!! Copyright (C) 2009-2020 X. Andrade, N. Tancogne-Dejean, M. Lueders
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

!> @brief apply the local potential (stored in the hamiltonian) to the states
!
subroutine X(hamiltonian_elec_base_local)(this, mesh, std, ispin, psib, vpsib, async)
  class(hamiltonian_elec_base_t),  intent(in)    :: this
  class(mesh_t),                  intent(in)    :: mesh
  type(states_elec_dim_t),        intent(in)    :: std    !< dimensions of the states
  integer,                        intent(in)    :: ispin  !< spin channel
  type(wfs_elec_t),               intent(in)    :: psib   !< original states
  type(wfs_elec_t),               intent(inout) :: vpsib  !< states multiplied by local potential
  logical,              optional, intent(in)    :: async

  logical :: on_device  !< .true. if the input batch is packed on the device
  logical :: has_imag   !< .true. if the hamiltonian stores an imaginary potential

  PUSH_SUB(X(hamiltonian_elec_base_local))

  on_device = (psib%status() == BATCH_DEVICE_PACKED)
  has_imag  = allocated(this%Impotential)

  ! Dispatch to the worker routine. The device buffers and the async flag are
  ! forwarded only for device-packed batches; the host array with the imaginary
  ! part is forwarded only for host batches.
  if (on_device .and. has_imag) then
    call X(hamiltonian_elec_base_local_sub)(this%potential, mesh, std, ispin, psib, vpsib, &
      potential_accel=this%potential_accel, impotential_accel=this%impotential_accel, &
      async=async)
  else if (on_device) then
    call X(hamiltonian_elec_base_local_sub)(this%potential, mesh, std, ispin, psib, vpsib, &
      potential_accel=this%potential_accel, async=async)
  else if (has_imag) then
    call X(hamiltonian_elec_base_local_sub)(this%potential, mesh, std, ispin, psib, vpsib, &
      Impotential=this%Impotential)
  else
    call X(hamiltonian_elec_base_local_sub)(this%potential, mesh, std, ispin, psib, vpsib)
  end if

  POP_SUB(X(hamiltonian_elec_base_local))
end subroutine X(hamiltonian_elec_base_local)

! ---------------------------------------------------------------------------------------
!
!> @brief apply a local potential to a set of states
!!
!! This auxiliary routine multiplies the wave functions in psib by a potential,
!! provided as argument, and returns the new set of states in vpsib.
!
subroutine X(hamiltonian_elec_base_local_sub)(potential, mesh, std, ispin, psib, vpsib, &
  Impotential, potential_accel, impotential_accel, async)
  real(real64), contiguous,                 intent(in)    :: potential(:,:) !< real potential v
  class(mesh_t),                     intent(in)    :: mesh   !< the mesh
  type(states_elec_dim_t),           intent(in)    :: std    !< dimensions of the states
  integer,                           intent(in)    :: ispin  !< spin channel
  class(batch_t), target,            intent(in)    :: psib   !< original wave functions
  class(batch_t), target,            intent(inout) :: vpsib  !< wave functions multiplied by v
  real(real64), optional, contiguous,       intent(in)    :: impotential(:,:) !< optional imaginary potential
  type(accel_mem_t),  optional, target, intent(in) :: potential_accel    !< device buffer for Re(v)
  type(accel_mem_t),  optional, target, intent(in) :: impotential_accel  !< device buffer for Im(v)
  logical,            optional,      intent(in)    :: async  !< if .true., skip the device synchronization

  integer :: ist, ip, is
#ifdef R_TCOMPLEX
  R_TYPE :: psi1, psi2
  real(real64)  :: Imvv
  complex(real64)  :: pot(1:4), tmp
  complex(real64), pointer :: psi(:, :), vpsi(:, :)
