!! Copyright (C) 2008 X. Andrade
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

! --------------------------------------------------------------

!> Apply a 'pair-wise' axpy operation to all functions of the batches xx and yy,
!! using one and the same constant aa for every function:
!! yy(:,:) = aa*xx(:,:) + yy(:,:)
!
subroutine X(batch_axpy_const)(np, aa, xx, yy)
  integer,           intent(in)    :: np !< number of points
  R_TYPE,            intent(in)    :: aa !< multiplier shared by all functions of the batch
  class(batch_t),    intent(in)    :: xx !< input batch
  class(batch_t),    intent(inout) :: yy !< yy(:,:) = aa*xx(:,:) + yy(:,:)

  integer :: ilin
  integer(int64) :: wg_points, row_limit, grid_y, grid_z
  complex(real64) :: aa_cmplx

  PUSH_SUB(X(batch_axpy_const))
  call profiling_in(TOSTRING(X(BATCH_AXPY_CONST)))

  call xx%check_compatibility_with(yy)
#ifdef R_TCOMPLEX
  ! a complex multiplier can only be applied to complex functions
  ASSERT(yy%type() == TYPE_CMPLX)
#endif

  select case (xx%status())
  case (BATCH_NOT_PACKED)
    ! one axpy per linear function of the batch
    if (yy%type() == TYPE_CMPLX) then
      do ilin = 1, yy%nst_linear
        call lalg_axpy(np, aa, xx%zff_linear(:, ilin), yy%zff_linear(:, ilin))
      end do
#ifdef R_TREAL
    else
      do ilin = 1, yy%nst_linear
        call lalg_axpy(np, aa, xx%dff_linear(:, ilin), yy%dff_linear(:, ilin))
      end do
#endif
    end if

  case (BATCH_PACKED)
    ! a single call over the whole packed block (pack_size(1) rows, np columns)
    if (yy%type() == TYPE_CMPLX) then
      call lalg_axpy(int(xx%pack_size(1), int32), np, aa, xx%zff_pack, yy%zff_pack)
#ifdef R_TREAL
    else
      call lalg_axpy(int(xx%pack_size(1), int32), np, aa, xx%dff_pack, yy%dff_pack)
#endif
    end if

  case (BATCH_DEVICE_PACKED)
    if (yy%type() /= TYPE_FLOAT) then
      ! complex functions: the multiplier is handed to the kernel as a complex number
      aa_cmplx = aa
      call accel_set_kernel_arg(kernel_zaxpy, 0, np)
      call accel_set_kernel_arg(kernel_zaxpy, 1, aa_cmplx)
      call accel_set_kernel_arg(kernel_zaxpy, 2, xx%ff_device)
      call accel_set_kernel_arg(kernel_zaxpy, 3, int(log2(xx%pack_size(1)), int32))
      call accel_set_kernel_arg(kernel_zaxpy, 4, yy%ff_device)
      call accel_set_kernel_arg(kernel_zaxpy, 5, int(log2(yy%pack_size(1)), int32))

      ! points handled per work-group; the point range is split over grid dimensions 2 and 3
      wg_points = accel_kernel_workgroup_size(kernel_zaxpy)/yy%pack_size(1)
      row_limit = accel_max_size_per_dim(2)*wg_points
      grid_z = np/row_limit + 1
      grid_y = min(row_limit, pad(np, wg_points))

      call accel_kernel_run(kernel_zaxpy, [yy%pack_size(1), grid_y, grid_z], [yy%pack_size(1), wg_points, 1_int64])
    else
      ! real functions
      call accel_set_kernel_arg(kernel_daxpy, 0, np)
      call accel_set_kernel_arg(kernel_daxpy, 1, aa)
      call accel_set_kernel_arg(kernel_daxpy, 2, xx%ff_device)
      call accel_set_kernel_arg(kernel_daxpy, 3, int(log2(xx%pack_size(1)), int32))
      call accel_set_kernel_arg(kernel_daxpy, 4, yy%ff_device)
      call accel_set_kernel_arg(kernel_daxpy, 5, int(log2(yy%pack_size(1)), int32))

      wg_points = accel_kernel_workgroup_size(kernel_daxpy)/yy%pack_size(1)
      row_limit = accel_max_size_per_dim(2)*wg_points
      grid_z = np/row_limit + 1
      grid_y = min(row_limit, pad(np, wg_points))

      call accel_kernel_run(kernel_daxpy, [yy%pack_size(1), grid_y, grid_z], [yy%pack_size(1), wg_points, 1_int64])
    end if

    call accel_finish()
  end select

  call profiling_count_operations(xx%nst_linear*np*(R_ADD + R_MUL))

  call profiling_out(TOSTRING(X(BATCH_AXPY_CONST)))
  POP_SUB(X(batch_axpy_const))
end subroutine X(batch_axpy_const)

! --------------------------------------------------------------

!> This routine applies a 'pair-wise' axpy operation to all functions
!! of the batches xx and yy, where the constant aa(iaa) belonging to the
!! state of each function is used for the mesh functions in the batch.
!!
!! The multipliers are first expanded into aa_linear, one entry per linear
!! index of the batch. For packed batches aa_linear has pack_size(1) entries,
!! and the padding entries beyond nst_linear are left at zero.
subroutine X(batch_axpy_vec)(np, aa, xx, yy, a_start, a_full)
  integer,            intent(in)    :: np      !< number of points
  R_TYPE,             intent(in)    :: aa(:)   !< array of multipliers
  class(batch_t),     intent(in)    :: xx
  class(batch_t),     intent(inout) :: yy      !< y(ist,:) = aa(ist) * x(ist,:) + y(ist,:)
  integer,  optional, intent(in)    :: a_start !< state index that corresponds to aa(1) (default = 1)
  logical,  optional, intent(in)    :: a_full  !< @brief Is aa indexed by absolute state index? (default = .true.)
  !!
  !!                                              By default (a_full = .true.), the function of state
  !!                                              xx%linear_to_ist(ist) uses aa(xx%linear_to_ist(ist) - a_start + 1),
  !!                                              i.e. aa covers all states starting at a_start (typically
  !!                                              st%nst entries). The correct entries are picked from the
  !!                                              state indices stored in each batch. This is used, for example,
  !!                                              for computing residuals given the eigenvalues.
  !!                                              For a_full=.false., the index is additionally shifted by the
  !!                                              first state of the batch, so aa only needs one entry per
  !!                                              state of the batch (assuming the batch holds consecutive states).

  integer :: ist, ip, effsize, iaa
  R_TYPE, allocatable     :: aa_linear(:)
  integer(int64) :: localsize, dim2, dim3
  integer :: size_factor
  type(accel_mem_t)      :: aa_buffer
#ifdef R_TREAL
  ! real multipliers duplicated per real/imaginary part (device path on complex data only)
  real(real64),  allocatable     :: aa_linear_double(:)
#endif
  type(accel_kernel_t), save :: kernel

  PUSH_SUB(X(batch_axpy_vec))

  ASSERT(not_in_openmp())

  call profiling_in(TOSTRING(X(BATCH_AXPY_VEC)))

  call xx%check_compatibility_with(yy)

  ! packed layouts need one (possibly zero) multiplier per row of the packed array
  effsize = yy%nst_linear
  if (yy%is_packed()) effsize = int(yy%pack_size(1), int32)
  SAFE_ALLOCATE(aa_linear(1:effsize))

  aa_linear = M_ZERO
  do ist = 1, yy%nst_linear
    iaa = xx%linear_to_ist(ist) - (optional_default(a_start, 1) - 1)
    ! shift index, if necessary
    if (.not. optional_default(a_full, .true.)) iaa = iaa - (xx%linear_to_ist(1) - 1)
    aa_linear(ist) = aa(iaa)
  end do

  select case (xx%status())
  case (BATCH_DEVICE_PACKED)
    call accel_kernel_start_call(kernel, 'axpy.cl', TOSTRING(X(axpy_vec)), flags = '-D' + R_TYPE_CL)

    if (yy%type() == TYPE_CMPLX) then
#ifdef R_TREAL
      ! real multipliers on complex data: every complex value is handled as two reals,
      ! so each multiplier is stored twice and the kernel is launched twice as wide
      size_factor = 2
      SAFE_ALLOCATE(aa_linear_double(1:2*yy%pack_size(1)))
      do ist = 1, int(yy%pack_size(1), int32)
        aa_linear_double(2*ist - 1) = aa_linear(ist)
        aa_linear_double(2*ist) = aa_linear(ist)
      end do
      call accel_create_buffer(aa_buffer, ACCEL_MEM_READ_ONLY, TYPE_FLOAT, 2*yy%pack_size(1))
      call accel_write_buffer(aa_buffer, 2*yy%pack_size(1), aa_linear_double, async=.true.)
