- Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathautoencoder.py
22 lines (19 loc) · 762 Bytes
/
autoencoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
importtorch
importtorch.nnasnn
fromdiffusersimportAutoencoderKL
class SDVAE_EMA(nn.Module):
    """Thin wrapper around the pretrained ``stabilityai/sd-vae-ft-ema`` VAE.

    The autoencoder weights are loaded in bfloat16 and moved to CUDA when it
    is available, otherwise to the CPU.
    """

    def __init__(self):
        super().__init__()
        self.model = AutoencoderKL.from_pretrained(
            "stabilityai/sd-vae-ft-ema", torch_dtype=torch.bfloat16
        )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)

    def encode(self, x):
        """Encode ``x`` into scaled latents.

        A latent is sampled from the posterior returned by the wrapped
        autoencoder and multiplied by ``config.scaling_factor``, i.e. the
        inverse of the division performed at the start of :meth:`decode`.

        Args:
            x: Input passed straight to ``self.model.encode``. Presumably an
                image batch in the dtype/device of the model -- confirm
                against callers.

        Returns:
            The sampled latent tensor, multiplied by the scaling factor.
        """
        # AutoencoderKL.encode() exposes the posterior as `.latent_dist`;
        # there is no `.latents` field on its output (that is VQModel's API).
        posterior = self.model.encode(x).latent_dist
        return posterior.sample() * self.model.config.scaling_factor

    def decode(self, x):
        """Decode scaled latents into uint8 images.

        Runs without gradient tracking and under CUDA autocast.

        Args:
            x: Latents that are divided by ``config.scaling_factor`` before
                being handed to ``self.model.decode``.

        Returns:
            A ``torch.uint8`` tensor with values in ``[0, 255]`` whose axes
            are the decoder output's axes reordered as ``(0, 2, 3, 1)``
            (presumably batch, height, width, channels -- confirm).
        """
        with torch.no_grad():
            # Same semantics as the deprecated torch.cuda.amp.autocast():
            # default CUDA autocast dtype, disabled when CUDA is unavailable.
            with torch.autocast(device_type="cuda"):
                latents = x / self.model.config.scaling_factor
                decoded = self.model.decode(latents).sample
                # Map [-1, 1] to [0, 1]; out-of-range values are clamped.
                unit = ((decoded + 1) / 2).clamp(0, 1)
                # The uint8 cast truncates the fractional part.
                pixels = (unit * 255).to(torch.uint8)
                pixels = pixels.permute(0, 2, 3, 1)
        return pixels