Today the DataLoader works well and returns reproducible results when torch.manual_seed is fixed.
However, the batch sampler or sampler (if it uses the torch random generator) returns different results when used alone than when used inside a DataLoader with num_workers > 0, even with the random state fixed.
The problem is here: because base_seed is drawn randomly from the global generator, the random indices produced by the first call of next(sampler_iter) that follows are not the same as they would have been if the sampler were iterated alone (with next) or driven by the DataLoader with num_workers=0.
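To make the mismatch concrete, here is a minimal sketch (the toy dataset, seed value, and worker count are arbitrary choices for illustration, not part of the report): with the same torch.manual_seed, the sampler iterated on its own and the same kind of sampler driven by a DataLoader with num_workers=2 yield different permutations.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# Toy dataset where each item's value equals its index, so the values
# returned by the loader reveal the sampling order.
dataset = TensorDataset(torch.arange(10))

# Order produced by the sampler on its own.
torch.manual_seed(0)
print([int(i) for i in RandomSampler(dataset)])

# Order produced inside the DataLoader: with num_workers > 0 the iterator
# draws base_seed from the global torch generator before touching the
# sampler, so this permutation no longer matches the one above.
torch.manual_seed(0)
loader = DataLoader(dataset, sampler=RandomSampler(dataset),
                    batch_size=1, num_workers=2)
print([int(batch[0]) for batch in loader])
```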
Can we set base_seed differently, without using the torch random generator?
For example, we can use a fixed base_seed in DataLoaderIter:
```python
base_seed = 7056
self.workers = [
    multiprocessing.Process(
        target=_worker_loop,
        args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
              base_seed + i, self.worker_init_fn, i))
    for i in range(self.num_workers)]
```

and set a random seed in _worker_loop:
```python
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
    # ...
    base_seed = torch.LongTensor(1).random_(seed)[0]
    torch.set_num_threads(1)
    torch.manual_seed(base_seed)
    # ...
```

What do you think?
Context: this would be helpful for resuming a DataLoader from a given iteration (for example, if training has crashed).
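As a rough illustration of that use case (resume_loader is a hypothetical helper, not part of the DataLoader API, and it assumes the deterministic base_seed proposed above): resuming mid-epoch could then amount to re-seeding and skipping the batches already consumed.

```python
import itertools
import torch

def resume_loader(loader, epoch_seed, batches_done):
    # Re-seed so the sampler regenerates the same permutation it produced
    # before the crash (this only works if the DataLoader's base_seed is
    # deterministic), then skip the batches that were already seen.
    torch.manual_seed(epoch_seed)
    return itertools.islice(iter(loader), batches_done, None)
```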
Thanks