#else
      size_factor = 1
      call accel_create_buffer(aa_buffer, ACCEL_MEM_READ_ONLY, TYPE_CMPLX, yy%pack_size(1))
      call accel_write_buffer(aa_buffer, yy%pack_size(1), aa_linear, async=.true.)
#endif
    else
#ifdef R_TREAL
      size_factor = 1
      call accel_create_buffer(aa_buffer, ACCEL_MEM_READ_ONLY, TYPE_FLOAT, yy%pack_size(1))
      call accel_write_buffer(aa_buffer, yy%pack_size(1), aa_linear, async=.true.)
#else
      !if aa is complex, the functions must be complex
      ASSERT(.false.)
#endif
    end if

    call accel_set_kernel_arg(kernel, 0, np)
    call accel_set_kernel_arg(kernel, 1, aa_buffer)
    call accel_set_kernel_arg(kernel, 2, xx%ff_device)
    call accel_set_kernel_arg(kernel, 3, int(log2(xx%pack_size(1)*size_factor), int32))
    call accel_set_kernel_arg(kernel, 4, yy%ff_device)
    call accel_set_kernel_arg(kernel, 5, int(log2(yy%pack_size(1)*size_factor), int32))

    ! number of points per work-group; the point range is spread over grid dimensions 2 and 3
    localsize = accel_kernel_workgroup_size(kernel)/(yy%pack_size(1)*size_factor)

    dim3 = np/(accel_max_size_per_dim(2)*localsize) + 1
    dim2 = min(accel_max_size_per_dim(2)*localsize, pad(np, localsize))

    call accel_kernel_run(kernel, (/yy%pack_size(1)*size_factor, dim2, dim3/), &
      (/yy%pack_size(1)*size_factor, localsize, 1_int64/))

    ! we need to wait here to make sure we deallocate the CPU arrays only
    ! after they have been copied to the GPU
    call accel_finish()

    call accel_release_buffer(aa_buffer)

  case (BATCH_PACKED)
    ! packed layout: first index is the linear function index, second the point
    if (yy%type() == TYPE_CMPLX) then
      !$omp parallel do private(ist)
      do ip = 1, np
        !$omp simd
        do ist = 1, yy%nst_linear
          yy%zff_pack(ist, ip) = aa_linear(ist)*xx%zff_pack(ist, ip) + yy%zff_pack(ist, ip)
        end do
      end do
    else
#ifdef R_TREAL
      !$omp parallel do private(ist)
      do ip = 1, np
        !$omp simd
        do ist = 1, yy%nst_linear
          yy%dff_pack(ist, ip) = aa_linear(ist)*xx%dff_pack(ist, ip) + yy%dff_pack(ist, ip)
        end do
      end do
#else
      !if aa is complex, the functions must be complex
      ASSERT(.false.)
#endif
    end if

  case (BATCH_NOT_PACKED)
    ! unpacked layout: one axpy call per linear function
    do ist = 1, yy%nst_linear
      if (yy%type() == TYPE_CMPLX) then
        call lalg_axpy(np, aa_linear(ist), xx%zff_linear(:, ist), yy%zff_linear(:, ist))
      else
#ifdef R_TREAL
        call lalg_axpy(np, aa_linear(ist), xx%dff_linear(:, ist), yy%dff_linear(:, ist))
#else
        !if aa is complex, the functions must be complex
        ASSERT(.false.)
#endif
      end if
    end do
  end select

  SAFE_DEALLOCATE_A(aa_linear)
#ifdef R_TREAL
  SAFE_DEALLOCATE_A(aa_linear_double)
#endif

  call profiling_count_operations(xx%nst_linear*np*(R_ADD + R_MUL))

  call profiling_out(TOSTRING(X(BATCH_AXPY_VEC)))
  POP_SUB(X(batch_axpy_vec))
end subroutine X(batch_axpy_vec)


! --------------------------------------------------------------------------
!> This routine performs a set of axpy operations for each function x of a batch (xx),
!! and accumulates the result into y (psi in this case), a single function:
!! \f$ psi(:, idim) = psi(:, idim) + \sum_{ist=1}^{nst} aa(ist) * xx(ist, idim) \f$
!!
!! The unpacked path and the packed multi-component path skip states whose
!! multiplier satisfies |aa(ist)| < M_EPSILON.
subroutine X(batch_axpy_function)(np, aa, xx, psi, nst)
  integer,           intent(in)    :: np         !< number of points
  class(batch_t),    intent(in)    :: xx         !< input batch
  R_TYPE, contiguous,intent(inout) :: psi(:,:)   !< result: \f$ psi = \sum_{ist=1}^{nst} aa(ist) * xx(ist) \f$
  R_TYPE,            intent(in)    :: aa(:)      !< array of multipliers
  integer, optional, intent(in)    :: nst        !< optional upper bound of sum (default: xx%nst)

  integer :: ist, indb, idim, nst_
  R_TYPE, allocatable :: phi(:,:)

  ! GPU related variables
  type(accel_mem_t) :: aa_buffer
  type(accel_mem_t) :: psi_buffer
  integer :: wgsize, np_padded
  integer :: local_sizes(3)
  integer :: global_sizes(3)

  PUSH_SUB(X(batch_axpy_function))
  call profiling_in(TOSTRING(X(BATCH_AXPY_FUNCTION)))

  ASSERT(xx%dim == ubound(psi,dim=2))
  ! the batch data is accessed through X(ff_*) and the R_TYPE kernels,
  ! so the batch must hold values of type R_TYPE
  ASSERT(xx%type() == R_TYPE_VAL)

  nst_ = xx%nst
  if (present(nst)) nst_ = nst

  ! the sum may not run past the states of the batch or the multipliers
  ASSERT(nst_ <= xx%nst)
  ASSERT(ubound(aa, dim=1) >= nst_)

  select case (xx%status())
  case (BATCH_NOT_PACKED)
    do ist = 1, nst_
      ! negligible multipliers do not contribute; skip all components of the state
      if (abs(aa(ist)) < M_EPSILON) cycle
      do idim = 1, xx%dim
        indb = xx%ist_idim_to_linear((/ist, idim/))
        call lalg_axpy(np, aa(ist), xx%X(ff_linear)(:, indb), psi(1:np, idim))
      end do
    end do

  case (BATCH_PACKED)

    if (xx%dim == 1) then

      ! psi(:,1) = psi(:,1) + ff_pack(1:nst_, 1:np)^T * aa(1:nst_), as a single gemv
      call blas_gemv('T', nst_, np, R_TOTYPE(M_ONE), xx%X(ff_pack)(1,1), &
        ubound(xx%X(ff_pack), dim=1), aa(1), 1, R_TOTYPE(M_ONE), psi(1,1), 1)

    else !Spinor case

      SAFE_ALLOCATE(phi(1:np, 1:xx%dim))

      do ist = 1, nst_
        if (abs(aa(ist)) < M_EPSILON) cycle
        call batch_get_state(xx, ist, np, phi)
        do idim = 1, xx%dim
          call lalg_axpy(np, aa(ist), phi(1:np, idim), psi(1:np, idim))
        end do
      end do

      SAFE_DEALLOCATE_A(phi)

    end if

  case (BATCH_DEVICE_PACKED)

    call accel_create_buffer(aa_buffer, ACCEL_MEM_READ_ONLY, R_TYPE_VAL, nst_)
    call accel_write_buffer(aa_buffer, nst_, aa)

    ! each component of psi is stored in its own slice of np_padded elements
    np_padded = pad_pow2(np)

    call accel_create_buffer(psi_buffer, ACCEL_MEM_READ_WRITE, R_TYPE_VAL, np_padded * xx%dim)
    do idim= 1, xx%dim
      call accel_write_buffer(psi_buffer, np, psi(1:np,idim), offset=(idim-1)*np_padded, async=.true.)
    end do

    call accel_set_kernel_arg(X(kernel_batch_axpy), 0, np)
    call accel_set_kernel_arg(X(kernel_batch_axpy), 1, nst_)
    call accel_set_kernel_arg(X(kernel_batch_axpy), 2, xx%dim)
    call accel_set_kernel_arg(X(kernel_batch_axpy), 3, xx%ff_device)
    call accel_set_kernel_arg(X(kernel_batch_axpy), 4, int(log2(xx%pack_size(1)), int32))
    call accel_set_kernel_arg(X(kernel_batch_axpy), 5, aa_buffer)
    call accel_set_kernel_arg(X(kernel_batch_axpy), 6, psi_buffer)
    call accel_set_kernel_arg(X(kernel_batch_axpy), 7, log2(np_padded))

    wgsize = accel_kernel_workgroup_size(X(kernel_batch_axpy))

    global_sizes = (/ pad(np, wgsize/xx%dim), xx%dim, 1 /)
    local_sizes  = (/ wgsize/xx%dim,          xx%dim, 1 /)

    call accel_kernel_run(X(kernel_batch_axpy), global_sizes, local_sizes)

    do idim = 1, xx%dim
      call accel_read_buffer(psi_buffer, np, psi(1:np,idim), offset=(idim-1)*np_padded)
    end do

    call accel_release_buffer(aa_buffer)
    call accel_release_buffer(psi_buffer)


  end select

  call profiling_out(TOSTRING(X(BATCH_AXPY_FUNCTION)))
  POP_SUB(X(batch_axpy_function))
