Poplar and PopLibs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
CTCLoss.hpp
Go to the documentation of this file.
1// Copyright (c) 2021 Graphcore Ltd. All rights reserved.
8#ifndef popnn_CTCLoss_hpp
9#define popnn_CTCLoss_hpp
10
11#include "CTCPlan.hpp"
12
13#include <poplar/Graph.hpp>
14#include <poplar/OptionFlags.hpp>
15#include <poplar/Program.hpp>
16
17namespace popnn {
18namespace ctc {
19
// Create a Plan for a CTC operation from the input/output element types and
// the problem dimensions (batchSize, maxTime, maxLabelLength, numClasses).
// `options` is optional and defaults to an empty set of flags.
// NOTE(review): the accepted option keys are not visible in this listing.
51Plan plan(const poplar::Graph &graph, const poplar::Type &inType,
52 const poplar::Type &outType, unsigned batchSize, unsigned maxTime,
53 unsigned maxLabelLength, unsigned numClasses,
54 const poplar::OptionFlags &options = {});
55
// Create a data input tensor of element type `type` for the given batchSize,
// maxTime and numClasses, using `plan`.
// NOTE(review): presumably created and mapped according to the plan, in the
// same way as the labels input; the dimension order of the returned tensor is
// not shown in this listing -- confirm against the API reference.
70poplar::Tensor createDataInput(poplar::Graph &graph, const poplar::Type &type,
71 const std::size_t batchSize,
72 const std::size_t maxTime,
73 const std::size_t numClasses, const Plan &plan,
74 const poplar::DebugContext &debugContext = {});
75
// Create and map a labels input tensor of shape [batchSize, maxLabelLength]
// of element type `type`, using `plan`.
// The first line of this declaration (original line 88) was missing from the
// listing and has been restored from the signature index further down the page.
88poplar::Tensor createLabelsInput(poplar::Graph &graph, const poplar::Type &type,
89 const std::size_t batchSize,
90 const std::size_t maxLabelLength,
91 const Plan &plan,
92 const poplar::DebugContext &debugContext = {});
93
112/*[INTERNAL]
113 * * `zeroInfinityRelTolerance` Positive decimal [=0.01]
114 *
115 * When zeroInfinity is enabled, the value of the loss for each batch
116 * is compared to the value of the loss that we would get if the input
117 * batches were too short to be aligned to the labels. This option
118 * determines the relative tolerance within which the two values are
119 * considered equal.
120 */
// Calculate the CTC loss and gradient from the `logProbs` input, creating and
// mapping the result tensor according to the plan provided.
// Returns a pair of tensors.
// NOTE(review): presumably (loss, gradient) in that order, with the required
// programs appended to `prog` -- confirm against the API reference.
140std::pair<poplar::Tensor, poplar::Tensor> calcLossAndGradientLogProbabilities(
141 poplar::Graph &graph, const poplar::Type &outType,
142 const poplar::Tensor &logProbs, const poplar::Tensor &labels,
143 const poplar::Tensor &dataLengths, const poplar::Tensor &labelLengths,
144 poplar::program::Sequence &prog, const unsigned blankClass,
145 const Plan &plan, const poplar::DebugContext &debugContext = {},
146 const poplar::OptionFlags &options = {});
147
// Calculate the CTC loss and gradient from the `logits` input, creating and
// mapping the result tensor according to the plan provided.
// Returns a pair of tensors.
// NOTE(review): presumably (loss, gradient) in that order, with the required
// programs appended to `prog` -- confirm against the API reference.
170std::pair<poplar::Tensor, poplar::Tensor> calcLossAndGradientLogits(
171 poplar::Graph &graph, const poplar::Type &outType,
172 const poplar::Tensor &logits, const poplar::Tensor &labels,
173 const poplar::Tensor &dataLengths, const poplar::Tensor &labelLengths,
174 poplar::program::Sequence &prog, const unsigned blankClass,
175 const Plan &plan, const poplar::DebugContext &debugContext = {},
176 const poplar::OptionFlags &options = {});
177
// Calculate the CTC loss from the `logProbs` input, creating and mapping the
// result tensor according to the plan provided.
// The first line of this declaration (original line 198) was missing from the
// listing and has been restored from the signature index further down the page.
198poplar::Tensor calcCTCLossLogProbabilities(
199 poplar::Graph &graph, const poplar::Type &outType,
200 const poplar::Tensor &logProbs, const poplar::Tensor &labels,
201 const poplar::Tensor &dataLengths, const poplar::Tensor &labelLengths,
202 poplar::program::Sequence &prog, const unsigned blankClass,
203 const Plan &plan, const poplar::DebugContext &debugContext = {},
204 const poplar::OptionFlags &options = {});
205
// Calculate the CTC loss from the `logits` input, creating and mapping the
// result tensor according to the plan provided.
// The first line of this declaration (original line 226) was missing from the
// listing and has been restored from the signature index further down the page.
226poplar::Tensor calcCTCLossLogits(
227 poplar::Graph &graph, const poplar::Type &outType,
228 const poplar::Tensor &logits, const poplar::Tensor &labels,
229 const poplar::Tensor &dataLengths, const poplar::Tensor &labelLengths,
230 poplar::program::Sequence &prog, const unsigned blankClass,
231 const Plan &plan, const poplar::DebugContext &debugContext = {},
232 const poplar::OptionFlags &options = {});
233
234} // namespace ctc
235} // namespace popnn
236
237#endif // popnn_CTCLoss_hpp
std::pair< poplar::Tensor, poplar::Tensor > calcLossAndGradientLogits(poplar::Graph &graph, const poplar::Type &outType, const poplar::Tensor &logits, const poplar::Tensor &labels, const poplar::Tensor &dataLengths, const poplar::Tensor &labelLengths, poplar::program::Sequence &prog, const unsigned blankClass, const Plan &plan, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Calculate the CTC loss & gradient, creating and mapping the result tensor according to the plan provided.
poplar::Tensor calcCTCLossLogProbabilities(poplar::Graph &graph, const poplar::Type &outType, const poplar::Tensor &logProbs, const poplar::Tensor &labels, const poplar::Tensor &dataLengths, const poplar::Tensor &labelLengths, poplar::program::Sequence &prog, const unsigned blankClass, const Plan &plan, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Calculate the CTC loss, creating and mapping the result tensor according to the plan provided.
poplar::Tensor calcCTCLossLogits(poplar::Graph &graph, const poplar::Type &outType, const poplar::Tensor &logits, const poplar::Tensor &labels, const poplar::Tensor &dataLengths, const poplar::Tensor &labelLengths, poplar::program::Sequence &prog, const unsigned blankClass, const Plan &plan, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Calculate the CTC loss, creating and mapping the result tensor according to the plan provided.
std::pair< poplar::Tensor, poplar::Tensor > calcLossAndGradientLogProbabilities(poplar::Graph &graph, const poplar::Type &outType, const poplar::Tensor &logProbs, const poplar::Tensor &labels, const poplar::Tensor &dataLengths, const poplar::Tensor &labelLengths, poplar::program::Sequence &prog, const unsigned blankClass, const Plan &plan, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Calculate the CTC loss & gradient, creating and mapping the result tensor according to the plan provided.
poplar::Tensor createLabelsInput(poplar::Graph &graph, const poplar::Type &type, const std::size_t batchSize, const std::size_t maxLabelLength, const Plan &plan, const poplar::DebugContext &debugContext={})
Create and map a labels input [batchSize, maxLabelLength] tensor which the gradient function will use...
Support for planning Connectionist Temporal Classification (CTC) Operations.
DebugContext gathers the common external parameters of the context of an operation.
Definition: DebugContext.hpp:221
This class represents a graph program to be executed on the IPU.
Definition: Graph.hpp:52
A set of option/value string flags to be used in various APIs.
Definition: OptionFlags.hpp:24
A reference to a subset of tensor elements.
Definition: Tensor.hpp:38
Class representing device data types.
Definition: Type.hpp:42
Program that executes a sequence of programs.
Definition: Program.hpp:77
An object representing a plan that describes how to map tensors and implement the CTC Loss or CTC Inf...
Definition: CTCPlan.hpp:19
Functions used in neural networks.
Definition: BatchNorm.hpp:14