Poplar and PopLibs
CastToGfloat.hpp
1// Copyright (c) 2019 Graphcore Ltd. All rights reserved.
2#ifndef popfloat_CastToGfloat_hpp
3#define popfloat_CastToGfloat_hpp
5#include <popfloat/experimental/GfloatExpr.hpp>
6#include <popfloat/experimental/GfloatExprUtil.hpp>
7#include <poplar/DebugContext.hpp>
8#include <poplar/Engine.hpp>
9#include <poplar/Graph.hpp>
10#include <poplar/OptionFlags.hpp>
11#include <poplar/Program.hpp>
12#include <poplar/Type.hpp>
13
14#include <functional>
15
16/*
17 * README: THIS CODE IS NOT SUPPORTED AND MAY BE REMOVED WITHOUT NOTICE.
18 *
19 * Using precision lower than FP32 allows more efficient computation and
20 * reduces memory usage, permitting the deployment of larger networks, the
21 * use of large batch sizes for memory limited systems and leads to faster
22 * data transfers. In order to study the effect of limited precision data
23 * representation and computation on neural network training on the IPU,
24 * this library introduces a set of functions that will allow the definition
25 * of custom numerical formats.
26 *
27 * The popfloat library allows the emulation of custom numerical formats on
28 * the IPU referred to as Gfloat. The user will be able to specify the float
29 * format's mantissa size, exponent size, exponent bias, and also enable or
30 * disable denormals as well as Inf/Nan signalling. Once the format has
31 * been defined, the user can cast an input (float/half) tensor to the new
32 * format and can choose from a selection of rounding modes (deterministic and
33 * stochastic) for quantisation.
34 * Supported deterministic rounding modes are:
35 * - round-to-zero,
36 * - round to nearest with ties to nearest even,
37 * - round to nearest with ties away from zero,
38 * - round towards positive infinity (ceil)
39 * - round towards negative infinity (floor).
40 *
41 * In the case of Stochastic Rounding (SR), the density function of the noise
42 * samples that are used as well as the number of SR bits can be specified.
43 *
 * The library also allows the quantised float format to be saved as 8-bit
 * or 16-bit values, provided that the float format size
 * (1 + mantissa bits + exponent bits) fits in 8 or 16 bits.
47 * Casting a gfloat format from its packed representation to a native higher
48 * precision IEEE float format (IEEE float formats supported on the IPU for
49 * calculation) is done with the castGfloatToNative functions.
50 *
51 */
52
53namespace popfloat {
54namespace experimental {
55
56class GfloatCast {
57public:
58 struct GfloatFormatOptions {
59 unsigned numMantissaBits;
60 unsigned numExponentBits;
61 unsigned numExponentBias;
62 bool enableDenorms;
63 bool enableInfsAndNans;
64
65 void parseGfloatFormatOptions(const poplar::OptionFlags &options);
66
67 GfloatFormatOptions(const poplar::OptionFlags &options) {
68 parseGfloatFormatOptions(options);
69 }
70
71 GfloatFormatOptions() {
72 numMantissaBits = 10;
73 numExponentBits = 5;
74 numExponentBias = 15;
75 enableDenorms = true;
76 enableInfsAndNans = true;
77 }
78 };
79
80 struct FormatConfig {
81 FormatConfig() = default;
82
83 /*
84 * GfloatFormatConfig: This structure stores the configuration parameters
85 * of the generic float format, `gfloat`, and defines attributes of the
86 * format used by the cast operations. Format parameters are:
87 * - numMantissaBits: defines the format's precision
88 * - numExponentBits: defines the number of magnitudes, for radix-2 format
89 * - exponentBias: offset of the stored exponent from the actual value
90 * - enableDenorms: enable gradual underflow
91 * - enableInfsAndNans: allow Inf/Nan signalling
92 * - specCalculationType: The native IPU IEEE float type used to
93 * calculate the gfloat format. Possible values are:
94 * - FP32: can be used for all gfloat formats
95 * - FP16: can only be used for gfloat formats that can be
96 * represented as IEEE FP16. i.e., when numMantissaBits
97 * and numExponentBits less than IEEE FP16's mantissa
98 * and exponent sizes, respectively).
99 * - AUTO: Will let the FormatConfig constructor choose the smallest
100 * native IEEE float to calculate the format.
101 */
102 FormatConfig(unsigned numMantissaBits, unsigned numExponentBits,
103 int exponentBias, bool enableDenorms, bool enableInfsAndNans,
104 popfloat::experimental::SpecType specCalculationType =
105 popfloat::experimental::SpecType::AUTO);
106
107 FormatConfig(GfloatFormatOptions formatOptions,
108 poplar::Type calculationType);
109
110 FormatConfig(unsigned numMantissaBits, unsigned numExponentBits,
111 int exponentBias, bool enableDenorms, bool enableInfsAndNans,
112 poplar::Type calculationType);
113
114 /* Copy constructor */
115 FormatConfig(const FormatConfig &formatConfig);
116
117 poplar::Type getCalculationType() const { return calculationType; }
118 poplar::Type getNativeType() const { return nativeType; }
119 poplar::Type getStorageType() const { return storageType; }
120
121 popfloat::experimental::FormatType getFormatType() const {
122 return formatType;
123 }
124
125 unsigned getNumMantissaBits() const { return numMantissaBits; }
126 unsigned getNumExponentBits() const { return numExponentBits; }
127 int getExponentBias() const { return exponentBias; }
128 bool isDenormEnabled() const { return enableDenorms; }
129 bool infAndNansEnabled() const { return enableInfsAndNans; }
130 bool isPackedFloatFormat() const {
131 return (formatType !=
132 popfloat::experimental::FormatType::QUANTISED_FP16) &&
133 (formatType !=
134 popfloat::experimental::FormatType::QUANTISED_FP32) &&
135 (formatType != popfloat::experimental::FormatType::IEEE_FP16);
136 };
137
138 unsigned getPackedFloatBits() const { return packedFloatBits; };
139 bool isBlockFloat() const { return blockFloat; };
140
141 unsigned getPackedFloatParameters() const { return packedFloatParameters; };
142
143 bool operator==(FormatConfig &other) const {
144 const auto numMantissaBits_ = other.getNumMantissaBits();
145 const auto numExponentBits_ = other.getNumExponentBits();
146 const auto exponentBias_ = other.getExponentBias();
147 const auto enableDenorms_ = other.isDenormEnabled();
148 const auto enableInfsAndNans_ = other.infAndNansEnabled();
149 const auto calculationType_ = other.getCalculationType();
150
151 return std::tie(numMantissaBits, numExponentBits, exponentBias,
152 enableDenorms, enableInfsAndNans, calculationType) ==
153 std::tie(numMantissaBits_, numExponentBits_, exponentBias_,
154 enableDenorms_, enableInfsAndNans_, calculationType_);
155 }
156
157 private:
158 /*
159 * calculationType: IEEE float type used to calculate the gfloat format.
160 * To cast a native IEEE float type to a gfloat format, we can use
161 * - poplar::HALF only if the gfloat format can be represented as an IEEE
162 * FP16. i.e. when the number of gfloat mantissa bits is less than or
163 * equal the IEEE FP16 mantissa size (10) and the number of gfloat
164 * exponent bits is less than or equal to IEEE FP16 mantissa size (5).
165 * - poplar::FLOAT can be used for all gfloat formats.
166 */
167 poplar::Type calculationType;
168
169 /*
170 * numMantissaBits: The Gfloat format's mantissa field size, which
171 * determines the number of fraction bits of the significand.
172 */
173 unsigned numMantissaBits;
174
175 /*
176 * numExponentBits: The Gfloat format's exponent field size.
177 */
178 unsigned numExponentBits;
179
180 /*
181 * exponentBias: The Gfloat format's exponent bias.
182 */
183 int exponentBias;
184
185 /*
186 * enableDenorms: to enable the Gfloat format's denormals. If false,
187 * gradual underflow is disabled, and denormal values will not
188 * represented. This means that the all-zero exponent field will
189 * represent zero.
190 */
191 bool enableDenorms;
192
193 /*
194 * enableInfsAndNans: to enable the Gfloat format's Infs/Nans signalling.
195 * This is ignored if numExponentBits=0. If enabled, input Infs/Nans are
196 * always propagated.
197 */
198 bool enableInfsAndNans;
199
200 /*
201 * nativeType: the format config will choose the smallest IEEE float type
202 * to represent the gfloat format. The result of quantisation is
203 * - poplar::HALF if the gfloat format can be represented as an IEEE FP16.
204 * - poplar::FLOAT if the gfloat format cannot be represented as an IEEE
205 * FP16.
206 * NOTE:
207 * - If the calculationType is IEEE FP32 and the gfloat format can be
208 * represented as IEEE FP16, nativeType will be set to IEEE FP16. (For
209 * instance, when casting to a 1/3/4 format and using IEEE FP32 as a
210 * calculationType). Otherwise, nativeType and calculationType will be
211 * the same.
212 * - When creating a CastConfig the user can override the native type
213 * to use to represent a gfloat format. (See CastConfig).
214 */
215 poplar::Type nativeType;
216
217 /*
218 * storageType: the format config will choose the samllest type that can be
219 * used to store a gfloat format
220 * - poplar::CHAR for a custom FP8 format.
221 * - poplar::SHORT for a custom FP16 format.
222 * - poplar::HALF for formats that can be represented as IEEE FP16.
223 * - poplar::FLOAT for all other format.
224 */
225 poplar::Type storageType;
226
227 /*
228 * formatType: Gfloat format type. The different gfloat format types are:
229 * - IEEE_FP16: This format denotes a cast from IEEE FP32 to IEEE FP16
230 * using rounding schemes not supported by the IPU
231 * - QUANTISED_FP32: Any Gfloat format that is stored as IEEE FP32
232 * - QUANTISED_FP16: Any Gfloat format that is stored as IEEE FP16
233 * - MIN_NORM_ALIGN_GF8: Any custom FP8 format with less than 5 exponent
234 * bits
235 * - ONE_FIVE_TWO_GF8: A 1/5/2 format with Infs/Nans enabled
236 * - MAX_NORM_ALIGN_GF8: A 1/5/2 format with Infs/Nans disabled
237 * - BFLOAT16: Google's Bfloat format (1/8/7)
238 * - NO_DENORM_GF16: A custom FP16 format with denorms disabled
239 * - ENABLE_DENORM_GF16: A custom FP16 with denorms enabled
240 */
241 popfloat::experimental::FormatType formatType;
242
243 /*
244 * packedFloatParameters: This is a packed representation of the gfloat
245 * format's parameters using 4 bytes (stored as INT32). The parameter
246 * packing is done such that:
247 * - one byte is used to store the number of mantissa bits
248 * - one byte is used to store the number of exponent bits
249 * - one byte is used to store the exponent bias
250 * - one bit is used to store the enableDenorms flag
251 * - one bit is used to store the enableInfsAndNans flag
252 */
253 unsigned packedFloatParameters;
254
255 /*
256 * blockFloat: This indicate if the format is INT or block-float. Block
257 * floating-point values are gfloat formats with zero exponent bits or
258 * gfloat formats with one exponent bit (numExponentBits=1) and Infs/Nans
259 * disabled.
260 */
261 bool blockFloat;
262
263 /*
264 * packedFloatBits: the number of bits used to pack the gfloat format
265 */
266 unsigned packedFloatBits;
267 };
268
269 struct GfloatCastOptions {
270 popfloat::experimental::RoundType roundMode;
271 popfloat::experimental::SRDensityType srNoiseDensity;
272 unsigned numSRBits;
273 double srNoiseOffset;
274 double srNoiseScale;
275 double srNoiseMax;
276 double srNoiseMin;
277 double bernoulliProb;
278 bool enableNanooMode;
279
280 void parseGfloatCastOptions(const poplar::OptionFlags &options);
281
282 GfloatCastOptions(const poplar::OptionFlags &options) {
283 parseGfloatCastOptions(options);
284 }
285
286 GfloatCastOptions() {
287 roundMode = popfloat::experimental::RoundType::INV;
288 srNoiseDensity = popfloat::experimental::SRDensityType::INVALID;
289 numSRBits = 24;
290 srNoiseOffset = 0.0;
291 srNoiseScale = 0.0;
292 srNoiseMax = 0.0;
293 srNoiseMin = 0.0;
294 bernoulliProb = 0.0;
295 enableNanooMode = true;
296 }
297 };
298
299 struct RoundConfig {
300 /*
301 * RoundConfig: This structure stores the configuration parameters for
302 * the rounding mode used in a castNativeToGfloat operation:
303 * - roundMode: quantisation rounding mode
304 * - numSRBits: number of random bits used for stochastic rounding,
305 * - srNoiseDensity: Stochasting rounding noise density,
306 * - srNoiseOffset: Stochastic rounding noise offset,
307 * - srNoiseScale: Stochastic rounding noise scaling factor,
308 * - srNoiseMax: Stochastic rounding maximum noise value,
309 * - srNoiseMin: Stochastic rounding minimum noise value,
310 * - bernoulliProb: Probability of rounding down for stochastic
311 * rounding with Bernoulli density
312 */
313
314 RoundConfig() = default;
315
316 RoundConfig(popfloat::experimental::RoundType roundMode, unsigned numSRBits,
317 poplar::Type calculationType,
318 popfloat::experimental::SRDensityType srNoiseDensity =
319 popfloat::experimental::SRDensityType::INVALID,
320 float srNoiseOffset = 0.0, float srNoiseScale = 0.0,
321 float srNoiseMax = 0.0, float srNoiseMin = 0.0,
322 float bernoulliProb = 0.0);
323
324 RoundConfig(const GfloatCast::RoundConfig &roundCfg);
325
326 RoundConfig(GfloatCastOptions castOptions, poplar::Type calculationType);
327
328 popfloat::experimental::RoundType getRoundMode() const {
329 return roundModeType;
330 }
331
332 unsigned getNumSRBits() const { return numSRBits; }
333
334 popfloat::experimental::SRDensityType getSRNoiseDensity() const {
335 return srNoiseDensity;
336 }
337 std::vector<unsigned> getRoundingParams() const { return roundingParams; }
338 std::vector<unsigned> getNoiseParams() const { return noiseParams; }
339 unsigned getDensityParam() const { return densityParam; }
340
341 float getBernoulliProbability() const { return bernoulliProb; }
342
343 float getSRNoiseOffset() const { return srNoiseOffset; }
344 float getSRNoiseScale() const { return srNoiseScale; }
345 float getSRNoiseMax() const { return srNoiseMax; }
346 float getSRNoiseMin() const { return srNoiseMin; }
347
348 std::vector<unsigned> getSRBitMask() const { return srBitMask; }
349
350 private:
351 /*
352 * roundModeType: Quantisation rounding mode. Supported rounding modes are:
353 * - RZ: round-to-zero (truncate)
354 * - RA: round-to-nearest with ties rounding away from zero
355 * - RN: round-to-nearest with ties rounding to nearest even value
356 * - RU: round-towards positive infinity (ceil)
357 * - RD: round-towards negative infinity (floor)
358 * - SR: stochastic rounding using as many random bits as the truncated
359 * mantissa for rounding.
360 * - SX: Stochastic rounding eXtension to limit the maximum number of
361 * random bits and to use different noise distributions for
362 * stochastic rounding.
363 */
364 popfloat::experimental::RoundType roundModeType;
365
366 /*
367 * numSRBits: The number of random bits (N) used for stochastic rounding.
368 * If T mantissa bits of the higher precision input are to be truncated, a
369 * maximum of N or T random bits are used for stochastic rounding,
370 * whichever is smallest. i.e., min(N,T) bits below the Gfloat's mantissa
371 * LSB are used for stochastic rounding.
372 */
373 unsigned numSRBits;
374
375 /*
376 * srNoiseDensity: Stochastic rounding noise density.
377 * Supported densities are
378 * - Uniform: the noise samples are uniformly distributed between
379 * two user-defined values min and max
380 * - Normal: the noise samples are normally distributed with a user-
381 * defined mean and standard deviation (stdDev). The values are
382 * clipped to a defined [min,max] range.
383 * - Truncated-Normal: the noise samples have a truncated normal
384 * distribution with a user-defined mean and standard deviation.
385 * Unlike the normal distribution, for truncated normal we sample
386 * from the normal distribution until all samples are in the range
387 * [min,max].
388 * - Laplace: the noise samples have a Laplace distribution with a user-
389 * defined offset (mu) and scale (b). The values are clipped to a
390 * defined [min,max] range.
391 * - Logistic: the noise samples have a logistic distribution with a
392 * user defined mean and scale (s). The values are clipped to a
393 * defined [min,max] range.
394 * - Logit-Normal: the noise samples have a logit-normal distribution
395 * with defined mean and scale parameter (standard of the normal
396 * values us whose logit is used). The values are clipped to a
397 * [min,max] range.
398 * - Truncated Logit-Normal: the noise samples have a logit-normal
399 * distribution clipped to a [min,max] range. The values whose
400 * logit is used, have a truncated normal distribution.
401 * - Bernoulli: the probability of rounding down is set for all inputs.
402 */
403 popfloat::experimental::SRDensityType srNoiseDensity;
404
405 /*
406 * bernoulliProb: used by the Bernoulli distribution as the stochastic
407 * rounding probability of truncating the mantissa.
408 */
409 float bernoulliProb;
410
411 /*
412 * srNoiseOffset: Stochastic rounding noise samples offset. This is used
413 * by the following densities:
414 * - Normal: to set the distribution mean
415 * - Truncated Normal: to set the distribution mean
416 * - Laplace: to set the distribution offset parameter mu
417 * - Logistic: to set the distribution mean
418 * - Logit-normal: to set the mean of the normal distribution used to
419 * generate the samples
420 * - Truncated logit-normal: to set the mean of the normal distribution
421 * used to generate the samples
422 */
423 float srNoiseOffset;
424
425 /*
426 * srNoiseScale: Stochastic rounding noise samples scale factor. This is
427 * used by the following densities:
428 * - Normal: to set the distribution standard deviation
429 * - Truncated Normal: to set the distribution standard deviation
430 * - Laplace: to set the distribution scale parameter b
431 * - Logistic: to set the distribution scale parameter s
432 * - Logit-normal: to set the standard deviation of the normal distribution
433 * used to generate the samples
434 * - Truncated logit-normal: to set the standard deviation of the normal
435 * distribution used to generate the samples
436 */
437 float srNoiseScale;
438
439 /*
440 * srNoiseMax: Stochastic rounding noise samples maximum value. For the
441 * following densities SRNoiseMax must satisfy:
442 * - Uniform: must be a value in the range [0,1]
443 * - Normal: must be a value in the range [-0.5,0.5]
444 * - Truncated must Normal: be a value in the range [-0.5,0.5]
445 * - Laplace: must be a value in the range [-0.5,0.5]
446 * - Logistic: must be a value in the range [0,1]
447 * - Logit-normal: must be a value in the range [0,1]
448 * - Truncated logit-normal: must be a value in the range [0,1]
449 */
450 float srNoiseMax;
451
452 /*
453 * srNoiseMin: Stochastic rounding noise samples minimum value. For the
454 * different densities, SRNoiseMin must satisfy:
455 * - Uniform: must be a value in the range [0,1]
456 * - Normal: must be a value in the range [-0.5,0.5]
457 * - Truncated must Normal: be a value in the range [-0.5,0.5]
458 * - Laplace: must be a value in the range [-0.5,0.5]
459 * - Logistic: must be a value in the range [0,1]
460 * - Logit-normal: must be a value in the range [0,1]
461 * - Truncated logit-normal: must be a value in the range [0,1]
462 */
463 float srNoiseMin;
464
465 /*
466 * NOTE: Stochastic rounding noise density:
467 * For a given higher precision input, x, the cast output is either
468 * y1 or y2 such that y1<=x< y2. For a noise sample, n, with a given
469 * density, the probability of x rounding down is given by:
470 * p(y1|x,n) = p(x+(y2-y1)*n<y2)=p(n<(y2-x)/(y2-y1))
471 * Scaling by (y2-y1) allows the noise samples to align below the mantissa
472 * LSB. After adding noise, the bottom T bits of the mantissa are lost and
473 * the result is truncated (RZ) or rounded-to-nearest away (RA), depending
474 * on the density. The rounding modes for the different distribution are:
475 * - Uniform: truncate (RZ),
476 * - Normal: round to nearest (RA),
477 * - Truncated Normal: round to nearest (RA),
478 * - Laplace: round to nearest (RA),
479 * - Logistic: truncate (RZ),
480 * - Logit-Normal: truncate (RZ),
481 * - Truncated Logit-Normal: truncate (RZ)
482 */
483
484 /*
485 * noiseParams: the user defined stochastic rounding density parameters
486 * (offset, scale, min, and max) will stored in one vector.
487 */
488 std::vector<unsigned> noiseParams;
489
490 /*
491 * densityParam: Other density parameters:
492 * - For truncated normal and truncated logit-normal this is the
493 * maximum number of times to sample from the Normal distribution
494 * per iteration
495 * - For Bernoulli, this is the scaled probability used by the `rmask`
496 * instruction
497 */
498 unsigned densityParam;
499
500 /*
501 * srBitMask: Bit mask used for stochastic rounding
502 */
503 std::vector<unsigned> srBitMask;
504
505 /*
506 * roundingParams: a vector of all rounding parameters
507 */
508 std::vector<unsigned> roundingParams;
509 };
510
511 /*
512 * CastConfig: This structure stores the configuration parameters
513 * of the gfloat cast operations. The different cast operations are:
514 * - Cast to quantised FP32 with the possibility to save the output
515 * as INT16 for custom FP16 formats.
516 * - Cast to quantised FP16 with the possibility to save the output
517 * as INT8 for custom FP8 formats.
518 * - Cast a custom FP16 to IEEE FP32 from the INT16 representation of the
519 * format.
520 * - Cast a custom FP8 to IEEE FP16 from the INT8 representation of the
521 * format.
522 */
523 struct CastConfig {
524 CastConfig() = default;
525
535 static CastConfig
536 createCastNativeToGF(popfloat::experimental::FormatType formatType,
537 poplar::Type calculationType, poplar::Type storageType,
538 RoundConfig roundCfg, bool enableNanooMode);
539
550 static CastConfig
551 createCastGFToNative(popfloat::experimental::FormatType formatType,
552 poplar::Type calculationType,
553 poplar::Type storageType);
554
555 popfloat::experimental::RoundType getRoundMode() const {
556 return roundConfig.getRoundMode();
557 }
558
559 unsigned getNumSRBits() const { return roundConfig.getNumSRBits(); }
560
561 popfloat::experimental::SRDensityType getSRNoiseDensity() const {
562 return roundConfig.getSRNoiseDensity();
563 }
564
565 std::vector<unsigned> getNoiseParams() const {
566 return roundConfig.getNoiseParams();
567 }
568
569 unsigned getDensityParam() const { return roundConfig.getDensityParam(); }
570
571 float getBernoulliProbability() const {
572 return roundConfig.getBernoulliProbability();
573 }
574
575 float getSRNoiseOffset() const { return roundConfig.getSRNoiseOffset(); }
576
577 float getSRNoiseScale() const { return roundConfig.getSRNoiseScale(); }
578
579 float getSRNoiseMax() const { return roundConfig.getSRNoiseMax(); }
580
581 float getSRNoiseMin() const { return roundConfig.getSRNoiseMin(); }
582
583 std::vector<unsigned> getSRBitMask() const {
584 return roundConfig.getSRBitMask();
585 }
586
587 bool isNanooModeEnabled() const { return enableNanooMode; }
588
589 std::vector<unsigned> getCastParams() const { return castParams; }
590
591 poplar::Type getCalculationType() const { return calculationType; }
592 poplar::Type getStorageType() const { return storageType; }
593
594 bool inPlaceOp(poplar::Type inType) const {
595 return (inType == storageType);
596 }
597
598 popfloat::experimental::FormatType getFormatType() const {
599 return floatFormatType;
600 }
601
602 bool getStoreAsNative() const { return storeAsNative; }
603
604 std::vector<unsigned> getRoundingParams() const {
605 return roundConfig.getRoundingParams();
606 }
607
608 private:
609 CastConfig(popfloat::experimental::FormatType floatFormatType,
610 poplar::Type calculationType, poplar::Type storageType,
611 RoundConfig roundCfg, bool enableNanooMode);
612
613 /*
614 * calculationType: IEEE float type used to calculate the gfloat format
615 * - poplar::HALF iff the gfloat format can be represented as an
616 * IEEE FP16. i.e. when the number of gfloat mantissa bits is less
617 * than or equal theIEEE FP16 mantissa size (10) and the number of
618 * gfloat exponent bits is less than or equal to IEEE FP16 mantissa
619 * size (5).
620 * - poplar::FLOAT any gfloat format can be represented as an IEEE
621 * FP32.
622 * NOTE: This is copied from the calculationType used for FormatConfig.
623 */
624 poplar::Type calculationType;
625
626 /*
627 * storageType: type used to represent a custom float format
628 * - poplar::FLOAT for quantised FP32 formats.
629 * - poplar::HALF for quantised FP16 formats.
630 * - poplar::CHAR for a custom FP8 formats
631 * - poplar::SHORT for a custom FP16 formats
632 * NOTE: This can be copied from the storageType chosen by FormatConfig,
633 * or can set by the user.
634 */
635 poplar::Type storageType;
636
637 /*
638 * An instance of RoundConfig storing attributes of the rounding method
639 * used in this cast operation.
640 */
641 RoundConfig roundConfig;
642
643 /*
644 * enableNanooMode: this is similar to the IPU's NaNOO mode
645 * - If true, the cast will generate QNaNs on overflow or when input
646 * values have magnitudes greater than the format's maximum value.
647 * - If false, the cast will clip on overflow and when input values
648 * are outside the range.
649 * NOTE:
650 * - Regardless of whether the Nanoo mode is turned on or off, input
651 * Infs/Nans will always propagate.
652 * - This mode should be disabled if the format's Infs/Nans are not
653 * enabled or if the number of exponent bits is zero. Otherwise,
654 * when a quantised gfloat is packed (INT8 for custom FP8 and INT16
655 * for custom FP16 propagated Infs/Nans will be packed with the all
656 * one-exponent. When the packed value are unpacked, the values that
657 * used to be Infs/Nans, after quantisation, will become values with
658 * the format's maximum exponent. This is equivalent to disabling
659 * the propagation of Infs/Nans in quantisation.
660 */
661 bool enableNanooMode;
662
663 /*
664 * floatFormatType: the Gfloat format type. The different types are:
665 * - IEEE_FP16: When casting from IEEE FP32 to IEEE FP16 using
666 * rounding modes not supported by the IPU
667 * - QUANTISED_FP32: Any Gfloat format that can only be stored as
668 * IEEE FP32
669 * - QUANTISED_FP16: Any Gfloat format that is stored as IEEE FP16
670 * - MIN_NORM_ALIGN_GF8: Any custom FP8 format with less than 5
671 * exponent bits
672 * - ONE_FIVE_TWO_GF8: A 1/5/2 format with Infs/Nans enabled
673 * - MAX_NORM_ALIGN_GF8: A 1/5/2 format with Infs/Nans disabled
674 * - BFLOAT16: Google's Bfloat format (1/8/7) with denorms not enabled.
675 * - NO_DENORM_GF16: A custom FP16 format with denorms not enabled
676 * - ENABLE_DENORM_GF16: A custom FP16 with denorms enabled
677 */
678 popfloat::experimental::FormatType floatFormatType;
679
680 /*
681 * storeAsNative: Indicates if a gfloat format is stored as a Native IEEE
682 * float if true, or if the gfloat format is packed to the smallest
683 * bit representation.
684 */
685 bool storeAsNative;
686
687 /*
688 * castParams: A vector of all parameters used by cast vertex
689 */
690 std::vector<unsigned> castParams;
691 };
692
724 GfloatCast(const FormatConfig &formatCfg, const RoundConfig &roundCfg,
725 const bool enableNanooMode,
726 const popfloat::experimental::SpecType &GFType =
727 popfloat::experimental::SpecType::AUTO,
728 const popfloat::experimental::SpecType &NativeType =
729 popfloat::experimental::SpecType::AUTO);
730
731 GfloatCast(const GfloatFormatOptions &formatOtions,
732 const GfloatCastOptions &castOptions, poplar::Type calculationType,
733 const popfloat::experimental::SpecType &GFType =
734 popfloat::experimental::SpecType::AUTO,
735 const popfloat::experimental::SpecType &NativeType =
736 popfloat::experimental::SpecType::AUTO);
737
742 GfloatCast(const GfloatCast &gfloatCast);
743
755 static poplar::Tensor
756 createCastOpParamsTensor(poplar::Graph &graph, const poplar::ComputeSet &cs,
757 poplar::Type calculationType,
758 const unsigned gfPacked,
759 const poplar::DebugContext &debugContext = {});
760
772 static poplar::Tensor
773 createCastOpParamsTensor(poplar::Graph &graph, const poplar::ComputeSet &cs,
774 poplar::Type calculationType,
775 poplar::Tensor gfPacked,
776 const poplar::DebugContext &debugContext = {});
777
789 static poplar::Tensor createCastOpParamsTensor(
791 poplar::Type calculationType, const unsigned gfStruct,
792 const poplar::DebugContext &debugContext = {});
793
805 static poplar::Tensor createCastOpParamsTensor(
807 poplar::Type calculationType, poplar::Tensor gfStruct,
808 const poplar::DebugContext &debugContext = {});
809
819 void createCastOpParamsTensor(poplar::Graph &graph,
821 const poplar::DebugContext &debugContext = {});
822
830 void createCastOpParamsTensor(poplar::Graph &graph,
831 const poplar::ComputeSet &cs);
832
845 static poplar::Tensor castNativeToGfloat(
846 poplar::Graph &graph, poplar::Tensor input, const poplar::Tensor &param,
847 poplar::program::Sequence &prog, const CastConfig &gfCastCfg,
848 const poplar::DebugContext &debugContext = {});
849
866 static poplar::Tensor
867 castNativeToGfloat(poplar::Graph &graph, poplar::Tensor input,
868 const poplar::Tensor &param, const poplar::ComputeSet &cs,
869 const CastConfig &gfCastCfg,
870 const poplar::DebugContext &debugContext = {});
871
884 castNativeToGfloat(poplar::Graph &graph, poplar::Tensor input,
886 const poplar::DebugContext &debugContext = {});
887
898 static void castNativeToGfloatInPlace(
899 poplar::Graph &graph, poplar::Tensor input, const poplar::Tensor &param,
900 poplar::program::Sequence &prog, const CastConfig &gfCastCfg,
901 const poplar::DebugContext &debugContext = {});
902
913 void castNativeToGfloatInPlace(poplar::Graph &graph, poplar::Tensor input,
915 const poplar::DebugContext &debugContext = {});
916
927 static void castNativeToGfloatInPlace(
928 poplar::Graph &graph, poplar::Tensor input, const poplar::Tensor &param,
929 const poplar::ComputeSet &cs, const CastConfig &gfCastCfg,
930 const poplar::DebugContext &debugContext = {});
931
940 void castNativeToGfloatInPlace(poplar::Graph &graph, poplar::Tensor input,
941 const poplar::ComputeSet &cs);
942
955 static poplar::Tensor castGfloatToNative(
956 poplar::Graph &graph, poplar::Tensor input, const poplar::Tensor &param,
957 poplar::program::Sequence &prog, const CastConfig &gfCastCfg,
958 const poplar::DebugContext &debugContext = {});
959
972 castGfloatToNative(poplar::Graph &graph, poplar::Tensor input,
974 const poplar::DebugContext &debugContext = {});
975
989 static poplar::Tensor
990 castGfloatToNative(poplar::Graph &graph, poplar::Tensor input,
991 const poplar::Tensor &param, const poplar::ComputeSet &cs,
992 const CastConfig &gfCastCfg,
993 const poplar::DebugContext &debugContext = {});
994
1006 poplar::Tensor castGfloatToNative(poplar::Graph &graph, poplar::Tensor input,
1007 const poplar::ComputeSet &cs);
1008
1013 poplar::Type getGFStorageType() const {
1014 return nativeToGFCastCfg.getStorageType();
1015 }
1016
1021 poplar::Type getCalculationType() const {
1022 return nativeToGFCastCfg.getCalculationType();
1023 }
1024
1029 poplar::Type getNativeStorageType() const {
1030 return gfToNativeCastCfg.getStorageType();
1031 }
1032
1037 FormatConfig getFormatConfig() const { return formatCfg; }
1038
1043 CastConfig getNativeToGFConfig() const { return nativeToGFCastCfg; }
1044
1049 CastConfig getGFToNativeConfig() const { return gfToNativeCastCfg; }
1050
1055 bool getStoreAsNative() const { return nativeToGFCastCfg.getStoreAsNative(); }
1056
1061 poplar::Tensor getCastOpParams() const { return *gfParams; }
1062
1067 bool isCastOpParamSet() const { return castOpParamSet; }
1068
1072 void setGfloatCastParameters(poplar::Tensor *gfParams_) {
1073 gfParams.reset(gfParams_);
1074 }
1075
1076 bool isNanooModeEnabled() const {
1077 return nativeToGFCastCfg.isNanooModeEnabled();
1078 }
1079
1080 std::vector<unsigned> getSRBitMask() const {
1081 return nativeToGFCastCfg.getSRBitMask();
1082 }
1083
1084 popfloat::experimental::RoundType getRoundMode() const {
1085 return nativeToGFCastCfg.getRoundMode();
1086 }
1087
1088 bool inPlaceOp(poplar::Type outType) const {
1089 return nativeToGFCastCfg.inPlaceOp(outType);
1090 }
1091
1092 std::vector<unsigned> getRoundingParams() const {
1093 return nativeToGFCastCfg.getRoundingParams();
1094 }
1095
1096protected:
1097 CastConfig nativeToGFCastCfg;
1098 CastConfig gfToNativeCastCfg;
1099 FormatConfig formatCfg;
1100 std::unique_ptr<poplar::Tensor> gfParams;
1101 bool castOpParamSet;
1102};
1103
1104} // end namespace experimental
1105} // end namespace popfloat
1106#endif
A reference to a compute set within a graph.
Definition: GraphElements.hpp:131
DebugContext gathers the common external parameters of the context of an operation.
Definition: DebugContext.hpp:221
This class represents a graph program to be executed on the IPU.
Definition: Graph.hpp:52
A set of option/value string flags to be used in various APIs.
Definition: OptionFlags.hpp:24
A reference to a subset of tensor elements.
Definition: Tensor.hpp:38
Class representing device data types.
Definition: Type.hpp:42
Program that executes a sequence of programs.
Definition: Program.hpp:77
Define a PopLibs exception.