!! Copyright (C) 2008 X. Andrade, 2020 S. Ohlmann
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

!> @brief This module implements batches of mesh functions
!!
!! In many situations, we need to perform the same operations over many mesh functions,
!! such as the electronic wave functions. It is therefore advantageous to group those functions into one object.
!! This can ensure that different mesh functions are contiguous in memory.
!!
!! Due to the nature of stencil operations, which constitute a large part of the low level operations on mesh functions,
!! it is often more efficient to perform the same stencil operation over different mesh functions
!! (i.e. using the state index as fast index), than looping first over the mesh index, which would, in general,
!! require a different stencil for each mesh point. This is, in particular, the case for calculations utilizing GPUs.
!!
!! Therefore, we store mesh functions in linear or in so-called packed form. The former refers to the
!! *natural* ordering where the mesh index is the fastest moving, while the latter is transposed.
!! Furthermore the arrays are padded to ensure aligned memory access.
!!
!! The packed form is even more advantageous on the GPU. Therefore, _only_ packed data is stored in the
!! device memory. On the devices, the padding is aligned with the size of a work group and can depend
!! on the actual device.
!
module batch_oct_m
  use accel_oct_m
  use allocate_hardware_aware_oct_m
  use blas_oct_m
  use debug_oct_m
  use global_oct_m
  use hardware_oct_m
  use iso_c_binding
  use math_oct_m
  use messages_oct_m
  use mpi_oct_m
  use profiling_oct_m
  use types_oct_m

  implicit none

  private
  public ::                         &
    batch_t,                        &
    batch_init,                     &
    dbatch_init,                    &
    zbatch_init,                    &
    batch_read_device_to_packed

  !> @brief Class defining batches of mesh functions
  !!
  !! A batch groups nst mesh functions (each with dim spinor components). The data
  !! is held either unpacked in host memory (dff/zff), packed in host memory
  !! (dff_pack/zff_pack) or packed in device memory (ff_device); the current
  !! location is recorded in status_of (and status_host for the host side).
  !!
  type batch_t
    private
    integer,                     public :: nst   !< number of functions in the batch
    integer,                     public :: dim   !< Spinor dimension of the state (one, or two for spinors)
    integer                             :: np    !< number of points in each function (this can be np or np_part)
    integer                             :: ndims !< The second dimension of ist_idim_index(:,:). Currently always set to 2.
    integer,        allocatable         :: ist_idim_index(:, :)
    !<                                                  @brief index mapping from global (ist,idim) to local ist.
    !!
    !!                                                  This maps ist and idim into one linear array.
    !!                                                  This index is constructed in batch_oct_m::batch_build_indices
    integer,        allocatable, public :: ist(:)    !< @brief map from a global to a local index
    !!
    !!                                                   The global index does not need to start at 1, while
    !!                                                   the local index is always in the range 1:nst.
    !!
    !!                                                   This index is constructed in batch_oct_m::batch_build_indices

    logical                             :: is_allocated  !< indicate allocation status of the unpacked host arrays
    logical                             :: own_memory    !< does the batch own the memory or is it foreign memory?
    !  We also need a linear array with the states in order to calculate derivatives, etc.
    integer,                     public :: nst_linear    !< nst_linear = nst * dim

    integer                             :: status_of     !< @brief packing status of the batch
    !!
    !!                                                   possible values are:
    !!                                                   BATCH_NOT_PACKED, BATCH_PACKED, BATCH_DEVICE_PACKED
    integer                             :: status_host   !< @brief packing status in CPU memory
    !!
    !!                                                      If Octopus runs on GPU, this indicates the status on the CPU.
    !!                                                      It can only be BATCH_NOT_PACKED and BATCH_PACKED.
    !!                                                      This makes transfers more efficient: usually we allocate a
    !!                                                      batch as packed on the CPU, then call do_pack to copy it to the GPU.
    !!                                                      In this case, it is really a copy.
    !!                                                      If the batch is unpacked on the CPU, we need to transpose in
    !!                                                      addition which makes it much slower.
    type(type_t)                        :: type_of             !< either TYPE_FLOAT or TYPE_CMPLX (TYPE_NONE before initialization)
    integer                             :: device_buffer_count !< keep track of pack operations performed on the device
    integer                             :: host_buffer_count   !< keep track of pack operations performed on the host
    logical                             :: special_memory      !< are we using hardware-aware (e.g. pinned) memory?
    logical                             :: needs_finish_unpack !< if .true., async unpacking has started and needs to be finished


    ! unpacked variables; linear variables are pointers with different shapes
    real(real64), pointer, contiguous,  public :: dff(:, :, :)     !< pointer to real mesh functions: indices are (1:np, 1:dim, 1:nst)
    complex(real64), pointer, contiguous,  public :: zff(:, :, :)     !< pointer to complex mesh functions: indices are (1:np, 1:dim, 1:nst)
    real(real64), pointer, contiguous,  public :: dff_linear(:, :) !< pointer to real mesh functions: indices are (1:np, 1:nst_linear)
    complex(real64), pointer, contiguous,  public :: zff_linear(:, :) !< pointer to complex mesh functions: indices are (1:np, 1:nst_linear)

    ! packed variables; only rank-2 arrays due to padding to powers of 2
    real(real64), pointer, contiguous,  public :: dff_pack(:, :)   !< pointer to real mesh functions: indices are (1:nst_linear, 1:np)
    complex(real64), pointer, contiguous,  public :: zff_pack(:, :)   !< pointer to complex mesh functions: indices are (1:nst_linear, 1:np)

    integer(int64),                 public :: pack_size(1:2)      !< pack_size = [pad_pow2(nst_linear), np]
    !!                                                            (see math_oct_m::pad_pow2)
    integer(int64),                 public :: pack_size_real(1:2) !< pack_size_real = pack_size;
    !!                                                            if batch type is complex, then
    !!                                                            pack_size_real(1) = 2*pack_size(1)

    type(accel_mem_t),           public :: ff_device           !< handle of the device buffer that holds the packed data

  contains
    procedure :: check_compatibility_with => batch_check_compatibility_with !< @copydoc batch_oct_m::batch_check_compatibility_with
    procedure :: clone_to => batch_clone_to                                 !< @copydoc batch_oct_m::batch_clone_to
    procedure :: clone_to_array => batch_clone_to_array                     !< @copydoc batch_oct_m::batch_clone_to_array
    procedure :: copy_to => batch_copy_to                                   !< @copydoc batch_oct_m::batch_copy_to
    procedure :: copy_data_to => batch_copy_data_to                         !< @copydoc batch_oct_m::batch_copy_data_to
    procedure :: do_pack => batch_do_pack                                   !< @copydoc batch_oct_m::batch_do_pack
    procedure :: do_unpack => batch_do_unpack                               !< @copydoc batch_oct_m::batch_do_unpack
    procedure :: finish_unpack => batch_finish_unpack                       !< @copydoc batch_oct_m::batch_finish_unpack
    procedure :: end => batch_end                                           !< @copydoc batch_oct_m::batch_end
    procedure :: inv_index => batch_inv_index                               !< @copydoc batch_oct_m::batch_inv_index
    procedure :: is_packed => batch_is_packed                               !< @copydoc batch_oct_m::batch_is_packed
    procedure :: ist_idim_to_linear => batch_ist_idim_to_linear             !< @copydoc batch_oct_m::batch_ist_idim_to_linear
    procedure :: linear_to_idim => batch_linear_to_idim                     !< @copydoc batch_oct_m::batch_linear_to_idim
    procedure :: linear_to_ist => batch_linear_to_ist                       !< @copydoc batch_oct_m::batch_linear_to_ist
    procedure :: pack_total_size => batch_pack_total_size                   !< @copydoc batch_oct_m::batch_pack_total_size
    procedure :: remote_access_start => batch_remote_access_start           !< @copydoc batch_oct_m::batch_remote_access_start
    procedure :: remote_access_stop => batch_remote_access_stop             !< @copydoc batch_oct_m::batch_remote_access_stop
    procedure :: status => batch_status                                     !< @copydoc batch_oct_m::batch_status
    procedure :: type => batch_type                                         !< @copydoc batch_oct_m::batch_type
    procedure :: type_as_int => batch_type_as_integer                       !< @copydoc batch_oct_m::batch_type_as_integer
    procedure, private :: dallocate_unpacked_host => dbatch_allocate_unpacked_host
    !< @copydoc batch_oct_m::dbatch_allocate_unpacked_host
    procedure, private :: zallocate_unpacked_host => zbatch_allocate_unpacked_host
    !< @copydoc batch_oct_m::zbatch_allocate_unpacked_host
    procedure, private :: allocate_unpacked_host => batch_allocate_unpacked_host
    !< @copydoc batch_oct_m::batch_allocate_unpacked_host
    procedure, private :: dallocate_packed_host => dbatch_allocate_packed_host
    !< @copydoc batch_oct_m::dbatch_allocate_packed_host
    procedure, private :: zallocate_packed_host => zbatch_allocate_packed_host
    !< @copydoc batch_oct_m::zbatch_allocate_packed_host
    procedure, private :: allocate_packed_host => batch_allocate_packed_host
    !< @copydoc batch_oct_m::batch_allocate_packed_host
    procedure, private :: allocate_packed_device => batch_allocate_packed_device
    !< @copydoc batch_oct_m::batch_allocate_packed_device
    procedure, private :: deallocate_unpacked_host => batch_deallocate_unpacked_host
    !< @copydoc batch_oct_m::batch_deallocate_unpacked_host
    procedure, private :: deallocate_packed_host => batch_deallocate_packed_host
    !< @copydoc batch_oct_m::batch_deallocate_packed_host
    procedure, private :: deallocate_packed_device => batch_deallocate_packed_device
    !< @copydoc batch_oct_m::batch_deallocate_packed_device
  end type batch_t

  !--------------------------------------------------------------
  !> @brief initialize a batch with existing memory
  !!
  !! Generic entry point for real (d*) and complex (z*) arrays of rank 1, 2 or 3.
  !! @note the provided arrays are assumed to be in unpacked shape.
  !!
  interface batch_init
    module procedure dbatch_init_with_memory_1, dbatch_init_with_memory_2, dbatch_init_with_memory_3
    module procedure zbatch_init_with_memory_1, zbatch_init_with_memory_2, zbatch_init_with_memory_3
  end interface batch_init

  !> functions are stored in CPU memory, unpacked order
  integer, public, parameter :: BATCH_NOT_PACKED = 0
  !> functions are stored in CPU memory, in transposed (packed) order
  integer, public, parameter :: BATCH_PACKED = 1
  !> functions are stored in device memory in packed order
  integer, public, parameter :: BATCH_DEVICE_PACKED = 2

  !> size (in number of wave-functions) of the buffer used to copy states to the (OpenCL) device
  integer, parameter :: CL_PACK_MAX_BUFFER_SIZE = 4