#endif
  real(real64)   :: vv
  logical :: pot_is_cmplx
  integer :: pnp
  integer(int64) :: localsize, dim2, dim3
  type(accel_mem_t), pointer :: potential_accel_, impotential_accel_

  call profiling_in(TOSTRING(X(VLPSI)))
  PUSH_SUB(X(hamiltonian_elec_base_local_sub))

  ! The potential is treated as complex as soon as an imaginary part is
  ! supplied, either as a host array or as a device buffer.
  pot_is_cmplx = .false.
  if (present(Impotential)) pot_is_cmplx = .true.
  if (present(impotential_accel)) pot_is_cmplx = .true.

  call psib%check_compatibility_with(vpsib)

  select case (psib%status())
  case (BATCH_DEVICE_PACKED)
    pnp = accel_padded_size(mesh%np)

    ! Real part: use the buffer provided by the caller, or create a temporary
    ! one and upload the host array (released again at the end of this branch).
    if(.not. present(potential_accel)) then
      SAFE_ALLOCATE(potential_accel_)
      call accel_create_buffer(potential_accel_, ACCEL_MEM_READ_WRITE, TYPE_FLOAT, pnp*std%nspin)
      if(std%ispin /= SPINORS) then
        call accel_write_buffer(potential_accel_, mesh%np, potential(:,ispin), offset=pnp*(ispin - 1))
      else
        do is = 1, std%nspin
          call accel_write_buffer(potential_accel_, mesh%np, potential(:,is), offset=pnp*(is - 1))
        end do
      end if
    else
      potential_accel_ => potential_accel
    end if

    ! Imaginary part: same strategy. An absent optional argument must not be
    ! used as a pointer target, so the pointer is nullified when no imaginary
    ! part is available at all.
    if(.not. present(impotential_accel) .and. present(impotential)) then
      SAFE_ALLOCATE(impotential_accel_)
      call accel_create_buffer(impotential_accel_, ACCEL_MEM_READ_WRITE, TYPE_FLOAT, pnp*std%nspin)
      if(std%ispin /= SPINORS) then
        call accel_write_buffer(impotential_accel_, mesh%np, impotential(:,ispin), offset=pnp*(ispin - 1))
      else
        do is = 1, std%nspin
          call accel_write_buffer(impotential_accel_, mesh%np, impotential(:,is), offset=pnp*(is - 1))
        end do
      end if
    else if (present(impotential_accel)) then
      impotential_accel_ => impotential_accel
    else
      nullify(impotential_accel_)
    end if

    if (.not. pot_is_cmplx) then
      select case (std%ispin)

      case (UNPOLARIZED, SPIN_POLARIZED)
        call accel_set_kernel_arg(kernel_vpsi, 0, pnp*(ispin - 1))
        call accel_set_kernel_arg(kernel_vpsi, 1, mesh%np)
        call accel_set_kernel_arg(kernel_vpsi, 2, potential_accel_)
        call accel_set_kernel_arg(kernel_vpsi, 3, psib%ff_device)
        call accel_set_kernel_arg(kernel_vpsi, 4, int(log2(psib%pack_size_real(1)), int32))
        call accel_set_kernel_arg(kernel_vpsi, 5, vpsib%ff_device)
        call accel_set_kernel_arg(kernel_vpsi, 6, int(log2(vpsib%pack_size_real(1)), int32))

        localsize = accel_kernel_workgroup_size(kernel_vpsi)/psib%pack_size_real(1)

        ! Split the mesh points over the second and third grid dimensions.
        dim3 = mesh%np/(accel_max_size_per_dim(2)*localsize) + 1
        dim2 = min(accel_max_size_per_dim(2)*localsize, pad(mesh%np, localsize))

        call accel_kernel_run(kernel_vpsi, (/psib%pack_size_real(1), dim2, dim3/), (/psib%pack_size_real(1), localsize, 1_int64/))

      case (SPINORS)
        call accel_set_kernel_arg(kernel_vpsi_spinors, 0, mesh%np)
        call accel_set_kernel_arg(kernel_vpsi_spinors, 1, potential_accel_)
        call accel_set_kernel_arg(kernel_vpsi_spinors, 2, pnp)
        call accel_set_kernel_arg(kernel_vpsi_spinors, 3, psib%ff_device)
        call accel_set_kernel_arg(kernel_vpsi_spinors, 4, int(psib%pack_size(1), int32))
        call accel_set_kernel_arg(kernel_vpsi_spinors, 5, vpsib%ff_device)
        call accel_set_kernel_arg(kernel_vpsi_spinors, 6, int(vpsib%pack_size(1), int32))

        localsize = accel_kernel_workgroup_size(kernel_vpsi_spinors)/(psib%pack_size(1)/2)

        dim3 = pnp/(accel_max_size_per_dim(2)*localsize) + 1
        dim2 = min(accel_max_size_per_dim(2)*localsize, pad(pnp, localsize))

        call accel_kernel_run(kernel_vpsi_spinors, (/psib%pack_size(1)/2, dim2, dim3/), &
          (/psib%pack_size(1)/2, localsize, 1_int64/))

      end select
    else
      ! complex potentials
      ! The kernels receive the resolved pointer, which is valid both when the
      ! caller passed a device buffer and when it was created locally above.
      select case (std%ispin)

      case (UNPOLARIZED, SPIN_POLARIZED)
        call accel_set_kernel_arg(kernel_vpsi_complex, 0, pnp*(ispin - 1))
        call accel_set_kernel_arg(kernel_vpsi_complex, 1, mesh%np)
        call accel_set_kernel_arg(kernel_vpsi_complex, 2, potential_accel_)
        call accel_set_kernel_arg(kernel_vpsi_complex, 3, impotential_accel_)
        call accel_set_kernel_arg(kernel_vpsi_complex, 4, psib%ff_device)
        call accel_set_kernel_arg(kernel_vpsi_complex, 5, int(log2(psib%pack_size(1)), int32))
        call accel_set_kernel_arg(kernel_vpsi_complex, 6, vpsib%ff_device)
        call accel_set_kernel_arg(kernel_vpsi_complex, 7, int(log2(vpsib%pack_size(1)), int32))

        localsize = accel_kernel_workgroup_size(kernel_vpsi_complex)/psib%pack_size(1)

        dim3 = mesh%np/(accel_max_size_per_dim(2)*localsize) + 1
        dim2 = min(accel_max_size_per_dim(2)*localsize, pad(mesh%np, localsize))

        call accel_kernel_run(kernel_vpsi_complex, (/psib%pack_size(1), dim2, dim3/), (/psib%pack_size(1), localsize, 1_int64/))

      case (SPINORS)
        call accel_set_kernel_arg(kernel_vpsi_spinors_complex, 0, mesh%np)
        call accel_set_kernel_arg(kernel_vpsi_spinors_complex, 1, potential_accel_)
        call accel_set_kernel_arg(kernel_vpsi_spinors_complex, 2, pnp)
        call accel_set_kernel_arg(kernel_vpsi_spinors_complex, 3, impotential_accel_)
        call accel_set_kernel_arg(kernel_vpsi_spinors_complex, 4, pnp)
        call accel_set_kernel_arg(kernel_vpsi_spinors_complex, 5, psib%ff_device)
        call accel_set_kernel_arg(kernel_vpsi_spinors_complex, 6, int(psib%pack_size(1), int32))
        call accel_set_kernel_arg(kernel_vpsi_spinors_complex, 7, vpsib%ff_device)
        call accel_set_kernel_arg(kernel_vpsi_spinors_complex, 8, int(vpsib%pack_size(1), int32))

        localsize = accel_kernel_workgroup_size(kernel_vpsi_spinors_complex)/(psib%pack_size(1)/2)

        dim3 = pnp/(accel_max_size_per_dim(2)*localsize) + 1
        dim2 = min(accel_max_size_per_dim(2)*localsize, pad(pnp, localsize))

        call accel_kernel_run(kernel_vpsi_spinors_complex, (/psib%pack_size(1)/2, dim2, dim3/), &
          (/psib%pack_size(1)/2, localsize, 1_int64/))

      end select
    end if

    ! Synchronize with the device unless the caller asked for async execution.
    if (.not. optional_default(async, .false.)) call accel_finish()

    call profiling_count_operations((R_MUL*psib%nst_linear)*mesh%np)
    call profiling_count_transfers(mesh%np, M_ONE)
    call profiling_count_transfers(mesh%np*psib%nst, R_TOTYPE(M_ONE))

    ! Release only the buffers that were created inside this routine.
    if(.not. present(potential_accel)) then
      call accel_release_buffer(potential_accel_)
      SAFE_DEALLOCATE_P(potential_accel_)
    end if
    if(.not. present(impotential_accel) .and. present(impotential)) then
      call accel_release_buffer(impotential_accel_)
      SAFE_DEALLOCATE_P(impotential_accel_)
    end if

  case (BATCH_PACKED)

    ! Host, packed layout: the state index is the fast index, so the loop over
    ! mesh points is the outer (threaded) loop. The product is added to vpsib.
    select case (std%ispin)
    case (UNPOLARIZED, SPIN_POLARIZED)
      if (pot_is_cmplx) then
