// clang-format off
/* ----------------------------------------------------------------------
   LAMMPS - Large-scale Atomic/Molecular Massively Parallel Simulator
   https://www.lammps.org/, Sandia National Laboratories
   LAMMPS development team: developers@lammps.org

   Copyright (2003) Sandia Corporation.  Under the terms of Contract
   DE-AC04-94AL85000 with Sandia Corporation, the U.S. Government retains
   certain rights in this software.  This software is distributed under
   the GNU General Public License.

   See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */

/* ----------------------------------------------------------------------
   Contributing author: Stan Moore (SNL)
------------------------------------------------------------------------- */

#include "bond_fene_kokkos.h"

#include "atom_kokkos.h"
#include "atom_masks.h"
#include "comm.h"
#include "error.h"
#include "force.h"
#include "kokkos.h"
#include "math_const.h"
#include "memory_kokkos.h"
#include "neighbor_kokkos.h"

#include <cmath>

using namespace LAMMPS_NS;
using MathConst::MY_CUBEROOT2;

/* ----------------------------------------------------------------------
   constructor: mark the style as Kokkos-capable, cache the Kokkos
   flavored atom and neighbor pointers, set execution space and data
   masks, and create the scalar flag views used by compute()
------------------------------------------------------------------------- */

template<class DeviceType>
BondFENEKokkos<DeviceType>::BondFENEKokkos(LAMMPS *lmp) : BondFENE(lmp)
{
  kokkosable = 1;

  // Kokkos-aware views of the global atom and neighbor objects

  atomKK = static_cast<AtomKokkos *>(atom);
  neighborKK = static_cast<NeighborKokkos *>(neighbor);

  // this style reads positions and updates forces, energy and virial

  execution_space = ExecutionSpaceFromDevice<DeviceType>::space;
  datamask_modify = F_MASK | ENERGY_MASK | VIRIAL_MASK;
  datamask_read = X_MASK | datamask_modify;

  // flag written by the bond kernel and its host-side mirror

  d_flag = typename AT::t_int_scalar("bond:flag");
  h_flag = HAT::t_int_scalar("bond:flag_mirror");
}

/* ----------------------------------------------------------------------
   destructor: release the per-atom energy/virial arrays unless the
   object is currently flagged as being in copy mode
------------------------------------------------------------------------- */

template<class DeviceType>
BondFENEKokkos<DeviceType>::~BondFENEKokkos()
{
  // copymode is raised by compute() while *this is handed to Kokkos;
  // in that state the per-atom arrays must be left untouched

  if (copymode) return;

  memoryKK->destroy_kokkos(k_eatom,eatom);
  memoryKK->destroy_kokkos(k_vatom,vatom);
}

/* ----------------------------------------------------------------------
   compute FENE bond forces, and energy/virial when requested
   eflag_in, vflag_in = energy and virial request flags; they are stored
     in eflag/vflag and forwarded to ev_init()
   the bond list is processed by one Kokkos kernel; afterwards the flag
     written by the kernel is copied to the host and reported as a
     warning (1) or an error (2)
------------------------------------------------------------------------- */