end subroutine X(batch_axpy_function)

! --------------------------------------------------------------------------
!> This routine performs a set of axpy operations adding the same function psi to all
!! functions of a batch (yy):
!! yy(indb) = yy(indb) + aa(indb) * psi(:, idim),
!! where indb is the linear index of the pair (ist, idim) in the batch.
!
subroutine X(batch_ax_function_py)(np, aa, psi, yy)
  integer,            intent(in)    :: np         !< number of points
  R_TYPE,             intent(in)    :: aa(:)      !< array of multipliers, one per linear index of yy
  R_TYPE, contiguous, intent(in)    :: psi(:,:)   !< mesh functions psi, one column per component
  class(batch_t),     intent(inout) :: yy         !< resulting batch yy(ist) = yy(ist) + aa(ist) * psi

  integer :: ist, indb, idim, ip
  ! GPU related variables
  type(accel_mem_t) :: aa_buffer
  type(accel_mem_t) :: psi_buffer
  integer :: wgsize, np_padded
  integer :: local_sizes(3)
  integer :: global_sizes(3)

  PUSH_SUB(X(batch_ax_function_py))
  call profiling_in(TOSTRING(X(BATCH_AX_FUNCTION_PY)))

  ASSERT(yy%dim == ubound(psi,dim=2))
  ASSERT(ubound(aa, dim=1) == yy%nst_linear)
  ! every code path accesses yy through X(ff_*) or the R_TYPE kernels,
  ! so the batch must hold values of type R_TYPE
  ASSERT(yy%type() == R_TYPE_VAL)

  select case (yy%status())
  case (BATCH_NOT_PACKED)
    do ist = 1, yy%nst
      do idim = 1, yy%dim
        indb = yy%ist_idim_to_linear((/ist, idim/))
        if (abs(aa(indb)) < M_EPSILON) cycle
        call lalg_axpy(np, aa(indb), psi(1:np, idim), yy%X(ff_linear)(:, indb))
      end do
    end do

  case (BATCH_PACKED)

    !$omp parallel do private(ist, idim, indb)
    do ip = 1, np
      do ist = 1, yy%nst
        do idim = 1, yy%dim
          indb = yy%ist_idim_to_linear((/ist, idim/))
          yy%X(ff_pack)(indb, ip) = yy%X(ff_pack)(indb, ip) + aa(indb) * psi(ip, idim)
        end do
      end do
    end do

  case (BATCH_DEVICE_PACKED)

    call accel_create_buffer(aa_buffer, ACCEL_MEM_READ_ONLY, R_TYPE_VAL, yy%nst_linear)
    call accel_write_buffer(aa_buffer, yy%nst_linear, aa)

    ! each component of psi is stored in its own slice of np_padded elements
    np_padded = pad_pow2(np)

    call accel_create_buffer(psi_buffer, ACCEL_MEM_READ_WRITE, R_TYPE_VAL, np_padded * yy%dim)
    do idim= 1, yy%dim
      call accel_write_buffer(psi_buffer, np, psi(1:np,idim), offset=(idim-1)*np_padded)
    end do

    call accel_set_kernel_arg(X(kernel_ax_function_py), 0, np)
    call accel_set_kernel_arg(X(kernel_ax_function_py), 1, yy%nst)
    call accel_set_kernel_arg(X(kernel_ax_function_py), 2, yy%dim)
    call accel_set_kernel_arg(X(kernel_ax_function_py), 3, yy%ff_device)
    call accel_set_kernel_arg(X(kernel_ax_function_py), 4, int(log2(yy%pack_size(1)), int32))
    call accel_set_kernel_arg(X(kernel_ax_function_py), 5, aa_buffer)
    call accel_set_kernel_arg(X(kernel_ax_function_py), 6, psi_buffer)
    call accel_set_kernel_arg(X(kernel_ax_function_py), 7, log2(np_padded))

    wgsize = accel_kernel_workgroup_size(X(kernel_ax_function_py))

    global_sizes = (/ pad(np, wgsize/yy%dim), yy%dim, 1 /)
    local_sizes  = (/ wgsize/yy%dim,          yy%dim, 1 /)

    call accel_kernel_run(X(kernel_ax_function_py), global_sizes, local_sizes)
    ! the temporary buffers may only be released once the kernel has finished
    call accel_finish()

    call accel_release_buffer(aa_buffer)
    call accel_release_buffer(psi_buffer)

  end select

  call profiling_out(TOSTRING(X(BATCH_AX_FUNCTION_PY)))
  POP_SUB(X(batch_ax_function_py))
end subroutine X(batch_ax_function_py)

! --------------------------------------------------------------
!> Scale all functions in a batch by the constant aa: xx(ist) = aa * xx(ist).
!!
!! Packed batches are scaled directly. All other layouts are forwarded to
!! batch_scal_vec with a vector filled with aa.
!!
subroutine X(batch_scal_const)(np, aa, xx)
  integer,           intent(in)    :: np  !< number of points
  R_TYPE,            intent(in)    :: aa  !< scaling factor
  class(batch_t),    intent(inout) :: xx  !< xx(ist) = xx(ist) * aa

  R_TYPE, allocatable :: aavec(:)

  PUSH_SUB(X(batch_scal_const))

  select case (xx%status())
  case (BATCH_PACKED)
    call profiling_in(TOSTRING(X(BATCH_SCAL_CONST)))

    if (xx%type() == TYPE_CMPLX) then
      call lalg_scal(int(xx%pack_size(1), int32), np, aa, xx%zff_pack)
    else
#ifdef R_TREAL
      call lalg_scal(int(xx%pack_size(1), int32), np, aa, xx%dff_pack)
#else
      ! if aa is complex, the functions must be complex; without this check a
      ! real batch would silently be left unscaled
      ASSERT(.false.)
#endif
    end if
    call profiling_count_operations(xx%nst_linear*np*R_MUL)
    call profiling_out(TOSTRING(X(BATCH_SCAL_CONST)))
  case default

    ! with a_full = .false., aavec(1) corresponds to the first state of the batch
    SAFE_ALLOCATE(aavec(1:xx%nst))
    aavec(1:xx%nst) = aa
    call X(batch_scal_vec)(np, aavec, xx, a_full = .false.)

    SAFE_DEALLOCATE_A(aavec)
  end select

  POP_SUB(X(batch_scal_const))
end subroutine X(batch_scal_const)

! --------------------------------------------------------------
!> Scale every function of a batch by a state-dependent factor:
!! xx(ist) = aa(ist) * xx(ist).
!!
!! The factors are gathered into one entry per linear index of the batch.
!! When the batch is packed, the padding entries up to pack_size(1) get a zero factor.
subroutine X(batch_scal_vec)(np, aa, xx, a_start, a_full)
  integer,           intent(in)    :: np        !< number of points
  R_TYPE,            intent(in)    :: aa(:)     !< scaling factors, indexed by state (see a_start and a_full)
  class(batch_t),    intent(inout) :: xx        !< xx(ist) = xx(ist) * aa(ist)
  integer, optional, intent(in)    :: a_start   !< state index that corresponds to aa(1) (default = 1)
  logical, optional, intent(in)    :: a_full    !< if .false., indices are also shifted by the first state of the batch (default = .true.)

  integer :: ilin, ipt, effsize, shift
  R_TYPE, allocatable     :: aa_linear(:)
  integer(int64) :: wg_points, row_limit, grid_y, grid_z
  integer :: nfactor
#ifdef R_TREAL
  real(real64),  allocatable     :: aa_linear_double(:)
#endif
  type(accel_mem_t)      :: coef_buffer
  type(accel_kernel_t), save :: kernel

  PUSH_SUB(X(batch_scal_vec))

  ASSERT(not_in_openmp())

  call profiling_in(TOSTRING(X(BATCH_SCAL_VEC)))

  if (xx%is_packed()) then
    effsize = int(xx%pack_size(1), int32)
  else
    effsize = xx%nst_linear
  end if
  SAFE_ALLOCATE(aa_linear(1:effsize))

  ! offset between a state index and its position in aa
  shift = optional_default(a_start, 1) - 1
  if (xx%nst_linear > 0 .and. .not. optional_default(a_full, .true.)) shift = shift + xx%linear_to_ist(1) - 1

  aa_linear = M_ZERO
  do ilin = 1, xx%nst_linear
    aa_linear(ilin) = aa(xx%linear_to_ist(ilin) - shift)
  end do

  select case (xx%status())
  case (BATCH_DEVICE_PACKED)

    if (xx%type() == TYPE_CMPLX) then
