GroupNorm
#include <popnn/GroupNorm.hpp>
Group normalization operations.
-
namespace popnn
Functions used in neural networks.
-
namespace gn
Functions
-
std::pair<poplar::Tensor, poplar::Tensor> groupNormStatistics(poplar::Graph &graph, const poplar::Tensor acts, float eps, poplar::program::Sequence &prog, unsigned numGroups, bool unbiasedVarEstimate, bool stableAlgo = false, const poplar::Type &partialsType = poplar::FLOAT, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})
Estimate mean and inverse of standard deviation of activations.
- Parameters
graph – The graph that the normalisation operation is added to.
acts – The activations for which the mean and variance are estimated.
eps – The epsilon value added to the variance to avoid division by zero.
prog – The program sequence to add the operation to.
numGroups – The number of groups to split the channel dimension into when calculating group norm statistics. The groupNormStridedChannelGrouping option defines how the split is made.
unbiasedVarEstimate – If true, an unbiased variance estimate will be computed.
stableAlgo – If true, computes the mean first and subtracts it from the activations before computing the variance. The implementation with this flag set to true is slower than when set to false.
partialsType – Poplar type used for intermediate values. If the type specified is smaller than the input/output type then partialsType is ignored and the input/output type is used instead.
debugContext – Optional debug information.
options – Group normalisation options. See groupNormalise().
- Returns
A pair of tensors containing the mean and the inverse standard deviation.
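As an illustration of the arithmetic described above, the following plain C++ sketch computes the two statistics for the elements of a single group. It is a host-side reference under stated assumptions, not the Poplar implementation: the helper name referenceGroupStats is hypothetical, the statistics are assumed to be taken per sample and per group as in conventional group normalisation, and the unbiased estimate is assumed to divide by N - 1 instead of N.

```cpp
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper, not part of popnn. `values` holds the elements of ONE
// group of ONE batch sample and must not be empty.
// Returns {mean, inverse standard deviation}.
std::pair<float, float> referenceGroupStats(const std::vector<float> &values,
                                            float eps,
                                            bool unbiasedVarEstimate) {
  const std::size_t n = values.size();

  double sum = 0.0;
  for (float v : values)
    sum += v;
  const double mean = sum / static_cast<double>(n);

  // Two-pass variance: subtract the mean before squaring.
  double sqDiff = 0.0;
  for (float v : values)
    sqDiff += (v - mean) * (v - mean);

  // Assumption: the unbiased estimate divides by N - 1 (Bessel's correction).
  const std::size_t divisor = (unbiasedVarEstimate && n > 1) ? n - 1 : n;
  const double variance = sqDiff / static_cast<double>(divisor);

  // eps is added to the variance so the division is always well defined.
  const double invStdDev = 1.0 / std::sqrt(variance + eps);
  return {static_cast<float>(mean), static_cast<float>(invStdDev)};
}
```

A sketch like this can serve as a comparison point when checking device results on small inputs.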
-
poplar::Tensor groupNormWhiten(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})
Whiten activations given the mean and inverse standard deviation.
- Parameters
graph – The graph that the normalisation operation is added to.
acts – The input activations that will be whitened.
mean – The previously calculated mean to subtract from the activations. Typically calculated using groupNormStatistics().
invStdDev – The previously calculated inverse standard deviation to multiply the activations by. Typically calculated using groupNormStatistics().
prog – The program sequence to add the operation to.
debugContext – Optional debug information.
options – Group normalisation options. See groupNormalise().
- Returns
A new tensor with the whitened activations.
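The parameter descriptions above state that the mean is subtracted and the result is multiplied by the inverse standard deviation. A minimal host-side sketch of that step for the elements of one group follows; referenceWhiten is a hypothetical helper name, not a popnn function.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of popnn. Whitens the elements of one group
// using that group's mean and inverse standard deviation.
std::vector<float> referenceWhiten(const std::vector<float> &values,
                                   float mean, float invStdDev) {
  std::vector<float> whitened(values.size());
  for (std::size_t i = 0; i < values.size(); ++i)
    whitened[i] = (values[i] - mean) * invStdDev;
  return whitened;
}
```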
-
std::pair<poplar::Tensor, poplar::Tensor> groupNormalise(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gamma, const poplar::Tensor &beta, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})
Group normalise activations given the mean, inverse standard deviation and group norm parameters.
Group normalisation options
groupNormStridedChannelGrouping (true, false) [=true]
Select groups of channels for group normalisation with a stride between channels. This makes the implementation more efficient but is unconventional. Among other things, it means that pre-trained weights cannot be used unless they were produced with this unconventional implementation.
If we have numGroups groups then the channels in the group groups[groupIdx] are given by:
Strided channel grouping: channelInGroupIdx * numGroups + groupIdx
Otherwise: channelInGroupIdx + channelsPerGroup * groupIdx
In the case of instanceNormalise() and layerNormalise() (which use group norm in their implementation) this option will have no effect.
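The two index formulas above can be compared with a short host-side sketch. The helper channelsInGroup is hypothetical (it is not a popnn function), and it assumes the number of channels is an exact multiple of numGroups.

```cpp
#include <vector>

// Hypothetical helper, not part of popnn. Lists the channel indices that
// belong to group `groupIdx`, using the two formulas from the option text.
// Assumes numChannels is an exact multiple of numGroups.
std::vector<unsigned> channelsInGroup(unsigned numChannels, unsigned numGroups,
                                      unsigned groupIdx, bool strided) {
  const unsigned channelsPerGroup = numChannels / numGroups;
  std::vector<unsigned> channels;
  channels.reserve(channelsPerGroup);
  for (unsigned channelInGroupIdx = 0; channelInGroupIdx < channelsPerGroup;
       ++channelInGroupIdx) {
    channels.push_back(strided
                           ? channelInGroupIdx * numGroups + groupIdx
                           : channelInGroupIdx + channelsPerGroup * groupIdx);
  }
  return channels;
}
```

With 8 channels and 2 groups, strided grouping places channels 1, 3, 5, 7 in group 1, whereas the conventional grouping places channels 4, 5, 6, 7 there.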
- Parameters
graph – The graph that the normalisation operation is added to.
acts – The input activations to whiten and normalise, with shape [B][C][..F..] where:
B is the batch size
C is the number of channels
..F.. are the dimensions of an N-dimensional field.
gamma – The gamma weights to multiply by when normalising the whitened activations.
beta – The beta weights to add when normalising the whitened activations.
mean – The mean to subtract when whitening the activations.
invStdDev – The inverse standard deviation to multiply by when whitening the activations.
prog – The program sequence to add the operation to.
debugContext – Optional debug information.
options – Group normalisation options.
- Returns
Two tensors containing:
normalised activations
whitened activations
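The relationship between the two returned tensors can be sketched on the host as a scale and shift of the whitened activations. This is only an illustration: referenceScaleShift is a hypothetical helper, the data is one sample laid out as [C][F] in a flat vector, and gamma and beta are assumed to hold one value per channel.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of popnn. `whitened` is one sample laid out
// as [C][F]; gamma and beta are assumed to hold one value per channel.
std::vector<float> referenceScaleShift(const std::vector<float> &whitened,
                                       const std::vector<float> &gamma,
                                       const std::vector<float> &beta,
                                       std::size_t fieldSize) {
  std::vector<float> normalised(whitened.size());
  for (std::size_t c = 0; c < gamma.size(); ++c) {
    for (std::size_t f = 0; f < fieldSize; ++f) {
      const std::size_t idx = c * fieldSize + f;
      normalised[idx] = whitened[idx] * gamma[c] + beta[c];
    }
  }
  return normalised;
}
```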
-
std::pair<poplar::Tensor, poplar::Tensor> groupNormParamGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &iStdDev, poplar::program::Sequence &prog, const poplar::Type &partialsType = poplar::FLOAT, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})
Compute gradients with respect to parameters for parameter update.
- Parameters
graph – The graph that the normalisation operation is added to.
acts – The forward-pass activation inputs to this layer.
gradsIn – The gradient with respect to the output of this layer.
mean – The mean of the acts tensor, typically calculated using groupNormStatistics().
iStdDev – The inverse standard deviation of the acts tensor, typically calculated using groupNormStatistics().
prog – The program sequence to add the operation to.
partialsType – Poplar type used for intermediate values. If the type specified is smaller than the input/output type then partialsType is ignored and the input/output type is used instead.
debugContext – Optional debug information.
options – Group normalisation options. See groupNormalise().
- Returns
A pair of tensors, gammaDelta and betaDelta, which are the gradients with respect to gamma and beta.
-
std::pair<poplar::Tensor, poplar::Tensor> groupNormParamGradients(poplar::Graph &graph, const poplar::Tensor &actsWhitened, const poplar::Tensor &gradsIn, poplar::program::Sequence &prog, const poplar::Type &partialsType = poplar::FLOAT, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})
Compute gradients with respect to parameters for parameter update.
- Parameters
graph – The graph that the normalisation operation is added to.
actsWhitened – The forward-pass whitened activation inputs to this layer.
gradsIn – The gradient with respect to the output of this layer.
prog – The program sequence to add the operation to.
partialsType – Poplar type used for intermediate values. If the type specified is smaller than the input/output type then partialsType is ignored and the input/output type is used instead.
debugContext – Optional debug information.
options – Group normalisation options. See groupNormalise().
- Returns
A pair of tensors, gammaDelta and betaDelta, which are the gradients with respect to gamma and beta.
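For orientation, the conventional parameter gradients of a normalisation layer can be written as per-channel reductions over the batch and field dimensions. The sketch below assumes that convention and a flat [B][C][F] layout; referenceParamGradients is a hypothetical helper and does not describe how popnn performs the reduction.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper, not part of popnn. Inputs are flat [B][C][F] arrays.
// Assumes the conventional definitions:
//   gammaDelta[c] = sum over b, f of gradsIn * actsWhitened
//   betaDelta[c]  = sum over b, f of gradsIn
std::pair<std::vector<float>, std::vector<float>>
referenceParamGradients(const std::vector<float> &actsWhitened,
                        const std::vector<float> &gradsIn, std::size_t B,
                        std::size_t C, std::size_t F) {
  std::vector<float> gammaDelta(C, 0.0f);
  std::vector<float> betaDelta(C, 0.0f);
  for (std::size_t b = 0; b < B; ++b) {
    for (std::size_t c = 0; c < C; ++c) {
      for (std::size_t f = 0; f < F; ++f) {
        const std::size_t idx = (b * C + c) * F + f;
        gammaDelta[c] += gradsIn[idx] * actsWhitened[idx];
        betaDelta[c] += gradsIn[idx];
      }
    }
  }
  return {gammaDelta, betaDelta};
}
```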
-
poplar::Tensor groupNormGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, const poplar::Tensor &gamma, poplar::program::Sequence &prog, const poplar::Type &partialsType = poplar::FLOAT, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})
Compute gradients with respect to input activations for the group norm layer.
Gradients are propagated through the complete layer including statistics computation.
- Parameters
graph – The graph that the normalisation operation is added to.
acts – The forward-pass activation inputs to this layer.
gradsIn – The gradient with respect to the output of this layer.
mean – The mean of the acts tensor, typically calculated using groupNormStatistics().
invStdDev – The inverse standard deviation of the acts tensor, typically calculated using groupNormStatistics().
gamma – The gamma weights to multiply by when normalising the whitened activations.
prog – The program sequence to add the operation to.
partialsType – Poplar type used for intermediate values. If the type specified is smaller than the input/output type then partialsType is ignored and the input/output type is used instead.
debugContext – Optional debug information.
options – Group normalisation options. See groupNormalise().
- Returns
A tensor containing the gradients with respect to the input activations for this layer.
-
poplar::Tensor groupNormGradients(poplar::Graph &graph, const poplar::Tensor &actsWhitened, const poplar::Tensor &gradsIn, const poplar::Tensor &invStdDev, const poplar::Tensor &gamma, poplar::program::Sequence &prog, const poplar::Type &partialsType = poplar::FLOAT, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})
Compute gradients with respect to input activations for the group norm layer.
Gradients are propagated through the complete layer including statistics computation.
- Parameters
graph – The graph that the normalisation operation is added to.
actsWhitened – The forward-pass whitened activation inputs to this layer.
gradsIn – The gradient with respect to the output of this layer.
invStdDev – The inverse standard deviation of the acts tensor, typically calculated using groupNormStatistics().
gamma – The gamma weights to multiply by when normalising the whitened activations.
prog – The program sequence to add the operation to.
partialsType – Poplar type used for intermediate values. If the type specified is smaller than the input/output type then partialsType is ignored and the input/output type is used instead.
debugContext – Optional debug information.
options – Group normalisation options. See groupNormalise().
- Returns
A tensor containing the gradients with respect to the input activations for this layer.
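Because the gradients pass through the statistics computation, each input gradient depends on every element of its group, not only on the matching output gradient. The standard result for a layer of this form is sketched below. It assumes the biased variance estimate, and all notation is introduced here for illustration and does not come from the API.

```latex
% Assumed notation (not from the source): x = acts, \hat{x} = actsWhitened,
% g = gradsIn, s = invStdDev of the group, c(i) = channel of element i,
% G = the elements sharing element i's sample and group, N = |G|.
\hat{x}_i = (x_i - \mu)\, s, \qquad y_i = \gamma_{c(i)}\, \hat{x}_i + \beta_{c(i)}

\frac{\partial L}{\partial x_i}
  = s \left( \gamma_{c(i)}\, g_i
      - \frac{1}{N} \sum_{j \in G} \gamma_{c(j)}\, g_j
      - \hat{x}_i\, \frac{1}{N} \sum_{j \in G} \gamma_{c(j)}\, g_j\, \hat{x}_j \right)
```

The right-hand side uses only the whitened activations, the inverse standard deviation, gamma and the incoming gradients, which matches the inputs this overload takes.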
-
void groupNormParamUpdate(poplar::Graph &graph, const poplar::Tensor &gammaDelta, const poplar::Tensor &betaDelta, float scale, poplar::Tensor &gamma, poplar::Tensor &beta, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})
Update parameters for the group norm layer.
Gradients are propagated through the complete layer including statistics computation.
The gamma and beta parameters are updated as follows:
gamma += gammaDelta * scale
beta += betaDelta * scale
scale is a float and therefore constant.
- Parameters
graph – The graph that the normalisation operation is added to.
gammaDelta – Value used to update gamma.
betaDelta – Value used to update beta.
scale – Scale factor for gammaDelta and betaDelta.
gamma – The gamma weights to multiply by when normalising the activations.
beta – The beta weights to add when normalising the activations.
prog – The program sequence to add the operation to.
debugContext – Optional debug information.
options – Group normalisation options. See groupNormalise().
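The update rule above, with a constant float scale, amounts to an element-wise multiply-accumulate. A host-side sketch follows; referenceParamUpdate is a hypothetical helper, not a popnn function. Passing a negative scale, such as the negated learning rate, gives a gradient descent step.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of popnn. Applies
//   gamma += gammaDelta * scale
//   beta  += betaDelta  * scale
// in place. All four vectors are assumed to have the same length.
void referenceParamUpdate(const std::vector<float> &gammaDelta,
                          const std::vector<float> &betaDelta, float scale,
                          std::vector<float> &gamma,
                          std::vector<float> &beta) {
  for (std::size_t c = 0; c < gamma.size(); ++c) {
    gamma[c] += gammaDelta[c] * scale;
    beta[c] += betaDelta[c] * scale;
  }
}
```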
-
void groupNormParamUpdate(poplar::Graph &graph, const poplar::Tensor &gammaDelta, const poplar::Tensor &betaDelta, const poplar::Tensor &scale, poplar::Tensor &gamma, poplar::Tensor &beta, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})
Update parameters for the group norm layer.
Gradients are propagated through the complete layer including statistics computation.
The gamma and beta parameters are updated as follows:
gamma += gammaDelta * scale
beta += betaDelta * scale
scale is a tensor and therefore variable.
- Parameters
graph – The graph that the normalisation operation is added to.
gammaDelta – Value used to update gamma.
betaDelta – Value used to update beta.
scale – Scale factor for gammaDelta and betaDelta.
gamma – The gamma weights to multiply by when normalising the activations.
beta – The beta weights to add when normalising the activations.
prog – The program sequence to add the operation to.
debugContext – Optional debug information.
options – Group normalisation options. See groupNormalise().