自定义数据集不接受 PyTorch 中的参数

srdg

我正在尝试使用此数据集在 PyTorch 中创建自定义数据它的形状为 (X, 785),X 是样本数,每行包含索引 0 和 784 像素值处的标签。这是我的代码:

from torch.utils.data import Dataset
def SignMNISTDataset(Dataset):

  def __init__(self, csv_file_path, mode='Train'):
    self.labels = []
    self.pixels = []
    self.mode = mode

    data = pd.read_csv(csv_file_path).values
    if self.mode == 'Train':
      self.labels = data[:,0].tolist()
      print("Training labels acquired")

    for idx in range(len(self.labels)):
      self.pixels.append(data[idx][1:].tolist())

  def __len__(self):
    return len(self.labels)

  def __getitem__(self, idx):
    pixels = self.pixels[idx]
    if self.mode == 'Train':
      labels = self.labels[idx]
      content = {"pixels":pixels, "label":labels}
    else:
      content = {"pixels":pixels}
    return content

training_data = SignMNISTDataset('sign_mnist_train/sign_mnist_train.csv', 'Train')

在运行时,我收到以下错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-46-0173199f8794> in <module>()
     27     return content
     28 
---> 29 training_data = SignMNISTDataset('sign_mnist_train/sign_mnist_train.csv', 'Train')
     30 from torch.utils.data import DataLoader
     31 

TypeError: SignMNISTDataset() takes 1 positional argument but 2 were given

这究竟是从哪里来的?在对象创建过程中,模式参数是否以某种方式不被读取?我的最终目标是按照本教程创建一个用于对符号字符进行分类的神经网络

我尝试mode在对象创建期间明确提及关键字这就是我得到的 -

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-48-fd796c48dc67> in <module>()
     27     return content
     28 
---> 29 training_data = SignMNISTDataset('sign_mnist_train/sign_mnist_train.csv', mode='Train')

TypeError: SignMNISTDataset() got an unexpected keyword argument 'mode'
吗哪

请使用

class SignMNISTDataset(Dataset):

而不是

def SignMNISTDataset(Dataset):

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章

在 Pytorch 中优化自定义参数

在pytorch中对自定义数据集进行数据预处理(transform.Normalize)

如何将基于自定义图像的数据集加载到Pytorch中以用于CNN?

如何在Pytorch中为图像及其遮罩创建自定义数据集?

从PyTorch自定义数据集的__getitem__中庞大的未压缩tar文件读取图像的最快方法

从<tr>中的ng-repeat不接受自定义过滤器中的参数

PyTorch中的自定义损失功能

pytorch中自定义LSTM的辍学

在Pytorch中为自定义的NN模块定义命名参数

使用PyTorch实施自定义数据集

在pytorch中加载自定义数据集

使用PyTorch加载图像的自定义数据集

在Pytorch中转换自定义数据集时出错

在pytorch中反向传播时自动更新自定义图层参数

如何在pytorch自定义模型的模块类中添加参数?

Pytorch自定义nn模块转发函数中的参数过多

Laravel:无法在自定义命令中传递可选参数出现错误“--default_value”选项不接受值

从目录 Pytorch 中的文件夹加载自定义数据

如何将自定义数据放入Pytorch DataLoader中?

datetime.strptime不接受自定义函数传递的参数

带有 dplyr 函数的自定义函数不接受参数值

使用React自定义钩子的API调用不接受更新的参数

在 pytorch 闪电中自定义优化器

在Pytorch中自定义距离损失功能?

关于 Tensorflow 和 PyTorch 中的自定义操作

在pytorch中为CNN设置自定义内核

在Pytorch自定义模块中添加模块

PyTorch中的自定义卷积核和环形卷积

在 pytorch 中创建自定义梯度下降