!! Copyright (C) 2005-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch, X. Andrade
!! Copyright (C) 2024 N. Tancogne-Dejean
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

!>@brief This module provides the routines for solving Ax=b using multigrid methods (V-cycle, W-cycle and full multigrid)
module multigrid_solver_oct_m
  use boundaries_oct_m
  use debug_oct_m
  use derivatives_oct_m
  use global_oct_m
  use, intrinsic :: iso_fortran_env
  use lalg_basic_oct_m
  use mesh_oct_m
  use mesh_function_oct_m
  use messages_oct_m
  use multigrid_oct_m
  use namespace_oct_m
  use nl_operator_oct_m
  use operate_f_oct_m
  use parser_oct_m
  use par_vec_oct_m
  use profiling_oct_m
  use space_oct_m
  use varinfo_oct_m

  implicit none

  !> Relaxation (smoothing) schemes selectable through MultigridRelaxationMethod.
  !! The values match the option values documented for that input variable.
  integer, parameter ::       &
    GAUSS_SEIDEL        = 1,  &
    WEIGHTED_JACOBI     = 2

  ! Everything is private unless listed below
  private
  public ::                     &
    multigrid_solver_V_cycle,   &
    multigrid_solver_W_cycle,   &
    multigrid_iterative_solver, &
    multigrid_FMG_solver,       &
    multigrid_solver_init,      &
    mg_solver_t

  !> Parameters of the multigrid solver, filled by multigrid_solver_init
  type mg_solver_t
    private

    real(real64), public :: threshold !< Residual norms below this value count as converged
    real(real64) :: relax_factor      !< Weight of the weighted-Jacobi relaxation; only parsed for that method

    integer, public :: maxcycles = 0  !< Upper bound for the cycle loop and for the coarsest-grid relaxation loop
    integer :: presteps               !< Relaxation steps before the coarse-grid correction
    integer :: poststeps              !< Relaxation steps after the coarse-grid correction
    integer :: restriction_method     !< Passed to dmultigrid_fine2coarse (MultigridRestrictionMethod)
    integer :: relaxation_method      !< GAUSS_SEIDEL or WEIGHTED_JACOBI
  end type mg_solver_t

  !> Shapes of a multigrid cycle. multigrid_iterative_solver accepts MG_V_SHAPE and
  !! MG_W_SHAPE; MG_FMG is not referenced inside this module (presumably used by
  !! callers to select multigrid_FMG_solver - to be confirmed).
  integer, parameter, public :: &
    MG_V_SHAPE = 1,             &
    MG_W_SHAPE = 2,             &
    MG_FMG     = 3