template<class DeviceType>
void BondFENEKokkos<DeviceType>::compute(int eflag_in, int vflag_in)
{
  eflag = eflag_in;
  vflag = vflag_in;

  ev_init(eflag,vflag,0);

  // reallocate per-atom arrays if necessary

  if (eflag_atom) {
    memoryKK->destroy_kokkos(k_eatom,eatom);
    memoryKK->create_kokkos(k_eatom,eatom,maxeatom,"bond:eatom");
    d_eatom = k_eatom.view<DeviceType>();
  }
  if (vflag_atom) {
    memoryKK->destroy_kokkos(k_vatom,vatom);
    memoryKK->create_kokkos(k_vatom,vatom,maxvatom,"bond:vatom");
    d_vatom = k_vatom.view<DeviceType>();
  }

  // make sure the coefficients set on the host are visible to the kernel

  k_k.template sync<DeviceType>();
  k_r0.template sync<DeviceType>();
  k_epsilon.template sync<DeviceType>();
  k_sigma.template sync<DeviceType>();

  x = atomKK->k_x.view<DeviceType>();
  f = atomKK->k_f.view<DeviceType>();
  neighborKK->k_bondlist.template sync<DeviceType>();
  bondlist = neighborKK->k_bondlist.view<DeviceType>();
  int nbondlist = neighborKK->nbondlist;
  nlocal = atom->nlocal;
  newton_bond = force->newton_bond;

  // clear the bond-length flag before the kernel runs

  Kokkos::deep_copy(d_flag,0);

  copymode = 1;

  // loop over the bond list
  // a user-selected chunk size is only applied to the parallel_for variants

  int bond_blocksize = 0;
  if (lmp->kokkos->bond_block_size_set)
    bond_blocksize = lmp->kokkos->bond_block_size;

  EV_FLOAT ev;

  if (evflag) {
    if (newton_bond) {
      Kokkos::parallel_reduce(Kokkos::RangePolicy<DeviceType, TagBondFENECompute<1,1> >(0,nbondlist),*this,ev);
    } else {
      Kokkos::parallel_reduce(Kokkos::RangePolicy<DeviceType, TagBondFENECompute<0,1> >(0,nbondlist),*this,ev);
    }
  } else {
    if (newton_bond) {
      if (bond_blocksize)
        Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType, TagBondFENECompute<1,0> >(0,nbondlist,Kokkos::ChunkSize(bond_blocksize)),*this);
      else
        Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType, TagBondFENECompute<1,0> >(0,nbondlist),*this);
    } else {
      if (bond_blocksize)
        Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType, TagBondFENECompute<0,0> >(0,nbondlist,Kokkos::ChunkSize(bond_blocksize)),*this);
      else
        Kokkos::parallel_for(Kokkos::RangePolicy<DeviceType, TagBondFENECompute<0,0> >(0,nbondlist),*this);
    }
  }

  Kokkos::deep_copy(h_flag,d_flag);

  // the kernel has been dispatched and its flag copied back, so leave copy
  // mode before reporting: error->one() does not return to this function,
  // and leaving copymode = 1 behind would make the destructor skip
  // releasing the per-atom arrays

  copymode = 0;

  if (h_flag() == 1)
    error->warning(FLERR,"FENE bond too long");
  else if (h_flag() == 2)
    error->one(FLERR,"Bad FENE bond");

  // add the reduced energy and virial to the global accumulators

  if (eflag_global) energy += static_cast<double>(ev.evdwl);
  if (vflag_global) {
    for (int n = 0; n < 6; n++)
      virial[n] += static_cast<double>(ev.v[n]);
  }

  // per-atom arrays were modified by the kernel; bring the host copy up to date

  if (eflag_atom) {
    k_eatom.template modify<DeviceType>();
    k_eatom.sync_host();
  }

  if (vflag_atom) {
    k_vatom.template modify<DeviceType>();
    k_vatom.sync_host();
  }
}

/* ----------------------------------------------------------------------
   kernel for one bond list entry
   n = index into bondlist, ev = energy/virial accumulator handed to
     ev_tally() when EVFLAG is set
   NEWTON_BOND = 1 applies the force to both atoms unconditionally,
     otherwise only to atoms with index < nlocal
   d_flag is raised to 1 when the log argument drops below 0.1 and to 2
     when it is <= -3.0; compute() turns these into a warning / an error
------------------------------------------------------------------------- */

