!! Copyright (C) 2002-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!
#include "global.h"

module states_elec_oct_m
  use accel_oct_m
  use batch_oct_m
  use batch_ops_oct_m
  use blacs_proc_grid_oct_m
  use boundaries_oct_m
  use batch_oct_m
  use batch_ops_oct_m
  use calc_mode_par_oct_m
  use comm_oct_m
  use debug_oct_m
  use derivatives_oct_m
  use distributed_oct_m
  use electron_space_oct_m
  use global_oct_m
  use grid_oct_m
  use io_oct_m
  use, intrinsic :: iso_fortran_env
  use kpoints_oct_m
  use lalg_basic_oct_m
  use loct_oct_m
  use math_oct_m
  use mesh_oct_m
  use mesh_batch_oct_m
  use mesh_function_oct_m
  use messages_oct_m
  use modelmb_particles_oct_m
  use mpi_oct_m
  use multicomm_oct_m
  use namespace_oct_m
#ifdef HAVE_OPENMP
  use omp_lib
#endif
  use parser_oct_m
  use profiling_oct_m
  use quickrnd_oct_m
  use restart_oct_m
  use smear_oct_m
  use space_oct_m
  use states_abst_oct_m
  use states_elec_group_oct_m
  use states_elec_dim_oct_m
  use types_oct_m
  use unit_oct_m
  use unit_system_oct_m
  use varinfo_oct_m
  use wfs_elec_oct_m

  implicit none

  private

  ! Entities exported by name. The module default is `private` (statement above), so
  ! anything not listed here is visible outside the module only if its own declaration
  ! carries the `public` attribute.
  public ::                           &
    states_elec_t,                         &
    states_elec_init,                      &
    states_elec_look,                      &
    states_elec_densities_init,            &
    states_elec_exec_init,                 &
    states_elec_allocate_wfns,             &
    states_elec_allocate_current,          &
    states_elec_deallocate_wfns,           &
    states_elec_null,                      &
    states_elec_end,                       &
    states_elec_copy,                      &
    states_elec_generate_random,           &
    states_elec_fermi,                     &
    states_elec_eigenvalues_sum,           &
    states_elec_calc_quantities,           &
    state_is_local,                        &
    state_kpt_is_local,                    &
    states_elec_choose_kpoints,            &
    states_elec_distribute_nodes,          &
    states_elec_wfns_memory,               &
    states_elec_get_state,                 &
    states_elec_set_state,                 &
    states_elec_get_points,                &
    states_elec_block_min,                 &
    states_elec_block_max,                 &
    states_elec_block_size,                &
    states_elec_count_pairs,               &
    occupied_states,                       &
    states_elec_set_zero,                  &
    states_elec_generate_random_vector,    &
    stress_t,                              &
    kpoints_distribute


  ! this type must be moved to stress module but due to circular dependency it is not possible now
  !> @brief Storage for the stress tensor and its individual contributions.
  !!
  !! Every contribution is a 3x3 array. All components, including the two scalar
  !! sum-rule violations, are default-initialized to zero so that a freshly
  !! declared or allocated object never carries undefined values.
  type stress_t
    real(real64)  :: total(3,3)           = M_ZERO
    real(real64)  :: kinetic(3,3)         = M_ZERO
    real(real64)  :: Hartree(3,3)         = M_ZERO
    real(real64)  :: xc(3,3)              = M_ZERO
    real(real64)  :: ps_local(3,3)        = M_ZERO
    real(real64)  :: ps_nl(3,3)           = M_ZERO
    real(real64)  :: ion_ion(3,3)         = M_ZERO
    real(real64)  :: vdw(3,3)             = M_ZERO
    real(real64)  :: hubbard(3,3)         = M_ZERO

    real(real64)  :: kinetic_sumrule      = M_ZERO !< Violation of the kinetic pressure sumrule
    real(real64)  :: hartree_sumrule      = M_ZERO !< Violation of the Hartree pressure sumrule
  end type stress_t

  ! TODO(Alex) Issue #672 Decouple k-point info from `states_elec_dim_t`

  !> @brief The states_elec_t class contains all electronic wave functions
  !!
  !! It also holds data derived from the wave functions, e.g. the density,
  !! currents, etc., as well as the eigenenergies and occupation numbers.
  !!
  !! The wave functions themselves are stored in groups in the type
  !! states_elec_group_oct_m::states_elec_group_t.
  !
  type, extends(states_abst_t) :: states_elec_t
    ! Components are public by default
    type(states_elec_dim_t)  :: d                     !< Spin dimensions; also contains k-points and weights
    integer                  :: nst_conv              !< Number of states considered for convergence: the number of
    !!                                                   states before extra states are added, plus ExtraStatesToConverge

    logical                  :: only_userdef_istates  !< only use user-defined states as initial states in propagation

    type(states_elec_group_t) :: group                !< Wave function plus blocking data
    integer :: block_size                !< @brief number of states per batch, as set by StatesBlockSize
    !!                                      input variable
    !!
    !!                                      The value is limited by the actual number of states.
    logical :: pack_states               !< @brief packing status as requested by StatesPack input variable.
    !!
    !!                                      NOTE(review): states_elec_init does not assign this component;
    !!                                      presumably it is set where StatesPack is parsed - confirm.
    !!                                      This does _not_ represent the actual status, but whether states _should_ be packed.


    character(len=1024), allocatable :: user_def_states(:,:,:) !< used for the user-defined wavefunctions
    !!                                                            (they are stored as formula strings)
    !!                                                            dimensions (st\%d\%dim, st\%nst, st\%d\%nik)

    ! TODO(Alex) Issue #821.  Collate current quantities into an object.
    ! the densities and currents (after all we are doing DFT :)
    real(real64), allocatable :: rho(:,:)               !< rho,                   dimension (gr\%np_part, st\%d\%nspin)
    real(real64), allocatable :: rho_core(:)            !< core charge for nl core corrections

    real(real64), allocatable :: current(:, :, :)       !< total current,         dimension (gr\%np_part, space\%dim, st\%d\%nspin)
    real(real64), allocatable :: current_para(:, :, :)  !< paramagnetic current,  dimension (gr\%np_part, space\%dim, st\%d\%nspin)
    real(real64), allocatable :: current_dia(:, :, :)   !< diamagnetic current,   dimension (gr\%np_part, space\%dim, st\%d\%nspin)
    real(real64), allocatable :: current_mag(:, :, :)   !< magnetization current, dimension (gr\%np_part, space\%dim, st\%d\%nspin)
    real(real64), allocatable :: current_kpt(:,:,:)     !< k-point resolved current, dimension (gr\%np, space\%dim, kpt_start:kpt_end)

    ! TODO(Alex) Issue #673. Create frozen density class and replace in states_elec_t
    ! It may be required to "freeze" the deepest orbitals during the evolution; the density
    ! of these orbitals is kept in frozen_rho. It is different from rho_core.
    real(real64), allocatable :: frozen_rho(:, :)       !< frozen density
    real(real64), allocatable :: frozen_tau(:, :)       !< frozen kinetic energy density
    real(real64), allocatable :: frozen_gdens(:,:,:)    !< frozen gradient of density
    real(real64), allocatable :: frozen_ldens(:,:)      !< frozen Laplacian of density

    logical            :: uniform_occ   !< .true. if occupations are equal for all states: no empty states, and no smearing

    real(real64), allocatable :: eigenval(:,:)        !< eigenvalues (st\%nst, st\%d\%nik)
    logical            :: fixed_occ            !< should the occupation numbers be fixed?
    logical            :: restart_fixed_occ    !< should the occupation numbers be fixed by restart?
    logical            :: restart_reorder_occs !< used for restart with altered occupation numbers
    real(real64), allocatable :: occ(:,:)             !< the occupation numbers, allocated as (st\%nst, st\%nik)
    real(real64), allocatable :: kweights(:)          !< weights for the k-point integrations
    integer            :: nik                  !< Number of irreducible subspaces

    logical, private   :: fixed_spins          !< In spinors mode, the spin direction is set
    !!                                            for the initial (random) orbitals.
    real(real64), allocatable :: spin(:, :, :)        !< spin orientations, dimension (1:3, 1:st\%nst, 1:st\%d\%nik)

    real(real64)   :: qtot          !< (-) The total charge in the system (used in Fermi)
    real(real64)   :: val_charge    !< valence charge

    type(stress_t) :: stress_tensors !< stress tensor contributions (see stress_t)

    logical        :: fromScratch   !< set to .true. in states_elec_init; according to the comment there
    !!                                 it is reset if restart_read is called
    type(smear_t)  :: smear         !< smearing of the electronic occupations

    ! TODO(Alex) Issue #823 Move modelmbparticles out of states_elec_t
    type(modelmb_particle_t) :: modelmbparticles !< initialized by modelmb_particles_init in states_elec_init
    integer, allocatable :: mmb_nspindown(:,:) !< number of down spins in the selected Young diagram for each type and state
    integer, allocatable :: mmb_iyoung(:,:)    !< index of the selected Young diagram for each type and state
    real(real64),   allocatable :: mmb_proj(:)        !< projection of the state onto the chosen Young diagram

    logical                     :: parallel_in_states = .false. !< Am I parallel in states?

    ! TODO(Alex) Issue #824. Package the communicators in a single instance prior to removing
    ! or consider creating a distributed_t instance for each (as distributed_t contains an instance of mpi_grp_t)
    type(mpi_grp_t)             :: mpi_grp              !< The MPI group related to the parallelization in states.
    type(mpi_grp_t)             :: dom_st_mpi_grp       !< The MPI group related to the domains-states "plane".
    type(mpi_grp_t)             :: st_kpt_mpi_grp       !< The MPI group related to the states-kpoints "plane".
    type(mpi_grp_t)             :: dom_st_kpt_mpi_grp   !< The MPI group related to the domains-states-kpoints "cube".
    type(blacs_proc_grid_t)     :: dom_st_proc_grid     !< The BLACS process grid for the domains-states plane

    type(distributed_t)         :: dist                 !< states distribution over processes
    logical                     :: scalapack_compatible !< Whether the states parallelization uses ScaLAPACK layout

    ! TODO(Alex) Issue #820. Remove lnst, st_start, st_end and node, as they are all contained within dist
    integer                     :: lnst                 !< Number of states on local process.
    integer                     :: st_start, st_end     !< Range of states processed by local process.
    integer, allocatable        :: node(:)              !< To which node belongs each state.

    ! TODO(Alex) Issue #824. Either change from data to a method, or package with `st_kpt_mpi_grp`
    integer, allocatable        :: st_kpt_task(:,:)     !< For a given task, what are kpt and st start/end

    type(multicomm_all_pairs_t), private :: ap          !< All-pairs schedule.
    logical                     :: symmetrize_density   !< value of the SymmetrizeDensity input variable
    integer                     :: randomization      !< Method used to generate random states
    integer                     :: orth_method = 0    !< @brief orthogonalization as requested by StatesOrthogonalization

    real(real64)   :: cl_states_mem !< NOTE(review): not assigned in the visible part of this file - TODO document

  contains
    procedure :: nullify => states_elec_null            !< @copydoc states_elec_null
    procedure :: write_info => states_elec_write_info   !< @copydoc states_elec_write_info
    procedure :: pack => states_elec_pack               !< @copydoc states_elec_pack
    procedure :: unpack => states_elec_unpack           !< @copydoc states_elec_unpack
    procedure :: set_zero => states_elec_set_zero       !< @copydoc states_elec_set_zero
    procedure :: dipole => states_elec_calculate_dipole !< @copydoc states_elec_calculate_dipole
  end type states_elec_t

  !> Method used to generate random states
  !!
  !! These are the values accepted by the StatesRandomization input variable
  !! (options par_independent = 1 and par_dependent = 2), which states_elec_init
  !! stores in states_elec_t::randomization.
  integer, public, parameter :: &
    PAR_INDEPENDENT = 1,              &
    PAR_DEPENDENT   = 2


  !> Generic name for the family of "get state" routines.
  !! One d/z pair per line; the numeric suffix distinguishes the four variants.
  !! (The d/z prefixes presumably denote real/complex versions - the specific
  !! procedures are defined outside the visible part of this file.)
  interface states_elec_get_state
    module procedure dstates_elec_get_state1, zstates_elec_get_state1
    module procedure dstates_elec_get_state2, zstates_elec_get_state2
    module procedure dstates_elec_get_state3, zstates_elec_get_state3
    module procedure dstates_elec_get_state4, zstates_elec_get_state4
  end interface states_elec_get_state

  !> Generic name for the family of "set state" routines.
  !! One d/z pair per line; the numeric suffix distinguishes the four variants.
  !! (The d/z prefixes presumably denote real/complex versions - the specific
  !! procedures are defined outside the visible part of this file.)
  interface states_elec_set_state
    module procedure dstates_elec_set_state1, zstates_elec_set_state1
    module procedure dstates_elec_set_state2, zstates_elec_set_state2
    module procedure dstates_elec_set_state3, zstates_elec_set_state3
    module procedure dstates_elec_set_state4, zstates_elec_set_state4
  end interface states_elec_set_state

  !> Generic name for the family of "get points" routines (two variants,
  !! each with a d and a z specific procedure defined elsewhere in the module).
  interface states_elec_get_points
    module procedure dstates_elec_get_points1, zstates_elec_get_points1
    module procedure dstates_elec_get_points2, zstates_elec_get_points2
  end interface states_elec_get_points

  !> Generic name for the random-vector generators (a d and a z specific
  !! procedure, both defined elsewhere in the module).
  interface states_elec_generate_random_vector
    module procedure dstates_elec_generate_random_vector
    module procedure zstates_elec_generate_random_vector
  end interface states_elec_generate_random_vector

