#pragma once

#include "../../debug/exception.h"
#include "../../debug/safe.h"
#include "../../memory/memory.h"

#include "../exception.h"

#include "var.h"

#include <string>
#include <utility>
#include <vector>

namespace math::symbolic
{
	// matrix template
	template<typename type> class matrix
	{
	public:

		// default ctor
		matrix(void)
		{
			this->m_rows = 0;
			this->m_cols = 0;
		}

		// copy ctor
		matrix(const matrix<type>& mat) : matrix()
		{
			this->operator=(mat);
		}

		// move operator
		matrix(matrix<type>&& mat) noexcept : matrix()
		{
			this->operator=(std::move(mat));
		}

		// allocate specific size
		matrix(size_t cols, size_t rows)
		{
			// copy size
			this->m_rows = rows;
			this->m_cols = cols;

			// allocate memory
			for (size_t i = 0; i < MAKE_MULT(this->m_rows, this->m_cols); i++)
				this->m_data.push_back(var<type>(0));
		}

		// dtor
		~matrix(void)
		{
			try
			{
				clear();
			}
			catch (...) {}
		}

		// copy operator
		const matrix<type>& operator=(const matrix<type>& obj)
		{
			// delete previous data if any
			clear();

			// allocate size
			this->m_rows = obj.m_rows;
			this->m_cols = obj.m_cols;

			// copy data
			this->m_data = obj.m_data;

			return *this;
		}

		// move operator
		const matrix<type>& operator=(matrix<type>&& obj) noexcept
		{
			// delete previous data if any
			clear();

			// move things
			this->m_rows = obj.m_rows;
			this->m_cols = obj.m_cols;
			this->m_data = std::move(obj.m_data);

			// clear moved object
			obj.m_rows = 0;
			obj.m_cols = 0;

			return *this;
		}

		// clear map
		void clear(void)
		{
			this->m_data.clear();

			this->m_rows = 0;
			this->m_cols = 0;
		}

		// check if valid
		bool is_valid(void) const
		{
			return this->m_data.size() != 0;
		}

		// get pixel (non-const version)
		var<type>& operator()(size_t col, size_t row)
		{
			// throw error if beyond dimensions
			if (row >= this->m_rows || col >= this->m_cols)
				throw_exception(matrix_out_of_bounds_exception);

			// return pixel reference
			return this->m_data[MAKE_XY(col, row, this->m_cols)];
		}

		// get pixel (const version)
		const var<type>& operator()(size_t col, size_t row) const
		{
			// throw error if beyond dimensions
			if (row >= this->m_rows || col >= this->m_cols)
				throw_exception(matrix_out_of_bounds_exception);

			// return pixel data
			return this->m_data[MAKE_XY(col, row, this->m_cols)];
		}

		// return number of rows
		size_t num_rows(void) const
		{
			return this->m_rows;
		}

		// return number of columns
		size_t num_cols(void) const
		{
			return this->m_cols;
		}

	private:
		size_t m_cols, m_rows;

		std::vector< var<type> > m_data;
	};

	// free-function accessor: number of rows of `obj` (forwards to obj.num_rows())
	template<typename type> size_t num_rows(const matrix<type>& obj)
	{
		return obj.num_rows();
	}

	// free-function accessor: number of columns of `obj` (forwards to obj.num_cols())
	template<typename type> size_t num_cols(const matrix<type>& obj)
	{
		return obj.num_cols();
	}

