#!/usr/bin/env python
# coding: utf-8
import os
import numpy as np
import sys
import matplotlib as plt
import tensorflow as tf
import keras


def visualise_batch(X, y, n_samples=5, image_shape=(28, 28)):
    """Show a few randomly chosen images from a batch, titled with their label.

    Args:
        X: batch of images; row ``X[i, :]`` is reshaped to ``image_shape``
            for display (28x28 by default, i.e. flattened MNIST-like rows).
        y: labels indexed as ``y[i, :]`` (so 2-D, e.g. one-hot rows); the
            whole row is printed in the plot title.
        n_samples: how many images to display. Indices are drawn
            independently with ``np.random.randint``, so repeats are possible.
        image_shape: shape each image row is reshaped to before plotting.

    Side effects:
        Prints a header line and opens one blocking matplotlib window per
        sample via ``plt.show()``.
    """
    # The module-level ``import matplotlib as plt`` binds the top-level
    # package, which has no ``title``/``imshow``/``show``; the plotting
    # functions live in ``matplotlib.pyplot``.
    import matplotlib.pyplot as plt

    print('NEW BATCH')
    for _ in range(n_samples):
        i = np.random.randint(len(X))
        label = y[i, :]
        pixels = X[i, :].reshape(image_shape)
        plt.title('Label is {label}'.format(label=label))
        plt.imshow(pixels, cmap='gray')
        plt.show()


def mini_batch(data_set, batch_size):
    """Draw a random mini-batch (with replacement) from an (images, labels) pair.

    Args:
        data_set: sequence whose element 0 holds the images and element 1 the
            labels; both are indexed along their first axis.
        batch_size: number of samples to draw.

    Returns:
        ``(batch_images, batch_labels)`` selected with the same random row
        indices. The indices come from ``np.random.randint`` drawn
        independently, so a row may appear more than once.
    """
    images = data_set[0]
    labels = data_set[1]
    n_rows = labels.shape[0]
    rows = np.random.randint(0, n_rows, size=batch_size)
    return images[rows], labels[rows]


def make_permuted_dataset(training_data, test_data, n_tasks=10, num_classes=10):
    """Build a permuted-input benchmark: one fixed column permutation per task.

    Every task shuffles the feature columns (pixels) of the images with its
    own random permutation. The same permutation is applied to the training
    and the test images of that task. Labels are one-hot encoded once and
    shared unchanged by all tasks.

    Args:
        training_data: ``(images, labels)`` pair. Images are indexed as
            ``images[:, columns]``, so they must be 2-D (samples, features);
            presumably flattened 28x28 images (784 columns), confirm with
            caller.
        test_data: ``(images, labels)`` pair with the same layout.
        n_tasks: number of permutations (tasks) to generate.
        num_classes: width of the one-hot label encoding.

    Returns:
        ``(dataset_train, dataset_test)``: two lists of length ``n_tasks``.
        Each entry is an ``(permuted_images, one_hot_labels)`` pair, and
        entry ``k`` uses the same permutation in both lists.
    """
    train_images, train_labels = training_data[0], training_data[1]
    test_images, test_labels = test_data[0], test_data[1]

    # Previously both one-hot encodings (and the test images) were taken from
    # training_data, so the "test" split was a copy of the training split.
    train_one_hot = tf.keras.utils.to_categorical(train_labels, num_classes=num_classes)
    test_one_hot = tf.keras.utils.to_categorical(test_labels, num_classes=num_classes)

    # Permute over the actual feature count (784 for flattened MNIST).
    n_features = train_images.shape[1]
    perms = [np.random.permutation(n_features) for _ in range(n_tasks)]

    dataset_train = [(train_images[:, perm], train_one_hot) for perm in perms]
    dataset_test = [(test_images[:, perm], test_one_hot) for perm in perms]
    return dataset_train, dataset_test
    
    
def make_split_dataset(training_data, test_data):
    """Build a split-by-class benchmark from raw (images, labels) pairs.

    First partitions both splits per class with ``split_by_class``. It then
    merges the classes into task groups with ``mix_split_classes`` using that
    function's default grouping.

    Returns:
        ``(train_groups, test_groups)`` exactly as returned by
        ``mix_split_classes``.
    """
    per_class_train, per_class_test = split_by_class(training_data, test_data)
    return mix_split_classes(per_class_train, per_class_test)
    
    
    
def split_by_class(training_data, test_data):
    """Split train and test data into one ``(images, one_hot_labels)`` pair per class.

    The set of classes is taken from the *training* labels (``np.unique``,
    so sorted). Test samples whose label does not occur in the training set
    are therefore not included in any class.

    Args:
        training_data: ``(images, labels)`` pair; rows of ``images`` align
            with ``labels``.
        test_data: ``(images, labels)`` pair with the same layout.

    Returns:
        ``(train_per_class, test_per_class)``: two lists, one entry per class
        in sorted label order. Each entry is ``(images_of_class,
        one_hot_labels)`` where the one-hot width is the number of training
        classes.

    NOTE(review): ``to_categorical`` uses the label value as a column index
    while the width is ``len(labels)``; this presumably assumes integer
    labels 0..K-1 (true for MNIST). Confirm before using other label sets.
    """
    labels = np.unique(training_data[1])
    num_classes = len(labels)

    def _split(data):
        # One (images, one_hot) pair per class, for a single (images, labels) split.
        # Flatten so the row mask is 1-D, matching the former np.in1d
        # behaviour (np.in1d is deprecated as of NumPy 2.0).
        flat_labels = np.ravel(data[1])
        per_class = []
        for label in labels:
            mask = flat_labels == label
            one_hot = tf.keras.utils.to_categorical(data[1][mask], num_classes=num_classes)
            per_class.append((data[0][mask], one_hot))
        return per_class

    return _split(training_data), _split(test_data)


def mix_split_classes(split_training_data, split_test_data, group_array=None):
    """Merge per-class data into task groups (e.g. the five Split-MNIST pairs).

    Args:
        split_training_data: list indexed by class id; each entry is an
            ``(images, one_hot_labels)`` pair whose one-hot columns are also
            indexed by class id.
        split_test_data: same layout for the test split.
        group_array: iterable of class-id groups (list of lists or a 2-D
            ndarray). Defaults to ``[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]``.

    Returns:
        ``(train_groups, test_groups)``: lists with one ``(images, labels)``
        pair per group. ``images`` stacks the group's classes in group order.
        ``labels`` keeps only the one-hot columns of the group's classes, so
        it has ``len(group)`` columns. Both arrays are concatenated onto an
        empty float64 array, so their dtype is promoted to at least float64.
    """
    # ``is None`` rather than ``== None``: for an ndarray the comparison is
    # element-wise and ``if`` would raise "truth value ... is ambiguous".
    if group_array is None:
        group_array = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
    mix_classes = []
    for data in (split_training_data, split_test_data):
        groups = []
        for group in group_array:
            # The empty float64 seeds keep the original output dtype and
            # shape (also for an empty group). A single concatenate per array
            # avoids re-copying the accumulated rows once per class.
            image_seed = np.zeros([0, *data[0][0].shape[1:]])
            label_seed = np.zeros([0, len(group)])
            images = np.concatenate([image_seed] + [data[label][0] for label in group])
            labels = np.concatenate([label_seed] + [data[label][1][:, group] for label in group])
            groups.append((images, labels))
        mix_classes.append(groups)
    return mix_classes[0], mix_classes[1]

