!! Copyright (C) 2002-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch, M. Verstraete
!! Copyright (C) 2021 S. Ohlmann
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

!----------------------------------------------------------------------
!> @brief Calculate the overlap matrix of two batches
!!
!! On exit, dot(aa\%ist(i), bb\%ist(j)) = < aa(i) | bb(j) > for all states i of aa and j of bb.
!! If reduce is true (default), the contributions of the local mesh points are summed
!! with mesh\%allreduce.
!!
!! The implementation depends on the batch status:
!!  - not packed: GEMM if both batches expose ff, the mesh is not curvilinear and aa\%dim == 1;
!!    otherwise blocked dot products over the mesh points
!!  - packed: GEMM if aa\%dim == 1, an explicit OpenMP loop over two-component states otherwise
!!  - device packed: accelerator GEMM if aa\%dim == 1, the spinor kernel otherwise
!!
!! The packed and device-packed GEMM paths compute \f$ A B^\dagger \f$, whose elements are the
!! complex conjugates of the requested overlaps; the flag conj marks this and the result is
!! conjugated when it is copied to dot.
!
subroutine X(mesh_batch_dotp_matrix)(mesh, aa, bb, dot, reduce)
  class(mesh_t),      intent(in)    :: mesh           !< underlying mesh
  class(batch_t),     intent(in)    :: aa             !< batch aa
  class(batch_t),     intent(in)    :: bb             !< batch bb
  R_TYPE, contiguous, intent(inout) :: dot(:, :)      !< result: dot\_ij = < aa(i) | bb(j) >;
  !!                                                     dimension (states\%st\_start: states\%st\_end, states\%st\_start: states\%st\_end)
  logical, optional,  intent(in)    :: reduce         !< optional flag whether to perform reduction; default = .true.

  integer :: ist, jst, idim, sp, block_size, ep, ip, ldaa, ldbb, indb, jndb
  R_TYPE :: ss
  R_TYPE, allocatable :: dd(:, :)
  logical :: use_blas, conj
  type(accel_mem_t) :: dot_buffer
  integer :: wgsize
  integer :: local_sizes(3)
  integer :: global_sizes(3)

  logical :: reduce_

  PUSH_SUB(X(mesh_batch_dotp_matrix))
  call profiling_in(TOSTRING(X(DOTP_BATCH)))

  reduce_ = .true.
  if (present(reduce)) reduce_ = reduce
  ! set by the code paths that compute conj(< aa(i) | bb(j) >) instead of < aa(i) | bb(j) >
  conj = .false.

  call aa%check_compatibility_with(bb, only_check_dim = .true.)

  SAFE_ALLOCATE(dd(1:aa%nst, 1:bb%nst))
  ! This has to be set to zero by hand since NaN * 0 = NaN.
  dd(1:aa%nst, 1:bb%nst) = R_TOTYPE(M_ZERO)

  use_blas = .false.

  select case (aa%status())
  case (BATCH_NOT_PACKED)
    use_blas = associated(aa%X(ff)) .and. associated(bb%X(ff)) .and. (.not. mesh%use_curvilinear) .and. (aa%dim == 1)

    if (use_blas) then
      call profiling_in(TOSTRING(X(DOTP_BATCH_GEMM)))

      call lalg_gemm_cn(aa%nst, aa%dim, bb%nst, bb%dim, mesh%np, R_TOTYPE(mesh%volume_element), &
        aa%X(ff), bb%X(ff), R_TOTYPE(M_ZERO), dd)

    else

      block_size = cpu_hardware%X(block_size)

      do idim = 1, aa%dim
        do sp = 1, mesh%np, block_size
          ep = min(mesh%np, sp + block_size - 1)

          if (mesh%use_curvilinear) then

            do ist = 1, aa%nst
              indb = aa%ist_idim_to_linear((/ist, idim/))
              do jst = 1, bb%nst
                jndb = bb%ist_idim_to_linear((/jst, idim/))

                ss = M_ZERO
                do ip = sp, ep
                  ss = ss + mesh%vol_pp(ip)*R_CONJ(aa%X(ff_linear)(ip, indb))*bb%X(ff_linear)(ip, jndb)
                end do
                dd(ist, jst) = dd(ist, jst) + ss

              end do
            end do

          else

            do ist = 1, aa%nst
              indb = aa%ist_idim_to_linear((/ist, idim/))
              do jst = 1, bb%nst
                jndb = bb%ist_idim_to_linear((/jst, idim/))

                dd(ist, jst) = dd(ist, jst) + mesh%volume_element*&
                  blas_dot(ep - sp + 1, aa%X(ff_linear)(sp, indb), 1, bb%X(ff_linear)(sp, jndb), 1)
              end do
            end do

          end if
        end do
      end do

    end if
  case (BATCH_PACKED)
    ASSERT(.not. mesh%use_curvilinear)
    use_blas = aa%dim == 1

    if (use_blas) then
      ! dd = A B^H holds the conjugated overlaps
      conj = .true.
      call profiling_in(TOSTRING(X(DOTP_BATCH_GEMM)))

      ldaa = int(aa%pack_size(1), int32)
      ldbb = int(bb%pack_size(1), int32)
      call blas_gemm(transa = 'n', transb = 'c', m = aa%nst, n = bb%nst, k = mesh%np, &
        alpha = R_TOTYPE(mesh%volume_element), &
        a = aa%X(ff_pack)(1, 1), lda = ldaa, &
        b = bb%X(ff_pack)(1, 1), ldb = ldbb, &
        beta = R_TOTYPE(M_ZERO), c = dd(1, 1), ldc = aa%nst)

    else

      ! two-component states: components of state ist are stored at 2*ist - 1 and 2*ist
      dd = M_ZERO
      !$omp parallel do private(ist, jst) reduction(+:dd)
      do ip = 1, mesh%np
        do jst = 1, bb%nst
          do ist = 1, aa%nst
            dd(ist, jst) = dd(ist, jst) + R_CONJ(aa%X(ff_pack)(2*ist - 1, ip))*bb%X(ff_pack)(2*jst - 1, ip) &
              + R_CONJ(aa%X(ff_pack)(2*ist    , ip))*bb%X(ff_pack)(2*jst    , ip)
          end do
        end do
      end do
      !$omp end parallel do
      dd = mesh%volume_element*dd

    end if

  case (BATCH_DEVICE_PACKED)
    ASSERT(.not. mesh%use_curvilinear)

    if (aa%dim == 1) then

      call accel_create_buffer(dot_buffer, ACCEL_MEM_WRITE_ONLY, R_TYPE_VAL, aa%nst*bb%nst)

      call profiling_in(TOSTRING(X(DOTP_BATCH_CL_GEMM)))

      call X(accel_gemm)(transA = ACCEL_BLAS_N, transB = ACCEL_BLAS_C, &
        M = int(aa%nst, int64), N = int(bb%nst, int64), K = int(mesh%np, int64), alpha = R_TOTYPE(M_ONE), &
        A = aa%ff_device, offA = 0_8, lda = int(aa%pack_size(1), int64), &
        B = bb%ff_device, offB = 0_8, ldb = int(bb%pack_size(1), int64), beta = R_TOTYPE(M_ZERO), &
        C = dot_buffer, offC = 0_8, ldc = int(aa%nst, int64))

      ! As in the packed CPU GEMM path, C = A B^H contains the conjugated overlaps.
      conj = .true.

      call profiling_count_operations(real(mesh%np, real64) *aa%nst*bb%nst*(R_ADD + R_MUL))

      call accel_finish()
      call profiling_out(TOSTRING(X(DOTP_BATCH_CL_GEMM)))

      call profiling_in(TOSTRING(X(DOTP_BATCH_COPY)))
      call accel_read_buffer(dot_buffer, aa%nst*bb%nst, dd)
      call profiling_count_transfers(aa%nst*bb%nst, dd(1, 1))
      call accel_finish()
      call profiling_out(TOSTRING(X(DOTP_BATCH_COPY)))

      call accel_release_buffer(dot_buffer)

    else

      ASSERT(R_TYPE_VAL == TYPE_CMPLX)

      call accel_create_buffer(dot_buffer, ACCEL_MEM_WRITE_ONLY, R_TYPE_VAL, aa%nst*bb%nst)

      wgsize = accel_kernel_workgroup_size(zkernel_dot_matrix_spinors)

      global_sizes = (/ pad(aa%nst, wgsize/bb%nst),  bb%nst, 1 /)
      local_sizes  = (/ wgsize/bb%nst,               bb%nst, 1 /)

      ASSERT(accel_buffer_is_allocated(aa%ff_device))
      ASSERT(accel_buffer_is_allocated(bb%ff_device))
      ASSERT(accel_buffer_is_allocated(dot_buffer))

      call profiling_in(TOSTRING(X(DOTP_BATCH_CL_KERNEL)))

      call accel_set_kernel_arg(zkernel_dot_matrix_spinors, 0, mesh%np)
      call accel_set_kernel_arg(zkernel_dot_matrix_spinors, 1, aa%nst)
      call accel_set_kernel_arg(zkernel_dot_matrix_spinors, 2, bb%nst)
      call accel_set_kernel_arg(zkernel_dot_matrix_spinors, 3, aa%ff_device)
      call accel_set_kernel_arg(zkernel_dot_matrix_spinors, 4, log2(aa%pack_size(1)))
      call accel_set_kernel_arg(zkernel_dot_matrix_spinors, 5, bb%ff_device)
      call accel_set_kernel_arg(zkernel_dot_matrix_spinors, 6, log2(bb%pack_size(1)))
      call accel_set_kernel_arg(zkernel_dot_matrix_spinors, 7, dot_buffer)
      call accel_set_kernel_arg(zkernel_dot_matrix_spinors, 8, aa%nst)


      call accel_kernel_run(zkernel_dot_matrix_spinors, global_sizes, local_sizes)

      call accel_finish()
      ! evaluated in real64 to avoid default-integer overflow for large meshes
      call profiling_count_operations(real(mesh%np, real64)*aa%nst*bb%nst*(R_ADD + R_MUL) + R_ADD)


      call profiling_out(TOSTRING(X(DOTP_BATCH_CL_KERNEL)))

      call profiling_in(TOSTRING(X(DOTP_BATCH_COPY)))
      call accel_read_buffer(dot_buffer, aa%nst*bb%nst, dd)
      call profiling_count_transfers(aa%nst*bb%nst, dd(1, 1))
      call accel_finish()
      call profiling_out(TOSTRING(X(DOTP_BATCH_COPY)))

      call accel_release_buffer(dot_buffer)

    end if

    ! the device paths compute the overlaps without the volume element
    do ist = 1, aa%nst
      do jst = 1, bb%nst
        dd(ist, jst) = mesh%volume_element*dd(ist, jst)
      end do
    end do

  case default
    ASSERT(.false.)

  end select

  if (aa%status() /= BATCH_DEVICE_PACKED) then
    if (mesh%use_curvilinear) then
      call profiling_count_operations(real(mesh%np, real64) *aa%nst*bb%nst*aa%dim*(R_ADD + 2*R_MUL))
    else
      call profiling_count_operations(real(mesh%np, real64) *aa%nst*bb%nst*aa%dim*(R_ADD + R_MUL))
    end if
  end if

  if (use_blas) call profiling_out(TOSTRING(X(DOTP_BATCH_GEMM)))

  if (reduce_) then
    call profiling_in(TOSTRING(X(DOTP_BATCH_REDUCE)))
    call mesh%allreduce(dd)
    call profiling_out(TOSTRING(X(DOTP_BATCH_REDUCE)))
  end if

  if (conj) then
    do jst = 1, bb%nst
      do ist = 1, aa%nst
        dot(aa%ist(ist), bb%ist(jst)) = R_CONJ(dd(ist, jst))
      end do
    end do
  else
    do jst = 1, bb%nst
      do ist = 1, aa%nst
        dot(aa%ist(ist), bb%ist(jst)) = dd(ist, jst)
      end do
    end do
  end if

  SAFE_DEALLOCATE_A(dd)

  ! close the region opened at entry (DOTP_BATCH_GEMM is closed above when it was opened)
  call profiling_out(TOSTRING(X(DOTP_BATCH)))
  POP_SUB(X(mesh_batch_dotp_matrix))