contains

  ! TODO(Alex): Issue #826. Rename to something like "states_elec_default_wfs_type", or remove
  !> @brief Put the basic flags of a states object into their default condition.
  !!
  !! The object is flagged as not packed and its wave-function type is set to
  !! TYPE_FLOAT. No other component is touched.
  !!
  !! @param[inout] st  states object whose defaults are (re)set
  subroutine states_elec_null(st)
    class(states_elec_t), intent(inout) :: st

    PUSH_SUB(states_elec_null)

    ! A freshly reset object holds no packed data
    st%packed = .false.

    ! Real wavefunctions are the default for calculations
    st%wfs_type = TYPE_FLOAT

    POP_SUB(states_elec_null)
  end subroutine states_elec_null


  !> @brief Initialize a new states_elec_t object
  !!
  !! Reads the state-related input variables (ExcessCharge, TotalStates, ExtraStates,
  !! ExtraStatesInPercent, ExtraStatesToConverge, StatesBlockSize,
  !! OnlyUserDefinedInitialStates, StatesRandomization, SymmetrizeDensity, ForceComplex),
  !! derives the number of states and the spin dimensions from the total charge,
  !! allocates the eigenvalue, occupation and auxiliary arrays, reads the initial
  !! occupations and spins, and leaves the object set up as not parallel in states.
  !!
  !! Invalid input is reported through messages_fatal.
  !!
  !! @param[inout] st              states object to initialize
  !! @param[in]    namespace       namespace used for input parsing and messages
  !! @param[in]    space           electron space; provides the spin mode (space\%ispin)
  !! @param[in]    valence_charge  valence charge, stored in st\%val_charge
  !! @param[in]    kpoints         k-point information
  subroutine states_elec_init(st, namespace, space, valence_charge, kpoints)
    type(states_elec_t), target, intent(inout) :: st
    type(namespace_t),           intent(in)    :: namespace
    type(electron_space_t),      intent(in)    :: space
    real(real64),                intent(in)    :: valence_charge
    type(kpoints_t),             intent(in)    :: kpoints

    real(real64) :: excess_charge, nempty_percent
    integer :: nempty, ntot, default
    integer :: nempty_conv
    logical :: force
    ! Tolerance used when deciding whether the charge needs one more state
    real(real64), parameter :: tol = 1e-13_real64

    PUSH_SUB_WITH_PROFILE(states_elec_init)

    st%fromScratch = .true. ! this will be reset if restart_read is called
    call states_elec_null(st)

    ! We get the spin dimension from the electronic space
    ! TODO: Remove spin space information from states_elec_dim
    st%d%ispin = space%ispin

    ! Use of spinors requires complex wavefunctions.
    if (st%d%ispin == SPINORS) call states_set_complex(st)

    if (st%d%ispin /= UNPOLARIZED .and. kpoints%use_time_reversal) then
      message(1) = "Time reversal symmetry is only implemented for unpolarized spins."
      message(2) = "Use KPointsUseTimeReversal = no."
      call messages_fatal(2, namespace=namespace)
    end if


    !%Variable ExcessCharge
    !%Type float
    !%Default 0.0
    !%Section States
    !%Description
    !% The net charge of the system. A negative value means that we are adding
    !% electrons, while a positive value means we are taking electrons
    !% from the system.
    !%End
    call parse_variable(namespace, 'ExcessCharge', M_ZERO, excess_charge)

    !%Variable TotalStates
    !%Type integer
    !%Default 0
    !%Section States
    !%Description
    !% This variable sets the total number of states that Octopus will
    !% use. This is normally not necessary since by default Octopus
    !% sets the number of states to the minimum necessary to hold the
    !% electrons present in the system. (This default behavior is
    !% obtained by setting <tt>TotalStates</tt> to 0).
    !%
    !% If you want to add some unoccupied states, probably it is more convenient to use the variable
    !% <tt>ExtraStates</tt>.
    !%End
    call parse_variable(namespace, 'TotalStates', 0, ntot)
    if (ntot < 0) then
      write(message(1), '(a,i5,a)') "Input: '", ntot, "' is not a valid value for TotalStates."
      call messages_fatal(1, namespace=namespace)
    end if

    !%Variable ExtraStates
    !%Type integer
    !%Default 0
    !%Section States
    !%Description
    !% The number of states is in principle calculated considering the minimum
    !% numbers of states necessary to hold the electrons present in the system.
    !% The number of electrons is
    !% in turn calculated considering the nature of the species supplied in the
    !% <tt>Species</tt> block, and the value of the <tt>ExcessCharge</tt> variable.
    !% However, one may command <tt>Octopus</tt> to use more states, which is necessary if one wants to
    !% use fractional occupational numbers, either fixed from the beginning through
    !% the <tt>Occupations</tt> block or by prescribing
    !% an electronic temperature with <tt>Smearing</tt>, or in order to calculate
    !% excited states (including with <tt>CalculationMode = unocc</tt>).
    !%End
    call parse_variable(namespace, 'ExtraStates', 0, nempty)
    if (nempty < 0) then
      write(message(1), '(a,i5,a)') "Input: '", nempty, "' is not a valid value for ExtraStates."
      message(2) = '(0 <= ExtraStates)'
      call messages_fatal(2, namespace=namespace)
    end if

    if (ntot > 0 .and. nempty > 0) then
      message(1) = 'You cannot set TotalStates and ExtraStates at the same time.'
      call messages_fatal(1, namespace=namespace)
    end if

    !%Variable ExtraStatesInPercent
    !%Type float
    !%Default 0
    !%Section States
    !%Description
    !% This variable allows to set the number of extra/empty states as percentage of the
    !% used occupied states. For example, a value 35 for ExtraStatesInPercent would amount
    !% to ceiling(35/100 * nstates) extra states, where nstates denotes the amount of occupied
    !% states Octopus is using for the system at hand.
    !%End
    call parse_variable(namespace, 'ExtraStatesInPercent', M_ZERO, nempty_percent)
    if (nempty_percent < 0) then
      ! The field must be wide enough for a sign and several integer digits; a narrower
      ! field is filled with asterisks for any rejected value <= -1.
      write(message(1), '(a,f12.6,a)') "Input: '", nempty_percent, &
        "' should be a percentage value x (where x is parts in hundred) larger or equal 0"
      call messages_fatal(1, namespace=namespace)
    end if

    if (nempty > 0 .and. nempty_percent > 0) then
      message(1) = 'You cannot set ExtraStates and ExtraStatesInPercent at the same time.'
      call messages_fatal(1, namespace=namespace)
    end if

    !%Variable ExtraStatesToConverge
    !%Type integer
    !%Default <tt>ExtraStates</tt> (Default 0)
    !%Section States
    !%Description
    !% For <tt>gs</tt> and <tt>unocc</tt> calculations. (For the <tt>gs</tt> calculation one needs to set <tt>ConvEigenError=yes</tt>)
    !% Specifies the number of extra states that will be considered for reaching the convergence.
    !% The calculation will consider the number of occupied states plus
    !% <tt>ExtraStatesToConverge</tt> for the convergence criteria.
    !% By default, all extra states need to be converged (For <tt>gs</tt> calculations only with <tt>ConvEigenError=yes</tt>).
    !% Thus, together with <tt>ExtraStates</tt>, one can have some more states which will not be
    !% considered for the convergence criteria, thus making the convergence of the
    !% unocc calculation faster.
    !%End
    call parse_variable(namespace, 'ExtraStatesToConverge', nempty, nempty_conv)
    ! Validate the value that was just parsed (ExtraStates itself was already
    ! checked to be non-negative above).
    if (nempty_conv < 0) then
      write(message(1), '(a,i5,a)') "Input: '", nempty_conv, "' is not a valid value for ExtraStatesToConverge."
      message(2) = '(0 <= ExtraStatesToConverge)'
      call messages_fatal(2, namespace=namespace)
    end if

    if (nempty_conv > nempty) then
      message(1) = 'You cannot set ExtraStatesToConverge to a higher value than ExtraStates.'
      call messages_fatal(1, namespace=namespace)
    end if

    ! For non-periodic systems this should just return the Gamma point
    call states_elec_choose_kpoints(st, kpoints, namespace)

    st%val_charge = valence_charge

    st%qtot = -(st%val_charge + excess_charge)

    if (st%qtot < -M_EPSILON) then
      write(message(1),'(a,f12.6,a)') 'Total charge = ', st%qtot, ' < 0'
      message(2) = 'Check Species and ExcessCharge.'
      call messages_fatal(2, only_root_writes = .true., namespace=namespace)
    end if

    ! Minimum number of states able to hold st%qtot, together with the spin dimensions.
    ! nint rounds to the nearest integer, so one state is added whenever the rounded
    ! count cannot accommodate the full charge (up to the tolerance tol).
    select case (st%d%ispin)
    case (UNPOLARIZED)
      st%d%dim = 1
      st%nst = nint(st%qtot/2)
      if (st%nst*2 - st%qtot < -tol) st%nst = st%nst + 1
      st%d%nspin = 1
      st%d%spin_channels = 1
    case (SPIN_POLARIZED)
      st%d%dim = 1
      st%nst = nint(st%qtot/2)
      if (st%nst*2 - st%qtot < -tol) st%nst = st%nst + 1
      st%d%nspin = 2
      st%d%spin_channels = 2
    case (SPINORS)
      st%d%dim = 2
      st%nst = nint(st%qtot)
      if (st%nst - st%qtot < -tol) st%nst = st%nst + 1
      st%d%nspin = 4
      st%d%spin_channels = 2
    end select

    if (ntot > 0) then
      if (ntot < st%nst) then
        message(1) = 'TotalStates is smaller than the number of states required by the system.'
        call messages_fatal(1, namespace=namespace)
      end if

      st%nst = ntot
    end if

    ! NOTE(review): ExtraStatesToConverge was parsed and checked against nempty before
    ! this point, i.e. before the percentage-based number of extra states is known.
    ! With ExtraStatesInPercent its default is therefore 0 and any positive value is
    ! rejected - confirm whether this is intended.
    if (nempty_percent > 0) then
      nempty = ceiling(nempty_percent * st%nst / 100)
    end if

    st%nst_conv = st%nst + nempty_conv
    st%nst = st%nst + nempty
    if (st%nst == 0) then
      message(1) = "Cannot run with number of states = zero."
      call messages_fatal(1, namespace=namespace)
    end if

    !%Variable StatesBlockSize
    !%Type integer
    !%Section Execution::Optimization
    !%Description
    !% Some routines work over blocks of eigenfunctions, which
    !% generally improves performance at the expense of increased
    !% memory consumption. This variable selects the size of the
    !% blocks to be used. If GPUs are used, the default is the
    !% warp size (32 for NVIDIA, 32 or 64 for AMD);
    !% otherwise it is 4.
    !%End

    if (accel_is_enabled()) then
      ! Some AMD GPUs have a warp size of 64.  When OpenCL is used
      ! accel%warp_size = 1 which is why we use max(accel%warp_size, 32)
      ! here so that StatesBlockSize is at least 32.
      default = max(accel%warp_size, 32)
    else
      default = 4
    end if

    ! The default never exceeds pad_pow2 of the number of states
    if (default > pad_pow2(st%nst)) default = pad_pow2(st%nst)

    ASSERT(default > 0)

    call parse_variable(namespace, 'StatesBlockSize', default, st%block_size)
    if (st%block_size < 1) then
      call messages_write("The variable 'StatesBlockSize' must be greater than 0.")
      call messages_fatal(namespace=namespace)
    end if

    ! A block cannot hold more states than exist
    st%block_size = min(st%block_size, st%nst)
    conf%target_states_block_size = st%block_size

    SAFE_ALLOCATE(st%eigenval(1:st%nst, 1:st%nik))
    st%eigenval = huge(st%eigenval)

    ! Periodic systems require complex wavefunctions
    ! but not if it is Gamma-point only
    if (.not. kpoints%gamma_only()) then
      call states_set_complex(st)
    end if

    !%Variable OnlyUserDefinedInitialStates
    !%Type logical
    !%Default no
    !%Section States
    !%Description
    !% If true, then only user-defined states from the block <tt>UserDefinedStates</tt>
    !% will be used as initial states for a time-propagation. No attempt is made
    !% to load ground-state orbitals from a previous ground-state run.
    !%End
    call parse_variable(namespace, 'OnlyUserDefinedInitialStates', .false., st%only_userdef_istates)

    ! we now allocate some arrays
    SAFE_ALLOCATE(st%occ     (1:st%nst, 1:st%nik))
    st%occ      = M_ZERO
    ! allocate space for formula strings that define user-defined states
    if (parse_is_defined(namespace, 'UserDefinedStates') .or. parse_is_defined(namespace, 'OCTInitialUserdefined') &
      .or. parse_is_defined(namespace, 'OCTTargetUserdefined')) then
      SAFE_ALLOCATE(st%user_def_states(1:st%d%dim, 1:st%nst, 1:st%nik))
      ! initially we mark all 'formulas' as undefined
      st%user_def_states(1:st%d%dim, 1:st%nst, 1:st%nik) = 'undefined'
    end if

    if (st%d%ispin == SPINORS) then
      SAFE_ALLOCATE(st%spin(1:3, 1:st%nst, 1:st%nik))
    end if

    !%Variable StatesRandomization
    !%Type integer
    !%Default par_independent
    !%Section States
    !%Description
    !% The randomization of states can be done in two ways:
    !% i) a parallelisation independent way (default), where the random states are identical,
    !% irrespectively of the number of tasks and
    !% ii) a parallelisation dependent way, which can prevent linear dependency
    !%  to occur for large systems.
    !%Option par_independent 1
    !% Parallelisation-independent randomization of states.
    !%Option par_dependent 2
    !% The randomization depends on the number of tasks used in the calculation.
    !%End
    call parse_variable(namespace, 'StatesRandomization', PAR_INDEPENDENT, st%randomization)


    call states_elec_read_initial_occs(st, namespace, excess_charge, kpoints)
    call states_elec_read_initial_spins(st, namespace)

    ! This test can only be done here, as smear_init is called inside states_elec_read_initial_occs, and
    ! only there smear%photodop is set.

    if (st%smear%photodop) then
      if (nempty == 0) then
        write(message(1), '(a,i5,a)') "PhotoDoping requires to specify ExtraStates."
        message(2) = '(0 == ExtraStates)'
        call messages_fatal(2, namespace=namespace)
      end if
    end if

    ! Start from a serial layout: every state is local and assigned to node 0
    st%st_start = 1
    st%st_end = st%nst
    st%lnst = st%nst
    SAFE_ALLOCATE(st%node(1:st%nst))
    st%node(1:st%nst) = 0

    call mpi_grp_init(st%mpi_grp, MPI_COMM_UNDEFINED)
    st%parallel_in_states = .false.

    call distributed_nullify(st%d%kpt, st%nik)

    call modelmb_particles_init(st%modelmbparticles, namespace, space)
    if (st%modelmbparticles%nparticle > 0) then
      ! FIXME: check why this is not initialized properly in the test, or why it is written out when not initialized
      SAFE_ALLOCATE(st%mmb_nspindown(1:st%modelmbparticles%ntype_of_particle, 1:st%nst))
      st%mmb_nspindown(:,:) = -1
      SAFE_ALLOCATE(st%mmb_iyoung(1:st%modelmbparticles%ntype_of_particle, 1:st%nst))
      st%mmb_iyoung(:,:) = -1
      SAFE_ALLOCATE(st%mmb_proj(1:st%nst))
      st%mmb_proj(:) = M_ZERO
    end if

    !%Variable SymmetrizeDensity
    !%Type logical
    !%Default no
    !%Section States
    !%Description
    !% When enabled the density is symmetrized. Currently, this can
    !% only be done for periodic systems. (Experimental.)
    !%End
    ! The value used when the variable is absent from the input is kpoints%use_symmetries
    call parse_variable(namespace, 'SymmetrizeDensity', kpoints%use_symmetries, st%symmetrize_density)
    call messages_print_var_value('SymmetrizeDensity', st%symmetrize_density, namespace=namespace)

    !%Variable ForceComplex
    !%Type logical
    !%Default no
    !%Section Execution::Debug
    !%Description
    !% Normally <tt>Octopus</tt> determines automatically the type necessary
    !% for the wavefunctions. When set to yes this variable will
    !% force the use of complex wavefunctions.
    !%
    !% Warning: This variable is designed for testing and
    !% benchmarking and normal users need not use it.
    !%End
    call parse_variable(namespace, 'ForceComplex', .false., force)

    if (force) call states_set_complex(st)

    st%packed = .false.

    POP_SUB_WITH_PROFILE(states_elec_init)
  end subroutine states_elec_init

  ! ---------------------------------------------------------
  !> Reads the 'states' file in the restart directory, and finds out
  !! the nik, dim, and nst contained in it.
  !!
  !! The first three lines of the file are expected to hold a label followed by
  !! an integer each: nst, dim and nik, in this order.
  !!
  !! @param[in]  restart  restart object giving access to the 'states' file
  !! @param[out] nik      value read from the third line (0 on failure)
  !! @param[out] dim      value read from the second line (0 on failure)
  !! @param[out] nst      value read from the first line (0 on failure)
  !! @param[out] ierr     0 on success; otherwise the error code of restart_read, or
  !!                      the iostat value of the failing read if a line cannot be parsed
  !
  subroutine states_elec_look(restart, nik, dim, nst, ierr)
    type(restart_t), intent(in)  :: restart
    integer,         intent(out) :: nik
    integer,         intent(out) :: dim
    integer,         intent(out) :: nst
    integer,         intent(out) :: ierr

    character(len=256) :: lines(3)
    character(len=20)  :: label
    integer :: iunit, ios

    PUSH_SUB(states_elec_look)

    ierr = 0

    ! Give the outputs a defined value so that they are never undefined on an error path
    nik = 0
    dim = 0
    nst = 0

    iunit = restart_open(restart, 'states')
    call restart_read(restart, iunit, lines, 3, ierr)
    if (ierr == 0) then
      ! Parse with iostat so that a malformed file is reported through ierr
      ! instead of aborting the run inside the read statement.
      read(lines(1), *, iostat=ios) label, nst
      if (ios == 0) read(lines(2), *, iostat=ios) label, dim
      if (ios == 0) read(lines(3), *, iostat=ios) label, nik
      if (ios /= 0) then
        ierr = ios
        nik = 0
        dim = 0
        nst = 0
      end if
    end if
    call restart_close(restart, iunit)

    POP_SUB(states_elec_look)
  end subroutine states_elec_look

  ! ---------------------------------------------------------
  !> Reads from the input file the initial occupations, if the
  !! block "Occupations" is present. Otherwise, it makes an initial
  !! guess for the occupations, maybe using the "Smearing"
  !! variable.
  !!
  !! The resulting occupations are placed on the st\%occ variable. The
  !! boolean st\%fixed_occ is also set to .true., if the occupations are
  !! set by the user through the "Occupations" block; false otherwise.
  !
  subroutine states_elec_read_initial_occs(st, namespace, excess_charge, kpoints)
    type(states_elec_t),  intent(inout) :: st
    type(namespace_t),    intent(in)    :: namespace
    real(real64),         intent(in)    :: excess_charge
    type(kpoints_t),      intent(in)    :: kpoints

    integer :: ik, ist, ispin, nspin, ncols, nrows, el_per_state, icol, start_pos, spin_n
    type(block_t) :: blk
    real(real64) :: rr, charge
    logical :: integral_occs, unoccupied_states
    real(real64), allocatable :: read_occs(:, :)
    real(real64) :: charge_in_block

    PUSH_SUB(states_elec_read_initial_occs)

    !%Variable RestartFixedOccupations
    !%Type logical
    !%Default yes
    !%Section States
    !%Description
    !% Setting this variable will make the restart proceed as
    !% if the occupations from the previous calculation had been set via the <tt>Occupations</tt> block,
    !% <i>i.e.</i> fixed. Otherwise, occupations will be determined by smearing.
    !%End
    call parse_variable(namespace, 'RestartFixedOccupations', .true., st%restart_fixed_occ)
    ! we will turn on st%fixed_occ if restart_read is ever called

    !%Variable Occupations
    !%Type block
    !%Section States
    !%Description
    !% The occupation numbers of the orbitals can be fixed through the use of this
    !% variable. For example:
    !%
    !% <tt>%Occupations
    !% <br>&nbsp;&nbsp;2 | 2 | 2 | 2 | 2
    !% <br>%</tt>
    !%
    !% would fix the occupations of the five states to 2. There can be
    !% at most as many columns as states in the calculation. If there are fewer columns
    !% than states, then the code will assume that the user is indicating the occupations
    !% of the uppermost states where all lower states have full occupation (i.e. 2 for spin-unpolarized
    !% calculations, 1 otherwise) and all higher states have zero occupation. The first column
    !% will be taken to refer to the lowest state such that the occupations would be consistent
    !% with the correct total charge. For example, if there are 8 electrons and 10 states (from
    !% <tt>ExtraStates = 6</tt>), then an abbreviated specification
    !%
    !% <tt>%Occupations
    !% <br>&nbsp;&nbsp;1 | 0 | 1
    !% <br>%</tt>
    !%
    !% would be equivalent to a full specification
    !%
    !% <tt>%Occupations
    !% <br>&nbsp;&nbsp;2 | 2 | 2 | 1 | 0 | 1 | 0 | 0 | 0 | 0
    !% <br>%</tt>
    !%
    !% This is an example of use for constrained density-functional theory,
    !% crudely emulating a HOMO->LUMO+1 optical excitation.
    !% The number of rows should be equal
    !% to the number of k-points times the number of spins. For example, for a finite system
    !% with <tt>SpinComponents == spin_polarized</tt>,
    !% this block should contain two lines, one for each spin channel.
    !% All rows must have the same number of columns.
    !%
    !% The <tt>Occupations</tt> block is useful for the ground state of highly symmetric
    !% small systems (like an open-shell atom), to fix the occupation numbers
    !% of degenerate states in order to help <tt>octopus</tt> to converge. This is to
    !% be used in conjunction with <tt>ExtraStates</tt>. For example, to calculate the
    !% carbon atom, one would do:
    !%
    !% <tt>ExtraStates = 2
    !% <br>%Occupations
    !% <br>&nbsp;&nbsp;2 | 2/3 | 2/3 | 2/3
    !% <br>%</tt>
    !%
    !% If you want the calculation to be spin-polarized (which makes more sense), you could do:
    !%
    !% <tt>ExtraStates = 2
    !% <br>%Occupations
    !% <br>&nbsp;&nbsp; 2/3 | 2/3 | 2/3
    !% <br>&nbsp;&nbsp; 0   |   0 |   0
    !% <br>%</tt>
    !%
    !% Note that in this case the first state is absent, the code will calculate four states
    !% (two because there are four electrons, plus two because <tt>ExtraStates</tt> = 2), and since
    !% it finds only three columns, it will occupy the first state with one electron for each
    !% of the spin options.
    !%
    !% If the sum of occupations is not equal to the total charge set by <tt>ExcessCharge</tt>,
    !% an error message is printed.
    !% If <tt>FromScratch = no</tt> and <tt>RestartFixedOccupations = yes</tt>,
    !% this block will be ignored.
    !%End

    integral_occs = .true.

    occ_fix: if (parse_block(namespace, 'Occupations', blk) == 0) then
      ! read in occupations
      st%fixed_occ = .true.

      ncols = parse_block_cols(blk, 0)
      if (ncols > st%nst) then
        call messages_input_error(namespace, "Occupations", "Too many columns in block Occupations.")
      end if

      nrows = parse_block_n(blk)
      if (nrows /= st%nik) then
        call messages_input_error(namespace, "Occupations", "Wrong number of rows in block Occupations.")
      end if

      do ik = 1, st%nik - 1
        if (parse_block_cols(blk, ik) /= ncols) then
          call messages_input_error(namespace, "Occupations", &
            "All rows in block Occupations must have the same number of columns.")
        end if
      end do

      ! Now we fill all the "missing" states with the maximum occupation.
      if (st%d%ispin == UNPOLARIZED) then
        el_per_state = 2
      else
        el_per_state = 1
      end if

      SAFE_ALLOCATE(read_occs(1:ncols, 1:st%nik))

      charge_in_block = M_ZERO
      do ik = 1, st%nik
        do icol = 1, ncols
          call parse_block_float(blk, ik - 1, icol - 1, read_occs(icol, ik))
          charge_in_block = charge_in_block + read_occs(icol, ik) * st%kweights(ik)
        end do
      end do

      spin_n = 2
      select case (st%d%ispin)
      case (UNPOLARIZED)
        spin_n = 2
      case (SPIN_POLARIZED)
        spin_n = 2
      case (SPINORS)
        spin_n = 1
      end select

      start_pos = nint((st%qtot - charge_in_block)/spin_n)

      if (start_pos + ncols > st%nst) then
        message(1) = "To balance charge, the first column in block Occupations is taken to refer to state"
        write(message(2),'(a,i6,a)') "number ", start_pos, " but there are too many columns for the number of states."
        write(message(3),'(a,i6,a)') "Solution: set ExtraStates = ", start_pos + ncols - st%nst
        call messages_fatal(3, namespace=namespace)
      end if

      do ik = 1, st%nik
        do ist = 1, start_pos
          st%occ(ist, ik) = el_per_state
        end do
      end do

      do ik = 1, st%nik
        do ist = start_pos + 1, start_pos + ncols
          st%occ(ist, ik) = read_occs(ist - start_pos, ik)
          integral_occs = integral_occs .and. &
            abs((st%occ(ist, ik) - el_per_state) * st%occ(ist, ik))  <=  M_EPSILON
        end do
      end do

      do ik = 1, st%nik
        do ist = start_pos + ncols + 1, st%nst
          st%occ(ist, ik) = M_ZERO
        end do
      end do

      call parse_block_end(blk)

      SAFE_DEALLOCATE_A(read_occs)

    else
      st%fixed_occ = .false.
      integral_occs = .false.

      ! first guess for occupation...paramagnetic configuration
      rr = M_ONE
      if (st%d%ispin == UNPOLARIZED) rr = M_TWO

      st%occ  = M_ZERO
      st%qtot = -(st%val_charge + excess_charge)

      nspin = 1
      if (st%d%nspin == 2) nspin = 2

      do ik = 1, st%nik, nspin
        charge = M_ZERO
        do ist = 1, st%nst
          do ispin = ik, ik + nspin - 1
            st%occ(ist, ispin) = min(rr, -(st%val_charge + excess_charge) - charge)
            charge = charge + st%occ(ist, ispin)
          end do
        end do
      end do

    end if occ_fix

    !%Variable RestartReorderOccs
    !%Type logical
    !%Default no
    !%Section States
    !%Description
    !% Consider doing a ground-state calculation, and then restarting with new occupations set
    !% with the <tt>Occupations</tt> block, in an attempt to populate the orbitals of the original
    !% calculation. However, the eigenvalues may reorder as the density changes, in which case the
    !% occupations will now be referring to different orbitals. Setting this variable to yes will
    !% try to solve this issue when the restart data is being read, by reordering the occupations
    !% according to the order of the expectation values of the restart wavefunctions.
    !%End
    if (st%fixed_occ) then
      call parse_variable(namespace, 'RestartReorderOccs', .false., st%restart_reorder_occs)
    else
      st%restart_reorder_occs = .false.
    end if

    call smear_init(st%smear, namespace, st%d%ispin, st%fixed_occ, integral_occs, kpoints)

    unoccupied_states = (st%d%ispin /= SPINORS .and. st%nst*2 > st%qtot) .or. (st%d%ispin == SPINORS .and. st%nst > st%qtot)

    if (.not. smear_is_semiconducting(st%smear) .and. .not. st%smear%method == SMEAR_FIXED_OCC) then
      if (.not. unoccupied_states) then
        call messages_write('Smearing needs unoccupied states (via ExtraStates or TotalStates) to be useful.')
        call messages_warning(namespace=namespace)
      end if
    end if

    ! sanity check
    charge = M_ZERO
    do ist = 1, st%nst
      charge = charge + sum(st%occ(ist, 1:st%nik) * st%kweights(1:st%nik))
    end do
    if (abs(charge - st%qtot) > 1e-6_real64) then
      message(1) = "Initial occupations do not integrate to total charge."
      write(message(2), '(6x,f12.6,a,f12.6)') charge, ' != ', st%qtot
      call messages_fatal(2, only_root_writes = .true., namespace=namespace)
    end if

    st%uniform_occ = smear_is_semiconducting(st%smear) .and. .not. unoccupied_states

    POP_SUB(states_elec_read_initial_occs)
  end subroutine states_elec_read_initial_occs


  ! ---------------------------------------------------------
  !> Reads, if present, the "InitialSpins" block.
  !!
  !! This is only done in spinors mode; otherwise the routine only sets
  !! st\%fixed_spins to false and returns.
  !! The spin of state ist is read from row ist of the block into
  !! st\%spin(1:3, ist, 1) and then replicated to all other indices
  !! 2..st\%nik of the last dimension. The boolean st\%fixed_spins is set to
  !! true if (and only if) the InitialSpins block is present.
  !!
  !! An input error is raised if the block has fewer rows than states, if a
  !! row has fewer than three columns, or if a spin does not satisfy
  !! <S_x>^2 + <S_y>^2 + <S_z>^2 = 1/4 to within 1e-6.
  !
  subroutine states_elec_read_initial_spins(st, namespace)
    type(states_elec_t), intent(inout) :: st        !< states whose spin and fixed_spins components are set
    type(namespace_t),   intent(in)    :: namespace !< namespace used to look up the block and to report errors

    integer :: i, j
    type(block_t) :: blk

    PUSH_SUB(states_elec_read_initial_spins)

    st%fixed_spins = .false.
    if (st%d%ispin /= SPINORS) then
      POP_SUB(states_elec_read_initial_spins)
      return
    end if

    !%Variable InitialSpins
    !%Type block
    !%Section States
    !%Description
    !% The spin character of the initial random guesses for the spinors can
    !% be fixed by making use of this block. Note that this will not "fix"
    !% the spins during the calculation (this cannot be done in spinors mode;
    !% being able to change the spins is why the spinors mode exists in the first
    !% place).
    !%
    !% This block is meaningless and ignored if the run is not in spinors mode
    !% (<tt>SpinComponents = spinors</tt>).
    !%
    !% The structure of the block is very simple: each row contains the desired
    !% <math>\left< S_x \right>, \left< S_y \right>, \left< S_z \right> </math> for one spinor,
    !% and there must be at least as many rows as states.
    !% If the calculation is for a periodic system
    !% and there is more than one <i>k</i>-point, the spins of all the <i>k</i>-points are
    !% the same.
    !%
    !% For example, if we have two spinors, and we want one in the <math>S_x</math> "up" state,
    !% and another one in the <math>S_x</math> "down" state:
    !%
    !% <tt>%InitialSpins
    !% <br>&nbsp;&nbsp;&nbsp; 0.5 | 0.0 | 0.0
    !% <br>&nbsp;&nbsp; -0.5 | 0.0 | 0.0
    !% <br>%</tt>
    !%
    !% WARNING: if the calculation is for a system described by pseudopotentials (as
    !% opposed to user-defined potentials or model systems), this option is
    !% meaningless since the random spinors are overwritten by the atomic orbitals.
    !%
    !% This constraint must be fulfilled:
    !% <br><math> \left< S_x \right>^2 + \left< S_y \right>^2 + \left< S_z \right>^2 = \frac{1}{4} </math>
    !%End
    spin_fix: if (parse_block(namespace, 'InitialSpins', blk) == 0) then
      ! One row per state is read below; report a too-short block explicitly
      ! instead of leaving it to parse_block_float to cope with a missing entry.
      if (parse_block_n(blk) < st%nst) then
        call messages_input_error(namespace, 'InitialSpins', 'The block must contain one row per state.')
      end if

      do i = 1, st%nst
        ! Block rows and columns are addressed with zero-based indices.
        if (parse_block_cols(blk, i-1) < 3) then
          call messages_input_error(namespace, 'InitialSpins', 'Each row must contain three columns.')
        end if
        do j = 1, 3
          call parse_block_float(blk, i-1, j-1, st%spin(j, i, 1))
        end do
        ! A spin-1/2 expectation value has squared length 1/4.
        if (abs(sum(st%spin(1:3, i, 1)**2) - M_FOURTH) > 1.0e-6_real64) then
          call messages_input_error(namespace, 'InitialSpins', &
            'Each spin must satisfy <S_x>^2 + <S_y>^2 + <S_z>^2 = 1/4.')
        end if
      end do
      call parse_block_end(blk)
      st%fixed_spins = .true.
      ! The same spins are used for every other index of the last dimension.
      do i = 2, st%nik
        st%spin(:, :, i) = st%spin(:, :, 1)
      end do
    end if spin_fix

    POP_SUB(states_elec_read_initial_spins)
  end subroutine states_elec_read_initial_spins


  ! ---------------------------------------------------------
  !> Allocates the KS wavefunctions defined within a states_elec_t structure.
  !!
  !! If wfs_type is given, it is stored in st\%wfs_type first. The blocks of
  !! states are then set up by states_elec_init_block, and finally
  !! states_elec_set_zero is called on st.
  !
  subroutine states_elec_allocate_wfns(st, mesh, wfs_type, skip, packed)
    type(states_elec_t),    intent(inout)   :: st       !< the states
    class(mesh_t),          intent(in)      :: mesh     !< underlying mesh
    type(type_t), optional, intent(in)      :: wfs_type !< optional type; either TYPE\_FLOAT or TYPE\_CMPLX
    logical,      optional, intent(in)      :: skip(:)  !< optional array of states to skip
    logical,      optional, intent(in)      :: packed   !< optional flag whether to pack; forwarded to states_elec_init_block

    PUSH_SUB(states_elec_allocate_wfns)

    ! Only override the stored wavefunction type when the caller asks for one.
    if (present(wfs_type)) then
      ASSERT(wfs_type == TYPE_FLOAT .or. wfs_type == TYPE_CMPLX)
      st%wfs_type = wfs_type
    end if

    ! Absent optional arguments are passed on as absent.
    call states_elec_init_block(st, mesh, skip = skip, packed=packed)
    call states_elec_set_zero(st)

    POP_SUB(states_elec_allocate_wfns)
  end subroutine states_elec_allocate_wfns

  !---------------------------------------------------------------------
  !> Initializes the data components in st\%group that describe how the states
  !! are distributed in blocks:
  !!
  !! st\%group\%nblocks: this is the number of blocks in which the states are divided. Note that
  !!   this number is the total number of blocks, regardless of how many are actually stored
  !!   in each node.
  !! st\%group\%block_start: in each node, the index of the first block.
  !! st\%group\%block_end: in each node, the index of the last block.
  !!   If the states are not parallelized, then block_start is 1 and block_end is st\%group\%nblocks.
  !!   If no block is local, they are left at -1 and -2, so that loops over
  !!   block_start:block_end do not run.
  !! st\%group\%iblock(1:st\%nst): it points, for each state, to the block that contains it
  !!   (0 for states that are not assigned to any block).
  !! st\%group\%block_is_local(ib, iqn): .true. if block ib for index iqn is stored in the running node.
  !! st\%group\%block_range(1:st\%group\%nblocks, 1:2): Block ib contains states from
  !!   st\%group\%block_range(ib, 1) to st\%group\%block_range(ib, 2).
  !! st\%group\%block_size(1:st\%group\%nblocks): Block ib contains a number st\%group\%block_size(ib) of states.
  !! st\%group\%block_node(ib, iqn): set to st\%st_kpt_mpi_grp\%rank by the process that holds the block,
  !!   zero elsewhere, and then combined over st\%st_kpt_mpi_grp with comm_allreduce.
  !! st\%group\%block_initialized: it should be .false. on entry, and .true. after exiting this routine.
  !!
  !! The set of batches st\%group\%psib(1:st\%group\%nblocks, st\%d\%kpt\%start:st\%d\%kpt\%end)
  !! contains the blocks themselves.
  !!
  !! If skip is present, only the contiguous range from the first to the last state
  !! that is not skipped is distributed in blocks; skipped states inside that
  !! range are still allocated.
  subroutine states_elec_init_block(st, mesh, verbose, skip, packed)
    type(states_elec_t),           intent(inout) :: st      !< the states
    type(mesh_t),                  intent(in)    :: mesh    !< mesh; mesh\%np_part is passed to the batch initialization
    logical, optional,             intent(in)    :: verbose !< default: true
    logical, optional,             intent(in)    :: skip(:) !< list of states to skip
    logical, optional,             intent(in)    :: packed  !< default: false

    integer :: ib, iqn, ist, istmin, istmax
    logical :: same_node, verbose_, packed_
    integer, allocatable :: bstart(:), bend(:)

    PUSH_SUB(states_elec_init_block)

    ! bstart/bend hold the first/last state of each block; there can be at most st%nst blocks.
    SAFE_ALLOCATE(bstart(1:st%nst))
    SAFE_ALLOCATE(bend(1:st%nst))
    SAFE_ALLOCATE(st%group%iblock(1:st%nst))

    st%group%iblock = 0

    verbose_ = optional_default(verbose, .true.)
    packed_ = optional_default(packed, .false.)

    ! In case we have a list of states to skip, we do not allocate the leading ones:
    ! istmin is the first state that is not skipped (1 if every state is skipped).
    istmin = 1
    if (present(skip)) then
      do ist = 1, st%nst
        if (.not. skip(ist)) then
          istmin = ist
          exit
        end if
      end do
    end if

    ! Likewise, istmax is the last state that is not skipped (st%nst if none is found).
    istmax = st%nst
    if (present(skip)) then
      do ist = st%nst, istmin, -1
        if (.not. skip(ist)) then
          istmax = ist
          exit
        end if
      end do
    end if

    if (present(skip) .and. verbose_) then
      call messages_write('Info: Allocating states from ')
      call messages_write(istmin, fmt = 'i8')
      call messages_write(' to ')
      call messages_write(istmax, fmt = 'i8')
      call messages_info()
    end if

    ! count and assign blocks
    ! ib counts the states accumulated in the block that is currently being filled
    ib = 0
    st%group%nblocks = 0
    bstart(1) = istmin
    do ist = istmin, istmax
      ib = ib + 1

      st%group%iblock(ist) = st%group%nblocks + 1

      same_node = .true.
      if (st%parallel_in_states .and. ist /= istmax) then
        ! We have to avoid that states that are in different nodes end
        ! up in the same block
        same_node = (st%node(ist + 1) == st%node(ist))
      end if

      ! Close the current block when it is full, when the last state has been reached,
      ! or when the next state belongs to a different node.
      if (ib == st%block_size .or. ist == istmax .or. .not. same_node) then
        ib = 0
        st%group%nblocks = st%group%nblocks + 1
        bend(st%group%nblocks) = ist
        if (ist /= istmax) bstart(st%group%nblocks + 1) = ist + 1
      end if
    end do

    SAFE_ALLOCATE(st%group%psib(1:st%group%nblocks, st%d%kpt%start:st%d%kpt%end))

    SAFE_ALLOCATE(st%group%block_is_local(1:st%group%nblocks, 1:st%nik))
    st%group%block_is_local = .false.
    st%group%block_start  = -1
    st%group%block_end    = -2  ! loops over block_start:block_end run zero times if no block is local

    ! A block is local if all of its states lie within this node's range st_start:st_end.
    do ib = 1, st%group%nblocks
      if (bstart(ib) >= st%st_start .and. bend(ib) <= st%st_end) then
        if (st%group%block_start == -1) st%group%block_start = ib
        st%group%block_end = ib
        do iqn = st%d%kpt%start, st%d%kpt%end
          st%group%block_is_local(ib, iqn) = .true.

          ! Initialize the batch with the real or complex variant, depending on the type of the states.
          if (states_are_real(st)) then
            call dwfs_elec_init(st%group%psib(ib, iqn), st%d%dim, bstart(ib), bend(ib), mesh%np_part, iqn, &
              special=.true., packed=packed_)
          else
            call zwfs_elec_init(st%group%psib(ib, iqn), st%d%dim, bstart(ib), bend(ib), mesh%np_part, iqn, &
              special=.true., packed=packed_)
          end if

        end do
      end if
    end do

    SAFE_ALLOCATE(st%group%block_range(1:st%group%nblocks, 1:2))
    SAFE_ALLOCATE(st%group%block_size(1:st%group%nblocks))

    st%group%block_range(1:st%group%nblocks, 1) = bstart(1:st%group%nblocks)
    st%group%block_range(1:st%group%nblocks, 2) = bend(1:st%group%nblocks)
    st%group%block_size(1:st%group%nblocks) = bend(1:st%group%nblocks) - bstart(1:st%group%nblocks) + 1

    st%group%block_initialized = .true.

    SAFE_ALLOCATE(st%group%block_node(1:st%group%nblocks, 1:st%nik))
    st%group%block_node = 0

    ASSERT(allocated(st%node))
    ASSERT(all(st%node >= 0) .and. all(st%node < st%mpi_grp%size))

    ! Each process writes its own rank for the blocks it holds; all other entries stay zero
    ! so that the reduction below can combine the contributions of all processes.
    do iqn = st%d%kpt%start, st%d%kpt%end
      do ib = st%group%block_start, st%group%block_end
        st%group%block_node(ib, iqn) = st%st_kpt_mpi_grp%rank
      end do
    end do

    call comm_allreduce(st%st_kpt_mpi_grp, st%group%block_node)

    if (verbose_) then
      call messages_write('Info: Blocks of states')
      call messages_info()
      do ib = 1, st%group%nblocks
        call messages_write('      Block ')
        call messages_write(ib, fmt = 'i8')
        call messages_write(' contains ')
        call messages_write(st%group%block_size(ib), fmt = 'i8')
        call messages_write(' states')
        if (st%group%block_size(ib) > 0) then
          call messages_write(':')
          call messages_write(st%group%block_range(ib, 1), fmt = 'i8')
          call messages_write(' - ')
          call messages_write(st%group%block_range(ib, 2), fmt = 'i8')
        end if
        call messages_info()
      end do
    end if

