#pragma once

#include "map2d.h"
#include "exception.h"

#include "../utils/utils.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <list>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

// Element-wise trigonometric functions on map2d<float>.
// NOTE(review): MAP2D_IMPL_STD_FUNC is defined in map2d.h (not visible here);
// presumably each line generates a `map2d<float> NAME(const map2d<float>&)`
// overload that applies std::NAME to every element — confirm against map2d.h.
MAP2D_IMPL_STD_FUNC(float, cos);
MAP2D_IMPL_STD_FUNC(float, sin);
MAP2D_IMPL_STD_FUNC(float, tan);
MAP2D_IMPL_STD_FUNC(float, acos);
MAP2D_IMPL_STD_FUNC(float, asin);
MAP2D_IMPL_STD_FUNC(float, atan);

/// Element-wise two-argument arctangent.
///
/// Combines `y` and `x` through map2d_in2out (presumably element by element at
/// matching positions — see map2d.h) and stores std::atan2(y, x) for each pair,
/// i.e. the angle in radians of the point (x, y), in the range [-pi, pi].
static map2d<float> atan2(const map2d<float>& y, const map2d<float>& x)
{
	return map2d_in2out<float, float, float>(y, x,
		[](float yv, float xv) -> float { return static_cast<float>(std::atan2(yv, xv)); });
}

// Element-wise hyperbolic functions and their inverses
// (generated by MAP2D_IMPL_STD_FUNC from map2d.h).
MAP2D_IMPL_STD_FUNC(float, cosh);
MAP2D_IMPL_STD_FUNC(float, sinh);
MAP2D_IMPL_STD_FUNC(float, tanh);
MAP2D_IMPL_STD_FUNC(float, acosh);
MAP2D_IMPL_STD_FUNC(float, asinh);
MAP2D_IMPL_STD_FUNC(float, atanh);

// Element-wise exponential and logarithmic functions.
// NOTE(review): std::exp itself is not generated in this list — confirm whether
// that is intentional or whether an `exp` overload is provided elsewhere.
MAP2D_IMPL_STD_FUNC(float, log);
MAP2D_IMPL_STD_FUNC(float, log10);
MAP2D_IMPL_STD_FUNC(float, exp2);
MAP2D_IMPL_STD_FUNC(float, expm1);
MAP2D_IMPL_STD_FUNC(float, log1p);
MAP2D_IMPL_STD_FUNC(float, log2);

// Element-wise square and cube roots.
MAP2D_IMPL_STD_FUNC(float, sqrt);
MAP2D_IMPL_STD_FUNC(float, cbrt);

// Element-wise error function, complementary error function and gamma functions.
MAP2D_IMPL_STD_FUNC(float, erf);
MAP2D_IMPL_STD_FUNC(float, erfc);
MAP2D_IMPL_STD_FUNC(float, tgamma);
MAP2D_IMPL_STD_FUNC(float, lgamma);

// Element-wise rounding to integral values.
MAP2D_IMPL_STD_FUNC(float, ceil);
MAP2D_IMPL_STD_FUNC(float, floor);
MAP2D_IMPL_STD_FUNC(float, trunc);
MAP2D_IMPL_STD_FUNC(float, round);

// Element-wise absolute value.
MAP2D_IMPL_STD_FUNC(float, fabs);

/// Element-wise floating-point remainder by a scalar.
///
/// Each element e becomes std::fmod(e, mod_val). Following std::fmod, the result
/// has the sign of e (not of mod_val), and a zero mod_val yields NaN.
static map2d<float> mod(const map2d<float>& obj, const float mod_val)
{
	return map2d_in2out<float, float, float>(obj, mod_val,
		[](float numer, float denom) -> float { return static_cast<float>(std::fmod(numer, denom)); });
}

/// Element-wise power with a scalar exponent.
///
/// Each element e becomes std::pow(e, exp_val). Following std::pow, a negative e
/// with a non-integer exp_val yields NaN.
static map2d<float> pow(const map2d<float>& obj, const float exp_val)
{
	return map2d_in2out<float, float, float>(obj, exp_val,
		[](float base, float power) -> float { return static_cast<float>(std::pow(base, power)); });
}

