Poplar and PopLibs
GroupNorm.hpp
Go to the documentation of this file.
1// Copyright (c) 2019 Graphcore Ltd. All rights reserved.
6#ifndef popnn_GroupNorm_hpp
7#define popnn_GroupNorm_hpp
8#include "poplar/Program.hpp"
9#include "poplar/Tensor.hpp"
10#include <poplar/OptionFlags.hpp>
11#include <utility>
12
13namespace popnn {
14namespace gn {
15
/**
 * Estimate the mean and inverse standard deviation of activations.
 *
 * \param graph               Graph the operation is defined on (taken by
 *                            non-const reference, so it may be modified).
 * \param acts                Input activations.
 *                            NOTE(review): taken by const value, unlike the
 *                            `const poplar::Tensor &` parameters used by the
 *                            other functions in this header. Changing it would
 *                            also require changing the out-of-view definition.
 * \param eps                 Small constant, presumably added to the variance
 *                            for numerical stability -- confirm.
 * \param prog                Program sequence the computation is presumably
 *                            appended to -- confirm.
 * \param numGroups           Number of groups; presumably the channels are
 *                            split into this many groups -- confirm.
 * \param unbiasedVarEstimate Presumably selects an unbiased variance estimate
 *                            when true (inferred from the name -- confirm).
 * \param stableAlgo          Defaults to false. Presumably selects a more
 *                            numerically stable algorithm -- confirm.
 * \param partialsType        Presumably the type of intermediate partial
 *                            results; defaults to poplar::FLOAT.
 * \param debugContext        Optional debug context for the operation.
 * \param options             Optional option flags; defaults to empty.
 * \returns A pair of tensors, presumably (mean, inverse standard deviation)
 *          -- confirm the order against the implementation.
 */
44std::pair<poplar::Tensor, poplar::Tensor>
45groupNormStatistics(poplar::Graph &graph, const poplar::Tensor acts, float eps,
46 poplar::program::Sequence &prog, unsigned numGroups,
47 bool unbiasedVarEstimate, bool stableAlgo = false,
48 const poplar::Type &partialsType = poplar::FLOAT,
49 const poplar::DebugContext &debugContext = {},
50 const poplar::OptionFlags &options = {});
51
69 const poplar::Tensor &mean,
70 const poplar::Tensor &invStdDev,
72 const poplar::DebugContext &debugContext = {},
73 const poplar::OptionFlags &options = {});
74
120std::pair<poplar::Tensor, poplar::Tensor>
122 const poplar::Tensor &gamma, const poplar::Tensor &beta,
123 const poplar::Tensor &mean, const poplar::Tensor &invStdDev,
125 const poplar::DebugContext &debugContext = {},
126 const poplar::OptionFlags &options = {});
127
/**
 * Compute gradients with respect to the group norm parameters, for use in a
 * parameter update.
 *
 * \param graph        Graph the operation is defined on (non-const reference).
 * \param acts         Input activations of the forward pass.
 * \param gradsIn      Incoming gradients.
 * \param mean         Mean of the activations.
 * \param iStdDev      Presumably the inverse standard deviation, as named
 *                     `invStdDev` elsewhere in this header -- confirm.
 * \param prog         Program sequence the computation is presumably appended
 *                     to -- confirm.
 * \param partialsType Presumably the type of intermediate partial results;
 *                     defaults to poplar::FLOAT.
 * \param debugContext Optional debug context for the operation.
 * \param options      Optional option flags; defaults to empty.
 * \returns A pair of tensors, presumably (gammaDelta, betaDelta), matching the
 *          inputs of groupNormParamUpdate -- confirm the order.
 */