end subroutine X(mesh_batch_dotp_matrix)

!-----------------------------------------------------------------

!> @brief calculate the overlap matrix of a batch with itself
!!
!! On exit, dot(aa\%ist(i), aa\%ist(j)) = < aa(i) | aa(j) >. For unpacked batches only the
!! lower triangle (j <= i) is computed and the upper triangle is filled with its complex
!! conjugate. Packed and device-packed batches are delegated to mesh_batch_dotp_matrix.
!
subroutine X(mesh_batch_dotp_self)(mesh, aa, dot, reduce)
  class(mesh_t),      intent(in)    :: mesh      !< underlying mesh
  class(batch_t),     intent(in)    :: aa        !< batch aa
  R_TYPE, contiguous, intent(inout) :: dot(:, :) !< result:  dot\_ij = < aa(i) | aa(j) >;  dimension (1:aa\%nst, 1:aa\%nst)
  logical, optional,  intent(in)    :: reduce    !< optional flag whether to perform reduction; default = .true.

  integer :: ist, jst, idim, sp, block_size, ep, ip, lda, indb, jndb
  R_TYPE :: ss
  logical :: use_blas, reduce_
  R_TYPE, allocatable :: dd(:, :)

  PUSH_SUB(X(mesh_batch_dotp_self))

  ! some limitations of the current implementation
  ASSERT(ubound(dot, dim = 1) >= aa%nst .and. ubound(dot, dim = 2) >= aa%nst)

  if (aa%status() /= BATCH_NOT_PACKED) then
    call X(mesh_batch_dotp_matrix)(mesh, aa, aa, dot, reduce)
    POP_SUB(X(mesh_batch_dotp_self))
    return
  end if

  reduce_ = .true.
  if (present(reduce)) reduce_ = reduce

  use_blas = associated(aa%X(ff)) .and. (.not. mesh%use_curvilinear)

  SAFE_ALLOCATE(dd(1:aa%nst, 1:aa%nst))
  ! This has to be set to zero by hand since NaN * 0 = NaN.
  dd(1:aa%nst, 1:aa%nst) = R_TOTYPE(0.0_real64)

  call profiling_in(TOSTRING(X(BATCH_DOTP_SELF)))

  if (use_blas) then
    call profiling_in(TOSTRING(X(BATCH_HERK)))

    lda = size(aa%X(ff), dim = 1)*aa%dim

    ! use_blas excludes curvilinear meshes, so the uniform volume element is the correct
    ! weight (the same factor the blocked non-curvilinear path below uses)
    call blas_herk('l', 'c', aa%nst, mesh%np, mesh%volume_element, aa%X(ff)(1, 1, 1), &
      lda, M_ZERO, dd(1, 1), ubound(dd, dim = 1))

    if (aa%dim == 2) then
      ! accumulate the second spin component on top of the first (beta = 1)
      call blas_herk('l', 'c', aa%nst, mesh%np, mesh%volume_element, aa%X(ff)(1, 2, 1), &
        lda, M_ONE, dd(1, 1), ubound(dd, dim = 1))
    end if

  else

    block_size = cpu_hardware%X(block_size)

    do idim = 1, aa%dim
      do sp = 1, mesh%np, block_size
        ep = min(mesh%np, sp + block_size - 1)

        if (mesh%use_curvilinear) then

          do ist = 1, aa%nst
            indb = aa%ist_idim_to_linear((/ist, idim/))
            do jst = 1, ist
              jndb = aa%ist_idim_to_linear((/jst, idim/))
              ss = M_ZERO
              do ip = sp, ep
                ss = ss + mesh%vol_pp(ip)*R_CONJ(aa%X(ff_linear)(ip, indb))*aa%X(ff_linear)(ip, jndb)
              end do
              dd(ist, jst) = dd(ist, jst) + ss

            end do
          end do

        else

          do ist = 1, aa%nst
            indb = aa%ist_idim_to_linear((/ist, idim/))
            do jst = 1, ist
              jndb = aa%ist_idim_to_linear((/jst, idim/))
              dd(ist, jst) = dd(ist, jst) + mesh%volume_element*&
                blas_dot(ep - sp + 1, aa%X(ff_linear)(sp, indb), 1, aa%X(ff_linear)(sp, jndb), 1)
            end do
          end do

        end if
      end do
    end do
  end if

  if (mesh%use_curvilinear) then
    call profiling_count_operations(real(mesh%np, real64) *aa%nst**2*aa%dim*(R_ADD + 2*R_MUL))
  else
    call profiling_count_operations(real(mesh%np, real64) *aa%nst**2*aa%dim*(R_ADD + R_MUL))
  end if

  if (use_blas) call profiling_out(TOSTRING(X(BATCH_HERK)))

  if (reduce_) then
    call profiling_in(TOSTRING(X(BATCH_SELF_REDUCE)))
    call mesh%allreduce(dd)
    call profiling_out(TOSTRING(X(BATCH_SELF_REDUCE)))
  end if

  ! only the lower triangle of dd is filled; the upper one follows from Hermiticity
  do ist = 1, aa%nst
    do jst = 1, ist
      dot(aa%ist(ist), aa%ist(jst)) = dd(ist, jst)
      dot(aa%ist(jst), aa%ist(ist)) = R_CONJ(dd(ist, jst))
    end do
  end do

  SAFE_DEALLOCATE_A(dd)

  call profiling_out(TOSTRING(X(BATCH_DOTP_SELF)))
  POP_SUB(X(mesh_batch_dotp_self))
