!! Copyright (C) 2023 N. Tancogne-Dejean
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"


!> @brief This module provides routines for communicating all batches in a ring-pattern scheme.
!!
!! Communication is done in two steps: the ring_pattern module determines which tasks talk to each other
!! at each step of the ring-pattern scheme.
!!
!! For a given step of the ring-pattern scheme, this module determines how many send/receive steps are needed,
!! given the number of batches owned by the tasks involved (local, sender, receiver).
!!
!! This module also provides the routines that handle the communication of batches
module states_elec_all_to_all_communications_oct_m
  use batch_oct_m
  use debug_oct_m
  use global_oct_m
  use math_oct_m
  use mesh_oct_m
  use messages_oct_m
  use mpi_oct_m
  use namespace_oct_m
  use profiling_oct_m
  use states_elec_oct_m
  use states_elec_parallel_oct_m
  use wfs_elec_oct_m

  implicit none

  private

  public ::                                            &
    states_elec_all_to_all_communications_t

  !> @brief Bookkeeping for the exchange of batches with one sender and one receiver task.
  !!
  !! All components are default-initialized to a "no communication" state (no partner
  !! task, zero batches), so that the query functions return well-defined values even
  !! if they are called before start().
  type states_elec_all_to_all_communications_t
    private

    integer :: task_from = -1 !< Task from which we receive data (negative: no task)
    integer :: task_to   = -1 !< Task to which we send data (negative: no task)

    integer :: nbatch_to_receive = 0 !< Number of batches to receive
    integer :: nbatch_to_send    = 0 !< Number of batches to send
    integer :: n_comms           = 0 !< Number of communications, max(nbatch_to_send, nbatch_to_receive)

    integer :: nblock_to_receive = 0 !< Number of blocks to receive
    integer :: nblock_to_send    = 0 !< Number of blocks to send

    type(MPI_Request), allocatable :: send_req(:) !< Array of MPI request for Isend
  contains
    procedure :: start => states_elec_all_to_all_communications_start
    procedure :: get_nreceiv => states_elec_all_to_all_communications_get_nreceive
    procedure :: get_nsend => states_elec_all_to_all_communications_get_nsend
    procedure :: get_ncom => states_elec_all_to_all_communications_get_ncom
    procedure :: alloc_receive_batch => states_elec_all_to_all_communications_alloc_receive_batch
    procedure :: get_send_indices => states_elec_all_to_all_communications_get_send_indices
    procedure :: get_receive_indices => states_elec_all_to_all_communications_get_receive_indices
    procedure :: dpost_all_mpi_isend => dstates_elec_all_to_all_communications_post_all_mpi_isend
    procedure :: zpost_all_mpi_isend => zstates_elec_all_to_all_communications_post_all_mpi_isend
    procedure :: dmpi_recv_batch => dstates_elec_all_to_all_communications_mpi_recv_batch
    procedure :: zmpi_recv_batch => zstates_elec_all_to_all_communications_mpi_recv_batch
    procedure :: wait_all_isend => states_elec_all_to_all_communications_wait_all_isend
  end type states_elec_all_to_all_communications_t