contains

  !--------------------------------------------------------------
  !> @brief finalize a batch and release allocated memory, if necessary
  !!
  !! If the batch was initialized with 'external' memory, this routine
  !! ensures that this memory is up-to-date, when the batch is finalized.
  !! This means, that the data is copied from the device (if requested)
  !! and unpacked.
  !!
  !! An asynchronous unpack that is still pending (batch_oct_m::batch_do_unpack
  !! called with async=.true. but not yet finished) is completed first. Without
  !! this, host buffers could be released while a device transfer may still
  !! target them, and the device buffer would never be released.
  !
  subroutine batch_end(this, copy)
    class(batch_t),          intent(inout) :: this
    logical,       optional, intent(in)    :: copy !< do we need to copy data from the device? Default = .true.
    !!                                                (forwarded to batch_oct_m::batch_do_unpack)

    PUSH_SUB(batch_end)

    ! complete an outstanding asynchronous unpack (no-op if none is pending)
    call this%finish_unpack()

    if (this%own_memory .and. this%is_packed()) then
      !deallocate directly to avoid unnecessary copies
      if (this%status() == BATCH_DEVICE_PACKED) then
        call this%deallocate_packed_device()
      end if
      if (this%status() == BATCH_PACKED .or. this%status_host == BATCH_PACKED) then
        call this%deallocate_packed_host()
      end if
      this%status_of = BATCH_NOT_PACKED
      this%status_host = BATCH_NOT_PACKED
      this%host_buffer_count = 0
      this%device_buffer_count = 0
    end if
    ! foreign memory: bring the data back into the caller's unpacked arrays
    if (this%status() == BATCH_DEVICE_PACKED) call this%do_unpack(copy, force = .true.)
    if (this%status() == BATCH_PACKED) call this%do_unpack(copy, force = .true.)

    if (this%is_allocated) then
      call this%deallocate_unpacked_host()
    end if

    SAFE_DEALLOCATE_A(this%ist_idim_index)
    SAFE_DEALLOCATE_A(this%ist)

    POP_SUB(batch_end)
  end subroutine batch_end

  !--------------------------------------------------------------
  !> @brief release the unpacked host arrays of the batch
  !!
  !! Memory obtained through the hardware-aware allocator (special_memory,
  !! e.g. pinned memory) is handed back to deallocate_hardware_aware with its
  !! size in bytes; otherwise the pointers are deallocated normally. All
  !! unpacked pointers (including the linear views) are nullified afterwards.
  !
  subroutine batch_deallocate_unpacked_host(this)
    class(batch_t),  intent(inout) :: this

    integer(int64) :: nelem

    PUSH_SUB(batch_deallocate_unpacked_host)

    this%is_allocated = .false.

    if (.not. this%special_memory) then
      SAFE_DEALLOCATE_P(this%dff)
      SAFE_DEALLOCATE_P(this%zff)
    else
      ! number of elements of the (np, dim, nst) array
      nelem = int(this%np, int64)*this%dim*this%nst
      if (associated(this%dff)) call deallocate_hardware_aware(c_loc(this%dff(1,1,1)), nelem*8)
      if (associated(this%zff)) call deallocate_hardware_aware(c_loc(this%zff(1,1,1)), nelem*16)
    end if

    nullify(this%dff, this%dff_linear, this%zff, this%zff_linear)

    POP_SUB(batch_deallocate_unpacked_host)
  end subroutine batch_deallocate_unpacked_host

  !--------------------------------------------------------------
  !> @brief release the packed host arrays of the batch
  !!
  !! Memory obtained through the hardware-aware allocator (special_memory,
  !! e.g. pinned memory) is handed back to deallocate_hardware_aware with its
  !! size in bytes; otherwise the pointers are deallocated normally.
  !
  subroutine batch_deallocate_packed_host(this)
    class(batch_t),  intent(inout) :: this

    integer(int64) :: nelem

    PUSH_SUB(batch_deallocate_packed_host)

    if (.not. this%special_memory) then
      SAFE_DEALLOCATE_P(this%dff_pack)
      SAFE_DEALLOCATE_P(this%zff_pack)
    else
      ! number of elements of the (pack_size(1), pack_size(2)) array
      nelem = this%pack_size(1)*this%pack_size(2)
      if (associated(this%dff_pack)) call deallocate_hardware_aware(c_loc(this%dff_pack(1,1)), nelem*8)
      if (associated(this%zff_pack)) call deallocate_hardware_aware(c_loc(this%zff_pack(1,1)), nelem*16)
    end if

    nullify(this%dff_pack, this%zff_pack)

    POP_SUB(batch_deallocate_packed_host)
  end subroutine batch_deallocate_packed_host

  !--------------------------------------------------------------
  !> @brief release the packed copy of the batch held in device memory
  !!
  !! The device buffer handle this%ff_device is passed to accel_release_buffer().
  !
  subroutine batch_deallocate_packed_device(this)
    class(batch_t),  intent(inout) :: this

    PUSH_SUB(batch_deallocate_packed_device)
    call accel_release_buffer(this%ff_device)
    POP_SUB(batch_deallocate_packed_device)
  end subroutine batch_deallocate_packed_device

  !--------------------------------------------------------------
  !> @brief allocate host (CPU) memory for unpacked data
  !!
  !! Dispatches to the real (d) or complex (z) specific version according to
  !! the batch type; nothing is allocated for any other type.
  !
  subroutine batch_allocate_unpacked_host(this)
    class(batch_t),  intent(inout) :: this

    type(type_t) :: btype

    PUSH_SUB(batch_allocate_unpacked_host)

    btype = this%type()
    if (btype == TYPE_CMPLX) then
      call this%zallocate_unpacked_host()
    else if (btype == TYPE_FLOAT) then
      call this%dallocate_unpacked_host()
    end if

    POP_SUB(batch_allocate_unpacked_host)
  end subroutine batch_allocate_unpacked_host

  !--------------------------------------------------------------
  !> @brief allocate host (CPU) memory for packed data
  !!
  !! Dispatches to the real (d) or complex (z) specific version according to
  !! the batch type; nothing is allocated for any other type.
  !
  subroutine batch_allocate_packed_host(this)
    class(batch_t),  intent(inout) :: this

    type(type_t) :: btype

    PUSH_SUB(batch_allocate_packed_host)

    btype = this%type()
    if (btype == TYPE_CMPLX) then
      call this%zallocate_packed_host()
    else if (btype == TYPE_FLOAT) then
      call this%dallocate_packed_host()
    end if

    POP_SUB(batch_allocate_packed_host)
  end subroutine batch_allocate_packed_host

  !--------------------------------------------------------------
  !> @brief allocate device (GPU) memory for packed data
  !!
  !! Creates a read/write device buffer of pack_size(1)*pack_size(2) elements
  !! of the batch type and stores its handle in this%ff_device.
  !
  subroutine batch_allocate_packed_device(this)
    class(batch_t),  intent(inout) :: this

    integer(int64) :: nelem

    PUSH_SUB(batch_allocate_packed_device)

    nelem = product(this%pack_size)
    call accel_create_buffer(this%ff_device, ACCEL_MEM_READ_WRITE, this%type(), nelem)

    POP_SUB(batch_allocate_packed_device)
  end subroutine batch_allocate_packed_device

  !--------------------------------------------------------------
  !> @brief initialize an empty batch
  !!
  !! Helper for the batch_oct_m::batch_init family. It sets the sizes and the
  !! book-keeping flags, allocates (but does not fill) the index tables and
  !! leaves all data pointers nullified; the batch type is TYPE_NONE afterwards.
  !
  subroutine batch_init_empty (this, dim, nst, np)
    type(batch_t), intent(out)   :: this !< the batch to be initialized
    integer,       intent(in)    :: dim  !< The number of spin dimensions
    integer,       intent(in)    :: nst  !< The number of states in the batch
    integer,       intent(in)    :: np   !< The number of points in each mesh function

    PUSH_SUB(batch_init_empty)

    ! sizes
    this%dim = dim
    this%nst = nst
    this%np = np
    this%nst_linear = nst*dim
    this%ndims = 2

    ! type and memory ownership
    this%type_of = TYPE_NONE
    this%own_memory = .false.
    this%is_allocated = .false.
    this%special_memory = .false.

    ! packing state: plain unpacked host data, no pack operations yet
    this%status_of = BATCH_NOT_PACKED
    this%status_host = BATCH_NOT_PACKED
    this%host_buffer_count = 0
    this%device_buffer_count = 0
    this%needs_finish_unpack = .false.

    ! index tables, filled later by the callers
    SAFE_ALLOCATE(this%ist_idim_index(1:this%nst_linear, 1:this%ndims))
    SAFE_ALLOCATE(this%ist(1:this%nst))

    ! no data attached yet
    nullify(this%dff, this%dff_linear, this%dff_pack)
    nullify(this%zff, this%zff_linear, this%zff_pack)

    POP_SUB(batch_init_empty)
  end subroutine batch_init_empty

  !--------------------------------------------------------------
  !> @brief clone a batch to a new batch
  !!
  !! This routine allocates the destination batch, clones the metadata of
  !! this batch into it (see batch_oct_m::batch_copy_to) and, if requested,
  !! copies the data.
  !!
  !! @note dest is an allocatable intent(out) dummy, so it is deallocated
  !! automatically on entry if the caller passed it allocated.
  !
  subroutine batch_clone_to(this, dest, pack, copy_data, new_np)
    class(batch_t),              intent(in)    :: this       !< source batch
    class(batch_t), allocatable, intent(out)   :: dest       !< destination batch (allocated here)
    logical,        optional,    intent(in)    :: pack       !< If .false. the new batch will not be packed.
    !!                                                          Default: batch_is_packed(this)
    logical,        optional,    intent(in)    :: copy_data  !< If .true. the batch data will be copied to the destination batch.
    !!                                                          Default: .false.
    integer,        optional,    intent(in)    :: new_np     !< If present, this replaces this%np in the initialization

    PUSH_SUB(batch_clone_to)

    ! intent(out) guarantees that dest is unallocated here, so a check
    ! for a previous allocation could never trigger
    SAFE_ALLOCATE_TYPE(batch_t, dest)

    call this%copy_to(dest, pack, copy_data, new_np)

    POP_SUB(batch_clone_to)
  end subroutine batch_clone_to

  !--------------------------------------------------------------
  !> @brief clone a batch into an array of new batches
  !!
  !! Allocates dest(1:n_batches) and initializes every element as a clone of
  !! the metadata of this batch (see batch_oct_m::batch_copy_to); the data is
  !! copied as well if requested.
  !!
  !! @note dest is an allocatable intent(out) dummy, so it is deallocated
  !! automatically on entry if the caller passed it allocated.
  !
  subroutine batch_clone_to_array(this, dest, n_batches, pack, copy_data)
    class(batch_t),              intent(in)    :: this       !< source batch
    class(batch_t), allocatable, intent(out)   :: dest(:)    !< destination batches (allocated here)
    integer,                     intent(in)    :: n_batches  !< number of clones to create
    logical,        optional,    intent(in)    :: pack       !< If .false. the new batch will not be packed.
    !!                                                          Default: batch_is_packed(this)
    logical,        optional,    intent(in)    :: copy_data  !< If .true. the batch data will be copied to the destination batch.
    !!                                                          Default: .false.

    integer :: ib

    PUSH_SUB(batch_clone_to_array)

    ! intent(out) guarantees that dest is unallocated here, so a check
    ! for a previous allocation could never trigger
    SAFE_ALLOCATE_TYPE_ARRAY(batch_t, dest, (1:n_batches))

    do ib = 1, n_batches
      call this%copy_to(dest(ib), pack, copy_data)
    end do

    POP_SUB(batch_clone_to_array)
  end subroutine batch_clone_to_array

  !--------------------------------------------------------------
  !> @brief make a copy of a batch
  !!
  !! This routine can perform a deep or a shallow copy of a batch:
  !! the metadata (type, sizes, index tables, packing state) is always
  !! copied, the data only if copy_data = .true.
  !!
  subroutine batch_copy_to(this, dest, pack, copy_data, new_np, special)
    class(batch_t),          intent(in)    :: this       !< The source batch
    class(batch_t),          intent(out)   :: dest       !< The destination batch
    logical,       optional, intent(in)    :: pack       !< If .false. the new batch will not be packed.
    !!                                                      Default: batch_is_packed(this)
    logical,       optional, intent(in)    :: copy_data  !< If .true. the batch data will be copied to the destination batch.
    !!                                                      Default: .false.
    integer,       optional, intent(in)    :: new_np     !< If present, this replaces this%np in the initialization
    logical,       optional, intent(in)    :: special    !< If present, this replaces the "not on the GPU" criterion
    !!                                                      below: the destination uses special (hardware-aware)
    !!                                                      memory only if the source does and special = .true.

    logical :: host_packed, special_
    integer :: np_

    PUSH_SUB(batch_copy_to)

    np_ = optional_default(new_np, this%np)

    host_packed = this%host_buffer_count > 0
    ! use special memory here only for batches not on the GPU to avoid allocating
    ! pinned memory for temporary batches because that leads to a severe performance
    ! decrease for GPU runs (up to 20x)
    if (present(special)) then
      ! honor the caller's choice (previously its value was ignored)
      special_ = this%special_memory .and. special
    else
      special_ = this%special_memory .and. .not. this%device_buffer_count > 0
    end if

    if (this%type() == TYPE_FLOAT) then
      call dbatch_init(dest, this%dim, 1, this%nst, np_, packed=host_packed, special=special_)
    else if (this%type() == TYPE_CMPLX) then
      call zbatch_init(dest, this%dim, 1, this%nst, np_, packed=host_packed, special=special_)
    else
      message(1) = "Internal error: unknown batch type in batch_copy_to."
      call messages_fatal(1)
    end if

    if (this%status() /= dest%status() .and. optional_default(pack, this%is_packed())) call dest%do_pack(copy = .false.)

    dest%ist_idim_index(1:this%nst_linear, 1:this%ndims) = this%ist_idim_index(1:this%nst_linear, 1:this%ndims)
    dest%ist(1:this%nst) = this%ist(1:this%nst)

    if (optional_default(copy_data, .false.)) then
      ASSERT(np_ == this%np)
      call this%copy_data_to(min(this%np, np_), dest)
    end if

    POP_SUB(batch_copy_to)
  end subroutine batch_copy_to

  ! ----------------------------------------------------
  !> @brief return the type of a batch (TYPE_FLOAT, TYPE_CMPLX or TYPE_NONE)
  !!
  !! This function is THREADSAFE
  !!
  pure function batch_type(this) result(btype)
    class(batch_t), intent(in) :: this
    type(type_t)               :: btype

    btype = this%type_of
  end function batch_type

  ! ----------------------------------------------------
  !> @brief integer code of the batch type, for debugging purposes only
  !!
  !! Returns 1 for TYPE_FLOAT, 2 for TYPE_CMPLX and 0 otherwise.
  pure function batch_type_as_integer(this) result(itype)
    class(batch_t), intent(in) :: this
    integer                    :: itype

    if (this%type() == TYPE_FLOAT) then
      itype = 1
    else if (this%type() == TYPE_CMPLX) then
      itype = 2
    else
      itype = 0
    end if

  end function batch_type_as_integer

  ! ----------------------------------------------------
  !> @brief return the packing status of a batch
  !! (BATCH_NOT_PACKED, BATCH_PACKED or BATCH_DEVICE_PACKED)
  !!
  !! This function is THREADSAFE
  !!
  pure function batch_status(this) result(bstatus)
    class(batch_t), intent(in) :: this
    integer                    :: bstatus

    bstatus = this%status_of
  end function batch_status

  ! ----------------------------------------------------
  !> @brief .true. if at least one pack operation (on host or device) is active
  logical pure function batch_is_packed(this) result(in_buffer)
    class(batch_t),      intent(in)    :: this

    in_buffer = .not. (this%device_buffer_count <= 0 .and. this%host_buffer_count <= 0)
  end function batch_is_packed

  ! ----------------------------------------------------
  !> @brief size in bytes of the packed representation of the batch
  !!
  !! Computed as (number of points, padded via accel_padded_size when
  !! accelerators are enabled) * pad_pow2(nst_linear) * size of one element.
  integer(int64) function batch_pack_total_size(this) result(total)
    class(batch_t),      intent(inout) :: this

    integer(int64) :: npoints

    npoints = this%np
    if (accel_is_enabled()) npoints = accel_padded_size(npoints)
    total = npoints*pad_pow2(this%nst_linear)*types_get_size(this%type())

  end function batch_pack_total_size

  ! ----------------------------------------------------

  !> @brief pack the data in a batch
  !!
  !! The data is moved into packed (transposed) storage: device memory when
  !! accelerators are enabled, host memory otherwise. If the batch is already
  !! in its final packed storage, only the matching pack counter is increased,
  !! so that nested do_pack()/do_unpack() calls can be balanced.
  !!
  subroutine batch_do_pack(this, copy, async)
    class(batch_t),      intent(inout) :: this  !< The current batch
    logical,   optional, intent(in)    :: copy  !< Do we copy the data to the packed memory? (default .true.)
    logical,   optional, intent(in)    :: async !< We can do an asynchronous operation. (default .false.)
    !!                                             Only used when host-packed data is transferred to the device.

    logical :: do_copy, is_async
    integer :: from_status, to_status

    ! no push_sub, called too frequently

    call profiling_in("BATCH_DO_PACK")

    do_copy = optional_default(copy, .true.)
    is_async = optional_default(async, .false.)

    ! determine where the packed data has to end up
    from_status = this%status()
    select case (from_status)
    case (BATCH_DEVICE_PACKED)
      to_status = BATCH_DEVICE_PACKED
    case (BATCH_NOT_PACKED, BATCH_PACKED)
      to_status = merge(BATCH_DEVICE_PACKED, BATCH_PACKED, accel_is_enabled())
    end select

    ! data only has to move if the storage changes
    if (from_status /= to_status) then
      if (to_status == BATCH_DEVICE_PACKED) then
        call this%allocate_packed_device()
        ! status_host is deliberately left untouched
        this%status_of = BATCH_DEVICE_PACKED
        if (do_copy) then
          if (from_status == BATCH_NOT_PACKED) then
            call batch_write_unpacked_to_device(this)
          else if (from_status == BATCH_PACKED) then
            call batch_write_packed_to_device(this, is_async)
          end if
        end if
      else if (to_status == BATCH_PACKED) then
        call this%allocate_packed_host()
        this%status_of = BATCH_PACKED
        this%status_host = BATCH_PACKED
        if (do_copy) then
          if (this%type() == TYPE_FLOAT) then
            call dbatch_pack_copy(this)
          else if (this%type() == TYPE_CMPLX) then
            call zbatch_pack_copy(this)
          end if
        end if
        ! a batch owning its memory drops the unpacked arrays once packed on the host
        if (this%own_memory) call this%deallocate_unpacked_host()
      end if
    end if

    ! register one more user of the packed storage
    if (to_status == BATCH_DEVICE_PACKED) then
      this%device_buffer_count = this%device_buffer_count + 1
    else if (to_status == BATCH_PACKED) then
      this%host_buffer_count = this%host_buffer_count + 1
    end if

    call profiling_out("BATCH_DO_PACK")
  end subroutine batch_do_pack

  ! ----------------------------------------------------
  !> @brief unpack a batch
  !!
  !! We unpack the batch if the 'packing counter' is one, or the force flag is given.
  !! Otherwise only the counter of the current packed storage is decremented.
  !! When the data is actually moved, the counter ends up at zero:
  !! - host-packed data (BATCH_PACKED) returns to BATCH_NOT_PACKED;
  !! - device-packed data returns to the host layout recorded in status_host.
  !!
  subroutine batch_do_unpack(this, copy, force, async)
    class(batch_t),     intent(inout) :: this
    logical, optional,  intent(in)    :: copy   !< indicate whether to copy the data (default .true.)
    logical, optional,  intent(in)    :: force  !< if force = .true., unpack independently of the counter (default .false.)
    logical, optional,  intent(in)    :: async  !< indicate whether the operation can by asynchronous (default .false.).
    !!                                             In this case the operation has to be completed by calling batch_finish_unpack()

    logical :: copy_, force_, async_
    integer               :: source, target

    PUSH_SUB(batch_do_unpack)

    call profiling_in("BATCH_DO_UNPACK")

    copy_ = optional_default(copy, .true.)

    force_ = optional_default(force, .false.)

    async_ = optional_default(async, .false.)

    ! get source and target states for this batch
    source = this%status()
    select case (source)
    case (BATCH_NOT_PACKED)
      target = source
    case (BATCH_PACKED)
      target = BATCH_NOT_PACKED
    case (BATCH_DEVICE_PACKED)
      ! go back to whatever layout the host side had before packing to the device
      target = this%status_host
    end select

    ! only do something if target is different from source
    if (source /= target) then
      select case (source)
      case (BATCH_PACKED)
        if (this%host_buffer_count == 1 .or. force_) then
          ! batches owning their memory released the unpacked arrays in do_pack; recreate them
          if (this%own_memory) call this%allocate_unpacked_host()
          ! unpack from packed_host to unpacked_host
          ! (for owned memory the packed buffer, released below, holds the only copy)
          if (copy_ .or. this%own_memory) then
            if (this%type() == TYPE_FLOAT) then
              call dbatch_unpack_copy(this)
            else if (this%type() == TYPE_CMPLX) then
              call zbatch_unpack_copy(this)
            end if
          end if
          call this%deallocate_packed_host()
          this%status_host = target
          this%status_of = target
          ! set to 1 so that the decrement below leaves the counter at 0
          this%host_buffer_count = 1
        end if
        this%host_buffer_count = this%host_buffer_count - 1
      case (BATCH_DEVICE_PACKED)
        if (this%device_buffer_count == 1 .or. force_) then
          if (copy_) then
            select case (target)
              ! unpack from packed_device to unpacked_host
            case (BATCH_NOT_PACKED)
              call batch_read_device_to_unpacked(this)
              ! unpack from packed_device to packed_host
            case (BATCH_PACKED)
              call batch_read_device_to_packed(this, async_)
            end select
          end if
          if (async_) then
            ! the device buffer may still be in use by the transfer;
            ! it is released later in batch_finish_unpack()
            this%needs_finish_unpack = .true.
          else
            call this%deallocate_packed_device()
          end if
          this%status_of = target
          ! set to 1 so that the decrement below leaves the counter at 0
          this%device_buffer_count = 1
        end if
        this%device_buffer_count = this%device_buffer_count - 1
      end select
    end if

    call profiling_out("BATCH_DO_UNPACK")

    POP_SUB(batch_do_unpack)
  end subroutine batch_do_unpack

  ! ----------------------------------------------------
  !> @brief complete an unpack that do_unpack() started with async=.true.
  !!
  !! When an asynchronous unpack is pending (needs_finish_unpack is set), this
  !! calls accel_finish() (assumed to block until the queued device read has
  !! completed), releases the packed device buffer and clears the pending flag.
  !! When nothing is pending it returns without doing anything.
  subroutine batch_finish_unpack(this)
    class(batch_t),      intent(inout)  :: this  !< batch whose unpack may still be pending

    PUSH_SUB(batch_finish_unpack)

    if (.not. this%needs_finish_unpack) then
      POP_SUB(batch_finish_unpack)
      return
    end if

    call accel_finish()
    call this%deallocate_packed_device()
    this%needs_finish_unpack = .false.

    POP_SUB(batch_finish_unpack)
  end subroutine batch_finish_unpack

  ! ----------------------------------------------------

  !> @brief copy the unpacked host data of a batch into its packed device buffer
  !!
  !! A batch with a single linear state is written directly from the first
  !! column of dff_linear/zff_linear into ff_device. For more states, groups of
  !! at most CL_PACK_MAX_BUFFER_SIZE states (never more than pack_size(1)) are
  !! first written into a temporary device buffer, each state at a stride of
  !! pack_size(2). The dpack/zpack kernel then rearranges each group into
  !! ff_device, given the state offset ist - 1.
  subroutine batch_write_unpacked_to_device(this)
    class(batch_t),      intent(inout)  :: this

    type(accel_kernel_t), pointer :: kernel
    type(accel_mem_t) :: tmp
    integer(int64) :: unroll
    integer :: ist, ist2, nchunk, last
    logical :: is_float

    PUSH_SUB(batch_write_unpacked_to_device)

    call profiling_in("BATCH_WRT_UNPACK_ACCEL")

    if (this%nst_linear /= 1) then

      is_float = this%type() == TYPE_FLOAT
      if (is_float) then
        kernel => dpack
      else
        kernel => zpack
      end if

      ! number of states staged on the device per kernel launch
      unroll = min(int(CL_PACK_MAX_BUFFER_SIZE, int64), this%pack_size(1))
      nchunk = int(unroll, int32)

      call accel_create_buffer(tmp, ACCEL_MEM_READ_ONLY, this%type(), unroll*this%pack_size(2))

      do ist = 1, this%nst_linear, nchunk
        last = min(ist + nchunk - 1, this%nst_linear)

        ! stage states ist .. last in tmp, one slot of pack_size(2) entries each
        do ist2 = ist, last
          if (is_float) then
            call accel_write_buffer(tmp, ubound(this%dff_linear, dim=1, kind=int64), this%dff_linear(:, ist2), &
              offset = (ist2 - ist)*this%pack_size(2))
          else
            call accel_write_buffer(tmp, ubound(this%zff_linear, dim=1, kind=int64), this%zff_linear(:, ist2), &
              offset = (ist2 - ist)*this%pack_size(2))
          end if
        end do

        ! let the pack kernel move the staged group into ff_device
        call accel_set_kernel_arg(kernel, 0, int(this%pack_size(1), int32))
        call accel_set_kernel_arg(kernel, 1, int(this%pack_size(2), int32))
        call accel_set_kernel_arg(kernel, 2, ist - 1)
        call accel_set_kernel_arg(kernel, 3, tmp)
        call accel_set_kernel_arg(kernel, 4, this%ff_device)

        call profiling_in("CL_PACK")
        call accel_kernel_run(kernel, (/this%pack_size(2), unroll/), (/accel_max_workgroup_size()/unroll, unroll/))

        if (is_float) then
          call profiling_count_transfers(unroll*this%pack_size(2), M_ONE)
        else
          call profiling_count_transfers(unroll*this%pack_size(2), M_ZI)
        end if

        call accel_finish()
        call profiling_out("CL_PACK")
      end do

      call accel_release_buffer(tmp)

    else
      ! a single state needs no rearrangement: copy it directly
      if (this%type() == TYPE_FLOAT) then
        call accel_write_buffer(this%ff_device, ubound(this%dff_linear, dim=1), this%dff_linear(:, 1))
      else if (this%type() == TYPE_CMPLX) then
        call accel_write_buffer(this%ff_device, ubound(this%zff_linear, dim=1), this%zff_linear(:, 1))
      else
        ASSERT(.false.)
      end if
    end if

    call profiling_out("BATCH_WRT_UNPACK_ACCEL")
    POP_SUB(batch_write_unpacked_to_device)
  end subroutine batch_write_unpacked_to_device

  ! ------------------------------------------------------------------

  !> @brief copy the packed device buffer of a batch back into its unpacked host arrays
  !!
  !! With a single linear state, ff_device is read directly into the first
  !! column of dff_linear/zff_linear. Otherwise, for each group of at most
  !! CL_PACK_MAX_BUFFER_SIZE states (never more than pack_size(1)), the
  !! dunpack/zunpack kernel is run from ff_device into a temporary device buffer
  !! with state offset ist - 1. The group is then read back one state per host
  !! column, each state at a stride of pack_size(2) in the temporary buffer.
  subroutine batch_read_device_to_unpacked(this)
    class(batch_t),      intent(inout) :: this

    type(accel_kernel_t), pointer :: kernel
    type(accel_mem_t) :: tmp
    integer(int64) :: unroll
    integer :: ist, ist2, nchunk, last
    logical :: is_float

    PUSH_SUB(batch_read_device_to_unpacked)
    call profiling_in("BATCH_READ_UNPACKED_ACCEL")

    is_float = this%type() == TYPE_FLOAT

    if (this%nst_linear /= 1) then

      ! number of states handled per kernel launch
      unroll = min(int(CL_PACK_MAX_BUFFER_SIZE, int64), this%pack_size(1))
      nchunk = int(unroll, int32)

      ! device scratch buffer for one group of states
      call accel_create_buffer(tmp, ACCEL_MEM_WRITE_ONLY, this%type(), unroll*this%pack_size(2))

      if (is_float) then
        kernel => dunpack
      else
        kernel => zunpack
      end if

      do ist = 1, this%nst_linear, nchunk
        last = min(ist + nchunk - 1, this%nst_linear)

        ! run the unpack kernel on the group starting at state ist
        call accel_set_kernel_arg(kernel, 0, int(this%pack_size(1), int32))
        call accel_set_kernel_arg(kernel, 1, int(this%pack_size(2), int32))
        call accel_set_kernel_arg(kernel, 2, ist - 1)
        call accel_set_kernel_arg(kernel, 3, this%ff_device)
        call accel_set_kernel_arg(kernel, 4, tmp)

        call profiling_in("CL_UNPACK")
        call accel_kernel_run(kernel, (/unroll, this%pack_size(2)/), (/unroll, accel_max_workgroup_size()/unroll/))

        if (is_float) then
          call profiling_count_transfers(unroll*this%pack_size(2), M_ONE)
        else
          call profiling_count_transfers(unroll*this%pack_size(2), M_ZI)
        end if

        call accel_finish()
        call profiling_out("CL_UNPACK")

        ! read states ist .. last back, one host column each
        do ist2 = ist, last
          if (is_float) then
            call accel_read_buffer(tmp, ubound(this%dff_linear, dim=1, kind=int64), this%dff_linear(:, ist2), &
              offset = (ist2 - ist)*this%pack_size(2))
          else
            call accel_read_buffer(tmp, ubound(this%zff_linear, dim=1, kind=int64), this%zff_linear(:, ist2), &
              offset = (ist2 - ist)*this%pack_size(2))
          end if
        end do
      end do

      call accel_release_buffer(tmp)

    else
      ! a single state needs no rearrangement: read it directly
      if (is_float) then
        call accel_read_buffer(this%ff_device, ubound(this%dff_linear, dim=1), this%dff_linear(:, 1))
      else
        call accel_read_buffer(this%ff_device, ubound(this%zff_linear, dim=1), this%zff_linear(:, 1))
      end if
    end if

    call profiling_out("BATCH_READ_UNPACKED_ACCEL")
    POP_SUB(batch_read_device_to_unpacked)
  end subroutine batch_read_device_to_unpacked

  ! ------------------------------------------------------------------
  !> @brief upload the host packed array (dff_pack or zff_pack) to ff_device
  !!
  !! All product(pack_size) elements are transferred. The optional async flag
  !! is forwarded unchanged to accel_write_buffer.
  subroutine batch_write_packed_to_device(this, async)
    class(batch_t),      intent(inout)  :: this   !< batch holding host packed data
    logical,   optional, intent(in)     :: async  !< forwarded to accel_write_buffer

    integer(int64) :: nelements

    PUSH_SUB(batch_write_packed_to_device)
    call profiling_in("BATCH_WRITE_PACKED_ACCEL")

    ! the packed array is transferred in one piece, padding included
    nelements = product(this%pack_size)

    if (this%type() == TYPE_FLOAT) then
      call accel_write_buffer(this%ff_device, nelements, this%dff_pack, async=async)
    else
      call accel_write_buffer(this%ff_device, nelements, this%zff_pack, async=async)
    end if

    call profiling_out("BATCH_WRITE_PACKED_ACCEL")
    POP_SUB(batch_write_packed_to_device)
  end subroutine batch_write_packed_to_device

  ! ------------------------------------------------------------------
  !> @brief download ff_device into the host packed array (dff_pack or zff_pack)
  !!
  !! All product(pack_size) elements are transferred. The optional async flag
  !! is forwarded unchanged to accel_read_buffer.
  subroutine batch_read_device_to_packed(this, async)
    class(batch_t),      intent(inout) :: this   !< batch receiving host packed data
    logical,   optional, intent(in)    :: async  !< forwarded to accel_read_buffer

    integer(int64) :: nelements

    PUSH_SUB(batch_read_device_to_packed)
    call profiling_in("BATCH_READ_PACKED_ACCEL")

    ! the packed array is transferred in one piece, padding included
    nelements = product(this%pack_size)

    if (this%type() == TYPE_FLOAT) then
      call accel_read_buffer(this%ff_device, nelements, this%dff_pack, async=async)
    else
      call accel_read_buffer(this%ff_device, nelements, this%zff_pack, async=async)
    end if

    call profiling_out("BATCH_READ_PACKED_ACCEL")
    POP_SUB(batch_read_device_to_packed)
  end subroutine batch_read_device_to_packed

  ! ------------------------------------------------------
  !> @brief inverse index lookup
  !!
  !! Returns the first linear index (1 .. nst_linear) whose stored pair in
  !! ist_idim_index matches cind over the first ndims components. Here ist is
  !! the state index as stored in ist_idim_index, ranging from 1 to st%nst.
  !! Aborts through ASSERT when no entry matches.
  !!
  integer function batch_inv_index(this, cind) result(index)
    class(batch_t),     intent(in)    :: this    !< the batch
    integer,            intent(in)    :: cind(:) !< combined index \(ist, idim\)

    ! linear search; index ends at nst_linear + 1 when nothing matches
    index = 1
    do while (index <= this%nst_linear)
      if (all(this%ist_idim_index(index, 1:this%ndims) == cind(1:this%ndims))) exit
      index = index + 1
    end do

    ASSERT(index <= this%nst_linear)

  end function batch_inv_index

  ! ------------------------------------------------------
  !> @brief direct index lookup
  !!
  !! Maps the combined index \(ist, idim\), where ist ranges from 1 to this%nst,
  !! to the linear index (ist - 1)*this%dim + idim. A one-element cind holds
  !! only ist, which is returned unchanged.
  !
  integer pure function batch_ist_idim_to_linear(this, cind) result(index)
    class(batch_t),     intent(in)    :: this    !< the batch
    integer,            intent(in)    :: cind(:) !< combined index \(ist, idim\)

    if (size(cind) /= 1) then
      index = this%dim*(cind(1) - 1) + cind(2)
    else
      index = cind(1)
    end if

  end function batch_ist_idim_to_linear

  ! ------------------------------------------------------
  !> @brief get the state index ist belonging to a linear (combined dim and nst) index
  !!
  !! The linear index interleaves the state index with the dimension, giving
  !! a one-dimensional ordering of states. The state index is read from the
  !! first column of ist_idim_index.
  !
  integer pure function batch_linear_to_ist(this, linear_index) result(ist)
    class(batch_t),     intent(in)    :: this          !< the batch
    integer,            intent(in)    :: linear_index  !< linear (state, dimension) index

    integer, parameter :: ist_column = 1

    ist = this%ist_idim_index(linear_index, ist_column)

  end function batch_linear_to_ist

  ! ------------------------------------------------------
  !> @brief get the dimension index idim belonging to a linear index
  !!
  !! The dimension index is read from the second column of ist_idim_index.
  !
  integer pure function batch_linear_to_idim(this, linear_index) result(idim)
    class(batch_t),     intent(in)    :: this          !< the batch
    integer,            intent(in)    :: linear_index  !< linear (state, dimension) index

    integer, parameter :: idim_column = 2

    idim = this%ist_idim_index(linear_index, idim_column)

  end function batch_linear_to_idim

  ! ------------------------------------------------------
  !> @brief start remote access to a batch on another node
  !!
  !! If the MPI group has more than one process, this packs the batch
  !! (do_pack) and exposes its host packed array (zff_pack or dff_pack) in an
  !! MPI RMA window. The window covers all product(pack_size) elements, and its
  !! displacement unit is the size of one element. Otherwise no window is
  !! created.
  !!
  !! The returned handle is MPI_WIN_NULL whenever no window was created.
  !!
  !! @note this is currently not allowed when using GPUs
  !! @note side effect: when a window is created, the packing of the batch is
  !!       increased by one (batch_remote_access_stop undoes it)
  !
  subroutine batch_remote_access_start(this, mpi_grp, rma_win)
    class(batch_t),  intent(inout) :: this     !< the current batch
    type(mpi_grp_t), intent(in)    :: mpi_grp  !< the MPI group
    type(MPI_Win),   intent(out)   :: rma_win  !< handle of rma window, MPI_WIN_NULL if none was created

    PUSH_SUB(batch_remote_access_start)

    ! Give the intent(out) handle a defined value on every path; it is only
    ! overwritten below when a window is actually created.
    rma_win = MPI_WIN_NULL

    if (mpi_grp%size > 1) then

      ASSERT(.not. accel_is_enabled())

      ! the window exposes the host packed array, so the batch must be packed
      call this%do_pack()

      if (this%type() == TYPE_CMPLX) then
