!! Copyright (C) 2016-2020 N. Tancogne-Dejean
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

module lda_u_oct_m
  use accel_oct_m
  use atomic_orbital_oct_m
  use boundaries_oct_m
  use batch_oct_m
  use batch_ops_oct_m
  use comm_oct_m
  use debug_oct_m
  use derivatives_oct_m
  use distributed_oct_m
  use electron_space_oct_m
  use energy_oct_m
  use global_oct_m
  use grid_oct_m
  use hamiltonian_elec_base_oct_m
  use ions_oct_m
  use kpoints_oct_m
  use lalg_basic_oct_m
  use lattice_vectors_oct_m
  use loct_oct_m
  use loct_math_oct_m
  use loewdin_oct_m
  use math_oct_m
  use mesh_oct_m
  use mesh_function_oct_m
  use messages_oct_m
  use mpi_oct_m
  use multicomm_oct_m
  use namespace_oct_m
  use orbitalbasis_oct_m
  use orbitalset_oct_m
  use orbitalset_utils_oct_m
  use parser_oct_m
  use poisson_oct_m
  use phase_oct_m
  use profiling_oct_m
  use restart_oct_m
  use space_oct_m
  use species_oct_m
  use states_abst_oct_m
  use states_elec_oct_m
  use states_elec_dim_oct_m
  use submesh_oct_m
  use symmetries_oct_m
  use types_oct_m
  use unit_system_oct_m
  use wfs_elec_oct_m

  implicit none

  private

  public ::                          &
    lda_u_t,                         &
    lda_u_init,                      &
    dlda_u_apply,                    &
    zlda_u_apply,                    &
    lda_u_update_basis,              &
    lda_u_update_occ_matrices,       &
    lda_u_end,                       &
    lda_u_build_phase_correction,    &
    lda_u_freeze_occ,                &
    lda_u_freeze_u,                  &
    dlda_u_set_occupations,          &
    zlda_u_set_occupations,          &
    dlda_u_get_occupations,          &
    zlda_u_get_occupations,          &
    dlda_u_update_potential,         &
    zlda_u_update_potential,         &
    lda_u_get_effectiveU,            &
    lda_u_set_effectiveU,            &
    lda_u_get_effectiveV,            &
    lda_u_set_effectiveV,            &
    dlda_u_commute_r,                &
    zlda_u_commute_r,                &
    dlda_u_commute_r_single,         &
    zlda_u_commute_r_single,         &
    dlda_u_force,                    &
    zlda_u_force,                    &
    dlda_u_rvu,                      &
    zlda_u_rvu,                      &
    lda_u_write_info,                &
    compute_ACBNO_U_kanamori,        &
    dcompute_dftu_energy,            &
    zcompute_dftu_energy


  !> Values of the DFT+U level stored in lda_u_t%level.
  !! DFT_U_NONE is the default of lda_u_t%level and is rejected by lda_u_init;
  !! DFT_U_ACBN0 enables the parsing of the ACBN0-specific input variables.
  integer, public, parameter ::        &
    DFT_U_NONE                    = 0, &
    DFT_U_EMPIRICAL               = 1, &
    DFT_U_ACBN0                   = 2

  !> Double-counting schemes, matching the options of the DFTUDoubleCounting
  !! input variable: fully localized limit (FLL), around mean field (AMF),
  !! and the mixed scheme (MIX).
  integer, public, parameter ::        &
    DFT_U_FLL                     = 0, &
    DFT_U_AMF                     = 1, &
    DFT_U_MIX                     = 2


  !> @class lda_u_t
  !! \brief Class to describe DFT+U parameters
  !!
  !! All pointer components are default-initialized to a disassociated state
  !! and all scalar components carry a default value, so that a freshly
  !! declared object is well defined before lda_u_init is called.
  type lda_u_t
    private
    integer,            public   :: level = DFT_U_NONE !< DFT+U level.

    ! DFT+U basic variables
    real(real64), allocatable, public   :: dn(:,:,:,:) !< Occupation matrices for the standard scheme
    real(real64), allocatable           :: dV(:,:,:,:) !< Potentials for the standard scheme

    ! ACBN0 variables
    ! NOTE(review): zn/zV look like the complex-valued counterparts of dn/dV - confirm
    complex(real64), allocatable, public   :: zn(:,:,:,:)
    complex(real64), allocatable           :: zV(:,:,:,:)
    real(real64),    allocatable, public   :: dn_alt(:,:,:,:) !< Stores the renormalized occ. matrices
    complex(real64), allocatable, public   :: zn_alt(:,:,:,:) !< if the ACBN0 functional is used

    real(real64), allocatable           :: renorm_occ(:,:,:,:,:) !< On-site occupations (for the ACBN0 functional)

    ! Coulomb integrals
    real(real64), allocatable, public   :: coulomb(:,:,:,:,:)         !< Coulomb integrals for all the system
    !                                                                 !<(for the ACBN0 functional)
    complex(real64), allocatable, public   :: zcoulomb(:,:,:,:,:,:,:) !< Coulomb integrals for all the system
    !                                                                 !< (for the ACBN0 functional with spinors)

    type(orbitalbasis_t),        public :: basis                !< The full basis of localized orbitals
    type(orbitalset_t), pointer, public :: orbsets(:) => NULL() !< All the orbital sets of the system
    integer,                     public :: norbsets = 0         !< Number of orbital sets

    integer,              public :: nspins = 0                 !< Copied from st%d%nspin in lda_u_init
    integer,              public :: spin_channels = 0          !< Copied from st%d%spin_channels in lda_u_init
    integer                      :: nspecies = 0               !< ions%nspecies, or 1 if the basis is built from states
    integer,              public :: maxnorbs = 0           !< Maximal number of orbitals for all the atoms
    integer                      :: max_np = 0             !< Maximum number of points in all orbitals submesh spheres

    logical                      :: useAllOrbitals = .false.       !< Do we use all atomic orbitals possible
    logical                      :: skipSOrbitals = .true.         !< Not using s orbitals
    logical                      :: freeze_occ = .false.           !< Occupation matrices are not recomputed during TD evolution
    logical                      :: freeze_u = .false.             !< U is not recomputed during TD evolution
    logical,              public :: intersite = .false.            !< intersite V are computed or not
    real(real64)                 :: intersite_radius = M_ZERO      !< Maximal distance for considering neighboring atoms
    logical,              public :: basisfromstates = .false.      !< We can construct the localized basis from user-defined states
    real(real64)                 :: acbn0_screening = M_ONE        !< We use or not the screening in the ACBN0 functional
    integer, allocatable         :: basisstates(:)                 !< The indices of states used to construct a localized basis
    integer, allocatable         :: basisstates_os(:)              !< The index of the orbital set to which belongs the state specified in basisstate(:)
    logical                      :: rot_inv = .false.              !< Use a rotationally invariant formula for U and J (ACBN0 case)
    integer                      :: double_couting = DFT_U_FLL     !< Double-counting term
    integer                      :: sm_poisson = SM_POISSON_DIRECT !< Poisson solver used for computing Coulomb integrals
    real(real64), allocatable    :: dc_alpha(:)                    !< Mixing of the double-counting term

    type(lattice_vectors_t), pointer :: latt => NULL() !< Points to ions%latt once lda_u_init has been called

    type(distributed_t) :: orbs_dist !< Distribution of the orbital sets (set up in lda_u_init)

    ! Intersite interaction variables
    integer, public     :: maxneighbors = 0 !< Largest nneighbors over all orbital sets
    ! The intersite arrays are (re)allocated in lda_u_update_basis; their last two
    ! dimensions run over the orbital sets and over the neighbors (up to maxneighbors).
    real(real64), allocatable, public  :: dn_ij(:,:,:,:,:), dn_alt_ij(:,:,:,:,:), dn_alt_ii(:,:,:,:,:)
    complex(real64), allocatable, public  :: zn_ij(:,:,:,:,:), zn_alt_ij(:,:,:,:,:), zn_alt_ii(:,:,:,:,:)

    ! Symmetrization-related variables
    logical                   :: symmetrize_occ_matrices = .false. !< Do we symmetrize the occupation matrices or not
    integer, allocatable      :: inv_map_symm(:,:) !< Mapping between orbital sets
    integer                   :: nsym = 0          !< Number of symmetries
    real(real64), allocatable :: symm_weight(:,:,:,:) !< Weight of the rotated orbital in the basis of the orbitals
  end type lda_u_t

