!! Copyright (C) 2002-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch
!! Copyright (C) 2019 M. Oliveira
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

module exponential_oct_m
  use accel_oct_m
  use batch_oct_m
  use batch_ops_oct_m
  use blas_oct_m
  use chebyshev_coefficients_oct_m
  use debug_oct_m
  use global_oct_m
  use hamiltonian_abst_oct_m
  use hamiltonian_elec_oct_m
  use hamiltonian_elec_base_oct_m
  use, intrinsic :: iso_fortran_env
  use lalg_adv_oct_m
  use lalg_basic_oct_m
  use loct_math_oct_m
  use math_oct_m
  use mesh_oct_m
  use mesh_function_oct_m
  use mesh_batch_oct_m
  use messages_oct_m
  use namespace_oct_m
  use parser_oct_m
  use profiling_oct_m
  use states_elec_oct_m
  use states_elec_calc_oct_m
  use types_oct_m
  use wfs_elec_oct_m
  use xc_oct_m

  implicit none

  ! Everything is private by default; only the names listed below are exported.
  private
  public ::                            &
    exponential_t,                     &
    exponential_init,                  &
    exponential_copy,                  &
    exponential_apply_all,             &
    exponential_lanczos_function_batch

  ! Identifiers of the available exponentiation methods. The numerical values
  ! are the option values of the input variable TDExponentialMethod
  ! (lanczos = 2, taylor = 3, chebyshev = 4).
  integer, public, parameter ::  &
    EXP_LANCZOS            = 2,  &
    EXP_TAYLOR             = 3,  &
    EXP_CHEBYSHEV          = 4

  !> Settings that control how the exponential of the Hamiltonian is applied.
  !!
  !! The components are filled by exponential_init. Note that the tolerances and
  !! exp_order are only read from the input for the method that uses them
  !! (lanczos_tol for EXP_LANCZOS, chebyshev_tol for EXP_CHEBYSHEV, exp_order for
  !! EXP_TAYLOR); for the other methods exponential_init does not assign them.
  type exponential_t
    private
    integer, public :: exp_method  !< which method is used to apply the exponential
    real(real64)    :: lanczos_tol !< tolerance for the Lanczos method
    real(real64)    :: chebyshev_tol !< tolerance for the Chebyshev method
    integer, public :: exp_order   !< order to which the propagator is expanded
    integer         :: arnoldi_gs  !< Orthogonalization scheme used for Arnoldi
    logical, public :: full_batch = .false. !< apply exponential to full batch instead of to each state in a batch
  contains
    ! Type-bound entry points; each binding maps to the module procedure of the same purpose.
    procedure :: apply_batch => exponential_apply_batch
    procedure :: apply_single => exponential_apply_single
    procedure :: apply_phi_batch => exponential_apply_phi_batch
  end type exponential_t