#ifdef HAVE_MPI
        call MPI_Win_create(this%zff_pack(1, 1), int(product(this%pack_size)*types_get_size(this%type()), MPI_ADDRESS_KIND), &
          types_get_size(this%type()), MPI_INFO_NULL, mpi_grp%comm, rma_win, mpi_err)
#endif
      else if (this%type() == TYPE_FLOAT) then
#ifdef HAVE_MPI
        call MPI_Win_create(this%dff_pack(1, 1), int(product(this%pack_size)*types_get_size(this%type()), MPI_ADDRESS_KIND), &
          types_get_size(this%type()), MPI_INFO_NULL, mpi_grp%comm, rma_win, mpi_err)
#endif
      else
        message(1) = "Internal error: unknown batch type in batch_remote_access_start."
        call messages_fatal(1)
      end if

    end if

    POP_SUB(batch_remote_access_start)
  end subroutine batch_remote_access_start

  ! ------------------------------------------------------
  !> @brief stop the remote access to the batch
  !!
  !! Counterpart of batch_remote_access_start. When rma_win is not MPI_WIN_NULL,
  !! the window is freed (in MPI builds) and the batch is unpacked once via
  !! do_unpack. A null handle leaves both the handle and the batch untouched.
  !! Per the MPI standard, MPI_Win_free resets the handle to MPI_WIN_NULL, which
  !! is why rma_win is intent(inout).
  !!
  !! @note side effect: the batch pack level is decreased by one when a window is freed.
  !
  subroutine batch_remote_access_stop(this, rma_win)
    class(batch_t),     intent(inout) :: this     !< the batch exposed through the window
    type(MPI_Win),      intent(inout) :: rma_win  !< handle obtained from batch_remote_access_start

    logical :: has_window

    PUSH_SUB(batch_remote_access_stop)

    has_window = rma_win /= MPI_WIN_NULL

    if (has_window) then
