!! Copyright (C) 2002-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

!> @brief This module calculates the derivatives (gradients, Laplacians, etc.)
!! of a function.
!!
!! @note The function whose derivative is to be calculated
!! *has* to be defined on (1:mesh\%np_part), while only the (1:mesh\%np) values of the
!! derivative are calculated.

module derivatives_oct_m
  use accel_oct_m
  use affine_coordinates_oct_m
  use batch_oct_m
  use batch_ops_oct_m
  use boundaries_oct_m
  use debug_oct_m
  use cartesian_oct_m
  use coordinate_system_oct_m
  use curv_briggs_oct_m
  use curv_gygi_oct_m
  use curv_modine_oct_m
  use global_oct_m
  use iso_c_binding
  use, intrinsic :: iso_fortran_env
  use lalg_basic_oct_m
  use lalg_adv_oct_m
  use loct_oct_m
  use math_oct_m
  use mesh_oct_m
  use mesh_function_oct_m
  use messages_oct_m
  use namespace_oct_m
  use nl_operator_oct_m
  use par_vec_oct_m
  use parser_oct_m
  use profiling_oct_m
  use space_oct_m
  use stencil_oct_m
  use stencil_cube_oct_m
  use stencil_star_oct_m
  use stencil_starplus_oct_m
  use stencil_stargeneral_oct_m
  use stencil_variational_oct_m
  use transfer_table_oct_m
  use types_oct_m
  use utils_oct_m
  use varinfo_oct_m

!   debug purposes
!   use io_binary_oct_m
!   use io_function_oct_m
!   use io_oct_m
!   use unit_oct_m
!   use unit_system_oct_m

  implicit none

  private
  public ::                             &
    derivatives_t,                      &
    derivatives_init,                   &
    derivatives_end,                    &
    derivatives_build,                  &
    derivatives_handle_batch_t,         &
    dderivatives_test,                  &
    zderivatives_test,                  &
    dderivatives_batch_start,           &
    zderivatives_batch_start,           &
    dderivatives_batch_finish,          &
    zderivatives_batch_finish,          &
    dderivatives_batch_perform,         &
    zderivatives_batch_perform,         &
    dderivatives_perform,               &
    zderivatives_perform,               &
    dderivatives_lapl,                  &
    zderivatives_lapl,                  &
    derivatives_lapl_diag,              &
    dderivatives_grad,                  &
    zderivatives_grad,                  &
    dderivatives_batch_grad,            &
    zderivatives_batch_grad,            &
    dderivatives_batch_div,             &
    zderivatives_batch_div,             &
    dderivatives_div,                   &
    zderivatives_div,                   &
    dderivatives_curl,                  &
    zderivatives_curl,                  &
    dderivatives_batch_curl,            &
    zderivatives_batch_curl,            &
    dderivatives_batch_curl_from_gradient, &
    zderivatives_batch_curl_from_gradient, &
    dderivatives_partial,               &
    zderivatives_partial,               &
    derivatives_get_lapl,               &
    derivatives_get_inner_boundary_mask, &
    derivatives_lapl_get_max_eigenvalue


  !> Identifiers for the boundary-condition types (private to this module).
  integer, parameter ::     &
    DER_BC_ZERO_F    = 0,   &  !< function is zero at the boundaries
    DER_BC_ZERO_DF   = 1,   &  !< first derivative of the function is zero
    DER_BC_PERIOD    = 2       !< boundary is periodic

  !> Stencil types; the values match the options of the DerivativesStencil input variable.
  integer, parameter, public ::     &
    DER_STAR         = 1,   &  !< star: only points on the axes
    DER_VARIATIONAL  = 2,   &  !< same points as the star, coefficients built in a different way
    DER_CUBE         = 3,   &  !< cube of points around each point
    DER_STARPLUS     = 4,   &  !< star plus a number of off-axis points
    DER_STARGENERAL  = 5       !< general star, default for non-orthogonal grids

  !> Communication methods; the values match the options of the
  !! ParallelizationOfDerivatives input variable.
  integer, parameter ::     &
    BLOCKING = 1,           &  !< blocking communication
    NON_BLOCKING = 2           !< non-blocking point-to-point communication

  !> @brief class representing derivatives
  !!
  !! The object is filled in two steps: derivatives_init reads the input options,
  !! allocates the operator array and creates the stencils; derivatives_build
  !! attaches the mesh and computes the operator weights.
  type derivatives_t
    ! Components are public by default
    type(boundaries_t)    :: boundaries       !< boundary conditions
    type(mesh_t), pointer :: mesh => NULL()   !< pointer to the underlying mesh (set in derivatives_build)
    integer               :: dim = 0          !< dimensionality of the space (space%dim)
    integer, private      :: periodic_dim = 0 !< Number of periodic dimensions of the space (space%periodic_dim)
    integer               :: order = 0        !< order of the discretization (value depends on stencil)
    integer               :: stencil_type = 0 !< type of discretization (one of the DER_STAR ... DER_STARGENERAL values)

    real(real64), allocatable    :: masses(:)        !< we can have different weights (masses) per space direction

    !> If the so-called variational discretization is used, this controls a
    !! possible filter on the Laplacian.
    real(real64), private :: lapl_cutoff = M_ZERO

    type(nl_operator_t), allocatable, private :: op(:)  !< op(1:dim) => gradient
    !!                                                     op(dim+1) => Laplacian
    type(nl_operator_t), pointer :: lapl => NULL()      !< shortcut for op(dim+1)
    type(nl_operator_t), pointer :: grad(:) => NULL()   !< shortcut for op; grad(1:dim) are the gradient operators

    integer, allocatable         :: n_ghost(:)   !< ghost points to add in each dimension
    !!                                              (largest absolute offset of the Laplacian stencil)
!#if defined(HAVE_MPI)
    integer, private             :: comm_method = 0  !< BLOCKING or NON_BLOCKING; only parsed in MPI builds
