- Notifications
You must be signed in to change notification settings - Fork 5.9k
/
Copy pathconvert_flux_xlabs_ipadapter_to_diffusers.py
97 lines (72 loc) · 3.59 KB
/
convert_flux_xlabs_ipadapter_to_diffusers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse
from contextlib import nullcontext

import safetensors.torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download

from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available


# The CLIP image encoder is only exported when transformers is installed.
if is_transformers_available():
    from transformers import CLIPVisionModelWithProjection

    vision = True
else:
    vision = False
"""
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
    --original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
    --filename "flux-ip-adapter.safetensors" \
    --output_path "flux-ip-adapter-hf/"
"""
# Instantiate models without allocating real tensors when accelerate is
# available; otherwise fall back to a no-op context manager.
# FIX: `is_accelerate_available` is a function — it must be *called*. The bare
# function object is always truthy, so the nullcontext fallback was unreachable.
CTX = init_empty_weights if is_accelerate_available() else nullcontext

parser = argparse.ArgumentParser()
# Either a Hub repo id (+ filename) or a local checkpoint path must be given.
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)

args = parser.parse_args()
def load_original_checkpoint(args):
    """Locate the XLabs checkpoint (Hub repo or local file) and load it.

    Args:
        args: Parsed CLI namespace; reads `original_state_dict_repo_id`,
            `filename`, and `checkpoint_path`.

    Returns:
        The raw state dict loaded with safetensors.

    Raises:
        ValueError: If neither a repo id nor a local checkpoint path is given.
    """
    if args.original_state_dict_repo_id is not None:
        # Download from the Hub (cached locally by huggingface_hub).
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
    elif args.checkpoint_path is not None:
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    return safetensors.torch.load_file(ckpt_path)
def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
    """Rename XLabs Flux IP-Adapter state-dict keys to the diffusers layout.

    Args:
        original_state_dict: Mutable mapping of tensors in the XLabs naming
            scheme. Consumed destructively — converted entries are popped.
        num_layers: Number of double-stream transformer blocks to convert
            (19 for Flux.1).

    Returns:
        A new dict keyed for diffusers' IP-Adapter loader.
    """
    converted_state_dict = {}

    # image_proj
    ## norm
    converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
    converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")

    ## proj
    # FIX: previously popped the already-removed `norm.*` keys here, which both
    # raised a KeyError and mapped the wrong tensors into `image_proj.proj.*`.
    converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.proj.weight")
    converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.proj.bias")

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"ip_adapter.{i}."
        # to_k_ip
        converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
        )
        converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
        )
        # to_v_ip
        converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
        )
        # FIX: was assigned to `to_k_ip.weight`, overwriting the K projection and
        # leaving `to_v_ip.weight` missing from the converted dict.
        converted_state_dict[f"{block_prefix}to_v_ip.weight"] = original_state_dict.pop(
            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
        )

    return converted_state_dict
def main(args):
    """Convert the XLabs IP-Adapter checkpoint and write diffusers-format files."""
    original_ckpt = load_original_checkpoint(args)

    # Flux.1 has 19 double-stream transformer blocks.
    num_layers = 19
    converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)

    print("Saving Flux IP-Adapter in Diffusers format.")
    safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")

    if vision:
        # Export the CLIP image encoder alongside the adapter weights.
        model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
        model.save_pretrained(f"{args.output_path}/image_encoder")


if __name__ == "__main__":
    main(args)