"""Time-integration methods for solving a linear coupled
elliptic-parabolic PDE system.
"""

import numpy as np
import dolfin as df
from ufl import nabla_div
from time import perf_counter as timer

from utils import sigma, epsilon, compute_error

# TODO: implement starting a simulation from a previously computed solution.
# Continuing a simulation from a previously computed solution is already
# possible: for a solver instance that is already set up, set the number of
# time steps (``steps``) and call the preferred solver method again. This also
# makes it possible to mix different solvers within one simulation.

# ---------------------------------------------------------------------------
#   Solver
# ---------------------------------------------------------------------------


class LinEPSolver:
    """A solver class that holds the 4 numerical methods to solve the
    linear coupled elliptic-parabolic PDE system.

    Typical use: construct, call :meth:`setup`, then call one of
    :meth:`ImplicitSolver`, :meth:`SemiExplicitSolver`,
    :meth:`ParEllipIterSolver` or :meth:`EllipParIterSolver`.  Every solver
    advances ``self.steps`` time steps of size ``self.tau``.  It starts from the
    current state (``self.u_n``, ``self.p_n``, ``self.t``) and writes the final
    state and the per-step diagnostics back onto the instance.  An already
    set-up instance can therefore be continued by calling a solver again.

    Args:
        ns: number of mesh cells per coordinate direction.
        nt: number of time steps over ``params['T']`` (``tau = T / nt``).
        f: body-force term, tested against the vector test function; its
            ``t`` attribute is set to the current time.
        g: source term of the pressure equation; its ``t`` attribute is set
            to the current time.
        u_D, p_D: Dirichlet data for displacement and pressure.  If they
            expose a ``t`` attribute it is kept in sync with the current time.
        boundary: boundary marker passed to every ``df.DirichletBC``.
        p_0: initial pressure.
        params: model parameters; keys read: 'T', 'POINT1', 'POINT2', 'DIM',
            'LAMBDA', 'MU', 'ALPHA', 'M', 'KAPPA_NU'.
        params_disc: discretization parameters; keys read: 'interp_func',
            'degree_u', 'degree_p'.
        t: start time of the simulation.
        steps: number of time steps performed per solver call (defaults to
            ``nt`` during setup).
        test: if True, ``u_D``/``p_D`` are treated as the exact solution and
            interpolated to obtain the initial state.
    """
    verbose = False

    def __init__(self, ns, nt,
                 f, g, u_D, p_D, boundary, p_0,
                 params, params_disc, t=0.0, steps=None, test=False):
        self.ns = ns
        self.nt = nt
        self.f = f
        self.g = g
        self.u_D = u_D
        self.p_D = p_D
        self.boundary = boundary
        # initial data is not needed, can be stored in the current solution,
        # but at the moment it is stored as well
        self.p_0 = p_0
        self.u_0 = None
        self.up_n = None    # initial data is stored in current solution
        self.params = params
        self.params_disc = params_disc
        self.t = t
        self.steps = steps
        self.test = test

        # filled in by setup()
        self.tau = None
        self.h = None
        self.mesh = None
        self.V = None
        self.Q = None
        self.VQ = None
        self.bc = None
        self.bc_u = None
        self.bc_p = None
        self.u_n = None
        self.p_n = None

        self.diagnostics = None
        self.set = False    # switch to True when setup is complete

    # =================================================================
    #   Setup
    # =================================================================
    def _setup_mesh_size(self):
        """Compute the time step ``tau = T / nt`` and the mesh size ``h``.

        ``h`` is the largest side of the box POINT1-POINT2 divided by ``ns``.
        Also defaults ``steps`` to ``nt`` if it was not given.
        """
        self.tau = self.params['T'] / self.nt
        extents = (x2 - x1 for x1, x2 in zip(self.params['POINT1'], self.params['POINT2']))
        self.h = max(extents) / self.ns
        if self.steps is None:
            self.steps = self.nt

    def _setup_mesh(self):
        """Create a uniform mesh of the box POINT1-POINT2 with ``ns`` cells per direction.

        A ``RectangleMesh`` is built when ``params['DIM'] == 2``, a
        ``BoxMesh`` otherwise.
        """
        lower = df.Point(self.params['POINT1'])
        upper = df.Point(self.params['POINT2'])
        if self.params['DIM'] == 2:
            self.mesh = df.RectangleMesh(lower, upper, self.ns, self.ns)
        else:
            self.mesh = df.BoxMesh(lower, upper, self.ns, self.ns, self.ns)

    def _setup_func_spaces(self):
        """Create the spaces V (displacement), Q (pressure) and VQ (mixed),
        the Dirichlet boundary conditions, and the current-state functions."""
        family = self.params_disc['interp_func']
        deg_u = self.params_disc['degree_u']
        deg_p = self.params_disc['degree_p']
        cell = self.mesh.ufl_cell()

        self.V = df.VectorFunctionSpace(self.mesh, family, deg_u)
        self.Q = df.FunctionSpace(self.mesh, family, deg_p)
        element = df.MixedElement([df.VectorElement(family, cell, deg_u),
                                   df.FiniteElement(family, cell, deg_p)])
        self.VQ = df.FunctionSpace(self.mesh, element)

        # mixed-space BCs for the monolithic solves (initial condition and
        # implicit scheme); separate BCs for the splitting schemes
        self.bc = [df.DirichletBC(self.VQ.sub(0), self.u_D, self.boundary),
                   df.DirichletBC(self.VQ.sub(1), self.p_D, self.boundary)]
        self.bc_u = df.DirichletBC(self.V, self.u_D, self.boundary)
        self.bc_p = df.DirichletBC(self.Q, self.p_D, self.boundary)
        self.u_n = df.Function(self.V)
        self.p_n = df.Function(self.Q)

    def _setup_init_cond(self):
        """Compute the initial displacement/pressure and make them the current state.

        The data are first evaluated at the start time ``self.t``.  Then:

        - If ``test`` is False, solve the mixed problem
          ``a(u_0, v) + (p, q) = <f, v> + (ALPHA div v, p_0) + (p_0, q)``
          with the Dirichlet conditions ``self.bc``.
        - Otherwise, interpolate ``u_D`` and ``p_D`` (the exact solution).
        """
        self._set_time(self.t)
        if self.test is False:  # when analytical solution is not given
            u, p = df.TrialFunctions(self.VQ)
            v, q = df.TestFunctions(self.VQ)
            prm = self.params

            a0 = df.inner(sigma(u, prm['DIM'], prm['LAMBDA'], prm['MU']), epsilon(v)) * df.dx
            rhs_f = df.dot(self.f, v) * df.dx
            lhs0 = a0 + df.dot(p, q) * df.dx
            rhs0 = (rhs_f
                    + df.inner(prm['ALPHA'] * nabla_div(v), self.p_0) * df.dx
                    + df.dot(self.p_0, q) * df.dx)

            # solve for consistent initial condition
            up_n = df.Function(self.VQ)
            LHS0, RHS0 = df.assemble_system(lhs0, rhs0, self.bc)
            df.solve(LHS0, up_n.vector(), RHS0)

            # exact copy of the mixed components into V and Q
            u_0 = df.Function(self.V)
            p_0 = df.Function(self.Q)
            df.FunctionAssigner([self.V, self.Q], self.VQ).assign([u_0, p_0], up_n)
        else:   # analytical solution given: interpolate it onto the mesh
            u_0 = df.interpolate(self.u_D, self.V)
            p_0 = df.interpolate(self.p_D, self.Q)
            # up_n is not necessary
            up_n = df.Function(self.VQ)
            df.FunctionAssigner(self.VQ, [self.V, self.Q]).assign(up_n, [u_0, p_0])

        self.p_0 = p_0  # this changes the data type of p_0!! Check for good practice
        self.u_0 = u_0
        # set current solution also to be the initial solution
        self.u_n.assign(u_0)
        self.p_n.assign(p_0)
        self.up_n = up_n

    def setup(self):
        """Build mesh, spaces, boundary conditions and the initial state."""
        self._setup_mesh_size()
        self._setup_mesh()
        self._setup_func_spaces()
        self._setup_init_cond()
        self.set = True

    # =================================================================
    #   Shared helpers
    # =================================================================
    def _require_setup(self):
        """Raise a clear error if a solver is called before :meth:`setup`."""
        if not self.set:
            raise RuntimeError('LinEPSolver.setup() must be called before running a solver')

    def _set_time(self, t):
        """Set time ``t`` in the source terms and, if present, in the boundary data.

        ``f`` and ``g`` are always updated.  ``u_D``/``p_D`` are updated only
        when they expose a ``t`` attribute, so constant boundary data still works.
        """
        self.f.t = t
        self.g.t = t
        for data in (self.u_D, self.p_D):
            if hasattr(data, 't'):
                data.t = t

    def _advance_time(self, t_start, cnt, timings):
        """Move to time step ``cnt`` (1-based) of the current run.

        The time is recomputed from ``t_start`` rather than accumulated, to
        avoid floating-point drift over many steps.  ``timings`` is only read
        for the optional progress report.
        """
        self.t = t_start + cnt * self.tau
        if self.verbose and cnt % 50 == 0:
            print('Solve at time t = %.3f' % self.t)
            # the current step has not run yet: average over the cnt-1 completed steps
            print('Average time per time-step: %f seconds' % np.mean(timings[:cnt - 1]))
        self._set_time(self.t)

    @staticmethod
    def _solve_linear(A_mat, L_form, bc, sol):
        """Assemble ``L_form``, apply ``bc`` to it, and solve ``A_mat x = b`` into ``sol``."""
        b_vec = df.assemble(L_form)
        bc.apply(b_vec)
        df.solve(A_mat, sol.vector(), b_vec)

    def _store_state(self, u_n, p_n, diagnostics):
        """Copy the final fields into the current state and keep the diagnostics."""
        self.u_n.assign(u_n)
        self.p_n.assign(p_n)
        self.diagnostics = diagnostics

    def _run_iterative(self, sweep, u_sol, p_sol, u_i, p_i, u_n, p_n,
                       Vr, Qr, Ar, Cr, tol, max_iters):
        """Time loop shared by the two iterative splitting schemes.

        In every time step the iterates start from the previous time-step
        values.  ``sweep`` (one pass of both sub-solves, writing into
        ``u_sol``/``p_sol``) is then repeated until the sum of the two
        ``compute_error`` values is ``<= tol`` or ``max_iters`` passes were
        made.  Only the time spent inside ``sweep`` is recorded.
        """
        timings = np.zeros(self.steps)
        n_iterations = np.zeros(self.steps)
        t_start = self.t
        for cnt in range(1, self.steps + 1):
            self._advance_time(t_start, cnt, timings)

            # Initialize with previous values
            u_i.assign(u_n)
            p_i.assign(p_n)
            error = 2.0 * tol
            n_iter = 0
            while error > tol and n_iter < max_iters:
                tic_i = timer()
                sweep()
                timings[cnt - 1] += timer() - tic_i

                error_u = compute_error(df.interpolate(u_sol, Vr), df.interpolate(u_i, Vr), Ar)
                error_p = compute_error(df.interpolate(p_sol, Qr), df.interpolate(p_i, Qr), Cr)
                error = error_u + error_p

                u_i.assign(u_sol)
                p_i.assign(p_sol)
                n_iter += 1

            n_iterations[cnt - 1] = n_iter
            u_n.assign(u_sol)
            p_n.assign(p_sol)

        self._store_state(u_n, p_n, {'timings': timings, 'n_iterations': n_iterations})

    # =================================================================
    #   Solvers
    # =================================================================
    # -----------------------------------------------------------------
    #   Implicit solver
    # -----------------------------------------------------------------
    def ImplicitSolver(self):
        """Advance ``self.steps`` steps with the fully implicit (monolithic) scheme.

        The coupled system on the mixed space VQ is assembled once (with the
        Dirichlet conditions applied).  Only the right-hand side is
        re-assembled per step.  The wall time of each step is stored in
        ``self.diagnostics['timings']``.
        """
        self._require_setup()
        print('implicit with tau=%.2e and h=%.2e' % (self.tau, self.h))
        start = timer()
        prm = self.params

        # trial and test functions
        u, p = df.TrialFunctions(self.VQ)
        v, q = df.TestFunctions(self.VQ)

        up_sol = df.Function(self.VQ)
        up_n = df.Function(self.VQ)
        df.FunctionAssigner(self.VQ, [self.V, self.Q]).assign(up_n, [self.u_n, self.p_n])
        u_n, p_n = df.split(up_n)

        # implicit forms; the pressure equation is multiplied by tau
        a = df.inner(sigma(u, prm['DIM'], prm['LAMBDA'], prm['MU']), epsilon(v)) * df.dx
        d_trans = prm['ALPHA'] * df.div(v) * p * df.dx
        c = (1 / prm['M']) * p * q * df.dx
        c_R = (1 / prm['M']) * p_n * q * df.dx
        d = prm['ALPHA'] * df.div(u) * q * df.dx
        d_R = prm['ALPHA'] * df.div(u_n) * q * df.dx
        b = self.tau * prm['KAPPA_NU'] * df.dot(df.grad(p), df.grad(q)) * df.dx
        rhs_f = df.inner(self.f, v) * df.dx
        rhs_g = self.tau * self.g * q * df.dx

        A = a - d_trans + d + c + b
        L = rhs_f + rhs_g + d_R + c_R

        # discretization
        tic = timer()
        A_mat = df.assemble(A)
        for bc in self.bc:
            bc.apply(A_mat)
        if self.verbose:
            print(' - elapsed time to assemble linear system: %f' % (timer() - tic))

        # simulate
        timings = np.zeros(self.steps)
        t_start = self.t
        for cnt in range(1, self.steps + 1):
            self._advance_time(t_start, cnt, timings)

            tic = timer()
            # Solve for displacement-pressure fields
            b_vec = df.assemble(L)
            for bc in self.bc:
                bc.apply(b_vec)
            df.solve(A_mat, up_sol.vector(), b_vec)
            timings[cnt - 1] = timer() - tic

            up_n.assign(up_sol)

        # save the state of simulation: exact copy of the mixed components
        df.FunctionAssigner([self.V, self.Q], self.VQ).assign([self.u_n, self.p_n], up_n)
        self.diagnostics = {'timings': timings}

        print('Time : %.2e' % (timer() - start))

    # ---------------------------------------------------------------------------
    #   semi explicit
    # ---------------------------------------------------------------------------
    def SemiExplicitSolver(self):
        """Advance ``self.steps`` steps with the semi-explicit splitting scheme.

        Each step solves two decoupled problems, each once:

        1. the displacement equation, with the pressure taken from the
           previous time step;
        2. the pressure equation, with the newly computed displacement.

        Both matrices are assembled once before the time loop.  Per-step wall
        times are stored in ``self.diagnostics['timings']``.
        """
        self._require_setup()
        print('semiexplicit with tau=%.2e and h=%.2e' % (self.tau, self.h))
        start = timer()
        prm = self.params

        u = df.TrialFunction(self.V)
        p = df.TrialFunction(self.Q)
        v = df.TestFunction(self.V)
        q = df.TestFunction(self.Q)
        u_sol = df.Function(self.V)  # Most recent solution
        p_sol = df.Function(self.Q)  # Most recent solution
        u_n = df.Function(self.V)
        p_n = df.Function(self.Q)
        u_n.assign(self.u_n)
        p_n.assign(self.p_n)

        a = df.inner(sigma(u, prm['DIM'], prm['LAMBDA'], prm['MU']), epsilon(v)) * df.dx
        d_trans = prm['ALPHA'] * df.div(v) * p_n * df.dx  # Semi-explicit due to p_n
        c = (1 / prm['M']) * p * q * df.dx
        c_R = (1 / prm['M']) * p_n * q * df.dx
        d = prm['ALPHA'] * df.div(u_sol) * q * df.dx  # Updated solution for u is used
        d_R = prm['ALPHA'] * df.div(u_n) * q * df.dx
        b = self.tau * prm['KAPPA_NU'] * df.dot(df.grad(p), df.grad(q)) * df.dx
        rhs_f = df.inner(self.f, v) * df.dx
        rhs_g = self.tau * self.g * q * df.dx

        A1, L1 = a, rhs_f + d_trans
        A2, L2 = c + b, rhs_g - d + d_R + c_R
        A1_mat = df.assemble(A1)
        A2_mat = df.assemble(A2)
        self.bc_u.apply(A1_mat)
        self.bc_p.apply(A2_mat)

        # simulate
        timings = np.zeros(self.steps)
        t_start = self.t
        for cnt in range(1, self.steps + 1):
            self._advance_time(t_start, cnt, timings)

            tic = timer()
            self._solve_linear(A1_mat, L1, self.bc_u, u_sol)  # displacement field
            self._solve_linear(A2_mat, L2, self.bc_p, p_sol)  # pressure field
            timings[cnt - 1] = timer() - tic

            u_n.assign(u_sol)
            p_n.assign(p_sol)

        self._store_state(u_n, p_n, {'timings': timings})

        print('Time : %.2e' % (timer() - start))

    # ---------------------------------------------------------------------------
    #   parabolic-elliptic iterative explicit
    # ---------------------------------------------------------------------------
    def ParEllipIterSolver(self,
                           Vr, Qr, Ar, Cr,
                           l_stab=1.0, tol=1e-5, max_iters=20):
        r"""Advance ``self.steps`` steps with the parabolic-elliptic iterative splitting.

        Within each time step, the pressure equation is solved first, using
        the previous displacement iterate.  Then the displacement equation is
        solved, using the new pressure.  This pair is repeated until
        convergence.  Variational form (the code assembles the pressure
        equation multiplied by tau):

        $
        \begin{align}
            a(u^{n+1,i+1}_h,v_h) - d(v_h, p^{n+1,i+1}_h) &= \langle f^{n+1},v_h\rangle,  \\
            d(D_\tau u^{n+1,i}_h, q_h) + c(D_\tau p^{n+1,i+1}_h,q_h) + b(p^{n+1,i+1}_h,q_h) +
             L\,\frac{M\alpha^2}{\tau K_\text{dr}} c(p^{n+1,i+1}_h - p^{n+1,i}_h,q_h) &= \langle g^{n+1},q_h\rangle
        \end{align}
        $

        In the code, the stabilization term is
        ``l_stab * ALPHA**2 / (2 * kdr) * (p - p_i, q)``.

        Args:
            Vr, Qr: spaces onto which the new and previous iterates are
                interpolated before calling ``compute_error``.
            Ar, Cr: passed to ``compute_error`` for the displacement and
                pressure increments (presumably norm-defining matrices --
                confirm in ``utils``).
            l_stab: multiplier of the stabilization coefficient (printed as L).
            tol: iteration stops once ``error_u + error_p <= tol``.
            max_iters: maximum number of iterations per time step.

        Stores ``timings`` (solve time per step) and ``n_iterations`` in
        ``self.diagnostics``.
        """
        self._require_setup()
        print('fixedstress with tau=%.2e and h=%.2e and L=%.2e' % (self.tau, self.h, l_stab))
        start = timer()
        prm = self.params

        # NOTE(review): 2/3 is used regardless of params['DIM']; the usual drained
        # bulk modulus is LAMBDA + 2*MU/DIM -- confirm this is intended for DIM == 2.
        kdr = prm['LAMBDA'] + 2.0 * prm['MU'] / 3.0
        L_STAB = prm['ALPHA'] ** 2 / 2.0 / kdr

        u = df.TrialFunction(self.V)
        p = df.TrialFunction(self.Q)
        v = df.TestFunction(self.V)
        q = df.TestFunction(self.Q)
        u_sol = df.Function(self.V)  # Most recent solution
        p_sol = df.Function(self.Q)  # Most recent solution
        u_i = df.Function(self.V, name='previous iterate')
        p_i = df.Function(self.Q, name='previous iterate')
        u_n = df.Function(self.V)
        p_n = df.Function(self.Q)
        u_n.assign(self.u_n)
        p_n.assign(self.p_n)

        a = df.inner(sigma(u, prm['DIM'], prm['LAMBDA'], prm['MU']), epsilon(v)) * df.dx
        d_trans = prm['ALPHA'] * df.div(v) * p_sol * df.dx
        c = (1 / prm['M']) * (p - p_n) * q * df.dx
        d = prm['ALPHA'] * df.div(u_i - u_n) * q * df.dx
        b = self.tau * prm['KAPPA_NU'] * df.dot(df.grad(p), df.grad(q)) * df.dx
        s = L_STAB * l_stab * (p - p_i) * q * df.dx
        rhs_f = df.inner(self.f, v) * df.dx
        rhs_g = self.tau * self.g * q * df.dx

        functional_u = a - d_trans - rhs_f
        A1, L1 = df.lhs(functional_u), df.rhs(functional_u)
        functional_p = d + c + b + s - rhs_g
        A2, L2 = df.lhs(functional_p), df.rhs(functional_p)

        A1_mat = df.assemble(A1)
        A2_mat = df.assemble(A2)
        self.bc_u.apply(A1_mat)
        self.bc_p.apply(A2_mat)

        def sweep():
            # pressure first, then displacement with the new pressure
            self._solve_linear(A2_mat, L2, self.bc_p, p_sol)
            self._solve_linear(A1_mat, L1, self.bc_u, u_sol)

        self._run_iterative(sweep, u_sol, p_sol, u_i, p_i, u_n, p_n,
                            Vr, Qr, Ar, Cr, tol, max_iters)

        print('Time : %.2e' % (timer() - start))

    # ---------------------------------------------------------------------------
    #   elliptic-parabolic iterative explicit
    # ---------------------------------------------------------------------------
    def EllipParIterSolver(self,
                           Vr, Qr, Ar, Cr,
                           l_stab=1.0, tol=1e-5, max_iters=20):
        r"""Advance ``self.steps`` steps with the elliptic-parabolic iterative splitting.

        Within each time step, the displacement equation is solved first,
        using the previous pressure iterate.  Then the pressure equation is
        solved, using the new displacement.  This pair is repeated until
        convergence.  Variational form (the code assembles the pressure
        equation multiplied by tau):

        $
        \begin{align}
            a(u^{n+1,i+1}_h,v_h) - d(v_h, p^{n+1,i}_h)
            + L\,d(v_h, M\,\alpha \nabla\cdot(u^{n+1,i+1}_h - u^{n+1,i}_h)) &= \langle f^{n+1},v_h\rangle, \\
            d(D_\tau u^{n+1,i+1}_h, q_h) + c(D_\tau p^{n+1,i+1}_h,q_h) + b(p^{n+1,i+1}_h,q_h) &= \langle g^{n+1},q_h\rangle
        \end{align}
        $

        In the code, the stabilization term is
        ``l_stab * L_STAB * (div v, div(u - u_i))``, with ``L_STAB`` computed below.

        Args:
            Vr, Qr: spaces onto which the new and previous iterates are
                interpolated before calling ``compute_error``.
            Ar, Cr: passed to ``compute_error`` for the displacement and
                pressure increments (presumably norm-defining matrices --
                confirm in ``utils``).
            l_stab: multiplier of the stabilization coefficient (printed as L).
            tol: iteration stops once ``error_u + error_p <= tol``.
            max_iters: maximum number of iterations per time step.

        Stores ``timings`` (solve time per step) and ``n_iterations`` in
        ``self.diagnostics``.
        """
        self._require_setup()
        print('undrained with tau=%.2e and h=%.2e and L=%.2e' % (self.tau, self.h, l_stab))
        start = timer()
        prm = self.params

        # NOTE(review): same DIM-independent 2/3 as in ParEllipIterSolver; the
        # constant 5/9 and the commented-out tau factor are not explained here.
        kdr = prm['LAMBDA'] + 2.0 * prm['MU'] / 3.0
        L_STAB = prm['M'] * prm['ALPHA'] ** 2 * \
            (prm['M'] * prm['ALPHA'] ** 2 / 2.0 / kdr /
             (1 + prm['M'] * prm['KAPPA_NU'] / (5 / 9)) ** 2)  # * self.tau

        u = df.TrialFunction(self.V)
        p = df.TrialFunction(self.Q)
        v = df.TestFunction(self.V)
        q = df.TestFunction(self.Q)
        u_sol = df.Function(self.V)  # Most recent solution
        p_sol = df.Function(self.Q)  # Most recent solution
        u_i = df.Function(self.V, name='previous iterate')
        p_i = df.Function(self.Q, name='previous iterate')
        u_n = df.Function(self.V)
        p_n = df.Function(self.Q)
        u_n.assign(self.u_n)
        p_n.assign(self.p_n)

        a = df.inner(sigma(u, prm['DIM'], prm['LAMBDA'], prm['MU']), epsilon(v)) * df.dx
        d_trans = prm['ALPHA'] * df.div(v) * p_i * df.dx
        s = L_STAB * l_stab * df.div(v) * (df.div(u - u_i)) * df.dx
        c = (1 / prm['M']) * (p - p_n) * q * df.dx
        d = prm['ALPHA'] * df.div(u_sol - u_n) * q * df.dx
        b = self.tau * prm['KAPPA_NU'] * df.dot(df.grad(p), df.grad(q)) * df.dx
        rhs_f = df.inner(self.f, v) * df.dx
        rhs_g = self.tau * self.g * q * df.dx

        functional_u = a - d_trans + s - rhs_f
        A1, L1 = df.lhs(functional_u), df.rhs(functional_u)
        functional_p = d + c + b - rhs_g
        A2, L2 = df.lhs(functional_p), df.rhs(functional_p)

        A1_mat = df.assemble(A1)
        A2_mat = df.assemble(A2)
        self.bc_u.apply(A1_mat)
        self.bc_p.apply(A2_mat)

        def sweep():
            # displacement first, then pressure with the new displacement
            self._solve_linear(A1_mat, L1, self.bc_u, u_sol)
            self._solve_linear(A2_mat, L2, self.bc_p, p_sol)

        self._run_iterative(sweep, u_sol, p_sol, u_i, p_i, u_n, p_n,
                            Vr, Qr, Ar, Cr, tol, max_iters)

        print('Time : %.2e' % (timer() - start))


