#include "linearAlgebraLib/src/conjugateGradient.hpp"

namespace linearAlgebraLib {

/**
 * Creates a conjugate gradient solver. The problem size is handed straight to
 * LinearAlgebraSolverBase; this constructor sets up no state of its own.
 *
 * @param cellCount Number of cells, forwarded unchanged to the base-class constructor.
 */
ConjugateGradient::ConjugateGradient(unsigned cellCount)
    : LinearAlgebraSolverBase(cellCount) {}

/**
 * Solves A x = b with the (unpreconditioned) conjugate gradient method.
 * A is `_coefficientMatrix` and b is `_rightHandSide`.
 *
 * CG is only guaranteed to converge when A is symmetric positive definite.
 *
 * @param maxIterations        Upper bound on the number of CG iterations. If 0, no
 *                             iteration is performed and the initial guess is returned.
 * @param convergenceThreshold Iteration stops once `getL2Norm()` of the residual
 *                             r = b - A x is no longer greater than this value. This is
 *                             checked before every iteration, including the first.
 * @return The approximate solution x. The initial guess is a freshly constructed `Vector`
 *         of the right-hand side's size. It is presumably all zeros (TODO confirm against
 *         the Vector constructor).
 */
Vector ConjugateGradient::solve(unsigned maxIterations, double convergenceThreshold) {
  auto &A = _coefficientMatrix;
  auto &b = _rightHandSide;

  Vector x(b.size());
  Vector r(b.size());
  Vector p(b.size());
  Vector Ap(b.size());

  // Initial residual; the first search direction is the residual itself.
  r = b - A * x;
  p = r;

  // Squared residual norm r^T r, carried across iterations so it is computed once per step.
  double rr = r.transpose() * r;

  unsigned iteration = 0;

  // Convergence is tested BEFORE each iteration. If the initial guess already solves
  // the system (e.g. b == 0), no step is taken. This avoids alpha = 0 / 0 = NaN.
  while (iteration < maxIterations && r.getL2Norm() > convergenceThreshold) {
    // The single matrix-vector product of this iteration.
    // It is reused for both alpha and the residual update.
    Ap = A * p;

    const double pAp = p.transpose() * Ap;
    if (pAp == 0.0) {
      // Exact-zero guard: p is zero or A is singular along p. Dividing would yield inf/NaN,
      // so stop and return the current iterate.
      break;
    }

    const double alpha = rr / pAp;
    x = x + alpha * p;
    r = r - alpha * Ap;

    const double rrNew = r.transpose() * r;
    const double beta = rrNew / rr;
    p = r + beta * p;
    rr = rrNew;

    ++iteration;
  }

  return x;
}

} // namespace linearAlgebraLib
