!! Copyright (C) 2024. A Buccheri.
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
!! 02110-1301, USA.

#include "global.h"

module kmeans_clustering_oct_m
  use, intrinsic :: iso_fortran_env
  use debug_oct_m
  use electron_space_oct_m
  use global_oct_m
  use mesh_oct_m
  use messages_oct_m
  use mpi_oct_m
  use mpi_distribute_oct_m
  use mpi_lib_oct_m
  use namespace_oct_m
  use profiling_oct_m
  use space_oct_m
  use sort_oct_m
  use quickrnd_oct_m
  implicit none
  private

  public :: assign_points_to_centroids_finite_bc, &
    update_centroids, &
    weighted_kmeans, &
    sample_initial_centroids

contains

  ! TODO(Alex) Issue #1004. Implement `assign_points_to_centroids` for periodic boundary conditions

  !> @brief Assign each grid point to the closest centroid.
  !! A centroid and its set of nearest grid points defines a cluster.
  !!
  !! This can be mathematically expressed as:
  !!\f[
  !!  C_\mu=\left\{Z\left(\mathbf{r}_i\right) \mid \operatorname{dist}\left(Z\left(\mathbf{r}_i\right), Z\left(\mathbf{r}_\mu\right)\right)
  !!      \leq \operatorname{dist}\left(Z\left(\mathbf{r}_i\right), Z\left(\mathbf{r}_m\right)\right) \text { for all } m \neq \mu\right\}
  !!\f]
  !! where \f$ Z(\mathbf{r}_i) \f$ are data points and \f$ Z(\mathbf{r}_\mu) \f$ is the centroid. The distance metric, \f$\operatorname{dist}\f$,
  !! is the squared Euclidean distance in this implementation (the dot product of the difference vector with itself).
  !!
  !! See eqs 10-13 in [Complex-valued K-means clustering of interpolative separable density fitting algorithm for large-scale hybrid functional
  !! enabled ab initio  molecular dynamics simulations within plane waves](https://doi.org/10.48550/arXiv.2208.07731)
  !!
  !! @note
  !! Any grid point that is equidistant from two or more centroids is assigned to the lowest-indexed centroid.
  !! This is because the algorithm only updates the assignment when the current distance is less than the stored minimum
  !! (by more than a small tolerance). If this inequality were less than or equal, the point would be assigned to the
  !! highest-indexed centroid.
  !! Note, some algorithms will assign equidistant points at random from the set of centroids at minimum distance.
  !! One might consider implementing this behaviour.
  !!
  !! @note
  !! Because we store the map from grid points to centroid indices, and not the opposite map, it is possible to silently miss assigning
  !! a centroid. This will typically happen if two or more centroids are duplicates (or indistinguishable). As described above, the first
  !! instance of the centroid, with the lowest index, will get assigned to the grid point and any duplicates will get passed over.
  subroutine assign_points_to_centroids_finite_bc(mesh, centroids, ip_to_ic)
    class(mesh_t), intent(in)  :: mesh                 !< Real-space grid; `mesh%x(ip, :)` is read as the position of point ip
    real(real64),  intent(in)  :: centroids(:, :)      !< Centroid positions (space%dim, Ncentroids)
    integer,       intent(out) :: ip_to_ic(:)          !< For each of the mesh%np grid points, index of its closest centroid

    ! A candidate centroid only replaces the current best if it is closer by more than `tol`.
    ! Without this margin, the winner among (near-)degenerate centroids would depend on
    ! floating-point noise and could differ between machines. A fraction of the grid spacing
    ! would be an equally valid choice for the tolerance.
    real(real64), parameter :: tol = 1.0e-13_real64

    integer                   :: ip, ic, ic_closest
    integer                   :: n_dim, n_centroids
    real(real64)              :: dist2, closest_dist2
    real(real64), allocatable :: r_ip(:)

    PUSH_SUB(assign_points_to_centroids_finite_bc)

    ! The map must hold exactly one entry per local grid point
    ASSERT(size(ip_to_ic) == mesh%np)

    n_dim = size(centroids, 1)
    n_centroids = size(centroids, 2)
    SAFE_ALLOCATE(r_ip(1:n_dim))

    !$omp parallel do default(shared) private(r_ip, ic, ic_closest, closest_dist2, dist2)
    do ip = 1, mesh%np
      ! Local copy of the coordinates of grid point `ip`
      r_ip = mesh%x(ip, :)

      ! Seed the search with the first centroid, then scan the remaining ones.
      ! All distances are squared Euclidean distances.
      ic_closest = 1
      closest_dist2 = sum((centroids(:, 1) - r_ip(:))**2)
      do ic = 2, n_centroids
        dist2 = sum((centroids(:, ic) - r_ip(:))**2)
        ! Ties (within tol) keep the lower-indexed centroid
        if (dist2 < closest_dist2 - tol) then
          closest_dist2 = dist2
          ic_closest = ic
        endif
      enddo

      ip_to_ic(ip) = ic_closest
    enddo
    !$omp end parallel do

    SAFE_DEALLOCATE_A(r_ip)

    POP_SUB(assign_points_to_centroids_finite_bc)

  end subroutine assign_points_to_centroids_finite_bc


  !> @brief Compute a new set of centroids.
  !!
  !! A centroid is defined as:
  !! \f[
  !!    \mathbf{r}_\mu = \frac{\sum_{\mathbf{r}_j \in C_\mu} \mathbf{r}_j w(\mathbf{r}_j)}{\sum_{\mathbf{r}_j \in C_\mu} w(\mathbf{r}_j)}
  !! \f]
  !! where \f$\mathbf{r}_j\f$ and \f$w(\mathbf{r}_j)\f$ are grid points and weights restricted to the cluster \f$ C_\mu \f$, respectively.
  !!
  !! If using domain decomposition, the routine initially accumulates the numerator and denominator using only the
  !! local set of grid points, then reduces these sums over all domains.
  !!
  !! @note
  !! A cluster whose total weight is zero has no defined weighted mean. This can occur if a centroid is poorly chosen
  !! at a point with no associated weight (such as the vacuum of a crystal cell), or if no grid point is assigned to it.
  !! The centroid of such a cluster is returned unchanged, at its input position, instead of being overwritten with the
  !! NaN/Inf that a division by zero would produce.
  subroutine update_centroids(mesh, weight, ip_to_ic, centroids)
    class(mesh_t), intent(in)    :: mesh                     !< Real-space grid instance
    real(real64),  intent(in)    :: weight(:)                !< Weights (mesh%np)
    integer,       intent(in)    :: ip_to_ic(:)              !< Index array that maps grid indices to centroid indices
    real(real64), contiguous,  intent(inout) :: centroids(:, :)          !< In: centroid positions (space%dim, Ncentroids)
    !                                                          Out: Updated centroid positions

    integer                   :: n_centroids, ic, ip
    real(real64)              :: one_over_denom
    real(real64), allocatable :: denominator(:)
    real(real64), allocatable :: input_centroids(:, :)

    PUSH_SUB(update_centroids)

    ! The indexing of weight and grid must be consistent => must belong to the same spatial distribution
    ! This does not explicitly assert this, but if the grid sizes differ, this is a clear indicator of a problem
    ASSERT(mesh%np == size(weight))
    ! Every local grid point is looked up in the map below
    ASSERT(size(ip_to_ic) >= mesh%np)

    n_centroids = size(centroids, 2)
    SAFE_ALLOCATE(denominator(1:n_centroids))
    ! Keep the incoming positions, as `centroids` is reused as the accumulator for the numerator
    SAFE_ALLOCATE_SOURCE(input_centroids(1:size(centroids, 1), 1:n_centroids), centroids)

    centroids(:, :) = 0._real64
    denominator(:) = 0._real64

    !$omp parallel do private(ic) reduction(+ : centroids, denominator)
    do ip = 1, mesh%np
      ic = ip_to_ic(ip)
      ! Initially accumulate the numerator in `centroids`
      centroids(:, ic) = centroids(:, ic) + (mesh%x(ip, :) * weight(ip))
      denominator(ic) = denominator(ic) + weight(ip)
    enddo
    !$omp end parallel do

    ! Gather contributions to numerator and denominator of all centroids, from all domains of the mesh/grid
    call mesh%allreduce(centroids)
    call mesh%allreduce(denominator)

    ! Normalise. Each iteration writes only its own column of `centroids`, so no reduction is required.
    !$omp parallel do default(shared) private(one_over_denom)
    do ic = 1, n_centroids
      if (abs(denominator(ic)) > tiny(1._real64)) then
        one_over_denom = 1._real64 / denominator(ic)
        centroids(:, ic) = centroids(:, ic) * one_over_denom
      else
        ! Zero total weight in cluster ic: the weighted mean is undefined, so retain the input position
        centroids(:, ic) = input_centroids(:, ic)
      endif
    enddo
    !$omp end parallel do

    SAFE_DEALLOCATE_A(input_centroids)
    SAFE_DEALLOCATE_A(denominator)
    POP_SUB(update_centroids)

  end subroutine update_centroids


  !> @brief Compute the difference in two grids as \f$abs(\mathbf{g}_1 - \mathbf{g}_2)\f$.
  !!
  !! If a component of the difference vector for grid point ip is greater than
  !! the tolerance, the element points_differ(ip) is set to true.
  subroutine compute_grid_difference(points, updated_points, tol, points_differ)
    real(real64), intent(in)    :: points(:, :)         !< Reference points (n_dims, N)
    real(real64), intent(in)    :: updated_points(:, :) !< Points compared against the reference (n_dims, N)
    real(real64), intent(in)    :: tol                  !< Largest per-component absolute change regarded as "no change"
    logical,      intent(out)   :: points_differ(:)     !< points_differ(ip) is true if any component of point ip
    !                                                      changed by more than tol (size N)

    integer                     :: ip

    PUSH_SUB(compute_grid_difference)

    ! Both sets of points must have identical shape, and every point needs an entry in the output mask,
    ! otherwise the loop below would read or write out of bounds, or leave mask elements undefined.
    ASSERT(all(shape(points) == shape(updated_points)))
    ASSERT(size(points_differ) == size(points, 2))

    !$omp parallel do default(shared)
    do ip = 1, size(points, 2)
      points_differ(ip) = any(abs(updated_points(:, ip) - points(:, ip)) > tol)
    enddo
    !$omp end parallel do

    ! Per-point report of the unconverged entries, only when debug info is enabled
    if (debug%info) then
      call report_differences_in_grids(points, updated_points, tol, points_differ)
    endif

    POP_SUB(compute_grid_difference)

  end subroutine compute_grid_difference


  !> @brief Report differences returned from `compute_grid_difference`.
  subroutine report_differences_in_grids(points, updated_points, tol, points_differ)
    real(real64), intent(in)  :: points(:, :)         !< Prior points (n_dims, N)
    real(real64), intent(in)  :: updated_points(:, :) !< Current points (n_dims, N)
    real(real64), intent(in)  :: tol                  !< Tolerance, printed at the end of each row
    logical,      intent(in)  :: points_differ(:)     !< True for every point that should be reported (size N)

    ! Unlimited repeat count: one row holds the current point, the prior point, their absolute
    ! difference (n_dims values each) and the tolerance, for any number of dimensions.
    ! `1X` is used because the Fortran standard requires an explicit count on the X edit descriptor.
    character(len=*), parameter :: row_format = '(*(F16.10, 1X))'

    integer :: i, n_unconverged

    PUSH_SUB(report_differences_in_grids)

    write(message(1), '(a)') "# Current Point  ,  Prior Point  ,  |ri - r_{i-1}|  ,  tol"
    call messages_info(1)

    ! One row per flagged point, counting them along the way
    n_unconverged = 0
    do i = 1, size(points_differ)
      if (.not. points_differ(i)) cycle
      n_unconverged = n_unconverged + 1
      write(message(1), row_format) updated_points(:, i), points(:, i), &
        abs(updated_points(:, i) - points(:, i)), tol
      call messages_info(1)
    enddo

    write(message(1), *) "Summary:", n_unconverged, "out of", size(points, 2), "are not converged"
    call messages_info(1)

    POP_SUB(report_differences_in_grids)

  end subroutine report_differences_in_grids


  !> @brief Weighted K-means clustering.
  !!
  !! The K-means algorithm divides a set of \f$N_r\f$ samples (in this case, grid points) into \f$N_\mu\f$ disjoint clusters \f$C\f$.
  !! The mean of each cluster defines the cluster centroid. Note that centroids are not, in general, points from the discrete `grid`
  !! - they can take any continuous values that span the grid limits.
  !!
  !! ## Theory
  !! Given a grid, and some initial guess at centroids, the K-means algorithm aims to choose centroids that minimise the inertia,
  !! or within-cluster sum-of-squares criterion:
  !! \f[
  !!    argmin \sum^{N_\mu}_{\mu=1} \sum_{\mathbf{r}_\mathbf{k} \in C_\mu} || Z(\mathbf{r}_k) - Z(\mathbf{r}_\mu)||^2
  !! \f]
  !! This implementation is based on Algorithm 1. given in [Complex-valued K-means clustering of interpolative separable density fitting algorithm
  !! for large-scale hybrid functional enabled ab initio  molecular dynamics simulations within plane waves](https://doi.org/10.48550/arXiv.2208.07731),
  !! however it is equivalent to implementations found in packages such as [scikit-learn](https://scikit-learn.org/stable/modules/clustering.html#k-means).
  !!
  !! ## Algorithm Description
  !! The K-means algorithm consists of looping between two steps:
  !!  1. The first step assigns each sample to its nearest centroid. See `assign_points_to_centroids_finite_bc`
  !!  2. The second step creates new centroids by taking the mean value of all of the samples assigned to each previous centroid. See `update_centroids`
  !! The difference between the old and the new centroids is computed, and the algorithm repeats these two steps until every
  !! component of this difference is less than a threshold, or the maximum number of iterations is reached.
  !!
  !! ## MPI Implementation
  !! * Routine is MPI-safe for domain decomposition.
  !! * It expects a local mesh grid and returns centroids that are defined on all ranks.
  subroutine weighted_kmeans(space, mesh, weight, centroids, n_iter, centroid_tol, discretize, inertia)
    class(space_t),  intent(in)    :: space                  !< Spatial dimensions and periodic dimensions
    class(mesh_t),   intent(in)    :: mesh                   !< Real-space grid instance
    real(real64),    intent(in)    :: weight(:)              !< Weights (n_points)
    real(real64), contiguous,   intent(inout) :: centroids(:, :)        !< In: Initial centroids (n_dim, n_centroid)
    !                                                           Out: Final centroids
    integer,      optional, intent(in ) :: n_iter             !< Optional max number of iterations (default 200)
    real(real64), optional, intent(in ) :: centroid_tol       !< Optional convergence criterion (default 1.e-4)
    logical,      optional, intent(in ) :: discretize         !< Optional Discretize centroid values to grid points (default true)
    real(real64), optional, intent(out) :: inertia            !< Optional metric quantifying the quality of the centroids

    logical                   :: discretize_centroids, converged
    integer                   :: n_iterations, n_centroid, i
    real(real64)              :: tol
    integer,      allocatable :: ip_to_ic(:)
    real(real64), allocatable :: prior_centroids(:, :)
    logical,      allocatable :: points_differ(:)

    PUSH_SUB(weighted_kmeans)

    n_iterations = optional_default(n_iter, 200)
    tol = optional_default(centroid_tol, 1.e-4_real64)
    discretize_centroids = optional_default(discretize, .true.)
    n_centroid = size(centroids, 2)

    ! Should use a positive number of iterations
    ASSERT(n_iterations >= 1)
    ! Number of weights inconsistent with number of grid points
    ASSERT(size(weight) == mesh%np)
    ! Spatial dimensions of centroids array is inconsistent
    ASSERT(size(centroids, 1) == space%dim)
    ! The assignment step always reads the first centroid, so at least one is required
    ASSERT(n_centroid >= 1)
    ! Assignment of points to centroids only implemented for finite BCs
    ASSERT(.not. space%is_periodic())

    ! Work arrays
    SAFE_ALLOCATE_SOURCE(prior_centroids(space%dim, n_centroid), centroids)
    SAFE_ALLOCATE(ip_to_ic(1:mesh%np))
    SAFE_ALLOCATE(points_differ(1:n_centroid))

    write(message(1), '(a)') 'Debug: Performing weighted Kmeans clustering '
    call messages_info(1, debug_only=.true.)

    converged = .false.
    do i = 1, n_iterations
      ! I0 so that the counter is printed correctly for any number of iterations
      write(message(1), '(a, I0)') 'Debug: Iteration ', i
      call messages_info(1, debug_only=.true.)
      ! TODO(Alex) Issue #1004. Implement `assign_points_to_centroids` for periodic boundary conditions
      call assign_points_to_centroids_finite_bc(mesh, centroids, ip_to_ic)

      call update_centroids(mesh, weight, ip_to_ic, centroids)
      call compute_grid_difference(prior_centroids, centroids, tol, points_differ)

      if (.not. any(points_differ)) then
        converged = .true.
        write(message(1), '(a)') 'Debug: All centroid points converged'
        call messages_info(1, debug_only=.true.)
        exit
      endif

      prior_centroids = centroids
    enddo

    if (.not. converged) then
      write(message(1), '(a, I0, a)') 'Debug: Centroids not converged after ', n_iterations, ' iterations'
      call messages_info(1, debug_only=.true.)
    endif

    if (discretize_centroids) then
      call mesh_discretize_values_to_mesh(mesh, centroids)
    endif

    ! NOTE(review): `ip_to_ic` is the assignment made before the final centroid update (and before
    ! any discretization), so the inertia uses a map that may be slightly stale with respect to the
    ! returned centroids. Left as is, since re-assigning here would change the reported values - confirm intent.
    if (present(inertia)) then
      call compute_centroid_inertia(mesh, centroids, weight, ip_to_ic, inertia)
    endif

    SAFE_DEALLOCATE_A(prior_centroids)
    SAFE_DEALLOCATE_A(ip_to_ic)
    SAFE_DEALLOCATE_A(points_differ)

    POP_SUB(weighted_kmeans)

  end subroutine weighted_kmeans


  !> @brief Sample initial centroids from the full mesh
  !!
  !! Points are chosen at random, with no replacement (no point can be chosen twice).
  !!
  subroutine sample_initial_centroids(mesh, centroids, seed_value)
    class(mesh_t),             intent(in )           :: mesh             !< Mesh the centroids are drawn from
    real(real64), contiguous,  intent(out)           :: centroids(:, :)  !< Sampled positions, one centroid per column
    integer(int64),            intent(inout), optional :: seed_value     !< Initial seed value for PRNG.
    !                                                                       Forwarded to the Fisher Yates shuffle, which mutates it

    integer(int32)              :: n_samples      !< Number of centroids requested (columns of `centroids`)
    integer(int32)              :: isample
    integer(int64), allocatable :: ip_global(:)   !< Global grid indices selected by the shuffle

    PUSH_SUB(sample_initial_centroids)

    n_samples = size(centroids, 2)
    SAFE_ALLOCATE(ip_global(1:n_samples))

    ! Draw n_samples indices from [1, np_global]
    call fisher_yates_shuffle(n_samples, mesh%np_global, seed_value, ip_global)

    ! Look up the coordinates that belong to each selected global index
    do isample = 1, n_samples
      centroids(:, isample) = mesh_x_global(mesh, ip_global(isample))
    enddo

    SAFE_DEALLOCATE_A(ip_global)

    POP_SUB(sample_initial_centroids)

  end subroutine sample_initial_centroids


  !> @brief Compute the inertia of all centroids.
  !!
  !! Inertia is defined as the sum of squared distances of grid points to their
  !! closest cluster center, weighted by the grid weights:
  !!
  !! \f[ I = \sum_{\text{ip}}^{N_p}
  !!            w_{\text{ip}} |\mathbf{C}_{\text{ip} \in C} - \mathbf{r}_{\text{ip}}|^2
  !! \f]
  !!
  subroutine compute_centroid_inertia(mesh, centroids, weight, ip_to_ic, inertia)
    class(mesh_t),   intent(in)    :: mesh                   !< Real-space grid instance
    real(real64),    intent(in)    :: centroids(:, :)        !< Centroids (n_dim, n_centroid)
    real(real64),    intent(in)    :: weight(:)              !< Weights (n_points)
    integer,         intent(in)    :: ip_to_ic(:)            !< Map grid point to associated centroid
    real(real64),    intent(out)   :: inertia                !< Inertia, summed over all domains of the mesh

    integer :: ip, ic

    PUSH_SUB(compute_centroid_inertia)

    inertia = 0.0_real64

    ! `ic` is assigned in every iteration, so it must be private to each thread.
    ! The loop index `ip` is private by default.
    !$omp parallel do default(shared) private(ic) reduction(+ : inertia)
    do ip = 1, mesh%np
      ic = ip_to_ic(ip)
      ! Weighted squared distance of point ip to the centroid it is mapped to
      inertia = inertia + weight(ip) * sum((centroids(:, ic) - mesh%x(ip, :))**2)
    enddo
    !$omp end parallel do

    ! Sum the local contributions over all domains
    call mesh%allreduce(inertia)

    POP_SUB(compute_centroid_inertia)

  end subroutine compute_centroid_inertia

end module kmeans_clustering_oct_m

!! Local Variables:
!! mode: f90
!! coding: utf-8
!! End:
