!! Copyright (C) 2017 N. Tancogne-Dejean
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

! ---------------------------------------------------------
!> Evaluates one atomic orbital of the given species on the submesh sm and
!! stores it in os%X(orb)(:, :, orbind + index_shift).
!!
!! The submesh is set up through atomic_orbital_init_submesh, and os%X(orb)
!! is allocated and zeroed on first use: with mesh%np points if use_mesh is
!! .true., with sm%np points otherwise.
!!
!! For d_dim == 1 a single component is built from the real orbital with
!! mm = orbind - 1 - ll. Otherwise a two-component orbital is assembled from
!! the orbitals with neighbouring values of mm, weighted by coefficients that
!! depend on whether jj is close to ll + 1/2 or not (presumably the
!! Clebsch-Gordan coefficients of the j = l +/- 1/2 spinors -- to be confirmed).
! ---------------------------------------------------------
subroutine X(get_atomic_orbital) (namespace, space, latt, pos, species, mesh, sm, ii, ll, jj, os, &
  orbind, radius, d_dim, use_mesh, normalize, index_shift)
  type(namespace_t),        intent(in)    :: namespace
  class(space_t),           intent(in)    :: space
  type(lattice_vectors_t),  intent(in)    :: latt
  real(real64),             intent(in)    :: pos(:) !< space%dim
  class(species_t),         intent(in)    :: species
  class(mesh_t),            intent(in)    :: mesh
  type(submesh_t),          intent(inout) :: sm          !< Submesh on which the orbital is evaluated
  integer,                  intent(in)    :: ii, ll      !< Orbital index and angular momentum, forwarded to the orbital evaluation
  real(real64),             intent(in)    :: jj          !< Total angular momentum, only used if d_dim /= 1
  type(orbitalset_t),       intent(inout) :: os          !< Orbital set receiving the orbital in os%X(orb)
  integer,                  intent(in)    :: orbind      !< Index of the orbital in the set; also fixes mm
  real(real64),             intent(in)    :: radius      !< Radius passed to the submesh initialization
  integer,                  intent(in)    :: d_dim       !< 1 for a single component, otherwise two components are built
  logical,                  intent(in)    :: use_mesh    !< If .true. the orbital is stored on the full mesh, else on the submesh
  logical,                  intent(in)    :: normalize   !< If .true. the orbital is scaled to unit norm on the submesh
  integer, optional,        intent(in)    :: index_shift !< In order to accumulate the orbitals from two consecutive calls, we can
  !                                                         pass an index shift, to set the orbitals in os%X(orb) to the right place


  real(real64), allocatable :: tmp(:)
  R_TYPE, allocatable :: ztmp(:,:)
  integer :: mm, shift_
  real(real64) :: coeff, norm

  PUSH_SUB(X(get_atomic_orbital))

  shift_ = optional_default(index_shift, 0)

  ! Initializes the submesh if not already done
  call atomic_orbital_init_submesh(namespace, space, latt, pos, mesh, sm, radius)

  ! First call for this orbital set: allocate the storage and start from zero
  if (.not. allocated(os%X(orb))) then
    if (use_mesh) then
      SAFE_ALLOCATE(os%X(orb)(1:mesh%np,1:os%ndim,1:os%norbs))
    else
      SAFE_ALLOCATE(os%X(orb)(1:sm%np,1:os%ndim,1:os%norbs))
    end if
    os%X(orb)(:,:,:) = R_TOTYPE(M_ZERO)
  end if

  if (d_dim == 1) then

    mm = orbind-1-ll

    !We get the orbital from the pseudopotential
    !In this case we want to get a real orbital and to store it in complex array
    SAFE_ALLOCATE(tmp(1:sm%np))
    call datomic_orbital_get_submesh(species, sm, ii, ll, mm, 1, tmp)
    if (normalize) then
      ! The norm is taken on sm, i.e. on the submesh on which tmp was allocated and filled.
      ! NOTE(review): callers are expected to pass os%sphere as sm, in which case this is
      ! the same object as before -- to be confirmed against the callers.
      norm = dsm_nrm2(sm, tmp)
      call lalg_scal(sm%np, M_ONE/norm, tmp)
    end if

    if (use_mesh) then
      call submesh_add_to_mesh(sm, tmp, os%X(orb)(1:mesh%np, 1, orbind+shift_))
    else
      os%X(orb)(1:sm%np, 1, orbind+shift_) = tmp(1:sm%np)
    end if
    SAFE_DEALLOCATE_A(tmp)

  else
    SAFE_ALLOCATE(ztmp(1:sm%np, 1:2))

    if (is_close(jj, ll+M_HALF)) then
      ! jj = ll + 1/2: component 1 carries mm, component 2 carries mm+1.
      ! A component is set to zero when its magnetic number falls outside [-ll, ll].
      mm = orbind - 2 - ll
      if (mm >= -ll) then
        call X(atomic_orbital_get_submesh)(species, sm, ii, ll, mm, 1, ztmp(:, 1))
        coeff = sqrt((ll+mm+M_ONE)/(M_TWO*ll+M_ONE))
        call lalg_scal(sm%np, coeff, ztmp(:, 1))
      else
        ztmp(1:sm%np, 1) = M_ZERO
      end if
      if (mm < ll) then
        call X(atomic_orbital_get_submesh)(species, sm, ii, ll, mm+1, 1, ztmp(:,2))
        coeff = sqrt((ll-mm)/(M_TWO*ll+M_ONE))
        call lalg_scal(sm%np, coeff, ztmp(:, 2))
      else
        ztmp(1:sm%np, 2) = M_ZERO
      end if
    else
      ! Other value of jj: component 2 carries mm (with a negative weight),
      ! component 1 carries mm-1 and is zero when mm-1 would be below -ll.
      mm = orbind - ll
      call X(atomic_orbital_get_submesh)(species, sm, ii, ll, mm, 1, ztmp(:,2))
      coeff = -sqrt((ll+mm)/(M_TWO*ll+M_ONE))
      call lalg_scal(sm%np, coeff, ztmp(:, 2))
      if (mm > -ll) then
        call X(atomic_orbital_get_submesh)(species, sm, ii, ll, mm-1, 1, ztmp(:,1))
        coeff = sqrt((ll-mm+M_ONE)/(M_TWO*ll+M_ONE))
        call lalg_scal(sm%np, coeff, ztmp(:, 1))
      else
        ztmp(1:sm%np, 1) = M_ZERO
      end if
    end if

    if (normalize) then
      ! Norm of the two-component orbital: sqrt(|comp 1|^2 + |comp 2|^2), evaluated on sm
      ! (same remark as in the single-component case above).
      norm = X(sm_nrm2)(sm, ztmp(:,1))**2
      norm = norm + X(sm_nrm2)(sm, ztmp(:,2))**2
      norm = sqrt(norm)
      call lalg_scal(sm%np, M_ONE/norm, ztmp(:,1))
      call lalg_scal(sm%np, M_ONE/norm, ztmp(:,2))
    end if

    if (use_mesh) then
      call submesh_add_to_mesh(sm, ztmp(:, 1), os%X(orb)(1:mesh%np, 1, orbind+shift_))
      call submesh_add_to_mesh(sm, ztmp(:, 2), os%X(orb)(1:mesh%np, 2, orbind+shift_))
    else
      os%X(orb)(1:sm%np, 1, orbind+shift_) = ztmp(1:sm%np, 1)
      os%X(orb)(1:sm%np, 2, orbind+shift_) = ztmp(1:sm%np, 2)
    end if
    ! Release the buffer that was actually allocated in this branch
    SAFE_DEALLOCATE_A(ztmp)

  end if

  POP_SUB(X(get_atomic_orbital))

