!! Copyright (C) 2023 N. Tancogne-Dejean
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

!>@brief Perform the SCDM method and return a states_elec_t object
!! with localized states
!!
!! For every spin channel the routine
!!  1. obtains a column permutation jpvt from a rank-revealing QR
!!     decomposition of the states (zstates_elec_rrqr_decomposition),
!!  2. builds the matrix chi from the occupation-weighted, conjugated values
!!     of the states at the first st%nst selected columns,
!!  3. performs a singular value decomposition of chi and stores the product
!!     matmul(u, vt) of the two returned factors as the gauge matrix umnk,
!!  4. rotates the states with umnk, fixes the global phase such that the
!!     element of largest modulus is real, and stores the result in scdm_st.
!!
!! The routine is currently restricted to st%nik == 1 and to grids that are
!! not parallelized in domains (both are asserted).
!!
!! @param[in]  this       SCDM object (not used by the current implementation)
!! @param[in]  namespace  Namespace passed on to the RRQR decomposition
!! @param[in]  space      Space object; space%dim sizes the k-point vector
!! @param[in]  gr         Grid on which the states are defined
!! @param[in]  kpoints    K-point information
!! @param[in]  st         States to be localized
!! @param[out] scdm_st    Copy of st in which the states are replaced by the
!!                        localized ones
subroutine X(scdm_get_localized_states)(this, namespace, space, gr, kpoints, st, scdm_st)
  type(scdm_t),         intent(in)  :: this
  type(namespace_t),    intent(in)  :: namespace
  type(space_t),        intent(in)  :: space
  type(grid_t),         intent(in)  :: gr
  type(kpoints_t),      intent(in)  :: kpoints
  type(states_elec_t),  intent(in)  :: st
  type(states_elec_t),  intent(out) :: scdm_st

  integer, allocatable :: jpvt(:)            ! Column permutation from the RRQR decomposition
  R_TYPE, allocatable   :: umnk(:,:,:)       ! SCDM-Wannier gauge matrices U(k)
  R_TYPE, allocatable   :: chi(:,:), zwn(:,:)
  ! occ_temp, scdm_mu and kvec are only needed by the code that is still
  ! disabled (see the TODOs below); they are kept so it can be restored as is.
  real(real64), allocatable   :: occ_temp(:)
  real(real64)         :: scdm_mu, smear,  kvec(3), kpoint(space%dim)
  integer :: ik_gamma, ipmax, iw, iw2, jst, ip, idim, ispin, nspin, idimmax
  integer :: ik, ist, np_global
  real(real64) :: wmod, wmodmax
  R_TYPE, allocatable :: u(:, :), vt(:, :), psi(:,:)
  real(real64), allocatable :: sg_values(:)

  PUSH_SUB(X(scdm_get_localized_states))

  !TODO: finish the code for k-points
  ASSERT(st%nik == 1)
  ! The pivot indices go up to np_global and the search for the maximum of the
  ! localized function is purely local: domain parallelization is not supported.
  ASSERT(.not. gr%parallel_in_domains)

  call states_elec_copy(scdm_st, st)

  scdm_mu = maxval(st%eigenval)

  nspin = 1
  if (st%d%ispin == SPIN_POLARIZED) nspin = 2

  np_global = int(gr%np_global, int32)

  ! Only the orbitals at the Gamma point are used to find the column permutation
  do ispin = 1, nspin

    SAFE_ALLOCATE(jpvt(1:gr%np_global*st%d%dim))
    SAFE_ALLOCATE(occ_temp(1:st%nst))

    ik_gamma = ispin

    !TODO: Restore these lines
    !  ! smear the states at gamma
    !  do ist = 1, st%nst
    !    occ_temp(ist)= st%occ(ist, ik_gamma)
    !    st%occ(ist, ik_gamma)=M_HALF*loct_erfc((st%eigenval(ist, ik_gamma)-scdm_mu) / scdm_sigma)
    !  end do

    ! NOTE(review): the complex (z) variant is called irrespective of R_TYPE;
    ! confirm that this is intended for the real instantiation of this template.
    call zstates_elec_rrqr_decomposition(st, namespace, gr, st%nst, .true., ik_gamma, jpvt)

    !TODO: Restore these lines
    !  ! reset occupations at gamma
    !  do ist = 1, st%nst
    !    st%occ(ist, ik_gamma) = occ_temp(ist)
    !  end do

    SAFE_ALLOCATE(umnk(1:st%nst, 1:st%nst, 1:st%nik))

    ! auxiliary arrays for scdm procedure
    SAFE_ALLOCATE(chi(1:st%nst, 1:st%nst))
    SAFE_ALLOCATE(psi(1:gr%np, 1:st%d%dim))

    chi = M_ZERO
    do ik = 1, st%nik
      kvec(:) = kpoints%reduced%point(:, ik)

      if (st%d%get_spin_index(ik) /= ispin) cycle ! We treat each spin channel independently


      !TODO: Restore these lines
      !  if (st%d%ispin == SPIN_POLARIZED) then
      !    ik_real = (ik-1)*2 + st%d%spin_channels
      !  else
      !    ik_real = ik
      !  end if


      ! We now use the first J columns of the matrix with permuted columns P\Pi
      do ist = 1, st%nst
        call states_elec_get_state(st, gr, ist, ik, psi)
        smear = st%occ(ist, ik) !M_HALF * loct_erfc((st%eigenval(ist, ik) - scdm_mu) / scdm_sigma)

        do jst = 1, st%nst
          ! Convert the combined (grid point, spinor component) pivot index into
          ! separate indices. The index is 1-based with the grid point running
          ! fastest. Working with jpvt-1 keeps the last point of each component
          ! (jpvt a multiple of np_global) at ip = np_global instead of ip = 0.
          ip   = mod(jpvt(jst) - 1, np_global) + 1
          idim = (jpvt(jst) - 1) / np_global + 1

          !TODO: Restore the application of the phase. Check properly the sign
          chi(ist, jst) = smear * R_CONJ(psi(ip, idim))
