Poplar and PopLibs
Gru.hpp
Go to the documentation of this file.
1// Copyright (c) 2019 Graphcore Ltd. All rights reserved.
8#ifndef popnn_Gru_hpp
9#define popnn_Gru_hpp
10
11#include <poplar/Tensor.hpp>
12#include <poplin/MatMul.hpp>
13#include <popnn/GruDef.hpp>
15#include <popnn/Rnn.hpp>
16
17namespace popnn {
18namespace gru {
19
/// Get the default order of the gates in a basic GRU cell
/// (the order in which BasicGruCellUnit entries appear in cellOrder).
25const std::vector<BasicGruCellUnit> getDefaultBasicGruCellOrder();
26
// Structure representing the parameters of the GRU.
// NOTE(review): this Doxygen listing omits several lines of the original
// header (e.g. the `dataType`, `activation` and `recurrentActivation`
// members and the tails of the constructor signatures) -- confirm against
// the full Gru.hpp before relying on the exact shape of this struct.
29struct GruParams {
31
32 // The datatype of the GRU.
 // The batch size.
37 std::size_t batchSize;
 // The number of time steps in the sequence of the GRU.
40 std::size_t timeSteps;
 // The number of neurons for the input and output layer.
43 std::vector<std::size_t> layerSizes;
 // If true, the GRU function returns the entire sequence of outputs;
 // otherwise it returns just the final output.
46 bool outputFullSequence = true;
 // If false, the GRU skips the calculation of the gradients of the inputs.
49 bool calcInputGradients = true;
 // Order in which the per-gate weights and biases are concatenated in the
 // outermost dimension; defaults to getDefaultBasicGruCellOrder().
55 std::vector<BasicGruCellUnit> cellOrder = getDefaultBasicGruCellOrder();
 // Controls whether the reset gate is applied before or after the
 // candidate weights and biases.
58 bool resetAfter = false;
63
 // Constructor for a fixed number of time steps.
 // NOTE(review): the closing part of this signature is omitted from this
 // listing -- confirm defaulted trailing parameters in the full header.
64 GruParams(poplar::Type dataType, std::size_t batchSize, std::size_t timeSteps,
65 std::vector<std::size_t> layerSizes,
68
 // Constructor taking a tensor of run-time time steps bounded by
 // `maxTimeSteps` (opening line of this signature omitted from listing).
70 std::size_t maxTimeSteps, const poplar::Tensor &timeSteps,
71 std::vector<std::size_t> layerSizes,
74
 // Copy constructor.
75 GruParams(const GruParams &other);
76};
77
// Return the floating-point operation count for the forward pass of a
// basic GRU cell with the given parameters.
// NOTE(review): semantics inferred from the name -- the Doxygen comment is
// omitted from this listing; confirm against the full header.
78uint64_t getBasicGruCellFwdFlops(const GruParams &params);
79
// FLOP count for the backward (gradient) pass of a basic GRU cell
// (presumably; see note above the Fwd variant -- confirm in full header).
80uint64_t getBasicGruCellBwdFlops(const GruParams &params);
81
// FLOP count for the weight-update pass of a basic GRU cell
// (presumably; see note above the Fwd variant -- confirm in full header).
82uint64_t getBasicGruCellWuFlops(const GruParams &params);
83
// Create and map a tensor to be used as the input to the GRU described by
// `params`.
// NOTE(review): detailed layout/mapping behavior is documented in the full
// header's Doxygen comment, which is omitted from this listing -- confirm.
// \param graph         Graph the tensor is added to.
// \param params        GRU parameters (batch size, time steps, layer sizes).
// \param debugContext  Debug information attached to the operation.
// \param options       Implementation option flags (defaulted to empty).
// \param planningCache Optional matmul planning cache (may be nullptr).
112poplar::Tensor createInput(poplar::Graph &graph, const GruParams &params,
113 const poplar::DebugContext &debugContext,
114 const poplar::OptionFlags &options = {},
115 poplin::PlanningCache *planningCache = nullptr);
116
// Create the initial state tensor for the GRU described by `params`.
// NOTE(review): unlike createInput, `options` and `cache` carry no default
// arguments here -- confirm this asymmetry is intentional in the full
// header rather than an artifact of this listing.
117poplar::Tensor createInitialState(poplar::Graph &graph, const GruParams &params,
118 const poplar::DebugContext &debugContext,
119 const poplar::OptionFlags &options,
120 poplin::PlanningCache *cache);
 // Structure holding all the parameters of a GRU cell, or the deltas for
 // those parameters. NOTE(review): the `struct GruWeights {` opening line
 // is omitted from this Doxygen listing; member roles below are inferred
 // from their names -- confirm against the full header.
 // Weights applied to the cell input.
126 poplar::Tensor inputWeights;
 // Recurrent weights applied to the previous output/state.
127 poplar::Tensor outputWeights;
 // Bias terms for the gates.
128 poplar::Tensor biases;
129};
130
134std::pair<poplar::Tensor, poplar::Tensor>
136 const poplar::DebugContext &debugContext,
137 const poplar::OptionFlags &options = {},
138 poplin::PlanningCache *planningCache = nullptr);
139
144 const poplar::DebugContext &debugContext,
145 const poplar::OptionFlags &options = {},
146 poplin::PlanningCache *planningCache = nullptr);
147
// Create the full set of weights for the GRU (kernel weights and biases)
// returned as a GruWeights structure.
// \param graph         Graph the weight tensors are added to.
// \param params        GRU parameters describing the layer.
// \param debugContext  Debug information attached to the operation.
// \param options       Implementation option flags (defaulted to empty).
// \param planningCache Optional matmul planning cache (may be nullptr).
151GruWeights createWeights(poplar::Graph &graph, const GruParams &params,
152 const poplar::DebugContext &debugContext,
153 const poplar::OptionFlags &options = {},
154 poplin::PlanningCache *planningCache = nullptr);
155
159 const poplar::DebugContext &debugContext,
160 const poplar::OptionFlags &options = {});
161
208 const poplar::Tensor &stateInit, const poplar::Tensor &in,
209 const GruWeights &weights, poplar::Tensor *intermediates,
211 const poplar::DebugContext &debugContext = {},
212 const poplar::OptionFlags &options = {},
213 poplin::PlanningCache *planningCache = nullptr);
214
265 const poplar::Tensor &stateInit, const poplar::Tensor &in,
266 const poplar::Tensor &realTimeSteps,
267 const GruWeights &weights, poplar::Tensor *intermediates,
269 const poplar::DebugContext &debugContext = {},
270 const poplar::OptionFlags &options = {},
271 poplin::PlanningCache *planningCache = nullptr);
272
320 const poplar::Tensor &stateInit,
321 const poplar::Tensor &in, const GruWeights &weights,
322 poplar::Tensor *intermediates,
323 const poplar::Tensor &attScores,
325 const poplar::DebugContext &debugContext = {},
326 const poplar::OptionFlags &options = {},
327 poplin::PlanningCache *planningCache = nullptr);
328
380auGruFwd(poplar::Graph &graph, const GruParams &params,
381 const poplar::Tensor &stateInit, const poplar::Tensor &in,
382 const poplar::Tensor &realTimeSteps, const GruWeights &weights,
383 poplar::Tensor *intermediates, const poplar::Tensor &attScores,
385 const poplar::DebugContext &debugContext = {},
386 const poplar::OptionFlags &options = {},
387 poplin::PlanningCache *planningCache = nullptr);
388
428 poplar::Graph &graph, const GruParams &params,
429 poplar::program::Sequence &prog, const poplar::Tensor &fwdOutputInit,
430 const poplar::Tensor &fwdIntermediatesSeq, const GruWeights &weights,
431 const poplar::Tensor &fwdInputSeq, const poplar::Tensor &fwdOutput,
432 const poplar::Tensor &gradLayerNext, poplar::Tensor *inputGrad,
433 poplar::Tensor *bwdIntermediates, const poplar::DebugContext &debugContext,
434 const poplar::OptionFlags &options_, poplin::PlanningCache *planningCache);
435
480gruBwd(poplar::Graph &graph, const GruParams &params,
481 poplar::program::Sequence &prog, const poplar::Tensor &fwdOutputInit,
482 const poplar::Tensor &fwdIntermediatesSeq, const GruWeights &weights,
483 const poplar::Tensor &fwdInputSeq, const poplar::Tensor &realTimeSteps,
484 const poplar::Tensor &fwdOutput, const poplar::Tensor &gradLayerNext,
485 poplar::Tensor *inputGrad, poplar::Tensor *bwdIntermediates,
486 const poplar::DebugContext &debugContext,
487 const poplar::OptionFlags &options_,
488 poplin::PlanningCache *planningCache);
489
531 poplar::Graph &graph, const GruParams &params,
532 poplar::program::Sequence &prog, const poplar::Tensor &fwdOutputInit,
533 const poplar::Tensor &fwdIntermediatesSeq, const GruWeights &weights,
534 const poplar::Tensor &fwdInputSeq, const poplar::Tensor &fwdOutput,
535 const poplar::Tensor &gradLayerNext, poplar::Tensor *inputGrad,
536 poplar::Tensor *bwdIntermediates, const poplar::Tensor &attentions,
537 poplar::Tensor *attentionsGrad, const poplar::DebugContext &debugContext,
538 const poplar::OptionFlags &options_, poplin::PlanningCache *planningCache);
539
586auGruBwd(poplar::Graph &graph, const GruParams &params,
587 poplar::program::Sequence &prog, const poplar::Tensor &fwdOutputInit,
588 const poplar::Tensor &fwdIntermediatesSeq, const GruWeights &weights,
589 const poplar::Tensor &fwdInputSeq, const poplar::Tensor &realTimeSteps,
590 const poplar::Tensor &fwdOutput, const poplar::Tensor &gradLayerNext,
591 poplar::Tensor *inputGrad, poplar::Tensor *bwdIntermediates,
592 const poplar::Tensor &attentions, poplar::Tensor *attentionsGrad,
593 const poplar::DebugContext &debugContext,
594 const poplar::OptionFlags &options_,
595 poplin::PlanningCache *planningCache);
596
627 const poplar::Tensor &fwdOutputInit,
628 const poplar::Tensor &fwdIntermediates,
629 const poplar::Tensor &bwdIntermediates,
630 const GruWeights &weights, const poplar::Tensor &input,
631 const poplar::Tensor &output,
632 const poplar::DebugContext &debugContext,
633 const poplar::OptionFlags &options_,
634 poplin::PlanningCache *planningCache);
635
666 const poplar::Tensor &fwdOutputInit,
667 const poplar::Tensor &fwdIntermediates,
668 const poplar::Tensor &bwdIntermediates,
669 const GruWeights &weights, const poplar::Tensor &input,
670 const poplar::Tensor &output,
671 const poplar::DebugContext &debugContext,
672 const poplar::OptionFlags &options_,
673 poplin::PlanningCache *planningCache);
674
712 poplar::Graph &graph, const GruParams &params,
713 poplar::program::Sequence &prog, const poplar::Tensor &fwdOutputInit,
714 const poplar::Tensor &fwdIntermediates, const GruWeights &weights,
715 const poplar::Tensor &input, const poplar::Tensor &output,
716 const poplar::Tensor &outputGrad, poplar::Tensor *inputGrad,
717 GruWeights &weightsGrad, const poplar::DebugContext &debugContext,
718 const poplar::OptionFlags &options_, poplin::PlanningCache *planningCache);
719
761 poplar::Graph &graph, const GruParams &params,
762 poplar::program::Sequence &prog, const poplar::Tensor &fwdOutputInit,
763 const poplar::Tensor &fwdIntermediates, const GruWeights &weights,
764 const poplar::Tensor &input, const poplar::Tensor &realTimeSteps,
765 const poplar::Tensor &output, const poplar::Tensor &outputGrad,
766 poplar::Tensor *inputGrad, GruWeights &weightsGrad,
767 const poplar::DebugContext &debugContext,
768 const poplar::OptionFlags &options_, poplin::PlanningCache *planningCache);
769
809 poplar::Graph &graph, const GruParams &params,
810 poplar::program::Sequence &prog, const poplar::Tensor &fwdOutputInit,
811 const poplar::Tensor &fwdIntermediates, const GruWeights &weights,
812 const poplar::Tensor &input, const poplar::Tensor &output,
813 const poplar::Tensor &outputGrad, poplar::Tensor *inputGrad,
814 GruWeights &weightsGrad, const poplar::Tensor &attentions,
815 poplar::Tensor *attentionsGrad, const poplar::DebugContext &debugContext,
816 const poplar::OptionFlags &options_, poplin::PlanningCache *planningCache);
817
861 poplar::Graph &graph, const GruParams &params,
862 poplar::program::Sequence &prog, const poplar::Tensor &fwdOutputInit,
863 const poplar::Tensor &fwdIntermediates, const GruWeights &weights,
864 const poplar::Tensor &input, const poplar::Tensor &realTimeSteps,
865 const poplar::Tensor &output, const poplar::Tensor &outputGrad,
866 poplar::Tensor *inputGrad, GruWeights &weightsGrad,
867 const poplar::Tensor &attentions, poplar::Tensor *attentionsGrad,
868 const poplar::DebugContext &debugContext,
869 const poplar::OptionFlags &options_, poplin::PlanningCache *planningCache);
870
871} // namespace gru
872} // namespace popnn
873
874#endif // popnn_Gru_hpp
Definitions for GRU cell operations.
poplar::Tensor auGruBwd(poplar::Graph &graph, const GruParams &params, poplar::program::Sequence &prog, const poplar::Tensor &fwdOutputInit, const poplar::Tensor &fwdIntermediatesSeq, const GruWeights &weights, const poplar::Tensor &fwdInputSeq, const poplar::Tensor &fwdOutput, const poplar::Tensor &gradLayerNext, poplar::Tensor *inputGrad, poplar::Tensor *bwdIntermediates, const poplar::Tensor &attentions, poplar::Tensor *attentionsGrad, const poplar::DebugContext &debugContext, const poplar::OptionFlags &options_, poplin::PlanningCache *planningCache)
Run AUGRU backward pass.
std::pair< poplar::Tensor, poplar::Tensor > createWeightsKernel(poplar::Graph &graph, const GruParams &params, const poplar::DebugContext &debugContext, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Create the weights kernel used to weight the input and output of a GRU.
poplar::Tensor gruBwdWithWU(poplar::Graph &graph, const GruParams &params, poplar::program::Sequence &prog, const poplar::Tensor &fwdOutputInit, const poplar::Tensor &fwdIntermediates, const GruWeights &weights, const poplar::Tensor &input, const poplar::Tensor &output, const poplar::Tensor &outputGrad, poplar::Tensor *inputGrad, GruWeights &weightsGrad, const poplar::DebugContext &debugContext, const poplar::OptionFlags &options_, poplin::PlanningCache *planningCache)
Run a combined GRU backward and weight update pass.
const std::vector< BasicGruCellUnit > getDefaultBasicGruCellOrder()
Get the default order of the gates in a basic GRU cell.
GruWeights gruWU(poplar::Graph &graph, const GruParams &params, poplar::program::Sequence &prog, const poplar::Tensor &fwdOutputInit, const poplar::Tensor &fwdIntermediates, const poplar::Tensor &bwdIntermediates, const GruWeights &weights, const poplar::Tensor &input, const poplar::Tensor &output, const poplar::DebugContext &debugContext, const poplar::OptionFlags &options_, poplin::PlanningCache *planningCache)
Run a standalone weight update pass.
poplar::Tensor gruFwd(poplar::Graph &graph, const GruParams &params, const poplar::Tensor &stateInit, const poplar::Tensor &in, const GruWeights &weights, poplar::Tensor *intermediates, poplar::program::Sequence &fwdProg, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Calculate the result of applying a GRU across a sequence.
poplar::Tensor gruBwd(poplar::Graph &graph, const GruParams &params, poplar::program::Sequence &prog, const poplar::Tensor &fwdOutputInit, const poplar::Tensor &fwdIntermediatesSeq, const GruWeights &weights, const poplar::Tensor &fwdInputSeq, const poplar::Tensor &fwdOutput, const poplar::Tensor &gradLayerNext, poplar::Tensor *inputGrad, poplar::Tensor *bwdIntermediates, const poplar::DebugContext &debugContext, const poplar::OptionFlags &options_, poplin::PlanningCache *planningCache)
Run GRU backward pass.
GruWeights auGruWU(poplar::Graph &graph, const GruParams &params, poplar::program::Sequence &prog, const poplar::Tensor &fwdOutputInit, const poplar::Tensor &fwdIntermediates, const poplar::Tensor &bwdIntermediates, const GruWeights &weights, const poplar::Tensor &input, const poplar::Tensor &output, const poplar::DebugContext &debugContext, const poplar::OptionFlags &options_, poplin::PlanningCache *planningCache)
Run a standalone weight update pass.
poplar::Tensor auGruFwd(poplar::Graph &graph, const GruParams &params, const poplar::Tensor &stateInit, const poplar::Tensor &in, const GruWeights &weights, poplar::Tensor *intermediates, const poplar::Tensor &attScores, poplar::program::Sequence &fwdProg, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Calculate the result of applying an AUGRU across a sequence.
poplar::Tensor createWeightsBiases(poplar::Graph &graph, const GruParams &params, const poplar::DebugContext &debugContext, const poplar::OptionFlags &options={}, poplin::PlanningCache *planningCache=nullptr)
Create the weights biases.
poplar::Tensor auGruBwdWithWU(poplar::Graph &graph, const GruParams &params, poplar::program::Sequence &prog, const poplar::Tensor &fwdOutputInit, const poplar::Tensor &fwdIntermediates, const GruWeights &weights, const poplar::Tensor &input, const poplar::Tensor &output, const poplar::Tensor &outputGrad, poplar::Tensor *inputGrad, GruWeights &weightsGrad, const poplar::Tensor &attentions, poplar::Tensor *attentionsGrad, const poplar::DebugContext &debugContext, const poplar::OptionFlags &options_, poplin::PlanningCache *planningCache)
Run a combined AUGRU backward and weight update pass.
poplar::Tensor createAttention(poplar::Graph &graph, const GruParams &params, const poplar::DebugContext &debugContext, const poplar::OptionFlags &options={})
Create an attention tensor for AUGRU.
Definitions for non-linearity operations.
Functions for recurrent neural networks (RNN).
DebugContext gathers the common external parameters of the context of an operation.
Definition: DebugContext.hpp:221
This class represents a graph program to be executed on the IPU.
Definition: Graph.hpp:52
A set of option/value string flags to be used in various APIs.
Definition: OptionFlags.hpp:24
A reference to a subset of tensor elements.
Definition: Tensor.hpp:38
Class representing device data types.
Definition: Type.hpp:42
Program that executes a sequence of programs.
Definition: Program.hpp:77
Functions used in neural networks.
Definition: BatchNorm.hpp:14
NonLinearityType
Definition: NonLinearityDef.hpp:11
@ TANH
Hyperbolic tangent:
Functions and data types for performing matrix multiplies on the IPU.
Structure representing the parameters of the GRU.
Definition: Gru.hpp:29
std::vector< std::size_t > layerSizes
The number of neurons for the input and output layer.
Definition: Gru.hpp:43
std::vector< BasicGruCellUnit > cellOrder
The weights and biases for all of the layers being processed are concatenated in the outermost dimension.
Definition: Gru.hpp:55
std::size_t batchSize
The batch size.
Definition: Gru.hpp:37
bool outputFullSequence
If true the GRU function returns the entire sequence of outputs, otherwise it returns just the final output.
Definition: Gru.hpp:46
NonLinearityType activation
Activation function.
Definition: Gru.hpp:60
NonLinearityType recurrentActivation
Recurrent activation function.
Definition: Gru.hpp:62
poplar::Type dataType
Definition: Gru.hpp:34
std::size_t timeSteps
The number of time steps in the sequence of the GRU.
Definition: Gru.hpp:40
bool resetAfter
Controls whether the reset gate is applied before or after the candidate weights and biases.
Definition: Gru.hpp:58
bool calcInputGradients
If this parameter is set to false then the GRU will skip the calculation of the gradients of the inputs.
Definition: Gru.hpp:49
Structure holding all the parameters of a GRU cell, or the deltas for those parameters (depending on the use).
Definition: Gru.hpp:125
Structure of Recurrent Neural Network (RNN) parameters which allows for any customized implementation...
Definition: Rnn.hpp:22