!! Copyright (C) 2024 Alex Buccheri

!> @brief Expose Fortuno serial and MPI data types and routines through common aliases.
!!
!!      Serial                        MPI                    Description
!! ------------------------------------------------------------------------------------------------------
!! execute_serial_cmd_app        execute_mpi_cmd_app     Accepts an array of test_item, and runs them.
!! init_serial_cmd_app           init_mpi_cmd_app        Sets up the cmd line app.
!! serial_cmd_app                mpi_cmd_app             Drives the tests through the cmd line app.
!! serial_case_item              mpi_case_item           Returns a test case instance as a generic test item.
!! serial_suite_item             mpi_suite_item          Returns a suite instance wrapped as test_item.
!! serial_check                  mpi_check               Perform a logical check (assertion) on a condition.
!!
#include "global.h"

!> @brief Common aliases for the Fortuno serial and MPI entities, plus array comparison helpers.
!!
!! When HAVE_MPI is defined, `execute_cmd_app`, `test_case`, `suite` and `check` are renamed
!! imports from `fortuno_mpi`; otherwise they are renamed imports from `fortuno_serial`.
module fortuno_interface_m

! Preprocessor directives are kept in column one: traditional-mode preprocessors do not
! recognise a directive whose `#` is preceded by whitespace.
#ifdef HAVE_MPI
    use fortuno_mpi, only : &
        execute_cmd_app => execute_mpi_cmd_app, &
        test_case => mpi_case_item, &
        suite => mpi_suite_item, &
        check => mpi_check, &
        is_equal, &
        test_item, &
        test_list, &
        global_comm, &
        this_rank
#else
    ! NOTE(review): `this_rank` is only imported in the MPI branch and no serial
    ! stand-in is defined in this module.
    use fortuno_serial, only : &
        execute_cmd_app => execute_serial_cmd_app, &
        test_case => serial_case_item, &
        suite => serial_suite_item, &
        check => serial_check, &
        is_equal, &
        test_item, &
        test_list