!           * exp(M_zI * dot_product(gr%x(jpvt(jst), 1:3), kvec(1:3)))
        end do !jst
      end do !ist

      SAFE_ALLOCATE( u(1:st%nst, 1:st%nst))
      SAFE_ALLOCATE(vt(1:st%nst, 1:st%nst))
      SAFE_ALLOCATE(sg_values(1:st%nst))

      ! The gauge matrix is the product of the two factors of the SVD of chi;
      ! the singular values themselves are not used.
      call lalg_singular_value_decomp(st%nst, st%nst, chi, u, vt, sg_values)
      umnk(:,:,ik) = matmul(u, vt)

      SAFE_DEALLOCATE_A(u)
      SAFE_DEALLOCATE_A(vt)
      SAFE_DEALLOCATE_A(sg_values)

    end do ! ik

    SAFE_DEALLOCATE_A(chi)
    SAFE_DEALLOCATE_A(jpvt)
    SAFE_DEALLOCATE_A(occ_temp)

    !Computing the Wannier states in the primitive cell, from the U matrices
    SAFE_ALLOCATE(zwn(1:gr%np, 1:st%d%dim))
    !TODO: Uncomment. Do not forget to deallocate this array
    !  SAFE _ALLOCATE(phase(1:gr%np))

    do iw = 1, st%nst
      zwn(:,:) = M_Z0

      do ik = 1, st%nik

        if (st%d%get_spin_index(ik) /= ispin) cycle ! We treat each spin channel independently

        kpoint(1:space%dim) = kpoints%get_point(ik, absolute_coordinates=.true.)

        !TODO: Reactivate these lines
        !    ! We compute the Wannier orbital on a grid centered around the Wannier function
        ! The minus sign here is due to the wrong convention of Octopus
        !    do ip = 1, mesh%np
        !      xx = mesh%x(ip, 1:space%dim)-centers(1:space%dim, iw)
        !      xx = ions%latt%fold_into_cell(xx)
        !      phase(ip) = exp(-M_zI* sum( xx * kpoint(1:space%dim)))
        !    end do

        do iw2 = 1, st%nst
          call states_elec_get_state(st, gr, iw2, ik, psi)

          ! Construction of the orbitals from the U matrix
          do idim = 1, st%d%dim
            !$omp parallel do
            do ip = 1, gr%np
              !TODO: restore the phase application
              zwn(ip, idim) = zwn(ip, idim) + umnk(iw2, iw, ik) * psi(ip, idim) !* phase(ip)
            end do
          end do
        end do!iw2

        ! Following what Wannier90 is doing, we fix the global phase by setting the max to be real
        ! We also normalize to the number of k-points at this step
        ipmax = 0
        idimmax = 0
        wmodmax = M_ZERO
        do idim = 1, st%d%dim
          do ip = 1, gr%np
            wmod = real(zwn(ip, idim)*R_CONJ(zwn(ip, idim)), real64)
            if (wmod > wmodmax) then
              ipmax = ip
              idimmax = idim
              wmodmax = wmod
            end if
          end do
        end do
        ! A function that vanishes everywhere has no maximum to fix the phase
        ! with; stop here rather than indexing zwn(0, 0) and dividing by zero.
        ASSERT(ipmax > 0)

        ! We normalize by the number of k-points per spin channel
        call lalg_scal(gr%np, st%d%dim, sqrt(wmodmax)/zwn(ipmax, idimmax)/st%nik*nspin, zwn)


        call states_elec_set_state(scdm_st, gr, iw, ik, zwn)

#ifdef R_TCOMPLEX
        ! Checking the ratio imag/real, only where the real part is not negligible
        wmodmax = M_ZERO
        do idim = 1, st%d%dim
          do ip = 1, gr%np
            if(abs(real(zwn(ip, idim), real64)) >= 1e-2_real64) then
              wmodmax = max(wmodmax, abs(aimag(zwn(ip, idim)))/abs(real(zwn(ip, idim), real64)))
            end if
          end do
        end do

        write(message(1), '(a,i4,a,f11.6)') 'Wannier function ', iw, ' Max. Im/Re Ratio = ', wmodmax
        call messages_info(1, debug_only=.true.)
#endif
      end do !ik
    end do !iw

    SAFE_DEALLOCATE_A(psi)
    SAFE_DEALLOCATE_A(umnk)
    SAFE_DEALLOCATE_A(zwn)

  end do ! ispin

  POP_SUB(X(scdm_get_localized_states))
end subroutine X(scdm_get_localized_states)

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
