"""Solve linear coupled elliptic-parabolic PDE system
using FEniCS
"""

import argparse
import numpy as np
import pandas as pd
import dolfin as df

import utils
from solvers import implicit, semiexplicit, parellip, ellippar


# Column prefixes of the four time-stepping schemes, in the order in which
# they are run and in which their columns appear in the summary table.
_METHODS = ('implicit', 'semiexplicit', 'fixedstress', 'undrained')

# Schemes whose diagnostics also carry an 'n_iterations' entry.
_ITERATIVE_METHODS = ('fixedstress', 'undrained')

# Keys read from the dict returned by utils.compile_errors; each one becomes
# a '<method>_<key>' column.
_ERROR_KEYS = ('err_u_A', 'err_u_L2', 'err_u_H1',
               'err_p_B', 'err_p_C', 'err_p_L2', 'err_p_H1')


def _build_problem_data(problem_params):
    """Create the dolfin expressions and boundary entry from the problem config.

    Args:
        problem_params: problem configuration; must provide 'DIM',
            'U_<axis>_CPP', 'F_<axis>_CPP', 'P_CPP', 'G_CPP', 'CONSTANT_G',
            'P_0_CPP', 'CONSTANT_P_0' and 'BOUNDARY'.

    Returns:
        Tuple ``(u_D, p_D, f, g, p_0, boundary)``.
    """
    # Any DIM other than 2 is treated as three-dimensional.
    axes = ('X', 'Y') if problem_params['DIM'] == 2 else ('X', 'Y', 'Z')

    u_D = df.Expression(tuple(problem_params['U_%s_CPP' % axis] for axis in axes), degree=1)
    p_D = df.Expression(problem_params['P_CPP'], degree=1)
    f = df.Expression(tuple(problem_params['F_%s_CPP' % axis] for axis in axes), degree=1, t=0.0)
    # Only the first entry of CONSTANT_G / CONSTANT_P_0 is used.
    g = df.Expression(problem_params['G_CPP'], degree=1, t=0.0, cG=problem_params['CONSTANT_G'][0])
    p_0 = df.Expression(problem_params['P_0_CPP'], degree=1, cP=problem_params['CONSTANT_P_0'][0])

    boundary = problem_params['BOUNDARY']
    return u_D, p_D, f, g, p_0, boundary


def _summary_columns():
    """Return the column names of the summary table, in output order."""
    columns = ['h', 'tau']
    for method in _METHODS:
        columns.append(method + '_rt_avg')
        columns.extend('%s_%s' % (method, key) for key in _ERROR_KEYS)
    columns.extend(method + '_n_avg' for method in _ITERATIVE_METHODS)
    return columns


def _record_result(df_summary, row, method, h, tau, diagnostics, errors):
    """Store the outcome of one solver run in row ``row`` of ``df_summary``.

    Args:
        df_summary: summary DataFrame, modified in place.
        row: index label of the row to fill.
        method: column prefix, one of ``_METHODS``.
        h: mesh size recorded for this run.
        tau: time step size recorded for this run.
        diagnostics: solver diagnostics; 'timings' is read for every method
            and 'n_iterations' for the iterative ones.
        errors: dict returned by utils.compile_errors.
    """
    df_summary.loc[row, 'h'] = h
    df_summary.loc[row, 'tau'] = tau
    # NOTE(review): the column is called '*_rt_avg' but stores the last entry
    # of diagnostics['timings'] -- confirm that this entry is the average.
    df_summary.loc[row, method + '_rt_avg'] = diagnostics['timings'][-1]
    for key in _ERROR_KEYS:
        df_summary.loc[row, '%s_%s' % (method, key)] = errors[key]
    if method in _ITERATIVE_METHODS:
        df_summary.loc[row, method + '_n_avg'] = diagnostics['n_iterations'].mean()


