在Jupyter Notebook Python中加载数据集

曼尼斯·库玛(Manish Kumar)

我在jupyter notebook python中运行以下代码:

# Run some setup code for this notebook.

import random
import numpy as np
from cs231n.data_utils import load_CIFAR10
import matplotlib.pyplot as plt

# This is a bit of magic to make matplotlib figures appear inline in the notebook
# rather than in a new window.
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# Some more magic so that the notebook will reload external python modules;
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

然后以下说明:

# Load the raw CIFAR-10 data.
cifar10_dir = 'cs231n/datasets/cifar-10-batches-py'
X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

# As a sanity check, we print out the size of the training and test data.
print ('Training data shape: ', X_train.shape)
print ('Training labels shape: ', y_train.shape)
print ('Test data shape: ', X_test.shape)
print ('Test labels shape: ', y_test.shape)

通过运行第二部分,我可能会遇到以下错误:

---------------------------------------------------------------------------
UnicodeDecodeError                        Traceback (most recent call last)
<ipython-input-5-9506c06e646a> in <module>()
      1 # Load the raw CIFAR-10 data.
      2 cifar10_dir = 'cs231n/datasets/cifar-10-batches-py'
----> 3 X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
      4 
      5 # As a sanity check, we print out the size of the training and test data.

C:\Users\lenovo\assignment1\cs231n\data_utils.py in load_CIFAR10(ROOT)
     20   for b in range(1,6):
     21     f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
---> 22     X, Y = load_CIFAR_batch(f)
     23     xs.append(X)
     24     ys.append(Y)

C:\Users\lenovo\assignment1\cs231n\data_utils.py in load_CIFAR_batch(filename)
      7   """ load single batch of cifar """
      8   with open(filename, 'rb') as f:
----> 9     datadict = pickle.load(f)
     10     X = datadict['data']
     11     Y = datadict['labels']

UnicodeDecodeError: 'ascii' codec can't decode byte 0x8b in position 6: ordinal not in range(128)

如何解决此错误?我正在使用Annaconda3运行此代码。似乎以上代码已在Annaonda2版本中写入。是否有解决这些错误的建议?

只是为了更多细节:

我正在尝试通过链接解决分配:http : //cs231n.github.io/assignments2016/assignment1/

编辑:

添加包含load_CIFAR定义的data_utils.py

import _pickle as pickle
import numpy as np
import os
from scipy.misc import imread

def load_CIFAR_batch(filename):
  """ load single batch of cifar """
  with open(filename, 'rb') as f:
    datadict = pickle.load(f)
    X = datadict['data']
    Y = datadict['labels']
    X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
    Y = np.array(Y)
    return X, Y

def load_CIFAR10(ROOT):
  """ load all of cifar """
  xs = []
  ys = []
  for b in range(1,6):
    f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
    X, Y = load_CIFAR_batch(f)
    xs.append(X)
    ys.append(Y)    
  Xtr = np.concatenate(xs)
  Ytr = np.concatenate(ys)
  del X, Y
  Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
  return Xtr, Ytr, Xte, Yte
马克斯·阿彭坎普

您正在加载的pickle文件很可能是使用python 2生成的。

由于pickle在Python2和Python3中的工作方式存在根本差异,因此您可以尝试使用latin-1编码加载文件,hich假定将0-255直接映射为char。

此方法需要进行完整性检查,因为不能保证产生一致的数据。

本文收集自互联网,转载请注明来源。

如有侵权,请联系 [email protected] 删除。

编辑于
0

我来说两句

0 条评论
登录 后参与评论

相关文章