!! Copyright (C) 2002-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch
!! Copyright (C) 2012-2013 M. Gruning, P. Melo, M. Oliveira
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!
!!
! ---------------------------------------------------------
!> This file handles the evaluation of the OEP potential, in the KLI or full OEP
!! as described in S. Kuemmel and J. Perdew, PRL 90, 043004 (2003)
!!
!! This file has to be outside the module xc, for it requires the Hpsi.
!! This is why it needs the xc_functional module. I prefer to put it here since
!! the rest of the Hamiltonian module does not know about the gory details
!! of how xc is defined and calculated.
subroutine X(xc_oep_calc)(oep, namespace, xcs, gr, hm, st, space, rcell_volume, ex, ec, vxc)
  ! Driver for the OEP potential.
  !
  ! Steps performed (in this order):
  !  1. Return immediately if oep%level == OEP_LEVEL_NONE.
  !  2. For SIC with oep%scdm_for_pzsic, build a separate set of states (oep_st) via
  !     scdm_get_localized_states; otherwise oep_st simply points to st.
  !  3. Allocate the temporary work arrays of oep (eigen_type, eigen_index, lxc, uxc_bar).
  !  4. Depending on oep%type, either fill xst with the result of the exchange/MGGA operator
  !     applied to the states (EXX, MGGA), or call the SIC routines.
  !  5. Compute oep%uxc_bar for the occupied states (and broadcast it when parallel in states).
  !  6. Solve the KLI equations, and for OEP_LEVEL_FULL (except on the first call) run the
  !     iterative solver X(xc_oep_solve).
  !  7. Release the temporary work arrays.
  !
  ! Arguments:
  !  oep          : OEP data; oep%vxc is added to vxc (when present), work arrays are
  !                 allocated here and freed before returning.
  !  ex, ec       : energies; ex is incremented by the value returned by
  !                 exchange_operator_compute_potentials for EXX, and both are passed on to
  !                 the SIC routines.
  !  vxc          : optional potential array. It is only touched when present.
  type(xc_oep_t),              intent(inout) :: oep
  type(namespace_t),           intent(in)    :: namespace
  type(xc_t),                  intent(inout) :: xcs
  type(grid_t),                intent(in)    :: gr
  type(hamiltonian_elec_t),    intent(inout) :: hm
  type(states_elec_t), target, intent(inout) :: st
  class(space_t),              intent(in)    :: space
  real(real64),                intent(in)    :: rcell_volume
  real(real64),                intent(inout) :: ex, ec
  real(real64), contiguous, optional, intent(inout) :: vxc(:,:) !< vxc(mesh%np, st%d%nspin)

  real(real64) :: eig
  R_TYPE :: uxc_bar
  integer :: is, ist, ixc, nspin_, idm, ib, ik
  ! Saved across calls: the full OEP solver is skipped on the very first call of this routine
  logical, save :: first = .true.
  R_TYPE, allocatable :: psi(:), xpsi(:)
  type(states_elec_t) :: xst
  type(states_elec_t), pointer :: oep_st
  logical :: exx

  if (oep%level == OEP_LEVEL_NONE) return

  call profiling_in(TOSTRING(X(XC_OEP)))
  PUSH_SUB(X(xc_oep_calc))

  if (oep%type == OEP_TYPE_SIC .and. oep%scdm_for_pzsic) then

    ! The new states will be stored in oep_st
    SAFE_ALLOCATE(oep_st)

    write(message(1), '(a,i4,a,f11.6)') 'Performing SCDM Wannierization for SIC'
    call messages_info(1, debug_only=.true.)

    ! See scdm_get_localized_states
    if (gr%parallel_in_domains) then
      call messages_not_implemented("PZ-SIC with spinors and domain parallelization")
    end if

    call X(scdm_get_localized_states)(oep%scdm, namespace, space, gr, hm%kpoints, st, oep_st)
  else
    oep_st => st
  end if


  ! initialize oep structure
  nspin_ = min(st%d%nspin, 2)
  SAFE_ALLOCATE(oep%eigen_type (1:st%nst))
  SAFE_ALLOCATE(oep%eigen_index(1:st%nst))
  SAFE_ALLOCATE(oep%X(lxc)(1:gr%np, 1:st%d%dim, st%st_start:st%st_end, st%d%kpt%start:st%d%kpt%end))
  oep%X(lxc) = M_ZERO
  SAFE_ALLOCATE(oep%uxc_bar(1:st%d%dim, 1:st%nst, st%d%kpt%start:st%d%kpt%end))

  ! We first apply the exchange operator to all the states
  call xst%nullify()

  ! exx flags that xst holds the operator applied to the states and must be copied to lxc below
  exx = .false.
  select case(oep%type)
  case(OEP_TYPE_EXX)
    functl_loop: do ixc = 1, 2
      if (xcs%functional(ixc, 1)%family /= XC_FAMILY_OEP) cycle
      select case (xcs%functional(ixc,1)%id)
      case (XC_OEP_X)
        call X(exchange_operator_compute_potentials)(hm%exxop, namespace, space, gr, st, xst, hm%kpoints, eig)
        ex = ex + eig
        exx = .true.
      end select
    end do functl_loop

    ! MGGA from OEP
  case(OEP_TYPE_MGGA)
    if (.not. xst%group%block_initialized) then
      call states_elec_copy(xst, st)
      call states_elec_set_zero(xst)
    end if

    do ik = st%d%kpt%start, st%d%kpt%end
      do ib = st%group%block_start, st%group%block_end
        ! Here we apply the term by calling hamiltonian_elec_apply_batch, which takes care of setting the
        ! phases properly, as well as updating boundary points/ghost points when needed
        call X(hamiltonian_elec_apply_batch)(hm, namespace, gr, st%group%psib(ib, ik), &
          xst%group%psib(ib, ik), terms = TERM_MGGA)
      end do
    end do
    exx = .true.

    ! this part handles the (pure) orbital functionals
    ! SIC a la PZ is handled here
  case(OEP_TYPE_SIC)
    ! this routine is only prepared for finite systems. (Why not?)
    if (st%nik > st%d%ispin) then
      call messages_not_implemented("OEP-SIC for periodic systems", namespace=namespace)
    end if
    ! The spinor case needs to be treated differently
    if (st%d%ispin == SPINORS) then
      call oep_sic_pauli(xcs, gr, hm%psolver, namespace, space, rcell_volume, oep_st, hm%kpoints, oep, ex, ec)
    else
      do is = 1, nspin_
        call X(oep_sic) (xcs, gr, hm%psolver, namespace, space, rcell_volume, oep_st, hm%kpoints, is, oep, ex, ec)
      end do
    end if

  case default
    ASSERT(.false.)
  end select

  ! calculate uxc_bar for the occupied states

  SAFE_ALLOCATE(psi(1:gr%np))
  SAFE_ALLOCATE(xpsi(1:gr%np))

  oep%uxc_bar(:, :, :) = M_ZERO
  do ik = st%d%kpt%start, st%d%kpt%end
    do ist = st%st_start, st%st_end
      if (abs(st%occ(ist, ik)) <= M_MIN_OCC) cycle

      do idm = 1, st%d%dim
        call states_elec_get_state(oep_st, gr, idm, ist, ik, psi)
        if (exx) then
          ! Here we copy the state from xst to X(lxc).
          ! This will be removed in the future, but it allows to keep both EXX and PZ-SIC in the code
          call states_elec_get_state(xst, gr, idm, ist, ik, xpsi)
          ! There is a complex conjugate here, as the lxc is defined as <\psi|X and
          ! exchange_operator_compute_potentials returns X|\psi>
