#pragma once

#include "../../utils/utils.h"
#include "../../threads/thread.h"
#include "../../message/handler.h"
#include "../../instancers/local_instancer.h"
#include "../../storage/buffer.h"
#include "../../storage/fs/spc0/object.h"

#include "../net.h"
#include "../packet.h"
#include "../endpoint.h"

#include "exception.h"
#include "message.h"

#include <memory>

#ifdef WIN32

#include <WinSock2.h>
#include <WS2tcpip.h>

#else

#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>

#endif

namespace net::tcp
{
	// TCP client endpoint.
	//
	// Owns a single stream socket and frames outgoing data with net::packet (IDENT is
	// forwarded to the net::packet constructor — presumably a protocol identifier, TODO confirm).
	// Received bytes are appended to net::packet via on_receive(); net::packet presumably calls
	// on_process() once a full packet is available, which decodes it into a storage_object_type,
	// instantiates the named msg subtype and dispatches it through handle().
	// The socket is serviced either by the thread::base loop (run()) or manually with sync_run().
	// Every access to m_socket / m_address is serialized through m_lock.
	template<const unsigned long IDENT, class storage_object_type=storage::fs::spc0::object> class client : public thread::base, public net::packet, public message::handler<msg>, public instancer::local<msg>, public net::endpoint<msg>
	{
	public:
		client(void) = delete;
		client(const client&) = delete;
		client& operator=(const client&) = delete;

		// ctor: adopt an already-created socket; it is closed by close() or on destruction
		client(SOCKET socket) : net::packet(IDENT), m_socket(socket)
		{
		}

		// ctor: resolve address:port and connect to the first IPv4 candidate that accepts.
		// Raises cannot_create_client_exception (via throw_exception) when resolution fails
		// or no candidate address could be connected.
		client(const char* address, int port) : net::packet(IDENT)
		{
			this->m_socket = _connect(address, port);

			if (this->m_socket == INVALID_SOCKET)
			{
				// could not create socket
				throw_exception(cannot_create_client_exception);
			}
		}

		// dtor: stop the worker thread (never letting an exception escape), then close the socket
		virtual ~client(void)
		{
			try
			{
				stop();
			}
			catch (...) {}

			close();
		}

		// force client to close; safe to call repeatedly
		void close(void)
		{
			std::lock_guard<std::mutex> lk(this->m_lock);

			if (this->m_socket != INVALID_SOCKET)
				closesocket(this->m_socket);

			this->m_socket = INVALID_SOCKET;
		}

		// Return the peer IPv4 address as a string, or "?" when it cannot be determined.
		// The first successful lookup is cached in m_address and returned from then on.
		std::string get_addr(void) const
		{
			std::lock_guard<std::mutex> lk(this->m_lock);

			// read the cache under the lock: it is written below and may be queried from several threads
			if (this->m_address != nullptr)
				return *this->m_address;

			struct sockaddr_in saddr_in{};
			socklen_t len = sizeof(saddr_in);

			if (getpeername(this->m_socket, reinterpret_cast<SOCKADDR*>(&saddr_in), &len) != 0)
				return std::string("?");

			char tmp[128];

			if (inet_ntop(AF_INET, &saddr_in.sin_addr, tmp, sizeof(tmp)) == nullptr)
				return std::string("?");

			this->m_address = std::make_unique<std::string>(tmp);

			return *this->m_address;
		}

		// Send `data` framed by net::packet::pack().
		// Returns false if the socket is closed or a send() call fails, true once every byte
		// of the framed buffer has been handed to the OS.
		virtual bool write(const storage::buffer& data) override
		{
			// cheap early-out so we don't pack data that can't be sent
			{
				std::lock_guard<std::mutex> lk(this->m_lock);

				if (this->m_socket == INVALID_SOCKET)
					return false;
			}

			// pack bytes
			auto buffer = net::packet::pack(data);

#ifdef MSG_NOSIGNAL
			// POSIX: get an EPIPE error instead of SIGPIPE (which terminates the process by default)
			constexpr int send_flags = MSG_NOSIGNAL;
#else
			constexpr int send_flags = 0;
#endif

			const char* cursor = reinterpret_cast<const char*>(buffer.data());
			size_t remaining = buffer.size();

			std::lock_guard<std::mutex> lk(this->m_lock);

			// send() may accept only part of the buffer: keep going until everything is written
			while (remaining > 0)
			{
				// the socket may have been closed by another thread while we were packing
				if (this->m_socket == INVALID_SOCKET)
					return false;

				const auto sent = ::send(this->m_socket, cursor, MAKE_INT(remaining), send_flags);

				if (sent <= 0)
					return false;

				cursor += sent;
				remaining -= static_cast<size_t>(sent);
			}

			// successful return
			return true;
		}

		// serialize a message into a storage_object_type and return its packed bytes
		virtual storage::buffer pack(const msg& msg_obj) override
		{
			storage_object_type storage_object;

			msg_obj.push(storage_object);

			return storage_object.pack();
		}

		// build a storage_object_type from packed bytes
		virtual std::shared_ptr<storage::fs::object_base> unpack(const storage::buffer& buffer) override
		{
			auto p = std::make_shared<storage_object_type>();

			p->unpack(buffer);

			return p;
		}

		// Service the socket once on the caller's thread (no worker thread involved).
		// Returns true when there was nothing to do (poll failed or no data was ready).
		bool sync_run(void)
		{
			return !_listen();
		}

	protected:
		// called after the socket has been closed because the peer went away
		virtual void on_disconnect(void)
		{
			_debug("client %s disconnected", get_addr().c_str());
		}

		// called when decoding or handling an incoming packet fails
		virtual void on_error(const std::string& err)
		{
			_warning("error in tcp client: %s", err.c_str());
		}

	private:
		// thread stops when socket is invalid
		virtual bool stop_condition(void) const override
		{
			std::lock_guard<std::mutex> lk(this->m_lock);

			return this->m_socket == INVALID_SOCKET;
		}

		// Resolve address:port (IPv4, TCP) and return a socket connected to the first candidate
		// that accepts, or INVALID_SOCKET if resolution fails or every candidate is refused.
		static SOCKET _connect(const char* address, int port)
		{
			struct addrinfo hints{};

			hints.ai_family = AF_INET;
			hints.ai_socktype = SOCK_STREAM;
			hints.ai_protocol = IPPROTO_TCP;

			char service[64];

			snprintf(service, sizeof(service), "%d", port);

			struct addrinfo* addresses_info = nullptr;

			if (getaddrinfo(address, service, &hints, &addresses_info) != 0)
				return INVALID_SOCKET;

			SOCKET connected = INVALID_SOCKET;

			for (struct addrinfo* p = addresses_info; p != nullptr; p = p->ai_next)
			{
				SOCKET candidate = ::socket(p->ai_family, p->ai_socktype, p->ai_protocol);

				if (candidate == INVALID_SOCKET)
					continue;

				if (::connect(candidate, p->ai_addr, static_cast<int>(p->ai_addrlen)) == SOCKET_ERROR)
				{
					// this candidate is unreachable: release it and try the next one
					closesocket(candidate);
					continue;
				}

				// we're connected
				connected = candidate;
				break;
			}

			freeaddrinfo(addresses_info);

			return connected;
		}

		// Close the socket and notify on_disconnect().
		// The peer address is resolved (and cached) first so on_disconnect() can still report it.
		void _disconnect(void)
		{
			// get_addr() takes m_lock itself, so it must run before we lock below
			(void)get_addr();

			{
				std::lock_guard<std::mutex> lk(this->m_lock);

				if (this->m_socket != INVALID_SOCKET)
					closesocket(this->m_socket);

				this->m_socket = INVALID_SOCKET;
			}

			on_disconnect();
		}

		// Poll the socket once and, if data is ready, read up to 4096 bytes into the packet decoder.
		// Returns false when there was nothing to do (poll failed or no data ready), true otherwise.
		bool _listen(void)
		{
			bool conn_reset = false;
			bool read_ready = false;

			{
				std::lock_guard<std::mutex> lk(this->m_lock);

				if (!net::poll(this->m_socket, &conn_reset, &read_ready, nullptr))
					return false;
			}

			// connection aborted
			if (conn_reset)
			{
				_disconnect();

				return true;
			}

			// check for incoming messages
			if (!read_ready)
				return false;

			char buffer[4096];
			int ret;

			{
				std::lock_guard<std::mutex> lk(this->m_lock);

				ret = static_cast<int>(recv(this->m_socket, buffer, sizeof(buffer), 0));
			}

			// other side disconnected (orderly shutdown)
			if (ret == 0)
			{
				_disconnect();

				return true;
			}

			// NOTE(review): a negative ret (socket error) is not handled here; the call still reports activity
			if (ret > 0)
				on_receive(reinterpret_cast<const unsigned char*>(buffer), static_cast<size_t>(ret));

			return true;
		}

		// receive loop body: back off for 1 ms when there was nothing to do
		virtual void run(void)
		{
			if (!_listen())
				std::this_thread::sleep_for(std::chrono::milliseconds(1));
		}

		// process incoming packet: decode, instantiate the named message type and dispatch it.
		// Any exception raised along the way is reported through on_error().
		virtual void on_process(const unsigned char* bytes, const size_t size) override
		{
			try
			{
				// unpack bytes
				auto storage_object = unpack(storage::buffer(bytes, size));

				if (storage_object == nullptr)
					return;

				// create instance based on object name
				auto obj = create_instance(storage_object->get_type_name());

				// guard in case create_instance() yields null (TODO confirm instancer contract) rather than dereferencing it
				if (obj == nullptr)
				{
					on_error("could not create message instance");
					return;
				}

				// retrieve object
				obj->pop(*storage_object);

				// handle object
				handle(*obj);
			}
			catch (exception::base& e)
			{
				on_error(e.to_string().c_str());
			}
			catch (const std::exception& e)
			{
				on_error(e.what());
			}
			catch (...)
			{
				on_error("unknown error!\r\n");
			}
		}

		// receive bytes from socket: append them to the current packet and try to process it
		void on_receive(const unsigned char* bytes, const size_t size)
		{
			net::packet::unpack(bytes, size);
		}

		// guards m_socket and m_address
		mutable std::mutex m_lock;

		SOCKET m_socket = INVALID_SOCKET;

		// cached peer address, filled by the first successful get_addr()
		mutable std::unique_ptr<std::string> m_address;
	};
};
