!! Copyright (C) 2002-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch
!! Copyright (C) 2023-2024 N. Tancogne-Dejean
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

module xc_fbe_oct_m
  use batch_oct_m
  use batch_ops_oct_m
  use comm_oct_m
  use debug_oct_m
  use derivatives_oct_m
  use electron_space_oct_m
  use exchange_operator_oct_m
  use global_oct_m
  use grid_oct_m
  use kpoints_oct_m
  use lalg_basic_oct_m
  use math_oct_m
  use messages_oct_m
  use mesh_oct_m
  use mesh_function_oct_m
  use mpi_oct_m
  use namespace_oct_m
  use nl_operator_oct_m
  use parser_oct_m
  use profiling_oct_m
  use poisson_oct_m
  use ring_pattern_oct_m
  use solvers_oct_m
  use space_oct_m
  use states_abst_oct_m
  use states_elec_oct_m
  use states_elec_dim_oct_m
  use states_elec_all_to_all_communications_oct_m
  use varinfo_oct_m
  use wfs_elec_oct_m
  use xc_functional_oct_m

  implicit none

  private
  public ::               &
    x_fbe_calc,           &
    lda_c_fbe,            &
    fbe_c_lda_sl

  type(grid_t), pointer :: gr_aux  => null()
  real(real64), pointer :: rho_aux(:) => null()
  real(real64), allocatable  :: diag_lapl(:) !< diagonal of the laplacian