#ifdef R_TREAL
      ! real factors on complex data: every complex number is handled as two reals,
      ! so each factor is stored twice and the launch is twice as wide
      nfactor = 2
      SAFE_ALLOCATE(aa_linear_double(1:2*xx%pack_size(1)))
      aa_linear_double(1::2) = aa_linear(1:xx%pack_size(1))
      aa_linear_double(2::2) = aa_linear(1:xx%pack_size(1))
      call accel_create_buffer(coef_buffer, ACCEL_MEM_READ_ONLY, TYPE_FLOAT, 2*xx%pack_size(1))
      call accel_write_buffer(coef_buffer, 2*xx%pack_size(1), aa_linear_double, async=.true.)
#else
      nfactor = 1
      call accel_create_buffer(coef_buffer, ACCEL_MEM_READ_ONLY, TYPE_CMPLX, xx%pack_size(1))
      call accel_write_buffer(coef_buffer, xx%pack_size(1), aa_linear, async=.true.)
#endif
    else
#ifdef R_TREAL
      nfactor = 1
      call accel_create_buffer(coef_buffer, ACCEL_MEM_READ_ONLY, TYPE_FLOAT, xx%pack_size(1))
      call accel_write_buffer(coef_buffer, xx%pack_size(1), aa_linear, async=.true.)
#else
      ! a complex factor can only be applied to complex functions
      ASSERT(.false.)
#endif
    end if

    call accel_kernel_start_call(kernel, 'axpy.cl', TOSTRING(X(scal_vec)), flags = '-D' + R_TYPE_CL)

    call accel_set_kernel_arg(kernel, 0, np)
    call accel_set_kernel_arg(kernel, 1, coef_buffer)
    call accel_set_kernel_arg(kernel, 2, xx%ff_device)
    call accel_set_kernel_arg(kernel, 3, int(log2(xx%pack_size(1)*nfactor), int32))

    ! points per work-group; the point range is split over grid dimensions 2 and 3
    wg_points = accel_kernel_workgroup_size(kernel)/(xx%pack_size(1)*nfactor)
    row_limit = accel_max_size_per_dim(2)*wg_points
    grid_z = np/row_limit + 1
    grid_y = min(row_limit, pad(np, wg_points))

    call accel_kernel_run(kernel, [xx%pack_size(1)*nfactor, grid_y, grid_z], &
      [xx%pack_size(1)*nfactor, wg_points, 1_int64])

    ! the asynchronous upload must be complete before the host arrays are deallocated
    call accel_finish()

    call accel_release_buffer(coef_buffer)

  case (BATCH_PACKED)
    if (xx%type() == TYPE_CMPLX) then
      !$omp parallel do private(ilin)
      do ipt = 1, np
        !$omp simd
        do ilin = 1, int(xx%pack_size(1), int32)
          xx%zff_pack(ilin, ipt) = aa_linear(ilin)*xx%zff_pack(ilin, ipt)
        end do
      end do
    else
#ifdef R_TREAL
      !$omp parallel do private(ilin)
      do ipt = 1, np
        !$omp simd
        do ilin = 1, int(xx%pack_size(1), int32)
          xx%dff_pack(ilin, ipt) = aa_linear(ilin)*xx%dff_pack(ilin, ipt)
        end do
      end do
#else
      ! a complex factor can only be applied to complex functions
      ASSERT(.false.)
#endif
    end if

  case (BATCH_NOT_PACKED)
    do ilin = 1, xx%nst_linear
      if (xx%type() == TYPE_CMPLX) then
        call lalg_scal(np, aa_linear(ilin), xx%zff_linear(:, ilin))
      else
#ifdef R_TREAL
        call lalg_scal(np, aa_linear(ilin), xx%dff_linear(:, ilin))
#else
        ! a complex factor can only be applied to complex functions
        ASSERT(.false.)
#endif
      end if
    end do
  end select

  SAFE_DEALLOCATE_A(aa_linear)
#ifdef R_TREAL
  SAFE_DEALLOCATE_A(aa_linear_double)
#endif

  call profiling_count_operations(xx%nst_linear*np*R_MUL)

  call profiling_out(TOSTRING(X(BATCH_SCAL_VEC)))
  POP_SUB(X(batch_scal_vec))
end subroutine X(batch_scal_vec)

! --------------------------------------------------------------
!> Calculate yy(ist,:) = xx(ist,:) + aa(ist)*yy(ist,:) for a batch.
!!
!! The function of state s = xx%linear_to_ist(ist) uses aa(s - a_start + 1). If
!! a_full = .false., the index is additionally shifted by the first state of the batch.
!
subroutine X(batch_xpay_vec)(np, xx, aa, yy, a_start, a_full)
  integer,           intent(in)    :: np        !< number of points
  class(batch_t),    intent(in)    :: xx
  R_TYPE,            intent(in)    :: aa(:)     !< array of constants aa(ist)
  class(batch_t),    intent(inout) :: yy        !< yy(ist,:) = xx(ist,:) + aa(ist)*yy(ist,:)
  integer, optional, intent(in)    :: a_start   !< state index that corresponds to aa(1) (default = 1)
  logical, optional, intent(in)    :: a_full    !< if .false., indices are also shifted by the first state of the batch (default = .true.)

  integer :: ist, ip, effsize, iaa, size_factor
  R_TYPE, allocatable     :: aa_linear(:)
  integer(int64) :: localsize, dim2, dim3
#ifdef R_TREAL
  real(real64),  allocatable     :: aa_linear_double(:)
#endif
  type(accel_mem_t)      :: aa_buffer
  type(accel_kernel_t), save :: kernel

  PUSH_SUB(X(batch_xpay_vec))

  ASSERT(not_in_openmp())

  call profiling_in(TOSTRING(X(BATCH_XPAY)))

  call xx%check_compatibility_with(yy)

  ! one multiplier per linear index; padding entries of packed batches stay zero
  effsize = yy%nst_linear
  if (yy%is_packed()) effsize = int(yy%pack_size(1), int32)
  SAFE_ALLOCATE(aa_linear(1:effsize))

  aa_linear = M_ZERO
  do ist = 1, yy%nst_linear
    iaa = xx%linear_to_ist(ist) - (optional_default(a_start, 1) - 1)
    if (.not. optional_default(a_full, .true.)) iaa = iaa - (xx%linear_to_ist(1) - 1)
    aa_linear(ist) = aa(iaa)
  end do

  select case (xx%status())
  case (BATCH_DEVICE_PACKED)
    if (yy%type() == TYPE_CMPLX) then
#ifdef R_TREAL
      ! real multipliers on complex data: every complex value is handled as two reals,
      ! so each multiplier is stored twice and the kernel is launched twice as wide
      size_factor = 2
      SAFE_ALLOCATE(aa_linear_double(1:2*yy%pack_size(1)))
      do ist = 1, int(yy%pack_size(1), int32)
        aa_linear_double(2*ist - 1) = aa_linear(ist)
        aa_linear_double(2*ist) = aa_linear(ist)
      end do
      call accel_create_buffer(aa_buffer, ACCEL_MEM_READ_ONLY, TYPE_FLOAT, 2*yy%pack_size(1))
      ! the write is asynchronous: aa_linear_double must stay allocated until
      ! accel_finish below, so it is only released at the end of the routine
      call accel_write_buffer(aa_buffer, 2*yy%pack_size(1), aa_linear_double, async=.true.)
#else
      size_factor = 1
      call accel_create_buffer(aa_buffer, ACCEL_MEM_READ_ONLY, TYPE_CMPLX, yy%pack_size(1))
      call accel_write_buffer(aa_buffer, yy%pack_size(1), aa_linear, async=.true.)
#endif
    else
#ifdef R_TREAL
      size_factor = 1
      call accel_create_buffer(aa_buffer, ACCEL_MEM_READ_ONLY, TYPE_FLOAT, yy%pack_size(1))
      call accel_write_buffer(aa_buffer, yy%pack_size(1), aa_linear, async=.true.)
#else
      !if aa is complex, the functions must be complex
      ASSERT(.false.)