end subroutine X(mesh_batch_dotp_self)

! --------------------------------------------------------------------------
!> @brief calculate the vector of dot-products of mesh functions between two batches
!!
!! dot(i) = < aa(i) | bb(i) >, summed over all components of state i. If cproduct is true,
!! aa(i) is not complex conjugated. Both batches must have the same status.
!
subroutine X(mesh_batch_dotp_vector)(mesh, aa, bb, dot, reduce, cproduct)
  class(mesh_t),      intent(in)    :: mesh       !< underlying mesh
  class(batch_t),     intent(in)    :: aa         !< batch aa
  class(batch_t),     intent(in)    :: bb         !< batch bb
  R_TYPE, contiguous, intent(inout) :: dot(:)     !< result: dot\_i = < aa(i) | bb(i) >
  logical, optional,  intent(in)    :: reduce     !< optional flag whether to perform reduction; default = .true.
  logical, optional,  intent(in)    :: cproduct   !< optional flag: complex conj. product; default = .false.

  integer :: ist, indb, idim, ip, wgsize, status
  logical :: cproduct_
  R_TYPE, allocatable :: tmp(:), cltmp(:)
  type(accel_mem_t)  :: dot_buffer
  type(accel_kernel_t), save  :: kernel_batch_dotpv

  PUSH_SUB(X(mesh_batch_dotp_vector))

  ASSERT(not_in_openmp())

  call profiling_in(TOSTRING(X(DOTPV_BATCH)))

  cproduct_ = optional_default(cproduct, .false.)

  call aa%check_compatibility_with(bb)

  status = aa%status()
  ASSERT(bb%status() == status)

  select case (status)
  case (BATCH_NOT_PACKED)
    do ist = 1, aa%nst
      dot(ist) = M_ZERO
      do idim = 1, aa%dim
        indb = aa%ist_idim_to_linear((/ist, idim/))
        dot(ist) = dot(ist) + X(mf_dotp)(mesh, aa%X(ff_linear)(:, indb), bb%X(ff_linear)(:, indb),&
          reduce = .false., dotu = cproduct_)
      end do
    end do

  case (BATCH_PACKED)
    ! tmp accumulates one partial sum per linear index (state and component)
    SAFE_ALLOCATE(tmp(1:aa%nst_linear))

    tmp = M_ZERO

    if (mesh%use_curvilinear) then
      if (.not. cproduct_) then
        !$omp parallel do private(ist) reduction(+:tmp)
        do ip = 1, mesh%np
          !$omp simd
          do ist = 1, aa%nst_linear
            tmp(ist) = tmp(ist) + mesh%vol_pp(ip)*R_CONJ(aa%X(ff_pack)(ist, ip))*bb%X(ff_pack)(ist, ip)
          end do
        end do
      else
        !$omp parallel do private(ist) reduction(+:tmp)
        do ip = 1, mesh%np
          !$omp simd
          do ist = 1, aa%nst_linear
            tmp(ist) = tmp(ist) + mesh%vol_pp(ip)*aa%X(ff_pack)(ist, ip)*bb%X(ff_pack)(ist, ip)
          end do
        end do
      end if
    else
      if (.not. cproduct_) then
        !$omp parallel do private(ist) reduction(+:tmp)
        do ip = 1, mesh%np
          !$omp simd
          do ist = 1, aa%nst_linear
            tmp(ist) = tmp(ist) + R_CONJ(aa%X(ff_pack)(ist, ip))*bb%X(ff_pack)(ist, ip)
          end do
        end do
      else
        !$omp parallel do private(ist) reduction(+:tmp)
        do ip = 1, mesh%np
          !$omp simd
          do ist = 1, aa%nst_linear
            tmp(ist) = tmp(ist) + aa%X(ff_pack)(ist, ip)*bb%X(ff_pack)(ist, ip)
          end do
        end do
      end if
    end if

    ! NOTE(review): in the curvilinear branch vol_pp has already been applied above; the extra
    ! volume_element factor below is only neutral if volume_element is 1 for curvilinear
    ! meshes -- confirm.
    do ist = 1, aa%nst
      dot(ist) = M_ZERO
      do idim = 1, aa%dim
        indb = aa%ist_idim_to_linear((/ist, idim/))
        dot(ist) = dot(ist) + mesh%volume_element*tmp(indb)
      end do
    end do

    SAFE_DEALLOCATE_A(tmp)

  case (BATCH_DEVICE_PACKED)

    ASSERT(.not. mesh%use_curvilinear)
    if(cproduct_) then

      call accel_create_buffer(dot_buffer, ACCEL_MEM_WRITE_ONLY, R_TYPE_VAL, aa%pack_size(1))

      ! one strided dotu per linear index, each issued on its own stream
      do ist = 1, aa%nst_linear
        call accel_set_stream(ist)
        call X(accel_dotu)(n = int(mesh%np, int64), &
          x = aa%ff_device, offx = int(ist - 1, int64), incx = int(aa%pack_size(1), int64), &
          y = bb%ff_device, offy = int(ist - 1, int64), incy = int(bb%pack_size(1), int64), &
          res = dot_buffer, offres = int(ist - 1, int64))
      end do
      call accel_synchronize_all_streams()
      call accel_set_stream(1)

      SAFE_ALLOCATE(cltmp(1:aa%pack_size(1)))

      call accel_read_buffer(dot_buffer, aa%pack_size(1), cltmp)
      call accel_release_buffer(dot_buffer)

      do ist = 1, aa%nst
        dot(ist) = M_ZERO
        do idim = 1, aa%dim
          indb = aa%ist_idim_to_linear((/ist, idim/))
          dot(ist) = dot(ist) + mesh%volume_element*cltmp(indb)
        end do
      end do
      SAFE_DEALLOCATE_A(cltmp)

    else
      call accel_kernel_start_call(kernel_batch_dotpv, 'mesh_batch_single.cl', TOSTRING(X(batch_dotpv)))

      call accel_create_buffer(dot_buffer, ACCEL_MEM_READ_WRITE, R_TYPE_VAL, aa%nst)
      call accel_set_buffer_to_zero(dot_buffer, R_TYPE_VAL, aa%nst)

      ASSERT(accel_buffer_is_allocated(aa%ff_device))
      ASSERT(accel_buffer_is_allocated(bb%ff_device))

      call accel_set_kernel_arg(kernel_batch_dotpv, 0, mesh%np)
      call accel_set_kernel_arg(kernel_batch_dotpv, 1, aa%nst_linear)
      call accel_set_kernel_arg(kernel_batch_dotpv, 2, aa%dim)
      call accel_set_kernel_arg(kernel_batch_dotpv, 3, mesh%volume_element)
      call accel_set_kernel_arg(kernel_batch_dotpv, 4, aa%ff_device)
      call accel_set_kernel_arg(kernel_batch_dotpv, 5, log2(int(aa%pack_size(1), int32)))
      call accel_set_kernel_arg(kernel_batch_dotpv, 6, bb%ff_device)
      call accel_set_kernel_arg(kernel_batch_dotpv, 7, log2(int(bb%pack_size(1), int32)))
      call accel_set_kernel_arg(kernel_batch_dotpv, 8, dot_buffer)

      ! Setting the size of the shared region
      wgsize = accel_kernel_workgroup_size(kernel_batch_dotpv)/int(aa%pack_size(1), int32)
      call accel_set_kernel_arg(kernel_batch_dotpv, 9, R_TYPE_VAL, wgsize*int(aa%pack_size(1), int32))

      call accel_kernel_run(kernel_batch_dotpv,             &
        (/pad(mesh%np, wgsize), int(aa%pack_size(1), int32)/), &
        (/wgsize, int(aa%pack_size(1), int32)/))

      ! read the aa%nst results into a temporary, copy them to dot and free the temporary
      SAFE_ALLOCATE(cltmp(1:aa%nst))
      call accel_read_buffer(dot_buffer, aa%nst, cltmp)
      dot(1:aa%nst) = cltmp
      call accel_release_buffer(dot_buffer)
      SAFE_DEALLOCATE_A(cltmp)

    end if
  end select

  if (optional_default(reduce, .true.)) then
    call profiling_in(TOSTRING(X(DOTPV_BATCH_REDUCE)))
    call mesh%allreduce(dot, dim = aa%nst)
    call profiling_out(TOSTRING(X(DOTPV_BATCH_REDUCE)))
  end if

  call profiling_count_operations(aa%nst_linear*real(mesh%np, real64) *(R_ADD + R_MUL))

  call profiling_out(TOSTRING(X(DOTPV_BATCH)))
  POP_SUB(X(mesh_batch_dotp_vector))