/// Element-wise Euclidean norm sqrt(x^2 + y^2) of two maps.
///
/// Uses std::hypot rather than the naive sqrt(a*a + b*b). In float, a*a already
/// overflows to inf once |a| exceeds about 1.8e19, and it underflows (losing all
/// precision) below about 1e-19. std::hypot computes the result without that
/// intermediate overflow/underflow.
///
/// The operand order passed to map2d_in2out (y first) is kept as before; the
/// norm itself is symmetric in x and y.
static map2d<float> hypot(const map2d<float>& x, const map2d<float>& y)
{
	return map2d_in2out<float, float, float>(y, x, [](float yv, float xv)
		{
			return static_cast<float>(std::hypot(xv, yv));
		});
}

/// Element-wise squared Euclidean norm x^2 + y^2 of two maps
/// (hypot() without the final square root).
static map2d<float> hypotsq(const map2d<float>& x, const map2d<float>& y)
{
	return map2d_in2out<float, float, float>(y, x,
		[](float yv, float xv) -> float { return static_cast<float>(yv * yv + xv * xv); });
}

// Element-wise arithmetic operators on map2d<float>.
// NOTE(review): MAP2D_IMPL_OPERATOR is defined in map2d.h; the second argument
// presumably names the result element type (float here) and the macro presumably
// also covers map/scalar operands — confirm against map2d.h.
MAP2D_IMPL_OPERATOR(float, float, +);
MAP2D_IMPL_OPERATOR(float, float, -);
MAP2D_IMPL_OPERATOR(float, float, *);
MAP2D_IMPL_OPERATOR(float, float, /);

// Element-wise comparison operators; the `bool` argument presumably makes each
// operator return a map2d<bool> (usable e.g. as a mask) — confirm in map2d.h.
MAP2D_IMPL_OPERATOR(float, bool, <);
MAP2D_IMPL_OPERATOR(float, bool, <=);
MAP2D_IMPL_OPERATOR(float, bool, >);
MAP2D_IMPL_OPERATOR(float, bool, >=);
MAP2D_IMPL_OPERATOR(float, bool, ==);
MAP2D_IMPL_OPERATOR(float, bool, !=);

/// Unary minus: returns a new map in which every element of `obj` is negated.
static map2d<float> operator-(const map2d<float>& obj)
{
	return map2d_in2out<float, float>(obj, [](float v) -> float { return -v; });
}

MAP2D_IMPL_SELF_OPERATOR(float, +=);

/// In-place addition of a scalar: each element visited by map2d_selfop gets `value` added.
///
/// @param self   map modified in place
/// @param value  scalar added to each element
/// @param mask   optional mask forwarded to map2d_selfop (presumably restricts which
///               elements are updated — TODO confirm in map2d.h)
/// @return       the reference returned by map2d_selfop (presumably `self`)
///
/// `inline` because this is a definition with external linkage in a header: without
/// it, including the header from two translation units is a multiple-definition
/// link error.
inline const map2d<float>& self_add(map2d<float>& self, const float value, std::shared_ptr<const map2d<bool>> mask = nullptr)
{
	// The element parameter is named `dst`, not `self`, to avoid shadowing the function parameter.
	return map2d_selfop<float, float>(self, value, [](float& dst, float val)
		{
			dst += val;
		}, mask);
}

/// In-place element-wise addition of another map: self(i) += other(i) for each
/// element visited by map2d_selfop.
///
/// @param self   map modified in place
/// @param other  map whose elements are added (pairing rules are defined by map2d_selfop)
/// @param mask   optional mask forwarded to map2d_selfop (presumably restricts which
///               elements are updated — TODO confirm in map2d.h)
/// @return       the reference returned by map2d_selfop (presumably `self`)
inline const map2d<float>& self_add(map2d<float>& self, const map2d<float>& other, std::shared_ptr<const map2d<bool>> mask = nullptr)
{
	return map2d_selfop<float, float>(self, other, [](float& dst, float val)
		{
			dst += val;
		}, mask);
}

MAP2D_IMPL_SELF_OPERATOR(float, -=);

