- Notifications
You must be signed in to change notification settings - Fork 45
/
Copy path__init__.py
93 lines (72 loc) · 2.65 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
from functools import wraps
from importlib import import_module

from hypothesis import strategies as st
from hypothesis.extra import array_api

from . import _version
__all__ = ["xp", "api_version", "xps"]

# You can comment the following out and instead import the specific array module
# you want to test, e.g. `import array_api_strict as xp`.
#
# Otherwise the array module is taken from ARRAY_API_TESTS_MODULE, which holds
# either a (possibly dotted) module name such as `mxnet.nd`, or a Python snippet
# of the form `exec('...')` that binds the array module to a variable named
# `xp`. An empty value is treated exactly like an unset variable.
env_var = os.environ.get("ARRAY_API_TESTS_MODULE", "")
if env_var:
    if env_var.startswith("exec('") and env_var.endswith("')"):
        # Strip the `exec('` prefix (6 chars) and the `')` suffix (2 chars).
        # NOTE: this deliberately executes arbitrary code taken from the
        # environment; it is only as trustworthy as whoever sets the variable.
        script = env_var[6:][:-2]
        namespace = {}
        exec(script, namespace)
        if "xp" not in namespace:
            # Without this check the failure is a bare `KeyError: 'xp'`.
            raise RuntimeError(
                "The exec('...') snippet in ARRAY_API_TESTS_MODULE must assign "
                "the array module to a variable named `xp`."
            )
        xp = namespace["xp"]
        xp_name = xp.__name__
    else:
        xp_name = env_var
        _module, _sub = xp_name, None
        if "." in xp_name:
            # Import only the top-level package here; the remainder is resolved
            # below, first as an attribute and then as a submodule.
            _module, _sub = xp_name.split(".", 1)
        xp = import_module(_module)
        if _sub:
            try:
                xp = getattr(xp, _sub)
            except AttributeError:
                # _sub may be a submodule that needs to be imported. We can't
                # do this in every case because some array modules are not
                # submodules that can be imported (like mxnet.nd).
                xp = import_module(xp_name)
else:
    raise RuntimeError(
        "No array module specified - either edit __init__.py or set the "
        "ARRAY_API_TESTS_MODULE environment variable."
    )
# Some versions of NumPy and CuPy do not expose xp.bool. When it is missing but
# xp.bool_ exists, alias xp.bool to xp.bool_ so the rest of the suite can rely
# on the standard name.
try:
    xp.bool
except AttributeError:
    if not hasattr(xp, "bool_"):
        # Neither spelling is available: propagate the original AttributeError.
        raise
    xp.bool = xp.bool_
# We monkey patch floats() to always disable subnormals as they are out-of-scope.
# The wrapper forwards every argument to the original st.floats unchanged,
# except that allow_subnormal is forced to False, overriding any value the
# caller passed. (Documentation lives in comments: @wraps copies the original
# function's __doc__ onto the wrapper.)
_floats = st.floats


@wraps(_floats)
def floats(*args, **kwargs):
    return _floats(*args, **{**kwargs, "allow_subnormal": False})


st.floats = floats
# We do the same with xps.from_dtype(). This is not strictly necessary, as the
# underlying floats() will never generate subnormals. We only do this because
# internal logic in xps.from_dtype() assumes xp.finfo() has its attributes as
# scalar floats, which is expected behaviour but disrupts many unrelated tests.
#
# Only the lookup of Hypothesis's private `_from_dtype` is guarded, so that an
# AttributeError raised anywhere else cannot silently skip the patch.
try:
    __from_dtype = array_api._from_dtype
except AttributeError:
    # Ignore monkey patching if Hypothesis changes the private API
    pass
else:

    @wraps(__from_dtype)
    def _from_dtype(*a, **kw):
        # Force allow_subnormal off, overriding any caller-supplied value.
        kw["allow_subnormal"] = False
        return __from_dtype(*a, **kw)

    array_api._from_dtype = _from_dtype
# Array API standard version that the Hypothesis strategies target. The value
# is chosen in this order:
#   1. ARRAY_API_TESTS_VERSION, when it is set to a non-empty value;
#   2. the array module's own __array_api_version__;
#   3. "2024.12".
# An empty ARRAY_API_TESTS_VERSION is ignored rather than forwarded as
# api_version="", which make_strategies_namespace would reject.
api_version = os.getenv("ARRAY_API_TESTS_VERSION") or getattr(
    xp, "__array_api_version__", "2024.12"
)
xps = array_api.make_strategies_namespace(xp, api_version=api_version)

__version__ = _version.get_versions()["version"]