"""Plotting utilities for the fields
"""

import numpy as np
import pandas as pd
import scipy.interpolate
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.mplot3d import Axes3D
import tikzplotlib as tkl
plt.ioff()

# Directory that every figure produced by this module is written to.
p_path = 'plots/'
# NOTE(review): s_path is not referenced by the plotting functions in this
# module; presumably used by callers -- confirm before removing.
s_path = '../linep/'

# Time horizon; the time-step sizes shown in the legends are T / nt.
T = 1.0
# ns_list = [8, 16, 32], nt_list = [16, 32, 64] (numpy integer scalars).
ns_list = [2 ** k for k in np.arange(3, 6)]
nt_list = [2 ** k for k in np.arange(4, 7)]
# Values of L plotted on the x-axis of the iteration plots: 0.5, 0.75, ..., 1.5.
l_stab_list = list(np.arange(0.5, 1.75, 0.25))
# Methods compared by the *_all error plots; also used as legend entries.
method_list = ['implicit', 'semiexplicit', 'fixedstress', 'undrained']
# omega = 0.5, 1.0, ..., 24.5
omega_list = list(np.arange(0.5, 25, 0.5))
t_list = [0.09375, 0.25000, 0.5000, 1.0000]

# One colour and one marker per curve; at most 8 curves per figure.
colors = [plt.cm.jet(v) for v in np.linspace(0.0, 1.0, 8)]
markers = ['^', '*', 'o', '+', 'd', 's', 'h', 'v']

# Columns kept from the per-run summary CSV files.
columns_ni = ['t', 'u_max', 'p_max', 'u_error_e', 'p_error_e', 'run_time']


def _interp_to_grid(coords, field, xx, yy):
    """Linearly interpolate the dof values of ``field`` onto the grid (xx, yy).

    ``coords`` holds one row per dof (x in column 0, y in column 1) and is
    assumed to be aligned with ``field.vector().get_local()`` -- TODO confirm
    for parallel runs. Grid points outside the convex hull of the dofs
    become NaN (scipy ``griddata`` default).
    """
    return scipy.interpolate.griddata((coords[:, 0], coords[:, 1]),
                                      field.vector().get_local(),
                                      (xx, yy), method='linear')


def _save_heatmap(values, title, name):
    """Save ``values`` as an image with a colorbar to ``name``.png/.tex and close the figure."""
    fig = plt.figure(figsize=(10, 10))
    image = plt.imshow(values)
    # the title must be set before append_axes makes the colorbar axes current
    plt.title(title)
    cax = make_axes_locatable(plt.gca()).append_axes("right", size="5%", pad=0.05)
    plt.colorbar(image, cax=cax)
    plt.savefig(name + '.png', bbox_inches='tight')
    tkl.save(name + '.tex')
    plt.close(fig)


def _save_surface(xx, yy, values, name):
    """Save a 3D surface plot of ``values`` over (xx, yy) to ``name``.png/.tex and close the figure."""
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(xx, yy, values)
    plt.savefig(name + '.png', bbox_inches='tight')
    tkl.save(name + '.tex')
    plt.close(fig)


def _save_quiver(xx, yy, ux_vec, uy_vec, name):
    """Save a quiver plot of (ux_vec, uy_vec) on every 6th grid point to ``name``.png/.svg."""
    skip = (slice(None, None, 6), slice(None, None, 6))
    fig, ax = plt.subplots()
    ax.set_title(r'                     $\|u\|$', loc='left')
    quiver = ax.quiver(xx[skip], yy[skip], ux_vec[skip], uy_vec[skip], color='b')
    # reference arrow of length U = 5e-10; the label exponent must match U
    ax.quiverkey(quiver, X=0.6, Y=1.01, U=5e-10,
                 label=r'$5.0\times 10^{-10} \;\; \mathrm{m s}^{-1}$', labelpos='E', coordinates='axes')
    ax.set_aspect('equal')
    fig.tight_layout()
    fig.savefig(name + '.png', bbox_inches='tight')
    fig.savefig(name + '.svg', bbox_inches='tight')
    # no .tex output: quiver plots are not supported by tikzplotlib
    plt.close(fig)


