!! Copyright (C) 2015 X. Andrade
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

!> @brief gather distributed states into a local array
!!
!! If the states are not parallelized (st\%parallel_in_states is false), psi is left
!! untouched. Otherwise each rank packs its st\%lnst local states, starting at
!! st\%st_start, into a contiguous buffer that is zero-padded up to the largest
!! per-rank state count. The buffers of all ranks are exchanged with an allgather
!! and the states of every rank are then copied back into psi in rank order.
subroutine X(states_elec_parallel_gather_3)(st, dims, psi)
  type(states_elec_t), intent(in)  :: st            !< for information on parallel distribution and the mpi group
  integer,           intent(in)    :: dims(2)       !< first and second dimensions of the send/receive buffers
  R_TYPE,            intent(inout) :: psi(:, :, :)  !< wave functions to be gathered;
  !!                                                   dimensions (1:st\%nst, 1:dims(1), 1:dims(2))

  integer :: maxst, ist, i1, i2, irank, ist_local, bufsize
  R_TYPE, allocatable :: sendpsi(:, :, :), recvpsi(:, :, :)

  !no PUSH_SUB, called too often

  call profiling_in(TOSTRING(X(STATES_GATHER)))

  if (st%parallel_in_states) then

    ! Every rank contributes a block of the same size, so the blocks are
    ! padded to the largest number of states held by any rank.
    maxst = maxval(st%dist%num(0:st%mpi_grp%size - 1))

    ! Number of elements sent by (and received from) each rank.
    bufsize = product(dims(1:2))*maxst

    SAFE_ALLOCATE(sendpsi(1:dims(1), 1:dims(2), 1:maxst))
    SAFE_ALLOCATE(recvpsi(1:dims(1), 1:dims(2), 1:maxst*st%mpi_grp%size))

    ! We have to use a temporary array to make the data contiguous.
    ! The state index is the first index of psi but the last one of the buffer;
    ! i1 is the innermost loop so that the buffer is traversed with unit stride.
    do ist = 1, st%lnst
      do i2 = 1, dims(2)
        do i1 = 1, dims(1)
          sendpsi(i1, i2, ist) = psi(st%st_start + ist - 1, i1, i2)
        end do
      end do
    end do

    ! Fill the padding so that no undefined values are sent.
    sendpsi(1:dims(1), 1:dims(2), st%lnst+1:maxst) = M_ZERO

    call st%mpi_grp%allgather(sendpsi(1, 1, 1), bufsize, R_MPITYPE, &
      recvpsi(1, 1, 1), bufsize, R_MPITYPE)

    ! now get the correct states from the data of each rank:
    ! the block of rank irank starts at offset irank*maxst in recvpsi and only
    ! its first st%dist%num(irank) entries are real states (the rest is padding).
    ist = 0
    do irank = 0, st%mpi_grp%size - 1
      do ist_local = 1, st%dist%num(irank)
        ist = ist + 1
        do i2 = 1, dims(2)
          do i1 = 1, dims(1)
            psi(ist, i1, i2) = recvpsi(i1, i2, irank*maxst + ist_local)
          end do
        end do
      end do
    end do

    SAFE_DEALLOCATE_A(sendpsi)
    SAFE_DEALLOCATE_A(recvpsi)

  end if

  call profiling_out(TOSTRING(X(STATES_GATHER)))
end subroutine X(states_elec_parallel_gather_3)

!---------------------------------------------------
!> @brief gather a one-dimensional array, distributed over states
!!
!! If the states are not parallelized (st\%parallel_in_states is false), aa is left
!! untouched. Otherwise the local section aa(st\%st_start:st\%st_end) is copied to a
!! temporary send buffer and an allgatherv over the states group writes the
!! sections of all ranks back into aa.
subroutine X(states_elec_parallel_gather_1)(st, aa)
  type(states_elec_t), intent(in)    :: st     !< for information on parallel distribution and the mpi group
  R_TYPE, contiguous,  intent(inout) :: aa(:)  !< array to be gathered, dimensions (1: st\%nst)

  integer :: irank
  integer, allocatable :: offsets(:)
  R_TYPE, allocatable :: local_part(:)

  !no PUSH_SUB, called too often

  call profiling_in(TOSTRING(X(STATES_GATHER)))

  if (st%parallel_in_states) then

    SAFE_ALLOCATE(offsets(0:st%mpi_grp%size - 1))
    SAFE_ALLOCATE(local_part(st%st_start:st%st_end))

    ! Zero-based displacement of the section of each rank inside aa
    ! (presumably range(1, irank) is the first state of that rank; confirm in the distribution type).
    do irank = 0, st%mpi_grp%size - 1
      offsets(irank) = st%dist%range(1, irank) - 1
    end do

    ! aa is also the receive buffer, so the local section is sent from a separate copy.
    local_part(st%st_start:st%st_end) = aa(st%st_start:st%st_end)

    call st%mpi_grp%allgatherv(local_part(st%st_start:), st%lnst, R_MPITYPE, &
      aa, st%dist%num, offsets, R_MPITYPE)

    SAFE_DEALLOCATE_A(offsets)
    SAFE_DEALLOCATE_A(local_part)

  end if

  call profiling_out(TOSTRING(X(STATES_GATHER)))
end subroutine X(states_elec_parallel_gather_1)

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