end subroutine X(mesh_batch_dotp_vector)

! --------------------------------------------------------------------------
!> @brief calculate the dot products between a batch and a vector of mesh functions
!!
!! dot(i) = < aa(i) | psi > for i = 1, ..., nst, where psi is a single mesh function whose
!! component idim is psi(:, idim) (aa\%dim components). The result is stored in dot(1:nst).
!
subroutine X(mesh_batch_mf_dotp)(mesh, aa, psi, dot, reduce, nst)
  class(mesh_t),     intent(in)    :: mesh      !< underlying mesh
  class(batch_t),    intent(in)    :: aa        !< batch aa
  R_TYPE,            intent(in)    :: psi(:,:)  !< mesh function; min. dimension (1:np, 1:aa\%dim)
  R_TYPE,            intent(inout) :: dot(:)    !< result dot\_i = < aa(i) | psi >; min. dimension (1:nst)
  logical, optional, intent(in)    :: reduce    !< optional flag whether to perform reduction; default = .true.
  integer, optional, intent(in)    :: nst       !< optional number of states; default = aa%nst

  integer :: ist, indb, idim, ip, nst_
  R_TYPE, allocatable :: phi(:, :)

  ! Variables related to the GPU:
  type(accel_mem_t) :: psi_buffer
  type(accel_mem_t) :: dot_buffer
  integer :: wgsize, np_padded
  integer :: local_sizes(3)
  integer :: global_sizes(3)

  PUSH_SUB(X(mesh_batch_mf_dotp))

  ASSERT(not_in_openmp())

  call profiling_in(TOSTRING(X(DOTPV_MF_BATCH)))

  ASSERT(aa%dim == ubound(psi,dim=2))

  nst_ = aa%nst
  if (present(nst)) nst_ = nst

  select case (aa%status())
  case (BATCH_NOT_PACKED)
    do ist = 1, nst_
      dot(ist) = M_ZERO
      do idim = 1, aa%dim
        indb = aa%ist_idim_to_linear((/ist, idim/))
        dot(ist) = dot(ist) + X(mf_dotp)(mesh, aa%X(ff_linear)(:, indb), psi(1:mesh%np,idim),&
          reduce = .false.)
      end do
    end do

  case (BATCH_PACKED)
    SAFE_ALLOCATE(phi(1:mesh%np, aa%dim))

    if (aa%dim == 1) then
      ! Here we compute the complex conjugate of the dot product first and then
      ! we take the conjugate at the end.

      ! Note: this is to avoid taking the complex conjugate of the whole batch, but rather that of
      ! the single function only.
      ! In the aa%dim>1 case, that is taken care of by the mf_dotp function.

      if (mesh%use_curvilinear) then
        !$omp parallel do
        do ip = 1, mesh%np
          phi(ip, 1) = mesh%vol_pp(ip)*R_CONJ(psi(ip, 1))
        end do
      else
        !$omp parallel do
        do ip = 1, mesh%np
          phi(ip, 1) = R_CONJ(psi(ip, 1))
        end do
      end if

      call blas_gemv('N', nst_, mesh%np, R_TOTYPE(mesh%volume_element), aa%X(ff_pack)(1,1), &
        ubound(aa%X(ff_pack), dim=1), phi(1,1), 1, R_TOTYPE(M_ZERO), dot(1), 1)

      do ist = 1, nst_
        dot(ist) = R_CONJ(dot(ist))
      end do

    else

      ! Note: curvilinear coordinates are handled inside the mf_dotp function!
      if (mesh%use_curvilinear) then
        dot(1:nst_) = M_ZERO
        do ist = 1, nst_
          call batch_get_state(aa, ist, mesh%np, phi)
          dot(ist) = X(mf_dotp)(mesh, aa%dim, phi(1:mesh%np, 1:aa%dim), psi(1:mesh%np, 1:aa%dim),&
            reduce = .false.)
        end do
      else
        ! two-component states: the components of state ist are stored at indb and indb + 1
        dot(1:nst_) = M_ZERO
        !$omp parallel do reduction(+:dot) private(ist, indb)
        do ip = 1, mesh%np
          do ist = 1, nst_
            indb = aa%ist_idim_to_linear((/ist, 1/))
            dot(ist) = dot(ist) + R_CONJ(aa%X(ff_pack)(indb, ip)) * psi(ip, 1) &
              + R_CONJ(aa%X(ff_pack)(indb+1, ip)) * psi(ip, 2)
          end do
        end do
        ! only scale the computed entries; dot may be longer than nst_
        dot(1:nst_) = dot(1:nst_) * mesh%volume_element
      end if

    end if

    SAFE_DEALLOCATE_A(phi)

  case (BATCH_DEVICE_PACKED)

    ASSERT(.not. mesh%use_curvilinear)

    np_padded = pad_pow2(mesh%np)

    call accel_create_buffer(dot_buffer, ACCEL_MEM_READ_WRITE, R_TYPE_VAL, aa%nst)
    call accel_create_buffer(psi_buffer, ACCEL_MEM_READ_ONLY, R_TYPE_VAL, np_padded * aa%dim)

    ! components of psi are stored with a stride of np_padded in psi_buffer
    do idim= 1, aa%dim
      call accel_write_buffer(psi_buffer, mesh%np, psi(1:mesh%np,idim), offset=(idim-1)*np_padded)
    end do

    wgsize = accel_kernel_workgroup_size(X(kernel_batch_dotp))

    global_sizes = (/ pad(aa%nst, wgsize),  1, 1 /)
    local_sizes  = (/ wgsize,               1, 1 /)

    ASSERT(accel_buffer_is_allocated(aa%ff_device))
    ASSERT(accel_buffer_is_allocated(psi_buffer))
    ASSERT(accel_buffer_is_allocated(dot_buffer))

    call accel_set_kernel_arg(X(kernel_batch_dotp), 0, mesh%np)
    call accel_set_kernel_arg(X(kernel_batch_dotp), 1, nst_)
    call accel_set_kernel_arg(X(kernel_batch_dotp), 2, aa%dim)
    call accel_set_kernel_arg(X(kernel_batch_dotp), 3, aa%ff_device)
    call accel_set_kernel_arg(X(kernel_batch_dotp), 4, int(log2(aa%pack_size(1)), int32))
    call accel_set_kernel_arg(X(kernel_batch_dotp), 5, psi_buffer)
    call accel_set_kernel_arg(X(kernel_batch_dotp), 6, log2(np_padded))
    call accel_set_kernel_arg(X(kernel_batch_dotp), 7, dot_buffer)

    call accel_kernel_run(X(kernel_batch_dotp), global_sizes, local_sizes)

    call accel_read_buffer(dot_buffer, nst_, dot)

    call accel_release_buffer(psi_buffer)
    call accel_release_buffer(dot_buffer)

    do ist = 1, nst_
      dot(ist) = dot(ist) * mesh%volume_element
    end do

  end select

  if (optional_default(reduce, .true.)) then
    call profiling_in(TOSTRING(X(DOTPV_MF_BATCH_REDUCE)))
    call mesh%allreduce(dot, dim = nst_)
    call profiling_out(TOSTRING(X(DOTPV_MF_BATCH_REDUCE)))
  end if

  call profiling_count_operations(nst_*aa%dim*real(mesh%np, real64) *(R_ADD + R_MUL))

  call profiling_out(TOSTRING(X(DOTPV_MF_BATCH)))
  POP_SUB(X(mesh_batch_mf_dotp))
