#pragma once

#include "base.h"
#include "mult.h"
#include "div.h"

#include <cmath>
#include <memory>
#include <string>
#include <utility>

namespace math::solve
{
	/**
	 * @brief Square-root node of the expression tree: represents sqrt(a).
	 *
	 * Holds a single operand @c a, which may be nullptr. A null operand is
	 * evaluated as zero by compute() and printed as "0" by to_string().
	 *
	 * @tparam type Value type returned by eval(); must be accepted by std::sqrt.
	 */
	template<typename type> class sqrt : public solve::base<type>
	{
	public:
		/**
		 * @brief Builds the node sqrt(a).
		 * @param a Operand node, may be nullptr. It is forwarded to the base
		 *          constructor as a one-element list (presumably the child list —
		 *          confirm in base.h) and kept in m_a.
		 */
		sqrt(std::shared_ptr<const math::solve::base<type>> a)
			: base<type>({ a }), m_a(std::move(a)) // base is initialized before m_a, so it sees `a` before the move
		{
		}

		/**
		 * @brief Textual form of the node.
		 * @return "0" when the operand is null; otherwise base::concat applied to
		 *         { operand's to_string(), "sqrt" } (operand first, then operator).
		 */
		virtual std::string to_string(void) const override
		{
			if (this->m_a == nullptr)
				return "0";

			return base<type>::concat({ this->m_a->to_string(), "sqrt" });
		}

	private:
		/**
		 * @brief Evaluates sqrt(a), using zero for a null operand.
		 *
		 * No domain check is done here: whatever std::sqrt returns for the
		 * operand value (e.g. NaN for a negative IEEE floating-point value)
		 * is returned as-is.
		 */
		virtual const type compute(void) const override
		{
			// type(0) keeps both ternary arms of the same type instead of mixing
			// `type` with an int literal.
			const type value = (this->m_a != nullptr) ? this->m_a->eval() : type(0);

			return std::sqrt(value);
		}

		/**
		 * @brief Builds the derivative d(sqrt(a))/dp = a' / (2 * sqrt(a)).
		 * @param p Node to differentiate with respect to.
		 * @return A div node a' / (2 * this), or nullptr when the operand is null
		 *         or its get_derivative(p) returned nullptr (presumably meaning
		 *         the derivative is identically zero — confirm against base.h).
		 */
		virtual std::shared_ptr<const math::solve::base<type>> create_derivative(const math::solve::base<type>* p) const override
		{
			if (this->m_a == nullptr)
				return nullptr;

			auto da = this->m_a->get_derivative(p);
			if (da == nullptr)
				return nullptr;

			// The denominator reuses this very node as sqrt(a). shared_from_this()
			// presumably comes from std::enable_shared_from_this in base, which
			// requires *this to already be owned by a shared_ptr.
			return std::make_shared<div<type>>(std::move(da), std::make_shared<const_mult<type>>(2, this->shared_from_this()));
		}

		/**
		 * @brief Reports whether this node depends on @p p.
		 * @return true iff the operand exists and its has(p) returns true.
		 */
		virtual bool has_dependency(const math::solve::base<type>* p) const override
		{
			return this->m_a != nullptr && this->m_a->has(p);
		}

	private:
		// Operand whose square root this node represents; may be nullptr.
		std::shared_ptr<const math::solve::base<type>> m_a;
	};
};