contains

  ! -------------------------------------------------------------------------------------
  !>@brief Interface to X(x_fbe_calc)
  !!
  !! Two run modes are available, selected through id:
  !!  - XC_OEP_X_FBE (adiabatic): we assume no current and solve the local
  !!    force-balance equation; energy and potential are those handled by X(x_fbe_calc).
  !!  - XC_OEP_X_FBE_SL (Sturm-Liouville): X(x_fbe_calc) supplies the force density
  !!    and the starting potential, the Sturm-Liouville equation is then solved for
  !!    the potential, and the energy is given by the virial relation.
  !!
  !! The potential is only accumulated into vxc when vxc is present; the energy is
  !! evaluated in either case. Any other value of id fails an assertion.
  subroutine x_fbe_calc (id, namespace, psolver, gr, st, space, ex, vxc)
    integer,                     intent(in)    :: id         !< XC_OEP_X_FBE or XC_OEP_X_FBE_SL
    type(namespace_t),           intent(in)    :: namespace
    type(poisson_t),             intent(in)    :: psolver
    type(grid_t),                intent(in)    :: gr
    type(states_elec_t),         intent(inout) :: st
    type(space_t),               intent(in)    :: space
    real(real64),                intent(inout) :: ex         !< exchange energy (overwritten in the SL mode)
    real(real64), contiguous, optional, intent(inout) :: vxc(:,:) !< vxc(gr%mesh%np, st%d%nspin)

    real(real64), allocatable :: fxc(:,:,:), internal_vxc(:,:)

    PUSH_SUB(x_fbe_calc)

    select case(id)
    case(XC_OEP_X_FBE)
      ! vxc is forwarded as an optional actual argument, so it may be absent here
      if (states_are_real(st)) then
        call dx_fbe_calc(namespace, psolver, gr, gr%der, st, ex, vxc=vxc)
      else
        call zx_fbe_calc(namespace, psolver, gr, gr%der, st, ex, vxc=vxc)
      end if
    case(XC_OEP_X_FBE_SL)
      SAFE_ALLOCATE(fxc(1:gr%np_part, 1:gr%box%dim, 1:st%d%spin_channels))
      SAFE_ALLOCATE(internal_vxc(1:gr%np, 1:st%d%spin_channels))
      internal_vxc = M_ZERO
      ! We first compute the force density; internal_vxc receives the potential
      ! that serves as the starting point of the Sturm-Liouville solver
      if (states_are_real(st)) then
        call dx_fbe_calc(namespace, psolver, gr, gr%der, st, ex, vxc=internal_vxc, fxc=fxc)
      else
        call zx_fbe_calc(namespace, psolver, gr, gr%der, st, ex, vxc=internal_vxc, fxc=fxc)
      end if

      ! The potential is only needed, and vxc may only be referenced, if vxc is present
      if (present(vxc)) then
        ! We solve the Sturm-Liouville equation
        call solve_sturm_liouville(namespace, gr, st, space, fxc, internal_vxc)

        ! Adds the calculated potential
        call lalg_axpy(gr%np, st%d%spin_channels, M_ONE, internal_vxc, vxc)
      end if

      ! Get the energy from the virial relation
      ex = get_virial_energy(gr, st%d%spin_channels, fxc)

      SAFE_DEALLOCATE_A(fxc)
      SAFE_DEALLOCATE_A(internal_vxc)
    case default
      ASSERT(.false.)
    end select

    POP_SUB(x_fbe_calc)
  end subroutine x_fbe_calc

  ! -------------------------------------------------------------------------------------
  !>@brief Solve the Sturm-Liouville equation
  !!
  !! For every spin channel, the linear problem
  !!   \nabla\cdot(\rho\nabla v) = -\nabla\cdot f_{xc}
  !! is solved iteratively with the QMR solver, using a Jacobi-like preconditioner.
  !! On entry, vxc is the adiabatic potential (used as the vector handed to the solver);
  !! on exit, it is the solution of the Sturm-Liouville equation.
  !!
  !! The module variables gr_aux, rho_aux and diag_lapl are set up for the internal
  !! procedures and released again before returning.
  subroutine solve_sturm_liouville(namespace, gr, st, space, fxc, vxc)
    type(namespace_t),            intent(in)     :: namespace
    type(grid_t), target,         intent(in)     :: gr
    type(states_elec_t), target,  intent(in)     :: st
    type(space_t),                intent(in)     :: space
    real(real64),  contiguous,    intent(inout)  :: fxc(:,:,:) !< force density; first extent of at least gr%np_part
    real(real64),  contiguous,    intent(inout)  :: vxc(:,:)   !< potential, one column per spin channel

    integer, parameter :: max_iter = 500                 !< iteration budget handed to the solver
    real(real64), parameter :: sl_tol = 1e-7_real64      !< threshold handed to the solver

    type(nl_operator_t) :: lapl_op(1)  !< this array is necessary for derivatives_get_lapl() to work
    character(len=32) :: op_label
    real(real64), allocatable :: div_f(:)
    real(real64) :: resid
    integer :: n_iter, isp

    PUSH_SUB(solve_sturm_liouville)

    ASSERT(ubound(fxc, dim=1) >= gr%np_part)

    ! Make the grid reachable from the internal procedures and from the *_aux mesh functions
    gr_aux => gr
    call mesh_init_mesh_aux(gr)

    ! The preconditioner relies on the diagonal of a Laplacian operator built from gr%der
    op_label = 'FBE preconditioner'
    call derivatives_get_lapl(gr%der, namespace, lapl_op, space, op_label, 1)
    SAFE_ALLOCATE(diag_lapl(1:lapl_op(1)%np))
    call dnl_operator_operate_diag(lapl_op(1), diag_lapl)
    call nl_operator_end(lapl_op(1))

    SAFE_ALLOCATE(div_f(1:gr%np))

    do isp = 1, st%d%spin_channels
      ! Right-hand side: minus the divergence of the force density
      call dderivatives_div(gr%der, fxc(:, :, isp), div_f)
      div_f = -div_f

      ! Density of the current spin channel, read by sl_operator and preconditioner
      rho_aux => st%rho(:, isp)

      ! n_iter holds the iteration budget on entry and is reported after the solve
      n_iter = max_iter
      call dqmr_sym_gen_dotu(gr%np, vxc(:, isp), div_f, &
        sl_operator, dmf_dotu_aux, dmf_nrm2_aux, preconditioner, &
        n_iter, residue = resid, threshold = sl_tol, showprogress = .false.)

      write(message(1), '(a, i6, a)') "Info: Sturm-Liouville solver converged in  ", n_iter, " iterations."
      write(message(2), '(a, es14.6)') "Info: The residue is ", resid
      call messages_info(2, namespace=namespace)
    end do

    SAFE_DEALLOCATE_A(div_f)
    SAFE_DEALLOCATE_A(diag_lapl)

    nullify(gr_aux)
    nullify(rho_aux)

    POP_SUB(solve_sturm_liouville)
  contains
    !----------------------------------------------------------------
    !> Applies the Sturm-Liouville operator: hx = \nabla\cdot(\rho\nabla x)
    subroutine sl_operator(x, hx)
      real(real64), contiguous, intent(in)  :: x(:)
      real(real64), contiguous, intent(out) :: hx(:)

      integer :: ipt, id
      real(real64), allocatable :: pot(:)
      real(real64), allocatable :: flux(:,:)

      ! The derivative routines are handed arrays with np_part points
      SAFE_ALLOCATE(pot(1:gr_aux%np_part))
      SAFE_ALLOCATE(flux(1:gr_aux%np_part, 1:gr_aux%box%dim))
      call lalg_copy(gr_aux%np, x, pot)

      ! flux = \nabla x
      call dderivatives_grad(gr_aux%der, pot, flux)

      ! flux = \rho \nabla x on the first np points
      !$omp parallel
      do id = 1, gr_aux%box%dim
        !$omp do
        do ipt = 1, gr_aux%np
          flux(ipt, id) = flux(ipt, id)*rho_aux(ipt)
        end do
        !$omp end do nowait
      end do
      !$omp end parallel

      ! hx = \nabla\cdot flux
      call dderivatives_div(gr_aux%der, flux, hx)

      SAFE_DEALLOCATE_A(pot)
      SAFE_DEALLOCATE_A(flux)
    end subroutine sl_operator

    !----------------------------------------------------------------
    !> Simple preconditioner
    !! Here we need to approximate P^-1
    !! We use the Jacobi approximation and that \nabla\cdot[ \rho \nabla v] \approx \rho \nabla^2 v,
    !! i.e. each entry is divided by the density (bounded from below) times the
    !! diagonal of the Laplacian.
    subroutine preconditioner(x, hx)
      real(real64), contiguous, intent(in)  :: x(:)
      real(real64), contiguous, intent(out) :: hx(:)

      real(real64), parameter :: rho_floor = 1e-12_real64 !< lower bound on the density in the division
      integer :: ipt

      !$omp parallel do
      do ipt = 1, gr_aux%np
        hx(ipt) = x(ipt) / (max(rho_aux(ipt), rho_floor) * diag_lapl(ipt))
      end do
      !$omp end parallel do

    end subroutine preconditioner

  end subroutine solve_sturm_liouville

  ! -------------------------------------------------------------------------------------
  !>@brief Computes the energy from the force virial relation
  !!
  !! Returns the sum over the first nspin spin channels of the mesh integral of
  !! \sum_{idir} x_{idir} f_{xc}(ip, idir, isp), where x are the coordinates
  !! returned by mesh_r for each of the gr%np points.
  real(real64) function get_virial_energy(gr, nspin, fxc) result(exc)
    type(grid_t),  intent(in) :: gr
    integer,       intent(in) :: nspin       !< number of spin channels to sum over
    real(real64),  intent(in) :: fxc(:,:,:)  !< force density, indexed as (point, direction, spin)

    integer :: isp, idir, ip
    real(real64), allocatable :: rfxc(:)
    real(real64) :: xx(gr%box%dim), rr, r_dot_f

    PUSH_SUB(get_virial_energy)

    exc = M_ZERO

    ! The integrand buffer has the same size for all spin channels:
    ! allocate it once instead of once per channel
    SAFE_ALLOCATE(rfxc(1:gr%np))

    do isp = 1, nspin
      do ip = 1, gr%np
        ! rr is returned by mesh_r as well, but only the coordinates are used here
        call mesh_r(gr, ip, rr, coords=xx)
        r_dot_f = M_ZERO
        do idir = 1, gr%box%dim
          r_dot_f = r_dot_f + fxc(ip, idir, isp) * xx(idir)
        end do
        rfxc(ip) = r_dot_f
      end do
      exc = exc + dmf_integrate(gr, rfxc)
    end do

    SAFE_DEALLOCATE_A(rfxc)

    POP_SUB(get_virial_energy)
  end function get_virial_energy


  ! -------------------------------------------------------------------------------------
  !>@brief Computes the local density correlation potential and energy obtained from the
  !! Colle-Salvetti approximation to the reduced density matrix, with a gradient expansion
  !! on the correlation force density.
  !!
  !! The spin-polarized case does not have an energy and uses the approximated potential that
  !! neglects \nabla \zeta.
  !! The energy is therefore wrong for the spin-polarized case
  !!
  !! Arrays are indexed as (spin channel, point). Points whose total density is below
  !! 1e-20 get a zero potential (and a zero energy density, if l_zk is present).
  subroutine lda_c_fbe (st, n_blocks, l_dens, l_dedd, l_zk)
    type(states_elec_t),         intent(in)    :: st
    integer,                     intent(in)    :: n_blocks    !< number of points to process
    real(real64),                intent(in)    :: l_dens(:,:) !< density
    real(real64),                intent(inout) :: l_dedd(:,:) !< potential
    real(real64), optional,      intent(inout) :: l_zk(:)     !< energy density (only set for nspin == 1)

    integer :: ipt, isp, nsp
    real(real64) :: dens, beta, beta_sq, one_pb, v_c, eps_c
    real(real64) :: q, q3

    PUSH_SUB(lda_c_fbe)

    nsp = st%d%spin_channels

    ! Set q such that we get the leading order of the r_s->0 limit for the HEG
    q = ((5.0_real64*sqrt(M_PI)**5)/(M_THREE*(M_ONE-log(M_TWO))))**(M_THIRD)
    q3 = q**3

    if (present(l_zk)) l_zk = M_ZERO

    do ipt = 1, n_blocks
      ! Total density at this point
      dens = sum(l_dens(1:nsp, ipt))
      if (dens < 1e-20_real64) then
        l_dedd(1:nsp, ipt) = M_ZERO
        cycle
      end if
      dens = max(dens, 1e-12_real64)

      beta = q*dens**M_THIRD
      beta_sq = beta**2
      ! Recurring denominator 1 + sqrt(pi)*beta
      one_pb = M_ONE + sqrt(M_PI)*beta

      ! Potential, identical for all spin channels at this stage
      ! First part of the potential
      v_c = (M_PI/q3)*((sqrt(M_PI)*beta/one_pb)**2 - M_ONE) * beta
      ! Second part of the potential
      v_c = v_c &
        - (5.0_real64*sqrt(M_PI))/(M_THREE*q3)*(log(one_pb) &
        - M_HALF/one_pb**2 + M_TWO/one_pb) + (5.0_real64*sqrt(M_PI))/(M_TWO*q3)
      l_dedd(1:nsp, ipt) = v_c

      select case (st%d%nspin)
      case (1)
        if (present(l_zk)) then
          ! Energy density
          ! First part of the energy density
          eps_c = (9.0_real64*q3)/M_TWO/beta &
            - M_TWO*q3*sqrt(M_PI) &
            - 12.0_real64/beta_sq*(q3/sqrt(M_PI)) &
            + M_THREE/(M_PI*dens)*(M_ONE/one_pb - M_ONE &
            + 5.0_real64*log(one_pb))

          ! Second part of the energy density
          eps_c = eps_c - 5.0_real64/6.0_real64*( &
            7.0_real64*q3/beta &
            + M_THREE/(M_PI*dens*one_pb) &
            - 17.0_real64*q3/sqrt(M_PI)/beta_sq &
            - 11.0_real64*q3*sqrt(M_PI)/(M_THREE) &
            + (20.0_real64/(M_PI*dens) + M_TWO*sqrt(M_PI)*q3)*log(one_pb) &
            - M_THREE/(M_PI*dens))
          eps_c = eps_c/(q**6)
          l_zk(ipt) = eps_c
        end if
      case (2)
        ! Here we have no energy density. The approximate potential of each channel
        ! is the one above scaled by twice the opposite-spin fraction of the density.
        do isp = 1, nsp
          l_dedd(isp, ipt) = l_dedd(isp, ipt) * M_TWO * l_dens(3-isp, ipt) / dens
        end do
      end select
    end do

    POP_SUB(lda_c_fbe)
  end subroutine lda_c_fbe

  ! -------------------------------------------------------------------------------------
  !>@brief Sturm-Liouville version of the FBE local-density correlation functional
  !!
  !! The correlation force density is built from the spin densities and their gradients.
  !! If vxc is present, the Sturm-Liouville equation is solved for the potential, starting
  !! from the lda_c_fbe potential, and the result is added to vxc.
  !! The energy is obtained from the virial relation in either case.
  subroutine fbe_c_lda_sl (namespace, psolver, gr, st, space, ec, vxc)
    type(namespace_t),           intent(in)    :: namespace
    type(poisson_t),             intent(in)    :: psolver    !< not referenced in this routine
    type(grid_t),                intent(in)    :: gr
    type(states_elec_t),         intent(inout) :: st
    type(space_t),               intent(in)    :: space
    real(real64),                intent(inout) :: ec         !< correlation energy (overwritten)
    real(real64), contiguous, optional, intent(inout) :: vxc(:,:) !< vxc(gr%mesh%np, st%d%nspin)

    integer :: idir, ip, ispin
    real(real64), allocatable :: fxc(:,:,:), internal_vxc(:,:), grad_rho(:,:,:), tmp1(:,:), tmp2(:,:)
    real(real64) :: q, beta, rho, l_gdens

    PUSH_SUB(fbe_c_lda_sl)

    SAFE_ALLOCATE(internal_vxc(1:gr%np, 1:st%d%spin_channels))

    ! Needed to get the initial guess for the iterative solution of the Sturm-Liouville equation
    ! lda_c_fbe indexes its arrays as (spin, point), hence the transposed copies
    SAFE_ALLOCATE(tmp1(1:st%d%spin_channels, 1:gr%np))
    SAFE_ALLOCATE(tmp2(1:st%d%spin_channels, 1:gr%np))
    tmp1 = transpose(st%rho(1:gr%np, 1:st%d%spin_channels))
    call lda_c_fbe(st, gr%np, tmp1, tmp2)
    internal_vxc = transpose(tmp2)
    SAFE_DEALLOCATE_A(tmp1)
    SAFE_DEALLOCATE_A(tmp2)

    ! Set q such that we get the leading order of the r_s->0 limit for the HEG
    q = ((5.0_real64*sqrt(M_PI)**5)/(M_THREE*(M_ONE-log(M_TWO))))**(M_THIRD)

    SAFE_ALLOCATE(fxc(1:gr%np_part, 1:gr%box%dim, 1:st%d%spin_channels))
    SAFE_ALLOCATE(grad_rho(1:gr%np, 1:gr%box%dim, 1:st%d%spin_channels))
    do ispin = 1, st%d%spin_channels
      call dderivatives_grad(gr%der, st%rho(:, ispin), grad_rho(:, :, ispin))
    end do

    ! Correlation force density on the first np points
    do ispin = 1, st%d%spin_channels
      do idir = 1, gr%box%dim
        do ip = 1, gr%np
          rho = sum(st%rho(ip, 1:st%d%spin_channels))
          ! No force where the density of this spin channel is negligible
          if (st%rho(ip, ispin) < 1e-20_real64) then
            fxc(ip, idir, ispin) = M_ZERO
            cycle
          end if
          rho = max(rho, 1e-12_real64)
          beta = rho**M_THIRD * q

          ! Gradient of the total density along idir
          l_gdens = sum(grad_rho(ip, idir, 1:st%d%spin_channels))

          if (st%d%spin_channels == 1) then
            fxc(ip, idir, ispin) = l_gdens * &
              ( M_PI * beta**2/((M_ONE + sqrt(M_PI)*beta)**2) - M_ONE &
              + M_THIRD * M_PI * beta**2 / ((M_ONE + sqrt(M_PI)*beta)**3) )
          else
            ! 3-ispin is the index of the opposite spin channel
            fxc(ip, idir, ispin) = M_TWO * (grad_rho(ip, idir, 3-ispin) * &
              (M_PI * beta**2/((M_ONE + sqrt(M_PI)*beta)**2) - M_ONE ) &
              + l_gdens * (M_THIRD * M_PI * beta**2 / ((M_ONE + sqrt(M_PI)*beta)**3) ) &
              * st%rho(ip, 3-ispin) / rho)
          end if

          fxc(ip, idir, ispin) = fxc(ip, idir, ispin) * M_PI/(M_THREE*beta**2)  * st%rho(ip, ispin)
        end do
      end do
    end do

    SAFE_DEALLOCATE_A(grad_rho)

    ! The potential is only needed, and vxc may only be referenced, if vxc is present
    if (present(vxc)) then
      ! We solve the Sturm-Liouville equation
      call solve_sturm_liouville(namespace, gr, st, space, fxc, internal_vxc)

      ! Adds the calculated potential
      call lalg_axpy(gr%np, st%d%spin_channels, M_ONE, internal_vxc, vxc)
    end if

    ! Get the energy from the virial relation
    ec = get_virial_energy(gr, st%d%spin_channels, fxc)

    SAFE_DEALLOCATE_A(fxc)
    SAFE_DEALLOCATE_A(internal_vxc)

    POP_SUB(fbe_c_lda_sl)
  end subroutine fbe_c_lda_sl


#include "undef.F90"
#include "real.F90"
#include "xc_fbe_inc.F90"

#include "undef.F90"
#include "complex.F90"
#include "xc_fbe_inc.F90"

end module xc_fbe_oct_m

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