!#endif
    ! NOTE(review): the four pointers below are only nullified in the part of the
    ! file reviewed here; presumably they link grids of different resolution - confirm.
    type(derivatives_t),    pointer :: finer  => NULL()
    type(derivatives_t),    pointer :: coarser => NULL()
    type(transfer_table_t), pointer :: to_finer => NULL()
    type(transfer_table_t), pointer :: to_coarser => NULL()
  end type derivatives_t

  !> @brief handle to transfer data from the start() to finish() calls.
  !!
  !! All components are private to this module. The pointer components have no
  !! default initialization, so they are undefined until they are assigned.
  type derivatives_handle_batch_t
    private
!#ifdef HAVE_MPI
    type(par_vec_handle_batch_t) :: pv_h           !< handle for ghost updates
!#endif
    type(derivatives_t), pointer :: der            !< pointer to the derivatives
    type(nl_operator_t), pointer :: op             !< pointer to the operation to be performed
    type(batch_t),       pointer :: ff             !< pointer to the initial batch
    type(batch_t),       pointer :: opff           !< pointer to the final batch
    logical                      :: ghost_update   !< flag whether ghost update needs to be performed
    logical                      :: factor_present !< indicate whether a scaling factor is used
    real(real64)                 :: factor         !< optional scaling factor
  end type derivatives_handle_batch_t

  !> Accelerator kernels held at module level. They are built in derivatives_init
  !! when acceleration is enabled; kernel_uvw_xyz is only built for non-Cartesian
  !! coordinate systems with at most three dimensions.
  type(accel_kernel_t) :: kernel_uvw_xyz, kernel_dcurl, kernel_zcurl

contains

  ! ---------------------------------------------------------
  !> @brief Read the derivative-related input options and create the stencils.
  !!
  !! The routine copies the dimensions of the space, parses the stencil type and
  !! the discretization order (plus the Laplacian filter for the variational
  !! stencil and the communication method in MPI builds), allocates the masses and
  !! the operator array, creates the stencils of the Laplacian and of the gradient
  !! and, when acceleration is enabled, builds the module-level kernels.
  !! The operator weights are not computed here.
  subroutine derivatives_init(der, namespace, space, coord_system, order)
    type(derivatives_t), target, intent(inout) :: der          !< object to be initialized
    type(namespace_t),           intent(in)    :: namespace    !< namespace used for parsing and messages
    class(space_t),              intent(in)    :: space        !< provides dim and periodic_dim
    class(coordinate_system_t),  intent(in)    :: coord_system !< coordinate system; selects the default stencil
    integer, optional,           intent(in)    :: order        !< if present, overrides the parsed DerivativesOrder

    integer :: stencil_default
    character(len=40) :: build_flags

    PUSH_SUB(derivatives_init)

    ! Non-orthogonal curvilinear coordinates are currently not implemented
    ASSERT(.not. coord_system%local_basis .or. coord_system%orthogonal)

    ! keep the dimensions of the space inside the object
    der%periodic_dim = space%periodic_dim
    der%dim = space%dim

    !%Variable DerivativesStencil
    !%Type integer
    !%Default stencil_star
    !%Section Mesh::Derivatives
    !%Description
    !% Decides what kind of stencil is used, <i>i.e.</i> which points, around
    !% each point in the mesh, are the neighboring points used in the
    !% expression of the differential operator.
    !%
    !% If curvilinear coordinates are to be used, then only the <tt>stencil_starplus</tt>
    !% or the <tt>stencil_cube</tt> may be used. We only recommend the <tt>stencil_starplus</tt>,
    !% since the cube typically needs far too much memory.
    !%Option stencil_star 1
    !% A star around each point (<i>i.e.</i>, only points on the axis).
    !%Option stencil_variational 2
    !% Same as the star, but with coefficients built in a different way.
    !%Option stencil_cube 3
    !% A cube of points around each point.
    !%Option stencil_starplus 4
    !% The star, plus a number of off-axis points.
    !%Option stencil_stargeneral 5
    !% The general star. Default for non-orthogonal grids.
    !%End
    ! a non-orthogonal system takes precedence over a local basis
    if (.not. coord_system%orthogonal) then
      stencil_default = DER_STARGENERAL
    else if (coord_system%local_basis) then
      stencil_default = DER_STARPLUS
    else
      stencil_default = DER_STAR
    end if

    call parse_variable(namespace, 'DerivativesStencil', stencil_default, der%stencil_type)

    if (.not. varinfo_valid_option('DerivativesStencil', der%stencil_type)) then
      call messages_input_error(namespace, 'DerivativesStencil')
    end if
    call messages_print_var_option("DerivativesStencil", der%stencil_type, namespace=namespace)

    ! with a local basis the star and the variational stencils are rejected
    if (coord_system%local_basis) then
      if (der%stencil_type < DER_CUBE) call messages_input_error(namespace, 'DerivativesStencil')
    end if

    if (der%stencil_type == DER_VARIATIONAL) then
      !%Variable DerivativesLaplacianFilter
      !%Type float
      !%Default 1.0
      !%Section Mesh::Derivatives
      !%Description
      !% Undocumented
      !%End
      call parse_variable(namespace, 'DerivativesLaplacianFilter', M_ONE, der%lapl_cutoff)
    end if

    !%Variable DerivativesOrder
    !%Type integer
    !%Default 4
    !%Section Mesh::Derivatives
    !%Description
    !% This variable gives the discretization order for the approximation of
    !% the differential operators. This means, basically, that
    !% <tt>DerivativesOrder</tt> points are used in each positive/negative
    !% spatial direction, <i>e.g.</i> <tt>DerivativesOrder = 1</tt> would give
    !% the well-known three-point formula in 1D.
    !% The number of points actually used for the Laplacian
    !% depends on the stencil used. Let <math>O</math> = <tt>DerivativesOrder</tt>, and <math>d</math> = <tt>Dimensions</tt>.
    !% <ul>
    !% <li> <tt>stencil_star</tt>: <math>2 O d + 1</math>
    !% <li> <tt>stencil_cube</tt>: <math>(2 O + 1)^d</math>
    !% <li> <tt>stencil_starplus</tt>: <math>2 O d + 1 + n</math> with <i>n</i> being 8
    !% in 2D and 24 in 3D.
    !% </ul>
    !%End
    call parse_variable(namespace, 'DerivativesOrder', 4, der%order)
    ! the argument, when given, wins over the input file
    if (present(order)) der%order = order