/// In-place subtraction of a scalar: each element visited by map2d_selfop has `value` subtracted.
///
/// @param self   map modified in place
/// @param value  scalar subtracted from each element
/// @param mask   optional mask forwarded to map2d_selfop (presumably restricts which
///               elements are updated — TODO confirm in map2d.h)
/// @return       the reference returned by map2d_selfop (presumably `self`)
///
/// `inline` because this is a definition with external linkage in a header: without
/// it, including the header from two translation units is a multiple-definition
/// link error.
inline const map2d<float>& self_subtract(map2d<float>& self, const float value, std::shared_ptr<const map2d<bool>> mask = nullptr)
{
	// The element parameter is named `dst`, not `self`, to avoid shadowing the function parameter.
	return map2d_selfop<float, float>(self, value, [](float& dst, float val)
		{
			dst -= val;
		}, mask);
}

/// In-place element-wise subtraction of another map: self(i) -= other(i) for each
/// element visited by map2d_selfop.
///
/// @param self   map modified in place
/// @param other  map whose elements are subtracted (pairing rules are defined by map2d_selfop)
/// @param mask   optional mask forwarded to map2d_selfop (presumably restricts which
///               elements are updated — TODO confirm in map2d.h)
/// @return       the reference returned by map2d_selfop (presumably `self`)
inline const map2d<float>& self_subtract(map2d<float>& self, const map2d<float>& other, std::shared_ptr<const map2d<bool>> mask = nullptr)
{
	return map2d_selfop<float, float>(self, other, [](float& dst, float val)
		{
			dst -= val;
		}, mask);
}

MAP2D_IMPL_SELF_OPERATOR(float, *=);

/// In-place multiplication by a scalar: each element visited by map2d_selfop is multiplied by `value`.
///
/// @param self   map modified in place
/// @param value  scalar factor
/// @param mask   optional mask forwarded to map2d_selfop (presumably restricts which
///               elements are updated — TODO confirm in map2d.h)
/// @return       the reference returned by map2d_selfop (presumably `self`)
///
/// `inline` because this is a definition with external linkage in a header: without
/// it, including the header from two translation units is a multiple-definition
/// link error.
inline const map2d<float>& self_mult(map2d<float>& self, const float value, std::shared_ptr<const map2d<bool>> mask = nullptr)
{
	// The element parameter is named `dst`, not `self`, to avoid shadowing the function parameter.
	return map2d_selfop<float, float>(self, value, [](float& dst, float val)
		{
			dst *= val;
		}, mask);
}

/// In-place element-wise multiplication by another map: self(i) *= other(i) for each
/// element visited by map2d_selfop.
///
/// @param self   map modified in place
/// @param other  map of factors (pairing rules are defined by map2d_selfop)
/// @param mask   optional mask forwarded to map2d_selfop (presumably restricts which
///               elements are updated — TODO confirm in map2d.h)
/// @return       the reference returned by map2d_selfop (presumably `self`)
inline const map2d<float>& self_mult(map2d<float>& self, const map2d<float>& other, std::shared_ptr<const map2d<bool>> mask = nullptr)
{
	return map2d_selfop<float, float>(self, other, [](float& dst, float val)
		{
			dst *= val;
		}, mask);
}

MAP2D_IMPL_SELF_OPERATOR(float, /=);

/// In-place division by a scalar: each element visited by map2d_selfop is divided by `value`.
/// No zero check is performed.
///
/// @param self   map modified in place
/// @param value  scalar divisor
/// @param mask   optional mask forwarded to map2d_selfop (presumably restricts which
///               elements are updated — TODO confirm in map2d.h)
/// @return       the reference returned by map2d_selfop (presumably `self`)
///
/// `inline` because this is a definition with external linkage in a header: without
/// it, including the header from two translation units is a multiple-definition
/// link error.
inline const map2d<float>& self_div(map2d<float>& self, const float value, std::shared_ptr<const map2d<bool>> mask = nullptr)
{
	// The element parameter is named `dst`, not `self`, to avoid shadowing the function parameter.
	return map2d_selfop<float, float>(self, value, [](float& dst, float val)
		{
			dst /= val;
		}, mask);
}

/// In-place element-wise division by another map: self(i) /= other(i) for each
/// element visited by map2d_selfop. No zero check is performed.
///
/// @param self   map modified in place
/// @param other  map of divisors (pairing rules are defined by map2d_selfop)
/// @param mask   optional mask forwarded to map2d_selfop (presumably restricts which
///               elements are updated — TODO confirm in map2d.h)
/// @return       the reference returned by map2d_selfop (presumably `self`)
inline const map2d<float>& self_div(map2d<float>& self, const map2d<float>& other, std::shared_ptr<const map2d<bool>> mask = nullptr)
{
	return map2d_selfop<float, float>(self, other, [](float& dst, float val)
		{
			dst /= val;
		}, mask);
}

