!! Copyright (C) 2002-2024 M. Marques, A. Castro, A. Rubio, G. Bertsch, A. Buccheri
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

!> @brief This module handles the parallelization settings requested by the calculation mode.
!!
!! This module uses a module-scope global object to allow calculation modes
!! to set the available parallelization strategies and whether the layout
!! must be compatible with ScaLAPACK, and to allow this information to be
!! accessed elsewhere. It does not, and should not, contain the definitions
!! of the calculation modes themselves, to avoid writing code explicitly
!! dependent on the calculation mode elsewhere.
module calc_mode_par_oct_m
  use debug_oct_m
  use global_oct_m

  implicit none

  private
  public ::          &
    calc_mode_par_t, &
    calc_mode_par

  !> Possible parallelization strategies.
  !!
  !! Strategy P_STRATEGY_X is stored in bit (P_STRATEGY_X - 1) of a mask, so only
  !! P_STRATEGY_DOMAINS to P_STRATEGY_MAX own a bit. P_STRATEGY_SERIAL has no bit of
  !! its own: its value (0) is the mask with no bits set.
  integer, public, parameter ::    &
    P_STRATEGY_SERIAL  = 0,          & !< single domain, all states, k-points on a single processor
    P_STRATEGY_DOMAINS = 1,          & !< parallelization in domains
    P_STRATEGY_STATES  = 2,          & !< parallelization in states
    P_STRATEGY_KPOINTS = 3,          & !< parallelization in k-points
    P_STRATEGY_OTHER   = 4,          & !< something else like e-h pairs
    P_STRATEGY_MAX     = 4             !< largest strategy index that owns a bit

  !> Initial value of both the available and the default masks: domains and k-points.
  !! P_STRATEGY_SERIAL (= 0, no bits set) is the empty mask the two bits are added to.
  integer, parameter :: default_parallelization_mask =  ior(ibset(P_STRATEGY_SERIAL, P_STRATEGY_DOMAINS - 1), &
    ibset(P_STRATEGY_SERIAL, P_STRATEGY_KPOINTS - 1))

  !> @brief Octopus parallelization modes, stored concurrently in a bit representation.
  !!
  !! ibset sets a specific bit in an integer, at a position given by P_STRATEGY_X - 1, to 1.
  !! For example:
  !!  ibset(P_STRATEGY_SERIAL, P_STRATEGY_DOMAINS - 1) = ibset(0, 0) sets the 0th bit of 0 to 1, giving 0001
  !!  ibset(P_STRATEGY_SERIAL, P_STRATEGY_KPOINTS - 1) = ibset(0, 2) sets the 2nd bit of 0 to 1, giving 0100
  !!
  !! ior performs a bitwise inclusive OR operation. For example:
  !!  ior(0001, 0100) = 0101
  !!
  !! For a 32-bit integer, up to 32 strategies can be represented.
  !!
  !! The components are private; they are modified and queried through the
  !! type-bound procedures below.
  !!
  !! One can query if a given strategy is enabled with btest:
  !! ```fortran
  !!  if (btest(calc_mode_par%parallel_mask(), P_STRATEGY_X - 1)) ...
  !! ```
  !! or check that every strategy requested in a mask `par_mask` is available,
  !! by comparing it with its bitwise AND against the available mask:
  !! ```fortran
  !!  if (par_mask == iand(par_mask, calc_mode_par%parallel_mask())) ...
  !! ```
  type calc_mode_par_t
    private
    integer :: par_mask = default_parallelization_mask        !< Parallelization mask
    integer :: def_par_mask = default_parallelization_mask    !< Default Parallelization mask
    logical :: scalapack_compat_ = .false.  !< Is the Parallelization strategy compatible with scalapack
  contains
    procedure :: set_parallelization => calc_mode_par_set_parallelization
    procedure :: unset_parallelization => calc_mode_par_unset_parallelization
    procedure :: set_scalapack_compat => calc_mode_par_set_scalapack_compat
    procedure :: scalapack_compat => calc_mode_par_scalapack_compat
    procedure :: parallel_mask => calc_mode_par_parallel_mask
    procedure :: default_parallel_mask => calc_mode_par_default_parallel_mask
  end type calc_mode_par_t

  !> Singleton instance of parallel calculation mode
  type(calc_mode_par_t) :: calc_mode_par