end subroutine X(mesh_batch_mf_dotp)


!--------------------------------------------------------------------------------------
!> @brief calculate the co-densities
!!
!! Result \f$ \rho_i(x) = \sum_\sigma \psi_\sigma(x) \phi_{i,\sigma}(x) \f$, where the
!! \f$\phi_i\f$ are the states of the batch aa and \f$\sigma\f$ runs over its aa\%dim
!! components. Neither factor is complex conjugated. Device-packed batches are not supported.
!
subroutine X(mesh_batch_codensity)(mesh, aa, psi, rho)
  class(mesh_t),     intent(in)    :: mesh     !< The mesh descriptor.
  class(batch_t),    intent(in)    :: aa       !< A batch which contains the mesh functions
  R_TYPE,            intent(in)    :: psi(:,:) !< A mesh function; dimension (1:mesh\%np, 1:aa\%dim)
  R_TYPE,            intent(out)   :: rho(:,:) !< An array containing the result of the co-density

  integer :: ist, ip, icomp, nblock, first, last

  PUSH_SUB(X(mesh_batch_codensity))

  ASSERT(not_in_openmp())

  call profiling_in("CODENSITIES")

  ASSERT((aa%status()) /= BATCH_DEVICE_PACKED)

  nblock = cpu_hardware%X(block_size)

  select case (aa%status())
  case (BATCH_NOT_PACKED)
    ! unpacked storage: aa%X(ff)(point, component, state)
    !$omp parallel do private(last, ist, ip, icomp)
    do first = 1, mesh%np, nblock
      last = min(first + nblock - 1, mesh%np)
      do ist = 1, aa%nst
        do ip = first, last
          rho(ip, ist) = psi(ip, 1) * aa%X(ff)(ip, 1, ist)
        end do
        do icomp = 2, aa%dim
          do ip = first, last
            rho(ip, ist) = rho(ip, ist) + psi(ip, icomp) * aa%X(ff)(ip, icomp, ist)
          end do
        end do
      end do
    end do

  case (BATCH_PACKED)
    ! packed storage: the components of state ist occupy rows (ist - 1)*aa%dim + 1 ... ist*aa%dim
    !$omp parallel do private(last, ist, ip, icomp)
    do first = 1, mesh%np, nblock
      last = min(first + nblock - 1, mesh%np)
      do ist = 1, aa%nst
        do ip = first, last
          rho(ip, ist) = psi(ip, 1) * aa%X(ff_pack)((ist - 1) * aa%dim + 1, ip)
        end do
        do icomp = 2, aa%dim
          do ip = first, last
            rho(ip, ist) = rho(ip, ist) + psi(ip, icomp) * aa%X(ff_pack)((ist - 1) * aa%dim + icomp, ip)
          end do
        end do
      end do
    end do
  end select

  call profiling_out("CODENSITIES")
  POP_SUB(X(mesh_batch_codensity))
end subroutine X(mesh_batch_codensity)

!--------------------------------------------------------------------------------------

!> @brief Exchange the points of the mesh functions in a batch between the domains
!! according to a map.
!!
!! Exactly one of forward_map and backward_map must be present:
!!  - forward_map(ip) gives the global index of the destination of local point ip;
!!  - backward_map is only tested for presence (its value is never read): if present,
!!    the exchange uses the point maps stored in mesh\%pv (send_count, send_disp,
!!    sendmap, recv_count, recv_disp, recvmap).
!! Device-packed batches are unpacked for the exchange and packed again afterwards.
!! Only implemented for meshes that are parallel in domains.
!
subroutine X(mesh_batch_exchange_points)(mesh, aa, forward_map, backward_map)
  class(mesh_t),         intent(in)    :: mesh            !< The mesh descriptor.
  class(batch_t),        intent(inout) :: aa              !< A batch which contains the mesh functions whose points will be exchanged.
  integer(int64), optional, intent(in)    :: forward_map(:)  !< Global index of the destination of each local point.
  logical, optional,     intent(in)    :: backward_map    !< If present, use the point maps of mesh\%pv (value is ignored).
  logical :: packed_on_entry

#ifdef HAVE_MPI
  integer :: ip, npart, ipart, ist, pos, nstl
  integer, allocatable :: send_count(:), recv_count(:), send_disp(:), recv_disp(:), partno(:)
  integer, allocatable :: send_count_nstl(:), recv_count_nstl(:), send_disp_nstl(:), recv_disp_nstl(:)
  integer(int64), allocatable :: send_indices(:), recv_indices(:)
  R_TYPE, allocatable  :: send_buffer(:, :), recv_buffer(:, :)
#endif

  PUSH_SUB(X(mesh_batch_exchange_points))

  ASSERT(present(backward_map) .neqv. present(forward_map))
  ASSERT(aa%type() == R_TYPE_VAL)
  packed_on_entry = aa%status() == BATCH_DEVICE_PACKED
  if (packed_on_entry) then
    call aa%do_unpack(force=.true.)
  end if

  if (.not. mesh%parallel_in_domains) then
    message(1) = "Not implemented for the serial case. Really, only in parallel."
    call messages_fatal(1)
  else