float mean(const map2d<float>& obj, std::shared_ptr<const map2d<bool>> mask = nullptr)
{
	auto kernel_func = [](const std::pair<float, size_t> &curr, float val)
	{
		return std::make_pair<float, size_t>(curr.first + val, curr.second + 1);
	};

	auto redux_func = [](const std::list< std::pair<float, size_t>>& list)
	{
		auto tmp = std::make_pair<float, size_t>(0, 0);

		for (auto& v : list)
		{
			tmp.first += v.first;
			tmp.second += v.second;
		}

		if (tmp.second == 0)
			return (float)0;

		return tmp.first / (float)tmp.second;
	};

	return map2d_redux<float, std::pair<float, size_t>, float>(obj, kernel_func, redux_func, mask);
}

float rms(const map2d<float>& obj, std::shared_ptr<const map2d<bool>> mask = nullptr)
{
	auto kernel_func = [](const std::pair<float, size_t>& curr, float val)
	{
		return std::make_pair<float, size_t>(curr.first + (val * val), curr.second + 1);
	};

	auto redux_func = [](const std::list< std::pair<float, size_t>>& list)
	{
		auto tmp = std::make_pair<float, size_t>(0, 0);

		for (auto& v : list)
		{
			tmp.first += v.first;
			tmp.second += v.second;
		}

		if (tmp.second == 0)
			return (float)0;

		return (float)std::sqrt(tmp.first / (float)tmp.second);
	};

	return map2d_redux<float, std::pair<float, size_t>, float>(obj, kernel_func, redux_func, mask);
}

float stdev(const map2d<float>& obj, std::shared_ptr<const map2d<bool>> mask = nullptr)
{
	auto kernel_func = [](const std::tuple<float, float, size_t>& curr, float val)
	{
		float sum, sumsq;
		size_t num;

		std::tie(sum, sumsq, num) = curr;

		return std::make_tuple<float, float, size_t>(sum + val, sumsq + (val * val), num + 1);
	};

	auto redux_func = [](const std::list< std::tuple<float, float, size_t>>& list)
	{
		auto tmp = std::make_tuple<float, float, size_t>(0, 0, 0);

		for (auto& v : list)
		{
			std::get<0>(tmp) += std::get<0>(v);
			std::get<1>(tmp) += std::get<1>(v);
			std::get<2>(tmp) += std::get<2>(v);
		}

		float sum, sumsq;
		size_t num;

		std::tie(sum, sumsq, num) = tmp;

		if (num == 0)
			return (float)0;

		auto meanval = sum / (float)num;

		return (float)std::sqrt(sumsq / (float)num - meanval * meanval);
	};

	return map2d_redux<float, std::tuple<float, float, size_t>, float>(obj, kernel_func, redux_func, mask);
}

/// Smallest element of `obj`.
///
/// Each part is folded with std::min by map2d_redux; the smallest of the partial
/// results is returned.
///
/// @param obj   input map
/// @param mask  optional mask forwarded to map2d_redux (presumably selects which
///              elements take part — TODO confirm in map2d.h)
/// @return      minimum of the partial results, or 0 if map2d_redux produced none
///
/// NOTE(review): each partial result starts from whatever initial accumulator
/// map2d_redux supplies. If that is a value-initialized float (0), a map whose
/// values are all positive would report 0 here — confirm against map2d.h.
/// `inline`: external-linkage definition in a header (avoids multiple-definition link errors).
inline float minval(const map2d<float>& obj, std::shared_ptr<const map2d<bool>> mask = nullptr)
{
	auto kernel_func = [](const float& curr, float val)
	{
		return std::min(curr, val);
	};

	auto redux_func = [](const std::list<float>& list)
	{
		if (list.empty())
			return 0.0f;

		return *std::min_element(list.begin(), list.end());
	};

	return map2d_redux<float, float, float>(obj, kernel_func, redux_func, mask);
}

