#!/usr/bin/env python
# coding: utf-8
import os, sys, time
# Wall-clock timestamps (seconds) used by the per-epoch timing printout below.
start, end = time.time(), time.time()
import tensorflow as tf
import keras
import numpy as np
import pickle
import utils

## HYPERPARAMETERS
# Positional command-line arguments (argv[0] is the script name), e.g.
#   python <script> 0 MNIST permuted 0 sgd 0.1 100 10 200 200
# i.e. GPU DATASET TASK SEED OPTIMIZER LR BATCH_SIZE N_EPOCHS_PER_TASK FIRST_HIDDEN SECOND_HIDDEN
inputs = sys.argv
if len(inputs) < 11:
    # Fail with a readable message instead of an IndexError further down.
    sys.exit('usage: <script> GPU DATASET TASK SEED OPTIMIZER LR BATCH_SIZE '
             'N_EPOCHS_PER_TASK FIRST_HIDDEN SECOND_HIDDEN')
visible_GPU = inputs[1]          # value exported as CUDA_VISIBLE_DEVICES
save_outputs_to_log_dir = False  # True: redirect stdout to logs/log_<HP_label>.txt

HP = {
    'dataset'           : inputs[2],        # 'MNIST' or 'FASHION'
    'task'              : inputs[3],        # 'permuted' or 'split'
    'seed'              : int(inputs[4]),
    'optimizer'         : inputs[5],        # 'adam' or 'sgd'
    'lr'                : float(inputs[6]),
    'batch_size'        : int(inputs[7]),
    'n_epochs_per_task' : int(inputs[8]),
    'first_hidden'      : int(inputs[9]),   # width K of the first hidden layer
    'second_hidden'     : int(inputs[10]),  # width L of the second hidden layer
}
# Run identifier 'FT__key=value__key=value...' (in HP insertion order), used in
# the log file name and the summary line.
HP_label = 'FT' + ''.join('__{}={}'.format(key, value) for key, value in HP.items())
    

# Load the raw train/test splits selected by HP['dataset'].
if HP['dataset'] == 'MNIST':
    from keras.datasets import mnist
    (X_train, Y_train), (X_test, Y_test) = mnist.load_data()
elif HP['dataset'] == 'FASHION':
    from keras.datasets import fashion_mnist
    (X_train, Y_train), (X_test, Y_test) = fashion_mnist.load_data()
else:
    raise ValueError("unknown dataset %r; expected 'MNIST' or 'FASHION'" % HP['dataset'])

# Flatten each image to a 784-vector and scale pixel values into [0, 1]
# (relies on Python 3 true division).
X_train = X_train.reshape([-1, 784])
X_test = X_test.reshape([-1, 784])
X_train = X_train / 255
X_test = X_test / 255

# Seed NumPy before deriving the task sequence so the tasks are reproducible.
# NOTE(review): assumes utils draws its randomness (e.g. pixel permutations)
# from np.random -- confirm. NumPy is seeded again before training, so the
# per-epoch shuffles are unaffected by this call.
np.random.seed(HP['seed'])

# Turn the data into a sequence of tasks; out_dim is the width of the output layer.
if HP['task'] == 'permuted':
    dataset_train, dataset_test = utils.make_permuted_dataset((X_train, Y_train), (X_test, Y_test))
    out_dim = 10
elif HP['task'] == 'split':
    dataset_train, dataset_test = utils.make_split_dataset((X_train, Y_train), (X_test, Y_test))
    out_dim = 2
else:
    raise ValueError("unknown task %r; expected 'permuted' or 'split'" % HP['task'])
    
# VISIBLE GPU
# Restrict TensorFlow to the GPU chosen on the command line, numbered in PCI bus
# order. NOTE(review): this must take effect before TF initialises its devices
# (first tf.Session); setting it before `import tensorflow` would be safer.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = visible_GPU
if save_outputs_to_log_dir:
    # Redirect all subsequent print() output into a per-run log file; the
    # original stdout is kept so it can be restored at the end of the run.
    os.makedirs('logs', exist_ok=True)  # open() below fails if the directory is missing
    orig_stdout = sys.stdout
    f = open(os.path.join('logs', 'log_' + HP_label + '.txt'), 'w')
    sys.stdout = f





##############
# GET SERIOUS: build the graph from a clean, seeded state
##############
np.random.seed(HP['seed'])             # NumPy RNG (drives the per-epoch shuffling)
tf.reset_default_graph()               # discard any previously built graph
tf.random.set_random_seed(HP['seed'])  # graph-level seed; set after the reset so it applies to the new default graph


#############
#PLACEHOLDERS
##############
# Inputs are flattened 784-pixel images; labels have out_dim columns
# (presumably one-hot class vectors -- TODO confirm against utils).
X_ph = tf.placeholder(tf.float32,[None, 784])
Y_ph = tf.placeholder(tf.float32, [None,out_dim])


