Poplar and PopLibs
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ConvParams.hpp
Go to the documentation of this file.
1// Copyright (c) 2019 Graphcore Ltd. All rights reserved.
8#ifndef poplin_ConvParams_hpp
9#define poplin_ConvParams_hpp
10#include "poplar/Type.hpp"
11#include <tuple>
12#include <vector>
13
14namespace poplin {
15
16struct ConvParams {
17 struct InputTransform {
20 std::vector<unsigned> truncationLower;
23 std::vector<unsigned> truncationUpper;
27 std::vector<unsigned> dilation;
30 std::vector<unsigned> paddingLower;
33 std::vector<unsigned> paddingUpper;
35 std::vector<bool> flip;
36
37 InputTransform() = default;
38 InputTransform(const std::size_t size);
50 InputTransform(std::vector<unsigned> truncationLower,
51 std::vector<unsigned> truncationUpper,
52 std::vector<unsigned> dilation,
53 std::vector<unsigned> paddingLower,
54 std::vector<unsigned> paddingUpper, std::vector<bool> flip);
55
56 friend bool operator<(const InputTransform &a, const InputTransform &b);
57 friend bool operator==(const InputTransform &a, const InputTransform &b);
58 friend bool operator!=(const InputTransform &a, const InputTransform &b);
59 };
60
61 struct OutputTransform {
64 std::vector<unsigned> truncationLower;
67 std::vector<unsigned> truncationUpper;
70 std::vector<unsigned> stride;
72 std::vector<unsigned> paddingLower;
74 std::vector<unsigned> paddingUpper;
75
76 OutputTransform() = default;
77 OutputTransform(const std::size_t size);
85 OutputTransform(std::vector<unsigned> truncationLower,
86 std::vector<unsigned> truncationUpper,
87 std::vector<unsigned> striding,
88 std::vector<unsigned> paddingLower,
89 std::vector<unsigned> paddingUpper);
90
91 friend bool operator<(const OutputTransform &a, const OutputTransform &b);
92 friend bool operator==(const OutputTransform &a, const OutputTransform &b);
93 friend bool operator!=(const OutputTransform &a, const OutputTransform &b);
94 };
95
96 poplar::Type inputType;
97 poplar::Type outputType;
99 std::size_t batchSize;
101 std::vector<std::size_t> inputFieldShape;
103 std::vector<std::size_t> kernelShape;
105 std::size_t inputChannelsPerConvGroup;
107 std::size_t outputChannelsPerConvGroup;
112 std::size_t numConvGroups;
113
115 InputTransform inputTransform;
117 InputTransform kernelTransform;
119 OutputTransform outputTransform;
120
121 ConvParams() = default;
122 /*
123 * \param dataType
124 * \param batchSize
125 * \param inputFieldShape
126 * \param kernelShape
127 * \param inputChannels
128 * \param outputChannels
129 * \param numConvGroups
130 */
131 ConvParams(poplar::Type dataType, std::size_t batchSize,
132 std::vector<std::size_t> inputFieldShape,
133 std::vector<std::size_t> kernelShape, std::size_t inputChannels,
134 std::size_t outputChannels, std::size_t numConvGroups);
135 /*
136 * \param inputType
137 * \param outputType
138 * \param batchSize
139 * \param inputFieldShape
140 * \param kernelShape
141 * \param inputChannels
142 * \param outputChannels
143 * \param numConvGroups
144 */
145 ConvParams(poplar::Type inputType, poplar::Type outputType,
146 std::size_t batchSize, std::vector<std::size_t> inputFieldShape,
147 std::vector<std::size_t> kernelShape, std::size_t inputChannels,
148 std::size_t outputChannels, std::size_t numConvGroups);
149 /*
150 * \param inputType
151 * \param outputType
152 * \param batchSize
153 * \param inputFieldShape
154 * \param kernelShape
155 * \param inputChannels
156 * \param outputChannels
157 * \param numConvGroups
158 * \param inputTransform
159 * \param kernelTransform
160 * \param outputTransform
161 */
162 ConvParams(poplar::Type inputType, poplar::Type outputType,
163 std::size_t batchSize, std::vector<std::size_t> inputFieldShape,
164 std::vector<std::size_t> kernelShape, std::size_t inputChannels,
165 std::size_t outputChannels, std::size_t numConvGroups,
166 InputTransform inputTransform, InputTransform kernelTransform,
167 OutputTransform outputTransform);
168
171 std::size_t getUntransformedOutputSize(unsigned dim) const;
173 std::size_t getOutputSize(unsigned dim) const;
175 std::size_t getNumOutputChansPerConvGroup() const {
176 return outputChannelsPerConvGroup;
177 }
179 std::size_t getNumOutputChans() const {
180 return outputChannelsPerConvGroup * numConvGroups;
181 }
183 std::size_t getInputSize(unsigned dim) const { return inputFieldShape[dim]; }
185 std::size_t getNumInputChansPerConvGroup() const {
186 return inputChannelsPerConvGroup;
187 }
189 std::size_t getNumInputChans() const {
190 return inputChannelsPerConvGroup * numConvGroups;
191 }
193 std::size_t getNumConvGroups() const { return numConvGroups; }
195 std::size_t getNumFieldDims() const { return inputFieldShape.size(); }
197 std::vector<std::size_t> getInputFieldShape() const {
198 return inputFieldShape;
199 }
201 std::vector<std::size_t> getKernelShape() const { return kernelShape; }
203 std::size_t getBatchSize() const { return batchSize; }
204
206 unsigned getTruncatedInputSize(unsigned dim) const;
208 unsigned getTruncatedKernelSize(unsigned dim) const;
211 unsigned getTransformedInputSize(unsigned dim) const;
214 unsigned getTransformedKernelSize(unsigned dim) const;
216 std::vector<size_t> getOutputFieldShape() const;
217
218 void validate() const;
219 ConvParams canonicalize() const;
220
221 friend bool operator<(const ConvParams &a, const ConvParams &b);
222 friend bool operator==(const ConvParams &a, const ConvParams &b);
223 friend bool operator!=(const ConvParams &a, const ConvParams &b);
224};
225
// Writes the convolution parameters `p` to the output stream `os` and returns
// the stream. Declared here only; the output format is determined by the
// out-of-line definition.
226std::ostream &operator<<(std::ostream &os, const ConvParams &p);
// Reads convolution parameters from the input stream `is` into `p` and
// returns the stream. Declared here only; presumably accepts the format
// written by operator<< -- confirm against the definition.
227std::istream &operator>>(std::istream &is, ConvParams &p);
228
// Hash of an input/kernel transform. Declared here only; defined elsewhere.
// NOTE(review): the name and signature match the `hash_value` customisation
// point used by Boost-style hashing -- confirm that is the intent.
229std::size_t hash_value(const ConvParams::InputTransform &it);
// Hash of an output transform. Declared here only; defined elsewhere.
230std::size_t hash_value(const ConvParams::OutputTransform &ot);
231
232} // namespace poplin
233
// Explicit specialisations of std::hash for the poplin convolution parameter
// types. Only the call operators are declared here; they are defined
// elsewhere (presumably in terms of the poplin::hash_value overloads --
// confirm against the implementation).
234namespace std {
235
// Hash functor for poplin::ConvParams::InputTransform.
236template <> struct hash<poplin::ConvParams::InputTransform> {
237 std::size_t operator()(const poplin::ConvParams::InputTransform &it) const;
238};
239
// Hash functor for poplin::ConvParams::OutputTransform.
240template <> struct hash<poplin::ConvParams::OutputTransform> {
241 std::size_t operator()(const poplin::ConvParams::OutputTransform &ot) const;
242};
243
// Hash functor for poplin::ConvParams.
244template <> struct hash<poplin::ConvParams> {
245 std::size_t operator()(const poplin::ConvParams &params) const;
246};
247
248} // namespace std
249
250#endif // poplin_ConvParams_hpp
Class representing device data types.
Definition: Type.hpp:42
std::istream & operator>>(std::istream &is, CollectiveOperator &op)
Parse token from input stream is to op.
std::ostream & operator<<(std::ostream &os, const CollectiveOperator &op)
Write op to output stream os.
std::size_t hash_value(const EngineOptions &options)
Obtain a hash value for an EngineOptions object.
Linear algebra functions.
Definition: Cholesky.hpp:14