!! Copyright (C) 2002-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch
!! Copyright (C) 2020 M. Oliveira
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

!> \ingroup Fortran_Module
!! \brief top level module for all calculation modes
module run_oct_m
  use accel_oct_m
  use casida_oct_m
  use em_resp_oct_m
  use external_potential_oct_m
  use external_waves_oct_m
  use fft_oct_m
  use geom_opt_oct_m
  use global_oct_m
  use ground_state_oct_m
  use minimizer_factory_oct_m
  use interactions_factory_oct_m
  use interaction_partner_oct_m
  use invert_ks_oct_m
  use lasers_oct_m
  use messages_oct_m
  use mpi_debug_oct_m
  use mpi_oct_m
  use multicomm_oct_m
  use multisystem_basic_oct_m
  use multisystem_debug_oct_m
  use multisystem_run_oct_m
  use namespace_oct_m
  use opt_control_oct_m
  use parser_oct_m
  use phonons_fd_oct_m
  use phonons_lr_oct_m
  use poisson_oct_m
  use propagator_factory_oct_m
  use propagator_oct_m
  use kdotp_oct_m
  use profiling_oct_m
  use pulpo_oct_m
  use restart_oct_m
  use static_pol_oct_m
  use system_factory_oct_m
  use system_oct_m
  use td_oct_m
  use test_oct_m
  use time_dependent_oct_m
  use unit_system_oct_m
  use unocc_oct_m
  use varinfo_oct_m
  use vdw_oct_m

  implicit none

  private
  public :: run

  ! Values accepted by the ResponseMethod input variable (see get_resp_method)
  integer, parameter :: LR = 1  !< ResponseMethod = sternheimer
  integer, parameter :: FD = 2  !< ResponseMethod = finite_differences

