Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions spec/draft/API_specification/searching_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ Objects in API
count_nonzero
nonzero
searchsorted
top_k
where
56 changes: 54 additions & 2 deletions src/array_api_stubs/_draft/searching_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
__all__ = ["argmax", "argmin", "count_nonzero", "nonzero", "searchsorted", "where"]
__all__ = [
"argmax",
"argmin",
"count_nonzero",
"nonzero",
"searchsorted",
"top_k",
"where",
]


from ._types import Optional, Tuple, Literal, Union, array
from ._types import Optional, Literal, Tuple, Union, array


def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array:
Expand Down Expand Up @@ -177,6 +185,50 @@ def searchsorted(
"""


def top_k(
Comment thread
rgommers marked this conversation as resolved.
x: array,
k: int,
/,
*,
axis: Optional[int] = None,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd second Olivier's comment (#722 (comment)):

I don't think users have an need for this: then can flatten the input by themselves if need.

As a user, I would definitely prefer to flatten myself, as opposed to getting a 1D array as an output for an nD input.

The "axis=None means ravel" default IMO makes sense for reductions which return a scalar: "give me the sum of all elements of this array which happens to be nD".

Returning a ravelled array for an nD input is not intuitive, unexpected, and I don't think it has much precedent even in NumPy? If anything, the default for np.sort, np.partition and np.argpartition is all axis=-1, so it seems to make sense to be consistent with them --- which is the what the NumPy PR default, too numpy/numpy#31659

And in the Array API spec, sort defaults to axis=-1, too.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that the default should be -1, not None. This is what all array libraries do for their implementation, and sort/partition default is at least as relevant as min/max.

A code search shows that it's also somewhat regularly used in practice for torch.topk, with dim=1 or dim=-1 most often.

mode: Literal["largest", "smallest"] = "largest",
Comment thread
rgommers marked this conversation as resolved.
) -> Tuple[array, array]:
"""
Returns the values and indices of the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension.

Parameters
----------
x: array
input array. Should have a real-valued data type.
k: int
number of elements to find. Must be a positive integer value.
axis: Optional[int]
axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``.
mode: Literal['largest', 'smallest']
search mode. Must be one of the following modes:

- ``'largest'``: return the ``k`` largest elements.
- ``'smallest'``: return the ``k`` smallest elements.

Default: ``'largest'``.

Returns
-------
out: Tuple[array, array]
a namedtuple ``(values, indices)`` whose

- first element must have the field name ``values`` and must be an array containing the ``k`` largest (or smallest) elements of ``x``. The array must have the same data type as ``x``. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``.
- second element must have the field name ``indices`` and must be an array containing indices of ``x`` that result in ``values``. The array must have the same shape as ``values`` and must have the default array index data type. If ``axis`` is ``None``, ``indices`` must be the indices of a flattened ``x``.

Notes
-----

- If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all elements.
- The order of the returned values and indices is left unspecified and thus implementation-dependent. Conforming implementations may return sorted or unsorted values.
- Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`).
"""


def where(
condition: array,
x1: Union[array, int, float, complex, bool],
Expand Down
Loading