- Notifications
You must be signed in to change notification settings - Fork 5.9k
/
Copy path test_pipeline_flux_control_inpaint.py
175 lines (150 loc) · 6.27 KB
/
test_pipeline_flux_control_inpaint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
importunittest
importnumpyasnp
importtorch
fromPILimportImage
fromtransformersimportAutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
fromdiffusersimport (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
FluxControlInpaintPipeline,
FluxTransformer2DModel,
)
fromdiffusers.utils.testing_utilsimport (
torch_device,
)
from ..test_pipelines_commonimport (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
    """Fast tests for ``FluxControlInpaintPipeline`` built from tiny, seeded dummy components.

    ``pipeline_class`` / ``params`` / ``batch_params`` are presumably consumed by
    ``PipelineTesterMixin`` to drive its generic pipeline checks (verify against the mixin).
    This class supplies the dummy components and inputs. It also adds a QKV-fusion
    consistency test and an output-shape test.
    """

    pipeline_class = FluxControlInpaintPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
    batch_params = frozenset(["prompt"])

    # Flux has no xformers attention processor, so the xformers test is switched off.
    test_xformers_attention = False

    def get_dummy_components(self):
        """Return a dict of tiny pipeline components keyed by pipeline argument name.

        ``torch.manual_seed(0)`` is called before the transformer, the CLIP text encoder,
        the T5 text encoder load and the VAE. That keeps each model's initial weights
        reproducible across calls.
        """
        transformer_kwargs = dict(
            patch_size=1,
            in_channels=8,
            out_channels=4,
            num_layers=1,
            num_single_layers=1,
            attention_head_dim=16,
            num_attention_heads=2,
            joint_attention_dim=32,
            pooled_projection_dim=32,
            axes_dims_rope=[4, 4, 8],
        )
        torch.manual_seed(0)
        transformer = FluxTransformer2DModel(**transformer_kwargs)

        clip_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=32,
        )
        torch.manual_seed(0)
        text_encoder = CLIPTextModel(clip_config)

        torch.manual_seed(0)
        text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        # Single-level VAE (one entry in block_out_channels) with a 1-channel latent space.
        torch.manual_seed(0)
        vae = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )

        return dict(
            scheduler=FlowMatchEulerDiscreteScheduler(),
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            vae=vae,
        )

    def get_dummy_inputs(self, device, seed=0):
        """Return keyword arguments for one small 8x8 inpainting call.

        ``device`` only affects how the generator is created. On MPS the global generator
        returned by ``torch.manual_seed`` is used. On every other device a dedicated CPU
        ``torch.Generator`` is seeded with ``seed``.
        """
        on_mps = str(device).startswith("mps")
        generator = torch.manual_seed(seed) if on_mps else torch.Generator(device="cpu").manual_seed(seed)

        # All-black source and control images; an all-white mask.
        source = Image.new("RGB", (8, 8), 0)
        control = Image.new("RGB", (8, 8), 0)
        mask = Image.new("RGB", (8, 8), 255)

        return dict(
            prompt="A painting of a squirrel eating a burger",
            control_image=control,
            generator=generator,
            image=source,
            mask_image=mask,
            strength=0.8,
            num_inference_steps=2,
            guidance_scale=30.0,
            height=8,
            width=8,
            max_sequence_length=48,
            output_type="np",
        )

    def test_fused_qkv_projections(self):
        """Fusing and then unfusing the transformer's QKV projections must not change outputs."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
        pipe.set_progress_bar_config(disable=None)

        def run_corner():
            # Each run gets freshly built inputs, so every run starts from the same seeded generator.
            frames = pipe(**self.get_dummy_inputs(device)).images
            return frames[0, -3:, -3:, -1]

        baseline = run_corner()

        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        transformer = pipe.transformer
        transformer.fuse_qkv_projections()
        assert check_qkv_fusion_processors_exist(transformer), (
            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        )
        assert check_qkv_fusion_matches_attn_procs_length(
            transformer, transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

        fused = run_corner()

        transformer.unfuse_qkv_projections()
        unfused = run_corner()

        assert np.allclose(baseline, fused, atol=1e-3, rtol=1e-3), (
            "Fusion of QKV projections shouldn't affect the outputs."
        )
        assert np.allclose(fused, unfused, atol=1e-3, rtol=1e-3), (
            "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
        )
        # Looser tolerance for the baseline-vs-unfused comparison.
        assert np.allclose(baseline, unfused, atol=1e-2, rtol=1e-2), (
            "Original outputs should match when fused QKV projections are disabled."
        )

    def test_flux_image_output_shape(self):
        """Output height/width are the requested size rounded down to a multiple of 2 * vae_scale_factor."""
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)
        multiple = pipe.vae_scale_factor * 2

        for height, width in [(32, 32), (72, 57)]:
            inputs.update(height=height, width=width)
            frame = pipe(**inputs).images[0]
            output_height, output_width, _channels = frame.shape
            assert (output_height, output_width) == (height - height % multiple, width - width % multiple)