#endif
    end if

    call accel_kernel_start_call(kernel, 'axpy.cl', TOSTRING(X(xpay_vec)), flags = '-D' + R_TYPE_CL)

    call accel_set_kernel_arg(kernel, 0, np)
    call accel_set_kernel_arg(kernel, 1, aa_buffer)
    call accel_set_kernel_arg(kernel, 2, xx%ff_device)
    call accel_set_kernel_arg(kernel, 3, int(log2(xx%pack_size(1)*size_factor), int32))
    call accel_set_kernel_arg(kernel, 4, yy%ff_device)
    call accel_set_kernel_arg(kernel, 5, int(log2(yy%pack_size(1)*size_factor), int32))

    localsize = accel_kernel_workgroup_size(kernel)/(yy%pack_size(1)*size_factor)

    dim3 = np/(accel_max_size_per_dim(2)*localsize) + 1
    dim2 = min(accel_max_size_per_dim(2)*localsize, pad(np, localsize))

    call accel_kernel_run(kernel, (/yy%pack_size(1)*size_factor, dim2, dim3/), (/yy%pack_size(1)*size_factor, localsize, 1_int64/))

    ! we need to wait here to make sure we deallocate the CPU arrays only
    ! after they have been copied to the GPU
    call accel_finish()

    call accel_release_buffer(aa_buffer)

  case (BATCH_PACKED)
    if (yy%type() == TYPE_CMPLX) then
      !$omp parallel do private(ip, ist)
      do ip = 1, np
        do ist = 1, yy%nst_linear
          yy%zff_pack(ist, ip) = xx%zff_pack(ist, ip) + aa_linear(ist)*yy%zff_pack(ist, ip)
        end do
      end do
    else
#ifdef R_TREAL
      !$omp parallel do private(ip, ist)
      do ip = 1, np
        do ist = 1, yy%nst_linear
          yy%dff_pack(ist, ip) = xx%dff_pack(ist, ip) + aa_linear(ist)*yy%dff_pack(ist, ip)
        end do
      end do
#else
      !if aa is complex, the functions must be complex
      ASSERT(.false.)
#endif
    end if

  case (BATCH_NOT_PACKED)
    do ist = 1, yy%nst_linear
      if (yy%type() == TYPE_CMPLX) then
        !$omp parallel do
        do ip = 1, np
          yy%zff_linear(ip, ist) = xx%zff_linear(ip, ist) + aa_linear(ist)*yy%zff_linear(ip, ist)
        end do
      else
#ifdef R_TREAL
        !$omp parallel do
        do ip = 1, np
          yy%dff_linear(ip, ist) = xx%dff_linear(ip, ist) + aa_linear(ist)*yy%dff_linear(ip, ist)
        end do
#else
        !if aa is complex, the functions must be complex
        ASSERT(.false.)
#endif
      end if
    end do
  end select

  call profiling_count_operations(xx%nst_linear*np*(R_ADD + R_MUL))

  SAFE_DEALLOCATE_A(aa_linear)
#ifdef R_TREAL
  SAFE_DEALLOCATE_A(aa_linear_double)
#endif

  call profiling_out(TOSTRING(X(BATCH_XPAY)))
  POP_SUB(X(batch_xpay_vec))
end subroutine X(batch_xpay_vec)

! --------------------------------------------------------------
!> Calculate yy(ist) = xx(ist) + aa*yy(ist) for a batch, using the same
!! constant aa for every state.
!!
!! A multiplier vector covering the range of state indices held by xx is
!! built and passed on to batch_xpay_vec.
!
subroutine X(batch_xpay_const)(np, xx, aa, yy)
  integer,           intent(in)    :: np  !< number of points
  class(batch_t),    intent(in)    :: xx
  R_TYPE,            intent(in)    :: aa  !< constant multiplier
  class(batch_t),    intent(inout) :: yy  !< yy(ist) = xx(ist) + a*yy(ist)

  integer :: minst, maxst, ii, ist
  R_TYPE, allocatable :: aavec(:)

  PUSH_SUB(X(batch_xpay_const))

  ! range of state indices present in the batch
  ! (for an empty batch minst > maxst and aavec below is zero-sized)
  minst = HUGE(minst)
  maxst = -HUGE(maxst)

  do ii = 1, xx%nst_linear
    ist = xx%linear_to_ist(ii)
    minst = min(minst, ist)
    maxst = max(maxst, ist)
  end do

  ! with a_start = minst, state ist is mapped to element ist - minst + 1 of the
  ! (1-based, assumed-shape) multiplier array inside batch_xpay_vec
  SAFE_ALLOCATE(aavec(minst:maxst))

  aavec = aa

  call X(batch_xpay_vec)(np, xx, aavec, yy, a_start = minst)

  SAFE_DEALLOCATE_A(aavec)

  POP_SUB(X(batch_xpay_const))
end subroutine X(batch_xpay_const)

! --------------------------------------------------------------
!> Write a single state with np points into a batch at linear position ist.
!!
!! The state is copied into the storage that matches the current layout of the
!! batch (unpacked, packed, or packed on the device). Unsupported combinations of
!! value type (R_TYPE) and batch type trigger an assertion.
!!
subroutine X(batch_set_state1)(this, ist, np, psi)
  class(batch_t),     intent(inout) :: this     !< batch to write the state to
  integer,            intent(in)    :: ist      !< linear position where to write (1 <= ist <= this%nst_linear)
  integer,            intent(in)    :: np       !< number of points
  R_TYPE, contiguous, intent(in)    :: psi(:)   !< the state to write

  integer :: ip
  type(accel_mem_t) :: tmp
#ifdef R_TREAL
  ! host-side complex copy of psi, used when a real state is written into a complex device batch
  complex(real64), allocatable :: zpsi(:)
#endif

  PUSH_SUB(X(batch_set_state1))

  ASSERT(not_in_openmp())

  call profiling_in(TOSTRING(X(BATCH_SET_STATE)))

  ASSERT(ist >= 1 .and. ist <= this%nst_linear)

  select case (this%status())
  case (BATCH_NOT_PACKED)
    ! unpacked: column ist of the linear array holds the state
    if (this%type() == TYPE_FLOAT) then
#ifdef R_TCOMPLEX
      ! cannot set a real batch with complex values
      ASSERT(.false.)
#else
      call lalg_copy(np, psi, this%dff_linear(:, ist))
#endif
    else
#ifdef R_TCOMPLEX
      call lalg_copy(np, psi, this%zff_linear(:, ist))
#else
      ! cannot set a complex unpacked batch with real values
      ASSERT(.false.)
#endif
    end if

  case (BATCH_PACKED)
    ! packed: row ist of the packed array holds the state (second index = point)
    if (this%type() == TYPE_FLOAT) then
#ifdef R_TCOMPLEX
      ! cannot set a real batch with complex values
      ASSERT(.false.)
#else
      !$omp parallel do
      do ip = 1, np
        this%dff_pack(ist, ip) = psi(ip)
      end do
#endif
    else
      ! complex packed batch: in the real version, values are converted on assignment
      !$omp parallel do
      do ip = 1, np
        this%zff_pack(ist, ip) = psi(ip)
      end do
    end if

  case (BATCH_DEVICE_PACKED)

    ! stage psi in a temporary device buffer of pack_size(2) elements of the batch type
    call accel_create_buffer(tmp, ACCEL_MEM_READ_ONLY, this%type(), this%pack_size(2))

    if (this%type() == TYPE_FLOAT) then
#ifdef R_TCOMPLEX
      ! cannot set a real batch with complex values
      ASSERT(.false.)
#else
      call accel_write_buffer(tmp, np, psi)
#endif
    else
#ifdef R_TCOMPLEX
      call accel_write_buffer(tmp, np, psi)
#else
      ! this is not ideal, we should do the conversion on the GPU, so
      ! that we copy half of the data there

      SAFE_ALLOCATE(zpsi(1:np))
      !$omp parallel do
      do ip = 1, np
        zpsi(ip) = psi(ip)
      end do

      call accel_write_buffer(tmp, np, zpsi)

      SAFE_DEALLOCATE_A(zpsi)
#endif
    end if

    ! now call an opencl kernel to rearrange the data
    ! NOTE(review): presumably X(pack) copies the np staged values into linear
    ! position ist - 1 (0-based) of this%ff_device, with pack_size(1) as the row
    ! stride; confirm against the kernel source
    call accel_set_kernel_arg(X(pack), 0, int(this%pack_size(1), int32))
    call accel_set_kernel_arg(X(pack), 1, int(np, int32))
    call accel_set_kernel_arg(X(pack), 2, ist - 1)
    call accel_set_kernel_arg(X(pack), 3, tmp)
    call accel_set_kernel_arg(X(pack), 4, this%ff_device)

    call accel_kernel_run(X(pack), (/int(this%pack_size(2)), 1/), (/accel_max_workgroup_size(), 1/))

    ! the staging buffer may only be released after the kernel has completed
    call accel_finish()

    call accel_release_buffer(tmp)

  end select

  call profiling_out(TOSTRING(X(BATCH_SET_STATE)))

  POP_SUB(X(batch_set_state1))
end subroutine X(batch_set_state1)