#endif
    use global_oct_m, only: optional_default
    implicit none

    ! Scope is deliberately public

    !> @brief Returns True if two arrays are element-wise equal within a tolerance.
    !!
    !! Based on the numpy function [allclose](https://numpy.org/doc/stable/reference/generated/numpy.allclose.html).
    !! The tolerance values are positive, typically very small numbers.
    !! The relative tolerance term `(rtol * abs(y))` and the absolute tolerance `atol`
    !! are added together and compared against the absolute difference between `x` and `y`,
    !! i.e. every element must satisfy `abs(x - y) <= atol + rtol * abs(y)`.
    !!
    !! Note that the test is not symmetric in `x` and `y`: the relative term scales with `y`.
    !! Arrays whose shapes differ are reported as not close.
    interface all_close
        module procedure :: all_close_real64_1d, all_close_real64_2d, &
                            all_close_complex64_1d, all_close_complex64_2d
    end interface all_close

contains

#ifndef HAVE_MPI
    !> @brief Serial stand-in for `global_comm`, which is only imported from `fortuno_mpi`
    !!        when HAVE_MPI is defined.
    !!
    !! @return comm  Always `MPI_COMM_UNDEFINED`, as provided by `mpi_oct_m`.
    function global_comm() result(comm)
        use mpi_oct_m, only: MPI_Comm, MPI_COMM_UNDEFINED
        type(MPI_Comm) :: comm
        comm = MPI_COMM_UNDEFINED
    end function global_comm
#endif

    !> @brief True if the rank-1 real arrays `x` and `y` agree element-wise within tolerance.
    !!
    !! @param[in] x     Values under test.
    !! @param[in] y     Reference values; the relative tolerance scales with `abs(y)`.
    !! @param[in] rtol  Optional relative tolerance; `optional_default` is given 1.e-5 as fallback.
    !! @param[in] atol  Optional absolute tolerance; `optional_default` is given 1.e-8 as fallback.
    !! @return    `.true.` when the sizes match and every element satisfies
    !!            `abs(x - y) <= atol + rtol * abs(y)`; `.false.` otherwise.
    logical function all_close_real64_1d(x, y, rtol, atol) result(is_close)
        use, intrinsic :: iso_fortran_env, only: real64

        real(real64), intent(in) :: x(:), y(:)
        real(real64), optional, intent(in) :: rtol
        real(real64), optional, intent(in) :: atol
        real(real64) :: atol_, rtol_

        ! Element-wise expressions require conforming arrays, so a size mismatch is
        ! reported as a failed comparison instead of being evaluated.
        is_close = size(x) == size(y)
        if (.not. is_close) return

        rtol_ = optional_default(rtol, 1.e-5_real64)
        atol_ = optional_default(atol, 1.e-8_real64)
        is_close = all(abs(x - y) <= (atol_ + rtol_ * abs(y)))

    end function all_close_real64_1d

    !> @brief True if the rank-2 real arrays `x` and `y` agree element-wise within tolerance.
    !!
    !! @param[in] x     Values under test.
    !! @param[in] y     Reference values; the relative tolerance scales with `abs(y)`.
    !! @param[in] rtol  Optional relative tolerance; `optional_default` is given 1.e-5 as fallback.
    !! @param[in] atol  Optional absolute tolerance; `optional_default` is given 1.e-8 as fallback.
    !! @return    `.true.` when the shapes match and every element satisfies
    !!            `abs(x - y) <= atol + rtol * abs(y)`; `.false.` otherwise.
    logical function all_close_real64_2d(x, y, rtol, atol) result(is_close)
        use, intrinsic :: iso_fortran_env, only: real64

        real(real64), intent(in) :: x(:, :), y(:, :)
        real(real64), optional, intent(in) :: rtol
        real(real64), optional, intent(in) :: atol
        real(real64) :: atol_, rtol_

        ! Element-wise expressions require conforming arrays, so a shape mismatch is
        ! reported as a failed comparison instead of being evaluated.
        is_close = all(shape(x) == shape(y))
        if (.not. is_close) return

        rtol_ = optional_default(rtol, 1.e-5_real64)
        atol_ = optional_default(atol, 1.e-8_real64)
        is_close = all(abs(x - y) <= (atol_ + rtol_ * abs(y)))

    end function all_close_real64_2d

    !> @brief True if the rank-1 complex arrays `x` and `y` agree element-wise within tolerance.
    !!
    !! The arguments are `complex(real64)`; `abs` is the complex modulus, so both the
    !! difference and the relative term are real-valued.
    !!
    !! @param[in] x     Values under test.
    !! @param[in] y     Reference values; the relative tolerance scales with `abs(y)`.
    !! @param[in] rtol  Optional relative tolerance; `optional_default` is given 1.e-5 as fallback.
    !! @param[in] atol  Optional absolute tolerance; `optional_default` is given 1.e-8 as fallback.
    !! @return    `.true.` when the sizes match and every element satisfies
    !!            `abs(x - y) <= atol + rtol * abs(y)`; `.false.` otherwise.
    logical function all_close_complex64_1d(x, y, rtol, atol) result(is_close)
        use, intrinsic :: iso_fortran_env, only: real64

        complex(real64), intent(in) :: x(:), y(:)
        real(real64), optional, intent(in) :: rtol
        real(real64), optional, intent(in) :: atol
        real(real64) :: atol_, rtol_

        ! Element-wise expressions require conforming arrays, so a size mismatch is
        ! reported as a failed comparison instead of being evaluated.
        is_close = size(x) == size(y)
        if (.not. is_close) return

        rtol_ = optional_default(rtol, 1.e-5_real64)
        atol_ = optional_default(atol, 1.e-8_real64)
        is_close = all(abs(x - y) <= (atol_ + rtol_ * abs(y)))

    end function all_close_complex64_1d

    !> @brief True if the rank-2 complex arrays `x` and `y` agree element-wise within tolerance.
    !!
    !! The arguments are `complex(real64)`; `abs` is the complex modulus, so both the
    !! difference and the relative term are real-valued.
    !!
    !! @param[in] x     Values under test.
    !! @param[in] y     Reference values; the relative tolerance scales with `abs(y)`.
    !! @param[in] rtol  Optional relative tolerance; `optional_default` is given 1.e-5 as fallback.
    !! @param[in] atol  Optional absolute tolerance; `optional_default` is given 1.e-8 as fallback.
    !! @return    `.true.` when the shapes match and every element satisfies
    !!            `abs(x - y) <= atol + rtol * abs(y)`; `.false.` otherwise.
    logical function all_close_complex64_2d(x, y, rtol, atol) result(is_close)
        use, intrinsic :: iso_fortran_env, only: real64

        complex(real64), intent(in) :: x(:, :), y(:, :)
        real(real64), optional, intent(in) :: rtol
        real(real64), optional, intent(in) :: atol
        real(real64) :: atol_, rtol_

        ! Element-wise expressions require conforming arrays, so a shape mismatch is
        ! reported as a failed comparison instead of being evaluated.
        is_close = all(shape(x) == shape(y))
        if (.not. is_close) return

        rtol_ = optional_default(rtol, 1.e-5_real64)
        atol_ = optional_default(atol, 1.e-8_real64)
        is_close = all(abs(x - y) <= (atol_ + rtol_ * abs(y)))

    end function all_close_complex64_2d

end module fortuno_interface_m