###########
#VARIABLES
###########
# Weights use Glorot/Xavier-uniform init, U(-r, r) with r = sqrt(6 / (fan_in + fan_out));
# biases start at the constant 0.1. The 6/(...) terms rely on Python 3 true division.
# NOTE: with only a graph-level seed set, TF derives each random op's seed from
# the graph seed and the op's position in the graph (see tf.random.set_random_seed),
# so adding, removing or reordering ops here changes the initial weights.
K = HP['first_hidden']
L = HP['second_hidden']
W0 = tf.Variable(tf.random.uniform([784,K], minval=-tf.sqrt(6/(784+K)),maxval=tf.sqrt(6/(784+K))))
b0 = tf.Variable(tf.ones([K])/10)
W1 = tf.Variable(tf.random.uniform([K,L], minval=-tf.sqrt(6/(L+K)),maxval=tf.sqrt(6/(L+K))))
b1 = tf.Variable(tf.ones([L])/10)
W2 = tf.Variable(tf.random.uniform([L,out_dim], minval=-tf.sqrt(6/(L+out_dim)),maxval=tf.sqrt(6/(out_dim+L))))
b2 = tf.Variable(tf.ones(out_dim)/10)


#############
### MODEL WITH 2 HIDDEN LAYERS, 784-K-L-out_dim
### (out_dim is 10 for the 'permuted' task, 2 for 'split')
############
H1 = tf.nn.relu(tf.matmul(X_ph,W0) + b0)
H2 = tf.nn.relu(tf.matmul(H1, W1) + b1)
Y_logits = tf.matmul(H2,W2) + b2  # unnormalised scores; the loss consumes these
Y_pred = tf.nn.softmax(Y_logits)  # class probabilities; used only for accuracy


########
#LOSS
########
# Mean softmax cross-entropy over the batch, computed from the logits.
CEL = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y_ph, logits=Y_logits))
# Accuracy: fraction of rows whose predicted argmax matches the label argmax.
correct_a = tf.equal(tf.argmax(Y_ph, axis=1), tf.argmax(Y_pred, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_a, tf.float32))


#######################
#TRAINING AND GRADIENTS
#######################
# Plain minimisation of the current task's loss (no extra regularisation terms).
if HP['optimizer'] == 'adam':
    trainer = tf.train.AdamOptimizer(HP['lr'])
elif HP['optimizer'] == 'sgd':
    trainer = tf.train.GradientDescentOptimizer(HP['lr'])
else:
    # Previously an unknown name left `trainer` unbound and failed with NameError.
    raise ValueError("unknown optimizer %r; expected 'adam' or 'sgd'" % HP['optimizer'])
train = trainer.minimize(CEL)


#############
#THE SESS
#############
# Train on each task in sequence. After each task, report test accuracy and loss
# on every task, and the mean accuracy over the tasks trained so far.
init = tf.global_variables_initializer()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # grow GPU memory on demand instead of reserving it all
print("Hyperparameters:", HP_label)

bs = HP['batch_size']
# Bound before training so code after the session always finds a value, even
# if there are no tasks (np.mean of this empty array is nan).
test_acc_old_tasks = np.zeros(0)

with tf.Session(config=config) as sess:
    sess.run(init)
    for task_id, data in enumerate(dataset_train):
        print("\nSTARTING TASK ", task_id)
        # presumably data is (inputs, one-hot labels) for this task -- TODO confirm against utils
        X_task, Y_task = data[0], data[1]
        n_samples = X_task.shape[0]
        # Only full mini-batches are used: the last n_samples % bs shuffled
        # samples are skipped each epoch.
        n_iterations = n_samples // bs
        print('n_iterations:  ', n_iterations)

        for epoch in range(HP['n_epochs_per_task']):
            print("Epoch ", epoch, '    Task ', task_id)
            start = time.time()

            # Reshuffle the task's training set once per epoch.
            perm = np.random.permutation(n_samples)
            X_shuf = X_task[perm, :]
            Y_shuf = Y_task[perm, :]
            for j in range(n_iterations):
                batch = slice(j * bs, (j + 1) * bs)
                sess.run(train, feed_dict={X_ph: X_shuf[batch, :], Y_ph: Y_shuf[batch, :]})

            # Measured around the epoch itself, so every epoch (including the
            # last) is reported and setup time is excluded.
            end = time.time()
            print('Time for epoch: ', end - start, '\n')

        # Evaluate on every task's test set; only tasks trained so far
        # (0..task_id) go into the running average.
        print("\n")
        test_acc_old_tasks = np.zeros(task_id + 1)
        for old_task in range(len(dataset_train)):
            X_eval = dataset_test[old_task][0]
            Y_eval = dataset_test[old_task][1]
            test_acc, test_loss = sess.run([accuracy, CEL], feed_dict={X_ph: X_eval, Y_ph: Y_eval})
            if old_task <= task_id:
                test_acc_old_tasks[old_task] = test_acc
            print("Task", old_task)
            print("Test acc: ", test_acc)
            print("Test loss: ", test_loss)

        print("average so far: ", np.mean(test_acc_old_tasks))
        print(" ")
 
if save_outputs_to_log_dir:
    # Undo the stdout redirection and close the per-run log file.
    sys.stdout = orig_stdout
    f.close()

# Append one line per run: the final average test accuracy over the trained
# tasks, followed by the run's hyperparameter label. The context manager closes
# the file even if the write fails.
with open('summary.txt', 'a') as summary_file:
    summary_file.write(str(np.mean(test_acc_old_tasks)) + ' ' + HP_label + '\n')