#ifndef R_TREAL
          xpsi = R_CONJ(xpsi)
#endif
          call lalg_axpy(gr%np, M_ONE, xpsi, oep%X(lxc)(1:gr%np, idm, ist, ik))
        end if
        ! Temporary var assignment due to length of macro line
        ! The dot product is computed without reduction; the reduction is done once after the loops
        uxc_bar = X(mf_dotp)(gr, psi, oep%X(lxc)(1:gr%np, idm, ist, ik), reduce = .false., dotu = .true.)
        oep%uxc_bar(idm, ist, ik) = R_REAL(uxc_bar)
      end do
    end do
  end do
  call gr%allreduce(oep%uxc_bar(:, :, st%d%kpt%start:st%d%kpt%end))

  SAFE_DEALLOCATE_A(psi)
  SAFE_DEALLOCATE_A(xpsi)

  call states_elec_end(xst)

  ! Each state was only treated by the process looping over it: distribute uxc_bar for all states
  if (oep_st%parallel_in_states) then
    call oep_st%mpi_grp%barrier()
    do ik = oep_st%d%kpt%start, oep_st%d%kpt%end
      do ist = 1, oep_st%nst
        call oep_st%mpi_grp%bcast(oep%uxc_bar(1, ist, ik), st%d%dim, MPI_DOUBLE_PRECISION, oep_st%node(ist))
      end do
    end do
  end if

  if (st%d%ispin == SPINORS) then
    call xc_oep_AnalyzeEigen(oep, oep_st, 1)
    call xc_KLI_Pauli_solve(gr, oep_st, oep)
    if (present(vxc)) then
      call lalg_axpy(gr%np, 4, M_ONE, oep%vxc, vxc)
    end if
    ! full OEP not implemented!
  else
    call X(xc_KLI_solve) (space, gr, oep_st, oep, rcell_volume)
    if (oep%level == OEP_LEVEL_FULL .and. (.not. first)) then
      do is = 1, nspin_
        ! get the HOMO state
        call xc_oep_AnalyzeEigen(oep, oep_st, is)
        if (present(vxc)) then
          call X(xc_oep_solve)(namespace, gr, hm, oep_st, is, vxc(:,is), oep)
        end if
      end do
    else !KLI
      ! vxc is optional: it must not be referenced when absent (same guard as the other branches)
      if (present(vxc)) then
        call lalg_axpy(gr%np, oep_st%d%nspin, M_ONE, oep%vxc, vxc)
      end if
    end if
    first = .false.
  end if
  SAFE_DEALLOCATE_A(oep%eigen_type)
  SAFE_DEALLOCATE_A(oep%eigen_index)
  SAFE_DEALLOCATE_A(oep%X(lxc))
  SAFE_DEALLOCATE_A(oep%uxc_bar)

  ! oep_st was only allocated here in the SCDM case; otherwise it aliases st and must not be freed
  if (oep%type == OEP_TYPE_SIC .and. oep%scdm_for_pzsic) then
    call states_elec_end(oep_st)
    SAFE_DEALLOCATE_P(oep_st)
  end if

  POP_SUB(X(xc_oep_calc))
  call profiling_out(TOSTRING(X(XC_OEP)))
