InstanceNorm

#include <popnn/InstanceNorm.hpp>

Instance normalization operations.

Instance norm is implemented as group norm with the number of groups equal to the number of channels.
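The sketch below illustrates why these are equivalent. It is a small CPU-side illustration, not part of popnn. `channelToGroup` is a hypothetical helper, and its contiguous channel-to-group assignment is an assumption, since group norm may order channels differently. When `numGroups == numChannels`, every group holds exactly one channel, so the statistics are computed per channel of each sample, whatever the ordering.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not a popnn API): for each channel, return the index of
// the normalisation group it falls into when numChannels channels are split
// into numGroups contiguous groups of equal size (the contiguous layout is an
// assumption made for illustration).
std::vector<std::size_t> channelToGroup(std::size_t numChannels,
                                        std::size_t numGroups) {
  assert(numGroups > 0 && numChannels % numGroups == 0);
  const std::size_t channelsPerGroup = numChannels / numGroups;
  std::vector<std::size_t> group(numChannels);
  for (std::size_t c = 0; c < numChannels; ++c) {
    group[c] = c / channelsPerGroup;
  }
  return group;
}
```

For example, `channelToGroup(4, 2)` pairs channels {0, 1} and {2, 3}, whereas `channelToGroup(4, 4)` maps channel `c` to group `c`. The second case is the instance norm configuration.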

namespace popnn

Functions used in neural networks.

namespace in

Functions

inline std::pair<poplar::Tensor, poplar::Tensor> instanceNormStatistics(poplar::Graph &graph, const poplar::Tensor acts, float eps, poplar::program::Sequence &prog, bool unbiasedVarEstimate, bool stableAlgo, const poplar::Type &partialsType = poplar::FLOAT, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})

Estimate the mean and inverse standard deviation of the activations.
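As a reference for what the statistics represent, here is a minimal CPU sketch. It is not the popnn implementation. `instanceNormStatisticsRef` is a hypothetical name, and the sketch makes three assumptions: the activations are a row-major `[N][C][F]` buffer (F being the flattened field size), `eps` is added to the variance before the inverse square root, and `unbiasedVarEstimate` selects division by `F - 1` instead of `F`. It always uses a two-pass variance (mean first, then squared deviations). That is only a guess at the kind of computation `stableAlgo` is intended to choose.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical CPU reference (not popnn). acts is row-major [N][C][F].
// Returns {mean, invStdDev}, each holding N * C values, one per (sample, channel).
std::pair<std::vector<float>, std::vector<float>>
instanceNormStatisticsRef(const std::vector<float> &acts, std::size_t N,
                          std::size_t C, std::size_t F, float eps,
                          bool unbiasedVarEstimate) {
  assert(F > 0 && acts.size() == N * C * F);
  assert(!unbiasedVarEstimate || F > 1);
  std::vector<float> mean(N * C), invStdDev(N * C);
  for (std::size_t nc = 0; nc < N * C; ++nc) {
    const float *x = acts.data() + nc * F;
    // First pass: mean over the field of this (sample, channel).
    double sum = 0.0;
    for (std::size_t i = 0; i < F; ++i) sum += x[i];
    const double mu = sum / F;
    // Second pass: sum of squared deviations from the mean.
    double sq = 0.0;
    for (std::size_t i = 0; i < F; ++i) sq += (x[i] - mu) * (x[i] - mu);
    const double var = sq / (unbiasedVarEstimate ? F - 1 : F);
    mean[nc] = static_cast<float>(mu);
    invStdDev[nc] = static_cast<float>(1.0 / std::sqrt(var + eps));
  }
  return {mean, invStdDev};
}
```

For a single sample with channels `{1, 2, 3, 4}` and `{10, 10, 10, 10}`, the means are 2.5 and 10. The first channel's biased variance is 1.25, so its inverse standard deviation is about 1/sqrt(1.25).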

inline poplar::Tensor instanceNormWhiten(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})

Whiten activations given the mean and inverse standard deviation.
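Whitening subtracts the per-(sample, channel) mean and scales by the inverse standard deviation. Below is a minimal CPU sketch of that arithmetic. It assumes a row-major `[N][C][F]` layout, and `instanceNormWhitenRef` is a hypothetical name, not a popnn function.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference (not popnn): whitened = (acts - mean) * invStdDev,
// where mean and invStdDev hold one value per (sample, channel) and are
// broadcast over the F field elements of that (sample, channel).
std::vector<float> instanceNormWhitenRef(const std::vector<float> &acts,
                                         const std::vector<float> &mean,
                                         const std::vector<float> &invStdDev,
                                         std::size_t N, std::size_t C,
                                         std::size_t F) {
  assert(acts.size() == N * C * F);
  assert(mean.size() == N * C && invStdDev.size() == N * C);
  std::vector<float> out(acts.size());
  for (std::size_t nc = 0; nc < N * C; ++nc) {
    for (std::size_t i = 0; i < F; ++i) {
      out[nc * F + i] = (acts[nc * F + i] - mean[nc]) * invStdDev[nc];
    }
  }
  return out;
}
```

With activations `{1, 3}`, mean `{2}` and inverse standard deviation `{1}`, the whitened result is `{-1, 1}`.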

inline std::pair<poplar::Tensor, poplar::Tensor> instanceNormalise(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gamma, const poplar::Tensor &beta, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})

Instance normalise activations given the mean, inverse standard deviation and norm parameters (gamma and beta).

The result is a pair of tensors:

  1. normalised activations

  2. whitened activations
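The pair can be illustrated on the CPU. In this sketch, the whitened activations are `(acts - mean) * invStdDev`, and the normalised activations apply the per-channel affine transform `gamma * whitened + beta`. The function name `instanceNormaliseRef` and the row-major `[N][C][F]` layout are assumptions for illustration. This is not the popnn code.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical CPU reference (not popnn). Returns {normalised, whitened},
// matching the order listed above. gamma and beta hold one value per channel;
// mean and invStdDev hold one value per (sample, channel).
std::pair<std::vector<float>, std::vector<float>>
instanceNormaliseRef(const std::vector<float> &acts,
                     const std::vector<float> &gamma,
                     const std::vector<float> &beta,
                     const std::vector<float> &mean,
                     const std::vector<float> &invStdDev, std::size_t N,
                     std::size_t C, std::size_t F) {
  assert(acts.size() == N * C * F && gamma.size() == C && beta.size() == C);
  assert(mean.size() == N * C && invStdDev.size() == N * C);
  std::vector<float> normalised(acts.size()), whitened(acts.size());
  for (std::size_t n = 0; n < N; ++n) {
    for (std::size_t c = 0; c < C; ++c) {
      const std::size_t nc = n * C + c;
      for (std::size_t i = 0; i < F; ++i) {
        const std::size_t idx = nc * F + i;
        const float w = (acts[idx] - mean[nc]) * invStdDev[nc];
        whitened[idx] = w;
        normalised[idx] = gamma[c] * w + beta[c];  // per-channel scale and shift
      }
    }
  }
  return {normalised, whitened};
}
```

With activations `{1, 3}`, mean 2, inverse standard deviation 1, gamma 2 and beta 0.5, the whitened result is `{-1, 1}` and the normalised result is `{-1.5, 2.5}`.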

inline std::pair<poplar::Tensor, poplar::Tensor> instanceNormParamGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &iStdDev, poplar::program::Sequence &prog, const poplar::Type &partialsType = poplar::FLOAT, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})

Compute gradients w.r.t. the parameters (gamma and beta) for the parameter update.

inline std::pair<poplar::Tensor, poplar::Tensor> instanceNormParamGradients(poplar::Graph &graph, const poplar::Tensor &actsWhitened, const poplar::Tensor &gradsIn, poplar::program::Sequence &prog, const poplar::Type &partialsType = poplar::FLOAT, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})

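Compute gradients w.r.t. the parameters (gamma and beta) for the parameter update. This overload takes whitened activations in place of the activations, mean and inverse standard deviation.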
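For the affine step `y = gamma * whitened + beta`, the standard parameter gradients are per-channel reductions over samples and field elements. The gamma gradient sums `gradsIn * whitened`, and the beta gradient sums `gradsIn`. The CPU sketch below works from whitened activations. The overload that takes `acts`, `mean` and `iStdDev` would conceptually whiten first and then apply the same reduction. The name `instanceNormParamGradientsRef`, the `[N][C][F]` layout and the `{gammaDelta, betaDelta}` return order are assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical CPU reference (not popnn). Returns {gammaDelta, betaDelta},
// one value per channel, reduced over all samples and field elements.
std::pair<std::vector<float>, std::vector<float>>
instanceNormParamGradientsRef(const std::vector<float> &actsWhitened,
                              const std::vector<float> &gradsIn, std::size_t N,
                              std::size_t C, std::size_t F) {
  assert(actsWhitened.size() == N * C * F && gradsIn.size() == N * C * F);
  std::vector<float> gammaDelta(C, 0.0f), betaDelta(C, 0.0f);
  for (std::size_t n = 0; n < N; ++n) {
    for (std::size_t c = 0; c < C; ++c) {
      for (std::size_t i = 0; i < F; ++i) {
        const std::size_t idx = (n * C + c) * F + i;
        gammaDelta[c] += gradsIn[idx] * actsWhitened[idx];  // d(loss)/d(gamma)
        betaDelta[c] += gradsIn[idx];                       // d(loss)/d(beta)
      }
    }
  }
  return {gammaDelta, betaDelta};
}
```

With two samples of a single channel, whitened values `{1, -1}` and incoming gradients `{2, 3}` give a gamma gradient of -1 and a beta gradient of 5.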

inline poplar::Tensor instanceNormGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, const poplar::Tensor &gamma, poplar::program::Sequence &prog, const poplar::Type &partialsType = poplar::FLOAT, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})

Compute gradients w.r.t. the input activations for the instance norm layer.

Gradients are propagated through the complete layer, including the statistics computation.

inline poplar::Tensor instanceNormGradients(poplar::Graph &graph, const poplar::Tensor &actsWhitened, const poplar::Tensor &gradsIn, const poplar::Tensor &invStdDev, const poplar::Tensor &gamma, poplar::program::Sequence &prog, const poplar::Type &partialsType = poplar::FLOAT, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})

Compute gradients w.r.t. the input activations for the instance norm layer. This overload takes whitened activations in place of the activations and mean.

Gradients are propagated through the complete layer, including the statistics computation.

inline void instanceNormParamUpdate(poplar::Graph &graph, const poplar::Tensor &gammaDelta, const poplar::Tensor &betaDelta, float scale, poplar::Tensor &gamma, poplar::Tensor &beta, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})

Update parameters given gradients w.r.t. parameters.
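A minimal CPU sketch of the update follows. It assumes the update has the form `param += scale * delta`, so plain SGD would pass `scale = -learningRate`. That sign convention is an assumption, not something stated here. `instanceNormParamUpdateRef` is a hypothetical name. The tensor-`scale` overload would conceptually do the same, with the scale read from a tensor.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference (not popnn). ASSUMPTION: the update is
// param += scale * delta, applied element-wise per channel.
void instanceNormParamUpdateRef(const std::vector<float> &gammaDelta,
                                const std::vector<float> &betaDelta,
                                float scale, std::vector<float> &gamma,
                                std::vector<float> &beta) {
  assert(gammaDelta.size() == gamma.size() && betaDelta.size() == beta.size());
  for (std::size_t c = 0; c < gamma.size(); ++c) gamma[c] += scale * gammaDelta[c];
  for (std::size_t c = 0; c < beta.size(); ++c) beta[c] += scale * betaDelta[c];
}
```

Under this assumption, starting from gamma 1 and beta 0 with deltas 2 and 4 and a scale of -0.5, gamma becomes 0 and beta becomes -2.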

inline void instanceNormParamUpdate(poplar::Graph &graph, const poplar::Tensor &gammaDelta, const poplar::Tensor &betaDelta, const poplar::Tensor &scale, poplar::Tensor &gamma, poplar::Tensor &beta, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &options = {})

Update parameters given gradients w.r.t. parameters. This overload takes the scale as a tensor rather than a float.

uint64_t getFwdFlops(uint64_t numChannels, uint64_t actsPerChannel, bool computeEstimates)

In flop computation, the following applies:

  • Acts per channel:

    • for fc layers: the batch size.

    • for conv layers: the field size per channel multiplied by the batch size.

  • Number of channels:

    • for fc layers: the number of activations in each batch element.

    • for conv layers: the total number of channels.
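The conventions above can be expressed as a small mapping from an activation shape to the `numChannels` and `actsPerChannel` arguments. The helpers `fcFlopArgs` and `convFlopArgs` and the struct `FlopArgs` are hypothetical, not part of popnn. The sketch reads fc "acts per channel" as the batch size and fc "number of channels" as the number of activations per sample. The results would then be passed as `getFwdFlops(args.numChannels, args.actsPerChannel, computeEstimates)`.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical helpers (not popnn): derive the flop-function arguments from
// an activation shape, following the conventions listed above.
struct FlopArgs {
  std::uint64_t numChannels;
  std::uint64_t actsPerChannel;
};

// Fully connected activations of shape [batchSize][numActivations].
FlopArgs fcFlopArgs(std::uint64_t batchSize, std::uint64_t numActivations) {
  return {numActivations, batchSize};
}

// Convolutional activations of shape [batchSize][channels][field...].
FlopArgs convFlopArgs(std::uint64_t batchSize, std::uint64_t channels,
                      const std::vector<std::uint64_t> &fieldShape) {
  assert(!fieldShape.empty());
  // Field size per channel is the product of the spatial dimensions.
  const std::uint64_t fieldSize =
      std::accumulate(fieldShape.begin(), fieldShape.end(), std::uint64_t{1},
                      std::multiplies<std::uint64_t>());
  return {channels, fieldSize * batchSize};
}
```

For example, a conv activation with batch size 8, 16 channels and a 32 x 32 field gives 16 channels and 8 * 32 * 32 = 8192 activations per channel.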

uint64_t getBwdFlops(uint64_t numChannels, uint64_t actsPerChannel)

uint64_t getWuFlops(uint64_t numChannels, uint64_t actsPerChannel)