!! Copyright (C) 2002-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch
!! Copyright (C) 2012-2013 M. Gruning, P. Melo, M. Oliveira
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!
!!
! ---------------------------------------------------------
!> This file handles the evaluation of the OEP potential, in the KLI or full OEP
!! as described in S. Kuemmel and J. Perdew, PRL 90, 043004 (2003)
!!
!! This file has to be outside the module xc, for it requires the Hpsi.
!! This is why it needs the xc_functl module. I prefer to put it here since
!! the rest of the Hamiltonian module does not know about the gory details
!! of how xc is defined and calculated.
subroutine X(xc_oep_photon_calc)(oep, namespace, xcs, gr, hm, st, space, ex, ec, vxc)
  ! Driver for the photon OEP/KLI potential.
  !
  ! Steps performed (nothing is done when oep%level == OEP_LEVEL_NONE):
  !  1. allocate the per-call work arrays of oep (eigen_type, eigen_index,
  !     X(lxc), uxc_bar); they are released again before returning;
  !  2. if an XC_OEP_X functional of family XC_FAMILY_OEP is present, apply the
  !     exchange operator to the states (result in xst) and add the returned
  !     energy to ex;
  !  3. build X(lxc) from xst and evaluate uxc_bar = Re <psi|lxc> for the local
  !     states, reducing over the grid and broadcasting over state groups;
  !  4. optionally discard the electron-electron part (oep%rm_ee_interaction);
  !  5. for every spin channel: analyse the eigenvalues, solve the photon KLI
  !     equation and then either solve the full OEP equation or add the KLI
  !     potential oep%vxc to vxc.
  !
  ! The saved flag "first" makes the very first call use the KLI branch even
  ! when the full OEP level is requested.
  type(xc_oep_photon_t),    intent(inout) :: oep       !< OEP data; work arrays are (de)allocated here
  type(namespace_t),        intent(in)    :: namespace
  type(xc_t),               intent(inout) :: xcs       !< only xcs%functional(:,1) family/id are read here
  type(grid_t),             intent(in)    :: gr
  type(hamiltonian_elec_t), intent(inout) :: hm
  type(states_elec_t),      intent(inout) :: st
  class(space_t),           intent(in)    :: space
  real(real64),             intent(inout) :: ex, ec    !< ex is incremented by the EXX energy; ec is not touched here
  real(real64), contiguous, optional, intent(inout) :: vxc(:,:) !< vxc(mesh%np, st%d%nspin)

  real(real64) :: eig
  integer :: is, ist, ixc, nspin_, isp, idm
  logical, save :: first = .true.
  R_TYPE, allocatable :: psi(:), xpsi(:)
  type(states_elec_t) :: xst
  logical :: exx

  if(oep%level == OEP_LEVEL_NONE) return

  call profiling_in(TOSTRING(X(XC_OEP_PHOTON)))
  PUSH_SUB(X(xc_oep_photon_calc))

  ! initialize oep structure
  nspin_ = min(st%d%nspin, 2)
  SAFE_ALLOCATE(oep%eigen_type (1:st%nst))
  SAFE_ALLOCATE(oep%eigen_index(1:st%nst))
  SAFE_ALLOCATE(oep%X(lxc)(1:gr%np, 1:nspin_, st%st_start:st%st_end, 1:nspin_))
  oep%X(lxc) = M_ZERO
  ! The first index of uxc_bar is addressed with idm below (which runs up to 2
  ! for spinors) and st%d%dim elements are broadcast from it, so its extent
  ! must be st%d%dim. For st%d%dim == 1 this is the same shape as before.
  SAFE_ALLOCATE(oep%uxc_bar(1:st%d%dim, 1:st%nst, 1:nspin_))

  !We first apply the exchange operator to all the states
  call xst%nullify()

  exx = .false.
  functl_loop: do ixc = 1, 2
    if(xcs%functional(ixc, 1)%family /= XC_FAMILY_OEP) cycle
    select case(xcs%functional(ixc,1)%id)
    case(XC_OEP_X)
      call X(exchange_operator_compute_potentials)(hm%exxop, namespace, space, gr, st, xst, hm%kpoints, eig)
      ex = ex + eig
      exx = .true.
    end select
  end do functl_loop

  ! calculate uxc_bar for the occupied states

  SAFE_ALLOCATE(psi(1:gr%np))
  SAFE_ALLOCATE(xpsi(1:gr%np))

  oep%uxc_bar(:, :, :) = M_ZERO
  do is = 1, nspin_
    ! distinguish between 'is' being the spin_channel index (collinear)
    ! and being the spinor (noncollinear)
    if (st%d%ispin==SPINORS) then
      isp = 1
      idm = is
    else
      isp = is
      idm = 1
    end if

    do ist = st%st_start, st%st_end
      call states_elec_get_state(st, gr, idm, ist, isp, psi)
      if(exx) then
        ! Here we copy the state from xst to X(lxc).
        ! This will be removed in the future, but it allows to keep both EXX and PZ-SIC in the code
        call states_elec_get_state(xst, gr, idm, ist, isp, xpsi)
        ! There is a complex conjugate here, as the lxc is defined as <\psi|X and
        ! exchange_operator_compute_potentials returns X|\psi>