end subroutine X(xc_OEP_calc)


! ---------------------------------------------------------
!> This routine follows closely the one of PRB 68, 035103 (2003)
!> Below we refer to the equation number of this paper
subroutine X(xc_oep_solve) (namespace, mesh, hm, st, is, vxc, oep)
  ! Iterative solution of the full OEP equation for the spin channel is.
  !
  ! At each iteration the potential in hm is rebuilt from vxc + oep%vxc + shift, a Sternheimer
  ! equation is solved for the orbital shift of every occupied state (stored in
  ! oep%lr%X(dl_psi)), the function ss is accumulated from the orbital shifts, and oep%vxc is
  ! updated by X(xc_oep_mix). The loop stops after oep%scftol%max_iter iterations or as soon as
  ! the norm of ss drops below oep%scftol%conv_abs_dens.
  !
  ! Arguments:
  !  hm   : Hamiltonian; hm%ks_pot%vxc and hm%ks_pot%vhxc are overwritten for the spin is.
  !  st   : states (not implemented for st%parallel_in_states).
  !  vxc  : on exit, oep%vxc(:, is) plus the constant shift has been added to it.
  !  oep  : OEP data; oep%vxc, oep%lr and oep%norm2ss are updated.
  type(namespace_t),        intent(in)    :: namespace
  class(mesh_t),            intent(in)    :: mesh
  type(hamiltonian_elec_t), intent(inout) :: hm
  type(states_elec_t),      intent(in)    :: st
  integer,                  intent(in)    :: is
  real(real64),             intent(inout) :: vxc(:) !< (mesh%np, given for the spin is)
  type(xc_oep_t),           intent(inout) :: oep

  integer :: iter, ist, iter_used
  real(real64) :: vxc_bar, ff, residue
  real(real64), allocatable :: ss(:)
  real(real64), allocatable :: psi2(:,:)
  R_TYPE, allocatable :: bb(:,:), psi(:, :)
  real(real64) :: shift

  PUSH_SUB(X(xc_oep_solve))
  call profiling_in(TOSTRING(X(OEP_LEVEL_FULL)))

  if (st%parallel_in_states) then
    call messages_not_implemented("Full OEP parallel in states", namespace=namespace)
  end if

  SAFE_ALLOCATE(     bb(1:mesh%np, 1))
  SAFE_ALLOCATE(     ss(1:mesh%np))
  SAFE_ALLOCATE(psi(1:mesh%np, 1:st%d%dim))
  SAFE_ALLOCATE(psi2(1:mesh%np, 1:st%d%dim))

  ! The orbital shifts are kept in oep%lr between calls; they are only zeroed when first allocated
  if (.not. lr_is_allocated(oep%lr)) then
    call lr_allocate(oep%lr, st, mesh)
    oep%lr%X(dl_psi)(:,:, :, :) = M_ZERO
  end if

  shift = M_ZERO

  do iter = 1, oep%scftol%max_iter
    ! We now update the potential
    ! It is made of the other xc potential (vxc) and the iterative oep (oep%vxc) and the oep shift
    hm%ks_pot%vxc(1:mesh%np,is) = vxc(1:mesh%np) + oep%vxc(1:mesh%np, is) + shift
    hm%ks_pot%vhxc(1:mesh%np,is) = hm%ks_pot%vhartree(1:mesh%np) + hm%ks_pot%vxc(1:mesh%np,is)
    call hamiltonian_elec_update_pot(hm, mesh)

    ! iteration over all states
    ss = M_ZERO
    do ist = 1, st%nst

      if (abs(st%occ(ist,is)) <= M_EPSILON) cycle  !only over occupied states
      call states_elec_get_state(st, mesh, ist, is, psi)
      psi2(:, 1) = real(R_CONJ(psi(:, 1))*psi(:,1), real64)

      ! evaluate right-hand side
      vxc_bar = dmf_integrate(mesh, psi2(:, 1)*hm%ks_pot%vxc(1:mesh%np, is))

      ! This is the right-hand side of Eq. 21
      bb(1:mesh%np, 1) = -(hm%ks_pot%vxc(1:mesh%np, is) - (vxc_bar - oep%uxc_bar(1, ist, is)))* &
        R_CONJ(psi(:, 1)) + oep%X(lxc)(1:mesh%np, 1, ist, is)

      call X(lr_orth_vector) (mesh, st, bb, ist, is, R_TOTYPE(M_ZERO))

      ! Sternheimer equation [H-E_i]psi_i = bb_i, where psi_i the orbital shift, see Eq. 21
      call X(linear_solver_solve_HXeY)(oep%solver, namespace, hm, mesh, st, ist, is, &
        oep%lr%X(dl_psi)(:,:, ist, is), bb, &
        R_TOTYPE(-st%eigenval(ist, is)), oep%scftol%final_tol, residue, iter_used)

      write(message(1),'(a,i3,a,es14.6,a,es14.6,a,i4)') "Debug: OEP - iter ", iter, &
        " linear solver residual ", residue, " tol ", &
        oep%scftol%final_tol, " iter_used ", iter_used
      call messages_info(1, namespace=namespace, debug_only=.true.)

      !We project out the occupied bands
      call X(lr_orth_vector) (mesh, st, oep%lr%X(dl_psi)(:,:, ist, is), ist, is, R_TOTYPE(M_ZERO))

      ! calculate this funny function ss
      ! ss = ss + 2*dl_psi*psi
      ! This is Eq. 25
      call lalg_axpy(mesh%np, M_TWO, R_REAL(oep%lr%X(dl_psi)(1:mesh%np, 1, ist, is)*psi(:, 1)), ss(:))
    end do

    ! The norm of ss is the residual used as convergence criterion
    ff = dmf_nrm2(mesh, ss)
    write(message(1),'(a,i3,a,es14.6,a,i4)') "Debug: OEP - iter ", iter, " residual ", ff, " max ", oep%scftol%max_iter
    call messages_info(1, namespace=namespace, debug_only=.true.)

    !Here we enforce Eq. (24), see the discussion below Eq. 26
    shift = get_shift()

    if (ff < oep%scftol%conv_abs_dens) exit

    !Here we mix the xc potential
    call X(xc_oep_mix)(oep, mesh, ss, st%rho(:,is), is)

  end do

  ! As the last iteration of the loop might not have converged, we have done a last mixing
  ! so we need to recompute the shift
  shift = get_shift()

  vxc(1:mesh%np) = vxc(1:mesh%np) + oep%vxc(1:mesh%np, is) + shift

  if (is == 1) then
    oep%norm2ss = ff
  else
    oep%norm2ss = oep%norm2ss + ff !adding up spin up and spin down component
  end if

  if (ff > oep%scftol%conv_abs_dens) then
    write(message(1), '(a)') "OEP did not converge."
    call messages_warning(1, namespace=namespace)

    ! otherwise the number below will be one too high
    ! (after a complete do loop the index is max_iter + 1)
    iter = iter - 1
  end if

  write(message(1), '(a,i4,a,es14.6)') "Info: After ", iter, " iterations, the OEP residual = ", ff
  message(2) = ''
  call messages_info(2, namespace=namespace)

  SAFE_DEALLOCATE_A(bb)
  SAFE_DEALLOCATE_A(ss)
  SAFE_DEALLOCATE_A(psi)
  SAFE_DEALLOCATE_A(psi2)

  call profiling_out(TOSTRING(X(OEP_LEVEL_FULL)))
  POP_SUB(X(xc_oep_solve))
