8#ifndef poplin_ConvParams_hpp
9#define poplin_ConvParams_hpp
10#include "poplar/Type.hpp"
/// Group of transform parameters, stored as one vector per parameter kind.
/// Referred to elsewhere in this file as ConvParams::InputTransform, and used
/// as the type of both ConvParams::inputTransform and
/// ConvParams::kernelTransform.
/// NOTE(review): the meaning of each field below is inferred from its name
/// only — confirm against the implementation. The closing brace of this
/// struct is not present in this listing.
17 struct InputTransform {
/// Presumably the amount truncated at the lower edge, one entry per
/// dimension — TODO confirm.
20 std::vector<unsigned> truncationLower;
/// Presumably the amount truncated at the upper edge, one entry per
/// dimension — TODO confirm.
23 std::vector<unsigned> truncationUpper;
/// Presumably the dilation factor, one entry per dimension — TODO confirm.
27 std::vector<unsigned> dilation;
/// Presumably the padding added at the lower edge, one entry per
/// dimension — TODO confirm.
30 std::vector<unsigned> paddingLower;
/// Presumably the padding added at the upper edge, one entry per
/// dimension — TODO confirm.
33 std::vector<unsigned> paddingUpper;
/// Presumably whether each dimension is flipped — TODO confirm.
35 std::vector<bool> flip;
/// Defaulted constructor: every vector member starts out empty.
37 InputTransform() =
default;
/// Construct from a single size value (declaration only; the definition is
/// not visible here).
/// NOTE(review): `size` is presumably the number of dimensions — confirm.
/// This constructor is not `explicit`, so a std::size_t converts implicitly.
38 InputTransform(
const std::size_t size);
/// Construct from one vector per field. The parameters are taken by value
/// and have the same names as the corresponding data members.
50 InputTransform(std::vector<unsigned> truncationLower,
51 std::vector<unsigned> truncationUpper,
52 std::vector<unsigned> dilation,
53 std::vector<unsigned> paddingLower,
54 std::vector<unsigned> paddingUpper, std::vector<bool> flip);
/// Ordering and equality operators, declared as friends and defined
/// elsewhere.
56 friend bool operator<(
const InputTransform &a,
const InputTransform &b);
57 friend bool operator==(
const InputTransform &a,
const InputTransform &b);
58 friend bool operator!=(
const InputTransform &a,
const InputTransform &b);
/// Group of transform parameters, stored as one vector per parameter kind.
/// Referred to elsewhere in this file as ConvParams::OutputTransform, and
/// used as the type of ConvParams::outputTransform.
/// NOTE(review): the meaning of each field below is inferred from its name
/// only — confirm against the implementation. The closing brace of this
/// struct is not present in this listing.
61 struct OutputTransform {
/// Presumably the amount truncated at the lower edge, one entry per
/// dimension — TODO confirm.
64 std::vector<unsigned> truncationLower;
/// Presumably the amount truncated at the upper edge, one entry per
/// dimension — TODO confirm.
67 std::vector<unsigned> truncationUpper;
/// Presumably the stride, one entry per dimension — TODO confirm.
70 std::vector<unsigned> stride;
/// Presumably the padding added at the lower edge, one entry per
/// dimension — TODO confirm.
72 std::vector<unsigned> paddingLower;
/// Presumably the padding added at the upper edge, one entry per
/// dimension — TODO confirm.
74 std::vector<unsigned> paddingUpper;
/// Defaulted constructor: every vector member starts out empty.
76 OutputTransform() =
default;
/// Construct from a single size value (declaration only; the definition is
/// not visible here).
/// NOTE(review): `size` is presumably the number of dimensions — confirm.
/// This constructor is not `explicit`, so a std::size_t converts implicitly.
77 OutputTransform(
const std::size_t size);
/// Construct from one vector per field, taken by value.
/// NOTE(review): the third parameter is spelled `striding` whereas the data
/// member is `stride`; it presumably initialises that member — confirm.
85 OutputTransform(std::vector<unsigned> truncationLower,
86 std::vector<unsigned> truncationUpper,
87 std::vector<unsigned> striding,
88 std::vector<unsigned> paddingLower,
89 std::vector<unsigned> paddingUpper);
/// Ordering and equality operators, declared as friends and defined
/// elsewhere.
91 friend bool operator<(
const OutputTransform &a,
const OutputTransform &b);
92 friend bool operator==(
const OutputTransform &a,
const OutputTransform &b);
93 friend bool operator!=(
const OutputTransform &a,
const OutputTransform &b);
/// Data members of ConvParams.
/// NOTE(review): the enclosing class header is not present in this listing.

/// Batch size; returned unchanged by getBatchSize().
99 std::size_t batchSize;
/// Input field shape. getNumFieldDims() returns its size() and
/// getInputSize(dim) returns element `dim`.
101 std::vector<std::size_t> inputFieldShape;
/// Kernel shape; getKernelShape() returns a copy of it.
103 std::vector<std::size_t> kernelShape;
/// Input channels per convolution group; getNumInputChans() multiplies this
/// by numConvGroups.
105 std::size_t inputChannelsPerConvGroup;
/// Output channels per convolution group; getNumOutputChans() multiplies
/// this by numConvGroups.
107 std::size_t outputChannelsPerConvGroup;
/// Number of convolution groups; returned unchanged by getNumConvGroups().
112 std::size_t numConvGroups;
/// Transform parameters held for the input (see InputTransform).
115 InputTransform inputTransform;
/// Transform parameters held for the kernel; same type as inputTransform.
117 InputTransform kernelTransform;
/// Transform parameters held for the output (see OutputTransform).
119 OutputTransform outputTransform;
/// Defaulted constructor.
/// NOTE(review): the std::size_t members have no in-class initialisers in
/// this listing, so they are left uninitialised under default-initialisation.
121 ConvParams() =
default;
/// Construct from a data type, batch size, shapes and channel/group counts
/// (declaration only; the definition is not visible here).
/// NOTE(review): the data members are per-conv-group channel counts, while
/// these parameters are named `inputChannels` / `outputChannels`; whether
/// they are totals or per-group values cannot be told from this declaration —
/// confirm against the definition.
131 ConvParams(
poplar::Type dataType, std::size_t batchSize,
132 std::vector<std::size_t> inputFieldShape,
133 std::vector<std::size_t> kernelShape, std::size_t inputChannels,
134 std::size_t outputChannels, std::size_t numConvGroups);
/// NOTE(review): the opening line(s) of the next declaration are missing
/// from this listing; only the tail of its parameter list is visible.
146 std::size_t batchSize, std::vector<std::size_t> inputFieldShape,
147 std::vector<std::size_t> kernelShape, std::size_t inputChannels,
148 std::size_t outputChannels, std::size_t numConvGroups);
/// NOTE(review): the opening line(s) of the next declaration are missing
/// from this listing as well. The visible tail additionally takes the input,
/// kernel and output transforms by value.
163 std::size_t batchSize, std::vector<std::size_t> inputFieldShape,
164 std::vector<std::size_t> kernelShape, std::size_t inputChannels,
165 std::size_t outputChannels, std::size_t numConvGroups,
166 InputTransform inputTransform, InputTransform kernelTransform,
167 OutputTransform outputTransform);
/// Size query for dimension `dim` (declaration only; defined elsewhere).
/// NOTE(review): presumably the output size before the output transform is
/// applied — confirm against the definition.
171 std::size_t getUntransformedOutputSize(
unsigned dim)
const;
/// Size query for dimension `dim` (declaration only; defined elsewhere).
/// NOTE(review): presumably the final output size in that dimension —
/// confirm against the definition.
173 std::size_t getOutputSize(
unsigned dim)
const;
/// Inline read-only accessors.
/// NOTE(review): several of these bodies have no closing brace in this
/// listing.