end subroutine X(get_atomic_orbital)



 ! ---------------------------------------------------------
!> Evaluates an atomic orbital of the given species on the points of a submesh.
!!
!! For species representing a real atom in a 3D box, a radial function is
!! evaluated at submesh%r and multiplied by a spherical harmonic (ll, mm):
!!  - pseudopotentials: the spline ps%ur(ii, ispin), or its derivative if
!!    derivative is .true.;
!!  - full_anc_t with ii == 1: the analytic ANC 1s orbital;
!!  - full_anc_t with ii /= 1 and other allelectron_t species: a hydrogen-like
!!    radial function with ii used as the principal quantum number.
!!
!! Otherwise the orbital is a Gaussian times Hermite polynomials of orders
!! (ii-1, ll-1, mm-1) along the box directions, and derivatives are not supported.
!!
!! For full (all-electron) species the result is rescaled to unit norm on the submesh.
subroutine X(atomic_orbital_get_submesh)(species, submesh, ii, ll, mm, ispin, phi, derivative)
  class(species_t), target, intent(in)  :: species    !< The species.
  type(submesh_t),          intent(in)  :: submesh    !< The submesh descriptor where the orbital will be calculated.
  integer,                  intent(in)  :: ii
  integer,                  intent(in)  :: ll
  integer,                  intent(in)  :: mm
  integer,                  intent(in)  :: ispin      !< The spin index.
  R_TYPE, contiguous,       intent(out) :: phi(:)     !< The function defined in the mesh where the orbitals is returned.
  logical,        optional, intent(in)  :: derivative !< If present and .true. returns the derivative of the orbital.

  integer :: ip, nn(3), idir
  real(real64) :: sqrtw, ww, prefac, zz
  R_TYPE, allocatable :: ylm(:)
  type(ps_t), pointer :: ps
  type(spline_t) :: dur
  logical :: derivative_, hydrogenic
  real(real64) :: tmp(1), norm

  PUSH_SUB(X(atomic_orbital_get_submesh))

  derivative_ = optional_default(derivative, .false.)

  ASSERT(ubound(phi, dim = 1) >= submesh%np)


  ! Only for 3D atomic systems
  if(species%represents_real_atom() .and. submesh%mesh%box%dim == 3) then

    ! phi first holds the radii; the pseudopotential branch evaluates the spline in place on it
    !$omp parallel do
    do ip = 1, submesh%np
      phi(ip) = submesh%r(ip)
    end do

    ! Set by the branches that need the hydrogen-like radial function, which is
    ! then evaluated once after the select type.
    hydrogenic = .false.

    select type(species)
    class is(pseudopotential_t)
      if (submesh%np > 0) then
        ps => species%ps
        if (.not. derivative_) then
          call spline_eval_vec(ps%ur(ii, ispin), submesh%np, phi)
        else
          call spline_init(dur)
          call spline_der(ps%ur(ii, ispin), dur, ps%projectors_sphere_threshold)
          call spline_eval_vec(dur, submesh%np, phi)
          call spline_end(dur)
        end if
        nullify(ps)
      end if

    type is(full_anc_t)
      ! For the ANC potential, we know the 1s orbital, so we use it as it is a better guess than
      ! the hydrogenic one
      if(ii == 1) then
        ASSERT(species%b() < 0) ! To be sure it was already computed

        ! See Eq. 16 in [Gygi J. Chem. Theory Comput. 2023, 19, 1300−1309]
        prefac = sqrt(species%get_z()**3/M_PI)
        !$omp parallel do private(ww)
        do ip = 1, submesh%np
          ww = species%get_z()*submesh%r(ip)*species%a()
          phi(ip) = -ww/species%a() * loct_erf(ww) + species%b()*exp(-ww**2)
          phi(ip) = prefac * exp(phi(ip))
        end do
      else
        hydrogenic = .true.
        zz = species%get_z()
      end if

    class is(allelectron_t)
      hydrogenic = .true.
      zz = species%get_z()
    end select

    if (hydrogenic) then
      ! Hydrogen-like radial function for charge zz and principal quantum number ii
      prefac = sqrt((2*zz/ii)**3 * factorial(ii - ll - 1) / (2*ii*factorial(ii+ll)))
      ! FIXME: cache result somewhat. e.g. re-use result for each m. and use recursion relation.
      !$omp parallel do private(ww, tmp)
      do ip = 1, submesh%np
        ww = zz*submesh%r(ip) / ii
        if(-ww < M_MIN_EXP_ARG) then
          ! exp(-ww) would underflow: the orbital is zero at this point
          phi(ip) = M_ZERO
        else
          ! Replacement for loct_sf_laguerre_n(ii-ll-1, real(2*ll + 1, real64) , 2*ww) that produces overflow
          ! TODO: vectorize this call, as the new routine supports it
          call generalized_laguerre_polynomial(1, ii-ll-1, 2*ll + 1, (/M_TWO*ww/), tmp)
          phi(ip) = prefac * exp(-ww) * (M_TWO * ww)**ll * tmp(1)
        end if
      end do
    end if

    SAFE_ALLOCATE(ylm(1:submesh%np))

