!! Copyright (C) 2002-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!


!>@brief Driver for the diagonalisation of the Hamiltonian in the subspace
!! spanned by the states.
!!
!! The work is dispatched according to this%method:
!!  - STANDARD:  forwarded to X(subspace_diag_standard), together with the
!!               (possibly absent) nonortho flag.
!!  - SCALAPACK: the local states are copied into a work array, handed to
!!               X(subspace_diag_scalapack) and copied back afterwards.
!!  - NONE:      no diagonalisation is performed.
!! For SCALAPACK and NONE, states flagged as non-orthonormal on entry are
!! first passed through the full orthogonalisation.
!!
!! On exit, the states are always orthonormal.
!!
!! If the states are distributed (st%parallel_in_states), diff is gathered
!! before returning.
subroutine X(subspace_diag)(this, namespace, mesh, st, hm, ik, eigenval, diff, nonortho)
  type(subspace_t),            intent(in)    :: this
  type(namespace_t),           intent(in)    :: namespace
  class(mesh_t),               intent(in)    :: mesh
  type(states_elec_t), target, intent(inout) :: st
  type(hamiltonian_elec_t),    intent(in)    :: hm
  integer,                     intent(in)    :: ik
  real(real64), contiguous,           intent(out)   :: eigenval(:)
  real(real64), contiguous,           intent(out)   :: diff(:)
  logical, optional,           intent(in)    :: nonortho

  integer :: jst
  logical :: input_nonorthonormal
  R_TYPE, allocatable :: wfs(:, :, :)

  PUSH_SUB(X(subspace_diag))
  call profiling_in(TOSTRING(X(SUBSPACE_DIAG)))

  ! An absent flag means that the incoming states are already orthonormal
  input_nonorthonormal = optional_default(nonortho, .false.)

  select case (this%method)

  case (OPTION__SUBSPACEDIAGONALIZATION__NONE)
    ! nothing to diagonalise; only restore orthonormality if it is missing
    if (input_nonorthonormal) then
      call X(states_elec_orthogonalization_full)(st, namespace, mesh, ik)
    end if

  case (OPTION__SUBSPACEDIAGONALIZATION__STANDARD)
    ! the standard routine deals with non-orthonormal states on its own
    call X(subspace_diag_standard)(namespace, mesh, st, hm, ik, eigenval, diff, nonortho)

  case (OPTION__SUBSPACEDIAGONALIZATION__SCALAPACK)

    if (input_nonorthonormal) then
      call X(states_elec_orthogonalization_full)(st, namespace, mesh, ik)
    end if

    ! The upper bound is clamped to 1 so that the array keeps a valid extent
    ! even when st%st_end is smaller than 1.
    SAFE_ALLOCATE(wfs(1:mesh%np_part, 1:st%d%dim, st%st_start:max(st%st_end,1)))

    ! copy the local states into the work array
    do jst = st%st_start, st%st_end
      call states_elec_get_state(st, mesh, jst, ik, wfs(:, :, jst))
    end do

    call X(subspace_diag_scalapack)(namespace, mesh, st, hm, ik, eigenval, wfs, diff)

    ! store the rotated states back
    do jst = st%st_start, st%st_end
      call states_elec_set_state(st, mesh, jst, ik, wfs(:, :, jst))
    end do

    SAFE_DEALLOCATE_A(wfs)

  case default
    ASSERT(.false.)

  end select

  if (st%parallel_in_states) then
    call states_elec_parallel_gather(st, diff)
  end if

  call profiling_out(TOSTRING(X(SUBSPACE_DIAG)))
  POP_SUB(X(subspace_diag))
end subroutine X(subspace_diag)