contains

  ! ---------------------------------------------------------
  !>@brief Reads the multigrid input variables and fills the solver object
  !!
  !! The convergence threshold is taken from thr; all other parameters are parsed
  !! from the input file. The default relaxation method depends on whether the mesh
  !! uses curvilinear coordinates. The argument space is not used by this routine.
  subroutine multigrid_solver_init(this, namespace, space, mesh, thr)
    type(mg_solver_t), intent(out)   :: this
    type(namespace_t), intent(in)    :: namespace
    class(space_t),    intent(in)    :: space
    type(mesh_t),      intent(inout) :: mesh
    real(real64),      intent(in)    :: thr

    integer :: default_relaxation

    PUSH_SUB(multigrid_solver_init)

    this%threshold = thr

    !%Variable MultigridPresmoothingSteps
    !%Type integer
    !%Default 1
    !%Section Hamiltonian::Poisson::Multigrid
    !%Description
    !% Number of smoothing steps (performed with the method selected by
    !% <tt>MultigridRelaxationMethod</tt>) before coarse-level
    !% correction in the multigrid solver.
    !%End
    call parse_variable(namespace, 'MultigridPresmoothingSteps', 1, this%presteps)

    !%Variable MultigridPostsmoothingSteps
    !%Type integer
    !%Default 4
    !%Section Hamiltonian::Poisson::Multigrid
    !%Description
    !% Number of smoothing steps (performed with the method selected by
    !% <tt>MultigridRelaxationMethod</tt>) after coarse-level
    !% correction in the multigrid solver.
    !%End
    call parse_variable(namespace, 'MultigridPostsmoothingSteps', 4, this%poststeps)

    !%Variable MultigridMaxCycles
    !%Type integer
    !%Default 50
    !%Section Hamiltonian::Poisson::Multigrid
    !%Description
    !% Maximum number of multigrid cycles that are performed if
    !% convergence is not achieved.
    !%End
    call parse_variable(namespace, 'MultigridMaxCycles', 50, this%maxcycles)

    !%Variable MultigridRestrictionMethod
    !%Type integer
    !%Default fullweight
    !%Section Hamiltonian::Poisson::Multigrid
    !%Description
    !% Method used for fine-to-coarse grid transfer.
    !%Option injection 1
    !% Injection
    !%Option fullweight 2
    !% Fullweight restriction
    !%End
    call parse_variable(namespace, 'MultigridRestrictionMethod', 2, this%restriction_method)
    if (.not. varinfo_valid_option('MultigridRestrictionMethod', this%restriction_method)) then
      call messages_input_error(namespace, 'MultigridRestrictionMethod')
    end if
    call messages_print_var_option("MultigridRestrictionMethod", this%restriction_method, namespace=namespace)

    !%Variable MultigridRelaxationMethod
    !%Type integer
    !%Section Hamiltonian::Poisson::Multigrid
    !%Description
    !% Method used to solve the linear system approximately in each grid for the
    !% multigrid procedure that solves a linear equation like the Poisson equation. Default is <tt>gauss_seidel</tt>,
    !% unless curvilinear coordinates are used, in which case the default is <tt>weighted_jacobi</tt>.
    !%Option gauss_seidel 1
    !% Gauss-Seidel.
    !%Option weighted_jacobi 2
    !% Jacobi relaxation with a weight. The weight is determined by MultigridRelaxationFactor.
    !%End

    ! The default smoother depends on the type of coordinates of the mesh
    default_relaxation = GAUSS_SEIDEL
    if (mesh%use_curvilinear) default_relaxation = WEIGHTED_JACOBI
    call parse_variable(namespace, 'MultigridRelaxationMethod', default_relaxation, this%relaxation_method)

    if (.not. varinfo_valid_option('MultigridRelaxationMethod', this%relaxation_method)) then
      call messages_input_error(namespace, 'MultigridRelaxationMethod')
    end if
    call messages_print_var_option("MultigridRelaxationMethod", this%relaxation_method, namespace=namespace)

    ! The relaxation factor is only read (and only defined) for the weighted Jacobi method
    if (this%relaxation_method == WEIGHTED_JACOBI) then
      !%Variable MultigridRelaxationFactor
      !%Type float
      !%Section Hamiltonian::Poisson::Multigrid
      !%Description
      !% Relaxation factor of the relaxation operator used for the
      !% multigrid method. Only used for the <tt>weighted_jacobi</tt> method.
      !% The default is 0.6666 for the <tt>weighted_jacobi</tt> method.
      !%End
      call parse_variable(namespace, 'MultigridRelaxationFactor', 0.6666_real64, this%relax_factor)
    end if

    POP_SUB(multigrid_solver_init)
  end subroutine multigrid_solver_init


  ! ---------------------------------------------------------
  !>@brief Performs one cycle of a V-shaped multigrid solver
  !!
  !! On a level that has a coarser level attached, one cycle consists of pre-smoothing,
  !! restriction of the residual, a recursive V-cycle on the coarser level (started from
  !! a zero correction), prolongation of the coarse correction, its addition to sol and
  !! post-smoothing. On the coarsest level multigrid_solver_solve_coarsest is called instead.
  recursive subroutine multigrid_solver_V_cycle(this, der, op, sol, rhs)
    type(mg_solver_t),           intent(in)    :: this
    type(derivatives_t),         intent(in)    :: der
    type(nl_operator_t),        intent(in)    :: op     !< Linear operator
    real(real64), contiguous,   intent(inout) :: sol(:) !< Solution to the problem
    real(real64), contiguous,   intent(in)    :: rhs(:) !< Right-hand side of the linear problem

    real(real64), allocatable :: res(:), res_coarse(:), corr(:), corr_coarse(:)

    PUSH_SUB(multigrid_solver_V_cycle)

    ! Residual on the current level; on the coarsest level it is the work array of the solver
    SAFE_ALLOCATE(res(1:der%mesh%np_part))

    if (.not. associated(der%coarser)) then

      ! Coarsest grid: end of the recursion
      call multigrid_solver_solve_coarsest(this, der, op, sol, rhs, res)

    else

      ASSERT(associated(op%coarser))

      SAFE_ALLOCATE(corr(1:der%mesh%np_part))
      SAFE_ALLOCATE(res_coarse(1:der%coarser%mesh%np_part))
      SAFE_ALLOCATE(corr_coarse(1:der%coarser%mesh%np_part))

      ! Smooth the current approximation before going down
      call multigrid_relax(this, der%mesh, der, op, sol, rhs, this%presteps)

      ! Residual of the smoothed approximation
      call get_residual(op, der, sol, rhs, res)

      ! The restricted residual is the right-hand side on the coarser level
      message(1) = "Debug: Multigrid restriction"
      call messages_info(1, debug_only=.true.)

      call dmultigrid_fine2coarse(der%to_coarser, der, der%coarser%mesh, res, res_coarse, this%restriction_method)

      ! Solve for the correction on the coarser level, starting from zero
      corr_coarse = M_ZERO
      call multigrid_solver_V_cycle(this, der%coarser, op%coarser, corr_coarse, res_coarse)

      ! Bring the coarse correction back to the current level
      message(1) = "Debug: Multigrid prolongation"
      call messages_info(1, debug_only=.true.)

      corr = M_ZERO
      call dmultigrid_coarse2fine(der%to_coarser, der%coarser, der%mesh, corr_coarse, corr)

      ! sol <- sol + correction
      call lalg_axpy(der%mesh%np, M_ONE, corr, sol)

      SAFE_DEALLOCATE_A(corr)
      SAFE_DEALLOCATE_A(res_coarse)
      SAFE_DEALLOCATE_A(corr_coarse)

      ! Smooth again after the correction
      call multigrid_relax(this, der%mesh, der, op, sol, rhs, this%poststeps)

    end if

    SAFE_DEALLOCATE_A(res)

    POP_SUB(multigrid_solver_V_cycle)
  end subroutine multigrid_solver_V_cycle

  ! ---------------------------------------------------------
  !>@brief Performs one cycle of a W-shaped multigrid solver
  !!
  !! Same building blocks as the V-cycle, but on every level that has a coarser level
  !! the sequence "smoothing (presteps) - restriction - recursive call - prolongation -
  !! correction" is carried out twice before the final post-smoothing. On the coarsest
  !! level multigrid_solver_solve_coarsest is called instead.
  recursive subroutine multigrid_solver_W_cycle(this, der, op, sol, rhs)
    type(mg_solver_t),           intent(in)    :: this
    type(derivatives_t),         intent(in)    :: der
    type(nl_operator_t),        intent(in)    :: op     !< Linear operator
    real(real64), contiguous,   intent(inout) :: sol(:) !< Solution to the problem
    real(real64), contiguous,   intent(in)    :: rhs(:) !< Right-hand side of the linear problem

    integer :: ipass
    real(real64), allocatable :: res(:), res_coarse(:), corr(:), corr_coarse(:)

    PUSH_SUB(multigrid_solver_W_cycle)

    ! Residual on the current level; on the coarsest level it is the work array of the solver
    SAFE_ALLOCATE(res(1:der%mesh%np_part))

    if (.not. associated(der%coarser)) then

      ! Coarsest grid: end of the recursion
      call multigrid_solver_solve_coarsest(this, der, op, sol, rhs, res)

    else

      ASSERT(associated(op%coarser))

      SAFE_ALLOCATE(corr(1:der%mesh%np_part))
      SAFE_ALLOCATE(res_coarse(1:der%coarser%mesh%np_part))
      SAFE_ALLOCATE(corr_coarse(1:der%coarser%mesh%np_part))

      ! Two identical coarse-grid correction passes give the W shape
      do ipass = 1, 2

        ! Smoothing (pre-smoothing in the first pass, re-smoothing in the second)
        call multigrid_relax(this, der%mesh, der, op, sol, rhs, this%presteps)

        ! Residual of the smoothed approximation
        call get_residual(op, der, sol, rhs, res)

        ! The restricted residual is the right-hand side on the coarser level
        message(1) = "Debug: Multigrid restriction"
        call messages_info(1, debug_only=.true.)

        call dmultigrid_fine2coarse(der%to_coarser, der, der%coarser%mesh, res, res_coarse, this%restriction_method)

        ! Solve for the correction on the coarser level, starting from zero
        corr_coarse = M_ZERO
        call multigrid_solver_W_cycle(this, der%coarser, op%coarser, corr_coarse, res_coarse)

        ! Bring the coarse correction back to the current level
        message(1) = "Debug: Multigrid prolongation"
        call messages_info(1, debug_only=.true.)

        corr = M_ZERO
        call dmultigrid_coarse2fine(der%to_coarser, der%coarser, der%mesh, corr_coarse, corr)

        ! sol <- sol + correction
        call lalg_axpy(der%mesh%np, M_ONE, corr, sol)

      end do

      SAFE_DEALLOCATE_A(corr)
      SAFE_DEALLOCATE_A(res_coarse)
      SAFE_DEALLOCATE_A(corr_coarse)

      ! Smooth again after the second correction
      call multigrid_relax(this, der%mesh, der, op, sol, rhs, this%poststeps)

    end if

    SAFE_DEALLOCATE_A(res)

    POP_SUB(multigrid_solver_W_cycle)

  end subroutine multigrid_solver_W_cycle

  ! ---------------------------------------------------------
  !>@brief An iterative multigrid solver.
  !!
  !! It performs multiple V- or W-cycles up to convergence, i.e. until the norm of the
  !! residual is below this%threshold, or until this%maxcycles cycles have been done.
  !! A warning is issued if the threshold has not been reached at the end.
  subroutine multigrid_iterative_solver(this, namespace, der, op, sol, rhs, multigrid_shape)
    type(mg_solver_t),          intent(in)    :: this
    type(namespace_t),          intent(in)    :: namespace
    type(derivatives_t),        intent(in)    :: der
    type(nl_operator_t),        intent(in)    :: op     !< Linear operator
    real(real64), contiguous,   intent(inout) :: sol(:) !< Solution to the problem
    real(real64), contiguous,   intent(inout) :: rhs(:) !< Right-hand side of the linear problem
    integer,                    intent(in)    :: multigrid_shape !< The shape of each cycle

    integer :: iter
    real(real64) :: resnorm
    real(real64), allocatable :: err(:)

    PUSH_SUB(multigrid_iterative_solver)

    SAFE_ALLOCATE(err(1:der%mesh%np))

    ! resnorm is tested after the loop. Give it a defined, non-converged value so
    ! that the test is well defined even if the loop body never runs (maxcycles < 1).
    resnorm = huge(resnorm)

    do iter = 1, this%maxcycles

      select case (multigrid_shape)
      case(MG_V_SHAPE)
        call multigrid_solver_V_cycle(this, der, op, sol, rhs)
      case(MG_W_SHAPE)
        call multigrid_solver_W_cycle(this, der, op, sol, rhs)
      case default
        ! Only V- and W-shaped cycles can be iterated here
        ASSERT(.false.)
      end select

      ! Compute the residual
      ! NOTE(review): the check below applies dderivatives_lapl, whereas the cycles
      ! relax with op. Both agree only if op is the Laplacian of der - to be confirmed.
      call dderivatives_lapl(der, sol, err)
      call lalg_axpy(der%mesh%np, -M_ONE, rhs, err)
      resnorm =  dmf_nrm2(der%mesh, err)

      ! On exit, iter keeps the number of cycles that were needed
      if (resnorm < this%threshold) exit

      write(message(1), '(a,i5,a,e13.6)') "Multigrid: base level: iter ", iter, " res ", resnorm
      call messages_info(1, namespace=namespace, debug_only=.true.)

    end do

    if (resnorm >= this%threshold) then
      message(1) = 'Multigrid Poisson solver did not converge.'
      write(message(2), '(a,e14.6)') '  Abs. norm of the residue = ', resnorm
      call messages_warning(2, namespace=namespace)
    else
      write(message(1), '(a,i4,a)') "Multigrid Poisson solver converged in ", iter, " iterations."
      write(message(2), '(a,e14.6)') '  Abs. norm of the residue = ', resnorm
      call messages_info(2, namespace=namespace, debug_only=.true.)
    end if

    SAFE_DEALLOCATE_A(err)

    POP_SUB(multigrid_iterative_solver)
  end subroutine multigrid_iterative_solver

  ! ---------------------------------------------------------
  !>@brief Full multigrid (FMG) solver
  !!
  !! The right-hand side is restricted recursively down to the coarsest level, where the
  !! problem is solved by relaxation. On the way up, the solution of each coarser level
  !! is prolongated and used as the starting point of the iterative V-cycle solver of the
  !! next finer level. On levels that have a coarser level, the incoming content of sol
  !! is overwritten, so no starting point is needed there.
  recursive subroutine multigrid_FMG_solver(this, namespace, der, op, sol, rhs)
    type(mg_solver_t),          intent(in)    :: this
    type(namespace_t),          intent(in)    :: namespace
    type(derivatives_t),        intent(in)    :: der
    type(nl_operator_t),        intent(in)    :: op     !< Linear operator
    real(real64), contiguous,   intent(inout) :: sol(:) !< Solution to the problem
    real(real64), contiguous,   intent(inout) :: rhs(:) !< Right-hand side of the linear problem

    real(real64), allocatable :: sol_coarse(:), rhs_coarse(:), work(:)

    PUSH_SUB(multigrid_FMG_solver)

    if (.not. associated(der%coarser)) then

      ! Coarsest grid - solve the problem, work receives the residual
      SAFE_ALLOCATE(work(1:der%mesh%np))
      call multigrid_solver_solve_coarsest(this, der, op, sol, rhs, work)
      SAFE_DEALLOCATE_A(work)

    else

      ASSERT(associated(op%coarser))

      SAFE_ALLOCATE(rhs_coarse(1:der%coarser%mesh%np_part))
      SAFE_ALLOCATE(sol_coarse(1:der%coarser%mesh%np_part))

      ! The restricted right-hand side defines the problem on the coarser level
      message(1) = "Debug: Full Multigrid restriction"
      call messages_info(1, debug_only=.true.)

      call dmultigrid_fine2coarse(der%to_coarser, der, der%coarser%mesh, rhs, rhs_coarse, this%restriction_method)

      ! Solve the coarser problem first, starting from zero
      sol_coarse = M_ZERO
      call multigrid_FMG_solver(this, namespace, der%coarser, op%coarser, sol_coarse, rhs_coarse)

      ! The prolongated coarse solution is the initial guess on this level
      message(1) = "Debug: Full Multigrid prolongation"
      call messages_info(1, debug_only=.true.)

      sol = M_ZERO
      call dmultigrid_coarse2fine(der%to_coarser, der%coarser, der%mesh, sol_coarse, sol)

      ! V cycles on this level, up to convergence or this%maxcycles
      call multigrid_iterative_solver(this, namespace, der, op, sol, rhs, MG_V_SHAPE)

      SAFE_DEALLOCATE_A(sol_coarse)
      SAFE_DEALLOCATE_A(rhs_coarse)

    end if

    POP_SUB(multigrid_FMG_solver)
  end subroutine multigrid_FMG_solver


  ! ---------------------------------------------------------
  !>@brief Computes the solution on the coarsest grid
  !!
  !! Single relaxation steps are repeated until the norm of the residual drops below
  !! this%threshold or until this%maxcycles steps have been done. The outcome is only
  !! reported through a debug message; hitting the step limit is not an error.
  subroutine multigrid_solver_solve_coarsest(this, der, op, sol, rhs, residue)
    type(mg_solver_t),          intent(in)    :: this
    type(derivatives_t),        intent(in)    :: der
    type(nl_operator_t),        intent(in)    :: op     !< Linear operator
    real(real64), contiguous,   intent(inout) :: sol(:) !< Solution to the problem
    real(real64), contiguous,   intent(in)    :: rhs(:) !< Right-hand side of the linear problem
    real(real64), contiguous,   intent(inout) :: residue(:) !< A work array for the residue

    integer :: iter
    real(real64)   :: resnorm
    logical :: converged

    PUSH_SUB(multigrid_solver_solve_coarsest)

    ! Defined values for the report below, even if the loop body never runs (maxcycles < 1)
    converged = .false.
    resnorm = huge(resnorm)

    ! Solution of the problem, i.e., multiple call to multigrid_relax up to convergence
    do iter = 1, this%maxcycles

      call multigrid_relax(this, der%mesh, der, op, sol, rhs, 1)

      call get_residual(op, der, sol, rhs, residue)
      resnorm = dmf_nrm2(der%mesh, residue)
      if (resnorm < this%threshold) then
        converged = .true.
        exit
      end if

    end do

    ! After an exhausted loop iter is maxcycles+1, so it is only reported on convergence
    if (converged) then
      write(message(1), '(a,i4,a)') "Debug: Multigrid coarsest grid solver converged in ", iter, " iterations."
    else
      write(message(1), '(a,i4,a)') "Debug: Multigrid coarsest grid solver not converged after ", &
        max(this%maxcycles, 0), " iterations."
    end if
    write(message(2), '(a,es18.6)') " Residue norm is ", resnorm
    call messages_info(2, debug_only=.true.)

    POP_SUB(multigrid_solver_solve_coarsest)
  end subroutine multigrid_solver_solve_coarsest

  ! ---------------------------------------------------------
  !>@brief Computes the residual
  !!
  !! The operator op is applied to sol through dderivatives_perform and the result is
  !! subtracted from rhs. Only the first der%mesh%np entries of residue are updated
  !! by the subtraction.
  subroutine get_residual(op, der, sol, rhs, residue)
    type(nl_operator_t),       intent(in)    :: op
    type(derivatives_t),       intent(in)    :: der
    real(real64), contiguous,  intent(inout) :: sol(:)
    real(real64), contiguous,  intent(in)    :: rhs(:)
    real(real64), contiguous,  intent(inout) :: residue(:)

    integer :: ipoint, npoints

    npoints = der%mesh%np

    ! residue holds the operator applied to sol ...
    call dderivatives_perform(op, der, sol, residue)

    ! ... and is then turned into rhs minus that quantity
    !$omp parallel do
    do ipoint = 1, npoints
      residue(ipoint) = rhs(ipoint) - residue(ipoint)
    end do
  end subroutine get_residual


  ! ---------------------------------------------------------
  !>@brief Given a nonlocal operator op, perform the relaxation operator
  !!
  !! This is needed to solve the linear problem Op sol = rhs steps times.
  !! The scheme (Gauss-Seidel or weighted Jacobi) is selected by this%relaxation_method.
  !! Nothing is relaxed if steps < 1. The profiling region is named MG_GAUSS_SEIDEL
  !! for both schemes.
  subroutine multigrid_relax(this, mesh, der, op, sol, rhs, steps)
    type(mg_solver_t),          intent(in)    :: this
    type(mesh_t),               intent(in)    :: mesh
    type(derivatives_t),        intent(in)    :: der
    type(nl_operator_t),        intent(in)    :: op     !< Linear operator
    real(real64), contiguous,   intent(inout) :: sol(:) !< Solution to the problem
    real(real64), contiguous,   intent(in)    :: rhs(:) !< Right-hand side of the linear problem
    integer,                    intent(in)    :: steps  !< Number of steps to be performed

    integer :: istep, index
    integer :: ip, nn, is
    real(real64)   :: point, factor
    real(real64), allocatable :: op_sol(:), diag(:)

    PUSH_SUB(multigrid_relax)
    call profiling_in("MG_GAUSS_SEIDEL")

    select case (this%relaxation_method)

    case (GAUSS_SEIDEL)

      ! The stencil size is set before the loop: it also enters the operation count
      ! after the loop, which is evaluated even when no step is performed (steps < 1).
      nn = op%stencil%size

      do istep = 1, steps

        ! Bring boundary (and, in parallel, ghost) points of sol up to date before each sweep
        call boundaries_set(der%boundaries, der%mesh, sol)

        if (mesh%parallel_in_domains) then
          call dpar_vec_ghost_update(mesh%pv, sol)
        end if

        if (op%const_w) then
          ! Same weights at every point: minus the inverse of the central weight is
          ! passed to the dedicated kernel
          factor = -M_ONE/op%w(op%stencil%center, 1)
          call dgauss_seidel(op%stencil%size, op%w(1, 1), op%nri, &
            op%ri(1, 1), op%rimap_inv(1), op%rimap_inv(2), factor, sol(1), rhs(1))
        else
          ! Point-dependent weights: sol(ip) <- sol(ip) - ((op sol)(ip) - rhs(ip))/w_center(ip)
          ! NOTE(review): sol is updated in place inside an OpenMP loop, so the values
          ! read through sol(index) may or may not be already updated - confirm that
          ! this ordering freedom is intended.
          !$omp parallel do private(point, index)
          do ip = 1, mesh%np
            point = M_ZERO
            do is = 1, nn
              index = nl_operator_get_index(op, is, ip)
              point = point + op%w(is, ip)*sol(index)
            end do
            sol(ip) = sol(ip) - (point-rhs(ip))/op%w(op%stencil%center, ip)
          end do
        end if

      end do
      call profiling_count_operations(mesh%np*(steps + 1)*(2*nn + 3))

    case (WEIGHTED_JACOBI)

      SAFE_ALLOCATE(op_sol(1:mesh%np))
      SAFE_ALLOCATE(diag(1:mesh%np))

      ! diag is replaced by relax_factor divided by the value returned by
      ! dnl_operator_operate_diag (presumably the diagonal of op), once for all steps
      call dnl_operator_operate_diag(op, diag)
      !$omp parallel do
      do ip = 1, mesh%np
        diag(ip) = this%relax_factor/diag(ip)
      end do

      do istep = 1, steps
        ! sol <- sol - relax_factor/diag * (op sol - rhs)
        call dderivatives_perform(op, der, sol, op_sol)
        !$omp parallel do
        do ip = 1, mesh%np
          sol(ip) = sol(ip) - diag(ip)*(op_sol(ip) - rhs(ip))
        end do
      end do

      SAFE_DEALLOCATE_A(diag)
      SAFE_DEALLOCATE_A(op_sol)

    end select

    call profiling_out("MG_GAUSS_SEIDEL")
    POP_SUB(multigrid_relax)

  end subroutine multigrid_relax

end module multigrid_solver_oct_m

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