#ifndef R_TREAL
        xpsi = R_CONJ(xpsi)
#endif
        call lalg_axpy(gr%np, M_ONE, xpsi, oep%X(lxc)(1:gr%np, idm, ist, is))
      end if

      ! Local (unreduced) contribution only; the sum over the grid partition is done once after the loops
      oep%uxc_bar(idm, ist, is) = R_REAL(X(mf_dotp)(gr, psi, oep%X(lxc)(1:gr%np, idm, ist, is), reduce = .false., dotu = .true.))
    end do
  end do
  call gr%allreduce(oep%uxc_bar)

  SAFE_DEALLOCATE_A(psi)
  SAFE_DEALLOCATE_A(xpsi)

  call states_elec_end(xst)

  ! Each process only filled the entries of its own states (st_start:st_end);
  ! the owner of every state sends its values to the rest of the states group.
  if(st%parallel_in_states) then
    if(st%d%ispin == SPIN_POLARIZED) then
      do isp = 1, 2
        do ist = 1, st%nst
          call  st%mpi_grp%bcast(oep%uxc_bar(1, ist, isp), st%d%dim, MPI_DOUBLE_PRECISION, st%node(ist))
        end do
      end do
    else
      ! NOTE(review): only the third index 1 is broadcast here, while the spinor
      ! branch above also writes uxc_bar(2, ist, 2) -- confirm whether spinors
      ! combined with states parallelization are meant to be supported.
      do ist = 1, st%nst
        call  st%mpi_grp%bcast(oep%uxc_bar(1, ist, 1), st%d%dim, MPI_DOUBLE_PRECISION, st%node(ist))
      end do
    end if
  end if

  ! remove electron-electron interaction only consider electron-photon interaction
  if (oep%rm_ee_interaction) then
    ! vxc is optional: it must not be referenced when the caller did not pass it
    if (present(vxc)) vxc = M_ZERO
    oep%uxc_bar = M_ZERO
    oep%X(lxc) = M_ZERO
  end if


  do is = 1, nspin_
    ! get the HOMO state
    call xc_oep_AnalyzeEigen(oep, st, is)
    !
    call X(xc_KLI_solve_photon) (namespace, gr, hm, st, is, oep, first)
    !
    ! if asked, solve the full OEP equation
    if(oep%level == OEP_LEVEL_FULL .and. (.not. first)) then
      if(present(vxc)) then
        call X(xc_oep_solve_photon)(namespace, gr, hm, st, is, vxc(:,is), oep)
      end if
    else  ! solve the KLI equation
      if(present(vxc)) then
        call lalg_axpy(gr%np, M_ONE, oep%vxc(1:gr%np, is), vxc(1:gr%np, is))
      end if
    end if
  end do

  ! From now on the full OEP branch may be taken
  first = .false.

  SAFE_DEALLOCATE_A(oep%eigen_type)
  SAFE_DEALLOCATE_A(oep%eigen_index)
  SAFE_DEALLOCATE_A(oep%X(lxc))
  SAFE_DEALLOCATE_A(oep%uxc_bar)

  call profiling_out(TOSTRING(X(XC_OEP_PHOTON)))

  POP_SUB(X(xc_oep_photon_calc))
end subroutine X(xc_oep_photon_calc)


