Poplar and PopLibs
Collectives.hpp
Go to the documentation of this file.
1// Copyright (c) 2019 Graphcore Ltd. All rights reserved.
2#ifndef INCLUDE_GCL_COLLECTIVES_HPP
3#define INCLUDE_GCL_COLLECTIVES_HPP
4
5#include <poplar/DebugContext.hpp>
6#include <poplar/Graph.hpp>
7#include <poplar/OptionFlags.hpp>
8#include <poplar/Program.hpp>
9#include <poplar/Tensor.hpp>
10#include <iosfwd>
11#include <vector>
12
13#if __cplusplus >= 201603L
17#define GCL_NO_DISCARD [[nodiscard]]
18#else
22#define GCL_NO_DISCARD
23#endif
24
28namespace gcl {
29
34enum class CommGroupType {
36 ALL,
49};
50
51// clang-format off
69// clang-format on
70struct CommGroup {
71 CommGroup() = default;
79 CommGroup(const CommGroupType groupType, unsigned groupSize,
80 unsigned replicaStride = 1);
81
82 virtual ~CommGroup() = default;
83
84protected:
86 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
89 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
90 unsigned mReplicaGroupSize = 0;
93 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
94 unsigned mReplicaGroupStride = 1;
95
100 friend std::ostream &operator<<(std::ostream &os, const CommGroup &group);
101};
102
107 ADD,
108 MEAN,
109 MUL,
110 MIN,
111 MAX,
113 LOGICAL_OR,
114 SQUARE_ADD,
115 LOCAL,
116};
117
125std::istream &operator>>(std::istream &is, CollectiveOperator &op);
126
134std::ostream &operator<<(std::ostream &os, const CollectiveOperator &op);
135
140
311 const CommGroup &group,
312 const poplar::DebugContext &debugContext = {},
313 const poplar::OptionFlags &options = {});
314
334GCL_NO_DISCARD std::vector<poplar::Tensor> allReduceCrossReplica(
335 poplar::Graph &graph, const std::vector<poplar::Tensor> &datas,
337 const CommGroup &group, const poplar::DebugContext &debugContext = {},
338 const poplar::OptionFlags &options = {});
339
353 const poplar::DebugContext &debugContext = {},
354 const poplar::OptionFlags &options = {});
355
369GCL_NO_DISCARD std::vector<poplar::Tensor>
371 const std::vector<poplar::Tensor> &datas,
373 const poplar::DebugContext &debugContext = {},
374 const poplar::OptionFlags &options = {});
375
389 poplar::Graph &graph, const poplar::Tensor &data,
390 poplar::Tensor &destination, CollectiveOperator op,
391 poplar::program::Sequence &prog, const CommGroup &group,
392 const poplar::DebugContext &debugContext = {},
393 const poplar::OptionFlags &options = {});
394
408 poplar::Graph &graph, const std::vector<poplar::Tensor> &datas,
409 const std::vector<poplar::Tensor> &destinations, CollectiveOperator op,
410 poplar::program::Sequence &prog, const CommGroup &group,
411 const poplar::DebugContext &debugContext = {},
412 const poplar::OptionFlags &options = {});
413
425 poplar::Graph &graph, const poplar::Tensor &data,
426 poplar::Tensor &destination, CollectiveOperator op,
428 const poplar::DebugContext &debugContext = {},
429 const poplar::OptionFlags &options = {});
430
445 const CommGroup &group,
446 const poplar::DebugContext &debugContext = {},
447 const poplar::OptionFlags &options = {});
448
467 std::vector<poplar::Tensor> &datas,
470 const CommGroup &group,
471 const poplar::DebugContext &debugContext = {},
472 const poplar::OptionFlags &options = {});
473
486 const poplar::DebugContext &debugContext = {},
487 const poplar::OptionFlags &options = {});
488
500 std::vector<poplar::Tensor> &datas,
503 const poplar::DebugContext &debugContext = {},
504 const poplar::OptionFlags &options = {});
505
563 poplar::Graph &graph, const poplar::Tensor &data, CollectiveOperator op,
564 poplar::program::Sequence &prog, const CommGroup &group,
565 const poplar::DebugContext &debugContext = {},
566 const poplar::OptionFlags &options = {});
567
580GCL_NO_DISCARD std::vector<poplar::Tensor> reduceScatterCrossReplica(
581 poplar::Graph &graph, const std::vector<poplar::Tensor> &datas,
583 const CommGroup &group, const poplar::DebugContext &debugContext = {},
584 const poplar::OptionFlags &options = {});
585
598 poplar::Graph &graph, const std::vector<poplar::Tensor> &datas,
599 const std::vector<poplar::Tensor> &destinations, CollectiveOperator op,
600 poplar::program::Sequence &prog, const CommGroup &group,
601 const poplar::DebugContext &debugContext = {},
602 const poplar::OptionFlags &options = {});
603
618 const poplar::DebugContext &debugContext = {},
619 const poplar::OptionFlags &options = {});
620
653 poplar::program::Sequence &prog, const CommGroup &group,
654 const poplar::DebugContext &debugContext = {},
655 const poplar::OptionFlags &options = {});
656
668GCL_NO_DISCARD std::vector<poplar::Tensor>
670 const std::vector<poplar::Tensor> &datas,
671 poplar::program::Sequence &prog, const CommGroup &group,
672 const poplar::DebugContext &debugContext = {},
673 const poplar::OptionFlags &options = {});
674
689 poplar::Graph &graph, const std::vector<poplar::Tensor> &datas,
690 const std::vector<poplar::Tensor> &destinations,
691 poplar::program::Sequence &prog, const CommGroup &group,
692 const poplar::DebugContext &debugContext = {},
693 const poplar::OptionFlags &options = {});
694
707 const poplar::DebugContext &debugContext = {},
708 const poplar::OptionFlags &options = {});
709
745 poplar::program::Sequence &prog, const CommGroup &group,
746 const poplar::DebugContext &debugContext = {},
747 const poplar::OptionFlags &options = {});
748
761 const poplar::DebugContext &debugContext = {},
762 const poplar::OptionFlags &options = {});
763
790 const CommGroup &group = {}, unsigned rootReplica = 0,
791 const poplar::DebugContext &debugContext = {},
792 const poplar::OptionFlags &options = {});
793
795
800
802struct Chunk {
804 GC_DEPRECATED_MSG("Use get/setTensor() instead")
805 // NOLINTNEXTLINE(readability-identifier-naming)
806 poplar::Tensor tensor;
808 GC_DEPRECATED_MSG("Use get/setIndex() instead")
809 // NOLINTNEXTLINE(readability-identifier-naming)
810 unsigned index = 0;
812 GC_DEPRECATED_MSG("Use get/setOffset() instead")
813 // NOLINTNEXTLINE(readability-identifier-naming)
814 unsigned offset = 0;
815
816#pragma GCC diagnostic push
817#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
818#pragma GCC diagnostic ignored "-Wshadow"
819 // In the deprecation period we need to ensure none of the
820 // default implementations warn about using the deprecated vars.
821 Chunk() = default;
822 ~Chunk() = default;
824 Chunk(const Chunk &) = default;
826 Chunk(Chunk &&) noexcept = default;
830 Chunk &operator=(const Chunk &) = default;
834 Chunk &operator=(Chunk &&) noexcept = default;
835
843 Chunk(poplar::Tensor tensor, unsigned index, unsigned offset);
844#pragma GCC diagnostic pop
845
850
854 GCL_NO_DISCARD unsigned getIndex() const;
855
859 GCL_NO_DISCARD unsigned getOffset() const;
860
865
869 void setIndex(unsigned index);
870
874 void setOffset(unsigned offset);
875};
876
878struct Chunks {
880 GC_DEPRECATED_MSG("Use get/setOriginalInput() instead")
881 // NOLINTNEXTLINE(readability-identifier-naming)
884 GC_DEPRECATED_MSG("Use get/setChunks() instead")
885 // NOLINTNEXTLINE(readability-identifier-naming)
886 std::vector<Chunk> chunks;
887
888#pragma GCC diagnostic push
889#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
890#pragma GCC diagnostic ignored "-Wshadow"
891 // In the deprecation period we need to ensure none of the
892 // default implementations warn about using the deprecated vars.
893 Chunks() = default;
894 ~Chunks() = default;
896 Chunks(const Chunks &) = default;
898 Chunks(Chunks &&) noexcept = default;
902 Chunks &operator=(const Chunks &) = default;
906 Chunks &operator=(Chunks &&) noexcept = default;
907
913 explicit Chunks(unsigned size) : chunks(std::vector<Chunk>(size)) {}
914#pragma GCC diagnostic pop
915
920
924 GCL_NO_DISCARD const std::vector<Chunk> &getChunks() const;
925
930
935 void setChunk(std::vector<Chunk>::size_type i, Chunk chunk);
936
940 void setChunks(std::vector<Chunk> chunks);
941
953};
954
966GC_DEPRECATED_MSG("Use Chunks::concat() instead")
967poplar::Tensor concatChunks(const Chunks &chunks);
968
1008 poplar::Graph &graph, const poplar::Tensor &toReduce, CollectiveOperator op,
1009 poplar::program::Sequence &prog,
1010 const poplar::DebugContext &debugContext = {},
1011 const poplar::OptionFlags &options = {});
1012
1048 const poplar::DebugContext &debugContext = {},
1049 const poplar::OptionFlags &options = {});
1050
1094 const poplar::DebugContext &debugContext = {},
1095 const poplar::OptionFlags &options = {});
1096
1098
1099} // namespace gcl
1100#endif // INCLUDE_GCL_COLLECTIVES_HPP
#define GCL_NO_DISCARD
Produce a compile-time warning for unused return values.
Definition: Collectives.hpp:22
DebugContext gathers the common external parameters of the context of an operation.
Definition: DebugContext.hpp:221
This class represents a graph program to be executed on the IPU.
Definition: Graph.hpp:52
A set of option/value string flags to be used in various APIs.
Definition: OptionFlags.hpp:24
A reference to a subset of tensor elements.
Definition: Tensor.hpp:38
Program that executes a sequence of programs.
Definition: Program.hpp:77
Graphcore Communications Library.
CommGroupType
Enum to define communication group specification type.
Definition: Collectives.hpp:34
@ ALL
All replicas viewed as one group.
@ CONSECUTIVE
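Groups are formed from consecutive replica indices: with N replicas and group size k, the groups are {0, ..., k-1}, {k, ..., 2k-1}, ..., {N-k, ..., N-1}.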
@ ORTHOGONAL
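Groups are sliced orthogonally to the replica ordering: with N replicas and group size k (so m = N/k groups), the groups are {0, m, 2m, ...}, {1, m+1, 2m+1, ...}, ..., {m-1, 2m-1, ..., N-1}.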
std::istream & operator>>(std::istream &is, CollectiveOperator &op)
Parse token from input stream is to op.
CollectiveOperator
Supported collective operators.
Definition: Collectives.hpp:106
@ LOGICAL_OR
Only supports boolean operands.
@ LOCAL
Do nothing and keep the local value.
@ SQUARE_ADD
Squares each element before applying ADD reduction.
@ LOGICAL_AND
Only supports boolean operands.
poplar::Tensor allToAllCrossReplica(poplar::Graph &graph, const poplar::Tensor &data, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Perform an all-to-all exchange of the elements of the input tensor based on replica ID.
void allReduceInPlaceCrossReplica(poplar::Graph &graph, poplar::Tensor &data, CollectiveOperator op, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
As allReduceCrossReplica() but writes result back to the input data tensor.
poplar::Tensor allGatherWithinReplica(poplar::Graph &graph, const Chunks &toGather, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Broadcast data distributed over all IPUs.
void allGatherToDestinationCrossReplica(poplar::Graph &graph, const std::vector< poplar::Tensor > &datas, const std::vector< poplar::Tensor > &destinations, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
As allGatherCrossReplica() but with vector input/output arguments.
poplar::Tensor allReduceCrossReplica(poplar::Graph &graph, const poplar::Tensor &data, CollectiveOperator op, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Perform an all-reduce operation.
poplar::Tensor broadcastCrossReplica(poplar::Graph &graph, const poplar::Tensor &data, poplar::program::Sequence &prog, const CommGroup &group={}, unsigned rootReplica=0, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Perform a broadcast from one replica to all other replicas.
Chunks reduceScatterWithinReplica(poplar::Graph &graph, const poplar::Tensor &toReduce, CollectiveOperator op, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Reduce a rank-2 tensor.
std::ostream & operator<<(std::ostream &os, const CollectiveOperator &op)
Write op to output stream os.
poplar::Tensor allReduceWithinReplica(poplar::Graph &graph, const poplar::Tensor &toReduce, CollectiveOperator op, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Perform an all-reduce operation on the specified tensor.
void reduceScatterToDestinationCrossReplica(poplar::Graph &graph, const std::vector< poplar::Tensor > &datas, const std::vector< poplar::Tensor > &destinations, CollectiveOperator op, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
As reduceScatterCrossReplica() but with vector input/output arguments.
void allReduceToDestinationCrossReplica(poplar::Graph &graph, const poplar::Tensor &data, poplar::Tensor &destination, CollectiveOperator op, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
As allReduceCrossReplica() but writes the result to the destination tensor.
poplar::Tensor concatChunks(const Chunks &chunks)
Concatenates chunks.
poplar::Tensor reduceScatterCrossReplica(poplar::Graph &graph, const poplar::Tensor &data, CollectiveOperator op, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Reduce the replicated rank-1 tensor data with the result scattered across the replicas.
poplar::Tensor allGatherCrossReplica(poplar::Graph &graph, const poplar::Tensor &data, poplar::program::Sequence &prog, const CommGroup &group, const poplar::DebugContext &debugContext={}, const poplar::OptionFlags &options={})
Gather the replicated tensor data.
Poplar classes and functions.
Definition: ArrayRef.hpp:14
Represents a section of a tensor mapped to an IPU.
Definition: Collectives.hpp:802
unsigned getOffset() const
Offset within rank (model parallel index).
unsigned offset
Offset within rank (model parallel index).
Definition: Collectives.hpp:814
void setOffset(unsigned offset)
Set offset.
Chunk(const Chunk &)=default
Defaulted to avoid warnings in the deprecation period.
poplar::Tensor tensor
Mapped tensor.
Definition: Collectives.hpp:806
unsigned getIndex() const
Ring index (data parallel index).
void setTensor(poplar::Tensor tensor)
Set mapped tensor.
poplar::Tensor getTensor() const
Mapped tensor.
void setIndex(unsigned index)
Set ring index.
unsigned index
Ring index (data parallel index).
Definition: Collectives.hpp:810
Chunk(Chunk &&) noexcept=default
Defaulted to avoid warnings in the deprecation period.
A vector of Chunk data.
Definition: Collectives.hpp:878
poplar::Tensor concat() const
Concatenates chunks.
poplar::Tensor getOriginalInput() const
Used to undo shuffles introduced by scatter.
poplar::Tensor originalInput
Used to undo shuffles introduced by scatter.
Definition: Collectives.hpp:882
void setOriginalInput(poplar::Tensor input)
Set original input.
Chunks(const Chunks &)=default
Defaulted to avoid warnings in the deprecation period.
void setChunks(std::vector< Chunk > chunks)
Set chunks produced by the scatter step.
const std::vector< Chunk > & getChunks() const
Chunks produced by the scatter step.
Chunks(Chunks &&) noexcept=default
Defaulted to avoid warnings in the deprecation period.
std::vector< Chunk > chunks
Chunks produced by the scatter step.
Definition: Collectives.hpp:886
void setChunk(std::vector< Chunk >::size_type i, Chunk chunk)
Set chunk.
Struct to specify sub-groups of replicas.
Definition: Collectives.hpp:70
friend std::ostream & operator<<(std::ostream &os, const CommGroup &group)
String representation of the CommGroup.
unsigned mReplicaGroupSize
Replica group size.
Definition: Collectives.hpp:90
CommGroup(const CommGroupType groupType, unsigned groupSize, unsigned replicaStride=1)
Construct CommGroup.
unsigned mReplicaGroupStride
Replica group stride.
Definition: Collectives.hpp:94
CommGroupType mReplicaGroupType
Replica group type.
Definition: Collectives.hpp:87