!! Copyright (C) 2023. A Buccheri.
!!
!! This Source Code Form is subject to the terms of the Mozilla Public
!! License, v. 2.0. If a copy of the MPL was not distributed with this
!! file, You can obtain one at https://mozilla.org/MPL/2.0/.
!!

!> @brief Driver for Chebyshev filter-based solver.
!!
!! This routine is implemented according to Algorithm 4.1 in the paper:
!! "Chebyshev-filtered subspace iteration method free of sparse diagonalization for
!!  solving the Kohn–Sham equation". http://dx.doi.org/10.1016/j.jcp.2014.06.056
!!
!! The scaled Chebyshev algorithm is always utilised, as we get an
!! estimate of the lowest eigenvalue of H from the lowest Ritz value of
!! the prior step. The reason for the scaling is to prevent a potential overflow,
!! which may happen if the filter degree is large or if the smallest eigenvalue is
!! mapped far away from −1.
!!
!! For the first SCF step, a Chebyshev polynomial of filter_params%degree is applied
!! filter_params%n_iter times to a set of search vectors found from the initial guess
!! for the density.
subroutine X(chebyshev_filter_solver)(namespace, sdiag, mesh, st, hm, ik, subspace_tol, &
  filter_params, scf_iter, prior_residuals)
  type(namespace_t),             intent(in)    :: namespace        !< Calling namespace
  type(subspace_t),              intent(in)    :: sdiag            !< Subspace diagonalisation choice
  class(mesh_t),                 intent(in)    :: mesh             !< Real-space mesh
  type(states_elec_t),           intent(inout) :: st               !< Eigenstates
  type(hamiltonian_elec_t),      intent(in)    :: hm               !< Hamiltonian
  integer,                       intent(in)    :: ik               !< k-point index
  real(real64),                  intent(in)    :: subspace_tol     !< Subspace iterative solver tolerance
  type(eigen_chebyshev_t),       intent(in)    :: filter_params    !< Chebyshev filter parameters
  integer,                       intent(in)    :: scf_iter         !< SCF iteration
  real(real64),                  intent(in)    :: prior_residuals(:) !< Eigenvalue residuals from prior SCF step

  !> Margin added above the largest prior Ritz value to form the lower filter bound
  real(real64), parameter :: ritz_margin = 1e-3_real64

  type(chebyshev_filter_bounds_t), pointer :: filter_bounds  !< Filter bounds for this call
  integer, allocatable :: block_degree(:)                    !< Polynomial degree applied to each block
  real(real64) :: ritz_min                                   !< Smallest prior Ritz value
  real(real64) :: damp_lower                                 !< Lower limit of the damped interval
  real(real64) :: damp_upper                                 !< Estimated upper limit of the spectrum

  PUSH_SUB(X(chebyshev_filter_solver))

  ! First SCF step (when iterations are requested): apply the filter iteratively,
  ! starting from search vectors built from the initial guess
  if (scf_iter == 1 .and. filter_params%n_iter > 0) then
    call X(firstscf_iterative_chebyshev_filter)(namespace, sdiag, mesh, st, hm, ik, subspace_tol, filter_params)
    POP_SUB(X(chebyshev_filter_solver))
    return
  end if

  ! Otherwise the prior Ritz values define the filter: the largest one (plus a margin)
  ! is the lower bound of the damped interval, the smallest one is used for the scaling
  damp_lower = maxval(st%eigenval(:, ik)) + ritz_margin
  ritz_min = minval(st%eigenval(:, ik))
  damp_upper = X(upper_bound_estimator)(namespace, mesh, st, hm, ik, filter_params%n_lanczos)
  filter_bounds => chebyshev_filter_bounds_t(damp_lower, damp_upper, a_l=ritz_min, safe_zero=.true.)

  ! One entry per block over all blocks (not only those local to this process),
  ! which keeps printing MPI-safe. A local-only array would need
  ! 1:st%group%block_end - st%group%block_start + 1
  SAFE_ALLOCATE(block_degree(1:st%group%nblocks))
  if (filter_params%optimize_degree) then
    block_degree = 0
    call optimal_chebyshev_polynomial_degree(namespace, st, ik, prior_residuals, subspace_tol, filter_bounds, &
      filter_params%degree, block_degree)
  else
    block_degree = filter_params%degree
  end if

  call X(chebyshev_filter)(namespace, mesh, st, hm, block_degree, filter_bounds, ik)

  SAFE_DEALLOCATE_A(block_degree)
  SAFE_DEALLOCATE_P(filter_bounds)

  POP_SUB(X(chebyshev_filter_solver))