template<class DeviceType>
template<int NEWTON_BOND, int EVFLAG>
// NOLINTNEXTLINE
KOKKOS_INLINE_FUNCTION
void BondFENEKokkos<DeviceType>::operator()(TagBondFENECompute<NEWTON_BOND,EVFLAG>, const int &n, EV_FLOAT& ev) const {

  // The f array is atomic
  Kokkos::View<KK_ACC_FLOAT*[3], typename DAT::t_kkacc_1d_3::array_layout,typename KKDevice<DeviceType>::value,Kokkos::MemoryTraits<Kokkos::Atomic|Kokkos::Unmanaged> > a_f = f;

  const int i1 = bondlist(n,0);
  const int i2 = bondlist(n,1);
  const int type = bondlist(n,2);

  const KK_FLOAT delx = x(i1,0) - x(i2,0);
  const KK_FLOAT dely = x(i1,1) - x(i2,1);
  const KK_FLOAT delz = x(i1,2) - x(i2,2);

  const KK_FLOAT r0 = d_r0[type];
  const KK_FLOAT k = d_k[type];
  const KK_FLOAT sigma = d_sigma[type];
  const KK_FLOAT epsilon = d_epsilon[type];

  // force from log term

  const KK_FLOAT rsq = delx*delx + dely*dely + delz*delz;
  const KK_FLOAT r0sq = r0 * r0;
  KK_FLOAT rlogarg = static_cast<KK_FLOAT>(1.0) - rsq/r0sq;

  // if r -> r0, then rlogarg -> 0.0 and the log term diverges
  // once rlogarg < 0.1, flag a warning (1) and clamp rlogarg to 0.1
  // if rlogarg <= -3.0, i.e. r >= 2*r0, something serious is wrong: flag an abort (2)
  // many bonds may write the flag concurrently, so use an atomic max to
  // guarantee that a warning can never overwrite an abort request

  if (rlogarg < static_cast<KK_FLOAT>(0.1)) {
    const int severity = (rlogarg <= static_cast<KK_FLOAT>(-3.0)) ? 2 : 1;
    Kokkos::atomic_max(&d_flag(),severity);
    rlogarg = static_cast<KK_FLOAT>(0.1);
  }

  KK_FLOAT fbond = -k/rlogarg;

  // force from LJ term, only inside the cutoff rsq < MY_CUBEROOT2*sigma^2

  KK_FLOAT sr6 = 0;
  KK_FLOAT sigma2 = sigma*sigma;
  if (rsq < static_cast<KK_FLOAT>(MY_CUBEROOT2)*sigma2) {
    const KK_FLOAT sr2 = sigma2/rsq;
    sr6 = sr2*sr2*sr2;
    fbond += static_cast<KK_FLOAT>(48.0)*epsilon*sr6*(sr6 - static_cast<KK_FLOAT>(0.5))/rsq;
  }

  // energy

  KK_FLOAT ebond = 0;
  if (eflag) {
    ebond = -static_cast<KK_FLOAT>(0.5) * k*r0sq*log(rlogarg);
    if (rsq < static_cast<KK_FLOAT>(MY_CUBEROOT2)*sigma2)
      ebond += static_cast<KK_FLOAT>(4.0)*epsilon*sr6*(sr6-static_cast<KK_FLOAT>(1.0)) + epsilon;
  }

  // apply force to each of 2 atoms

  if (NEWTON_BOND || i1 < nlocal) {
    a_f(i1,0) += static_cast<KK_ACC_FLOAT>(delx*fbond);
    a_f(i1,1) += static_cast<KK_ACC_FLOAT>(dely*fbond);
    a_f(i1,2) += static_cast<KK_ACC_FLOAT>(delz*fbond);
  }

  if (NEWTON_BOND || i2 < nlocal) {
    a_f(i2,0) -= static_cast<KK_ACC_FLOAT>(delx*fbond);
    a_f(i2,1) -= static_cast<KK_ACC_FLOAT>(dely*fbond);
    a_f(i2,2) -= static_cast<KK_ACC_FLOAT>(delz*fbond);
  }

  if (EVFLAG) ev_tally(ev,i1,i2,ebond,fbond,delx,dely,delz);
}

/* ----------------------------------------------------------------------
   kernel overload without a reduction argument: forwards to the
   reducing overload with a local accumulator that is discarded
------------------------------------------------------------------------- */

