!! Copyright (C) 2002-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

!> Enforced time-reversal symmetry (ETRS) propagators for the electronic
!! states: plain ETRS, self-consistent ETRS, approximate ETRS (AETRS) and
!! corrected AETRS (CAETRS).
!!
!! Every routine advances the states over one step, from time - dt to time,
!! by splitting it into two half steps: the first one with H(time - dt) and
!! the second one with H(time). The variants differ in how the potential
!! entering H(time) is obtained.
module propagator_etrs_oct_m
  use accel_oct_m
  use batch_oct_m
  use debug_oct_m
  use density_oct_m
  use electron_space_oct_m
  use exponential_oct_m
  use ext_partner_list_oct_m
  use grid_oct_m
  use gauge_field_oct_m
  use global_oct_m
  use hamiltonian_elec_oct_m
  use hamiltonian_elec_base_oct_m
  use interaction_partner_oct_m
  use ion_dynamics_oct_m
  use ions_oct_m
  use ks_potential_oct_m
  use lalg_basic_oct_m
  use lda_u_oct_m
  use lda_u_io_oct_m
  use math_oct_m
  use messages_oct_m
  use mesh_function_oct_m
  use multicomm_oct_m
  use namespace_oct_m
  use parser_oct_m
  use profiling_oct_m
  use propagator_base_oct_m
  use space_oct_m
  use states_elec_dim_oct_m
  use states_elec_oct_m
  use types_oct_m
  use v_ks_oct_m
  use wfs_elec_oct_m
  use propagation_ops_elec_oct_m
  use xc_oct_m

  implicit none

  private

  public ::                       &
    td_etrs,                      &
    td_etrs_sc,                   &
    td_aetrs,                     &
    td_caetrs

