#pragma once

#include "../debug/exception.h"
#include "../debug/safe.h"
#include "../memory/memory.h"

#include "exception.h"

#include <cmath>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

namespace math
{
	// Magnitude threshold used by inv() to decide that a value/determinant is too
	// small to invert. Types without a specialization get 0.
	template<typename type> constexpr type determinant_eps(void)
	{
		return static_cast<type>(0);
	}

	// single precision threshold
	template<> constexpr float determinant_eps<float>(void)
	{
		return static_cast<float>(1e-12);
	}

	// double precision threshold
	template<> constexpr double determinant_eps<double>(void)
	{
		return static_cast<double>(1e-26);
	}

	// extended precision threshold (the literal is a double converted to long double)
	template<> constexpr long double determinant_eps<long double>(void)
	{
		return static_cast<long double>(1e-32);
	}

	// matrix template
	// Dense 2D matrix of `type` values held in a memory<type> buffer (project type).
	// Every accessor in this class takes (col, row) -- column index FIRST -- and the
	// sizing constructor likewise takes (cols, rows).
	template<typename type> class matrix
	{
	public:

		// default ctor: empty 0x0 matrix
		matrix(void)
		{
			this->m_rows = 0;
			this->m_cols = 0;
		}

		// copy ctor: delegates to the copy operator
		matrix(const matrix<type>& mat) : matrix()
		{
			this->operator=(mat);
		}

		// move ctor: takes over the buffer of `mat`, which is left 0x0
		matrix(matrix<type>&& mat) noexcept : matrix()
		{
			this->operator=(std::move(mat));
		}

		// allocate a matrix with `cols` columns and `rows` rows, zero-filled
		matrix(size_t cols, size_t rows)
		{
			// copy size
			this->m_rows = rows;
			this->m_cols = cols;

			// allocate memory
			this->m_data = memory<type>(MAKE_MULT(this->m_rows, this->m_cols));

			// initialize with zeros
			this->operator=(0);
		}

		// dtor: releases the data; exceptions from clear() are swallowed because
		// a destructor must not throw
		~matrix(void)
		{
			try
			{
				clear();
			}
			catch (...) {}
		}

		// copy operator
		const matrix<type>& operator=(const matrix<type>& obj)
		{
			// self-assignment must be a no-op: clear() below would otherwise wipe
			// the source (which is *this) before its data is copied
			if (this == &obj)
				return *this;

			// delete previous data if any
			clear();

			// copy size
			this->m_rows = obj.m_rows;
			this->m_cols = obj.m_cols;

			// copy data
			// NOTE(review): deep vs shallow copy depends on memory<type>::operator= -- confirm
			this->m_data = obj.m_data;

			return *this;
		}

		// move operator: takes over obj's buffer and leaves obj 0x0
		const matrix<type>& operator=(matrix<type>&& obj) noexcept
		{
			// self-move must not clear the data it is about to "move"
			if (this == &obj)
				return *this;

			// delete previous data if any
			clear();

			// move things
			this->m_rows = obj.m_rows;
			this->m_cols = obj.m_cols;
			this->m_data = std::move(obj.m_data);

			// clear moved object
			obj.m_rows = 0;
			obj.m_cols = 0;

			return *this;
		}

		// release the data and reset dimensions to 0x0
		void clear(void)
		{
			this->m_data.clear();

			this->m_rows = 0;
			this->m_cols = 0;
		}

		// check if the underlying buffer is valid
		bool is_valid(void) const
		{
			return this->m_data.is_valid();
		}

		// element access (non-const); throws matrix_out_of_bounds_exception when out of range
		type& operator()(size_t col, size_t row)
		{
			// throw error if beyond dimensions
			if (row >= this->m_rows || col >= this->m_cols)
				throw_exception(matrix_out_of_bounds_exception);

			// return element reference
			return this->m_data.at(MAKE_XY(col, row, this->m_cols));
		}

		// element access (const, returns a copy); throws matrix_out_of_bounds_exception
		const type operator()(size_t col, size_t row) const
		{
			// throw error if beyond dimensions
			if (row >= this->m_rows || col >= this->m_cols)
				throw_exception(matrix_out_of_bounds_exception);

			// return element value
			return this->m_data.at(MAKE_XY(col, row, this->m_cols));
		}

		// add matrix element-wise; throws matrix_size_mismatch_exception on shape mismatch
		const matrix& operator+=(const matrix<type>& map)
		{
			if (map.num_cols() != num_cols() || map.num_rows() != num_rows())
				throw_exception(matrix_size_mismatch_exception, map.num_cols(), map.num_rows(), num_cols(), num_rows());

			for (size_t row = 0; row < num_rows(); row++)
				for (size_t col = 0; col < num_cols(); col++)
					this->operator()(col, row) += map.operator()(col, row);

			return *this;
		}

		// add constant to every element
		const matrix& operator+=(const type value)
		{
			for (size_t row = 0; row < num_rows(); row++)
				for (size_t col = 0; col < num_cols(); col++)
					this->operator()(col, row) += value;

			return *this;
		}

		// subtract matrix element-wise; throws matrix_size_mismatch_exception on shape mismatch
		const matrix& operator-=(const matrix<type>& map)
		{
			if (map.num_cols() != num_cols() || map.num_rows() != num_rows())
				throw_exception(matrix_size_mismatch_exception, map.num_cols(), map.num_rows(), num_cols(), num_rows());

			for (size_t row = 0; row < num_rows(); row++)
				for (size_t col = 0; col < num_cols(); col++)
					this->operator()(col, row) -= map.operator()(col, row);

			return *this;
		}

		// subtract constant from every element
		const matrix& operator-=(const type value)
		{
			for (size_t row = 0; row < num_rows(); row++)
				for (size_t col = 0; col < num_cols(); col++)
					this->operator()(col, row) -= value;

			return *this;
		}

		// multiply every element by a constant
		const matrix& operator*=(const type value)
		{
			for (size_t row = 0; row < num_rows(); row++)
				for (size_t col = 0; col < num_cols(); col++)
					this->operator()(col, row) *= value;

			return *this;
		}

		// divide every element by a constant
		const matrix& operator/=(const type value)
		{
			for (size_t row = 0; row < num_rows(); row++)
				for (size_t col = 0; col < num_cols(); col++)
					this->operator()(col, row) /= value;

			return *this;
		}

		// assign a constant to every element
		const matrix& operator=(const type value)
		{
			for (size_t row = 0; row < num_rows(); row++)
				for (size_t col = 0; col < num_cols(); col++)
					this->operator()(col, row) = value;

			return *this;
		}

		// transpose (name kept as "tanspose" for compatibility): result(row, col) == this(col, row)
		auto tanspose(void) const
		{
			matrix ret(num_rows(), num_cols());

			for (size_t col = 0; col < num_cols(); col++)
				for (size_t row = 0; row < num_rows(); row++)
					ret(row, col) = this->operator()(col, row);

			return ret;
		}

		// submatrix obtained by removing column `col` and row `row`
		// (its determinant is the (col, row) minor).
		// Throws matrix_wrong_size_exception on an empty matrix and
		// invalid_column_exception / invalid_row_exception on bad indices.
		auto comatrix(size_t col, size_t row) const
		{
			// cannot apply on null matrix
			if (num_cols() == 0 || num_rows() == 0)
				throw_exception(matrix_wrong_size_exception, num_rows(), num_cols());

			if (col >= num_cols())
				throw_exception(invalid_column_exception, col);

			if (row >= num_rows())
				throw_exception(invalid_row_exception, row);

			// prepare output matrix
			matrix ret(num_cols() - 1, num_rows() - 1);

			// remove column and row
			for (size_t y = 0; y < num_rows(); y++)
			{
				if (y == row)
					continue;

				for (size_t x = 0; x < num_cols(); x++)
				{
					if (x == col)
						continue;

					// shift indices past the removed column/row down by one
					size_t i = (x < col) ? x : x - 1;
					size_t j = (y < row) ? y : y - 1;

					ret(i, j) = this->operator()(x, y);
				}
			}

			// return matrix
			return ret;
		}

		// compute determinant by cofactor (Laplace) expansion along row 0.
		// Throws matrix_wrong_size_exception if the matrix is empty or not square.
		// NOTE: the recursion makes the cost grow factorially with the size, so this
		// is only suitable for small matrices.
		type determinant(void) const
		{
			// a determinant needs a non-empty square matrix; previously a non-square
			// input recursed until one dimension reached zero before failing
			if (num_rows() == 0 || num_cols() == 0 || num_rows() != num_cols())
				throw_exception(matrix_wrong_size_exception, num_rows(), num_cols());

			// special case for size 1
			if (num_rows() == 1)
				return this->operator()(0, 0);

			// otherwise apply the expansion formula by browsing columns of row 0
			type sum_val = 0;

			for (size_t col = 0; col < num_cols(); col++)
			{
				const type pivot = this->operator()(col, 0);

				// a zero entry contributes nothing: skip its (expensive) minor
				if (pivot == type(0))
					continue;

				// compute minor of current column
				const type temp_val = comatrix(col, 0).determinant();

				// alternate signs along the row
				if ((col % 2) == 0)
					sum_val += pivot * temp_val;
				else
					sum_val -= pivot * temp_val;
			}

			// return sum
			return sum_val;
		}

		// matrix of minors: ret(col, row) = determinant of comatrix(col, row)
		auto minor(void) const
		{
			// prepare output
			matrix ret(num_cols(), num_rows());

			// browse all items
			for (size_t row = 0; row < num_rows(); row++)
				for (size_t col = 0; col < num_cols(); col++)
					ret(col, row) = comatrix(col, row).determinant();

			// return matrix
			return ret;
		}

		// cofactor matrix: minors with sign (-1)^(col + row) applied
		auto cofactor(void) const
		{
			// compute minor matrix first
			auto ret = minor();

			// change sign of odd elements
			for (size_t row = 0; row < num_rows(); row++)
				for (size_t col = 0; col < num_cols(); col++)
					if (((col + row) % 2) != 0)
						ret(col, row) = -ret(col, row);

			// return matrix
			return ret;
		}

		// extract row `row` as a vector of num_cols() values; throws invalid_row_exception
		std::vector<type> extact_row(size_t row) const
		{
			if (row >= num_rows())
				throw_exception(invalid_row_exception, row);

			std::vector<type> ret(num_cols());

			for (size_t col = 0; col < num_cols(); col++)
				ret[col] = this->operator()(col, row);

			return ret;
		}

		// extract column `col` as a vector of num_rows() values; throws invalid_column_exception
		std::vector<type> extact_col(size_t col) const
		{
			if (col >= num_cols())
				throw_exception(invalid_column_exception, col);

			std::vector<type> ret(num_rows());

			for (size_t row = 0; row < num_rows(); row++)
				ret[row] = this->operator()(col, row);

			return ret;
		}

		// return number of rows
		size_t num_rows(void) const
		{
			return this->m_rows;
		}

		// return number of columns
		size_t num_cols(void) const
		{
			return this->m_cols;
		}

		// return identity-like matrix: 1 where col == row, 0 elsewhere (need not be square)
		static matrix<type> identity(size_t num_cols, size_t num_rows)
		{
			matrix<type> ret(num_cols, num_rows);

			for (size_t row = 0; row < num_rows; row++)
				for (size_t col = 0; col < num_cols; col++)
					ret(col, row) = (type)((row == col) ? 1 : 0);

			return ret;
		}

		// return zeros matrix
		static matrix<type> zeros(size_t num_cols, size_t num_rows)
		{
			matrix<type> ret(num_cols, num_rows);

			for (size_t row = 0; row < num_rows; row++)
				for (size_t col = 0; col < num_cols; col++)
					ret(col, row) = (type)0;

			return ret;
		}

		// return ones matrix
		static matrix<type> ones(size_t num_cols, size_t num_rows)
		{
			matrix<type> ret(num_cols, num_rows);

			for (size_t row = 0; row < num_rows; row++)
				for (size_t col = 0; col < num_cols; col++)
					ret(col, row) = (type)1;

			return ret;
		}

	private:
		size_t m_cols, m_rows;

		memory<type> m_data;
	};

