- Notifications
You must be signed in to change notification settings - Fork 5.9k
/
Copy pathtest_scheduler_unclip.py
141 lines (102 loc) · 4.9 KB
/
test_scheduler_unclip.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import unittest

import torch

from diffusers import UnCLIPScheduler

from .test_schedulers import SchedulerCommonTest
# UnCLIPScheduler is a modified DDPMScheduler with a subset of the configuration.
class UnCLIPSchedulerTest(SchedulerCommonTest):
    """Scheduler tests for ``UnCLIPScheduler``.

    Runs the shared ``SchedulerCommonTest`` checks over a sweep of config
    values, then pins ``_get_variance`` and two full sampling loops against
    recorded reference numbers.
    """

    scheduler_classes = (UnCLIPScheduler,)

    def get_scheduler_config(self, **kwargs):
        """Return the default UnCLIP scheduler config with ``kwargs`` applied on top."""
        config = {
            "num_train_timesteps": 1000,
            "variance_type": "fixed_small_log",
            "clip_sample": True,
            "clip_sample_range": 1.0,
            "prediction_type": "epsilon",
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        """Run the common config check for several ``num_train_timesteps`` values."""
        for timesteps in [1, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_variance_type(self):
        """Run the common config check for each supported ``variance_type``."""
        for variance in ["fixed_small_log", "learned_range"]:
            self.check_over_configs(variance_type=variance)

    def test_clip_sample(self):
        """Run the common config check with sample clipping on and off."""
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)

    def test_clip_sample_range(self):
        """Run the common config check for several ``clip_sample_range`` values."""
        for clip_sample_range in [1, 5, 10, 20]:
            self.check_over_configs(clip_sample_range=clip_sample_range)

    def test_prediction_type(self):
        """Run the common config check for each ``prediction_type``."""
        for prediction_type in ["epsilon", "sample"]:
            self.check_over_configs(prediction_type=prediction_type)

    def test_time_indices(self):
        """Run the common forward check over (time_step, prev_timestep) pairs."""
        for time_step in [0, 500, 999]:
            for prev_timestep in [None, 5, 100, 250, 500, 750]:
                # Only test an explicit prev_timestep that lies strictly before
                # time_step; None lets the scheduler pick the previous step.
                if prev_timestep is not None and prev_timestep >= time_step:
                    continue

                self.check_over_forward(time_step=time_step, prev_timestep=prev_timestep)

    def test_variance_fixed_small_log(self):
        """Compare ``_get_variance`` to reference values under ``fixed_small_log``."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(variance_type="fixed_small_log")
        scheduler = scheduler_class(**scheduler_config)

        assert torch.sum(torch.abs(scheduler._get_variance(0) - 1.0000e-10)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.0549625)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.9994987)) < 1e-5

    def test_variance_learned_range(self):
        """Compare ``_get_variance`` to reference values under ``learned_range``."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(variance_type="learned_range")
        scheduler = scheduler_class(**scheduler_config)

        predicted_variance = 0.5

        # Compare the absolute difference: a bare ``value - expected < tol``
        # is one-sided and would accept any value far below the reference.
        for t, expected in [(1, -10.1712790), (487, -5.7998052), (999, -0.0010011)]:
            variance = scheduler._get_variance(t, predicted_variance=predicted_variance)
            assert abs(variance - expected) < 1e-5, f"t={t}: got {variance}, expected {expected}"

    def test_full_loop(self):
        """Run the default timestep schedule and check the final sample's abs sum/mean."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        timesteps = scheduler.timesteps

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        # Seed the generator passed to step() so the reference numbers are reproducible.
        generator = torch.manual_seed(0)

        for t in timesteps:
            # 1. predict noise residual
            residual = model(sample, t)

            # 2. predict previous mean of sample x_t-1
            sample = scheduler.step(residual, t, sample, generator=generator).prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 252.2682495) < 1e-2
        assert abs(result_mean.item() - 0.3284743) < 1e-3

    def test_full_loop_skip_timesteps(self):
        """Run a 25-step schedule with explicit ``prev_timestep`` and check the result."""
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        scheduler.set_timesteps(25)

        timesteps = scheduler.timesteps

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        generator = torch.manual_seed(0)

        for i, t in enumerate(timesteps):
            # 1. predict noise residual
            residual = model(sample, t)

            # The next entry of the shortened schedule is the step to jump to;
            # the final step has no successor, so pass None.
            if i + 1 == timesteps.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = timesteps[i + 1]

            # 2. predict previous mean of sample x_t-1
            sample = scheduler.step(
                residual, t, sample, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 258.2044983) < 1e-2
        assert abs(result_mean.item() - 0.3362038) < 1e-3

    @unittest.skip("Test not supported.")
    def test_trained_betas(self):
        pass

    @unittest.skip("Test not supported.")
    def test_add_noise_device(self):
        pass