Poplar and PopLibs
InstanceNorm.hpp
Go to the documentation of this file.
1// Copyright (c) 2019 Graphcore Ltd. All rights reserved.
8#ifndef popnn_InstanceNorm_hpp
9#define popnn_InstanceNorm_hpp
10#include "popnn/GroupNorm.hpp"
11#include "poputil/DebugInfo.hpp"
12
13namespace popnn {
14namespace in {
15
40inline std::pair<poplar::Tensor, poplar::Tensor>
42 float eps, poplar::program::Sequence &prog,
43 bool unbiasedVarEstimate, bool stableAlgo,
44 const poplar::Type &partialsType = poplar::FLOAT,
45 const poplar::DebugContext &debugContext = {},
46 const poplar::OptionFlags &options = {}) {
47 poputil::PoplibsOpDebugInfo di(debugContext,
48 DI_ARGS(acts, eps, unbiasedVarEstimate,
49 stableAlgo, partialsType, options));
50
51 auto outputs = popnn::gn::groupNormStatistics(
52 graph, acts, eps, prog, acts.dim(1), unbiasedVarEstimate, stableAlgo,
53 partialsType, {di}, options);
54
55 di.addOutputs({{"mean", poputil::toProfileValue(outputs.first)},
56 {"iStd", poputil::toProfileValue(outputs.second)}});
57 return outputs;
58}
59
76inline poplar::Tensor
78 const poplar::Tensor &mean, const poplar::Tensor &invStdDev,
80 const poplar::DebugContext &debugContext = {},
81 const poplar::OptionFlags &options = {}) {
82 poputil::PoplibsOpDebugInfo di(debugContext,
83 DI_ARGS(acts, mean, invStdDev, options));
84
85 auto output = popnn::gn::groupNormWhiten(graph, acts, mean, invStdDev, prog,
86 {di}, options);
87 di.addOutput(output);
88 return output;
89}
90
120inline std::pair<poplar::Tensor, poplar::Tensor>
122 const poplar::Tensor &gamma, const poplar::Tensor &beta,
123 const poplar::Tensor &mean, const poplar::Tensor &invStdDev,
125 const poplar::DebugContext &debugContext = {},
126 const poplar::OptionFlags &options = {}) {
127 poputil::PoplibsOpDebugInfo di(
128 debugContext, DI_ARGS(acts, gamma, beta, mean, invStdDev, options));
129
130 auto outputs = popnn::gn::groupNormalise(graph, acts, gamma, beta, mean,
131 invStdDev, prog, {di}, options);
132
133 di.addOutputs({{"normActs", poputil::toProfileValue(outputs.first)},
134 {"whitenedActs", poputil::toProfileValue(outputs.second)}});
135 return outputs;
136}
137
158inline std::pair<poplar::Tensor, poplar::Tensor> instanceNormParamGradients(
159 poplar::Graph &graph, const poplar::Tensor &acts,
160 const poplar::Tensor &gradsIn, const poplar::Tensor &mean,
161 const poplar::Tensor &iStdDev, poplar::program::Sequence &prog,
162 const poplar::Type &partialsType = poplar::FLOAT,
163 const poplar::DebugContext &debugContext = {},
164 const poplar::OptionFlags &options = {}) {
165
166 poputil::PoplibsOpDebugInfo di(
167 debugContext,
168 DI_ARGS(acts, gradsIn, mean, iStdDev, partialsType, options));
169
170 auto outputs = popnn::gn::groupNormParamGradients(
171 graph, acts, gradsIn, mean, iStdDev, prog, partialsType, {di}, options);
172
173 di.addOutputs({{"meanGrad", poputil::toProfileValue(outputs.first)},
174 {"iStdDevGrad", poputil::toProfileValue(outputs.second)}});
175 return outputs;
176}
177
195inline std::pair<poplar::Tensor, poplar::Tensor> instanceNormParamGradients(
196 poplar::Graph &graph, const poplar::Tensor &actsWhitened,
197 const poplar::Tensor &gradsIn, poplar::program::Sequence &prog,
198 const poplar::Type &partialsType = poplar::FLOAT,
199 const poplar::DebugContext &debugContext = {},
200 const poplar::OptionFlags &options = {}) {
201
202 poputil::PoplibsOpDebugInfo di(
203 debugContext, DI_ARGS(actsWhitened, gradsIn, partialsType, options));
204
205 auto outputs = popnn::gn::groupNormParamGradients(
206 graph, actsWhitened, gradsIn, prog, partialsType, {di}, options);
207
208 di.addOutputs({{"meanGrad", poputil::toProfileValue(outputs.first)},
209 {"iStdDevGrad", poputil::toProfileValue(outputs.second)}});
210 return outputs;
211}
212
238inline poplar::Tensor
240 const poplar::Tensor &gradsIn, const poplar::Tensor &mean,
241 const poplar::Tensor &invStdDev,
242 const poplar::Tensor &gamma,
244 const poplar::Type &partialsType = poplar::FLOAT,
245 const poplar::DebugContext &debugContext = {},
246 const poplar::OptionFlags &options = {}) {
247
248 poputil::PoplibsOpDebugInfo di(
249 debugContext,
250 DI_ARGS(acts, gradsIn, mean, invStdDev, gamma, partialsType, options));
251
252 auto output =
253 popnn::gn::groupNormGradients(graph, acts, gradsIn, mean, invStdDev,
254 gamma, prog, partialsType, {di}, options);
255
256 di.addOutput(output);
257 return output;
258}
259
284 poplar::Graph &graph, const poplar::Tensor &actsWhitened,
285 const poplar::Tensor &gradsIn, const poplar::Tensor &invStdDev,
286 const poplar::Tensor &gamma, poplar::program::Sequence &prog,
287 const poplar::Type &partialsType = poplar::FLOAT,
288 const poplar::DebugContext &debugContext = {},
289 const poplar::OptionFlags &options = {}) {
290
291 poputil::PoplibsOpDebugInfo di(
292 debugContext,
293 DI_ARGS(actsWhitened, gradsIn, invStdDev, gamma, partialsType, options));
294
295 auto output =
296 popnn::gn::groupNormGradients(graph, actsWhitened, gradsIn, invStdDev,
297 gamma, prog, partialsType, {di}, options);
298 di.addOutput(output);
299 return output;
300}
301
324inline void
326 const poplar::Tensor &betaDelta, float scale,
327 poplar::Tensor &gamma, poplar::Tensor &beta,
329 const poplar::DebugContext &debugContext = {},
330 const poplar::OptionFlags &options = {}) {
331 poputil::PoplibsOpDebugInfo di(
332 debugContext,
333 DI_ARGS(gammaDelta, betaDelta, scale, gamma, beta, options));
334
335 return popnn::gn::groupNormParamUpdate(graph, gammaDelta, betaDelta, scale,
336 gamma, beta, prog, {di}, options);
337}
338
361inline void
363 const poplar::Tensor &betaDelta,
364 const poplar::Tensor &scale, poplar::Tensor &gamma,
366 const poplar::DebugContext &debugContext = {},
367 const poplar::OptionFlags &options = {}) {
368 poputil::PoplibsOpDebugInfo di(
369 debugContext,
370 DI_ARGS(gammaDelta, betaDelta, scale, gamma, beta, options));
371 return popnn::gn::groupNormParamUpdate(graph, gammaDelta, betaDelta, scale,
372 gamma, beta, prog, {di}, options);
373}
374
/** Flop count for the forward pass of an instance norm layer (inferred from
 *  the name — confirm against the definition).
 *
 *  \param numChannels      Number of channels.
 *  \param actsPerChannel   Number of activations per channel.
 *  \param computeEstimates NOTE(review): presumably whether the cost of
 *                          computing the mean/inverse standard deviation is
 *                          included — confirm.
 */