! --------------------------------------------------------------
!> Write a single state with np points into a batch, at the linear position
!! that corresponds to the multi-index `index` (translated with this%inv_index).
!
subroutine X(batch_set_state2)(this, index, np, psi)
  class(batch_t),     intent(inout) :: this      !< batch that receives the state
  integer,            intent(in)    :: index(:)  !< multi-index of the state.
  !!                                                For further information see batch_ops_oct_m::batch_set_state
  integer,            intent(in)    :: np        !< number of points
  R_TYPE, contiguous, intent(in)    :: psi(:)    !< state to write to the batch

  integer :: ilinear

  PUSH_SUB(X(batch_set_state2))

  ASSERT(this%nst_linear > 0)

  ! translate the multi-index into a linear position, then delegate
  ilinear = this%inv_index(index)
  call X(batch_set_state1)(this, ilinear, np, psi)

  POP_SUB(X(batch_set_state2))
end subroutine X(batch_set_state2)

! --------------------------------------------------------------
!> Write a set of states with np points into a batch: column idim of psi is
!! stored at linear position (ii - 1)*this%dim + idim, for idim = 1, ..., this%dim.
!
subroutine X(batch_set_state3)(this, ii, np, psi)
  class(batch_t),     intent(inout) :: this       !< batch to write the states into
  integer,            intent(in)    :: ii         !< position of the state in the batch (1-based)
  integer,            intent(in)    :: np         !< number of points
  R_TYPE, contiguous, intent(in)    :: psi(:, :)  !< states to write, one column per component

  integer :: idim, ibase

  PUSH_SUB(X(batch_set_state3))

  ! linear index just before the first component of state ii
  ibase = (ii - 1)*this%dim
  do idim = 1, this%dim
    call X(batch_set_state1)(this, ibase + idim, np, psi(:, idim))
  end do

  POP_SUB(X(batch_set_state3))
end subroutine X(batch_set_state3)

! --------------------------------------------------------------
!> @brief Read a single state with np points from a batch into psi.
!!
!! The state is addressed by its linear index ist. Unpacked, packed and
!! device-packed batches are handled. A real psi (R_TREAL) cannot be filled from
!! a complex batch (this aborts through ASSERT); when psi is complex and the
!! batch is real, the values are promoted to complex.
!!
subroutine X(batch_get_state1)(this, ist, np, psi)
  class(batch_t),     intent(in)    :: this    !< batch to read the state from
  integer,            intent(in)    :: ist     !< linear index of the state, 1 <= ist <= this%nst_linear
  integer,            intent(in)    :: np      !< number of points to copy
  R_TYPE, contiguous, intent(out)   :: psi(:)  !< receives the state; psi(1:np) is written

  integer :: ip
  type(accel_mem_t) :: tmp
#ifdef R_TCOMPLEX
  real(real64), allocatable :: dpsi(:)
#endif

  PUSH_SUB(X(batch_get_state1))

  ASSERT(not_in_openmp())

  call profiling_in(TOSTRING(X(BATCH_GET_STATE)))

  ASSERT(ubound(psi, dim = 1) >= np)
  ASSERT(ist >= 1 .and. ist <= this%nst_linear)

  select case (this%status())
  case (BATCH_NOT_PACKED)
    if (this%type() == TYPE_FLOAT) then
      !$omp parallel do
      do ip = 1, np
        psi(ip) = this%dff_linear(ip, ist)
      end do
      !$omp end parallel do
    else
#ifdef R_TREAL
      ! cannot get a real value from a complex batch
      ASSERT(.false.)
#else
      !$omp parallel do
      do ip = 1, np
        psi(ip) = this%zff_linear(ip, ist)
      end do
      !$omp end parallel do
#endif
    end if

  case (BATCH_PACKED)
    if (this%type() == TYPE_FLOAT) then
      !$omp parallel do
      do ip = 1, np
        psi(ip) = this%dff_pack(ist, ip)
      end do
      !$omp end parallel do
    else
#ifdef R_TREAL
      ! cannot get a real value from a complex batch
      ASSERT(.false.)
#else
      !$omp parallel do
      do ip = 1, np
        psi(ip) = this%zff_pack(ist, ip)
      end do
      !$omp end parallel do
#endif
    end if

  case (BATCH_DEVICE_PACKED)

    ASSERT(np <= this%pack_size(2))

    ! temporary device buffer that receives the unpacked state
    call accel_create_buffer(tmp, ACCEL_MEM_WRITE_ONLY, this%type(), this%pack_size(2))

    if (this%type() == TYPE_FLOAT) then
      ! the unpack kernel takes the state index as ist - 1
      call accel_set_kernel_arg(dunpack, 0, int(this%pack_size(1), int32))
      call accel_set_kernel_arg(dunpack, 1, np)
      call accel_set_kernel_arg(dunpack, 2, ist - 1)
      call accel_set_kernel_arg(dunpack, 3, this%ff_device)
      call accel_set_kernel_arg(dunpack, 4, tmp)

      call accel_kernel_run(dunpack, (/1, int(this%pack_size(2), int32)/), (/1, accel_max_workgroup_size()/))

      call accel_finish()

#ifdef R_TREAL
      call accel_read_buffer(tmp, np, psi)
#else
      ! the device buffer holds real values: read them into a real scratch
      ! array (the queue was already synchronized by accel_finish above)
      SAFE_ALLOCATE(dpsi(1:np))

      call accel_read_buffer(tmp, np, dpsi)

      ! and convert to complex on the cpu
      !$omp parallel do
      do ip = 1, np
        psi(ip) = dpsi(ip)
      end do
      !$omp end parallel do

      SAFE_DEALLOCATE_A(dpsi)
#endif
    else
      call accel_set_kernel_arg(zunpack, 0, int(this%pack_size(1), int32))
      call accel_set_kernel_arg(zunpack, 1, np)
      call accel_set_kernel_arg(zunpack, 2, ist - 1)
      call accel_set_kernel_arg(zunpack, 3, this%ff_device)
      call accel_set_kernel_arg(zunpack, 4, tmp)

      call accel_kernel_run(zunpack, (/1, int(this%pack_size(2), int32)/), (/1, accel_max_workgroup_size()/))

      call accel_finish()

#ifdef R_TREAL
      ! cannot get a real value from a complex batch
      ASSERT(.false.)
#else
      call accel_read_buffer(tmp, np, psi)
#endif
    end if

    call accel_release_buffer(tmp)

  end select

  call profiling_out(TOSTRING(X(BATCH_GET_STATE)))

  POP_SUB(X(batch_get_state1))
end subroutine X(batch_get_state1)

! --------------------------------------------------------------

!> @brief Read one state with np points from the batch, addressing it by an
!! index vector that this%inv_index translates into the linear state index.
!!
!! The actual copy is delegated to batch_get_state1.
subroutine X(batch_get_state2)(this, index, np, psi)
  class(batch_t),     intent(in)    :: this      !< batch to read from
  integer,            intent(in)    :: index(:)  !< index vector identifying the state
  integer,            intent(in)    :: np        !< number of points to copy
  R_TYPE, contiguous, intent(out)   :: psi(:)    !< receives the state

  integer :: ist_linear

  PUSH_SUB(X(batch_get_state2))

  ASSERT(this%nst_linear > 0)

  ! map the index vector onto the linear position understood by batch_get_state1
  ist_linear = this%inv_index(index)
  call X(batch_get_state1)(this, ist_linear, np, psi)

  POP_SUB(X(batch_get_state2))
end subroutine X(batch_get_state2)


! --------------------------------------------------------------

!> @brief Read the this%dim components of one state (np points each) from the batch.
!!
!! Column idim of psi receives the state with linear index (ii - 1)*this%dim + idim,
!! read through batch_get_state1.
subroutine X(batch_get_state3)(this, ii, np, psi)
  class(batch_t),     intent(in)    :: this       !< batch to read from
  integer,            intent(in)    :: ii         !< position of the state in the batch
  integer,            intent(in)    :: np         !< number of points to copy
  R_TYPE, contiguous, intent(out)   :: psi(:, :)  !< receives the state; one column per component

  integer :: idim, first_linear

  PUSH_SUB(X(batch_get_state3))

  ! linear indices of the components of state ii follow each other
  first_linear = (ii - 1)*this%dim
  do idim = 1, this%dim
    call X(batch_get_state1)(this, first_linear + idim, np, psi(:, idim))
  end do

  POP_SUB(X(batch_get_state3))
end subroutine X(batch_get_state3)

! --------------------------------------------------------------
!> @brief copy the points sp:ep of all states in a batch into a mesh function
!!
!! psi is indexed as psi(ist, idim, ip). For unpacked batches the state index is
!! this%linear_to_ist(ii); for packed batches it is the local state index shifted
!! by this%ist(1) - 1. Device-packed batches are not supported.
!! A real psi (R_TREAL) cannot be filled from a complex batch.
!!
subroutine X(batch_get_points)(this, sp, ep, psi)
  class(batch_t),     intent(in)    :: this           !< the batch to get points from
  integer,            intent(in)    :: sp             !< starting point
  integer,            intent(in)    :: ep             !< end point
  R_TYPE, contiguous, intent(inout) :: psi(:, :, sp:) !< mesh function into which the points are written;
  !!                                                 dimensions (1:nst, 1:dim, sp:ep)
  !!                                                 NOTE(review): in the packed case the first index
  !!                                                 reaches this%ist(1) - 1 + this%nst

  integer :: idim, ist, ii, ip, shift

  PUSH_SUB(X(batch_get_points))

  ASSERT(not_in_openmp())

  call profiling_in(TOSTRING(X(GET_POINTS)))