!!$!!!!DEBUG
!!$    ! some debug output that I will keep here for the moment
!!$    if (mpi_grp_is_root(mpi_world)) then
!!$      print*, "NST       ", st%nst
!!$      print*, "BLOCKSIZE ", st%block_size
!!$      print*, "NBLOCKS   ", st%group%nblocks
!!$
!!$      print*, "==============="
!!$      do ist = 1, st%nst
!!$        print*, st%node(ist), ist, st%group%iblock(ist)
!!$      end do
!!$      print*, "==============="
!!$
!!$      do ib = 1, st%group%nblocks
!!$        print*, ib, bstart(ib), bend(ib)
!!$      end do
!!$
!!$    end if
!!$!!!!ENDOFDEBUG

    SAFE_DEALLOCATE_A(bstart)
    SAFE_DEALLOCATE_A(bend)
    POP_SUB(states_elec_init_block)
  end subroutine states_elec_init_block


  ! ---------------------------------------------------------
  !> Deallocates the KS wavefunctions defined within a states_elec_t structure.
  !!
  !! This is a thin wrapper that calls states_elec_group_end on st\%group,
  !! passing st\%d along.
  subroutine states_elec_deallocate_wfns(st)
    type(states_elec_t), intent(inout) :: st !< states whose group of wavefunction blocks is released

    PUSH_SUB(states_elec_deallocate_wfns)

    call states_elec_group_end(st%group, st%d)

    POP_SUB(states_elec_deallocate_wfns)
  end subroutine states_elec_deallocate_wfns


  ! ---------------------------------------------------------
  !> @brief Allocate and zero the density st\%rho and print the size of a block of states.
  !!
  !! st\%rho is allocated with shape (gr\%np_part, st\%d\%nspin). The reported size is
  !! gr\%np_part * 8 * st\%block_size, printed in megabytes.
  subroutine states_elec_densities_init(st, gr)
    type(states_elec_t), target, intent(inout) :: st
    type(grid_t),                intent(in)    :: gr

    ! presumably the size in bytes of one real64 value -- TODO confirm
    real(real64), parameter :: size_per_value = 8.0_real64
    real(real64) :: block_mem

    PUSH_SUB(states_elec_densities_init)

    SAFE_ALLOCATE(st%rho(1:gr%np_part, 1:st%d%nspin))
    st%rho(:, :) = M_ZERO

    ! NOTE(review): neither st%d%dim nor the real/complex type of the states enters this estimate.
    block_mem = gr%np_part*size_per_value*st%block_size

    call messages_write('Info: states-block size = ')
    call messages_write(block_mem, fmt = '(f10.1)', align_left = .true., units = unit_megabytes, print_units = .true.)
    call messages_info()

    POP_SUB(states_elec_densities_init)
  end subroutine states_elec_densities_init

  !---------------------------------------------------------------------
  !> @brief Allocate and zero those current arrays of st that are not allocated yet.
  !!
  !! Arrays that are already allocated are left untouched (they are not reset to zero).
  !! current, current_para, current_dia and current_mag get the shape
  !! (mesh\%np_part, space\%dim, st\%d\%nspin); current_kpt gets the shape
  !! (mesh\%np, space\%dim, st\%d\%kpt\%start:st\%d\%kpt\%end).
  subroutine states_elec_allocate_current(st, space, mesh)
    type(states_elec_t), intent(inout) :: st
    class(space_t),      intent(in)    :: space
    class(mesh_t),       intent(in)    :: mesh

    PUSH_SUB(states_elec_allocate_current)

    ! --- arrays defined on np_part points, one per spin component ---

    if (.not. allocated(st%current)) then
      SAFE_ALLOCATE(st%current(1:mesh%np_part, 1:space%dim, 1:st%d%nspin))
      st%current(:, :, :) = M_ZERO
    end if

    if (.not. allocated(st%current_para)) then
      SAFE_ALLOCATE(st%current_para(1:mesh%np_part, 1:space%dim, 1:st%d%nspin))
      st%current_para(:, :, :) = M_ZERO
    end if

    if (.not. allocated(st%current_dia)) then
      SAFE_ALLOCATE(st%current_dia(1:mesh%np_part, 1:space%dim, 1:st%d%nspin))
      st%current_dia(:, :, :) = M_ZERO
    end if

    if (.not. allocated(st%current_mag)) then
      SAFE_ALLOCATE(st%current_mag(1:mesh%np_part, 1:space%dim, 1:st%d%nspin))
      st%current_mag(:, :, :) = M_ZERO
    end if

    ! --- array defined on np points, indexed over the local k-point range ---

    if (.not. allocated(st%current_kpt)) then
      SAFE_ALLOCATE(st%current_kpt(1:mesh%np, 1:space%dim, st%d%kpt%start:st%d%kpt%end))
      st%current_kpt(:, :, :) = M_ZERO
    end if

    POP_SUB(states_elec_allocate_current)
  end subroutine states_elec_allocate_current

  !---------------------------------------------------------------------
  !> @brief Further initializations
  !!
  !! This subroutine reads the following input variables:
  !! - StatesPack: whether or not to pack the states (st\%pack_states);
  !! - StatesOrthogonalization: the orthogonalization method (st\%orth_method);
  !! - StatesCLDeviceMemory: the device memory setting for the states (st\%cl_states_mem).
  !!
  !! It also flags the obsolete variable StatesMirror. The block size
  !! (st\%block_size) is not set in this routine.
  !!
  subroutine states_elec_exec_init(st, namespace, mc)
    type(states_elec_t),  intent(inout) :: st        !< states whose execution options are set
    type(namespace_t),    intent(in)    :: namespace !< namespace used to parse the variables
    type(multicomm_t),    intent(in)    :: mc        !< only consulted when compiled with HAVE_SCALAPACK

    integer :: default

    PUSH_SUB(states_elec_exec_init)

    !%Variable StatesPack
    !%Type logical
    !%Section Execution::Optimization
    !%Description
    !% When set to yes, states are stored in packed mode, which improves
    !% performance considerably. Not all parts of the code will profit from
    !% this, but should nevertheless work regardless of how the states are
    !% stored.
    !%
    !% If GPUs are used and this variable is set to yes, Octopus
    !% will store the wave-functions in device (GPU) memory. If
    !% there is not enough memory to store all the wave-functions,
    !% execution will stop with an error.
    !%
    !% See also the related <tt>HamiltonianApplyPacked</tt> variable.
    !%
    !% The default is yes.
    !%End

    call parse_variable(namespace, 'StatesPack', .true., st%pack_states)

    call messages_print_var_value('StatesPack', st%pack_states, namespace=namespace)

    call messages_obsolete_variable(namespace, 'StatesMirror')

    !%Variable StatesOrthogonalization
    !%Type integer
    !%Section SCF::Eigensolver
    !%Description
    !% The full orthogonalization method used by some
    !% eigensolvers. The default is <tt>cholesky_serial</tt>, except with state
    !% parallelization, the default is <tt>cholesky_parallel</tt>.
    !%Option cholesky_serial 1
    !% Cholesky decomposition implemented using
    !% BLAS/LAPACK. Can be used with domain parallelization but not
    !% state parallelization.
    !%Option cholesky_parallel 2
    !% Cholesky decomposition implemented using
    !% ScaLAPACK. Compatible with states parallelization.
    !%Option cgs 3
    !% Classical Gram-Schmidt (CGS) orthogonalization.
    !% Can be used with domain parallelization but not state parallelization.
    !% The algorithm is defined in Giraud et al., Computers and Mathematics with Applications 50, 1069 (2005).
    !%Option mgs 4
    !% Modified Gram-Schmidt (MGS) orthogonalization.
    !% Can be used with domain parallelization but not state parallelization.
    !% The algorithm is defined in Giraud et al., Computers and Mathematics with Applications 50, 1069 (2005).
    !%Option drcgs 5
    !% Classical Gram-Schmidt orthogonalization with double-step reorthogonalization.
    !% Can be used with domain parallelization but not state parallelization.
    !% The algorithm is taken from Giraud et al., Computers and Mathematics with Applications 50, 1069 (2005).
    !% According to this reference, this is much more precise than CGS or MGS algorithms. The MGS version seems not to improve much the stability and would require more communications over the domains.
    !%End

    ! The default is the serial Cholesky method; it is switched to the parallel one only
    ! if the code was built with ScaLAPACK and the states are parallelized.
    default = OPTION__STATESORTHOGONALIZATION__CHOLESKY_SERIAL