#ifdef R_TCOMPLEX
    ! complex spherical harmonics. FIXME: vectorize
    !$omp parallel do
    do ip = 1, submesh%np
      call ylmr_cmplx(submesh%rel_x(1:3, ip), ll, mm, ylm(ip))
    end do
#else
    ! real spherical harmonics
    if(submesh%np > 0) then
      call loct_ylm(submesh%np, submesh%rel_x(1,1), submesh%r(1), ll, mm, ylm(1))
    end if
#endif

    ! Radial part times angular part
    !$omp parallel do
    do ip = 1, submesh%np
      phi(ip) = phi(ip)*ylm(ip)
    end do

    SAFE_DEALLOCATE_A(ylm)

  else

    ASSERT(.not. derivative_)
    ! Question: why not implemented derivatives here?
    ! Answer: because they are linearly dependent with lower-order Hermite polynomials.

    select type(species)
    class is(jellium_t)
      ww = species%get_omega()
    class is(allelectron_t)
      ww = species%get_omega()
    class default
      ASSERT(.false.)
    end select
    sqrtw = sqrt(ww)

    ! FIXME: this is a pretty dubious way to handle l and m quantum numbers. Why not use ylm?
    nn = (/ii, ll, mm/)

    ! Gaussian envelope times one Hermite polynomial per box direction
    !$omp parallel do private(idir)
    do ip = 1, submesh%np
      phi(ip) = exp(-ww*submesh%r(ip)**2/M_TWO)
      do idir = 1, submesh%mesh%box%dim
        phi(ip) = phi(ip) * hermite(nn(idir) - 1, submesh%rel_x(idir, ip)*sqrtw)
      end do
    end do

  end if

  if (species%is_full()) then
    ! When doing the all-electron calculations with a large spacing, the 1s orbitals
    ! will look like a delta function. However, the integral, i.e., |\phi_{1s}(0)|^2 dV
    ! will be very large, e.g. 644 for Ag with a spacing of 0.3.
    ! Without renormalizing this, the 1s orbitals leads to unnormalized guess density way too large
    ! and hence a wrong LCAO
    norm = X(sm_nrm2)(submesh, phi)
    ! An orbital that vanishes on the submesh is left untouched instead of being
    ! multiplied by 1/0, which would fill phi with NaN/Inf.
    if (norm > M_ZERO) then
      call lalg_scal(submesh%np, M_ONE/norm, phi)
    end if
  end if

  POP_SUB(X(atomic_orbital_get_submesh))