#ifdef HAVE_MPI
    npart = mesh%mpi_grp%size
    nstl = aa%nst_linear

    SAFE_ALLOCATE(send_count(1:npart))
    SAFE_ALLOCATE(recv_count(1:npart))
    SAFE_ALLOCATE(send_count_nstl(1:npart))
    SAFE_ALLOCATE(recv_count_nstl(1:npart))
    SAFE_ALLOCATE(send_disp_nstl(1:npart))
    SAFE_ALLOCATE(recv_disp_nstl(1:npart))
    SAFE_ALLOCATE(send_buffer(1:nstl, 1:mesh%np))
    SAFE_ALLOCATE(recv_buffer(1:nstl, 1:mesh%np))

    if (present(forward_map)) then

      ! the global destination indices are only needed for the forward map
      SAFE_ALLOCATE(send_indices(1:mesh%np))
      SAFE_ALLOCATE(recv_indices(1:mesh%np))
      SAFE_ALLOCATE(send_disp(1:npart))
      SAFE_ALLOCATE(recv_disp(1:npart))
      ASSERT(ubound(forward_map, dim = 1) == mesh%np)

      ! get their destination
      SAFE_ALLOCATE(partno(1:mesh%np))
      call partition_get_partition_number(mesh%partition, mesh%np, forward_map, partno)

      ! compute the send counts
      send_count = 0
      do ip = 1, mesh%np
        ipart = partno(ip)
        send_count(ipart) = send_count(ipart) + 1
      end do
      ASSERT(sum(send_count) == mesh%np)

      ! Receiving number of points is the inverse matrix of the sending points
      call mesh%mpi_grp%alltoall(send_count, 1, MPI_INTEGER, &
        recv_count, 1, MPI_INTEGER)
      ASSERT(sum(recv_count) == mesh%np)

      ! compute displacements
      send_disp(1) = 0
      recv_disp(1) = 0
      do ipart = 2, npart
        send_disp(ipart) = send_disp(ipart - 1) + send_count(ipart - 1)
        recv_disp(ipart) = recv_disp(ipart - 1) + recv_count(ipart - 1)
      end do

      ASSERT(send_disp(npart) + send_count(npart) == mesh%np)
      ASSERT(recv_disp(npart) + recv_count(npart) == mesh%np)

      ! Pack for sending; send_count is reused as a running fill counter per partition
      send_count = 0
      select case (aa%status())
      case (BATCH_NOT_PACKED)
        do ip = 1, mesh%np
          ipart = partno(ip)
          send_count(ipart) = send_count(ipart) + 1
          pos = send_disp(ipart) + send_count(ipart)
          do ist = 1, nstl
            send_buffer(ist, pos) = aa%X(ff_linear)(ip, ist)
          end do
          send_indices(pos) = forward_map(ip)
        end do
      case (BATCH_PACKED)
        do ip = 1, mesh%np
          ipart = partno(ip)
          send_count(ipart) = send_count(ipart) + 1
          pos = send_disp(ipart) + send_count(ipart)
          do ist = 1, nstl
            send_buffer(ist, pos) = aa%X(ff_pack)(ist, ip)
          end do
          send_indices(pos) = forward_map(ip)
        end do
      end select

      SAFE_DEALLOCATE_A(partno)

      send_count_nstl = send_count * nstl
      send_disp_nstl = send_disp * nstl
      recv_count_nstl = recv_count * nstl
      recv_disp_nstl = recv_disp * nstl
      call mesh%mpi_grp%alltoallv(send_buffer, send_count_nstl, send_disp_nstl, R_MPITYPE, &
        recv_buffer, recv_count_nstl, recv_disp_nstl, R_MPITYPE)

      call mesh%mpi_grp%alltoallv(send_indices, send_count, send_disp, MPI_INTEGER8, &
        recv_indices, recv_count, recv_disp, MPI_INTEGER8)

      ! unpack received data
      select case (aa%status())
      case (BATCH_NOT_PACKED)
        !$omp parallel do simd schedule(static) private(ip)
        do pos = 1, mesh%np
          ip = mesh_global2local(mesh, recv_indices(pos))
          ASSERT(ip /= 0)
          do ist = 1, nstl
            aa%X(ff_linear)(ip, ist) = recv_buffer(ist, pos)
          end do
        end do
      case (BATCH_PACKED)
        !$omp parallel do simd schedule(static) private(ip)
        do pos = 1, mesh%np
          ip = mesh_global2local(mesh, recv_indices(pos))
          ASSERT(ip /= 0)
          do ist = 1, nstl
            aa%X(ff_pack)(ist, ip) = recv_buffer(ist, pos)
          end do
        end do
      end select

      SAFE_DEALLOCATE_A(send_disp)
      SAFE_DEALLOCATE_A(recv_disp)
      SAFE_DEALLOCATE_A(send_indices)
      SAFE_DEALLOCATE_A(recv_indices)

    else ! backward map

      recv_count = mesh%pv%recv_count
      ASSERT(sum(recv_count) == mesh%np)

      send_count = mesh%pv%send_count
      ASSERT(sum(send_count) == mesh%np)

      ASSERT(mesh%pv%send_disp(npart) + send_count(npart) == mesh%np)
      ASSERT(mesh%pv%recv_disp(npart) + recv_count(npart) == mesh%np)

      ! Pack for sending
      select case (aa%status())
      case (BATCH_NOT_PACKED)
        !$omp parallel do simd schedule(static)
        do ip = 1, mesh%np
          do ist = 1, nstl
            send_buffer(ist, mesh%pv%sendmap(ip)) = aa%X(ff_linear)(ip, ist)
          end do
        end do
      case (BATCH_PACKED)
        !$omp parallel do simd schedule(static)
        do ip = 1, mesh%np
          do ist = 1, nstl
            send_buffer(ist, mesh%pv%sendmap(ip)) = aa%X(ff_pack)(ist, ip)
          end do
        end do
      end select

      send_count_nstl = send_count * nstl
      send_disp_nstl = mesh%pv%send_disp * nstl
      recv_count_nstl = recv_count * nstl
      recv_disp_nstl = mesh%pv%recv_disp * nstl
      call mesh%mpi_grp%alltoallv(send_buffer, send_count_nstl, send_disp_nstl, R_MPITYPE, &
        recv_buffer, recv_count_nstl, recv_disp_nstl, R_MPITYPE)

      ! Unpack on receiving
      select case (aa%status())
      case (BATCH_NOT_PACKED)
        !$omp parallel
        do ist = 1, nstl
          !$omp do simd schedule(static)
          do ip = 1, mesh%np
            aa%X(ff_linear)(mesh%pv%recvmap(ip), ist) = recv_buffer(ist, ip)
          end do
        end do
        !$omp end parallel
      case (BATCH_PACKED)
        !$omp parallel do simd schedule(static)
        do ip = 1, mesh%np
          do ist = 1, nstl
            aa%X(ff_pack)(ist, mesh%pv%recvmap(ip)) = recv_buffer(ist, ip)
          end do
        end do
      end select

    end if

    SAFE_DEALLOCATE_A(send_count)
    SAFE_DEALLOCATE_A(recv_count)
    SAFE_DEALLOCATE_A(send_buffer)
    SAFE_DEALLOCATE_A(recv_buffer)
    SAFE_DEALLOCATE_A(send_count_nstl)
    SAFE_DEALLOCATE_A(recv_count_nstl)
    SAFE_DEALLOCATE_A(send_disp_nstl)
    SAFE_DEALLOCATE_A(recv_disp_nstl)
#endif
  end if

  if (packed_on_entry) then
    call aa%do_pack()
  end if
  POP_SUB(X(mesh_batch_exchange_points))
end subroutine X(mesh_batch_exchange_points)

! -----------------------------------------------------
!> This function should not be called directly, but through mesh_batch_nrm2.
!
subroutine X(priv_mesh_batch_nrm2)(mesh, aa, nrm2)
  class(mesh_t),           intent(in)    :: mesh    !< underlying mesh
  class(batch_t),          intent(in)    :: aa      !< batch whose states are measured
  real(real64),            intent(out)   :: nrm2(:) !< result: norm of each state (ist = 1..aa%nst), local part only

  ! Computes the (domain-local) 2-norm of every state of the batch, combining the
  ! spin/spinor components (idim) with hypot. No reduction over the mesh domains
  ! is done here; the caller is responsible for it (see the note at the end).

  integer :: ist, idim, indb, ip, sp, np, num_threads, ithread, istream
  real(real64) :: a0
  real(real64), allocatable :: scal(:,:), ssq(:,:), local_scal(:), local_ssq(:)
  type(accel_mem_t)  :: nrm2_buffer

  PUSH_SUB(X(priv_mesh_batch_nrm2))
  call profiling_in(TOSTRING(X(MESH_BATCH_NRM2)))

  select case (aa%status())
  case (BATCH_NOT_PACKED)
    ! unpacked storage: one non-reduced X(mf_nrm2) call per linear component
    do ist = 1, aa%nst
      nrm2(ist) = M_ZERO
      do idim = 1, aa%dim
        indb = aa%ist_idim_to_linear((/ist, idim/))
        nrm2(ist) = hypot(nrm2(ist), X(mf_nrm2)(mesh, aa%X(ff_linear)(:, indb), reduce = .false.))
      end do
    end do

  case (BATCH_PACKED)

    num_threads = 1
    !$omp parallel
    !$omp single