! ---------------------------------------------------------
!>@brief Diagonalises the Hamiltonian in the subspace defined by the states.
!!
!! On exit, the states are always orthonormal. On entry, they may be
!! non-orthonormal: in that case the overlap matrix is computed alongside the
!! subspace Hamiltonian and a generalized eigenvalue problem is solved. This is
!! useful because it avoids a separate Cholesky orthogonalisation of the states
!! beforehand.
!!
!! After the rotation of the states, the residue of every state handled by
!! this process is recomputed and returned in diff.
subroutine X(subspace_diag_standard)(namespace, mesh, st, hm, ik, eigenval, diff, nonortho)
  type(namespace_t),           intent(in)    :: namespace
  class(mesh_t),               intent(in)    :: mesh
  type(states_elec_t), target, intent(inout) :: st          !< Orthogonalised eigenstates
  type(hamiltonian_elec_t),    intent(in)    :: hm
  integer,                     intent(in)    :: ik
  real(real64), contiguous,           intent(out)   :: eigenval(:) !< Eigenvalues
  real(real64),                intent(out)   :: diff(:)     !< Residue
  logical, optional,           intent(in)    :: nonortho    !< if yes, eigenvectors are not orthonormal on entry

  !> Diagonal entries with a modulus not above this value are left untouched
  real(real64), parameter :: phase_threshold = 1e-16_real64

  R_TYPE, allocatable :: hmss(:, :) !< Subspace Hamiltonian on input of the solver, eigenvectors afterwards
  R_TYPE, allocatable :: smat(:, :) !< Overlap matrix of the states (generalized problem only)
  R_TYPE, allocatable :: res2(:)    !< Squared norm of the residual vectors
  R_TYPE              :: diag_phase
  type(wfs_elec_t)    :: hpsib
  integer             :: ib, ist, first_st, last_st
  logical             :: generalized

  PUSH_SUB(X(subspace_diag_standard))

  SAFE_ALLOCATE(hmss(1:st%nst, 1:st%nst))

  generalized = optional_default(nonortho, .false.)

  ! Diagonalize the Hamiltonian in the subspace.
  if (.not. generalized) then
    call X(subspace_diag_hamiltonian)(namespace, mesh, st, hm, ik, hmss)
    call lalg_eigensolve(st%nst, hmss, eigenval)
  else
    ! Non-orthonormal states: the orthogonalization is folded into the
    ! diagonalization by solving H c = e S c with the overlap matrix S.
    SAFE_ALLOCATE(smat(1:st%nst, 1:st%nst))
    call X(subspace_diag_hamiltonian)(namespace, mesh, st, hm, ik, hmss, smat)
    call lalg_geneigensolve(st%nst, hmss, smat, eigenval, preserve_mat=.false.)
    SAFE_DEALLOCATE_A(smat)
  end if

  ! Divide every eigenvector by the phase of its diagonal element.
  do ist = 1, st%nst
    diag_phase = hmss(ist, ist)
    if (.not. (abs(diag_phase) > phase_threshold)) cycle
    diag_phase = diag_phase / abs(diag_phase)
    call lalg_scal(st%nst, M_ONE/diag_phase, hmss(:, ist))
  end do

  ! The new eigenfunctions are linear combinations of the old ones.
  call states_elec_rotate(st, namespace, mesh, hmss, ik)

  ! Recalculate the residues.
  call profiling_in(TOSTRING(X(SUBSPACE_DIFF)))

  SAFE_ALLOCATE(res2(1:st%nst))
  res2 = R_TOTYPE(M_ZERO)

  do ib = st%group%block_start, st%group%block_end

    first_st = states_elec_block_min(st, ib)
    last_st  = states_elec_block_max(st, ib)

    if (hm%apply_packed()) then
      call st%group%psib(ib, ik)%do_pack()
    end if

    call st%group%psib(ib, ik)%copy_to(hpsib)

    ! hpsib <- H psi - eigenval psi, followed by the local part of its squared norm
    call X(hamiltonian_elec_apply_batch)(hm, namespace, mesh, st%group%psib(ib, ik), hpsib)
    call batch_axpy(mesh%np, -eigenval, st%group%psib(ib, ik), hpsib)
    call X(mesh_batch_dotp_vector)(mesh, hpsib, hpsib, res2(first_st:last_st), reduce = .false.)

    call hpsib%end()

    if (hm%apply_packed()) then
      call st%group%psib(ib, ik)%do_unpack(copy = .false.)
    end if

  end do

  ! the dot products above were not reduced, so do it once for all states
  call mesh%allreduce(res2)
  diff(1:st%nst) = sqrt(abs(res2))

  SAFE_DEALLOCATE_A(res2)

  call profiling_out(TOSTRING(X(SUBSPACE_DIFF)))

  SAFE_DEALLOCATE_A(hmss)

  POP_SUB(X(subspace_diag_standard))

