8#ifndef popnn_InstanceNorm_hpp
9#define popnn_InstanceNorm_hpp
40inline std::pair<poplar::Tensor, poplar::Tensor>
43 bool unbiasedVarEstimate,
bool stableAlgo,
47 poputil::PoplibsOpDebugInfo di(debugContext,
48 DI_ARGS(acts, eps, unbiasedVarEstimate,
49 stableAlgo, partialsType, options));
51 auto outputs = popnn::gn::groupNormStatistics(
52 graph, acts, eps, prog, acts.
dim(1), unbiasedVarEstimate, stableAlgo,
53 partialsType, {di}, options);
55 di.addOutputs({{
"mean", poputil::toProfileValue(outputs.first)},
56 {
"iStd", poputil::toProfileValue(outputs.second)}});
82 poputil::PoplibsOpDebugInfo di(debugContext,
83 DI_ARGS(acts, mean, invStdDev, options));
85 auto output = popnn::gn::groupNormWhiten(graph, acts, mean, invStdDev, prog,
120inline std::pair<poplar::Tensor, poplar::Tensor>
127 poputil::PoplibsOpDebugInfo di(
128 debugContext, DI_ARGS(acts, gamma, beta, mean, invStdDev, options));
130 auto outputs = popnn::gn::groupNormalise(graph, acts, gamma, beta, mean,
131 invStdDev, prog, {di}, options);
133 di.addOutputs({{
"normActs", poputil::toProfileValue(outputs.first)},
134 {
"whitenedActs", poputil::toProfileValue(outputs.second)}});
166 poputil::PoplibsOpDebugInfo di(
168 DI_ARGS(acts, gradsIn, mean, iStdDev, partialsType, options));
170 auto outputs = popnn::gn::groupNormParamGradients(
171 graph, acts, gradsIn, mean, iStdDev, prog, partialsType, {di}, options);
173 di.addOutputs({{
"meanGrad", poputil::toProfileValue(outputs.first)},
174 {
"iStdDevGrad", poputil::toProfileValue(outputs.second)}});
202 poputil::PoplibsOpDebugInfo di(
203 debugContext, DI_ARGS(actsWhitened, gradsIn, partialsType, options));
205 auto outputs = popnn::gn::groupNormParamGradients(
206 graph, actsWhitened, gradsIn, prog, partialsType, {di}, options);
208 di.addOutputs({{
"meanGrad", poputil::toProfileValue(outputs.first)},
209 {
"iStdDevGrad", poputil::toProfileValue(outputs.second)}});
248 poputil::PoplibsOpDebugInfo di(
250 DI_ARGS(acts, gradsIn, mean, invStdDev, gamma, partialsType, options));
253 popnn::gn::groupNormGradients(graph, acts, gradsIn, mean, invStdDev,
254 gamma, prog, partialsType, {di}, options);
256 di.addOutput(output);
291 poputil::PoplibsOpDebugInfo di(
293 DI_ARGS(actsWhitened, gradsIn, invStdDev, gamma, partialsType, options));
296 popnn::gn::groupNormGradients(graph, actsWhitened, gradsIn, invStdDev,
297 gamma, prog, partialsType, {di}, options);
298 di.addOutput(output);
331 poputil::PoplibsOpDebugInfo di(
333 DI_ARGS(gammaDelta, betaDelta, scale, gamma, beta, options));
335 return popnn::gn::groupNormParamUpdate(graph, gammaDelta, betaDelta, scale,
336 gamma, beta, prog, {di}, options);
368 poputil::PoplibsOpDebugInfo di(
370 DI_ARGS(gammaDelta, betaDelta, scale, gamma, beta, options));
371 return popnn::gn::groupNormParamUpdate(graph, gammaDelta, betaDelta, scale,
372 gamma, beta, prog, {di}, options);
389uint64_t getFwdFlops(uint64_t numChannels, uint64_t actsPerChannel,
390 bool computeEstimates);
395uint64_t getBwdFlops(uint64_t numChannels, uint64_t actsPerChannel);
400uint64_t getWuFlops(uint64_t numChannels, uint64_t actsPerChannel);
Poplibs generic debug info structure.
Group normalization operations.
std::pair< poplar::Tensor, poplar::Tensor > instanceNormParamGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &iStdDev, poplar::program::Sequence &prog, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Compute gradients with respect to the norm parameters, for use in the parameter update.
Definition: InstanceNorm.hpp:158
poplar::Tensor instanceNormGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, const poplar::Tensor &gamma, poplar::program::Sequence &prog, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Compute gradients with respect to input activations for the instance norm layer.
Definition: InstanceNorm.hpp:239
void instanceNormParamUpdate(poplar::Graph &graph, const poplar::Tensor &gammaDelta, const poplar::Tensor &betaDelta, float scale, poplar::Tensor &gamma, poplar::Tensor &beta, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Update parameters for the instance norm layer.
Definition: InstanceNorm.hpp:325
std::pair< poplar::Tensor, poplar::Tensor > instanceNormStatistics(poplar::Graph &graph, const poplar::Tensor acts, float eps, poplar::program::Sequence &prog, bool unbiasedVarEstimate, bool stableAlgo, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Estimate the mean and inverse standard deviation of the activations.
Definition: InstanceNorm.hpp:41
std::pair< poplar::Tensor, poplar::Tensor > instanceNormalise(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gamma, const poplar::Tensor &beta, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Instance normalise activations given the mean, inverse standard deviation and norm parameters.
Definition: InstanceNorm.hpp:121
poplar::Tensor instanceNormWhiten(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Whiten activations given the mean and inverse standard deviation.
Definition: InstanceNorm.hpp:77
DebugContext gathers the common external parameters of the context of an operation.
Definition: DebugContext.hpp:221
This class represents a graph program to be executed on the IPU.
Definition: Graph.hpp:52
A set of option/value string flags to be used in various APIs.
Definition: OptionFlags.hpp:24
A reference to a subset of tensor elements.
Definition: Tensor.hpp:38
std::size_t dim(unsigned i) const
Get a dimension of the tensor.
Class representing device data types.
Definition: Type.hpp:42
Program that executes a sequence of programs.
Definition: Program.hpp:77
Type FLOAT
Device type: float
Functions used in neural networks.
Definition: BatchNorm.hpp:14