import os
import pickle
import random

import numpy as np
from sklearn import datasets

# Local cache file holding the raw (X, y) arrays as two consecutive pickles.
PICKLE_DATA_FILENAME = "mnist.pickle"

# Module-level views of the most recently loaded splits; populated by
# load_mnistdata().
train_x = None
train_y = None
dev_x = None
dev_y = None
test_x = None
test_y = None


def _load_raw_mnist():
    """Return the raw ``(X, y)`` arrays, using the local pickle cache if present.

    When the cache file does not exist, the ``mnist_784`` dataset is fetched
    through ``sklearn.datasets.fetch_openml`` and then cached.
    """
    if os.path.exists(PICKLE_DATA_FILENAME):
        # NOTE: pickle.load can execute arbitrary code from a tampered file;
        # only use a cache file that this module created itself.
        with open(PICKLE_DATA_FILENAME, "rb") as file:
            X = pickle.load(file)
            y = pickle.load(file)
        return X, y

    X, y = datasets.fetch_openml(
        "mnist_784", return_X_y=True, cache=True, as_frame=False
    )

    # Write to a temporary file and rename it into place so that an
    # interrupted dump never leaves a truncated cache behind (a truncated
    # file would pass the existence check and then fail in pickle.load on
    # every subsequent run).
    tmp_filename = PICKLE_DATA_FILENAME + ".tmp"
    with open(tmp_filename, "wb") as file:
        pickle.dump(X, file)
        pickle.dump(y, file)
    os.replace(tmp_filename, PICKLE_DATA_FILENAME)
    return X, y


def load_mnistdata():
    """Load MNIST and split it into train / dev / test sets.

    The features are divided by 255 and the labels are converted to
    integers and one-hot encoded over 10 classes. Rows are split in their
    stored order (no shuffling): the first 59,500 go to train, the next
    3,500 to dev and the next 7,000 to test. These sizes assume the dataset
    has at least 70,000 rows; shorter inputs yield shorter or empty slices.

    The module-level globals ``train_x``, ``train_y``, ``dev_x``, ``dev_y``,
    ``test_x`` and ``test_y`` are updated as a side effect.

    Returns:
        ``((train_x, train_y), (dev_x, dev_y), (test_x, test_y))``.
    """
    global train_x, train_y, dev_x, dev_y, test_x, test_y

    X, y = _load_raw_mnist()

    # Simple normalisation (presumably maps 8-bit pixel values into [0, 1]).
    X = X / 255

    # Labels may arrive as strings; convert them in one vectorised step and
    # use them as row indices into the identity matrix to one-hot encode.
    y = np.asarray(y).astype(int)
    Y = np.eye(10)[y]

    # Split boundaries, expressed in units of 3,500 rows (17 / 1 / 2).
    unit = 3500
    train_end = unit * 17
    dev_end = unit * 18
    test_end = unit * 20

    train_x, train_y = X[:train_end], Y[:train_end]
    dev_x, dev_y = X[train_end:dev_end], Y[train_end:dev_end]
    test_x, test_y = X[dev_end:test_end], Y[dev_end:test_end]

    return ((train_x, train_y), (dev_x, dev_y), (test_x, test_y))