#pragma once

#include "../../symbolic/var.h"
#include "../../symbolic/matrix.h"

#include "exception.h"

namespace math::optics::paraxial
{
	/*
	 * ABCD (ray transfer) matrix class for paraxial optics.
	 *
	 * Stored as a 2x2 math::symbolic::matrix<type>. The four elements are
	 * exposed through the a(), b(), c() and d() accessors. The focal lengths
	 * efl(), bfl() and ffl() are derived from those elements.
	 *
	 * NOTE(review): b() reads element operator()(1, 0) and c() reads
	 * operator()(0, 1). If matrix<type>::operator() is indexed (row, col),
	 * the conventional layout [[A, B], [C, D]] would place B at (0, 1) and
	 * C at (1, 0). Confirm the indexing convention of math::symbolic::matrix
	 * before relying on matrix products of abcd objects.
	 */
	template<typename type> class abcd : public math::symbolic::matrix<type>
	{
	public:

		// default ctor: identity matrix, A = D = 1.
		// B and C are default-constructed vars (presumably zero; confirm
		// against math::symbolic::var).
		abcd(void) : math::symbolic::matrix<type>(2, 2)
		{
			this->a() = math::symbolic::var<type>(static_cast<type>(1));
			this->b() = math::symbolic::var<type>();
			this->c() = math::symbolic::var<type>();
			this->d() = math::symbolic::var<type>(static_cast<type>(1));
		}

		// element ctor: builds an arbitrary 2x2 ABCD matrix from its four elements
		abcd(const math::symbolic::var<type>& _a, const math::symbolic::var<type>& _b, const math::symbolic::var<type>& _c, const math::symbolic::var<type>& _d) : math::symbolic::matrix<type>(2, 2)
		{
			// every element is written below, so there is no need to run the
			// identity initialization of the default ctor first
			this->a() = _a;
			this->b() = _b;
			this->c() = _c;
			this->d() = _d;
		}

		// assign from a generic matrix.
		// If other is not 2x2, matrix_is_not_abcd_exception is reported via
		// throw_exception before any element of *this is modified.
		// Returns *this (non-const, as is conventional for assignment).
		abcd<type>& operator=(const math::symbolic::matrix<type>& other)
		{
			// an ABCD matrix is always 2x2
			if (other.num_cols() != 2 || other.num_rows() != 2)
				throw_exception(matrix_is_not_abcd_exception);

			// delegate the element copy to the base class
			math::symbolic::matrix<type>::operator=(other);

			return *this;
		}

		// get A term (element (0, 0))
		math::symbolic::var<type>& a(void)
		{
			return this->operator()(0, 0);
		}

		// get A term (const version)
		const math::symbolic::var<type>& a(void) const
		{
			return this->operator()(0, 0);
		}

		// get B term (element (1, 0), see class note about indexing)
		math::symbolic::var<type>& b(void)
		{
			return this->operator()(1, 0);
		}

		// get B term (const version)
		const math::symbolic::var<type>& b(void) const
		{
			return this->operator()(1, 0);
		}

		// get C term (element (0, 1), see class note about indexing)
		math::symbolic::var<type>& c(void)
		{
			return this->operator()(0, 1);
		}

		// get C term (const version)
		const math::symbolic::var<type>& c(void) const
		{
			return this->operator()(0, 1);
		}

		// get D term (element (1, 1))
		math::symbolic::var<type>& d(void)
		{
			return this->operator()(1, 1);
		}

		// get D term (const version)
		const math::symbolic::var<type>& d(void) const
		{
			return this->operator()(1, 1);
		}

		// effective focal length: -1 / C
		// (behaviour for C == 0, i.e. an afocal system, is whatever var division does)
		[[nodiscard]] math::symbolic::var<type> efl(void) const
		{
			return -static_cast<type>(1) / c();
		}

		// back focal length: -A / C
		[[nodiscard]] math::symbolic::var<type> bfl(void) const
		{
			return -a() / c();
		}

		// front focal length: -D / C
		[[nodiscard]] math::symbolic::var<type> ffl(void) const
		{
			return -d() / c();
		}
	};
};