Poplar and PopLibs
Lstm.hpp
Go to the documentation of this file.
1// Copyright (c) 2017 Graphcore Ltd. All rights reserved.
8#ifndef popnn_Lstm_hpp
9#define popnn_Lstm_hpp
10
11#include <poplar/Tensor.hpp>
12#include <poplin/MatMul.hpp>
13#include <popnn/LstmDef.hpp>
15#include <popnn/Rnn.hpp>
16
17namespace popnn {
18namespace lstm {
19
// Get the default order of the gates in a basic LSTM cell.
// The order is returned as a vector of BasicLstmCellUnit values; it is also
// the default initialiser of LstmParams::cellOrder.
25const std::vector<BasicLstmCellUnit> getDefaultBasicLstmCellOrder();
26
// Structure representing the parameters of the LSTM.
//
// NOTE(review): this listing omits several source lines of the struct. The
// index at the end of the page also lists the members `dataType`
// (Lstm.hpp:34), `activation` (Lstm.hpp:70) and `recurrentActivation`
// (Lstm.hpp:72), which are not shown here.
29struct LstmParams {
31
// The batch size.
37 std::size_t batchSize;
// The number of time steps in the sequence of the LSTM.
40 std::size_t timeSteps;
// The number of neurons before and after each layer of the LSTM.
46 std::vector<std::size_t> layerSizes;
// If true the LSTM function returns the entire sequence of outputs; otherwise
// it returns just the final one (the index description is truncated at this
// point -- confirm the exact wording). Defaults to true.
49 bool outputFullSequence = true;
// If set to false the LSTM skips the calculation of weighted inputs (the
// index description is truncated after this). Defaults to true.
52 bool doInputWeightCalc = true;
// If set to false the LSTM skips the calculation of the gradients of the
// inputs (index description truncated -- presumably "inputs"; confirm).
// Defaults to true.
55 bool calcInputGradients = true;
// If set to true the LSTM preserves the internal state at the last valid
// time step (index description truncated -- confirm). Defaults to false.
62 bool preserveFinalState = false;
// Order of the cell units; defaults to getDefaultBasicLstmCellOrder().
// The index notes that the weights and biases for all of the layers being
// processed are concatenated in the outermost dimension (description
// truncated) -- presumably in this order; confirm.
68 std::vector<BasicLstmCellUnit> cellOrder = getDefaultBasicLstmCellOrder();
73
// Fragment of a parameter list taking the time-step count and layer sizes;
// presumably a constructor -- the lines carrying its head and tail are
// omitted from this listing.
75 std::size_t timeSteps, std::vector<std::size_t> layerSizes,
78
// Fragment of a second parameter list taking `maxTimeSteps` plus a
// `timeSteps` tensor; presumably a constructor for per-batch variable time
// steps -- head and tail are omitted from this listing, confirm.
80 std::size_t maxTimeSteps, const poplar::Tensor &timeSteps,
81 std::vector<std::size_t> layerSizes,
84};
85
// Structure holding the state of an LSTM cell, or the gradients for the state
// (depending on the context).
90struct LstmState {
// Output component of the state.
91 poplar::Tensor output;
// Cell-state component of the state.
92 poplar::Tensor cellState;
93
// Returns the state as a single poplar::Tensor. How `output` and `cellState`
// are combined is not visible in this header listing.
94 poplar::Tensor getAsTensor() const;
95};
96
// Returns a list of (poplin::MatMulParams, poplar::OptionFlags) pairs derived
// from the given LSTM parameters and option flags.
// NOTE(review): judging by the name these describe the matrix multiplications
// to pre-plan for the LSTM -- confirm against the implementation.
// Both arguments are taken by value.
101std::vector<std::pair<poplin::MatMulParams, poplar::OptionFlags>>
102getMatMulPrePlanParameters(LstmParams params, poplar::OptionFlags opts);
103
// Returns a 64-bit count computed from the LSTM parameters; per the name,
// presumably the number of FLOPs of the forward pass -- confirm.
104uint64_t getBasicLstmCellFwdFlops(const LstmParams &params);
105
// Returns a 64-bit count computed from the LSTM parameters; per the name,
// presumably the number of FLOPs of the backward pass -- confirm.
106uint64_t getBasicLstmCellBwdFlops(const LstmParams &params);
107
// Returns a 64-bit count computed from the LSTM parameters; per the name,
// presumably the number of FLOPs of the weight-update pass -- confirm.
108uint64_t getBasicLstmCellWuFlops(const LstmParams &params);
109
// Declares createInput: returns a poplar::Tensor built in `graph` for the
// given LSTM parameters.
// NOTE(review): presumably this is the LSTM input tensor -- the header's own
// doc comment is omitted from this listing; confirm.
//
// graph          Graph the tensor is added to.
// params         The LSTM parameters.
// debugContext   Debug information (no default for this function).
// options        Option flags; defaults to empty.
// planningCache  Optional planning cache; defaults to nullptr.
160poplar::Tensor createInput(poplar::Graph &graph, const LstmParams &params,
161 const poplar::DebugContext &debugContext,
162 const poplar::OptionFlags &options = {},
163 poplin::PlanningCache *planningCache = nullptr);
164
181 const poplar::DebugContext &debugContext,
182 const poplar::OptionFlags &options = {},
183 poplin::PlanningCache *planningCache = nullptr);
184
201 const poplar::DebugContext &debugContext,
202 const poplar::OptionFlags &options = {},
203 poplin::PlanningCache *planningCache = nullptr);
204
// Declares createInitialState: returns an LstmState (output plus cell state)
// built in `graph` for the given LSTM parameters.
// NOTE(review): presumably the initial forward state of the LSTM; the
// header's own doc comment is omitted from this listing -- confirm.
//
// graph          Graph the state tensors are added to.
// params         The LSTM parameters.
// debugContext   Debug information (no default for this function).
// options        Option flags; defaults to empty.
// planningCache  Optional planning cache; defaults to nullptr.
219LstmState createInitialState(poplar::Graph &graph, const LstmParams &params,
220 const poplar::DebugContext &debugContext,
221 const poplar::OptionFlags &options = {},
222 poplin::PlanningCache *planningCache = nullptr);
223
// Initialise the forward state of an LSTM with zeros.
// NOTE: this listing omits source line 233. According to the index at the end
// of the page, the full signature also takes
// `poplar::program::Sequence &prog` between `initialState` and `debugContext`.
232void zeroInitialState(poplar::Graph &graph, const LstmState &initialState,
234 const poplar::DebugContext &debugContext = {});
235
// Body of a struct whose opening line (source line 240) is omitted from this
// listing. The index gives "Definition: Lstm.hpp:240" for the structure
// holding all the parameters of an LSTM cell, or the deltas for those
// parameters (depending on context); the declarations that follow use it
// under the name LstmWeights.
// Members: input weights, output weights and biases, each a poplar::Tensor.
241 poplar::Tensor inputWeights;
242 poplar::Tensor outputWeights;
243 poplar::Tensor biases;
244};
245
// Declares createWeightsKernel: returns a pair of tensors built in `graph`
// for the given LSTM parameters.
// NOTE(review): presumably the pair is (input weights, output weights),
// mirroring the inputWeights/outputWeights members of LstmWeights -- confirm.
//
// graph          Graph the tensors are added to.
// params         The LSTM parameters.
// debugContext   Debug information (no default for this function).
// options        Option flags; defaults to empty.
// planningCache  Optional planning cache; defaults to nullptr.
249std::pair<poplar::Tensor, poplar::Tensor>
250createWeightsKernel(poplar::Graph &graph, const LstmParams &params,
251 const poplar::DebugContext &debugContext,
252 const poplar::OptionFlags &options = {},
253 poplin::PlanningCache *planningCache = nullptr);
254
// Declares createWeightsBiases. The line carrying the return type (source
// line 257) is omitted from this listing.
// NOTE(review): presumably returns the biases tensor (cf. the `biases` member
// declared just above) -- confirm against the real header.
// Takes the same graph / params / debugContext / options / planningCache
// arguments as the neighbouring create* functions.
258createWeightsBiases(poplar::Graph &graph, const LstmParams &params,
259 const poplar::DebugContext &debugContext,
260 const poplar::OptionFlags &options = {},
261 poplin::PlanningCache *planningCache = nullptr);
262
// Declares createWeights: returns an LstmWeights (input weights, output
// weights and biases) built in `graph` for the given LSTM parameters.
// NOTE(review): presumably combines createWeightsKernel and
// createWeightsBiases -- not visible from this header; confirm.
//
// graph          Graph the tensors are added to.
// params         The LSTM parameters.
// debugContext   Debug information (no default for this function).
// options        Option flags; defaults to empty.
// planningCache  Optional planning cache; defaults to nullptr.
266LstmWeights createWeights(poplar::Graph &graph, const LstmParams &params,
267 const poplar::DebugContext &debugContext,
268 const poplar::OptionFlags &options = {},
269 poplin::PlanningCache *planningCache = nullptr);
270
// Calculate the result of applying an LSTM across a sequence.
// Returns a pair of tensors.
// NOTE(review): presumably (output, final cell state) -- confirm.
//
// NOTE: this listing omits source line 305. According to the index at the end
// of the page, the full signature also takes
// `poplar::program::Sequence &fwdProg` between `intermediates` and
// `debugContext`.
//
// graph          Graph to which the LSTM is added.
// params         The LSTM parameters.
// stateInit      Initial state (output and cell state) of the LSTM.
// in             Input tensor.
// weights        The LSTM weights structure.
// intermediates  Pointer through which intermediate results may be returned.
//                NOTE(review): presumably may be null for inference -- confirm.
// debugContext   Optional debug information.
// options        Option flags; defaults to empty.
// planningCache  Optional planning cache; defaults to nullptr.
301std::pair<poplar::Tensor, poplar::Tensor>
302lstmFwd(poplar::Graph &graph, const LstmParams &params,
303 const LstmState &stateInit, const poplar::Tensor &in,
304 const LstmWeights &weights, poplar::Tensor *intermediates,
306 const poplar::DebugContext &debugContext = {},
307 const poplar::OptionFlags &options = {},
308 poplin::PlanningCache *planningCache = nullptr);
309
// Run LSTM backward pass.
// Returns an LstmState; per the LstmState description this can hold the
// gradients for the state.
// NOTE(review): presumably the gradient of the initial state -- confirm.
//
// graph              Graph to which the backward pass is added.
// params             The LSTM parameters.
// prog               Program sequence the backward-pass programs are added to.
// fwdStateInit       Initial state used by the forward pass.
// fwdIntermediates   Intermediates tensor from the forward pass.
// weights            The LSTM weights structure.
// input              Input tensor of the forward pass.
// output             Output tensor of the forward pass.
// outputGrad         Gradient with respect to the output.
// lastStepStateGrad  Pointer to the state gradient at the last step.
//                    NOTE(review): presumably may be null -- confirm.
// inputGrad          Pointer through which the input gradient may be returned.
// bwdIntermediates   Pointer through which backward intermediates may be
//                    returned (consumed by a standalone weight update).
// debugContext       Optional debug information.
// options            Option flags; defaults to empty.
// planningCache      Optional planning cache; defaults to nullptr.
354LstmState
355lstmBwd(poplar::Graph &graph, const LstmParams &params,
356 poplar::program::Sequence &prog, const LstmState &fwdStateInit,
357 const poplar::Tensor &fwdIntermediates, const LstmWeights &weights,
358 const poplar::Tensor &input, const poplar::Tensor &output,
359 const poplar::Tensor &outputGrad, const LstmState *lastStepStateGrad,
360 poplar::Tensor *inputGrad, poplar::Tensor *bwdIntermediates,
361 const poplar::DebugContext &debugContext = {},
362 const poplar::OptionFlags &options = {},
363 poplin::PlanningCache *planningCache = nullptr);
364
// Parameter-list fragment: the lines carrying the return type, function name
// and leading parameters are omitted from this listing.
// NOTE(review): the visible parameters parallel lstmBwd above, except that the
// last-step gradient is passed as `const poplar::Tensor *lastCellStateGrad`
// instead of `const LstmState *lastStepStateGrad`. Presumably this is an
// overload of lstmBwd taking only the cell-state gradient -- confirm against
// the real header.
414 const LstmState &fwdStateInit,
415 const poplar::Tensor &fwdIntermediates,
416 const LstmWeights &weights, const poplar::Tensor &input,
417 const poplar::Tensor &output,
418 const poplar::Tensor &outputGrad,
419 const poplar::Tensor *lastCellStateGrad,
420 poplar::Tensor *inputGrad, poplar::Tensor *bwdIntermediates,
421 const poplar::DebugContext &debugContext = {},
422 const poplar::OptionFlags &options = {},
423 poplin::PlanningCache *planningCache = nullptr);
424
// Parameter-list fragment: the lines carrying the return type, function name
// and leading parameters are omitted from this listing.
// The visible parameters match the tail of lstmWU as given in the index at
// the end of the page ("Run a standalone weight update pass"), whose full
// signature starts `LstmWeights lstmWU(poplar::Graph &graph,
// const LstmParams &params, poplar::program::Sequence &prog, ...`.
// It consumes both the forward and the backward intermediates.
455 const LstmState &fwdStateInit,
456 const poplar::Tensor &fwdIntermediates,
457 const poplar::Tensor &bwdIntermediates,
458 const LstmWeights &weights, const poplar::Tensor &input,
459 const poplar::Tensor &output,
460 const poplar::DebugContext &debugContext = {},
461 const poplar::OptionFlags &options = {},
462 poplin::PlanningCache *planningCache = nullptr);
463
// Parameter-list fragment: the lines carrying the return type, function name
// and leading parameters are omitted from this listing.
// The visible parameters match the tail of lstmBwdWithWU as given in the
// index at the end of the page ("Run a combined LSTM backward and weight
// update pass"), whose full signature starts `LstmState lstmBwdWithWU(
// poplar::Graph &graph, const LstmParams &params,
// poplar::program::Sequence &prog, ...`.
// `weightsGrad` is a non-const reference, i.e. an output of the call.
507 const LstmState &fwdStateInit,
508 const poplar::Tensor &fwdIntermediates,
509 const LstmWeights &weights, const poplar::Tensor &input,
510 const poplar::Tensor &output,
511 const poplar::Tensor &outputGrad,
512 const LstmState *lastStepStateGrad,
513 poplar::Tensor *inputGrad, LstmWeights &weightsGrad,
514 const poplar::DebugContext &debugContext = {},
515 const poplar::OptionFlags &options = {},
516 poplin::PlanningCache *planningCache = nullptr);
517
// Parameter-list fragment: the lines carrying the return type, function name
// and leading parameters are omitted from this listing.
// NOTE(review): the visible parameters parallel lstmBwdWithWU, except that
// the last-step gradient is passed as
// `const poplar::Tensor *lastCellStateGrad` instead of
// `const LstmState *lastStepStateGrad`. Presumably this is an overload of
// lstmBwdWithWU taking only the cell-state gradient -- confirm against the
// real header.
564 const LstmState &fwdStateInit,
565 const poplar::Tensor &fwdIntermediates,
566 const LstmWeights &weights, const poplar::Tensor &input,
567 const poplar::Tensor &output,
568 const poplar::Tensor &outputGrad,
569 const poplar::Tensor *lastCellStateGrad,
570 poplar::Tensor *inputGrad, LstmWeights &weightsGrad,
571 const poplar::DebugContext &debugContext = {},
572 const poplar::OptionFlags &options = {},
573 poplin::PlanningCache *planningCache = nullptr);
574
575} // namespace lstm
576} // namespace popnn
577
578#endif // popnn_Lstm_hpp
Definitions for LSTM cell operations.
LstmState lstmBwdWithWU(poplar::Graph &graph, const LstmParams &params, poplar::program::Sequence &prog, const LstmState &fwdStateInit, const poplar::Tensor &fwdIntermediates, const LstmWeights &weights, const poplar::Tensor &input, const poplar::Tensor &output, const poplar::Tensor &outputGrad, const LstmState *lastStepStateGrad, poplar::Tensor *inputGrad, LstmWeights &weightsGrad, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Run a combined LSTM backward and weight update pass.
poplar::Tensor createInitialOutput(poplar::Graph &graph, const LstmParams &params, const poplar::DebugContext &debugContext, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Create the initial output that can be combined with the initial cell state using an LstmState.
void zeroInitialState(poplar::Graph &graph, const LstmState &initialState, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={})
Initialise the forward state of an LSTM with zeros.
LstmState lstmBwd(poplar::Graph &graph, const LstmParams &params, poplar::program::Sequence &prog, const LstmState &fwdStateInit, const poplar::Tensor &fwdIntermediates, const LstmWeights &weights, const poplar::Tensor &input, const poplar::Tensor &output, const poplar::Tensor &outputGrad, const LstmState *lastStepStateGrad, poplar::Tensor *inputGrad, poplar::Tensor *bwdIntermediates, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Run LSTM backward pass.
poplar::Tensor createInitialCellState(poplar::Graph &graph, const LstmParams &params, const poplar::DebugContext &debugContext, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Create the initial cell state that can be combined with the initial output using an LstmState.
std::pair< poplar::Tensor, poplar::Tensor > lstmFwd(poplar::Graph &graph, const LstmParams &params, const LstmState &stateInit, const poplar::Tensor &in, const LstmWeights &weights, poplar::Tensor *intermediates, poplar::program::Sequence &fwdProg, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Calculate the result of applying an LSTM across a sequence.
LstmWeights lstmWU(poplar::Graph &graph, const LstmParams &params, poplar::program::Sequence &prog, const LstmState &fwdStateInit, const poplar::Tensor &fwdIntermediates, const poplar::Tensor &bwdIntermediates, const LstmWeights &weights, const poplar::Tensor &input, const poplar::Tensor &output, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Run a standalone weight update pass.
const std::vector< BasicLstmCellUnit > getDefaultBasicLstmCellOrder()
Get the default order of the gates in a basic LSTM cell.
Definitions for non-linearity operations.
Functions for recurrent neural networks (RNN).
DebugContext gathers the common external parameters of the context of an operation.
Definition: DebugContext.hpp:221
This class represents a graph program to be executed on the IPU.
Definition: Graph.hpp:52
A set of option/value string flags to be used in various APIs.
Definition: OptionFlags.hpp:24
A reference to a subset of tensor elements.
Definition: Tensor.hpp:38
Class representing device data types.
Definition: Type.hpp:42
Program that executes a sequence of programs.
Definition: Program.hpp:77
Functions used in neural networks.
Definition: BatchNorm.hpp:14
NonLinearityType
Definition: NonLinearityDef.hpp:11
@ TANH
Hyperbolic tangent: $\tanh(x)$
Functions and data types for performing matrix multiplies on the IPU.
Structure representing the parameters of the LSTM.
Definition: Lstm.hpp:29
std::vector< BasicLstmCellUnit > cellOrder
The weights and biases for all of the layers being processed are concatenated in the outermost dimens...
Definition: Lstm.hpp:68
bool preserveFinalState
If this parameter is set to true then the LSTM will preserve the internal state at the last valid tim...
Definition: Lstm.hpp:62
poplar::Type dataType
The datatype of the LSTM.
Definition: Lstm.hpp:34
bool outputFullSequence
If true the Lstm function returns the entire sequence of outputs, otherwise it returns just the final...
Definition: Lstm.hpp:49
NonLinearityType recurrentActivation
Recurrent activation function.
Definition: Lstm.hpp:72
std::vector< std::size_t > layerSizes
The number of neurons before and after each layer of the LSTM.
Definition: Lstm.hpp:46
NonLinearityType activation
Activation function.
Definition: Lstm.hpp:70
std::size_t timeSteps
The number of time steps in the sequence of the LSTM.
Definition: Lstm.hpp:40
bool calcInputGradients
If this parameter is set to false then the LSTM will skip the calculation of the gradients of the inp...
Definition: Lstm.hpp:55
std::size_t batchSize
The batch size.
Definition: Lstm.hpp:37
bool doInputWeightCalc
If this parameter is set to false then the LSTM will skip the calculation of weighted inputs (only us...
Definition: Lstm.hpp:52
Structure holding the state of an LSTM cell, or the gradients for the state (depending on the context)...
Definition: Lstm.hpp:90
Structure holding all the parameters of an LSTM cell, or the deltas for those parameters (depending o...
Definition: Lstm.hpp:240
Structure of Recurrent Neural Network (RNN) parameters which allows for any customized implementation...
Definition: Rnn.hpp:22