TopK

#include <popops/TopK.hpp>

Functions for finding the top k elements.

namespace popops

Common functions, such as elementwise operations and reductions.

Functions

std::ostream &operator<<(std::ostream &os, const TopKParams &p)

poplar::Tensor createTopKInput(poplar::Graph &graph, const poplar::Type &type, const std::vector<std::size_t> &shape, const TopKParams &params, const poplar::DebugContext &debugContext = {})

Create and return a new tensor laid out optimally for use as the input to a topK operation with the given parameters.

Parameters
  • graph – The Poplar graph to add the tensor to.

  • type – The Poplar type of elements in the returned tensor.

  • shape – The shape of the returned tensor.

  • params – The parameters of the top k that the returned tensor will be used as input to.

  • debugContext – Optional debug information.

Returns

A newly created tensor with the given shape and a complete tile mapping.

poplar::Tensor topK(poplar::Graph &graph, poplar::program::Sequence &prog, const poplar::Tensor &t, const TopKParams &params, const poplar::DebugContext &debugContext = {})

Return the top k values in the innermost dimension of a tensor.

Parameters
  • graph – The Poplar graph to add the operation to.

  • prog – The Poplar sequence to add the operation to.

  • t – The tensor in which to find the top-k values in the innermost dimension.

  • params – The parameters of the top k.

  • debugContext – Optional debug information.

Returns

A tensor with the top k values found in the innermost dimension of t.
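The semantics of topK can be modelled on the host as a minimal sketch, assuming a row-major tensor stored in a std::vector whose innermost dimension has inner elements. The names topKReference and SortOrderSketch are hypothetical and only mirror the k, largest and sortOrder fields of TopKParams; this is not the Poplar implementation, which builds graph programs rather than sorting on the host.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for popops::SortOrder, used only in this host-side sketch.
enum class SortOrderSketch { NONE, ASCENDING, DESCENDING };

// Hypothetical host-side model of the documented topK semantics (not the Poplar
// implementation). `data` is a row-major tensor whose innermost dimension has
// `inner` elements; each row of `inner` elements is reduced to its top k.
std::vector<float> topKReference(const std::vector<float> &data, std::size_t inner,
                                 unsigned k, bool largest, SortOrderSketch order) {
  assert(inner > 0 && data.size() % inner == 0);
  assert(k <= inner);  // the documented constraint on TopKParams::k
  std::vector<float> out;
  out.reserve(data.size() / inner * k);
  for (std::size_t row = 0; row < data.size(); row += inner) {
    std::vector<float> slice(data.begin() + row, data.begin() + row + inner);
    auto mid = slice.begin() + k;
    // Move the k best elements to the front of the row, best first.
    if (largest)
      std::partial_sort(slice.begin(), mid, slice.end(), std::greater<float>());
    else
      std::partial_sort(slice.begin(), mid, slice.end());
    // Apply the requested output order. The sketch assumes NONE imposes no
    // particular order, so it keeps the best-first order from partial_sort.
    if (order == SortOrderSketch::ASCENDING)
      std::sort(slice.begin(), mid);
    else if (order == SortOrderSketch::DESCENDING)
      std::sort(slice.begin(), mid, std::greater<float>());
    out.insert(out.end(), slice.begin(), mid);
  }
  return out;
}
```

For example, topKReference({3, 1, 4, 1, 5, 9, 2, 6}, 4, 2, true, SortOrderSketch::DESCENDING) treats the input as two rows of four elements and returns {4, 3, 9, 6}: only the innermost dimension shrinks to k.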

std::pair<poplar::Tensor, poplar::Tensor> topKKeyValue(poplar::Graph &graph, poplar::program::Sequence &prog, const poplar::Tensor &keys, const poplar::Tensor &values, const TopKParams &params, const poplar::DebugContext &debugContext = {})

Return the top k values in the innermost dimension of the keys tensor, along with the elements of the values tensor permuted in the same way as the keys.

Parameters
  • graph – The Poplar graph to add the operation to.

  • prog – The Poplar sequence to add the operation to.

  • keys – The tensor in which to find the top-k values in the innermost dimension.

  • values – A tensor with the same shape as keys, whose elements are permuted in the same way as keys.

  • params – The parameters of the top k.

  • debugContext – Optional debug information.

Returns

A pair of tensors. The first contains the top k values found in the innermost dimension of keys. The second contains the elements of values, permuted in the same way as keys.
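The key/value relationship can be illustrated with a minimal host-side sketch, assuming row-major std::vector storage with inner elements per innermost row. The function topKKeyValueReference is hypothetical; it always orders results best first and breaks ties by original position, which are assumptions of this sketch rather than documented behaviour (the real ordering is controlled by TopKParams::sortOrder).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical host-side model of topKKeyValue semantics: select the top k
// entries of `keys` in each innermost row and move the matching entries of
// `values` in exactly the same way. Results are ordered best first.
std::pair<std::vector<float>, std::vector<int>>
topKKeyValueReference(const std::vector<float> &keys, const std::vector<int> &values,
                      std::size_t inner, unsigned k, bool largest) {
  assert(keys.size() == values.size());  // values must have the same shape as keys
  assert(inner > 0 && keys.size() % inner == 0 && k <= inner);
  std::pair<std::vector<float>, std::vector<int>> out;
  for (std::size_t row = 0; row < keys.size(); row += inner) {
    // Positions 0 .. inner-1 within this row, ordered by their key.
    std::vector<std::size_t> idx(inner);
    std::iota(idx.begin(), idx.end(), std::size_t{0});
    std::stable_sort(idx.begin(), idx.end(), [&](std::size_t a, std::size_t b) {
      return largest ? keys[row + a] > keys[row + b] : keys[row + a] < keys[row + b];
    });
    // The same permutation is applied to keys and values.
    for (unsigned i = 0; i < k; ++i) {
      out.first.push_back(keys[row + idx[i]]);
      out.second.push_back(values[row + idx[i]]);
    }
  }
  return out;
}
```

For example, keys {0.5, 2.0, 1.0, 3.0} with values {10, 20, 30, 40} and k = 2 give keys {3.0, 2.0} and values {40, 20}: each value travels with its key.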

std::pair<poplar::Tensor, poplar::Tensor> topKWithPermutation(poplar::Graph &graph, poplar::program::Sequence &prog, const poplar::Tensor &t, const TopKParams &params, const poplar::DebugContext &debugContext = {})

Return the top k values in the innermost dimension of a tensor along with the indices of those values in the input tensor in the innermost dimension.

Parameters
  • graph – The Poplar graph to add the operation to.

  • prog – The Poplar sequence to add the operation to.

  • t – The tensor in which to find the top-k values in the innermost dimension.

  • params – The parameters of the top k.

  • debugContext – Optional debug information.

Returns

A pair of tensors. The first contains the top k values found in the innermost dimension of t. The second contains the indices of those values within the innermost dimension of t.
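The meaning of the returned indices can be pinned down with a minimal host-side sketch, assuming row-major std::vector storage. The function topKWithPermutationReference is hypothetical; the key point it shows is that each index is a position within its innermost row, not a flat offset into t. Best-first ordering and tie-breaking by original position are assumptions of this sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical host-side model of topKWithPermutation semantics. For each
// innermost row of the row-major tensor `t`, returns the top k values and their
// indices within that row. Conceptually this is topKKeyValue with a "values"
// tensor holding 0 .. inner-1 in every row.
std::pair<std::vector<float>, std::vector<unsigned>>
topKWithPermutationReference(const std::vector<float> &t, std::size_t inner,
                             unsigned k, bool largest) {
  assert(inner > 0 && t.size() % inner == 0 && k <= inner);
  std::pair<std::vector<float>, std::vector<unsigned>> out;
  std::vector<unsigned> idx(inner);
  for (std::size_t row = 0; row < t.size(); row += inner) {
    const float *r = t.data() + row;
    std::iota(idx.begin(), idx.end(), 0u);  // candidate indices 0 .. inner-1
    // Only the first k positions need to be ordered.
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](unsigned a, unsigned b) {
                        if (r[a] != r[b]) return largest ? r[a] > r[b] : r[a] < r[b];
                        return a < b;  // deterministic tie-break on position
                      });
    for (unsigned i = 0; i < k; ++i) {
      out.first.push_back(r[idx[i]]);
      out.second.push_back(idx[i]);
    }
  }
  return out;
}
```

For example, with t = {3, 1, 4, 1, 5, 9, 2, 6} viewed as two rows of four and k = 2, the values are {4, 3, 9, 6} and the indices are {2, 0, 1, 3}; the index 1 for the value 9 counts from the start of the second row.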

struct TopKParams
#include <popops/TopK.hpp>

Parameters for topK* APIs.

Public Functions

TopKParams(unsigned k, bool largest, SortOrder sortOrder, bool stableSort = false) noexcept

Public Members

unsigned k

The number of outputs from the top k operation.

This must be less than or equal to the number of elements in the innermost dimension of the tensor used as input to the operation.
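The constraint on k can be checked before building the operation. The following is a minimal sketch with hypothetical helpers kIsValidForShape and topKOutputShape (not part of popops); it also assumes that the result has the input shape with only the innermost dimension replaced by k, which follows from the descriptions above but is an inference rather than a quoted guarantee.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: true if k satisfies the documented constraint for an
// input tensor of the given shape (k <= size of the innermost dimension).
bool kIsValidForShape(const std::vector<std::size_t> &shape, unsigned k) {
  return !shape.empty() && k <= shape.back();
}

// Hypothetical helper: the assumed shape of a top-k result, where only the
// innermost dimension changes (to k). Callers must pass a valid shape/k pair.
std::vector<std::size_t> topKOutputShape(std::vector<std::size_t> shape, unsigned k) {
  assert(kIsValidForShape(shape, k));
  shape.back() = k;
  return shape;
}
```

For example, an input of shape {3, 4, 10} with k = 2 passes the check and maps to a result of shape {3, 4, 2}, whereas k = 11 would be rejected.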

bool largest

If true, return the top k largest elements.

Otherwise, return the top k smallest elements.

SortOrder sortOrder

The required ordering of elements in the resulting tensor.

bool stableSort

When sortOrder != SortOrder::NONE and stableSort is true, the relative order of values that compare equal is guaranteed not to change in the output.
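What the stableSort guarantee means can be illustrated on the host with std::stable_sort. This is an analogy, not the popops implementation: the hypothetical function stableTopK tags each element with its original position, sorts by value in descending order, and keeps the first k, so elements that compare equal stay in their original relative order.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical illustration of a stable top-k over one row: each element is
// paired with its original position, and std::stable_sort keeps equal values
// in the order in which they appeared in the input.
std::vector<std::pair<int, std::size_t>> stableTopK(const std::vector<int> &row,
                                                    std::size_t k) {
  assert(k <= row.size());
  std::vector<std::pair<int, std::size_t>> tagged;
  for (std::size_t i = 0; i < row.size(); ++i) tagged.emplace_back(row[i], i);
  // Descending by value only; ties are not reordered by a stable sort.
  std::stable_sort(tagged.begin(), tagged.end(),
                   [](const auto &a, const auto &b) { return a.first > b.first; });
  tagged.resize(k);  // keep the k largest
  return tagged;
}
```

For example, stableTopK({7, 9, 7, 9, 1}, 3) returns (9, 1), (9, 3), (7, 0): the two 9s keep their input order (position 1 before position 3). Without a stability guarantee, an implementation would be free to emit them in either order.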