contains

  ! ---------------------------------------------------------
  !> Read the ResponseMethod input variable and check it against its allowed options.
  !!
  !! \param namespace  input namespace the variable is parsed from
  !! \return the selected method, LR (Sternheimer) by default or FD; an invalid
  !!         value is reported through messages_input_error
  integer function get_resp_method(namespace) result(method)
    type(namespace_t), intent(in) :: namespace

    PUSH_SUB(get_resp_method)

    !%Variable ResponseMethod
    !%Type integer
    !%Default sternheimer
    !%Section Linear Response
    !%Description
    !% Some response properties can be calculated either via
    !% Sternheimer linear response or by using finite
    !% differences. You can use this variable to select how you want
    !% them to be calculated, it applies to <tt>em_resp</tt> and <tt>vib_modes</tt>
    !% calculation modes. By default, the Sternheimer linear-response
    !% technique is used.
    !%Option sternheimer 1
    !% The linear response is obtained by solving a self-consistent
    !% Sternheimer equation for the variation of the orbitals. This
    !% is the recommended method.
    !%Option finite_differences 2
    !% Properties are calculated as a finite-differences derivative of
    !% the energy obtained by several ground-state calculations. This
    !% method, slow and limited only to static response, is kept
    !% mainly because it is simple and useful for testing purposes.
    !%End

    call parse_variable(namespace, 'ResponseMethod', LR, method)

    if (.not. varinfo_valid_option('ResponseMethod', method)) call messages_input_error(namespace, 'ResponseMethod')

    POP_SUB(get_resp_method)
  end function get_resp_method

  ! ---------------------------------------------------------
  !> main routine to run all calculations:
  !! This routine parses the input file, sets up the systems and interactions, and
  !! calls the corresponding routines for the requested run mode.
  !!
  !! \param namespace     top-level input namespace
  !! \param calc_mode_id  selected calculation mode (one of the OPTION__CALCULATIONMODE__* values)
  subroutine run(namespace, calc_mode_id)
    type(namespace_t), intent(in) :: namespace
    integer,           intent(in) :: calc_mode_id

    type(partner_list_t)     :: partners !< structured (hierarchical) list of partners
    class(system_t), pointer :: systems  !< top level system (either electrons or multisystem_basic container)
    type(system_factory_t) :: system_factory
    type(interactions_factory_t) :: interactions_factory
    logical :: from_scratch
    integer :: iunit_out
    type(partner_iterator_t) :: iter               !< used to walk the partners when deallocating them
    class(interaction_partner_t), pointer :: partner

    PUSH_SUB(run)

    call messages_print_with_emphasis(msg="Calculation Mode", namespace=namespace)
    call messages_print_var_option("CalculationMode", calc_mode_id, namespace=namespace)
    call messages_print_with_emphasis(namespace=namespace)

    call calc_mode_init()

    ! The recipe mode is fully handled by pulpo_print; nothing else is initialized
    if (calc_mode_id == OPTION__CALCULATIONMODE__RECIPE) then
      call pulpo_print()
      POP_SUB(run)
      return
    end if

    call restart_module_init(namespace)

    call unit_system_init(namespace)

    call accel_init(mpi_world, namespace)

    ! initialize FFTs
    call fft_all_init(namespace)

    ! The test mode does not create any system. Release what was initialized above
    ! in the same order as at the end of this routine (accel_init must be paired
    ! with accel_end on this path too).
    if (calc_mode_id == OPTION__CALCULATIONMODE__TEST) then
      call test_run(namespace)
      call fft_all_end()
      call accel_end(global_namespace)
      call mpi_debug_statistics()
      POP_SUB(run)
      return
    end if

    ! Create systems
    if (parse_is_defined(namespace, "Systems")) then
      ! We are running in multi-system mode
      systems => system_factory%create(namespace, SYSTEM_MULTISYSTEM)
    else
      ! Fall back to old behaviour
      systems => electrons_t(namespace, generate_epot = calc_mode_id /= OPTION__CALCULATIONMODE__DUMMY)
    end if

    ! initialize everything that needs parallelization
    call systems%init_parallelization(mpi_world)

    ! Create list of partners
    select type (systems)
    class is (multisystem_basic_t)
      ! Systems are also partners
      partners = systems%list
      ! Add external potentials to partners list
      call load_external_potentials(partners, namespace)

      call load_external_waves(partners, namespace)

      ! Add lasers to the partner list
      call load_lasers(partners, namespace)

    type is (electrons_t)
      call partners%add(systems)
    end select

    ! Initialize the algorithms of multisystem containers:
    ! a minimizer for ground-state runs, a propagator for time-dependent runs
    select type (systems)
    class is (multisystem_basic_t)
      select case (calc_mode_id)
      case (OPTION__CALCULATIONMODE__GS)
        call systems%new_algorithm(minimizer_factory_t(systems%namespace))
      case (OPTION__CALCULATIONMODE__TD)
        call systems%new_algorithm(propagator_factory_t(systems%namespace))
      end select
    end select

    ! Stop before propagating if the final time would lead to a deadlock
    if (calc_mode_id == OPTION__CALCULATIONMODE__TD) then
      call check_td_final_time()
    end if

    ![create_interactions] !doxygen marker. Dont delete
    ! Create and initialize interactions
    !
    ! This function is called recursively for all subsystems of systems.
    ! If systems is a multisystem_basic_t container, the partners list contains all subsystems.
    call systems%create_interactions(interactions_factory, partners)
    ![create_interactions]

    select type (systems)
    class is (multisystem_basic_t)
      ! Write the interaction graph as a DOT graph for debug
      if ((debug%interaction_graph .or. debug%interaction_graph_full) .and. mpi_grp_is_root(mpi_world)) then
        iunit_out = io_open('debug/interaction_graph.dot', systems%namespace, action='write')
        write(iunit_out, '(a)') 'digraph {'
        call systems%write_interaction_graph(iunit_out, debug%interaction_graph_full)
        write(iunit_out, '(a)') '}'
        call io_close(iunit_out)
      end if
    end select

    if (.not. systems%process_is_slave()) then
      call messages_write('Info: Octopus initialization completed.', new_line = .true.)
      call messages_write('Info: Starting calculation mode.')
      call messages_info(namespace=namespace)

      !%Variable FromScratch
      !%Type logical
      !%Default false
      !%Section Execution
      !%Description
      !% When this variable is set to true, <tt>Octopus</tt> will perform a
      !% calculation from the beginning, without looking for restart
      !% information.
      !% NOTE: If available, mesh partitioning information will be used for
      !% initializing the calculation regardless of the set value for this variable.
      !%End
      call parse_variable(namespace, 'FromScratch', .false., from_scratch)

      call profiling_in("CALC_MODE")

      select case (calc_mode_id)
      case (OPTION__CALCULATIONMODE__GS)
        select type (systems)
        class is (multisystem_basic_t)
          call multisystem_run(systems, from_scratch)
        type is (electrons_t)
          call ground_state_run(systems, from_scratch)
        end select
      case (OPTION__CALCULATIONMODE__UNOCC)
        call unocc_run(systems, from_scratch)
      case (OPTION__CALCULATIONMODE__TD)
        select type (systems)
        class is (multisystem_basic_t)
          call multisystem_run(systems, from_scratch)
        type is (electrons_t)
          call time_dependent_run(systems, from_scratch)
        end select
      case (OPTION__CALCULATIONMODE__GO)
        call geom_opt_run(systems, from_scratch)
      case (OPTION__CALCULATIONMODE__OPT_CONTROL)
        call opt_control_run(systems)
      case (OPTION__CALCULATIONMODE__EM_RESP)
        select case (get_resp_method(namespace))
        case (FD)
          call static_pol_run(systems, from_scratch)
        case (LR)
          call em_resp_run(systems, from_scratch)
        end select
      case (OPTION__CALCULATIONMODE__CASIDA)
        call casida_run(systems, from_scratch)
      case (OPTION__CALCULATIONMODE__VDW)
        call vdW_run(systems, from_scratch)
      case (OPTION__CALCULATIONMODE__VIB_MODES)
        select case (get_resp_method(namespace))
        case (FD)
          call phonons_run(systems)
        case (LR)
          call phonons_lr_run(systems, from_scratch)
        end select
      case (OPTION__CALCULATIONMODE__ONE_SHOT)
        message(1) = "CalculationMode = one_shot is obsolete. Please use gs with MaximumIter = 0."
        call messages_fatal(1, namespace=namespace)
      case (OPTION__CALCULATIONMODE__KDOTP)
        call kdotp_lr_run(systems, from_scratch)
      case (OPTION__CALCULATIONMODE__DUMMY)
      case (OPTION__CALCULATIONMODE__INVERT_KS)
        call invert_ks_run(systems)
      case (OPTION__CALCULATIONMODE__RECIPE)
        ASSERT(.false.) !this is handled before, if we get here, it is an error
      end select

      call profiling_out("CALC_MODE")
    end if

    select type (systems)
    class is (multisystem_basic_t)
      ! Deallocate the partners that are not systems
      ! (external potentials, external waves and lasers)
      call iter%start(partners)
      do while (iter%has_next())
        select type(ptr => iter%get_next())
        class is(external_potential_t)
          partner => ptr
          SAFE_DEALLOCATE_P(partner)
        class is(external_waves_t)
          partner => ptr
          SAFE_DEALLOCATE_P(partner)
        class is(lasers_t)
          partner => ptr
          SAFE_DEALLOCATE_P(partner)
        end select
      end do
    end select

    ! Finalize systems
    SAFE_DEALLOCATE_P(systems)

    call fft_all_end()

    call accel_end(global_namespace)

    call mpi_debug_statistics()

    POP_SUB(run)

  contains

    ! ---------------------------------------------------------
    !> Call the mode-specific initialization routine for the gs/go/unocc,
    !! td and casida calculation modes; other modes need none here.
    subroutine calc_mode_init()

      PUSH_SUB(calc_mode_init)

      select case (calc_mode_id)
      case (OPTION__CALCULATIONMODE__GS, OPTION__CALCULATIONMODE__GO, OPTION__CALCULATIONMODE__UNOCC)
        call ground_state_run_init()
      case (OPTION__CALCULATIONMODE__TD)
        call td_run_init()
      case (OPTION__CALCULATIONMODE__CASIDA)
        call casida_run_init()
      end select

      POP_SUB(calc_mode_init)
    end subroutine calc_mode_init

    ! ---------------------------------------------------------
    !> Stop the code if a multisystem TD propagation would run into a deadlock.
    !!
    !! The deadlock currently appears when systems have different time steps and a
    !! system with a smaller time step could still do a step beyond the last step
    !! reachable by the system with the largest time step:
    !!
    !!                 2   4   6   8  10  propagation time=11
    !! system1 (dt=2): *   *   *   *   * |
    !! system2 (dt=4):     *       *     |
    !!
    !! As system2 is not allowed to do the step to 12, it is stuck at 8. System1
    !! tries to get to 10, but is stuck at a barrier, waiting for system2, which
    !! is behind.
    !!
    !! Propagations are therefore only allowed up to the largest time the system
    !! with the largest dt can reach within the given propagation time. If a
    !! system with a smaller dt could perform an extra step within the propagation
    !! time, the calculation is stopped before the propagation, to avoid wasting
    !! CPU time in the deadlock.
    !! See issue: https://gitlab.com/octopus-code/octopus/-/issues/1226
    !!
    !! TODO: Fix the framework, so that this infinite loop can be avoided in the first place.
    subroutine check_td_final_time()
      type(partner_iterator_t) :: sys_iter
      type(system_list_t)      :: flat_list
      real(real64) :: largest_dt           !< largest time step among the subsystem propagators
      real(real64) :: largest_allowed_time !< last time reachable with largest_dt
      real(real64) :: time_tolerance       !< slack for rounding errors in the comparison below

      PUSH_SUB(check_td_final_time)

      select type (sys => systems)
      class is (multisystem_basic_t)

        largest_dt = M_ZERO
        largest_allowed_time = M_ZERO

        call sys%get_flat_list(flat_list)

        call sys_iter%start(flat_list)
        do while (sys_iter%has_next())
          select type (subsystem => sys_iter%get_next())
          class is (system_t)
            select type (subalgorithm => subsystem%algo)
            class is (propagator_t)
              largest_dt = max(largest_dt, subalgorithm%dt)
            end select
          end select
        end do

        ! Without any subsystem propagator there is nothing to check, and the
        ! division by largest_dt below would be a division by zero.
        if (largest_dt > M_ZERO) then

          select type (prop => sys%algo)
          class is (propagator_t)
            largest_allowed_time = floor(prop%final_time / largest_dt) * largest_dt
          class default
            ASSERT(.false.)
          end select

          ! floor(T/dt)*dt is evaluated with different dt on the two sides of the
          ! comparison, so commensurate time steps can still differ in the last bits
          ! (e.g. 3*0.3 evaluates to 0.8999999999999999 but 9*0.1 to 0.9).
          time_tolerance = 10 * spacing(largest_allowed_time)

          call sys_iter%start(flat_list)
          do while (sys_iter%has_next())
            select type (subsystem => sys_iter%get_next())
            class is (system_t)
              select type (subalgorithm => subsystem%algo)
              class is (propagator_t)
                if (floor(subalgorithm%final_time / subalgorithm%dt) * subalgorithm%dt &
                  > largest_allowed_time + time_tolerance) then
                  write(message(1), *) "Incommensurate propagation time: The calculation would run into a deadlock," &
                    // " as a system with"
                  write(message(2), *) "a smaller timestep would attempt a timestep beyond the last step" &
                    // " of a system with a larger time step."
                  write(message(3), *) "Please, adjust the TDPropagationTime and/or the TDTimeStep variables of the systems."
                  write(message(4), *) ""
                  ! NOTE(review): largest_allowed_time is printed as stored internally (presumably
                  ! atomic units); confirm whether it should be converted to the input units.
                  write(message(5), *) "With the current timesteps, the TDPropagationTime should be ", largest_allowed_time
                  call messages_fatal(5, namespace=namespace)
                end if
              end select
            end select
          end do

        end if

      end select

      POP_SUB(check_td_final_time)
    end subroutine check_td_final_time

  end subroutine run

end module run_oct_m

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