#ifdef HAVE_MPI
    !%Variable ParallelizationOfDerivatives
    !%Type integer
    !%Default non_blocking
    !%Section Execution::Parallelization
    !%Description
    !% This option selects how the communication of mesh boundaries is performed.
    !%Option blocking 1
    !% Blocking communication.
    !%Option non_blocking 2
    !% Communication is based on non-blocking point-to-point communication.
    !%End

    call parse_variable(namespace, 'ParallelizationOfDerivatives', NON_BLOCKING, der%comm_method)

    if (.not. varinfo_valid_option('ParallelizationOfDerivatives', der%comm_method)) then
      call messages_input_error(namespace, 'ParallelizationOfDerivatives')
    end if

    call messages_obsolete_variable(namespace, 'OverlapDerivatives', 'ParallelizationOfDerivatives')
#endif

    ! unit masses by default; if needed, they should be initialized in modelmb_particles_init
    SAFE_ALLOCATE(der%masses(1:space%dim))
    der%masses(:) = M_ONE

    ! one operator per gradient component plus one for the Laplacian;
    ! grad and lapl are just views into this array
    SAFE_ALLOCATE(der%op(1:der%dim + 1))
    der%grad => der%op
    der%lapl => der%op(der%dim + 1)

    call derivatives_get_stencil_lapl(der, space, coord_system)
    call derivatives_get_stencil_grad(der)

    ! number of ghost points per direction: largest offset of the Laplacian stencil
    SAFE_ALLOCATE(der%n_ghost(1:der%dim))
    der%n_ghost(:) = maxval(abs(der%lapl%stencil%points(1:der%dim, :)), dim=2)

    nullify(der%coarser, der%finer, der%to_coarser, der%to_finer)

    if (accel_is_enabled()) then
      ! Check if we can build the uvw_to_xyz kernel
      select type (coord_system)
      type is (cartesian_t)
        ! the kernel is never called for Cartesian coordinates, nothing to build
      class default
        if (der%dim <= 3) then
          write(build_flags, '(A,I1.1)') ' -DDIMENSION=', der%dim
          call accel_kernel_build(kernel_uvw_xyz, 'uvw_to_xyz.cl', 'uvw_to_xyz', build_flags)
        else
          message(1) = "Calculation of derivatives on the GPU with dimension > 3 are only implemented for Cartesian coordinates."
          call messages_fatal(1, namespace=namespace)
        end if
      end select
      call accel_kernel_build(kernel_dcurl, 'curl.cl', 'dcurl', flags = '-DRTYPE_DOUBLE')
      call accel_kernel_build(kernel_zcurl, 'curl.cl', 'zcurl', flags = '-DRTYPE_COMPLEX')
    end if

    POP_SUB(derivatives_init)
  end subroutine derivatives_init


  ! ---------------------------------------------------------
  !> @brief Release the resources held by a derivatives object.
  !!
  !! Ends all operators, frees the allocatable components, resets every pointer
  !! component and ends the boundaries. The operator array must be allocated.
  subroutine derivatives_end(der)
    type(derivatives_t), intent(inout) :: der !< object to be finalized

    integer :: iop

    PUSH_SUB(derivatives_end)

    ASSERT(allocated(der%op))

    ! gradient components first, the Laplacian is the last entry
    do iop = 1, der%dim + 1
      call nl_operator_end(der%op(iop))
    end do

    SAFE_DEALLOCATE_A(der%masses)
    SAFE_DEALLOCATE_A(der%n_ghost)
    SAFE_DEALLOCATE_A(der%op)

    ! lapl and grad pointed into op, which is gone now
    nullify(der%lapl, der%grad)
    nullify(der%coarser, der%finer, der%to_coarser, der%to_finer)

    call boundaries_end(der%boundaries)

    POP_SUB(derivatives_end)
  end subroutine derivatives_end


  ! ---------------------------------------------------------
  !> @brief Initialize the Laplacian operator and create its stencil.
  !!
  !! The stencil is chosen according to der%stencil_type. The star and the
  !! variational types share the star stencil; the general star first needs its
  !! arms, which are obtained from the coordinate system.
  subroutine derivatives_get_stencil_lapl(der, space, coord_system)
    type(derivatives_t),        intent(inout) :: der          !< der%lapl must already be associated
    class(space_t),             intent(in)    :: space        !< its dim is used for the arms of the general star
    class(coordinate_system_t), intent(in)    :: coord_system !< used for the arms of the general star

    PUSH_SUB(derivatives_get_stencil_lapl)

    ASSERT(associated(der%lapl))

    call nl_operator_init(der%lapl, "Laplacian")

    ! pick the stencil generator that corresponds to the requested type
    select case (der%stencil_type)
    case (DER_STARGENERAL)
      call stencil_stargeneral_get_arms(der%lapl%stencil, space%dim, coord_system)
      call stencil_stargeneral_get_lapl(der%lapl%stencil, der%dim, der%order)
    case (DER_STARPLUS)
      call stencil_starplus_get_lapl(der%lapl%stencil, der%dim, der%order)
    case (DER_CUBE)
      call stencil_cube_get_lapl(der%lapl%stencil, der%dim, der%order)
    case (DER_STAR, DER_VARIATIONAL)
      call stencil_star_get_lapl(der%lapl%stencil, der%dim, der%order)
    end select

    POP_SUB(derivatives_get_stencil_lapl)
  end subroutine derivatives_get_stencil_lapl


  ! ---------------------------------------------------------
  !> Returns the diagonal elements of the Laplacian, needed for preconditioning
  !!
  !! The output array must hold at least der%mesh%np elements; this is checked
  !! with an assertion before the operator is queried.
  subroutine derivatives_lapl_diag(der, lapl)
    type(derivatives_t), intent(in)  :: der      !< derivatives object providing the Laplacian and the mesh
    real(real64),        intent(out) :: lapl(:)  !< lapl(mesh%np)

    PUSH_SUB(derivatives_lapl_diag)

    ASSERT(ubound(lapl, DIM=1) >= der%mesh%np)

    ! the Laplacian is a real operator
    call dnl_operator_operate_diag(der%lapl, lapl)

    POP_SUB(derivatives_lapl_diag)

  end subroutine derivatives_lapl_diag


  ! ---------------------------------------------------------
  !> @brief Initialize the gradient operators and create their stencils.
  !!
  !! One operator is set up per space direction. The star, variational and
  !! general-star types all use the simple star stencil for the gradient.
  subroutine derivatives_get_stencil_grad(der)
    type(derivatives_t), intent(inout) :: der !< der%grad must already be associated

    integer   :: idir
    character :: axis

    PUSH_SUB(derivatives_get_stencil_grad)

    ASSERT(associated(der%grad))

    do idir = 1, der%dim
      ! only the first four directions get an axis label
      if (idir < 5) then
        axis = index2axis(idir)
      else
        axis = ' '
      end if

      call nl_operator_init(der%grad(idir), "Gradient "//axis)

      select case (der%stencil_type)
      case (DER_CUBE)
        call stencil_cube_get_grad(der%grad(idir)%stencil, der%dim, der%order)
      case (DER_STARPLUS)
        call stencil_starplus_get_grad(der%grad(idir)%stencil, der%dim, idir, der%order)
      case (DER_STAR, DER_VARIATIONAL, DER_STARGENERAL)
        ! the general star also uses the simple star stencil here
        call stencil_star_get_grad(der%grad(idir)%stencil, der%dim, idir, der%order)
      end select
    end do

    POP_SUB(derivatives_get_stencil_grad)

  end subroutine derivatives_get_stencil_grad

  ! ---------------------------------------------------------
  !> @brief build the derivatives object:
  !!
  !! This routine initializes the boundary conditions and
  !! builds the actual nl_operator_oct_m::nl_operator_t objects
  !! for the derivative operators.
  !!
  !! The stencil weights are computed by imposing that the action of the discrete Laplacian
  !! or gradient is exact on a set of polynomials. The set of polynomials needs to be chosen
  !! such that it yields linearly independent vectors when evaluated on the points of the
  !! stencil.
  !!
  !! More specifically, the weights are the solution of the equation
  !! \f[ mat * weights = rhs, \f]
  !! where \f$ mat_{i,j} \f$ is the value of the \f$i\f$-th polynomial on the \f$j\f$-th
  !! point of the stencil and \f$rhs\f$ is the result of applying the operator to the
  !! polynomial and evaluating it at 0.
  !!
  !! Let \f$u_1,u_2,u_3\f$ be the primitive coordinates and \f$F=BB^T\f$, with
  !! \f$B\f$ the change-of-coordinates matrix (see Natan et al. (2008), PRB 78, 075109,
  !! eqs. 10 and 11). The polynomials are given as \f$u_1^{\alpha_1}u_2^{\alpha_2}u_3^{\alpha_3}\f$,
  !! and the right-hand side for the gradient is
  !! - 1 if \f$\alpha_i = 1, \alpha_j = 0, j\ne i\f$
  !! - 0 otherwise
  !! and for the Laplacian
  !! - \f$2 F_{i,i}\f$ if \f$\alpha_i = 2, \alpha_j = 0, j\ne i\f$
  !! - \f$F_{i,j} + F_{j,i}\f$ if \f$\alpha_i = 1, \alpha_j = 1, \alpha_k = 0, k\ne i, k \ne j\f$
  !! - 0 otherwise
  !! The polynomials need to include the cases for which the right-hand-side is non-zero.
  !!
  !! The method has been tested for the default and cube stencils on all 14 Bravais lattices
  !! in 3D and yields acceptable errors for the derivatives.
  subroutine derivatives_build(der, namespace, space, mesh, qvector, regenerate, verbose)
    type(derivatives_t),    intent(inout) :: der        !< derivatives object; its operator array must be allocated
    type(namespace_t),      intent(in)    :: namespace  !< namespace used for boundaries and messages
    class(space_t),         intent(in)    :: space      !< space (spatial and periodic dimensions)
    class(mesh_t),  target, intent(in)    :: mesh       !< the underlying mesh
    real(real64), optional,        intent(in)    :: qvector(:) !< momentum transfer for spiral BC
    !!                                                     (pass through to boundaries_oct_m::boundaries_init())
    logical, optional,      intent(in)    :: regenerate !< if .true., keep the boundaries and free the old
    !!                                                     weights before rebuilding (default .false.)
    logical, optional,      intent(in)    :: verbose    !< passed on to derivatives_make_discretization

    integer, allocatable :: polynomials(:,:)
    real(real64),   allocatable :: rhs(:,:)
    integer :: i
    logical :: const_w_
    logical :: regenerate_
    character(len=32) :: name
    type(nl_operator_t) :: auxop
    integer :: np_zero_bc

    PUSH_SUB(derivatives_build)

    ! resolve the optional flag once instead of on every use
    regenerate_ = optional_default(regenerate, .false.)

    ! the boundaries are only initialized on the first build
    if (.not. regenerate_) then
      call boundaries_init(der%boundaries, namespace, space, mesh, qvector)
    end if

    ASSERT(allocated(der%op))
    ASSERT(der%stencil_type >= DER_STAR .and. der%stencil_type <= DER_STARGENERAL)
    ASSERT(.not.(der%stencil_type == DER_VARIATIONAL .and. mesh%use_curvilinear))

    der%mesh => mesh    ! make a pointer to the underlying mesh

    const_w_ = .true.

    ! need non-constant weights for curvilinear and scattering meshes
    if (mesh%use_curvilinear) const_w_ = .false.

    np_zero_bc = 0

    ! build operators
    do i = 1, der%dim+1
      if (regenerate_) then
        SAFE_DEALLOCATE_A(der%op(i)%w)
      end if
      ! regenerate is forwarded as an optional, so its absence is preserved
      call nl_operator_build(space, mesh, der%op(i), der%mesh%np, const_w = const_w_, &
        regenerate=regenerate)
      np_zero_bc = max(np_zero_bc, nl_operator_np_zero_bc(der%op(i)))
    end do

    ASSERT(np_zero_bc > mesh%np .and. np_zero_bc <= mesh%np_part)

    select case (der%stencil_type)

    case (DER_STAR, DER_STARPLUS, DER_STARGENERAL) ! Laplacian and gradient have different stencils
      ! one linear problem per operator, each with its own polynomials
      do i = 1, der%dim + 1
        SAFE_ALLOCATE(polynomials(1:der%dim, 1:der%op(i)%stencil%size))
        SAFE_ALLOCATE(rhs(1:der%op(i)%stencil%size, 1))

        if (i <= der%dim) then  ! gradient
          call stencil_stars_pol_grad(der%stencil_type, der%dim, i, der%order, polynomials)
          call get_rhs_grad(der, polynomials, i, rhs(:,1))
          name = index2axis(i) // "-gradient"
        else                      ! Laplacian
          call stencil_stars_pol_lapl(der%stencil_type, der%op(der%dim+1)%stencil, der%dim, der%order, polynomials)
          call get_rhs_lapl(der, polynomials, rhs(:,1))
          name = "Laplacian"
        end if

        call derivatives_make_discretization(namespace, der%dim, der%periodic_dim, der%mesh, der%masses, polynomials, rhs, 1, &
          der%op(i:i), name, verbose=verbose)
        SAFE_DEALLOCATE_A(polynomials)
        SAFE_DEALLOCATE_A(rhs)
      end do

    case (DER_CUBE)
      ! Laplacian and gradient have similar stencils, so use one call to derivatives_make_discretization
      ! to solve the linear equation once for several right-hand sides
      SAFE_ALLOCATE(polynomials(1:der%dim, 1:der%op(1)%stencil%size))
      SAFE_ALLOCATE(rhs(1:der%op(1)%stencil%size, 1:der%dim + 1))
      call stencil_cube_polynomials_lapl(der%dim, der%order, polynomials)

      ! columns 1..dim hold the gradient components, the last column the Laplacian
      do i = 1, der%dim
        call get_rhs_grad(der, polynomials, i, rhs(:,i))
      end do
      call get_rhs_lapl(der, polynomials, rhs(:, der%dim+1))

      name = "derivatives"
      call derivatives_make_discretization(namespace, der%dim, der%periodic_dim, der%mesh, der%masses, polynomials, rhs, &
        der%dim+1, der%op(:), name, verbose=verbose)

      SAFE_DEALLOCATE_A(polynomials)
      SAFE_DEALLOCATE_A(rhs)

    case (DER_VARIATIONAL)
      ! we have the explicit coefficients
      call stencil_variational_coeff_lapl(der%dim, der%order, mesh%spacing, der%lapl, alpha = der%lapl_cutoff)
    end select


    ! Here the Laplacian is forced to be self-adjoint, and the gradient to be skew-self-adjoint
    if (mesh%use_curvilinear) then
      ! The nl_operator_copy routine has an assert
      if (accel_is_enabled()) then
        call messages_write('Curvilinear coordinates on GPUs is not implemented')
        call messages_fatal(namespace=namespace)
      end if

      ! each operator is replaced by the result of nl_operator_adjoint, via a temporary
      do i = 1, der%dim
        call nl_operator_init(auxop, "auxop")
        call nl_operator_adjoint(der%grad(i), auxop, der%mesh, self_adjoint=.false.)

        call nl_operator_end(der%grad(i))
        call nl_operator_copy(der%grad(i), auxop)
        call nl_operator_end(auxop)
      end do
      call nl_operator_init(auxop, "auxop")
      call nl_operator_adjoint(der%lapl, auxop, der%mesh, self_adjoint=.true.)

      call nl_operator_end(der%lapl)
      call nl_operator_copy(der%lapl, auxop)
      call nl_operator_end(auxop)
    end if

    POP_SUB(derivatives_build)

  contains
    !> Polynomials for the gradient along one direction, for the star-like stencils.
    !! The selection is made on the stencil_type argument, not on the host variable.
    subroutine stencil_stars_pol_grad(stencil_type, dim, direction, order, polynomials)
      integer, intent(in)    :: stencil_type      !< DER_STAR, DER_STARGENERAL or DER_STARPLUS
      integer, intent(in)    :: dim               !< number of space dimensions
      integer, intent(in)    :: direction         !< direction of the gradient component
      integer, intent(in)    :: order             !< discretization order
      integer, intent(inout) :: polynomials(:, :) !< powers of each coordinate for each polynomial

      select case (stencil_type)
      case (DER_STAR, DER_STARGENERAL)
        call stencil_star_polynomials_grad(direction, order, polynomials)
      case (DER_STARPLUS)
        call stencil_starplus_pol_grad(dim, direction, order, polynomials)
      end select
    end subroutine stencil_stars_pol_grad

    !> Polynomials for the Laplacian, for the star-like stencils.
    !! The selection is made on the stencil_type argument, not on the host variable.
    subroutine stencil_stars_pol_lapl(stencil_type, stencil, dim, order, polynomials)
      integer,         intent(in)    :: stencil_type      !< DER_STAR, DER_STARPLUS or DER_STARGENERAL
      type(stencil_t), intent(in)    :: stencil           !< Laplacian stencil, used by the general star
      integer,         intent(in)    :: dim               !< number of space dimensions
      integer,         intent(in)    :: order             !< discretization order
      integer,         intent(inout) :: polynomials(:, :) !< powers of each coordinate for each polynomial

      select case (stencil_type)
      case (DER_STAR)
        call stencil_star_polynomials_lapl(dim, order, polynomials)
      case (DER_STARPLUS)
        call stencil_starplus_pol_lapl(dim, order, polynomials)
      case (DER_STARGENERAL)
        call stencil_stargeneral_pol_lapl(stencil, dim, order, polynomials)
      end select
    end subroutine stencil_stars_pol_lapl

  end subroutine derivatives_build

  ! ---------------------------------------------------------
  !> @brief Right-hand side of the weight equations for the Laplacian.
  !!
  !! For each polynomial (a column of powers) the entry is
  !! - twice the diagonal element of the metric if exactly one power is 2 and all others are 0,
  !! - the sum of the two off-diagonal elements if exactly two powers are 1 and all others are 0,
  !! - zero otherwise.
  !! The metric is the identity, except for affine coordinates, where it is the
  !! change-of-basis matrix times its transpose.
  subroutine get_rhs_lapl(der, polynomials, rhs)
    type(derivatives_t), intent(in)  :: der              !< provides dim and the coordinate system of the mesh
    integer,             intent(in)  :: polynomials(:,:) !< polynomials(i, j): power of coordinate i in polynomial j
    real(real64),        intent(out) :: rhs(:)           !< one entry per polynomial

    integer :: idir, jdir, ipol
    integer :: n_zero, n_linear, n_quadratic
    real(real64) :: metric(1:der%dim, 1:der%dim)

    PUSH_SUB(get_rhs_lapl)

    ! assume orthogonal basis as default (i.e. for curvilinear coordinates)
    metric(:, :) = M_ZERO
    do idir = 1, der%dim
      metric(idir, idir) = M_ONE
    end do

    select type (coord => der%mesh%coord_system)
    class is (affine_coordinates_t)
      metric(1:der%dim, 1:der%dim) = matmul(coord%basis%change_of_basis_matrix(1:der%dim, 1:der%dim), &
        transpose(coord%basis%change_of_basis_matrix(1:der%dim, 1:der%dim)))
    class is (curv_gygi_t)
    class is (curv_briggs_t)
    class is (curv_modine_t)
    class default
      message(1) = "Weight computation not implemented for the coordinate system chosen."
      call messages_fatal(1)
    end select

    rhs(:) = M_ZERO
    do ipol = 1, size(polynomials, dim=2)
      ! how many coordinates enter with power 0, 1 and 2
      n_zero      = count(polynomials(1:der%dim, ipol) == 0)
      n_linear    = count(polynomials(1:der%dim, ipol) == 1)
      n_quadratic = count(polynomials(1:der%dim, ipol) == 2)

      ! a single quadratic term: the Laplacian of the polynomial is twice the diagonal element
      if (n_quadratic == 1 .and. n_zero == der%dim - 1) then
        do idir = 1, der%dim
          if (polynomials(idir, ipol) == 2) rhs(ipol) = M_TWO*metric(idir, idir)
        end do
      end if

      ! a product of two linear terms: the Laplacian is the sum of the off-diagonal elements
      if (n_linear == 2 .and. n_zero == der%dim - 2) then
        do idir = 1, der%dim
          if (polynomials(idir, ipol) /= 1) cycle
          do jdir = idir + 1, der%dim
            if (polynomials(jdir, ipol) == 1) rhs(ipol) = metric(idir, jdir) + metric(jdir, idir)
          end do
        end do
      end if
    end do

    POP_SUB(get_rhs_lapl)
  end subroutine get_rhs_lapl

  ! ---------------------------------------------------------
  !> @brief Right-hand side of the weight equations for one gradient component.
  !!
  !! The entry is one for the polynomial that is linear in the direction dir and
  !! constant in every other direction, and zero for all other polynomials.
  subroutine get_rhs_grad(der, polynomials, dir, rhs)
    type(derivatives_t), intent(in)  :: der              !< provides dim and the size of the gradient stencil
    integer,             intent(in)  :: polynomials(:,:) !< polynomials(k, j): power of coordinate k in polynomial j
    integer,             intent(in)  :: dir              !< direction of the gradient component
    real(real64),        intent(out) :: rhs(:)           !< one entry per polynomial

    integer :: ipol

    PUSH_SUB(get_rhs_grad)

    rhs(:) = M_ZERO
    do ipol = 1, der%grad(dir)%stencil%size
      ! the power along dir has to be exactly one ...
      if (polynomials(dir, ipol) /= 1) cycle
      ! ... and it has to be the only non-vanishing power
      if (count(polynomials(1:der%dim, ipol) /= 0) /= 1) cycle
      rhs(ipol) = M_ONE
    end do

    POP_SUB(get_rhs_grad)
  end subroutine get_rhs_grad


  ! ---------------------------------------------------------
  !> @brief Compute the finite-difference weights of one or several operators that share a stencil.
  !!
  !! For every mesh point (or only once if the weights are constant) the matrix
  !! of the polynomials evaluated on the stencil points is built and the linear
  !! problem mat * sol = rhs is solved. Column i of the solution is stored as the
  !! weights of op(i). The stencil, the point count and the constant-weight flag
  !! are all taken from op(1).
  subroutine derivatives_make_discretization(namespace, dim, periodic_dim, mesh, masses, pol, rhs, nderiv, op, name, &
    verbose)
    type(namespace_t),      intent(in)    :: namespace    !< namespace used for the info message
    integer,                intent(in)    :: dim          !< number of space dimensions
    integer,                intent(in)    :: periodic_dim !< not referenced in this routine
    type(mesh_t),           intent(in)    :: mesh         !< provides spacing, coordinates and coordinate system
    real(real64),           intent(in)    :: masses(:)    !< the displacements are scaled by sqrt(masses)
    integer,                intent(in)    :: pol(:,:)     !< pol(k, j): power of coordinate k in polynomial j
    integer,                intent(in)    :: nderiv       !< number of operators, i.e. columns of rhs
    real(real64), contiguous,      intent(inout) :: rhs(:,:) !< rhs(j, i): right-hand side of operator i for polynomial j
    type(nl_operator_t),    intent(inout) :: op(:)        !< operators that receive the weights
    character(len=32),      intent(in)    :: name         !< label used in the info message
    logical, optional,      intent(in)    :: verbose      !< print the info message (default .true.)

    integer :: p, p_max, i, j, k, pow_max
    real(real64)   :: x(dim)
    real(real64), allocatable :: mat(:,:), sol(:,:), powers(:,:)
    logical :: transform_to_cartesian

    PUSH_SUB(derivatives_make_discretization)

    SAFE_ALLOCATE(mat(1:op(1)%stencil%size, 1:op(1)%stencil%size))
    SAFE_ALLOCATE(sol(1:op(1)%stencil%size, 1:nderiv))

    if (optional_default(verbose, .true.)) then
      message(1) = 'Info: Generating weights for finite-difference discretization of ' // trim(name)
      call messages_info(1, namespace=namespace)
    end if

    select type (coord => mesh%coord_system)
    class is (affine_coordinates_t)
      transform_to_cartesian = .false.
    class is (curv_gygi_t)
      transform_to_cartesian = .true.
    class is (curv_briggs_t)
      transform_to_cartesian = .true.
    class is (curv_modine_t)
      transform_to_cartesian = .true.
    class default
      message(1) = "Weight computation not implemented for the coordinate system chosen."
      call messages_fatal(1)
    end select

    ! use to generate power lookup table
    ! The first power is always written below, so the table needs at least the
    ! columns 0 and 1 even if no polynomial contains a positive power.
    pow_max = max(1, maxval(pol))
    SAFE_ALLOCATE(powers(1:dim, 0:pow_max))
    powers(:,:) = M_ZERO
    powers(:,0) = M_ONE

    ! constant weights only need to be computed once
    p_max = op(1)%np
    if (op(1)%const_w) p_max = 1

    do p = 1, p_max
      ! first polynomial is just a constant
      mat(1,:) = M_ONE
      ! i indexes the point in the stencil
      do i = 1, op(1)%stencil%size
        if (mesh%use_curvilinear) then
          ! displacement of the stencil point with respect to the point p
          x = mesh%x(p + op(1)%ri(i, op(1)%rimap(p)), :) - mesh%x(p, :)
        else
          x = real(op(1)%stencil%points(:, i), real64) *mesh%spacing
          ! transform to Cartesian coordinates only for curvilinear meshes
          if (transform_to_cartesian) then
            x = mesh%coord_system%to_cartesian(x)
          end if
        end if

        ! NB: these masses are applied on the cartesian directions. Should add a check for non-orthogonal axes
        x = x*sqrt(masses)

        ! calculate powers
        powers(:, 1) = x
        do k = 2, pow_max
          powers(:, k) = x*powers(:, k-1)
        end do

        ! generate the matrix
        ! j indexes the polynomial being used
        do j = 2, op(1)%stencil%size
          mat(j, i) = powers(1, pol(1, j))
          do k = 2, dim
            mat(j, i) = mat(j, i)*powers(k, pol(k, j))
          end do
        end do
      end do ! loop over i = point in stencil

      ! linear problem to solve for derivative weights:
      !   mat * sol = rhs
      call lalg_linsyssolve(op(1)%stencil%size, nderiv, mat, rhs, sol)

      ! for the cube stencil, all derivatives are calculated at once, so assign
      ! the correct solution to each operator
      do i = 1, nderiv
        op(i)%w(:, p) = sol(:, i)
      end do

    end do ! loop over points p

    do i = 1, nderiv
      call nl_operator_output_weights(op(i))
    end do

    ! In case of constant weights, we store the weights of the Laplacian on the GPU, as this
    ! saves many unnecessary transfers
    do i = 1, nderiv
      call nl_operator_allocate_gpu_buffers(op(i))
      call nl_operator_update_gpu_buffers(op(i))
    end do


    SAFE_DEALLOCATE_A(mat)
    SAFE_DEALLOCATE_A(sol)
    SAFE_DEALLOCATE_A(powers)

    POP_SUB(derivatives_make_discretization)
  end subroutine derivatives_make_discretization

#ifdef HAVE_MPI
  ! ---------------------------------------------------------
  !> @brief Tell whether the communication method is anything other than blocking.
  logical function derivatives_overlap(this) result(overlap)
    type(derivatives_t), intent(in) :: this !< derivatives object whose communication method is queried

    PUSH_SUB(derivatives_overlap)

    overlap = .not. (this%comm_method == BLOCKING)

    POP_SUB(derivatives_overlap)
  end function derivatives_overlap
#endif

  ! ---------------------------------------------------------
  !> @brief Set up a Laplacian operator of the requested order in op(1).
  !!
  !! The stencil is a plain star for orthogonal coordinate systems and a
  !! general star (with extra arms taken from the coordinate system) otherwise.
  !! The operator is built on this%mesh with constant weights unless the mesh
  !! is curvilinear, and its weights are then obtained from
  !! derivatives_make_discretization. Only the first element of op is touched.
  !!
  !! @note At the moment this code is almost copy-pasted from derivatives_build.
  subroutine derivatives_get_lapl(this, namespace, op, space, name, order)
    type(derivatives_t),         intent(in)    :: this
    type(namespace_t),           intent(in)    :: namespace
    type(nl_operator_t),         intent(inout) :: op(:)   !< only op(1) is initialized and built
    class(space_t),              intent(in)    :: space
    character(len=32),           intent(in)    :: name    !< label given to the operator
    integer,                     intent(in)    :: order   !< order passed on to the stencil routines

    logical :: orthogonal_axes
    integer, allocatable :: pol(:,:)
    real(real64), allocatable :: rhs(:,:)

    PUSH_SUB(derivatives_get_lapl)

    orthogonal_axes = this%mesh%coord_system%orthogonal

    ! choose the stencil that matches the coordinate system
    call nl_operator_init(op(1), name)
    if (orthogonal_axes) then
      call stencil_star_get_lapl(op(1)%stencil, this%dim, order)
    else
      call stencil_stargeneral_get_arms(op(1)%stencil, this%dim, this%mesh%coord_system)
      call stencil_stargeneral_get_lapl(op(1)%stencil, this%dim, order)
    end if

    ! weights are constant over the mesh only when it is not curvilinear
    call nl_operator_build(space, this%mesh, op(1), this%mesh%np, const_w = .not. this%mesh%use_curvilinear)

    ! one polynomial (exponents per dimension) and one rhs entry per stencil point
    SAFE_ALLOCATE(pol(1:this%dim, 1:op(1)%stencil%size))
    SAFE_ALLOCATE(rhs(1:op(1)%stencil%size, 1))

    if (orthogonal_axes) then
      call stencil_star_polynomials_lapl(this%dim, order, pol)
    else
      call stencil_stargeneral_pol_lapl(op(1)%stencil, this%dim, order, pol)
    end if
    call get_rhs_lapl(this, pol, rhs(:, 1))

    ! a single operator (the Laplacian) is discretized here
    call derivatives_make_discretization(namespace, this%dim, this%periodic_dim, this%mesh, this%masses, &
      pol, rhs, 1, op(1:1), name)

    SAFE_DEALLOCATE_A(pol)
    SAFE_DEALLOCATE_A(rhs)

    POP_SUB(derivatives_get_lapl)
  end subroutine derivatives_get_lapl

  ! ---------------------------------------------------------
  !> @brief Flag the grid points that lie in a layer, as wide as the stencil,
  !! next to the last row of points in the grid.
  !!
  !! E.g. if stencil = 2, then: </br>
  !! 1 1 1 1 1 1 1 </br>
  !! 1 1 1 1 1 1 1 </br>
  !! 1 1 0 0 0 1 1 </br>
  !! 1 1 0 0 0 1 1 </br>
  !! 1 1 0 0 0 1 1 </br>
  !! 1 1 1 1 1 1 1 </br>
  !! 1 1 1 1 1 1 1 </br>
  !! The innermost points of the grid are left unmasked, while the points between
  !! the innermost region and the boundary are masked.
  !!
  !! A point is masked as soon as one point of the Laplacian stencil centred on it
  !! has an index larger than mesh%np + mesh%pv%np_ghost.
  function derivatives_get_inner_boundary_mask(this) result(mask)
    type(derivatives_t), intent(in)  :: this

    logical :: mask(1:this%mesh%np) !< .true. for the points whose stencil reaches outside the grid
    integer :: ipoint, istencil, neighbour, last_inside

    ! Indices above this limit are treated as lying outside the grid
    last_inside = this%mesh%np + this%mesh%pv%np_ghost

    ! Visit every point in the grid
    grid_points: do ipoint = 1, this%mesh%np
      mask(ipoint) = .false.
      ! Visit every point of the stencil centred on the current grid point
      stencil_points: do istencil = 1, this%lapl%stencil%size
        ! Index of the point obtained as: grid_point + displacement_due_to_stencil
        neighbour = nl_operator_get_index(this%lapl, istencil, ipoint)
        if (neighbour <= last_inside) cycle stencil_points
        ! The displaced point is outside the grid: one hit is enough to mask the point
        mask(ipoint) = .true.
        exit stencil_points
      end do stencil_points
    end do grid_points

  end function derivatives_get_inner_boundary_mask

  ! ---------------------------------------------------------
  !> Estimate the maximum eigenvalue of the discrete Laplacian.
  !!
  !! For the star and star_general stencils on non-curvilinear meshes, the Fourier
  !! transform of the stencil is evaluated at the maximum phase, which gives an
  !! upper bound on the spectrum: every weight enters with the sign (-1)**n, where
  !! n is the largest absolute offset of the corresponding stencil point.
  !! In all other cases the upper bound from the continuum is used, i.e. the sum
  !! over the dimensions of pi**2/spacing**2.
  real(real64) function derivatives_lapl_get_max_eigenvalue(this) result(max_eigenvalue)
    type(derivatives_t), intent(in) :: this

    integer :: ipt, idir, reach
    logical :: use_stencil_transform

    PUSH_SUB(derivatives_lapl_get_max_eigenvalue)

    use_stencil_transform = (this%stencil_type == DER_STAR .or. this%stencil_type == DER_STARGENERAL) &
      .and. .not. this%mesh%use_curvilinear

    max_eigenvalue = M_ZERO

    if (use_stencil_transform) then
      ! Fourier transform of the stencil at the maximum phase: alternating-sign sum of
      ! the weights stored for the first point
      do ipt = 1, this%lapl%stencil%size
        reach = maxval(abs(this%lapl%stencil%points(:, ipt)))
        if (mod(reach, 2) == 0) then
          max_eigenvalue = max_eigenvalue + this%lapl%w(ipt, 1)
        else
          max_eigenvalue = max_eigenvalue - this%lapl%w(ipt, 1)
        end if
      end do
      max_eigenvalue = abs(max_eigenvalue)
    else
      ! upper bound from the continuum for the remaining cases
      do idir = 1, this%dim
        max_eigenvalue = max_eigenvalue + M_PI**2/this%mesh%spacing(idir)**2
      end do
    end if

    POP_SUB(derivatives_lapl_get_max_eigenvalue)
  end function derivatives_lapl_get_max_eigenvalue


#include "undef.F90"
#include "real.F90"
#include "derivatives_inc.F90"

#include "undef.F90"
#include "complex.F90"
#include "derivatives_inc.F90"

end module derivatives_oct_m

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
