#pragma once

#include "base.h"

#include <memory>
#include <string>

namespace math::solve
{
	/* negate solve: the additive inverse of a single operand (a missing operand counts as 0) */
	template<typename type> class negate : public solve::base<type>
	{
	public:
		/* a: operand to negate; may be nullptr, in which case the node behaves as the constant 0 */
		negate(std::shared_ptr<const math::solve::base<type>> a) : base<type>({ a }), m_a(a)
		{
		}

		/* textual form: "0" when there is no operand, otherwise { operand, "neg" } joined by base<type>::concat */
		virtual std::string to_string(void) const override
		{
			if (this->m_a != nullptr)
				return base<type>::concat({ this->m_a->to_string(), "neg" });

			return "0";
		}

	private:
		/* value: -eval(a), where a missing operand evaluates as 0 */
		virtual const type compute(void) const
		{
			auto operand = (this->m_a == nullptr) ? 0 : this->m_a->eval();

			return -operand;
		}

		/* derivative with respect to p: (-a)' = -(a').
		   Returns nullptr when there is no operand or the operand yields no derivative
		   (presumably nullptr stands for a zero derivative -- confirm against base<type>). */
		virtual std::shared_ptr<const math::solve::base<type>> create_derivative(const math::solve::base<type>* p) const override
		{
			if (this->m_a == nullptr)
				return nullptr;

			auto derivative = this->m_a->get_derivative(p);

			if (derivative == nullptr)
				return nullptr;

			return std::make_shared<negate<type>>(derivative);
		}

		/* true when the operand exists and its has() reports p */
		virtual bool has_dependency(const math::solve::base<type>* p) const override
		{
			return this->m_a != nullptr && this->m_a->has(p);
		}

	private:
		std::shared_ptr<const math::solve::base<type>> m_a;	// operand; nullptr means 0
	};

	/* subtract two solves: value is eval(a) - eval(b); a missing operand counts as 0 */
	template<typename type> class subtract : public solve::base<type>
	{
	public:
		/* a: minuend, b: subtrahend; either may be nullptr (treated as 0) */
		subtract(std::shared_ptr<const math::solve::base<type>> a, std::shared_ptr<const math::solve::base<type>> b) : base<type>({ a, b }), m_a(a), m_b(b)
		{
		}

		/* textual form; each null combination mirrors what compute() evaluates:
		   both present -> { b, a, "-" } via concat, only a -> a, only b -> { b, "neg" }, none -> "0" */
		virtual std::string to_string(void) const override
		{
			const bool has_a = this->m_a != nullptr;
			const bool has_b = this->m_b != nullptr;

			if (has_a && has_b)
				return base<type>::concat({ this->m_b->to_string(), this->m_a->to_string(), "-" });

			if (has_a)
				return this->m_a->to_string();

			if (has_b)
				return base<type>::concat({ this->m_b->to_string(), "neg" });

			return "0";
		}

	private:
		/* value: eval(a) - eval(b), with missing operands evaluated as 0 */
		virtual const type compute(void) const
		{
			auto lhs = (this->m_a == nullptr) ? 0 : this->m_a->eval();
			auto rhs = (this->m_b == nullptr) ? 0 : this->m_b->eval();

			return lhs - rhs;
		}

		/* derivative with respect to p: (a - b)' = a' - b'.
		   A nullptr operand derivative is treated as absent, so the result degrades to a', -b', or nullptr. */
		virtual std::shared_ptr<const math::solve::base<type>> create_derivative(const math::solve::base<type>* p) const override
		{
			auto da = (this->m_a != nullptr) ? this->m_a->get_derivative(p) : nullptr;
			auto db = (this->m_b != nullptr) ? this->m_b->get_derivative(p) : nullptr;

			// without b' the result is just a' (which is itself nullptr when a' is absent too)
			if (db == nullptr)
				return da;

			if (da == nullptr)
				return std::make_shared<negate<type>>(db);

			return std::make_shared<subtract<type>>(da, db);
		}

		/* true when either present operand's has() reports p (a is checked first) */
		virtual bool has_dependency(const math::solve::base<type>* p) const override
		{
			return (this->m_a != nullptr && this->m_a->has(p))
				|| (this->m_b != nullptr && this->m_b->has(p));
		}

	private:
		std::shared_ptr<const math::solve::base<type>> m_a;	// minuend; nullptr means 0
		std::shared_ptr<const math::solve::base<type>> m_b;	// subtrahend; nullptr means 0
	};

