我想保存PyTorch的torch.utils.data.dataloader.DataLoader
实例,以便我可以继续上次中断的地方进行训练(保留随机播放的种子,状态和所有内容)。
这很简单。一个人应该设计自己的方法Sampler
,该方法采用起始索引并自己对数据进行随机排序:
import random
from torch.utils.data.dataloader import Sampler
random.seed(224) # use a fixed number
class MySampler(Sampler):
def __init__(self, data, i=0):
random.shuffle(data)
self.seq = list(range(len(data)))[i * batch_size:]
def __iter__(self):
return iter(self.seq)
def __len__(self):
return len(self.seq)
现在将最后一个索引保存在i
某处,并在下一次实例化DataLoader
它时使用它:
train_dataset = MyDataset(train_data)
train_sampler = MySampler(train_dataset, last_i)
train_data_loader = DataLoader(dataset=train_dataset,
batch_size=batch_size,
sampler=train_sampler,
shuffle=False) # don't forget to set DataLoader's shuffle to False
在Colab上进行培训时,这非常有用。
本文收集自互联网,转载请注明来源。
如有侵权,请联系 [email protected] 删除。
我来说两句