	// return the number of rows of `obj` (free-function form of matrix::num_rows)
	template<typename type> size_t num_rows(const matrix<type>& obj)
	{
		const size_t rows = obj.num_rows();
		return rows;
	}

	// return the number of columns of `obj` (free-function form of matrix::num_cols)
	template<typename type> size_t num_cols(const matrix<type>& obj)
	{
		const size_t cols = obj.num_cols();
		return cols;
	}

	// element-wise sum of two matrices of identical shape.
	// Throws matrix_size_mismatch_exception (rows/cols of a, then of b) when shapes differ.
	template<typename type> auto operator+(const matrix<type>& a, const matrix<type>& b)
	{
		const bool same_shape = (a.num_rows() == b.num_rows()) && (a.num_cols() == b.num_cols());
		if (!same_shape)
			throw_exception(matrix_size_mismatch_exception, a.num_rows(), a.num_cols(), b.num_rows(), b.num_cols());

		const size_t height = a.num_rows();
		const size_t width = a.num_cols();
		matrix<type> sum(width, height);

		for (size_t y = 0; y < height; ++y)
		{
			for (size_t x = 0; x < width; ++x)
			{
				sum(x, y) = a(x, y) + b(x, y);
			}
		}

		return sum;
	}

	// addition of matrix and constant: returns a new matrix with entries a(col, row) + value
	template<typename type> auto operator+(const matrix<type>& a, const type value)
	{
		const size_t width = a.num_cols();
		const size_t height = a.num_rows();
		matrix<type> out(width, height);

		for (size_t y = 0; y < height; ++y)
		{
			for (size_t x = 0; x < width; ++x)
			{
				out(x, y) = a(x, y) + value;
			}
		}

		return out;
	}

