8#ifndef __popops_Expr_hpp__
9#define __popops_Expr_hpp__
12#include <poplar/Tensor.hpp>
13#include <poplar/Type.hpp>
14#include <poplar/TypeTraits.hpp>
38 using ExprClassID = void (*)(void);
40 Expr(ExprClassID classId) : classId(classId) {}
45 template <
class T>
bool isA()
const {
return classId == T::getClassId(); }
47 template <
class T> T *getAs() {
50 return static_cast<T *
>(
this);
53 template <
class T>
const T *getAs()
const {
56 return static_cast<const T *
>(
this);
59 virtual std::unique_ptr<Expr> clone()
const = 0;
61 virtual std::string name(
const std::vector<poplar::Tensor> &)
const = 0;
63 virtual bool deepEquals(
const Expr &other)
const = 0;
65 virtual void print(std::ostream &os,
unsigned indent = 0,
66 bool prettyPrint =
true)
const = 0;
// Free-function comparison of two whole expression trees.
// NOTE(review): presumably equivalent to a.deepEquals(b) (Expr declares a
// pure virtual deepEquals(const Expr &other) const); confirm against the
// out-of-line definition.
70bool deepEquals(
const Expr &a,
const Expr &b);
72template <
class T>
class ExprType :
public Expr {
74 static ExprClassID getClassId() {
return &loc; }
77 ExprType() : Expr(getClassId()) {}
84 std::unique_ptr<Expr> expr;
87 Any(
const Expr &expr) : expr(expr.clone()) {}
89 operator Expr &() {
return *expr; }
90 operator const Expr &()
const {
return *expr; }
91 std::string name(
const std::vector<poplar::Tensor> &inputs)
const {
92 return expr->name(inputs);
97class Const :
public ExprType<Const> {
100 std::unique_ptr<char[]> data;
103 template <
typename T>
Const(T x,
bool isHalfType) {
104 static_assert(std::is_integral<T>::value ||
105 std::is_floating_point<T>::value,
106 "Constant expression values should be integrals or floats");
107 typeTraits = poplar::TypeTraits::make<T>();
113 data.reset(
new char[typeTraits.size]);
114 const char *p =
reinterpret_cast<const char *
>(&x);
115 std::copy(p, p + typeTraits.size, data.get());
119 template <
typename T,
typename =
typename std::enable_if<
120 poplar::TypeTraits::isSimpleType<T>(), T>::type>
124 : typeTraits(std::move(typeTraits_)), type(type_) {
125 data.reset(
new char[typeTraits.size]);
126 std::copy(data_, data_ + typeTraits.size, data.get());
131 :
Const(other.typeTraits, other.type, other.data.get()) {}
134 std::swap(*
this, tmp);
138 char *getData()
const {
return data.get(); }
144 std::string printValue()
const;
146 double getDataAsDouble()
const;
148 std::uint64_t getDataForUnsignedIntegral()
const;
150 std::unique_ptr<Expr> clone()
const override {
151 return std::unique_ptr<Expr>(
new Const(typeTraits, type, data.get()));
153 std::string name(
const std::vector<poplar::Tensor> &)
const override;
155 bool deepEquals(
const Expr &other)
const override;
157 void print(std::ostream &os,
unsigned indent = 0,
158 bool prettyPrint =
true)
const override;
170 std::swap(*
this, tmp);
175inline ConstHalf operator"" _half(
long double x) {
// Cast expression: holds an owned clone of the operand expression
// (returned by getLHS()) and the poplar::Type it is converted to
// (returned by getRHSType()).
181class Cast :
public ExprType<Cast> {
182 std::unique_ptr<Expr> a;
187 : a(a_.clone()), bType(bType_) {}
189 Cast &operator=(
Cast &&other) =
default;
190 Cast(
const Cast &other) :
Cast(*other.a, other.bType) {}
// NOTE(review): BUG - this is declared as operator== but its body swaps a
// temporary `tmp` into *this (copy-and-swap), i.e. it is an assignment and
// is almost certainly meant to be `operator=`. Consequences:
//  1. Copy assignment is not user-declared while the move assignment above
//     is, so the implicit copy-assignment operator is deleted and
//     `castA = castB` (lvalue castB) fails to compile.
//  2. `castA == castB` on non-const Cast objects resolves to this member
//     (exact match) instead of the free operator== that builds an Equal
//     expression, so it silently overwrites castA and returns Cast&.
// Fix: rename to `Cast &operator=(const Cast &other)`.
191 Cast &operator==(
const Cast &other) {
193 std::swap(*
this, tmp);
// Operand expression being cast.
197 const Expr &getLHS()
const {
return *a; }
// Destination element type of the cast.
198 const poplar::Type &getRHSType()
const {
return bType; }
// Deep copy: rebuilds a Cast from the held operand and target type.
200 std::unique_ptr<Expr> clone()
const override {
201 return std::unique_ptr<Expr>(
new Cast(*a, bType));
203 std::string name(
const std::vector<poplar::Tensor> &inputs)
const override;
205 bool deepEquals(
const Expr &other)
const override;
207 void print(std::ostream &os,
unsigned indent = 0,
208 bool prettyPrint =
true)
const override;
211class PlaceHolder :
public ExprType<PlaceHolder> {
215 PlaceHolder(
unsigned index) : index(index) {}
216 PlaceHolder(PlaceHolder &&) =
default;
217 PlaceHolder &operator=(PlaceHolder &&) =
default;
218 PlaceHolder(
const PlaceHolder &other) : PlaceHolder(other.index) {}
219 PlaceHolder &operator=(
const PlaceHolder &other) {
220 PlaceHolder tmp{other};
221 std::swap(*
this, tmp);
225 unsigned getIndex()
const {
return index; }
227 std::unique_ptr<Expr> clone()
const override {
228 return std::unique_ptr<Expr>(
new PlaceHolder(index));
230 std::string name(
const std::vector<poplar::Tensor> &inputs)
const override;
232 bool deepEquals(
const Expr &other)
const override;
234 void print(std::ostream &os,
unsigned indent = 0,
235 bool prettyPrint =
true)
const override;
// Predefined placeholder expressions: _N is PlaceHolder(N), for N = 1..20.
// NOTE(review): presumably _N refers to the N-th input supplied when the
// expression is evaluated (1-based, as there is no _0); confirm against the
// code that interprets PlaceHolder::getIndex().
// NOTE(review): if these sit at namespace scope (as they appear to), a
// non-inline `const` variable has internal linkage, so every translation unit
// that includes this header gets its own copy of all twenty objects.
238const PlaceHolder _1(1);
239const PlaceHolder _2(2);
240const PlaceHolder _3(3);
241const PlaceHolder _4(4);
242const PlaceHolder _5(5);
243const PlaceHolder _6(6);
244const PlaceHolder _7(7);
245const PlaceHolder _8(8);
246const PlaceHolder _9(9);
247const PlaceHolder _10(10);
248const PlaceHolder _11(11);
249const PlaceHolder _12(12);
250const PlaceHolder _13(13);
251const PlaceHolder _14(14);
252const PlaceHolder _15(15);
253const PlaceHolder _16(16);
254const PlaceHolder _17(17);
255const PlaceHolder _18(18);
256const PlaceHolder _19(19);
257const PlaceHolder _20(20);
262 std::unique_ptr<Expr> a;
265 UnaryOp(UnaryOpType type,
const Expr &a) : type(type), a(a.clone()) {}
271 std::swap(*
this, tmp);
275 UnaryOpType getOpType()
const {
return type; }
277 const Expr &getArg()
const {
return *a; }
279 std::unique_ptr<Expr> clone()
const override {
280 return std::unique_ptr<Expr>(
new UnaryOp(type, *a));
282 std::string name(
const std::vector<poplar::Tensor> &inputs)
const override;
283 std::string exprName(
const std::vector<poplar::Tensor> &inputs)
const {
284 return a->name(inputs);
287 bool deepEquals(
const Expr &other)
const override;
289 void print(std::ostream &os,
unsigned indent = 0,
290 bool prettyPrint =
true)
const override;
293#define POPLIBS_DEFINE_EXPR_UNARY_OP(Name, Op) \
294 class Name : public UnaryOp { \
296 Name(const Expr &a) : UnaryOp(UnaryOpType::Op, a) {} \
// Defines the unary expression class `Name` (via POPLIBS_DEFINE_EXPR_UNARY_OP)
// and additionally a prefix `operator Sym` on any Expr that returns a `Name`
// node. The operator builds an expression tree rather than computing a value,
// e.g. `-_1` yields a Neg expression.
299#define POPLIBS_DEFINE_EXPR_UNARY_OP_AND_SYMBOL(Name, Op, Sym) \
300 POPLIBS_DEFINE_EXPR_UNARY_OP(Name, Op) \
301 inline Name operator Sym(const Expr &a) { return Name(a); }
// Unary expression classes, one per UnaryOpType value.
// Only BitwiseNot (~), Not (!) and Neg (-) also get operator overloads; all
// other ops must be constructed by name from an Expr, e.g. Sqrt(_1).
303POPLIBS_DEFINE_EXPR_UNARY_OP(Abs, ABSOLUTE)
304POPLIBS_DEFINE_EXPR_UNARY_OP(Asin, ASIN)
305POPLIBS_DEFINE_EXPR_UNARY_OP_AND_SYMBOL(BitwiseNot, BITWISE_NOT, ~)
306POPLIBS_DEFINE_EXPR_UNARY_OP(Cbrt, CBRT)
307POPLIBS_DEFINE_EXPR_UNARY_OP(Erf, ERF)
308POPLIBS_DEFINE_EXPR_UNARY_OP(Ceil, CEIL)
309POPLIBS_DEFINE_EXPR_UNARY_OP(Cos, COS)
310POPLIBS_DEFINE_EXPR_UNARY_OP(Exp, EXPONENT)
311POPLIBS_DEFINE_EXPR_UNARY_OP(Expm1, EXPONENT_MINUS_ONE)
312POPLIBS_DEFINE_EXPR_UNARY_OP(Floor, FLOOR)
313POPLIBS_DEFINE_EXPR_UNARY_OP(GeluErf, GELU_ERF)
314POPLIBS_DEFINE_EXPR_UNARY_OP(Inv, INVERSE)
315POPLIBS_DEFINE_EXPR_UNARY_OP(IsFinite, IS_FINITE)
316POPLIBS_DEFINE_EXPR_UNARY_OP(IsInf, IS_INF)
317POPLIBS_DEFINE_EXPR_UNARY_OP(IsNaN, IS_NAN)
318POPLIBS_DEFINE_EXPR_UNARY_OP(Log, LOGARITHM)
319POPLIBS_DEFINE_EXPR_UNARY_OP(Log1p, LOGARITHM_ONE_PLUS)
320POPLIBS_DEFINE_EXPR_UNARY_OP_AND_SYMBOL(Not, LOGICAL_NOT, !)
321POPLIBS_DEFINE_EXPR_UNARY_OP_AND_SYMBOL(Neg, NEGATE, -)
322POPLIBS_DEFINE_EXPR_UNARY_OP(Signum, SIGNUM)
323POPLIBS_DEFINE_EXPR_UNARY_OP(Sin, SIN)
324POPLIBS_DEFINE_EXPR_UNARY_OP(Tan, TAN)
325POPLIBS_DEFINE_EXPR_UNARY_OP(Tanh, TANH)
326POPLIBS_DEFINE_EXPR_UNARY_OP(Round, ROUND)
327POPLIBS_DEFINE_EXPR_UNARY_OP(Trunc, TRUNC)
328POPLIBS_DEFINE_EXPR_UNARY_OP(Sqrt, SQRT)
329POPLIBS_DEFINE_EXPR_UNARY_OP(Square, SQUARE)
330POPLIBS_DEFINE_EXPR_UNARY_OP(Sigmoid, SIGMOID)
331POPLIBS_DEFINE_EXPR_UNARY_OP(Rsqrt, RSQRT)
336 std::unique_ptr<Expr> a, b;
340 : type(type), a(a.clone()), b(b.clone()) {}
346 std::swap(*
this, tmp);
350 BinaryOpType getOpType()
const {
return type; }
352 const Expr &getLHS()
const {
return *a; }
353 const Expr &getRHS()
const {
return *b; }
355 std::unique_ptr<Expr> clone()
const override {
356 return std::unique_ptr<Expr>(
new BinaryOp(type, *a, *b));
358 std::string name(
const std::vector<poplar::Tensor> &inputs)
const override;
359 std::string exprName(
const std::vector<poplar::Tensor> &inputs)
const {
360 return a->name(inputs) +
"_" + b->name(inputs);
363 bool deepEquals(
const Expr &other)
const override;
365 void print(std::ostream &os,
unsigned indent = 0,
366 bool prettyPrint =
true)
const override;
369#define POPLIBS_DEFINE_EXPR_BINARY_OP(Name, Op) \
370 class Name : public BinaryOp { \
372 Name(const Expr &a, const Expr &b) : BinaryOp(BinaryOpType::Op, a, b) {} \
375#define POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Name, Op, Sym) \
376 POPLIBS_DEFINE_EXPR_BINARY_OP(Name, Op) \
377 template <typename T> \
378 inline typename std::enable_if<!std::is_base_of<Expr, T>::value && \
379 poplar::TypeTraits::isSimpleType<T>(), \
381 operator Sym(const T &a, const Expr &b) { \
382 return Name(Const(a), b); \
384 template <typename T> \
385 inline typename std::enable_if<!std::is_base_of<Expr, T>::value && \
386 poplar::TypeTraits::isSimpleType<T>(), \
388 operator Sym(const Expr &a, const T &b) { \
389 return Name(a, Const(b)); \
391 inline Name operator Sym(const Expr &a, const Expr &b) { return Name(a, b); }
// Binary expression classes, one per BinaryOpType value.
// Ops defined with the _AND_SYMBOL macro also get operator overloads for
// (Expr, Expr), (simple value, Expr) and (Expr, simple value). In the mixed
// forms the plain value is wrapped in Const, so `_1 + 2` builds
// Add(_1, Const(2)).
// Atan2, BitwiseXnor, InvStdDevToVariance, Max, Min, Pow, ShrSE and
// VarianceToInvStdDev have no operator and must be constructed by name,
// e.g. Max(_1, _2).
// The comparison operators (==, !=, <, <=, >, >=) return expression nodes,
// not bool. For example, `_1 == _2` builds an Equal expression.
// Likewise && and || are overloaded and therefore do not short-circuit: both
// operands are always evaluated. This is harmless here, since evaluating an
// operand only constructs an Expr.
393POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Add, ADD, +)
394POPLIBS_DEFINE_EXPR_BINARY_OP(Atan2, ATAN2)
395POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(BitwiseAnd, BITWISE_AND, &)
396POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(BitwiseOr, BITWISE_OR, |)
397POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(BitwiseXor, BITWISE_XOR, ^)
398POPLIBS_DEFINE_EXPR_BINARY_OP(BitwiseXnor, BITWISE_XNOR)
399POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Divide, DIVIDE, /)
400POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Equal, EQUAL, ==)
401POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Gte, GREATER_THAN_EQUAL, >=)
402POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Gt, GREATER_THAN, >)
403POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Lte, LESS_THAN_EQUAL, <=)
404POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(And, LOGICAL_AND, &&)
405POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Or, LOGICAL_OR, ||)
406POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Lt, LESS_THAN, <)
407POPLIBS_DEFINE_EXPR_BINARY_OP(InvStdDevToVariance, INV_STD_DEV_TO_VARIANCE)
408POPLIBS_DEFINE_EXPR_BINARY_OP(Max, MAXIMUM)
409POPLIBS_DEFINE_EXPR_BINARY_OP(Min, MINIMUM)
410POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Mul, MULTIPLY, *)
411POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(NotEqual, NOT_EQUAL, !=)
412POPLIBS_DEFINE_EXPR_BINARY_OP(Pow, POWER)
413POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Rem, REMAINDER, %)
414POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Shl, SHIFT_LEFT, <<)
415POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Shr, SHIFT_RIGHT, >>)
416POPLIBS_DEFINE_EXPR_BINARY_OP(ShrSE, SHIFT_RIGHT_SIGN_EXTEND)
417POPLIBS_DEFINE_EXPR_BINARY_OP_AND_SYMBOL(Sub, SUBTRACT, -)
418POPLIBS_DEFINE_EXPR_BINARY_OP(VarianceToInvStdDev, VARIANCE_TO_INV_STD_DEV)
423 std::unique_ptr<Expr> a, b, c;
427 : type(type), a(a.clone()), b(b.clone()), c(c.clone()) {}
431 :
TernaryOp(other.type, *other.a, *other.b, *other.c) {}
434 std::swap(*
this, tmp);
440 const Expr &getArg0()
const {
return *a; }
441 const Expr &getArg1()
const {
return *b; }
442 const Expr &getArg2()
const {
return *c; }
444 std::unique_ptr<Expr> clone()
const override {
445 return std::unique_ptr<Expr>(
new TernaryOp(type, *a, *b, *c));
447 std::string name(
const std::vector<poplar::Tensor> &inputs)
const override;
448 std::string exprName(
const std::vector<poplar::Tensor> &inputs)
const {
449 return a->name(inputs) +
"_" + b->name(inputs) +
"_" + c->name(inputs);
452 bool deepEquals(
const Expr &other)
const override;
454 void print(std::ostream &os,
unsigned indent = 0,
455 bool prettyPrint =
true)
const override;
458#define POPLIBS_DEFINE_EXPR_TERNARY_OP(Name, Op) \
459 class Name : public TernaryOp { \
461 Name(const Expr &a, const Expr &b, const Expr &c) \
462 : TernaryOp(TernaryOpType::Op, a, b, c) {} \
// Ternary expression classes, one per TernaryOpType value. They have no
// operator overloads, so construct them by name from three Exprs, e.g.
// Select(a, b, c) or Clamp(a, b, c).
// NOTE(review): the operand roles (which argument is Select's condition,
// which are Clamp's bounds) are not visible here. Confirm against the
// TernaryOpType documentation before relying on argument order.
471POPLIBS_DEFINE_EXPR_TERNARY_OP(
Select, SELECT)
472POPLIBS_DEFINE_EXPR_TERNARY_OP(Clamp, CLAMP)
Operators used in expressions with elements of tensors.
TernaryOpType
Enumeration defining operators used by Expr for building expressions.
Definition: ExprOp.hpp:16
Class representing device data types.
Definition: Type.hpp:42
A class that can contain any expression, useful for building up expression trees dynamically where th...
Definition: Expr.hpp:83
A class to represent expressions with binary operators.
Definition: Expr.hpp:334
A class to represent cast expressions.
Definition: Expr.hpp:181
A class to represent constant expressions of type half.
Definition: Expr.hpp:162
A class to represent constant expressions.
Definition: Expr.hpp:97
Type to represent element expressions.
Definition: Expr.hpp:36
Computes the conditional ternary operation.
Definition: Expr.hpp:471
A class to represent expressions with ternary operators.
Definition: Expr.hpp:421
A class to represent expressions with unary operators.
Definition: Expr.hpp:260
half2 max(half2 src0, half2 src1)
Targets the f16v2max instruction.
Definition: ipu_intrinsics:333
Type HALF
Device type: half
Common functions, such as elementwise and reductions.
Definition: AllTrue.hpp:15
std::ostream & operator<<(std::ostream &os, const Operation &op)
Write op to output stream os.
A structure to provide information about arithmetic (integer and floating point) types.
Definition: TypeTraits.hpp:22
Template structure to relate a host type to a device type.
Definition: Type.hpp:192