#ifdef R_TCOMPLEX
        !$omp parallel do private(vv, Imvv, ist, tmp)
        do ip = 1, mesh%np
          vv = potential(ip, ispin)
          Imvv = Impotential(ip, ispin)
          tmp = cmplx(vv, Imvv, real64)
          !$omp simd
          do ist = 1, psib%nst_linear
            vpsib%zff_pack(ist, ip) = vpsib%zff_pack(ist, ip) + tmp*psib%zff_pack(ist, ip)
          end do
        end do
        call profiling_count_operations(2*((R_ADD+R_MUL)*psib%nst_linear)*mesh%np)
#else
        ! Complex potential can only be applied to complex batches
        ASSERT(.false.)
#endif
      else
        !$omp parallel do private(vv, ist)
        do ip = 1, mesh%np
          vv = potential(ip, ispin)
          !$omp simd
          do ist = 1, psib%nst_linear
            vpsib%X(ff_pack)(ist, ip) = vpsib%X(ff_pack)(ist, ip) + vv*psib%X(ff_pack)(ist, ip)
          end do
        end do
        call profiling_count_operations(((R_ADD+R_MUL)*psib%nst_linear)*mesh%np)
      end if
      call profiling_count_transfers(mesh%np, M_ONE)
      call profiling_count_transfers(mesh%np*psib%nst_linear, R_TOTYPE(M_ONE))

    case (SPINORS)
