!! Copyright (C) 2009-2020 X. Andrade, N. Tancogne-Dejean, M. Lueders
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

! ---------------------------------------------------------------------------------------
!> @brief Start application of non-local potentials (stored in the Hamiltonian) to
!! the wave functions
!!
!! This routine is split into the *_start and *_finish parts, in order to allow
!! for other operations during device transfer or communications.
!!
!! This first part computes the projections of the states in psib onto the projectors
!! of all projector matrices and stores them in the argument projection:
!!  - device-packed batches: a single accelerator kernel launch treats all projector
!!    matrices. The device buffer projection%buff_projection is passed to the kernel
!!    (presumably as its output; the kernel source projector.cl is not part of this
!!    file). The host array of projection is only allocated and filled if the mesh is
!!    parallel in domains.
!!  - all other batches: for each projector matrix the points listed in pmat%map are
!!    gathered into a contiguous work array (multiplied by the projector phases if
!!    these are allocated), multiplied with the projectors through a BLAS gemm call
!!    and scaled by pmat%scal. The result is kept in the host array of projection.
!!
!! The routine returns immediately if this%has_non_local_potential is false.
!
subroutine X(nonlocal_pseudopotential_start)(this, mesh, std, spiral_bnd, psib, projection, async)
  class(nonlocal_pseudopotential_t), target, intent(in)    :: this
  class(mesh_t),                    intent(in)    :: mesh    !< the mesh
  type(states_elec_dim_t),          intent(in)    :: std     !< dimensions of the states
  logical,                          intent(in)    :: spiral_bnd !< flag for spiral boundary conditions
  type(wfs_elec_t),                 intent(in)    :: psib       !< original wave functions
  type(projection_t),               intent(out)   :: projection !< on exit: the projections (host array and/or device buffers)
  logical, optional,                intent(in)    :: async      !< if present and .true., do not wait for the device kernel

  integer :: ist, iproj, imat, nreal, iprojection
  integer :: npoints, nprojs, nst_linear, maxnpoints
  ! ind(imat): offset of the first projection of matrix imat in the projection array
  integer, allocatable :: ind(:)
  type(projector_matrix_t), pointer :: pmat
  integer(int64) :: padnprojs, lnprojs, size, thread_block_size, grid_size
  ! NOTE(review): nphase is assigned below but never read in this routine.
  integer :: nphase
  ! The kernel handles are saved, so each one is kept between calls.
  type(accel_kernel_t), save, target :: zker_proj_bra, dker_proj_bra, dker_proj_bra_phase_spiral
  type(accel_kernel_t), save, target :: zker_proj_bra_phase, dker_proj_bra_phase, zker_proj_bra_phase_spiral
  type(accel_kernel_t), pointer :: kernel
  integer, allocatable :: spin_to_phase(:)
  ! lpsi(ist, ip): values of the states at the points of one projector sphere
  R_TYPE, allocatable :: lpsi(:, :)
#ifdef R_TCOMPLEX
  integer :: iphase, map_ip, ip
  complex(real64), allocatable :: tmp_proj(:, :)
#endif

  ! NOTE(review): block_size is assigned in the CPU branch but never read afterwards.
  integer :: block_size

  if (.not. this%has_non_local_potential) return

  ! Only the projector-matrix representation is handled here.
  ASSERT(this%apply_projector_matrices)

  call profiling_in(TOSTRING(X(VNLPSI_MAT_BRA)))
  PUSH_SUB(X(nonlocal_pseudopotential_start))

  nst_linear = psib%nst_linear
  ! nreal is the number of rows used when the batch data is handed to the real gemm
  ! call below: in the complex version every state counts as two real numbers.
#ifdef R_TCOMPLEX
  nreal = 2*nst_linear
#else
  nreal = nst_linear
#endif
  nphase = 1

  ! A batch that carries a phase requires the projector phases to be available.
  if (psib%has_phase) then
    ASSERT(allocated(this%projector_phases))
  end if

  if (psib%status() == BATCH_DEVICE_PACKED) then

    ! ---- accelerator branch ----
    call profiling_in(TOSTRING(X(CL_PROJ_BRA)))
    ! only do this if we have some points of projector matrices
    if (this%max_npoints > 0) then

      ! Device buffer for the projections: full_projection_size entries for each of the
      ! pack_size(1) states of the packed batch.
      call accel_create_buffer(projection%buff_projection, ACCEL_MEM_READ_WRITE, R_TYPE_VAL, &
        this%full_projection_size*psib%pack_size(1))

      if (allocated(this%projector_phases)) then
        ! Phases are only possible for complex batches.
        ASSERT(R_TYPE_VAL == TYPE_CMPLX)

        if (spiral_bnd) then

          nphase = 3

          ! Build, for every state index of the packed batch, the index of the phase to be
          ! used. The values are 0-based (0, 1, 2); they are selected with the same
          ! conditions as the phases 1, 2, 3 (iphase) in the unpacked CPU branch below.
          SAFE_ALLOCATE(spin_to_phase(1:psib%pack_size(1)))
          call accel_create_buffer(projection%buff_spin_to_phase, ACCEL_MEM_READ_ONLY, TYPE_INTEGER, psib%pack_size(1))

          do ist = 1, nst_linear
            if (this%spin(3, psib%linear_to_ist(ist), psib%ik) > 0 .and. psib%linear_to_idim(ist) == 2) then
              spin_to_phase(ist) = 1
            else if (this%spin(3, psib%linear_to_ist(ist), psib%ik) < 0 .and. psib%linear_to_idim(ist) == 1) then
              spin_to_phase(ist) = 2
            else
              spin_to_phase(ist) = 0
            end if
          end do
          ! Entries beyond nst_linear (padding of the pack) are set to zero.
          ! This might not be necessary:
          do ist = nst_linear+1, int(psib%pack_size(1), int32)
            spin_to_phase(ist) = 0
          end do

          call accel_write_buffer(projection%buff_spin_to_phase, psib%pack_size(1), spin_to_phase)

          ! The kernel variant (z or d) is chosen from the type of the first projector matrix.
          if (this%projector_matrices(1)%is_cmplx) then
            call accel_kernel_start_call(zker_proj_bra_phase_spiral, 'projector.cl', 'zprojector_bra_phase_spiral',&
              flags = '-DRTYPE_COMPLEX')
            kernel => zker_proj_bra_phase_spiral
          else
            call accel_kernel_start_call(dker_proj_bra_phase_spiral, 'projector.cl', 'dprojector_bra_phase_spiral',&
              flags = '-DRTYPE_DOUBLE')
            kernel => dker_proj_bra_phase_spiral
          end if
          SAFE_DEALLOCATE_A(spin_to_phase)
        else
          if (this%projector_matrices(1)%is_cmplx) then
            call accel_kernel_start_call(zker_proj_bra_phase, 'projector.cl', 'zprojector_bra_phase',&
              flags = '-DRTYPE_COMPLEX')
            kernel => zker_proj_bra_phase
          else
            call accel_kernel_start_call(dker_proj_bra_phase, 'projector.cl', 'dprojector_bra_phase',&
              flags = '-DRTYPE_DOUBLE')
            kernel => dker_proj_bra_phase
          end if
        end if
        ! With phases the kernel is always launched over pack_size(1) states.
        size = psib%pack_size(1)

      else
        ! Without phases: complex projectors are launched over pack_size(1), real projectors
        ! over pack_size_real(1).
        if (this%projector_matrices(1)%is_cmplx) then
          call accel_kernel_start_call(zker_proj_bra, 'projector.cl', 'zprojector_bra', flags = '-DRTYPE_COMPLEX')
          kernel => zker_proj_bra
          size = psib%pack_size(1)
        else
          call accel_kernel_start_call(dker_proj_bra, 'projector.cl', 'dprojector_bra', flags = '-DRTYPE_DOUBLE')
          kernel => dker_proj_bra
          size = psib%pack_size_real(1)
        end if
      end if

      ! NOTE: For the projector_bra_* kernels, there is no problem with  self-overlapping spheres.

      ! Positional kernel arguments; their order has to match the kernel signature in
      ! projector.cl, which is not visible from this file.
      call accel_set_kernel_arg(kernel, 0, this%nprojector_matrices)
      call accel_set_kernel_arg(kernel, 1, this%buff_offsets)
      call accel_set_kernel_arg(kernel, 2, this%buff_matrices)
      call accel_set_kernel_arg(kernel, 3, this%buff_maps)
      call accel_set_kernel_arg(kernel, 4, this%buff_scals)
      call accel_set_kernel_arg(kernel, 5, psib%ff_device)
      ! log2(size) is passed for the batch buffer and again for the projection buffer
      ! (presumably size is a power of two - TODO confirm).
      call accel_set_kernel_arg(kernel, 6, int(log2(size), int32))
      call accel_set_kernel_arg(kernel, 7, projection%buff_projection)
      call accel_set_kernel_arg(kernel, 8, int(log2(size), int32))

      if (allocated(this%projector_phases)) then
        call accel_set_kernel_arg(kernel, 9, this%buff_projector_phases)
        ! Note: we need to use this%nphase, as the kernel might be called with spiral=false, but
        !       the phases been built with spiralBC=true
        ! Argument 10 is the offset of the phases of k-point psib%ik, counted from the first
        ! local k-point std%kpt%start.
        call accel_set_kernel_arg(kernel, 10, (psib%ik - std%kpt%start)*this%total_points*this%nphase)
        if (spiral_bnd) then
          call accel_set_kernel_arg(kernel, 11, projection%buff_spin_to_phase)
          call accel_set_kernel_arg(kernel, 12, this%nphase)
        end if
      end if

      ! In case of CUDA we use an optimized kernel, in which the loop over npoints is broken
      ! further into chunks, in order to parallelize over the threads within a warp.
      ! Therefore we need to launch warp_size * size kernels. The size of each block needs to
      ! have multiples of warp_size as x-dimension.
      call accel_get_unfolded_size(size, grid_size, thread_block_size)
      ! Number of projectors per work group, limited by the work group size of the kernel;
      ! the global size in that dimension is max_nprojs padded to a multiple of it.
      lnprojs = min(accel_kernel_workgroup_size(kernel)/thread_block_size, int(this%max_nprojs, int64))
      padnprojs = pad(this%max_nprojs, lnprojs)

      ! Launch dimensions: (states, projectors, projector matrices).
      call accel_kernel_run(kernel, &
        (/grid_size, padnprojs, int(this%nprojector_matrices, int64)/), (/thread_block_size, lnprojs, 1_int64/))

      ! This loop only accumulates the operation count for the profiler.
      do imat = 1, this%nprojector_matrices
        pmat => this%projector_matrices(imat)

        npoints = pmat%npoints
        nprojs = pmat%nprojs

        !! update number of operations for nphase !!
        call profiling_count_operations(nreal*nprojs*M_TWO*npoints + nst_linear*nprojs)
      end do

      ! Wait for the device unless the caller asked for asynchronous execution.
      if(.not. optional_default(async, .false.)) call accel_finish()
    end if

    ! A host copy of the projections is only made when the mesh is parallel in domains
    ! (the finish routine reduces it over the domains). It is zero if this process has
    ! no projector points.
    if (mesh%parallel_in_domains) then
      SAFE_ALLOCATE(projection%X(projection)(1:psib%pack_size(1), 1:this%full_projection_size))
      projection%X(projection) = M_ZERO
      if (this%max_npoints > 0) then
        call accel_read_buffer(projection%buff_projection, &
          this%full_projection_size*psib%pack_size(1), projection%X(projection))
      end if
    end if
    call profiling_out(TOSTRING(X(CL_PROJ_BRA)))

  else

    ! ---- CPU branch ----
    ! This routine uses blocking to optimize cache usage. One block of
    ! |phi> is loaded in cache L1 and then we calculate the dot
    ! product of it with the corresponding blocks of |psi_k>, next we
    ! load another block and do the same. This way we only have to load
    ! |psi> from the L2 or memory.
    ! NOTE(review): block_size is read here but not used in the code below, which
    ! relies on gemm calls instead - confirm whether the comment above is still accurate.
    block_size = cpu_hardware%X(block_size)


    SAFE_ALLOCATE(projection%X(projection)(1:nst_linear, 1:this%full_projection_size))
    projection%X(projection) = M_ZERO

    SAFE_ALLOCATE(ind(1:this%nprojector_matrices))

    ! First pass over the projector matrices: determine the offset of each matrix in the
    ! projection array, the largest number of points and the operation counts.
    iprojection = 0
    maxnpoints = 0
    do imat = 1, this%nprojector_matrices
      pmat => this%projector_matrices(imat)
      npoints = pmat%npoints
      maxnpoints = max(maxnpoints, npoints)
      nprojs = pmat%nprojs
      ind(imat) = iprojection
      iprojection = iprojection + nprojs
      call profiling_count_operations(nprojs*(R_ADD + R_MUL)*npoints + nst_linear*nprojs)
      if (allocated(this%projector_phases)) then
        call profiling_count_operations(R_MUL*npoints*nst_linear)
      end if
    end do

    ! Work array, large enough for the biggest projector sphere; reused for all matrices.
    SAFE_ALLOCATE(lpsi(1:nst_linear, 1:maxnpoints))

    ! Second pass: gather, project and scale, one projector matrix at a time.
    do imat = 1, this%nprojector_matrices
      pmat => this%projector_matrices(imat)
      iprojection = ind(imat)
      npoints = pmat%npoints
      nprojs = pmat%nprojs

      ! No points for this matrix: its projections keep the initial value zero.
      if (npoints == 0) cycle

      if (.not. allocated(this%projector_phases)) then

        ! No phases: copy the points listed in pmat%map from the batch into lpsi.
        call X(batch_copy_with_map_to_array)(npoints, pmat%map, psib, lpsi)

      else
#ifdef R_TCOMPLEX
        if (.not. spiral_bnd) then
          ! Gather the points and multiply them by the first phase of each point.
          if (psib%status() == BATCH_PACKED) then
            !$omp parallel do private(ist, map_ip)
            do ip = 1, npoints
              map_ip = pmat%map(ip)
              !$omp simd
              do ist = 1, nst_linear
                lpsi(ist, ip) = psib%zff_pack(ist, map_ip)*this%projector_phases(ip, 1, imat, psib%ik)
              end do
            end do
          else
            ! Batch is not packed: read from the linear storage, state by state.
            !$omp parallel
            do ist = 1, nst_linear
              !$omp do simd
              do ip = 1, npoints
                lpsi(ist, ip) = psib%zff_linear(pmat%map(ip), ist)*this%projector_phases(ip, 1, imat, psib%ik)
              end do
            end do
            !$omp end parallel
          end if
        else
          ! Spiral boundary conditions: the phase index (1, 2 or 3) depends on the sign of
          ! this%spin(3, ...) of the state and on the spinor component.
          if (psib%status() == BATCH_PACKED) then
            ! States are treated in pairs (ist, ist+1), assumed to be the two spinor
            ! components of the same state - TODO confirm against the batch layout.
            !$omp parallel do private(ist)
            do ip = 1, npoints
              do ist = 1, nst_linear, 2
                if (this%spin(3,psib%linear_to_ist(ist), psib%ik)>0) then
                  lpsi(ist, ip)   = psib%zff_pack(ist,   pmat%map(ip))*this%projector_phases(ip, 1, imat, psib%ik)
                  lpsi(ist+1, ip) = psib%zff_pack(ist+1, pmat%map(ip))*this%projector_phases(ip, 2, imat, psib%ik)
                else
                  lpsi(ist, ip)   = psib%zff_pack(ist,   pmat%map(ip))*this%projector_phases(ip, 3, imat, psib%ik)
                  lpsi(ist+1, ip) = psib%zff_pack(ist+1, pmat%map(ip))*this%projector_phases(ip, 1, imat, psib%ik)
                end if
              end do
            end do
          else
            do ist = 1, nst_linear
              ! Select the phase for this linear index; iphase is fixed before the
              ! parallel loop over the points starts.
              if (this%spin(3, psib%linear_to_ist(ist), psib%ik) > 0 .and. psib%linear_to_idim(ist) == 2) then
                iphase = 2
              else if (this%spin(3, psib%linear_to_ist(ist), psib%ik) < 0 .and. psib%linear_to_idim(ist) == 1) then
                iphase = 3
              else
                iphase = 1
              end if
              !$omp parallel do
              do ip = 1, npoints
                lpsi(ist, ip) = psib%zff_linear(pmat%map(ip), ist)*this%projector_phases(ip, iphase, imat, psib%ik)
              end do
            end do
          end if
        end if
#else
        ! Phases not allowed for real batches
        ASSERT(.false.)
#endif
      end if

      if (pmat%is_cmplx) then
#ifdef R_TCOMPLEX
        ! Complex projectors: tmp_proj = (conjugate transpose of zprojectors) * (transpose of lpsi),
        ! which has the projector index first; it is copied transposed into the projection array.
        SAFE_ALLOCATE(tmp_proj(1:nprojs, 1:nst_linear))
        call blas_gemm('C', 'T', nprojs, nst_linear, npoints, &
          M_z1, pmat%zprojectors(1, 1), npoints, lpsi(1, 1), nst_linear, M_z0, tmp_proj(1,1), nprojs)
        !$omp parallel do private(iproj, ist)
        do iproj = 1, nprojs
          !$omp simd
          do ist = 1, nst_linear
            projection%X(projection)(ist, iprojection + iproj) = tmp_proj(iproj, ist)
          end do
        end do
        !$omp end parallel do
        SAFE_DEALLOCATE_A(tmp_proj)
        call profiling_count_operations(nst_linear*nprojs*M_TWO*npoints)
#else
        ! Complex projection matrix not allowed for real batches
        ASSERT(.false.)
#endif
      else
        ! Real projectors: projection = lpsi * dprojectors. The gemm call uses nreal rows, so
        ! in the complex version the data is handled as pairs of real numbers.
        call blas_gemm('N', 'N', nreal, nprojs, npoints, &
          M_ONE, lpsi(1, 1), nreal, pmat%dprojectors(1, 1), npoints, M_ZERO,  projection%X(projection)(1, iprojection + 1), nreal)
        call profiling_count_operations(nreal*nprojs*M_TWO*npoints)
      end if

      ! Multiply each projection by the scaling factor of its projector.
      !$omp parallel do private(iproj, ist)
      do iproj = 1, nprojs
        !$omp simd
        do ist = 1, nst_linear
          projection%X(projection)(ist, iprojection + iproj) = projection%X(projection)(ist, iprojection + iproj)*pmat%scal(iproj)
        end do
      end do
      !$omp end parallel do

    end do

    SAFE_DEALLOCATE_A(ind)
    SAFE_DEALLOCATE_A(lpsi)
  endif

  POP_SUB(X(nonlocal_pseudopotential_start))
  call profiling_out(TOSTRING(X(VNLPSI_MAT_BRA)))
end subroutine X(nonlocal_pseudopotential_start)

! ---------------------------------------------------------------------------------------
!> @brief finish the application of non-local potentials.
!!
!! Second half of the operation started by the *_start routine. It takes the projections
!! stored in the argument projection and
!!  - reduces the host array over the domains if the mesh is parallel in domains,
!!  - applies the mixing matrices of the projectors (zmix or dmix) if they are allocated,
!!  - multiplies the projections with the projectors and adds the result to vpsib at
!!    the points listed in the map of each projector matrix (multiplied by the complex
!!    conjugate of the projector phases if these are allocated).
!!
!! For device-packed batches these steps are done by accelerator kernels in the internal
!! procedure finish_accel. On exit the host array of projection is deallocated; in the
!! device-packed branch the device buffers of projection are released as well.
!!
!! The routine returns immediately if this%has_non_local_potential is false.
!
subroutine X(nonlocal_pseudopotential_finish)(this, mesh, spiral_bnd, std, projection, vpsib)
  class(nonlocal_pseudopotential_t), target, intent(in)    :: this
  class(mesh_t),                    intent(in)    :: mesh       !< the mesh
  type(states_elec_dim_t),          intent(in)    :: std        !< dimensions of the states (dim and first local k-point are used)
  logical,                          intent(in)    :: spiral_bnd !< flag for spiral boundary conditions
  type(projection_t),       target, intent(inout) :: projection !< projections from the *_start routine; released on exit
  class(wfs_elec_t),                intent(inout) :: vpsib      !< batch to which the result is added

  integer :: ist, ip, imat, nreal, iprojection
  ! NOTE(review): nphase is assigned below but never read in this routine.
  integer :: npoints, nprojs, nst_linear, nphase
  ! psi(ist, ip): contribution for the points of one projector sphere
  R_TYPE, allocatable :: psi(:, :)
  type(projector_matrix_t), pointer :: pmat
#ifdef R_TCOMPLEX
  integer :: iproj, idim, iphase
  complex(real64)  :: phase, phase_pq, phase_mq
  complex(real64), allocatable :: tmp_proj(:, :, :)
#endif

  if (.not. this%has_non_local_potential) return

  ! Only the projector-matrix representation is handled here.
  ASSERT(this%apply_projector_matrices)

  call profiling_in(TOSTRING(X(VNLPSI_MAT_KET)))
  PUSH_SUB(X(nonlocal_pseudopotential_finish))

  nst_linear = vpsib%nst_linear
  ! nreal is the number of rows used in the real gemm call below: in the complex version
  ! every state counts as two real numbers.
#ifdef R_TCOMPLEX
  nreal = 2*nst_linear
#else
  nreal = nst_linear
#endif
  nphase = 1
  if (spiral_bnd) nphase = 3

  ! A batch that carries a phase requires the projector phases to be available.
  if (vpsib%has_phase) then
    ASSERT(allocated(this%projector_phases))
  end if

  ! reduce the projections. Note that for device-packed batches the host array of projection is
  ! only allocated if mesh%parallel_in_domains; for all other batches it is always allocated.
  if (mesh%parallel_in_domains) then
    call profiling_in(TOSTRING(X(VNLPSI_MAT_REDUCE)))
    call mesh%allreduce(projection%X(projection))
    call profiling_out(TOSTRING(X(VNLPSI_MAT_REDUCE)))
  end if

  if (vpsib%status() == BATCH_DEVICE_PACKED) then

    ! ---- accelerator branch ----
    ! Send the reduced projections back to the device. The transfer is asynchronous;
    ! finish_accel synchronizes before the host array is deallocated below.
    if (mesh%parallel_in_domains) then
      ! only do this if we have points of some projector matrices
      if (this%max_npoints > 0) then
        call accel_write_buffer(projection%buff_projection, &
          this%full_projection_size*vpsib%pack_size(1), projection%X(projection), async=.true.)
      end if
    end if

    call finish_accel()
    SAFE_DEALLOCATE_A(projection%X(projection))
    ! NOTE(review): the buffers are released unconditionally here, while the *_start routine
    ! only creates them if this%max_npoints > 0 - confirm that releasing a buffer that was
    ! never created is harmless.
    call accel_release_buffer(projection%buff_projection)
    if (spiral_bnd) then
      call accel_release_buffer(projection%buff_spin_to_phase)
    end if

  else

    ! ---- CPU branch ----
    ASSERT(allocated(projection%X(projection)))

    ! iprojection is the offset of the projections of the current matrix.
    iprojection = 0
    do imat = 1, this%nprojector_matrices
      pmat => this%projector_matrices(imat)

      npoints = pmat%npoints
      nprojs = pmat%nprojs

      if (allocated(pmat%zmix)) then
#ifdef R_TCOMPLEX
        ! Complex mixing between the projections of the two components of each state.
        ! zmix(:, :, 1) and zmix(:, :, 2) act on the same component as the result, while
        ! zmix(:, :, 3) and zmix(:, :, 4) take the projections of the other component.
        ! The component indices 1 and 2 are hard-coded, so this assumes std%dim == 2.
        SAFE_ALLOCATE(tmp_proj(1:nprojs, 1:vpsib%nst, 1:std%dim))

        do ist = 1, vpsib%nst

          tmp_proj(1:nprojs, ist, 1) = matmul(pmat%zmix(1:nprojs, 1:nprojs, 1), &
            projection%X(projection)((ist-1)*std%dim+1, iprojection + 1:iprojection + nprojs)) &
            + matmul(pmat%zmix(1:nprojs, 1:nprojs, 3), &
            projection%X(projection)((ist-1)*std%dim+2, iprojection + 1:iprojection + nprojs))

          tmp_proj(1:nprojs, ist, 2) = matmul(pmat%zmix(1:nprojs, 1:nprojs, 2), &
            projection%X(projection)((ist-1)*std%dim+2, iprojection + 1:iprojection + nprojs)) &
            + matmul(pmat%zmix(1:nprojs, 1:nprojs, 4), &
            projection%X(projection)((ist-1)*std%dim+1, iprojection + 1:iprojection + nprojs))
        end do

        ! Copy the mixed projections back; the temporary is needed because both components
        ! of the original projections enter each result.
        do ist = 1, vpsib%nst
          do idim = 1, std%dim
            do iproj = 1, nprojs
              projection%X(projection)((ist-1)*std%dim+idim, iprojection + iproj) = tmp_proj(iproj, ist, idim)
            end do
          end do
        end do

        SAFE_DEALLOCATE_A(tmp_proj)
#else
        ! Complex projection matrix not allowed for real batches
        ASSERT(.false.)
#endif
      else if (allocated(pmat%dmix)) then
        ! Real mixing: the projections of each linear state index are multiplied by dmix.
        do ist = 1, nst_linear
          projection%X(projection)(ist, iprojection + 1:iprojection + nprojs) = &
            matmul(pmat%dmix(1:nprojs, 1:nprojs), projection%X(projection)(ist, iprojection + 1:iprojection + nprojs))
        end do
      end if

      if (npoints /=  0) then

        SAFE_ALLOCATE(psi(1:nst_linear, 1:npoints))

        ! Matrix-multiply again.
        ! the line below does: psi = matmul(projection, transpose(pmat%projectors))
        if (.not. pmat%is_cmplx) then
          ! Real projectors: nreal rows are used, so in the complex version the data is
          ! handled as pairs of real numbers.
          call blas_gemm('N', 'T', nreal, npoints, nprojs, &
            M_ONE, projection%X(projection)(1, iprojection + 1), nreal, pmat%dprojectors(1, 1), npoints, &
            M_ZERO, psi(1, 1), nreal)
          call profiling_count_operations(nreal*nprojs*M_TWO*npoints)
        else
#ifdef R_TCOMPLEX
          ! Complex projectors: plain transpose (no conjugation) of zprojectors.
          call blas_gemm('N', 'T', nst_linear, npoints, nprojs, &
            M_z1, projection%X(projection)(1, iprojection + 1), nst_linear, pmat%zprojectors(1, 1), npoints, &
            M_z0, psi(1, 1), nst_linear)
          call profiling_count_operations(nst_linear*nprojs*(R_ADD+R_MUL)*npoints)
#else
          ! Complex projection matrix not allowed for real batches
          ASSERT(.false.)
#endif
        end if

        call profiling_in(TOSTRING(X(PROJ_MAT_SCATTER)))

        ! Scatter step: add psi to vpsib at the points listed in pmat%map. The loops over the
        ! points are only threaded if this%projector_self_overlap is false (presumably because
        ! the map may otherwise contain the same mesh point more than once - TODO confirm).
        if (.not. allocated(this%projector_phases)) then
          ! and copy the points from the local buffer to its position
          if (vpsib%status() == BATCH_PACKED) then
            !$omp parallel do private(ip, ist) if (.not. this%projector_self_overlap)
            do ip = 1, npoints
              do ist = 1, nst_linear
                vpsib%X(ff_pack)(ist, pmat%map(ip)) = vpsib%X(ff_pack)(ist, pmat%map(ip)) + psi(ist, ip)
              end do
            end do
            !$omp end parallel do
          else
            do ist = 1, nst_linear
              !$omp parallel do if (.not. this%projector_self_overlap)
              do ip = 1, npoints
                vpsib%X(ff_linear)(pmat%map(ip), ist) = vpsib%X(ff_linear)(pmat%map(ip), ist) + psi(ist, ip)
              end do
              !$omp end parallel do
            end do
          end if
          call profiling_count_operations(nst_linear*npoints*R_ADD)
        else
#ifdef R_TCOMPLEX
          ! With phases: psi is multiplied by the complex conjugate of the phase that was
          ! applied when gathering the points in the *_start routine.
          if (.not. spiral_bnd) then
            ! and copy the points from the local buffer to its position
            if (vpsib%status() == BATCH_PACKED) then
              !$omp parallel do private(ip, ist, phase) if (.not. this%projector_self_overlap)
              do ip = 1, npoints
                phase = conjg(this%projector_phases(ip, 1, imat, vpsib%ik))
                do ist = 1, nst_linear
                  vpsib%zff_pack(ist, pmat%map(ip)) = vpsib%zff_pack(ist, pmat%map(ip)) &
                    + psi(ist, ip)*phase
                end do
              end do
              !$omp end parallel do
            else
              do ist = 1, nst_linear
                !$omp parallel do if (.not. this%projector_self_overlap)
                do ip = 1, npoints
                  vpsib%zff_linear(pmat%map(ip), ist) = vpsib%zff_linear(pmat%map(ip), ist) &
                    + psi(ist, ip)*conjg(this%projector_phases(ip, 1, imat, vpsib%ik))
                end do
                !$omp end parallel do
              end do
            end if
            call profiling_count_operations(nst_linear*npoints*(R_ADD+R_MUL))
          else
            ! Spiral boundary conditions: the phase index (1, 2 or 3) depends on the sign of
            ! this%spin(3, ...) of the state and on the spinor component.
            ! and copy the points from the local buffer to its position
            if (vpsib%status() == BATCH_PACKED) then
              ! States are treated in pairs (ist, ist+1), assumed to be the two spinor
              ! components of the same state - TODO confirm against the batch layout.
              !$omp parallel do private(ip, ist, phase, phase_pq, phase_mq) if (.not. this%projector_self_overlap)
              do ip = 1, npoints
                phase = conjg(this%projector_phases(ip, 1, imat, vpsib%ik))
                phase_pq = conjg(this%projector_phases(ip, 2, imat, vpsib%ik))
                phase_mq = conjg(this%projector_phases(ip, 3, imat, vpsib%ik))
                do ist = 1, nst_linear, 2
                  if (this%spin(3, vpsib%linear_to_ist(ist), vpsib%ik) > 0) then
                    vpsib%zff_pack(ist, pmat%map(ip)) = vpsib%zff_pack(ist, pmat%map(ip)) &
                      + psi(ist, ip)*phase
                    vpsib%zff_pack(ist+1, pmat%map(ip)) = vpsib%zff_pack(ist+1, pmat%map(ip)) &
                      + psi(ist+1, ip)*phase_pq
                  else
                    vpsib%zff_pack(ist, pmat%map(ip)) = vpsib%zff_pack(ist, pmat%map(ip)) &
                      + psi(ist, ip)*phase_mq
                    vpsib%zff_pack(ist+1, pmat%map(ip)) = vpsib%zff_pack(ist+1, pmat%map(ip)) &
                      + psi(ist+1, ip)*phase
                  end if
                end do
              end do
              !$omp end parallel do
            else
              do ist = 1, nst_linear
                ! Select the phase for this linear index; iphase is fixed before the
                ! parallel loop over the points starts.
                if (this%spin(3, vpsib%linear_to_ist(ist), vpsib%ik) > 0 .and. vpsib%linear_to_idim(ist) == 2) then
                  iphase = 2
                else if (this%spin(3, vpsib%linear_to_ist(ist), vpsib%ik) < 0 .and. vpsib%linear_to_idim(ist) == 1) then
                  iphase = 3
                else
                  iphase = 1
                end if
                !$omp parallel do if (.not. this%projector_self_overlap)
                do ip = 1, npoints
                  vpsib%zff_linear(pmat%map(ip), ist) = vpsib%zff_linear(pmat%map(ip), ist) &
                    + psi(ist, ip)*conjg(this%projector_phases(ip, iphase, imat, vpsib%ik))
                end do
                !$omp end parallel do
              end do
            end if
            call profiling_count_operations(nst_linear*npoints*(R_ADD+R_MUL))
          end if
#else
          ! Phases not allowed for real batches
          ASSERT(.false.)
#endif
        end if
        call profiling_out(TOSTRING(X(PROJ_MAT_SCATTER)))
      end if

      ! psi is only allocated if npoints is not zero; the deallocation macro is also used
      ! in the other case.
      SAFE_DEALLOCATE_A(psi)

      iprojection = iprojection + nprojs
    end do
  end if
  ! Release the host array in all branches (in the device branch this was already done above).
  SAFE_DEALLOCATE_A(projection%X(projection))

  POP_SUB(X(nonlocal_pseudopotential_finish))
  call profiling_out(TOSTRING(X(VNLPSI_MAT_KET)))

contains

  !> Accelerator version of the finish step. Uses the variables of the host routine
  !! (this, std, spiral_bnd, projection, vpsib, nreal, nst_linear, pmat, npoints, nprojs).
  !! It optionally runs the mixing kernel, then launches the ket kernel for every
  !! projector matrix and every region of that matrix, and finally waits for the device.
  subroutine finish_accel()

    integer :: imat, iregion, nregions_self_overlap, iregion_self_overlap
    integer(int64) :: padnprojs, lnprojs, size, size2, wgsize
    ! The kernel handles are saved, so each one is kept between calls.
    type(accel_kernel_t), save, target :: ker_proj_ket, dker_mix, zker_mix
    type(accel_kernel_t), save, target :: zker_proj_ket_phase, dker_proj_ket_phase
    type(accel_kernel_t), save, target :: dker_proj_ket_phase_spiral, zker_proj_ket_phase_spiral
    type(accel_kernel_t), pointer :: kernel, ker_mix
    ! Buffer read by the ket kernel: a temporary one if mixing is applied, otherwise
    ! a pointer to projection%buff_projection.
    type(accel_mem_t), pointer :: buff_proj

    PUSH_SUB(X(nonlocal_pseudopotential_finish).finish_accel)

    ! In this case we run one kernel per projector, since all write to
    ! the wave-function. Otherwise we would need to do atomic
    ! operations.

    ! only do this if we have points of some projector matrices
    if (this%max_npoints > 0) then

      if (this%projector_mix) then
        call profiling_in(TOSTRING(X(CL_PROJ_MIX)))

        ! The mixing kernel gets projection%buff_projection and a newly created buffer
        ! buff_proj of the same number of elements (presumably input and output,
        ! respectively - the kernel source is not part of this file).
        SAFE_ALLOCATE(buff_proj)
        if (allocated(this%projector_matrices(1)%zmix)) then
          size = vpsib%pack_size(1)
          ! Half the number of states in the launch (presumably one work item per pair
          ! of spinor components - TODO confirm against the kernel).
          size2 = size/2
          call accel_kernel_start_call(zker_mix, 'projector.cl', 'zprojector_mix', flags = '-DRTYPE_COMPLEX')
          call accel_create_buffer(buff_proj, ACCEL_MEM_READ_WRITE, TYPE_CMPLX, this%full_projection_size*size)
          ker_mix => zker_mix
        else
          size = vpsib%pack_size_real(1)
          size2 = size
          call accel_kernel_start_call(dker_mix, 'projector.cl', 'dprojector_mix', flags = '-DRTYPE_DOUBLE')
          call accel_create_buffer(buff_proj, ACCEL_MEM_READ_WRITE, TYPE_FLOAT, this%full_projection_size*size)
          ker_mix => dker_mix
        end if

        ! Positional kernel arguments; their order has to match the kernel signature.
        call accel_set_kernel_arg(ker_mix, 0, this%nprojector_matrices)
        call accel_set_kernel_arg(ker_mix, 1, this%buff_offsets)
        call accel_set_kernel_arg(ker_mix, 2, this%buff_mix)
        call accel_set_kernel_arg(ker_mix, 3, projection%buff_projection)
        call accel_set_kernel_arg(ker_mix, 4, int(log2(size), int32))
        call accel_set_kernel_arg(ker_mix, 5, buff_proj)

        ! Number of projectors per work group, limited by the work group size of the kernel;
        ! the global size in that dimension is max_nprojs padded to a multiple of it.
        lnprojs = min(accel_kernel_workgroup_size(ker_mix)/size2, int(this%max_nprojs, int64))
        padnprojs = pad(this%max_nprojs, lnprojs)

        call accel_kernel_run(ker_mix, &
          (/size2, padnprojs, int(this%nprojector_matrices, int64)/), (/size2, lnprojs, 1_int64/))

        call profiling_out(TOSTRING(X(CL_PROJ_MIX)))
      else

        ! No mixing: the ket kernel reads the projections directly.
        buff_proj => projection%buff_projection

      end if

      call profiling_in(TOSTRING(X(CL_PROJ_KET)))

      ! Select the ket kernel. The variant (z or d) is chosen from the type of the first
      ! projector matrix; with phases there are separate kernels for spiral boundaries.
      if (allocated(this%projector_phases)) then
        ASSERT(R_TYPE_VAL == TYPE_CMPLX)
        if (spiral_bnd) then
          if (this%projector_matrices(1)%is_cmplx) then
            call accel_kernel_start_call(zker_proj_ket_phase_spiral, 'projector.cl', 'zprojector_ket_phase_spiral',&
              flags = '-DRTYPE_COMPLEX')
            kernel => zker_proj_ket_phase_spiral
          else
            call accel_kernel_start_call(dker_proj_ket_phase_spiral, 'projector.cl', 'dprojector_ket_phase_spiral',&
              flags = '-DRTYPE_DOUBLE')
            kernel => dker_proj_ket_phase_spiral
          end if
        else
          if (this%projector_matrices(1)%is_cmplx) then
            call accel_kernel_start_call(zker_proj_ket_phase, 'projector.cl', 'zprojector_ket_phase',&
              flags = '-DRTYPE_COMPLEX')
            kernel => zker_proj_ket_phase
          else
            call accel_kernel_start_call(dker_proj_ket_phase, 'projector.cl', 'dprojector_ket_phase',&
              flags = '-DRTYPE_DOUBLE')
            kernel => dker_proj_ket_phase
          end if
        end if
        size = vpsib%pack_size(1)
      else
        ! Without phases a single saved handle is used for both the z and the d kernel.
        if (this%projector_matrices(1)%is_cmplx) then
          call accel_kernel_start_call(ker_proj_ket, 'projector.cl', 'zprojector_ket', flags = '-DRTYPE_COMPLEX')
          size = vpsib%pack_size(1)
        else
          call accel_kernel_start_call(ker_proj_ket, 'projector.cl', 'dprojector_ket', flags = '-DRTYPE_DOUBLE')
          size = vpsib%pack_size_real(1)
        end if
        kernel => ker_proj_ket
      end if

      ! One kernel launch per projector matrix and per region of that matrix. The matrices
      ! of region iregion are this%regions(iregion) to this%regions(iregion+1)-1.
      do iregion = 1, this%nregions

        do imat = this%regions(iregion), this%regions(iregion+1)-1
          nregions_self_overlap = this%projector_matrices(imat)%nregions

          do iregion_self_overlap = 1, nregions_self_overlap

            ! Positional kernel arguments; their order has to match the kernel signature.
            call accel_set_kernel_arg(kernel, 0, this%nprojector_matrices)
            ! 0-based index of the projector matrix
            call accel_set_kernel_arg(kernel, 1, imat - 1)
            call accel_set_kernel_arg(kernel, 2, this%buff_offsets)
            call accel_set_kernel_arg(kernel, 3, this%buff_matrices)
            call accel_set_kernel_arg(kernel, 4, this%buff_maps)
            call accel_set_kernel_arg(kernel, 5, buff_proj)
            call accel_set_kernel_arg(kernel, 6, int(log2(size), int32))
            call accel_set_kernel_arg(kernel, 7, vpsib%ff_device)
            call accel_set_kernel_arg(kernel, 8, int(log2(size), int32))

            ! Bounds of the current region of this matrix, shifted by one for the kernel.
            call accel_set_kernel_arg(kernel, 9, this%projector_matrices(imat)%regions(iregion_self_overlap) - 1)
            call accel_set_kernel_arg(kernel,10, this%projector_matrices(imat)%regions(iregion_self_overlap+1) - 1)

            if (allocated(this%projector_phases)) then
              call accel_set_kernel_arg(kernel, 11, this%buff_projector_phases)
              ! Note: we need to use this%nphase, as the kernel might be called with spiral=false, but
              !       the phases been built with spiralBC=true
              ! Argument 12 is the offset of the phases of k-point vpsib%ik, counted from the
              ! first local k-point std%kpt%start.
              call accel_set_kernel_arg(kernel, 12, (vpsib%ik - std%kpt%start)*this%total_points*this%nphase)
              if (spiral_bnd) then
                call accel_set_kernel_arg(kernel, 13, projection%buff_spin_to_phase)
                call accel_set_kernel_arg(kernel, 14, this%nphase)
              end if
            end if

            ! Number of points per work group, limited by the work group size of the kernel.
            wgsize = accel_kernel_workgroup_size(kernel)/size

            ! implicit loops over ist, ip
            call accel_kernel_run(kernel, &
              (/size, int(pad(this%max_npoints, wgsize), int64), 1_int64/), &
              (/size, wgsize, 1_int64/))

          end do ! iregion_self_overlap
        end do ! imat

      end do ! iregion

      ! This loop only accumulates the operation count for the profiler.
      do imat = 1, this%nprojector_matrices
        pmat => this%projector_matrices(imat)
        npoints = pmat%npoints
        nprojs = pmat%nprojs
        call profiling_count_operations(nreal*nprojs*M_TWO*npoints)
        call profiling_count_operations(nst_linear*npoints*R_ADD)
      end do

      ! we need to synchronize here to make sure GPU and CPU buffers can be deallocated
      call accel_finish()

      ! The temporary buffer only exists if the mixing kernel was used.
      if (this%projector_mix) then
        call accel_release_buffer(buff_proj)
        SAFE_DEALLOCATE_P(buff_proj)
      end if
      call profiling_out(TOSTRING(X(CL_PROJ_KET)))
    end if

    POP_SUB(X(nonlocal_pseudopotential_finish).finish_accel)
  end subroutine finish_accel

end subroutine X(nonlocal_pseudopotential_finish)

! ---------------------------------------------------------------------------------------
!> @brief calculate contribution to forces, from non-local potentials
!!
!! TODO: add more details
!
subroutine X(nonlocal_pseudopotential_force)(this, mesh, st, spiral_bnd, iqn, ndim, psi1b, psi2b, force)
  class(nonlocal_pseudopotential_t), target, intent(in)    :: this !< projector matrices, mixing matrices, device buffers
  class(mesh_t),                         intent(in)    :: mesh !< the mesh (allreduce and volume_element)
  type(states_elec_t),                   intent(in)    :: st !< the states (kweights, occ, d%dim, d%kpt%start)
  logical,                               intent(in)    :: spiral_bnd !< must be .false. (asserted below)
  integer,                               intent(in)    :: iqn !< index used for st%kweights and st%occ
  integer,                               intent(in)    :: ndim !< number of directions, psi2b(1:ndim) is used
  type(wfs_elec_t),                      intent(in)    :: psi1b !< batch stored in slot 0 of the projections
  type(wfs_elec_t),                      intent(in)    :: psi2b(:) !< one batch per direction (gradient of psi)
  real(real64),                          intent(inout) :: force(:, :) !< force(1:ndim, iatom), accumulated

  integer :: ii, ist, ip, iproj, imat, nreal, iprojection, iatom, idir
  integer :: npoints, nprojs, nst_linear
  real(real64), allocatable :: ff(:)
  R_TYPE, allocatable :: psi(:, :, :), projs(:, :, :)
  type(projector_matrix_t), pointer :: pmat
  integer(int64) :: padnprojs, lnprojs, size, grid_size, thread_block_size
  type(accel_mem_t) :: buff_projs
  type(accel_kernel_t), save, target :: dker_proj_bra_force, zker_proj_bra_force
  type(accel_kernel_t), save, target :: zker_proj_bra_force_phase, dker_proj_bra_force_phase
  type(accel_kernel_t), pointer :: kernel
#ifdef R_TCOMPLEX
  integer :: idim
  complex(real64), allocatable :: tmp_proj(:, :, :)
#endif

  ! Without a non-local potential the force array is left untouched
  if (.not. this%has_non_local_potential) return

  ASSERT(this%apply_projector_matrices)

  call profiling_in(TOSTRING(X(VNLPSI_MAT_ELEM)))
  PUSH_SUB(X(nonlocal_pseudopotential_force))

  ! Only psi2b(1) is checked against psi1b
  ASSERT(psi1b%nst_linear == psi2b(1)%nst_linear)
  ASSERT(psi1b%status() == psi2b(1)%status())

  ! Spiral boundary conditions are not handled by this routine
  ASSERT(.not. spiral_bnd)

  ! nreal counts real numbers per point of the batch: two per state for complex batches
  nst_linear = psi1b%nst_linear
#ifdef R_TCOMPLEX
  nreal = 2*nst_linear
#else
  nreal = nst_linear
#endif

  ! Host path: every batch status except BATCH_DEVICE_PACKED
  if( .not. psi1b%status() == BATCH_DEVICE_PACKED) then
    ! projs(0, :, :) receives the projections of psi1b,
    ! projs(idir, :, :) with idir >= 1 those of psi2b(idir)
    SAFE_ALLOCATE(projs(0:ndim, 1:nst_linear, 1:this%full_projection_size))
    projs = 0.0_real64

    ! iprojection is the running offset of the current matrix in the last index of projs
    iprojection = 0
    do imat = 1, this%nprojector_matrices
      pmat => this%projector_matrices(imat)

      npoints = pmat%npoints
      nprojs = pmat%nprojs

      if (npoints /= 0) then

        ! Work array with the values of psi1b (slot 0) and psi2b (slots 1:ndim) on the points of this matrix
        SAFE_ALLOCATE(psi(0:ndim, 1:nst_linear, 1:npoints))

        call profiling_in(TOSTRING(X(PROJ_MAT_ELEM_GATHER)))

        ! Phases not allowed for real batches
        if (allocated(this%projector_phases)) then
#ifndef R_TCOMPLEX
          message(1) = "Phases not allowed for real batches"
          call messages_fatal(1)
#endif
        end if

        ! Collect all the points we need in a continuous array and apply the phase (if applicable)
        select case (psi1b%status())
        case (BATCH_DEVICE_PACKED)
          ! Not reachable from here, as this branch excludes BATCH_DEVICE_PACKED; kept as a guard
          call messages_not_implemented("nonlocal_pseudopotential_nlocal_force for BATCH_DEVICE_PACKED")

        case (BATCH_PACKED)
          if (allocated(this%projector_phases)) then ! Complex batches with phase
#ifdef R_TCOMPLEX
            !$omp parallel do private(ist, idir)
            do ip = 1, npoints
              do ist = 1, nst_linear
                psi(0, ist, ip) = psi1b%X(ff_pack)(ist, pmat%map(ip)) * this%projector_phases(ip, 1, imat, psi1b%ik)
                do idir = 1, ndim
                  psi(idir, ist, ip) = psi2b(idir)%X(ff_pack)(ist, pmat%map(ip)) * this%projector_phases(ip, 1, imat, psi1b%ik)
                end do
              end do
            end do
#endif
          else ! Real batches, or complex batches without phases
            !$omp parallel do private(ist, idir)
            do ip = 1, npoints
              do ist = 1, nst_linear
                psi(0, ist, ip) = psi1b%X(ff_pack)(ist, pmat%map(ip))
                do idir = 1, ndim
                  psi(idir, ist, ip) = psi2b(idir)%X(ff_pack)(ist, pmat%map(ip))
                end do
              end do
            end do
          end if

        case (BATCH_NOT_PACKED)
          ! Same gather as above, reading from the unpacked (point, state) storage
          if (allocated(this%projector_phases)) then ! Complex batches with phase
#ifdef R_TCOMPLEX
            !$omp parallel do private(ist, idir)
            do ip = 1, npoints
              do ist = 1, nst_linear
                psi(0, ist, ip) = psi1b%X(ff_linear)(pmat%map(ip), ist) * this%projector_phases(ip, 1, imat, psi1b%ik)
                do idir = 1, ndim
                  psi(idir, ist, ip) = psi2b(idir)%X(ff_linear)(pmat%map(ip), ist) * this%projector_phases(ip, 1, imat, psi1b%ik)
                end do
              end do
            end do
#endif
          else ! Real batches, or complex batches without phases
            !$omp parallel do private(ist, idir)
            do ip = 1, npoints
              do ist = 1, nst_linear
                psi(0, ist, ip) = psi1b%X(ff_linear)(pmat%map(ip), ist)
                do idir = 1, ndim
                  psi(idir, ist, ip) = psi2b(idir)%X(ff_linear)(pmat%map(ip), ist)
                end do
              end do
            end do
          end if

        case default
          message(1) = "Unknown batch status"
          call messages_fatal(1)
        end select

        call profiling_out(TOSTRING(X(PROJ_MAT_ELEM_GATHER)))

        ! Now matrix-multiply to calculate the projections. We can do all the matrix multiplications at once
        if (.not. pmat%is_cmplx) then
          ! Real projectors: psi is passed with (ndim + 1)*nreal rows, so for complex batches
          ! the real and imaginary parts are multiplied as separate real rows
          call blas_gemm('N', 'N', (ndim + 1)*nreal, nprojs, npoints, M_ONE, &
            psi(0, 1, 1), (ndim + 1)*nreal, pmat%dprojectors(1, 1), npoints, &
            M_ZERO, projs(0, 1, iprojection + 1), (ndim + 1)*nreal)

          call profiling_count_operations(nreal*(ndim + 1)*nprojs*M_TWO*npoints)
        else
#ifdef R_TCOMPLEX
          ! Complex projectors: the conjugate transpose of the projectors times the transpose of psi
          ! gives tmp_proj(iproj, direction and state), which is then reordered into the layout of projs
          SAFE_ALLOCATE(tmp_proj(1:nprojs, 1:nst_linear*(ndim + 1), 1))
          call blas_gemm('C', 'T', nprojs, (ndim + 1)*nst_linear, npoints, &
            M_z1, pmat%zprojectors(1, 1), npoints, psi(0, 1, 1), (ndim + 1)*nst_linear, &
            M_z0, tmp_proj(1,1,1), nprojs)
          do iproj = 1, nprojs
            do ist = 1, nst_linear
              do idir = 0, ndim
                projs(idir , ist, iprojection + iproj) = tmp_proj(iproj, (ist-1)*(ndim+1)+idir+1, 1)
              end do
            end do
          end do
          SAFE_DEALLOCATE_A(tmp_proj)

          call profiling_count_operations(nst_linear*(ndim + 1)*nprojs*(R_ADD+R_MUL)*npoints)
#endif
        end if

      else

        ! No local points for this matrix: its projections are zero before the reduction
        projs(0:ndim, 1:nst_linear, iprojection + 1:iprojection + nprojs) = 0.0_real64

      end if

      SAFE_DEALLOCATE_A(psi)

      iprojection = iprojection + nprojs

    end do


  else
    ! Device path (BATCH_DEVICE_PACKED).
    ! Mostly Cargo Cult programming, copying accelerated projector application
    ! from nonlocal_pseudopotential_nlocal_start, but modified to deal with
    ! the gradient of psi, psi2b, having ndim directions
    !
    ! The kernel variant is selected from the presence of phases and from the type of
    ! the first projector matrix; size is the number of batch entries the kernel works on
    if (allocated(this%projector_phases)) then
      if (this%projector_matrices(1)%is_cmplx) then
        call accel_kernel_start_call(zker_proj_bra_force_phase, 'projector.cl', 'zprojector_bra_force_phase',&
          flags = '-DRTYPE_COMPLEX')
        kernel => zker_proj_bra_force_phase
      else
        call accel_kernel_start_call(dker_proj_bra_force_phase, 'projector.cl', 'dprojector_bra_force_phase',&
          flags = '-DRTYPE_DOUBLE')
        kernel => dker_proj_bra_force_phase
      end if
      ASSERT(R_TYPE_VAL == TYPE_CMPLX)
      size = psi1b%pack_size(1)
    else
      if (this%projector_matrices(1)%is_cmplx) then
        call accel_kernel_start_call(zker_proj_bra_force, 'projector.cl', 'zprojector_bra_force',&
          flags = '-DRTYPE_COMPLEX')
        kernel => zker_proj_bra_force
        size = psi1b%pack_size(1)
      else
        call accel_kernel_start_call(dker_proj_bra_force, 'projector.cl', 'dprojector_bra_force',&
          flags = '-DRTYPE_DOUBLE')
        kernel => dker_proj_bra_force
        size = psi1b%pack_size_real(1)
      end if
    end if

    ! Host copy of the device projections; the state index is padded to pack_size(1)
    SAFE_ALLOCATE(projs(0:ndim, 1:psi1b%pack_size(1), 1:this%full_projection_size))
    projs = 0.0_real64

    if (this%max_npoints > 0) then
      call accel_create_buffer(buff_projs, ACCEL_MEM_READ_WRITE, R_TYPE_VAL, &
        (ndim+1)*this%full_projection_size*psi1b%pack_size(1))
      call profiling_in(TOSTRING(X(CL_PROJ_FORCE)))

      ! One kernel launch per slot: idir = 0 projects psi1b, idir >= 1 projects psi2b(idir)
      do idir = 0, ndim
        call accel_set_kernel_arg(kernel, 0, idir)
        call accel_set_kernel_arg(kernel, 1, ndim+1)
        call accel_set_kernel_arg(kernel, 2, int(size/psi1b%pack_size(1), int32))
        call accel_set_kernel_arg(kernel, 3, this%nprojector_matrices)
        call accel_set_kernel_arg(kernel, 4, this%buff_offsets)
        call accel_set_kernel_arg(kernel, 5, this%buff_matrices)
        call accel_set_kernel_arg(kernel, 6, this%buff_maps)
        if ( idir == 0 ) then
          call accel_set_kernel_arg(kernel, 7, psi1b%ff_device)
        else
          call accel_set_kernel_arg(kernel, 7, psi2b(idir)%ff_device)
        end if
        call accel_set_kernel_arg(kernel, 8, int(log2(size), int32))
        call accel_set_kernel_arg(kernel, 9, buff_projs)
        call accel_set_kernel_arg(kernel,10, int(log2(size), int32))

        if (allocated(this%projector_phases)) then
          call accel_set_kernel_arg(kernel, 11, this%buff_projector_phases)
          ! Note: we need to use this%nphase, as the kernel might be called with spiral=false, but
          !       the phases been built with spiralBC=true
          !> spiral BC ik should be the same for psi2b
          call accel_set_kernel_arg(kernel, 12, (psi1b%ik - st%d%kpt%start)*this%total_points*this%nphase)
        end if
        ! In case of CUDA we use an optimized kernel, in which the loop over npoints is broken
        ! further into chunks, in order to parallelize over the threads within a warp.
        ! Therefore we need to launch warp_size * size kernels. The size of each block needs to
        ! have multiples of warp_size as x-dimension.
        call accel_get_unfolded_size(size, grid_size, thread_block_size)
        ! Number of projectors per work group, limited by the kernel work-group size
        lnprojs = min(accel_kernel_workgroup_size(kernel)/thread_block_size, int(this%max_nprojs, int64))
        padnprojs = pad(this%max_nprojs, lnprojs)

        call accel_kernel_run(kernel, &
          (/grid_size, padnprojs, int(this%nprojector_matrices, int64)/), (/thread_block_size, lnprojs, 1_int64/))

        ! Only bookkeeping of the operation count, no numerical work in this loop
        do imat = 1, this%nprojector_matrices
          pmat => this%projector_matrices(imat)

          npoints = pmat%npoints
          nprojs = pmat%nprojs

          !! update number of operations for nphase !!
          call profiling_count_operations(nreal*nprojs*M_TWO*npoints + nst_linear*nprojs)
        end do
      end do
      ! Bring all projections back to the host; the rest of the routine runs on the CPU
      call accel_read_buffer(buff_projs,(ndim+1)*this%full_projection_size*psi1b%pack_size(1), projs)
      call accel_release_buffer(buff_projs)

      call profiling_out(TOSTRING(X(CL_PROJ_FORCE)))
    end if

  end if

  ! Sum the partial projections through the mesh reduction
  call profiling_in(TOSTRING(X(VNLPSI_MAT_ELEM_REDUCE)))
  call mesh%allreduce(projs)
  call profiling_out(TOSTRING(X(VNLPSI_MAT_ELEM_REDUCE)))

  ! Second pass over the projector matrices: apply the mixing matrices to the psi2b
  ! projections (idir >= 1 only) and accumulate the force of the corresponding atom
  iprojection = 0
  do imat = 1, this%nprojector_matrices
    pmat => this%projector_matrices(imat)

    npoints = pmat%npoints
    nprojs = pmat%nprojs

    iatom = this%projector_to_atom(imat)

    if (allocated(pmat%zmix)) then
#ifdef R_TCOMPLEX
      ! The linear state index is (ist-1)*st%d%dim + idim. The two entries of each state
      ! (presumably the spinor components, to be confirmed) are coupled by the four zmix blocks:
      !   new(1) = zmix(:, :, 1) * old(1) + zmix(:, :, 3) * old(2)
      !   new(2) = zmix(:, :, 2) * old(2) + zmix(:, :, 4) * old(1)
      SAFE_ALLOCATE(tmp_proj(1:nprojs, 1:psi1b%nst, 1:st%d%dim))

      do idir = 1, ndim
        do ist = 1, psi1b%nst
          tmp_proj(1:nprojs, ist, 1) = matmul(pmat%zmix(1:nprojs, 1:nprojs, 1), &
            projs(idir, (ist-1)*st%d%dim+1, iprojection + 1:iprojection + nprojs)) &
            + matmul(pmat%zmix(1:nprojs, 1:nprojs, 3), &
            projs(idir, (ist-1)*st%d%dim+2, iprojection + 1:iprojection + nprojs))
          tmp_proj(1:nprojs, ist, 2) = matmul(pmat%zmix(1:nprojs, 1:nprojs, 2), &
            projs(idir, (ist-1)*st%d%dim+2, iprojection + 1:iprojection + nprojs)) &
            + matmul(pmat%zmix(1:nprojs, 1:nprojs, 4), &
            projs(idir, (ist-1)*st%d%dim+1, iprojection + 1:iprojection + nprojs))
        end do

        ! Copy the mixed values back into projs
        do ist = 1, psi1b%nst
          do idim = 1, st%d%dim
            do iproj = 1, nprojs
              projs(idir, (ist-1)*st%d%dim+idim, iprojection + iproj) = tmp_proj(iproj, ist, idim)
            end do
          end do
        end do
      end do

      SAFE_DEALLOCATE_A(tmp_proj)
#else
      ! Complex projector matrix not allowed for real batches
      ASSERT(.false.)
#endif
    else if (allocated(pmat%dmix)) then

      ! Real mixing matrix, applied state by state
      do idir = 1, ndim
        do ist = 1, nst_linear
          projs(idir, ist, iprojection + 1:iprojection + nprojs) = &
            matmul(pmat%dmix(1:nprojs, 1:nprojs), projs(idir, ist, iprojection + 1:iprojection + nprojs))
        end do
      end do
    end if

    ! ff accumulates the contribution of this projector matrix before it is added to force
    SAFE_ALLOCATE(ff(1:ndim))

    ff(1:ndim) = 0.0_real64

    do ii = 1, psi1b%nst_linear
      ist = psi1b%linear_to_ist(ii)
      ! Skip states whose weighted occupation is negligible
      if (st%kweights(iqn)*abs(st%occ(ist, iqn)) <= M_EPSILON) cycle
      do iproj = 1, nprojs
        do idir = 1, ndim
          ! - 2 * weight * occupation * scal * volume element * Re(conjg(projs(0)) * projs(idir))
          ff(idir) = ff(idir) - M_TWO*st%kweights(iqn)*st%occ(ist, iqn)*pmat%scal(iproj)*mesh%volume_element* &
            R_REAL(R_CONJ(projs(0, ii, iprojection + iproj))*projs(idir, ii, iprojection + iproj))
        end do
      end do
    end do

    force(1:ndim, iatom) = force(1:ndim, iatom) + ff(1:ndim)

    call profiling_count_operations((R_ADD + 2*R_MUL)*nst_linear*ndim*nprojs)

    SAFE_DEALLOCATE_A(ff)

    iprojection = iprojection + nprojs

  end do

  SAFE_DEALLOCATE_A(projs)
  call profiling_out(TOSTRING(X(VNLPSI_MAT_ELEM)))

  POP_SUB(X(nonlocal_pseudopotential_force))
end subroutine X(nonlocal_pseudopotential_force)

! ---------------------------------------------------------------------------------------
!> @brief apply the commutator between the non-local potential and the position to the wave functions.
!!
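!! For every projector matrix the scaled projections of psib and of position(idir)*psib
!! (idir = 1..3) are computed, reduced over the mesh and mixed (if mixing matrices exist).
!! They are then expanded back onto the points of the projector, and commpsib(idir) is
!! incremented on those points by
!!   position(idir) * (back-projection of psib) - (back-projection of position(idir)*psib).
!! Batches with status BATCH_DEVICE_PACKED are handled by the contained device routine.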
!
subroutine X(nonlocal_pseudopotential_position_commutator)(this, mesh, std, spiral_bnd, psib, commpsib, async)
  class(nonlocal_pseudopotential_t), target, intent(in)    :: this !< projector matrices, phases and device buffers
  class(mesh_t),                         intent(in)    :: mesh        !< the mesh
  type(states_elec_dim_t),               intent(in)    :: std         !< the electronic states
  logical,                               intent(in)    :: spiral_bnd  !< flag for spiral boundary conditions
  type(wfs_elec_t),                      intent(in)    :: psib        !< original wave functions psi
  class(wfs_elec_t), target,             intent(inout) :: commpsib(:) !< resulting [v_nl, r]*psi
  logical, optional,                     intent(in)    :: async !< device path only: if .true., accel_finish is not called

  integer :: ist, ip, iproj, imat, nreal, iprojection, idir
  integer :: npoints, nprojs, nst
  integer, allocatable :: ind(:)
  R_TYPE :: aa, bb, cc, dd
  R_TYPE, allocatable :: projections(:, :, :)
  R_TYPE, allocatable :: psi(:, :, :), lpsi(:,:)
  type(projector_matrix_t), pointer :: pmat
  integer(int64) :: wgsize, size
  class(wfs_elec_t), pointer :: commpsib_(:)
#ifdef R_TCOMPLEX
  integer :: idim
  complex(real64) :: phase
  complex(real64), allocatable :: tmp_proj(:, :, :)
#endif

  ! Without a non-local potential commpsib is left untouched
  if (.not. this%has_non_local_potential) return

  ASSERT(this%apply_projector_matrices)

  PUSH_SUB(X(nonlocal_pseudopotential_position_commutator))
  call profiling_in(TOSTRING(X(COMMUTATOR)))

  ! Only packed batches and no spiral boundary conditions are supported
  ASSERT(psib%is_packed())
  ASSERT(.not. spiral_bnd)

  ! nreal counts real numbers per point of the batch: two per state for complex batches
  nst = psib%nst_linear
#ifdef R_TCOMPLEX
  nreal = 2*nst
#else
  nreal = nst
#endif
  ! this pointer assignment is needed as a workaround for gcc 12
  ! otherwise it throws errors related to openmp and polymorphic arrays
  commpsib_ => commpsib

  ! Device batches are handled entirely by the contained accelerator routine
  if (psib%status() == BATCH_DEVICE_PACKED) then
    call X(commutator_accel)()
    call profiling_out(TOSTRING(X(COMMUTATOR)))
    POP_SUB(X(nonlocal_pseudopotential_position_commutator))
    return
  end if

  ! Last index of projections: 0 for the projection of psi,
  ! 1..3 for the projection of position(1..3, :) times psi
  SAFE_ALLOCATE(projections(1:nst, 1:this%full_projection_size, 0:3))
  projections = M_ZERO

  ! ind(imat) is the offset of the first projector of matrix imat in the second index of
  ! projections; it is precomputed so that the loop over matrices below can run in parallel
  SAFE_ALLOCATE(ind(1:this%nprojector_matrices))

  iprojection = 0
  do imat = 1, this%nprojector_matrices
    pmat => this%projector_matrices(imat)
    npoints = pmat%npoints
    nprojs = pmat%nprojs
    ind(imat) = iprojection
    iprojection = iprojection + nprojs
    !    call profiling_count_operations(nprojs*(R_ADD + R_MUL)*npoints + nst*nprojs)
  end do

  ! First pass: projections of psi and of position*psi, one projector matrix per iteration
  !$omp parallel do private(imat, pmat, iprojection, npoints, nprojs, iproj, ist, aa, bb, cc, dd, ip, lpsi)
  do imat = 1, this%nprojector_matrices
    pmat => this%projector_matrices(imat)
    iprojection = ind(imat)
    npoints = pmat%npoints
    nprojs = pmat%nprojs

    ! Nothing to project if this matrix has no points here; its entries stay zero
    if (npoints == 0) cycle

    ! lpsi: values of psib on the points of this matrix, multiplied by the phase if phases exist
    SAFE_ALLOCATE(lpsi(1:npoints, 1:nst))
    if (.not. allocated(this%projector_phases)) then
      do ist = 1, nst
        !$omp simd
        do ip = 1, npoints
          lpsi(ip, ist) = psib%X(ff_pack)(ist, pmat%map(ip))
        end do
      end do
    else
#ifdef R_TCOMPLEX
      do ip = 1, npoints
        !$omp simd
        do ist = 1, nst
          lpsi(ip, ist) = psib%zff_pack(ist, pmat%map(ip)) &
            *this%projector_phases(ip, 1, imat, psib%ik)
        end do
      end do
#else
      ! Phases not allowed for real batches
      ASSERT(.false.)
#endif
    end if

    do iproj = 1, nprojs

      if (pmat%is_cmplx) then
#ifdef R_TCOMPLEX
        do ist = 1, nst
          ! aa: sum of conjg(projector)*psi, bb, cc, dd: the same weighted by position(1..3, ip)
          aa = 0.0_real64
          bb = 0.0_real64
          cc = 0.0_real64
          dd = 0.0_real64

          !$omp simd reduction(+:aa, bb, cc, dd)
          do ip = 1, npoints
            aa = aa + R_CONJ(pmat%zprojectors(ip, iproj))*lpsi(ip, ist)
            bb = bb + R_CONJ(pmat%zprojectors(ip, iproj))*pmat%position(1, ip)*lpsi(ip, ist)
            cc = cc + R_CONJ(pmat%zprojectors(ip, iproj))*pmat%position(2, ip)*lpsi(ip, ist)
            dd = dd + R_CONJ(pmat%zprojectors(ip, iproj))*pmat%position(3, ip)*lpsi(ip, ist)
          end do
          ! The scaling factor of the projector is applied here, before the reduction
          projections(ist, iprojection + iproj, 0) = pmat%scal(iproj)*aa
          projections(ist, iprojection + iproj, 1) = pmat%scal(iproj)*bb
          projections(ist, iprojection + iproj, 2) = pmat%scal(iproj)*cc
          projections(ist, iprojection + iproj, 3) = pmat%scal(iproj)*dd
        end do
#else
        ! Complex projection matrix not allowed for real batches
        ASSERT(.false.)
#endif
      else
        ! Real projectors: same sums as in the complex case, without conjugation
        do ist = 1, nst
          aa = 0.0_real64
          bb = 0.0_real64
          cc = 0.0_real64
          dd = 0.0_real64
          !$omp simd reduction(+:aa, bb, cc, dd)
          do ip = 1, npoints
            aa = aa + pmat%dprojectors(ip, iproj)*lpsi(ip, ist)
            bb = bb + pmat%dprojectors(ip, iproj)*pmat%position(1, ip)*lpsi(ip, ist)
            cc = cc + pmat%dprojectors(ip, iproj)*pmat%position(2, ip)*lpsi(ip, ist)
            dd = dd + pmat%dprojectors(ip, iproj)*pmat%position(3, ip)*lpsi(ip, ist)
          end do
          projections(ist, iprojection + iproj, 0) = pmat%scal(iproj)*aa
          projections(ist, iprojection + iproj, 1) = pmat%scal(iproj)*bb
          projections(ist, iprojection + iproj, 2) = pmat%scal(iproj)*cc
          projections(ist, iprojection + iproj, 3) = pmat%scal(iproj)*dd
        end do

      end if
    end do

    SAFE_DEALLOCATE_A(lpsi)
  end do
  !$omp end parallel do

  ! reduce the projections
  call profiling_in(TOSTRING(X(COMMUTATOR_REDUCE)))
  call mesh%allreduce(projections)
  call profiling_out(TOSTRING(X(COMMUTATOR_REDUCE)))

  ! Second pass: mix the projections, expand them back onto the points of each projector
  ! matrix and update commpsib
  iprojection = 0
  do imat = 1, this%nprojector_matrices
    pmat => this%projector_matrices(imat)

    npoints = pmat%npoints
    nprojs = pmat%nprojs

    if (allocated(pmat%zmix)) then
#ifdef R_TCOMPLEX
      ! The linear state index is (ist-1)*std%dim + idim. The two entries of each state
      ! (presumably the spinor components, to be confirmed) are coupled by the four zmix blocks:
      !   new(1) = zmix(:, :, 1) * old(1) + zmix(:, :, 3) * old(2)
      !   new(2) = zmix(:, :, 2) * old(2) + zmix(:, :, 4) * old(1)
      SAFE_ALLOCATE(tmp_proj(1:nprojs, 1:psib%nst, 1:std%dim))

      do idir = 0, 3
        do ist = 1, psib%nst
          tmp_proj(1:nprojs, ist, 1) = matmul(pmat%zmix(1:nprojs, 1:nprojs, 1), &
            projections((ist-1)*std%dim+1, iprojection + 1:iprojection + nprojs, idir)) &
            + matmul(pmat%zmix(1:nprojs, 1:nprojs, 3), &
            projections((ist-1)*std%dim+2, iprojection + 1:iprojection + nprojs, idir))
          tmp_proj(1:nprojs, ist, 2) = matmul(pmat%zmix(1:nprojs, 1:nprojs, 2), &
            projections((ist-1)*std%dim+2, iprojection + 1:iprojection + nprojs, idir)) &
            + matmul(pmat%zmix(1:nprojs, 1:nprojs, 4), &
            projections((ist-1)*std%dim+1, iprojection + 1:iprojection + nprojs, idir))
        end do

        ! Copy the mixed values back into projections
        do ist = 1, psib%nst
          do idim = 1, std%dim
            do iproj = 1, nprojs
              projections((ist-1)*std%dim+idim, iprojection + iproj, idir) = tmp_proj(iproj, ist, idim)
            end do
          end do
        end do
      end do

      SAFE_DEALLOCATE_A(tmp_proj)
#else
      ! Complex projection matrix not allowed for real batches
      ASSERT(.false.)
#endif
    else if (allocated(pmat%dmix)) then
      ! Real mixing matrix, applied state by state to all four components
      do idir = 0, 3
        do ist = 1, nst
          projections(ist, iprojection + 1:iprojection + nprojs, idir) = &
            matmul(pmat%dmix(1:nprojs, 1:nprojs), projections(ist, iprojection + 1:iprojection + nprojs, idir))
        end do
      end do
    end if

    if (npoints /=  0) then

      ! psi(:, :, idir): projections of component idir expanded back onto the points of this matrix
      SAFE_ALLOCATE(psi(1:nst, 1:npoints, 0:3))

      ! Matrix-multiply again.
      ! the line below does: psi = matmul(projection, transpose(pmat%projectors))

      if (.not. pmat%is_cmplx) then
        ! Real projectors: nreal rows are used, so complex values are multiplied as pairs of reals
        do idir = 0, 3
          call blas_gemm('N', 'T', nreal, npoints, nprojs, &
            M_ONE, projections(1, iprojection + 1, idir), nreal, pmat%dprojectors(1, 1), npoints, &
            M_ZERO, psi(1, 1, idir), nreal)
        end do
        call profiling_count_operations(nreal*nprojs*M_TWO*npoints*4)

      else
#ifdef R_TCOMPLEX
        do idir = 0, 3
          call blas_gemm('N', 'T', nst, npoints, nprojs, &
            M_z1, projections(1, iprojection + 1, idir), nst, pmat%zprojectors(1, 1), npoints, &
            M_z0, psi(1, 1, idir), nst)
        end do
#endif
        call profiling_count_operations(nst*nprojs*(R_ADD+R_MUL)*npoints*4)
      end if

      if (allocated(this%projector_phases)) then
#ifdef R_TCOMPLEX
        ! Multiply by the complex conjugate of the phase that was applied when gathering lpsi
        !$omp parallel private(ip, ist, phase)
        do idir = 0, 3
          !$omp do
          do ip = 1, npoints
            phase = conjg(this%projector_phases(ip, 1, imat, psib%ik))
            !$omp simd
            do ist = 1, nst
              psi(ist, ip, idir) = phase*psi(ist, ip, idir)
            end do
          end do
          !$omp end do nowait
        end do
        !$omp end parallel
        call profiling_count_operations(nst*npoints*3*R_MUL)
#else
        ! Phases not allowed for real batches
        ASSERT(.false.)
#endif
      end if

      ! Update commpsib on the points of this matrix. The parallel region is switched off when
      ! projector_self_overlap is set (presumably because map(ip) can then repeat; to be confirmed)
      !$omp parallel private(ip, ist)  if (.not. this%projector_self_overlap)
      do idir = 1, 3
        !$omp do
        do ip = 1, npoints
          do ist = 1, nst
            commpsib_(idir)%X(ff_pack)(ist, pmat%map(ip)) = commpsib_(idir)%X(ff_pack)(ist, pmat%map(ip)) &
              - psi(ist, ip, idir) + pmat%position(idir, ip)*psi(ist, ip, 0)
          end do
        end do
        !$omp end do nowait
      end do
      !$omp end parallel

      call profiling_count_operations(nst*npoints*3*(2*R_ADD+R_MUL))
    end if

    SAFE_DEALLOCATE_A(psi)

    iprojection = iprojection + nprojs
  end do

  SAFE_DEALLOCATE_A(ind)

  call profiling_out(TOSTRING(X(COMMUTATOR)))
  POP_SUB(X(nonlocal_pseudopotential_position_commutator))

contains

  !> Device implementation of the position commutator for BATCH_DEVICE_PACKED batches.
  !!
  !! Three stages are run on the accelerator:
  !!  - a bra kernel fills buff_proj with the projections of psib,
  !!  - if this%projector_mix is set, a mix kernel writes the mixed projections into a
  !!    separately allocated buffer (otherwise buff_proj is used directly),
  !!  - a ket kernel, launched once per self-overlap region of every projector matrix,
  !!    updates the device data of commpsib(1:3).
  !! When the mesh is parallel in domains, the projections are copied to the host,
  !! reduced and written back between the first and the second stage.
  !!
  !! All data (this, mesh, std, psib, commpsib, async, size, wgsize, imat) come from the host routine.
  subroutine X(commutator_accel)()
    type(accel_kernel_t), target, save :: dker_commutator_bra, dker_commutator_bra_phase, dker_mix
    type(accel_kernel_t), target, save :: dker_commutator_ket, dker_commutator_ket_phase
    type(accel_kernel_t), target, save :: zker_commutator_bra, zker_commutator_bra_phase, zker_mix
    type(accel_kernel_t), target, save :: zker_commutator_ket, zker_commutator_ket_phase
    type(accel_kernel_t), pointer :: kernel, ker_mix
    type(accel_mem_t), target :: buff_proj
    type(accel_mem_t), pointer :: buff_proj_copy
    integer(int64) :: padnprojs, lnprojs, size2
    integer :: iregion, nregions_self_overlap, iregion_self_overlap
    R_TYPE, allocatable :: proj(:)

    ! --- Stage 1: bra kernel -------------------------------------------------------------
    ! The kernel variant depends on the presence of phases and on the type of the first
    ! projector matrix; size is the number of batch entries handled by the kernel
    if (allocated(this%projector_phases)) then
      ASSERT(R_TYPE_VAL == TYPE_CMPLX)
      if (this%projector_matrices(1)%is_cmplx) then
        call accel_kernel_start_call(zker_commutator_bra_phase, 'projector_commutator.cl', 'zprojector_commutator_bra_phase',&
          flags = '-DRTYPE_COMPLEX')
        kernel => zker_commutator_bra_phase
      else
        call accel_kernel_start_call(dker_commutator_bra_phase, 'projector_commutator.cl', 'dprojector_commutator_bra_phase',&
          flags = '-DRTYPE_DOUBLE')
        kernel => dker_commutator_bra_phase
      end if
      size = psib%pack_size(1)
    else
      if (this%projector_matrices(1)%is_cmplx) then
        call accel_kernel_start_call(zker_commutator_bra, 'projector_commutator.cl', 'zprojector_commutator_bra',&
          flags = '-DRTYPE_COMPLEX')
        size = psib%pack_size(1)
        kernel => zker_commutator_bra
      else
        call accel_kernel_start_call(dker_commutator_bra, 'projector_commutator.cl', 'dprojector_commutator_bra',&
          flags = '-DRTYPE_DOUBLE')
        size = psib%pack_size_real(1)
        kernel => dker_commutator_bra
      end if
    end if

    ! Four components per projection and batch entry
    call accel_create_buffer(buff_proj, ACCEL_MEM_READ_WRITE, R_TYPE_VAL, 4*this%full_projection_size*psib%pack_size(1))

    call accel_set_kernel_arg(kernel,  0, this%nprojector_matrices)
    call accel_set_kernel_arg(kernel,  1, this%buff_offsets)
    call accel_set_kernel_arg(kernel,  2, this%buff_matrices)
    call accel_set_kernel_arg(kernel,  3, this%buff_maps)
    call accel_set_kernel_arg(kernel,  4, this%buff_scals)
    call accel_set_kernel_arg(kernel,  5, this%buff_position)
    call accel_set_kernel_arg(kernel,  6, psib%ff_device)
    call accel_set_kernel_arg(kernel,  7, log2(int(size, int32)))
    call accel_set_kernel_arg(kernel,  8, buff_proj)
    call accel_set_kernel_arg(kernel,  9, log2(int(size, int32)))

    if (allocated(this%projector_phases)) then
      call accel_set_kernel_arg(kernel, 10, this%buff_projector_phases)
      call accel_set_kernel_arg(kernel, 11, (psib%ik - std%kpt%start)*this%total_points)
    end if

    ! Number of projectors per work group, limited by the work-group size of the bra kernel
    lnprojs = min(accel_kernel_workgroup_size(kernel)/size, int(this%max_nprojs, int64))
    padnprojs = pad(this%max_nprojs, lnprojs)

    call accel_kernel_run(kernel, (/size, padnprojs, int(this%nprojector_matrices, int64)/), (/size, lnprojs, 1_int64/))

    ! With domain parallelization the projections are reduced on the host and sent back
    if (mesh%parallel_in_domains) then
      SAFE_ALLOCATE(proj(1:4*this%full_projection_size*psib%pack_size(1)))
      call accel_read_buffer(buff_proj, 4*this%full_projection_size*psib%pack_size(1), proj)
      call mesh%allreduce(proj)
      call accel_write_buffer(buff_proj, 4*this%full_projection_size*psib%pack_size(1), proj)
      SAFE_DEALLOCATE_A(proj)
    end if

    ! --- Stage 2: optional mixing --------------------------------------------------------
    if (this%projector_mix) then

      ! The mixed projections go to a separate buffer, whose descriptor is owned by this routine
      SAFE_ALLOCATE(buff_proj_copy)

      if (allocated(this%projector_matrices(1)%zmix)) then
        call accel_kernel_start_call(zker_mix, 'projector_commutator.cl', 'zprojector_mix_commutator', flags = '-DRTYPE_COMPLEX')
        ker_mix => zker_mix
        size = psib%pack_size(1)
        ! The complex mix kernel is launched over half of the batch entries
        size2 = size/2
        call accel_create_buffer(buff_proj_copy, ACCEL_MEM_READ_WRITE, TYPE_CMPLX, &
          4*this%full_projection_size*size)
      else
        call accel_kernel_start_call(dker_mix, 'projector_commutator.cl', 'dprojector_mix_commutator', flags = '-DRTYPE_DOUBLE')
        ker_mix => dker_mix
        size = psib%pack_size_real(1)
        size2 = size
        call accel_create_buffer(buff_proj_copy, ACCEL_MEM_READ_WRITE, TYPE_FLOAT, &
          4*this%full_projection_size*size)
      end if
      call accel_set_kernel_arg(ker_mix, 0, this%nprojector_matrices)
      call accel_set_kernel_arg(ker_mix, 1, this%buff_offsets)
      call accel_set_kernel_arg(ker_mix, 2, this%buff_mix)
      call accel_set_kernel_arg(ker_mix, 3, buff_proj)
      call accel_set_kernel_arg(ker_mix, 4, log2(int(size, int32)))
      call accel_set_kernel_arg(ker_mix, 5, buff_proj_copy)

      ! The work-group limit must be the one of the kernel that is launched (ker_mix);
      ! at this point kernel still refers to the bra kernel of stage 1
      lnprojs = min(accel_kernel_workgroup_size(ker_mix)/size2, int(this%max_nprojs, int64))
      padnprojs = pad(this%max_nprojs, lnprojs)

      call accel_kernel_run(ker_mix, (/size2, padnprojs, int(this%nprojector_matrices, int64)/), (/size2, lnprojs, 1_int64/))

    else

      ! No mixing: the ket kernel reads the projections directly from buff_proj
      buff_proj_copy => buff_proj

    end if

    ! --- Stage 3: ket kernel -------------------------------------------------------------
    if (allocated(this%projector_phases)) then
      ASSERT(R_TYPE_VAL == TYPE_CMPLX)

      if (this%projector_matrices(1)%is_cmplx) then
        call accel_kernel_start_call(zker_commutator_ket_phase, 'projector_commutator.cl', 'zprojector_commutator_ket_phase',&
          flags = '-DRTYPE_COMPLEX')
        kernel => zker_commutator_ket_phase
      else
        call accel_kernel_start_call(dker_commutator_ket_phase, 'projector_commutator.cl', 'dprojector_commutator_ket_phase',&
          flags = '-DRTYPE_DOUBLE')
        kernel => dker_commutator_ket_phase
      end if
      size = psib%pack_size(1)
    else
      if (this%projector_matrices(1)%is_cmplx) then
        call accel_kernel_start_call(zker_commutator_ket, 'projector_commutator.cl', 'zprojector_commutator_ket',&
          flags = '-DRTYPE_COMPLEX')
        kernel => zker_commutator_ket
        size = psib%pack_size(1)
      else
        call accel_kernel_start_call(dker_commutator_ket, 'projector_commutator.cl', 'dprojector_commutator_ket',&
          flags = '-DRTYPE_DOUBLE')
        kernel => dker_commutator_ket
        size = psib%pack_size_real(1)
      end if
    end if

    ! One launch per self-overlap region of every projector matrix; arguments 12 and 13
    ! pass the zero-based bounds of the region taken from the regions array of the matrix
    do iregion = 1, this%nregions

      do imat = this%regions(iregion), this%regions(iregion+1)-1

        nregions_self_overlap = this%projector_matrices(imat)%nregions

        do iregion_self_overlap = 1, nregions_self_overlap

          call accel_set_kernel_arg(kernel,  0, this%nprojector_matrices)
          call accel_set_kernel_arg(kernel,  1, imat - 1)
          call accel_set_kernel_arg(kernel,  2, this%buff_offsets)
          call accel_set_kernel_arg(kernel,  3, this%buff_matrices)
          call accel_set_kernel_arg(kernel,  4, this%buff_maps)
          call accel_set_kernel_arg(kernel,  5, this%buff_position)
          call accel_set_kernel_arg(kernel,  6, buff_proj_copy)
          call accel_set_kernel_arg(kernel,  7, log2(int(size, int32)))
          call accel_set_kernel_arg(kernel,  8, commpsib(1)%ff_device)
          call accel_set_kernel_arg(kernel,  9, commpsib(2)%ff_device)
          call accel_set_kernel_arg(kernel, 10, commpsib(3)%ff_device)
          call accel_set_kernel_arg(kernel, 11, log2(int(size, int32)))

          call accel_set_kernel_arg(kernel, 12, this%projector_matrices(imat)%regions(iregion_self_overlap) - 1)
          call accel_set_kernel_arg(kernel, 13, this%projector_matrices(imat)%regions(iregion_self_overlap+1) - 1)

          if (allocated(this%projector_phases)) then
            call accel_set_kernel_arg(kernel, 14, this%buff_projector_phases)
            call accel_set_kernel_arg(kernel, 15, (psib%ik - std%kpt%start)*this%total_points)
          end if

          wgsize = accel_kernel_workgroup_size(kernel)/size

          call accel_kernel_run(kernel, &
            (/size, pad(this%max_npoints, wgsize), 1_int64 /), &
            (/size, wgsize, 1_int64/))

        end do

      end do

    end do

    ! Release the buffer of the mixed projections and free its descriptor. The descriptor
    ! was allocated in stage 2, so it has to be deallocated here; allocating it again
    ! would leak both the old and the new descriptor.
    if (this%projector_mix) then
      call accel_release_buffer(buff_proj_copy)
      SAFE_DEALLOCATE_P(buff_proj_copy)
    end if

    ! Unless the caller asked for asynchronous execution, wait for the device to finish
    if (.not. optional_default(async,.false.)) call accel_finish()

    call accel_release_buffer(buff_proj)

  end subroutine X(commutator_accel)

end subroutine X(nonlocal_pseudopotential_position_commutator)

! ---------------------------------------------------------------------------------------
!> @brief Accumulates to commpsib the result of x V_{nl} | psib >
!!
!! The routine returns immediately if the Hamiltonian has no non-local potential.
!! Batches stored on the device are delegated to the internal accelerator routine.
!! On the host the work is done in three stages:
!!  1. project the packed wave functions onto the projectors of every projector
!!     matrix (including the scaling factors and, if allocated, the projector phases),
!!  2. reduce the projections over the mesh and apply the optional mixing matrices,
!!  3. multiply back by the projectors and add the result, weighted by
!!     pmat%position(idir, ip), to commpsib(idir) for idir = 1, 2, 3.
!!
!! Spiral boundary conditions are not supported here (this is asserted).
!
subroutine X(nonlocal_pseudopotential_r_vnlocal)(this, mesh, std, spiral_bnd, psib, commpsib)
  class(nonlocal_pseudopotential_t), target, intent(in)   :: this
  class(mesh_t),                         intent(in)    :: mesh        !< the mesh
  type(states_elec_dim_t),               intent(in)    :: std         !< dimensions of the states
  logical,                               intent(in)    :: spiral_bnd  !< spiral boundary flag, must be .false.
  type(wfs_elec_t),                      intent(in)    :: psib        !< packed input wave functions
  class(wfs_elec_t), target,             intent(inout) :: commpsib(1:3) !< accumulated result, one batch per idir

  integer :: ist, ip, iproj, imat, nreal, iprojection, idir
  integer :: npoints, nprojs, nst
  integer, allocatable :: ind(:)
  R_TYPE :: aa
  R_TYPE, allocatable :: projections(:, :)
  R_TYPE, allocatable :: psi(:, :), lpsi(:,:)
  type(projector_matrix_t), pointer :: pmat
  ! wgsize and size are only used by the internal accelerator routine (host association)
  integer(int64) :: wgsize, size
  class(wfs_elec_t), pointer :: commpsib_(:)
#ifdef R_TCOMPLEX
  integer :: idim
  complex(real64) :: phase
  complex(real64), allocatable :: tmp_proj(:, :, :)
#endif

  if (.not. this%has_non_local_potential) return

  ASSERT(this%apply_projector_matrices)

  PUSH_SUB(X(nonlocal_pseudopotential_r_vnlocal))
  call profiling_in(TOSTRING(X(R_VNL)))

  ASSERT(psib%is_packed())
  ASSERT(.not. spiral_bnd)

  nst = psib%nst_linear
  ! nreal is the leading dimension seen by the real gemm: a complex number counts as two reals
#ifdef R_TCOMPLEX
  nreal = 2*nst
#else
  nreal = nst
#endif
  ! this pointer assignment is needed as a workaround for gcc 12
  ! otherwise it throws errors related to openmp and polymorphic arrays
  commpsib_ => commpsib

  ! device batches are handled entirely by the accelerator implementation
  if (psib%status() == BATCH_DEVICE_PACKED) then
    call X(commutator_accel)()
    call profiling_out(TOSTRING(X(R_VNL)))
    POP_SUB(X(nonlocal_pseudopotential_r_vnlocal))
    return
  end if

  SAFE_ALLOCATE(projections(1:nst, 1:this%full_projection_size))
  projections = M_ZERO

  SAFE_ALLOCATE(ind(1:this%nprojector_matrices))

  ! ind(imat) is the offset of the first projection of matrix imat inside projections;
  ! it is computed beforehand so that the loop below can run in parallel over imat
  iprojection = 0
  do imat = 1, this%nprojector_matrices
    ind(imat) = iprojection
    iprojection = iprojection + this%projector_matrices(imat)%nprojs
  end do

  !$omp parallel do private(imat, pmat, iprojection, npoints, nprojs, iproj, ist, aa, ip, lpsi)
  do imat = 1, this%nprojector_matrices
    pmat => this%projector_matrices(imat)
    iprojection = ind(imat)
    npoints = pmat%npoints
    nprojs = pmat%nprojs

    if (npoints == 0) cycle

    ! gather the wave functions on the points of this projector (transposed layout)
    SAFE_ALLOCATE(lpsi(1:npoints, 1:nst))
    if (.not. allocated(this%projector_phases)) then
      do ist = 1, nst
        !$omp simd
        do ip = 1, npoints
          lpsi(ip, ist) = psib%X(ff_pack)(ist, pmat%map(ip))
        end do
      end do
    else
#ifdef R_TCOMPLEX
      do ip = 1, npoints
        !$omp simd
        do ist = 1, nst
          lpsi(ip, ist) = psib%zff_pack(ist, pmat%map(ip)) &
            *this%projector_phases(ip, 1, imat, psib%ik)
        end do
      end do
#else
      ! Phases not allowed for real batches
      ASSERT(.false.)
#endif
    end if

    ! projections = scal * < projector | psi >
    do iproj = 1, nprojs

      if (pmat%is_cmplx) then
#ifdef R_TCOMPLEX
        do ist = 1, nst
          aa = 0.0_real64

          !$omp simd reduction(+:aa)
          do ip = 1, npoints
            aa = aa + R_CONJ(pmat%zprojectors(ip, iproj))*lpsi(ip, ist)
          end do
          projections(ist, iprojection + iproj) = pmat%scal(iproj)*aa
        end do
#else
        ! Complex projection matrix not allowed for real batches
        ASSERT(.false.)
#endif
      else
        do ist = 1, nst
          aa = 0.0_real64
          !$omp simd reduction(+:aa)
          do ip = 1, npoints
            aa = aa + pmat%dprojectors(ip, iproj)*lpsi(ip, ist)
          end do
          projections(ist, iprojection + iproj) = pmat%scal(iproj)*aa
        end do

      end if
    end do

    SAFE_DEALLOCATE_A(lpsi)
  end do
  !$omp end parallel do

  ! reduce the projections
  call profiling_in(TOSTRING(X(R_VNL_REDUCE)))
  call mesh%allreduce(projections)
  call profiling_out(TOSTRING(X(R_VNL_REDUCE)))

  iprojection = 0
  do imat = 1, this%nprojector_matrices
    pmat => this%projector_matrices(imat)

    npoints = pmat%npoints
    nprojs = pmat%nprojs

    if (allocated(pmat%zmix)) then
#ifdef R_TCOMPLEX
      ! the two components of each state (rows (ist-1)*dim+1 and (ist-1)*dim+2) are mixed
      ! with the four blocks of zmix; a temporary is needed as both components are read
      SAFE_ALLOCATE(tmp_proj(1:nprojs, 1:psib%nst, 1:std%dim))

      do ist = 1, psib%nst
        tmp_proj(1:nprojs, ist, 1) = matmul(pmat%zmix(1:nprojs, 1:nprojs, 1), &
          projections((ist-1)*std%dim+1, iprojection + 1:iprojection + nprojs)) &
          + matmul(pmat%zmix(1:nprojs, 1:nprojs, 3), &
          projections((ist-1)*std%dim+2, iprojection + 1:iprojection + nprojs))
        tmp_proj(1:nprojs, ist, 2) = matmul(pmat%zmix(1:nprojs, 1:nprojs, 2), &
          projections((ist-1)*std%dim+2, iprojection + 1:iprojection + nprojs)) &
          + matmul(pmat%zmix(1:nprojs, 1:nprojs, 4), &
          projections((ist-1)*std%dim+1, iprojection + 1:iprojection + nprojs))
      end do

      do ist = 1, psib%nst
        do idim = 1, std%dim
          do iproj = 1, nprojs
            projections((ist-1)*std%dim+idim, iprojection + iproj) = tmp_proj(iproj, ist, idim)
          end do
        end do
      end do

      SAFE_DEALLOCATE_A(tmp_proj)
#else
      ! Complex projection matrix not allowed for real batches
      ASSERT(.false.)
#endif
    else if (allocated(pmat%dmix)) then
      do ist = 1, nst
        projections(ist, iprojection + 1:iprojection + nprojs) = &
          matmul(pmat%dmix(1:nprojs, 1:nprojs), projections(ist, iprojection + 1:iprojection + nprojs))
      end do
    end if

    if (npoints /=  0) then

      SAFE_ALLOCATE(psi(1:nst, 1:npoints))

      ! Matrix-multiply again.
      ! the line below does: psi = matmul(projection, transpose(pmat%projectors))

      if (.not. pmat%is_cmplx) then
        ! real projectors: complex data are treated as nreal real rows
        call blas_gemm('N', 'T', nreal, npoints, nprojs, &
          M_ONE, projections(1, iprojection + 1), nreal, pmat%dprojectors(1, 1), npoints, &
          M_ZERO, psi(1, 1), nreal)
        call profiling_count_operations(nreal*nprojs*M_TWO*npoints*4)

      else
#ifdef R_TCOMPLEX
        call blas_gemm('N', 'T', nst, npoints, nprojs, &
          M_z1, projections(1, iprojection + 1), nst, pmat%zprojectors(1, 1), npoints, &
          M_z0, psi(1, 1), nst)
#else
        ! Complex projection matrix not allowed for real batches
        ! (without this guard psi would be used uninitialised below)
        ASSERT(.false.)
#endif
        call profiling_count_operations(nst*nprojs*(R_ADD+R_MUL)*npoints*4)
      end if

      if (allocated(this%projector_phases)) then
#ifdef R_TCOMPLEX
        ! undo the phase applied before the projection (complex conjugate)
        !$omp parallel do private(ip, ist, phase)
        do ip = 1, npoints
          phase = conjg(this%projector_phases(ip, 1, imat, psib%ik))
          !$omp simd
          do ist = 1, nst
            psi(ist, ip) = phase*psi(ist, ip)
          end do
        end do
        !$omp end parallel do
        call profiling_count_operations(nst*npoints*3*R_MUL)
#else
        ! Phases not allowed for real batches
        ASSERT(.false.)
#endif
      end if

      ! the parallel region is disabled when projector_self_overlap is set
      ! NOTE(review): presumably because map(ip) can then repeat within one matrix,
      ! which would make the scatter-add below race - confirm
      !$omp parallel private(ip, ist)  if (.not. this%projector_self_overlap)
      do idir = 1, 3
        !$omp do
        do ip = 1, npoints
          do ist = 1, nst
            commpsib_(idir)%X(ff_pack)(ist, pmat%map(ip)) = &
              commpsib_(idir)%X(ff_pack)(ist, pmat%map(ip)) + pmat%position(idir, ip)*psi(ist, ip)
          end do
        end do
        !$omp end do nowait
      end do
      !$omp end parallel

      call profiling_count_operations(nst*npoints*3*(2*R_ADD+R_MUL))
    end if

    SAFE_DEALLOCATE_A(psi)

    iprojection = iprojection + nprojs
  end do

  SAFE_DEALLOCATE_A(ind)
  ! release the projections through the same tracked mechanism used to allocate them
  SAFE_DEALLOCATE_A(projections)

  call profiling_out(TOSTRING(X(R_VNL)))
  POP_SUB(X(nonlocal_pseudopotential_r_vnlocal))

contains

  !> @brief Accelerator implementation of the accumulation of x V_{nl} | psib >
  !!
  !! Works on the host-associated arguments of the enclosing routine:
  !!  1. a "bra" kernel fills buff_proj with the projections of psib,
  !!  2. if the mesh is parallel in domains, the projections are copied to the
  !!     host, reduced over the mesh and written back,
  !!  3. if this%projector_mix is set, a mixing kernel is run with buff_proj and a
  !!     separately allocated buff_proj_copy as arguments; otherwise buff_proj_copy
  !!     simply points to buff_proj,
  !!  4. a "ket" kernel is launched once per projector matrix and per self-overlap
  !!     region of that matrix, with the three commpsib device buffers as arguments.
  !!
  !! The kernel objects carry the save attribute and so persist between calls.
  subroutine X(commutator_accel)()
    type(accel_kernel_t), target, save :: dker_commutator_bra, dker_commutator_bra_phase, dker_mix
    type(accel_kernel_t), target, save :: dker_commutator_ket, dker_commutator_ket_phase
    type(accel_kernel_t), target, save :: zker_commutator_bra, zker_commutator_bra_phase, zker_mix
    type(accel_kernel_t), target, save :: zker_commutator_ket, zker_commutator_ket_phase
    type(accel_kernel_t), pointer :: kernel, ker_mix
    type(accel_mem_t), target :: buff_proj
    type(accel_mem_t), pointer :: buff_proj_copy
    integer(int64) :: padnprojs, lnprojs, size2
    integer :: iregion, nregions_self_overlap, iregion_self_overlap
    R_TYPE, allocatable :: proj(:)


    ! ---- select the bra kernel (complex or real projectors, with or without phases)
    if (allocated(this%projector_phases)) then
      ASSERT(R_TYPE_VAL == TYPE_CMPLX)
      if (this%projector_matrices(1)%is_cmplx) then
        call accel_kernel_start_call(zker_commutator_bra_phase, 'projector.cl', 'zprojector_r_vnl_bra_phase',&
          flags = '-DRTYPE_COMPLEX')
        kernel => zker_commutator_bra_phase
      else
        call accel_kernel_start_call(dker_commutator_bra_phase, 'projector.cl', 'dprojector_r_vnl_bra_phase',&
          flags = '-DRTYPE_DOUBLE')
        kernel => dker_commutator_bra_phase
      end if
      size = psib%pack_size(1)
    else
      if (this%projector_matrices(1)%is_cmplx) then
        call accel_kernel_start_call(zker_commutator_bra, 'projector.cl', 'zprojector_r_vnl_bra',&
          flags = '-DRTYPE_COMPLEX')
        size = psib%pack_size(1)
        kernel => zker_commutator_bra
      else
        ! the kernel name follows the same pattern as the other r_vnl kernels
        call accel_kernel_start_call(dker_commutator_bra, 'projector.cl', 'dprojector_r_vnl_bra',&
          flags = '-DRTYPE_DOUBLE')
        size = psib%pack_size_real(1)
        kernel => dker_commutator_bra
      end if
    end if

    call accel_create_buffer(buff_proj, ACCEL_MEM_READ_WRITE, R_TYPE_VAL, this%full_projection_size*psib%pack_size(1))

    call accel_set_kernel_arg(kernel,  0, this%nprojector_matrices)
    call accel_set_kernel_arg(kernel,  1, this%buff_offsets)
    call accel_set_kernel_arg(kernel,  2, this%buff_matrices)
    call accel_set_kernel_arg(kernel,  3, this%buff_maps)
    call accel_set_kernel_arg(kernel,  4, this%buff_scals)
    call accel_set_kernel_arg(kernel,  5, psib%ff_device)
    call accel_set_kernel_arg(kernel,  6, log2(int(size, int32)))
    call accel_set_kernel_arg(kernel,  7, buff_proj)
    call accel_set_kernel_arg(kernel,  8, log2(int(size, int32)))

    if (allocated(this%projector_phases)) then
      call accel_set_kernel_arg(kernel, 9, this%buff_projector_phases)
      call accel_set_kernel_arg(kernel, 10, (psib%ik - std%kpt%start)*this%total_points)
    end if

    ! local size in the projector dimension: as many projectors as fit in one work group
    lnprojs = min(accel_kernel_workgroup_size(kernel)/size, int(this%max_nprojs, int64))
    padnprojs = pad(this%max_nprojs, lnprojs)

    call accel_kernel_run(kernel, (/size, padnprojs, int(this%nprojector_matrices, int64)/), (/size, lnprojs, 1_int64/))

    call accel_finish()

    ! ---- reduce the projections over the mesh through a host copy
    if (mesh%parallel_in_domains) then
      SAFE_ALLOCATE(proj(1:this%full_projection_size*psib%pack_size(1)))
      call accel_read_buffer(buff_proj, this%full_projection_size*psib%pack_size(1), proj)
      call mesh%allreduce(proj)
      call accel_write_buffer(buff_proj, this%full_projection_size*psib%pack_size(1), proj)
      SAFE_DEALLOCATE_A(proj)
    end if

    ! ---- optional mixing of the projections
    if (this%projector_mix) then

      SAFE_ALLOCATE(buff_proj_copy)

      if (allocated(this%projector_matrices(1)%zmix)) then
        call accel_kernel_start_call(zker_mix, 'projector.cl', 'zprojector_mix', flags = '-DRTYPE_COMPLEX')
        ker_mix => zker_mix
        size = psib%pack_size(1)
        ! In the case of spinors, we need to distribute states in pairs, else access random memory
        size2 = size/2
        call accel_create_buffer(buff_proj_copy, ACCEL_MEM_READ_WRITE, TYPE_CMPLX, &
          this%full_projection_size*size)
      else
        call accel_kernel_start_call(dker_mix, 'projector.cl', 'dprojector_mix', flags = '-DRTYPE_DOUBLE')
        ker_mix => dker_mix
        size = psib%pack_size_real(1)
        size2 = size
        call accel_create_buffer(buff_proj_copy, ACCEL_MEM_READ_WRITE, TYPE_FLOAT, &
          this%full_projection_size*size)
      end if
      call accel_set_kernel_arg(ker_mix, 0, this%nprojector_matrices)
      call accel_set_kernel_arg(ker_mix, 1, this%buff_offsets)
      call accel_set_kernel_arg(ker_mix, 2, this%buff_mix)
      call accel_set_kernel_arg(ker_mix, 3, buff_proj)
      call accel_set_kernel_arg(ker_mix, 4, log2(int(size, int32)))
      call accel_set_kernel_arg(ker_mix, 5, buff_proj_copy)

      ! the work-group limit must be queried for the kernel that is launched here
      ! (ker_mix), not for the bra kernel that kernel still points to
      lnprojs = min(accel_kernel_workgroup_size(ker_mix)/size2, int(this%max_nprojs, int64))
      padnprojs = pad(this%max_nprojs, lnprojs)

      call accel_kernel_run(ker_mix, (/size2, padnprojs, int(this%nprojector_matrices, int64)/), (/size2, lnprojs, 1_int64/))

      call accel_finish()

    else

      ! no mixing: the ket kernel reads the projections directly
      buff_proj_copy => buff_proj

    end if

    ! ---- select the ket kernel (complex or real projectors, with or without phases)
    if (allocated(this%projector_phases)) then
      ASSERT(R_TYPE_VAL == TYPE_CMPLX)

      if (this%projector_matrices(1)%is_cmplx) then
        call accel_kernel_start_call(zker_commutator_ket_phase, 'projector.cl', 'zprojector_r_vnl_ket_phase',&
          flags = '-DRTYPE_COMPLEX')
        kernel => zker_commutator_ket_phase
      else
        call accel_kernel_start_call(dker_commutator_ket_phase, 'projector.cl', 'dprojector_r_vnl_ket_phase',&
          flags = '-DRTYPE_DOUBLE')
        kernel => dker_commutator_ket_phase
      end if
      size = psib%pack_size(1)
    else
      if (this%projector_matrices(1)%is_cmplx) then
        call accel_kernel_start_call(zker_commutator_ket, 'projector.cl', 'zprojector_r_vnl_ket',&
          flags = '-DRTYPE_COMPLEX')
        kernel => zker_commutator_ket
        size = psib%pack_size(1)
      else
        call accel_kernel_start_call(dker_commutator_ket, 'projector.cl', 'dprojector_r_vnl_ket',&
          flags = '-DRTYPE_DOUBLE')
        kernel => dker_commutator_ket
        size = psib%pack_size_real(1)
      end if
    end if

    ! ---- one ket launch per projector matrix and per self-overlap region of that matrix
    do iregion = 1, this%nregions

      do imat = this%regions(iregion), this%regions(iregion+1)-1

        nregions_self_overlap = this%projector_matrices(imat)%nregions

        do iregion_self_overlap = 1, nregions_self_overlap

          call accel_set_kernel_arg(kernel,  0, this%nprojector_matrices)
          ! zero-based index of the projector matrix
          call accel_set_kernel_arg(kernel,  1, imat - 1)
          call accel_set_kernel_arg(kernel,  2, this%buff_offsets)
          call accel_set_kernel_arg(kernel,  3, this%buff_matrices)
          call accel_set_kernel_arg(kernel,  4, this%buff_maps)
          call accel_set_kernel_arg(kernel,  5, this%buff_position)
          call accel_set_kernel_arg(kernel,  6, buff_proj_copy)
          call accel_set_kernel_arg(kernel,  7, log2(int(size, int32)))
          call accel_set_kernel_arg(kernel,  8, commpsib(1)%ff_device)
          call accel_set_kernel_arg(kernel,  9, commpsib(2)%ff_device)
          call accel_set_kernel_arg(kernel, 10, commpsib(3)%ff_device)
          call accel_set_kernel_arg(kernel, 11, log2(int(size, int32)))

          ! bounds of the current self-overlap region, shifted by one for the kernel
          call accel_set_kernel_arg(kernel, 12, this%projector_matrices(imat)%regions(iregion_self_overlap) - 1)
          call accel_set_kernel_arg(kernel, 13, this%projector_matrices(imat)%regions(iregion_self_overlap+1) - 1)

          if (allocated(this%projector_phases)) then
            call accel_set_kernel_arg(kernel, 14, this%buff_projector_phases)
            call accel_set_kernel_arg(kernel, 15, (psib%ik - std%kpt%start)*this%total_points)
          end if

          wgsize = accel_kernel_workgroup_size(kernel)/size

          call accel_kernel_run(kernel, &
            (/size, pad(this%max_npoints, wgsize), 1_int64 /), &
            (/size, wgsize, 1_int64/))

          call accel_finish()

        end do

      end do

    end do

    ! the separate copy only exists when mixing was applied: release the device
    ! buffer and free the pointer target allocated above
    if (this%projector_mix) then
      call accel_release_buffer(buff_proj_copy)
      SAFE_DEALLOCATE_P(buff_proj_copy)
    end if

    call accel_release_buffer(buff_proj)

  end subroutine X(commutator_accel)

end subroutine X(nonlocal_pseudopotential_r_vnlocal)


!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