end subroutine X(subspace_diag_standard)

! ---------------------------------------------------------
!>@brief Diagonalises the Hamiltonian in the subspace defined by the states
!! using distributed matrices; this version is aware of parallelization in
!! states but consumes more memory.
!!
!! The routine
!!  1. applies the Hamiltonian to the local states,
!!  2. builds the distributed subspace matrix <psi|H|psi> with pblas_gemm,
!!  3. diagonalises it with ELPA if available, otherwise with the ScaLAPACK
!!     routines pzheev (complex) or pdsyev (real),
!!  4. rotates the states with the eigenvectors (second pblas_gemm),
!!  5. recomputes the residue of every local state.
!!
!! Without HAVE_SCALAPACK the body is empty and no argument is modified.
subroutine X(subspace_diag_scalapack)(namespace, mesh, st, hm, ik, eigenval, psi, diff)
  type(namespace_t),        intent(in)    :: namespace
  class(mesh_t),            intent(in)    :: mesh
  type(states_elec_t),      intent(inout) :: st
  type(hamiltonian_elec_t), intent(in)    :: hm
  integer,                  intent(in)    :: ik
  real(real64),             intent(out)   :: eigenval(:)             !< Eigenvalues of the subspace matrix
  R_TYPE, contiguous,       intent(inout) :: psi(:, :, st%st_start:) !< Local states; rotated on exit
  real(real64),             intent(out)   :: diff(:)                 !< Residues, set for the local states only

#ifdef HAVE_SCALAPACK
  R_TYPE, allocatable :: hs(:, :), hpsi(:, :, :), evectors(:, :)
  integer             :: ist, size
  integer :: psi_block(1:2), total_np, psi_desc(BLACS_DLEN), hs_desc(BLACS_DLEN), info
  integer :: nbl, nrow, ncol
  type(wfs_elec_t) :: psib, hpsib
#ifdef HAVE_ELPA
  class(elpa_t), pointer :: elpa
#else
  integer :: lwork
  R_TYPE :: rttmp
  R_TYPE, allocatable :: work(:)
#ifdef R_TCOMPLEX
  integer :: lrwork
  complex(real64), allocatable :: rwork(:)
  complex(real64) :: ftmp
#endif
#endif

  PUSH_SUB(X(subspace_diag_scalapack))

  SAFE_ALLOCATE(hpsi(1:mesh%np_part, 1:st%d%dim, st%st_start:max(st%st_end,1)))

  call states_elec_parallel_blacs_blocksize(st, namespace, mesh, psi_block, total_np)

  ! descriptor of the distributed (points x states) matrix holding psi and hpsi
  call descinit(psi_desc(1), total_np, st%nst, psi_block(1), psi_block(2), 0, 0,  st%dom_st_proc_grid%context, &
    st%d%dim*mesh%np_part, info)

  if (info /= 0) then
    write(message(1), '(a,i6)') "subspace diagonalization descinit for psi failed with error code ", info
    call messages_fatal(1, namespace=namespace)
  end if

  ! select the blocksize, we use the division used for state
  ! parallelization but with a maximum of 64
  nbl = min(64, psi_block(2))

  ! calculate the size of the matrix in each node
  nrow = max(1, numroc(st%nst, nbl, st%dom_st_proc_grid%myrow, 0, st%dom_st_proc_grid%nprow))
  ncol = max(1, numroc(st%nst, nbl, st%dom_st_proc_grid%mycol, 0, st%dom_st_proc_grid%npcol))

  SAFE_ALLOCATE(hs(1:nrow, 1:ncol))

  ! descriptor of the distributed (states x states) subspace matrix
  call descinit(hs_desc(1), st%nst, st%nst, nbl, nbl, 0, 0, st%dom_st_proc_grid%context, nrow, info)

  if (info /= 0) then
    write(message(1), '(a,i6)') "subspace diagonalization descinit for Hamiltonian failed with error code ", info
    call messages_fatal(1, namespace=namespace)
  end if

  ! calculate |hpsi> = H |psi>
  do ist = st%st_start, st%st_end, st%block_size
    ! the last block may contain fewer than st%block_size states
    size = min(st%block_size, st%st_end - ist + 1)

    call wfs_elec_init(psib, hm%d%dim, ist, ist + size - 1, psi(:, :, ist:), ik)
    call wfs_elec_init(hpsib, hm%d%dim, ist, ist + size - 1, hpsi(: , :, ist:), ik)

    call X(hamiltonian_elec_apply_batch)(hm, namespace, mesh, psib, hpsib)

    call psib%end()
    call hpsib%end()
  end do

  ! We need to set to zero some extra parts of the array
  if (st%d%dim == 1) then
    psi(mesh%np + 1:psi_block(1), 1:st%d%dim, st%st_start:max(st%st_end,1)) = M_ZERO
    hpsi(mesh%np + 1:psi_block(1), 1:st%d%dim, st%st_start:max(st%st_end,1)) = M_ZERO
  else
    psi(mesh%np + 1:mesh%np_part, 1:st%d%dim, st%st_start:max(st%st_end,1)) = M_ZERO
    hpsi(mesh%np + 1:mesh%np_part, 1:st%d%dim, st%st_start:max(st%st_end,1)) = M_ZERO
  end if

  call profiling_in(TOSTRING(X(SCALAPACK_GEMM1)))

  ! get the matrix <psi|H|psi> = <psi|hpsi>
  call pblas_gemm('c', 'n', st%nst, st%nst, total_np, &
    R_TOTYPE(mesh%vol_pp(1)), psi(1, 1, st%st_start), 1, 1, psi_desc(1), &
    hpsi(1, 1, st%st_start), 1, 1, psi_desc(1), &
    R_TOTYPE(M_ZERO), hs(1, 1), 1, 1, hs_desc(1))

  SAFE_ALLOCATE(evectors(1:nrow, 1:ncol))
  call profiling_out(TOSTRING(X(SCALAPACK_GEMM1)))

  call profiling_in(TOSTRING(X(SCALAPACK_DIAG)))

  ! now diagonalize
