Skip to content

[WIP] ENH: dask+cupy, dask+sparse etc. namespaces#270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base:main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions array_api_compat/common/_helpers.py
Original file line numberDiff line numberDiff line change
Expand Up@@ -397,7 +397,9 @@ def is_dask_namespace(xp: Namespace) -> bool:
"""
Returns True if `xp` is a Dask namespace.

This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
This includes ``dask.array`` itself, the version wrapped by array-api-compat,
and the bespoke namespaces generated by
``array_api_compat.dask.array.wrap_namespace``.

See Also
--------
Expand All@@ -411,7 +413,13 @@ def is_dask_namespace(xp: Namespace) -> bool:
is_pydata_sparse_namespace
is_array_api_strict_namespace
"""
return xp.__name__ in {"dask.array", _compat_module_name() + ".dask.array"}
da_compat_name = _compat_module_name() + '.dask.array'
name = xp.__name__
return (
name in {'dask.array', da_compat_name}
or name.startswith(da_compat_name + '.')
and name[len(da_compat_name) + 1:] not in ("linalg", "fft")
)


def is_jax_namespace(xp: Namespace) -> bool:
Expand DownExpand Up@@ -597,9 +605,16 @@ def your_function(x, y):
elif is_dask_array(x):
if _use_compat:
_check_api_version(api_version)
from ..dask import array as dask_namespace

namespaces.add(dask_namespace)
from ..dask.array import wrap_namespace

# The meta-namespace is only used to generate the meta-array, so it
# would be useless to create a namespace such as e.g.
# array_api_compat.dask.array.array_api_compat.cupy.
# It would get worse once you vendor array-api-compat!
# So keep it clean with array_api_compat.dask.array.cupy.
mxp = array_namespace(x._meta, use_compat=False)
xp = wrap_namespace(mxp)
namespaces.add(xp)
else:
import dask.array as da

Expand Down
1 change: 1 addition & 0 deletions array_api_compat/dask/array/__init__.py
Original file line numberDiff line numberDiff line change
Expand Up@@ -4,6 +4,7 @@

# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
from ._meta import wrap_namespace # noqa: F401

__array_api_version__: Final = "2024.12"

Expand Down
13 changes: 11 additions & 2 deletions array_api_compat/dask/array/_aliases.py
Original file line numberDiff line numberDiff line change
Expand Up@@ -152,6 +152,7 @@ def asarray(
dtype: DType | None = None,
device: Device | None = None,
copy: py_bool | None = None,
like: Array | None = None,
**kwargs: object,
) -> Array:
"""
Expand All@@ -168,7 +169,11 @@ def asarray(
if copy is False:
raise ValueError("Unable to avoid copy when changing dtype")
obj = obj.astype(dtype)
return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue]
if copy:
obj = obj.copy()
if like is not None:
obj = da.asarray(obj, like=like)
return obj

if copy is False:
raise NotImplementedError(
Expand All@@ -177,7 +182,11 @@ def asarray(

# copy=None to be uniform across dask < 2024.12 and >= 2024.12
# see https://github.com/dask/dask/pull/11524/
obj = np.array(obj, dtype=dtype, copy=True)
if like is not None:
mxp = array_namespace(like)
obj = mxp.asarray(obj, dtype=dtype, copy=True)
else:
obj = np.array(obj, dtype=dtype, copy=True)
return da.from_array(obj)


Expand Down
49 changes: 49 additions & 0 deletions array_api_compat/dask/array/_meta.py
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
import functools
import sys
import types

from ...common._helpers import is_numpy_namespace
from ...common._typing import Namespace

__all__ = ['wrap_namespace']


def wrap_namespace(xp: Namespace) -> Namespace:
"""Create a bespoke Dask namespace that wraps around another namespace.

Parameters
----------
xp : namespace
Namespace to be wrapped by Dask

Returns
-------
namespace :
A module object that duplicates array_api_compat.dask.array, with the
difference that all creation functions will create an array with the same
meta namespace as the input.
"""
from .. import array as da_compat

if is_numpy_namespace(xp):
return da_compat

mod_name = f'{da_compat.__name__}.{xp.__name__}'
try:
return sys.modules[mod_name]
except KeyError:
pass

mod = types.ModuleType(mod_name)
sys.modules[mod_name] = mod

meta = xp.empty(())
for name, v in da_compat.__dict__.items():
if name.startswith('_'):
continue
if name in {'arange', 'asarray', 'empty', 'eye', 'from_dlpack',
'full', 'linspace', 'ones', 'zeros'}:
v = functools.wraps(v)(functools.partial(v, like=meta))
setattr(mod, name, v)

return mod
Loading
close