!! Copyright (C) 2002-2006 M. Marques, A. Castro, A. Rubio, G. Bertsch
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.
!!

#include "global.h"

module lalg_basic_oct_m
  use blas_oct_m
  use debug_oct_m
  use global_oct_m
  use, intrinsic :: iso_fortran_env
  use messages_oct_m
  use profiling_oct_m
  use utils_oct_m

  implicit none

  ! Everything is module-private unless exported below.
  private

  ! BLAS level I wrappers
  public :: lalg_swap, lalg_scal, lalg_axpy, lalg_copy, lalg_nrm2
  ! BLAS level II wrappers
  public :: lalg_symv, lalg_gemv
  ! BLAS level III wrappers
  public :: lalg_gemm, lalg_gemm_cn, lalg_gemm_nc, lalg_gemm_cc
  public :: lalg_trmm, lalg_symm
  ! ------------------------------------------------------------------
  ! BLAS level I
  ! ------------------------------------------------------------------

  !> Exchanges the contents of two vectors.
  !! The specific procedures are listed grouped by their numeric suffix.
  interface lalg_swap
    module procedure swap_1_2, swap_2_2, swap_3_2, swap_4_2
    module procedure swap_1_4, swap_2_4, swap_3_4, swap_4_4
  end interface lalg_swap

  !> Multiplies a vector by a constant.
  !! The specific procedures are listed grouped by their numeric suffix.
  interface lalg_scal
    module procedure scal_1_2, scal_2_2, scal_3_2, scal_4_2
    module procedure scal_1_4, scal_2_4, scal_3_4, scal_4_4, scal_5_4, scal_6_4
  end interface lalg_scal

  !> Computes a constant times a vector plus a vector.
  !! The specific procedures are listed grouped by their numeric suffix.
  interface lalg_axpy
    module procedure axpy_1_2, axpy_2_2, axpy_3_2, axpy_4_2
    module procedure axpy_1_4, axpy_2_4, axpy_3_4, axpy_4_4, axpy_5_4, axpy_6_4, axpy_7_4
  end interface lalg_axpy

  !> Copies a vector x into a vector y.
  !! The specific procedures are listed grouped by their numeric suffix.
  interface lalg_copy
    module procedure copy_1_2, copy_2_2, copy_3_2, copy_4_2
    module procedure copy_1_4, copy_2_4, copy_3_4, copy_4_4
  end interface lalg_copy

  !> Returns the Euclidean norm of a vector.
  interface lalg_nrm2
    module procedure nrm2_2, nrm2_4
  end interface lalg_nrm2

  ! ------------------------------------------------------------------
  ! BLAS level II
  ! ------------------------------------------------------------------

  !> Matrix-vector product added to a vector.
  !! NOTE(review): going by the name, the matrix is presumably symmetric (BLAS xSYMV) - confirm
  !! against the specific procedures.
  interface lalg_symv
    module procedure symv_1_2, symv_2_2
    module procedure symv_1_4, symv_2_4
  end interface lalg_symv

  !> Matrix-vector multiplication.
  !! NOTE(review): going by the name, this presumably wraps the general matrix-vector
  !! product (BLAS xGEMV) - confirm against the specific procedures.
  interface lalg_gemv
    module procedure gemv_1_2, gemv_2_2
    module procedure gemv_1_4, gemv_2_4
  end interface lalg_gemv

  ! ------------------------------------------------------------------
  ! BLAS level III
  ! ------------------------------------------------------------------

  !> Matrix-matrix product added to a matrix.
  !! Besides the suffixed specifics, the generic also resolves to dgemm_simple, which takes
  !! real(real64) assumed-shape matrices and reads the extents from the array shapes.
  interface lalg_gemm
    module procedure gemm_1_2, gemm_2_2
    module procedure gemm_1_4, gemm_2_4
    module procedure dgemm_simple
  end interface lalg_gemm

  !> Same operation as lalg_gemm, but using the (Hermitian) transpose of A.
  interface lalg_gemm_cn
    module procedure gemm_cn_1_2, gemm_cn_2_2
    module procedure gemm_cn_1_4, gemm_cn_2_4
  end interface lalg_gemm_cn

  !> Same operation as lalg_gemm, but using the (Hermitian) transpose of B.
  interface lalg_gemm_nc
    module procedure gemm_nc_1_2, gemm_nc_2_2
    module procedure gemm_nc_1_4, gemm_nc_2_4
  end interface lalg_gemm_nc

  !> Same operation as lalg_gemm, but using the (Hermitian) transpose of both A and B.
  interface lalg_gemm_cc
    module procedure gemm_cc_1_2, gemm_cc_1_4
  end interface lalg_gemm_cc

  !> Matrix-matrix multiplications that expect an upper triangular matrix for a.
  !! For real matrices, \f$A = A^T\f$, for complex matrices \f$A = A^H\f$.
  interface lalg_symm
    module procedure symm_1_2, symm_1_4
  end interface lalg_symm

  !> Matrix-matrix multiplication.
  !! NOTE(review): going by the name, one factor is presumably triangular (BLAS xTRMM) - confirm
  !! against the specific procedures.
  interface lalg_trmm
    module procedure trmm_1_2, trmm_1_4
  end interface lalg_trmm