#ifdef HAVE_ELPA
  if (elpa_init(20170403) /= elpa_ok) then
    write(message(1),'(a)') "ELPA API version not supported"
    call messages_fatal(1, namespace=namespace)
  end if
  elpa => elpa_allocate()

  ! set parameters describing the matrix
  call elpa%set("na", st%nst, info)
  call elpa%set("nev", st%nst, info)
  call elpa%set("local_nrows", nrow, info)
  call elpa%set("local_ncols", ncol, info)
  call elpa%set("nblk", nbl, info)
  call elpa%set("mpi_comm_parent", st%dom_st_mpi_grp%comm%MPI_VAL, info)
  call elpa%set("process_row", st%dom_st_proc_grid%myrow, info)
  call elpa%set("process_col", st%dom_st_proc_grid%mycol, info)

  info = elpa%setup()

  ! the return code of the setup would otherwise be overwritten by the next call
  if (info /= elpa_ok) then
    write(message(1),'(a,i6,a,a)') "Error in ELPA setup, code: ", info, ", message: ", &
      elpa_strerr(info)
    call messages_fatal(1, namespace=namespace)
  end if

  ! one stage solver usually shows worse performance than two stage solver
  call elpa%set("solver", elpa_solver_2stage, info)

  ! call eigensolver
  call elpa%eigenvectors(hs, eigenval, evectors, info)

  ! error handling
  if (info /= elpa_ok) then
    write(message(1),'(a,i6,a,a)') "Error in ELPA, code: ", info, ", message: ", &
      elpa_strerr(info)
    call messages_fatal(1, namespace=namespace)
  end if

  call elpa_deallocate(elpa)
  call elpa_uninit()

#else
! Use ScaLAPACK function if ELPA not available