	// addition of constant and matrix
	template<typename type> auto operator+(const type value, const matrix<type>& a)
	{
		matrix<type> ret(a.num_cols(), a.num_rows());

		for (size_t row = 0; row < ret.num_rows(); row++)
			for (size_t col = 0; col < ret.num_cols(); col++)
				ret(col, row) = value + a(col, row);

		return ret;
	}

	// element-wise difference a - b of two matrices of identical shape.
	// Throws matrix_size_mismatch_exception (rows/cols of a, then of b) when shapes differ.
	template<typename type> auto operator-(const matrix<type>& a, const matrix<type>& b)
	{
		const bool same_shape = (a.num_rows() == b.num_rows()) && (a.num_cols() == b.num_cols());
		if (!same_shape)
			throw_exception(matrix_size_mismatch_exception, a.num_rows(), a.num_cols(), b.num_rows(), b.num_cols());

		const size_t height = a.num_rows();
		const size_t width = a.num_cols();
		matrix<type> diff(width, height);

		for (size_t y = 0; y < height; ++y)
		{
			for (size_t x = 0; x < width; ++x)
			{
				diff(x, y) = a(x, y) - b(x, y);
			}
		}

		return diff;
	}

	// subtraction of matrix and constant: returns a new matrix with entries a(col, row) - value
	template<typename type> auto operator-(const matrix<type>& a, const type value)
	{
		const size_t width = a.num_cols();
		const size_t height = a.num_rows();
		matrix<type> out(width, height);

		for (size_t y = 0; y < height; ++y)
		{
			for (size_t x = 0; x < width; ++x)
			{
				out(x, y) = a(x, y) - value;
			}
		}

		return out;
	}