/// Largest element of `obj`.
///
/// Each part is folded with std::max by map2d_redux; the largest of the partial
/// results is returned.
///
/// @param obj   input map
/// @param mask  optional mask forwarded to map2d_redux (presumably selects which
///              elements take part — TODO confirm in map2d.h)
/// @return      maximum of the partial results, or 0 if map2d_redux produced none
///
/// NOTE(review): each partial result starts from whatever initial accumulator
/// map2d_redux supplies. If that is a value-initialized float (0), a map whose
/// values are all negative would report 0 here — confirm against map2d.h.
/// `inline`: external-linkage definition in a header (avoids multiple-definition link errors).
inline float maxval(const map2d<float>& obj, std::shared_ptr<const map2d<bool>> mask = nullptr)
{
	auto kernel_func = [](const float& curr, float val)
	{
		return std::max(curr, val);
	};

	auto redux_func = [](const std::list<float>& list)
	{
		if (list.empty())
			return 0.0f;

		return *std::max_element(list.begin(), list.end());
	};

	return map2d_redux<float, float, float>(obj, kernel_func, redux_func, mask);
}

float sum(const map2d<float>& obj, std::shared_ptr<const map2d<bool>> mask = nullptr)
{
	auto kernel_func = [](const float& curr, float val)
	{
		return (float)(curr + val);
	};

	auto redux_func = [](const std::list<float>& list)
	{
		if (list.size() == 0)
			return (float)0;

		float tmp = 0;

		for (auto& v : list)
			tmp = tmp + v;

		return (float)tmp;
	};

	return map2d_redux<float, float, float>(obj, kernel_func, redux_func, mask);
}

float sumsq(const map2d<float>& obj, std::shared_ptr<const map2d<bool>> mask = nullptr)
{
	auto kernel_func = [](const float& curr, float val)
	{
		return (float)(curr + (val * val));
	};

	auto redux_func = [](const std::list<float>& list)
	{
		if (list.size() == 0)
			return (float)0;

		float tmp = 0;

		for (auto& v : list)
			tmp = tmp + v;

		return (float)tmp;
	};

	return map2d_redux<float, float, float>(obj, kernel_func, redux_func, mask);
}

// Create a vector by summing the columns of the image: vec[x] = sum over y of obj(x, y).
// The result has obj.width() entries. Each entry starts at 0 and accumulates every
// row, so a map with zero height yields zeros instead of reading a non-existent row 0.
static std::vector<float> sum_cols(const map2d<float>& obj)
{
	std::vector<float> vec(obj.width());

	for (size_t x = 0; x < obj.width(); x++)
	{
		for (size_t y = 0; y < obj.height(); y++)
			vec[x] += obj(x, y);
	}

	return vec;
}

// Create a vector by summing the rows of the image: vec[y] = sum over x of obj(x, y).
// The result has obj.height() entries. Each entry starts at 0 and accumulates every
// column, so a map with zero width yields zeros instead of reading a non-existent column 0.
static std::vector<float> sum_rows(const map2d<float>& obj)
{
	std::vector<float> vec(obj.height());

	for (size_t y = 0; y < obj.height(); y++)
	{
		for (size_t x = 0; x < obj.width(); x++)
			vec[y] += obj(x, y);
	}

	return vec;
}

// Create a vector holding the maximum value of each column: vec[x] = max over y of obj(x, y).
// The result has obj.width() entries. A map with zero height has no rows to compare,
// so every entry is left at 0 (the value maxval() reports for an empty reduction)
// instead of reading the non-existent row 0.
static std::vector<float> max_cols(const map2d<float>& obj)
{
	std::vector<float> vec(obj.width());

	if (obj.height() == 0)
		return vec;

	for (size_t x = 0; x < obj.width(); x++)
	{
		float best = obj(x, 0);

		for (size_t y = 1; y < obj.height(); y++)
			best = std::max(best, obj(x, y));

		vec[x] = best;
	}

	return vec;
}

// Create a vector holding the maximum value of each row: vec[y] = max over x of obj(x, y).
// The result has obj.height() entries. A map with zero width has no columns to compare,
// so every entry is left at 0 (the value maxval() reports for an empty reduction)
// instead of reading the non-existent column 0.
static std::vector<float> max_rows(const map2d<float>& obj)
{
	std::vector<float> vec(obj.height());

	if (obj.width() == 0)
		return vec;

	for (size_t y = 0; y < obj.height(); y++)
	{
		float best = obj(0, y);

		for (size_t x = 1; x < obj.width(); x++)
			best = std::max(best, obj(x, y));

		vec[y] = best;
	}

	return vec;
}