#ifdef R_TCOMPLEX

  ! workspace query: lwork = lrwork = -1 returns the required sizes in rttmp and ftmp
  call pzheev(jobz = 'V', uplo = 'U', n = st%nst, a = hs(1, 1) , ia = 1, ja = 1, desca = hs_desc(1), &
    w = eigenval(1), z = evectors(1, 1), iz = 1, jz = 1, descz = hs_desc(1), &
    work = rttmp, lwork = -1, rwork = ftmp, lrwork = -1, info = info)

  if (info /= 0) then
    write(message(1),'(a,i6)') "ScaLAPACK pzheev workspace query failure, error code = ", info
    call messages_fatal(1, namespace=namespace)
  end if

  lwork = nint(abs(rttmp))
  lrwork = nint(real(ftmp, real64) )

  SAFE_ALLOCATE(work(1:lwork))
  SAFE_ALLOCATE(rwork(1:lrwork))

  if (st%nst == 1) then
    ! pzheev from scalapack seems to return wrong eigenvectors for one state,
    ! so we do not call it in this case.
    eigenval(1) = real(hs(1, 1), real64)
    evectors(1, 1) = R_TOTYPE(M_ONE)
  else
    call pzheev(jobz = 'V', uplo = 'U', n = st%nst, a = hs(1, 1) , ia = 1, ja = 1, desca = hs_desc(1), &
      w = eigenval(1), z = evectors(1, 1), iz = 1, jz = 1, descz = hs_desc(1), &
      work = work(1), lwork = lwork, rwork = rwork(1), lrwork = lrwork, info = info)
  end if

  if (info /= 0) then
    write(message(1),'(a,i6)') "ScaLAPACK pzheev call failure, error code = ", info
    call messages_fatal(1, namespace=namespace)
  end if

  SAFE_DEALLOCATE_A(work)
  SAFE_DEALLOCATE_A(rwork)

#else

  ! workspace query: lwork = -1 returns the required size in rttmp
  call pdsyev(jobz = 'V', uplo = 'U', n = st%nst, a = hs(1, 1) , ia = 1, ja = 1, desca = hs_desc(1), &
    w = eigenval(1), z = evectors(1, 1), iz = 1, jz = 1, descz = hs_desc(1), work = rttmp, lwork = -1, info = info)

  if (info /= 0) then
    write(message(1),'(a,i6)') "ScaLAPACK pdsyev workspace query failure, error code = ", info
    call messages_fatal(1, namespace=namespace)
  end if

  lwork = nint(abs(rttmp))
  SAFE_ALLOCATE(work(1:lwork))

  call pdsyev(jobz = 'V', uplo = 'U', n = st%nst, a = hs(1, 1) , ia = 1, ja = 1, desca = hs_desc(1), &
    w = eigenval(1), z = evectors(1, 1), iz = 1, jz = 1, descz = hs_desc(1), work = work(1), lwork = lwork, info = info)

  if (info /= 0) then
    write(message(1),'(a,i6)') "ScaLAPACK pdsyev call failure, error code = ", info
    call messages_fatal(1, namespace=namespace)
  end if

  SAFE_DEALLOCATE_A(work)
#endif

#endif
!(HAVE_ELPA)

  call profiling_out(TOSTRING(X(SCALAPACK_DIAG)))

  ! the subspace matrix is not needed any more, only its eigenvectors are
  SAFE_DEALLOCATE_A(hs)

  ! keep a copy of the old states in hpsi: psi is overwritten by the rotation below
  do ist = st%st_start, st%st_end
    call lalg_copy(mesh%np, st%d%dim, psi(:,:,ist), hpsi(:,:,ist))
  end do

  ! psi <- (old psi) * evectors
  call profiling_in(TOSTRING(X(SCALAPACK_GEMM2)))
  call pblas_gemm('n', 'n', total_np, st%nst, st%nst, &
    R_TOTYPE(M_ONE), hpsi(1, 1, st%st_start), 1, 1, psi_desc(1), &
    evectors(1, 1), 1, 1, hs_desc(1), &
    R_TOTYPE(M_ZERO), psi(1, 1, st%st_start), 1, 1, psi_desc(1))
  call profiling_out(TOSTRING(X(SCALAPACK_GEMM2)))

  ! Recalculate the residues.
  ! The slot hpsi(:, :, st%st_start) is reused as scratch space for H|psi> of every state.
  do ist = st%st_start, st%st_end
    call X(hamiltonian_elec_apply_single)(hm, namespace, mesh, psi(:, :, ist) , hpsi(:, :, st%st_start), ist, ik)
    diff(ist) = X(states_elec_residue)(mesh, st%d%dim, hpsi(:, :, st%st_start), eigenval(ist), psi(:, :, ist))
  end do

  SAFE_DEALLOCATE_A(hpsi)
  SAFE_DEALLOCATE_A(evectors)

  POP_SUB(X(subspace_diag_scalapack))