	// Element-wise sum of two matrices with identical dimensions.
	// An empty (invalid) operand acts as a neutral element: the other operand
	// is returned unchanged. If both are empty, invalid_matrix_exception is
	// raised; if the dimensions differ, matrix_size_mismatch_exception is
	// raised (both via throw_exception).
	template<typename type> auto operator+(const matrix<type>& a, const matrix<type>& b)
	{
		const bool lhs_ok = a.is_valid();
		const bool rhs_ok = b.is_valid();

		if (!lhs_ok && !rhs_ok)
		{
			throw_exception(invalid_matrix_exception);
		}

		if (!rhs_ok)
			return a;

		if (!lhs_ok)
			return b;

		if (a.num_rows() != b.num_rows() || a.num_cols() != b.num_cols())
		{
			throw_exception(matrix_size_mismatch_exception, a.num_rows(), a.num_cols(), b.num_rows(), b.num_cols());
		}

		const size_t height = a.num_rows();
		const size_t width = a.num_cols();
		matrix<type> sum(width, height);

		for (size_t r = 0; r != height; ++r)
		{
			for (size_t c = 0; c != width; ++c)
			{
				sum(c, r) = a(c, r) + b(c, r);
			}
		}

		return sum;
	}

	// Adds the scalar `value` to every element of `a` (element + value).
	// Raises invalid_matrix_exception (via throw_exception) when `a` is empty.
	template<typename type> auto operator+(const matrix<type>& a, const type value)
	{
		if (!a.is_valid())
		{
			throw_exception(invalid_matrix_exception);
		}

		const size_t height = a.num_rows();
		const size_t width = a.num_cols();
		matrix<type> out(width, height);

		for (size_t r = 0; r != height; ++r)
		{
			for (size_t c = 0; c != width; ++c)
			{
				out(c, r) = a(c, r) + value;
			}
		}

		return out;
	}

	// addition of constant and matrix
	template<typename type> auto operator+(const type value, const matrix<type>& a)
	{
		if (!a.is_valid())
			throw_exception(invalid_matrix_exception);

		matrix<type> ret(a.num_cols(), a.num_rows());

		for (size_t row = 0; row < ret.num_rows(); row++)
			for (size_t col = 0; col < ret.num_cols(); col++)
				ret(col, row) = value + a(col, row);

		return ret;
	}

	// subtraction of two matrices: element-wise a - b
	//
	// An empty (invalid) operand is treated as a zero matrix:
	//   - b empty      -> a is returned unchanged
	//   - a empty      -> the element-wise negation of b is returned
	//   - both empty   -> invalid_matrix_exception (via throw_exception)
	//   - sizes differ -> matrix_size_mismatch_exception (via throw_exception)
	template<typename type> auto operator-(const matrix<type>& a, const matrix<type>& b)
	{
		if (a.is_valid() && !b.is_valid())
			return a;

		if (!a.is_valid() && b.is_valid())
		{
			// No unary operator- exists for matrix<type>, so the former
			// `return -b;` could not be instantiated. Negate element-wise as
			// 0 - b(col, row), using the same (scalar - element) subtraction
			// that the scalar-minus-matrix overload relies on.
			matrix<type> neg(b.num_cols(), b.num_rows());

			for (size_t row = 0; row < neg.num_rows(); row++)
				for (size_t col = 0; col < neg.num_cols(); col++)
					neg(col, row) = static_cast<type>(0) - b(col, row);

			return neg;
		}

		if (!a.is_valid() && !b.is_valid())
		{
			throw_exception(invalid_matrix_exception);
		}

		if (a.num_rows() != b.num_rows() || a.num_cols() != b.num_cols())
		{
			throw_exception(matrix_size_mismatch_exception, a.num_rows(), a.num_cols(), b.num_rows(), b.num_cols());
		}

		matrix<type> ret(a.num_cols(), a.num_rows());

		for (size_t row = 0; row < ret.num_rows(); row++)
			for (size_t col = 0; col < ret.num_cols(); col++)
				ret(col, row) = a(col, row) - b(col, row);

		return ret;
	}

	// Subtracts the scalar `value` from every element of `a` (element - value).
	// Raises invalid_matrix_exception (via throw_exception) when `a` is empty.
	template<typename type> auto operator-(const matrix<type>& a, const type value)
	{
		if (!a.is_valid())
		{
			throw_exception(invalid_matrix_exception);
		}

		const size_t height = a.num_rows();
		const size_t width = a.num_cols();
		matrix<type> diff(width, height);

		for (size_t r = 0; r != height; ++r)
		{
			for (size_t c = 0; c != width; ++c)
			{
				diff(c, r) = a(c, r) - value;
			}
		}

		return diff;
	}

