#include "linearAlgebraLib/src/sparseMatrixCSR.hpp"

namespace linearAlgebraLib {

/// @brief Creates a rows x columns matrix with no stored (non-zero) entries.
/// @param rows    Number of matrix rows.
/// @param columns Number of matrix columns.
///
/// The CSR row-pointer array always has rows + 1 entries. Starting every entry
/// at zero means every row's slice [_rows[r], _rows[r + 1]) is empty.
SparseMatrixCSR::SparseMatrixCSR(unsigned rows, unsigned columns) : _numberOfRows(rows), _numberOfColumns(columns) {
  // Row pointers are integer offsets into _values/_columns, so fill them with an
  // integer zero rather than a floating-point literal.
  _rows.resize(_numberOfRows + 1, 0);
}

/// @brief Stores `value` at position (row, column).
/// @param row    Row index; must be < getNumberOfRows().
/// @param column Column index; must be < getNumberOfColumns().
/// @param value  Value to store.
///
/// - A non-zero value overwrites an existing entry, or is appended to the end of
///   the row's slice.
/// - A zero value removes any stored entry, so that get() returns 0.0 afterwards.
///
/// Inserting or removing an entry shifts every later row pointer by one.
void SparseMatrixCSR::set(unsigned row, unsigned column, double value) {
  assert(row < _numberOfRows && column < _numberOfColumns && "Invalid row or column index.");

  const unsigned rowStart = _rows[row];
  const unsigned rowEnd = _rows[row + 1];

  // Locate an already-stored entry for (row, column). rowEnd acts as "not found".
  unsigned index = rowEnd;
  for (unsigned i = rowStart; i < rowEnd; ++i) {
    if (_columns[i] == column) {
      index = i;
      break;
    }
  }
  const bool rowColumnHasEntry = index != rowEnd;

  if (value != 0.0) {
    if (rowColumnHasEntry) {
      _values[index] = value;
    } else {
      // Append the new entry at the end of this row's slice.
      _values.insert(_values.begin() + rowEnd, value);
      _columns.insert(_columns.begin() + rowEnd, column);

      // Every later row now starts one position further along.
      for (unsigned i = row + 1; i <= _numberOfRows; ++i) {
        ++_rows[i];
      }
    }
  } else if (rowColumnHasEntry) {
    // Explicitly zeroing an entry must drop it from storage. Otherwise the stale
    // non-zero value would still be returned by get() and used in products.
    _values.erase(_values.begin() + index);
    _columns.erase(_columns.begin() + index);

    for (unsigned i = row + 1; i <= _numberOfRows; ++i) {
      --_rows[i];
    }
  }
}

/// @brief Returns the value at (row, column), or 0.0 if no entry is stored there.
/// @param row    Row index; must be < getNumberOfRows().
/// @param column Column index; must be < getNumberOfColumns().
double SparseMatrixCSR::get(unsigned row, unsigned column) const {
  assert(row < _numberOfRows && column < _numberOfColumns && "Invalid row or column index.");

  // The stored entries of `row` occupy the half-open slice [first, last).
  const unsigned first = _rows[row];
  const unsigned last = _rows[row + 1];

  for (unsigned k = first; k < last; ++k)
    if (_columns[k] == column)
      return _values[k];

  // Any position that is not stored is an implicit zero.
  return 0.0;
}

unsigned SparseMatrixCSR::getNumberOfRows() const {
  return _numberOfRows;
}

unsigned SparseMatrixCSR::getNumberOfColumns() const {
  return _numberOfColumns;
}

/// @brief Matrix-vector product A * x.
/// @param rhs Input vector x; its size must equal getNumberOfColumns().
/// @return Vector y with getNumberOfRows() entries, where
///         y[row] = sum over stored (row, c) of value * rhs[c].
///
/// NOTE(review): this member does not modify the matrix and could be `const`,
/// but that needs a matching change in the class declaration.
Vector SparseMatrixCSR::operator*(const Vector& rhs) {
  assert(rhs.size() == _numberOfColumns && "Vector size does not match the number of columns in the matrix.");

  // One output entry per matrix ROW. Sizing by the column count would give the
  // wrong length for non-square matrices, and would index out of bounds when
  // rows > columns.
  Vector result(_numberOfRows);

  for (unsigned row = 0; row < _numberOfRows; ++row) {
    for (unsigned columnIndex = _rows[row]; columnIndex < _rows[row + 1]; ++columnIndex) {
      const auto column = _columns[columnIndex];
      result[row] += _values[columnIndex] * rhs[column];
    }
  }

  return result;
}

/// @brief Scalar-matrix product: returns a copy of `matrix` with every stored
///        value multiplied by `scaleFactor`.
///
/// Only stored entries are scaled. The sparsity pattern (columns and row
/// pointers) is copied unchanged.
SparseMatrixCSR operator*(const double &scaleFactor, const SparseMatrixCSR &matrix) {
  SparseMatrixCSR scaled(matrix);
  for (auto &storedValue : scaled._values)
    storedValue *= scaleFactor;
  return scaled;
}

/// @brief Writes a debug dump of the matrix to `os`.
///
/// The output has two parts:
/// 1. The raw CSR arrays: stored values, their column indices, and the
///    rows + 1 row pointers.
/// 2. A dense rendering with one line per row; unstored positions print as 0.
///
/// Each line ends with std::endl, so the stream is flushed after every line.
std::ostream& operator<<(std::ostream &os, const SparseMatrixCSR &rhs) {
  const auto entryCount = rhs._values.size();

  os << "_Values: ";
  for (std::size_t k = 0; k < entryCount; ++k)
    os << rhs._values[k] << " ";
  os << std::endl;

  // _columns runs in parallel with _values, so the same count applies.
  os << "Columns: ";
  for (std::size_t k = 0; k < entryCount; ++k)
    os << rhs._columns[k] << " ";
  os << std::endl;

  os << "Row Pointers: ";
  for (unsigned r = 0; r <= rhs._numberOfRows; ++r)
    os << rhs._rows[r] << " ";
  os << std::endl;

  // Reuse one dense buffer for every row. It is reset to zeros before each
  // row's entries are scattered into it.
  std::vector<double> denseRow;
  for (unsigned r = 0; r < rhs._numberOfRows; ++r) {
    denseRow.assign(rhs._numberOfColumns, 0.0);

    const unsigned sliceBegin = rhs._rows[r];
    const unsigned sliceEnd = rhs._rows[r + 1];
    for (unsigned k = sliceBegin; k < sliceEnd; ++k)
      denseRow[rhs._columns[k]] = rhs._values[k];

    for (std::size_t c = 0; c < denseRow.size(); ++c)
      os << denseRow[c] << " ";
    os << std::endl;
  }
  return os;
}

} // namespace linearAlgebraLib
