41 lines
1.1 KiB
Python
41 lines
1.1 KiB
Python
|
import os
|
||
|
import pickle
|
||
|
import random
|
||
|
from sklearn import datasets
|
||
|
import numpy as np
|
||
|
|
||
|
# On-disk cache for the downloaded dataset. The path is relative, so it is
# resolved against the current working directory, not this file's location.
PICKLE_DATA_FILENAME = "mnist.pickle"

# Module-level dataset splits; load_mnistdata() assigns them via `global`.
# They stay None until that function has been called.
train_x = train_y = None
dev_x = dev_y = None
test_x = test_y = None
|
||
|
|
||
|
def load_mnistdata(*, filename=None, train_size=3500 * 17, dev_size=3500,
                   test_size=3500 * 2):
    """Load MNIST, normalize it, one-hot encode the labels and split it.

    On the first call the ``mnist_784`` dataset is fetched with
    ``datasets.fetch_openml`` and cached as two consecutive pickles
    (features, then labels) in the cache file; later calls read that file.

    Features are divided by 255 and labels are converted to one-hot rows of
    ``np.eye(10)``. Samples are split in stored order (no shuffling) into
    consecutive train / dev / test slices, which are also assigned to the
    module-level globals ``train_x``, ``train_y``, ``dev_x``, ``dev_y``,
    ``test_x`` and ``test_y``.

    Args:
        filename: path of the pickle cache. Defaults to
            ``PICKLE_DATA_FILENAME`` (relative to the current directory).
        train_size: number of leading samples in the train slice.
        dev_size: number of samples in the dev slice, after the train slice.
        test_size: number of samples in the test slice, after the dev slice.
            The defaults reproduce the original slices
            ``[0:59500]``, ``[59500:63000]`` and ``[63000:70000]``.

    Returns:
        ``((train_x, train_y), (dev_x, dev_y), (test_x, test_y))``.

    Raises:
        ValueError: if any of the slice sizes is negative.
    """
    global train_x, train_y, dev_x, dev_y, test_x, test_y

    if min(train_size, dev_size, test_size) < 0:
        raise ValueError("train_size, dev_size and test_size must be >= 0")

    path = PICKLE_DATA_FILENAME if filename is None else filename

    if os.path.exists(path):
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at a cache file this function wrote itself.
        with open(path, "rb") as cache:
            X = pickle.load(cache)
            y = pickle.load(cache)
    else:
        X, y = datasets.fetch_openml(
            "mnist_784", return_X_y=True, cache=True, as_frame=False
        )
        # Write to a temporary file and atomically rename it into place, so
        # an interrupted write never leaves a truncated cache that would make
        # every later call fail inside pickle.load.
        tmp_path = os.fspath(path) + ".tmp"
        with open(tmp_path, "wb") as cache:
            pickle.dump(X, cache)
            pickle.dump(y, cache)
        os.replace(tmp_path, path)

    # Simple normalization; assumes 8-bit pixel intensities (0..255).
    X = X / 255

    # Labels may be stored as strings (e.g. "5"); convert to ints, then
    # one-hot encode by indexing rows of the 10x10 identity matrix.
    y = np.asarray(y).astype(int)
    Y = np.eye(10)[y]

    # NOTE(review): if the stored order follows the official 60k/10k
    # train/test split, the default dev slice straddles it — confirm.
    dev_end = train_size + dev_size
    test_end = dev_end + test_size
    train_x, train_y = X[:train_size], Y[:train_size]
    dev_x, dev_y = X[train_size:dev_end], Y[train_size:dev_end]
    test_x, test_y = X[dev_end:test_end], Y[dev_end:test_end]

    return ((train_x, train_y), (dev_x, dev_y), (test_x, test_y))
|