	// subtaction of constant and matrix
	template<typename type> auto operator-(const type value, const matrix<type>& a)
	{
		if (!a.is_valid())
			throw_exception(invalid_matrix_exception);

		matrix<type> ret(a.num_cols(), a.num_rows());

		for (size_t row = 0; row < ret.num_rows(); row++)
			for (size_t col = 0; col < ret.num_cols(); col++)
				ret(col, row) = value - a(col, row);

		return ret;
	}

	// Matrix product a * b. The column count of `a` must equal the row count
	// of `b`; the result has a.num_rows() rows and b.num_cols() columns.
	// An empty (invalid) operand makes the other operand be returned
	// unchanged. If both are empty, invalid_matrix_exception is raised; on an
	// inner-dimension mismatch, matrix_size_mismatch_exception is raised (both
	// via throw_exception).
	// NOTE(review): returning the non-empty operand for "x * empty" treats the
	// empty matrix like an identity — confirm this is the intended semantics.
	template<typename type> auto operator*(const matrix<type>& a, const matrix<type>& b)
	{
		const bool lhs_ok = a.is_valid();
		const bool rhs_ok = b.is_valid();

		if (!lhs_ok && !rhs_ok)
		{
			throw_exception(invalid_matrix_exception);
		}

		if (!rhs_ok)
			return a;

		if (!lhs_ok)
			return b;

		if (a.num_cols() != b.num_rows())
		{
			throw_exception(matrix_size_mismatch_exception, a.num_rows(), a.num_cols(), b.num_rows(), b.num_cols());
		}

		const size_t inner = a.num_cols();
		matrix<type> prod(b.num_cols(), a.num_rows());

		for (size_t r = 0; r != prod.num_rows(); ++r)
		{
			for (size_t c = 0; c != prod.num_cols(); ++c)
			{
				// seed with the k = 0 term, then fold in the remaining terms
				auto& cell = prod(c, r);
				cell = a(0, r) * b(c, 0);

				for (size_t k = 1; k < inner; ++k)
					cell = cell + a(k, r) * b(c, k);
			}
		}

		return prod;
	}

	// Scales every element of `a` by the scalar `value` (element * value).
	// Raises invalid_matrix_exception (via throw_exception) when `a` is empty.
	template<typename type>  auto operator*(const matrix<type>& a, const type value)
	{
		if (!a.is_valid())
		{
			throw_exception(invalid_matrix_exception);
		}

		const size_t height = a.num_rows();
		const size_t width = a.num_cols();
		matrix<type> scaled(width, height);

		for (size_t r = 0; r != height; ++r)
		{
			for (size_t c = 0; c != width; ++c)
			{
				scaled(c, r) = a(c, r) * value;
			}
		}

		return scaled;
	}

	// multiplication constant and matrix
	template<typename type>  auto operator*(const type value, const matrix<type>& a)
	{
		if (!a.is_valid())
			throw_exception(invalid_matrix_exception);

		matrix<type> ret(a.num_cols(), a.num_rows());

		for (size_t row = 0; row < ret.num_rows(); row++)
			for (size_t col = 0; col < ret.num_cols(); col++)
				ret(col, row) = value * a(col, row);

		return ret;
	}

	// Divides every element of `a` by the scalar `value` (element / value).
	// Raises invalid_matrix_exception (via throw_exception) when `a` is empty.
	// Division by zero is not checked here; it is left to the element type.
	template<typename type> auto operator/(const matrix<type>& a, const type value)
	{
		if (!a.is_valid())
		{
			throw_exception(invalid_matrix_exception);
		}

		const size_t height = a.num_rows();
		const size_t width = a.num_cols();
		matrix<type> quot(width, height);

		for (size_t r = 0; r != height; ++r)
		{
			for (size_t c = 0; c != width; ++c)
			{
				quot(c, r) = a(c, r) / value;
			}
		}

		return quot;
	}
};