#ifdef HAVE_MPI
      call MPI_Win_free(rma_win, mpi_err)
#endif
      ! undo the do_pack() done when the window was created
      call this%do_unpack()
    end if

    POP_SUB(batch_remote_access_stop)
  end subroutine batch_remote_access_stop

  ! --------------------------------------------------------------
  !> @brief copy the first np points of every mesh function to another batch
  !!
  !! Both batches must pass check_compatibility_with. The copy uses the
  !! batches' current storage:
  !!  - device packed: the kernel_copy kernel. accel_finish() is called
  !!    afterwards unless async is present and .true.
  !!  - host packed: blas_copy. One call covers np*pack_size(1) elements if that
  !!    count fits a 32-bit integer; otherwise there is one call per point.
  !!  - unpacked: one blas_copy of np elements per linear state.
  !
  subroutine batch_copy_data_to(this, np, dest, async)
    class(batch_t),    intent(in)    :: this !< source batch
    integer,           intent(in)    :: np   !< number of points to copy for each mesh function
    class(batch_t),    intent(inout) :: dest !< destination batch
    logical, optional, intent(in)    :: async!< asynchronous GPU operations or not

    integer(int64) :: wgsize, ny, nz, total
    integer :: ist, ip

    PUSH_SUB(batch_copy_data_to)
    call profiling_in("BATCH_COPY_DATA_TO")

    call this%check_compatibility_with(dest)

    select case (this%status())

    case (BATCH_DEVICE_PACKED)
      call accel_set_kernel_arg(kernel_copy, 0, np)
      call accel_set_kernel_arg(kernel_copy, 1, this%ff_device)
      call accel_set_kernel_arg(kernel_copy, 2, log2(int(this%pack_size_real(1), int32)))
      call accel_set_kernel_arg(kernel_copy, 3, dest%ff_device)
      call accel_set_kernel_arg(kernel_copy, 4, log2(int(dest%pack_size_real(1), int32)))

      ! work-group extent along the point direction
      wgsize = accel_kernel_workgroup_size(kernel_copy)/dest%pack_size_real(1)

      ! spread the points over the 2nd and 3rd grid dimensions so that the 2nd
      ! one stays within accel_max_size_per_dim(2)*wgsize
      nz = np/(accel_max_size_per_dim(2)*wgsize) + 1
      ny = min(accel_max_size_per_dim(2)*wgsize, pad(int(np, int64), wgsize))

      call accel_kernel_run(kernel_copy, (/dest%pack_size_real(1), ny, nz/), (/dest%pack_size_real(1), wgsize, 1_int64/))

      if (.not. optional_default(async, .false.)) call accel_finish()

    case (BATCH_PACKED)
      total = np*this%pack_size(1)
      if (total <= huge(0_int32)) then
        ! the whole region fits into a single 32-bit BLAS length
        if (dest%type() == TYPE_FLOAT) then
          call blas_copy(int(total, int32), this%dff_pack(1, 1), 1, dest%dff_pack(1, 1), 1)
        else
          call blas_copy(int(total, int32), this%zff_pack(1, 1), 1, dest%zff_pack(1, 1), 1)
        end if
      else
        ! BLAS cannot handle 8-byte integers, so copy one point at a time
        do ip = 1, np
          if (dest%type() == TYPE_FLOAT) then
            call blas_copy(int(this%pack_size(1), int32), this%dff_pack(1, ip), 1, dest%dff_pack(1, ip), 1)
          else
            call blas_copy(int(this%pack_size(1), int32), this%zff_pack(1, ip), 1, dest%zff_pack(1, ip), 1)
          end if
        end do
      end if

    case (BATCH_NOT_PACKED)
      do ist = 1, dest%nst_linear
        if (dest%type() == TYPE_CMPLX) then
          call blas_copy(np, this%zff_linear(1, ist), 1, dest%zff_linear(1, ist), 1)
        else
          call blas_copy(np, this%dff_linear(1, ist), 1, dest%dff_linear(1, ist), 1)
        end if
      end do

    end select

    call profiling_out("BATCH_COPY_DATA_TO")
    POP_SUB(batch_copy_data_to)
  end subroutine batch_copy_data_to

  ! --------------------------------------------------------------
  !> @brief assert that two batches have compatible dimensions (and type)
  !!
  !! Checks via ASSERT, in this order, that both batches have:
  !!  - the same type;
  !!  - the same nst_linear (skipped when only_check_dim is .true.);
  !!  - the same status;
  !!  - the same dim.
  !
  subroutine batch_check_compatibility_with(this, target, only_check_dim)
    class(batch_t),    intent(in) :: this            !< first batch
    class(batch_t),    intent(in) :: target          !< batch to compare against
    logical, optional, intent(in) :: only_check_dim  !< if .true., do not compare nst_linear (default .false.)

    logical :: check_nst

    PUSH_SUB(batch_check_compatibility_with)

    check_nst = .not. optional_default(only_check_dim, .false.)

    ASSERT(this%type() == target%type())
    if (check_nst) then
      ASSERT(this%nst_linear == target%nst_linear)
    end if
    ASSERT(this%status() == target%status())
    ASSERT(this%dim == target%dim)

    POP_SUB(batch_check_compatibility_with)

  end subroutine batch_check_compatibility_with