def convergence():
    """Convergence study for numerical methods
    for solving the coupled linear Elliptic-Parabolic PDE
    system using FEniCS, with the parameters specified
    in the config files.

    Reads the two config file paths from the command line, computes a
    reference solution with the implicit solver on the NS_REF / NT_REF
    discretization, then runs the implicit, semiexplicit, fixedstress and
    undrained solvers for every combination of the configured mesh and time
    step counts. Errors against the reference, the runtime and (for the two
    iterative schemes) the mean iteration count are written to
    ``summary.csv`` in the current working directory.
    """

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('problem_file', type=str,
                        default=None, help='Biot poroelasticity parameters (YAML file)')
    parser.add_argument('discretization_file', type=str,
                        default=None, help='Discretization and FEniCS parameters (YAML file)')

    args = parser.parse_args()

    problem_params = utils.get_config(args.problem_file)
    discretization_params = utils.get_config(args.discretization_file)

    # -----------------------------------------------------
    # set up the problem
    # -----------------------------------------------------
    u_D, p_D, f, g, p_0, boundary = _build_problem_data(problem_params)

    # discretization: powers of two between the configured exponents (inclusive)
    nt_list = [2 ** i for i in np.arange(discretization_params['NT_MIN'], discretization_params['NT_MAX'] + 1)]
    ns_list = [2 ** i for i in np.arange(discretization_params['NS_MIN'], discretization_params['NS_MAX'] + 1)]

    # -------------------------------------------------------------------------------------
    #   reference call
    # -------------------------------------------------------------------------------------
    nt_ref = 2 ** discretization_params['NT_REF']
    ns_ref = 2 ** discretization_params['NS_REF']

    print('\n\n-------------------')
    print('compute reference solution')
    solver = implicit(ns_ref, nt_ref, f, g, u_D, p_D, boundary, p_0,
                      problem_params, discretization_params)
    u_ref, p_ref = solver.u_n, solver.p_n

    # ------------------------------------------------------------------------------------
    # assemble matrices for energy norms (on the reference discretization)
    # ------------------------------------------------------------------------------------
    V, Q = utils.construct_func_spaces(ns_ref, problem_params, discretization_params)
    A, B, C = utils.construct_energy_norm_operators(V, Q, problem_params)
    H_V, dH_V, H_Q, dH_Q = utils.construct_norm_operators(V, Q)

    # largest extent of the box spanned by POINT1 and POINT2
    domain_size = max(x2 - x1 for x1, x2 in zip(problem_params['POINT1'], problem_params['POINT2']))

    # compute coupling strength; only the first return value is reported
    ns_cs = 2 ** discretization_params['NS_CS']
    omega_true, _c_d1, _c_d2, _c_a, _c_c = utils.compute_coupling_strength(
        ns_cs, problem_params['M'], problem_params, discretization_params)
    print('\n\n-------------------')
    print('The coupling strength is:\n')
    print(omega_true)

    # ------------------------------------------------------------------------------------
    # summary table: one row per (ns, nt) pair, shared by all methods
    # ------------------------------------------------------------------------------------
    n_rows = len(nt_list) * len(ns_list)
    df_summary = pd.DataFrame([], index=np.arange(n_rows), columns=_summary_columns())

    tau_list = [problem_params['T'] / nt for nt in nt_list]
    h_list = [domain_size / ns for ns in ns_list]

    # Arguments shared by every solver; the two iterative schemes additionally
    # receive the reference spaces / operators and their iteration controls.
    common_args = (f, g, u_D, p_D, boundary, p_0, problem_params, discretization_params)
    iterative_args = (V, Q, A, C)
    iterative_kwargs = dict(l_stab=1.0, tol=1e-5, max_iters=20)

    # (column prefix, banner text, solver, extra positional args, extra keyword args)
    studies = [
        ('implicit', 'convergence implicit euler', implicit, (), {}),
        ('semiexplicit', 'convergence semiexplicit', semiexplicit, (), {}),
        ('fixedstress', 'convergence fixedstress', parellip, iterative_args, iterative_kwargs),
        ('undrained', 'convergence undrained', ellippar, iterative_args, iterative_kwargs),
    ]

    for method, banner, solver_cls, extra_args, extra_kwargs in studies:
        print('\n\n-------------------')
        print(banner)

        # Every method restarts at row 0 and walks the (ns, nt) grid in the
        # same order, so all results for one (h, tau) pair share a row.
        row = 0
        for h, ns in zip(h_list, ns_list):
            for tau, nt in zip(tau_list, nt_list):
                solver = solver_cls(ns, nt, *(common_args + extra_args), **extra_kwargs)
                u_n, p_n, diagnostics = solver.u_n, solver.p_n, solver.diagnostics

                # interpolate (u_n, p_n) in reference spaces!!!
                errors = utils.compile_errors(u_ref, p_ref, df.interpolate(u_n, V), df.interpolate(p_n, Q),
                                              A, B, C, H_V, dH_V, H_Q, dH_Q)
                print('Error: %.2e' % (errors['err_u_A'] + errors['err_p_C']))
                print('Runtime: %.2e' % diagnostics['timings'][-1])

                _record_result(df_summary, row, method, h, tau, diagnostics, errors)
                row += 1

    # -------------------------------------------------------------------------------
    # write the collected results
    # -------------------------------------------------------------------------------
    df_summary.to_csv('summary.csv')


