'''Script for simulating the motion of a spring.'''
from integrate import euler

import numpy as np
import matplotlib.pyplot as plt
import time


def force(t):
    '''Return the external excitation applied to the spring at time t.

    The excitation is the unit-amplitude sinusoid F(t) = sin(t).

    Parameters
    ----------
    t : float or array_like
        Time point(s); np.sin is applied element-wise to arrays.

    Returns
    -------
    float or NumPy array
        Value(s) of sin(t).
    '''
    excitation = np.sin(t)
    return excitation


def spring_ode(t, x, m, d, k):
    '''Evaluate the first-order right-hand side of the spring model.

    The underlying second-order model is

        m p''(t) + d p'(t) + k p(t) = force(t).

    Writing the state as x(t) = [p(t), p'(t)] turns it into the
    first-order system

        x'(t) = f(t, x(t)),

    and this function returns f(t, x).

    Parameters
    ----------
    t : float
        Time point.
    x : NumPy array of shape (2,)
        State [p(t), p'(t)] at time t.
    m : float
        Mass of the body attached to the spring.
    d : float
        Damping of the spring.
    k : float
        Stiffness of the spring.

    Returns
    -------
    NumPy array of shape (2,), dtype float
        Time derivative [p'(t), p''(t)] of the state.
    '''
    position, velocity = x[0], x[1]
    # Second-order model solved for p''(t).
    acceleration = (force(t) - d * velocity - k * position) / m
    return np.array([velocity, acceleration], dtype=float)


# parameter values
m = 1      # mass of the body attached to the spring
d = 0.1    # damping of the spring
k = 1      # stiffness of the spring

# initial condition: position p(0) and velocity p'(0)
p0 = 1
v0 = 1
# Force a floating-point state vector. With integer p0/v0, np.array infers an
# integer dtype, and an integrator that allocates its solution array from
# x0's dtype (TODO confirm for integrate.euler) would truncate every step.
x0 = np.array([p0, v0], dtype=float)

# time mesh: tN equally spaced points on [tstart, tend], endpoints included
tstart = 0
tend = 100
tN = 3000
t = np.linspace(tstart, tend, tN)

# simulation
print('Starting the simulation.')
# perf_counter is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() follows the wall clock and can jump.
t_initial = time.perf_counter()

# Bind the spring parameters so the integrator sees an f(t, x) callback.
# The lambda's argument names differ from the module-level t (time mesh) and
# x (solution) to avoid confusing the two.
x = euler(lambda ti, xi: spring_ode(ti, xi, m, d, k), t, x0)

t_final = time.perf_counter()
t_elapsed = t_final - t_initial
print('Elapsed time is {:.5f} seconds.'.format(t_elapsed))

# save data before plotting: plt.show() may block until the figure window is
# closed, so writing the results first ensures they reach disk even if the
# script is interrupted while the plot is open.
# Columns are t and p; x[0, :] is used as the position trace, presumably
# because euler returns state components as rows and time points as
# columns -- confirm against integrate.euler.
data = np.column_stack((t, x[0, :]))
np.savetxt('data.csv', data, header='t p', comments='')

# plot
fig, ax = plt.subplots()
ax.plot(t, x[0, :])
ax.set_xlabel('$t$')
ax.set_ylabel('$p$')
plt.show()