!$  num_threads = omp_get_num_threads()
    !$omp end single
    !$omp end parallel

    SAFE_ALLOCATE(scal(1:aa%nst_linear, 1:num_threads))
    SAFE_ALLOCATE(ssq(1:aa%nst_linear, 1:num_threads))

    scal = M_ZERO
    ssq  = M_ONE

    ! We need to work on local quantities in the loops to avoid false sharing.
    ! The values are copied to the shared scal/ssq arrays only once per thread.
    SAFE_ALLOCATE(local_scal(1:aa%nst_linear))
    SAFE_ALLOCATE(local_ssq(1:aa%nst_linear))

    ! divide the range from 1:mesh%np across the OpenMP threads and sum independently
    ! the reduction is done outside the parallel region
    !$omp parallel private(ithread, sp, np, a0, ip, ist, local_ssq, local_scal) firstprivate(num_threads)
    call multicomm_divide_range_omp(mesh%np, sp, np)
    ithread = 1
!$  ithread = omp_get_thread_num() + 1

    local_scal = M_ZERO
    local_ssq  = M_ONE

    ! The algorithm for the squared sum is the same as used, e.g., in openblas.
    ! The idea is that one wants to avoid an overflow caused by squaring a big
    ! number by using separate values for the sum of squares and the scale.
    ! Only at the end, the norm is computed as scal*sqrt(ssq) - in this way
    ! the largest number which is stored in scal is never squared.
    if (.not. mesh%use_curvilinear) then

      do ip = sp, sp + np - 1
        do ist = 1, aa%nst_linear
          ! first add real part
          a0 = abs(R_REAL(aa%X(ff_pack)(ist, ip)))
          ! only add a0 if it is non-zero
          if (a0 > M_TINY) then
            if (local_scal(ist) < a0) then
              local_ssq(ist) = M_ONE + local_ssq(ist)*(local_scal(ist)/a0)**2
              local_scal(ist) = a0
            else
              local_ssq(ist) = local_ssq(ist) + (a0/local_scal(ist))**2
            end if
          end if
#ifdef R_TCOMPLEX
          ! then we add imaginary part
          a0 = abs(R_AIMAG(aa%X(ff_pack)(ist, ip)))
          ! only add a0 if it is non-zero
          if (a0 > M_TINY) then
            if (local_scal(ist) < a0) then
              local_ssq(ist) = M_ONE + local_ssq(ist)*(local_scal(ist)/a0)**2
              local_scal(ist) = a0
            else
              local_ssq(ist) = local_ssq(ist) + (a0/local_scal(ist))**2
            end if
          end if
#endif
        end do
      end do

      ssq(:,ithread) = local_ssq(:)
      scal(:,ithread) = local_scal(:)

    else

      ! curvilinear mesh: each squared term is weighted by the point volume vol_pp(ip)
      do ip = sp, sp + np - 1
        do ist = 1, aa%nst_linear
          ! first add real part; the imaginary part is added separately below, so
          ! taking abs() of the full (complex) value here would count it twice
          a0 = abs(R_REAL(aa%X(ff_pack)(ist, ip)))
          ! only add a0 if it is non-zero
          if (a0 > M_TINY) then
            if (local_scal(ist) < a0) then
              local_ssq(ist) =  mesh%vol_pp(ip) + local_ssq(ist)*(local_scal(ist)/a0)**2
              local_scal(ist) = a0
            else
              local_ssq(ist) = local_ssq(ist) + mesh%vol_pp(ip)*(a0/local_scal(ist))**2
            end if
          end if
#ifdef R_TCOMPLEX
          ! then we add imaginary part
          a0 = abs(R_AIMAG(aa%X(ff_pack)(ist, ip)))
          ! only add a0 if it is non-zero
          if (a0 > M_TINY) then
            if (local_scal(ist) < a0) then
              local_ssq(ist) =  mesh%vol_pp(ip) + local_ssq(ist)*(local_scal(ist)/a0)**2
              local_scal(ist) = a0
            else
              local_ssq(ist) = local_ssq(ist) + mesh%vol_pp(ip)*(a0/local_scal(ist))**2
            end if
          end if
#endif
        end do
      end do

      ssq(:,ithread) = local_ssq(:)
      scal(:,ithread) = local_scal(:)

    end if
    !$omp end parallel

    SAFE_DEALLOCATE_A(local_scal)
    SAFE_DEALLOCATE_A(local_ssq)

    ! now do the reduction: sum the components of the different threads without overflow
    do ithread = 2, num_threads
      do ist = 1, aa%nst_linear
        ! The accumulation above only accepts a0 > M_TINY, so scal is either exactly
        ! zero or larger than M_TINY. Skipping only exact zeros avoids 0/0 below
        ! while keeping the result independent of the number of threads.
        if (scal(ist, ithread) <= M_ZERO) cycle
        if (scal(ist, 1) < scal(ist, ithread)) then
          ssq(ist, 1) = ssq(ist, 1) * (scal(ist, 1)/scal(ist, ithread))**2 + ssq(ist, ithread)
          scal(ist, 1) = scal(ist, ithread)
        else
          ssq(ist, 1) = ssq(ist, 1) + ssq(ist, ithread) * (scal(ist, ithread)/scal(ist, 1))**2
        end if
      end do
    end do

    ! the result is in scal(ist, 1) and ssq(ist, 1)
    do ist = 1, aa%nst
      nrm2(ist) = M_ZERO
      do idim = 1, aa%dim
        indb = aa%ist_idim_to_linear((/ist, idim/))
        nrm2(ist) = hypot(nrm2(ist), scal(indb, 1)*sqrt(mesh%volume_element*ssq(indb, 1)))
      end do
    end do

    SAFE_DEALLOCATE_A(scal)
    SAFE_DEALLOCATE_A(ssq)

  case (BATCH_DEVICE_PACKED)

    ASSERT(.not. mesh%use_curvilinear)

    SAFE_ALLOCATE(ssq(1:aa%pack_size(1), 1))

    ! we need to make sure everything on the current stream has finished
    ! because aa%ff_device is accessed from different streams here
    call accel_finish()
    ! save the current stream, reset later
    call accel_get_stream(istream)
    call accel_create_buffer(nrm2_buffer, ACCEL_MEM_WRITE_ONLY, TYPE_FLOAT, aa%pack_size(1))

    do ist = 1, aa%nst_linear
#ifndef __HIP_PLATFORM_AMD__
      ! On AMD GPUs, running several nrm2 kernels in parallel in different streams seems
      ! to trigger a race condition, leading to stochastic deviations in the result
      call accel_set_stream(ist)
#endif
      call X(accel_nrm2)(N = int(mesh%np, int64), X = aa%ff_device, offx = int(ist - 1, int64), &
        incx = int(aa%pack_size(1), int64), res = nrm2_buffer, offres = int(ist - 1, int64))
    end do
#ifndef __HIP_PLATFORM_AMD__
    call accel_synchronize_all_streams()
    call accel_set_stream(istream)
