import tempfile
import unittest

import numpy as np
import pytest
import torch

from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor
from diffusers.utils.testing_utils import torch_device


class AttnAddedKVProcessorTests(unittest.TestCase):
    def get_constructor_arguments(self, only_cross_attention: bool = False):
        query_dim = 10

        if only_cross_attention:
            cross_attention_dim = 12
        else:
            # when only cross attention is not set, the cross attention dim must be the same as the query dim
            cross_attention_dim = query_dim

        return {
            "query_dim": query_dim,
            "cross_attention_dim": cross_attention_dim,
            "heads": 2,
            "dim_head": 4,
            "added_kv_proj_dim": 6,
            "norm_num_groups": 1,
            "only_cross_attention": only_cross_attention,
            "processor": AttnAddedKVProcessor(),
        }

    def get_forward_arguments(self, query_dim, added_kv_proj_dim):
        batch_size = 2

        hidden_states = torch.rand(batch_size, query_dim, 3, 2)
        encoder_hidden_states = torch.rand(batch_size, 4, added_kv_proj_dim)
        attention_mask = None

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "attention_mask": attention_mask,
        }

    def test_only_cross_attention(self):
        # self and cross attention
        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=False)
        attn = Attention(**constructor_args)

        self.assertTrue(attn.to_k is not None)
        self.assertTrue(attn.to_v is not None)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        self_and_cross_attn_out = attn(**forward_args)

        # only self attention
        torch.manual_seed(0)

        constructor_args = self.get_constructor_arguments(only_cross_attention=True)
        attn = Attention(**constructor_args)

        self.assertTrue(attn.to_k is None)
        self.assertTrue(attn.to_v is None)

        forward_args = self.get_forward_arguments(
            query_dim=constructor_args["query_dim"], added_kv_proj_dim=constructor_args["added_kv_proj_dim"]
        )

        only_cross_attn_out = attn(**forward_args)

        self.assertTrue((only_cross_attn_out != self_and_cross_attn_out).all())
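

# A minimal standalone sketch (not part of the test suite) of the behavior exercised
# above: with `added_kv_proj_dim` set, `Attention` routes `encoder_hidden_states`
# through its `add_k_proj`/`add_v_proj` layers and `AttnAddedKVProcessor` returns a
# tensor with the same (batch, channel, height, width) shape as the input. The helper
# name `_demo_added_kv_attention` is illustrative and reuses the shapes from
# `get_forward_arguments`.
def _demo_added_kv_attention():
    attn = Attention(
        query_dim=10,
        cross_attention_dim=10,
        heads=2,
        dim_head=4,
        added_kv_proj_dim=6,
        norm_num_groups=1,
        processor=AttnAddedKVProcessor(),
    )
    hidden_states = torch.rand(2, 10, 3, 2)  # (batch, channels, height, width)
    encoder_hidden_states = torch.rand(2, 4, 6)  # (batch, sequence, added_kv_proj_dim)
    out = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)
    assert out.shape == hidden_states.shape  # processor reshapes back to the residual shape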


class DeprecatedAttentionBlockTests(unittest.TestCase):
    @pytest.fixture(scope="session")
    def is_dist_enabled(pytestconfig):
        return pytestconfig.getoption("dist") == "loadfile"

    @pytest.mark.xfail(
        condition=torch.device(torch_device).type == "cuda" and is_dist_enabled,
        reason="Test currently fails on our GPU CI because of `loadfile`. Note that it only fails when the tests are distributed from `pytest ... tests/models`. If the tests are run individually, even with `loadfile` it won't fail.",
        strict=True,
    )
    def test_conversion_when_using_device_map(self):
        pipe = DiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
        )

        pre_conversion = pipe(
            "foo",
            num_inference_steps=2,
            generator=torch.Generator("cpu").manual_seed(0),
            output_type="np",
        ).images

        # the initial conversion succeeds
        pipe = DiffusionPipeline.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-torch", device_map="balanced", safety_checker=None
        )

        conversion = pipe(
            "foo",
            num_inference_steps=2,
            generator=torch.Generator("cpu").manual_seed(0),
            output_type="np",
        ).images

        with tempfile.TemporaryDirectory() as tmpdir:
            # save the converted model
            pipe.save_pretrained(tmpdir)

            # can also load the converted weights
            pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="balanced", safety_checker=None)

        after_conversion = pipe(
            "foo",
            num_inference_steps=2,
            generator=torch.Generator("cpu").manual_seed(0),
            output_type="np",
        ).images

        self.assertTrue(np.allclose(pre_conversion, conversion, atol=1e-3))
        self.assertTrue(np.allclose(conversion, after_conversion, atol=1e-3))
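

# A minimal sketch (illustrative, not part of the suite) of the reproducibility pattern
# the test above relies on: pinning a CPU `torch.Generator` seed makes two pipeline runs
# comparable with `np.allclose`, so the `device_map="balanced"` round-trip can be treated
# as a pure save/load check. `_demo_seeded_reproducibility` is a hypothetical helper,
# not a diffusers API.
def _demo_seeded_reproducibility(pipe):
    kwargs = {"num_inference_steps": 2, "output_type": "np"}
    first = pipe("foo", generator=torch.Generator("cpu").manual_seed(0), **kwargs).images
    second = pipe("foo", generator=torch.Generator("cpu").manual_seed(0), **kwargs).images
    assert np.allclose(first, second, atol=1e-3)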