!! Copyright (C) 2024 N. Tancogne-Dejean
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

!> This module takes care of testing optimizers using standard test functions
module minimizer_tests_oct_m
  use debug_oct_m
  use global_oct_m
  use io_oct_m
  use iso_c_binding
  use, intrinsic :: iso_fortran_env
  use lalg_basic_oct_m
  use profiling_oct_m
  use messages_oct_m
  use minimizer_oct_m
  use mpi_oct_m
  use namespace_oct_m
  use unit_oct_m
  use unit_system_oct_m

  implicit none

  private
  public ::                    &
    test_optimizers

contains

  !---------------------------------------------------------------------------
  !>@brief Unit tests for different optimizers
  !!
  !! At the moment, only the FIRE algorithm is tested: it is used to minimize the
  !! two-dimensional Rosenbrock function, starting from the point (3,4).
  !! The details of each iteration are written by write_iter_info, and a short
  !! summary (iteration count and position of the minimum found) is printed
  !! through messages_info.
  !!
  !! @param[in] namespace  Namespace used to remove a pre-existing optimization.log file
  subroutine test_optimizers(namespace)
    type(namespace_t), intent(in) :: namespace

    integer,      parameter :: ndim = 2                !< Dimension of the test problem
    integer,      parameter :: max_iter = 1000         !< Maximum number of iterations handed to the minimizer
    real(real64), parameter :: tol_grad = 1.0e-6_real64 !< Gradient tolerance handed to the minimizer
    real(real64), dimension(ndim) :: x0, x, mass
    real(real64) :: dt, energy
    integer :: ierr

    PUSH_SUB(test_optimizers)

    ! write_iter_info opens the log in append mode, so start from a clean file.
    ! NOTE(review): the file is removed through `namespace`, whereas write_iter_info opens it
    ! through `global_namespace`; confirm that both resolve to the same file for all callers.
    call io_rm("optimization.log", namespace)

    ! Testing the FIRE algorithm with the Rosenbrock function in 2D
    ! Starting point
    x0 = [3.0_real64, 4.0_real64]

    ! Initial parameters for the FIRE algorithm with Velocity Verlet
    x = x0
    ! This timestep is taken to give a reasonable result. Too large values lead to unstable results.
    ! The literal carries the real64 kind; a bare 0.002 would first be rounded to default real precision.
    dt = 0.002_real64
    mass = M_ONE
    ! The last argument (1) selects the integrator; per the comment above this is Velocity Verlet.
    call minimize_fire(ndim, ndim, x, dt, tol_grad, max_iter, rosenbrock_gradient_2d, write_iter_info, &
      energy, ierr, mass, 1)

    ! NOTE(review): -ierr is reported as the number of iterations; presumably minimize_fire returns
    ! minus the iteration count on convergence. Confirm, and consider handling the non-converged case.
    write(message(1), "(a,i6,a)") "FIRE algorithm converged in ", -ierr, " iterations"
    write(message(2), "(a,2(f8.4,a))") " The minimum is found to be at (", x(1), ", ", x(2), ")"
    write(message(3), "(a)") " Analytical minimum is (1,1)."
    call messages_info(3)

    POP_SUB(test_optimizers)
  end subroutine test_optimizers

  !---------------------------------------------------------------------------
  !>@brief Value and gradient of the Rosenbrock function
  !!
  !! The function is given by \f$ f(x, y) = (a - x)^2 + b(y - x^2)^2 \f$,
  !! here with \f$a=1\f$ and \f$b=100\f$.
  !! It has a global minimum at \f$(x,y)=(a,a^2)\f$.
  !!
  !! The gradient is
  !! \f$ \partial_x f = -2(a - x) - 4 b x (y - x^2) \f$ and
  !! \f$ \partial_y f = 2 b (y - x^2) \f$.
  !!
  !! See https://en.wikipedia.org/wiki/Rosenbrock_function
  !!
  !! Only the two-dimensional version is implemented; the routine asserts n == 2.
  !! The gradient is always evaluated, irrespective of the value of getgrad.
  subroutine rosenbrock_gradient_2d(n, x, val, getgrad, grad) !< Get the new gradients given a new x
    integer,      intent(in)    :: n        !< Number of variables, must be 2
    real(real64), intent(in)    :: x(n)     !< Point (x, y) at which the function is evaluated
    real(real64), intent(inout) :: val      !< On exit, value of the Rosenbrock function at x
    integer,      intent(in)    :: getgrad  !< Not used by this routine
    real(real64), intent(inout) :: grad(n)  !< On exit, gradient of the function at x

    real(real64), parameter :: a = M_ONE
    real(real64), parameter :: b = 100.0_real64

    real(real64) :: dy

    ! The name given to PUSH_SUB/POP_SUB must match the routine for meaningful debug traces
    PUSH_SUB(rosenbrock_gradient_2d)

    ! Only the 2D version is implemented here
    ASSERT(n==2)

    ! Distance to the parabola y = x^2, shared by the value and both gradient components
    dy = x(2) - x(1)**2

    ! Define the function value so that the caller does not receive an undefined number.
    ! NOTE(review): assumes the minimizer expects the function value in val - confirm against its interface.
    val = (a - x(1))**2 + b * dy**2

    grad(1) = -M_TWO * (a - x(1)) - M_FOUR * b * dy * x(1)
    grad(2) =  M_TWO * b * dy

    POP_SUB(rosenbrock_gradient_2d)
  end subroutine rosenbrock_gradient_2d

  !---------------------------------------------------------------------------
  !>@brief Per-iteration output routine handed to the minimizer
  !!
  !! Appends one line to the file optimization.log containing the iteration
  !! number, the first two components of x and maxgrad. A header line is written
  !! first when iter equals 1. Only the root process of mpi_world writes; all other
  !! processes return immediately. The arguments val and maxdr are not used.
  subroutine write_iter_info(iter, n, val, maxdr, maxgrad, x) !< Output for each iteration step
    integer,      intent(in) :: iter     !< Iteration number
    integer,      intent(in) :: n        !< Size of x
    real(real64), intent(in) :: val      !< Not used here
    real(real64), intent(in) :: maxdr    !< Not used here
    real(real64), intent(in) :: maxgrad  !< Value written in the max_force column
    real(real64), intent(in) :: x(n)     !< Current point; only x(1) and x(2) are written

    character(len=*), parameter :: log_file = 'optimization.log'
    integer :: log_unit

    ! Nothing to do on the non-root processes
    if (.not. mpi_grp_is_root(mpi_world)) return

    ! The file is reopened in append mode and closed again at every call
    log_unit = io_open(log_file, global_namespace, action = 'write', position = 'append')

    ! Column titles precede the data of the first iteration
    if (iter == 1) write(log_unit, '(a10, 3a20)') '#     iter','x','y', 'max_force'

    write(log_unit, '(i10,3f20.10)') iter, x(1), x(2), maxgrad

    call io_close(log_unit)

  end subroutine write_iter_info

end module minimizer_tests_oct_m

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
