Poplar and PopLibs
BatchNorm.hpp
Go to the documentation of this file.
1// Copyright (c) 2017 Graphcore Ltd. All rights reserved.
6#ifndef popnn_BatchNorm_hpp
7#define popnn_BatchNorm_hpp
8#include "poplar/DebugContext.hpp"
9#include "poplar/Program.hpp"
10#include "poplar/Tensor.hpp"
11#include "poplin/Norms.hpp"
12#include <utility>
13
14namespace popnn {
15namespace bn {
16
40std::pair<poplar::Tensor, poplar::Tensor>
41batchNormStatistics(poplar::Graph &graph, const poplar::Tensor acts, float eps,
42 poplar::program::Sequence &prog, bool unbiasedVarEstimate,
43 bool stableAlgo = false,
44 const poplar::Type &partialsType = poplar::FLOAT,
45 const poplar::DebugContext &debugContext = {},
46 const poplar::OptionFlags &options = {});
47
90std::pair<poplar::Tensor, poplar::Tensor> distributedBatchNormStatistics(
91 poplar::Graph &replicatedGraph, const poplar::Tensor acts, float eps,
92 poplar::program::Sequence &prog, bool unbiasedVarEstimate,
94 unsigned normBatchSize, bool stableAlgo = false,
95 const poplar::Type &partialsType = poplar::FLOAT,
96 const poplar::DebugContext &debugContext = {},
97 const poplar::OptionFlags &options = {});
98
116 const poplar::Tensor &mean,
117 const poplar::Tensor &invStdDev,
119 const poplar::DebugContext &debugContext = {},
120 const poplar::OptionFlags &options = {});
121
149std::pair<poplar::Tensor, poplar::Tensor>
151 const poplar::Tensor &gamma, const poplar::Tensor &beta,
152 const poplar::Tensor &mean, const poplar::Tensor &invStdDev,
154 const poplar::DebugContext &debugContext = {},
155 const poplar::OptionFlags &options = {});
156
// NOTE(review): extraction damage. Everything before listing line 172 (this
// declaration's doc comment, return type, name and leading parameters) is
// missing, as is listing line 174 between `addend` and `debugContext`. The
// leading numbers on each line appear to be documentation-listing line
// numbers fused into the code. Presumably this is a batchNormalise()
// overload taking a pre-combined multiplicand and addend. Restore it from
// the upstream popnn/BatchNorm.hpp before compiling.
172 const poplar::Tensor &combinedMultiplicand,
173 const poplar::Tensor &addend,
175 const poplar::DebugContext &debugContext = {},
176 const poplar::OptionFlags &options = {});
177
198std::pair<poplar::Tensor, poplar::Tensor> batchNormParamGradients(
199 poplar::Graph &graph, const poplar::Tensor &acts,
200 const poplar::Tensor &gradsIn, const poplar::Tensor &mean,
201 const poplar::Tensor &iStdDev, poplar::program::Sequence &prog,
202 const poplar::Type &partialsType = poplar::FLOAT,
203 const poplar::DebugContext &debugContext = {},
204 const poplar::OptionFlags &options = {});
205
223std::pair<poplar::Tensor, poplar::Tensor> batchNormParamGradients(
224 poplar::Graph &graph, const poplar::Tensor &actsWhitened,
225 const poplar::Tensor &gradsIn, poplar::program::Sequence &prog,
226 const poplar::Type &partialsType = poplar::FLOAT,
227 const poplar::DebugContext &debugContext = {},
228 const poplar::OptionFlags &options = {});
229
256 const poplar::Tensor &gradsIn, const poplar::Tensor &mean,
257 const poplar::Tensor &invStdDev, const poplar::Tensor &gamma,
259 const poplar::Type &partialsType = poplar::FLOAT,
260 const poplar::DebugContext &debugContext = {},
261 const poplar::OptionFlags &options = {});
262
// NOTE(review): extraction damage. Everything before listing line 288 (this
// declaration's doc comment, return type, name and leading parameters) is
// missing, as is listing line 290 between `gamma` and `partialsType`. The
// leading numbers on each line appear to be documentation-listing line
// numbers fused into the code. Presumably this is a batchNormGradients()
// overload taking whitened activations. Restore it from the upstream
// popnn/BatchNorm.hpp before compiling.
288 const poplar::Tensor &gradsIn,
289 const poplar::Tensor &invStdDev, const poplar::Tensor &gamma,
291 const poplar::Type &partialsType = poplar::FLOAT,
292 const poplar::DebugContext &debugContext = {},
293 const poplar::OptionFlags &options = {});
294
336 poplar::Graph &replicatedGraph, const poplar::Tensor &actsWhitened,
337 const poplar::Tensor &gradsIn, const poplar::Tensor &invStdDev,
338 const poplar::Tensor &gamma, poplar::program::Sequence &prog,
340 unsigned normBatchSize, const poplar::Type &partialsType = poplar::FLOAT,
341 const poplar::DebugContext &debugContext = {},
342 const poplar::OptionFlags &options = {});
343
// NOTE(review): extraction damage. Everything before listing line 387 (this
// declaration's doc comment, return type and name) is missing, as are
// listing lines 390-391 between `gamma` and `normBatchSize`. The leading
// numbers on each line appear to be documentation-listing line numbers
// fused into the code. Presumably this is a distributedBatchNormGradients()
// overload taking raw activations plus mean. Restore it from the upstream
// popnn/BatchNorm.hpp before compiling.
387 poplar::Graph &replicatedGraph, const poplar::Tensor &acts,
388 const poplar::Tensor &gradsIn, const poplar::Tensor &mean,
389 const poplar::Tensor &invStdDev, const poplar::Tensor &gamma,
392 unsigned normBatchSize, const poplar::Type &partialsType = poplar::FLOAT,
393 const poplar::DebugContext &debugContext = {},
394 const poplar::OptionFlags &options = {});
395
419 const poplar::Tensor &gammaDelta,
420 const poplar::Tensor &betaDelta, float scale,
421 poplar::Tensor &gamma, poplar::Tensor &beta,
423 const poplar::DebugContext &debugContext = {},
424 const poplar::OptionFlags &options = {});
425
// NOTE(review): extraction damage. Everything before listing line 449 (this
// declaration's doc comment, return type, name and leading parameters) is
// missing, as is listing line 452 between `gamma` and `debugContext`. The
// leading numbers on each line appear to be documentation-listing line
// numbers fused into the code. Presumably this is a batchNormParamUpdate()
// overload taking the scale as a tensor. Restore it from the upstream
// popnn/BatchNorm.hpp before compiling.
449 const poplar::Tensor &gammaDelta,
450 const poplar::Tensor &betaDelta,
451 const poplar::Tensor &scale, poplar::Tensor &gamma,
453 const poplar::DebugContext &debugContext = {},
454 const poplar::OptionFlags &options = {});
455} // namespace bn
456} // namespace popnn
457#endif // popnn_BatchNorm_hpp
std::pair< poplar::Tensor, poplar::Tensor > batchNormParamGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &iStdDev, poplar::program::Sequence &prog, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Compute gradients with respect to parameters required for parameter update.
std::pair< poplar::Tensor, poplar::Tensor > batchNormStatistics(poplar::Graph &graph, const poplar::Tensor acts, float eps, poplar::program::Sequence &prog, bool unbiasedVarEstimate, bool stableAlgo=false, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Estimate mean and inverse of standard deviation of batched activations.
poplar::Tensor batchNormGradients(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gradsIn, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, const poplar::Tensor &gamma, poplar::program::Sequence &prog, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Compute gradients with respect to input activations for the batch norm layer.
std::pair< poplar::Tensor, poplar::Tensor > distributedBatchNormStatistics(poplar::Graph &replicatedGraph, const poplar::Tensor acts, float eps, poplar::program::Sequence &prog, bool unbiasedVarEstimate, poplin::DistributedNormReduceCallback reduceCallback, unsigned normBatchSize, bool stableAlgo=false, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Compute the batch normalisation statistics for a part of the activations tensor.
poplar::Tensor distributedBatchNormGradients(poplar::Graph &replicatedGraph, const poplar::Tensor &actsWhitened, const poplar::Tensor &gradsIn, const poplar::Tensor &invStdDev, const poplar::Tensor &gamma, poplar::program::Sequence &prog, poplin::DistributedNormReduceCallback reduceCallback, unsigned normBatchSize, const poplar::Type &partialsType=poplar::FLOAT, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Propagate the gradients through the batch norm layer where equal-sized batch elements are distributed...
std::pair< poplar::Tensor, poplar::Tensor > batchNormalise(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &gamma, const poplar::Tensor &beta, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Batch normalise the activations using the given mean, standard deviation and batch norm parameters.
void batchNormParamUpdate(poplar::Graph &graph, const poplar::Tensor &gammaDelta, const poplar::Tensor &betaDelta, float scale, poplar::Tensor &gamma, poplar::Tensor &beta, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Update the parameters for the batch norm layer.
poplar::Tensor batchNormWhiten(poplar::Graph &graph, const poplar::Tensor &acts, const poplar::Tensor &mean, const poplar::Tensor &invStdDev, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Whiten activations given the mean and standard deviation.
DebugContext gathers the common external parameters of the context of an operation.
Definition: DebugContext.hpp:221
This class represents a graph program to be executed on the IPU.
Definition: Graph.hpp:52
A set of option/value string flags to be used in various APIs.
Definition: OptionFlags.hpp:24
A reference to a subset of tensor elements.
Definition: Tensor.hpp:38
Class representing device data types.
Definition: Type.hpp:42
Program that executes a sequence of programs.
Definition: Program.hpp:77
Type FLOAT
Device type: float
std::function< std::vector< poplar::Tensor >(poplar::Graph &replicatedGraph, const std::vector< poplar::Tensor > &inputsToReduce, poplar::program::Sequence &prog, unsigned groupSize, const poplar::DebugContext &debugContext, const poplar::OptionFlags &options)> DistributedNormReduceCallback
Callback to reduce statistics and gradients.
Definition: Norms.hpp:150
Functions used in neural networks.
Definition: BatchNorm.hpp:14
Functions to support normalising values in a tensor.