end subroutine X(atomic_orbital_get_submesh)


 ! ---------------------------------------------------------
!> Does the same job as X(atomic_orbital_get_submesh), but with an extra check
!! that all points can be evaluated first.
!!
!! For pseudopotentials, points whose radius exceeds the range of the radial
!! spline ps%ur(ii, ispin) cannot be evaluated. If such points exist, a temporary
!! submesh containing only the points within range is created, the orbital is
!! evaluated on it, and the values are copied back to the original ordering;
!! phi is set to zero on the points that are out of range.
!!
!! Nothing is done (phi is left untouched) if the submesh has no points.
subroutine X(atomic_orbital_get_submesh_safe)(species, submesh, ii, ll, mm, ispin, phi)
  class(species_t), target, intent(in)  :: species    !< The species.
  type(submesh_t),          intent(in)  :: submesh    !< The submesh descriptor where the orbital will be calculated.
  integer,                  intent(in)  :: ii
  integer,                  intent(in)  :: ll
  integer,                  intent(in)  :: mm
  integer,                  intent(in)  :: ispin      !< The spin index.
  R_TYPE, contiguous,       intent(out) :: phi(:)     !< The function defined in the mesh where the orbitals is returned.

  integer :: ip, is
  logical :: safe
  integer, allocatable :: map(:)
  type(submesh_t) :: tmp_sm
  R_TYPE, allocatable :: phi_tmp(:)
  real(real64) :: threshold

  if (submesh%np == 0) return

  PUSH_SUB(X(atomic_orbital_get_submesh_safe))

  ! Only pseudopotentials have a limited evaluation range; every other species is always safe
  safe = .true.
  select type(species)
  class is(pseudopotential_t)
    threshold = spline_range_max(species%ps%ur(ii, ispin))
    if (any(submesh%r(1:submesh%np) > threshold)) safe = .false.
  end select

  if (safe) then

    call X(atomic_orbital_get_submesh)(species, submesh, ii, ll, mm, ispin, phi)

  else
    ! This branch is only reached for pseudopotentials, so threshold has been set above.

    ! Number of points that can be evaluated
    is = count(submesh%r(1:submesh%np) <= threshold)

    SAFE_ALLOCATE(map(1:is))
    tmp_sm%mesh => submesh%mesh
    tmp_sm%np = is
    SAFE_ALLOCATE(tmp_sm%rel_x(1:submesh%mesh%box%dim, 1:tmp_sm%np))
    SAFE_ALLOCATE(tmp_sm%r(1:tmp_sm%np))
    SAFE_ALLOCATE(phi_tmp(1:tmp_sm%np))

    ! Gather the in-range points, remembering their position in the original submesh
    is = 0
    do ip = 1, submesh%np
      if (submesh%r(ip) <= threshold) then
        is = is + 1
        map(is) = ip
        tmp_sm%rel_x(:, is) = submesh%rel_x(:, ip)
        tmp_sm%r(is) = submesh%r(ip)
      end if
    end do

    call X(atomic_orbital_get_submesh)(species, tmp_sm, ii, ll, mm, ispin, phi_tmp)

    ! Scatter back; out-of-range points stay at zero
    phi = R_TOTYPE(M_ZERO)
    do ip = 1, tmp_sm%np
      phi(map(ip)) = phi_tmp(ip)
    end do

    call submesh_end(tmp_sm)
    SAFE_DEALLOCATE_A(phi_tmp)
    SAFE_DEALLOCATE_A(map)

  end if

  POP_SUB(X(atomic_orbital_get_submesh_safe))
end subroutine X(atomic_orbital_get_submesh_safe)
