Poplar and PopLibs
Expr.hpp
Go to the documentation of this file.
1// Copyright (c) 2018 Graphcore Ltd. All rights reserved.
8#ifndef __popops_Expr_hpp__
9#define __popops_Expr_hpp__
#include <algorithm>
#include <cassert>
#include <iosfwd>
#include <limits>
#include <memory>
#include <poplar/Tensor.hpp>
#include <poplar/Type.hpp>
#include <poplar/TypeTraits.hpp>
#include <popops/ExprOp.hpp>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
19
20namespace popops {
21namespace expr {
22
36class Expr {
37protected:
38 using ExprClassID = void (*)(void);
39 ExprClassID classId;
40 Expr(ExprClassID classId) : classId(classId) {}
41
42public:
43 virtual ~Expr();
44
45 template <class T> bool isA() const { return classId == T::getClassId(); }
46
47 template <class T> T *getAs() {
48 if (!isA<T>())
49 return 0;
50 return static_cast<T *>(this);
51 }
52
53 template <class T> const T *getAs() const {
54 if (!isA<T>())
55 return 0;
56 return static_cast<const T *>(this);
57 }
58
59 virtual std::unique_ptr<Expr> clone() const = 0;
60
61 virtual std::string name(const std::vector<poplar::Tensor> &) const = 0;
62
63 virtual bool deepEquals(const Expr &other) const = 0;
64
65 virtual void print(std::ostream &os, unsigned indent = 0,
66 bool prettyPrint = true) const = 0;
67};
68
69std::ostream &operator<<(std::ostream &os, const Expr &expr);
70bool deepEquals(const Expr &a, const Expr &b);
71
72template <class T> class ExprType : public Expr {
73 static void loc();
74 static ExprClassID getClassId() { return &loc; }
75
76public:
77 ExprType() : Expr(getClassId()) {}
78 friend class Expr;
79};
80
83class Any {
84 std::unique_ptr<Expr> expr;
85
86public:
87 Any(const Expr &expr) : expr(expr.clone()) {}
88
89 operator Expr &() { return *expr; }
90 operator const Expr &() const { return *expr; }
91 std::string name(const std::vector<poplar::Tensor> &inputs) const {
92 return expr->name(inputs);
93 }
94};
95
97class Const : public ExprType<Const> {
98 poplar::TypeTraits typeTraits;
99 poplar::Type type;
100 std::unique_ptr<char[]> data;
101
102protected:
103 template <typename T> Const(T x, bool isHalfType) {
104 static_assert(std::is_integral<T>::value ||
105 std::is_floating_point<T>::value,
106 "Constant expression values should be integrals or floats");
107 typeTraits = poplar::TypeTraits::make<T>();
108 if (isHalfType) {
109 type = poplar::HALF;
110 } else {
112 }
113 data.reset(new char[typeTraits.size]);
114 const char *p = reinterpret_cast<const char *>(&x);
115 std::copy(p, p + typeTraits.size, data.get());
116 }
117
118public:
119 template <typename T, typename = typename std::enable_if<
120 poplar::TypeTraits::isSimpleType<T>(), T>::type>
121 Const(T x) : Const(x, false) {}
122
123 Const(poplar::TypeTraits typeTraits_, poplar::Type type_, const char *data_)
124 : typeTraits(std::move(typeTraits_)), type(type_) {
125 data.reset(new char[typeTraits.size]);
126 std::copy(data_, data_ + typeTraits.size, data.get());
127 }
128 Const(Const &&) = default;
129 Const &operator=(Const &&) = default;
130 Const(const Const &other)
131 : Const(other.typeTraits, other.type, other.data.get()) {}
132 Const &operator=(const Const &other) {
133 Const tmp{other};
134 std::swap(*this, tmp);
135 return *this;
136 }
137
138 char *getData() const { return data.get(); }
139
140 const poplar::TypeTraits &getTypeTraits() const { return typeTraits; }
141
142 const poplar::Type &getType() const { return type; }
143
144 std::string printValue() const;
145
146 double getDataAsDouble() const;
147
148 std::uint64_t getDataForUnsignedIntegral() const;
149
150 std::unique_ptr<Expr> clone() const override {
151 return std::unique_ptr<Expr>(new Const(typeTraits, type, data.get()));
152 }
153 std::string name(const std::vector<poplar::Tensor> &) const override;
154
155 bool deepEquals(const Expr &other) const override;
156
157 void print(std::ostream &os, unsigned indent = 0,
158 bool prettyPrint = true) const override;
159};
160
162class ConstHalf : public Const {
163public:
164 ConstHalf(float x) : Const(x, true) {}
165 ConstHalf(ConstHalf &&) = default;
166 ConstHalf &operator=(ConstHalf &&) = default;
167 ConstHalf(const ConstHalf &other) : Const(other) {}
168 ConstHalf &operator=(const ConstHalf &other) {
169 ConstHalf tmp{other};
170 std::swap(*this, tmp);
171 return *this;
172 }
173};
174
175inline ConstHalf operator"" _half(long double x) {
176 assert(x <= std::numeric_limits<float>::max());
177 return ConstHalf(static_cast<float>(x));
178}
179
181class Cast : public ExprType<Cast> {
182 std::unique_ptr<Expr> a;
183 poplar::Type bType;
184
185public:
186 Cast(const Expr &a_, const poplar::Type bType_)
187 : a(a_.clone()), bType(bType_) {}
188 Cast(Cast &&other) = default;
189 Cast &operator=(Cast &&other) = default;
190 Cast(const Cast &other) : Cast(*other.a, other.bType) {}
191 Cast &operator==(const Cast &other) {
192 Cast tmp{other};
193 std::swap(*this, tmp);
194 return *this;
195 }
196
197 const Expr &getLHS() const { return *a; }
198 const poplar::Type &getRHSType() const { return bType; }
199
200 std::unique_ptr<Expr> clone() const override {
201 return std::unique_ptr<Expr>(new Cast(*a, bType));
202 }
203 std::string name(const std::vector<poplar::Tensor> &inputs) const override;
204
205 bool deepEquals(const Expr &other) const override;
206
207 void print(std::ostream &os, unsigned indent = 0,
208 bool prettyPrint = true) const override;
209};
210
211class PlaceHolder : public ExprType<PlaceHolder> {
212 unsigned index;
213
214public:
215 PlaceHolder(unsigned index) : index(index) {}
216 PlaceHolder(PlaceHolder &&) = default;
217 PlaceHolder &operator=(PlaceHolder &&) = default;
218 PlaceHolder(const PlaceHolder &other) : PlaceHolder(other.index) {}
219 PlaceHolder &operator=(const PlaceHolder &other) {
220 PlaceHolder tmp{other};
221 std::swap(*this, tmp);
222 return *this;
223 }
224
225 unsigned getIndex() const { return index; }
226
227 std::unique_ptr<Expr> clone() const override {
228 return std::unique_ptr<Expr>(new PlaceHolder(index));
229 }
230 std::string name(const std::vector<poplar::Tensor> &inputs) const override;
231
232 bool deepEquals(const Expr &other) const override;
233
234 void print(std::ostream &os, unsigned indent = 0,
235 bool prettyPrint = true) const override;
236};
237
238const PlaceHolder _1(1);
239const PlaceHolder _2(2);
240const PlaceHolder _3(3);
241const PlaceHolder _4(4);
242const PlaceHolder _5(5);
243const PlaceHolder _6(6);
244const PlaceHolder _7(7);
245const PlaceHolder _8(8);
246const PlaceHolder _9(9);
247const PlaceHolder _10(10);
248const PlaceHolder _11(11);
249const PlaceHolder _12(12);
250const PlaceHolder _13(13);
251const PlaceHolder _14(14);
252const PlaceHolder _15(15);
253const PlaceHolder _16(16);
254const PlaceHolder _17(17);
255const PlaceHolder _18(18);
256const PlaceHolder _19(19);
257const PlaceHolder _20(20);
258
260class UnaryOp : public ExprType<UnaryOp> {
261 UnaryOpType type;
262 std::unique_ptr<Expr> a;
263
264public:
265 UnaryOp(UnaryOpType type, const Expr &a) : type(type), a(a.clone()) {}
266 UnaryOp(UnaryOp &&) = default;
267 UnaryOp &operator=(UnaryOp &&) = default;
268 UnaryOp(const UnaryOp &other) : UnaryOp(other.type, *other.a) {}
269 UnaryOp &operator=(const UnaryOp &other) {
270 UnaryOp tmp(other);
271 std::swap(*this, tmp);
272 return *this;
273 }
274
275 UnaryOpType getOpType() const { return type; }
276
277 const Expr &getArg() const { return *a; }
278
279 std::unique_ptr<Expr> clone() const override {
280 return std::unique_ptr<Expr>(new UnaryOp(type, *a));
281 }
282 std::string name(const std::vector<poplar::Tensor> &inputs) const override;
283 std::string exprName(const std::vector<poplar::Tensor> &inputs) const {
284 return a->name(inputs);
285 };
286
287 bool deepEquals(const Expr &other) const override;
288
289 void print(std::ostream &os, unsigned indent = 0,
290 bool prettyPrint = true) const override;
291};
292
// Defines `Name`, a subclass of UnaryOp whose constructor fixes the operator
// to UnaryOpType::Op.
#define POPLIBS_DEFINE_EXPR_UNARY_OP(Name, Op)                                 \
  class Name : public UnaryOp {                                                \
  public:                                                                      \
    Name(const Expr &operand) : UnaryOp(UnaryOpType::Op, operand) {}           \
  };

// Same as POPLIBS_DEFINE_EXPR_UNARY_OP, and also overloads the prefix
// operator `Sym` on expressions so that `Sym e` builds a `Name`.
#define POPLIBS_DEFINE_EXPR_UNARY_OP_AND_SYMBOL(Name, Op, Sym)                 \
  POPLIBS_DEFINE_EXPR_UNARY_OP(Name, Op)                                       \
  inline Name operator Sym(const Expr &operand) { return Name(operand); }
302
303POPLIBS_DEFINE_EXPR_UNARY_OP(Abs, ABSOLUTE)
304POPLIBS_DEFINE_EXPR_UNARY_OP(Asin, ASIN)
305POPLIBS_DEFINE_EXPR_UNARY_OP_AND_SYMBOL(BitwiseNot, BITWISE_NOT, ~)
306POPLIBS_DEFINE_EXPR_UNARY_OP(Cbrt, CBRT)
307POPLIBS_DEFINE_EXPR_UNARY_OP(Erf, ERF)
308POPLIBS_DEFINE_EXPR_UNARY_OP(Ceil, CEIL)
309POPLIBS_DEFINE_EXPR_UNARY_OP(Cos, COS)
310POPLIBS_DEFINE_EXPR_UNARY_OP(Exp, EXPONENT)
311POPLIBS_DEFINE_EXPR_UNARY_OP(Expm1, EXPONENT_MINUS_ONE)
312POPLIBS_DEFINE_EXPR_UNARY_OP(Floor, FLOOR)
313POPLIBS_DEFINE_EXPR_UNARY_OP(GeluErf, GELU_ERF)
314POPLIBS_DEFINE_EXPR_UNARY_OP(Inv, INVERSE)
315POPLIBS_DEFINE_EXPR_UNARY_OP(IsFinite, IS_FINITE)
316POPLIBS_DEFINE_EXPR_UNARY_OP(IsInf, IS_INF)
317POPLIBS_DEFINE_EXPR_UNARY_OP(IsNaN, IS_NAN)
318POPLIBS_DEFINE_EXPR_UNARY_OP(Log, LOGARITHM)
319POPLIBS_DEFINE_EXPR_UNARY_OP(Log1p, LOGARITHM_ONE_PLUS)
320POPLIBS_DEFINE_EXPR_UNARY_OP_AND_SYMBOL(Not, LOGICAL_NOT, !)
321POPLIBS_DEFINE_EXPR_UNARY_OP_AND_SYMBOL(Neg, NEGATE, -)
322POPLIBS_DEFINE_EXPR_UNARY_OP(Signum, SIGNUM)
323POPLIBS_DEFINE_EXPR_UNARY_OP(Sin, SIN)
324POPLIBS_DEFINE_EXPR_UNARY_OP(Tan, TAN)
325POPLIBS_DEFINE_EXPR_UNARY_OP(Tanh, TANH)
326POPLIBS_DEFINE_EXPR_UNARY_OP(Round, ROUND)
327POPLIBS_DEFINE_EXPR_UNARY_OP(Trunc, TRUNC)
328POPLIBS_DEFINE_EXPR_UNARY_OP(Sqrt, SQRT)
329POPLIBS_DEFINE_EXPR_UNARY_OP(Square, SQUARE)
330POPLIBS_DEFINE_EXPR_UNARY_OP(Sigmoid, SIGMOID)
331POPLIBS_DEFINE_EXPR_UNARY_OP(Rsqrt, RSQRT)
332
334class BinaryOp : public ExprType<BinaryOp> {
335 BinaryOpType type;
336 std::unique_ptr<Expr> a, b;
337
338public:
339 BinaryOp(BinaryOpType type, const Expr &a, const Expr &b)
340 : type(type), a(a.clone()), b(b.clone()) {}
341 BinaryOp(BinaryOp &&) = default;
342 BinaryOp &operator=(BinaryOp &&) = default;
343 BinaryOp(const BinaryOp &other) : BinaryOp(other.type, *other.a, *other.b) {}
344 BinaryOp &operator=(const BinaryOp &other) {
345 BinaryOp tmp{other};
346 std::swap(*this, tmp);
347 return *this;
348 }
349
350 BinaryOpType getOpType() const { return type; }
351
352 const Expr &getLHS() const { return *a; }
353 const Expr &getRHS() const { return *b; }
354
355 std::unique_ptr<Expr> clone() const override {
356 return std::unique_ptr<Expr>(new BinaryOp(type, *a, *b));
357 }
358 std::string name(const std::vector<poplar::Tensor> &inputs) const override;
359 std::string exprName(const std::vector<poplar::Tensor> &inputs) const {
360 return a->name(inputs) + "_" + b->name(inputs);
361 }
362
363 bool deepEquals(const Expr &other) const override;
364
365 void print(std::ostream &os, unsigned indent = 0,
366 bool prettyPrint = true) const override;
367};
368
// Defines `Name`, a subclass of BinaryOp whose constructor fixes the operator
// to BinaryOpType::Op.
#define POPLIBS_DEFINE_EXPR_BINARY_OP(Name, Op)                                \
  class Name : public BinaryOp {                                               \
  public:                                                                      \
    Name(const Expr &lhs, const Expr &rhs)                                     \
        : BinaryOp(BinaryOpType::Op, lhs, rhs) {}                              \
  };

// Same as POPLIBS_DEFINE_EXPR_BINARY_OP, and also overloads the infix
// operator `Sym` for (Expr, Expr), (Expr, scalar) and (scalar, Expr), where a
// scalar is any non-Expr simple type and is wrapped in a Const. When Sym is
// && or ||, the overloads do not short-circuit: both operands are always
// evaluated.
#define POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Name, Op, Sym)                \
  POPLIBS_DEFINE_EXPR_BINARY_OP(Name, Op)                                      \
  inline Name operator Sym(const Expr &lhs, const Expr &rhs) {                 \
    return Name(lhs, rhs);                                                     \
  }                                                                            \
  template <typename T>                                                        \
  inline typename std::enable_if<!std::is_base_of<Expr, T>::value &&           \
                                     poplar::TypeTraits::isSimpleType<T>(),    \
                                 Name>::type                                   \
  operator Sym(const Expr &lhs, const T &rhs) {                                \
    return Name(lhs, Const(rhs));                                              \
  }                                                                            \
  template <typename T>                                                        \
  inline typename std::enable_if<!std::is_base_of<Expr, T>::value &&           \
                                     poplar::TypeTraits::isSimpleType<T>(),    \
                                 Name>::type                                   \
  operator Sym(const T &lhs, const Expr &rhs) {                                \
    return Name(Const(lhs), rhs);                                              \
  }
392
393POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Add, ADD, +)
394POPLIBS_DEFINE_EXPR_BINARY_OP(Atan2, ATAN2)
395POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(BitwiseAnd, BITWISE_AND, &)
396POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(BitwiseOr, BITWISE_OR, |)
397POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(BitwiseXor, BITWISE_XOR, ^)
398POPLIBS_DEFINE_EXPR_BINARY_OP(BitwiseXnor, BITWISE_XNOR)
399POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Divide, DIVIDE, /)
400POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Equal, EQUAL, ==)
401POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Gte, GREATER_THAN_EQUAL, >=)
402POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Gt, GREATER_THAN, >)
403POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Lte, LESS_THAN_EQUAL, <=)
404POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(And, LOGICAL_AND, &&)
405POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Or, LOGICAL_OR, ||)
406POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Lt, LESS_THAN, <)
407POPLIBS_DEFINE_EXPR_BINARY_OP(InvStdDevToVariance, INV_STD_DEV_TO_VARIANCE)
408POPLIBS_DEFINE_EXPR_BINARY_OP(Max, MAXIMUM)
409POPLIBS_DEFINE_EXPR_BINARY_OP(Min, MINIMUM)
410POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Mul, MULTIPLY, *)
411POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(NotEqual, NOT_EQUAL, !=)
412POPLIBS_DEFINE_EXPR_BINARY_OP(Pow, POWER)
413POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Rem, REMAINDER, %)
414POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Shl, SHIFT_LEFT, <<)
415POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Shr, SHIFT_RIGHT, >>)
416POPLIBS_DEFINE_EXPR_BINARY_OP(ShrSE, SHIFT_RIGHT_SIGN_EXTEND)
417POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Sub, SUBTRACT, -)
418POPLIBS_DEFINE_EXPR_BINARY_OP(VarianceToInvStdDev, VARIANCE_TO_INV_STD_DEV)
419
421class TernaryOp : public ExprType<TernaryOp> {
422 TernaryOpType type;
423 std::unique_ptr<Expr> a, b, c;
424
425public:
426 TernaryOp(TernaryOpType type, const Expr &a, const Expr &b, const Expr &c)
427 : type(type), a(a.clone()), b(b.clone()), c(c.clone()) {}
428 TernaryOp(TernaryOp &&) = default;
429 TernaryOp &operator=(TernaryOp &&) = default;
430 TernaryOp(const TernaryOp &other)
431 : TernaryOp(other.type, *other.a, *other.b, *other.c) {}
432 TernaryOp &operator=(const TernaryOp &other) {
433 TernaryOp tmp{other};
434 std::swap(*this, tmp);
435 return *this;
436 }
437
438 TernaryOpType getOpType() const { return type; }
439
440 const Expr &getArg0() const { return *a; }
441 const Expr &getArg1() const { return *b; }
442 const Expr &getArg2() const { return *c; }
443
444 std::unique_ptr<Expr> clone() const override {
445 return std::unique_ptr<Expr>(new TernaryOp(type, *a, *b, *c));
446 }
447 std::string name(const std::vector<poplar::Tensor> &inputs) const override;
448 std::string exprName(const std::vector<poplar::Tensor> &inputs) const {
449 return a->name(inputs) + "_" + b->name(inputs) + "_" + c->name(inputs);
450 }
451
452 bool deepEquals(const Expr &other) const override;
453
454 void print(std::ostream &os, unsigned indent = 0,
455 bool prettyPrint = true) const override;
456};
457
// Defines `Name`, a subclass of TernaryOp whose constructor fixes the
// operator to TernaryOpType::Op.
#define POPLIBS_DEFINE_EXPR_TERNARY_OP(Name, Op)                               \
  class Name : public TernaryOp {                                              \
  public:                                                                      \
    Name(const Expr &first, const Expr &second, const Expr &third)             \
        : TernaryOp(TernaryOpType::Op, first, second, third) {}                \
  };
