如何在dataloader pytorch中分别加载根目录下的数据

Prithvi Raj Kanaujia

所以,我试图在 pytorch 中加载这个数据集,我在加载它时遇到了问题。

正如你可以看到我检查数据集目录看起来像这样:

    • monet_jpg

    • monet_tfrec

    • 照片_jpg

    • photo_tfrec

所以,我想在单独的数据加载器变量中加载照片和莫奈图像。但是这个方法好像行不通。

编辑:我的意思是 monet_ds 和 photo_ds 只返回 monet 图像(而 photo_ds 应该从 photo_jpg 返回图像)

我正在尝试通过此代码加载数据:

import torchvision.datasets as dset
import torchvision.utils as vutils
from torch.utils.data import Subset
​
def load_data(dataroot , image_size, batch_size, workers,ngpu,shuffle=True):
    #DataLoading
    # Create the dataset
    dataset = dset.ImageFolder(root=dataroot,
                            transform=transforms.Compose([
                                transforms.Resize(image_size),
                                transforms.CenterCrop(image_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            ]))
    print(dataset.class_to_idx)
    #print(dataset.imgs)
    monet_ds = Subset(dataset, range(0,299))
    photo_ds = Subset(dataset, range(300,))
    
    # Create the dataloader
    monet_ds = torch.utils.data.DataLoader(monet_ds, batch_size=batch_size,
                                             num_workers=workers)
    photo_ds = torch.utils.data.DataLoader(photo_ds, batch_size=batch_size,
                                             num_workers=workers)
    # Decide which device we want to run on
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
​
    print("Data loaded...")

root = "../input/gan-getting-started"
monet_ds, photo_ds, device = load_data(root, image_size, batch_size, workers, ngpu)

任何在 pytorch 中完美加载这些数据的帮助都会有很大帮助。谢谢你。

贝瑞尔

似乎它们是完全独立的,因此以下应该可以正常工作:

import os

from torchvision.datasets.folder import default_loader
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class MonetPhotoDataset(Dataset):
    def __init__(self, root, transform=None):
        self.transform = transform
        self.img_paths = sorted(os.path.join(root, x) for x in os.listdir(root) if x.endswith('.jpg'))

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        sample = default_loader(img_path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


def load_data(dataroot, image_size, batch_size, workers, ngpu, shuffle=True):
    # set up transform
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    
    # create datasets
    monet_ds = MonetPhotoDataset(root=os.path.join(dataroot, 'monet_jpg'), transform=transform)
    photo_ds = MonetPhotoDataset(root=os.path.join(dataroot, 'photo_jpg'), transform=transform)

    # create dataloaders
    monet_dl = DataLoader(monet_ds, batch_size=batch_size, num_workers=workers)
    photo_dl = DataLoader(photo_ds, batch_size=batch_size, num_workers=workers)

    # decide which device we want to run on
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
​
    print("Data loaded...")
    return monet_dl, photo_dl, device

root = "../input/gan-getting-started"
monet_dl, photo_dl, device = load_data(root, image_size, batch_size, workers, ngpu)

PS:我保留了 ,load_data因为我假设你依赖它在你的代码中的签名,但否则我不会使用它。另外,我没有测试上面的代码,所以预计会有一些错字,但逻辑是正确的。

请注意,此数据集仅返回图像。

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章

如何保存PyTorch的DataLoader实例?

如何将图像数据加载到 pytorch dataLoader?

如何在Pytorch中简化用于自动编码器的DataLoader

如何在TF Object Detection 2.0中分别加载已保存的Faster R-CNN的两个阶段?

Angular 2如何分别加载2 App根组件

如何使用模态框分别加载文本?

如何将图像加载到Pytorch DataLoader中?

如何将MNIST图像加载到Pytorch DataLoader中?

如何在目录下的目录中加载 .jpg 图像?

Pytorch Dataloader如何处理可变大小的数据?

如何在json中分别返回这些数据?

如何在R中分别绑定多个数据帧?

如何在SD卡(Android 8.0+)的根目录下保存文件?

如何在解决方案根目录下添加“ src”文件夹

Terraform:如何在根目录下创建 api 网关 POST 方法?

如何在chromebook上使用Linux在根目录下显示图形?

如何在ammap上使用Dataloader

如何在webpack中根目录加载CSS

__getitem__的idx在PyTorch的DataLoader中如何工作?

如何使用PyTorch Dataloader从Mosaic增强中获取类标签?

如何将不适合内存的巨大数据集拆分和加载到pytorch Dataloader中?

amCharts:如何正常管理dataLoader中的数据不足

在 Apollo GraphQL 中如何访问 Dataloader 中的数据源?

如何将 XML 元素的内容分别加载到 Python 列表?

PyTorch DataLoader如何与PyTorch数据集交互以转换批次?

使用pytorch DataLoader如何获取两个ndarray(数据和标签)?

PyTorch:如何将DataLoader用于自定义数据集

如何将Pytorch Dataloader转换为numpy数组以使用matplotlib显示图像数据?

如何将自定义数据放入Pytorch DataLoader中?