contains

  ! ---------------------------------------------------------
  !> Initialize the DFT+U object.
  !!
  !! The routine parses the input variables of the Hamiltonian::DFT+U section,
  !! builds the localized orbital basis (from atomic orbitals, or from the
  !! user-selected states listed in DFTUBasisStates when DFTUBasisFromStates=yes),
  !! calls dlda_u_allocate/zlda_u_allocate to allocate the required arrays, and
  !! builds the symmetrization map of the occupation matrices when symmetries are used.
  !!
  !! Any invalid or unsupported input terminates the run through messages_fatal
  !! or messages_not_implemented.
  !!
  !! \param this       DFT+U object to set up
  !! \param namespace  namespace used for parsing and for messages
  !! \param space      simulation space (its periodic_dim is passed to orbitalbasis_init)
  !! \param level      DFT+U level; must not be DFT_U_NONE
  !! \param gr         grid on which the orbitals are built
  !! \param ions       ions; this%latt is pointed to ions%latt
  !! \param st         electronic states (real/complex character, spin and k-point information)
  !! \param mc         multicommunicator, forwarded to lda_u_loadbasis
  !! \param kpoints    k-points; use_symmetries controls the symmetrization
  !! \param has_phase  not referenced in this routine
  subroutine lda_u_init(this, namespace, space, level, gr, ions, st, mc, kpoints, has_phase)
    type(lda_u_t),     target, intent(inout) :: this
    type(namespace_t),         intent(in)    :: namespace
    class(space_t),            intent(in)    :: space
    integer,                   intent(in)    :: level
    type(grid_t),              intent(in)    :: gr
    type(ions_t),      target, intent(in)    :: ions
    type(states_elec_t),       intent(in)    :: st
    type(multicomm_t),         intent(in)    :: mc
    type(kpoints_t),           intent(in)    :: kpoints
    logical,                   intent(in)    :: has_phase

    integer :: is, ierr
    type(block_t) :: blk

    PUSH_SUB(lda_u_init)

    ASSERT(.not. (level == DFT_U_NONE))

    call messages_print_with_emphasis(msg="DFT+U", namespace=namespace)
    if (gr%parallel_in_domains) call messages_experimental("dft+u parallel in domains", namespace=namespace)
    this%level = level

    this%latt => ions%latt

    !%Variable DFTUBasisFromStates
    !%Type logical
    !%Default no
    !%Section Hamiltonian::DFT+U
    !%Description
    !% If set to yes, Octopus will construct the localized basis from
    !% user-defined states. The states are taken at the Gamma point (or the first k-point) of the
    !% states in the restart_proj folder.
    !% The states are defined via the block DFTUBasisStates
    !%End
    call parse_variable(namespace, 'DFTUBasisFromStates', .false., this%basisfromstates)
    if (this%basisfromstates) call messages_experimental("DFTUBasisFromStates", namespace=namespace)

    !%Variable DFTUDoubleCounting
    !%Type integer
    !%Default dft_u_fll
    !%Section Hamiltonian::DFT+U
    !%Description
    !% This variable selects which DFT+U
    !% double counting term is used.
    !%Option dft_u_fll 0
    !% (Default) The Fully Localized Limit (FLL)
    !%Option dft_u_amf 1
    !% (Experimental) Around mean field double counting, as defined in PRB 44, 943 (1991) and PRB 49, 14211 (1994).
    !%Option dft_u_mix 2
    !% (Experimental) Mixed double counting term as introduced by Petukhov et al., PRB 67, 153106 (2003).
    !% This recovers the FLL and AMF as limiting cases.
    !%End
    call parse_variable(namespace, 'DFTUDoubleCounting', DFT_U_FLL, this%double_couting)
    call messages_print_var_option('DFTUDoubleCounting', this%double_couting, namespace=namespace)
    if (this%double_couting /= DFT_U_FLL) call messages_experimental("DFTUDoubleCounting /= dft_u_ffl", namespace=namespace)
    if (st%d%ispin == SPINORS .and. this%double_couting /= DFT_U_FLL) then
      call messages_not_implemented("AMF and MIX double counting with spinors", namespace=namespace)
    end if

    !%Variable DFTUPoissonSolver
    !%Type integer
    !%Section Hamiltonian::DFT+U
    !%Description
    !% This variable selects which Poisson solver
    !% is used to compute the Coulomb integrals over a submesh.
    !% These are non-periodic Poisson solvers.
    !% The FFT Poisson solver with spherical cutoff is used by default.
    !%
    !%Option dft_u_poisson_direct 0
    !% Direct Poisson solver. Slow but working in all cases.
    !%Option dft_u_poisson_isf 1
    !% (Experimental) ISF Poisson solver on a submesh.
    !% This does not work for non-orthogonal cells nor domain parallelization.
    !%Option dft_u_poisson_psolver 2
    !% (Experimental) PSolver Poisson solver on a submesh.
    !% This does not work for non-orthogonal cells nor domain parallelization.
    !% Requires the PSolver external library.
    !%Option dft_u_poisson_fft 3
    !% (Default) FFT Poisson solver on a submesh.
    !% This uses the 0D periodic version of the FFT kernels.
    !%End
    call parse_variable(namespace, 'DFTUPoissonSolver', SM_POISSON_FFT, this%sm_poisson)
    call messages_print_var_option('DFTUPoissonSolver', this%sm_poisson, namespace=namespace)
    if (this%sm_poisson /= SM_POISSON_DIRECT .and. this%sm_poisson /= SM_POISSON_FFT) then
      call messages_experimental("DFTUPoissonSolver different from dft_u_poisson_direct", namespace=namespace)
      call messages_experimental("and dft_u_poisson_fft", namespace=namespace)
    end if
    if (this%sm_poisson == SM_POISSON_ISF) then
      if (gr%parallel_in_domains) then
        call messages_not_implemented("DFTUPoissonSolver=dft_u_poisson_isf with domain parallelization", namespace=namespace)
      end if
      if (ions%latt%nonorthogonal) then
        call messages_not_implemented("DFTUPoissonSolver=dft_u_poisson_isf with non-orthogonal cells", namespace=namespace)
      end if
    end if
    if (this%sm_poisson == SM_POISSON_PSOLVER) then
#if !(defined HAVE_PSOLVER)
      message(1) = "The PSolver Poisson solver cannot be used since the code was not compiled with the PSolver library."
      call messages_fatal(1, namespace=namespace)