148std::pair<poplar::Tensor, poplar::Tensor> groupNormParamGradients(
149 poplar::Graph &graph, const poplar::Tensor &acts,
150 const poplar::Tensor &gradsIn, const poplar::Tensor &mean,
151 const poplar::Tensor &iStdDev, poplar::program::Sequence &prog,
152 const poplar::Type &partialsType = poplar::FLOAT,
153 const poplar::DebugContext &debugContext = {},
154 const poplar::OptionFlags &options = {});
155
/**
 * Overload of groupNormParamGradients that takes whitened activations instead
 * of the raw activations plus mean and inverse standard deviation.
 *
 * \param graph        Graph the operation is defined on (non-const reference).
 * \param actsWhitened Whitened activations, presumably as produced by
 *                     groupNormWhiten -- confirm.
 * \param gradsIn      Incoming gradients.
 * \param prog         Program sequence the computation is presumably appended
 *                     to -- confirm.
 * \param partialsType Presumably the type of intermediate partial results;
 *                     defaults to poplar::FLOAT.
 * \param debugContext Optional debug context for the operation.
 * \param options      Optional option flags; defaults to empty.
 * \returns A pair of tensors, presumably (gammaDelta, betaDelta) -- confirm.
 */
173std::pair<poplar::Tensor, poplar::Tensor> groupNormParamGradients(
174 poplar::Graph &graph, const poplar::Tensor &actsWhitened,
175 const poplar::Tensor &gradsIn, poplar::program::Sequence &prog,
176 const poplar::Type &partialsType = poplar::FLOAT,
177 const poplar::DebugContext &debugContext = {},
178 const poplar::OptionFlags &options = {});
179
206 const poplar::Tensor &gradsIn, const poplar::Tensor &mean,
207 const poplar::Tensor &invStdDev, const poplar::Tensor &gamma,
209 const poplar::Type &partialsType = poplar::FLOAT,
210 const poplar::DebugContext &debugContext = {},
211 const poplar::OptionFlags &options = {});
212
237 const poplar::Tensor &gradsIn,
238 const poplar::Tensor &invStdDev, const poplar::Tensor &gamma,
240 const poplar::Type &partialsType = poplar::FLOAT,
241 const poplar::DebugContext &debugContext = {},
242 const poplar::OptionFlags &options = {});
243
267 const poplar::Tensor &gammaDelta,
268 const poplar::Tensor &betaDelta, float scale,
269 poplar::Tensor &gamma, poplar::Tensor &beta,
271 const poplar::DebugContext &debugContext = {},
272 const poplar::OptionFlags &options = {});
273
297 const poplar::Tensor &gammaDelta,
298 const poplar::Tensor &betaDelta,
299 const poplar::Tensor &scale, poplar::Tensor &gamma,
301 const poplar::DebugContext &debugContext = {},
302 const poplar::OptionFlags &options = {});
303} // namespace gn
304} // namespace popnn
305#endif // popnn_GroupNorm_hpp
std::pair< poplar::Tensor, poplar::Tensor > groupNormParamGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &iStdDev, poplar::program::Sequence &prog, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Compute gradients with respect to parameters for parameter update.
std::pair< poplar::Tensor, poplar::Tensor > groupNormalise(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gamma, const poplar::Tensor &beta, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Group normalise activations given the mean, the inverse standard deviation and the group norm parameters (gamma and beta).
poplar::Tensor groupNormWhiten(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Whiten activations given the mean and the inverse standard deviation.
std::pair< poplar::Tensor, poplar::Tensor > groupNormStatistics(poplar::Graph &graph, const poplar::Tensor acts, float eps, poplar::program::Sequence &prog, unsigned numGroups, bool unbiasedVarEstimate, bool stableAlgo=false, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Estimate the mean and the inverse standard deviation of activations.
poplar::Tensor groupNormGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, const poplar::Tensor &gamma, poplar::program::Sequence &prog, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Compute gradients with respect to input activations for the group norm layer.
void groupNormParamUpdate(poplar::Graph &graph, const poplar::Tensor &gammaDelta, const poplar::Tensor &betaDelta, float scale, poplar::Tensor &gamma, poplar::Tensor &beta, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Update parameters for the group norm layer.
DebugContext gathers the common external parameters of the context of an operation.
Definition: DebugContext.hpp:221
This class represents a graph program to be executed on the IPU.
Definition: Graph.hpp:52
A set of option/value string flags to be used in various APIs.
Definition: OptionFlags.hpp:24
A reference to a subset of tensor elements.
Definition: Tensor.hpp:38
Class representing device data types.
Definition: Type.hpp:42
Program that executes a sequence of programs.
Definition: Program.hpp:77
Type FLOAT
Device type: float
Functions used in neural networks.
Definition: BatchNorm.hpp:14