#pragma once

#include "base.h"
#include "add.h"

#include <memory>
#include <string>
#include <utility>

namespace math::solve
{
	/* mult two solves: node representing the product a * b of two sub-expressions.
	   A nullptr operand is treated as zero (see to_string / compute below). */
	template<typename type> class mult : public solve::base<type>
	{
	public:
		// a, b: the two factors; either may be nullptr, which is treated as zero.
		// The base receives its own copies of both pointers before they are moved
		// into the members (bases are always initialized before members).
		mult(std::shared_ptr<const math::solve::base<type>> a, std::shared_ptr<const math::solve::base<type>> b)
			: base<type>({ a, b }), m_a(std::move(a)), m_b(std::move(b))
		{
		}

		// Returns "0" if either factor is missing; otherwise passes
		// { b, a, "*" } (operator token last) to base::concat.
		virtual std::string to_string(void) const override
		{
			if (this->m_a == nullptr || this->m_b == nullptr)
				return "0";

			return base<type>::concat({ this->m_b->to_string(), this->m_a->to_string(), "*" });
		}

	private:
		// Evaluates a * b; a missing factor contributes zero of the value type.
		virtual const type compute(void) const override
		{
			// static_cast keeps both ternary branches of type `type` instead of
			// relying on an implicit conversion between `type` and int.
			auto _a = (this->m_a != nullptr) ? this->m_a->eval() : static_cast<type>(0);
			auto _b = (this->m_b != nullptr) ? this->m_b->eval() : static_cast<type>(0);

			return _a * _b;
		}

		// Product rule: (a * b)' = a'*b + b'*a.
		// nullptr is used to represent a zero term: a null derivative of a factor
		// drops its summand, and nullptr is returned when every summand vanishes.
		virtual std::shared_ptr<const math::solve::base<type>> create_derivative(const math::solve::base<type>* p) const override
		{
			auto _da = (this->m_a != nullptr) ? this->m_a->get_derivative(p) : nullptr;
			auto _db = (this->m_b != nullptr) ? this->m_b->get_derivative(p) : nullptr;

			// Both summands present (a non-null _da / _db implies a non-null m_a / m_b).
			if (_da != nullptr && _db != nullptr)
				return std::make_shared<add<type>>(std::make_shared<mult<type>>(_da, this->m_b), std::make_shared<mult<type>>(this->m_a, _db));

			// Only a'*b survives; it is zero if b itself is missing.
			if (_da != nullptr && _db == nullptr && this->m_b != nullptr)
				return std::make_shared<mult<type>>(_da, this->m_b);

			// Only a*b' survives; it is zero if a itself is missing.
			if (_db != nullptr && _da == nullptr && this->m_a != nullptr)
				return std::make_shared<mult<type>>(this->m_a, _db);

			return nullptr;
		}

		// True if either factor reports a dependency on p via has(p).
		virtual bool has_dependency(const math::solve::base<type>* p) const override
		{
			if (this->m_a != nullptr && this->m_a->has(p))
				return true;

			if (this->m_b != nullptr && this->m_b->has(p))
				return true;

			return false;
		}

	private:
		std::shared_ptr<const math::solve::base<type>> m_a, m_b;
	};

	/* mult one solve and one const: node representing a * k, where a is a
	   sub-expression and k a constant of the value type. A nullptr sub-expression
	   is treated as zero. */
	template<typename type> class mult_const : public solve::base<type>
	{
	public:
		// a: the sub-expression (may be nullptr); b: the constant factor k.
		// The base copies `a` before it is moved into the member.
		mult_const(std::shared_ptr<const math::solve::base<type>> a, const type b)
			: base<type>({ a }), m_a(std::move(a)), m_b(b)
		{
		}

		// Returns "0" if the sub-expression is missing; otherwise passes
		// { k, a, "*" } (operator token last) to base::concat.
		// NOTE(review): unlike const_mult::to_string, a zero constant is not
		// rendered as "0" here; kept as-is to preserve existing output.
		virtual std::string to_string(void) const override
		{
			if (this->m_a == nullptr)
				return "0";

			return base<type>::concat({ type2string(this->m_b), this->m_a->to_string(), "*" });
		}

	private:
		// Evaluates a * k; a missing sub-expression contributes zero of the value type.
		virtual const type compute(void) const override
		{
			auto _a = (this->m_a != nullptr) ? this->m_a->eval() : static_cast<type>(0);

			return _a * this->m_b;
		}

		// (a * k)' = k * a'. Returns nullptr (a zero term) when a' is null or k == 0.
		virtual std::shared_ptr<const math::solve::base<type>> create_derivative(const math::solve::base<type>* p) const override
		{
			auto _da = (this->m_a != nullptr) ? this->m_a->get_derivative(p) : nullptr;

			if (_da != nullptr && this->m_b != 0)
				return std::make_shared<mult_const<type>>(_da, this->m_b);

			return nullptr;
		}

		// True if the sub-expression reports a dependency on p via has(p).
		virtual bool has_dependency(const math::solve::base<type>* p) const override
		{
			if (this->m_a != nullptr && this->m_a->has(p))
				return true;

			return false;
		}

	private:
		std::shared_ptr<const math::solve::base<type>> m_a;
		type m_b;
	};