contains

  !------------------------------------------------------------
  !>@brief Set up the object for one exchange: data is received from task_from and sent to task_to.
  !!
  !! The routine evaluates how many batches (and blocks of states) are received and sent,
  !! stores the partner tasks, and sets the number of communication steps to the larger
  !! of the two batch counts.
  subroutine states_elec_all_to_all_communications_start(this, st, task_from, task_to)
    class(states_elec_all_to_all_communications_t),  intent(inout) :: this
    type(states_elec_t),   intent(in) :: st
    integer,               intent(in) :: task_from !< Task we receive from
    integer,               intent(in) :: task_to   !< Task we send to

    integer :: n_recv, n_send

    PUSH_SUB(states_elec_all_to_all_communications_start)

    ! Counts for the incoming data first, then for the outgoing data
    n_recv = states_elec_all_to_all_communications_eval_nreceive(st, task_from, this%nblock_to_receive)
    n_send = states_elec_all_to_all_communications_eval_nsend(st, task_to, this%nblock_to_send)

    this%task_to = task_to
    this%task_from = task_from
    this%nbatch_to_receive = n_recv
    this%nbatch_to_send = n_send

    ! Enough steps to cover both the sends and the receives
    this%n_comms = max(n_recv, n_send)

    POP_SUB(states_elec_all_to_all_communications_start)
  end subroutine states_elec_all_to_all_communications_start

  !------------------------------------------------------------
  !>@brief Evaluate how many batches we will receive from task_from
  !!
  !! The state and k-point ranges of task_from are read from st%st_kpt_task
  !! (columns 1-2: first/last state, columns 3-4: first/last k-point).
  !! A block is counted when its state range lies entirely inside the state range of
  !! task_from; the number of batches is this number of blocks times the number of
  !! k-points of task_from. For a negative task_from, both counts are zero.
  !!
  !! The function result is the number of batches to receive.
  integer function states_elec_all_to_all_communications_eval_nreceive(st, task_from, nblock_to_receive) result(nbatch_to_receive)
    type(states_elec_t),   intent(in)  :: st
    integer,               intent(in)  :: task_from         !< Task we receive from (negative: none)
    integer,               intent(out) :: nblock_to_receive !< Number of blocks of states to receive

    integer :: st_start, st_end, nkpt

    PUSH_SUB(states_elec_all_to_all_communications_eval_nreceive)

    nbatch_to_receive = 0
    nblock_to_receive = 0

    !What we receive for the wfn
    if (task_from > -1) then
      st_start = st%st_kpt_task(task_from, 1)
      st_end   = st%st_kpt_task(task_from, 2)
      nkpt     = st%st_kpt_task(task_from, 4) - st%st_kpt_task(task_from, 3) + 1

      ! Blocks whose states all belong to task_from
      nblock_to_receive = count(st%group%block_range(1:st%group%nblocks, 1) >= st_start .and. &
        st%group%block_range(1:st%group%nblocks, 2) <= st_end)
      nbatch_to_receive = nblock_to_receive * nkpt
    end if

    ! i0 is used so that ranks and counts with more than two digits are printed correctly
    write(message(1), '(a,i0,a,i0,a,i0)') 'Debug: Task ', st%st_kpt_mpi_grp%rank, ' will receive ', &
      nbatch_to_receive, ' batches from task ', task_from
    call messages_info(1, all_nodes=.true., debug_only=.true.)

    POP_SUB(states_elec_all_to_all_communications_eval_nreceive)
  end function states_elec_all_to_all_communications_eval_nreceive

  !------------------------------------------------------------
  !>@brief Evaluate how many batches we will send to task_to
  !!
  !! All local blocks (st%group%block_start to st%group%block_end) are sent for each local
  !! k-point (st%d%kpt%start to st%d%kpt%end). For a negative task_to, both counts are zero.
  !!
  !! The function result is the number of batches to send.
  integer function states_elec_all_to_all_communications_eval_nsend(st, task_to, nblock_to_send) result(nbatch_to_send)
    type(states_elec_t),   intent(in)  :: st
    integer,               intent(in)  :: task_to        !< Task we send to (negative: none)
    integer,               intent(out) :: nblock_to_send !< Number of blocks of states to send

    PUSH_SUB(states_elec_all_to_all_communications_eval_nsend)

    nbatch_to_send = 0
    nblock_to_send = 0

    if (task_to > -1) then
      nblock_to_send = (st%group%block_end-st%group%block_start+1)
      nbatch_to_send = nblock_to_send*(st%d%kpt%end-st%d%kpt%start+1)
    end if

    ! i0 is used so that ranks and counts with more than two digits are printed correctly
    write(message(1), '(a,i0,a,i0,a,i0)') 'Debug: Task ', st%st_kpt_mpi_grp%rank, ' will send ', nbatch_to_send, &
      ' batches to task ', task_to
    call messages_info(1, all_nodes=.true., debug_only=.true.)

    POP_SUB(states_elec_all_to_all_communications_eval_nsend)
  end function states_elec_all_to_all_communications_eval_nsend

  !------------------------------------------------------------
  !>@brief Returns the number of communication steps
  !!
  !! This is the value stored by start(), i.e. the maximum of the number of batches
  !! to send and the number of batches to receive.
  integer pure function states_elec_all_to_all_communications_get_ncom(this) result(n_comms)
    class(states_elec_all_to_all_communications_t),  intent(in) :: this

    n_comms = this%n_comms
  end function states_elec_all_to_all_communications_get_ncom

  !------------------------------------------------------------
  !>@brief Returns the number of batches to receive, as stored by start()
  !!
  !! Bound to the type as get_nreceiv.
  integer pure function states_elec_all_to_all_communications_get_nreceive(this) result(nbatch_to_receive)
    class(states_elec_all_to_all_communications_t),  intent(in) :: this

    nbatch_to_receive = this%nbatch_to_receive
  end function states_elec_all_to_all_communications_get_nreceive

  !------------------------------------------------------------
  !>@brief Returns the number of batches to send, as stored by start()
  integer pure function states_elec_all_to_all_communications_get_nsend(this) result(nbatch_to_send)
    class(states_elec_all_to_all_communications_t),  intent(in) :: this

    nbatch_to_send = this%nbatch_to_send
  end function states_elec_all_to_all_communications_get_nsend

  !------------------------------------------------------------
  !>@brief Given the icom step, allocate the receive buffer (wfs_elec_t)
  !!
  !! The communication step is decomposed into a block index (fastest index) and a
  !! k-point index, both offset by the first block and first k-point of task_from.
  !! The routine must only be called when there is something to receive, i.e. when the
  !! number of blocks to receive is positive.
  subroutine states_elec_all_to_all_communications_alloc_receive_batch(this, st, icom, np, psib)
    class(states_elec_all_to_all_communications_t),  intent(in)  :: this
    type(states_elec_t),                             intent(in)  :: st
    integer,                                         intent(in)  :: icom !< Communication step
    integer,                                         intent(in)  :: np   !< Number of points in the batch (mesh%np or mesh%np_part)
    type(wfs_elec_t),                                intent(out) :: psib !< Allocated batch

    integer :: block_id, ib, ik

    PUSH_SUB(states_elec_all_to_all_communications_alloc_receive_batch)

    ! Without blocks to receive, the mod and the division below are undefined
    ASSERT(this%nblock_to_receive > 0)

    ! Given the icom, returns the id of the block of state communicated
    block_id = mod(icom-1, this%nblock_to_receive)+1
    ! icom-block_id is a multiple of nblock_to_receive, so the integer division is exact
    ik = (icom-block_id)/this%nblock_to_receive + st%st_kpt_task(this%task_from, 3)
    ib = block_id - 1 + st%group%iblock(st%st_kpt_task(this%task_from, 1))

    ! i0 is used so that ranks and indices with more than two digits are printed correctly
    write(message(1), '(a,i0,a,i0,a,i0)') 'Debug: Task ', st%st_kpt_mpi_grp%rank, ' allocates memory for block ', &
      ib, ' and k-point ', ik
    call messages_info(1, all_nodes=.true., debug_only=.true.)

    call states_elec_parallel_allocate_batch(st, psib, np, ib, ik, packed=.true.)

    POP_SUB(states_elec_all_to_all_communications_alloc_receive_batch)
  end subroutine states_elec_all_to_all_communications_alloc_receive_batch

  !------------------------------------------------------------
  !>@brief Given the icom step, returns the block and k-point indices to be sent
  !!
  !! The communication step is decomposed into a block index (fastest index) and a
  !! k-point index, offset by the block containing the first local state and by the
  !! first local k-point. The routine must only be called when there is something to
  !! send, i.e. when the number of blocks to send is positive.
  subroutine states_elec_all_to_all_communications_get_send_indices(this, st, icom, ib, ik)
    class(states_elec_all_to_all_communications_t),  intent(in)  :: this
    type(states_elec_t),                             intent(in)  :: st
    integer,                                         intent(in)  :: icom !< Communication step
    integer,                                         intent(out) :: ib   !< Block index of the batch to send
    integer,                                         intent(out) :: ik   !< k-point index of batch to send

    integer :: block_id

    PUSH_SUB(states_elec_all_to_all_communications_get_send_indices)

    ! Without blocks to send, the mod and the division below are undefined
    ASSERT(this%nblock_to_send > 0)

    ! Given the icom, returns the id of the block of state communicated
    block_id = mod(icom-1, this%nblock_to_send) + 1
    ! icom-block_id is a multiple of nblock_to_send, so the integer division is exact
    ik = (icom-block_id)/this%nblock_to_send + st%d%kpt%start
    ib = block_id - 1 + st%group%iblock(st%st_start)

    ! i0 is used so that ranks and indices with more than two digits are printed correctly
    write(message(1), '(a,i0,a,i0,a,i0)') 'Debug: Task ', st%st_kpt_mpi_grp%rank, ' will send the block ', &
      ib, ' with k-point ', ik
    call messages_info(1, all_nodes=.true., debug_only=.true.)

    POP_SUB(states_elec_all_to_all_communications_get_send_indices)
  end subroutine states_elec_all_to_all_communications_get_send_indices

  !------------------------------------------------------------
  !>@brief Given the icom step, returns the block and k-point indices to be received
  !!
  !! The communication step is decomposed into a block index (fastest index) and a
  !! k-point index, both offset by the first block and first k-point of task_from.
  !! The routine must only be called when there is something to receive, i.e. when the
  !! number of blocks to receive is positive.
  subroutine states_elec_all_to_all_communications_get_receive_indices(this, st, icom, ib, ik)
    class(states_elec_all_to_all_communications_t),  intent(in)  :: this
    type(states_elec_t),                             intent(in)  :: st
    integer,                                         intent(in)  :: icom !< Communication step
    integer,                                         intent(out) :: ib   !< Block index of the batch to receive
    integer,                                         intent(out) :: ik   !< k-point index of batch to receive

    integer :: block_id

    PUSH_SUB(states_elec_all_to_all_communications_get_receive_indices)

    ! Without blocks to receive, the mod and the division below are undefined
    ASSERT(this%nblock_to_receive > 0)

    ! Given the icom, returns the id of the block of state communicated
    block_id = mod(icom-1, this%nblock_to_receive)+1
    ! icom-block_id is a multiple of nblock_to_receive, so the integer division is exact
    ik = (icom-block_id)/this%nblock_to_receive + st%st_kpt_task(this%task_from, 3)
    ib = block_id - 1 + st%group%iblock(st%st_kpt_task(this%task_from, 1))

    if (debug%info) then
      ! i0 is used so that ranks and indices with more than two digits are printed correctly
      write(message(1), '(a,i0,a,i0,a,i0)') 'Task ', st%st_kpt_mpi_grp%rank, ' will receive the block ', &
        ib, ' with k-point ', ik
      call messages_info(1, all_nodes=.true.)
    end if

    POP_SUB(states_elec_all_to_all_communications_get_receive_indices)
  end subroutine states_elec_all_to_all_communications_get_receive_indices

  !------------------------------------------------------------
  !>@brief Wait for the pending isend requests and release the request array
  !!
  !! Nothing is done (apart from the profiling calls) if the request array is not allocated.
  !! NOTE(review): send_req is not allocated anywhere in this part of the file; presumably
  !! this is done by the post_all_mpi_isend routines of the included file - to be confirmed.
  subroutine states_elec_all_to_all_communications_wait_all_isend(this, st)
    class(states_elec_all_to_all_communications_t),  intent(inout) :: this
    type(states_elec_t),                             intent(in)    :: st

    PUSH_SUB(states_elec_all_to_all_communications_wait_all_isend)
    call profiling_in("ALL_TO_ALL_COMM")

    if (allocated(this%send_req)) then

      ! One request per batch to send is waited for (presumably a wrapper around MPI_Waitall)
      call st%st_kpt_mpi_grp%wait(this%nbatch_to_send, this%send_req)

      ! The request array is released, so a second call to this routine does nothing
      SAFE_DEALLOCATE_A(this%send_req)

    end if

    call profiling_out("ALL_TO_ALL_COMM")
    POP_SUB(states_elec_all_to_all_communications_wait_all_isend)
  end subroutine states_elec_all_to_all_communications_wait_all_isend



#include "undef.F90"
#include "real.F90"
#include "states_elec_all_to_all_communications_inc.F90"

#include "undef.F90"
#include "complex.F90"
#include "states_elec_all_to_all_communications_inc.F90"
#include "undef.F90"

end module states_elec_all_to_all_communications_oct_m

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