	// subtaction of constant and matrix
	template<typename type> auto operator-(const type value, const matrix<type>& a)
	{
		matrix<type> ret(a.num_cols(), a.num_rows());

		for (size_t row = 0; row < ret.num_rows(); row++)
			for (size_t col = 0; col < ret.num_cols(); col++)
				ret(col, row) = value - a(col, row);

		return ret;
	}

	// multiplication of two matrices: returns a * b, which has a.num_rows() rows
	// and b.num_cols() columns.
	// Throws matrix_size_mismatch_exception (rows/cols of a, then of b) when
	// a.num_cols() != b.num_rows().
	template<typename type> auto operator*(const matrix<type>& a, const matrix<type>& b)
	{
		if (a.num_cols() != b.num_rows())
			throw_exception(matrix_size_mismatch_exception, a.num_rows(), a.num_cols(), b.num_rows(), b.num_cols());

		// the sizing ctor takes (cols, rows): the product has b's columns and a's rows
		// (passing them the other way round broke every non-square product)
		matrix<type> ret(b.num_cols(), a.num_rows());

		for (size_t row = 0; row < ret.num_rows(); row++)
			for (size_t col = 0; col < ret.num_cols(); col++)
			{
				// accumulate locally rather than through the bounds-checked accessor
				type acc = 0;

				for (size_t k = 0; k < a.num_cols(); k++)
					acc += a(k, row) * b(col, k);

				ret(col, row) = acc;
			}

		return ret;
	}

