#ifndef popnn_BatchNorm_hpp
#define popnn_BatchNorm_hpp
#include "poplar/DebugContext.hpp"
#include "poplar/Program.hpp"
#include "poplar/Tensor.hpp"
std::pair<poplar::Tensor, poplar::Tensor>
bool stableAlgo = false,
unsigned normBatchSize,
bool stableAlgo = false,
std::pair<poplar::Tensor, poplar::Tensor>
std::pair< poplar::Tensor, poplar::Tensor > batchNormParamGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &iStdDev, poplar::program::Sequence &prog, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Compute gradients with respect to parameters required for parameter update.
std::pair< poplar::Tensor, poplar::Tensor > batchNormStatistics(poplar::Graph &graph, const poplar::Tensor acts, float eps, poplar::program::Sequence &prog, bool unbiasedVarEstimate, bool stableAlgo=false, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Estimate the mean and inverse standard deviation of batched activations.
poplar::Tensor batchNormGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, const poplar::Tensor &gamma, poplar::program::Sequence &prog, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Compute gradients with respect to input activations for the batch norm layer.
std::pair< poplar::Tensor, poplar::Tensor > distributedBatchNormStatistics(poplar::Graph &replicatedGraph, const poplar::Tensor acts, float eps, poplar::program::Sequence &prog, bool unbiasedVarEstimate, poplin::DistributedNormReduceCallback reduceCallback, unsigned normBatchSize, bool stableAlgo=false, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Compute the batch normalisation statistics for a part of the activations tensor.
poplar::Tensor distributedBatchNormGradients(poplar::Graph &replicatedGraph, const poplar::Tensor &actsWhitened, const poplar::Tensor &gradsIn, const poplar::Tensor &invStdDev, const poplar::Tensor &gamma, poplar::program::Sequence &prog, poplin::DistributedNormReduceCallback reduceCallback, unsigned normBatchSize, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Propagate the gradients through the batch norm layer where equal-sized batch elements are distributed...
std::pair< poplar::Tensor, poplar::Tensor > batchNormalise(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gamma, const poplar::Tensor &beta, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Batch normalise the activations using the given mean, inverse standard deviation and batch norm parameters.
void batchNormParamUpdate(poplar::Graph &graph, const poplar::Tensor &gammaDelta, const poplar::Tensor &betaDelta, float scale, poplar::Tensor &gamma, poplar::Tensor &beta, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Update the parameters for the batch norm layer.
poplar::Tensor batchNormWhiten(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Whiten activations given the mean and inverse standard deviation.
DebugContext gathers the common external parameters of the context of an operation.
Definition: DebugContext.hpp:221
This class represents a graph program to be executed on the IPU.
Definition: Graph.hpp:52
A set of option/value string flags to be used in various APIs.
Definition: OptionFlags.hpp:24
A reference to a subset of tensor elements.
Definition: Tensor.hpp:38
Class representing device data types.
Definition: Type.hpp:42
Program that executes a sequence of programs.
Definition: Program.hpp:77
Type FLOAT
Device type: float
std::function< std::vector< poplar::Tensor >(poplar::Graph &replicatedGraph, const std::vector< poplar::Tensor > &inputsToReduce, poplar::program::Sequence &prog, unsigned groupSize, const poplar::DebugContext &debugContext, const poplar::OptionFlags &options)> DistributedNormReduceCallback
Callback to reduce statistics and gradients.
Definition: Norms.hpp:150
Functions used in neural networks.
Definition: BatchNorm.hpp:14
Functions to support normalising values in a tensor.