template<class DeviceType>
template<int NEWTON_BOND, int EVFLAG>
// NOLINTNEXTLINE
KOKKOS_INLINE_FUNCTION
void BondFENEKokkos<DeviceType>::operator()(TagBondFENECompute<NEWTON_BOND,EVFLAG>, const int &n) const {
  using tag_type = TagBondFENECompute<NEWTON_BOND,EVFLAG>;

  EV_FLOAT scratch;
  this->template operator()<NEWTON_BOND,EVFLAG>(tag_type(), n, scratch);
}

/* ----------------------------------------------------------------------
   allocate the base class arrays, then create the dual views holding
   the per-type coefficients and cache their device-side views
------------------------------------------------------------------------- */

template<class DeviceType>
void BondFENEKokkos<DeviceType>::allocate()
{
  BondFENE::allocate();

  // bond types are indexed starting at 1, so reserve nbondtypes+1 entries

  const int nentries = atom->nbondtypes + 1;

  k_k = DAT::tdual_kkfloat_1d("BondFene::k",nentries);
  k_r0 = DAT::tdual_kkfloat_1d("BondFene::r0",nentries);
  k_epsilon = DAT::tdual_kkfloat_1d("BondFene::epsilon",nentries);
  k_sigma = DAT::tdual_kkfloat_1d("BondFene::sigma",nentries);

  // device views read by the bond kernel

  d_sigma = k_sigma.template view<DeviceType>();
  d_epsilon = k_epsilon.template view<DeviceType>();
  d_r0 = k_r0.template view<DeviceType>();
  d_k = k_k.template view<DeviceType>();
}

/* ----------------------------------------------------------------------
   set coeffs for one or more types
   the base class parses the arguments; the affected entries are then
   copied into the host side of the dual views and marked as modified
------------------------------------------------------------------------- */

template<class DeviceType>
void BondFENEKokkos<DeviceType>::coeff(int narg, char **arg)
{
  BondFENE::coeff(narg, arg);

  // range of bond types addressed by the first argument

  int first,last;
  utils::bounds(FLERR,arg[0],1,atom->nbondtypes,first,last,error);

  auto h_k = k_k.view_host();
  auto h_r0 = k_r0.view_host();
  auto h_epsilon = k_epsilon.view_host();
  auto h_sigma = k_sigma.view_host();

  for (int itype = first; itype <= last; ++itype) {
    h_k[itype] = static_cast<KK_FLOAT>(k[itype]);
    h_r0[itype] = static_cast<KK_FLOAT>(r0[itype]);
    h_epsilon[itype] = static_cast<KK_FLOAT>(epsilon[itype]);
    h_sigma[itype] = static_cast<KK_FLOAT>(sigma[itype]);
  }

  // host copies changed; compute() syncs them to the device before use

  k_k.modify_host();
  k_r0.modify_host();
  k_epsilon.modify_host();
  k_sigma.modify_host();
}


/* ----------------------------------------------------------------------
   read coeffs from restart file via the base class, then mirror the
   coefficients of all bond types into the host side of the dual views
------------------------------------------------------------------------- */

template<class DeviceType>
void BondFENEKokkos<DeviceType>::read_restart(FILE *fp)
{
  BondFENE::read_restart(fp);

  auto h_k = k_k.view_host();
  auto h_r0 = k_r0.view_host();
  auto h_epsilon = k_epsilon.view_host();
  auto h_sigma = k_sigma.view_host();

  // bond types are numbered 1..nbondtypes

  const int ntypes = atom->nbondtypes;
  for (int itype = 1; itype <= ntypes; ++itype) {
    h_k[itype] = static_cast<KK_FLOAT>(k[itype]);
    h_r0[itype] = static_cast<KK_FLOAT>(r0[itype]);
    h_epsilon[itype] = static_cast<KK_FLOAT>(epsilon[itype]);
    h_sigma[itype] = static_cast<KK_FLOAT>(sigma[itype]);
  }

  // host copies changed; compute() syncs them to the device before use

  k_k.modify_host();
  k_r0.modify_host();
  k_epsilon.modify_host();
  k_sigma.modify_host();
}