	// multiplication matrix and vector: returns mat * vec with vec treated as a
	// column vector; the result has mat.num_rows() entries,
	// ret[row] = sum_k mat(k, row) * vec[k].
	// Throws matrix_size_mismatch_exception when vec.size() != mat.num_cols().
	// (The previous indexing computed vec^T * mat instead.)
	template<typename type> auto operator*(const matrix<type>& mat, const std::vector<type>& vec)
	{
		if (mat.num_cols() != vec.size())
			throw_exception(matrix_size_mismatch_exception, mat.num_rows(), mat.num_cols(), vec.size(), 1);

		std::vector<type> ret(mat.num_rows());

		for (size_t row = 0; row < ret.size(); row++)
		{
			type acc = 0;

			for (size_t k = 0; k < mat.num_cols(); k++)
				acc += mat(k, row) * vec[k];

			ret[row] = acc;
		}

		return ret;
	}

	// multiplication vector and matrix
	template<typename type>  auto operator*(const std::vector<type>& vec, const matrix<type>& mat)
	{
		return mat.tanspose() * vec;
	}

	// multiplication matrix and constant: returns a new matrix with entries a(col, row) * value
	template<typename type>  auto operator*(const matrix<type>& a, const type value)
	{
		matrix<type> scaled(a.num_cols(), a.num_rows());

		for (size_t r = 0, rows = scaled.num_rows(); r < rows; ++r)
			for (size_t c = 0, cols = scaled.num_cols(); c < cols; ++c)
				scaled(c, r) = a(c, r) * value;

		return scaled;
	}

	// multiplication constant and matrix
	template<typename type>  auto operator*(const type value, const matrix<type>& a)
	{
		matrix<type> ret(a.num_cols(), a.num_rows());

		for (size_t row = 0; row < ret.num_rows(); row++)
			for (size_t col = 0; col < ret.num_cols(); col++)
				ret(col, row) = value * a(col, row);

		return ret;
	}

