
TensorFlow: How to Use static_rnn

tf.nn.static_rnn

Aliases:

tf.contrib.rnn.static_rnn

tf.nn.static_rnn

Creates a recurrent neural network from the given RNN cell, statically unrolled for as many time steps as there are elements in inputs.

tf.nn.static_rnn(
    cell,
    inputs,
    initial_state=None,
    dtype=None,
    sequence_length=None,
    scope=None
)
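Conceptually, static_rnn calls the cell once per element of inputs and passes the state from each step to the next, so the graph it builds is unrolled for len(inputs) time steps. A minimal NumPy sketch of that loop follows, assuming a BasicRNNCell-style cell that computes tanh([x, h] · W + b). The names basic_rnn_cell and static_rnn_sketch are hypothetical helpers, not TensorFlow APIs, and the sketch ignores sequence_length and scope.

```python
import numpy as np

# Hypothetical helpers for illustration only; they are not part of TensorFlow.


def basic_rnn_cell(x, h, W, b):
    """One BasicRNNCell-style step (assumed form): new_h = tanh([x, h] @ W + b).

    The returned output and new state are the same array.
    """
    new_h = np.tanh(np.concatenate([x, h], axis=1) @ W + b)
    return new_h, new_h  # (output, new_state)


def static_rnn_sketch(inputs, W, b, initial_state=None, dtype=np.float32):
    """Unroll the cell over a Python list of [batch_size, input_size] arrays."""
    batch_size = inputs[0].shape[0]
    num_units = b.shape[0]
    # Without an initial_state, start from zeros; this is why dtype matters then.
    if initial_state is None:
        state = np.zeros((batch_size, num_units), dtype=dtype)
    else:
        state = initial_state
    outputs = []
    for x_t in inputs:  # one cell call per time step
        output, state = basic_rnn_cell(x_t, state, W, b)
        outputs.append(output)
    return outputs, state  # list of length T, final state


rng = np.random.default_rng(0)
T, batch_size, input_size, num_units = 4, 2, 3, 5
inputs = [rng.standard_normal((batch_size, input_size)).astype(np.float32) for _ in range(T)]
W = rng.standard_normal((input_size + num_units, num_units)).astype(np.float32)
b = np.zeros(num_units, dtype=np.float32)

outputs, state = static_rnn_sketch(inputs, W, b)
print(len(outputs), outputs[0].shape, state.shape)  # 4 (2, 5) (2, 5)
```

The Python for-loop mirrors the "static" part of the name: the number of steps is fixed by the length of the inputs list when the graph is built.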

Parameter Description:

  1. cell: an instance of RNNCell, the unit applied at every time step, such as BasicRNNCell or BasicLSTMCell.
  2. inputs: a list of length T whose elements are input tensors of shape [batch_size, input_size], one per time step.
  3. initial_state: (optional) the initial state of the RNN. If cell.state_size is an integer, this must be a tensor of appropriate type and shape [batch_size, cell.state_size]. If cell.state_size is a tuple, this should be a tuple of tensors having shapes [batch_size, s] for s in cell.state_size.
  4. dtype: (optional) the data type of the initial state and expected output. Required if initial_state is not provided or if the RNN state has a heterogeneous dtype.
  5. sequence_length: (optional) the length of each sequence in inputs, given as an int32 or int64 vector of size [batch_size].
  6. scope: (optional) the VariableScope for the created subgraph; defaults to "rnn".
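The TensorFlow documentation describes the sequence_length behaviour per batch row b and time step t: once t >= sequence_length[b], the output is all zeros and the state from step sequence_length[b] - 1 is copied through; before that, the cell runs normally. The NumPy sketch below reproduces that rule under stated assumptions (a hypothetical tanh cell, a zero initial state, and a made-up helper name rnn_with_lengths). It illustrates the semantics only, not how TensorFlow implements them.

```python
import numpy as np


def rnn_with_lengths(inputs, sequence_length, W, b):
    """Hypothetical helper: unroll a tanh RNN cell while honouring per-example lengths.

    inputs: list of T arrays of shape [batch_size, input_size]
    sequence_length: int array of shape [batch_size]
    """
    batch_size = inputs[0].shape[0]
    num_units = b.shape[0]
    state = np.zeros((batch_size, num_units))  # assumed zero initial state
    outputs = []
    for t, x_t in enumerate(inputs):
        new_state = np.tanh(np.concatenate([x_t, state], axis=1) @ W + b)
        active = (t < sequence_length)[:, None]  # [batch_size, 1] mask
        # Finished rows emit zeros and keep their last valid state.
        outputs.append(np.where(active, new_state, 0.0))
        state = np.where(active, new_state, state)
    return outputs, state


rng = np.random.default_rng(1)
T, batch_size, input_size, num_units = 4, 2, 3, 5
inputs = [rng.standard_normal((batch_size, input_size)) for _ in range(T)]
W = rng.standard_normal((input_size + num_units, num_units))
b = np.zeros(num_units)
lengths = np.array([4, 2])  # the second example stops after 2 steps

outputs, state = rnn_with_lengths(inputs, lengths, W, b)
print(np.allclose(outputs[3][1], 0.0))       # True: padded step outputs zeros
print(np.allclose(state[1], outputs[1][1]))  # True: state copied from step 1
```

With lengths [4, 2], the first row runs all four steps while the second row's final state is the one it reached at step 1.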

Return value:

A pair (outputs, state), where:

outputs: a list of length T containing the output for each input, i.e. one output per time step.

state: the final state.
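For a cell whose state_size is a tuple, such as an LSTM cell, the returned state (like any initial_state you pass in) is a tuple with one [batch_size, s] tensor per entry s of state_size; for an LSTM cell these are the cell state c and the hidden state h. The NumPy sketch below shows those shapes with a simplified, hypothetical LSTM step loosely modeled on BasicLSTMCell. The gate order i, j, f, o and the forget bias of 1.0 are assumptions here, and lstm_step is not a TensorFlow function.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, state, W, b, forget_bias=1.0):
    """Hypothetical, simplified LSTM step; state is a (c, h) tuple."""
    c, h = state
    gates = np.concatenate([x, h], axis=1) @ W + b  # [batch_size, 4 * num_units]
    i, j, f, o = np.split(gates, 4, axis=1)          # assumed gate order
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, (new_c, new_h)  # output is h; state is the (c, h) tuple


rng = np.random.default_rng(2)
T, batch_size, input_size, num_units = 4, 2, 3, 5
inputs = [rng.standard_normal((batch_size, input_size)) for _ in range(T)]
W = rng.standard_normal((input_size + num_units, 4 * num_units))
b = np.zeros(4 * num_units)

# Tuple initial state: one [batch_size, num_units] array per entry of state_size.
state = (np.zeros((batch_size, num_units)), np.zeros((batch_size, num_units)))
outputs = []
for x_t in inputs:
    out, state = lstm_step(x_t, state, W, b)
    outputs.append(out)

c, h = state
print(len(outputs), c.shape, h.shape)  # 4 (2, 5) (2, 5)
print(np.array_equal(h, outputs[-1]))  # True: h is the last output
```

Unlike a basic RNN cell, the final state here is not just the last output: it also carries c.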

Code example (TensorFlow 1.x API; tf.contrib was removed in TensorFlow 2.x):

import tensorflow as tf

x = tf.Variable(tf.random_normal([2, 4, 3]))  # [batch_size, timesteps, embedding_dim]
x = tf.unstack(x, axis=1)  # split along the time axis: a list of 4 tensors, each [2, 3]

n_neurons = 5  # number of units in the cell (size of each output and of the state)

basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
# dtype is required because no initial_state is passed
output_seqs, states = tf.contrib.rnn.static_rnn(basic_cell, x, dtype=tf.float32)

print(len(output_seqs))  # 4: one output per time step
print(output_seqs[0])    # output at time step 0
print(output_seqs[1])    # output at time step 1
print(states)            # final state

The output is as follows:

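4
Tensor("rnn/basic_rnn_cell/Tanh:0", shape=(2, 5), dtype=float32)
Tensor("rnn/basic_rnn_cell/Tanh_1:0", shape=(2, 5), dtype=float32)
Tensor("rnn/basic_rnn_cell/Tanh_3:0", shape=(2, 5), dtype=float32)

The printed values are symbolic graph tensors, not numbers. The final state is Tanh_3, the same tensor produced at the last (fourth) time step, because BasicRNNCell returns its output as its new state.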
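output_seqs is a plain Python list of per-step tensors. When a later layer expects a single [batch_size, timesteps, n_neurons] tensor, the list can be stacked along axis 1 (tf.stack(output_seqs, axis=1) in TensorFlow), which undoes the tf.unstack(x, axis=1) step of the example. The sketch below shows the same round trip with NumPy arrays standing in for the TensorFlow tensors; the constant-filled arrays are placeholder data, not real RNN outputs.

```python
import numpy as np

T, batch_size, num_units = 4, 2, 5
# Placeholder stand-ins for the per-step outputs of static_rnn:
# a list of T arrays, each [batch_size, num_units], filled with the step index.
outputs = [np.full((batch_size, num_units), t, dtype=np.float32) for t in range(T)]

stacked = np.stack(outputs, axis=1)  # [batch_size, T, num_units]
print(stacked.shape)                 # (2, 4, 5)

# Slicing along axis 1 recovers the original list, as tf.unstack(..., axis=1) would.
restored = [stacked[:, t, :] for t in range(T)]
print(all(np.array_equal(a, b) for a, b in zip(outputs, restored)))  # True
```

Stacking on axis 1 keeps the batch dimension first, matching the [batch_size, timesteps, embedding_dim] layout of the example's input.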