def implicit(ns, nt,
             f, g, u_D, p_D, boundary, p_0,
             params, params_disc, test=False):
    """Build a ``LinEPSolver``, set it up, and run its implicit scheme.

    All arguments are forwarded unchanged to the ``LinEPSolver`` constructor.

    Returns:
        The solver instance after the run.
    """
    ep_solver = LinEPSolver(ns, nt, f, g, u_D, p_D, boundary, p_0,
                            params, params_disc, test=test)
    ep_solver.setup()
    ep_solver.ImplicitSolver()
    return ep_solver


def semiexplicit(ns, nt,
                 f, g, u_D, p_D, boundary, p_0,
                 params, params_disc, test=False):
    """Build a ``LinEPSolver``, set it up, and run its semi-explicit scheme.

    All arguments are forwarded unchanged to the ``LinEPSolver`` constructor.

    Returns:
        The solver instance after the run.
    """
    ep_solver = LinEPSolver(ns, nt, f, g, u_D, p_D, boundary, p_0,
                            params, params_disc, test=test)
    ep_solver.setup()
    ep_solver.SemiExplicitSolver()
    return ep_solver


def parellip(ns, nt,
             f, g, u_D, p_D, boundary, p_0,
             params, params_disc,
             V, Q, A, C,
             l_stab=1.0, tol=1e-5, max_iters=20, test=False):
    """Build a ``LinEPSolver``, set it up, and run the parabolic-elliptic iteration.

    The first ten arguments (and ``test``) go to the ``LinEPSolver``
    constructor.  ``V``, ``Q``, ``A``, ``C`` are passed as ``Vr``, ``Qr``,
    ``Ar``, ``Cr``, together with ``l_stab``, ``tol`` and ``max_iters``, to
    ``ParEllipIterSolver``.

    Returns:
        The solver instance after the run.
    """
    ep_solver = LinEPSolver(ns, nt, f, g, u_D, p_D, boundary, p_0,
                            params, params_disc, test=test)
    ep_solver.setup()
    iter_opts = dict(l_stab=l_stab, tol=tol, max_iters=max_iters)
    ep_solver.ParEllipIterSolver(V, Q, A, C, **iter_opts)
    return ep_solver


def ellippar(ns, nt,
             f, g, u_D, p_D, boundary, p_0,
             params, params_disc,
             V, Q, A, C,
             l_stab=1.0, tol=1e-5, max_iters=20, test=False):
    """Build a ``LinEPSolver``, set it up, and run the elliptic-parabolic iteration.

    The first ten arguments (and ``test``) go to the ``LinEPSolver``
    constructor.  ``V``, ``Q``, ``A``, ``C`` are passed as ``Vr``, ``Qr``,
    ``Ar``, ``Cr``, together with ``l_stab``, ``tol`` and ``max_iters``, to
    ``EllipParIterSolver``.

    Returns:
        The solver instance after the run.
    """
    ep_solver = LinEPSolver(ns, nt, f, g, u_D, p_D, boundary, p_0,
                            params, params_disc, test=test)
    ep_solver.setup()
    iter_opts = dict(l_stab=l_stab, tol=tol, max_iters=max_iters)
    ep_solver.EllipParIterSolver(V, Q, A, C, **iter_opts)
    return ep_solver