#ifdef HAVE_SCALAPACK
    if (multicomm_strategy_is_parallel(mc, P_STRATEGY_STATES)) then
      default = OPTION__STATESORTHOGONALIZATION__CHOLESKY_PARALLEL
    end if
#endif

    call parse_variable(namespace, 'StatesOrthogonalization', default, st%orth_method)

    ! Reject values that are not among the documented options.
    if (.not. varinfo_valid_option('StatesOrthogonalization', st%orth_method)) then
      call messages_input_error(namespace, 'StatesOrthogonalization')
    end if
    call messages_print_var_option('StatesOrthogonalization', st%orth_method, namespace=namespace)

    !%Variable StatesCLDeviceMemory
    !%Type float
    !%Section Execution::Optimization
    !%Default -512
    !%Description
    !% This variable selects the amount of OpenCL device memory that
    !% will be used by Octopus to store the states.
    !%
    !% A positive number smaller than 1 indicates a fraction of the total
    !% device memory. A number larger than one indicates an absolute
    !% amount of memory in megabytes. A negative number indicates an
    !% amount of memory in megabytes that would be subtracted from
    !% the total device memory.
    !%End
    call parse_variable(namespace, 'StatesCLDeviceMemory', -512.0_real64, st%cl_states_mem)

    POP_SUB(states_elec_exec_init)
  end subroutine states_elec_exec_init


  ! ---------------------------------------------------------
  !> @brief make a (selective) copy of a states_elec_t object
  !!
  !! stout is first reset with states_elec_null and then filled from stin.
  !! Array components are duplicated with SAFE_ALLOCATE_SOURCE_A.
  !!
  !! NOTE(review): current_para, current_dia and current_mag are not copied here,
  !! although states_elec_allocate_current allocates them together with current
  !! and current_kpt -- confirm that this is intentional.
  !
  subroutine states_elec_copy(stout, stin, exclude_wfns, exclude_eigenval, special)
    type(states_elec_t), target, intent(inout) :: stout            !< destination (reset before copying)
    type(states_elec_t),         intent(in)    :: stin             !< source
    logical, optional,           intent(in)    :: exclude_wfns     !< do not copy rho and the wavefunction blocks (default: false)
    logical, optional,           intent(in)    :: exclude_eigenval !< do not copy eigenvalues, occ, spin (default: false)
    logical, optional,           intent(in)    :: special          !< forwarded to states_elec_group_copy (presumably: allocate on GPU)

    logical :: copy_wfns, copy_eigenval

    PUSH_SUB(states_elec_copy)

    copy_wfns     = .not. optional_default(exclude_wfns, .false.)
    copy_eigenval = .not. optional_default(exclude_eigenval, .false.)

    call states_elec_null(stout)

    ! ---- components that come with their own copy routine ----
    call states_elec_dim_copy(stout%d, stin%d)
    call modelmb_particles_copy(stout%modelmbparticles, stin%modelmbparticles)
    call smear_copy(stout%smear, stin%smear)

    call mpi_grp_copy(stout%mpi_grp, stin%mpi_grp)
    call mpi_grp_copy(stout%dom_st_kpt_mpi_grp, stin%dom_st_kpt_mpi_grp)
    call mpi_grp_copy(stout%st_kpt_mpi_grp, stin%st_kpt_mpi_grp)
    call mpi_grp_copy(stout%dom_st_mpi_grp, stin%dom_st_mpi_grp)

#ifdef HAVE_SCALAPACK
    ! note the argument order: source first, destination second
    call blacs_proc_grid_copy(stin%dom_st_proc_grid, stout%dom_st_proc_grid)