contains

  ! ---------------------------------------------------------
  !> Initialize the exponential object from the input file.
  !!
  !! Reads TDExponentialMethod and, depending on the selected method,
  !! TDChebyshevTol (chebyshev), TDLanczosTol and ArnoldiOrthogonalization (lanczos),
  !! or TDExpOrder (taylor). Invalid values are reported through messages_input_error.
  !!
  !! @param[out] te          exponential object to set up
  !! @param[in]  namespace   namespace used for parsing the input and for messages
  !! @param[in]  full_batch  optional; if present, its value is stored in te%full_batch
  subroutine exponential_init(te, namespace, full_batch)
    type(exponential_t), intent(out) :: te
    type(namespace_t),   intent(in)  :: namespace
    logical, optional,   intent(in)  :: full_batch

    PUSH_SUB(exponential_init)

    !%Variable TDExponentialMethod
    !%Type integer
    !%Default taylor
    !%Section Time-Dependent::Propagation
    !%Description
    !% Method used to numerically calculate the exponential of the Hamiltonian,
    !% a core part of the full algorithm used to approximate the evolution
    !% operator, specified through the variable <tt>TDPropagator</tt>.
    !% In the case of using the Magnus method, described below, the action of the exponential
    !% of the Magnus operator is also calculated through the algorithm specified
    !% by this variable.
    !%Option lanczos 2
    !% Allows for larger time-steps.
    !% However, the larger the time-step, the longer the computational time per time-step.
    !% In certain cases, if the time-step is too large, the code will emit a warning
    !% whenever it considers that the evolution may not be properly proceeding --
    !% the Lanczos process did not converge. The method consists in a Krylov
    !% subspace approximation of the action of the exponential
    !% (see M. Hochbruck and C. Lubich, <i>SIAM J. Numer. Anal.</i> <b>34</b>, 1911 (1997) for details).
    !% The performance of the method is controlled by the tolerance (controlled by <tt>TDLanczosTol</tt>).
    !% The smaller the tolerance, the more precisely the exponential
    !% is calculated, but also the larger the dimension of the Arnoldi
    !% subspace. If the maximum dimension (currently 200) is not enough to meet the criterion,
    !% the above-mentioned warning is emitted.
    !% Be aware that the larger the required dimension of the Krylov subspace, the larger
    !% the memory required for this method. So if you run out of memory, try to reduce
    !% the time step.
    !%Option taylor 3
    !% This method amounts to a straightforward application of the definition of
    !% the exponential of an operator, in terms of its Taylor expansion.
    !%
    !% <math>\exp_{\rm STD} (-i\delta t H) = \sum_{i=0}^{k} {(-i\delta t)^i\over{i!}} H^i.</math>
    !%
    !% The order <i>k</i> is determined by variable <tt>TDExpOrder</tt>.
    !% Some numerical considerations from <a href=http://www.phys.washington.edu/~bertsch/num3.ps>
    !% Jeff Giansiracusa and George F. Bertsch</a>
    !% suggest the 4th order as especially suitable and stable.
    !%Option chebyshev 4
    !% In principle, the Chebyshev expansion
    !% of the exponential represents it more accurately than the canonical or standard expansion.
    !% <tt>TDChebyshevTol</tt> determines the tolerance to which the expansion is computed.
    !%
    !% There exists a closed analytic form for the coefficients of the exponential in terms
    !% of Chebyshev polynomials:
    !%
    !% <math>\exp_{\rm CHEB} \left( -i\delta t H \right) = \sum_{k=0}^{\infty} (2-\delta_{k0})(-i)^{k}J_k(\delta t) T_k(H),</math>
    !%
    !% where <math>J_k</math> are the Bessel functions of the first kind, and H has to be previously
    !% scaled to <math>[-1,1]</math>.
    !% See H. Tal-Ezer and R. Kosloff, <i>J. Chem. Phys.</i> <b>81</b>,
    !% 3967 (1984); R. Kosloff, <i>Annu. Rev. Phys. Chem.</i> <b>45</b>, 145 (1994);
    !% C. W. Clenshaw, <i>MTAC</i> <b>9</b>, 118 (1955).
    !%End
    call parse_variable(namespace, 'TDExponentialMethod', EXP_TAYLOR, te%exp_method)

    ! Read the method-specific tolerance and reject unknown methods.
    select case (te%exp_method)
    case (EXP_TAYLOR)
      ! nothing to read here; the expansion order is parsed further below
    case (EXP_CHEBYSHEV)
      !%Variable TDChebyshevTol
      !%Type float
      !%Default 1e-10
      !%Section Time-Dependent::Propagation
      !%Description
      !% An internal tolerance variable for the Chebyshev method. The smaller, the more
      !% precisely the exponential is calculated and the more iterations are needed, i.e.,
      !% it becomes more expensive. The expansion is terminated once the error estimate
      !% is below this tolerance.
      !%End
      call parse_variable(namespace, 'TDChebyshevTol', 1e-10_real64, te%chebyshev_tol)
      ! the tolerance must be strictly positive
      if (te%chebyshev_tol <= M_ZERO) call messages_input_error(namespace, 'TDChebyshevTol')
    case (EXP_LANCZOS)
      !%Variable TDLanczosTol
      !%Type float
      !%Default 1e-5
      !%Section Time-Dependent::Propagation
      !%Description
      !% An internal tolerance variable for the Lanczos method. The smaller, the more
      !% precisely the exponential is calculated, and also the bigger the dimension
      !% of the Krylov subspace needed to perform the algorithm. One should carefully
      !% make sure that this value is not too big, or else the evolution will be
      !% wrong.
      !%End
      call parse_variable(namespace, 'TDLanczosTol', 1e-5_real64, te%lanczos_tol)
      ! the tolerance must be strictly positive
      if (te%lanczos_tol <= M_ZERO) call messages_input_error(namespace, 'TDLanczosTol')

    case default
      call messages_input_error(namespace, 'TDExponentialMethod')
    end select
    call messages_print_var_option('TDExponentialMethod', te%exp_method, namespace=namespace)

    ! The expansion order is only read for the Taylor method; te%exp_order is
    ! not assigned for the other methods.
    if (te%exp_method == EXP_TAYLOR) then
      !%Variable TDExpOrder
      !%Type integer
      !%Default 4
      !%Section Time-Dependent::Propagation
      !%Description
      !% For <tt>TDExponentialMethod</tt> = <tt>taylor</tt>,
      !% the order to which the exponential is expanded.
      !%End
      call parse_variable(namespace, 'TDExpOrder', DEFAULT__TDEXPORDER, te%exp_order)
      ! orders below 2 are rejected
      if (te%exp_order < 2) call messages_input_error(namespace, 'TDExpOrder')

    end if

    ! Default orthogonalization scheme; it can only be changed from the input
    ! when the Lanczos method is selected.
    te%arnoldi_gs = OPTION__ARNOLDIORTHOGONALIZATION__CGS
    if (te%exp_method == EXP_LANCZOS) then
      !%Variable ArnoldiOrthogonalization
      !%Type integer
      !%Section Time-Dependent::Propagation
      !%Description
      !% The orthogonalization method used for the Arnoldi procedure.
      !% Only for TDExponentialMethod = lanczos.
      !%Option cgs 3
      !% Classical Gram-Schmidt (CGS) orthogonalization.
      !% The algorithm is defined in Giraud et al., Computers and Mathematics with Applications 50, 1069 (2005).
      !%Option drcgs 5
      !% Classical Gram-Schmidt orthogonalization with double-step reorthogonalization.
      !% The algorithm is taken from Giraud et al., Computers and Mathematics with Applications 50, 1069 (2005).
      !% According to this reference, this is much more precise than CGS or MGS algorithms.
      !%End
      call parse_variable(namespace, 'ArnoldiOrthogonalization', OPTION__ARNOLDIORTHOGONALIZATION__CGS, &
        te%arnoldi_gs)
    end if

    ! do lanczos expansion for full batch?
    ! (te is intent(out), so te%full_batch holds its default initialization,
    ! .false., unless the optional argument is given)
    te%full_batch = optional_default(full_batch, te%full_batch)

    POP_SUB(exponential_init)
  end subroutine exponential_init

  ! ---------------------------------------------------------
  !> Copy all settings of the exponential object tei into teo.
  !!
  !! Every component of exponential_t is copied, so that teo behaves exactly
  !! like tei afterwards, whichever exponentiation method is selected.
  !!
  !! @param[inout] teo  destination object
  !! @param[in]    tei  source object
  subroutine exponential_copy(teo, tei)
    type(exponential_t), intent(inout) :: teo
    type(exponential_t), intent(in)    :: tei

    PUSH_SUB(exponential_copy)

    teo%exp_method    = tei%exp_method
    teo%lanczos_tol   = tei%lanczos_tol
    ! the Chebyshev tolerance is needed by exponential_cheby_batch
    teo%chebyshev_tol = tei%chebyshev_tol
    teo%exp_order     = tei%exp_order
    teo%arnoldi_gs    = tei%arnoldi_gs
    ! keep the per-state / full-batch behaviour of the Lanczos expansion
    teo%full_batch    = tei%full_batch

    POP_SUB(exponential_copy)
  end subroutine exponential_copy

  ! ---------------------------------------------------------
  !> Apply the exponential to a single state stored in a plain array.
  !!
  !! The array zpsi is wrapped into a batch holding only state ist at k-point ik
  !! and handed over to the batch routine. If the Hamiltonian carries an
  !! inhomogeneous term, the corresponding state is fetched from hm%inh_st and
  !! passed along as well.
  subroutine exponential_apply_single(te, namespace, mesh, hm, zpsi, ist, ik, deltat, vmagnus, imag_time)
    class(exponential_t),        intent(inout) :: te
    type(namespace_t),           intent(in)    :: namespace
    class(mesh_t),               intent(in)    :: mesh
    type(hamiltonian_elec_t),    intent(inout) :: hm
    integer,                     intent(in)    :: ist
    integer,                     intent(in)    :: ik
    complex(real64), contiguous, intent(inout) :: zpsi(:, :)
    real(real64),                intent(in)    :: deltat
    real(real64),   optional,    intent(in)    :: vmagnus(mesh%np, hm%d%nspin, 2)
    logical, optional,           intent(in)    :: imag_time

    type(wfs_elec_t) :: state_batch, inh_batch
    complex(real64), allocatable :: inh_state(:, :)
    logical :: with_inh_term

    PUSH_SUB(exponential_apply_single)

    ! The phase is set on the first np points only; for the points np+1 to np_part
    ! it is accounted for as a phase correction inside the Hamiltonian.
    if (hm%phase%is_allocated()) then
      call hm%phase%apply_to_single(zpsi, mesh%np, hm%d%dim, ik, .false.)
    end if

    ! one-state batch wrapping the input array
    call wfs_elec_init(state_batch, hm%d%dim, ist, ist, zpsi, ik)

    with_inh_term = hamiltonian_elec_inh_term(hm)

    if (.not. with_inh_term) then
      call te%apply_batch(namespace, mesh, hm, state_batch, deltat, &
        vmagnus=vmagnus, imag_time=imag_time)
    else
      ! wrap the matching inhomogeneous state into its own one-state batch
      SAFE_ALLOCATE(inh_state(1:mesh%np_part, 1:hm%d%dim))
      call states_elec_get_state(hm%inh_st, mesh, ist, ik, inh_state(:, :))
      call wfs_elec_init(inh_batch, hm%d%dim, ist, ist, inh_state, ik)
      call te%apply_batch(namespace, mesh, hm, state_batch, deltat, &
        vmagnus=vmagnus, imag_time=imag_time, inh_psib=inh_batch)
      call inh_batch%end()
      SAFE_DEALLOCATE_A(inh_state)
    end if

    call state_batch%end()

    ! remove the phase again from the result
    if (hm%phase%is_allocated()) then
      call hm%phase%apply_to_single(zpsi, mesh%np, hm%d%dim, ik, .true.)
    end if

    POP_SUB(exponential_apply_single)
  end subroutine exponential_apply_single

  ! ---------------------------------------------------------
  !> This routine performs the operation:
  !! \f[
  !! \exp{-i*\Delta t*hm(t)}|\psi>  <-- |\psi>
  !! \f]
  !! If imag_time is present and is set to true, it performs instead:
  !! \f[
  !! \exp{ \Delta t*hm(t)}|\psi>  <-- |\psi>
  !! \f]
  !! If an inhomogeneous term is present, the operation is:
  !! \f[
  !! \exp{-i*\Delta t*hm(t)}|\psi> + \Delta t*\phi_1{-i*\Delta t*hm(t)}|inh\psi>  <-- |\psi>
  !! \f]
  !! where:
  !! \f[
  !! \phi_1(x) = (e^x - 1)/x
  !! \f]
  !!
  !! If psib2 and deltat2 are given (they must be given together), psib2 receives
  !! the result of the same operation applied to the original content of psib,
  !! but using the time step deltat2.
  !!
  !! The method is selected by te%exp_method. The Chebyshev method requires a
  !! Hermitian operator and does not support an inhomogeneous term; both cases
  !! end with a fatal error. vmagnus is only accepted for hamiltonian_elec_t.
  ! ---------------------------------------------------------
  subroutine exponential_apply_batch(te, namespace, mesh, hm, psib, deltat, psib2, deltat2, vmagnus, imag_time, inh_psib)
    class(exponential_t),               intent(inout) :: te
    type(namespace_t),                  intent(in)    :: namespace
    class(mesh_t),                      intent(in)    :: mesh
    class(hamiltonian_abst_t),          intent(inout) :: hm
    class(batch_t),                     intent(inout) :: psib
    real(real64),                       intent(in)    :: deltat
    class(batch_t),           optional, intent(inout) :: psib2
    real(real64),             optional, intent(in)    :: deltat2
    real(real64),             optional, intent(in)    :: vmagnus(:,:,:) !(mesh%np, hm%d%nspin, 2)
    logical,                  optional, intent(in)    :: imag_time
    class(batch_t),           optional, intent(inout) :: inh_psib   !< inhomogeneous term

    complex(real64) :: deltat_, deltat2_
    class(chebyshev_function_t), pointer :: chebyshev_function
    logical :: imag_time_

    PUSH_SUB(exponential_apply_batch)
    call profiling_in("EXPONENTIAL_BATCH")

    ASSERT(psib%type() == TYPE_CMPLX)

    ! the second batch and the second time step only make sense together
    ASSERT(present(psib2) .eqv. present(deltat2))
    if (present(inh_psib)) then
      ASSERT(inh_psib%nst == psib%nst)
    end if

    if (present(vmagnus)) then
      select type(hm)
      class is (hamiltonian_elec_t)
        ASSERT(size(vmagnus, 1) >= mesh%np)
        ASSERT(size(vmagnus, 2) == hm%d%nspin)
        ASSERT(size(vmagnus, 3) == 2)
      class default
        write(message(1), '(a)') 'Magnus operators only implemented for electrons at the moment'
        call messages_fatal(1, namespace=namespace)
      end select
    end if

    ! deltat2_ always gets a defined value (zero if deltat2 is absent)
    deltat2_ = cmplx(optional_default(deltat2, M_ZERO), M_ZERO, real64)

    ! Complex time steps handed to the Taylor and Lanczos routines.
    ! NOTE(review): for imaginary time the two time steps get opposite signs
    ! (-i*deltat versus +i*deltat2) -- confirm that this is intended.
    imag_time_ = optional_default(imag_time, .false.)
    if (imag_time_) then
      deltat_ = -M_zI*deltat
      if (present(deltat2)) deltat2_ = M_zI*deltat2
    else
      deltat_ = cmplx(deltat, M_ZERO, real64)
      if (present(deltat2)) deltat2_ = cmplx(deltat2, M_ZERO, real64)
    end if

    if (.not. hm%is_hermitian() .and. te%exp_method == EXP_CHEBYSHEV) then
      write(message(1), '(a)') 'The Chebyshev expansion cannot be used for non-Hermitian operators.'
      write(message(2), '(a)') 'Please use the Lanczos exponentiation scheme ("TDExponentialMethod = lanczos")'
      write(message(3), '(a)') 'or the Taylor expansion ("TDExponentialMethod = taylor") method.'
      call messages_fatal(3, namespace=namespace)
    end if

    select case (te%exp_method)
    case (EXP_TAYLOR)
      ! psib2 and deltat2_ are only handed over if deltat2 is present
      ! (psib2 is present exactly then, see the assertion above).
      if (present(deltat2)) then
        call exponential_taylor_series_batch(te, namespace, mesh, hm, psib, deltat_, &
          psib2, deltat2_, vmagnus)
      else
        call exponential_taylor_series_batch(te, namespace, mesh, hm, psib, deltat_, &
          vmagnus=vmagnus)
      end if
      ! second pass: add the contribution of the inhomogeneous term
      if (present(inh_psib)) then
        if (present(deltat2)) then
          call exponential_taylor_series_batch(te, namespace, mesh, hm, psib, deltat_, &
            psib2, deltat2_, vmagnus, inh_psib)
        else
          call exponential_taylor_series_batch(te, namespace, mesh, hm, psib, deltat_, &
            vmagnus=vmagnus, inh_psib=inh_psib)
        end if
      end if

    case (EXP_LANCZOS)
      ! keep a copy of the unpropagated states for the second time step
      if (present(psib2)) call psib%copy_data_to(mesh%np, psib2)
      call exponential_lanczos_batch(te, namespace, mesh, hm, psib, deltat_, vmagnus)
      if (present(inh_psib)) then
        call exponential_lanczos_batch(te, namespace, mesh, hm, psib, deltat_, vmagnus, inh_psib)
      end if
      if (present(psib2)) then
        call exponential_lanczos_batch(te, namespace, mesh, hm, psib2, deltat2_, vmagnus)
        if (present(inh_psib)) then
          call exponential_lanczos_batch(te, namespace, mesh, hm, psib2, deltat2_, vmagnus, inh_psib)
        end if
      end if

    case (EXP_CHEBYSHEV)
      if (present(inh_psib)) then
        write(message(1), '(a)') 'Chebyshev exponential ("TDExponentialMethod = chebyshev")'
        write(message(2), '(a)') 'with inhomogeneous term is not implemented'
        call messages_fatal(2, namespace=namespace)
      end if
      ! keep a copy of the unpropagated states for the second time step
      if (present(psib2)) call psib%copy_data_to(mesh%np, psib2)

      ! initialize the class for computing the coefficients for time step deltat
      if (imag_time_) then
        chebyshev_function => chebyshev_exp_imagtime_t(hm%spectral_half_span, hm%spectral_middle_point, deltat)
      else
        chebyshev_function => chebyshev_exp_t(hm%spectral_half_span, hm%spectral_middle_point, deltat)
      end if
      call exponential_cheby_batch(te, namespace, mesh, hm, psib, deltat, chebyshev_function, vmagnus)
      deallocate(chebyshev_function)

      if (present(psib2)) then
        ! The time step enters the expansion only through the coefficient object,
        ! so the second batch needs its own object built with deltat2.
        if (imag_time_) then
          chebyshev_function => chebyshev_exp_imagtime_t(hm%spectral_half_span, hm%spectral_middle_point, deltat2)
        else
          chebyshev_function => chebyshev_exp_t(hm%spectral_half_span, hm%spectral_middle_point, deltat2)
        end if
        call exponential_cheby_batch(te, namespace, mesh, hm, psib2, deltat2, chebyshev_function, vmagnus)
        deallocate(chebyshev_function)
      end if

    end select

    call profiling_out("EXPONENTIAL_BATCH")
    POP_SUB(exponential_apply_batch)
  end subroutine exponential_apply_batch

  ! ---------------------------------------------------------
  !> Apply a truncated Taylor series in (-i*deltat*H) to a batch.
  !!
  !! Without inh_psib, the terms (-i*deltat*H)^n psib, n = 1..te%exp_order, are
  !! accumulated onto psib; the factor of term n is the previous factor times
  !! (-i*deltat)/(n + phik_shift).
  !!
  !! With inh_psib, the series is applied to inh_psib instead and added to the
  !! current content of psib: first deltat*inh_psib, then the higher-order terms,
  !! the factor of term n being the previous one times (-i*deltat)/(n + phik_shift + 1).
  !!
  !! If psib2 is present, it is first overwritten with the content of psib and
  !! then receives the same series with deltat2 in place of deltat. The
  !! applications of the Hamiltonian are shared between the two batches.
  !!
  !! deltat is expected to be either purely real or purely imaginary: the routine
  !! tracks whether the current factor is real to use the cheaper real axpy.
  subroutine exponential_taylor_series_batch(te, namespace, mesh, hm, psib, deltat, psib2, deltat2, vmagnus, inh_psib, phik_shift)
    type(exponential_t),                intent(inout) :: te
    type(namespace_t),                  intent(in)    :: namespace
    class(mesh_t),                      intent(in)    :: mesh
    class(hamiltonian_abst_t),          intent(inout) :: hm
    class(batch_t),                     intent(inout) :: psib
    complex(real64),                    intent(in)    :: deltat
    class(batch_t),           optional, intent(inout) :: psib2
    complex(real64),          optional, intent(in)    :: deltat2
    real(real64),             optional, intent(in)    :: vmagnus(:,:,:) !(mesh%np, hm%d%nspin, 2)
    class(batch_t),           optional, intent(inout) :: inh_psib       !< inhomogeneous term
    integer,                  optional, intent(in)    :: phik_shift     !< shift in the Taylor expansion coefficients for phi_k

    complex(real64) :: zfact, zfact2
    integer :: iter, denom, phik_shift_
    logical :: zfact_is_real
    class(batch_t), allocatable :: psi1b, hpsi1b

    PUSH_SUB(exponential_taylor_series_batch)
    call profiling_in("EXP_TAYLOR_BATCH")

    ! work batches: psi1b holds H^(n-1) applied to the start vector, hpsi1b holds H^n
    call psib%clone_to(psi1b)
    call psib%clone_to(hpsi1b)

    zfact = M_z1
    zfact2 = M_z1
    ! true if the time step is real; the flag is toggled with every extra factor -i*deltat
    zfact_is_real = abs(aimag(deltat)) < M_EPSILON

    ! NOTE(review): this copy is also done when inh_psib is present, i.e. it
    ! replaces whatever psib2 contained before by the current psib -- confirm
    ! that this is intended when psib2 already holds a propagated state.
    if (present(psib2)) call psib%copy_data_to(mesh%np, psib2)

    if (present(inh_psib)) then
      ! zeroth-order term of the inhomogeneous contribution: deltat*inh_psib.
      ! The factor is only real for a real time step; otherwise the complex
      ! factor has to be used, or the term would be dropped.
      zfact = zfact*deltat
      if (zfact_is_real) then
        call batch_axpy(mesh%np, real(zfact), inh_psib, psib)
      else
        call batch_axpy(mesh%np, zfact, inh_psib, psib)
      end if

      if (present(psib2)) then
        zfact2 = zfact2*deltat2
        if (zfact_is_real) then
          call batch_axpy(mesh%np, real(zfact2), inh_psib, psib2)
        else
          call batch_axpy(mesh%np, zfact2, inh_psib, psib2)
        end if
      end if
    end if

    ! shift the denominator by this shift for the phi_k functions
    phik_shift_ = optional_default(phik_shift, 0)

    do iter = 1, te%exp_order
      denom = iter+phik_shift_
      ! the series of the inhomogeneous term has one more factor in the denominator
      if (present(inh_psib)) denom = denom + 1
      zfact = zfact*(-M_zI*deltat)/denom
      if (present(deltat2)) zfact2 = zfact2*(-M_zI*deltat2)/denom
      ! every factor -i*deltat switches between a real and an imaginary coefficient
      zfact_is_real = .not. zfact_is_real
      ! FIXME: need a test here for runaway exponential, e.g. for too large dt.
      !  in runaway case the problem is really hard to trace back: the positions
      !  go haywire on the first step of dynamics (often NaN) and with debugging options
      !  the code stops in ZAXPY below without saying why.

      ! hpsi1b = H^iter applied to the start vector (psib or inh_psib)
      if (iter /= 1) then
        call operate_batch(hm, namespace, mesh, psi1b, hpsi1b, vmagnus)
      else
        if (present(inh_psib)) then
          call operate_batch(hm, namespace, mesh, inh_psib, hpsi1b, vmagnus)
        else
          call operate_batch(hm, namespace, mesh, psib, hpsi1b, vmagnus)
        end if
      end if

      ! accumulate the term, using the real axpy when the coefficient is real
      if (zfact_is_real) then
        call batch_axpy(mesh%np, real(zfact), hpsi1b, psib)
        if (present(psib2)) call batch_axpy(mesh%np, real(zfact2), hpsi1b, psib2)
      else
        call batch_axpy(mesh%np, zfact, hpsi1b, psib)
        if (present(psib2)) call batch_axpy(mesh%np, zfact2, hpsi1b, psib2)
      end if

      ! prepare the input of the next application of H (not needed after the last term)
      if (iter /= te%exp_order) call hpsi1b%copy_data_to(mesh%np, psi1b)

    end do

    call psi1b%end()
    call hpsi1b%end()
    SAFE_DEALLOCATE_A(psi1b)
    SAFE_DEALLOCATE_A(hpsi1b)

    call profiling_out("EXP_TAYLOR_BATCH")
    POP_SUB(exponential_taylor_series_batch)
  end subroutine exponential_taylor_series_batch

  !> Apply the exponential, or the inhomogeneous contribution, with the Lanczos method.
  !!
  !! Without inh_psib, psib is replaced by the function "exponential" of
  !! (-i*deltat*H) applied to psib. With inh_psib, psib is left in place and
  !! deltat times the function "phi1" of (-i*deltat*H) applied to inh_psib is
  !! added to it; inh_psib itself is not modified.
  subroutine exponential_lanczos_batch(te, namespace, mesh, hm, psib, deltat, vmagnus, inh_psib)
    type(exponential_t),                intent(inout) :: te
    type(namespace_t),                  intent(in)    :: namespace
    class(mesh_t),                      intent(in)    :: mesh
    class(hamiltonian_abst_t),          intent(inout) :: hm
    class(batch_t),                     intent(inout) :: psib
    complex(real64),                    intent(in)    :: deltat
    real(real64),             optional, intent(in)    :: vmagnus(:,:,:) !(mesh%np, hm%d%nspin, 2)
    class(batch_t),           optional, intent(in)    :: inh_psib       !< inhomogeneous term

    class(batch_t), allocatable :: tmpb

    PUSH_SUB(exponential_lanczos_batch)

    if (present(inh_psib)) then
      ! work on a copy, as the Lanczos routine overwrites its input batch
      call inh_psib%clone_to(tmpb, copy_data=.true.)
      ! psib = psib + deltat * phi1(-i*deltat*H) inh_psib
      call exponential_lanczos_function_batch(te, namespace, mesh, hm, tmpb, deltat, phi1, vmagnus)
      ! Use the cheaper real axpy only if the time step is really real; taking
      ! the real part of an imaginary time step would drop the whole term.
      if (abs(aimag(deltat)) < M_EPSILON) then
        call batch_axpy(mesh%np, real(deltat, real64), tmpb, psib)
      else
        call batch_axpy(mesh%np, deltat, tmpb, psib)
      end if
      call tmpb%end()
      SAFE_DEALLOCATE_A(tmpb)
    else
      call exponential_lanczos_function_batch(te, namespace, mesh, hm, psib, deltat, exponential, vmagnus)
    end if

    POP_SUB(exponential_lanczos_batch)
  end subroutine exponential_lanczos_batch

  ! ---------------------------------------------------------
  !> @brief Compute fun(H) psib, i.e. the application of a function of the Hamiltonian to a batch
  !!
  !! Usually, the function is an exponential, or a related function.
  !! The function is evaluated at (-i*deltat) times the projection of the
  !! Hamiltonian onto a Krylov subspace, which is grown until the residual
  !! estimate of every state is below te%lanczos_tol, a breakdown occurs, or the
  !! maximum subspace dimension is reached (a warning is issued in this case).
  !! On exit psib holds the result. A batch of null vectors is returned unchanged.
  !!
  !! Some details of the implementation can be understood from
  !! Saad, Y. (1992). Analysis of some Krylov subspace approximations to the matrix exponential operator.
  !! SIAM Journal on Numerical Analysis, 29(1), 209-228.
  !!
  !! A pdf can be accessed [here](https://www-users.cse.umn.edu/~saad/PDF/RIACS-90-ExpTh.pdf)
  !! or [via the DOI](https://doi.org/10.1137/0729014).
  !! Equation numbers below refer to this paper.
  subroutine exponential_lanczos_function_batch(te, namespace, mesh, hm, psib, deltat, fun, vmagnus)
    type(exponential_t),                intent(inout) :: te
    type(namespace_t),                  intent(in)    :: namespace
    class(mesh_t),                      intent(in)    :: mesh
    class(hamiltonian_abst_t),          intent(inout) :: hm
    class(batch_t),                     intent(inout) :: psib
    complex(real64),                    intent(in)    :: deltat
    interface
      complex(real64) function fun(z)
        import real64
        complex(real64), intent(in) :: z
      end
    end interface
    real(real64),             optional, intent(in)    :: vmagnus(:,:,:) !(mesh%np, hm%d%nspin, 2)

    integer ::  iter, l, ii, ist, max_initialized, krylov_dim
    complex(real64), allocatable :: hamilt(:,:,:), expo(:,:,:)
    real(real64), allocatable :: beta(:), res(:), norm(:)
    integer, parameter :: max_order = 200
    real(real64), parameter :: null_norm_tol = 1.0e-12_real64 ! below this norm a vector counts as null
    type(batch_p_t), allocatable :: vb(:) ! Krylov subspace vectors

    PUSH_SUB(exponential_lanczos_function_batch)
    call profiling_in("EXP_LANCZOS_FUN_BATCH")

    if (te%exp_method /= EXP_LANCZOS) then
      message(1) = "The exponential method needs to be set to Lanzcos (TDExponentialMethod=lanczos)."
      call messages_fatal(1, namespace=namespace)
    end if

    SAFE_ALLOCATE(beta(1:psib%nst))
    SAFE_ALLOCATE(res(1:psib%nst))
    SAFE_ALLOCATE(norm(1:psib%nst))
    SAFE_ALLOCATE(vb(1:max_order))
    call psib%clone_to(vb(1)%p)
    ! number of Krylov vectors that have been cloned and must be released at the end
    max_initialized = 1

    ! first Krylov vector: the (normalized) input; beta holds the norms
    call psib%copy_data_to(mesh%np, vb(1)%p, async=.true.)
    call mesh_batch_nrm2(mesh, vb(1)%p, beta)

    ! for the full-batch variant a single norm for the whole batch is used
    if (te%full_batch) beta = norm2(beta)

    ! If we have a null vector, no need to compute the action of the exponential.
    if (all(abs(beta) <= null_norm_tol)) then
      SAFE_DEALLOCATE_A(beta)
      SAFE_DEALLOCATE_A(res)
      SAFE_DEALLOCATE_A(norm)
      call vb(1)%p%end()
      SAFE_DEALLOCATE_A(vb)
      call profiling_out("EXP_LANCZOS_FUN_BATCH")
      POP_SUB(exponential_lanczos_function_batch)
      return
    end if

    ! Normalize the first Krylov vector. States with (numerically) zero norm are
    ! left untouched to avoid dividing by zero when only some states are null.
    norm = M_ONE
    do ist = 1, psib%nst
      if (abs(beta(ist)) > null_norm_tol) norm(ist) = M_ONE/beta(ist)
    end do
    call batch_scal(mesh%np, norm, vb(1)%p, a_full = .false.)

    SAFE_ALLOCATE(hamilt(1:max_order+1, 1:max_order+1, 1:psib%nst))
    SAFE_ALLOCATE(  expo(1:max_order+1, 1:max_order+1, 1:psib%nst))

    ! dimension of the Krylov subspace for which expo has been computed
    krylov_dim = 0

    ! This is the Lanczos loop...
    do iter = 1, max_order-1
      call psib%clone_to(vb(iter + 1)%p)
      max_initialized = iter + 1

      ! to apply the Hamiltonian
      call operate_batch(hm, namespace, mesh, vb(iter)%p, vb(iter+1)%p, vmagnus)

      ! We use either the Lanczos method (Hermitian case) or the Arnoldi method
      if (hm%is_hermitian()) then
        ! only the previous two vectors are needed; the rest of the column is zero
        l = max(1, iter - 1)
        hamilt(1:max(l-1, 1), iter, 1:psib%nst) = M_ZERO
      else
        ! orthogonalize against all previous vectors; zero the row below the subdiagonal
        l = 1
        if (iter > 2) then
          hamilt(iter, 1:iter-2, 1:psib%nst) = M_ZERO
        end if
      end if

      ! Orthogonalize against previous vectors; the overlaps and the norm of the
      ! new vector are stored as column iter of the projected Hamiltonian
      call zmesh_batch_orthogonalization(mesh, iter - l + 1, vb(l:iter), vb(iter+1)%p, &
        normalize = .false., overlap = hamilt(l:iter, iter, 1:psib%nst), norm = hamilt(iter + 1, iter, 1:psib%nst), &
        gs_scheme = te%arnoldi_gs, full_batch = te%full_batch)

      ! We now need to compute exp(Hm), where Hm is the projection of the linear transformation
      ! of the Hamiltonian onto the Krylov subspace Km
      ! See Eq. 4
      ! We also estimate the error we made. This is given by the formula denoted Er2 in Sec. 5.2
      do ii = 1, psib%nst
        call zlalg_matrix_function(iter, -M_zI*deltat, hamilt(:,:,ii), expo(:,:,ii), fun, hm%is_hermitian())
        res(ii) = abs(hamilt(iter + 1, iter, ii) * abs(expo(iter, 1, ii)))
      end do !ii
      krylov_dim = iter

      if (all(abs(hamilt(iter + 1, iter, :)) < 1.0e4_real64*M_EPSILON)) exit ! "Happy breakdown"
      ! We normalize only if the norm is non-zero
      ! see http://www.netlib.org/utk/people/JackDongarra/etemplates/node216.html#alg:arn0
      norm = M_ONE
      do ist = 1, psib%nst
        if (abs(hamilt(iter + 1, iter, ist)) >= 1.0e4_real64 * M_EPSILON) then
          norm(ist) = M_ONE / abs(hamilt(iter + 1, iter, ist))
        end if
      end do

      call batch_scal(mesh%np, norm, vb(iter+1)%p, a_full = .false.)

      if (iter > 3 .and. all(res < te%lanczos_tol)) exit

    end do !iter

    ! The loop index equals max_order only if the loop ran to completion without exit
    if (iter == max_order) then ! Here one should consider the possibility of the happy breakdown.
      write(message(1),'(a,i5,a,es9.2)') 'Lanczos exponential expansion did not converge after ', krylov_dim, &
        ' iterations. Residual: ', maxval(res)
      call messages_warning(1, namespace=namespace)
    else
      write(message(1),'(a,i5)') 'Debug: Lanczos exponential iterations: ', krylov_dim
      call messages_info(1, namespace=namespace, debug_only=.true.)
    end if

    ! See Eq. 4 for the expression here
    ! zpsi = nrm * V * expo(1:krylov_dim, 1) = nrm * V * expo * V^(T) * zpsi
    ! Only the first krylov_dim entries of the first column of expo have been
    ! computed, so the sum must not run beyond krylov_dim.
    call batch_scal(mesh%np, expo(1,1,1:psib%nst), psib, a_full = .false.)
    ! TODO: We should have a routine batch_gemv for improved performance (see #1070 on gitlab)
    do ii = 2, krylov_dim
      call batch_axpy(mesh%np, beta(1:psib%nst)*expo(ii,1,1:psib%nst), vb(ii)%p, psib, a_full = .false.)
      ! In order to apply the two exponentials, we must store the eigenvalues and eigenvectors given by zlalg_exp
      ! And to recontruct here the exp(i*dt*H) for deltat2
    end do

    do ii = 1, max_initialized
      call vb(ii)%p%end()
    end do

    SAFE_DEALLOCATE_A(vb)
    SAFE_DEALLOCATE_A(hamilt)
    SAFE_DEALLOCATE_A(expo)
    SAFE_DEALLOCATE_A(beta)
    SAFE_DEALLOCATE_A(res)
    SAFE_DEALLOCATE_A(norm)

    call accel_finish()

    call profiling_out("EXP_LANCZOS_FUN_BATCH")

    POP_SUB(exponential_lanczos_function_batch)
  end subroutine exponential_lanczos_function_batch


  ! ---------------------------------------------------------
  !> Calculates the exponential of the Hamiltonian through an expansion in
  !! Chebyshev polynomials.
  !!
  !! The expansion coefficients and the error estimate are provided by
  !! chebyshev_function. The order is the smallest one (up to a fixed maximum)
  !! for which the error estimate is positive and below te%chebyshev_tol; if no
  !! such order exists, the maximum order is used and a warning is issued.
  !! The polynomials of the shifted and scaled Hamiltonian are generated with
  !! the three-term recurrence and accumulated onto psib.
  !!
  !! The argument deltat is not used by this routine; the time step enters only
  !! through chebyshev_function.
  !!
  !! It uses the algorithm as outlined in
  !! [Hochbruck, M. & Ostermann, A.: Exponential integrators.
  !! Acta Numerica 19, 209–286 (2010)](https://doi.org/10.1017/S0962492910000048),
  !! Section 4.1.
  !! See also [Kosloff, J. Phys. Chem. (1988), 92, 2087-2100](http://doi.org/10.1021/j100319a003),
  !! Section III.1 and
  !! [Kosloff, Annu. Rev. Phys. Chem. (1994), 45, 145–178](https://doi.org/10.1146/annurev.pc.45.100194.001045),
  !! equations 7.1ff and especially figure 2 for the convergence properties of the coefficients.
  subroutine exponential_cheby_batch(te, namespace, mesh, hm, psib, deltat, chebyshev_function, vmagnus)
    type(exponential_t),                intent(inout) :: te
    type(namespace_t),                  intent(in)    :: namespace
    class(mesh_t),                      intent(in)    :: mesh
    class(hamiltonian_abst_t),          intent(inout) :: hm
    class(batch_t),                     intent(inout) :: psib
    real(real64),                       intent(in)    :: deltat
    class(chebyshev_function_t),        intent(in)    :: chebyshev_function
    real(real64),             optional, intent(in)    :: vmagnus(:,:,:) !(mesh%np, hm%d%nspin, 2)

    integer :: j, order_needed
    logical :: converged
    complex(real64), allocatable :: coefficients(:)
    real(real64) :: error
    class(batch_t), allocatable, target :: psi0, psi1, psi2
    class(batch_t), pointer :: psi_n, psi_n1, psi_n2
    integer, parameter :: max_order = 200

    PUSH_SUB(exponential_cheby_batch)
    call profiling_in("EXP_CHEBY_BATCH")

    ! three work batches for the recurrence; psi0 starts as a copy of the input
    call psib%clone_to(psi0)
    call psib%clone_to(psi1)
    call psib%clone_to(psi2)
    call psib%copy_data_to(mesh%np, psi0)

    ! find the lowest order that fulfils the tolerance
    converged = .false.
    order_needed = max_order
    do j = 1, max_order
      error = chebyshev_function%get_error(j)
      if (error > M_ZERO .and. error < te%chebyshev_tol) then
        order_needed = j
        converged = .true.
        exit
      end if
    end do

    ! request exactly the coefficients that are used below (orders 0 to order_needed)
    call chebyshev_function%get_coefficients(order_needed, coefficients)

    ! zero-order term
    call batch_scal(mesh%np, coefficients(0), psib)
    ! first-order term
    ! shifted Hamiltonian
    call operate_batch(hm, namespace, mesh, psi0, psi1, vmagnus)
    call batch_axpy(mesh%np, -hm%spectral_middle_point, psi0, psi1)
    call batch_scal(mesh%np, M_ONE/hm%spectral_half_span, psi1)
    ! accumulate result
    call batch_axpy(mesh%np, coefficients(1), psi1, psib)

    ! use pointers to avoid copies
    psi_n => psi2
    psi_n1 => psi1
    psi_n2 => psi0

    do j = 2, order_needed
      ! compute shifted Hamiltonian and Chebyshev recurrence formula:
      ! psi_n = 2/half_span * (H - middle_point) psi_n1 - psi_n2
      call operate_batch(hm, namespace, mesh, psi_n1, psi_n, vmagnus)
      call batch_axpy(mesh%np, -hm%spectral_middle_point, psi_n1, psi_n)
      call batch_xpay(mesh%np, psi_n2, -M_TWO/hm%spectral_half_span, psi_n)
      call batch_scal(mesh%np, -M_ONE, psi_n)

      ! accumulate result
      call batch_axpy(mesh%np, coefficients(j), psi_n, psib)

      ! shift pointers for the three-term recurrence, this avoids copies
      if (mod(j, 3) == 2) then
        psi_n => psi0
        psi_n1 => psi2
        psi_n2 => psi1
      else if (mod(j, 3) == 1) then
        psi_n => psi2
        psi_n1 => psi1
        psi_n2 => psi0
      else
        psi_n => psi1
        psi_n1 => psi0
        psi_n2 => psi2
      end if
    end do

    if (.not. converged) then
      ! report the order actually used and the magnitude of the last coefficient
      write(message(1),'(a,i5,a,es9.2)') 'Chebyshev exponential expansion did not converge after ', order_needed, &
        ' iterations. Coefficient: ', abs(coefficients(order_needed))
      call messages_warning(1, namespace=namespace)
    else
      write(message(1),'(a,i5)') 'Debug: Chebyshev exponential iterations: ', order_needed
      call messages_info(1, namespace=namespace, debug_only=.true.)
    end if

    call psi0%end()
    call psi1%end()
    call psi2%end()
    SAFE_DEALLOCATE_A(psi0)
    SAFE_DEALLOCATE_A(psi1)
    SAFE_DEALLOCATE_A(psi2)

    SAFE_DEALLOCATE_A(coefficients)

    call profiling_out("EXP_CHEBY_BATCH")

    POP_SUB(exponential_cheby_batch)
  end subroutine exponential_cheby_batch


  !> Apply the Hamiltonian hm to the batch psib, with hpsib passed as the
  !! batch receiving the result.
  !!
  !! When vmagnus is absent the plain complex application (zapply) is
  !! used. When it is present the call is routed to zmagnus_apply, which
  !! additionally receives vmagnus.
  subroutine operate_batch(hm, namespace, mesh, psib, hpsib, vmagnus)
    class(hamiltonian_abst_t),          intent(inout) :: hm
    type(namespace_t),                  intent(in)    :: namespace
    class(mesh_t),                      intent(in)    :: mesh
    class(batch_t),                     intent(inout) :: psib
    class(batch_t),                     intent(inout) :: hpsib
    real(real64),             optional, intent(in)    :: vmagnus(:, :, :)

    PUSH_SUB(operate_batch)

    if (.not. present(vmagnus)) then
      ! regular application of the Hamiltonian
      call hm%zapply(namespace, mesh, psib, hpsib)
    else
      ! Magnus variant: vmagnus is forwarded unchanged
      call hm%zmagnus_apply(namespace, mesh, psib, hpsib, vmagnus)
    end if

    POP_SUB(operate_batch)
  end subroutine operate_batch

  ! ---------------------------------------------------------
  !> Apply the Taylor expansion of the exponential to all the states in
  !! st at once (only te%exp_method == EXP_TAYLOR is accepted).
  !!
  !! Note that this routine not only computes the exponential, but
  !! also adds an extra term if there is an inhomogeneous term in the
  !! Hamiltonian hm.
  !!
  !! On exit st holds the accumulated result. If order is present it is
  !! set to te%exp_order*st%nik*st%nst.
  subroutine exponential_apply_all(te, namespace, mesh, hm, st, deltat, order)
    type(exponential_t),      intent(inout) :: te
    type(namespace_t),        intent(in)    :: namespace
    class(mesh_t),            intent(inout) :: mesh
    type(hamiltonian_elec_t), intent(inout) :: hm
    type(states_elec_t),      intent(inout) :: st
    real(real64),             intent(in)    :: deltat
    integer, optional,        intent(inout) :: order

    integer :: iterm, iqn, iblock
    real(real64) :: weight

    ! st_term holds the current term of the series, st_hterm the
    ! Hamiltonian applied to the previous term
    type(states_elec_t) :: st_term, st_hterm

    PUSH_SUB(exponential_apply_all)

    ASSERT(te%exp_method == EXP_TAYLOR)

    call states_elec_copy(st_term, st)
    call states_elec_copy(st_hterm, st)

    ! Taylor expansion of the exponential: weight = deltat**iterm / iterm!
    weight = M_ONE
    do iterm = 1, te%exp_order
      weight = weight * deltat / iterm

      ! from the second term on, the Hamiltonian acts on the previous term
      if (iterm > 1) then
        call zhamiltonian_elec_apply_all(hm, namespace, mesh, st_term, st_hterm)
      else
        call zhamiltonian_elec_apply_all(hm, namespace, mesh, st, st_hterm)
      end if

      call add_series_term(weight)
    end do

    call states_elec_end(st_term)
    call states_elec_end(st_hterm)

    ! Inhomogeneous contribution, only if the Hamiltonian has one.
    if (hamiltonian_elec_inh_term(hm)) then

      call states_elec_copy(st_term, hm%inh_st)
      call states_elec_copy(st_hterm, hm%inh_st)

      ! lowest-order contribution: deltat times the inhomogeneous states
      do iqn = st%d%kpt%start, st%d%kpt%end
        do iblock = st%group%block_start, st%group%block_end
          call batch_axpy(mesh%np, deltat, st_term%group%psib(iblock, iqn), st%group%psib(iblock, iqn))
        end do
      end do

      ! higher orders: weight = deltat**iterm / (iterm+1)!, and one more
      ! power of deltat is supplied when the term is accumulated
      weight = M_ONE
      do iterm = 1, te%exp_order
        weight = weight * deltat / (iterm+1)

        if (iterm > 1) then
          call zhamiltonian_elec_apply_all(hm, namespace, mesh, st_term, st_hterm)
        else
          call zhamiltonian_elec_apply_all(hm, namespace, mesh, hm%inh_st, st_hterm)
        end if

        call add_series_term(deltat * weight)
      end do

      call states_elec_end(st_term)
      call states_elec_end(st_hterm)

    end if

    ! value handed back to the caller; the original author notes that
    ! this "should be the correct number"
    if (present(order)) order = te%exp_order*st%nik*st%nst

    POP_SUB(exponential_apply_all)

  contains

    !> For every local k-point/block pair: overwrite st_term with
    !! -i * st_hterm and then add factor * st_term to st.
    subroutine add_series_term(factor)
      real(real64), intent(in) :: factor

      integer :: ik, ib

      do ik = st%d%kpt%start, st%d%kpt%end
        do ib = st%group%block_start, st%group%block_end
          call batch_set_zero(st_term%group%psib(ib, ik))
          call batch_axpy(mesh%np, -M_zI, st_hterm%group%psib(ib, ik), st_term%group%psib(ib, ik))
          call batch_axpy(mesh%np, factor, st_term%group%psib(ib, ik), st%group%psib(ib, ik))
        end do
      end do
    end subroutine add_series_term

  end subroutine exponential_apply_all

  !> Apply the phi_k function selected by k to the batch psib, using the
  !! method stored in te%exp_method.
  !!
  !! - EXP_TAYLOR: k is forwarded as phik_shift to the Taylor series routine.
  !! - EXP_LANCZOS and EXP_CHEBYSHEV: only k = 1 (phi1) and k = 2 (phi2)
  !!   are available; any other k ends in a fatal message.
  !!
  !! psib must be a complex batch. If vmagnus is given, hm has to be a
  !! hamiltonian_elec_t and the shape of vmagnus is checked against
  !! mesh%np, hm%d%nspin and 2.
  subroutine exponential_apply_phi_batch(te, namespace, mesh, hm, psib, deltat, k, vmagnus)
    class(exponential_t),               intent(inout) :: te
    type(namespace_t),                  intent(in)    :: namespace
    class(mesh_t),                      intent(in)    :: mesh
    class(hamiltonian_abst_t),          intent(inout) :: hm
    class(batch_t),                     intent(inout) :: psib
    real(real64),                       intent(in)    :: deltat
    integer,                            intent(in)    :: k
    real(real64),             optional, intent(in)    :: vmagnus(:,:,:) !(mesh%np, hm%d%nspin, 2)

    complex(real64) :: zdeltat
    class(chebyshev_function_t), pointer :: cheby_expansion

    PUSH_SUB_WITH_PROFILE(exponential_apply_phi_batch)

    ASSERT(psib%type() == TYPE_CMPLX)
    if (present(vmagnus)) then
      ! the nspin check needs the electronic Hamiltonian type
      select type(hm)
      class is (hamiltonian_elec_t)
        ASSERT(size(vmagnus, 1) >= mesh%np)
        ASSERT(size(vmagnus, 2) == hm%d%nspin)
        ASSERT(size(vmagnus, 3) == 2)
      class default
        write(message(1), '(a)') 'Magnus operators only implemented for electrons at the moment'
        call messages_fatal(1, namespace=namespace)
      end select
    end if

    ! warn (but continue) when Chebyshev is combined with a non-Hermitian Hamiltonian
    if (te%exp_method == EXP_CHEBYSHEV .and. .not. hm%is_hermitian()) then
      write(message(1), '(a)') 'The Chebyshev expansion for the exponential will only converge if the imaginary'
      write(message(2), '(a)') 'eigenvalues are small enough compared to the span of the real eigenvalues,'
      write(message(3), '(a)') 'i.e., for ratios smaller than about 1e-3.'
      write(message(4), '(a)') 'The Lanczos method ("TDExponentialMethod = lanczos") is guaranteed to'
      write(message(5), '(a)') 'always converge in this case.'
      call messages_warning(5, namespace=namespace)
    end if

    ! the Taylor and Lanczos routines are called with a complex time step
    zdeltat = cmplx(deltat, M_ZERO, real64)

    select case (te%exp_method)
    case (EXP_TAYLOR)
      call exponential_taylor_series_batch(te, namespace, mesh, hm, psib, zdeltat, vmagnus=vmagnus, phik_shift=k)

    case (EXP_LANCZOS)
      select case (k)
      case (1)
        call exponential_lanczos_function_batch(te, namespace, mesh, hm, psib, zdeltat, phi1, vmagnus)
      case (2)
        call exponential_lanczos_function_batch(te, namespace, mesh, hm, psib, zdeltat, phi2, vmagnus)
      case default
        write(message(1), '(a)') 'Lanczos expansion not implemented for phi_k, k > 2'
        call messages_fatal(1, namespace=namespace)
      end select

    case (EXP_CHEBYSHEV)
      ! build the Chebyshev coefficient object for the requested phi function
      select case (k)
      case (1)
        cheby_expansion => chebyshev_numerical_t(hm%spectral_half_span, hm%spectral_middle_point, deltat, phi1)
      case (2)
        cheby_expansion => chebyshev_numerical_t(hm%spectral_half_span, hm%spectral_middle_point, deltat, phi2)
      case default
        write(message(1), '(a)') 'Chebyshev expansion not implemented for phi_k, k > 2'
        call messages_fatal(1, namespace=namespace)
      end select
      call exponential_cheby_batch(te, namespace, mesh, hm, psib, deltat, cheby_expansion, vmagnus)
      ! the pointer target was created above and is released here
      deallocate(cheby_expansion)

    end select
    POP_SUB_WITH_PROFILE(exponential_apply_phi_batch)
  end subroutine exponential_apply_phi_batch

end module exponential_oct_m

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
