from sklearn import datasets
import numpy as np
from layer import *
import os
import pickle
import matplotlib
import matplotlib.pyplot as plt
import random
import itertools
import math
import mnist_load
from p4_model import *
#matplotlib.use("TkAgg")
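
# Load the MNIST train/dev/test splits; each split unpacks into an (inputs, labels) pair below.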
train_set, dev_set, test_set = mnist_load.load_mnistdata()
train_x,train_y = train_set
dev_x,dev_y = dev_set
test_x,test_y = test_set
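
# NumPy random generator used to draw minibatch indices.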
gen: np.random.Generator = np.random.default_rng()
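
# Hyperparameters: learning rate (eta) and minibatch size.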
eta = 0.00001
MiniBatchN = 32
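
# Build or restore the network; load_or_create_model comes from the starred imports
# (presumably p4_model), and [300, 10] is taken to be the layer widths.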
model = load_or_create_model([300,10])
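
# Minibatch iterations per epoch: a budget of 3500*17 samples divided by the batch size.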
end_n = math.floor(3500*17 /MiniBatchN)
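
# Training loop: draw a random minibatch each iteration, take one update step,
# and record a dev-set checkpoint every 200 iterations.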
for epoch in range(1):
    # one epoch
    for iteration in range(0, end_n):
        choiced_index = gen.choice(range(0, len(train_x)), MiniBatchN)
        batch_x = train_x[choiced_index]
        batch_y = train_y[choiced_index]
        #batch_x = train_x[MiniBatchN*iteration:MiniBatchN*(iteration+1)]
        #batch_y = train_y[MiniBatchN*iteration:MiniBatchN*(iteration+1)]
        model.train_one_iterate(batch_x, batch_y, eta)
        if (model.iteration - 1) % 200 == 0:
            model.set_checkpoint(dev_x, dev_y)
        if model.iteration % 10 == 0:
            print(f"iteration {model.iteration}")
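
# Evaluate on the dev set; model.caculate appears to return per-example losses
# (with the predictions used by get_confusion below) exposing a .numpy() view.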
J = model.caculate(dev_x,dev_y)
loss = np.average(J.numpy())
print('dev set : avg loss : ', loss)
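
# Confusion matrix and accuracy on the dev set; get_confusion and
# get_accuracy_from_confusion come from the starred imports.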
confusion = get_confusion(J)
accuracy = get_accuracy_from_confusion(confusion)
print('accuracy : {:.2f}%'.format(accuracy * 100))
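
# Persist the trained model (save_model presumably comes from p4_model);
# the `if True:` guard acts as a manual save toggle.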
if True:
    save_model(model)
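
# Plot accuracy and loss over the recorded checkpoints (each checkpoint exposes
# .iteration, .accuracy and .loss).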
plt.subplot(1,2,1)
plt.title("accuracy")
plt.plot([*map(lambda x: x.iteration, model.checkpoints)],
         [*map(lambda x: x.accuracy, model.checkpoints)])
plt.subplot(1,2,2)
plt.title("loss")
plt.plot([*map(lambda x: x.iteration, model.checkpoints)],
         [*map(lambda x: x.loss, model.checkpoints)])
plt.show()
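
# Draw the dev-set confusion matrix as a heatmap with per-cell counts.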
plt.figure()  # open a fresh figure so the heatmap does not land on the loss subplot
plt.title("confusion matrix")
plt.imshow(confusion,cmap='Blues')
plt.colorbar()
for i, j in itertools.product(range(confusion.shape[0]), range(confusion.shape[1])):
    plt.text(j, i, "{:}".format(confusion[i, j]), horizontalalignment="center",
             color="white" if i == j else "black")
plt.show()