!--------------------------------------------------------------
  !> @brief build the index ist(:) and ist_idim_index(:,:) and set pack_size and pack_size_real
  !!
  !! For each state ist in st_start .. st_end and each idim in 1 .. dim:
  !!  - linear entry dim*(ist - st_start) + idim of ist_idim_index is set to (ist, idim);
  !!  - ist(ist - st_start + 1) is set to ist.
  !!
  !! Packed extents:
  !!  - pack_size(1) = pad_pow2(nst_linear);
  !!  - pack_size(2) = np, passed through accel_padded_size when acceleration is enabled;
  !!  - pack_size_real equals pack_size, with its first extent doubled for complex batches.
  !!
  subroutine batch_build_indices(this, st_start, st_end)
    class(batch_t), intent(inout) :: this
    integer,        intent(in)    :: st_start  !< first state of the batch
    integer,        intent(in)    :: st_end    !< last state of the batch

    integer :: ilin, istate, idim

    PUSH_SUB(batch_build_indices)

    ! ilin runs through the linear index dim*(istate - st_start) + idim
    ilin = 0
    do istate = st_start, st_end
      this%ist(istate - st_start + 1) = istate
      do idim = 1, this%dim
        ilin = ilin + 1
        this%ist_idim_index(ilin, 1:2) = (/istate, idim/)
      end do
    end do

    this%pack_size(1) = pad_pow2(this%nst_linear)
    this%pack_size(2) = this%np
    if (accel_is_enabled()) then
      this%pack_size(2) = accel_padded_size(this%pack_size(2))
    end if

    this%pack_size_real = this%pack_size
    if (type_is_complex(this%type())) then
      this%pack_size_real(1) = 2*this%pack_size_real(1)
    end if

    POP_SUB(batch_build_indices)
  end subroutine batch_build_indices


#include "real.F90"
#include "batch_inc.F90"
#include "undef.F90"

#include "complex.F90"
#include "batch_inc.F90"
#include "undef.F90"

end module batch_oct_m

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