/// Returns outputChannelsPerConvGroup.
175 std::size_t getNumOutputChansPerConvGroup()
const {
176 return outputChannelsPerConvGroup;
/// Returns outputChannelsPerConvGroup * numConvGroups.
179 std::size_t getNumOutputChans()
const {
180 return outputChannelsPerConvGroup * numConvGroups;
/// Returns inputFieldShape[dim]. The index is not range-checked here.
183 std::size_t getInputSize(
unsigned dim)
const {
return inputFieldShape[dim]; }
/// Returns inputChannelsPerConvGroup.
185 std::size_t getNumInputChansPerConvGroup()
const {
186 return inputChannelsPerConvGroup;
/// Returns inputChannelsPerConvGroup * numConvGroups.
189 std::size_t getNumInputChans()
const {
190 return inputChannelsPerConvGroup * numConvGroups;
/// Returns numConvGroups.
193 std::size_t getNumConvGroups()
const {
return numConvGroups; }
/// Returns the number of field dimensions, i.e. inputFieldShape.size().
195 std::size_t getNumFieldDims()
const {
return inputFieldShape.size(); }
/// Returns a copy of inputFieldShape (returned by value).
197 std::vector<std::size_t> getInputFieldShape()
const {
198 return inputFieldShape;
/// Returns a copy of kernelShape (returned by value).
201 std::vector<std::size_t> getKernelShape()
const {
return kernelShape; }
/// Returns batchSize.
203 std::size_t getBatchSize()
const {
return batchSize; }
/// Derived-size queries and utilities (declarations only; defined elsewhere).
/// The four size queries below return `unsigned`, whereas getInputSize()
/// returns std::size_t.

/// NOTE(review): presumably the input size in `dim` after truncation —
/// confirm against the definition.
206 unsigned getTruncatedInputSize(
unsigned dim)
const;
/// NOTE(review): presumably the kernel size in `dim` after truncation —
/// confirm against the definition.
208 unsigned getTruncatedKernelSize(
unsigned dim)
const;
/// NOTE(review): presumably the input size in `dim` after the whole input
/// transform — confirm against the definition.
211 unsigned getTransformedInputSize(
unsigned dim)
const;
/// NOTE(review): presumably the kernel size in `dim` after the whole kernel
/// transform — confirm against the definition.
214 unsigned getTransformedKernelSize(
unsigned dim)
const;
/// Returns a vector of sizes by value.
/// NOTE(review): presumably one getOutputSize() entry per field dimension —
/// confirm against the definition.
216 std::vector<size_t> getOutputFieldShape()
const;
/// Const member returning nothing.
/// NOTE(review): how an invalid parameter set is reported (e.g. by throwing)
/// is not visible here — confirm against the definition.
218 void validate()
const;
/// Const member returning a new ConvParams by value; *this is not modified.
/// NOTE(review): what the canonical form is cannot be told from this
/// declaration — confirm against the definition.
219 ConvParams canonicalize()
const;
/// Ordering and equality operators for ConvParams, declared as friends and
/// defined elsewhere.
/// NOTE(review): which members take part in the comparison is not visible
/// here.
221 friend bool operator<(
const ConvParams &a,
const ConvParams &b);
222 friend bool operator==(
const ConvParams &a,
const ConvParams &b);
223 friend bool operator!=(
const ConvParams &a,
const ConvParams &b);
/// Stream insertion for ConvParams (declaration only; the text format is
/// defined elsewhere).
226std::ostream &
operator<<(std::ostream &os,
const ConvParams &p);
/// Stream extraction into `p` (declaration only).
/// NOTE(review): presumably accepts the format written by operator<< —
/// confirm against the definition.
227std::istream &
operator>>(std::istream &is, ConvParams &p);
/// Hash of a ConvParams::InputTransform (declaration only).
229std::size_t
hash_value(
const ConvParams::InputTransform &it);
/// Hash of a ConvParams::OutputTransform (declaration only).
230std::size_t
hash_value(
const ConvParams::OutputTransform &ot);
/// Explicit specialisation of `hash` for poplin::ConvParams::InputTransform;
/// the call operator is declared here and defined elsewhere.
/// NOTE(review): presumably a specialisation of std::hash inside namespace
/// std, but the enclosing namespace and this struct's closing brace are not
/// present in this listing — confirm.
236template <>
struct hash<
poplin::ConvParams::InputTransform> {
237 std::size_t operator()(
const poplin::ConvParams::InputTransform &it)
const;
/// Explicit specialisation of `hash` for poplin::ConvParams::OutputTransform;
/// the call operator is declared here and defined elsewhere.
/// NOTE(review): presumably a specialisation of std::hash inside namespace
/// std, but the enclosing namespace and this struct's closing brace are not
/// present in this listing — confirm.
240template <>
struct hash<
poplin::ConvParams::OutputTransform> {
241 std::size_t operator()(
const poplin::ConvParams::OutputTransform &ot)
const;
/// Explicit specialisation of `hash` for poplin::ConvParams; the call
/// operator is declared here and defined elsewhere.
/// The parameter is a const reference named `params`. Its `&params` spelling
/// had been mis-decoded as the HTML entity `&para` and showed up as a
/// pilcrow followed by `ms`, which is not a valid declarator.
/// NOTE(review): presumably a specialisation of std::hash inside namespace
/// std, but the enclosing namespace and this struct's closing brace are not
/// present in this listing — confirm.
244template <>
struct hash<
poplin::ConvParams> {
245 std::size_t operator()(
const poplin::ConvParams &params)
const;
Class representing device data types.
Definition: Type.hpp:42
std::istream & operator>>(std::istream &is, CollectiveOperator &op)
Parse a token from the input stream `is` into `op`.
std::ostream & operator<<(std::ostream &os, const CollectiveOperator &op)
Write `op` to the output stream `os`.
std::size_t hash_value(const EngineOptions &options)
Obtain a hash value for an EngineOptions object.
Linear algebra functions.
Definition: Cholesky.hpp:14