#endif

    ! same argument order as above: source first, destination second
    call distributed_copy(stin%dist, stout%dist)

    if (stin%parallel_in_states) call multicomm_all_pairs_copy(stout%ap, stin%ap)

    ! ---- components copied by plain assignment ----
    stout%nik                  = stin%nik
    stout%nst                  = stin%nst
    stout%lnst                 = stin%lnst
    stout%st_start             = stin%st_start
    stout%st_end               = stin%st_end
    stout%wfs_type             = stin%wfs_type
    stout%block_size           = stin%block_size
    stout%orth_method          = stin%orth_method
    stout%cl_states_mem        = stin%cl_states_mem
    stout%pack_states          = stin%pack_states
    stout%packed               = stin%packed
    stout%only_userdef_istates = stin%only_userdef_istates
    stout%uniform_occ          = stin%uniform_occ
    stout%fixed_occ            = stin%fixed_occ
    stout%restart_fixed_occ    = stin%restart_fixed_occ
    stout%fixed_spins          = stin%fixed_spins
    stout%qtot                 = stin%qtot
    stout%val_charge           = stin%val_charge
    stout%parallel_in_states   = stin%parallel_in_states
    stout%scalapack_compatible = stin%scalapack_compatible
    stout%symmetrize_density   = stin%symmetrize_density
    stout%randomization        = stin%randomization

    ! The number of blocks is copied even if the wavefunctions are excluded;
    ! the blocks themselves are handled by states_elec_group_copy at the end.
    stout%group%nblocks = stin%group%nblocks

    ! ---- array components ----
    SAFE_ALLOCATE_SOURCE_A(stout%kweights, stin%kweights)

    if (stin%modelmbparticles%nparticle > 0) then
      SAFE_ALLOCATE_SOURCE_A(stout%mmb_nspindown, stin%mmb_nspindown)
      SAFE_ALLOCATE_SOURCE_A(stout%mmb_iyoung, stin%mmb_iyoung)
      SAFE_ALLOCATE_SOURCE_A(stout%mmb_proj, stin%mmb_proj)
    end if

    if (copy_wfns) then
      SAFE_ALLOCATE_SOURCE_A(stout%rho, stin%rho)
    end if

    if (copy_eigenval) then
      SAFE_ALLOCATE_SOURCE_A(stout%eigenval, stin%eigenval)
      SAFE_ALLOCATE_SOURCE_A(stout%occ, stin%occ)
      SAFE_ALLOCATE_SOURCE_A(stout%spin, stin%spin)
    end if

    SAFE_ALLOCATE_SOURCE_A(stout%user_def_states, stin%user_def_states)

    SAFE_ALLOCATE_SOURCE_A(stout%current, stin%current)
    SAFE_ALLOCATE_SOURCE_A(stout%current_kpt, stin%current_kpt)
    SAFE_ALLOCATE_SOURCE_A(stout%rho_core, stin%rho_core)
    SAFE_ALLOCATE_SOURCE_A(stout%frozen_rho, stin%frozen_rho)
    SAFE_ALLOCATE_SOURCE_A(stout%frozen_tau, stin%frozen_tau)
    SAFE_ALLOCATE_SOURCE_A(stout%frozen_gdens, stin%frozen_gdens)
    SAFE_ALLOCATE_SOURCE_A(stout%frozen_ldens, stin%frozen_ldens)

    SAFE_ALLOCATE_SOURCE_A(stout%node, stin%node)
    SAFE_ALLOCATE_SOURCE_A(stout%st_kpt_task, stin%st_kpt_task)

    ! ---- wavefunction blocks ----
    if (copy_wfns) call states_elec_group_copy(stin%d, stin%group, stout%group, special=special)

    POP_SUB(states_elec_copy)
  end subroutine states_elec_copy


  ! ---------------------------------------------------------
  !> @brief finalize the states_elec_t object
  !!
  !! Ends the sub-objects of st and releases its array components.
  !!
  !! NOTE(review): st\%d is finalized by states_elec_dim_end before
  !! states_elec_deallocate_wfns passes it to states_elec_group_end -- confirm that
  !! states_elec_group_end does not rely on anything released by states_elec_dim_end.
  !
  subroutine states_elec_end(st)
    type(states_elec_t), intent(inout) :: st

    PUSH_SUB(states_elec_end)

    call states_elec_dim_end(st%d)

    ! The model many-body arrays are only released if there are model particles;
    ! this has to be tested before modelmb_particles_end is called.
    if (st%modelmbparticles%nparticle > 0) then
      SAFE_DEALLOCATE_A(st%mmb_nspindown)
      SAFE_DEALLOCATE_A(st%mmb_iyoung)
      SAFE_DEALLOCATE_A(st%mmb_proj)
    end if
    call modelmb_particles_end(st%modelmbparticles)

    ! releases the wavefunction blocks through states_elec_group_end
    call states_elec_deallocate_wfns(st)

    ! ---- per-state data ----
    SAFE_DEALLOCATE_A(st%user_def_states)
    SAFE_DEALLOCATE_A(st%eigenval)
    SAFE_DEALLOCATE_A(st%occ)
    SAFE_DEALLOCATE_A(st%spin)
    SAFE_DEALLOCATE_A(st%kweights)

    ! ---- densities ----
    SAFE_DEALLOCATE_A(st%rho)
    SAFE_DEALLOCATE_A(st%rho_core)
    SAFE_DEALLOCATE_A(st%frozen_rho)
    SAFE_DEALLOCATE_A(st%frozen_tau)
    SAFE_DEALLOCATE_A(st%frozen_gdens)
    SAFE_DEALLOCATE_A(st%frozen_ldens)

    ! ---- currents ----
    SAFE_DEALLOCATE_A(st%current)
    SAFE_DEALLOCATE_A(st%current_para)
    SAFE_DEALLOCATE_A(st%current_dia)
    SAFE_DEALLOCATE_A(st%current_mag)
    SAFE_DEALLOCATE_A(st%current_kpt)

    ! ---- parallelization data ----
#ifdef HAVE_SCALAPACK
    call blacs_proc_grid_end(st%dom_st_proc_grid)
#endif

    call distributed_end(st%dist)

    SAFE_DEALLOCATE_A(st%node)
    SAFE_DEALLOCATE_A(st%st_kpt_task)

    if (st%parallel_in_states) then
      SAFE_DEALLOCATE_A(st%ap%schedule)
    end if

    POP_SUB(states_elec_end)
  end subroutine states_elec_end


  ! TODO(Alex) Issue #684. Abstract duplicate code in states_elec_generate_random, to get it to
  ! a point where it can be refactored.
  !> @brief randomize states
  subroutine states_elec_generate_random(st, mesh, kpoints, ist_start_, ist_end_, ikpt_start_, ikpt_end_, normalized)
    type(states_elec_t),    intent(inout) :: st          !< object to randomize
    class(mesh_t),          intent(in)    :: mesh        !< underlying mesh
    type(kpoints_t),        intent(in)    :: kpoints     !< kpoint list
    integer, optional,      intent(in)    :: ist_start_  !< optional start state index, default = 1
    integer, optional,      intent(in)    :: ist_end_    !< optional end state index, default = st\%nst
    integer, optional,      intent(in)    :: ikpt_start_ !< optional start kpoint index, default = 1
    integer, optional,      intent(in)    :: ikpt_end_   !< optional end kpoint index, default = st\%nik
    logical, optional,      intent(in)    :: normalized  !< optional flag whether the generated states should have norm 1,
    !!                                                      default = .true.

    integer :: ist, ik, id, ist_start, ist_end, jst, ikpt_start, ikpt_end
    complex(real64)   :: alpha, beta
    real(real64), allocatable :: dpsi(:,  :)
    complex(real64), allocatable :: zpsi(:,  :), zpsi2(:)
    integer :: ikpoint
    type(batch_t) :: ffb

    logical :: normalized_

    PUSH_SUB(states_elec_generate_random)

    normalized_ = optional_default(normalized, .true.)

    ist_start = optional_default(ist_start_, 1)
    ist_end = optional_default(ist_end_, st%nst)
    ikpt_start = optional_default(ikpt_start_, 1)
    ikpt_end = optional_default(ikpt_end_, st%nik)

    SAFE_ALLOCATE(dpsi(1:mesh%np, 1:st%d%dim))
    if (states_are_complex(st)) then
      SAFE_ALLOCATE(zpsi(1:mesh%np, 1:st%d%dim))
    end if

    ! In every branch below the random function is generated for ALL (ist, ik) in the requested
    ! range, and only afterwards the non-local ones are skipped: the wave functions have to be
    ! generated even if they are not used, in order to be consistent with the random seed in
    ! parallel runs.

    select case (st%d%ispin)
    case (UNPOLARIZED, SPIN_POLARIZED)

      do ik = ikpt_start, ikpt_end
        ikpoint = st%d%get_kpoint_index(ik)
        do ist = ist_start, ist_end
          if (states_are_real(st).or.kpoints_point_is_gamma(kpoints, ikpoint)) then
            ! Real random function (also used at the Gamma point for complex states)
            call fill_random_real(dpsi(:, 1), normalized)
            if (.not. state_kpt_is_local(st, ist, ik)) cycle
            if (states_are_complex(st)) then !Gamma point
              zpsi(1:mesh%np, 1) = cmplx(dpsi(1:mesh%np, 1), M_ZERO, real64)
              call states_elec_set_state(st, mesh, ist,  ik, zpsi)
            else
              call states_elec_set_state(st, mesh, ist,  ik, dpsi)
            end if
          else
            call fill_random_cplx(zpsi(:, 1), normalized)
            if (.not. state_kpt_is_local(st, ist, ik)) cycle
            call states_elec_set_state(st, mesh, ist,  ik, zpsi)
          end if
        end do
      end do

    case (SPINORS)

      ASSERT(states_are_complex(st))

      if (st%fixed_spins) then

        ! Work array for the orthogonalization against the previous states
        SAFE_ALLOCATE(zpsi2(1:mesh%np))

        do ik = ikpt_start, ikpt_end
          ikpoint = st%d%get_kpoint_index(ik)
          do ist = ist_start, ist_end
            ! Spatial part of the spinor, stored in the first component of zpsi
            if (kpoints_point_is_gamma(kpoints, ikpoint)) then
              call fill_random_real(dpsi(:, 1), normalized)
              zpsi(1:mesh%np, 1) = cmplx(dpsi(1:mesh%np, 1), M_ZERO, real64)
            else
              call fill_random_cplx(zpsi(:, 1), normalized)
            end if

            ! The locality check has to come before any call to states_elec_set_state:
            ! a state that is not stored on this process must never be written.
            ! (The spinor is stored only once, fully built, at the end of this iteration.)
            if (.not. state_kpt_is_local(st, ist, ik)) cycle

            ! In this case, the spinors are made of a spatial part times a vector [alpha beta]^T in
            ! spin space (i.e., same spatial part for each spin component). So (alpha, beta)
            ! determines the spin values. The values of (alpha, beta) can be obtained
            ! with simple formulae from <Sx>, <Sy>, <Sz>.
            !
            ! Note that here we orthonormalize the orbital part. This ensures that the spinors
            ! are untouched later in the general orthonormalization, and therefore the spin values
            ! of each spinor remain the same.
            ! NOTE(review): this reads the states jst < ist on this process; presumably they are
            ! all local here - confirm for runs that are parallel in states.
            do jst = ist_start, ist - 1
              call states_elec_get_state(st, mesh, 1, jst, ik, zpsi2)
              zpsi(1:mesh%np, 1) = zpsi(1:mesh%np, 1) - zmf_dotp(mesh, zpsi(:, 1), zpsi2)*zpsi2(1:mesh%np)
            end do

            call zmf_normalize(mesh, 1, zpsi)
            zpsi(1:mesh%np, 2) = zpsi(1:mesh%np, 1)

            alpha = cmplx(sqrt(M_HALF + st%spin(3, ist, ik)), M_ZERO, real64)
            beta  = cmplx(sqrt(M_ONE - abs(alpha)**2), M_ZERO, real64)
            if (abs(alpha) > M_ZERO) then
              beta = cmplx(st%spin(1, ist, ik) / abs(alpha), st%spin(2, ist, ik) / abs(alpha), real64)
            end if
            zpsi(1:mesh%np, 1) = alpha*zpsi(1:mesh%np, 1)
            zpsi(1:mesh%np, 2) = beta*zpsi(1:mesh%np, 2)
            call states_elec_set_state(st, mesh, ist,  ik, zpsi)
          end do
        end do

        SAFE_DEALLOCATE_A(zpsi2)

      else
        do ik = ikpt_start, ikpt_end
          do ist = ist_start, ist_end
            ! Each spinor component is drawn unnormalized; the full spinor is normalized below.
            do id = 1, st%d%dim
              call fill_random_cplx(zpsi(:, id), .false.)
            end do
            if (.not. state_kpt_is_local(st, ist, ik)) cycle
            ! The norm is imposed on the spinor as a whole, not on each spin channel:
            if (normalized_) call zmf_normalize(mesh, st%d%dim, zpsi)
            call states_elec_set_state(st, mesh, ist,  ik, zpsi)
          end do
        end do
      end if

    end select

    SAFE_DEALLOCATE_A(dpsi)
    SAFE_DEALLOCATE_A(zpsi)

    POP_SUB(states_elec_generate_random)

  contains

    !> Fill ff with a real random mesh function, following the randomization mode of st.
    subroutine fill_random_real(ff, norm)
      real(real64), contiguous, intent(inout) :: ff(:) !< mesh function to fill
      logical, optional,        intent(in)    :: norm  !< forwarded to dmf_random as "normalized"

      PUSH_SUB(states_elec_generate_random.fill_random_real)

      if (st%randomization == PAR_INDEPENDENT) then
        call dmf_random(mesh, ff, &
          pre_shift = mesh%pv%xlocal-1, &
          post_shift = mesh%pv%np_global - mesh%pv%xlocal - mesh%np + 1, &
          normalized = norm)
        ! Ensures that the grid points are properly distributed in the domain parallel case
        if (mesh%parallel_in_domains) then
          call batch_init(ffb, ff)
          call dmesh_batch_exchange_points(mesh, ffb, backward_map = .true.)
          call ffb%end()
        end if
      else
        call dmf_random(mesh, ff, normalized = norm)
      end if

      POP_SUB(states_elec_generate_random.fill_random_real)
    end subroutine fill_random_real

    !> Fill ff with a complex random mesh function, following the randomization mode of st.
    subroutine fill_random_cplx(ff, norm)
      complex(real64), contiguous, intent(inout) :: ff(:) !< mesh function to fill
      logical, optional,           intent(in)    :: norm  !< forwarded to zmf_random as "normalized"

      PUSH_SUB(states_elec_generate_random.fill_random_cplx)

      if (st%randomization == PAR_INDEPENDENT) then
        call zmf_random(mesh, ff, &
          pre_shift = mesh%pv%xlocal-1, &
          post_shift = mesh%pv%np_global - mesh%pv%xlocal - mesh%np + 1, &
          normalized = norm)
        ! Ensures that the grid points are properly distributed in the domain parallel case
        if (mesh%parallel_in_domains) then
          call batch_init(ffb, ff)
          call zmesh_batch_exchange_points(mesh, ffb, backward_map = .true.)
          call ffb%end()
        end if
      else
        call zmf_random(mesh, ff, normalized = norm)
      end if

      POP_SUB(states_elec_generate_random.fill_random_cplx)
    end subroutine fill_random_cplx

  end subroutine states_elec_generate_random

  ! ---------------------------------------------------------
  !> @brief calculate the Fermi level for the states in this object
  !
  subroutine states_elec_fermi(st, namespace, mesh, compute_spin)
    type(states_elec_t), intent(inout) :: st           !< states object; occupations (and spins) are updated
    type(namespace_t),   intent(in)    :: namespace    !< namespace used for warnings and errors
    class(mesh_t),       intent(in)    :: mesh         !< mesh on which the spinors are defined
    logical, optional,   intent(in)    :: compute_spin !< optional flag: compute spins for SPINOR case? default = .true.

    integer            :: jst, jk
    real(real64)       :: occ_charge
    logical            :: update_spin
    complex(real64), allocatable :: spinor(:, :)

    PUSH_SUB(states_elec_fermi)

    ! Determine the Fermi energy and fill the occupations with the smearing object
    call smear_find_fermi_energy(st%smear, namespace, st%eigenval, st%occ, st%qtot, &
      st%nik, st%nst, st%kweights)
    call smear_fill_occupations(st%smear, st%eigenval, st%occ, st%nik, st%nst)

    ! Sanity check: the k-weighted occupations have to add up to the total charge
    occ_charge = M_ZERO
    do jst = 1, st%nst
      occ_charge = occ_charge + sum(st%occ(jst, 1:st%nik) * st%kweights(1:st%nik))
    end do

    if (abs(occ_charge - st%qtot) > 1e-6_real64) then
      message(1) = 'Occupations do not integrate to total charge.'
      write(message(2), '(6x,f12.8,a,f12.8)') occ_charge, ' != ', st%qtot
      call messages_warning(2, namespace=namespace)
      ! A mismatch is only a warning, but a vanishing charge is fatal
      if (occ_charge < M_EPSILON) then
        message(1) = "There don't seem to be any electrons at all!"
        call messages_fatal(1, namespace=namespace)
      end if
    end if

    update_spin = (st%d%ispin == SPINORS) .and. optional_default(compute_spin,.true.)

    if (update_spin) then
      ASSERT(states_are_complex(st))

      ! Entries of non-local states stay zero, so that the reduction below collects
      ! the values computed by the other processes.
      st%spin(:,:,:) = M_ZERO

      SAFE_ALLOCATE(spinor(1:mesh%np, 1:st%d%dim))
      do jk = st%d%kpt%start, st%d%kpt%end
        do jst = st%st_start, st%st_end
          call states_elec_get_state(st, mesh, jst, jk, spinor)
          st%spin(1:3, jst, jk) = state_spin(mesh, spinor)
        end do
      end do
      SAFE_DEALLOCATE_A(spinor)

      if (st%parallel_in_states .or. st%d%kpt%parallel) then
        call comm_allreduce(st%st_kpt_mpi_grp, st%spin)
      end if

    end if

    POP_SUB(states_elec_fermi)
  end subroutine states_elec_fermi


  ! ---------------------------------------------------------
  !> function to calculate the eigenvalues sum using occupations as weights
  !
  function states_elec_eigenvalues_sum(st, alt_eig) result(tot)
    type(states_elec_t),  intent(in) :: st                                     !< the states object
    real(real64),      optional, intent(in) :: alt_eig(st%st_start:, st%d%kpt%start:) !< alternative eigenvalues;
    !<                                                      dimension (st\%st_start:st\%st_end, st\%d\%kpt\%start:st\%d\%kpt\%end)
    real(real64)                     :: tot                                    !< eigenvalue sum

    integer :: iq, first, last

    PUSH_SUB(states_elec_eigenvalues_sum)

    ! Range of states handled by this process
    first = st%st_start
    last  = st%st_end

    ! Local part of sum_k w_k sum_i occ(i,k) * eig(i,k); the source of the eigenvalues
    ! is selected once, outside of the k-point loop.
    tot = M_ZERO
    if (present(alt_eig)) then
      do iq = st%d%kpt%start, st%d%kpt%end
        tot = tot + st%kweights(iq) * sum(st%occ(first:last, iq) * &
          alt_eig(first:last, iq))
      end do
    else
      do iq = st%d%kpt%start, st%d%kpt%end
        tot = tot + st%kweights(iq) * sum(st%occ(first:last, iq) * &
          st%eigenval(first:last, iq))
      end do
    end if

    ! Collect the partial sums of the other processes in the states/k-points group
    if (st%parallel_in_states .or. st%d%kpt%parallel) then
      call comm_allreduce(st%st_kpt_mpi_grp, tot)
    end if

    POP_SUB(states_elec_eigenvalues_sum)
  end function states_elec_eigenvalues_sum


  !> @brief Distribute states over the processes for states parallelization
  subroutine states_elec_distribute_nodes(st, namespace, mc)
    type(states_elec_t),    intent(inout) :: st         !< states object receiving the distribution
    type(namespace_t),      intent(in)    :: namespace  !< namespace used for parsing and messages
    type(multicomm_t),      intent(in)    :: mc  !< MPI communicators

    logical :: scalapack_default

    PUSH_SUB(states_elec_distribute_nodes)

    ! TODO(Alex) Issue #820. This is superfluous. These defaults are set in initialisation of
    ! states, and in the state distribution instance
    ! Defaults: all states on this process, no states parallelization.
    st%parallel_in_states = .false.
    st%lnst               = st%nst
    st%st_start           = 1
    st%st_end             = st%nst
    st%node(:)            = 0

    ! MPI groups of the states object, built from the communicators of the multicomm object
    call mpi_grp_init(st%st_kpt_mpi_grp, mc%st_kpt_comm)
    call mpi_grp_init(st%dom_st_mpi_grp, mc%dom_st_comm)
    call mpi_grp_init(st%dom_st_kpt_mpi_grp, mc%dom_st_kpt_comm)
    call mpi_grp_init(st%mpi_grp, mc%group_comm(P_STRATEGY_STATES))

    scalapack_default = calc_mode_par%scalapack_compat() .and. .not. st%d%kpt%parallel

    !%Variable ScaLAPACKCompatible
    !%Type logical
    !%Section Execution::Parallelization
    !%Description
    !% Whether to use a layout for states parallelization which is compatible with ScaLAPACK.
    !% The default is yes for <tt>CalculationMode = gs, unocc, go</tt> without k-point parallelization,
    !% and no otherwise. (Setting to other than default is experimental.)
    !% The value must be yes if any ScaLAPACK routines are called in the course of the run;
    !% it must be set by hand for <tt>td</tt> with <tt>TDDynamics = bo</tt>.
    !% This variable has no effect unless you are using states parallelization and have linked ScaLAPACK.
    !% Note: currently, use of ScaLAPACK is not compatible with task parallelization (<i>i.e.</i> slaves).
    !%End
    call parse_variable(namespace, 'ScaLAPACKCompatible', scalapack_default, st%scalapack_compatible)