#ifdef R_TREAL
  ! cannot get a real value from a complex batch
  ASSERT(this%type() /= TYPE_CMPLX)
#endif

  select case (this%status())
  case (BATCH_NOT_PACKED)

    if (this%type() == TYPE_FLOAT) then
      do ii = 1, this%nst_linear
        ist = this%linear_to_ist(ii)
        idim = this%linear_to_idim(ii)
        psi(ist, idim, sp:ep) = this%dff_linear(sp:ep, ii)
      end do
    else
#ifdef R_TREAL
      ! cannot get a real value from a complex batch
      ASSERT(.false.)
#else
      do ii = 1, this%nst_linear
        ist = this%linear_to_ist(ii)
        idim = this%linear_to_idim(ii)
        psi(ist, idim, sp:ep) = this%zff_linear(sp:ep, ii)
      end do
#endif
    end if

  case (BATCH_PACKED)

    shift = this%ist(1) - 1

    ! The packed index (ist - 1)*this%dim + idim is computed inline instead of
    ! through a scalar temporary, so that no variable written inside the simd
    ! loop is shared between SIMD lanes.
    if (this%type() == TYPE_FLOAT) then

      !$omp parallel do private(ip, ist, idim)
      do ip = sp, ep
        do idim = 1, this%dim
          !$omp simd
          do ist = 1, this%nst
            psi(ist + shift, idim, ip) = this%dff_pack((ist - 1)*this%dim + idim, ip)
          end do
        end do
      end do
      !$omp end parallel do

    else
#ifdef R_TREAL
      ! cannot get a real value from a complex batch
      ASSERT(.false.)
#else
      !$omp parallel do private(ip, ist, idim)
      do ip = sp, ep
        do idim = 1, this%dim
          !$omp simd
          do ist = 1, this%nst
            psi(ist + shift, idim, ip) = this%zff_pack((ist - 1)*this%dim + idim, ip)
          end do
        end do
      end do
      !$omp end parallel do
#endif
    end if

  case (BATCH_DEVICE_PACKED)
    call messages_not_implemented('batch_get_points for CL packed batches')
  end select

  ! for an empty range nothing was transferred and psi(1, 1, sp) may lie out of bounds
  if (ep >= sp) then
    call profiling_count_transfers((ep - sp + 1)*this%nst_linear, psi(1, 1, sp))
  end if

  call profiling_out(TOSTRING(X(GET_POINTS)))

  POP_SUB(X(batch_get_points))
end subroutine X(batch_get_points)

! --------------------------------------------------------------
!> @brief copy the points sp:ep of a mesh function into all states of a batch
!!
!! This mirrors batch_get_points: psi(ist, idim, ip) is stored in the batch for
!! every state/component of the batch and every ip in sp:ep. For unpacked batches
!! the state index is this%linear_to_ist(ii); for packed batches it is the local
!! state index shifted by this%ist(1) - 1. Device-packed batches are not supported.
!! A complex psi (R_TCOMPLEX) cannot be stored in a real batch.
!!
subroutine X(batch_set_points)(this, sp, ep, psi)
  class(batch_t),     intent(inout) :: this            !< the batch to write points into
  integer,            intent(in)    :: sp              !< starting point
  integer,            intent(in)    :: ep              !< end point
  R_TYPE, contiguous, intent(in)    :: psi(:, :, sp:)  !< mesh function providing the points;
  !!                                                      dimensions (1:nst, 1:dim, sp:ep)
  !!                                                      NOTE(review): in the packed case the first
  !!                                                      index reaches this%ist(1) - 1 + this%nst

  integer :: idim, ist, ii, ip, shift

  PUSH_SUB(X(batch_set_points))

  ASSERT(not_in_openmp())

  call profiling_in(TOSTRING(X(SET_POINTS)))

  select case (this%status())
  case (BATCH_NOT_PACKED)

    if (this%type() == TYPE_FLOAT) then
#ifdef R_TCOMPLEX
      ! cannot set a real batch with complex values
      ASSERT(.false.)
#else
      do ii = 1, this%nst_linear
        ist = this%linear_to_ist(ii)
        idim = this%linear_to_idim(ii)
        this%dff_linear(sp:ep, ii) = psi(ist, idim, sp:ep)
      end do
#endif
    else

      do ii = 1, this%nst_linear
        ist = this%linear_to_ist(ii)
        idim = this%linear_to_idim(ii)
        this%zff_linear(sp:ep, ii) = psi(ist, idim, sp:ep)
      end do

    end if

  case (BATCH_PACKED)

    shift = this%ist(1) - 1

    ! The packed index (ist - 1)*this%dim + idim is computed inline instead of
    ! through a scalar temporary, so that no variable written inside the simd
    ! loop is shared between SIMD lanes.
    if (this%type() == TYPE_FLOAT) then
#ifdef R_TCOMPLEX
      ! cannot set a real batch with complex values
      ASSERT(.false.)
#else
      !$omp parallel do private(ip, ist, idim)
      do ip = sp, ep
        do idim = 1, this%dim
          !$omp simd
          do ist = 1, this%nst
            this%dff_pack((ist - 1)*this%dim + idim, ip) = psi(ist + shift, idim, ip)
          end do
        end do
      end do
      !$omp end parallel do
#endif
    else

      !$omp parallel do private(ip, ist, idim)
      do ip = sp, ep
        do idim = 1, this%dim
          !$omp simd
          do ist = 1, this%nst
            this%zff_pack((ist - 1)*this%dim + idim, ip) = psi(ist + shift, idim, ip)
          end do
        end do
      end do
      !$omp end parallel do
    end if

  case (BATCH_DEVICE_PACKED)
    call messages_not_implemented('batch_set_points for CL packed batches')
  end select

  call profiling_out(TOSTRING(X(SET_POINTS)))

  POP_SUB(X(batch_set_points))
end subroutine X(batch_set_points)

! --------------------------------------------------------------
!> @brief multiply all functions in a batch pointwise by a given mesh function ff
!!
!! yy(ip, ist) = ff(ip) * xx(ip, ist) for ip = 1..np and every linear state index.
!! For device-packed batches only the real (R_TREAL) variant is implemented: yy is
!! zeroed and the kernel that applies the local potential (kernel_vpsi) is reused.
!
subroutine X(batch_mul)(np, ff,  xx, yy)
  integer,           intent(in)    :: np     !< number of points
  R_TYPE,            intent(in)    :: ff(:)  !< mesh function; ff(1:np) is used
  class(batch_t),    intent(in)    :: xx     !< input batch
  class(batch_t),    intent(inout) :: yy     !< output batch

  integer :: ist, ip
  R_TYPE :: mul
#if defined(R_TREAL)
  integer(int64) :: localsize, dim2, dim3
  type(accel_mem_t) :: ff_buffer
#endif

  PUSH_SUB(X(batch_mul))

  ASSERT(not_in_openmp())

  call profiling_in(TOSTRING(X(BATCH_MUL)))

  call xx%check_compatibility_with(yy)
  ! every branch below reads ff(1:np)
  ASSERT(ubound(ff, dim = 1) >= np)
#ifdef R_TCOMPLEX
  ! if ff is complex, the output functions must be complex as well
  ASSERT(yy%type() == TYPE_CMPLX)
#endif

  select case (yy%status())
  case (BATCH_DEVICE_PACKED)

#if defined(R_TREAL)

    ! We reuse here the routine to apply the local potential.
    ! NOTE(review): yy is zeroed before xx is read, so passing the same batch as
    ! xx and yy on this path would zero the input; the zeroing suggests that
    ! kernel_vpsi accumulates into its output -- confirm.
    call batch_set_zero(yy)

    call accel_create_buffer(ff_buffer, ACCEL_MEM_READ_ONLY, R_TYPE_VAL, np)
    call accel_write_buffer(ff_buffer, np, ff)

    call accel_set_kernel_arg(kernel_vpsi, 0, 0)
    call accel_set_kernel_arg(kernel_vpsi, 1, np)
    call accel_set_kernel_arg(kernel_vpsi, 2, ff_buffer)
    call accel_set_kernel_arg(kernel_vpsi, 3, xx%ff_device)
    call accel_set_kernel_arg(kernel_vpsi, 4, int(log2(xx%pack_size_real(1)), int32))
    call accel_set_kernel_arg(kernel_vpsi, 5, yy%ff_device)
    call accel_set_kernel_arg(kernel_vpsi, 6, int(log2(yy%pack_size_real(1)), int32))

    localsize = accel_kernel_workgroup_size(kernel_vpsi)/xx%pack_size_real(1)

    dim3 = np/(accel_max_size_per_dim(2)*localsize) + 1
    dim2 = min(accel_max_size_per_dim(2)*localsize, pad(np, localsize))

    call accel_kernel_run(kernel_vpsi, (/xx%pack_size_real(1), dim2, dim3/), &
      (/xx%pack_size_real(1), localsize, 1_int64/))

    call accel_release_buffer(ff_buffer)
