Skip to content

Commit

Permalink
ENH: special: use from_dlpack for array conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
mdhaber committed Nov 14, 2024
1 parent ae5f5bb commit baad43a
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 6 deletions.
111 changes: 108 additions & 3 deletions scipy/_lib/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,102 @@ def _asarray(
return array


def _resolve_device(array, xp_out, device=None):
if device is not None:
return device

xp_in = array_namespace(array)

# Here is what the logic is intended to be
# if is_numpy(xp_out) or is_array_api_strict(xp_out) or is_cupy(xp_out):
# # These libraries only support one type of device
# return xp_out.__array_namespace_info__().default_device()
# else:
# return xp_in.device(array)

# but there are many issues with device support right now. For example:
# - NumPy `from_dlpack` only understands device='cpu' (not a device object)
# - CuPy `from_dlpack` doesn't work
# - the format of the return of `device` seems different in different libraries

# So in the meantime:
if is_numpy(xp_out):
return "cpu"
elif is_array_api_strict(xp_out):
return xp_out.__array_namespace_info__().default_device()
else:
return None


def xp_asarray(
array: ArrayLike,
dtype: Any = None,
order: Literal['K', 'A', 'C', 'F'] | None = None,
copy: bool | None = None,
*,
xp: ModuleType | None = None,
device: str | None = None,
check_finite: bool = False,
subok: bool = False,
) -> Array:
"""SciPy-specific replacement for `xp.asarray` with `order`, `check_finite`,
`subok`, and automatic array type conversion and device transfer.
Memory layout parameter `order` is not exposed in the Array API standard.
`order` is only enforced if the input array implementation
is NumPy based, otherwise `order` is just silently ignored.
`check_finite` is also not a keyword in the array API standard; included
here for convenience rather than that having to be a separate function
call inside SciPy functions.
`subok` is included to allow this function to preserve the behaviour of
`np.asanyarray` for NumPy based inputs.
"""
xp_in = array_namespace(array)
if xp is None:
xp = xp_in

if is_numpy(xp_in):
# If object is array-like (but not array), make it an ndarray; otherwise, no-op.
array = np.asanyarray(array)

if is_numpy(xp_in) and not array.__class__.__name__.endswith('.ndarray') and subok:
# If it's a (strict) subclass of ndarray and we must preserve the subclass...
if not is_numpy(xp):
message = f"Array library {xp} cannot respect `subok=True`."
raise TypeError(message)
# This will do the right thing for NumPy 2.0+; for earlier versions of NumPy,
# it will not respect copy=False.
array = np.array(array, order=order, dtype=dtype, copy=copy, subok=True)
elif is_numpy(xp):
# Convert to NumPy array. Raise if copy is necessary and copy=False;
# otherwise, don't copy yet.
array = np_compat.from_dlpack(array, copy=None if copy is True else copy,
device=_resolve_device(array, np_compat, device))
# Now apply all the options
array = np_compat.asarray(array, order=order, dtype=dtype, copy=copy)
elif is_cupy(xp):
# This shouldn't be needed, but CuPy from_dlpack doesn't work. It doesn't
# accept a device argument, and when calling without one, it recommends:
# Use `cupy.array(numpy.from_dlpack(input))` instead.
array = xp.asarray(np_compat.from_dlpack(array), dtype=dtype, copy=copy)
# I don't know if other libraries support `from_dlpack`, either
# We might need other special cases like this.
else:
if order is not None:
pass # should we raise?
array = xp.from_dlpack(array, copy=None if copy is True else copy,
device=_resolve_device(array, xp, device))
array = xp.asarray(array, dtype=dtype, copy=copy)

if check_finite:
_check_finite(array, xp)

return array


def xp_atleast_nd(x: Array, *, ndim: int, xp: ModuleType | None = None) -> Array:
"""Recursively expand the dimension to have at least `ndim`."""
if xp is None:
Expand Down Expand Up @@ -316,7 +412,10 @@ def xp_assert_equal(actual, desired, *, check_namespace=True, check_dtype=True,
err_msg = None if err_msg == '' else err_msg
return xp.testing.assert_close(actual, desired, rtol=0, atol=0, equal_nan=True,
check_dtype=False, msg=err_msg)
# JAX uses `np.testing`

# JAX uses `np.testing` natively; array-api-strict must be converted
actual = xp_asarray(actual, xp=np) if is_array_api_strict(xp) else actual
desired = xp_asarray(desired, xp=np) if is_array_api_strict(xp) else desired
return np.testing.assert_array_equal(actual, desired, err_msg=err_msg)


Expand Down Expand Up @@ -349,7 +448,10 @@ def xp_assert_close(actual, desired, *, rtol=None, atol=0, check_namespace=True,
err_msg = None if err_msg == '' else err_msg
return xp.testing.assert_close(actual, desired, rtol=rtol, atol=atol,
equal_nan=True, check_dtype=False, msg=err_msg)
# JAX uses `np.testing`

# JAX uses `np.testing` natively; array-api-strict must be converted
actual = xp_asarray(actual, xp=np) if is_array_api_strict(xp) else actual
desired = xp_asarray(desired, xp=np) if is_array_api_strict(xp) else desired
return np.testing.assert_allclose(actual, desired, rtol=rtol,
atol=atol, err_msg=err_msg)

Expand All @@ -374,7 +476,10 @@ def xp_assert_less(actual, desired, *, check_namespace=True, check_dtype=True,
actual = actual.cpu()
if desired.device.type != 'cpu':
desired = desired.cpu()
# JAX uses `np.testing`

# JAX uses `np.testing` natively; array-api-strict must be converted
actual = xp_asarray(actual, xp=np) if is_array_api_strict(xp) else actual
desired = xp_asarray(desired, xp=np) if is_array_api_strict(xp) else desired
return np.testing.assert_array_less(actual, desired,
err_msg=err_msg, verbose=verbose)

Expand Down
6 changes: 3 additions & 3 deletions scipy/special/_support_alternative_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
from scipy._lib._array_api import (
array_namespace, scipy_namespace_for, is_numpy
array_namespace, scipy_namespace_for, is_numpy, xp_asarray
)
from . import _ufuncs
# These don't really need to be imported, but otherwise IDEs might not realize
Expand Down Expand Up @@ -41,9 +41,9 @@ def get_array_special_func(f_name, xp, n_array_args):
def __f(*args, _f=_f, _xp=xp, **kwargs):
array_args = args[:n_array_args]
other_args = args[n_array_args:]
array_args = [np.asarray(arg) for arg in array_args]
array_args = [xp_asarray(arg, xp=np) for arg in array_args]
out = _f(*array_args, *other_args, **kwargs)
return _xp.asarray(out)
return xp_asarray(out, xp=xp)

return __f

Expand Down

0 comments on commit baad43a

Please sign in to comment.