#endif
      if (gr%parallel_in_domains) then
        call messages_not_implemented("DFTUPoissonSolver=dft_u_poisson_psolver with domain parallelization", namespace=namespace)
      end if
      if (ions%latt%nonorthogonal) then
        call messages_not_implemented("DFTUPoissonSolver=dft_u_poisson_psolver with non-orthogonal cells", namespace=namespace)
      end if
    end if

    if (this%level == DFT_U_ACBN0) then
      !%Variable UseAllAtomicOrbitals
      !%Type logical
      !%Default no
      !%Section Hamiltonian::DFT+U
      !%Description
      !% If set to yes, Octopus will determine the effective U for all atomic orbitals
      !% from the pseudopotential. Only available with ACBN0 functional.
      !% It is strongly recommended to set AOLoewdin=yes when using the option.
      !%End
      call parse_variable(namespace, 'UseAllAtomicOrbitals', .false., this%useAllOrbitals)
      if (this%useAllOrbitals) call messages_experimental("UseAllAtomicOrbitals", namespace=namespace)

      !%Variable SkipSOrbitals
      !%Type logical
      !%Default yes
      !%Section Hamiltonian::DFT+U
      !%Description
      !% If set to yes, Octopus will determine the effective U for all atomic orbitals
      !% from the pseudopotential but s orbitals. Only available with ACBN0 functional.
      !%End
      call parse_variable(namespace, 'SkipSOrbitals', .true., this%skipSOrbitals)
      if (.not. this%SkipSOrbitals) call messages_experimental("SkipSOrbitals", namespace=namespace)

      !%Variable ACBN0Screening
      !%Type float
      !%Default 1.0
      !%Section Hamiltonian::DFT+U
      !%Description
      !% If set to 0, no screening will be included in the ACBN0 functional, and the U
      !% will be estimated from bare Hartree-Fock. If set to 1 (default), the full screening
      !% of the U, as defined in the ACBN0 functional, is used.
      !%End
      call parse_variable(namespace, 'ACBN0Screening', M_ONE, this%acbn0_screening)
      call messages_print_var_value('ACBN0Screening', this%acbn0_screening, namespace=namespace)

      !%Variable ACBN0RotationallyInvariant
      !%Type logical
      !%Section Hamiltonian::DFT+U
      !%Description
      !% If set to yes, Octopus will use for U and J a formula which is rotationally invariant.
      !% This is different from the original formula for U and J.
      !% This is activated by default, except in the case of spinors, as this is not yet implemented in this case.
      !%End
      call parse_variable(namespace, 'ACBN0RotationallyInvariant', st%d%ispin /= SPINORS, this%rot_inv)
      call messages_print_var_value('ACBN0RotationallyInvariant', this%rot_inv, namespace=namespace)
      if (this%rot_inv .and. st%d%ispin == SPINORS) then
        call messages_not_implemented("Rotationally invariant ACBN0 with spinors", namespace=namespace)
      end if

      !%Variable ACBN0IntersiteInteraction
      !%Type logical
      !%Default no
      !%Section Hamiltonian::DFT+U
      !%Description
      !% If set to yes, Octopus will determine the effective intersite interaction V
      !% Only available with ACBN0 functional.
      !% It is strongly recommended to set AOLoewdin=yes when using the option.
      !%End
      call parse_variable(namespace, 'ACBN0IntersiteInteraction', .false., this%intersite)
      call messages_print_var_value('ACBN0IntersiteInteraction', this%intersite, namespace=namespace)
      if (this%intersite) call messages_experimental("ACBN0IntersiteInteraction", namespace=namespace)

      if (this%intersite) then

        !This is a non local operator. To make this working, one probably needs to apply the
        ! symmetries to the generalized occupation matrices
        if (kpoints%use_symmetries) then
          call messages_not_implemented("Intersite interaction with kpoint symmetries", namespace=namespace)
        end if

        !%Variable ACBN0IntersiteCutoff
        !%Type float
        !%Section Hamiltonian::DFT+U
        !%Description
        !% The cutoff radius defining the maximal intersite distance considered.
        !% Only available with ACBN0 functional with intersite interaction.
        !%End
        call parse_variable(namespace, 'ACBN0IntersiteCutoff', M_ZERO, this%intersite_radius, unit = units_inp%length)
        ! The cutoff must be strictly positive: zero (the default, i.e. unset)
        ! and negative values are both rejected.
        if (this%intersite_radius < M_EPSILON) then
          message(1) = "ACBN0IntersiteCutoff must be greater than 0"
          call messages_fatal(1, namespace=namespace)
        end if

      end if

    end if

    call lda_u_write_info(this, namespace=namespace)

    if (.not. this%basisfromstates) then

      call orbitalbasis_init(this%basis, namespace, space%periodic_dim)

      if (states_are_real(st)) then
        call dorbitalbasis_build(this%basis, namespace, ions, gr, st%d%kpt, st%d%dim, &
          this%skipSOrbitals, this%useAllOrbitals)
      else
        call zorbitalbasis_build(this%basis, namespace, ions, gr, st%d%kpt, st%d%dim, &
          this%skipSOrbitals, this%useAllOrbitals)
      end if
      this%orbsets => this%basis%orbsets
      this%norbsets = this%basis%norbsets
      this%maxnorbs = this%basis%maxnorbs
      this%max_np = this%basis%max_np
      this%nspins = st%d%nspin
      this%spin_channels = st%d%spin_channels
      this%nspecies = ions%nspecies

      !We allocate the necessary resources
      if (states_are_real(st)) then
        call dlda_u_allocate(this, st)
      else
        call zlda_u_allocate(this, st)
      end if

      call distributed_nullify(this%orbs_dist, this%norbsets)
#ifdef HAVE_MPI
      if (.not. gr%parallel_in_domains) then
        call distributed_init(this%orbs_dist, this%norbsets, MPI_COMM_WORLD, "orbsets")
      end if
