fid_evaluation.py
import math
import os

import numpy as np
import torch
from einops import rearrange, repeat
from pytorch_fid.fid_score import calculate_frechet_distance
from pytorch_fid.inception import InceptionV3
from torch.nn.functional import adaptive_avg_pool2d
from tqdm.auto import tqdm
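

# Split `num` items into batches of size `divisor`, with a smaller final batch
# for any remainder, e.g. num_to_groups(10, 4) -> [4, 4, 2].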
def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr
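

# FIDEvaluation computes the Frechet Inception Distance between images produced
# by `sampler.sample(...)` and real images yielded by the dataloader `dl`, using
# pooled InceptionV3 features from pytorch-fid.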
class FIDEvaluation:
    def __init__(
        self,
        batch_size,
        dl,
        sampler,
        channels=3,
        accelerator=None,
        stats_dir="./results",
        device="cuda",
        num_fid_samples=50000,
        inception_block_idx=2048,
    ):
        self.batch_size = batch_size
        self.n_samples = num_fid_samples
        self.device = device
        self.channels = channels
        self.dl = dl
        self.sampler = sampler
        self.stats_dir = stats_dir
        self.print_fn = print if accelerator is None else accelerator.print
        assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
        self.inception_v3 = InceptionV3([block_idx]).to(device)
        self.dataset_stats_loaded = False
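
    # Extract pooled InceptionV3 activations for a batch of images. Grayscale
    # inputs (channels == 1) are repeated to 3 channels before the forward pass,
    # and spatial maps are average-pooled down to one feature vector per image.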
    def calculate_inception_features(self, samples):
        if self.channels == 1:
            samples = repeat(samples, "b 1 ... -> b c ...", c=3)

        self.inception_v3.eval()
        features = self.inception_v3(samples)[0]

        if features.size(2) != 1 or features.size(3) != 1:
            features = adaptive_avg_pool2d(features, output_size=(1, 1))
        features = rearrange(features, "... 1 1 -> ...")
        return features
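
    # Load cached feature statistics (mean m2, covariance s2) of the real dataset
    # from `<stats_dir>/dataset_stats.npz`, or compute and cache them on first use
    # by running up to `n_samples` images from the dataloader through Inception.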
    def load_or_precalc_dataset_stats(self):
        path = os.path.join(self.stats_dir, "dataset_stats")
        try:
            ckpt = np.load(path + ".npz")
            self.m2, self.s2 = ckpt["m2"], ckpt["s2"]
            self.print_fn("Dataset stats loaded from disk.")
            ckpt.close()
        except OSError:
            num_batches = int(math.ceil(self.n_samples / self.batch_size))
            stacked_real_features = []
            self.print_fn(
                f"Stacking Inception features for {self.n_samples} samples from the real dataset."
            )
            for _ in tqdm(range(num_batches)):
                try:
                    real_samples = next(self.dl)
                except StopIteration:
                    break
                real_samples = real_samples.to(self.device)
                real_features = self.calculate_inception_features(real_samples)
                stacked_real_features.append(real_features)
            stacked_real_features = (
                torch.cat(stacked_real_features, dim=0).cpu().numpy()
            )
            m2 = np.mean(stacked_real_features, axis=0)
            s2 = np.cov(stacked_real_features, rowvar=False)
            np.savez_compressed(path, m2=m2, s2=s2)
            self.print_fn(f"Dataset stats cached to {path}.npz for future use.")
            self.m2, self.s2 = m2, s2
        self.dataset_stats_loaded = True
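
    # Compute the FID: draw `n_samples` images from the sampler, extract Inception
    # features, and return the Frechet distance between their mean/covariance
    # (m1, s1) and the cached dataset statistics (m2, s2).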
    @torch.inference_mode()
    def fid_score(self):
        if not self.dataset_stats_loaded:
            self.load_or_precalc_dataset_stats()
        self.sampler.eval()
        batches = num_to_groups(self.n_samples, self.batch_size)
        stacked_fake_features = []
        self.print_fn(
            f"Stacking Inception features for {self.n_samples} generated samples."
        )
        for batch in tqdm(batches):
            fake_samples = self.sampler.sample(batch_size=batch)
            fake_features = self.calculate_inception_features(fake_samples)
            stacked_fake_features.append(fake_features)
        stacked_fake_features = torch.cat(stacked_fake_features, dim=0).cpu().numpy()
        m1 = np.mean(stacked_fake_features, axis=0)
        s1 = np.cov(stacked_fake_features, rowvar=False)

        return calculate_frechet_distance(m1, s1, self.m2, self.s2)
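

# Example usage (a minimal sketch): `trained_sampler` and `train_dl` below are
# hypothetical placeholders for a model exposing `.eval()` and
# `.sample(batch_size=...)`, and an iterator yielding batches of image tensors.
#
#   fid_eval = FIDEvaluation(
#       batch_size=32,
#       dl=train_dl,
#       sampler=trained_sampler,
#       channels=3,
#       num_fid_samples=5000,
#   )
#   fid = fid_eval.fid_score()  # lower is better
#   print(f"FID: {fid:.3f}")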