	/* mult one const and one solve: node representing k * b, where k is a
	   constant of the value type and b a sub-expression. A nullptr
	   sub-expression is treated as zero. */
	template<typename type> class const_mult : public solve::base<type>
	{
	public:
		// a: the constant factor k; b: the sub-expression (may be nullptr).
		// The base copies `b` before it is moved into the member.
		const_mult(const type a, std::shared_ptr<const math::solve::base<type>> b)
			: base<type>({ b }), m_a(a), m_b(std::move(b))
		{
		}

		// Returns "0" if k == 0 or the sub-expression is missing; otherwise passes
		// { b, k, "*" } (operator token last) to base::concat.
		virtual std::string to_string(void) const override
		{
			if (this->m_a == 0 || this->m_b == nullptr)
				return "0";

			return base<type>::concat({ this->m_b->to_string(), type2string(this->m_a), "*" });
		}

	private:
		// Evaluates k * b; a missing sub-expression contributes zero of the value type.
		virtual const type compute(void) const override
		{
			auto _b = (this->m_b != nullptr) ? this->m_b->eval() : static_cast<type>(0);

			return this->m_a * _b;
		}

		// (k * b)' = k * b'. Returns nullptr (a zero term) when b' is null or k == 0.
		virtual std::shared_ptr<const math::solve::base<type>> create_derivative(const math::solve::base<type>* p) const override
		{
			auto _db = (this->m_b != nullptr) ? this->m_b->get_derivative(p) : nullptr;

			if (_db != nullptr && this->m_a != 0)
				return std::make_shared<const_mult<type>>(this->m_a, _db);

			return nullptr;
		}

		// True if the sub-expression reports a dependency on p via has(p).
		virtual bool has_dependency(const math::solve::base<type>* p) const override
		{
			if (this->m_b != nullptr && this->m_b->has(p))
				return true;

			return false;
		}

	private:
		type m_a;
		std::shared_ptr<const math::solve::base<type>> m_b;
	};

	/* square solves: node representing a * a for a single sub-expression.
	   A nullptr sub-expression is treated as zero. */
	template<typename type> class square : public solve::base<type>
	{
	public:
		// a: the sub-expression to square (may be nullptr).
		// The base copies `a` before it is moved into the member.
		square(std::shared_ptr<const math::solve::base<type>> a)
			: base<type>({ a }), m_a(std::move(a))
		{
		}

		// Returns "0" if the sub-expression is missing; otherwise passes
		// { a, "sq" } to base::concat.
		virtual std::string to_string(void) const override
		{
			if (this->m_a == nullptr)
				return "0";

			return base<type>::concat({ this->m_a->to_string(), "sq" });
		}

	private:
		// Evaluates a * a; a missing sub-expression contributes zero of the value type.
		virtual const type compute(void) const override
		{
			auto _a = (this->m_a != nullptr) ? this->m_a->eval() : static_cast<type>(0);

			return _a * _a;
		}

		// (a * a)' = a'*a + a'*a = 2 * (a' * a).
		// Returns nullptr (a zero term) when a' is null.
		virtual std::shared_ptr<const math::solve::base<type>> create_derivative(const math::solve::base<type>* p) const override
		{
			// _da can only be non-null when m_a is non-null, so no separate m_a check is needed below.
			auto _da = (this->m_a != nullptr) ? this->m_a->get_derivative(p) : nullptr;

			if (_da == nullptr)
				return nullptr;

			return std::make_shared<mult_const<type>>(std::make_shared<mult<type>>(_da, this->m_a), static_cast<type>(2));
		}

		// True if the sub-expression reports a dependency on p via has(p).
		virtual bool has_dependency(const math::solve::base<type>* p) const override
		{
			if (this->m_a != nullptr && this->m_a->has(p))
				return true;

			return false;
		}

	private:
		std::shared_ptr<const math::solve::base<type>> m_a;
	};
};