# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
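"""Samplers whose iteration progress can be captured and restored.

This module provides a ``RandomSampler`` and ``BatchSampler`` whose iterators
implement the ``Stateful`` protocol (``state_dict`` / ``load_state_dict``),
plus a ``StatefulDistributedSampler`` that can resume a distributed epoch from
the point where its state was saved.
"""
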
import itertools
from typing import Any, Dict, Iterator, List, Optional, Sized

import torch.utils.data.sampler
from torch.utils.data import Dataset
from torch.utils.data.dataloader import _InfiniteConstantSampler
from torch.utils.data.sampler import Sampler

from .stateful import Stateful


class _StatefulRandomSamplerIterator(Iterator[int], Stateful):
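    """Iterator over ``RandomSampler`` indices that can checkpoint its position.

    Indices are drawn either as a full permutation of the data source (without
    replacement) or in chunks of ``chunk_size`` random draws (with replacement).
    ``state_dict`` records the generator state captured at construction together
    with the number of indices yielded; ``load_state_dict`` restores the
    generator and fast-forwards by replaying that many draws.
    """
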
    _GENERATOR = "generator"
    _YIELDED = "yielded"

    def __init__(self, sampler):
        self.sampler = sampler
        self.generator_state = self.sampler.generator.get_state()
        self.yielded = 0
        self.next_yielded = None
        self.n = len(sampler.data_source)
        self.replacement = sampler.replacement
        self.num_samples = sampler.num_samples
        self.chunk_size = 32
        self.perm: List[int] = self._get_perm()
        self.perm_index = 0
        self.chunk_index = 0

    def __iter__(self):
        return self

    def _get_perm(self) -> List[int]:
        if self.replacement:
            return torch.randint(
                high=self.n,
                size=(self.chunk_size,),
                dtype=torch.int64,
                generator=self.sampler.generator,
            ).tolist()
        else:
            return torch.randperm(self.n, generator=self.sampler.generator).tolist()

    def __next__(self):
        if self.yielded == self.num_samples:
            raise StopIteration()
        if self.perm_index == len(self.perm):
            self.perm = self._get_perm()
            self.perm_index = 0

        val = self.perm[self.perm_index]
        self.perm_index += 1
        self.yielded += 1
        return val

    def state_dict(self) -> dict:
        return {
            self._YIELDED: self.yielded,
            self._GENERATOR: self.generator_state,
        }

    def load_state_dict(self, state_dict: dict) -> None:
        self.next_yielded = state_dict[self._YIELDED]
        self.generator_state = state_dict[self._GENERATOR]
        self.sampler.generator.set_state(self.generator_state)

        if self.next_yielded is not None:
            self.perm = self._get_perm()  # We want permutations from the latest generator state that's loaded
            for _ in range(self.next_yielded):
                next(self)
            self.yielded = self.next_yielded
            self.next_yielded = None


class RandomSampler(Sampler[int]):
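    """Random sampler with a checkpointable iterator.

    Mirrors the constructor of ``torch.utils.data.RandomSampler``
    (``data_source``, ``replacement``, ``num_samples``, ``generator``), but
    ``__iter__`` returns a ``_StatefulRandomSamplerIterator`` so iteration can
    be saved and resumed via ``state_dict`` / ``load_state_dict``.
    """
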
    def __init__(
        self,
        data_source: Sized,
        replacement: bool = False,
        num_samples: Optional[int] = None,
        generator=None,
    ) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        if generator is None:
            # Previously the random seed was fixed at 1. We now generate a seed at
            # construction time to ensure deterministic randomness.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        return _StatefulRandomSamplerIterator(self)

    def __len__(self) -> int:
        return self.num_samples


class _BatchSamplerIterator(Iterator[list[int]], Stateful):
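    """Batches indices from ``sampler`` and tracks how many samples were yielded.

    ``state_dict`` stores the count of yielded samples and, when the wrapped
    sampler and/or its iterator implement ``Stateful``, their states as well.
    ``load_state_dict`` restores those states; for a plain, non-``Stateful``
    sampler it instead fast-forwards a fresh iterator by the yielded count.
    """
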
    _SAMPLES_YIELDED = "samples_yielded"
    _SAMPLER_STATE = "sampler_state"
    _SAMPLER_ITER_STATE = "sampler_iter_state"

    def __init__(self, sampler, batch_size: int, drop_last: bool):
        self.sampler = sampler
        self.sampler_iter = iter(self.sampler)
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.samples_yielded = 0

    def __next__(self) -> list[int]:
        batch = []
        try:
            for _ in range(self.batch_size):
                batch.append(next(self.sampler_iter))
                self.samples_yielded += 1
            return batch
        except StopIteration:
            if self.drop_last or len(batch) == 0:
                raise StopIteration
            else:
                return batch

    def state_dict(self) -> Dict[str, Any]:
        sd: Dict[str, Any] = {self._SAMPLES_YIELDED: self.samples_yielded}
        if isinstance(self.sampler, Stateful):
            sd[self._SAMPLER_STATE] = self.sampler.state_dict()
        if isinstance(self.sampler_iter, Stateful):
            sd[self._SAMPLER_ITER_STATE] = self.sampler_iter.state_dict()
        return sd

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.samples_yielded = state_dict[self._SAMPLES_YIELDED]
        if self._SAMPLER_STATE in state_dict:
            assert isinstance(self.sampler, Stateful)
            self.sampler.load_state_dict(state_dict[self._SAMPLER_STATE])
        self.sampler_iter = iter(self.sampler)
        if self._SAMPLER_ITER_STATE in state_dict:
            assert isinstance(self.sampler_iter, Stateful)
            self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE])

        if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance(
            self.sampler, _InfiniteConstantSampler
        ):
            # The underlying sampler is not stateful, so fast-forward its fresh
            # iterator by the number of samples already yielded.
            for _ in range(self.samples_yielded):
                next(self.sampler_iter)

    def update_state_dict(self) -> None:
        if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"):
            self.sampler_iter.update_state_dict()


class BatchSampler(torch.utils.data.sampler.BatchSampler):
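    """``torch.utils.data.BatchSampler`` whose ``__iter__`` returns a
    ``_BatchSamplerIterator``, making batch iteration checkpointable when the
    underlying sampler supports it."""
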
    def __init__(self, sampler, batch_size, drop_last):
        super().__init__(sampler, batch_size, drop_last)

    def __iter__(self):
        return _BatchSamplerIterator(
            sampler=self.sampler,
            batch_size=self.batch_size,
            drop_last=self.drop_last,
        )


class StatefulDistributedSampler(torch.utils.data.distributed.DistributedSampler):
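    """``DistributedSampler`` that records how many indices it has yielded.

    ``state_dict`` saves the yielded count; after ``load_state_dict`` the next
    call to ``__iter__`` skips that many indices so the epoch resumes where it
    stopped.
    """
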
    _YIELDED = "yielded"

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
        self.yielded = 0
        self.next_yielded = None

    def __iter__(self):
        self.yielded = 0
        if self.next_yielded is not None:
            self.yielded = self.next_yielded
            self.next_yielded = None
        it = super().__iter__()
        for idx in itertools.islice(it, self.yielded, None):
            self.yielded += 1
            yield idx

    def state_dict(self) -> Dict[str, Any]:
        return {self._YIELDED: self.yielded}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        if self._YIELDED not in state_dict:
            raise ValueError("Invalid state_dict")
        if state_dict[self._YIELDED] < 0:
            raise ValueError("Cannot load state_dict with negative yielded value")
        self.next_yielded = state_dict[self._YIELDED]
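

# Illustrative usage sketch (comments only; not part of the module). It assumes
# a simple sized sequence such as ``ds = list(range(10))``; the sampler only
# calls ``len(ds)``. The sketch shows how an in-progress iterator's state can be
# loaded into a fresh iterator so the remaining indices are reproduced:
#
#     ds = list(range(10))
#     it = iter(RandomSampler(ds))
#     consumed = [next(it) for _ in range(4)]  # draw a few indices
#     state = it.state_dict()                  # generator state + yielded count
#     remaining = list(it)                     # indices the original run would still yield
#
#     resumed = iter(RandomSampler(ds))
#     resumed.load_state_dict(state)           # restore the generator, replay the 4 draws
#     assert list(resumed) == remaining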