- Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathstubs.py
98 lines (83 loc) · 3.45 KB
/
stubs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
importinspect
importsys
fromimportlibimportimport_module
fromimportlib.utilimportfind_spec
frompathlibimportPath
fromtypesimportFunctionType, ModuleType
fromtypingimportDict, List
from . importapi_version
# Public names exported by this module. The sanity check at the bottom of the
# file asserts that each of them is non-empty.
__all__ = [
    "name_to_func",
    "array_methods",
    "array_attributes",
    "category_to_funcs",
    "EXTENSIONS",
    "extension_to_funcs",
]
# Name of the stubs sub-package for the selected spec version: the version
# string with dots turned into underscores, prefixed with an underscore.
spec_module = f"_{api_version.replace('.', '_')}"

# The spec sources live in the ``array-api`` directory that sits beside the
# directory containing this file.
spec_dir = Path(__file__).parent.parent.joinpath(
    "array-api", "spec", api_version, "API_specification"
)
assert spec_dir.exists(), f"{spec_dir} not found - try `git submodule update --init`"
sigs_dir = Path(__file__).parent.parent.joinpath(
    "array-api", "src", "array_api_stubs", spec_module
)
assert sigs_dir.exists()

# Put the directory that contains ``array_api_stubs`` on sys.path so that the
# stubs package for this version becomes importable, then verify that it is.
sigs_abs_path: str = str(sigs_dir.parent.parent.resolve())
sys.path.append(sigs_abs_path)
assert find_spec(f"array_api_stubs.{spec_module}") is not None
# Import every stub module found directly in ``sigs_dir`` and index it by its
# file name without the ``.py`` extension.
name_to_mod: Dict[str, ModuleType] = {}
for path in sigs_dir.glob("*.py"):
    # ``Path.stem`` drops only the final suffix, whereas
    # ``str.replace(".py", "")`` would also strip a ".py" that occurs earlier
    # in the file name and so produce a wrong module name.
    name = path.stem
    name_to_mod[name] = import_module(f"array_api_stubs.{spec_module}.{name}")
# The stub class describing the array object.
array = name_to_mod["array_object"].array

# Every function defined on the array stub, except the constructor.
array_methods = [
    method
    for method_name, method in inspect.getmembers(array, inspect.isfunction)
    if method_name != "__init__"  # probably exists for Sphinx
]

# Names of all members of the array stub that are properties.
array_attributes = [
    attr_name
    for attr_name, _ in inspect.getmembers(
        array, lambda member: isinstance(member, property)
    )
]
# Group the stub functions by category: a stub module whose name ends in
# ``_functions`` contributes the objects listed in its ``__all__`` under the
# module name with ``_functions`` removed.
category_to_funcs: Dict[str, List[FunctionType]] = {}
for name, mod in name_to_mod.items():
    if not name.endswith("_functions"):
        continue
    category = name.replace("_functions", "")
    objects = [getattr(mod, func_name) for func_name in mod.__all__]
    # sanity check: everything exported must be a function
    assert not any(not isinstance(obj, FunctionType) for obj in objects)
    category_to_funcs[category] = objects
# Flat list holding the array methods first, followed by the functions of
# every category.
all_funcs = list(array_methods)
for funcs in category_to_funcs.values():
    all_funcs.extend(funcs)

# Look-up table from a function's ``__name__`` to the function itself; when
# several functions share a name, the one latest in ``all_funcs`` is kept.
name_to_func: Dict[str, FunctionType] = dict((f.__name__, f) for f in all_funcs)
# Stays empty for spec versions that compare lower than "2023.12".
info_funcs = []
if api_version >= "2023.12":
    # The stubs keep the info functions in info.py, although "info" is not a
    # name defined by the standard.
    info_mod = name_to_mod["info"]

    # __array_namespace_info__ is listed in info.__all__ but belongs to the
    # top-level namespace rather than the info namespace, so it is left out of
    # info_funcs.
    info_funcs = [
        getattr(info_mod, func_name)
        for func_name in info_mod.__all__
        if func_name != "__array_namespace_info__"
    ]
    assert not any(not isinstance(f, FunctionType) for f in info_funcs)
    name_to_func.update((f.__name__, f) for f in info_funcs)

    # __array_namespace_info__ itself is registered as a regular function and
    # is the only entry of the "info" category.
    all_funcs.append(info_mod.__array_namespace_info__)
    name_to_func["__array_namespace_info__"] = info_mod.__array_namespace_info__
    category_to_funcs["info"] = [info_mod.__array_namespace_info__]
# Extensions handled for the selected spec version: "linalg" always, plus
# "fft" when the version compares greater than or equal to "2022.12".
EXTENSIONS: List[str] = (
    ["linalg", "fft"] if api_version >= "2022.12" else ["linalg"]
)
# Map each extension name to the functions listed in its stub module's
# ``__all__``.
extension_to_funcs: Dict[str, List[FunctionType]] = {}
for ext in EXTENSIONS:
    mod = name_to_mod[ext]
    objects = [getattr(mod, name) for name in mod.__all__]
    assert all(isinstance(o, FunctionType) for o in objects)  # sanity check
    funcs = []
    for func in objects:
        # A stub whose docstring mentions "Alias" is replaced by the function
        # of the same name already registered in name_to_func (presumably the
        # main-namespace function it aliases). ``__doc__`` is None for a
        # function without a docstring, so fall back to "" instead of letting
        # the ``in`` test raise TypeError.
        if "Alias" in (func.__doc__ or ""):
            funcs.append(name_to_func[func.__name__])
        else:
            funcs.append(func)
    extension_to_funcs[ext] = funcs
# Register extension functions in name_to_func, never overwriting a name that
# is already present.
for funcs in extension_to_funcs.values():
    for func in funcs:
        name_to_func.setdefault(func.__name__, func)
# sanity check: every name listed in __all__ must refer to a non-empty
# collection (at module level globals() is the same mapping as locals())
for attr in __all__:
    assert len(globals()[attr]) > 0, f"{attr} is empty"