私は1つの小さなデータセットと1つの大きなデータセットを持っており、それらは2つの別々のクラスを意味します。私がトレーニングしているネットワークはスタイル転送であるため、トレーニングを継続するには、クラスごとに1つの画像が必要です。ただし、小さいデータセットがなくなるとすぐにトレーニングは停止します。小さなデータセットからそのサイズを超えてランダムにサンプリングを続けるにはどうすればよいですか?
試しましたRandomSampler()
が、うまくいきませんでした。小さなデータセットのコードは次のとおりです。
sampler = RandomSampler(self)
dataloader = DataLoader(self, batch_size=26, shuffle=False, sampler=sampler)
while True:
for data in dataloader:
yield data
私も試しましたiterator.cycle
が、それも役に立ちませんでした。
loader = iter(cycle(self.dataset.gen(attribute_id, True)))
A, y_A = next(loader)
B, y_B = next(self.dataset.gen(attribute_id, False))
でのあなたのアイデアRandomSampler
はそう遠くはありませんでした。と呼ばれるサンプラーがありSubsetRandomSampler
ます。通常、サブセットはセット全体よりも小さいですが、そうである必要はありません。
小さいデータセットにA
エントリがあり、2番目のデータセットにがあるとしますB
。インデックスを定義できます。
indices = np.random.randint(0, A, B)
sampler = torch.utils.data.sampler.SubsetRandomSampler(indices)
これによりB
、小さいデータセットに有効な範囲のインデックスが生成されます。
テスト:
loader = torch.utils.data.DataLoader(set_A, batch_size=1, sampler=sampler)
print(len(loader)) # B
この記事はインターネットから収集されたものであり、転載の際にはソースを示してください。
侵害の場合は、連絡してください[email protected]
コメントを追加