contains

  !> @brief Add a parallelization strategy to the list of possible ones.
  !!
  !! @param[inout] this     Parallel calculation mode to modify.
  !! @param[in]    par      Strategy to enable, in P_STRATEGY_DOMAINS..P_STRATEGY_MAX.
  !!                        P_STRATEGY_SERIAL owns no bit and is rejected.
  !! @param[in]    default  If true, the strategy is also added to the default mask.
  subroutine calc_mode_par_set_parallelization(this, par, default)
    class(calc_mode_par_t), intent(inout) :: this
    integer, intent(in) :: par
    logical, intent(in) :: default  !< Add Parallelization strategy to defaults

    PUSH_SUB(calc_mode_par_set_parallelization)

    ! The bit position is par - 1; a non-positive par would give an invalid ibset position.
    ASSERT(par > P_STRATEGY_SERIAL .and. par <= P_STRATEGY_MAX)

    this%par_mask = ibset(this%par_mask, par - 1)
    if (default) this%def_par_mask = ibset(this%def_par_mask, par - 1)

    POP_SUB(calc_mode_par_set_parallelization)

  end subroutine calc_mode_par_set_parallelization

  !> @brief Remove a parallelization strategy from the list of possible ones.
  !! It will also be removed from the default.
  !!
  !! @param[inout] this  Parallel calculation mode to modify.
  !! @param[in]    par   Strategy to disable, in P_STRATEGY_DOMAINS..P_STRATEGY_MAX.
  subroutine calc_mode_par_unset_parallelization(this, par)
    class(calc_mode_par_t), intent(inout) :: this
    integer, intent(in) :: par

    PUSH_SUB(calc_mode_par_unset_parallelization)

    ! The bit position is par - 1; a non-positive par would give an invalid ibclr position.
    ASSERT(par > P_STRATEGY_SERIAL .and. par <= P_STRATEGY_MAX)

    this%par_mask = ibclr(this%par_mask, par - 1)
    this%def_par_mask = ibclr(this%def_par_mask, par - 1)

    POP_SUB(calc_mode_par_unset_parallelization)

  end subroutine calc_mode_par_unset_parallelization

  !> @brief Set that the current run mode requires division of states
  !! and domains to be compatible with scalapack.
  !!
  !! There is no corresponding unset: once set, the flag stays true.
  subroutine calc_mode_par_set_scalapack_compat(this)
    class(calc_mode_par_t), intent(inout) :: this

    PUSH_SUB(calc_mode_par_set_scalapack_compat)

    this%scalapack_compat_ = .true.

    POP_SUB(calc_mode_par_set_scalapack_compat)

  end subroutine calc_mode_par_set_scalapack_compat

  !> @brief Get whether the current run mode requires division of states
  !! and domains to be compatible with scalapack.
  !!
  !! @return .true. if set_scalapack_compat has been called, .false. otherwise.
  logical pure function calc_mode_par_scalapack_compat(this) result(compat)
    class(calc_mode_par_t), intent(in) :: this

    compat = this%scalapack_compat_

  end function calc_mode_par_scalapack_compat

  !> @brief Get the mask of available parallelization strategies.
  !!
  !! @return Bit mask where bit (P_STRATEGY_X - 1) is set if strategy X is available.
  integer pure function calc_mode_par_parallel_mask(this) result(par_mask)
    class(calc_mode_par_t), intent(in) :: this

    par_mask = this%par_mask

  end function calc_mode_par_parallel_mask

  !> @brief Get the default parallel mask used for a calculation.
  !!
  !! Note, this might be different from the modes available.
  !!
  !! @return Bit mask where bit (P_STRATEGY_X - 1) is set if strategy X is used by default.
  integer pure function calc_mode_par_default_parallel_mask(this) result(par_mask)
    class(calc_mode_par_t), intent(in) :: this

    par_mask = this%def_par_mask

  end function calc_mode_par_default_parallel_mask

end module calc_mode_par_oct_m

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
