#pragma once

#include "../../../debug/evemon.h"

#include "../../symbolic/var.h"

namespace math::optics::paraxial
{
	/* 3rd order aberration wavefront class
	 *
	 * Stores five symbolic coefficients in a fixed order:
	 *   [0] spherical, [1] coma, [2] astigmatism, [3] field curvature,
	 *   [4] distortion
	 * and evaluates a quadratic form of them for a given normalized field
	 * position (see rms()).
	 */
	template<typename type> class wavefront
	{
	public:

		// ctor from vars; each coefficient is copied into its fixed slot
		// (spherical, coma, astigmatism, field curvature, distortion)
		wavefront(const math::symbolic::var<type>& _spherical,
			const math::symbolic::var<type>& _coma,
			const math::symbolic::var<type>& _astig,
			const math::symbolic::var<type>& _fieldcurv,
			const math::symbolic::var<type>& _dist)
		{
			this->m_data[0] = _spherical;
			this->m_data[1] = _coma;
			this->m_data[2] = _astig;
			this->m_data[3] = _fieldcurv;
			this->m_data[4] = _dist;
		}

		// get spherical aberration (slot 0)
		const math::symbolic::var<type>& spherical(void) const
		{
			return this->m_data[0];
		}

		// get coma (slot 1)
		const math::symbolic::var<type>& coma(void) const
		{
			return this->m_data[1];
		}

		// get astigmatism (slot 2)
		const math::symbolic::var<type>& astig(void) const
		{
			return this->m_data[2];
		}

		// get field curvature (slot 3)
		const math::symbolic::var<type>& fieldcurv(void) const
		{
			return this->m_data[3];
		}

		// get distortion (slot 4)
		const math::symbolic::var<type>& dist(void) const
		{
			return this->m_data[4];
		}

		// get rms for given y field pos (0<=y<=1)
		//
		// Returns sqrt( sum_ij c_i * c_j * Q[i][j] ) where c is the stored
		// coefficient vector and Q is a symmetric 5x5 weight matrix that
		// depends only on y. A y outside [0, 1] only triggers a warning;
		// the value is still computed.
		//
		// NOTE(review): the Q entries look like the unit-pupil averages of
		// the pairwise products of the terms rho^4, y*rho^3*cos(t),
		// y^2*rho^2*cos(t)^2, y^2*rho^2 and y^3*rho*cos(t), i.e. the mean of
		// W^2 without subtracting the mean of W - confirm against the
		// convention expected by callers.
		math::symbolic::var<type> rms(const type y) const
		{
			if (y < 0 || y > 1)
				_warning("computing rms for boggous y");

			// short-hand for the even powers of y used by the weights
			const type y2 = y * y;
			const type y4 = y2 * y2;
			const type y6 = y4 * y2;

			// weight matrix; every entry not listed below stays zero
			type Q[5][5];

			for (size_t i = 0; i < 5; i++)
				for (size_t j = 0; j < 5; j++)
					Q[i][j] = 0;

			// upper triangle (diagonal included)
			Q[0][0] = static_cast<type>(1) / static_cast<type>(5);
			Q[0][2] = y2 / static_cast<type>(8);
			Q[0][3] = y2 / static_cast<type>(4);
			Q[1][1] = y2 / static_cast<type>(8);
			Q[1][4] = y4 / static_cast<type>(6);
			Q[2][2] = y4 / static_cast<type>(8);
			Q[2][3] = y4 / static_cast<type>(6);
			Q[3][3] = y4 / static_cast<type>(3);
			Q[4][4] = y6 / static_cast<type>(4);

			// mirror to the lower triangle so that Q is symmetric
			for (size_t i = 0; i < 5; i++)
				for (size_t j = i + 1; j < 5; j++)
					Q[j][i] = Q[i][j];

			// accumulate the quadratic form; the accumulator must use the
			// class scalar type (it used to be hard-coded to double, which
			// does not match the var<type> return type for other scalars)
			// assumes a value-initialized var represents zero - TODO confirm
			auto sumsq = math::symbolic::var<type>();

			for (size_t i = 0; i < 5; i++)
				for (size_t j = 0; j < 5; j++)
					sumsq = sumsq + this->m_data[i] * this->m_data[j] * Q[i][j];

			return math::symbolic::sqrt(sumsq);
		}

	private:
		// coefficients in the order: spherical, coma, astigmatism,
		// field curvature, distortion
		math::symbolic::var<type> m_data[5];
	};
};