#ifdef R_TCOMPLEX
      ASSERT(mod(psib%nst_linear, 2) == 0)
      !the spinor case is more complicated since it mixes the two components.
      ! Columns 1 and 2 of the potential multiply the two spinor components
      ! directly; columns 3 and 4 are combined into the complex coupling
      ! between the components.
      if (pot_is_cmplx) then
        !$omp parallel do private(psi1, psi2, ist, pot)
        do ip = 1, mesh%np
          pot(1:2) = cmplx(potential(ip, 1:2), Impotential(ip, 1:2), real64)
          pot(3) = cmplx(potential(ip, 3) - Impotential(ip, 4), potential(ip, 4) + Impotential(ip, 3), real64)
          pot(4) = cmplx(potential(ip, 3) + Impotential(ip, 4),-potential(ip, 4) + Impotential(ip, 3), real64)
          do ist = 1, psib%nst_linear, 2
            psi1 = psib%zff_pack(ist    , ip)
            psi2 = psib%zff_pack(ist + 1, ip)
            vpsib%zff_pack(ist    , ip) = vpsib%zff_pack(ist    , ip) + pot(1)*psi1 + pot(3)*psi2
            vpsib%zff_pack(ist + 1, ip) = vpsib%zff_pack(ist + 1, ip) + pot(2)*psi2 + pot(4)*psi1
          end do
        end do
        !$omp end parallel do
        call profiling_count_operations(((4*R_ADD + 4*R_MUL)*psib%nst + R_ADD+R_MUL)*mesh%np)
      else
        !$omp parallel do private(psi1, psi2, ist, tmp)
        do ip = 1, mesh%np
          tmp = cmplx(potential(ip, 3), potential(ip, 4), real64)
          do ist = 1, psib%nst_linear, 2
            psi1 = psib%zff_pack(ist    , ip)
            psi2 = psib%zff_pack(ist + 1, ip)
            vpsib%zff_pack(ist    , ip) = vpsib%zff_pack(ist    , ip) + &
              potential(ip, 1)*psi1 + tmp*psi2
            vpsib%zff_pack(ist + 1, ip) = vpsib%zff_pack(ist + 1, ip) + &
              potential(ip, 2)*psi2 + conjg(tmp)*psi1
          end do
        end do
        !$omp end parallel do
        call profiling_count_operations((4*R_ADD + 4*R_MUL)*mesh%np*psib%nst)
      end if
