"""Utility functions
"""

# ----------------------------------------------------------------------
#   Imports
# ----------------------------------------------------------------------

import os
import yaml
import numpy as np
import pandas as pd
import dolfin as df
from ufl import nabla_div
from time import perf_counter as timer


# ----------------------------------------------------------------------
#   Config file reader
# ----------------------------------------------------------------------
def get_config(config_file):
    """Load parameters from a YAML configuration file.

    Args:
        config_file: Path to the YAML file to read.

    Returns:
        The parsed YAML document. Callers in this module index the
        parameters like a dict (presumably a mapping of parameter names to
        values; confirm against the config files in use).

    Raises:
        ValueError: If ``config_file`` is empty or None. Previously this
            case fell through to ``return params`` with ``params`` unbound
            and crashed with an ``UnboundLocalError``.
    """
    if not config_file:
        raise ValueError('A configuration file path must be provided')
    # NOTE(review): FullLoader can construct more than plain YAML data;
    # prefer yaml.safe_load if config files may come from untrusted sources.
    with open(config_file, 'r', encoding='utf-8') as file_reader:
        params = yaml.load(file_reader, Loader=yaml.FullLoader)
    return params


# ----------------------------------------------------------------------
#   Define strain and stress
# ----------------------------------------------------------------------
def epsilon(w):
    """Return the symmetric part of the gradient of ``w`` (small-strain tensor).

    Evaluates ``0.5 * (nabla_grad(w) + nabla_grad(w)^T)`` as a UFL expression.
    """
    grad_w = df.nabla_grad(w)
    return 0.5*(grad_w + grad_w.T)


def sigma(w, dim, lamda, mu):
    """Return the linear-elastic stress tensor of ``w`` as a UFL expression.

    Evaluates ``lamda * div(w) * I + 2 * mu * epsilon(w)``, where ``I`` is
    the ``dim``-dimensional identity and ``lamda``/``mu`` are the Lamé
    coefficients.
    """
    volumetric_part = lamda*nabla_div(w)*df.Identity(dim)
    shear_part = 2*mu*epsilon(w)
    return volumetric_part + shear_part


# ------------------------------------------------------------------------------------
#   Function spaces and assembled matrices for measuring errors in the energy, H^1 and L^2 norms
# ------------------------------------------------------------------------------------
def construct_func_spaces(ns, params, params_disc):
    """Build a structured mesh and the two function spaces defined on it.

    Args:
        ns: Number of cells along each coordinate direction.
        params: Mapping with 'DIM' (2 builds a rectangle mesh, any other
            value builds a box mesh) and the opposite corners 'POINT1' and
            'POINT2'.
        params_disc: Mapping with the element family 'interp_func' and the
            polynomial degrees 'degree_u' and 'degree_p'.

    Returns:
        Tuple ``(V, Q)``: a vector-valued space of degree 'degree_u' (used
        for ``u`` elsewhere in this module) and a scalar space of degree
        'degree_p' (used for ``p``).
    """
    planar = params['DIM'] == 2
    mesh_cls = df.RectangleMesh if planar else df.BoxMesh
    cells = (ns,) * (2 if planar else 3)
    mesh = mesh_cls(df.Point(params['POINT1']), df.Point(params['POINT2']), *cells)

    family = params_disc['interp_func']
    vector_space = df.VectorFunctionSpace(mesh, family, params_disc['degree_u'])
    scalar_space = df.FunctionSpace(mesh, family, params_disc['degree_p'])
    return vector_space, scalar_space


def construct_energy_norm_operators(V, Q, params):
    """Assemble the matrices that define the energy norms on ``V`` and ``Q``.

    Args:
        V: Vector function space.
        Q: Scalar function space.
        params: Mapping providing 'DIM', 'LAMBDA', 'MU', 'KAPPA_NU' and 'M'.

    Returns:
        Tuple ``(A, B, C)`` of assembled matrices for the forms
        ``(sigma(u), epsilon(v))``, ``KAPPA_NU * (grad p, grad q)`` and
        ``(1/M) * (p, q)``.
    """
    vec_trial = df.TrialFunction(V)
    vec_test = df.TestFunction(V)
    scal_trial = df.TrialFunction(Q)
    scal_test = df.TestFunction(Q)

    stress = sigma(vec_trial, params['DIM'], params['LAMBDA'], params['MU'])
    elastic_form = df.inner(stress, epsilon(vec_test)) * df.dx
    diffusion_form = params['KAPPA_NU'] * df.dot(df.grad(scal_trial), df.grad(scal_test)) * df.dx
    storage_form = (1/params['M']) * df.inner(scal_trial, scal_test) * df.dx

    return tuple(df.assemble(form) for form in (elastic_form, diffusion_form, storage_form))


