2#ifndef INCLUDE_GCL_COLLECTIVES_HPP
3#define INCLUDE_GCL_COLLECTIVES_HPP
5#include <poplar/DebugContext.hpp>
6#include <poplar/Graph.hpp>
7#include <poplar/OptionFlags.hpp>
8#include <poplar/Program.hpp>
9#include <poplar/Tensor.hpp>
13#if __cplusplus >= 201603L
17#define GCL_NO_DISCARD [[nodiscard]]
80 unsigned replicaStride = 1);
335 poplar::Graph &graph,
const std::vector<poplar::Tensor> &datas,
371 const std::vector<poplar::Tensor> &datas,
408 poplar::Graph &graph,
const std::vector<poplar::Tensor> &datas,
467 std::vector<poplar::Tensor> &datas,
500 std::vector<poplar::Tensor> &datas,
581 poplar::Graph &graph,
const std::vector<poplar::Tensor> &datas,
598 poplar::Graph &graph,
const std::vector<poplar::Tensor> &datas,
670 const std::vector<poplar::Tensor> &datas,
689 poplar::Graph &graph,
const std::vector<poplar::Tensor> &datas,
690 const std::vector<poplar::Tensor> &destinations,
790 const CommGroup &group = {},
unsigned rootReplica = 0,
804 GC_DEPRECATED_MSG(
"Use get/setTensor() instead")
808 GC_DEPRECATED_MSG("Use get/setIndex() instead")
812 GC_DEPRECATED_MSG("Use get/setOffset() instead")
816#pragma GCC diagnostic push
817#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
818#pragma GCC diagnostic ignored "-Wshadow"
844#pragma GCC diagnostic pop
880 GC_DEPRECATED_MSG(
"Use get/setOriginalInput() instead")
884 GC_DEPRECATED_MSG("Use get/setChunks() instead")
888#pragma GCC diagnostic push
889#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
890#pragma GCC diagnostic ignored "-Wshadow"
914#pragma GCC diagnostic pop
966GC_DEPRECATED_MSG(
"Use Chunks::concat() instead")
1009 poplar::program::Sequence &prog,
1010 const
poplar::DebugContext &debugContext = {},
#define GCL_NO_DISCARD
Produce compile time warning for unused return values.
Definition: Collectives.hpp:22
DebugContext gathers the common external parameters of the context of an operation.
Definition: DebugContext.hpp:221
This class represents a graph program to be executed on the IPU.
Definition: Graph.hpp:52
A set of option/value string flags to be used in various APIs.
Definition: OptionFlags.hpp:24
A reference to a subset of tensor elements.
Definition: Tensor.hpp:38
Program that executes a sequence of programs.
Definition: Program.hpp:77
Graphcore Communications Library.
CommGroupType
Enum to define communication group specification type.
Definition: Collectives.hpp:34
@ ALL
All replicas viewed as one group.
@ CONSECUTIVE
Groups are consecutive in replica.
@ ORTHOGONAL
Groups are sliced orthogonal to the replica ordering.
std::istream & operator>>(std::istream &is, CollectiveOperator &op)
Parse a token from the input stream 'is' into 'op'.
CollectiveOperator
Supported collective operators.
Definition: Collectives.hpp:106
@ LOGICAL_OR
Only supports boolean operands.
@ LOCAL
Do nothing and keep the local value.
@ SQUARE_ADD
Squares each element before applying ADD reduction.
@ LOGICAL_AND
Only supports boolean operands.
poplar::Tensor allToAllCrossReplica(poplar::Graph &graph, const poplar::Tensor &data, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Perform an all-to-all exchange of the elements of the input tensor based on replica ID.
void allReduceInPlaceCrossReplica(poplar::Graph &graph, poplar::Tensor &data, CollectiveOperator op, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
As allReduceCrossReplica() but writes result back to the input data tensor.
poplar::Tensor allGatherWithinReplica(poplar::Graph &graph, const Chunks &toGather, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Broadcast data distributed over all IPUs.
void allGatherToDestinationCrossReplica(poplar::Graph &graph, const std::vector< poplar::Tensor > &datas, const std::vector< poplar::Tensor > &destinations, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
As allGatherCrossReplica() but with vector input/output arguments.
poplar::Tensor allReduceCrossReplica(poplar::Graph &graph, const poplar::Tensor &data, CollectiveOperator op, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Perform an all-reduce operation.
poplar::Tensor broadcastCrossReplica(poplar::Graph &graph, const poplar::Tensor &data, poplar::program::Sequence &prog, const CommGroup &group={}, unsigned rootReplica=0, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Perform a broadcast from one replica to all other replicas.
Chunks reduceScatterWithinReplica(poplar::Graph &graph, const poplar::Tensor &toReduce, CollectiveOperator op, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Reduce a rank 2 tensor.
std::ostream & operator<<(std::ostream &os, const CollectiveOperator &op)
Write 'op' to the output stream 'os'.
poplar::Tensor allReduceWithinReplica(poplar::Graph &graph, const poplar::Tensor &toReduce, CollectiveOperator op, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Perform an all-reduce operation on the specified tensor.
void reduceScatterToDestinationCrossReplica(poplar::Graph &graph, const std::vector< poplar::Tensor > &datas, const std::vector< poplar::Tensor > &destinations, CollectiveOperator op, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
As reduceScatterCrossReplica() but with vector input/output arguments.
void allReduceToDestinationCrossReplica(poplar::Graph &graph, const poplar::Tensor &data, poplar::Tensor &destination, CollectiveOperator op, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
As allReduceCrossReplica() but writes the result to the destination tensor.
poplar::Tensor concatChunks(const Chunks &chunks)
Concatenates chunks.
poplar::Tensor reduceScatterCrossReplica(poplar::Graph &graph, const poplar::Tensor &data, CollectiveOperator op, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Reduce the replicated rank-1 tensor data with the result scattered across the replicas.
poplar::Tensor allGatherCrossReplica(poplar::Graph &graph, const poplar::Tensor &data, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Gather the replicated tensor data.
Poplar classes and functions.
Definition: ArrayRef.hpp:14
Represents a section of a tensor mapped to an IPU.
Definition: Collectives.hpp:802
unsigned getOffset() const
Offset within rank (model parallel index).
unsigned offset
Offset within rank (model parallel index).
Definition: Collectives.hpp:814
void setOffset(unsigned offset)
Set offset.
Chunk(const Chunk &)=default
Defaulted to avoid warnings in deprecation period.
poplar::Tensor tensor
Mapped tensor.
Definition: Collectives.hpp:806
unsigned getIndex() const
Ring index (data parallel index)
void setTensor(poplar::Tensor tensor)
Set mapped tensor.
poplar::Tensor getTensor() const
Mapped tensor.
void setIndex(unsigned index)
Set ring index.
unsigned index
Ring index (data parallel index)
Definition: Collectives.hpp:810
Chunk(Chunk &&) noexcept=default
Defaulted to avoid warnings in deprecation period.
A vector of Chunk data.
Definition: Collectives.hpp:878
poplar::Tensor concat() const
Concatenates chunks.
poplar::Tensor getOriginalInput() const
Used to undo shuffles introduced by scatter.
poplar::Tensor originalInput
Used to undo shuffles introduced by scatter.
Definition: Collectives.hpp:882
void setOriginalInput(poplar::Tensor input)
Set original input.
Chunks(const Chunks &)=default
Defaulted to avoid warnings in deprecation period.
void setChunks(std::vector< Chunk > chunks)
Set chunks produced by scatter step.
const std::vector< Chunk > & getChunks() const
Chunks produced by the scatter step.
Chunks(Chunks &&) noexcept=default
Defaulted to avoid warnings in deprecation period.
std::vector< Chunk > chunks
Chunks produced by the scatter step.
Definition: Collectives.hpp:886
void setChunk(std::vector< Chunk >::size_type i, Chunk chunk)
Set chunk.
Struct to specify sub-groups of replicas.
Definition: Collectives.hpp:70
friend std::ostream & operator<<(std::ostream &os, const CommGroup &group)
String representation of the CommGroup.
unsigned mReplicaGroupSize
Replica group size.
Definition: Collectives.hpp:90
CommGroup(const CommGroupType groupType, unsigned groupSize, unsigned replicaStride=1)
Construct CommGroup.
unsigned mReplicaGroupStride
Replica group stride.
Definition: Collectives.hpp:94
CommGroupType mReplicaGroupType
Replica group type.
Definition: Collectives.hpp:87