#endif /* SCALAPACK */
end subroutine X(subspace_diag_scalapack)

! ------------------------------------------------------
!>@brief Builds the Hamiltonian matrix (and optionally the overlap matrix) in
!! the subspace defined by the states. No diagonalisation is done here.
!!
!! The Hamiltonian is applied to every local block of states and the matrices
!! are accumulated over chunks of mesh points:
!!  - hmss    <- hmss    + w * hpsi * psi^H   (gemm, 'n', 'c')
!!  - overlap <- overlap + w * psi  * psi^H   (herk, upper triangle only)
!! where w is the integration weight of the points. Both matrices are reduced
!! over the mesh at the end and the overlap is completed to a full Hermitian
!! matrix.
!!
!! When the states are packed and the accelerator is enabled, the products are
!! evaluated on the device; otherwise BLAS is used with a chunk size derived
!! from the L2 cache size.
subroutine X(subspace_diag_hamiltonian)(namespace, mesh, st, hm, ik, hmss, overlap)
  type(namespace_t),            intent(in)    :: namespace
  class(mesh_t),                intent(in)    :: mesh
  type(states_elec_t), target,  intent(inout) :: st
  type(hamiltonian_elec_t),     intent(in)    :: hm
  integer,                      intent(in)    :: ik
  R_TYPE, contiguous,           intent(out)   :: hmss(:, :)    !< Subspace Hamiltonian matrix
  R_TYPE, contiguous, optional, intent(out)   :: overlap(:, :) !< Overlap matrix of the states

  integer       :: ib, ip
  real(real64)  :: sqrt_vol
  R_TYPE, allocatable :: psi(:, :, :), hpsi(:, :, :)
  class(wfs_elec_t), allocatable :: hpsib(:)
  integer :: sp, size, block_size
  type(accel_mem_t) :: psi_buffer, hpsi_buffer, hmss_buffer, overlap_buffer

  PUSH_SUB(X(subspace_diag_hamiltonian))
  call profiling_in(TOSTRING(X(SUBSPACE_HAMILTONIAN)))

  SAFE_ALLOCATE_TYPE_ARRAY(wfs_elec_t, hpsib, (st%group%block_start:st%group%block_end))

  ! the CPU branch accumulates into overlap, so it has to start from zero
  if (present(overlap)) then
    overlap = M_ZERO
  end if

  ! hpsib(ib) = H psib(ib) for all the local blocks
  do ib = st%group%block_start, st%group%block_end
    call st%group%psib(ib, ik)%copy_to(hpsib(ib))
    call X(hamiltonian_elec_apply_batch)(hm, namespace, mesh, st%group%psib(ib, ik), hpsib(ib))
  end do

  if (st%are_packed() .and. accel_is_enabled()) then

    ! NOTE(review): this branch only uses mesh%volume_element as weight; mesh%vol_pp is
    ! not applied here, so curvilinear meshes are presumably not supported on this path.

    ASSERT(ubound(hmss, dim = 1) == st%nst)

    call accel_create_buffer(hmss_buffer, ACCEL_MEM_READ_WRITE, R_TYPE_VAL, st%nst*st%nst)
    call accel_set_buffer_to_zero(hmss_buffer, R_TYPE_VAL, st%nst*st%nst)

    ! we have to copy the blocks to a temporary array
    block_size = batch_points_block_size()

    ! need conversion to i8 to avoid possible overflow
    call accel_create_buffer(psi_buffer, ACCEL_MEM_READ_WRITE, R_TYPE_VAL, int(st%nst, int64)*st%d%dim*block_size)
    call accel_create_buffer(hpsi_buffer, ACCEL_MEM_READ_WRITE, R_TYPE_VAL, int(st%nst, int64)*st%d%dim*block_size)

    if (present(overlap)) then
      call accel_create_buffer(overlap_buffer, ACCEL_MEM_READ_WRITE, R_TYPE_VAL, st%nst*st%nst)
      call accel_set_buffer_to_zero(overlap_buffer, R_TYPE_VAL, st%nst*st%nst)
    end if

    ! host copies are only needed to gather the states of the other processes
    if (st%parallel_in_states) then
      SAFE_ALLOCATE(psi(1:st%nst, 1:st%d%dim, 1:block_size))
      SAFE_ALLOCATE(hpsi(1:st%nst, 1:st%d%dim, 1:block_size))
    end if

    do sp = 1, mesh%np, block_size
      ! the last chunk may be shorter
      size = min(block_size, mesh%np - sp + 1)

      do ib = st%group%block_start, st%group%block_end
        ASSERT(R_TYPE_VAL == st%group%psib(ib, ik)%type())
        call batch_get_points(st%group%psib(ib, ik), sp, sp + size - 1, psi_buffer, st%nst, st%d%dim)
        call batch_get_points(hpsib(ib), sp, sp + size - 1, hpsi_buffer, st%nst, st%d%dim)
      end do

      if (st%parallel_in_states) then
        ! device -> host, gather over the states, host -> device
        call accel_read_buffer(psi_buffer, int(st%nst, int64)*st%d%dim*block_size, psi)
        call states_elec_parallel_gather(st, (/st%d%dim, size/), psi)
        call accel_write_buffer(psi_buffer, int(st%nst, int64)*st%d%dim*block_size, psi)
        call accel_read_buffer(hpsi_buffer, int(st%nst, int64)*st%d%dim*block_size, hpsi)
        call states_elec_parallel_gather(st, (/st%d%dim, size/), hpsi)
        call accel_write_buffer(hpsi_buffer, int(st%nst, int64)*st%d%dim*block_size, hpsi)
      end if

      ! hmss <- hmss + volume_element * hpsi * psi^H
      call X(accel_gemm)(transA = ACCEL_BLAS_N, &
        transB = ACCEL_BLAS_C,                  &
        M = int(st%nst, int64),                    &
        N = int(st%nst, int64),                    &
        K = int(size*st%d%dim, int64),             &
        alpha = R_TOTYPE(mesh%volume_element),  &
        A = hpsi_buffer,                        &
        offA = 0_int64,                            &
        lda = int(st%nst, int64),                  &
        B = psi_buffer,                         &
        offB = 0_int64,                            &
        ldb = int(st%nst, int64),                  &
        beta = R_TOTYPE(M_ONE),                 &
        C = hmss_buffer,                        &
        offC = 0_int64,                            &
        ldc = int(st%nst, int64))

      ! overlap <- overlap + volume_element * psi * psi^H (upper triangle)
      if (present(overlap)) then
        call X(accel_herk)(uplo = ACCEL_BLAS_UPPER, &
          trans = ACCEL_BLAS_N,                     &
          n = int(st%nst, int64),                      &
          k = int(size*st%d%dim, int64),               &
          alpha = mesh%volume_element,              &
          A = psi_buffer,                           &
          offa = 0_int64,                              &
          lda = int(st%nst, int64),                    &
          beta = M_ONE,                             &
          C = overlap_buffer,                       &
          offc = 0_int64,                              &
          ldc = int(st%nst, int64))
      end if

      call accel_finish()

    end do

    if (st%parallel_in_states) then
      SAFE_DEALLOCATE_A(psi)
      SAFE_DEALLOCATE_A(hpsi)
    end if

    call accel_release_buffer(psi_buffer)
    call accel_release_buffer(hpsi_buffer)

    call accel_read_buffer(hmss_buffer, st%nst*st%nst, hmss)
    call accel_release_buffer(hmss_buffer)
    if (present(overlap)) then
      call accel_read_buffer(overlap_buffer, st%nst*st%nst, overlap)
      call accel_release_buffer(overlap_buffer)
    end if

  else

    ! chunk size chosen from the L2 cache size, with a lower bound
