- Notifications
You must be signed in to change notification settings - Fork 5.9k
/
Copy pathbenchmark_text_to_image.py
40 lines (32 loc) · 1.07 KB
/
benchmark_text_to_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Script that parses benchmark options from the command line, selects a
# benchmark class based on the chosen checkpoint name, and runs it.
import argparse
import sys


# "." is appended to the module search path before the import below, which is
# why that import is not at the top of the file (hence the E402 suppression).
# NOTE(review): presumably `base_classes` lives in the directory the script is
# launched from -- confirm.
sys.path.append(".")
from base_classes import TextToImageBenchmark, TurboTextToImageBenchmark  # noqa: E402


# Values accepted by the --ckpt option (passed to argparse as `choices`).
ALL_T2I_CKPTS = [
    "Lykon/DreamShaper",
    "segmind/SSD-1B",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "kandinsky-community/kandinsky-2-2-decoder",
    "warp-ai/wuerstchen",
    "stabilityai/sdxl-turbo",
]


if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--ckpt", type=str, default="Lykon/DreamShaper", choices=ALL_T2I_CKPTS)
    cli.add_argument("--batch_size", type=int, default=1)
    cli.add_argument("--num_inference_steps", type=int, default=50)
    # Boolean switches: absent means False, present means True.
    for switch in ("--model_cpu_offload", "--run_compile"):
        cli.add_argument(switch, action="store_true")
    options = cli.parse_args()

    # Any checkpoint whose name contains "turbo" uses the Turbo benchmark
    # class; every other checkpoint uses the regular one.
    runner_cls = TurboTextToImageBenchmark if "turbo" in options.ckpt else TextToImageBenchmark

    # The parsed options are handed to both the constructor and benchmark().
    runner = runner_cls(options)
    runner.benchmark(options)