/* ----------------------------------------------------------------------
   tally energy and virial of one bond into the reduction accumulator
   and the per-atom arrays
   i, j = the two atoms of the bond, ebond = bond energy,
   fbond = force prefactor, delx/dely/delz = separation vector
   with newton_bond set the full contribution is tallied, otherwise
   half of it for each atom with index < nlocal
------------------------------------------------------------------------- */

template<class DeviceType>
// NOLINTNEXTLINE
KOKKOS_INLINE_FUNCTION
void BondFENEKokkos<DeviceType>::ev_tally(EV_FLOAT &ev, const int &i, const int &j,
      const KK_FLOAT &ebond, const KK_FLOAT &fbond, const KK_FLOAT &delx,
                const KK_FLOAT &dely, const KK_FLOAT &delz) const
{
  // The eatom and vatom arrays are accessed through atomic views
  using device_value = typename KKDevice<DeviceType>::value;
  using atomic_traits = Kokkos::MemoryTraits<Kokkos::Atomic|Kokkos::Unmanaged>;

  Kokkos::View<KK_ACC_FLOAT*, typename DAT::t_kkacc_1d::array_layout,device_value,atomic_traits> a_eatom = d_eatom;
  Kokkos::View<KK_ACC_FLOAT*[6], typename DAT::t_kkacc_1d_6::array_layout,device_value,atomic_traits> a_vatom = d_vatom;

  const bool i_local = i < nlocal;
  const bool j_local = j < nlocal;

  // energy

  if (eflag_either) {
    const KK_ACC_FLOAT half_e = static_cast<KK_ACC_FLOAT>(static_cast<KK_FLOAT>(0.5)*ebond);

    if (eflag_global) {
      if (newton_bond) {
        ev.evdwl += static_cast<KK_ACC_FLOAT>(ebond);
      } else {
        if (i_local) ev.evdwl += half_e;
        if (j_local) ev.evdwl += half_e;
      }
    }

    if (eflag_atom) {
      if (newton_bond || i_local) a_eatom[i] += half_e;
      if (newton_bond || j_local) a_eatom[j] += half_e;
    }
  }

  // virial

  if (!vflag_either) return;

  // components in the order xx, yy, zz, xy, xz, yz

  KK_ACC_FLOAT vfull[6];
  vfull[0] = static_cast<KK_ACC_FLOAT>(delx*delx*fbond);
  vfull[1] = static_cast<KK_ACC_FLOAT>(dely*dely*fbond);
  vfull[2] = static_cast<KK_ACC_FLOAT>(delz*delz*fbond);
  vfull[3] = static_cast<KK_ACC_FLOAT>(delx*dely*fbond);
  vfull[4] = static_cast<KK_ACC_FLOAT>(delx*delz*fbond);
  vfull[5] = static_cast<KK_ACC_FLOAT>(dely*delz*fbond);

  KK_ACC_FLOAT vhalf[6];
  for (int m = 0; m < 6; m++)
    vhalf[m] = static_cast<KK_ACC_FLOAT>(0.5)*vfull[m];

  if (vflag_global) {
    for (int m = 0; m < 6; m++) {
      if (newton_bond) {
        ev.v[m] += vfull[m];
      } else {
        if (i_local) ev.v[m] += vhalf[m];
        if (j_local) ev.v[m] += vhalf[m];
      }
    }
  }

  if (vflag_atom) {
    const bool tally_i = newton_bond || i_local;
    const bool tally_j = newton_bond || j_local;
    for (int m = 0; m < 6; m++) {
      if (tally_i) a_vatom(i,m) += vhalf[m];
      if (tally_j) a_vatom(j,m) += vhalf[m];
    }
  }
}

/* ---------------------------------------------------------------------- */

// Explicit instantiations of the class template: always for LMPDeviceType,
// and additionally for LMPHostType when LMP_KOKKOS_GPU is defined.
// NOTE(review): presumably the two types coincide in builds without
// LMP_KOKKOS_GPU, which would make a second instantiation a duplicate - confirm.
namespace LAMMPS_NS {
template class BondFENEKokkos<LMPDeviceType>;
#ifdef LMP_KOKKOS_GPU
template class BondFENEKokkos<LMPHostType>;
#endif
}