! ---------------------------------------------------------
!> This is the photon version of the xc_oep_solve routine
subroutine X(xc_oep_solve_photon) (namespace, mesh, hm, st, is, vxc, oep)
  ! Iterative solution of the full OEP equation for spin channel "is",
  ! including the photon contributions (xc_oep_pt_* helpers, real states only).
  !
  ! Every iteration
  !  - installs vxc + oep%vxc + shift as the xc potential of the Hamiltonian,
  !  - solves, for each occupied state, a Sternheimer equation for the orbital
  !    shift (stored in oep%lr),
  !  - accumulates the function ss (Eq. 25) and uses its norm ff as residual,
  !  - recomputes the constant shift and, unless converged, mixes oep%vxc.
  ! The loop stops when ff < oep%scftol%conv_abs_dens or after
  ! oep%scftol%max_iter iterations. On exit vxc holds the input vxc plus
  ! oep%vxc plus the final shift, and oep%norm2ss accumulates ff over spins.
  !
  ! Not implemented for states parallelization; asserts for complex states.
  type(namespace_t),        intent(in)    :: namespace
  class(mesh_t),            intent(in)    :: mesh
  type(hamiltonian_elec_t), intent(inout) :: hm
  type(states_elec_t),      intent(in)    :: st
  integer,                  intent(in)    :: is
  real(real64), contiguous, intent(inout) :: vxc(:) !< (mesh%np, given for the spin is)
  type(xc_oep_photon_t),    intent(inout) :: oep

  integer :: iter, ist, iter_used
  real(real64) :: vxc_bar, ff, residue
  real(real64), allocatable :: ss(:)
  real(real64), allocatable :: psi2(:,:)
  R_TYPE, allocatable :: bb(:,:), psi(:, :), phi1(:,:,:)
  real(real64) :: shift
  logical :: converged

  call profiling_in(TOSTRING(X(OEP_LEVEL_FULL_PHOTON)))
  PUSH_SUB(X(xc_oep_solve_photon))

  if(st%parallel_in_states) &
    call messages_not_implemented("Full OEP parallel in states", namespace=namespace)

#ifndef R_TREAL
  ! Photons with OEP are only implemented for real states
  ASSERT(.false.)
#endif

  SAFE_ALLOCATE(     bb(1:mesh%np, 1:1))
  SAFE_ALLOCATE(     ss(1:mesh%np))
  SAFE_ALLOCATE(psi(1:mesh%np, 1:st%d%dim))
  SAFE_ALLOCATE(psi2(1:mesh%np, 1:st%d%dim))
  SAFE_ALLOCATE(phi1(1:mesh%np, 1:st%d%dim, 1:oep%noccst))

  ! The linear-response containers persist in oep between calls; they are
  ! created and zeroed only the first time they are needed.
  if(.not. lr_is_allocated(oep%lr)) then
    call lr_allocate(oep%lr, st, mesh)
    oep%lr%X(dl_psi)(:,:, :, :) = M_ZERO
  end if

  if(.not. lr_is_allocated(oep%photon_lr)) then
    call lr_allocate(oep%photon_lr, st, mesh)
    oep%photon_lr%ddl_psi(:, :, :, :) = M_ZERO
  end if

  shift = M_ZERO
  converged = .false.

#ifdef R_TREAL
  call xc_oep_pt_phi(namespace, mesh, hm, st, is, oep, phi1)
#endif

  do iter = 1, oep%scftol%max_iter
    ! We now update the potential
    ! It is made of the other xc potential (vxc) and the iterative oep (oep%vxc) and the oep shift
    hm%ks_pot%vxc(1:mesh%np, is) = vxc(1:mesh%np) + oep%vxc(1:mesh%np, is) + shift
    hm%ks_pot%vhxc(1:mesh%np ,is) = hm%ks_pot%vhartree(1:mesh%np) + hm%ks_pot%vxc(1:mesh%np, is)
    call hamiltonian_elec_update_pot(hm, mesh)

    ! iteration over all states
    ss = M_ZERO
    do ist = 1, st%nst

      if(abs(st%occ(ist, is))<= M_EPSILON) cycle  !only over occupied states
      call states_elec_get_state(st, mesh, ist, is, psi)
      psi2(:, 1) = real(R_CONJ(psi(:, 1))*psi(:,1), real64)
      ! evaluate right-hand side
      vxc_bar = dmf_integrate(mesh, psi2(:, 1)*hm%ks_pot%vxc(1:mesh%np, is))

      ! This the right-hand side of Eq. 21
      bb(1:mesh%np, 1) = -(hm%ks_pot%vxc(1:mesh%np, is) - (vxc_bar - oep%uxc_bar(1, ist, is)))* &
        R_CONJ(psi(:, 1)) + oep%X(lxc)(1:mesh%np, 1, ist, is)

#ifdef R_TREAL
      call xc_oep_pt_rhs(mesh, st, is, oep, phi1, ist, bb)