def plot_fields(t, u, p, div_u, f_name, save_flag=False, n_points=50):
    """
    Plots the field values at a given time t and writes them under ``p_path``.

    For p, div(u), u_x and u_y a heat map and a 3D surface plot are saved as
    .png and .tex (tikzplotlib); the displacement is additionally saved as a
    quiver plot (.png and .svg). Every figure is closed after saving.

    :param t: time point at which field values are to be plotted
    :param u: displacement field (split into its components u_x, u_y)
    :param p: pressure field
    :param div_u: divergence of the displacement; interpolated using the dof
        coordinates of ``p``'s function space
    :param f_name: simulation parameter key, appended to every file name
    :param save_flag: currently unused -- the plots are always saved
    :param n_points: number of grid points per direction of the uniform
        interpolation grid on [0, 1] x [0, 1]
    :return: None
    """
    # uniform mesh on the unit square to plot the fields on
    grid = np.linspace(0, 1, n_points)
    xx, yy = np.meshgrid(grid, grid)
    tag = '%.4f_' % t

    # pressure and div(u), both interpolated from the dofs of p's space
    q_coords = p.function_space().tabulate_dof_coordinates()
    p_vec = _interp_to_grid(q_coords, p, xx, yy)
    _save_heatmap(p_vec, '$p$ at $t=%.4f$' % t, p_path + 'p_' + tag + f_name)
    _save_surface(xx, yy, p_vec, p_path + 'p3_' + tag + f_name)

    div_u_vec = _interp_to_grid(q_coords, div_u, xx, yy)
    _save_heatmap(div_u_vec, '$\\nabla\\cdot u$ at $t=%.4f$' % t, p_path + 'div_u_' + tag + f_name)
    _save_surface(xx, yy, div_u_vec, p_path + 'div_u3_' + tag + f_name)

    # displacement components, interpolated from the dofs of u's first subspace
    V0 = u.function_space().sub(0).collapse()
    u_x, u_y = u.split(deepcopy=True)
    v_coords = V0.tabulate_dof_coordinates()
    ux_vec = _interp_to_grid(v_coords, u_x, xx, yy)
    uy_vec = _interp_to_grid(v_coords, u_y, xx, yy)

    # the surface plots show the displacement components themselves
    _save_heatmap(ux_vec, '$u_x$ at $t=%.4f$' % t, p_path + 'u_x_' + tag + f_name)
    _save_surface(xx, yy, ux_vec, p_path + 'ux3_' + tag + f_name)
    _save_heatmap(uy_vec, '$u_y$ at $t=%.4f$' % t, p_path + 'u_y_' + tag + f_name)
    _save_surface(xx, yy, uy_vec, p_path + 'uy3_' + tag + f_name)

    _save_quiver(xx, yy, ux_vec, uy_vec, p_path + 'u_' + tag + f_name)


def plot_avg_iters(method, h, df_avg_iters, save_flg=False):
    """
    Plot the average number of iterations against the stabilization
    parameter L for a fixed mesh size ``h``, one curve per time step tau.

    :param method: method name; its first two characters prefix the file name
    :param h: mesh size whose rows are selected from ``df_avg_iters``
    :param df_avg_iters: DataFrame with columns 'h', 'tau' and 'n_iter'.
        NOTE(review): within each (h, tau) group the rows are assumed to be
        ordered like ``l_stab_list`` -- confirm against the producer.
    :param save_flg: if True, save the figure as .png and .tex under ``p_path``
    """
    # group by the bare column name: pandas >= 2.2 expects a tuple key in
    # get_group when grouping by a one-element list
    h_df = df_avg_iters.groupby('h').get_group(h)
    plt.figure(figsize=(10, 10))
    legend = []
    for cnt, (tau, tau_df) in enumerate(h_df.groupby('tau')):
        plt.plot(l_stab_list, tau_df['n_iter'],
                 color=colors[cnt], marker=markers[cnt], linewidth=2)
        # label each curve with its own tau: groupby iterates tau ascending,
        # whereas T / nt over nt_list is descending
        legend.append('$\\tau=%.2e$' % tau)
    plt.grid(True)
    plt.xlabel('$L$')
    plt.ylabel('Average number of iterations')
    plt.legend(legend, bbox_to_anchor=(1.04, 0.5), loc='center left')
    plt.title('Average number of iterations for $h=%.2e$' % h)
    if save_flg:
        f_name = p_path + method[:2] + '_avg_iters' + '_h=%.2e' % h
        plt.savefig(f_name + '.png', bbox_inches='tight')
        tkl.save(f_name + '.tex')