	/* subtract one solve and one const: value is eval(a) - b; a missing solve counts as 0 */
	template<typename type> class subtract_const : public solve::base<type>
	{
	public:
		/* a: solve minuend (may be nullptr, treated as 0); b: constant subtrahend */
		subtract_const(std::shared_ptr<const math::solve::base<type>> a, const type b) : base<type>({ a }), m_a(a), m_b(b)
		{
		}

		/* textual form; must agree with compute() */
		virtual std::string to_string(void) const override
		{
			// with no solve operand the value is 0 - b == -b, so render it as a negated constant
			// (rendering only "b" would contradict compute())
			if (this->m_a == nullptr)
				return base<type>::concat({ type2string(this->m_b), "neg" });

			return base<type>::concat({ type2string(this->m_b), this->m_a->to_string(), "-" });
		}

	private:
		/* value: eval(a) - b, with a missing solve evaluated as 0 */
		virtual const type compute(void) const
		{
			auto _a = (this->m_a != nullptr) ? this->m_a->eval() : 0;

			return _a - this->m_b;
		}

		/* derivative with respect to p: (a - k)' = a'; nullptr when there is no solve operand */
		virtual std::shared_ptr<const math::solve::base<type>> create_derivative(const math::solve::base<type>* p) const override
		{
			return (this->m_a != nullptr) ? this->m_a->get_derivative(p) : nullptr;
		}

		/* true when the solve operand exists and its has() reports p */
		virtual bool has_dependency(const math::solve::base<type>* p) const override
		{
			return this->m_a != nullptr && this->m_a->has(p);
		}

	private:
		std::shared_ptr<const math::solve::base<type>> m_a;	// solve minuend; nullptr means 0
		type m_b;	// constant subtrahend
	};

	/* subtract one const and one solve: value is a - eval(b); a missing solve counts as 0 */
	template<typename type> class const_subtract : public solve::base<type>
	{
	public:
		/* a: constant minuend; b: solve subtrahend (may be nullptr, treated as 0) */
		const_subtract(const type a, std::shared_ptr<const math::solve::base<type>> b) : base<type>({ b }), m_a(a), m_b(b)
		{
		}

		/* textual form; must agree with compute() */
		virtual std::string to_string(void) const override
		{
			// with no solve operand the value is a - 0 == a, so only the constant is rendered
			// (rendering it negated would contradict compute())
			if (this->m_b == nullptr)
				return type2string(this->m_a);

			return base<type>::concat({ this->m_b->to_string(), type2string(this->m_a), "-" });
		}

	private:
		/* value: a - eval(b), with a missing solve evaluated as 0 */
		virtual const type compute(void) const
		{
			auto _b = (this->m_b != nullptr) ? this->m_b->eval() : 0;

			return this->m_a - _b;
		}

		/* derivative with respect to p: (k - b)' = -b'.
		   nullptr is returned when there is no solve operand or it yields no derivative,
		   matching how negate/subtract propagate a missing derivative. */
		virtual std::shared_ptr<const math::solve::base<type>> create_derivative(const math::solve::base<type>* p) const override
		{
			if (this->m_b == nullptr)
				return nullptr;

			auto _db = this->m_b->get_derivative(p);

			// do not wrap a missing derivative in a negate node; propagate it as nullptr
			if (_db == nullptr)
				return nullptr;

			// negate is a class template: it must be instantiated explicitly here
			return std::make_shared<negate<type>>(_db);
		}

		/* true when the solve operand exists and its has() reports p */
		virtual bool has_dependency(const math::solve::base<type>* p) const override
		{
			return this->m_b != nullptr && this->m_b->has(p);
		}

	private:
		type m_a;	// constant minuend
		std::shared_ptr<const math::solve::base<type>> m_b;	// solve subtrahend; nullptr means 0
	};
};