contains

  ! ---------------------------------------------------------
  !> Propagator with enforced time-reversal symmetry (ETRS).
  !!
  !! Advances the states from time - dt to time in two halves: dt/2 with
  !! H(time - dt), then dt/2 with H(time). For interacting theory levels the
  !! potentials of H(time) come from a KS calculation on the density returned
  !! by the first (fused) propagation call. They are kept in vhxc_t2 while
  !! H(time - dt) is restored for the first half step.
  !!
  !! For independent particles no potentials are stored or restored; H(time)
  !! is obtained only through the Hamiltonian update.
  subroutine td_etrs(ks, namespace, space, hm, ext_partners, gr, st, tr, time, dt, &
    ions_dyn, ions, mc)
    type(v_ks_t),                     intent(inout) :: ks
    type(namespace_t),                intent(in)    :: namespace
    type(electron_space_t),           intent(in)    :: space
    type(hamiltonian_elec_t),         intent(inout) :: hm
    type(partner_list_t),             intent(in)    :: ext_partners
    type(grid_t),                     intent(inout) :: gr
    type(states_elec_t),              intent(inout) :: st
    type(propagator_base_t),          intent(inout) :: tr
    real(real64),                     intent(in)    :: time   !< time at the end of the step
    real(real64),                     intent(in)    :: dt     !< time step
    type(ion_dynamics_t),             intent(inout) :: ions_dyn
    type(ions_t),                     intent(inout) :: ions
    type(multicomm_t),                intent(inout) :: mc    !< index and domain communicators

    type(xc_copied_potentials_t) :: vhxc_t1, vhxc_t2
    type(gauge_field_t), pointer :: gfield

    PUSH_SUB(td_etrs)

    if (hm%theory_level /= INDEPENDENT_PARTICLES) then

      ! keep the potentials of H(time - dt)
      call hm%ks_pot%store_potentials(vhxc_t1)

      ! propagate dt/2 and get the density after a full step dt
      call propagation_ops_elec_fuse_density_exp_apply(tr%te, namespace, st, gr, hm, M_HALF*dt, dt)

      ! KS potentials from that density; stored as the ones for H(time)
      call v_ks_calc(ks, namespace, space, hm, st, ions, ext_partners, &
        calc_current = .false., calc_energy = .false., calc_eigenval = .false.)

      call hm%ks_pot%store_potentials(vhxc_t2)
      ! go back to H(time - dt)
      call hm%ks_pot%restore_potentials(vhxc_t1)
      call hm%update(gr, namespace, space, ext_partners, time = time - dt)

    else

      call propagation_ops_elec_exp_apply(tr%te, namespace, st, gr, hm, M_HALF*dt)

    end if

    ! propagate dt/2 with H(t)

    ! first move the ions to time t
    call propagation_ops_elec_move_ions(tr%propagation_ops_elec, gr, hm, st, namespace, space, ions_dyn, ions, &
      ext_partners, mc, time, dt)

    gfield => list_get_gauge_field(ext_partners)
    if(associated(gfield)) then
      call propagation_ops_elec_propagate_gauge_field(tr%propagation_ops_elec, gfield, dt, time)
    end if

    ! vhxc_t2 is only filled in the interacting branch above; for independent
    ! particles there are no stored potentials to restore.
    if (hm%theory_level /= INDEPENDENT_PARTICLES) then
      call hm%ks_pot%restore_potentials(vhxc_t2)
    end if

    call propagation_ops_elec_update_hamiltonian(namespace, space, st, gr, hm, ext_partners, time)

    ! propagate the remaining dt/2 with H(time)
    call propagation_ops_elec_fuse_density_exp_apply(tr%te, namespace, st, gr, hm, M_HALF*dt)

    POP_SUB(td_etrs)
  end subroutine td_etrs

  ! ---------------------------------------------------------
  !> Propagator with enforced time-reversal symmetry and self-consistency.
  !!
  !! Like td_etrs, but the second half step with H(time) is repeated (at most
  !! niter times) until the change of the KS potentials, as measured by
  !! hm%ks_pot%check_convergence, is <= sctol. Between attempts the states
  !! are reset to the copy taken after the first half step.
  !!
  !! Only valid for interacting theory levels (asserted on entry).
  subroutine td_etrs_sc(ks, namespace, space, hm, ext_partners, gr, st, tr, time, dt, &
    ions_dyn, ions, mc, sctol, scsteps)
    type(v_ks_t),                     intent(inout) :: ks
    type(namespace_t),                intent(in)    :: namespace
    type(electron_space_t),           intent(in)    :: space
    type(hamiltonian_elec_t),         intent(inout) :: hm
    type(partner_list_t),             intent(in)    :: ext_partners
    type(grid_t),                     intent(inout) :: gr
    type(states_elec_t),              intent(inout) :: st
    type(propagator_base_t),          intent(inout) :: tr
    real(real64),                     intent(in)    :: time   !< time at the end of the step
    real(real64),                     intent(in)    :: dt     !< time step
    type(ion_dynamics_t),             intent(inout) :: ions_dyn
    type(ions_t),                     intent(inout) :: ions
    type(multicomm_t),                intent(inout) :: mc    !< index and domain communicators
    real(real64),                     intent(in)    :: sctol !< threshold on the potential change
    !> Number of self-consistency iterations performed. If the loop ends
    !! without reaching sctol, the DO variable is niter + 1 on exit, so
    !! niter + 1 is reported.
    integer,                optional, intent(out)   :: scsteps

    real(real64) :: diff
    integer :: ik, ib, iter
    class(wfs_elec_t), allocatable :: psi2(:, :)
    ! these are hardcoded for the moment
    integer, parameter :: niter = 10
    type(gauge_field_t), pointer :: gfield
    type(xc_copied_potentials_t) :: vhxc_t1, vhxc_t2

    PUSH_SUB(td_etrs_sc)

    ASSERT(hm%theory_level /= INDEPENDENT_PARTICLES)

    call hm%ks_pot%store_potentials(vhxc_t1)

    call messages_new_line()
    call messages_write('        Self-consistency iteration:')
    call messages_info(namespace=namespace)

    ! Propagate the states by dt/2 (to time - dt/2) and compute the density
    ! after a full step dt (at time)
    call propagation_ops_elec_fuse_density_exp_apply(tr%te, namespace, st, gr, hm, M_HALF*dt, dt)

    call v_ks_calc(ks, namespace, space, hm, st, ions, ext_partners, &
      calc_current = .false., calc_energy = .false., calc_eigenval = .false.)


    ! keep the predicted potentials for H(time) and go back to H(time - dt)
    call hm%ks_pot%store_potentials(vhxc_t2)
    call hm%ks_pot%restore_potentials(vhxc_t1)

    call propagation_ops_elec_update_hamiltonian(namespace, space, st, gr, hm, ext_partners, time - dt)

    ! propagate dt/2 with H(t)

    ! first move the ions to time t
    call propagation_ops_elec_move_ions(tr%propagation_ops_elec, gr, hm, st, namespace, space, ions_dyn, ions, &
      ext_partners, mc, time, dt)

    gfield => list_get_gauge_field(ext_partners)
    if(associated(gfield)) then
      call propagation_ops_elec_propagate_gauge_field(tr%propagation_ops_elec, gfield, dt, time)
    end if

    call hm%ks_pot%restore_potentials(vhxc_t2)

    call propagation_ops_elec_update_hamiltonian(namespace, space, st, gr, hm, ext_partners, time)

    SAFE_ALLOCATE_TYPE_ARRAY(wfs_elec_t, psi2, (st%group%block_start:st%group%block_end, st%d%kpt%start:st%d%kpt%end))

    ! store the state at half iteration
    do ik = st%d%kpt%start, st%d%kpt%end
      do ib = st%group%block_start, st%group%block_end
        call st%group%psib(ib, ik)%copy_to(psi2(ib, ik), copy_data=.true.)
      end do
    end do

    do iter = 1, niter

      ! reference potentials for measuring the change produced by this iteration
      call hm%ks_pot%store_potentials(vhxc_t2)

      call propagation_ops_elec_fuse_density_exp_apply(tr%te, namespace, st, gr, hm, M_HALF * dt)

      call v_ks_calc(ks, namespace, space, hm, st, ions, ext_partners, &
        time = time, calc_current = .false., calc_energy = .false., calc_eigenval = .false.)
      call lda_u_update_occ_matrices(hm%lda_u, namespace, gr, st, hm%hm_base, hm%phase, hm%energy)

      ! now check how much the potential changed
      diff = hm%ks_pot%check_convergence(vhxc_t2, gr, st%rho, st%qtot)

      call messages_write('          step ')
      call messages_write(iter)
      call messages_write(', residue = ')
      call messages_write(abs(diff), fmt = '(1x,es9.2)')
      call messages_info(namespace=namespace)

      if (diff <= sctol) exit

      if (iter /= niter) then
        ! we are not converged, restore the states
        do ik = st%d%kpt%start, st%d%kpt%end
          do ib = st%group%block_start, st%group%block_end
            call psi2(ib, ik)%copy_data_to(gr%np, st%group%psib(ib, ik))
          end do
        end do
      end if

    end do

    if (hm%lda_u_level /= DFT_U_NONE) then
      call lda_u_write_U(hm%lda_u, namespace=namespace)
      call lda_u_write_V(hm%lda_u, namespace=namespace)
    end if

    ! print an empty line
    call messages_info(namespace=namespace)

    ! iter is niter + 1 here when the loop finished without converging
    if (present(scsteps)) scsteps = iter

    do ik = st%d%kpt%start, st%d%kpt%end
      do ib = st%group%block_start, st%group%block_end
        call psi2(ib, ik)%end()
      end do
    end do

    SAFE_DEALLOCATE_A(psi2)

    POP_SUB(td_etrs_sc)
  end subroutine td_etrs_sc

  ! ---------------------------------------------------------
  !> Propagator with approximate enforced time-reversal symmetry (AETRS).
  !!
  !! No KS potential is computed here. The first dt/2 uses the current
  !! H(time - dt). The potentials for the second dt/2 are taken from the
  !! interpolation data in tr%vks_old.
  subroutine td_aetrs(namespace, space, hm, gr, st, tr, time, dt, ions_dyn, ions, ext_partners, mc)
    type(namespace_t),                intent(in)    :: namespace
    type(electron_space_t),           intent(in)    :: space
    type(hamiltonian_elec_t),         intent(inout) :: hm
    type(grid_t),                     intent(inout) :: gr
    type(states_elec_t),              intent(inout) :: st
    type(propagator_base_t),          intent(inout) :: tr
    real(real64),                     intent(in)    :: time   !< time at the end of the step
    real(real64),                     intent(in)    :: dt     !< time step
    type(ion_dynamics_t),             intent(inout) :: ions_dyn
    type(ions_t),                     intent(inout) :: ions
    type(partner_list_t),             intent(in)    :: ext_partners
    type(multicomm_t),                intent(inout) :: mc    !< index and domain communicators

    type(gauge_field_t), pointer :: gfield

    PUSH_SUB(td_aetrs)

    ! propagate half of the time step with H(time - dt)
    call propagation_ops_elec_exp_apply(tr%te, namespace, st, gr, hm, M_HALF*dt)

    !Get the potentials from the interpolation
    call propagation_ops_elec_interpolate_get(hm, gr, tr%vks_old)

    ! move the ions to time t
    call propagation_ops_elec_move_ions(tr%propagation_ops_elec, gr, hm, st, namespace, space, ions_dyn, &
      ions, ext_partners, mc, time, dt)

    !Propagate gauge field
    gfield => list_get_gauge_field(ext_partners)
    if(associated(gfield)) then
      call propagation_ops_elec_propagate_gauge_field(tr%propagation_ops_elec, gfield, dt, time)
    end if

    !Update Hamiltonian
    call propagation_ops_elec_update_hamiltonian(namespace, space, st, gr, hm, ext_partners, time)

    !Do the time propagation for the second half of the time step
    call propagation_ops_elec_fuse_density_exp_apply(tr%te, namespace, st, gr, hm, M_HALF*dt)

    POP_SUB(td_aetrs)
  end subroutine td_aetrs

  ! ---------------------------------------------------------
  !> Propagator with corrected approximate enforced time-reversal symmetry
  !! (CAETRS).
  !!
  !! Like AETRS, but the KS potential at time - dt is computed during the first
  !! half step and fed into the interpolation data (tr%vks_old).
  !!
  !! Before the second half step, each state is multiplied pointwise by
  !! exp(-i*vold), where vold is derived by mix_potentials from the previously
  !! interpolated potential. The new density is accumulated on the fly.
  subroutine td_caetrs(ks, namespace, space, hm, ext_partners, gr, st, tr, time, dt, &
    ions_dyn, ions, mc)
    type(v_ks_t),                     intent(inout) :: ks
    type(namespace_t),                intent(in)    :: namespace
    type(electron_space_t),           intent(in)    :: space
    type(hamiltonian_elec_t),         intent(inout) :: hm
    type(partner_list_t),             intent(in)    :: ext_partners
    type(grid_t),                     intent(inout) :: gr
    type(states_elec_t),              intent(inout) :: st
    type(propagator_base_t),          intent(inout) :: tr
    real(real64),                     intent(in)    :: time   !< time at the end of the step
    real(real64),                     intent(in)    :: dt     !< time step
    type(ion_dynamics_t),             intent(inout) :: ions_dyn
    type(ions_t),                     intent(inout) :: ions
    type(multicomm_t),                intent(inout) :: mc    !< index and domain communicators

    integer :: ik, ispin, ip, ist, ib
    real(real64) :: vv
    complex(real64) :: phase
    type(density_calc_t)  :: dens_calc
    integer(int64)           :: pnp, wgsize, dim2, dim3
    type(accel_mem_t)    :: phase_buff
    type(gauge_field_t), pointer :: gfield
    type(xc_copied_potentials_t) :: vold

    PUSH_SUB(td_caetrs)

    ! Get the interpolated KS potentials into vold
    call hm%ks_pot%get_interpolated_potentials(tr%vks_old, 2, storage=vold)
    ! And set it to the Hamiltonian
    call hm%ks_pot%restore_potentials(vold)

    call propagation_ops_elec_update_hamiltonian(namespace, space, st, gr, hm, ext_partners, time - dt)

    ! start the KS potential calculation for time - dt; it is finished after
    ! the first half step
    call v_ks_calc_start(ks, namespace, space, hm, st, ions, ions%latt, ext_partners, &
      time = time - dt, calc_energy = .false.)

    ! propagate half of the time step with H(time - dt)
    call propagation_ops_elec_exp_apply(tr%te, namespace, st, gr, hm, M_HALF*dt)

    call v_ks_calc_finish(ks, hm, namespace, space, ions%latt, st, ext_partners)

    call hm%ks_pot%set_interpolated_potentials(tr%vks_old, 1)

    call hm%ks_pot%perform_interpolation(tr%vks_old, (/time - dt, time - M_TWO*dt, time - M_THREE*dt/), time)

    ! Combine vold with the current vhxc, scaled by dt. vold%vhxc is used
    ! below directly as a phase angle, so it must already contain the dt factor.
    ! NOTE(review): expected to be vold <- dt/2*(vold - vhxc); confirm against
    ! ks_potential_t%mix_potentials.
    call hm%ks_pot%mix_potentials(vold, dt)

    ! copy vold to a cl buffer
    if (accel_is_enabled() .and. hm%apply_packed()) then
      if (family_is_mgga_with_exc(hm%xc)) then
        call messages_not_implemented('CAETRS propagator with accel and MGGA with energy functionals', namespace=namespace)
      end if
      pnp = accel_padded_size(gr%np)
      call accel_create_buffer(phase_buff, ACCEL_MEM_READ_ONLY, TYPE_FLOAT, pnp*st%d%nspin)
      call vold%copy_vhxc_to_buffer(int(gr%np, int64), st%d%nspin, pnp, phase_buff)
    end if

    !Get the potentials from the interpolator
    call propagation_ops_elec_interpolate_get(hm, gr, tr%vks_old)

    ! move the ions to time t
    call propagation_ops_elec_move_ions(tr%propagation_ops_elec, gr, hm, st, namespace, space, ions_dyn, &
      ions, ext_partners, mc, time, dt)

    gfield => list_get_gauge_field(ext_partners)
    if(associated(gfield)) then
      call propagation_ops_elec_propagate_gauge_field(tr%propagation_ops_elec, gfield, dt, time)
    end if

    call propagation_ops_elec_update_hamiltonian(namespace, space, st, gr, hm, ext_partners, time)

    call density_calc_init(dens_calc, st, gr, st%rho)

    ! propagate the other half with H(t)
    do ik = st%d%kpt%start, st%d%kpt%end
      ispin = st%d%get_spin_index(ik)

      do ib = st%group%block_start, st%group%block_end
        if (hm%apply_packed()) then
          call st%group%psib(ib, ik)%do_pack()
          if (hamiltonian_elec_inh_term(hm)) call hm%inh_st%group%psib(ib, ik)%do_pack()
        end if

        ! multiply every state of the batch by exp(-i*vold(ip, ispin))
        call profiling_in("CAETRS_PHASE")
        select case (st%group%psib(ib, ik)%status())
        case (BATCH_NOT_PACKED)
          do ip = 1, gr%np
            vv = vold%vhxc(ip, ispin)
            phase = cmplx(cos(vv), -sin(vv), real64)
            do ist = 1, st%group%psib(ib, ik)%nst_linear
              st%group%psib(ib, ik)%zff_linear(ip, ist) = st%group%psib(ib, ik)%zff_linear(ip, ist)*phase
            end do
          end do
        case (BATCH_PACKED)
          do ip = 1, gr%np
            vv = vold%vhxc(ip, ispin)
            phase = cmplx(cos(vv), -sin(vv), real64)
            do ist = 1, st%group%psib(ib, ik)%nst_linear
              st%group%psib(ib, ik)%zff_pack(ist, ip) = st%group%psib(ib, ik)%zff_pack(ist, ip)*phase
            end do
          end do
        case (BATCH_DEVICE_PACKED)
          ! argument 0: offset of this spin channel inside phase_buff
          call accel_set_kernel_arg(kernel_phase, 0, pnp*(ispin - 1))
          call accel_set_kernel_arg(kernel_phase, 1, phase_buff)
          call accel_set_kernel_arg(kernel_phase, 2, st%group%psib(ib, ik)%ff_device)
          call accel_set_kernel_arg(kernel_phase, 3, log2(st%group%psib(ib, ik)%pack_size(1)))

          ! the points are split over dims 2 and 3 so that dim 2 stays within
          ! accel_max_size_per_dim(2)*wgsize
          wgsize = accel_max_workgroup_size()/st%group%psib(ib, ik)%pack_size(1)
          dim3 = pnp/(accel_max_size_per_dim(2)*wgsize) + 1
          dim2 = min(accel_max_size_per_dim(2)*wgsize, pad(pnp, wgsize))

          call accel_kernel_run(kernel_phase, (/st%group%psib(ib, ik)%pack_size(1), dim2, dim3/), &
            (/st%group%psib(ib, ik)%pack_size(1), wgsize, 1_int64/))

        end select
        call profiling_out("CAETRS_PHASE")

        call hm%phase%set_phase_corr(gr, st%group%psib(ib, ik))
        if (hamiltonian_elec_inh_term(hm)) then
          call tr%te%apply_batch(namespace, gr, hm, st%group%psib(ib, ik), M_HALF*dt, &
            inh_psib = hm%inh_st%group%psib(ib, ik))
        else
          call tr%te%apply_batch(namespace, gr, hm, st%group%psib(ib, ik), M_HALF*dt)
        end if
        call hm%phase%unset_phase_corr(gr, st%group%psib(ib, ik))

        call density_calc_accumulate(dens_calc, st%group%psib(ib, ik))

        if (hm%apply_packed()) then
          call st%group%psib(ib, ik)%do_unpack()
          if (hamiltonian_elec_inh_term(hm)) call hm%inh_st%group%psib(ib, ik)%do_unpack()
        end if
      end do
    end do

    call density_calc_end(dens_calc)

    if (accel_is_enabled() .and. hm%apply_packed()) then
      call accel_release_buffer(phase_buff)
    end if

    POP_SUB(td_caetrs)
  end subroutine td_caetrs

end module propagator_etrs_oct_m

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
