"""Create plots from the results of the parametric study.
"""

import os
import argparse
import numpy as np
import pandas as pd
import scipy.interpolate
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.mplot3d import Axes3D
import tikzplotlib as tkl

import utils
from plot_utils import *

p_path = 'plots/'
s_path = './'
# NOTE(review): s_path is not referenced in this file's visible code;
# presumably kept for other scripts or future use -- confirm before removing.

# Output directory ``plots/`` under the current working directory.
# It is created at import time. makedirs(exist_ok=True) is a no-op when the
# directory already exists, so no separate (racy) existence check is needed.
save_dir = os.path.join(os.getcwd(), p_path)
os.makedirs(save_dir, exist_ok=True)

# Five colors sampled evenly from the 'jet' colormap, plus five matching markers.
colors = [plt.cm.jet(x) for x in np.linspace(0.0, 1.0, 5)]
markers = ['^', '*', 'o', '+', 'd']

columns_ni = ['t', 'u_max', 'p_max', 'u_error_e', 'p_error_e', 'run_time']

# Global matplotlib styling applied to every figure produced by this script.
plt.rcParams.update({
    'font.size': 14,
    'xtick.color': 'k',
    'xtick.labelsize': 12,
    'ytick.color': 'k',
    'ytick.labelsize': 12,
    'axes.labelcolor': 'darkred',
    'axes.labelsize': 14,
    'axes.titlecolor': 'darkred',
    'axes.titlesize': 16,
})
# To restore matplotlib's defaults: plt.rcParams.update(plt.rcParamsDefault)


def _str_to_bool(value):
    """Convert a command-line string such as 'True', 'false', 'yes' or '0' to a bool.

    argparse with ``type=bool`` treats every non-empty string -- including
    'False' -- as True, so boolean options need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_arg_parser():
    """Return the command-line parser used by plot_simulation_summary()."""
    parser = argparse.ArgumentParser(description=__doc__)
    # pd.read_csv(None) cannot succeed, so the summary file is mandatory.
    parser.add_argument('--summary_file', type=str, required=True,
                        help='Summary of the parametric study (csv file)')
    parser.add_argument('--problem_file', type=str,
                        default=None, help='Biot poroelasticity parameters (YAML file)')
    parser.add_argument('--discretization_params', type=str,
                        default=None, help='Discretization params (YAML file)')
    # The boolean options accept an explicit value ("--log_log True/False").
    # A bare flag ("--log_log") means True.
    parser.add_argument('--log_log', type=_str_to_bool, nargs='?', const=True,
                        default=False, help='Toggle to True to plot in log-log scale')
    parser.add_argument('--save_flg', type=_str_to_bool, nargs='?', const=True,
                        default=False, help='Toggle to True to save the plots')
    return parser


def plot_simulation_summary():
    """Parse command-line options and generate the error plots of the parametric study.

    Reads the CSV summary given by ``--summary_file`` and the problem and
    discretization configurations (loaded via ``utils.get_config``). All plots
    are drawn at the last time in ``T_LIST``:

    * for every method, error vs h and error vs tau for the fields u and p,
      in the 'L2' and 'H1' error types;
    * the ``*_all`` plots for u and p in 'L2' and 'H1': error vs h at the
      fixed time step ``tau``, and error vs tau at the fixed mesh size ``h``;
    * the same ``*_all`` plots with err_type 'A' for u and 'C' for p.

    ``--save_flg`` and ``--log_log`` are forwarded to every plotting call.
    """
    args = _build_arg_parser().parse_args()

    df = pd.read_csv(args.summary_file)
    problem_params = utils.get_config(args.problem_file)
    discretization_params = utils.get_config(args.discretization_params)

    method_list = ['implicit', 'semiexplicit', 'fixedstress', 'undrained']
    nt_list = [2 ** i for i in np.arange(discretization_params['NT_MIN'], discretization_params['NT_MAX'] + 1)]
    ns_list = [2 ** i for i in np.arange(discretization_params['NS_MIN'], discretization_params['NS_MAX'] + 1)]
    t = discretization_params['T_LIST'][-1]

    # Smallest time step of the study (largest number of time steps).
    tau = problem_params['T'] / nt_list[-1]
    # Largest extent of the box spanned by POINT1 and POINT2, over all coordinates.
    domain_size = max(x2 - x1 for x1, x2 in zip(problem_params['POINT1'], problem_params['POINT2']))
    # Mesh size of the coarsest spatial discretization (fewest subdivisions).
    # NOTE(review): h uses the coarsest mesh (ns_list[0]) while tau uses the
    # finest step (nt_list[-1]); confirm this pairing is intended.
    h = domain_size / ns_list[0]

    plot_kwargs = {'save_flg': args.save_flg, 'log_log': args.log_log}
    # The *_all plots take no method argument. Judging by their names, each one
    # varies one discretization parameter while the other is held fixed.
    sweeps = ((plot_error_h_all, tau), (plot_error_tau_all, h))

    for err_type in ('L2', 'H1'):
        # Call order per method: h(u), h(p), tau(u), tau(p).
        for method in method_list:
            for plot_fn in (plot_error_h, plot_error_tau):
                for field in ('u', 'p'):
                    plot_fn(method, df, t, field=field, err_type=err_type, **plot_kwargs)

        for plot_all_fn, fixed_step in sweeps:
            for field in ('u', 'p'):
                plot_all_fn(fixed_step, t, df, field=field, err_type=err_type, **plot_kwargs)

    # Field-specific error types: 'A' for u and 'C' for p.
    # Their meaning is defined by plot_utils.
    for plot_all_fn, fixed_step in sweeps:
        for field, norm in (('u', 'A'), ('p', 'C')):
            plot_all_fn(fixed_step, t, df, field=field, err_type=norm, **plot_kwargs)


if __name__ == '__main__':
    # Run only when executed as a script, not when imported.
    plot_simulation_summary()
