我是pytorch和tensorflow用户。为了使用AWS sagemaker的弹性推断,我遇到了Mxnet。
Mxnet gluon数据集api似乎与pytorch的数据集非常相似。
class CustomDataset(mxnet.gluon.data.Dataset):
def __init__(self):
self.train_df = pd.read_csv('/shared/KTUTOR/test_summary_data.csv')
def __getitem__(self, idx):
return mxnet.nd.array(self.train_df.loc[idx, ['TT', 'TF', 'FT', 'FF']], dtype='float64'), mxnet.nd.array(self.train_df.loc[idx, ['p1']], dtype='float64')
def __len__(self):
return len(self.train_df)
我如上所述定义了customdataset,并将数据类型设置为float64。
test_data = mxnet.gluon.data.DataLoader(CustomDataset(), batch_size=8, shuffle=True, num_workers=2)
我用DataLoader包装了数据集,到目前为止没有任何错误。当我将数据传递到网络时,错误会增加。
for epoch in range(1):
for data, label in test_data:
print(data.dtype)
print(label.dtype)
with autograd.record():
output = net(data)
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(batch_size)
net(data)中的错误上升,并且错误消息如下所示。
MXNetError: [07:53:55] src/operator/contrib/../elemwise_op_common.h:135: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node at 1-th input: expected float64, got float32
Stack trace:
[bt] (0) /root/anaconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x4b09db)
[0x7f00f96519db] ...
当我打印数据和标签的类型时,它们都是float64,但是MXNet告诉我数据的数据类型是float32。有人可以解释为什么会这样吗?在此先感谢。
本文收集自互联网,转载请注明来源。
如有侵权,请联系 [email protected] 删除。
我来说两句