- Notifications
You must be signed in to change notification settings - Fork 28.8k
/
Copy pathcheck_model_tester.py
63 lines (55 loc) · 2.54 KB
/
check_model_tester.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os

from get_test_info import get_tester_classes
# Upper bounds for integer attributes of the configs built by model testers. A
# tester whose `get_config()` returns a larger value for one of these keys is
# reported as a failure, because such a config is too large for testing.
MAX_CONFIG_VALUES = {
    "vocab_size": 100,
    "max_position_embeddings": 128,
    "hidden_size": 40,
    "d_model": 40,
    "num_layers": 5,
    "num_hidden_layers": 5,
    "num_encoder_layers": 5,
    "num_decoder_layers": 5,
}

# File-name prefixes of test modules this check skips.
# TODO: deal with TF/Flax too
SKIPPED_TEST_FILE_PREFIXES = ("test_modeling_tf_", "test_modeling_flax_")


def is_checked_test_file(path):
    """Return whether the test module at `path` should be inspected.

    Only the file name is compared against `SKIPPED_TEST_FILE_PREFIXES`: the paths
    produced by `glob` start with the `tests/models/<model>/` directories, so matching
    the prefixes against the whole path would never exclude anything.

    Args:
        path (`str`): Path to a `test_modeling_*.py` file.

    Returns:
        `bool`: `False` for TF/Flax test modules, `True` otherwise.
    """
    return not os.path.basename(path).startswith(SKIPPED_TEST_FILE_PREFIXES)


def find_oversized_config_values(config_dict):
    """List the integer config entries that exceed their limit in `MAX_CONFIG_VALUES`.

    Args:
        config_dict (`dict`): A config serialized with `to_dict()`.

    Returns:
        `list[tuple[str, int, int]]`: `(key, value, limit)` for every entry whose key has
        a limit and whose integer value is strictly greater than that limit, in the
        iteration order of `config_dict`. Keys without a limit and non-integer values
        are ignored.
    """
    oversized = []
    for key, value in config_dict.items():
        limit = MAX_CONFIG_VALUES.get(key)
        if limit is not None and isinstance(value, int) and value > limit:
            oversized.append((key, value, limit))
    return oversized


if __name__ == "__main__":
    failures = []

    # Without `recursive=True`, `glob` treats `**` like `*`, so this matches the test
    # modules exactly one directory below `tests/models`.
    pattern = os.path.join("tests", "models", "**", "test_modeling_*.py")
    test_files = [path for path in glob.glob(pattern) if is_checked_test_file(path)]

    for test_file in test_files:
        for tester_class in get_tester_classes(test_file):
            # A few tester classes don't have a `parent` parameter in `__init__`, so
            # constructing them this way raises; such testers are skipped.
            # TODO: handle this better
            try:
                tester = tester_class(parent=None)
            except Exception:
                continue
            if not hasattr(tester, "get_config"):
                continue

            config = tester.get_config()
            for k, v, target in find_oversized_config_values(config.to_dict()):
                failures.append(
                    f"{tester_class.__name__} will produce a `config` of type `{config.__class__.__name__}`"
                    f' with config["{k}"] = {v} which is too large for testing! Set its value to be smaller'
                    f" than {target}."
                )

    if failures:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))