def plot_error_h(method, df, t, field='u', err_type='L2', save_flg=False, log_log=False):
    """
    Plot the relative error of ``method`` against the time step tau, with one
    curve per mesh size h and a dashed slope-one reference line.

    :param method: method name; selects the column
        ``<method>_err_<field>_<err_type>``, and its first two characters
        prefix the file name
    :param df: DataFrame with columns 'h', 'tau' and the error column
    :param t: time point, used in the title and the file name
    :param field: 'u' for the displacement; anything else is labelled as p
    :param err_type: norm name used in the column, label and file name
    :param save_flg: if True, save the figure as .png and .tex under ``p_path``
    :param log_log: if True, use logarithmic scales on both axes
    """
    error = method + '_err_' + field + '_' + err_type
    sym = 'u' if field == 'u' else 'p'
    y_label = '$||%s(T) - %s^{N}_{h}||_{%s} / ||%s(T)||_{%s}$' % (sym, sym, err_type, sym, err_type)
    plot_fn = plt.loglog if log_log else plt.plot

    plt.figure(figsize=(10, 10))
    legend = []
    for cnt, (h, group) in enumerate(df.groupby('h')):
        # sort so the line connects the points in increasing tau
        group = group.sort_values('tau')
        plot_fn(group['tau'], group[error],
                color=colors[cnt], marker=markers[cnt], linewidth=2)
        legend.append('$h=%.2e$' % h)

    # the x-axis is tau, so the slope-one reference runs over the tau values
    tau_ref = np.sort(df['tau'].unique())
    plot_fn(tau_ref, tau_ref, linestyle='dashed', color='black')
    legend.append('linear')

    plt.grid(True)
    plt.xlabel('$\\tau$')
    plt.ylabel(y_label)
    plt.legend(legend, bbox_to_anchor=(1.04, 0.5), loc='center left')
    plt.title('Error for $%s$ at $t=%.4f$' % (field, t))
    if save_flg:
        f_name = p_path + method[:2] + '_tau_err_' + field + '_%s_t=%.4f' % (err_type, t)
        plt.savefig(f_name + '.png', bbox_inches='tight')
        tkl.save(f_name + '.tex')


def plot_error_tau(method, df, t, field='u', err_type='L2', save_flg=False, log_log=False):
    """
    Plot the relative error of ``method`` against the mesh size h, with one
    curve per time step tau and a dashed slope-one reference line.

    :param method: method name; selects the column
        ``<method>_err_<field>_<err_type>``, and its first two characters
        prefix the file name
    :param df: DataFrame with columns 'h', 'tau' and the error column
    :param t: time point, used in the title and the file name
    :param field: 'u' for the displacement; anything else is labelled as p
    :param err_type: norm name used in the column, label and file name
    :param save_flg: if True, save the figure as .png and .tex under ``p_path``
    :param log_log: if True, use logarithmic scales on both axes
    """
    error = method + '_err_' + field + '_' + err_type
    sym = 'u' if field == 'u' else 'p'
    y_label = '$||%s(T) - %s^{N}_{h}||_{%s} / ||%s(T)||_{%s}$' % (sym, sym, err_type, sym, err_type)
    plot_fn = plt.loglog if log_log else plt.plot

    plt.figure(figsize=(10, 10))
    legend = []
    for cnt, (tau, group) in enumerate(df.groupby('tau')):
        # sort so the line connects the points in increasing h
        group = group.sort_values('h')
        plot_fn(group['h'], group[error],
                color=colors[cnt], marker=markers[cnt], linewidth=2)
        legend.append('$\\tau=%.2e$' % tau)

    # the x-axis is h, so the slope-one reference runs over the h values
    h_ref = np.sort(df['h'].unique())
    plot_fn(h_ref, h_ref, linestyle='dashed', color='black')
    legend.append('linear')

    plt.grid(True)
    plt.xlabel('$h$')
    plt.ylabel(y_label)
    plt.legend(legend, bbox_to_anchor=(1.04, 0.5), loc='center left')
    plt.title('Error for $%s$ at $t=%.4f$' % (field, t))
    if save_flg:
        f_name = p_path + method[:2] + '_h_err_' + field + '_%s_t=%.4f' % (err_type, t)
        plt.savefig(f_name + '.png', bbox_inches='tight')
        tkl.save(f_name + '.tex')