#ifdef R_TREAL
    ! Assume 2 to be the memory_address_size
    block_size = max(40, cpu_hardware%l2%size/(2 * sizeof_real64 * st%nst))
#else
    block_size = max(20, cpu_hardware%l2%size/(2 * sizeof_complex64 * st%nst))
#endif

    hmss(1:st%nst, 1:st%nst) = M_ZERO

    SAFE_ALLOCATE(psi(1:st%nst, 1:st%d%dim, 1:block_size))
    SAFE_ALLOCATE(hpsi(1:st%nst, 1:st%d%dim, 1:block_size))

    do sp = 1, mesh%np, block_size
      ! the last chunk may be shorter
      size = min(block_size, mesh%np - sp + 1)

      do ib = st%group%block_start, st%group%block_end
        call batch_get_points(st%group%psib(ib, ik), sp, sp + size - 1, psi)
        call batch_get_points(hpsib(ib), sp, sp + size - 1, hpsi)
      end do

      if (st%parallel_in_states) then
        call states_elec_parallel_gather(st, (/st%d%dim, size/), psi)
        call states_elec_parallel_gather(st, (/st%d%dim, size/), hpsi)
      end if

      if (present(overlap)) then
        ! For curvilinear meshes the point volume vol_pp has to enter both matrices
        ! exactly once. The overlap is psi*psi^H, so psi may only carry sqrt(vol_pp)
        ! when herk is called; using the fully weighted psi there would weight the
        ! overlap with vol_pp**2 and make it inconsistent with hmss.
        if (mesh%use_curvilinear) then
          do ip = 1, size
            sqrt_vol = sqrt(mesh%vol_pp(sp + ip - 1))
            psi(1:st%nst, 1:st%d%dim, ip) = psi(1:st%nst, 1:st%d%dim, ip)*sqrt_vol
          end do
        end if

        ! overlap <- overlap + volume_element * psi * psi^H (upper triangle)
        call blas_herk(uplo = 'u',     &
          trans = 'n',                 &
          n = st%nst,                  &
          k = size*st%d%dim,           &
          alpha = mesh%volume_element, &
          a = psi(1, 1, 1),            &
          lda = ubound(psi, dim = 1),  &
          beta = M_ONE,                &
          c = overlap(1, 1),           &
          ldc = ubound(overlap, dim = 1))

        ! apply the remaining sqrt(vol_pp): psi now carries the full weight for hmss
        if (mesh%use_curvilinear) then
          do ip = 1, size
            sqrt_vol = sqrt(mesh%vol_pp(sp + ip - 1))
            psi(1:st%nst, 1:st%d%dim, ip) = psi(1:st%nst, 1:st%d%dim, ip)*sqrt_vol
          end do
        end if

      else if (mesh%use_curvilinear) then
        do ip = 1, size
          psi(1:st%nst, 1:st%d%dim, ip) = psi(1:st%nst, 1:st%d%dim, ip)*mesh%vol_pp(sp + ip - 1)
        end do
      end if

      ! hmss <- hmss + volume_element * hpsi * psi^H
      call blas_gemm(transa = 'n',             &
        transb = 'c',                          &
        m = st%nst,                            &
        n = st%nst,                            &
        k = size*st%d%dim,                     &
        alpha = R_TOTYPE(mesh%volume_element), &
        a = hpsi(1, 1, 1),                     &
        lda = ubound(hpsi, dim = 1),           &
        b = psi(1, 1, 1),                      &
        ldb = ubound(psi, dim = 1),            &
        beta = R_TOTYPE(M_ONE),                &
        c = hmss(1, 1),                        &
        ldc = ubound(hmss, dim = 1))

    end do

    SAFE_DEALLOCATE_A(psi)
    SAFE_DEALLOCATE_A(hpsi)

  end if

  call profiling_count_operations((R_ADD + R_MUL)*st%nst*(st%nst - M_ONE)*mesh%np)

  do ib = st%group%block_start, st%group%block_end
    call hpsib(ib)%end()
  end do

  SAFE_DEALLOCATE_A(hpsib)

  ! sum the contributions of all the mesh partitions
  call mesh%allreduce(hmss, dim = (/st%nst, st%nst/))

  if (present(overlap)) then

    call profiling_count_operations((R_ADD + R_MUL)*M_HALF*st%nst*st%d%dim*(st%nst - M_ONE)*mesh%np)

    call mesh%allreduce(overlap, dim = (/st%nst, st%nst/))

    ! herk only filled the upper triangle
    call upper_triangular_to_hermitian(st%nst, overlap)
  end if

  call profiling_out(TOSTRING(X(SUBSPACE_HAMILTONIAN)))
  POP_SUB(X(subspace_diag_hamiltonian))

end subroutine X(subspace_diag_hamiltonian)


!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
