`StatefulDataLoader` is a drop-in replacement for `torch.utils.data.DataLoader` which offers `state_dict`/`load_state_dict` methods for handling mid-epoch checkpointing: `state_dict` captures the state of the most recently requested iterator, and `load_state_dict` applies to the next iterator requested from the dataloader.
By default, the state includes the number of batches yielded and uses this to naively fast-forward the sampler (map-style) or the dataset (iterable-style). However, if the sampler and/or dataset include `state_dict`/`load_state_dict` methods, then `StatefulDataLoader` will call them during its own `state_dict`/`load_state_dict` calls. Under the hood, `StatefulDataLoader` handles aggregation and distribution of state across multiprocess workers (but not across ranks).
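Because state is not shared across ranks, distributed training typically checkpoints one dataloader `state_dict` per rank. Below is a minimal sketch, assuming `torch.distributed` may or may not be initialized; the file layout and helper names are illustrative, not part of the library:

```python
import os

import torch
import torch.distributed as dist
from torchdata.stateful_dataloader import StatefulDataLoader


def save_rank_state(dataloader: StatefulDataLoader, ckpt_dir: str) -> None:
    # StatefulDataLoader aggregates state across its worker processes,
    # but not across ranks, so each rank saves its own state file.
    rank = dist.get_rank() if dist.is_initialized() else 0
    torch.save(dataloader.state_dict(), os.path.join(ckpt_dir, f"dataloader_rank{rank}.pt"))


def load_rank_state(dataloader: StatefulDataLoader, ckpt_dir: str) -> None:
    # On resume (same world size and num_workers), each rank loads the state it saved.
    rank = dist.get_rank() if dist.is_initialized() else 0
    state = torch.load(os.path.join(ckpt_dir, f"dataloader_rank{rank}.pt"), weights_only=False)
    dataloader.load_state_dict(state)
```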
## Install guide

`torchdata.stateful_dataloader` is currently available in `torchdata>=0.8.0`.
Using pip:

```bash
pip install "torchdata>=0.8.0"
```
Using conda:

```bash
conda install torchdata -c pytorch-nightly
```
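A quick way to confirm the installed version supports the stateful dataloader, using standard-library package metadata rather than any torchdata-specific API:

```python
from importlib.metadata import version

# torchdata.stateful_dataloader requires torchdata>=0.8.0
print(version("torchdata"))

from torchdata.stateful_dataloader import StatefulDataLoader  # should import without error
```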
## Saving and Loading State

```python
from torchdata.stateful_dataloader import StatefulDataLoader

dataloader = StatefulDataLoader(dataset, num_workers=2, ...)
for i, batch in enumerate(dataloader):
    ...
    if i == 10:
        state_dict = dataloader.state_dict()
        break

...

# Training run resumes with the previous checkpoint
dataloader = StatefulDataLoader(dataset, num_workers=2, ...)

# Resume state with DataLoader
dataloader.load_state_dict(state_dict)
for i, batch in enumerate(dataloader):
    ...
```
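In practice, the dataloader state is usually saved in the same checkpoint as the model and optimizer. Here is a minimal sketch using only `torch.save`/`torch.load`; the dataset, checkpoint keys, and file path are illustrative:

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader import StatefulDataLoader

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataloader = StatefulDataLoader(TensorDataset(torch.randn(64, 4)), batch_size=8, num_workers=2)

# ... train for some number of steps, then bundle the dataloader state
# together with the rest of the training state
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "dataloader": dataloader.state_dict(),
}
torch.save(checkpoint, "checkpoint.pt")

# On resume: rebuild the same objects, then restore each component.
# weights_only=False because the bundle contains non-tensor Python objects.
checkpoint = torch.load("checkpoint.pt", weights_only=False)
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
dataloader.load_state_dict(checkpoint["dataloader"])
for batch in dataloader:
    ...  # iteration picks up where the saved state left off
```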
## Custom State: Map-Style

For efficient resuming (rather than relying on the naive fast-forward), define `state_dict`/`load_state_dict` methods in your sampler. If your dataset has worker-specific state (e.g. RNG transform state), you can also add `state_dict`/`load_state_dict` methods to your dataset.
```python
from typing import *

import torch
import torch.utils.data
from torchdata.stateful_dataloader import StatefulDataLoader

# If you are using the default RandomSampler and BatchSampler in torch.utils.data,
# they are patched when you import torchdata.stateful_dataloader so that defining
# a custom sampler here is unnecessary
class MySampler(torch.utils.data.Sampler[int]):
    def __init__(self, high: int, seed: int, limit: int):
        self.seed, self.high, self.limit = seed, high, limit
        self.g = torch.Generator()
        self.g.manual_seed(self.seed)
        self.i = 0

    def __iter__(self):
        while self.i < self.limit:
            val = int(torch.randint(high=self.high, size=(1,), generator=self.g))
            self.i += 1
            yield val

    def load_state_dict(self, state_dict: Dict[str, Any]):
        self.i = state_dict["i"]
        self.g.set_state(state_dict["rng"])

    def state_dict(self) -> Dict[str, Any]:
        return {"i": self.i, "rng": self.g.get_state()}


# Optional: save dataset random transform state
class NoisyRange(torch.utils.data.Dataset):
    def __init__(self, high: int, mean: float, std: float):
        self.high, self.mean, self.std = high, torch.tensor([float(mean)]), float(std)

    def __len__(self):
        return self.high

    def __getitem__(self, idx: int) -> float:
        if not (0 <= idx < self.high):
            raise IndexError()
        x = torch.normal(self.mean, self.std)
        noise = x.item()
        return idx + noise

    def load_state_dict(self, state_dict):
        torch.set_rng_state(state_dict["rng"])

    def state_dict(self):
        return {"rng": torch.get_rng_state()}


# Test both single/multiprocess dataloading
for num_workers in [0, 2]:
    print(f"{num_workers=}")
    dl = StatefulDataLoader(
        NoisyRange(5, 1, 1), sampler=MySampler(5, 1, 10),
        batch_size=2, drop_last=False, num_workers=num_workers,
    )
    batches = []
    for i, batch in enumerate(dl):
        batches.append(batch)
        if i == 2:
            sd = dl.state_dict()

    dl.load_state_dict(sd)
    batches2 = list(dl)

    print(batches[3:])
    print(batches2)

"""
Output:
num_workers=0
[tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
[tensor([-0.4526, 3.7948], dtype=torch.float64), tensor([6.5494, 3.0470], dtype=torch.float64)]
num_workers=2
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
[tensor([3.7412, 1.2438], dtype=torch.float64), tensor([4.4807, 4.0036], dtype=torch.float64)]
"""
```
## Custom State: Iterable-Style

Tracking iteration order with iterable-style datasets requires capturing state from each worker-level instance of the dataset. You can define `state_dict`/`load_state_dict` methods on your dataset which capture this worker-level state; `StatefulDataLoader` will handle aggregation across workers and distribution back to the workers. Calling `load_state_dict` requires the `StatefulDataLoader` to be constructed with the same `num_workers` as the one that produced the provided `state_dict`.
```python
from typing import *

import torch
import torch.utils.data
from torchdata.stateful_dataloader import StatefulDataLoader


class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, high: int, seed: int):
        self.high, self.seed = high, seed
        self.g = torch.Generator()
        self.i = 0

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
        else:
            worker_id = 0
            num_workers = 1
        self.g.manual_seed(self.seed)
        arr = torch.randperm(self.high, generator=self.g)
        arr = arr[worker_id:self.high:num_workers]
        for idx in range(self.i, len(arr)):
            self.i += 1
            yield arr[idx]
        self.i = 0

    def state_dict(self):
        return {"i": self.i}

    def load_state_dict(self, state_dict):
        self.i = state_dict["i"]


# Test both single/multiprocess dataloading
for num_workers in [0, 2]:
    print(f"{num_workers=}")
    dl = StatefulDataLoader(
        MyIterableDataset(12, 0), batch_size=2, drop_last=False,
        num_workers=num_workers,
    )
    batches = []
    for i, batch in enumerate(dl):
        batches.append(batch)
        if i == 2:
            sd = dl.state_dict()

    dl.load_state_dict(sd)
    batches2 = list(dl)

    print(batches[3:])
    print(batches2)

"""
Output:
num_workers=0
[tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
[tensor([ 2, 10]), tensor([3, 1]), tensor([11, 6])]
num_workers=2
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
[tensor([ 4, 10]), tensor([ 3, 11]), tensor([1, 6])]
"""
```
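Because `load_state_dict` requires the resuming `StatefulDataLoader` to be built with the same `num_workers`, one option is to record that value next to the saved state and reuse it when reconstructing the dataloader. A minimal sketch reusing `MyIterableDataset` from the example above; the `payload` keys are illustrative, not part of the library:

```python
from torchdata.stateful_dataloader import StatefulDataLoader

NUM_WORKERS = 2  # must match between the saving run and the resuming run

dl = StatefulDataLoader(MyIterableDataset(12, 0), batch_size=2, num_workers=NUM_WORKERS)
for i, batch in enumerate(dl):
    if i == 2:
        break

# Record the worker count next to the dataloader state
payload = {"num_workers": NUM_WORKERS, "state": dl.state_dict()}

# On resume, rebuild the dataloader with the recorded worker count before loading
resumed = StatefulDataLoader(
    MyIterableDataset(12, 0), batch_size=2, num_workers=payload["num_workers"]
)
resumed.load_state_dict(payload["state"])
```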
## Beta Usage and Feedback

We'd love to hear from and work with early adopters to shape our designs. Please reach out by raising an issue if you're interested in using this tooling for your project.
## License

TorchData is BSD licensed, as found in the LICENSE file.