- Notifications
You must be signed in to change notification settings - Fork 362
/
Copy pathrun_qdq_debug.py
78 lines (60 loc) · 3.08 KB
/
run_qdq_debug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import argparse
from typing import Optional, Sequence

import onnx
from onnxruntime.quantization.qdq_loss_debug import (
    collect_activations,
    compute_activation_error,
    compute_weight_error,
    create_activation_matching,
    create_weight_matching,
    modify_model_output_intermediate_tensors,
)

import resnet50_data_reader
def get_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the QDQ debugging script.

    Args:
        argv: Argument strings to parse. When ``None`` (the default) argparse
            falls back to ``sys.argv[1:]``, exactly as the script did before;
            passing an explicit list lets callers and tests drive the parser
            without touching ``sys.argv``.

    Returns:
        Namespace with ``float_model`` and ``qdq_model`` (both required) and
        ``calibrate_dataset`` (defaults to ``"./test_images"``).

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            an argument is invalid.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--float_model", required=True, help="Path to original floating point model"
    )
    parser.add_argument("--qdq_model", required=True, help="Path to qdq model")
    parser.add_argument(
        "--calibrate_dataset", default="./test_images", help="calibration data set"
    )
    return parser.parse_args(argv)
def_generate_aug_model_path(model_path: str) ->str:
aug_model_path= (
model_path[: -len(".onnx")] ifmodel_path.endswith(".onnx") elsemodel_path
)
returnaug_model_path+".save_tensors.onnx"
def main():
    """Compare a float ONNX model with its QDQ-quantized counterpart.

    Steps (each announced on stdout):
      1. Match weights between the two models and print the per-weight
         cross-model error.
      2. Write augmented copies of both models (paths derived by
         ``_generate_aug_model_path``) that expose intermediate tensors.
      3. Run both augmented models over the calibration dataset and collect
         their activations.
      4. Match QDQ activations against float activations and print, per
         tensor, the cross-model error (when available) and the QDQ error.
    """
    # Process input parameters and setup model input data reader
    args = get_args()
    float_model_path = args.float_model
    qdq_model_path = args.qdq_model
    calibration_dataset_path = args.calibrate_dataset

    separator = "------------------------------------------------\n"

    print(separator)
    print("Comparing weights of float model vs qdq model.....")
    matched_weights = create_weight_matching(float_model_path, qdq_model_path)
    weights_error = compute_weight_error(matched_weights)
    for weight_name, err in weights_error.items():
        print(f"Cross model error of '{weight_name}': {err}\n")

    print(separator)
    print("Augmenting models to save intermediate activations......")
    aug_float_model_path = _generate_aug_model_path(float_model_path)
    modify_model_output_intermediate_tensors(float_model_path, aug_float_model_path)
    aug_qdq_model_path = _generate_aug_model_path(qdq_model_path)
    modify_model_output_intermediate_tensors(qdq_model_path, aug_qdq_model_path)

    print(separator)
    print("Running the augmented floating point model to collect activations......")
    # NOTE(review): the reader is built from the float model path; presumably
    # it reads the model's input metadata from it — confirm in resnet50_data_reader.
    input_data_reader = resnet50_data_reader.ResNet50DataReader(
        calibration_dataset_path, float_model_path
    )
    float_activations = collect_activations(aug_float_model_path, input_data_reader)

    print(separator)
    print("Running the augmented qdq model to collect activations......")
    # Presumably resets the reader so the QDQ model is fed the same inputs.
    input_data_reader.rewind()
    qdq_activations = collect_activations(aug_qdq_model_path, input_data_reader)

    print(separator)
    print("Comparing activations of float model vs qdq model......")
    act_matching = create_activation_matching(qdq_activations, float_activations)
    act_error = compute_activation_error(act_matching)
    for act_name, err in act_error.items():
        # NOTE(review): compute_activation_error appears to report
        # "xmodel_err" only for activations with a float-model counterpart;
        # guard the lookup so one unmatched tensor does not abort the whole
        # report with a KeyError.
        if "xmodel_err" in err:
            print(f"Cross model error of '{act_name}': {err['xmodel_err']}\n")
        else:
            print(f"Cross model error of '{act_name}': not available\n")
        print(f"QDQ error of '{act_name}': {err['qdq_err']}\n")
if __name__ == "__main__":
    # Run the comparison pipeline only when executed as a script, not on import.
    main()