464
471POPLIBS_DEFINE_EXPR_TERNARY_OP(Select, SELECT)
472POPLIBS_DEFINE_EXPR_TERNARY_OP(Clamp, CLAMP)
473
474} // namespace expr
475} // namespace popops
476
477#endif // __popops_Expr_hpp__
Operators used in expressions with elements of tensors.
TernaryOpType
Enumeration defining operators used by Expr for building expressions.
Definition: ExprOp.hpp:16
Class representing device data types.
Definition: Type.hpp:42
A class that can contain any expression, useful for building up expression trees dynamically where th...
Definition: Expr.hpp:83
A class to represent expressions with binary operators.
Definition: Expr.hpp:334
A class to represent cast expressions.
Definition: Expr.hpp:181
A class to represent constant expressions of type half.
Definition: Expr.hpp:162
A class to represent constant expressions.
Definition: Expr.hpp:97
Type to represent element expressions.
Definition: Expr.hpp:36
Computes the conditional ternary operation.
Definition: Expr.hpp:471
A class to represent expressions with ternary operators.
Definition: Expr.hpp:421
A class to represent expressions with unary operators.
Definition: Expr.hpp:260
half2 max(half2 src0, half2 src1)
Targets the f16v2max instruction.
Definition: ipu_intrinsics:333
Type HALF
Device type: half
Common functions, such as elementwise and reductions.
Definition: AllTrue.hpp:15
std::ostream & operator<<(std::ostream &os, const Operation &op)
Write op to output stream os.
A structure to provide information about arithmetic (integer and floating point) types.
Definition: TypeTraits.hpp:22
Template structure to relate a host type to a device type.
Definition: Type.hpp:192