end subroutine X(chebyshev_filter_solver)


!> @brief Iterative application of Chebyshev filter, for use with the first SCF step.
!!
!! The initial search vectors are found using the initial guess for the density (sum of atomic
!! densities, LCAO, etc.), and are returned in st%group%psib.
subroutine X(firstscf_iterative_chebyshev_filter)(namespace, sdiag, mesh, st, hm, ik, tolerance, filter_params)
  type(namespace_t),        intent(in)    :: namespace      !< Calling namespace
  type(subspace_t),         intent(in)    :: sdiag          !< Subspace diagonalisation choice
  class(mesh_t),            intent(in)    :: mesh           !< Real-space mesh
  type(states_elec_t),      intent(inout) :: st             !< Initial guess at search vectors, and much more
  type(hamiltonian_elec_t), intent(in)    :: hm             !< Hamiltonian
  integer,                  intent(in)    :: ik             !< k-point index
  real(real64),             intent(in)    :: tolerance      !< Subspace iterative solver tolerance
  type(eigen_chebyshev_t),  intent(in)    :: filter_params  !< Chebyshev filter parameters

  real(real64), allocatable :: e_diff(:)                    !< Eigenvalue differences from the subspace diagonalisation
  logical, allocatable :: converged(:)                      !< Has each eigenvalue converged
  R_TYPE,  allocatable :: search_vector(:, :)               !< Random search vector for the initial bounds estimate
  integer, allocatable :: degree(:)                         !< Filter degree, per block
  type(chebyshev_filter_bounds_t), pointer :: bounds        !< Filter bounds
  real(real64) :: upper_bound                               !< Upper bound
  real(real64) :: lower_bound                               !< Lower bound
  logical :: all_converged                                  !< Every state converged before n_iter was reached
  integer :: ifilter, ist                                   !< Loop indices

  PUSH_SUB(X(firstscf_iterative_chebyshev_filter))

  SAFE_ALLOCATE(search_vector(1:mesh%np_part, 1:hm%d%dim))
  ! We reset the seed, such that the result is the same for each k-point
  ! This makes the result independent of the k-point parallelization
  call states_elec_generate_random_vector(mesh, st, search_vector, normalized=.true., reset_seed=.true.)

  ! Initial bounds estimate
  call X(filter_bounds_estimator)(namespace, mesh, hm, ik, filter_params%n_lanczos, &
    filter_params%bound_mixing, search_vector, bounds)
  SAFE_DEALLOCATE_A(search_vector)
  upper_bound = bounds%upper

  ! Allocate and initialise the convergence data explicitly (tracked by SAFE_ALLOCATE),
  ! so that it is well defined even if the filter loop body never executes
  SAFE_ALLOCATE(e_diff(1:st%nst))
  SAFE_ALLOCATE(converged(1:st%nst))
  e_diff = 0.0_real64
  converged = .false.

  ! If allocation were local, degree would have size (1:st%group%block_end - st%group%block_start + 1)
  SAFE_ALLOCATE(degree(1:st%group%nblocks))
  degree = 0

  all_converged = .false.

  do ifilter = 1, filter_params%n_iter
    write(message(1), '(a, I3, a, I3)') &
      "Debug: Chebyshev 1st iterative step. Filter loop  ", ifilter, '/', filter_params%n_iter
    call messages_info(1, namespace=namespace, debug_only=.true.)

    ! Diagonalise subspace to get Ritz pairs
    call X(subspace_diag)(sdiag, namespace, mesh, st, hm, ik, st%eigenval(:, ik), e_diff, nonortho = ifilter > 1)

    converged = e_diff < tolerance
    if (all(converged)) then
      all_converged = .true.
      exit
    endif

    SAFE_DEALLOCATE_P(bounds)
    lower_bound = maxval(st%eigenval(:, ik))

    ! If the largest Ritz value becomes larger than max eigenvalue of the spectrum
    ! this implies that the estimate for the upper bound needs refining
    if (lower_bound >= upper_bound) then
      upper_bound = X(upper_bound_estimator)(namespace, mesh, st, hm, ik, filter_params%n_lanczos)
    endif

    bounds => chebyshev_filter_bounds_t(lower_bound, &  ! Largest Ritz value
      upper_bound, &                                    ! Current estimate (refined above if needed)
      a_l=minval(st%eigenval(:, ik)), &                 ! Smallest Ritz value
      safe_zero=.true.)

    if (filter_params%optimize_degree) then
      call optimal_chebyshev_polynomial_degree(namespace, st, ik, e_diff, tolerance, bounds, &
        filter_params%degree, degree)
    else
      degree = filter_params%degree
    endif

    call X(chebyshev_filter)(namespace, mesh, st, hm, degree, bounds, ik)
  enddo

  ! Per-state convergence report, only when the iteration budget was exhausted
  if (.not. all_converged) then
    write(message(1), '(a,1x,i7)') 'Debug: Chebyshev 1st iterative step state convergence for ik = ', ik
    call messages_info(1, namespace=namespace, debug_only=.true.)

    do ist = 1, size(converged)
      write(message(1), '(i8,1x,2(f16.12,1x),l)') ist, st%eigenval(ist, ik), e_diff(ist), converged(ist)
      call messages_info(1, namespace=namespace, debug_only=.true.)
    enddo
  end if

  ! Single cleanup path: bounds is associated here both after convergence (exit) and
  ! after the last filter application
  SAFE_DEALLOCATE_A(converged)
  SAFE_DEALLOCATE_A(degree)
  SAFE_DEALLOCATE_A(e_diff)
  SAFE_DEALLOCATE_P(bounds)
  POP_SUB(X(firstscf_iterative_chebyshev_filter))