#endif

      call X(lr_orth_vector) (mesh, st, bb, ist, is, R_TOTYPE(M_ZERO))

      ! Sternheimer equation [H-E_i]psi_i = bb_i, where psi_i the orbital shift, see Eq. 21
      call X(linear_solver_solve_HXeY)(oep%solver, namespace, hm, mesh, st, ist, is, &
        oep%lr%X(dl_psi)(:,:, ist, is), bb, &
        R_TOTYPE(-st%eigenval(ist, is)), oep%scftol%final_tol, residue, iter_used)

      write(message(1),'(a,i3,a,es14.6,a,es14.6,a,i4)') "Debug: OEP - iter ", iter, &
        " linear solver residual ", residue, " tol ", &
        oep%scftol%final_tol, " iter_used ", iter_used
      call messages_info(1, debug_only=.true.)

      !We project out the occupied bands
      call X(lr_orth_vector) (mesh, st, oep%lr%X(dl_psi)(:,:, ist, is), ist, is, R_TOTYPE(M_ZERO))


      ! calculate this funny function ss
      ! ss = ss + 2*dl_psi*psi
      ! This is Eq. 25
      call lalg_axpy(mesh%np, M_TWO, R_REAL(oep%lr%X(dl_psi)(1:mesh%np, 1, ist, is)*psi(:, 1)), ss(:))

#ifdef R_TREAL
      call xc_oep_pt_inhomog(mesh, st, is, phi1, ist, ss)
#endif
    end do

    ff = dmf_nrm2(mesh, ss)
    write(message(1),'(a,i3,a,es14.6,a,i4)') "Debug: OEP - iter ", iter, " residual ", ff, " max ", oep%scftol%max_iter
    call messages_info(1, debug_only=.true.)

    !Here we enforce Eq. (24), see the discussion below Eq. 26
    shift = get_shift()

    ! Record convergence explicitly so that the report after the loop is the
    ! exact complement of this test (also for ff == tolerance or a NaN residual).
    if (ff < oep%scftol%conv_abs_dens) then
      converged = .true.
      exit
    end if


    !Here we mix the xc potential
    call X(xc_oep_mix)(oep, mesh, ss, st%rho(:,is), is)

  end do

  ! As the last state of the loop might not converge, we have done a last mixing
  ! so we need to recompute the shift
  shift = get_shift()

  vxc(1:mesh%np) = vxc(1:mesh%np) + oep%vxc(1:mesh%np, is) + shift

  if (is == 1) then
    oep%norm2ss = ff
  else
    oep%norm2ss = oep%norm2ss + ff !adding up spin up and spin down component
  end if

  if(.not. converged) then
    write(message(1), '(a)') "OEP did not converge."
    call messages_warning(1, namespace=namespace)

    ! the loop ran to completion, so iter is max_iter + 1:
    ! otherwise the number below will be one too high
    iter = iter - 1
  end if

  write(message(1), '(a,i4,a,es14.6)') "Info: After ", iter, " iterations, the OEP residual = ", ff
  message(2) = ''
  call messages_info(2)

  SAFE_DEALLOCATE_A(bb)
  SAFE_DEALLOCATE_A(ss)
  SAFE_DEALLOCATE_A(psi)
  SAFE_DEALLOCATE_A(psi2)
  SAFE_DEALLOCATE_A(phi1)

  call profiling_out(TOSTRING(X(OEP_LEVEL_FULL_PHOTON)))

  POP_SUB(X(xc_oep_solve_photon))

contains
  !> Constant shift of the OEP potential: minus the sum, over the states
  !! flagged with oep%eigen_type == 2, of (vxc_bar - uxc_bar), where vxc_bar
  !! is the orbital average of oep%vxc (plus the photon correction for real
  !! states). Uses and overwrites the host variables ist, psi, psi2 and
  !! vxc_bar, so it must not be called from inside the host loop over ist.
  real(real64) function get_shift() result(shift)
    PUSH_SUB(get_shift)

    shift = M_ZERO
    do ist = 1, st%nst
      if(oep%eigen_type(ist) == 2) then
        call states_elec_get_state(st, mesh, ist, is, psi)
        psi2(:, 1) = real(R_CONJ(psi(:, 1))*psi(:,1), real64)
        vxc_bar = dmf_integrate(mesh, psi2(:, 1)*oep%vxc(1:mesh%np, is))
#ifdef R_TREAL
        call xc_oep_pt_uxcbar(mesh, st, is, oep, phi1, ist, vxc_bar)
#endif
        shift = shift - (vxc_bar - oep%uxc_bar(1, ist, is))
      end if
    end do

    POP_SUB(get_shift)
  end function get_shift
end subroutine X(xc_oep_solve_photon)

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