# ---------------------------------------------------------------------
#   Of all methods
# ---------------------------------------------------------------------
def plot_error_h_all(tau, t, df, field='u', err_type='L2', save_flg=False, log_log=False):
    """
    Compare the relative errors of every method in ``method_list`` against
    the mesh size h for a fixed time step ``tau``, with a dashed slope-one
    reference line.

    :param tau: time step whose rows are selected from ``df``
    :param t: time point, used in the title and the file name
    :param df: DataFrame with columns 'h', 'tau' and one
        ``<method>_err_<field>_<err_type>`` column per method
    :param field: 'u' for the displacement; anything else is labelled as p
    :param err_type: norm name used in the columns, label and file name
    :param save_flg: if True, save the figure as .png and .tex under ``p_path``
    :param log_log: if True, use logarithmic scales on both axes
    """
    sym = 'u' if field == 'u' else 'p'
    y_label = '$||%s(T) - %s^{N}_{h}||_{%s} / ||%s(T)||_{%s}$' % (sym, sym, err_type, sym, err_type)
    plot_fn = plt.loglog if log_log else plt.plot

    # group by the bare column name (pandas >= 2.2 expects a tuple key in
    # get_group for a one-element list) and sort so lines run in increasing h
    tau_df = df.groupby('tau').get_group(tau).sort_values('h')

    plt.figure(figsize=(10, 10))
    for cnt, method in enumerate(method_list):
        error = method + '_err_' + field + '_' + err_type
        plot_fn(tau_df['h'], tau_df[error],
                color=colors[cnt], marker=markers[cnt], linewidth=2)
    plot_fn(tau_df['h'], tau_df['h'], linestyle='dashed', color='black')
    legend = [*method_list, 'linear']

    plt.grid(True)
    plt.xlabel('$h$')
    plt.ylabel(y_label)
    plt.legend(legend, bbox_to_anchor=(1.04, 0.5), loc='center left')
    plt.title('Error for $%s$ with $\\tau=%.2e$  at $t=%.4f$' % (field, tau, t))
    if save_flg:
        f_name = p_path + 'err_' + field + '_%s_tau=%.2e_t=%.4f' % (err_type, tau, t)
        plt.savefig(f_name + '.png', bbox_inches='tight')
        tkl.save(f_name + '.tex')


def plot_error_tau_all(h, t, df, field='u', err_type='L2', save_flg=False, log_log=False):
    """
    Compare the relative errors of every method in ``method_list`` against
    the time step tau for a fixed mesh size ``h``, with a dashed slope-one
    reference line.

    :param h: mesh size whose rows are selected from ``df``
    :param t: time point, used in the title and the file name
    :param df: DataFrame with columns 'h', 'tau' and one
        ``<method>_err_<field>_<err_type>`` column per method
    :param field: 'u' for the displacement; anything else is labelled as p
    :param err_type: norm name used in the columns, label and file name
    :param save_flg: if True, save the figure as .png and .tex under ``p_path``
    :param log_log: if True, use logarithmic scales on both axes
    """
    sym = 'u' if field == 'u' else 'p'
    y_label = '$||%s(T) - %s^{N}_{h}||_{%s} / ||%s(T)||_{%s}$' % (sym, sym, err_type, sym, err_type)
    plot_fn = plt.loglog if log_log else plt.plot

    # group by the bare column name (pandas >= 2.2 expects a tuple key in
    # get_group for a one-element list) and sort so lines run in increasing tau
    h_df = df.groupby('h').get_group(h).sort_values('tau')

    plt.figure(figsize=(10, 10))
    for cnt, method in enumerate(method_list):
        error = method + '_err_' + field + '_' + err_type
        plot_fn(h_df['tau'], h_df[error],
                color=colors[cnt], marker=markers[cnt], linewidth=2)
    plot_fn(h_df['tau'], h_df['tau'], linestyle='dashed', color='black')
    legend = [*method_list, 'linear']

    plt.grid(True)
    plt.xlabel('$\\tau$')
    plt.ylabel(y_label)
    plt.legend(legend, bbox_to_anchor=(1.04, 0.5), loc='center left')
    plt.title('Error for $%s$ with $h=%.2e$  at $t=%.4f$' % (field, h, t))
    if save_flg:
        f_name = p_path + 'err_' + field + '_%s_h=%.2e_t=%.4f' % (err_type, h, t)
        plt.savefig(f_name + '.png', bbox_inches='tight')
        tkl.save(f_name + '.tex')


