Poplar and PopLibs
StreamCallback.hpp
1// Copyright (c) 2019 Graphcore Ltd. All rights reserved.
2
3#ifndef poplar_StreamCallback_hpp
4#define poplar_StreamCallback_hpp
5
6#include <poplar/CallbackTraits.hpp>
7#include <poplar/Quarter.hpp>
8#include <poplar/SSOPointer.hpp>
9#include <poplar/exceptions.hpp>
10
11#include <memory>
12
13namespace poplar {
14
15// Stream callback interface base class for all data types.
/// Common root of the stream callback interfaces.
///
/// Holds what every interface shares: the Result type and the optional
/// lifecycle hooks complete() and invalidatePrefetched(). Both hooks default
/// to doing nothing, so an implementation only overrides the ones it supports
/// (LegacyStreamCallback, for instance, pins them to no-ops).
class StreamCallbackBase {
public:
  /// Outcome reported by callback operations such as
  /// StreamCallback::prefetch(). LegacyStreamCallback::prefetch() always
  /// reports NotAvailable.
  enum class Result { Success, NotAvailable };

  // Virtual so implementations can be owned and destroyed through a
  // StreamCallbackBase pointer (StreamCallbackHandle stores them that way).
  virtual ~StreamCallbackBase() = default;

  /// Optional hook; the default implementation does nothing.
  /// NOTE(review): when Poplar invokes this is not visible in this header —
  /// presumably once the stream copy has finished; confirm.
  virtual void complete() {}

  /// Optional hook; the default implementation does nothing.
  /// NOTE(review): presumably asks the callback to discard data produced by
  /// an earlier prefetch(); confirm against the Poplar documentation.
  virtual void invalidatePrefetched() {}
};
33
43class StreamCallback : public StreamCallbackBase {
44public:
45 virtual ~StreamCallback() = default;
46
78 virtual Result prefetch(void *p) = 0;
79
87 virtual void fetch(void *) = 0;
88};
89
94public:
95 virtual ~StreamCallbackWithMetadata() = default;
96
101 QuarterMetadata getMetadata() const { return metadata; };
102
107 void setMetadata(QuarterMetadata md) { metadata = md; };
108
109private:
110 QuarterMetadata metadata;
111};
112
118public:
121
125 void notify();
126
129 bool isAwaiting() const;
130
131protected:
135 void wait();
136
137private:
138 struct ConditionVar;
139 SSOPointer<ConditionVar> conditionVar;
140};
141
147public:
149 virtual Result prefetch(void *) final override {
150 return Result::NotAvailable;
151 }
153 virtual void invalidatePrefetched() final override {}
155 virtual void complete() final override {}
156};
157
162public:
169 template <class CallbackImpl,
170 typename = typename std::enable_if<
171 std::is_base_of<StreamCallback, CallbackImpl>::value ||
172 std::is_base_of<StreamCallbackWithMetadata,
173 CallbackImpl>::value>::type>
174 StreamCallbackHandle(std::unique_ptr<CallbackImpl> f)
175 : callback(std::move(f)) {
176 if (!callback) {
177 throw poplar_error("Invalid null stream callback");
178 }
179 }
180
187 template <class F, typename = typename std::enable_if<
188 traits::is_callback<F>::value>::type>
189 StreamCallbackHandle(F &&f) : callback(makeCallback(std::forward<F>(f))) {}
190
191 // Non copy constructible
193
194 // Move constructible
196
197 // This forces the user to provide a non-empty callback and allows the
198 // internal implementation to only make use of std::unique_ptr<StreamCallback>
199 // if necessary.
202 operator std::unique_ptr<StreamCallbackBase>() && {
203 return std::move(callback);
204 }
205
212 operator std::unique_ptr<StreamCallback>() &&;
213
220 operator std::unique_ptr<StreamCallbackWithMetadata>() &&;
221
222private:
223 template <class F>
224 static std::unique_ptr<StreamCallback> makeCallback(F &&f) {
225 struct NonPrefetchable final : LegacyStreamCallback {
226 using Function = typename traits::remove_cvref<F>::type;
227
228 NonPrefetchable(F &&f) : function(std::forward<F>(f)) {}
229 void fetch(void *p) override { function(p); }
230 Function function;
231 };
232 return std::unique_ptr<NonPrefetchable>(
233 new NonPrefetchable(std::forward<F>(f)));
234 }
235
236 std::unique_ptr<StreamCallbackBase> callback;
237};
238
239} // end namespace poplar
240
241#endif // poplar_StreamCallback_hpp
A reference to a function stored within a graph.
Definition: GraphElements.hpp:148
Convenience StreamCallback specialization for implementations that do not support the prefetch, invalidatePrefetched or complete operations.
Definition: StreamCallback.hpp:146
virtual void invalidatePrefetched() final override
Not available in legacy streams.
Definition: StreamCallback.hpp:153
virtual void complete() final override
Not available in legacy streams.
Definition: StreamCallback.hpp:155
virtual Result prefetch(void *) final override
Not available in legacy streams.
Definition: StreamCallback.hpp:149
Quarter metadata type.
Definition: Quarter.hpp:37
Expands StreamCallback API with functions that prevent further progress.
Definition: StreamCallback.hpp:117
bool isAwaiting() const
Returns whether the callback task is waiting to be notified.
void notify()
Schedules the callback task back to execution.
void wait()
Calling this function will suspend the execution of the current callback.
Wrapper for StreamCallback instances.
Definition: StreamCallback.hpp:161
StreamCallbackHandle(std::unique_ptr< CallbackImpl > f)
Constructs a handle from an instance of a stream callback implementation.
Definition: StreamCallback.hpp:174
StreamCallbackHandle(F &&f)
Constructs a handle from a callable instance.
Definition: StreamCallback.hpp:189
Interface used to add support for stream copies to produce/consume data if the data type requires metadata.
Definition: StreamCallback.hpp:93
QuarterMetadata getMetadata() const
Get the binary representation of metadata on the host.
Definition: StreamCallback.hpp:101
void setMetadata(QuarterMetadata md)
Set the binary representation of metadata on the host.
Definition: StreamCallback.hpp:107
Interface used during stream copies to produce/consume the data being exchanged between the host and the device.
Definition: StreamCallback.hpp:43
virtual Result prefetch(void *p)=0
Callback function to fill the host buffer (host-to-device streams only).
virtual void fetch(void *)=0
Callback function to fill the host buffer.
Poplar classes and functions.
Definition: ArrayRef.hpp:14
Base class for Poplar exceptions.
Definition: exceptions.hpp:16