!! Copyright (C) 2024 S. Ohlmann
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

module chebyshev_coefficients_oct_m
  use debug_oct_m
  use fftw_oct_m
  use fftw_params_oct_m
  use global_oct_m
  use, intrinsic :: iso_c_binding
  use, intrinsic :: iso_fortran_env
  use loct_math_oct_m
  use math_oct_m
  use messages_oct_m
  use profiling_oct_m

  implicit none

  private
  public ::                            &
    chebyshev_function_t,              &
    chebyshev_exp_t,                   &
    chebyshev_exp_imagtime_t,          &
    chebyshev_numerical_t

  !> Abstract base type for a function of an operator scaled by a time step,
  !! expanded in Chebyshev polynomials on the interval
  !! [middle_point - half_span, middle_point + half_span].
  type, abstract :: chebyshev_function_t
    private
    real(real64) :: half_span     !< half width of the interval
    real(real64) :: middle_point  !< center of the interval
    real(real64) :: deltat        !< time step multiplying the operator
  contains
    procedure(chebyshev_get_coefficients), deferred :: get_coefficients
    procedure(chebyshev_get_error),        deferred :: get_error
    procedure                                       :: set_parameters => chebyshev_set_parameters
  end type chebyshev_function_t

  abstract interface
    !> Allocate coefficients(0:order) and fill it with the expansion coefficients.
    subroutine chebyshev_get_coefficients(this, order, coefficients)
      import :: chebyshev_function_t, real64
      implicit none
      class(chebyshev_function_t),  intent(in)  :: this
      integer,                      intent(in)  :: order
      complex(real64), allocatable, intent(out) :: coefficients(:)
    end subroutine chebyshev_get_coefficients

    !> Return an estimate of the truncation error for an expansion of the given order.
    !! Implementations may return -1 to signal that no estimate is available.
    real(real64) function chebyshev_get_error(this, order)
      import :: chebyshev_function_t, real64
      implicit none
      class(chebyshev_function_t), intent(in) :: this
      integer,                     intent(in) :: order
    end function chebyshev_get_error
  end interface

  !> Type encapsulating the Chebyshev coefficients of the real-time propagator
  !! exp(-i H deltat) (phase exp(-i middle_point deltat) times a Bessel series).
  type, extends(chebyshev_function_t) :: chebyshev_exp_t
  contains
    procedure :: get_coefficients => chebyshev_exp_coefficients
    procedure :: get_error => chebyshev_exp_error
  end type chebyshev_exp_t

  !> chebyshev_exp_t(half_span, middle_point, deltat) returns a pointer to a
  !! newly allocated object.
  interface chebyshev_exp_t
    procedure chebyshev_exp_constructor
  end interface chebyshev_exp_t

  !> Type encapsulating the Chebyshev coefficients of the imaginary-time
  !! propagator exp(-H deltat).
  type, extends(chebyshev_function_t) :: chebyshev_exp_imagtime_t
  contains
    procedure :: get_coefficients => chebyshev_exp_imagtime_coefficients
    procedure :: get_error => chebyshev_exp_imagtime_error
  end type chebyshev_exp_imagtime_t

  !> chebyshev_exp_imagtime_t(half_span, middle_point, deltat) returns a pointer
  !! to a newly allocated object.
  interface chebyshev_exp_imagtime_t
    procedure chebyshev_exp_imagtime_constructor
  end interface chebyshev_exp_imagtime_t

  !> Interface of a complex-valued function of one complex argument. It is
  !! required so that chebyshev_numerical_t can hold a procedure pointer to it.
  abstract interface
    complex(real64) function complex_function_i(z)
      import :: real64
      implicit none
      complex(real64), intent(in) :: z
    end function complex_function_i
  end interface

  !> Type encapsulating Chebyshev coefficients of an arbitrary complex function
  !! (e.g. the phi_k functions), computed numerically with a discrete cosine transform.
  type, extends(chebyshev_function_t) :: chebyshev_numerical_t
    !> function to expand; disassociated until the constructor sets it
    procedure(complex_function_i), pointer, nopass :: complex_function => null()
  contains
    procedure :: get_coefficients => chebyshev_numerical_coefficients
    procedure :: get_error => chebyshev_numerical_error
  end type chebyshev_numerical_t

  !> chebyshev_numerical_t(half_span, middle_point, deltat, complex_function)
  !! returns a pointer to a newly allocated object.
  interface chebyshev_numerical_t
    procedure chebyshev_numerical_constructor
  end interface chebyshev_numerical_t