def construct_norm_operators(V, Q):
    """Assemble L2 and gradient Gram matrices on ``V`` and ``Q``.

    Args:
        V: Vector function space.
        Q: Scalar function space.

    Returns:
        Tuple ``(H_V, dH_V, H_Q, dH_Q)``: for each space, the matrix of
        ``inner(u, v) dx`` followed by the matrix of
        ``inner(nabla_grad(u), nabla_grad(v)) dx``. Their sum gives the H^1
        inner product.
    """
    vec_trial, vec_test = df.TrialFunction(V), df.TestFunction(V)
    scal_trial, scal_test = df.TrialFunction(Q), df.TestFunction(Q)

    forms = (
        df.inner(vec_trial, vec_test) * df.dx,
        df.inner(df.nabla_grad(vec_trial), df.nabla_grad(vec_test)) * df.dx,
        df.inner(scal_trial, scal_test) * df.dx,
        df.inner(df.nabla_grad(scal_trial), df.nabla_grad(scal_test)) * df.dx,
    )
    return tuple(df.assemble(form) for form in forms)


# def compute_rel_error(u_r, u, Ar):
#     error_u = u_r.vector() - u.vector()
#
#     den = np.dot(u_r.vector(), Ar*u_r.vector())
#     if den != 0.0:
#         error_u = np.sqrt(np.dot(error_u, Ar*error_u)/den)
#     else:
#         error_u = np.sqrt(np.dot(error_u, Ar * error_u))
#         print('Relative error couldn\'t be computed, hence absolute error is computed')
#
#     return error_u


def compute_error(u_r, u, Ar, rel=False):
    """Return the error between two functions in the norm induced by ``Ar``.

    The absolute error is ``sqrt(e . (Ar * e))`` with
    ``e = u_r.vector() - u.vector()``.

    Args:
        u_r: Function whose coefficient vector is the reference.
        u: Function compared against ``u_r``.
        Ar: Operator applied to coefficient vectors with ``*`` (e.g. an
            assembled matrix).
        rel: When exactly ``True``, divide by the ``Ar``-norm of ``u_r``.

    Returns:
        The absolute error, or the relative error when ``rel is True`` and
        the reference norm is non-zero. If the reference norm is zero, a
        message is printed and the absolute error is returned.
    """
    diff = u_r.vector() - u.vector()
    error = np.sqrt(np.dot(diff, Ar * diff))
    if rel is True:
        ref = u_r.vector()
        den = np.sqrt(np.dot(ref, Ar * ref))
        if den != 0.0:
            # Bug fix: the relative value used to be stored in an unused local
            # (``error_u``), so the absolute error was always returned.
            error = error/den
        else:
            print('Relative error couldn\'t be computed, hence absolute error is computed')

    return error


def compute_norm(u, H_V, dH_V=None, norm='L2'):
    """Return the L2 or H1 norm of ``u`` from assembled Gram matrices.

    Args:
        u: Function whose coefficient vector ``x`` is measured.
        H_V: Operator giving the L2 part ``x . (H_V * x)``.
        dH_V: Operator giving the gradient part ``x . (dH_V * x)``; required
            whenever ``norm != 'L2'``.
        norm: ``'L2'`` selects the L2 norm; any other value selects the H1
            norm.

    Returns:
        ``sqrt(x . H_V x)`` for L2, otherwise ``sqrt(x . H_V x + x . dH_V x)``.

    Raises:
        TypeError: If the H1 norm is requested without ``dH_V``. The
            original code failed with an opaque ``None * vector`` TypeError.
    """
    vector = u.vector()
    squared = np.dot(vector, H_V * vector)
    if norm != 'L2':
        if dH_V is None:
            raise TypeError("dH_V is required when norm != 'L2'")
        squared = squared + np.dot(vector, dH_V * vector)
    # Result kept in its own name instead of rebinding the ``norm`` parameter.
    return np.sqrt(squared)


def compute_error_L2(u_r, u, H_V, rel=False):
    """Return the L2-type error ``sqrt(e . (H_V * e))`` with ``e = u_r - u``.

    When ``rel`` is exactly ``True``, the result is divided by the same norm
    of ``u_r``. If that norm is zero, a message is printed and the absolute
    error is returned instead.
    """
    diff = u_r.vector() - u.vector()
    abs_err = np.sqrt(np.dot(diff, H_V * diff))
    if rel is not True:
        return abs_err

    ref = u_r.vector()
    scale = np.sqrt(np.dot(ref, H_V * ref))
    if scale == 0.0:
        print('Relative error couldn\'t be computed, hence absolute error is computed')
        return abs_err
    return abs_err/scale


def compute_error_H1(u_r, u, H_V, dH_V, rel=False):
    """Return the H1-type error between ``u_r`` and ``u``.

    The squared norm of a coefficient vector ``x`` is
    ``x . (H_V * x) + x . (dH_V * x)``. When ``rel`` is exactly ``True``,
    the error is divided by this norm of ``u_r``. If that norm is zero, a
    message is printed and the absolute error is returned instead.
    """
    def _squared(vec):
        # L2 part plus gradient part.
        return np.dot(vec, H_V * vec) + np.dot(vec, dH_V * vec)

    diff = u_r.vector() - u.vector()
    abs_err = np.sqrt(_squared(diff))
    if rel is not True:
        return abs_err

    scale = np.sqrt(_squared(u_r.vector()))
    if scale == 0.0:
        print('Relative error couldn\'t be computed, hence absolute error is computed')
        return abs_err
    return abs_err/scale


