11#include <poplar/Tensor.hpp>
101std::vector<std::pair<poplin::MatMulParams, poplar::OptionFlags>>
104uint64_t getBasicLstmCellFwdFlops(
const LstmParams &params);
106uint64_t getBasicLstmCellBwdFlops(
const LstmParams &params);
108uint64_t getBasicLstmCellWuFlops(
const LstmParams &params);
163 poplin::PlanningCache *planningCache =
nullptr);
183 poplin::PlanningCache *planningCache =
nullptr);
203 poplin::PlanningCache *planningCache =
nullptr);
222 poplin::PlanningCache *planningCache =
nullptr);
249std::pair<poplar::Tensor, poplar::Tensor>
253 poplin::PlanningCache *planningCache =
nullptr);
261 poplin::PlanningCache *planningCache =
nullptr);
269 poplin::PlanningCache *planningCache =
nullptr);
301std::pair<poplar::Tensor, poplar::Tensor>
308 poplin::PlanningCache *planningCache =
nullptr);
363 poplin::PlanningCache *planningCache =
nullptr);
423 poplin::PlanningCache *planningCache =
nullptr);
462 poplin::PlanningCache *planningCache =
nullptr);
516 poplin::PlanningCache *planningCache =
nullptr);
573 poplin::PlanningCache *planningCache =
nullptr);
Definitions for LSTM cell operations.
LstmState lstmBwdWithWU(poplar::Graph &graph, const LstmParams &params, poplar::program::Sequence &prog, const LstmState &fwdStateInit, const poplar::Tensor &fwdIntermediates, const LstmWeights &weights, const poplar::Tensor &input, const poplar::Tensor &output, const poplar::Tensor &outputGrad, const LstmState *lastStepStateGrad, poplar::Tensor *inputGrad, LstmWeights &weightsGrad, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Run a combined LSTM backward and weight update pass.
poplar::Tensor createInitialOutput(poplar::Graph &graph, const LstmParams &params, const poplar::DebugContext &debugContext, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Create the initial output that can be combined with the initial cell state using an LstmState.
void zeroInitialState(poplar::Graph &graph, const LstmState &initialState, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={})
Initialise the forward state of an LSTM with zeros.
LstmState lstmBwd(poplar::Graph &graph, const LstmParams &params, poplar::program::Sequence &prog, const LstmState &fwdStateInit, const poplar::Tensor &fwdIntermediates, const LstmWeights &weights, const poplar::Tensor &input, const poplar::Tensor &output, const poplar::Tensor &outputGrad, const LstmState *lastStepStateGrad, poplar::Tensor *inputGrad, poplar::Tensor *bwdIntermediates, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Run LSTM backward pass.
poplar::Tensor createInitialCellState(poplar::Graph &graph, const LstmParams &params, const poplar::DebugContext &debugContext, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Create the initial cell state that can be combined with the initial output using an LstmState.
std::pair< poplar::Tensor, poplar::Tensor > lstmFwd(poplar::Graph &graph, const LstmParams &params, const LstmState &stateInit, const poplar::Tensor &in, const LstmWeights &weights, poplar::Tensor *intermediates, poplar::program::Sequence &fwdProg, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Calculate the result of applying an LSTM across a sequence.
LstmWeights lstmWU(poplar::Graph &graph, const LstmParams &params, poplar::program::Sequence &prog, const LstmState &fwdStateInit, const poplar::Tensor &fwdIntermediates, const poplar::Tensor &bwdIntermediates, const LstmWeights &weights, const poplar::Tensor &input, const poplar::Tensor &output, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Run a standalone weight update pass.
const std::vector< BasicLstmCellUnit > getDefaultBasicLstmCellOrder()
Get the default order of the gates in a basic LSTM cell.
Definitions for non-linearity operations.
Functions for recurrent neural networks (RNN).
DebugContext gathers the common external parameters of the context of an operation.
Definition: DebugContext.hpp:221
This class represents a graph program to be executed on the IPU.
Definition: Graph.hpp:52
A set of option/value string flags to be used in various APIs.
Definition: OptionFlags.hpp:24
A reference to a subset of tensor elements.
Definition: Tensor.hpp:38
Class representing device data types.
Definition: Type.hpp:42
Program that executes a sequence of programs.
Definition: Program.hpp:77
Functions used in neural networks.
Definition: BatchNorm.hpp:14
NonLinearityType
Definition: NonLinearityDef.hpp:11
@ TANH
Hyperbolic tangent:
Functions and data types for performing matrix multiplies on the IPU.
Structure representing the parameters of the LSTM.
Definition: Lstm.hpp:29
std::vector< BasicLstmCellUnit > cellOrder
The weights and biases for all of the layers being processed are concatenated in the outermost dimens...
Definition: Lstm.hpp:68
bool preserveFinalState
If this parameter is set to true then the LSTM will preserve the internal state at the last valid tim...
Definition: Lstm.hpp:62
poplar::Type dataType
The datatype of the LSTM.
Definition: Lstm.hpp:34
bool outputFullSequence
If true the Lstm function returns the entire sequence of outputs, otherwise it returns just the final...
Definition: Lstm.hpp:49
NonLinearityType recurrentActivation
Recurrent activation function.
Definition: Lstm.hpp:72
std::vector< std::size_t > layerSizes
The number of neurons before and after each layer of the LSTM.
Definition: Lstm.hpp:46
NonLinearityType activation
Activation function.
Definition: Lstm.hpp:70
std::size_t timeSteps
The number of time steps in the sequence of the LSTM.
Definition: Lstm.hpp:40
bool calcInputGradients
If this parameter is set to false then the LSTM will skip the calculation of the gradients of the inp...
Definition: Lstm.hpp:55
std::size_t batchSize
The batch size.
Definition: Lstm.hpp:37
bool doInputWeightCalc
If this parameter is set to false then the LSTM will skip the calculation of weighted inputs (only us...
Definition: Lstm.hpp:52
Structure holding the state of a LSTM cell, or the gradients for the state (depending on the context)...
Definition: Lstm.hpp:90
Structure holding all the parameters of an LSTM cell, or the deltas for those parameters (depending o...
Definition: Lstm.hpp:240
Structure of Recurrent Neural Network (RNN) parameters which allows for any customized implementation...
Definition: Rnn.hpp:22