使用Mongo DB的PyTorch DataLoader

帕斯卡

我想知道使用连接到MongoDB的DataLoader是否明智,以及如何实现。

背景

我在一个(本地)MongoDB中有大约2000万个文档。文件数量超出内存容量。我想在数据上训练一个深层的神经网络。到目前为止,我一直先将数据导出到文件系统,并将子文件夹命名为文档的类。但是我发现这种方法毫无意义。如果数据已经很好地保存在数据库中,为什么要先导出(然后删除)。

问题1:

我对吗?直接连接到MongoDB是否有意义?还是有理由不这样做(例如,数据库通常太慢等)?如果数据库太慢(为什么?),可以以某种方式预取数据吗?

问题2:

如何实现PyTorch DataLoader我在网上仅发现很少的代码片段([1][2]),这使我怀疑自己的方法。

程式码片段

我访问MongoDB的一般方法如下。我想这没什么特别的。

import pymongo
from pymongo import MongoClient

myclient = pymongo.MongoClient("mongodb://localhost:27017/")
mydb = myclient["xyz"]
mycol = mydb["xyz_documents"]

query = {
    # some filters
}

results = mycol.find(query)

# results is now a cursor that can run through all docs
# Assume, for the sake of this example, that each doc contains a class name and some image that I want to train a classifier on
希蒙·马斯凯(Szymon Maszke)

介绍

这个有点开放,但是我们尝试一下,如果我在某个地方错了,也请纠正我。

到目前为止,我一直先将数据导出到文件系统,并将子文件夹命名为文档的类。

海事组织这是不明智的,因为:

  • 您实际上是在复制数据
  • 任何时候只要您只想训练新的代码和数据库,就必须重复此操作
  • 您可以一次访问多个数据点,并将它们缓存在RAM中,以供以后重用,而无需多次读取硬盘驱动器(这很繁重)

我对吗?直接连接到MongoDB是否有意义?

上面给出了,可能是(尤其是在清晰,可移植的实现方面)

还是有理由不这样做(例如,数据库通常会变慢等)?

在这种情况下AFAIK DB应该不会变慢,因为它将缓存对其的访问,但是不幸的是我不是数据库专家。现成的数据库实现了许多快速访问的技巧。

可以以某种方式预取数据吗?

是的,如果您只想获取数据,则可以一次加载较大部分的数据(例如1024记录),然后从中返回一批数据(例如batch_size=128

实作

一个人将如何实现PyTorch DataLoader?我在网上发现的代码片段很少([1]和[2]),这使我对自己的方法产生了疑问。

我不确定您为什么要这么做。torch.utils.data.Dataset如您列出的示例所示,您应该追求的目标

我将从与此处类似的简单非优化方法开始,所以:

  • 打开与db的连接,__init__并一直保持使用状态(我会从中创建一个上下文管理器,torch.utils.data.Dataset因此在完成时关闭连接)
  • 我不会将结果转换为list(特别是因为明显的原因您无法将其放入RAM中),因为它错过了生成器的意义
  • 我将在此数据集中执行批处理(batch_size 此处有一个参数)。
  • 我不确定__getitem__函数,但似乎它可以一次返回多个数据点,因此我会用它,它应该允许我们使用num_workers>0(假设mycol.find(query)每次都以相同的顺序返回数据)

鉴于此,我要做的就是遵循这些方针:

class DatabaseDataset(torch.utils.data.Dataset):
    def __init__(self, query, batch_size, path: str, database: str):
        self.batch_size = batch_size

        client = pymongo.MongoClient(path)
        self.db = client[database]
        self.query = query
        # Or non-approximate method, if the approximate method
        # returns smaller number of items you should be fine
        self.length = self.db.estimated_document_count()

        self.cursor = None

    def __enter__(self):
        # Ensure that this find returns the same order of query every time
        # If not, you might get duplicated data
        # It is rather unlikely (depending on batch size), shouldn't be a problem
        # for 20 million samples anyway
        self.cursor = self.db.find(self.query)
        return self

    def shuffle(self):
        # Find a way to shuffle data so it is returned in different order
        # If that happens out of the box you might be fine without it actually
        pass

    def __exit__(self, *_, **__):
        # Or anything else how to close the connection
        self.cursor.close()

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        # Read takes long, hence if you can load a batch of documents it should speed things up
        examples = self.cursor[index * batch_size : (index + 1) * batch_size]
        # Do something with this data
        ...
        # Return the whole batch
        return data, labels

现在分批处理已完成DatabaseDataset,因此torch.utils.data.DataLoader可以进行batch_size=1您可能需要挤压其他尺寸。

由于MongoDB使用锁(这并不奇怪,请参阅此处num_workers>0应该不是问题。

可能的用法(示意性):

with DatabaseDataset(...) as e:
    dataloader = torch.utils.data.DataLoader(e, batch_size=1)
    for epoch in epochs:
        for batch in dataloader:
            # And all the stuff
            ...
        dataset.shuffle() # after each epoch

记住在这种情况下改组实现!(也可以在上下文管理器中进行改组,您可能想要手动关闭连接或沿这些方式关闭连接)。

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章

創建和使用 PyTorch DataLoader

DataLoader 使用 pytorch 创建数据集

PyTorch DataLoader的VSCode错误?

Pytorch DataLoader shuffle = False?

分割DataLoader PyTorch

PyTorch:使用torchvision.datasets.ImageFolder和DataLoader进行测试

在pytorch中使用dataloader進行採樣替換

使用 pytorch Dataloader 加载并正确显示图像数据集

如何使用PyTorch Dataloader从Mosaic增强中获取类标签?

Pytorch DataLoader内存未释放

如何保存PyTorch的DataLoader实例?

PyTorch:随机播放DataLoader

将 PyTorch 与 cuda 11.1 一起使用时,PyTorch 无法正常工作:Dataloader

在Dataloader中使用带有腌制数据的生成器进行PyTorch

使用pytorch DataLoader如何获取两个ndarray(数据和标签)?

在PyTorch中使用Dataloader迭代数据集时出现IndexError

即使没有使用图像,PyTorch 也要求 DataLoader 中的图像维度

如何将Pytorch Dataloader转换为numpy数组以使用matplotlib显示图像数据?

PyTorch DataLoader 对并行运行的批次使用相同的随机种子

从PyTorch DataLoader获取单个随机示例

Pytorch DataLoader多个数据源

无法通过PyTorch DataLoader进行迭代

Pytorch Dataloader shuffle 与多个数据集

PyTorch Dataloader:数据集在 RAM 中完成

使用Facebook的DataLoader传递参数

Pytorch:在dataloader.dataset上使用torch.utils.random_split()后,数据中缺少批处理大小

使用PyTorch Dataloader将3维和1维特征传递给神经网络

使用Dataloader处理GraphQL字段参数?

如何使用Dataloader创建适当形状的输入?