#else
    call messages_not_implemented("Accel batch_mul for the complex case")
#endif

  case (BATCH_PACKED)
    if (yy%type() == TYPE_CMPLX) then
      !$omp parallel do private(ip, ist, mul)
      do ip = 1, np
        mul = ff(ip)
        do ist = 1, yy%nst_linear
          yy%zff_pack(ist, ip) = mul*xx%zff_pack(ist, ip)
        end do
      end do
      !$omp end parallel do
    else
#ifdef R_TREAL
      !$omp parallel do private(ip, ist, mul)
      do ip = 1, np
        mul = ff(ip)
        do ist = 1, yy%nst_linear
          yy%dff_pack(ist, ip) = mul*xx%dff_pack(ist, ip)
        end do
      end do
      !$omp end parallel do
#endif
    end if

  case (BATCH_NOT_PACKED)
    if (yy%type() == TYPE_CMPLX) then
      do ist = 1, yy%nst_linear
        !$omp parallel do
        do ip = 1, np
          yy%zff_linear(ip, ist) = ff(ip)*xx%zff_linear(ip, ist)
        end do
        !$omp end parallel do
      end do
    else
#ifdef R_TREAL
      do ist = 1, yy%nst_linear
        !$omp parallel do
        do ip = 1, np
          yy%dff_linear(ip, ist) = ff(ip)*xx%dff_linear(ip, ist)
        end do
        !$omp end parallel do
      end do

#endif
    end if
  end select

  call profiling_out(TOSTRING(X(BATCH_MUL)))
  POP_SUB(X(batch_mul))

end subroutine X(batch_mul)

! --------------------------------------------------------------
!> @brief Pointwise sum of two batches on the points listed in map:
!! zz(map(ip_in)) = xx(map(ip_in)) + yy(map(ip_in)) for ip_in = 1..np and every state.
!!
!! Points not listed in map are left untouched in zz.
!! Not implemented for device-packed batches.
subroutine X(batch_add_with_map)(np, map, xx, yy, zz)
  integer,           intent(in)    :: np      !< number of entries of map to use
  integer,           intent(in)    :: map(:)  !< indices of the points to operate on
  class(batch_t),    intent(in)    :: xx      !< first summand
  class(batch_t),    intent(in)    :: yy      !< second summand
  class(batch_t),    intent(inout) :: zz      !< result

  integer :: ii, ip_in, ip

  PUSH_SUB(X(batch_add_with_map))

  ASSERT(not_in_openmp())

  ! own profiling region, so timings are not merged with batch_copy_with_map
  call profiling_in(TOSTRING(X(BATCH_ADD_WITH_MAP)))

  call xx%check_compatibility_with(yy)
  call xx%check_compatibility_with(zz)
  ASSERT(ubound(map, dim = 1) >= np)
  ! the X(ff_linear)/X(ff_pack) components used below only exist for the matching type
#ifdef R_TREAL
  ASSERT(xx%type() == TYPE_FLOAT)
#else
  ASSERT(xx%type() == TYPE_CMPLX)
#endif

  select case (xx%status())
  case (BATCH_NOT_PACKED)
    do ii = 1, xx%nst_linear
      do ip_in = 1, np
        ip = map(ip_in)
        zz%X(ff_linear)(ip, ii) = xx%X(ff_linear)(ip, ii) + yy%X(ff_linear)(ip, ii)
      end do
    end do
  case (BATCH_PACKED)
    do ip_in = 1, np
      ip = map(ip_in)
      do ii = 1, xx%nst_linear
        zz%X(ff_pack)(ii, ip) = xx%X(ff_pack)(ii, ip) + yy%X(ff_pack)(ii, ip)
      end do
    end do
  case (BATCH_DEVICE_PACKED)
    call messages_not_implemented("batch_add_with_map CL")
  end select

  call profiling_out(TOSTRING(X(BATCH_ADD_WITH_MAP)))
  POP_SUB(X(batch_add_with_map))
end subroutine X(batch_add_with_map)

! --------------------------------------------------------------
!> @brief Copy the points listed in map from batch xx into batch yy:
!! yy(map(k)) = xx(map(k)) for k = 1..np and every state.
!!
!! Points not listed in map keep their previous values in yy.
!! Not implemented for device-packed batches.
subroutine X(batch_copy_with_map)(np, map, xx, yy)
  integer,           intent(in)    :: np      !< number of entries of map to use
  integer,           intent(in)    :: map(:)  !< indices of the points to copy
  class(batch_t),    intent(in)    :: xx      !< source batch
  class(batch_t),    intent(inout) :: yy      !< destination batch

  integer :: ist_lin, imap, ipt

  PUSH_SUB_WITH_PROFILE(X(batch_copy_with_map))

  ASSERT(not_in_openmp())

  call xx%check_compatibility_with(yy)

  select case (xx%status())
  case (BATCH_PACKED)
    ! packed layout: the state index is the leading dimension, so the points
    ! are distributed among threads and the states are copied in the inner loop
    !$omp parallel do private(ipt, imap, ist_lin)
    do imap = 1, np
      ipt = map(imap)
      do ist_lin = 1, xx%nst_linear
        yy%X(ff_pack)(ist_lin, ipt) = xx%X(ff_pack)(ist_lin, ipt)
      end do
    end do
    !$omp end parallel do

  case (BATCH_NOT_PACKED)
    ! unpacked layout: one state at a time, its points in parallel
    do ist_lin = 1, xx%nst_linear
      !$omp parallel do private(ipt, imap)
      do imap = 1, np
        ipt = map(imap)
        yy%X(ff_linear)(ipt, ist_lin) = xx%X(ff_linear)(ipt, ist_lin)
      end do
      !$omp end parallel do
    end do

  case (BATCH_DEVICE_PACKED)
    call messages_not_implemented("batch_copy_with_map CL")

  end select

  POP_SUB_WITH_PROFILE(X(batch_copy_with_map))
end subroutine X(batch_copy_with_map)

! ---------------------------------------------------------
!>@brief Transfer a batch from the mesh to an array on the submesh (defined by a map)
!!
!! array(ist, ip) = xx(map(ip), ist) for ip = 1..np and every linear state index ist.
!! The batch type must match R_TYPE and the batch must not be device-packed.
subroutine X(batch_copy_with_map_to_array)(np, map, xx, array)
  integer,          intent(in)    :: np          !< number of submesh points to transfer
  integer,          intent(in)    :: map(:)      !< mesh point corresponding to each submesh point
  class(batch_t),   intent(in)    :: xx          !< batch defined on the mesh
  R_TYPE,           intent(inout) :: array(:,:)  !< (psib%nst_linear, submesh%np)

  integer :: ip, ist, ip_map

  PUSH_SUB_WITH_PROFILE(X(batch_copy_with_map_to_array))

  ASSERT(xx%status()/= BATCH_DEVICE_PACKED)
  ASSERT(not_in_openmp())

  ! the loops below read map(1:np) and write array(1:nst_linear, 1:np)
  ASSERT(ubound(map, dim = 1) >= np)
  ASSERT(ubound(array, dim = 1) >= xx%nst_linear)
  ASSERT(ubound(array, dim = 2) >= np)

#ifdef R_TREAL
  ASSERT(xx%type() == TYPE_FLOAT)
#else
  ASSERT(xx%type() == TYPE_CMPLX)
#endif

  select case (xx%status())
  case (BATCH_NOT_PACKED)
    ! every thread walks over the states; the points of each state are
    ! shared out among the threads by the inner worksharing loop
    !$omp parallel
    do ist = 1, xx%nst_linear
      !$omp do simd
      do ip = 1, np
        array(ist, ip) = xx%X(ff_linear)(map(ip), ist)
      end do
    end do
    !$omp end parallel
  case (BATCH_PACKED)
    !$omp parallel do private(ist, ip_map)
    do ip = 1, np
      ip_map = map(ip)
      !$omp simd
      do ist = 1, xx%nst_linear
        array(ist, ip) = xx%X(ff_pack)(ist, ip_map)
      end do
    end do
    !$omp end parallel do
  end select

  POP_SUB_WITH_PROFILE(X(batch_copy_with_map_to_array))
end subroutine X(batch_copy_with_map_to_array)

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
