#include "gtest/gtest.h"

#include "linearAlgebraLib/linearAlgebraLib.hpp"

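// solve the 1D heat equation dT/dt = gamma * d^2 T/dx^2 at steady state (dT/dt = 0) over a domain of length L.
// The equation is discretised implicitly on a cell-centred mesh, and the resulting linear system is solved with
// the conjugate gradient method.
// initial condition: 0 everywhere
// boundary condition: T(0) = 0, T(L) = 1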
TEST(SystemTest, HeatEquation1DImplicit) {
  // input variables
  const double gamma = 1.0;
  const unsigned numberOfCells = 100;
  const double domainLength = 1.0;
  const double boundaryValueLeft = 0.0;
  const double boundaryValueRight = 1.0;
  const double dx = domainLength / numberOfCells;

  // vectors and matrices
  linearAlgebraLib::Vector coordinateX(numberOfCells);
  linearAlgebraLib::Vector temperature(numberOfCells);
  linearAlgebraLib::Vector boundaryConditions(numberOfCells);

  linearAlgebraLib::SparseMatrixCSR coefficientMatrix(numberOfCells, numberOfCells);

  // initialise arrays and set up a uniform 1D mesh; coordinates are stored at the cell centres
  for (unsigned i = 0; i < numberOfCells; ++i) {
    coordinateX[i] = i * dx + dx / 2.0;
    temperature[i] = 0.0;
    boundaryConditions[i] = 0.0;
  }

  // coefficients coupling a cell to its east / west neighbour (centre-to-centre distance dx)
  const double aE = gamma / dx;
  const double aW = gamma / dx;
  const double aP = -1.0 * (aE + aW);

  // interior cells: three-point stencil
  for (unsigned i = 1; i < numberOfCells - 1; ++i) {
    coefficientMatrix.set(i, i, aP);
    coefficientMatrix.set(i, i + 1, aE);
    coefficientMatrix.set(i, i - 1, aW);
  }

  // boundary cells: the boundary face is only dx / 2 away from the cell centre, which doubles the coefficient on
  // the boundary side; the known boundary value is moved to the right-hand side below
  coefficientMatrix.set(0, 0, -1.0 * (aE + 2.0 * aW));
  coefficientMatrix.set(0, 1, aE);
  coefficientMatrix.set(numberOfCells - 1, numberOfCells - 2, aW);
  coefficientMatrix.set(numberOfCells - 1, numberOfCells - 1, -1.0 * (2.0 * aE + aW));

  // set boundary conditions (right-hand-side contributions of the Dirichlet values)
  boundaryConditions[0] = -2.0 * aW * boundaryValueLeft;
  boundaryConditions[numberOfCells - 1] = -2.0 * aE * boundaryValueRight;

  // solve the linear system using the conjugate gradient method
  // NOTE(review): arguments presumably are (max iterations, convergence tolerance) — confirm against
  // ConjugateGradient::solve
  linearAlgebraLib::ConjugateGradient CGSolver(numberOfCells);
  CGSolver.setCoefficientMatrix(coefficientMatrix);
  CGSolver.setRightHandSide(boundaryConditions);
  temperature = CGSolver.solve(100, 1e-10);

  // The steady-state solution with Dirichlet boundaries is linear: T(x) = T_L + (T_R - T_L) * x / L.
  // This discretisation reproduces a linear profile exactly at the cell centres, so any remaining difference
  // comes from the solver tolerance and round-off. The reference is built from the input constants, so the
  // test stays valid if they are changed.
  linearAlgebraLib::Vector difference(numberOfCells);
  for (unsigned i = 0; i < numberOfCells; ++i) {
    const double analyticTemperature =
        boundaryValueLeft + (boundaryValueRight - boundaryValueLeft) * coordinateX[i] / domainLength;
    // assign (not accumulate) so the result does not depend on how Vector initialises its storage;
    // no absolute value is needed because the L2 norm squares each entry
    difference[i] = temperature[i] - analyticTemperature;
  }

  // ensure the numerical solution matches the analytic profile to within 1e-8 in the L2 norm
  ASSERT_NEAR(difference.getL2Norm(), 0.0, 1e-8);
}