Scatter

#include <popops/Scatter.hpp>

Scatter operations.

namespace popops

Common functions, such as elementwise and reductions.

Typedefs

using UpdateComputationFunc = std::function<poplar::Tensor(poplar::Graph&, poplar::Tensor&, poplar::Tensor&, poplar::program::Sequence&)>

Functions

void scatter(poplar::Graph &graph, const poplar::Tensor &operand, const poplar::Tensor &indices, const poplar::Tensor &updates, std::size_t indexVectorDim, std::vector<unsigned> updateWindowDims, std::vector<std::size_t> insertWindowDims, std::vector<unsigned> scatterDimsToOperandDims, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &optionFlags = {})

The scatter operation produces a result equal to the input array operand, except that several slices (at the positions specified by indices) are updated with the values in updates.
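As a minimal host-side sketch of these semantics (plain C++, not the Poplar API), the following models one common configuration, assumed for illustration: operand has shape [N][D], indices has shape [K][1] with indexVectorDim = 1, updates has shape [K][D], updateWindowDims = {1}, insertWindowDims = {0} and scatterDimsToOperandDims = {0}. Under that configuration, row indices[k] of the result is replaced by row k of updates. The names scatterRows and Matrix are hypothetical.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Host-side model only; this is not the Poplar API.
using Matrix = std::vector<std::vector<float>>;

// Assumed configuration: operand [N][D], indices [K][1] (passed here as a
// flat list of K row numbers), updates [K][D], indexVectorDim = 1,
// updateWindowDims = {1}, insertWindowDims = {0},
// scatterDimsToOperandDims = {0}.
Matrix scatterRows(Matrix operand, const std::vector<std::size_t> &indices,
                   const Matrix &updates) {
  if (indices.size() != updates.size())
    throw std::invalid_argument("expected one index per update row");
  for (std::size_t k = 0; k < indices.size(); ++k) {
    if (indices[k] >= operand.size())
      throw std::out_of_range("row index out of bounds");
    if (updates[k].size() != operand[indices[k]].size())
      throw std::invalid_argument("update row length must match operand");
    operand[indices[k]] = updates[k]; // overwrite the whole slice (row)
  }
  return operand; // operand was copied, so the caller's array is unchanged
}
```

For example, scattering updates {{1, 2}, {3, 4}} at indices {2, 0} into a 3 x 2 array of zeros yields {{3, 4}, {0, 0}, {1, 2}}. The sketch applies updates in order; it does not describe how Poplar orders writes when indices contains duplicates.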

  • Scatter options:

    • remapOutOfBoundIndices (true, false) [=false] If true, out-of-bounds indices are mapped to index 0.

    • paddingIndexUsed (true, false) [=false] If true, indices may contain a padding index equal to the size of the slice dimension of operand. The padding values returned for a padding index are zeros.
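The remapOutOfBoundIndices behaviour can be pictured with a small host-side sketch. It assumes, for illustration only, that each index component is checked against the size of the operand dimension it addresses; the helper name remapIndices is hypothetical and is not part of popops.

```cpp
#include <vector>

// Hypothetical model of remapOutOfBoundIndices=true for the index components
// that address one operand dimension of size dimSize (an assumption): any
// component outside [0, dimSize) is replaced by 0; in-range ones are kept.
std::vector<long> remapIndices(const std::vector<long> &indices, long dimSize) {
  std::vector<long> result;
  result.reserve(indices.size());
  for (long idx : indices)
    result.push_back((idx < 0 || idx >= dimSize) ? 0 : idx);
  return result;
}
```

With dimSize = 4, the components {-1, 0, 3, 4, 7} become {0, 0, 3, 0, 0}.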

Note

This is a near-direct port of the XLA scatter operation (https://www.tensorflow.org/xla/operation_semantics#scatter), based on tensorflow/compiler/xla/service/scatter_expander.cc.

Parameters
  • graph – The Poplar graph.

  • operand – Array to be scattered into.

  • indices – Array containing the starting indices of the slices that must be scattered to.

  • updates – Array containing the values that must be used for scattering.

  • indexVectorDim – The dimension in indices that contains the starting indices.

  • updateWindowDims – The set of dimensions in the shape of updates that are window dimensions.

  • insertWindowDims – The set of window dimensions that must be inserted into the shape of updates.

  • scatterDimsToOperandDims – A map of dimensions from the scatter indices to the operand index space. This array is interpreted as mapping i to scatterDimsToOperandDims[i]. It must be one-to-one and total.

  • prog – The program to be extended.

  • debugContext – Optional debug information.

  • optionFlags – Scatter options, as described above.
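To make the roles of indexVectorDim, updateWindowDims, insertWindowDims and scatterDimsToOperandDims concrete, here is a hypothetical host-side reference that follows the XLA scatter semantics linked in the Note above. For each element of updates, its coordinates along updateWindowDims are placed, in order, into the operand dimensions not listed in insertWindowDims; component k of the index vector read from indices is added to operand dimension scatterDimsToOperandDims[k]; the element then overwrites operand at that position. Assumptions: the names Array, contains, flatten and scatterReference are illustrative only, the sketch always overwrites (as the first overload does), it throws on any out-of-bounds position instead of modelling remapOutOfBoundIndices or paddingIndexUsed, and it applies duplicates in order.

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Dense row-major array: a shape plus its elements in row-major order.
template <typename T> struct Array {
  std::vector<std::size_t> shape;
  std::vector<T> data;
};

// True if dimension number x appears in dims.
template <typename T> bool contains(const std::vector<T> &dims, std::size_t x) {
  return std::find(dims.begin(), dims.end(), static_cast<T>(x)) != dims.end();
}

// Row-major offset of a multi-dimensional index.
inline std::size_t flatten(const std::vector<std::size_t> &shape,
                           const std::vector<std::size_t> &idx) {
  std::size_t offset = 0;
  for (std::size_t d = 0; d < shape.size(); ++d)
    offset = offset * shape[d] + idx[d];
  return offset;
}

// Hypothetical reference: XLA-style scatter that overwrites operand elements.
Array<float> scatterReference(Array<float> operand,
                              const Array<std::size_t> &indices,
                              const Array<float> &updates,
                              std::size_t indexVectorDim,
                              const std::vector<unsigned> &updateWindowDims,
                              const std::vector<std::size_t> &insertWindowDims,
                              const std::vector<unsigned> &scatterDimsToOperandDims) {
  const std::size_t rank = operand.shape.size();
  // Update dimensions that are not window dimensions select an index vector.
  std::vector<std::size_t> updateScatterDims;
  for (std::size_t d = 0; d < updates.shape.size(); ++d)
    if (!contains(updateWindowDims, d))
      updateScatterDims.push_back(d);
  // Operand dimensions not in insertWindowDims receive window coordinates.
  std::vector<std::size_t> windowOperandDims;
  for (std::size_t d = 0; d < rank; ++d)
    if (!contains(insertWindowDims, d))
      windowOperandDims.push_back(d);
  // indexVectorDim == rank(indices) means an implicit trailing dimension of 1.
  std::vector<std::size_t> idxShape = indices.shape;
  if (indexVectorDim == idxShape.size())
    idxShape.push_back(1);
  const std::size_t vecLen = idxShape.at(indexVectorDim);
  if (updateScatterDims.size() + 1 != idxShape.size() ||
      windowOperandDims.size() != updateWindowDims.size() ||
      scatterDimsToOperandDims.size() != vecLen)
    throw std::invalid_argument("inconsistent scatter dimension numbers");

  std::vector<std::size_t> u(updates.shape.size());
  std::vector<std::size_t> idxPos(idxShape.size());
  for (std::size_t flatU = 0; flatU < updates.data.size(); ++flatU) {
    // Decode the flat position of this update element into u.
    std::size_t rem = flatU;
    for (std::size_t d = updates.shape.size(); d-- > 0;) {
      u[d] = rem % updates.shape[d];
      rem /= updates.shape[d];
    }
    std::vector<std::size_t> pos(rank, 0);
    // Window part: window coordinate k goes to the k-th non-inserted dim.
    for (std::size_t k = 0; k < updateWindowDims.size(); ++k)
      pos[windowOperandDims[k]] = u[updateWindowDims[k]];
    // Scatter part: component k of the index vector offsets operand
    // dimension scatterDimsToOperandDims[k].
    for (std::size_t k = 0; k < vecLen; ++k) {
      for (std::size_t j = 0, g = 0; j < idxShape.size(); ++j)
        idxPos[j] = (j == indexVectorDim) ? k : u[updateScatterDims[g++]];
      pos[scatterDimsToOperandDims[k]] += indices.data[flatten(idxShape, idxPos)];
    }
    for (std::size_t d = 0; d < rank; ++d)
      if (pos[d] >= operand.shape[d])
        throw std::out_of_range("scatter position out of bounds");
    operand.data[flatten(operand.shape, pos)] = updates.data[flatU];
  }
  return operand;
}
```

For example, with a 3 x 3 operand of zeros, indices {{0, 1}, {2, 2}} (shape [2][2]), indexVectorDim = 1, updates {5, 7}, updateWindowDims = {}, insertWindowDims = {0, 1} and scatterDimsToOperandDims = {0, 1}, the sketch writes 5 to element [0][1] and 7 to element [2][2]. With indices of shape [K][1], updateWindowDims = {1}, insertWindowDims = {0} and scatterDimsToOperandDims = {0}, the same function replaces whole rows of a 2-D operand.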

void scatter(poplar::Graph &graph, const poplar::Tensor &operand, const poplar::Tensor &indices, const poplar::Tensor &updates, std::size_t indexVectorDim, std::vector<unsigned> updateWindowDims, std::vector<std::size_t> insertWindowDims, std::vector<unsigned> scatterDimsToOperandDims, UpdateComputationFunc &updateComputation, poplar::program::Sequence &prog, const poplar::DebugContext &debugContext = {}, const poplar::OptionFlags &optionFlags = {})

Similar to the above scatter(), but allows for a user-defined update computation.

This computation combines the existing values in the input tensor with the updates during the scatter.

See the overload above for the available optionFlags.

Note

The first tensor argument passed to updateComputation is always the current value from the operand tensor, and the second is always the value from the updates tensor. This ordering matters when updateComputation is not commutative.
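Why the argument order matters can be seen in a small host-side sketch. In Poplar the update computation receives tensors and appends to a program; here, as a simplifying assumption, it is modelled as a scalar function whose first argument is the current operand value and whose second is the update value. The row-wise configuration (operand [N][D], indices as K row numbers, updates [K][D]) and the names ScalarUpdate and scatterRowsWith are hypothetical.

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Scalar stand-in for UpdateComputationFunc (an assumption): first argument is
// the current operand value, second is the update value.
using ScalarUpdate = std::function<float(float current, float update)>;

// Row-wise scatter that combines each element instead of overwriting it.
std::vector<std::vector<float>>
scatterRowsWith(std::vector<std::vector<float>> operand,
                const std::vector<std::size_t> &indices,
                const std::vector<std::vector<float>> &updates,
                const ScalarUpdate &combine) {
  for (std::size_t k = 0; k < indices.size(); ++k) {
    auto &row = operand.at(indices[k]); // throws std::out_of_range if invalid
    for (std::size_t d = 0; d < row.size(); ++d)
      row[d] = combine(row[d], updates.at(k).at(d)); // (current, update)
  }
  return operand;
}
```

With an operand of {{10, 10}, {10, 10}}, indices {1, 0}, updates {{1, 2}, {3, 4}} and a combiner that returns current - update, the result is {{7, 6}, {9, 8}}; swapping the argument order would instead give {{-7, -6}, {-9, -8}}.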

Parameters
  • graph – The Poplar graph.

  • operand – Array to be scattered into.

  • indices – Array containing the starting indices of the slices that must be scattered to.

  • updates – Array containing the values that must be used for scattering.

  • indexVectorDim – The dimension in indices that contains the starting indices.

  • updateWindowDims – The set of dimensions in the shape of updates that are window dimensions.

  • insertWindowDims – The set of window dimensions that must be inserted into the shape of updates.

  • scatterDimsToOperandDims – A map of dimensions from the scatter indices to the operand index space. This array is interpreted as mapping i to scatterDimsToOperandDims[i]. It must be one-to-one and total.

  • updateComputation – Computation used to combine the existing values in the input tensor with the updates during the scatter.

  • prog – The program to be extended.

  • debugContext – Optional debug information.

  • optionFlags – Scatter options, as for the first overload.