end subroutine X(firstscf_iterative_chebyshev_filter)


!> @brief Chebyshev Filter.
!!
!! Filter an eigenspectrum by an m-degree Chebyshev polynomial that dampens on the interval [a,b].
!! Based on algorithm 3.2 of "Chebyshev-filtered subspace iteration method free of sparse diagonalization
!! for solving the Kohn–Sham equation". http://dx.doi.org/10.1016/j.jcp.2014.06.056
!!
!! Application of the simple or scaled filter depends upon the choice of bounds.
!! In the simple case, sigma will reduce to 1.
subroutine X(chebyshev_filter) (namespace, mesh, st, hm, degree, bounds, ik, normalize)
  type(namespace_t),               intent(in)    :: namespace  !< Calling namespace
  class(mesh_t),                   intent(in)    :: mesh       !< Real-space mesh
  class(hamiltonian_elec_t),       intent(in)    :: hm         !< Hamiltonian
  integer,                         intent(in)    :: degree(:)  !< Chebyshev polynomial degree, per block
  !                                                            !< (filter applied using recursive definition).
  !                                                            !< A degree < 1 leaves the block unfiltered.
  type(chebyshev_filter_bounds_t), intent(in)    :: bounds     !< Polynomial filter bounds of the subspace to dampen
  integer,                         intent(in)    :: ik         !< k-point index
  type(states_elec_t), target,  intent(inout)    :: st         !< KS containing input eigenvectors and
  !                                                            !< returned with updated eigenvectors
  logical, optional,               intent(in)    :: normalize  !< By default, the outcome is normalized

  real(real64) :: c                                            !< Centre of the filter limits
  real(real64) :: hw                                           !< Half-width of the filter limits
  real(real64) :: sigma_1                                      !< Initial scaling factor, from the bounds
  real(real64) :: sigma, sigma_new, tau                        !< Filter scaling
  logical :: do_normalize                                      !< Normalize each block after filtering
  type(wfs_elec_t), pointer :: Y_old, Y, Y_new                 !< Pointers, used to avoid memory transfer
  integer :: idegree, iblock                                   !< Loop indices
  type(batch_pointer_t) :: physical_memory(3)                  !< Physical memory
  integer :: iperm(3)                                          !< Indices for the circular permutation of the pointers

  PUSH_SUB(X(chebyshev_filter))

  call profiling_in(TOSTRING(X(CHEBY)))

  hw = bounds%half_width()
  c = bounds%center()
  ! The initial scaling only depends on the bounds, so it is evaluated once for all blocks
  sigma_1 = bounds%sigma()
  ! Tau is not updated w.r.t. application of the filter
  tau = M_TWO / sigma_1
  ! For the composition unit test, normalization is deactivated
  do_normalize = optional_default(normalize, .true.)

  SAFE_ALLOCATE(physical_memory(2)%batch)
  SAFE_ALLOCATE(physical_memory(3)%batch)

  ! Block index
  do iblock = st%group%block_start, st%group%block_end

    ! A degree-0 Chebyshev polynomial is a constant, so the filter reduces to the identity.
    ! Without this guard, Y_new would be dereferenced below without ever being associated.
    if (degree(iblock) < 1) then
      if (do_normalize) then
        call X(mesh_batch_normalize)(mesh, st%group%psib(iblock, ik))
      end if
      cycle
    end if

    sigma = sigma_1

    physical_memory(1)%batch => st%group%psib(iblock, ik)
    call hm%phase%set_phase_corr(mesh, st%group%psib(iblock, ik))
    call physical_memory(1)%batch%copy_to(physical_memory(2)%batch)
    call physical_memory(1)%batch%copy_to(physical_memory(3)%batch)

    ! Define Y = (H X - c X) * (sigma / hw)  in stages:
    ! Y = H X
    call X(hamiltonian_elec_apply_batch)(hm, namespace, mesh, physical_memory(1)%batch, physical_memory(2)%batch)
    ! Y -> (Y - cX)
    call batch_axpy(mesh%np, -c, physical_memory(1)%batch, physical_memory(2)%batch)
    ! Y -> Y * sigma / hw
    call batch_scal(mesh%np, sigma / hw, physical_memory(2)%batch)

    ! If the filter is only applied once, update Y_new accordingly
    if (degree(iblock) == 1) then
      Y_new => physical_memory(2)%batch
    endif

    do idegree = 2, degree(iblock)
      sigma_new = M_ONE / (tau - sigma)

      ! Having st%group%psib(iblock, ik), Yb, and HY been allocated, we have enough memory to
      ! avoid memory transfer. To do this, we employ pointers that we swap during the execution
      ! In order to avoid memory transfer, we do a circular permutation of the pointers
      iperm = [mod(idegree - 2, 3) + 1, mod(idegree - 1, 3) + 1, mod(idegree, 3) + 1]
      Y_old => physical_memory(iperm(1))%batch
      Y     => physical_memory(iperm(2))%batch
      Y_new => physical_memory(iperm(3))%batch

      ! Construct Y_new = (HY - c * Y) * (2 * sigma_new / hw) - (sigma * sigma_new) * Y_old
      ! in stages:
      ! HY = H Y
      call X(hamiltonian_elec_apply_batch)(hm, namespace, mesh, Y, Y_new)
      ! HY -> (HY - cY)
      call batch_axpy(mesh%np, -c, Y, Y_new)
      ! HY -> (2 * sigma_new / hw) * HY
      call batch_scal(mesh%np, M_TWO * sigma_new / hw, Y_new)
      ! HY -> HY - (sigma * sigma_new) * Y_old
      call batch_axpy(mesh%np, -sigma * sigma_new, Y_old, Y_new)

      sigma = sigma_new
    end do

    ! Copy the final data. When the last term was written in place into
    ! st%group%psib(iblock, ik) (physical_memory(1)), there is nothing to copy.
    if (.not. associated(Y_new, physical_memory(1)%batch)) then
      call Y_new%copy_data_to(mesh%np, st%group%psib(iblock, ik))
    end if
    nullify(Y_old, Y, Y_new)
    nullify(physical_memory(1)%batch)
    call physical_memory(2)%batch%end()
    call physical_memory(3)%batch%end()

    ! Remove the phase correction
    call hm%phase%unset_phase_corr(mesh, st%group%psib(iblock, ik))

    ! Normalize each state to avoid numerical problems in the subspace diagonalization
    if (do_normalize) then
      call X(mesh_batch_normalize)(mesh, st%group%psib(iblock, ik))
    end if
  end do

  SAFE_DEALLOCATE_P(physical_memory(2)%batch)
  SAFE_DEALLOCATE_P(physical_memory(3)%batch)

  call profiling_out(TOSTRING(X(CHEBY)))

  POP_SUB(X(chebyshev_filter))
end subroutine X(chebyshev_filter)

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