def compile_errors(u_r, p_r, u, p, Ar, Br, Cr, H_V, dH_V, H_Q, dH_Q, rel=True):
    """Collect the errors of ``u`` and ``p`` in several norms into one dict.

    Keys ``err_u_A``, ``err_p_B`` and ``err_p_C`` use the operators
    ``Ar``/``Br``/``Cr``. Keys ``err_{u,p}_L2`` and ``err_{u,p}_H1`` use the
    Gram matrices ``H_V``/``dH_V`` for ``u`` and ``H_Q``/``dH_Q`` for ``p``.
    ``rel`` is forwarded to every error routine.
    """
    return {
        'err_u_A': compute_error(u_r, u, Ar, rel=rel),
        'err_p_B': compute_error(p_r, p, Br, rel=rel),
        'err_p_C': compute_error(p_r, p, Cr, rel=rel),
        'err_u_L2': compute_error_L2(u_r, u, H_V, rel=rel),
        'err_u_H1': compute_error_H1(u_r, u, H_V, dH_V, rel=rel),
        'err_p_L2': compute_error_L2(p_r, p, H_Q, rel=rel),
        'err_p_H1': compute_error_H1(p_r, p, H_Q, dH_Q, rel=rel),
    }


# -------------------------------------------------------------------------------------
#   compute coupling strength
# -------------------------------------------------------------------------------------
def compute_coupling_strength(ns, M_omega, params, params_disc):
    """Estimate the coupling strength from the spectral norms of assembled forms.

    Builds a structured mesh and a mixed space from a vector element and a
    scalar element. It assembles the elasticity form ``a``, the scaled mass
    form ``c`` and the two coupling forms ``d_trans``/``d`` on that space.
    Each form is measured by the largest singular value of its dense matrix.

    Args:
        ns: Number of cells along each coordinate direction.
        M_omega: Scalar in the ``(1 / M_omega) * p * q`` form.
        params: Mapping with 'DIM', 'POINT1', 'POINT2', 'LAMBDA', 'MU' and
            'ALPHA'.
        params_disc: Mapping with 'interp_func', 'degree_u' and 'degree_p'.

    Returns:
        Tuple ``(omega_true, C_d1, C_d2, c_a, c_c)``, where
        ``omega_true = sqrt(C_d1 * C_d2 / (c_a * c_c))`` and the other
        entries are the largest singular values of the ``d_trans``, ``d``,
        ``a`` and ``c`` matrices.
    """
    if params['DIM'] == 2:
        mesh = df.RectangleMesh(df.Point(params['POINT1']), df.Point(params['POINT2']), ns, ns)
    else:
        mesh = df.BoxMesh(df.Point(params['POINT1']), df.Point(params['POINT2']), ns, ns, ns)

    # Mixed function space; element degrees come from the config.
    u_elem = df.VectorElement(params_disc['interp_func'], mesh.ufl_cell(), params_disc['degree_u'])
    p_elem = df.FiniteElement(params_disc['interp_func'], mesh.ufl_cell(), params_disc['degree_p'])
    VQ = df.FunctionSpace(mesh, df.MixedElement([u_elem, p_elem]))

    (v, q) = df.TestFunctions(VQ)
    (u, p) = df.TrialFunctions(VQ)

    a = df.inner(sigma(u, params['DIM'], params['LAMBDA'], params['MU']), epsilon(v)) * df.dx
    d_trans = params['ALPHA'] * df.div(v) * p * df.dx
    c = (1 / M_omega) * p * q * df.dx
    d = params['ALPHA'] * df.div(u) * q * df.dx

    A = df.assemble(a)
    C = df.assemble(c)
    D_trans = df.assemble(d_trans)
    D = df.assemble(d)

    def _largest_singular_value(matrix):
        # Only singular values are needed. compute_uv=False skips forming the
        # dense N x N factors U and V, which the default call builds.
        return np.linalg.svd(matrix.array(), compute_uv=False).max()

    c_a = _largest_singular_value(A)
    c_c = _largest_singular_value(C)
    C_d1 = _largest_singular_value(D_trans)
    C_d2 = _largest_singular_value(D)
    omega_true = np.sqrt(C_d1 * C_d2 / (c_a * c_c))

    return omega_true, C_d1, C_d2, c_a, c_c


def compile_summary(t, error_u, error_p, u_sol, p_sol, run_time):
    """Bundle the diagnostics for one solve into a dictionary.

    Records ``t`` (presumably the current time value; confirm with the
    caller), the maximum absolute local coefficient of ``u_sol`` and
    ``p_sol``, the supplied errors and the run time. Keys are inserted in a
    fixed order: 't', 'u_max', 'p_max', 'u_error_e', 'p_error_e',
    'run_time'.
    """
    u_peak = np.abs(u_sol.vector().get_local()).max()
    p_peak = np.abs(p_sol.vector().get_local()).max()
    return {
        't': t,
        'u_max': u_peak,
        'p_max': p_peak,
        'u_error_e': error_u,
        'p_error_e': error_p,
        'run_time': run_time,
    }