#else
      ! Spinors always imply complex batches
      ASSERT(.false.)
#endif
    end select

  case (BATCH_NOT_PACKED)

    ! Host, unpacked layout: the mesh index is the fast index, so each state
    ! is processed with a threaded loop over mesh points.
    select case (std%ispin)
    case (UNPOLARIZED, SPIN_POLARIZED)
      if (pot_is_cmplx) then
#ifdef R_TCOMPLEX
        !$omp parallel private(ip, ist)
        do ist = 1, psib%nst
          !$omp do simd
          do ip = 1, mesh%np
            vpsib%X(ff)(ip, 1, ist) = vpsib%X(ff)(ip, 1, ist) + &
              cmplx(potential(ip, ispin), Impotential(ip, ispin), real64) * psib%X(ff)(ip, 1, ist)
          end do
          !$omp end do simd nowait
        end do
        !$omp end parallel
        call profiling_count_operations(2*((R_ADD+R_MUL)*psib%nst)*mesh%np)
#else
        ! Complex potential can only be applied to complex batches
        ASSERT(.false.)
#endif
      else
        !$omp parallel private(ip, ist)
        do ist = 1, psib%nst
          !$omp do simd
          do ip = 1, mesh%np
            vpsib%X(ff)(ip, 1, ist) = vpsib%X(ff)(ip, 1, ist) + &
              potential(ip, ispin) * psib%X(ff)(ip, 1, ist)
          end do
          !$omp end do simd nowait
        end do
        !$omp end parallel
        call profiling_count_operations(((R_ADD+R_MUL)*psib%nst)*mesh%np)
      end if

      call profiling_count_transfers(mesh%np, M_ONE)
      call profiling_count_transfers(mesh%np*psib%nst, R_TOTYPE(M_ONE))

    case (SPINORS)
#ifdef R_TCOMPLEX
      !the spinor case is more complicated since it mixes the two components.
      if (pot_is_cmplx) then
        do ist = 1, psib%nst
          psi  => psib%zff(:, :, ist)
          vpsi => vpsib%zff(:, :, ist)

          do ip = 1, mesh%np
            pot(1:2) = cmplx(potential(ip, 1:2), Impotential(ip, 1:2), real64)
            pot(3) = cmplx(potential(ip, 3) - Impotential(ip, 4), potential(ip, 4) + Impotential(ip, 3), real64)
            pot(4) = cmplx(potential(ip, 3) + Impotential(ip, 4),-potential(ip, 4) + Impotential(ip, 3), real64)
            vpsi(ip, 1) = vpsi(ip, 1) + pot(1)*psi(ip, 1) + pot(3)*psi(ip, 2)
            vpsi(ip, 2) = vpsi(ip, 2) + pot(2)*psi(ip, 2) + pot(4)*psi(ip, 1)
          end do
        end do
        call profiling_count_operations((7*R_ADD + 7*R_MUL)*mesh%np*psib%nst)

      else
        do ist = 1, psib%nst
          psi  => psib%zff(:, :, ist)
          vpsi => vpsib%zff(:, :, ist)

          do ip = 1, mesh%np
            vpsi(ip, 1) = vpsi(ip, 1) + potential(ip, 1)*psi(ip, 1) + &
              cmplx(potential(ip, 3), potential(ip, 4), real64)*psi(ip, 2)
            vpsi(ip, 2) = vpsi(ip, 2) + potential(ip, 2)*psi(ip, 2) + &
              cmplx(potential(ip, 3), -potential(ip, 4), real64)*psi(ip, 1)
          end do
        end do
        call profiling_count_operations((6*R_ADD + 6*R_MUL)*mesh%np*psib%nst)
      end if
#else
      ! Spinors always imply complex batches
      ASSERT(.false.)