def compute_weak_coupling(h, tau, t, omega_list, omega_true_list, save_flag=False):
    """
    Collect the errors and run time at time ``t`` from the weak-coupling
    summary CSV files into one DataFrame.

    For every ``omega`` in ``omega_list`` the file
    ``test2/summary/h=<h>_tau=<tau>_omega=<omega>.csv`` is read, and the
    first row whose 't' equals ``t`` up to floating-point tolerance is used.
    Exact equality is avoided because values parsed from CSV may differ from
    ``t`` in the last bits.

    :param h: mesh size; used in the file name and stored in column 'h'
    :param tau: time step; used in the file name and stored in column 'tau'
    :param t: time at which the errors are extracted
    :param omega_list: omega values used to build the file names
    :param omega_true_list: values stored in the 'omega' column; entry j
        belongs to ``omega_list[j]``
    :param save_flag: currently unused
    :return: DataFrame with columns h, tau, omega, u_error, p_error and
        run_time, with exactly one row per entry of ``omega_list``
    :raises IndexError: if a file has no row at time ``t`` or
        ``omega_true_list`` is shorter than ``omega_list``
    """
    rows = []
    for j, omega in enumerate(omega_list):
        f_name = 'h=%.2e_tau=%.2e_omega=%.2f.csv' % (h, tau, omega)
        df = pd.read_csv('test2/summary/' + f_name)[columns_ni]
        at_t = df.loc[np.isclose(df['t'].astype(float), t)]
        rows.append({
            'h': h,
            'tau': tau,
            'omega': omega_true_list[j],
            'u_error': at_t['u_error_e'].iloc[0],
            'p_error': at_t['p_error_e'].iloc[0],
            'run_time': at_t['run_time'].iloc[0],
        })

    # one row per omega (a fixed preallocation would leave NaN rows or grow)
    return pd.DataFrame(rows, columns=['h', 'tau', 'omega', 'u_error', 'p_error', 'run_time'])


def plot_error_weak_coupling(h, t, df, field='u', save_flg=False, log_scale=False):
    """
    Plot the error of ``field`` against omega, one curve per time step tau.

    :param h: mesh size, used in the title and the file name
    :param t: time point, used in the title and the file name
    :param df: DataFrame with columns 'tau', 'omega' and '<field>_error',
        e.g. as returned by ``compute_weak_coupling``
    :param field: prefix of the error column ('u' or 'p')
    :param save_flg: if True, save the figure as .png and .tex under ``p_path``
    :param log_scale: if True, use a logarithmic y-axis
    """
    error = field + '_error'
    plot_fn = plt.semilogy if log_scale else plt.plot
    plt.figure(figsize=(10, 10))
    legend = []
    for cnt, (tau, omega_df) in enumerate(df.groupby('tau')):
        # take omega from the data itself so x and y always line up, whatever
        # omega list the frame was built from
        omega_df = omega_df.sort_values('omega')
        plot_fn(omega_df['omega'], omega_df[error], color=colors[cnt], marker='^', linewidth=2)
        legend.append('$\\tau=%.2e$' % tau)
    plt.grid(True)
    plt.xlabel('$\\omega$')
    plt.ylabel('Error for %s' % field)
    plt.legend(legend, bbox_to_anchor=(1.04, 0.5), loc='center left')
    plt.title('Error for $h=%.2e$  at $t=%.4f$' % (h, t))
    if save_flg:
        f_name = p_path + 'coupling_' + field + '_h=%.2e_t=%.4f' % (h, t)
        plt.savefig(f_name + '.png', bbox_inches='tight')
        tkl.save(f_name + '.tex')