#endif

    call accel_read_buffer(nrm2_buffer, aa%pack_size(1), ssq)

    call accel_release_buffer(nrm2_buffer)

    ! here ssq already holds the (unscaled) norm of each linear component
    do ist = 1, aa%nst
      nrm2(ist) = M_ZERO
      do idim = 1, aa%dim
        indb = aa%ist_idim_to_linear((/ist, idim/))
        nrm2(ist) = hypot(nrm2(ist), sqrt(mesh%volume_element)*ssq(indb, 1))
      end do
    end do

    SAFE_DEALLOCATE_A(ssq)

  end select

  ! REDUCTION IS REQUIRED, THIS IS DONE BY THE CALLING FUNCTION

  call profiling_count_operations(real(mesh%np, real64) *aa%nst_linear*(R_ADD + R_MUL))

  call profiling_out(TOSTRING(X(MESH_BATCH_NRM2)))
  POP_SUB(X(priv_mesh_batch_nrm2))
end subroutine X(priv_mesh_batch_nrm2)

! ---------------------------------------------------------
!> @brief Orthonormalizes states of phib to the orbitals of nst batches of psi.
!!
!! It also permits doing only the orthogonalization (no normalization).
!
!TODO: add more documentation
!
subroutine X(mesh_batch_orthogonalization)(mesh, nst, psib, phib,  &
  normalize, overlap, norm, gs_scheme, full_batch)
  class(mesh_t),     intent(in)    :: mesh      !< underlying mesh
  integer,           intent(in)    :: nst       !< number of states
  class(batch_p_t),  intent(in)    :: psib(:)   !< psi(nst) (array of nst pointers to batches)
  class(batch_t),    intent(inout) :: phib      !< phi
  logical, optional, intent(in)    :: normalize !< optional flag whether to normalize the result; default = .false.
  R_TYPE,  optional, intent(out)   :: overlap(:,:) !< optional result: (nst, phib%nst)
  R_TYPE,  optional, intent(out)   :: norm(:)      !< optional result: array of norms
  integer, optional, intent(in)    :: gs_scheme    !< optional: Gram-Schmidt scheme to use
  logical, optional, intent(in)    :: full_batch   !< optional: orthogonalize full batch

  ! Projects out of phib the components along the nst batches psib(1:nst)
  ! (classical Gram-Schmidt, or two passes for DRCGS) and optionally normalizes
  ! phib and/or returns its norms and the accumulated overlaps.

  logical :: do_normalize, whole_batch, use_drcgs
  integer :: jst, istep, nphi, npass
  R_TYPE, allocatable :: proj(:,:), proj_acc(:,:), phi_sq(:)
  R_TYPE :: total

  call profiling_in(TOSTRING(X(BATCH_GRAM_SCHMIDT)))
  PUSH_SUB(X(mesh_batch_orthogonalization))

  whole_batch = optional_default(full_batch, .false.)
  do_normalize = optional_default(normalize, .false.)
  nphi = phib%nst

  ! proj(i, j) = < psib(j) | phib(i) >
  SAFE_ALLOCATE(proj(1:nphi, 1:nst))
  proj = R_TOTYPE(M_ZERO)

  do jst = 1, nst
    call phib%check_compatibility_with(psib(jst)%p)
  end do

  ! DRCGS performs two passes and needs a separate accumulator for the overlap
  use_drcgs = .false.
  if (present(gs_scheme)) use_drcgs = (gs_scheme == OPTION__ARNOLDIORTHOGONALIZATION__DRCGS)

  npass = 1
  if (use_drcgs) then
    npass = 2
    SAFE_ALLOCATE(proj_acc(1:nphi, 1:nst))
    proj_acc = R_TOTYPE(M_ZERO)
  end if

  do istep = 1, npass
    if (use_drcgs .and. nst >= 1) then
      ! DRCGS: first remove the (already reduced) component along the last state
      call X(mesh_batch_dotp_vector)(mesh, psib(nst)%p, phib, proj(1:nphi, 1))
      if (whole_batch) then
        total = sum(proj(1:nphi, 1))
        proj(1:nphi, 1) = total
      end if
      call batch_axpy(mesh%np, -proj(1:nphi, 1), psib(nst)%p, phib, a_full = .false.)
      if (present(overlap)) proj_acc(1:nphi, nst) = proj_acc(1:nphi, nst) + proj(1:nphi, 1)
    end if
    proj = R_TOTYPE(M_ZERO)

    ! local (non-reduced) projections on every psib batch, reduced in one call below
    !TODO: We should reuse phib here for improved performances
    do jst = 1, nst
      call X(mesh_batch_dotp_vector)(mesh, psib(jst)%p, phib, proj(1:nphi, jst), reduce = .false.)
      if (whole_batch) then
        total = sum(proj(1:nphi, jst))
        proj(1:nphi, jst) = total
      end if
    end do

    call profiling_in(TOSTRING(X(BATCH_GS_REDUCE)))
    call mesh%allreduce(proj, dim = [nphi, nst])
    call profiling_out(TOSTRING(X(BATCH_GS_REDUCE)))

    !TODO: We should have a routine batch_gemv for improved performances
    do jst = 1, nst
      call batch_axpy(mesh%np, -proj(1:nphi, jst), psib(jst)%p, phib, a_full = .false.)
    end do

    ! accumulate the overlap of this pass
    if (use_drcgs .and. present(overlap)) then
      do jst = 1, nst
        proj_acc(1:nphi, jst) = proj_acc(1:nphi, jst) + proj(1:nphi, jst)
      end do
    end if
  end do

  ! The overlap is returned transposed, i.e. with shape (nst, phib%nst). This
  ! suits the Lanczos implementation (currently the only user), which acts on
  ! phib%nst arrays of dimension nst. For an orthogonalization it is also the
  ! natural layout: for each state, its overlap with all the others.
  if (present(overlap)) then
    if (use_drcgs) then
      overlap(1:nst, 1:nphi) = transpose(proj_acc(1:nphi, 1:nst))
    else
      overlap(1:nst, 1:nphi) = transpose(proj(1:nphi, 1:nst))
    end if
  end if

  if (do_normalize .or. present(norm)) then
    SAFE_ALLOCATE(phi_sq(1:nphi))
    ! the dot product is used instead of mesh_batch_nrm2, which is too slow
    call X(mesh_batch_dotp_vector)(mesh, phib, phib, phi_sq)
    if (whole_batch) then
      total = sum(phi_sq)
      phi_sq(:) = total
    end if
    if (present(norm)) norm(1:nphi) = sqrt(real(phi_sq(1:nphi), real64))
    if (do_normalize) then
      call batch_scal(mesh%np, M_ONE/sqrt(real(phi_sq, real64)), phib, a_full = .false.)
    end if
    SAFE_DEALLOCATE_A(phi_sq)
  end if

  SAFE_DEALLOCATE_A(proj)
  SAFE_DEALLOCATE_A(proj_acc)

  POP_SUB(X(mesh_batch_orthogonalization))
  call profiling_out(TOSTRING(X(BATCH_GRAM_SCHMIDT)))
end subroutine X(mesh_batch_orthogonalization)

! ---------------------------------------------------------
!> @brief Normalize a batch
!
subroutine X(mesh_batch_normalize)(mesh, psib, norm)
  class(mesh_t),     intent(in)    :: mesh    !< underlying mesh
  class(batch_t),    intent(inout) :: psib    !< batch to normalize
  real(real64), optional,   intent(out)   :: norm(:) !< optional result: array of norms

  ! Divides every state of psib by its norm (as computed by mesh_batch_nrm2),
  ! optionally returning the norms before scaling.

  real(real64), allocatable :: state_norm(:)
  integer :: nstates

  PUSH_SUB(X(mesh_batch_normalize))

  nstates = psib%nst
  SAFE_ALLOCATE(state_norm(1:nstates))

  ! norm of each state of the batch; hand it back to the caller if requested
  call mesh_batch_nrm2(mesh, psib, state_norm)
  if (present(norm)) norm(1:nstates) = state_norm(1:nstates)

  ! rescale each state by the inverse of its norm
  call batch_scal(mesh%np, R_TOTYPE(M_ONE/state_norm(1:nstates)), psib, a_full = .false.)

  SAFE_DEALLOCATE_A(state_norm)

  POP_SUB(X(mesh_batch_normalize))
end subroutine X(mesh_batch_normalize)

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