contains

  ! Returns the constant shift of the potential: minus the sum, over the states with
  ! oep%eigen_type(ist) == 2, of (vxc_bar - oep%uxc_bar), where vxc_bar is the integral of
  ! |psi|^2 * oep%vxc scaled by oep%socc*st%occ.
  ! Note: this overwrites the host variables psi, psi2 and vxc_bar.
  real(real64) function get_shift() result(shift)
    integer :: ist

    PUSH_SUB(get_shift)

    shift = M_ZERO
    do ist = 1, st%nst
      if (oep%eigen_type(ist) == 2) then
        call states_elec_get_state(st, mesh, ist, is, psi)
        psi2(:, 1) = real(R_CONJ(psi(:, 1))*psi(:,1), real64)
        vxc_bar = dmf_integrate(mesh, psi2(:, 1)*oep%vxc(1:mesh%np, is))*oep%socc*st%occ(ist, is)
        shift = shift - (vxc_bar - oep%uxc_bar(1, ist,is))
      end if
    end do

    POP_SUB(get_shift)
  end function get_shift

end subroutine X(xc_oep_solve)


!----------------------------------------------------------------------
!> A routine that takes care of mixing the potential
subroutine X(xc_oep_mix)(oep, mesh, ss, rho, is)
  ! Updates oep%vxc(:, is) with a step proportional to ss, according to oep%mixing_scheme:
  !  - OEP_MIXING_SCHEME_CONST: constant prefactor oep%mixing
  !  - OEP_MIXING_SCHEME_DENS : point-wise prefactor -vxc/(2*rho), capped in absolute value
  !  - OEP_MIXING_SCHEME_BB   : Barzilai-Borwein prefactor, which also updates oep%mixing,
  !                             oep%vxc_old and oep%ss_old
  class(xc_oep_t),          intent(inout) :: oep
  type(mesh_t),             intent(in)    :: mesh
  real(real64), contiguous, intent(in)    :: ss(:)
  real(real64),             intent(in)    :: rho(:)
  integer,                  intent(in)    :: is

  ! Largest absolute value allowed for the local mixing prefactor
  real(real64), parameter :: max_local_mixing = 1e3_real64
  real(real64) :: bb_numerator, bb_denominator, ss_norm
  real(real64), allocatable :: step(:)

  PUSH_SUB(X(xc_oep_mix))

  select case (oep%mixing_scheme)
  case (OEP_MIXING_SCHEME_CONST)
    ! Eq. 26: vxc <- vxc + mixing * ss
    call lalg_axpy(mesh%np, oep%mixing, ss, oep%vxc(:, is))

  case (OEP_MIXING_SCHEME_DENS)
    ! Eq. 28 of the Kuemmel paper: density-dependent local prefactor
    SAFE_ALLOCATE(step(1:mesh%np))

    step(1:mesh%np) = - M_HALF * oep%vxc(1:mesh%np, is) / (rho(1:mesh%np) + M_TINY)
    ! Cap the prefactor to avoid nonsense local mixing.
    ! This typically occurs when the wavefunctions are not well converged
    where (abs(step) > max_local_mixing) step = sign(max_local_mixing, step)
    step(1:mesh%np) = step(1:mesh%np) * ss(1:mesh%np)

    call lalg_axpy(mesh%np, M_ONE, step, oep%vxc(:, is))

    SAFE_DEALLOCATE_A(step)

  case (OEP_MIXING_SCHEME_BB)
    ! Barzilai-Borwein scheme, as explained in
    ! Hollins, et al. PRB 85, 235126 (2012)
    ! The prefactor is only recomputed once a previous potential is available (not on the first run)
    if (dmf_nrm2(mesh, oep%vxc_old(1:mesh%np,is)) > M_EPSILON) then
      bb_numerator = dmf_dotp(mesh, oep%vxc(1:mesh%np,is) - oep%vxc_old(1:mesh%np,is), &
        ss - oep%ss_old(:, is))
      bb_denominator = dmf_dotp(mesh, ss - oep%ss_old(:, is), ss - oep%ss_old(:, is))
      oep%mixing = -bb_numerator / bb_denominator
    end if

    ss_norm = dmf_nrm2(mesh, ss)
    write(message(1), '(a,es14.6,a,es15.8)') "Info: oep%mixing:", oep%mixing, " norm2ss: ", ss_norm
    call messages_info(1, debug_only=.true.)

    ! Store the current potential and ss for the next call, then take the step
    call lalg_copy(mesh%np, oep%vxc(:,is), oep%vxc_old(:,is))
    call lalg_copy(mesh%np, ss, oep%ss_old(:, is))
    call lalg_axpy(mesh%np, oep%mixing, ss, oep%vxc(:, is))

  end select

  POP_SUB(X(xc_oep_mix))
end subroutine X(xc_oep_mix)

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