contains

  !> Store the interval (center and half width) and the time step in the object.
  subroutine chebyshev_set_parameters(this, half_span, middle_point, deltat)
    class(chebyshev_function_t), intent(inout) :: this
    real(real64),                intent(in)    :: half_span, middle_point, deltat

    PUSH_SUB(chebyshev_set_parameters)

    this%deltat       = deltat
    this%middle_point = middle_point
    this%half_span    = half_span

    POP_SUB(chebyshev_set_parameters)
  end subroutine chebyshev_set_parameters

  !> Allocate a chebyshev_exp_t and initialize its interval and time step.
  !! The object is allocated here and returned through a pointer; nothing in
  !! this module deallocates it.
  function chebyshev_exp_constructor(half_span, middle_point, deltat) result(chebyshev_function)
    real(real64),           intent(in) :: half_span, middle_point, deltat
    class(chebyshev_exp_t), pointer    :: chebyshev_function

    PUSH_SUB(chebyshev_exp_constructor)

    allocate(chebyshev_function)
    call chebyshev_function%set_parameters(half_span, middle_point, deltat)

    POP_SUB(chebyshev_exp_constructor)
  end function chebyshev_exp_constructor

  !> Chebyshev coefficients of exp(-i H deltat) (Jacobi-Anger expansion):
  !! c_k = exp(-i middle_point deltat) (-i)^k 2 J_k(half_span deltat), k = 0..order,
  !! with c_0 carrying only half of the factor 2.
  subroutine chebyshev_exp_coefficients(this, order, coefficients)
    class(chebyshev_exp_t),       intent(in)  :: this
    integer,                      intent(in)  :: order
    complex(real64), allocatable, intent(out) :: coefficients(:)

    integer         :: k
    complex(real64) :: shift_phase

    PUSH_SUB(chebyshev_exp_coefficients)
    SAFE_ALLOCATE(coefficients(0:order))

    ! phase from the center of the interval, identical for every k
    shift_phase = exp(-M_zI*this%middle_point*this%deltat)
    do k = 0, order
      coefficients(k) = shift_phase * (-M_zI)**k * M_TWO * loct_bessel(k, this%half_span*this%deltat)
    end do
    coefficients(0) = coefficients(0) / M_TWO

    POP_SUB(chebyshev_exp_coefficients)
  end subroutine chebyshev_exp_coefficients

  !> Use the error estimate from Lubich, C. From Quantum to Classical Molecular Dynamics:
  !! Reduced Models and Numerical Analysis. (EMS Press, 2008), doi:10.4171/067, Theorems 2.1 to 2.4
  !!
  !! With w = |half_span*deltat| and r = 2*(order+1)/w, the estimate is
  !!   2 r^-(order+1) / (1 - 1/r) * exp(w*(r - 1/r)/2).
  !! It is evaluated through 1/r = w/(2*(order+1)) and w*(r - 1/r)/2 = (order+1) - w/(2r).
  !! Then w = 0 gives an error of 0 instead of the NaN from 0*Inf.
  !! Returns -1 (no estimate) when order < w. The test uses |deltat|, consistent
  !! with the formula, so negative time steps are handled as well.
  real(real64) function chebyshev_exp_error(this, order)
    class(chebyshev_exp_t), intent(in)  :: this
    integer,                intent(in)  :: order

    real(real64) :: span_t, inv_r

    PUSH_SUB(chebyshev_exp_error)
    span_t = abs(this%half_span*this%deltat)
    if (order >= span_t) then
      ! order >= span_t guarantees inv_r <= 1/2
      inv_r = span_t/(M_TWO*(order+1))
      chebyshev_exp_error = M_TWO*inv_r**(order+1)/(M_ONE - inv_r) * exp((order+1) - span_t*inv_r/M_TWO)
    else
      chebyshev_exp_error = -M_ONE
    end if
    POP_SUB(chebyshev_exp_error)
  end function chebyshev_exp_error

  !> Allocate a chebyshev_exp_imagtime_t and initialize its interval and time step.
  !! The object is allocated here and returned through a pointer; nothing in
  !! this module deallocates it.
  function chebyshev_exp_imagtime_constructor(half_span, middle_point, deltat) result(chebyshev_function)
    real(real64),                    intent(in) :: half_span, middle_point, deltat
    class(chebyshev_exp_imagtime_t), pointer    :: chebyshev_function

    PUSH_SUB(chebyshev_exp_imagtime_constructor)

    allocate(chebyshev_function)
    call chebyshev_function%set_parameters(half_span, middle_point, deltat)

    POP_SUB(chebyshev_exp_imagtime_constructor)
  end function chebyshev_exp_imagtime_constructor

  !> Chebyshev coefficients of exp(-H deltat):
  !! c_k = exp(-middle_point deltat) * 2 * loct_bessel_in(k, -half_span deltat), k = 0..order,
  !! with c_0 carrying only half of the factor 2.
  !! NOTE(review): loct_bessel_in is presumably the modified Bessel function I_k — confirm.
  subroutine chebyshev_exp_imagtime_coefficients(this, order, coefficients)
    class(chebyshev_exp_imagtime_t), intent(in)  :: this
    integer,                         intent(in)  :: order
    complex(real64), allocatable,    intent(out) :: coefficients(:)

    integer      :: k
    real(real64) :: damping

    PUSH_SUB(chebyshev_exp_imagtime_coefficients)
    SAFE_ALLOCATE(coefficients(0:order))

    ! factor from the center of the interval, identical for every k
    damping = exp(-this%middle_point*this%deltat)
    do k = 0, order
      coefficients(k) = damping * M_TWO * loct_bessel_in(k, -this%half_span*this%deltat)
    end do
    coefficients(0) = coefficients(0) / M_TWO

    POP_SUB(chebyshev_exp_imagtime_coefficients)
  end subroutine chebyshev_exp_imagtime_coefficients

  !> Use the error estimate from Hochbruck, M. & Ostermann, A. Exponential integrators.
  !! Acta Numerica 19, 209–286 (2010), Theorem 4.1 (page 265) and L. Bergamaschi and M. Vianello:
  !! Efficient computation of the exponential operator for large, sparse, symmetric matrices,
  !! Numer. Linear Algebra Appl. 7, 27–45 (2000), eq. 2.7
  !!
  !! The bound uses upper_bound = middle_point + half_span and
  !! lower_bound = middle_point - half_span. Its form depends on whether
  !! order <= upper_bound*deltat.
  real(real64) function chebyshev_exp_imagtime_error(this, order)
    class(chebyshev_exp_imagtime_t), intent(in)  :: this
    integer,                         intent(in)  :: order

    real(real64) :: upper_bound, lower_bound
    ! Constants of the cited bound. The kind suffix keeps them from being rounded to
    ! default (single) precision before conversion.
    ! NOTE(review): they agree with 2/(1+sqrt(5)) and exp(b)/(2+sqrt(5)) to the digits given — confirm.
    real(real64), parameter :: b = 0.618_real64, d = 0.438_real64

    PUSH_SUB(chebyshev_exp_imagtime_error)
    upper_bound = this%half_span+this%middle_point
    lower_bound = -this%half_span+this%middle_point
    if (order <= upper_bound*this%deltat) then
      chebyshev_exp_imagtime_error = M_TWO*exp(-lower_bound*this%deltat)*exp(-b*(order+1)**2/(upper_bound*this%deltat))*&
        (M_ONE+sqrt(upper_bound*this%deltat*M_PI/(M_FOUR*b))) + M_TWO*d**(upper_bound*this%deltat)/(M_ONE-d)
    else
      chebyshev_exp_imagtime_error = M_TWO*exp(-lower_bound*this%deltat)*d**order/(M_ONE-d)
    end if
    POP_SUB(chebyshev_exp_imagtime_error)
  end function chebyshev_exp_imagtime_error

  !> Allocate a chebyshev_numerical_t, initialize its interval and time step,
  !! and point its complex_function component at the given procedure.
  !! NOTE(review): only a pointer to complex_function is stored, so the procedure
  !! must stay valid (e.g. not a finished host's internal procedure) while the object is used.
  function chebyshev_numerical_constructor(half_span, middle_point, deltat, complex_function) result(chebyshev_function)
    real(real64),                  intent(in) :: half_span, middle_point, deltat
    procedure(complex_function_i)             :: complex_function
    class(chebyshev_numerical_t),  pointer    :: chebyshev_function

    PUSH_SUB(chebyshev_numerical_constructor)

    allocate(chebyshev_function)
    chebyshev_function%complex_function => complex_function
    call chebyshev_function%set_parameters(half_span, middle_point, deltat)

    POP_SUB(chebyshev_numerical_constructor)
  end function chebyshev_numerical_constructor

  !> Use a discrete cosine transform to compute the coefficients, because
  !! no analytical formula is available for the phi_k functions.
  !!
  !! this%complex_function is sampled at -i*deltat*(half_span*cos(theta_j) + middle_point),
  !! with theta_j = pi*j/order for j = 0..order, i.e. at the order+1 Chebyshev extreme points.
  !! A DCT-I (FFTW REDFT00) of the samples is divided by order, and its first and last
  !! entries are halved. Real and imaginary parts are transformed separately.
  !! order must be at least 1: the grid divides by order and REDFT00 needs two points.
  subroutine chebyshev_numerical_coefficients(this, order, coefficients)
    class(chebyshev_numerical_t), intent(in)  :: this
    integer,                      intent(in)  :: order
    complex(real64), allocatable, intent(out) :: coefficients(:)

    integer :: i
    complex(real64) :: sample
    real(real64) :: re_values(0:order), im_values(0:order)
    real(real64) :: re_coefficients(0:order), im_coefficients(0:order)
    real(real64) :: theta
    type(c_ptr) :: plan

    PUSH_SUB(chebyshev_numerical_coefficients)

    ASSERT(order >= 1)

    SAFE_ALLOCATE(coefficients(0:order))
    ! compute values for DCT
    do i = 0, order
      theta = M_PI*i/order
      sample = this%complex_function(-M_zI*this%deltat*(this%half_span*cos(theta)+this%middle_point))
      re_values(i) = real(sample, real64)
      im_values(i) = aimag(sample)
    end do
    ! The plan is also executed on im_values/im_coefficients (new-array execute).
    ! These automatic arrays need not share the alignment of the planning arrays,
    ! so SIMD algorithms with alignment requirements are excluded via FFTW_UNALIGNED.
    plan = fftw_plan_r2r_1d(order+1, re_values, re_coefficients, FFTW_REDFT00, &
      ior(FFTW_ESTIMATE, FFTW_UNALIGNED))
    call fftw_execute_r2r(plan, re_values, re_coefficients)
    call fftw_execute_r2r(plan, im_values, im_coefficients)
    call fftw_destroy_plan(plan)
    do i = 0, order
      ! normalize correctly, also need factor 2 for expansion coefficients
      coefficients(i) = cmplx(re_coefficients(i), im_coefficients(i), real64)/order
    end do
    ! need to divide first and last coefficient by 2
    coefficients(0) = coefficients(0) / M_TWO
    coefficients(order) = coefficients(order) / M_TWO

    POP_SUB(chebyshev_numerical_coefficients)
  end subroutine chebyshev_numerical_coefficients

  !> use the error estimate from Lubich, C. From Quantum to Classical Molecular Dynamics:
  !! Reduced Models and Numerical Analysis. (EMS Press, 2008), doi:10.4171/067, Theorems 2.1 to 2.4
  !!
  !! With w = |half_span*deltat| and r = 2*order/w, the estimate is
  !!   2 r^-order / (1 - 1/r) * Re f(w*(r - 1/r)/2),   f = this%complex_function.
  !! It is evaluated through 1/r = w/(2*order) and w*(r - 1/r)/2 = order - w/(2r),
  !! so w = 0 does not produce 0/0.
  !! Returns -1 (no estimate) when order < 1 or order < w. The test uses |deltat|,
  !! consistent with the formula, so negative time steps are handled as well.
  real(real64) function chebyshev_numerical_error(this, order)
    class(chebyshev_numerical_t), intent(in)  :: this
    integer,                      intent(in)  :: order

    real(real64) :: span_t
    real(real64) :: inv_r

    PUSH_SUB(chebyshev_numerical_error)
    span_t = abs(this%half_span*this%deltat)
    if (order >= 1 .and. order >= span_t) then
      ! order >= span_t guarantees inv_r <= 1/2
      inv_r = span_t/(M_TWO*order)
      chebyshev_numerical_error = M_TWO*inv_r**order/(M_ONE - inv_r) * &
        real(this%complex_function(cmplx(order - span_t*inv_r/M_TWO, M_ZERO, real64)), real64)
    else
      chebyshev_numerical_error = -M_ONE
    end if
    POP_SUB(chebyshev_numerical_error)
  end function chebyshev_numerical_error
end module chebyshev_coefficients_oct_m

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