#endif

    else

      !%Variable DFTUBasisStates
      !%Type block
      !%Default none
      !%Section Hamiltonian::DFT+U
      !%Description
      !% This block starts by a line containing a single integer describing the number of
      !% orbital sets. One orbital set is a group of orbitals on which one adds a Hubbard U.
      !% Each following line of this block contains the index of a state to be used to construct the
      !% localized basis, followed by the index of the corresponding orbital set.
      !% See DFTUBasisFromStates for details.
      !%End
      if (parse_block(namespace, 'DFTUBasisStates', blk) == 0) then
        ! First line: number of orbital sets; every following line describes one state
        call parse_block_integer(blk, 0, 0, this%norbsets)
        this%maxnorbs = parse_block_n(blk)-1
        if (this%maxnorbs <1) then
          message(1) = 'DFTUBasisStates must contains at least one state.'
          call messages_fatal(1, namespace=namespace)
        end if
        SAFE_ALLOCATE(this%basisstates(1:this%maxnorbs))
        SAFE_ALLOCATE(this%basisstates_os(1:this%maxnorbs))
        do is = 1, this%maxnorbs
          call parse_block_integer(blk, is, 0, this%basisstates(is))
          call parse_block_integer(blk, is, 1, this%basisstates_os(is))
        end do
        call parse_block_end(blk)
      else
        message(1) = 'DFTUBasisStates must be specified if DFTUBasisFromStates=yes'
        call messages_fatal(1, namespace=namespace)
      end if

      if (states_are_real(st)) then
        call dorbitalbasis_build_empty(this%basis, gr, st%d%kpt, st%d%dim, this%norbsets, this%basisstates_os)
      else
        call zorbitalbasis_build_empty(this%basis, gr, st%d%kpt, st%d%dim, this%norbsets, this%basisstates_os)
      end if

      this%max_np = gr%np
      this%nspins = st%d%nspin
      this%spin_channels = st%d%spin_channels
      this%nspecies = 1

      this%orbsets => this%basis%orbsets

      call distributed_nullify(this%orbs_dist, this%norbsets)

      !We allocate the necessary resources
      if (states_are_real(st)) then
        call dlda_u_allocate(this, st)
      else
        call zlda_u_allocate(this, st)
      end if

      call lda_u_loadbasis(this, namespace, space, st, gr, mc, ierr)
      if (ierr /= 0) then
        message(1) = "Unable to load DFT+U basis from selected states."
        call messages_fatal(1, namespace=namespace)
      end if

    end if

    ! Symmetrization of the occupation matrices
    ! (never done when the basis is built from user-defined states)
    this%symmetrize_occ_matrices = st%symmetrize_density .or. kpoints%use_symmetries
    if (this%basisfromstates) this%symmetrize_occ_matrices = .false.
    if (this%symmetrize_occ_matrices) then
      call build_symmetrization_map(this, ions, gr, st)
    end if

    SAFE_ALLOCATE(this%dc_alpha(this%norbsets))
    this%dc_alpha = M_ONE

    call messages_print_with_emphasis(namespace=namespace)

    POP_SUB(lda_u_init)
  end subroutine lda_u_init


  ! ---------------------------------------------------------
  !> Allocate and compute the Coulomb integrals of the localized basis.
  !!
  !! With a basis built from user-defined states the real-array integrals
  !! (this%coulomb) are always computed. Otherwise the integrals are only
  !! computed for the ACBN0 level, and the complex array this%zcoulomb is used
  !! as soon as one orbital set has ndim > 1.
  !!
  !! \param has_phase  not referenced in this routine
  subroutine lda_u_init_coulomb_integrals(this, namespace, space, gr, st, psolver, has_phase)
    type(lda_u_t),     target, intent(inout) :: this
    type(namespace_t),         intent(in)    :: namespace
    class(space_t),            intent(in)    :: space
    type(grid_t),              intent(in)    :: gr
    type(states_elec_t),       intent(in)    :: st
    type(poisson_t),           intent(in)    :: psolver
    logical,                   intent(in)    :: has_phase

    logical :: multi_component
    integer :: iset
    integer :: nmax

    PUSH_SUB(lda_u_init_coulomb_integrals)

    nmax = this%maxnorbs

    ! With an atomic-orbital basis, only the ACBN0 level needs the integrals
    if (.not. this%basisfromstates .and. this%level /= DFT_U_ACBN0) then
      POP_SUB(lda_u_init_coulomb_integrals)
      return
    end if

    ! The complex integrals are only considered for an atomic-orbital basis
    multi_component = .false.
    if (.not. this%basisfromstates) then
      do iset = 1, this%norbsets
        if (this%orbsets(iset)%ndim > 1) multi_component = .true.
      end do
    end if

    if (multi_component) then
      ASSERT(.not. states_are_real(st))
      message(1) = 'Computing complex Coulomb integrals of the localized basis.'
      call messages_info(1, namespace=namespace)
      SAFE_ALLOCATE(this%zcoulomb(1:nmax, 1:nmax, 1:nmax, 1:nmax, 1:st%d%dim, 1:st%d%dim, 1:this%norbsets))
      call compute_complex_coulomb_integrals(this, gr, st, psolver, namespace, space)
    else
      message(1) = 'Computing the Coulomb integrals of the localized basis.'
      call messages_info(1, namespace=namespace)
      SAFE_ALLOCATE(this%coulomb(1:nmax, 1:nmax, 1:nmax, 1:nmax, 1:this%norbsets))
      if (states_are_real(st)) then
        call dcompute_coulomb_integrals(this, namespace, space, gr, psolver)
      else
        call zcompute_coulomb_integrals(this, namespace, space, gr, psolver)
      end if
    end if

    POP_SUB(lda_u_init_coulomb_integrals)

  end subroutine lda_u_init_coulomb_integrals


  ! ---------------------------------------------------------
  !> Release everything held by the DFT+U object and reset its level to
  !! DFT_U_NONE. Arrays that were never allocated are skipped by the
  !! deallocation macro.
  subroutine lda_u_end(this)
    type(lda_u_t), intent(inout) :: this

    PUSH_SUB(lda_u_end)

    this%level = DFT_U_NONE

    ! Occupation matrices and potentials
    SAFE_DEALLOCATE_A(this%dn)
    SAFE_DEALLOCATE_A(this%dn_alt)
    SAFE_DEALLOCATE_A(this%dV)
    SAFE_DEALLOCATE_A(this%zn)
    SAFE_DEALLOCATE_A(this%zn_alt)
    SAFE_DEALLOCATE_A(this%zV)
    SAFE_DEALLOCATE_A(this%renorm_occ)

    ! Coulomb integrals
    SAFE_DEALLOCATE_A(this%coulomb)
    SAFE_DEALLOCATE_A(this%zcoulomb)

    ! Intersite quantities
    SAFE_DEALLOCATE_A(this%dn_ij)
    SAFE_DEALLOCATE_A(this%dn_alt_ij)
    SAFE_DEALLOCATE_A(this%dn_alt_ii)
    SAFE_DEALLOCATE_A(this%zn_ij)
    SAFE_DEALLOCATE_A(this%zn_alt_ij)
    SAFE_DEALLOCATE_A(this%zn_alt_ii)

    ! Basis-from-states bookkeeping, double counting and symmetrization data
    SAFE_DEALLOCATE_A(this%basisstates)
    SAFE_DEALLOCATE_A(this%basisstates_os)
    SAFE_DEALLOCATE_A(this%dc_alpha)
    SAFE_DEALLOCATE_A(this%inv_map_symm)
    SAFE_DEALLOCATE_A(this%symm_weight)

    ! The orbital sets are only an alias into the basis: drop the alias first
    nullify(this%orbsets)
    call orbitalbasis_end(this%basis)

    this%max_np = 0

    if (.not. this%basisfromstates) call distributed_end(this%orbs_dist)

    POP_SUB(lda_u_end)
  end subroutine lda_u_end

  !> Rebuild the localized basis and the quantities derived from it.
  !! This is needed when the ions are moved.
  !!
  !! Steps, in order: rebuild the orbital basis (unless it comes from
  !! user-defined states), re-initialize the intersite data and its arrays
  !! (if intersite interaction is on), rebuild the phases or orthogonalize the
  !! basis, and finally recompute the Coulomb integrals.
  !! Returns immediately if the level is DFT_U_NONE.
  subroutine lda_u_update_basis(this, space, gr, ions, st, psolver, namespace, kpoints, has_phase)
    type(lda_u_t),     target, intent(inout) :: this
    class(space_t),            intent(in)    :: space
    type(grid_t),              intent(in)    :: gr
    type(ions_t),      target, intent(in)    :: ions
    type(states_elec_t),       intent(in)    :: st
    type(poisson_t),           intent(in)    :: psolver
    type(namespace_t),         intent(in)    :: namespace
    type(kpoints_t),           intent(in)    :: kpoints
    logical,                   intent(in)    :: has_phase

    integer :: iset, nmax, nsp, nsets, nneigh

    if (this%level == DFT_U_NONE) return

    PUSH_SUB(lda_u_update_basis)

    if (.not. this%basisfromstates) then
      ! Throw away the old orbital basis so that it can be built again
      call orbitalbasis_end(this%basis)
      nullify(this%orbsets)

      if (states_are_real(st)) then
        call dorbitalbasis_build(this%basis, namespace, ions, gr, st%d%kpt, st%d%dim, &
          this%skipSOrbitals, this%useAllOrbitals, verbose = .false.)
      else
        call zorbitalbasis_build(this%basis, namespace, ions, gr, st%d%kpt, st%d%dim, &
          this%skipSOrbitals, this%useAllOrbitals, verbose = .false.)
      end if
      this%max_np = this%basis%max_np
      this%orbsets => this%basis%orbsets
    end if

    ! With intersite interaction, the neighbor information and the intersite
    ! arrays depend on the ionic positions and must be rebuilt as well
    if (this%intersite) then
      this%maxneighbors = 0
      do iset = 1, this%norbsets
        call orbitalset_init_intersite(this%orbsets(iset), namespace, space, iset, ions, gr%der, psolver, &
          this%orbsets, this%norbsets, this%maxnorbs, this%intersite_radius, st%d%kpt, has_phase, &
          this%sm_poisson, this%basisfromstates, this%basis%combine_j_orbitals)
        this%maxneighbors = max(this%maxneighbors, this%orbsets(iset)%nneighbors)
      end do

      nmax   = this%maxnorbs
      nsp    = this%nspins
      nsets  = this%norbsets
      nneigh = this%maxneighbors

      if (states_are_real(st)) then
        SAFE_DEALLOCATE_A(this%dn_ij)
        SAFE_ALLOCATE(this%dn_ij(1:nmax, 1:nmax, 1:nsp, 1:nsets, 1:nneigh))
        this%dn_ij = M_ZERO

        SAFE_DEALLOCATE_A(this%dn_alt_ij)
        SAFE_ALLOCATE(this%dn_alt_ij(1:nmax, 1:nmax, 1:nsp, 1:nsets, 1:nneigh))
        this%dn_alt_ij = M_ZERO

        SAFE_DEALLOCATE_A(this%dn_alt_ii)
        SAFE_ALLOCATE(this%dn_alt_ii(1:2, 1:nmax, 1:nsp, 1:nsets, 1:nneigh))
        this%dn_alt_ii = M_ZERO
      else
        SAFE_DEALLOCATE_A(this%zn_ij)
        SAFE_ALLOCATE(this%zn_ij(1:nmax, 1:nmax, 1:nsp, 1:nsets, 1:nneigh))
        this%zn_ij = M_Z0

        SAFE_DEALLOCATE_A(this%zn_alt_ij)
        SAFE_ALLOCATE(this%zn_alt_ij(1:nmax, 1:nmax, 1:nsp, 1:nsets, 1:nneigh))
        this%zn_alt_ij = M_Z0

        SAFE_DEALLOCATE_A(this%zn_alt_ii)
        SAFE_ALLOCATE(this%zn_alt_ii(1:2, 1:nmax, 1:nsp, 1:nsets, 1:nneigh))
        this%zn_alt_ii = M_Z0
      end if
    end if

    ! The phase of the orbital projection is rebuilt here, like the one of the pseudopotentials.
    ! With a laser field, the phase is recomputed in hamiltonian_elec_update.
    if (has_phase) then
      call lda_u_build_phase_correction(this, space, st%d, gr%der%boundaries, namespace, kpoints)
    else if (.not. this%basisfromstates) then
      ! Without a phase, the orthogonalization has to be done at this point
      if (this%basis%orthogonalization) then
        call dloewdin_orthogonalize(this%basis, st%d%kpt, namespace)
      else if (debug%info .and. space%is_periodic()) then
        call dloewdin_info(this%basis, st%d%kpt, namespace)
      end if
    end if

    ! Drop the old Coulomb integrals (the macro skips unallocated arrays) and recompute them
    SAFE_DEALLOCATE_A(this%coulomb)
    SAFE_DEALLOCATE_A(this%zcoulomb)
    call lda_u_init_coulomb_integrals(this, namespace, space, gr, st, psolver, has_phase)

    POP_SUB(lda_u_update_basis)

  end subroutine lda_u_update_basis

  !> Dispatch to the real or complex X(update_occ_matrices) routine.
  !!
  !! Nothing is done if DFT+U is off (level DFT_U_NONE) or if the occupation
  !! matrices are frozen. For complex states, the phase is forwarded only when
  !! it is allocated. The DFT+U energy is returned in energy%dft_u.
  !!
  !! \param hm_base  not referenced in this routine
  subroutine lda_u_update_occ_matrices(this, namespace, mesh, st, hm_base, phase, energy)
    type(lda_u_t),                 intent(inout) :: this
    type(namespace_t),             intent(in)    :: namespace
    class(mesh_t),                 intent(in)    :: mesh
    type(states_elec_t),           intent(inout) :: st
    type(hamiltonian_elec_base_t), intent(in)    :: hm_base
    type(phase_t),                 intent(in)    :: phase
    type(energy_t),                intent(inout) :: energy

    if (this%freeze_occ .or. this%level == DFT_U_NONE) return

    PUSH_SUB(lda_u_update_occ_matrices)

    if (states_are_real(st)) then
      call dupdate_occ_matrices(this, namespace, mesh, st, energy%dft_u)
    else if (phase%is_allocated()) then
      call zupdate_occ_matrices(this, namespace, mesh, st, energy%dft_u, phase)
    else
      call zupdate_occ_matrices(this, namespace, mesh, st, energy%dft_u)
    end if

    POP_SUB(lda_u_update_occ_matrices)
  end subroutine lda_u_update_occ_matrices


  !> Build the phase correction to the global phase for all orbitals.
  !!
  !! The phase and the phase shift of every orbital set are updated; then,
  !! unless the basis comes from user-defined states, the basis is
  !! orthogonalized (or, in debug mode for periodic systems, only analyzed).
  !! Spiral boundary conditions are not implemented and stop the run.
  !!
  !! \param vec_pot      optional vector potential, forwarded as is
  !! \param vec_pot_var  optional space-dependent vector potential, forwarded as is
  subroutine lda_u_build_phase_correction(this, space, std, boundaries, namespace, kpoints, vec_pot, vec_pot_var)
    type(lda_u_t),                 intent(inout) :: this
    class(space_t),                intent(in)    :: space
    type(states_elec_dim_t),       intent(in)    :: std
    type(boundaries_t),            intent(in)    :: boundaries
    type(namespace_t),             intent(in)    :: namespace
    type(kpoints_t),               intent(in)    :: kpoints
    real(real64), optional,  allocatable, intent(in)    :: vec_pot(:) !< (space%dim)
    real(real64), optional,  allocatable, intent(in)    :: vec_pot_var(:, :) !< (1:space%dim, 1:ns)

    integer :: iset
    logical :: spin_polarized

    if (boundaries%spiralBC) then
      call messages_not_implemented("DFT+U with spiral boundary conditions", namespace=namespace)
    end if

    PUSH_SUB(lda_u_build_phase_correction)

    message(1) = 'Debug: Building the phase correction for DFT+U orbitals.'
    call messages_info(1, namespace=namespace, debug_only=.true.)

    spin_polarized = (std%ispin == SPIN_POLARIZED)

    do iset = 1, this%norbsets
      call orbitalset_update_phase(this%orbsets(iset), space%dim, std%kpt, kpoints, &
        spin_polarized, vec_pot, vec_pot_var)
      call orbitalset_update_phase_shift(this%orbsets(iset), space%dim, std%kpt, kpoints, &
        spin_polarized, vec_pot, vec_pot_var)
    end do

    if (.not. this%basisfromstates) then
      if (this%basis%orthogonalization) then
        call zloewdin_orthogonalize(this%basis, std%kpt, namespace)
      else if (debug%info .and. space%is_periodic()) then
        call zloewdin_info(this%basis, std%kpt, namespace)
      end if
    end if

    POP_SUB(lda_u_build_phase_correction)

  end subroutine lda_u_build_phase_correction

  ! ---------------------------------------------------------
  !> Dispatch the evaluation of the ACBN0 Kanamori parameters to the routine
  !! matching the type of the states (real or complex) and the number of spin
  !! channels (the "restricted" variants are used when this%nspins == 1).
  subroutine compute_ACBNO_U_kanamori(this, st, kanamori)
    type(lda_u_t),        intent(in)  :: this
    type(states_elec_t),  intent(in)  :: st
    real(real64),         intent(out) :: kanamori(:,:)

    logical :: restricted

    restricted = (this%nspins == 1)

    if (states_are_real(st)) then
      if (restricted) then
        call dcompute_ACBNO_U_kanamori_restricted(this, kanamori)
      else
        call dcompute_ACBNO_U_kanamori(this, kanamori)
      end if
    else
      if (restricted) then
        call zcompute_ACBNO_U_kanamori_restricted(this, kanamori)
      else
        call zcompute_ACBNO_U_kanamori(this, kanamori)
      end if
    end if

  end subroutine compute_ACBNO_U_kanamori

  ! ---------------------------------------------------------
  !> Raise the freeze_occ flag of the DFT+U object.
  !! Once set, lda_u_update_occ_matrices returns without doing anything.
  subroutine lda_u_freeze_occ(this)
    type(lda_u_t), intent(inout) :: this

    this%freeze_occ = .true.

  end subroutine lda_u_freeze_occ

  ! ---------------------------------------------------------
  !> Raise the freeze_u flag of the DFT+U object.
  !! NOTE(review): presumably this stops the effective U from being
  !! recomputed; the code reading the flag is not visible here.
  subroutine lda_u_freeze_u(this)
    type(lda_u_t), intent(inout) :: this

    this%freeze_u = .true.

  end subroutine lda_u_freeze_u

  ! ---------------------------------------------------------
  !> Store one effective U per orbital set: entry iset of Ueff is copied into
  !! the Ueff component of orbital set iset.
  subroutine lda_u_set_effectiveU(this, Ueff)
    type(lda_u_t),  intent(inout) :: this
    real(real64),   intent(in)    :: Ueff(:) !< (this%norbsets)

    integer :: iset

    PUSH_SUB(lda_u_set_effectiveU)

    ! Ueff is expected to hold at least this%norbsets values
    do iset = 1, this%norbsets
      this%orbsets(iset)%Ueff = Ueff(iset)
    end do

    POP_SUB(lda_u_set_effectiveU)
  end subroutine lda_u_set_effectiveU

  ! ---------------------------------------------------------
  !> Collect the effective U of every orbital set: the Ueff component of
  !! orbital set iset is copied into entry iset of Ueff. Entries beyond
  !! this%norbsets are left untouched.
  subroutine lda_u_get_effectiveU(this, Ueff)
    type(lda_u_t),  intent(in)    :: this
    real(real64),   intent(inout) :: Ueff(:) !< (this%norbsets)

    integer :: iset

    PUSH_SUB(lda_u_get_effectiveU)

    do iset = 1, this%norbsets
      Ueff(iset) = this%orbsets(iset)%Ueff
    end do

    POP_SUB(lda_u_get_effectiveU)
  end subroutine lda_u_get_effectiveU

  ! ---------------------------------------------------------
  !> Scatter a flat array of intersite interactions into the orbital sets.
  !!
  !! Veff is read as consecutive segments, one per orbital set and in orbital
  !! set order; the segment of set iset has nneighbors entries and is stored
  !! in V_ij(1:nneighbors, 0) of that set.
  subroutine lda_u_set_effectiveV(this, Veff)
    type(lda_u_t),  intent(inout) :: this
    real(real64),   intent(in)    :: Veff(:)

    integer :: iset, offset, nneigh

    PUSH_SUB(lda_u_set_effectiveV)

    ! offset counts the entries of Veff consumed by the previous orbital sets
    offset = 0
    do iset = 1, this%norbsets
      nneigh = this%orbsets(iset)%nneighbors
      this%orbsets(iset)%V_ij(1:nneigh, 0) = Veff(offset+1:offset+nneigh)
      offset = offset + nneigh
    end do

    POP_SUB(lda_u_set_effectiveV)
  end subroutine lda_u_set_effectiveV

  ! ---------------------------------------------------------
  !> Gather the intersite interactions of all orbital sets into a flat array.
  !!
  !! Veff is filled with consecutive segments, one per orbital set and in
  !! orbital set order; the segment of set iset has nneighbors entries and is
  !! taken from V_ij(1:nneighbors, 0) of that set.
  subroutine lda_u_get_effectiveV(this, Veff)
    type(lda_u_t),  intent(in)    :: this
    real(real64),   intent(inout) :: Veff(:)

    integer :: iset, offset, nneigh

    PUSH_SUB(lda_u_get_effectiveV)

    ! offset counts the entries of Veff written for the previous orbital sets
    offset = 0
    do iset = 1, this%norbsets
      nneigh = this%orbsets(iset)%nneighbors
      Veff(offset+1:offset+nneigh) = this%orbsets(iset)%V_ij(1:nneigh, 0)
      offset = offset + nneigh
    end do

    POP_SUB(lda_u_get_effectiveV)
  end subroutine lda_u_get_effectiveV

  ! ---------------------------------------------------------
  !> Print the bibliographic references of the DFT+U method in use and of its
  !! implementation.
  !!
  !! The method reference depends on this%level (empirical or not) and, for
  !! the non-empirical case, on whether intersite interactions are enabled.
  !! iunit and namespace are forwarded unchanged to messages_info.
  subroutine lda_u_write_info(this, iunit, namespace)
    type(lda_u_t),               intent(in) :: this
    integer,           optional, intent(in) :: iunit
    type(namespace_t), optional, intent(in) :: namespace

    PUSH_SUB(lda_u_write_info)

    write(message(1), '(1x)')
    call messages_info(1, iunit=iunit, namespace=namespace)

    ! The header is common to all methods; only the reference line differs
    write(message(1), '(a)') "Method:"
    if (this%level == DFT_U_EMPIRICAL) then
      write(message(2), '(a)') "  [1] Dudarev et al., Phys. Rev. B 57, 1505 (1998)"
    else if (.not. this%intersite) then
      write(message(2), '(a)') "  [1] Agapito et al., Phys. Rev. X 5, 011006 (2015)"
    else
      write(message(2), '(a)') "  [1] Tancogne-Dejean, and Rubio, Phys. Rev. B 102, 155117 (2020)"
    end if
    call messages_info(2, iunit=iunit, namespace=namespace)

    ! The implementation paper appeared in volume 96 of Phys. Rev. B
    write(message(1), '(a)') "Implementation:"
    write(message(2), '(a)') "  [1] Tancogne-Dejean, Oliveira, and Rubio, Phys. Rev. B 96, 245133 (2017)"
    write(message(3), '(1x)')
    call messages_info(3, iunit=iunit, namespace=namespace)

    POP_SUB(lda_u_write_info)

  end subroutine lda_u_write_info

  ! ---------------------------------------------------------
  !> Load the DFT+U localized basis from states stored in the projection
  !! restart folder (RESTART_PROJ) and transfer them to submeshes.
  !!
  !! The routine proceeds as follows:
  !!  - the 'wfns' index file is scanned and the file name of every state
  !!    listed in this%basisstates is recorded (only entries with ik == 1);
  !!  - the corresponding mesh functions are read into the orbital sets;
  !!  - the orbitals are normalized if this%basis%normalize is set, and the
  !!    phase of complex single-component orbitals is fixed;
  !!  - the center of each orbital set is computed and the orbitals are copied
  !!    onto a submesh built from that center and os%radius;
  !!  - if accel is enabled and st%d%dim == 1, device buffers are created and
  !!    the orbitals are written to them.
  !!
  !! ierr is set to 0 on entry. It is only changed when the header of 'wfns'
  !! cannot be read, a case that is fatal right after. All other failures
  !! detected here (restart initialization, missing state) are fatal as well.
  subroutine lda_u_loadbasis(this, namespace, space, st, mesh, mc, ierr)
    type(lda_u_t),        intent(inout) :: this
    type(namespace_t),    intent(in)    :: namespace
    class(space_t),       intent(in)    :: space
    type(states_elec_t),  intent(in)    :: st
    class(mesh_t),        intent(in)    :: mesh
    type(multicomm_t),    intent(in)    :: mc
    integer,              intent(out)   :: ierr

    integer :: err, wfns_file, is, ist, idim, ik, ios, iorb, parse_stat
    type(restart_t) :: restart_gs
    character(len=256)   :: lines(3)
    character(len=256), allocatable :: restart_file(:, :)
    logical,            allocatable :: restart_file_present(:, :)
    character(len=12)    :: filename
    character(len=1)     :: char
    character(len=50)    :: str
    type(orbitalset_t), pointer :: os
    integer, allocatable :: count(:)
    real(real64)         :: norm, center(space%dim)
    real(real64), allocatable   :: dpsi(:,:,:)
    complex(real64), allocatable   :: zpsi(:,:,:)


    PUSH_SUB(lda_u_loadbasis)

    ierr = 0

    message(1) = "Debug: Loading DFT+U basis from states."
    call messages_info(1, debug_only=.true.)

    call restart_init(restart_gs, namespace, RESTART_PROJ, RESTART_TYPE_LOAD, mc, err, mesh=mesh)

    ! If any error occured up to this point then it is not worth continuing,
    ! as there something fundamentally wrong with the restart files
    if (err /= 0) then
      call restart_end(restart_gs)
      message(1) = "Error loading DFT+U basis from states, cannot proceed with the calculation"
      call messages_fatal(1, namespace=namespace)
      POP_SUB(lda_u_loadbasis)
      return
    end if

    ! open files to read
    wfns_file  = restart_open(restart_gs, 'wfns')
    call restart_read(restart_gs, wfns_file, lines, 2, err)
    if (err /= 0) then
      ierr = ierr - 2**5
    else if (states_are_real(st)) then
      ! The second header line tells whether the stored states are real or complex
      read(lines(2), '(a)') str
      if (str(2:8) == 'Complex') then
        message(1) = "Cannot read real states from complex wavefunctions."
        call messages_fatal(1, namespace=namespace)
      else if (str(2:5) /= 'Real') then
        message(1) = "Restart file 'wfns' does not specify real/complex; cannot check compatibility."
        call messages_warning(1, namespace=namespace)
      end if
    end if
    ! complex can be restarted from real, so there is no problem.

    ! If any error occured up to this point then it is not worth continuing,
    ! as there something fundamentally wrong with the restart files
    if (err /= 0) then
      call restart_close(restart_gs, wfns_file)
      call restart_end(restart_gs)
      message(1) = "Error loading DFT+U basis from states, cannot proceed with the calculation"
      call messages_fatal(1, namespace=namespace)
      POP_SUB(lda_u_loadbasis)
      return
    end if

    SAFE_ALLOCATE(restart_file(1:st%d%dim, 1:st%nst))
    SAFE_ALLOCATE(restart_file_present(1:st%d%dim, 1:st%nst))
    restart_file_present = .false.

    ! Next we read the list of states from the files.
    ! Entries that cannot be parsed are skipped: a state that is needed but was
    ! not registered here triggers a fatal error in the loading loop below.
    do
      call restart_read(restart_gs, wfns_file, lines, 1, err)
      ! A failed read (for instance a file truncated before the closing '%')
      ! would fail again on the next attempt, so we stop scanning here instead
      ! of looping forever on values left over from a previous line.
      if (err /= 0) exit

      read(lines(1), '(a)') char
      ! We reached the end of the file
      if (char == '%') exit

      read(lines(1), *, iostat=parse_stat) ik, char, ist, char, idim, char, filename
      if (parse_stat /= 0) cycle

      ! Only the first k-point entry is used, and indices coming from the file
      ! must fit in the bookkeeping arrays
      if (ik /= 1) cycle
      if (idim < 1 .or. idim > st%d%dim) cycle
      if (ist < 1 .or. ist > st%nst) cycle

      if (any(this%basisstates==ist)) then
        restart_file(idim, ist) = trim(filename)
        restart_file_present(idim, ist) = .true.
      end if
    end do
    call restart_close(restart_gs, wfns_file)

    !We loop over the states we need
    ! count(ios) is the number of orbitals already assigned to orbital set ios
    SAFE_ALLOCATE(count(1:this%norbsets))
    count = 0
    do is = 1, this%maxnorbs
      ist = this%basisstates(is)
      ios = this%basisstates_os(is)
      count(ios) = count(ios)+1
      do idim = 1, st%d%dim

        if (.not. restart_file_present(idim, ist)) then
          write(message(1), '(a,i3,a)') "Cannot read states ", ist, " from the projection folder"
          call messages_fatal(1, namespace=namespace)
        end if

        ! NOTE(review): the status returned in err is not checked after these
        ! reads; confirm whether a failed read should stop the calculation.
        if (states_are_real(st)) then
          call restart_read_mesh_function(restart_gs, space, restart_file(idim, ist), mesh, &
            this%orbsets(ios)%dorb(:,idim,count(ios)), err)
        else
          call restart_read_mesh_function(restart_gs, space, restart_file(idim, ist), mesh, &
            this%orbsets(ios)%zorb(:,idim,count(ios)), err)
        end if

      end do
    end do
    SAFE_DEALLOCATE_A(count)
    SAFE_DEALLOCATE_A(restart_file)
    SAFE_DEALLOCATE_A(restart_file_present)
    call restart_end(restart_gs)

    ! Normalize the orbitals. This is important if we use Wannier orbitals instead of KS states
    if(this%basis%normalize) then
      do ios = 1, this%norbsets
        do iorb = 1, this%orbsets(ios)%norbs
          if (states_are_real(st)) then
            norm = dmf_nrm2(mesh, st%d%dim, this%orbsets(ios)%dorb(:,:,iorb))
            call lalg_scal(mesh%np, st%d%dim, M_ONE/norm, this%orbsets(ios)%dorb(:,:,iorb))
          else
            norm = zmf_nrm2(mesh, st%d%dim, this%orbsets(ios)%zorb(:,:,iorb))
            call lalg_scal(mesh%np, st%d%dim, M_ONE/norm, this%orbsets(ios)%zorb(:,:,iorb))
          end if
        end do
      end do
    end if

    ! We rotate the orbitals in the complex plane to have them as close as possible to real functions
    if(states_are_complex(st) .and. st%d%dim == 1) then
      do ios = 1, this%norbsets
        do iorb = 1, this%orbsets(ios)%norbs
          call zmf_fix_phase(mesh, this%orbsets(ios)%zorb(:,1,iorb))
        end do
      end do
    end if

    ! We determine the center of charge by computing <w|r|w>
    ! We could also determine the spread by \Omega = <w|r^2|w> - <w|r|w>^2
    do ios = 1, this%norbsets
      if (states_are_real(st)) then
        call dorbitalset_get_center_of_mass(this%orbsets(ios), space, mesh, this%latt)
      else
        call zorbitalset_get_center_of_mass(this%orbsets(ios), space, mesh, this%latt)
      end if
    end do

    message(1) = "Debug: Converting the Wannier states to submeshes."
    call messages_info(1, debug_only=.true.)

    ! We now transfer the states to a submesh centered on the center of mass of the Wannier orbitals
    this%max_np = 0
    do ios = 1, this%norbsets
      os => this%orbsets(ios)
      ! Keep a copy of the center and release the array before submesh_init is called
      center = os%sphere%center
      SAFE_DEALLOCATE_A(os%sphere%center)
      if (states_are_real(st)) then
        ! dpsi holds the orbitals on the full mesh while os%dorb is resized to the submesh
        SAFE_ALLOCATE(dpsi(1:mesh%np, 1:os%ndim, 1:os%norbs))
        dpsi(1:mesh%np, 1:os%ndim, 1:os%norbs) = os%dorb(1:mesh%np, 1:os%ndim, 1:os%norbs)

        SAFE_DEALLOCATE_A(os%dorb)
        !We initialise the submesh corresponding to the orbital
        call submesh_init(os%sphere, space, mesh, this%latt, center, os%radius)
        SAFE_ALLOCATE(os%dorb(1:os%sphere%np, 1:os%ndim, 1:os%norbs))
        do iorb = 1, os%norbs
          do idim = 1, os%ndim
            call dsubmesh_copy_from_mesh(os%sphere, dpsi(:,idim,iorb), os%dorb(:,idim, iorb))
          end do
        end do
        SAFE_DEALLOCATE_A(dpsi)
      else
        ! zpsi holds the orbitals on the full mesh while os%zorb is resized to the submesh
        SAFE_ALLOCATE(zpsi(1:mesh%np, 1:os%ndim, 1:os%norbs))
        zpsi(1:mesh%np, 1:os%ndim, 1:os%norbs) = os%zorb(1:mesh%np, 1:os%ndim, 1:os%norbs)
        SAFE_DEALLOCATE_A(os%zorb)
        !We initialise the submesh corresponding to the orbital
        call submesh_init(os%sphere, space, mesh, this%latt, center, os%radius)
        SAFE_ALLOCATE(os%zorb(1:os%sphere%np, 1:os%ndim, 1:os%norbs))
        do iorb = 1, os%norbs
          do idim = 1, os%ndim
            call zsubmesh_copy_from_mesh(os%sphere, zpsi(:,idim,iorb), os%zorb(:,idim, iorb))
          end do
        end do
        SAFE_DEALLOCATE_A(zpsi)

        ! Storage for the k-point dependent quantities of the local k-points (complex case only)
        SAFE_ALLOCATE(os%phase(1:os%sphere%np, st%d%kpt%start:st%d%kpt%end))
        SAFE_ALLOCATE(os%eorb_submesh(1:os%sphere%np, 1:os%ndim, 1:os%norbs, st%d%kpt%start:st%d%kpt%end))
      end if
      os%use_submesh = .true. ! We are now on a submesh
      this%max_np = max(this%max_np, os%sphere%np)
    end do

    this%basis%use_submesh = .true.

    ! If we use GPUs, we need to transfert the orbitals on the device
    if (accel_is_enabled() .and. st%d%dim == 1) then
      do ios = 1, this%norbsets
        os => this%orbsets(ios)

        ! Leading dimensions of the device buffers: padded submesh size, at least 1
        os%ldorbs = max(accel_padded_size(os%sphere%np), 1)
        os%ldorbs_eorb = max(accel_padded_size(os%sphere%np), 1)
        if (states_are_real(st)) then
          call accel_create_buffer(os%dbuff_orb, ACCEL_MEM_READ_ONLY, TYPE_FLOAT, os%ldorbs*os%norbs)
        else
          call accel_create_buffer(os%zbuff_orb, ACCEL_MEM_READ_ONLY, TYPE_CMPLX, os%ldorbs*os%norbs)
          SAFE_ALLOCATE(os%buff_eorb(st%d%kpt%start:st%d%kpt%end))

          do ik= st%d%kpt%start, st%d%kpt%end
            call accel_create_buffer(os%buff_eorb(ik), ACCEL_MEM_READ_ONLY, TYPE_CMPLX, os%ldorbs_eorb*os%norbs)
          end do
        end if

        call accel_create_buffer(os%sphere%buff_map, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, max(os%sphere%np, 1))
        call accel_write_buffer(os%sphere%buff_map, os%sphere%np, os%sphere%map)

        ! Orbital iorb is written at offset (iorb - 1)*os%ldorbs in the buffer
        do iorb = 1, os%norbs
          if(states_are_complex(st)) then
            call accel_write_buffer(os%zbuff_orb, os%sphere%np, os%zorb(:, 1, iorb), &
              offset = (iorb - 1)*os%ldorbs)
          else
            call accel_write_buffer(os%dbuff_orb, os%sphere%np, os%dorb(:, 1, iorb), &
              offset = (iorb - 1)*os%ldorbs)
          end if
        end do
      end do
    end if


    message(1) = "Debug: Loading DFT+U basis from states done."
    call messages_info(1, debug_only=.true.)

    POP_SUB(lda_u_loadbasis)
  end subroutine lda_u_loadbasis

  !> @brief Builds a mapping between the orbital sets based on symmetries
  !!
  !! For every orbital set and every symmetry operation, inv_map_symm stores
  !! the index of the orbital set located on the atom given by
  !! ions%inv_map_symm_atoms and sharing the same norbs, nn, ll and jj.
  !! The weights symm_weight are then filled: they are set to one for orbital
  !! sets with a single orbital, computed by orbitals_get_symm_weight for the
  !! other single-component sets, and left at zero when ndim > 1.
  !!
  !! The argument st is not used by this routine.
  subroutine build_symmetrization_map(this, ions, gr, st)
    type(lda_u_t),        intent(inout) :: this
    type(ions_t),         intent(in)    :: ions
    type(grid_t),         intent(in)    :: gr
    type(states_elec_t),  intent(in)    :: st

    integer :: nops, isym, iset, jset, atom_ref, atom_image

    PUSH_SUB(build_symmetrization_map)

    nops = ions%symm%nops
    this%nsym = nops

    ! -1 marks pairs (orbital set, operation) for which no partner was found
    SAFE_ALLOCATE(this%inv_map_symm(1:this%norbsets, 1:nops))
    this%inv_map_symm = -1

    do iset = 1, this%norbsets
      atom_ref = this%orbsets(iset)%iatom
      do isym = 1, nops
        atom_image = ions%inv_map_symm_atoms(atom_ref, isym)

        ! Take the first orbital set on the image atom with matching quantum numbers
        do jset = 1, this%norbsets
          if (this%orbsets(jset)%iatom /= atom_image) cycle
          if (this%orbsets(jset)%norbs /= this%orbsets(iset)%norbs) cycle
          if (this%orbsets(jset)%nn /= this%orbsets(iset)%nn) cycle
          if (this%orbsets(jset)%ll /= this%orbsets(iset)%ll) cycle
          if (.not. is_close(this%orbsets(jset)%jj, this%orbsets(iset)%jj)) cycle

          this%inv_map_symm(iset, isym) = jset
          exit
        end do
        ASSERT(this%inv_map_symm(iset, isym) > 0)
      end do
    end do

    SAFE_ALLOCATE(this%symm_weight(1:this%maxnorbs, 1:this%maxnorbs, 1:this%nsym, 1:this%norbsets))
    this%symm_weight = M_ZERO

    do iset = 1, this%norbsets
      if (this%orbsets(iset)%norbs == 1) then
        ! s-orbitals
        this%symm_weight(1, 1, 1:this%nsym, iset) = M_ONE
      else if (this%orbsets(iset)%ndim <= 1) then
        ! The case ndim > 1 is not implemented yet and keeps zero weights
        call orbitals_get_symm_weight(this%orbsets(iset), ions%space, ions%latt, gr, ions%symm, &
          this%symm_weight(:,:,:,iset))
      end if
    end do

    POP_SUB(build_symmetrization_map)
  end subroutine build_symmetrization_map

