8#ifndef popnn_LayerNorm_hpp
9#define popnn_LayerNorm_hpp
44inline std::pair<poplar::Tensor, poplar::Tensor>
47 bool stableAlgo =
false,
51 poputil::PoplibsOpDebugInfo di(debugContext,
52 DI_ARGS(acts, eps, unbiasedVarEstimate,
53 stableAlgo, partialsType, options));
55 auto outputs = popnn::gn::groupNormStatistics(graph, acts, eps, prog, 1,
56 unbiasedVarEstimate, stableAlgo,
57 partialsType, {di}, options);
59 di.addOutputs({{
"mean", poputil::toProfileValue(outputs.first)},
60 {
"iStdDev", poputil::toProfileValue(outputs.second)}});
86 poputil::PoplibsOpDebugInfo di(debugContext,
87 DI_ARGS(acts, mean, invStdDev, options));
89 auto output = popnn::gn::groupNormWhiten(graph, acts, mean, invStdDev, prog,
124inline std::pair<poplar::Tensor, poplar::Tensor>
131 poputil::PoplibsOpDebugInfo di(
132 debugContext, DI_ARGS(acts, gamma, beta, mean, invStdDev, options));
134 auto outputs = popnn::gn::groupNormalise(graph, acts, gamma, beta, mean,
135 invStdDev, prog, {di}, options);
137 di.addOutputs({{
"normActs", poputil::toProfileValue(outputs.first)},
138 {
"whitenedActs", poputil::toProfileValue(outputs.second)}});
170 poputil::PoplibsOpDebugInfo di(
172 DI_ARGS(acts, gradsIn, mean, iStdDev, partialsType, options));
174 auto outputs = popnn::gn::groupNormParamGradients(
175 graph, acts, gradsIn, mean, iStdDev, prog, partialsType, {di}, options);
177 di.addOutputs({{
"meanGrad", poputil::toProfileValue(outputs.first)},
178 {
"iStdDevGrad", poputil::toProfileValue(outputs.second)}});
205 poputil::PoplibsOpDebugInfo di(
206 debugContext, DI_ARGS(actsWhitened, gradsIn, partialsType, options));
208 auto outputs = popnn::gn::groupNormParamGradients(
209 graph, actsWhitened, gradsIn, prog, partialsType, {di}, options);
211 di.addOutputs({{
"meanGrad", poputil::toProfileValue(outputs.first)},
212 {
"iStdDevGrad", poputil::toProfileValue(outputs.second)}});
248 poputil::PoplibsOpDebugInfo di(
250 DI_ARGS(acts, gradsIn, mean, invStdDev, gamma, partialsType, options));
253 popnn::gn::groupNormGradients(graph, acts, gradsIn, mean, invStdDev,
254 gamma, prog, partialsType, {di}, options);
255 di.addOutput(output);
289 poputil::PoplibsOpDebugInfo di(
291 DI_ARGS(actsWhitened, gradsIn, invStdDev, gamma, partialsType, options));
294 popnn::gn::groupNormGradients(graph, actsWhitened, gradsIn, invStdDev,
295 gamma, prog, partialsType, {di}, options);
296 di.addOutput(output);
329 poputil::PoplibsOpDebugInfo di(
331 DI_ARGS(gammaDelta, betaDelta, scale, gamma, beta, options));
333 return popnn::gn::groupNormParamUpdate(graph, gammaDelta, betaDelta, scale,
334 gamma, beta, prog, {di}, options);
367 poputil::PoplibsOpDebugInfo di(
369 DI_ARGS(gammaDelta, betaDelta, scale, gamma, beta, options));
371 return popnn::gn::groupNormParamUpdate(graph, gammaDelta, betaDelta, scale,
372 gamma, beta, prog, {di}, options);
Poplibs generic debug info structure.
Group normalization operations.
poplar::Tensor layerNormGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, const poplar::Tensor &gamma, poplar::program::Sequence &prog, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Compute gradients with respect to input activations for the layer norm layer.
Definition: LayerNorm.hpp:241
poplar::Tensor layerNormWhiten(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Whiten activations given the mean and inverse standard deviation.
Definition: LayerNorm.hpp:81
std::pair< poplar::Tensor, poplar::Tensor > layerNormalise(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gamma, const poplar::Tensor &beta, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Layer normalise activations given the mean, inverse standard deviation and layer norm parameters (gamma and beta).
Definition: LayerNorm.hpp:125
std::pair< poplar::Tensor, poplar::Tensor > layerNormStatistics(poplar::Graph &graph, const poplar::Tensor acts, float eps, poplar::program::Sequence &prog, bool unbiasedVarEstimate, bool stableAlgo=false, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Estimate the mean and inverse standard deviation of the activations.
Definition: LayerNorm.hpp:45
void layerNormParamUpdate(poplar::Graph &graph, const poplar::Tensor &gammaDelta, const poplar::Tensor &betaDelta, float scale, poplar::Tensor &gamma, poplar::Tensor &beta, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Update layer norm parameters given the gradients with respect to the parameters.
Definition: LayerNorm.hpp:322
std::pair< poplar::Tensor, poplar::Tensor > layerNormParamGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &iStdDev, poplar::program::Sequence &prog, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Compute gradients with respect to parameters for parameter update.
Definition: LayerNorm.hpp:162
DebugContext gathers the common external parameters of the context of an operation.
Definition: DebugContext.hpp:221
This class represents a graph program to be executed on the IPU.
Definition: Graph.hpp:52
A set of option/value string flags to be used in various APIs.
Definition: OptionFlags.hpp:24
A reference to a subset of tensor elements.
Definition: Tensor.hpp:38
Class representing device data types.
Definition: Type.hpp:42
Program that executes a sequence of programs.
Definition: Program.hpp:77
half2 ln(half2 src)
Targets the f16v2ln instruction.
Definition: ipu_intrinsics:567
Type FLOAT
Device type: float
Functions used in neural networks.
Definition: BatchNorm.hpp:14