// SPDX-FileCopyrightText: Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
// SPDX-License-Identifier: BSD-3-Clause
#include "vtkCellData.h"
#include "vtkDataSetTriangleFilter.h"
#include "vtkDoubleArray.h"
#include "vtkGhostCellsGenerator.h"
#include "vtkIdTypeArray.h"
#include "vtkImageData.h"
#include "vtkInformation.h"
#include "vtkInformationVector.h"
#include "vtkMPIController.h"
#include "vtkNew.h"
#include "vtkObjectFactory.h"
#include "vtkPointData.h"
#include "vtkRTAnalyticSource.h"
#include "vtkSOADataArrayTemplate.h"
#include "vtkTimerLog.h"
#include "vtkUnsignedCharArray.h"
#include "vtkUnstructuredGrid.h"
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>

namespace
{
// An RTAnalyticSource that generates GlobalNodeIds
class vtkRTAnalyticSource2 : public vtkRTAnalyticSource
{
public:
  static vtkRTAnalyticSource2* New();
  vtkTypeMacro(vtkRTAnalyticSource2, vtkRTAnalyticSource);

protected:
  vtkRTAnalyticSource2() = default;

  void ExecuteDataWithInformation(vtkDataObject* output, vtkInformation* outInfo) override
  {
    Superclass::ExecuteDataWithInformation(output, outInfo);

    // Split the update extent further based on piece request.
    vtkImageData* data = vtkImageData::GetData(outInfo);
    int* outExt = data->GetExtent();
    int* whlExt = this->GetWholeExtent();

    // find the region to loop over
    int maxX = (outExt[1] - outExt[0]) + 1;
    int maxY = (outExt[3] - outExt[2]) + 1;
    int maxZ = (outExt[5] - outExt[4]) + 1;

    int dX = (whlExt[1] - whlExt[0]) + 1;
    int dY = (whlExt[3] - whlExt[2]) + 1;

    vtkNew<vtkIdTypeArray> ids;
    ids->SetName("GlobalNodeIds");
    ids->SetNumberOfValues(maxX * maxY * maxZ);
    data->GetPointData()->SetGlobalIds(ids);

    vtkIdType cnt = 0;
    for (int idxZ = 0; idxZ < maxZ; idxZ++)
    {
      for (int idxY = 0; idxY < maxY; idxY++)
      {
        for (int idxX = 0; idxX < maxX; idxX++, cnt++)
        {
          ids->SetValue(
            cnt, (idxX + outExt[0]) + (idxY + outExt[2]) * dX + (idxZ + outExt[4]) * (dX * dY));
        }
      }
    }
  }

private:
  vtkRTAnalyticSource2(const vtkRTAnalyticSource2&) = delete;
  void operator=(const vtkRTAnalyticSource2&) = delete;
};

vtkStandardNewMacro(vtkRTAnalyticSource2);

// Check that the field data carries the "FieldData" marker array added by the test:
// it must exist, be a vtkUnsignedCharArray holding at least one value, and its first
// value must be 2 (the value set on the input grid).
// `fd` may be null, which counts as a failure.
// Returns true when the marker array was propagated unchanged.
bool CheckFieldData(vtkFieldData* fd)
{
  vtkUnsignedCharArray* fdArray =
    fd ? vtkUnsignedCharArray::SafeDownCast(fd->GetArray("FieldData")) : nullptr;

  // Guard against an empty array before reading element 0.
  if (!fdArray || fdArray->GetNumberOfValues() < 1 || fdArray->GetValue(0) != 2)
  {
    std::cerr << "Field data array value is not the same as the input" << std::endl;
    return false;
  }

  return true;
}

// Check that a cell data array exists, is non-empty, and has exactly one tuple per
// cell of `unstructuredGrid`. Every failure path reports its reason on std::cerr.
// A null `da` is reported as "not found"; a null grid is also a failure.
// Returns true when the array is present and sized to the grid's cell count.
bool CheckCellDataArray(vtkUnstructuredGrid* unstructuredGrid, vtkDataArray* da)
{
  if (!unstructuredGrid)
  {
    std::cerr << "No grid to check cell data against." << std::endl;
    return false;
  }
  if (!da)
  {
    std::cerr << "Cell data not found." << std::endl;
    return false;
  }

  const vtkIdType nbTuples = da->GetNumberOfTuples();
  if (nbTuples <= 0)
  {
    std::cerr << "Cell data has no data." << std::endl;
    return false;
  }

  const vtkIdType nbCells = unstructuredGrid->GetNumberOfCells();
  if (nbTuples != nbCells)
  {
    // Report the mismatch like the other failure paths instead of failing silently.
    const char* name = da->GetName();
    std::cerr << "Cell data array \"" << (name ? name : "<unnamed>") << "\" has " << nbTuples
              << " tuples but the grid has " << nbCells << " cells." << std::endl;
    return false;
  }

  return true;
}

// Validate the two cell arrays the test attaches ("Vorticity" and "Pressure").
// Returns false if the grid or its cell data is missing, or if either array fails
// CheckCellDataArray; "Pressure" is only checked once "Vorticity" has passed.
bool CheckCellData(vtkUnstructuredGrid* unstructuredGrid)
{
  if (unstructuredGrid == nullptr)
  {
    return false;
  }

  vtkCellData* cd = unstructuredGrid->GetCellData();
  if (cd == nullptr)
  {
    return false;
  }

  if (!CheckCellDataArray(unstructuredGrid, cd->GetArray("Vorticity")))
  {
    return false;
  }
  return CheckCellDataArray(unstructuredGrid, cd->GetArray("Pressure"));
}

} // anonymous namespace