uint64_t getFwdFlops(uint64_t numChannels, uint64_t actsPerChannel,
                     bool computeEstimates);

/** Flop count for the backward pass (inferred from the name — confirm). */
uint64_t getBwdFlops(uint64_t numChannels, uint64_t actsPerChannel);

/** Flop count for the weight-update pass (inferred from the name — confirm). */
uint64_t getWuFlops(uint64_t numChannels, uint64_t actsPerChannel);
401
402} // namespace in
403} // namespace popnn
404#endif // popnn_InstanceNorm_hpp
Poplibs generic debug info structure.
Group normalization operations.
std::pair< poplar::Tensor, poplar::Tensor > instanceNormParamGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &iStdDev, poplar::program::Sequence &prog, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Compute gradients with respect to parameters for parameter update.
Definition: InstanceNorm.hpp:158
poplar::Tensor instanceNormGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, const poplar::Tensor &gamma, poplar::program::Sequence &prog, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Compute gradients with respect to input activations for the instance norm layer.
Definition: InstanceNorm.hpp:239
void instanceNormParamUpdate(poplar::Graph &graph, const poplar::Tensor &gammaDelta, const poplar::Tensor &betaDelta, float scale, poplar::Tensor &gamma, poplar::Tensor &beta, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Update parameters for the instance norm layer.
Definition: InstanceNorm.hpp:325
std::pair< poplar::Tensor, poplar::Tensor > instanceNormStatistics(poplar::Graph &graph, const poplar::Tensor acts, float eps, poplar::program::Sequence &prog, bool unbiasedVarEstimate, bool stableAlgo, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Estimate mean and inverse of standard deviation of activations.
Definition: InstanceNorm.hpp:41
std::pair< poplar::Tensor, poplar::Tensor > instanceNormalise(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gamma, const poplar::Tensor &beta, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Instance-normalise activations given the mean, inverse standard deviation and norm parameters.
Definition: InstanceNorm.hpp:121
poplar::Tensor instanceNormWhiten(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Whiten activations given the mean and inverse standard deviation.
Definition: InstanceNorm.hpp:77
DebugContext gathers the common external parameters of the context of an operation.
Definition: DebugContext.hpp:221
This class represents a graph program to be executed on the IPU.
Definition: Graph.hpp:52
A set of option/value string flags to be used in various APIs.
Definition: OptionFlags.hpp:24
A reference to a subset of tensor elements.
Definition: Tensor.hpp:38
std::size_t dim(unsigned i) const
Get a dimension of the tensor.
Class representing device data types.
Definition: Type.hpp:42
Program that executes a sequence of programs.
Definition: Program.hpp:77
Type FLOAT
Device type: float
Functions used in neural networks.
Definition: BatchNorm.hpp:14