#ifdef HAVE_SCALAPACK
    ! A value different from the default is flagged as experimental
    if (st%scalapack_compatible .neqv. scalapack_default) then
      call messages_experimental('Setting ScaLAPACKCompatible to other than default', namespace=namespace)
    end if

    if (st%scalapack_compatible) then
      if (multicomm_have_slaves(mc)) then
        call messages_not_implemented("ScaLAPACK usage with task parallelization (slaves)", namespace=namespace)
      end if
      call blacs_proc_grid_init(st%dom_st_proc_grid, st%dom_st_mpi_grp)
    end if
#else
    ! Without ScaLAPACK the parsed value is discarded
    st%scalapack_compatible = .false.
#endif

    if (multicomm_strategy_is_parallel(mc, P_STRATEGY_STATES)) then

#ifdef HAVE_MPI
      call multicomm_create_all_pairs(st%mpi_grp, st%ap)
#endif

      ! Stop if there are fewer states than processes in the states group
      if (st%nst < st%mpi_grp%size) then
        message(1) = "Have more processors than necessary"
        write(message(2),'(i4,a,i4,a)') st%mpi_grp%size, " processors and ", st%nst, " states."
        call messages_fatal(2, namespace=namespace)
      end if

      call distributed_init(st%dist, st%nst, st%mpi_grp%comm, "states", scalapack_compat = st%scalapack_compatible)

      ! TODO(Alex) Issue #820. Remove lnst, st_start, st_end and node, as they are all contained within dist
      ! Mirror the distribution into the legacy members of the states object
      st%parallel_in_states = st%dist%parallel
      st%lnst               = st%dist%nlocal
      st%st_start           = st%dist%start
      st%st_end             = st%dist%end
      st%node(1:st%nst)     = st%dist%node(1:st%nst)

    end if

    call states_elec_kpoints_distribution(st)

    POP_SUB(states_elec_distribute_nodes)
  end subroutine states_elec_distribute_nodes


  !> @brief calculate selected quantities
  !!
  !! This function can calculate several quantities that depend on
  !! derivatives of the orbitals from the states and the density.
  !! The quantities to be calculated depend on the arguments passed.
  !
  subroutine states_elec_calc_quantities(gr, st, kpoints, nlcc, &
    kinetic_energy_density, paramagnetic_current, density_gradient, density_laplacian, &
    gi_kinetic_energy_density, st_end)
    type(grid_t),            intent(in)    :: gr         !< the underlying grid
    type(states_elec_t),     intent(in)    :: st         !< the states object
    type(kpoints_t),         intent(in)    :: kpoints    !< the kpoint list
    logical,                 intent(in)    :: nlcc       !< flag whether to use non-local core corrections
    real(real64), contiguous, optional, target, intent(out) :: kinetic_energy_density(:,:)       !< The kinetic energy density.
    real(real64), contiguous, optional, target, intent(out) :: paramagnetic_current(:,:,:)       !< The paramagnetic current.
    real(real64), contiguous, optional,         intent(out) :: density_gradient(:,:,:)           !< The gradient of the density.
    real(real64), contiguous, optional,         intent(out) :: density_laplacian(:,:)            !< The Laplacian of the density.
    real(real64), contiguous, optional,         intent(out) :: gi_kinetic_energy_density(:,:)    !< The gauge-invariant kinetic energy density.
    integer, optional,       intent(in)    :: st_end                            !< Maximum state used to compute the quantities

    real(real64), pointer, contiguous :: jp(:, :, :)
    real(real64), pointer, contiguous :: tau(:, :)
    complex(real64), allocatable :: wf_psi(:,:), gwf_psi(:,:,:), wf_psi_conj(:,:), lwf_psi(:,:)
    real(real64), allocatable :: abs_wf_psi(:,:), abs_gwf_psi(:,:)
    complex(real64), allocatable :: psi_gpsi(:,:)
    complex(real64)   :: c_tmp
    integer :: is, ik, ist, i_dim, st_dim, ii, st_end_
    real(real64)   :: ww, kpoint(gr%der%dim)
    logical :: something_to_do

    call profiling_in("STATES_CALC_QUANTITIES")

    PUSH_SUB(states_elec_calc_quantities)

    st_end_ = min(st%st_end, optional_default(st_end, st%st_end))

    something_to_do = present(kinetic_energy_density) .or. present(gi_kinetic_energy_density) .or. &
      present(paramagnetic_current) .or. present(density_gradient) .or. present(density_laplacian)
    ASSERT(something_to_do)

    SAFE_ALLOCATE( wf_psi(1:gr%np_part, 1:st%d%dim))
    SAFE_ALLOCATE( wf_psi_conj(1:gr%np_part, 1:st%d%dim))
    SAFE_ALLOCATE(gwf_psi(1:gr%np, 1:gr%der%dim, 1:st%d%dim))
    SAFE_ALLOCATE(abs_wf_psi(1:gr%np, 1:st%d%dim))
    SAFE_ALLOCATE(abs_gwf_psi(1:gr%np, 1:st%d%dim))
    SAFE_ALLOCATE(psi_gpsi(1:gr%np, 1:st%d%dim))
    if (present(density_laplacian)) then
      SAFE_ALLOCATE(lwf_psi(1:gr%np, 1:st%d%dim))
    end if

    nullify(tau)
    if (present(kinetic_energy_density)) tau => kinetic_energy_density

    nullify(jp)
    if (present(paramagnetic_current)) jp => paramagnetic_current

    ! for the gauge-invariant kinetic energy density we need the
    ! current and the kinetic energy density
    if (present(gi_kinetic_energy_density)) then
      if (.not. present(paramagnetic_current) .and. states_are_complex(st)) then
        SAFE_ALLOCATE(jp(1:gr%np, 1:gr%der%dim, 1:st%d%nspin))
      end if
      if (.not. present(kinetic_energy_density)) then
        SAFE_ALLOCATE(tau(1:gr%np, 1:st%d%nspin))
      end if
    end if

    if (associated(tau)) tau = M_ZERO
    if (associated(jp)) jp = M_ZERO
    if (present(density_gradient)) density_gradient(:,:,:) = M_ZERO
    if (present(density_laplacian)) density_laplacian(:,:) = M_ZERO
    if (present(gi_kinetic_energy_density)) gi_kinetic_energy_density = M_ZERO

    do ik = st%d%kpt%start, st%d%kpt%end

      kpoint(1:gr%der%dim) = kpoints%get_point(st%d%get_kpoint_index(ik))
      is = st%d%get_spin_index(ik)

      do ist = st%st_start, st_end_
        ww = st%kweights(ik)*st%occ(ist, ik)
        if (abs(ww) <= M_EPSILON) cycle

        ! all calculations will be done with complex wavefunctions
        call states_elec_get_state(st, gr, ist, ik, wf_psi)

        do st_dim = 1, st%d%dim
          call boundaries_set(gr%der%boundaries, gr, wf_psi(:, st_dim))
        end do

        ! calculate gradient of the wavefunction
        do st_dim = 1, st%d%dim
          call zderivatives_grad(gr%der, wf_psi(:,st_dim), gwf_psi(:,:,st_dim), set_bc = .false.)
        end do

        ! calculate the Laplacian of the wavefunction
        if (present(density_laplacian)) then
          do st_dim = 1, st%d%dim
            call zderivatives_lapl(gr%der, wf_psi(:,st_dim), lwf_psi(:,st_dim), ghost_update = .false., set_bc = .false.)
          end do
        end if

        ! We precompute some quantites, to avoid to compute it many times
        wf_psi_conj(1:gr%np, 1:st%d%dim) = conjg(wf_psi(1:gr%np,1:st%d%dim))
        abs_wf_psi(1:gr%np, 1:st%d%dim) = real(wf_psi_conj(1:gr%np, 1:st%d%dim) * wf_psi(1:gr%np, 1:st%d%dim), real64)

        if (present(density_laplacian)) then
          density_laplacian(1:gr%np, is) = density_laplacian(1:gr%np, is) + &
            ww * M_TWO*real(wf_psi_conj(1:gr%np, 1) * lwf_psi(1:gr%np, 1), real64)
          if (st%d%ispin == SPINORS) then
            density_laplacian(1:gr%np, 2) = density_laplacian(1:gr%np, 2) + &
              ww * M_TWO*real(wf_psi_conj(1:gr%np, 2) * lwf_psi(1:gr%np, 2), real64)
            !$omp parallel do private(c_tmp)
            do ii = 1, gr%np
              c_tmp = ww*(lwf_psi(ii, 1) * wf_psi_conj(ii, 2) + wf_psi(ii, 1) * conjg(lwf_psi(ii, 2)))
              density_laplacian(ii, 3) = density_laplacian(ii, 3) + real(c_tmp, real64)
              density_laplacian(ii, 4) = density_laplacian(ii, 4) + aimag(c_tmp)
            end do
          end if
        end if

        if (associated(tau)) then
          tau(1:gr%np, is) = tau(1:gr%np, is) &
            + ww * sum(kpoint(1:gr%der%dim)**2) * abs_wf_psi(1:gr%np, 1)
          if (st%d%ispin == SPINORS) then
            tau(1:gr%np, 2) = tau(1:gr%np, 2) &
              + ww * sum(kpoint(1:gr%der%dim)**2) * abs_wf_psi(1:gr%np, 2)

            !$omp parallel do private(c_tmp)
            do ii = 1, gr%np
              c_tmp = ww * sum(kpoint(1:gr%der%dim)**2) * wf_psi(ii, 1) * wf_psi_conj(ii, 2)
              tau(ii, 3) = tau(ii, 3) + real(c_tmp, real64)
              tau(ii, 4) = tau(ii, 4) + aimag(c_tmp)
            end do
          end if
        end if

        do i_dim = 1, gr%der%dim

          ! We precompute some quantites, to avoid to compute them many times
          psi_gpsi(1:gr%np, 1:st%d%dim) = wf_psi_conj(1:gr%np, 1:st%d%dim) &
            * gwf_psi(1:gr%np,i_dim,1:st%d%dim)
          abs_gwf_psi(1:gr%np, 1:st%d%dim) = real(conjg(gwf_psi(1:gr%np, i_dim, 1:st%d%dim)) &
            * gwf_psi(1:gr%np, i_dim, 1:st%d%dim), real64)

          if (present(density_gradient)) then
            density_gradient(1:gr%np, i_dim, is) = density_gradient(1:gr%np, i_dim, is) &
              + ww * M_TWO * real(psi_gpsi(1:gr%np, 1), real64)
            if (st%d%ispin == SPINORS) then
              density_gradient(1:gr%np, i_dim, 2) = density_gradient(1:gr%np, i_dim, 2)  &
                + ww * M_TWO*real(psi_gpsi(1:gr%np, 2), real64)
              !$omp parallel do private(c_tmp)
              do ii = 1, gr%np
                c_tmp = ww * (gwf_psi(ii, i_dim, 1) * wf_psi_conj(ii, 2) + wf_psi(ii, 1) * conjg(gwf_psi(ii, i_dim, 2)))
                density_gradient(ii, i_dim, 3) = density_gradient(ii, i_dim, 3) + real(c_tmp, real64)
                density_gradient(ii, i_dim, 4) = density_gradient(ii, i_dim, 4) + aimag(c_tmp)
              end do
            end if
          end if

          if (present(density_laplacian)) then
            call lalg_axpy(gr%np, ww*M_TWO, abs_gwf_psi(:,1), density_laplacian(:,is))
            if (st%d%ispin == SPINORS) then
              call lalg_axpy(gr%np, ww*M_TWO, abs_gwf_psi(:,2), density_laplacian(:,2))
              !$omp parallel do private(c_tmp)
              do ii = 1, gr%np
                c_tmp = M_TWO * ww * gwf_psi(ii, i_dim, 1) * conjg(gwf_psi(ii, i_dim, 2))
                density_laplacian(ii, 3) = density_laplacian(ii, 3) + real(c_tmp, real64)
                density_laplacian(ii, 4) = density_laplacian(ii, 4) + aimag(c_tmp)
              end do
            end if
          end if

          ! the expression for the paramagnetic current with spinors is
          !     j = ( jp(1)             jp(3) + i jp(4) )
          !         (-jp(3) + i jp(4)   jp(2)           )
          if (associated(jp)) then
            if (.not.(states_are_real(st))) then
              jp(1:gr%np, i_dim, is) = jp(1:gr%np, i_dim, is) + &
                ww*(aimag(psi_gpsi(1:gr%np, 1)) - abs_wf_psi(1:gr%np, 1)*kpoint(i_dim))
              if (st%d%ispin == SPINORS) then
                jp(1:gr%np, i_dim, 2) = jp(1:gr%np, i_dim, 2) + &
                  ww*( aimag(psi_gpsi(1:gr%np, 2)) - abs_wf_psi(1:gr%np, 2)*kpoint(i_dim))
                !$omp parallel do private(c_tmp)
                do ii = 1, gr%np
                  c_tmp = -ww*M_HALF*M_zI*(gwf_psi(ii, i_dim, 1)*wf_psi_conj(ii, 2) - wf_psi(ii, 1)*conjg(gwf_psi(ii, i_dim, 2)) &
                    - M_TWO * M_zI*wf_psi(ii, 1)*wf_psi_conj(ii, 2)*kpoint(i_dim))
                  jp(ii, i_dim, 3) = jp(ii, i_dim, 3) + real(c_tmp, real64)
                  jp(ii, i_dim, 4) = jp(ii, i_dim, 4) + aimag(c_tmp)
                end do
              end if
            end if
          end if

          ! the expression for the paramagnetic current with spinors is
          !     t = ( tau(1)              tau(3) + i tau(4) )
          !         ( tau(3) - i tau(4)   tau(2)            )
          if (associated(tau)) then
            tau(1:gr%np, is) = tau(1:gr%np, is) + ww*(abs_gwf_psi(1:gr%np,1) &
              - M_TWO*aimag(psi_gpsi(1:gr%np, 1))*kpoint(i_dim))
            if (st%d%ispin == SPINORS) then
              tau(1:gr%np, 2) = tau(1:gr%np, 2) + ww*(abs_gwf_psi(1:gr%np, 2) &
                - M_TWO*aimag(psi_gpsi(1:gr%np, 2))*kpoint(i_dim))
              !$omp parallel do private(c_tmp)
              do ii = 1, gr%np
                c_tmp = ww * ( gwf_psi(ii, i_dim, 1)*conjg(gwf_psi(ii, i_dim, 2))  &
                  + M_zI * (gwf_psi(ii,i_dim,1)*wf_psi_conj(ii, 2)    &
                  - wf_psi(ii, 1)*conjg(gwf_psi(ii,i_dim,2)))*kpoint(i_dim))
                tau(ii, 3) = tau(ii, 3) + real(c_tmp, real64)
                tau(ii, 4) = tau(ii, 4) + aimag(c_tmp)
              end do
            end if
          end if

        end do

      end do
    end do

    SAFE_DEALLOCATE_A(wf_psi_conj)
    SAFE_DEALLOCATE_A(abs_wf_psi)
    SAFE_DEALLOCATE_A(abs_gwf_psi)
    SAFE_DEALLOCATE_A(psi_gpsi)

    if (.not. present(gi_kinetic_energy_density)) then
      if (.not. present(paramagnetic_current)) then
        SAFE_DEALLOCATE_P(jp)
      end if
      if (.not. present(kinetic_energy_density)) then
        SAFE_DEALLOCATE_P(tau)
      end if
    end if

    if (st%parallel_in_states .or. st%d%kpt%parallel) call reduce_all(st%st_kpt_mpi_grp)

    ! We have to symmetrize everything as they are calculated from the
    ! wavefunctions.
    ! This must be done before compute the gauge-invariant kinetic energy density
    if (st%symmetrize_density) then
      do is = 1, st%d%nspin
        if (associated(tau)) then
          call dgrid_symmetrize_scalar_field(gr, tau(:, is), suppress_warning = .true.)
        end if

        if (present(density_laplacian)) then
          call dgrid_symmetrize_scalar_field(gr, density_laplacian(:, is), suppress_warning = .true.)
        end if

        if (associated(jp)) then
          call dgrid_symmetrize_vector_field(gr, jp(:, :, is), suppress_warning = .true.)
        end if

        if (present(density_gradient)) then
          call dgrid_symmetrize_vector_field(gr, density_gradient(:, :, is), suppress_warning = .true.)
        end if
      end do
    end if


    if (allocated(st%rho_core) .and. nlcc .and. (present(density_laplacian) .or. present(density_gradient))) then
      do ii = 1, gr%np
        wf_psi(ii, 1) = st%rho_core(ii)/st%d%spin_channels
      end do

      call boundaries_set(gr%der%boundaries, gr, wf_psi(:, 1))

      if (present(density_gradient)) then
        ! calculate gradient of the NLCC
        call zderivatives_grad(gr%der, wf_psi(:,1), gwf_psi(:,:,1), set_bc = .false.)
        do is = 1, st%d%spin_channels
          density_gradient(1:gr%np, 1:gr%der%dim, is) = density_gradient(1:gr%np, 1:gr%der%dim, is) + &
            real(gwf_psi(1:gr%np, 1:gr%der%dim,1))
        end do
      end if

      ! calculate the Laplacian of the wavefunction
      if (present(density_laplacian)) then
        call zderivatives_lapl(gr%der, wf_psi(:,1), lwf_psi(:,1), set_bc = .false.)

        do is = 1, st%d%spin_channels
          density_laplacian(1:gr%np, is) = density_laplacian(1:gr%np, is) + real(lwf_psi(1:gr%np, 1))
        end do
      end if
    end if

    !If we freeze some of the orbitals, we need to had the contributions here
    !Only in the case we are not computing it
    if (allocated(st%frozen_tau) .and. .not. present(st_end)) then
      call lalg_axpy(gr%np, st%d%nspin, M_ONE, st%frozen_tau, tau)
    end if
    if (allocated(st%frozen_gdens) .and. .not. present(st_end)) then
      do is = 1, st%d%nspin
        call lalg_axpy(gr%np, gr%der%dim, M_ONE, st%frozen_gdens(:,:,is), density_gradient(:,:,is))
      end do
    end if
    if (allocated(st%frozen_tau) .and. .not. present(st_end)) then
      call lalg_axpy(gr%np, st%d%nspin, M_ONE, st%frozen_ldens, density_laplacian)
    end if

    SAFE_DEALLOCATE_A(wf_psi)
    SAFE_DEALLOCATE_A(gwf_psi)
    SAFE_DEALLOCATE_A(lwf_psi)


    !We compute the gauge-invariant kinetic energy density
    if (present(gi_kinetic_energy_density)) then
      do is = 1, st%d%nspin
        ASSERT(associated(tau))
        gi_kinetic_energy_density(1:gr%np, is) = tau(1:gr%np, is)
      end do
      if (states_are_complex(st)) then
        ASSERT(associated(jp))
        if (st%d%ispin /= SPINORS) then
          do is = 1, st%d%nspin
            !$omp parallel do
            do ii = 1, gr%np
              if (st%rho(ii, is) < 1.0e-7_real64) cycle
              gi_kinetic_energy_density(ii, is) = &
                gi_kinetic_energy_density(ii, is) - sum(jp(ii,1:gr%der%dim, is)**2)/st%rho(ii, is)
            end do
          end do
        else ! Note that this is only the U(1) part of the gauge-invariant KED
          !$omp parallel do
          do ii = 1, gr%np
            gi_kinetic_energy_density(ii, 1) = gi_kinetic_energy_density(ii, 1) &
              - sum(jp(ii,1:gr%der%dim, 1)**2)/(SAFE_TOL(st%rho(ii, 1),M_EPSILON))
            gi_kinetic_energy_density(ii, 2) = gi_kinetic_energy_density(ii, 2) &
              - sum(jp(ii,1:gr%der%dim, 2)**2)/(SAFE_TOL(st%rho(ii, 2),M_EPSILON))
            gi_kinetic_energy_density(ii, 3) = &
              gi_kinetic_energy_density(ii, 3) - sum(jp(ii,1:gr%der%dim, 3)**2 + jp(ii,1:gr%der%dim, 4)**2) &
              /(SAFE_TOL((st%rho(ii, 3)**2 + st%rho(ii, 4)**2), M_EPSILON))*st%rho(ii, 3)
            gi_kinetic_energy_density(ii, 4) = &
              gi_kinetic_energy_density(ii, 4) + sum(jp(ii,1:gr%der%dim, 3)**2 + jp(ii,1:gr%der%dim, 4)**2) &
              /(SAFE_TOL((st%rho(ii, 3)**2 + st%rho(ii, 4)**2), M_EPSILON))*st%rho(ii, 4)
          end do
        end if
      end if
    end if

    if (.not. present(kinetic_energy_density)) then
      SAFE_DEALLOCATE_P(tau)
    end if
    if (.not. present(paramagnetic_current)) then
      SAFE_DEALLOCATE_P(jp)
    end if


    POP_SUB(states_elec_calc_quantities)

    call profiling_out("STATES_CALC_QUANTITIES")

  contains

    subroutine reduce_all(grp)
      type(mpi_grp_t), intent(in)  :: grp !< group passed to every comm_allreduce call below

      PUSH_SUB(states_elec_calc_quantities.reduce_all)

      ! Host-associated quantities are only reduced when they exist:
      ! tau and jp are tested with associated(), the optional host arguments with present().

      ! quantities reduced with extents (gr%np, st%d%nspin)
      if (associated(tau)) then
        call comm_allreduce(grp, tau, dim = [gr%np, st%d%nspin])
      end if

      if (present(density_laplacian)) then
        call comm_allreduce(grp, density_laplacian, dim = [gr%np, st%d%nspin])
      end if

      ! quantities with an extra direction index are reduced one spin component at a time,
      ! each slice with extents (gr%np, gr%der%dim)
      do is = 1, st%d%nspin
        if (associated(jp)) then
          call comm_allreduce(grp, jp(:, :, is), dim = [gr%np, gr%der%dim])
        end if

        if (present(density_gradient)) call comm_allreduce(grp, density_gradient(:, :, is), dim = [gr%np, gr%der%dim])
      end do

      POP_SUB(states_elec_calc_quantities.reduce_all)
    end subroutine reduce_all

  end subroutine states_elec_calc_quantities


  ! ---------------------------------------------------------
  !> calculate the spin vector for a spinor wave function f1
  !
  function state_spin(mesh, f1) result(spin)
    type(mesh_t),    intent(in) :: mesh      !< underlying mesh
    complex(real64), intent(in) :: f1(:, :)  !< spinor wave function; dimensions (1:mesh\%np, 1:2)
    real(real64)                :: spin(1:3) !< spin vector

    complex(real64) :: overlap
    real(real64)    :: norm_first, norm_second

    PUSH_SUB(state_spin)

    ! mesh dot product of the two spinor components and the mesh norm of each of them
    overlap     = zmf_dotp(mesh, f1(:, 1), f1(:, 2))
    norm_first  = zmf_nrm2(mesh, f1(:, 1))
    norm_second = zmf_nrm2(mesh, f1(:, 2))

    ! The bracket holds the three sigma-matrix expectation values;
    ! the spin is half of them, hence the M_HALF prefactor.
    spin = M_HALF*[M_TWO*real(overlap, real64), M_TWO*aimag(overlap), norm_first**2 - norm_second**2]

    POP_SUB(state_spin)
  end function state_spin

  ! ---------------------------------------------------------
  !> check whether a given state (ist) is on the local node
  !
  logical function state_is_local(st, ist)
    type(states_elec_t), intent(in) :: st  !< states object providing st_start and st_end
    integer,             intent(in) :: ist !< index of the state to test

    PUSH_SUB(state_is_local)

    ! the state is local unless it falls outside the closed range [st_start, st_end]
    state_is_local = .not. (ist < st%st_start .or. ist > st%st_end)

    POP_SUB(state_is_local)
  end function state_is_local

  ! ---------------------------------------------------------
  !> check whether a given state (ist, ik) is on the local node
  !
  logical function state_kpt_is_local(st, ist, ik)
    type(states_elec_t), intent(in) :: st  !< states object providing the local state and k-point ranges
    integer,             intent(in) :: ist !< index of the state to test
    integer,             intent(in) :: ik  !< index of the k-point to test

    logical :: state_in_range, kpt_in_range

    PUSH_SUB(state_kpt_is_local)

    ! both indices have to lie inside their respective closed local ranges
    state_in_range = st%st_start <= ist .and. ist <= st%st_end
    kpt_in_range   = st%d%kpt%start <= ik .and. ik <= st%d%kpt%end

    state_kpt_is_local = state_in_range .and. kpt_in_range

    POP_SUB(state_kpt_is_local)
  end function state_kpt_is_local


  ! ---------------------------------------------------------
  !> return the memory usage of a states_elec_t object
  real(real64) function states_elec_wfns_memory(st, mesh) result(memory)
    type(states_elec_t), intent(in) :: st   !< states object (uses d%dim, nst and d%kpt%nglobal)
    class(mesh_t),       intent(in) :: mesh !< mesh (uses np_part_global)

    real(real64) :: bytes_per_value

    PUSH_SUB(states_elec_wfns_memory)

    ! Size in bytes of one real(real64) value. The kind parameter real64 itself must not be
    ! used as a byte count: kind numbers are processor-dependent and are not guaranteed to
    ! equal the storage size in bytes. storage_size() returns the size in bits.
    bytes_per_value = real(storage_size(1.0_real64)/8, real64)

    ! orbitals: one value per global mesh point (np_part_global), spinor dimension, state and k-point
    ! NOTE(review): the estimate counts one real(real64) per coefficient; complex wave functions
    ! would presumably need twice as much - confirm whether callers account for that.
    memory = bytes_per_value*real(mesh%np_part_global, real64)*st%d%dim*real(st%nst, real64)*st%d%kpt%nglobal

    POP_SUB(states_elec_wfns_memory)
  end function states_elec_wfns_memory

  ! ---------------------------------------------------------
  !> pack the batches in this states object
  !
  subroutine states_elec_pack(st, copy)
    class(states_elec_t),    intent(inout) :: st   !< states whose batches are packed
    logical,      optional, intent(in)    :: copy !< forwarded unchanged to do_pack of every batch

    ! Number of bytes in one mebibyte. The int64 kind constant is used instead of a
    ! hard-coded kind number (such as _8), which is processor-dependent.
    integer(int64), parameter :: bytes_per_mib = 1024_int64**2

    integer :: iqn, ib
    integer(int64) :: max_mem, mem

    PUSH_SUB(states_elec_pack)

    ! nothing to do, already packed
    if (st%packed) then
      POP_SUB(states_elec_pack)
      return
    end if

    st%packed = .true.

    ! Determine the memory budget for the packed states.
    if (accel_is_enabled()) then
      max_mem = accel_global_memory_size()

      if (st%cl_states_mem > M_ONE) then
        ! value above one: taken as the budget itself, in units of 1024**2 bytes
        max_mem = int(st%cl_states_mem, int64)*bytes_per_mib
      else if (st%cl_states_mem < 0.0_real64) then
        ! negative value: that many units of 1024**2 bytes are subtracted from the device memory size
        max_mem = max_mem + int(st%cl_states_mem, int64)*bytes_per_mib
      else
        ! value in [0, 1]: fraction of the device memory size
        max_mem = int(st%cl_states_mem*real(max_mem, real64) , int64)
      end if
    else
      ! no accelerator: no limit is imposed
      max_mem = huge(max_mem)
    end if

    ! Pack batch after batch until the accumulated pack size exceeds the budget.
    mem = 0
    qnloop: do iqn = st%d%kpt%start, st%d%kpt%end
      do ib = st%group%block_start, st%group%block_end

        mem = mem + st%group%psib(ib, iqn)%pack_total_size()

        if (mem > max_mem) then
          call messages_write('Not enough CL device memory to store all states simultaneously.', new_line = .true.)
          call messages_write('Only ')
          call messages_write(ib - st%group%block_start)
          call messages_write(' of ')
          call messages_write(st%group%block_end - st%group%block_start + 1)
          call messages_write(' blocks will be stored in device memory.', new_line = .true.)
          call messages_warning()
          ! the remaining batches (including those of later k-points) are left unpacked
          exit qnloop
        end if

        call st%group%psib(ib, iqn)%do_pack(copy)
      end do
    end do qnloop

    POP_SUB(states_elec_pack)
  end subroutine states_elec_pack

  ! ------------------------------------------------------------
  !> unpack the batches in this states object
  !
  subroutine states_elec_unpack(st, copy)
    class(states_elec_t),    intent(inout) :: st   !< states whose batches are unpacked
    logical,      optional, intent(in)    :: copy !< forwarded unchanged to do_unpack of every batch

    integer :: ik, iblock

    PUSH_SUB(states_elec_unpack)

    ! nothing to do, the states are not packed
    if (.not. st%packed) then
      POP_SUB(states_elec_unpack)
      return
    end if

    st%packed = .false.

    do ik = st%d%kpt%start, st%d%kpt%end
      do iblock = st%group%block_start, st%group%block_end
        ! only batches that report being packed are unpacked
        if (.not. st%group%psib(iblock, ik)%is_packed()) cycle
        call st%group%psib(iblock, ik)%do_unpack(copy)
      end do
    end do

    POP_SUB(states_elec_unpack)
  end subroutine states_elec_unpack

  ! -----------------------------------------------------------
  !> write information about the states object
  !
  subroutine states_elec_write_info(st, namespace)
    class(states_elec_t),    intent(in) :: st        !< states object whose qtot, nst and block_size are reported
    type(namespace_t),       intent(in) :: namespace !< namespace passed to the message routines

    PUSH_SUB(states_elec_write_info)

    ! opening emphasis line carrying the section title
    call messages_print_with_emphasis(msg="States", namespace=namespace)

    ! three-line summary: total charge, number of states and block size
    write(message(1), '(a,f12.3)') 'Total electronic charge  = ', st%qtot
    write(message(2), '(a,i8)')    'Number of states         = ', st%nst
    write(message(3), '(a,i8)')    'States block-size        = ', st%block_size
    call messages_info(3, namespace=namespace)

    ! closing emphasis line (no title)
    call messages_print_with_emphasis(namespace=namespace)

    POP_SUB(states_elec_write_info)
  end subroutine states_elec_write_info

  ! ------------------------------------------------------------
  !> explicitly set all wave functions in the states to zero
  !
  subroutine states_elec_set_zero(st)
    class(states_elec_t),    intent(inout) :: st !< states whose local batches are zeroed

    integer :: ik, iblock

    PUSH_SUB(states_elec_set_zero)

    ! visit every batch in the local k-point and block ranges
    do ik = st%d%kpt%start, st%d%kpt%end
      do iblock = st%group%block_start, st%group%block_end
        call batch_set_zero(st%group%psib(iblock, ik))
      end do
    end do

    POP_SUB(states_elec_set_zero)
  end subroutine states_elec_set_zero

  ! ------------------------------------------------------------
  !> return index of first state in block ib
  !
  pure integer function states_elec_block_min(st, ib) result(first_state)
    type(states_elec_t),    intent(in) :: st !< states object holding the block ranges
    integer,                intent(in) :: ib !< block index

    ! column 1 of block_range is read for the first state of the block
    first_state = st%group%block_range(ib, 1)
  end function states_elec_block_min

  ! ------------------------------------------------------------
  !> return index of last state in block ib
  !
  pure integer function states_elec_block_max(st, ib) result(last_state)
    type(states_elec_t),    intent(in) :: st !< states object holding the block ranges
    integer,                intent(in) :: ib !< block index

    ! column 2 of block_range is read for the last state of the block
    last_state = st%group%block_range(ib, 2)
  end function states_elec_block_max

  ! ------------------------------------------------------------
  !> return number of states in block ib
  !
  pure integer function states_elec_block_size(st, ib) result(nstates)
    type(states_elec_t),    intent(in) :: st !< states object holding the block sizes
    integer,                intent(in) :: ib !< block index

    ! the result is named nstates so that the intrinsic size() is not shadowed
    nstates = st%group%block_size(ib)
  end function states_elec_block_size

  ! ---------------------------------------------------------
  !> number of occupied-unoccupied pairs for Casida
  !
  subroutine states_elec_count_pairs(st, namespace, n_pairs, n_occ, n_unocc, is_included, is_frac_occ)
    type(states_elec_t),  intent(in)  :: st
    type(namespace_t),    intent(in)  :: namespace
    integer,              intent(out) :: n_pairs    !< result: number of pairs
    integer,              intent(out) :: n_occ(:)   !< result: number of occ. states per k-point; dimension (1:nik)
    integer,              intent(out) :: n_unocc(:) !< result: number of unocc. states per k-point; dimension (1:nik)
    logical, allocatable, intent(out) :: is_included(:,:,:) !< result: mask whether a pair (occ, unocc, k-point) is included;
    !!                                                         allocated as (1:maxval(n_occ), minval(n_occ)+1:st\%nst, 1:st\%nik)
    logical,              intent(out) :: is_frac_occ !< result: are there fractional occupations?

    integer :: ik, iocc, iunocc, n_filled, n_partially_filled, n_half_filled
    character(len=80) :: nst_label, default_list, wfn_list
    real(real64) :: energy_window

    PUSH_SUB(states_elec_count_pairs)

    ! Classify the orbitals of every k-point. Partially and half-filled orbitals enter
    ! both counts: n_occ includes them and n_unocc only excludes the completely filled ones.
    is_frac_occ = .false.
    do ik = 1, st%nik
      call occupied_states(st, namespace, ik, n_filled, n_partially_filled, n_half_filled)
      is_frac_occ = is_frac_occ .or. n_partially_filled > 0 .or. n_half_filled > 0
      n_occ(ik) = n_filled + n_partially_filled + n_half_filled
      n_unocc(ik) = st%nst - n_filled
    end do

    !%Variable CasidaKSEnergyWindow
    !%Type float
    !%Section Linear Response::Casida
    !%Description
    !% An alternative to <tt>CasidaKohnShamStates</tt> for specifying which occupied-unoccupied
    !% transitions will be used: all those whose eigenvalue differences are less than this
    !% number will be included. If a value less than 0 is supplied, this criterion will not be used.
    !%End

    call parse_variable(namespace, 'CasidaKSEnergyWindow', -M_ONE, energy_window, units_inp%energy)

    !%Variable CasidaKohnShamStates
    !%Type string
    !%Section Linear Response::Casida
    !%Default all states
    !%Description
    !% The calculation of the excitation spectrum of a system in the Casida frequency-domain
    !% formulation of linear-response time-dependent density functional theory (TDDFT)
    !% implies the use of a basis set of occupied/unoccupied Kohn-Sham orbitals. This
    !% basis set should, in principle, include all pairs formed by all occupied states,
    !% and an infinite number of unoccupied states. In practice, one has to truncate this
    !% basis set, selecting a number of occupied and unoccupied states that will form the
    !% pairs. These states are specified with this variable. If there are, say, 15 occupied
    !% states, and one sets this variable to the value "10-18", this means that occupied
    !% states from 10 to 15, and unoccupied states from 16 to 18 will be considered.
    !%
    !% This variable is a string in list form, <i>i.e.</i> expressions such as "1,2-5,8-15" are
    !% valid. You should include a non-zero number of unoccupied states and a non-zero number
    !% of occupied states.
    !%End

    ! the mask starts out empty; both selection modes below switch on the pairs they accept
    n_pairs = 0
    SAFE_ALLOCATE(is_included(maxval(n_occ), minval(n_occ) + 1:st%nst , st%nik))
    is_included = .false.

    if (energy_window < M_ZERO) then
      ! selection through the state list CasidaKohnShamStates; the default is "1-<nst>"
      write(nst_label,'(i6)') st%nst
      write(default_list,'(a,a)') "1-", trim(adjustl(nst_label))
      call parse_variable(namespace, 'CasidaKohnShamStates', default_list, wfn_list)

      write(message(1),'(a,a)') "Info: States that form the basis: ", trim(wfn_list)
      call messages_info(1, namespace=namespace)

      ! a pair is counted when both of its states appear in the list
      do ik = 1, st%nik
        do iunocc = n_occ(ik) + 1, st%nst
          if (.not. loct_isinstringlist(iunocc, wfn_list)) cycle
          do iocc = 1, n_occ(ik)
            if (.not. loct_isinstringlist(iocc, wfn_list)) cycle
            n_pairs = n_pairs + 1
            is_included(iocc, iunocc, ik) = .true.
          end do
        end do
      end do

    else ! using CasidaKSEnergyWindow

      write(message(1),'(a,f12.6,a)') "Info: including transitions with energy < ", &
        units_from_atomic(units_out%energy, energy_window), trim(units_abbrev(units_out%energy))
      call messages_info(1, namespace=namespace)

      ! a pair is counted when its eigenvalue difference is below the window
      do ik = 1, st%nik
        do iunocc = n_occ(ik) + 1, st%nst
          do iocc = 1, n_occ(ik)
            if (st%eigenval(iunocc, ik) - st%eigenval(iocc, ik) < energy_window) then
              n_pairs = n_pairs + 1
              is_included(iocc, iunocc, ik) = .true.
            end if
          end do
        end do
      end do

    end if

    POP_SUB(states_elec_count_pairs)
  end subroutine states_elec_count_pairs


  !> @brief return information about occupied orbitals in many-body state
  !!
  !! Returns information about which single-particle orbitals are
  !! occupied or not in a _many-particle_ state st:
  !!   - n_filled are the number of orbitals that are totally filled
  !!            (the occupation number is two, if ispin = UNPOLARIZED,
  !!            or it is one in the other cases).
  !!   - n_half_filled is only meaningful if ispin = UNPOLARIZED. It
  !!            is the number of orbitals where there is only one
  !!            electron in the orbital.
  !!   - n_partially_filled is the number of orbitals that are neither filled,
  !!            half-filled, nor empty.
  !! The integer arrays filled, partially_filled and half_filled point
  !!   to the indices where the filled, partially filled and half_filled
  !!   orbitals are, respectively.
  !
  subroutine occupied_states(st, namespace, ik, n_filled, n_partially_filled, n_half_filled, &
    filled, partially_filled, half_filled)
    type(states_elec_t),    intent(in)  :: st        !< states object providing occ, nst and d%ispin
    type(namespace_t),      intent(in)  :: namespace !< namespace used for the fatal error message
    integer,                intent(in)  :: ik        !< k-point index (second index of st%occ)
    integer,                intent(out) :: n_filled, n_partially_filled, n_half_filled
    integer,      optional, intent(out) :: filled(:), partially_filled(:), half_filled(:)

    integer :: ist
    logical :: unpolarized, is_full, is_half

    PUSH_SUB(occupied_states)

    ! index lists are zero-filled; entries beyond the respective counter stay zero
    if (present(filled))           filled(:) = 0
    if (present(partially_filled)) partially_filled(:) = 0
    if (present(half_filled))      half_filled(:) = 0
    n_filled = 0
    n_partially_filled = 0
    n_half_filled = 0

    ! Full occupation is M_TWO in the unpolarized case and M_ONE otherwise;
    ! half-filled orbitals are only distinguished in the unpolarized case.
    select case (st%d%ispin)
    case (UNPOLARIZED)
      unpolarized = .true.
    case (SPIN_POLARIZED, SPINORS)
      unpolarized = .false.
    case default
      ! any other spin mode: nothing is classified
      POP_SUB(occupied_states)
      return
    end select

    do ist = 1, st%nst
      associate (occ => st%occ(ist, ik))
        if (unpolarized) then
          is_full = abs(occ - M_TWO) < M_MIN_OCC
          is_half = abs(occ - M_ONE) < M_MIN_OCC
        else
          is_full = abs(occ - M_ONE) < M_MIN_OCC
          is_half = .false.
        end if

        if (is_full) then
          n_filled = n_filled + 1
          if (present(filled)) filled(n_filled) = ist
        else if (is_half) then
          n_half_filled = n_half_filled + 1
          if (present(half_filled)) half_filled(n_half_filled) = ist
        else if (occ > M_MIN_OCC) then
          n_partially_filled = n_partially_filled + 1
          if (present(partially_filled)) partially_filled(n_partially_filled) = ist
        else if (abs(occ) > M_MIN_OCC) then
          ! only reached for occupations below -M_MIN_OCC
          write(message(1),*) 'Internal error in occupied_states: Illegal occupation value ', occ
          call messages_fatal(1, namespace=namespace)
        end if
      end associate
    end do

    POP_SUB(occupied_states)
  end subroutine occupied_states


  !> @brief distribute k-points over the nodes in the corresponding communicator
  subroutine kpoints_distribute(this, mc)
    type(states_elec_t), intent(inout) :: this !< states object whose d%kpt distribution is initialized
    type(multicomm_t),       intent(in)    :: mc   !< provides the communicator of the k-point strategy

    PUSH_SUB(kpoints_distribute)

    ! distribute the this%nik k-points over the communicator stored for P_STRATEGY_KPOINTS
    call distributed_init(this%d%kpt, this%nik, &
      mc%group_comm(P_STRATEGY_KPOINTS), "k-points")

    POP_SUB(kpoints_distribute)
  end subroutine kpoints_distribute


  ! TODO(Alex) Issue #824. Consider converting this to a function that returns `st_kpt_task`,
  ! as this is only called in a couple of places, or package with the `st_kpt_mpi_grp` when split
  ! from st instance
  !> @brief Assign the start and end indices for states and kpoints, for "st_kpt_mpi_grp" communicator.
  subroutine states_elec_kpoints_distribution(st)
    type(states_elec_t),    intent(inout) :: st !< states object whose st_kpt_task table is filled

    integer :: my_rank, last_rank

    PUSH_SUB(states_elec_kpoints_distribution)

    my_rank   = st%st_kpt_mpi_grp%rank
    last_rank = st%st_kpt_mpi_grp%size - 1

    ! the table has one row per rank of st_kpt_mpi_grp (ranks start at 0) and four columns
    if (.not. allocated(st%st_kpt_task)) then
      SAFE_ALLOCATE(st%st_kpt_task(0:last_rank, 1:4))
    end if

    ! every rank fills in only its own row; all other rows are set to zero
    st%st_kpt_task(0:last_rank, :) = 0
    st%st_kpt_task(my_rank, 1) = st%st_start
    st%st_kpt_task(my_rank, 2) = st%st_end
    st%st_kpt_task(my_rank, 3) = st%d%kpt%start
    st%st_kpt_task(my_rank, 4) = st%d%kpt%end

    ! the reduction over the group is only performed when states or k-points are parallelized
    if (st%parallel_in_states .or. st%d%kpt%parallel) then
      call comm_allreduce(st%st_kpt_mpi_grp, st%st_kpt_task)
    end if

    POP_SUB(states_elec_kpoints_distribution)
  end subroutine states_elec_kpoints_distribution

  ! ---------------------------------------------------------
  !> @brief double up k-points for SPIN_POLARIZED calculations
  !
  subroutine states_elec_choose_kpoints(st, kpoints, namespace)
    type(states_elec_t), target, intent(inout) :: st        !< states object; nik and kweights are set here
    type(kpoints_t),             intent(in)    :: kpoints   !< k-point set providing the count and the weights
    type(namespace_t),           intent(in)    :: namespace !< namespace used for the debug output

    integer :: iq

    PUSH_SUB(states_elec_choose_kpoints)

    ! in the spin-polarized case the number of k-point indices is doubled
    st%nik = kpoints_number(kpoints)
    if (st%d%ispin == SPIN_POLARIZED) then
      st%nik = 2*st%nik
    end if

    SAFE_ALLOCATE(st%kweights(1:st%nik))

    ! each index iq takes the weight of the k-point it maps to through get_kpoint_index
    do iq = 1, st%nik
      st%kweights(iq) = kpoints%get_weight(st%d%get_kpoint_index(iq))
    end do

    if (debug%info) then
      call print_kpoints_debug()
    end if

    POP_SUB(states_elec_choose_kpoints)

  contains

    ! write the k-point information to the file debug/kpoints
    subroutine print_kpoints_debug
      integer :: unit_kpoints

      PUSH_SUB(states_elec_choose_kpoints.print_kpoints_debug)

      call io_mkdir('debug/', namespace)

      unit_kpoints = io_open('debug/kpoints', namespace, action = 'write')
      call kpoints%write_info(iunit=unit_kpoints, absolute_coordinates = .true.)
      call io_close(unit_kpoints)

      POP_SUB(states_elec_choose_kpoints.print_kpoints_debug)
    end subroutine print_kpoints_debug

  end subroutine states_elec_choose_kpoints

  !> @brief calculate the expectation value of the dipole moment of electrons
  !!
  !! @note this routine is only meaningful for isolated systems.
  !! For periodic systems, we need to consider the Berry phase.
  !
  function states_elec_calculate_dipole(this, gr) result(dipole)
    class(states_elec_t), intent(in) :: this                !< states object providing rho and d%spin_channels
    class(mesh_t),        intent(in) :: gr                  !< mesh on which the density is defined
    real(real64)                     :: dipole(1:gr%box%dim) !< one dipole component per box dimension

    integer :: ichannel
    real(real64) :: channel_dipole(1:gr%box%dim, this%d%spin_channels)

    PUSH_SUB(states_elec_calculate_dipole)

    ! dipole of the density of each spin channel, one column per channel
    do ichannel = 1, this%d%spin_channels
      call dmf_dipole(gr, this%rho(:, ichannel), channel_dipole(:, ichannel))
    end do

    ! total: sum over the spin channels
    ! (note kept from the original code: dipole moment <mu_el> = \sum_i -e <x_i>)
    dipole = sum(channel_dipole, dim=2)

    POP_SUB(states_elec_calculate_dipole)
  end function states_elec_calculate_dipole


#include "undef.F90"
#include "real.F90"
#include "states_elec_inc.F90"

#include "undef.F90"
#include "complex.F90"
#include "states_elec_inc.F90"
#include "undef.F90"

end module states_elec_oct_m


!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