!>@brief Computes the weight of each rotated orbitals in the basis of the same localized subspace
  subroutine orbitals_get_symm_weight(os, space, latt, gr, symm, weight)
    ! Spherical harmonics of angular momentum os%ll are generated on a small
    ! temporary submesh centered at the origin, copied to the grid and
    ! normalized. Each of them is then transformed by every symmetry operation,
    ! and weight(im, imp, iop) receives the overlap between the untransformed
    ! orbital imp and orbital im transformed by operation iop.
    !
    ! Only single-component orbital sets with norbs == 2*ll+1 are supported.
    ! Entries of weight outside (1:os%norbs, 1:os%norbs, 1:symm%nops) are not
    ! assigned here, but the whole array is summed by gr%allreduce.
    type(orbitalset_t),   intent(in)    :: os
    type(space_t),        intent(in)    :: space
    type(lattice_vectors_t), intent(in) :: latt
    type(grid_t),         intent(in)    :: gr
    type(symmetries_t),   intent(in)    :: symm
    real(real64),         intent(inout) :: weight(:,:,:)

    integer :: im, imp, iop, mm
    real(real64), allocatable :: orb(:,:), orb_sym(:), ylm(:)
    type(submesh_t) :: sphere
    real(real64) :: rc, norm, origin(space%dim)

    PUSH_SUB(orbitals_get_symm_weight)

    ASSERT(os%ndim == 1)
    ASSERT(2*os%ll+1 == os%norbs)

    SAFE_ALLOCATE(orb_sym(1:gr%np))
    SAFE_ALLOCATE(orb(1:gr%np, 1:os%norbs))

    ! We generate an artificial submesh to compute the symmetries on it
    ! The radius is such that the volume of the sphere is 50 times the volume
    ! of a grid cell, i.e. the sphere holds about 50 points
    rc = (50.0_real64 * M_THREE/M_FOUR/M_PI*product(gr%spacing))**M_THIRD
    origin = M_ZERO
    call submesh_init(sphere, space, gr, latt, origin, rc)

    SAFE_ALLOCATE(ylm(1:sphere%np))

    ! We then compute the spherical harmonics in this submesh
    do im = 1, os%norbs
      ! Orbital index im = 1..2*ll+1 corresponds to m = -ll..ll
      mm = im-1-os%ll
      call loct_ylm(sphere%np, sphere%rel_x(1,1), sphere%r(1), os%ll, mm, ylm(1))
      orb(:,im) = M_ZERO
      call submesh_add_to_mesh(sphere, ylm, orb(:,im))
      norm = dmf_nrm2(gr, orb(:,im))
      call lalg_scal(gr%np, M_ONE / norm, orb(:,im))
    end do
    SAFE_DEALLOCATE_A(ylm)

    ! The orbitals now live on the grid: the temporary submesh is released
    call submesh_end(sphere)

    ! Then we rotate the orbitals on the grid and project them on the unrotated ones
    do im = 1, os%norbs
      do iop = 1, symm%nops
        call dgrid_symmetrize_single(gr, iop, orb(:,im), orb_sym)

        ! The reduction is postponed and done once for all entries below
        do imp = 1, os%norbs
          weight(im, imp, iop) = dmf_dotp(gr, orb(:,imp), orb_sym, reduce=.false.)
        end do
      end do
    end do

    call gr%allreduce(weight)

    SAFE_DEALLOCATE_A(orb)
    SAFE_DEALLOCATE_A(orb_sym)

    POP_SUB(orbitals_get_symm_weight)
  end subroutine orbitals_get_symm_weight

#include "dft_u_noncollinear_inc.F90"

#include "undef.F90"
#include "real.F90"
#include "lda_u_inc.F90"

#include "undef.F90"
#include "complex.F90"
#include "lda_u_inc.F90"
end module lda_u_oct_m