contains

!> @brief GEMM with a simplified API for two matrices of consistent shape and type.
!!
!! GEMM performs one of the matrix-matrix operations:
!! ```
!!  C := alpha*op( A )*op( B ) + beta*C,
!! ```
!! where  op( X ) is one of:
!! ```
!! op( X ) = X   or   op( X ) = X**T
!! ```
!! operating on all elements of both matrices. If all defaults are used, the routine performs
!! \f$C = A B\f$. See the official [BLAS GEMM](https://www.netlib.org/lapack/explore-html/dd/d09/group__gemm)
!! documentation for more details.
!!
!! The extents m, n and k are read from the shapes of \p a and \p b after applying \p transa and
!! \p transb. They are checked with assertions against each other and against the shape of \p c.
!!
!! Matrices with a zero extent are handled here without calling BLAS, because there is no first
!! element that could be passed on. If \p c is empty nothing is done. If only the contracted
!! dimension is empty, the product is the zero matrix, so \p c is set to zero when \p beta is zero
!! and scaled by \p beta otherwise.
  subroutine dgemm_simple(a, b, c, transa, transb, alpha, beta)
    real(real64), contiguous,   intent(in )    :: a(:,:)          !< Left factor of the product
    real(real64), contiguous,   intent(in )    :: b(:,:)          !< Right factor of the product
    real(real64), contiguous,   intent(inout)  :: c(:,:)          !< Result matrix. If not summing to \p c, it does
    !<                                                               not need to be zeroed on input, as \p beta is 0.
    character(len=1), optional, intent(in )    :: transa, transb  !< Transpose \p a and \p b. Default: N for both
    real(real64),     optional, intent(in )    :: alpha           !< Scale the product, op( A )*op( B ). Default: 1.0
    real(real64),     optional, intent(in )    :: beta            !< Scale input value of \p c.
    !<                                                               For example, 1.0 allows addition of the MM product
    !<                                                               to an input \p c. The default, 0.0, defines \p c
    !<                                                               as the MM product.

    integer          :: m, k, l, n
    character(len=1) :: ta, tb
    real(real64)     :: p_alpha, p_beta

    PUSH_SUB(dgemm_simple)

    ! Resolve the optional arguments to the documented defaults.
    ta = 'N'
    tb = 'N'
    if (present(transa)) ta = transa
    if (present(transb)) tb = transb

    p_alpha = optional_default(alpha, 1.0_real64)
    p_beta = optional_default(beta, 0.0_real64)

    ! Extents of op( A ): m x k. Any flag other than N or n is read as transposed here.
    if (ta == 'n' .or. ta == 'N') then
      m = size(a, 1)
      k = size(a, 2)
    else
      m = size(a, 2)
      k = size(a, 1)
    end if
    ! Extents of op( B ): l x n, where l must match k.
    if (tb == 'n' .or. tb == 'N') then
      l = size(b, 1)
      n = size(b, 2)
    else
      l = size(b, 2)
      n = size(b, 1)
    end if
    ASSERT(size(c, 1) == m)
    ASSERT(size(c, 2) == n)
    ASSERT(k == l)

    ! An empty result matrix leaves nothing to compute, and c(1, 1) does not exist.
    if (m == 0 .or. n == 0) then
      POP_SUB(dgemm_simple)
      return
    end if

    ! Empty contracted dimension: a(1, 1) and b(1, 1) do not exist, and the product is zero.
    ! Only the scaling of c remains. As in reference GEMM, an exactly zero beta overwrites c
    ! instead of multiplying it, so that undefined input values in c cannot propagate.
    if (k == 0) then
      if (p_beta == 0.0_real64) then
        c = 0.0_real64
      else
        c = p_beta * c
      end if
      POP_SUB(dgemm_simple)
      return
    end if

    ! The first elements are passed to BLAS together with the leading dimensions.
    call blas_gemm(ta, tb, m, n, k, p_alpha, a(1, 1), lead_dim(a), &
      b(1, 1), lead_dim(b), p_beta, c(1, 1), lead_dim(c))

    POP_SUB(dgemm_simple)

  end subroutine dgemm_simple

! The template lalg_basic_blas_inc.F90 is included twice, once with N_ARG_TYPES set to 2
! and once with it set to 4. The macro is undefined again after each pass.
! NOTE(review): presumably each pass provides the specific procedures carrying the matching
! _2 or _4 suffix that the generic interfaces of this module refer to - confirm in the
! include file, together with the argument types selected by each value.
#  define N_ARG_TYPES 2
#  include "lalg_basic_blas_inc.F90"
#  undef N_ARG_TYPES

! Second pass of the same template.
#  define N_ARG_TYPES 4
#  include "lalg_basic_blas_inc.F90"
#  undef N_ARG_TYPES

end module lalg_basic_oct_m


!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