	// division matrix and constant: returns a new matrix with entries a(col, row) / value
	// (no check is made for a zero divisor)
	template<typename type> auto operator/(const matrix<type>& a, const type value)
	{
		matrix<type> quotient(a.num_cols(), a.num_rows());

		for (size_t r = 0, rows = quotient.num_rows(); r < rows; ++r)
			for (size_t c = 0, cols = quotient.num_cols(); c < cols; ++c)
				quotient(c, r) = a(c, r) / value;

		return quotient;
	}

	// convert 1D vector to 2D matrix, filling it row by row:
	// ret(col, row) = vec[row * ret.num_cols() + col].
	// NOTE(review): the matrix sizing ctor takes (cols, rows), so the result has
	// `num_rows` COLUMNS and `num_cols` ROWS -- the parameter names are swapped
	// relative to the result. vec2rows calls reshape(vec, N, 1) to get a 1-row
	// matrix, i.e. it relies on this convention.
	// Throws null_vector_exception for an empty vector and
	// wrong_dimension_vector_exception when vec.size() != num_rows * num_cols.
	template<typename type>  auto reshape(const std::vector<type>& vec, size_t num_rows, size_t num_cols)
	{
		if (vec.size() == 0)
			throw_exception(null_vector_exception);

		if (vec.size() != MAKE_MULT(num_rows, num_cols))
			throw_exception(wrong_dimension_vector_exception, num_rows, num_cols, vec.size());

		matrix<type> ret(num_rows, num_cols);

		// iterate over the result's actual dimensions and write every (col, row)
		// exactly once; writing to a fixed row 0 left all other rows zero
		for (size_t row = 0; row < ret.num_rows(); row++)
			for (size_t col = 0; col < ret.num_cols(); col++)
				ret(col, row) = vec[MAKE_ADD(MAKE_MULT(row, ret.num_cols()), col)];

		return ret;
	}

	// convert 1D vector to row matrix: the result has 1 row and vec.size() columns,
	// with entry (i, 0) == vec[i]. An empty vector makes reshape throw null_vector_exception.
	template<typename type>  auto vec2rows(const std::vector<type>& vec)
	{
		const size_t length = vec.size();
		return reshape(vec, length, static_cast<size_t>(1));
	}

	// convert 1D vector to column matrix
	template<typename type>  auto vec2cols(const std::vector<type>& vec)
	{
		return reshape(vec, 1, vec.size());
	}

	// transpose of matrix (free-function form of matrix::tanspose; name kept for compatibility)
	template<typename type>  auto tanspose(const matrix<type>& mat)
	{
		auto transposed = mat.tanspose();
		return transposed;
	}

	// inverse of matrix, computed as adjugate / determinant, where the adjugate is
	// the transposed cofactor matrix.
	// Throws matrix_not_invertible_exception if the matrix is not square, or if its
	// determinant is exactly zero or smaller in magnitude than determinant_eps<type>().
	// An empty (0x0) matrix makes determinant() throw matrix_wrong_size_exception.
	template<typename type>  auto inv(const matrix<type>& mat)
	{
		// matrix must be square
		if (mat.num_rows() != mat.num_cols())
			throw_exception(matrix_not_invertible_exception);

		// simple value case: the inverse is the reciprocal
		if (mat.num_rows() == 1)
		{
			const type value = mat(0, 0);

			// the exact-zero test covers types whose determinant_eps is 0 (e.g. integers),
			// for which the magnitude test alone can never trigger
			if (value == type(0) || std::fabs(value) < determinant_eps<type>())
				throw_exception(matrix_not_invertible_exception);

			matrix<type> ret(1, 1);
			ret(0, 0) = type(1) / value;

			return ret;
		}

		// compute determinant
		const type det_value = mat.determinant();

		// trigger error if zero or too small; determinant_eps must be called --
		// comparing against the bare template name is ill-formed
		if (det_value == type(0) || std::fabs(det_value) < determinant_eps<type>())
			throw_exception(matrix_not_invertible_exception);

		// compute inverse
		return mat.cofactor().tanspose() / det_value;
	}

};