def test():
    """Solve the coupled linear Elliptic-Parabolic PDE
    system using FEniCS, with the parameters specified
    in the config files.

    Reads the two config file paths from the command line and runs each of
    the four solvers (implicit, semiexplicit, fixedstress, undrained) once on
    a fixed discretization with 2**3 as both the mesh and the time step count.
    The reference solution comes from the implicit solver on the same mesh
    with twice as many time steps. Errors, runtime and (for the two iterative
    schemes) the mean iteration count are written as a single row to
    ``summary.csv`` in the current working directory.
    """

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('problem_file', type=str,
                        default=None, help='Biot poroelasticity parameters (YAML file)')
    parser.add_argument('discretization_file', type=str,
                        default=None, help='Discretization and FEniCS parameters (YAML file)')

    args = parser.parse_args()

    problem_params = utils.get_config(args.problem_file)
    discretization_params = utils.get_config(args.discretization_file)

    # -----------------------------------------------------
    # set up the problem
    # -----------------------------------------------------
    ns_ref = 2 ** 3
    nt_ref = 2 ** 3

    # data and boundary conditions; any DIM other than 2 is treated as 3D
    axes = ('X', 'Y') if problem_params['DIM'] == 2 else ('X', 'Y', 'Z')
    u_D = df.Expression(tuple(problem_params['U_%s_CPP' % axis] for axis in axes), degree=1)
    p_D = df.Expression(problem_params['P_CPP'], degree=1)
    f = df.Expression(tuple(problem_params['F_%s_CPP' % axis] for axis in axes), degree=1, t=0.0)
    # Only the first entry of CONSTANT_G / CONSTANT_P_0 is used.
    g = df.Expression(problem_params['G_CPP'], degree=1, t=0.0, cG=problem_params['CONSTANT_G'][0])
    p_0 = df.Expression(problem_params['P_0_CPP'], degree=1, cP=problem_params['CONSTANT_P_0'][0])
    boundary = problem_params['BOUNDARY']

    # function spaces and norm operators used for the error evaluation
    V, Q = utils.construct_func_spaces(ns_ref, problem_params, discretization_params)
    A, B, C = utils.construct_energy_norm_operators(V, Q, problem_params)
    H_V, dH_V, H_Q, dH_Q = utils.construct_norm_operators(V, Q)

    # reference solution: same mesh, twice as many time steps
    solver = implicit(ns_ref, nt_ref * 2,
                      f, g, u_D, p_D, boundary, p_0,
                      problem_params, discretization_params)
    u_ref, p_ref = solver.u_n, solver.p_n

    tau = problem_params['T'] / nt_ref
    # largest extent of the box spanned by POINT1 and POINT2
    domain_size = max(x2 - x1 for x1, x2 in zip(problem_params['POINT1'], problem_params['POINT2']))
    h = domain_size / ns_ref

    # Arguments shared by every solver; the two iterative schemes additionally
    # receive the spaces / operators and their iteration controls.
    common_args = (f, g, u_D, p_D, boundary, p_0, problem_params, discretization_params)
    iterative_args = (V, Q, A, C)
    iterative_kwargs = dict(l_stab=1.0, tol=1e-5, max_iters=20)

    # (column prefix, solver, extra positional args, extra keyword args)
    methods = [
        ('implicit', implicit, (), {}),
        ('semiexplicit', semiexplicit, (), {}),
        ('fixedstress', parellip, iterative_args, iterative_kwargs),
        ('undrained', ellippar, iterative_args, iterative_kwargs),
    ]
    iterative_methods = ('fixedstress', 'undrained')
    error_keys = ('err_u_A', 'err_u_L2', 'err_u_H1',
                  'err_p_B', 'err_p_C', 'err_p_L2', 'err_p_H1')

    # summary table with a single row shared by all methods
    columns = ['h', 'tau']
    for method, _, _, _ in methods:
        columns.append(method + '_rt_avg')
        columns.extend('%s_%s' % (method, key) for key in error_keys)
    columns.extend(method + '_n_avg' for method in iterative_methods)

    n_rows = 1
    df_summary = pd.DataFrame([], index=np.arange(n_rows), columns=columns)

    row = 0
    for method, solver_cls, extra_args, extra_kwargs in methods:
        print('\n\n-------------------')
        print('test %s method' % method)
        solver = solver_cls(ns_ref, nt_ref, *(common_args + extra_args), **extra_kwargs)
        u_n, p_n, diagnostics = solver.u_n, solver.p_n, solver.diagnostics
        errors = utils.compile_errors(u_ref, p_ref, df.interpolate(u_n, V), df.interpolate(p_n, Q),
                                      A, B, C, H_V, dH_V, H_Q, dH_Q)
        print('Error: %.2e' % (errors['err_u_A'] + errors['err_p_C']))
        print('Runtime: %.2e' % diagnostics['timings'][-1])

        df_summary.loc[row, 'h'] = h
        df_summary.loc[row, 'tau'] = tau
        # NOTE(review): the column is called '*_rt_avg' but stores the last
        # entry of diagnostics['timings'] -- confirm that this is the average.
        df_summary.loc[row, method + '_rt_avg'] = diagnostics['timings'][-1]
        for key in error_keys:
            df_summary.loc[row, '%s_%s' % (method, key)] = errors[key]
        if method in iterative_methods:
            df_summary.loc[row, method + '_n_avg'] = diagnostics['n_iterations'].mean()

    df_summary.to_csv('summary.csv')


if __name__ == '__main__':
    # Script entry point: run the full convergence study.
    # To run the quick single-discretization check instead, call test() here.
    convergence()