#endif
    end select

  end select

  call profiling_out(TOSTRING(X(VLPSI)))
  POP_SUB(X(hamiltonian_elec_base_local_sub))

end subroutine X(hamiltonian_elec_base_local_sub)

! -----------------------------------------------------------------------------
!> @brief apply magnetic terms from the Hamiltonian to the wave functions
!!
!! Adds the A.p and A^2 terms, applied to |\psi>, to vpsib.
!
subroutine X(hamiltonian_elec_base_magnetic)(this, mesh, der, std, ep, ispin, psib, vpsib)
  class(hamiltonian_elec_base_t), intent(in)    :: this
  class(mesh_t),                  intent(in)    :: mesh
  type(derivatives_t),            intent(in)    :: der
  type(states_elec_dim_t),        intent(in)    :: std
  type(epot_t),                   intent(in)    :: ep
  integer,                        intent(in)    :: ispin
  type(wfs_elec_t), target,       intent(inout) :: psib
  type(wfs_elec_t), target,       intent(inout) :: vpsib

  integer :: idir, ip
  integer :: ndim  !< number of spatial directions of the mesh box
  integer :: nch   !< number of spin channels that receive the vector potential
  type(wfs_elec_t) :: adotpb, gradb(mesh%box%dim)
  real(real64), allocatable :: avec(:,:,:), a2(:,:)

  ! Nothing to do without magnetic terms or without a stored vector potential.
  if (.not. this%has_magnetic()) return

  if (.not. allocated(this%vector_potential)) return

#ifndef R_TCOMPLEX
  ! Vector potential not allowed with real wavefunctions
  ASSERT(.false.)
#endif

  call profiling_in(TOSTRING(X(MAGNETIC)))
  PUSH_SUB(X(hamiltonian_elec_base_magnetic))

  ndim = mesh%box%dim
  nch  = std%spin_channels

  ! Work batch with the layout of psib (no data copied) and the gradient of psib.
  call psib%copy_to(adotpb, copy_data = .false.)

  call X(derivatives_batch_grad)(der, psib, gradb, set_bc=.false.)

  SAFE_ALLOCATE(a2(1:mesh%np,1:std%nspin))
  SAFE_ALLOCATE(avec(1:mesh%np, 1:std%nspin, mesh%box%dim))

  ! Build, point by point, the local potentials that represent A_idir and
  ! A^2/(2m). Only the first nch columns are filled; the remaining ones are zero.
  !$omp parallel do private(idir)
  do ip = 1, mesh%np
    do idir = 1, ndim
      avec(ip, 1:nch, idir) = this%vector_potential(idir, ip)
      avec(ip, nch+1:std%nspin, idir) = M_ZERO
    end do
    a2(ip, 1:nch) = (M_HALF / this%mass) * sum(this%vector_potential(1:ndim, ip)**2)
    a2(ip, nch+1:std%nspin) = M_ZERO
  end do
  !$omp end parallel do

  ! Adding the A^2 term
  call zhamiltonian_elec_base_local_sub(a2, mesh, std, ispin, psib, vpsib)

  ! Adding the A.p term
  ! This reads -i\hbar A.[\nabla\psi(r)]
  ! The contributions of all directions are accumulated in adotpb first and
  ! then added to vpsib with the prefactor -i/m.
  ! TODO: the batch_axpy call is inefficient
  call batch_set_zero(adotpb)
  do idir = 1, ndim
    call zhamiltonian_elec_base_local_sub(avec(:,:,idir), mesh, std, ispin, gradb(idir), adotpb)
  end do
  call batch_axpy(mesh%np, -M_zI / this%mass, adotpb, vpsib)

  ! Release the temporary batches and work arrays.
  do idir = 1, ndim
    call gradb(idir)%end()
  end do
  call adotpb%end()

  SAFE_DEALLOCATE_A(a2)
  SAFE_DEALLOCATE_A(avec)

  POP_SUB(X(hamiltonian_elec_base_magnetic))
  call profiling_out(TOSTRING(X(MAGNETIC)))
end subroutine X(hamiltonian_elec_base_magnetic)

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
