- Notifications
You must be signed in to change notification settings - Fork 5.9k
/
Copy pathserver.py
133 lines (105 loc) · 4.06 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import asyncio
import logging
import os
import random
import tempfile
import traceback
import uuid

import aiohttp
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Pipeline
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class TextToImageInput(BaseModel):
    """Request body for the image-generation endpoint.

    ``model`` and ``prompt`` are required strings; ``size`` and ``n`` are
    optional and default to ``None``.
    """

    model: str
    prompt: str
    size: str | None = None
    n: int | None = None
class HttpClient:
    """Holder for one shared ``aiohttp.ClientSession``.

    ``start()`` creates the session, ``stop()`` closes it, and calling the
    instance returns the live session.
    """

    # ``None`` until ``start()`` runs, and again after ``stop()``.
    session: aiohttp.ClientSession | None = None

    def start(self):
        """Create the shared client session."""
        self.session = aiohttp.ClientSession()

    async def stop(self):
        """Close the shared session, if one is open, and forget it.

        Safe to call when the client was never started or is already stopped.
        """
        # Detach first so the attribute is cleared even if close() raises.
        session, self.session = self.session, None
        if session is not None:
            await session.close()

    def __call__(self) -> aiohttp.ClientSession:
        """Return the live session.

        Raises:
            AssertionError: if no session is open (``start()`` not called, or
                ``stop()`` already ran).
        """
        # Raised explicitly rather than via ``assert`` so the guard survives
        # ``python -O``; the exception type is kept for existing callers.
        if self.session is None:
            raise AssertionError("HttpClient session has not been started")
        return self.session
class TextToImagePipeline:
    """Holder for the Stable Diffusion 3 pipeline loaded by ``start()``."""

    # Both stay ``None`` until ``start()`` has selected a device.
    pipeline: StableDiffusion3Pipeline | None = None
    device: str | None = None

    def start(self):
        """Load the pipeline onto CUDA if available, otherwise onto MPS.

        The checkpoint is read from the ``MODEL_PATH`` environment variable;
        when unset, a backend-specific default is used (3.5-large on CUDA,
        3.5-medium on MPS).

        Raises:
            RuntimeError: if neither a CUDA nor an MPS device is available.
        """
        if torch.cuda.is_available():
            device = "cuda"
            default_model = "stabilityai/stable-diffusion-3.5-large"
            logger.info("Loading CUDA")
        elif torch.backends.mps.is_available():
            device = "mps"
            default_model = "stabilityai/stable-diffusion-3.5-medium"
            logger.info("Loading MPS for Mac M Series")
        else:
            raise RuntimeError("No CUDA or MPS device available")

        # Only the device and default checkpoint differ per backend; the
        # loading call itself is shared.
        model_path = os.getenv("MODEL_PATH", default_model)
        self.device = device
        self.pipeline = StableDiffusion3Pipeline.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
        ).to(device=self.device)
app = FastAPI()

# Base URL used when building links to saved images; override via SERVICE_URL.
service_url = os.getenv("SERVICE_URL", "http://localhost:8000")

# Generated images are written under the system temp directory and served
# back from the /images mount.
image_dir = os.path.join(tempfile.gettempdir(), "images")
# exist_ok avoids the check-then-create race when several processes start
# at the same time.
os.makedirs(image_dir, exist_ok=True)
app.mount("/images", StaticFiles(directory=image_dir), name="images")

http_client = HttpClient()
shared_pipeline = TextToImagePipeline()

# Configure CORS settings
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm this is intended before exposing the server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods, e.g., GET, POST, OPTIONS, etc.
    allow_headers=["*"],  # Allows all headers
)
@app.on_event("startup")
def startup():
    """Open the shared HTTP session and load the diffusion pipeline."""
    http_client.start()
    shared_pipeline.start()


@app.on_event("shutdown")
async def shutdown():
    """Close the HTTP session opened at startup so it is not leaked."""
    # Guarded so shutdown is harmless if the session was never created.
    if http_client.session is not None:
        await http_client.stop()
def save_image(image):
    """Write *image* into ``image_dir`` and return the URL it is served from.

    Args:
        image: Object exposing ``save(path)``; presumably a PIL image taken
            from the pipeline output (confirm against the caller).

    Returns:
        The URL of the saved file: ``service_url`` + ``/images/`` + filename.
    """
    # Same short id as the first group of the dashed UUID string (8 hex chars).
    filename = f"draw{uuid.uuid4().hex[:8]}.png"
    image_path = os.path.join(image_dir, filename)
    # write image to disk at image_path
    logger.info("Saving image to %s", image_path)
    image.save(image_path)
    # URLs always use "/", so join by hand: os.path.join would insert
    # backslashes on Windows.
    return f"{service_url.rstrip('/')}/images/{filename}"
@app.get("/")
@app.post("/")
@app.options("/")
async def base():
    """Return the welcome message for GET, POST and OPTIONS on the root path."""
    greeting = "Welcome to Diffusers! Where you can use diffusion models to generate images"
    return greeting
@app.post("/v1/images/generations")
async def generate_image(image_input: TextToImageInput):
    """Generate one image for ``image_input.prompt`` and return its URL.

    Args:
        image_input: Parsed request body. Only ``prompt`` is used here.

    Returns:
        ``{"data": [{"url": <image url>}]}`` for the first generated image.

    Raises:
        HTTPException: Re-raised unchanged if raised inside the handler;
            any other failure becomes a 500 whose detail is the error
            message followed by the traceback.
    """
    # NOTE(review): ``model``, ``size`` and ``n`` are accepted by the schema
    # but are not read by this handler.
    try:
        # get_running_loop() is the supported way to reach the loop from
        # inside a coroutine.
        loop = asyncio.get_running_loop()
        base_pipeline = shared_pipeline.pipeline
        # A new scheduler and pipeline wrapper are built for each request,
        # presumably so concurrent requests do not share scheduler state.
        scheduler = base_pipeline.scheduler.from_config(base_pipeline.scheduler.config)
        pipeline = StableDiffusion3Pipeline.from_pipe(base_pipeline, scheduler=scheduler)
        generator = torch.Generator(device=shared_pipeline.device)
        generator.manual_seed(random.randint(0, 10000000))
        # Run the blocking pipeline call in the default executor so the
        # event loop is not blocked while it runs.
        output = await loop.run_in_executor(
            None, lambda: pipeline(image_input.prompt, generator=generator)
        )
        logger.info("output: %s", output)
        image_url = save_image(output.images[0])
        return {"data": [{"url": image_url}]}
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Image generation failed")
        # Some exceptions expose ``message``; fall back to str(e) when it is
        # missing or not a string so the concatenation below cannot fail.
        message = getattr(e, "message", None)
        if not isinstance(message, str):
            message = str(e)
        raise HTTPException(status_code=500, detail=message + traceback.format_exc()) from e
if __name__ == "__main__":
    # uvicorn is imported here, so it is only required when this file is run
    # as a script.
    import uvicorn

    # Listen on every interface, port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)