//------------------------------------------------------------------------------
// Program main
int TestPUnstructuredGridGhostCellsGenerator(int argc, char* argv[])
{
  int ret = EXIT_SUCCESS;
  // Initialize the MPI controller
  vtkNew<vtkMPIController> controller;
  controller->Initialize(&argc, &argv, 0);
  vtkMultiProcessController::SetGlobalController(controller);
  int myRank = controller->GetLocalProcessId();
  int nbRanks = controller->GetNumberOfProcesses();

  // Create the pipeline to produce the initial grid
  vtkNew<vtkRTAnalyticSource2> wavelet;
  const int gridSize = 50;
  wavelet->SetWholeExtent(0, gridSize, 0, gridSize, 0, gridSize);
  vtkNew<vtkDataSetTriangleFilter> tetrahedralize;
  tetrahedralize->SetInputConnection(wavelet->GetOutputPort());
  tetrahedralize->UpdatePiece(myRank, nbRanks, 0);

  vtkNew<vtkUnstructuredGrid> initialGrid;
  initialGrid->ShallowCopy(tetrahedralize->GetOutput());

  // Add field data
  vtkNew<vtkUnsignedCharArray> fdArray;
  fdArray->SetNumberOfTuples(1);
  fdArray->SetName("FieldData");
  fdArray->SetValue(0, 2);
  vtkNew<vtkFieldData> fd;
  fd->AddArray(fdArray);
  initialGrid->SetFieldData(fd);

  // add cell data
  vtkNew<vtkDoubleArray> dblArray;
  dblArray->SetName("Vorticity");
  dblArray->SetNumberOfTuples(initialGrid->GetNumberOfCells());
  for (vtkIdType i = 0; i < initialGrid->GetNumberOfCells(); ++i)
  {
    dblArray->SetValue(i, static_cast<double>(i));
  }
  initialGrid->GetCellData()->AddArray(dblArray);

  double* p = new double[initialGrid->GetNumberOfCells()];
  for (vtkIdType i = 0; i < initialGrid->GetNumberOfCells(); ++i)
  {
    p[i] = static_cast<double>(i);
    dblArray->SetValue(i, static_cast<double>(i));
  }
  vtkNew<vtkSOADataArrayTemplate<double>> soaArray;
  soaArray->SetName("Pressure");
  soaArray->SetNumberOfComponents(1);
  soaArray->SetArray(0, p, initialGrid->GetNumberOfCells(), /*updateMaxId*/ true, /*save*/ true);
  initialGrid->GetCellData()->AddArray(soaArray);

  if (!CheckCellData(initialGrid))
  {
    std::cerr << "Cell data was not initialized correctly" << std::endl;
    ret = EXIT_FAILURE;
  }

  // Prepare the ghost cells generator
  vtkNew<vtkGhostCellsGenerator> ghostGenerator;
  ghostGenerator->SetInputData(initialGrid);
  ghostGenerator->SetController(controller);

  // Check BuildIfRequired option
  ghostGenerator->BuildIfRequiredOff();
  ghostGenerator->UpdatePiece(myRank, nbRanks, 0);

  if (!vtkUnstructuredGrid::SafeDownCast(ghostGenerator->GetOutputDataObject(0))
         ->GetCellGhostArray())
  {
    std::cerr << "Ghost were not generated but were explicitly requested on process "
              << controller->GetLocalProcessId() << std::endl;
    ret = EXIT_FAILURE;
  }

  ghostGenerator->BuildIfRequiredOn();
  ghostGenerator->UpdatePiece(myRank, nbRanks, 0);

  if (vtkUnstructuredGrid::SafeDownCast(ghostGenerator->GetOutputDataObject(0))
        ->GetCellGhostArray())
  {
    std::cerr << "Ghost were generated but were not requested on process "
              << controller->GetLocalProcessId() << std::endl;
    ret = EXIT_FAILURE;
  }

  // Check that field data is copied
  ghostGenerator->Update();
  if (!CheckFieldData(ghostGenerator->GetOutput()->GetFieldData()))
  {
    std::cerr << "Field data was not copied correctly" << std::endl;
    ret = EXIT_FAILURE;
  }

  auto recievedGrid = vtkUnstructuredGrid::SafeDownCast(ghostGenerator->GetOutputDataObject(0));
  if (!CheckCellData(recievedGrid))
  {
    std::cerr << "Cell data was not copied correctly" << std::endl;
    ret = EXIT_FAILURE;
  }

  // Check if algorithm works with empty input on all nodes except first one
  vtkNew<vtkUnstructuredGrid> emptyGrid;
  ghostGenerator->SetInputData(myRank == 0 ? initialGrid : emptyGrid);
  ghostGenerator->Modified();
  for (int step = 0; step < 2; ++step)
  {
    ghostGenerator->UpdatePiece(myRank, nbRanks, 1);
  }
  ghostGenerator->SetInputData(initialGrid);
  ghostGenerator->Modified();

  // Check ghost cells generated with and without the global point ids
  // for several ghost layer levels
  int maxGhostLevel = 2;
  vtkSmartPointer<vtkUnstructuredGrid> outGrids[2];
  for (int ghostLevel = 1; ghostLevel <= maxGhostLevel; ++ghostLevel)
  {
    for (int step = 0; step < 2; ++step)
    {
      ghostGenerator->Modified();
      vtkNew<vtkTimerLog> timer;
      timer->StartTimer();
      ghostGenerator->UpdatePiece(myRank, nbRanks, ghostLevel);
      timer->StopTimer();

      // Save the grid for further analysis
      outGrids[step] = vtkUnstructuredGrid::SafeDownCast(ghostGenerator->GetOutputDataObject(0));

      if (!CheckFieldData(outGrids[step]->GetFieldData()))
      {
        std::cerr << "Field data was not copied" << std::endl;
        ret = EXIT_FAILURE;
      }
      if (!CheckCellData(outGrids[step]))
      {
        std::cerr << "Cell data was not copied" << std::endl;
        ret = EXIT_FAILURE;
      }

      double elapsed = timer->GetElapsedTime();

      // get some performance statistics
      double minGhostUpdateTime = 0.0;
      double maxGhostUpdateTime = 0.0;
      double avgGhostUpdateTime = 0.0;
      controller->Reduce(&elapsed, &minGhostUpdateTime, 1, vtkCommunicator::MIN_OP, 0);
      controller->Reduce(&elapsed, &maxGhostUpdateTime, 1, vtkCommunicator::MAX_OP, 0);
      controller->Reduce(&elapsed, &avgGhostUpdateTime, 1, vtkCommunicator::SUM_OP, 0);
      avgGhostUpdateTime /= static_cast<double>(nbRanks);
      if (controller->GetLocalProcessId() == 0)
      {
        std::cerr << "-- Ghost Level: " << ghostLevel << " Elapsed Time: min=" << minGhostUpdateTime
                  << ", avg=" << avgGhostUpdateTime << ", max=" << maxGhostUpdateTime << std::endl;
      }
    }

    vtkIdType initialNbOfCells = initialGrid->GetNumberOfCells();

    // quantitative correct values for runs with 4 MPI processes
    // components are for [ghostlevel][procid][bounds]
    vtkIdType correctCellCounts[2] = { 675800 / 4, 728800 / 4 };
    double correctBounds[2][4][6] = {
      {
        { 0.000000, 50.000000, 0.000000, 26.000000, 0.000000, 26.000000 },
        { 0.000000, 50.000000, 24.000000, 50.000000, 0.000000, 26.000000 },
        { 0.000000, 50.000000, 0.000000, 26.000000, 24.000000, 50.000000 },
        { 0.000000, 50.000000, 24.000000, 50.000000, 24.000000, 50.000000 },
      },
      { { 0.000000, 50.000000, 0.000000, 27.000000, 0.000000, 27.000000 },
        { 0.000000, 50.000000, 23.000000, 50.000000, 0.000000, 27.000000 },
        { 0.000000, 50.000000, 0.000000, 27.000000, 23.000000, 50.000000 },
        { 0.000000, 50.000000, 23.000000, 50.000000, 23.000000, 50.000000 } }
    };
    for (int step = 0; step < 2; ++step)
    {
      if (nbRanks == 4)
      {
        if (outGrids[step]->GetNumberOfCells() != correctCellCounts[ghostLevel - 1])
        {
          std::cerr << "Wrong number of cells on process " << myRank << " for " << ghostLevel
                    << " ghost levels!\n";
          ret = EXIT_FAILURE;
        }
        double bounds[6];
        outGrids[step]->GetBounds(bounds);
        for (int i = 0; i < 6; i++)
        {
          if (std::abs(bounds[i] - correctBounds[ghostLevel - 1][myRank][i]) > .001)
          {
            std::cerr << "Wrong bounds for " << ghostLevel << " ghost levels!\n";
            ret = EXIT_FAILURE;
          }
        }
      }

      vtkUnsignedCharArray* ghosts =
        vtkArrayDownCast<vtkUnsignedCharArray>(outGrids[step]->GetCellGhostArray());
      if (initialNbOfCells >= outGrids[step]->GetNumberOfCells())
      {
        std::cerr << "Obtained grids for ghost level " << ghostLevel
                  << " has less or as many cells as the input grid!\n";
        ret = EXIT_FAILURE;
      }
      if (!ghosts)
      {
        std::cerr << "Ghost cells array not found at ghost level " << ghostLevel << ", step "
                  << step << "!\n";
        ret = EXIT_FAILURE;
        continue;
      }

      for (vtkIdType i = 0; i < ghosts->GetNumberOfTuples(); ++i)
      {
        unsigned char val = ghosts->GetValue(i);
        if (i < initialNbOfCells && val != 0)
        {
          std::cerr << "Ghost Level " << ghostLevel << " Cell " << i
                    << " is not supposed to be a ghost cell but it is!\n";
          ret = EXIT_FAILURE;
          break;
        }
        if (i >= initialNbOfCells && val != 1)
        {
          std::cerr << "Ghost Level " << ghostLevel << " Cell " << i
                    << " is supposed to be a ghost cell but it's not!\n";
          ret = EXIT_FAILURE;
          break;
        }
      }
    }
  }

  controller->Finalize();
  return ret;
}
