tensorflow source analysis: BasicLSTMCell
Published: 2019-06-18

BasicLSTMCell is the simplest LSTM cell. Its source is in /tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py. BasicLSTMCell inherits from RNNCell, whose source is in /tensorflow/python/ops/rnn_cell_impl.py.
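Notes:
1. The input_size parameter is deprecated and unused. The size of the cell is set by num_units, and the input depth is read from the inputs when the cell is called.
2. The official recommendation is state_is_tuple=True. In that case the state passed in and returned is a 2-tuple of c (the cell state) and h (the output).
3. The output h and the cell state c both have shape batch_size * num_units. The input has shape batch_size * input_depth; when input_depth equals num_units, input, output, and cell state all share the shape batch_size * num_units.

```python
import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.x

# Set num_units.
cell = tf.contrib.rnn.BasicLSTMCell(num_units, forget_bias=0.0, state_is_tuple=True)
# Set batch_size; c and h both start as zeros of shape batch_size * num_units.
_initial_state = cell.zero_state(batch_size, tf.float32)
```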
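To make notes 2 and 3 concrete, here is a plain-Python sketch (no TensorFlow) of the two state layouts. `LSTMStateTupleSketch`, `zero_state_sketch`, and `split_concatenated` are hypothetical helpers that only mimic the shapes involved, assuming nested lists stand in for tensors: with `state_is_tuple=True` the state is a (c, h) pair of batch_size * num_units matrices; with `state_is_tuple=False` it is one batch_size * (2 * num_units) matrix that has to be split along axis 1.

```python
from collections import namedtuple

# Hypothetical stand-in for LSTMStateTuple: a named (c, h) pair.
LSTMStateTupleSketch = namedtuple("LSTMStateTupleSketch", ["c", "h"])


def zeros(rows, cols):
    """A rows x cols matrix of zeros as nested lists."""
    return [[0.0] * cols for _ in range(rows)]


def zero_state_sketch(batch_size, num_units, state_is_tuple=True):
    """Mimic the shapes cell.zero_state produces for a BasicLSTMCell."""
    if state_is_tuple:
        # c and h are separate batch_size x num_units matrices.
        return LSTMStateTupleSketch(zeros(batch_size, num_units),
                                    zeros(batch_size, num_units))
    # One batch_size x (2 * num_units) matrix laid out as [c | h].
    return zeros(batch_size, 2 * num_units)


def split_concatenated(state, num_units):
    """Recover (c, h) from the concatenated layout by splitting each row in two."""
    c = [row[:num_units] for row in state]
    h = [row[num_units:] for row in state]
    return c, h


batch_size, num_units = 2, 3
tuple_state = zero_state_sketch(batch_size, num_units)
flat_state = zero_state_sketch(batch_size, num_units, state_is_tuple=False)
c, h = split_concatenated(flat_state, num_units)
print(len(tuple_state.c), len(tuple_state.c[0]))  # 2 3
print(len(flat_state), len(flat_state[0]))        # 2 6
```

With the tuple layout the cell can unpack `c, h = state` directly, while the concatenated layout needs a split on input and a concat on output at every step.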
4.
```python
# Excerpt from core_rnn_cell_impl.py; module-level imports omitted.
class BasicLSTMCell(RNNCell):
  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For advanced models, please use the full LSTMCell that follows.
  """

  def __init__(self, num_units, forget_bias=1.0, input_size=None,
               state_is_tuple=True, activation=tanh):
    """Initialize the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      input_size: Deprecated and unused.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  If False, they are concatenated
        along the column axis.  The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.
    """
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation

  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell (LSTM)."""
    with vs.variable_scope(scope or "basic_lstm_cell"):
      # Parameters of gates are concatenated into one multiply for efficiency.
      if self._state_is_tuple:
        c, h = state
      else:
        c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
      # Linear step: concat = [inputs, h] W + b.
      # _linear allocates W with shape (input_depth + num_units, 4 * num_units)
      # and b with shape (4 * num_units,): four parameter blocks, one per gate.
      # When input_depth == num_units, W is (2 * num_units, 4 * num_units).
      # concat has shape (batch_size, 4 * num_units).
      # _linear concatenates inputs and h along axis 1 before the matmul,
      # so a single W works even when input_depth != num_units.
      concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

      new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
               self._activation(j))
      new_h = self._activation(new_c) * sigmoid(o)

      if self._state_is_tuple:
        new_state = LSTMStateTuple(new_c, new_h)
      else:
        new_state = array_ops.concat([new_c, new_h], 1)
      return new_h, new_state
```
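The arithmetic in `__call__` can be traced without TensorFlow. The sketch below is a plain-Python re-implementation for illustration only, assuming nested lists stand in for tensors and tanh is the activation; `basic_lstm_step`, `matmul`, and `num_params` are hypothetical helper names. It concatenates `[inputs, h]`, multiplies by a single W of shape (input_depth + num_units, 4 * num_units), adds b, splits the result into i, j, f, o, and applies new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * tanh(j) and new_h = tanh(new_c) * sigmoid(o).

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def matmul(a, b):
    """Multiply an n x k matrix by a k x m matrix (nested lists)."""
    return [[sum(a[r][t] * b[t][col] for t in range(len(b)))
             for col in range(len(b[0]))]
            for r in range(len(a))]


def basic_lstm_step(x, c, h, W, b, forget_bias=1.0):
    """One BasicLSTMCell-style step; returns (new_h, (new_c, new_h)).

    x: batch x input_depth; c, h: batch x num_units;
    W: (input_depth + num_units) x (4 * num_units); b: list of 4 * num_units.
    """
    n = len(c[0])  # num_units
    xh = [xr + hr for xr, hr in zip(x, h)]  # concat([inputs, h], axis=1)
    z = [[v + bv for v, bv in zip(row, b)] for row in matmul(xh, W)]  # xh W + b
    new_c, new_h = [], []
    for r, row in enumerate(z):
        # Split the 4 * num_units columns into the i, j, f, o blocks.
        i, j, f, o = (row[k * n:(k + 1) * n] for k in range(4))
        c_row = [c[r][u] * sigmoid(f[u] + forget_bias)
                 + sigmoid(i[u]) * math.tanh(j[u]) for u in range(n)]
        h_row = [math.tanh(c_row[u]) * sigmoid(o[u]) for u in range(n)]
        new_c.append(c_row)
        new_h.append(h_row)
    return new_h, (new_c, new_h)


def num_params(input_depth, num_units):
    """Entries in W plus entries in b."""
    return (input_depth + num_units) * 4 * num_units + 4 * num_units


# input_depth (2) differs from num_units (3): a single W still works.
random.seed(0)
input_depth, num_units, batch_size = 2, 3, 1
W = [[random.uniform(-0.1, 0.1) for _ in range(4 * num_units)]
     for _ in range(input_depth + num_units)]
b = [0.0] * (4 * num_units)
x = [[1.0, -1.0]]
c0 = [[0.0] * num_units for _ in range(batch_size)]
h0 = [[0.0] * num_units for _ in range(batch_size)]
out, (c1, h1) = basic_lstm_step(x, c0, h0, W, b)
print(len(out[0]), num_params(input_depth, num_units))  # 3 72
```

The forget_bias term is easy to see here: when f is near zero, as it tends to be early in training, the default forget_bias=1.0 moves the forget gate from sigmoid(0) = 0.5 to sigmoid(1), about 0.73, so the cell keeps more of c. The setup in the notes passes forget_bias=0.0 instead.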

5. The LSTM layer: computation for each batch

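```python
outputs = []            # per-step outputs (initialization assumed from context)
state = _initial_state  # zero state from cell.zero_state (assumed from context)
with tf.variable_scope("RNN"):
    for time_step in range(num_steps):
        # Step 0 creates the cell's W and b; later steps reuse them.
        if time_step > 0: tf.get_variable_scope().reuse_variables()
        # inputs: batch_size x num_steps x input_depth
        (cell_output, state) = cell(inputs[:, time_step, :], state)
        outputs.append(cell_output)
```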
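The `reuse_variables()` call matters because, in TF1, `tf.get_variable` refuses to create a variable that already exists unless reuse is on, and refuses to fetch one that does not exist when reuse is on. Time step 0 therefore creates the cell's W and b, and every later step must reuse them, which is how one set of weights is shared across all time steps. The sketch below models those two rules with a hypothetical `VariableStoreSketch` class (not TensorFlow's API) and unrolls a toy one-unit tanh RNN cell, deliberately not an LSTM, so the focus stays on reuse; the variable name and weight values are made up.

```python
import math


class VariableStoreSketch:
    """Hypothetical, simplified model of TF1 get_variable reuse rules."""

    def __init__(self):
        self.variables = {}
        self.reuse = False

    def reuse_variables(self):
        # Like tf.get_variable_scope().reuse_variables(): reuse stays on from now on.
        self.reuse = True

    def get_variable(self, name, initializer):
        if self.reuse:
            if name not in self.variables:
                raise ValueError(f"{name} does not exist; cannot reuse it")
            return self.variables[name]
        if name in self.variables:
            raise ValueError(f"{name} already exists; did you mean to reuse it?")
        self.variables[name] = initializer()
        return self.variables[name]


def toy_cell(store, x, h):
    """One-unit tanh RNN cell whose weights come from the store."""
    w = store.get_variable("RNN/weights", lambda: [0.5, 0.8])
    out = math.tanh(w[0] * x + w[1] * h)
    return out, out  # (cell_output, new_state)


def unroll(inputs, use_reuse=True):
    store = VariableStoreSketch()
    state, outputs = 0.0, []
    for time_step in range(len(inputs)):  # len(inputs) plays the role of num_steps
        if time_step > 0 and use_reuse:
            store.reuse_variables()
        cell_output, state = toy_cell(store, inputs[time_step], state)
        outputs.append(cell_output)
    return outputs, store


outputs, store = unroll([1.0, 0.0, -1.0, 0.5])
print(len(outputs), len(store.variables))  # 4 1
try:
    unroll([1.0, 0.0], use_reuse=False)
except ValueError as err:
    print("without reuse:", err)
```

Four time steps produce four outputs but only one stored variable; dropping the reuse call makes step 1 fail because step 0 already created the weights.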

6. Each epoch

7. The full run

Reposted from: https://www.cnblogs.com/yuetz/p/6563377.html
