#pragma once

#include "../../utils/utils.h"
#include "../../threads/thread.h"

#include "exception.h"
#include "message.h"
#include "client.h"

#include "../net.h"

#include <chrono>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <thread>

#ifdef WIN32

#include <WinSock2.h>
#include <WS2tcpip.h>

#else

#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>

#endif

namespace net::tcp
{
	/* server class
	 *
	 * TCP listening server. Accepted connections are wrapped by create() and kept in
	 * m_clients (guarded by m_lock). Connections are accepted either by the server's
	 * own thread (run(), via thread::base) or manually through sync_run().
	 */
	template<const unsigned long IDENT> class server : public thread::base
	{
	public:
		using client_base_t = net::tcp::client<IDENT>;

	protected:

		/* default server client implementation: forwards the connection's disconnect
		 * and error notifications back to the owning server */
		class client : public net::tcp::client<IDENT>, public std::enable_shared_from_this<client>
		{
		public:
			// bases are always initialized before members; list them in that order
			client(server& parent, SOCKET socket) : net::tcp::client<IDENT>(socket), m_server(parent) {}

		private:
			// unregister from the server (server::remove also fires server::on_disconnect)
			virtual void on_disconnect(void) override
			{
				this->m_server.remove(this->shared_from_this());
			}

			// report the error through the server's on_error hook
			virtual void on_error(const std::string& msg) override
			{
				this->m_server.on_error(this->shared_from_this(), msg);
			}

			server& m_server;
		};

	public:
		/* ctor: bind to address:port and start listening.
		 *
		 * address - host/interface, passed to getaddrinfo() together with AI_PASSIVE
		 * port    - TCP port number
		 *
		 * on failure the partially created socket is released through close() and
		 * cannot_create_server_exception is raised via throw_exception()
		 */
		server(const char* address, int port)
		{
			if (!_open(address, port))
			{
				close();

				// could not create socket
				throw_exception(cannot_create_server_exception);
			}
		}

		// the server owns a socket and a mutex: it is neither copyable nor assignable
		server(const server&) = delete;
		server& operator=(const server&) = delete;

		// dtor: stops the thread, drops all clients and closes the listening socket
		virtual ~server(void)
		{
			close();
		}

		// send a message to every connected client; m_lock is held for the whole broadcast
		void send_all(const msg& obj)
		{
			std::lock_guard<std::mutex> lk(this->m_lock);

			for (const auto& p : this->m_clients)
				if (p != nullptr)
					p->send(obj);
		}

		// send a message to every connected client; a null message is ignored
		void send_all(const std::shared_ptr<const msg>& p)
		{
			if (p != nullptr)
				send_all(*p);
		}

		/* equivalent to run but in synchronous (unthreaded) mode: accepts at most one
		 * pending connection, then calls sync_run() on every client.
		 * returns true if a connection was accepted or any client's sync_run() returned true */
		bool sync_run(void)
		{
			bool success = (_listen() != nullptr);

			// iterate over a snapshot: a client callback may call remove(), which takes m_lock
			for (const auto& p : get_client_list())
				if (p != nullptr)
					success |= p->sync_run();

			return success;
		}

		// get a snapshot (copy) of the current client list, taken under m_lock
		std::list<std::shared_ptr<client_base_t>> get_client_list(void) const
		{
			std::lock_guard<std::mutex> lk(this->m_lock);

			return this->m_clients;
		}

		/* close all connections and the listening socket.
		 * the socket is closed at most once; later calls find it already INVALID_SOCKET */
		void close(void)
		{
			_debug("closing tcp server");

			// stop thread first
			stop();

			std::list<std::shared_ptr<client_base_t>> clients;

			{
				std::lock_guard<std::mutex> lk(this->m_lock);

				// detach the client list; the references are released below, outside the lock
				clients.swap(this->m_clients);

				// close socket
				if (this->m_socket != INVALID_SOCKET)
					closesocket(this->m_socket);

				this->m_socket = INVALID_SOCKET;
			}

			// remove() takes m_lock and std::mutex is not recursive, so any client teardown
			// that ends up in remove() must not run while m_lock is held here
			clients.clear();
		}

	private:
		/* resolve address:port, bind the first address that yields a usable socket and
		 * put it into listening state. returns true on success. on failure m_socket may
		 * still hold an open socket (when listen() fails); the caller releases it via close() */
		bool _open(const char* address, int port)
		{
			// get address for TCP socket
			struct addrinfo hints{};

			hints.ai_family = AF_INET;
			hints.ai_socktype = SOCK_STREAM;
			hints.ai_protocol = IPPROTO_TCP;
			hints.ai_flags = AI_PASSIVE;

			const std::string service = std::to_string(port);

			struct addrinfo* addresses_info = nullptr;

			if (getaddrinfo(address, service.c_str(), &hints, &addresses_info) != 0)
				return false;

			// try each resolved address until one can be bound
			for (struct addrinfo* curr = addresses_info; curr != nullptr; curr = curr->ai_next)
			{
				this->m_socket = socket(curr->ai_family, curr->ai_socktype, curr->ai_protocol);

				if (this->m_socket == INVALID_SOCKET)
					continue;

				// SO_REUSEADDR expects an int-sized flag; Winsock only types the pointer as
				// const char*, so pass a real int (a 1-byte char with sizeof(int) over-reads)
				const int reuse = 1;

				if (setsockopt(this->m_socket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char*>(&reuse), sizeof(reuse)) < 0)
					_warning("cannot set SO_REUSEADDR");

				// keep the first socket that binds
				if (bind(this->m_socket, curr->ai_addr, static_cast<int>(curr->ai_addrlen)) != SOCKET_ERROR)
					break;

				// close socket if failed and try next address
				closesocket(this->m_socket);

				this->m_socket = INVALID_SOCKET;
			}

			// free address list
			if (addresses_info != nullptr)
				freeaddrinfo(addresses_info);

			// no address could be bound
			if (this->m_socket == INVALID_SOCKET)
				return false;

			// listen to the socket
			return listen(this->m_socket, SOMAXCONN) != SOCKET_ERROR;
		}

		/* poll the listening socket and accept at most one pending connection.
		 * returns the new client (already stored in m_clients and announced through
		 * on_connect), or nullptr if nothing is pending or accept() failed */
		std::shared_ptr<client_base_t> _listen(void)
		{
			bool conn_reset = false, read_ready = false, write_ready = false;
			SOCKET client_socket = INVALID_SOCKET;

			{
				std::lock_guard<std::mutex> lk(this->m_lock);

				// poll socket
				if (!net::poll(this->m_socket, &conn_reset, &read_ready, &write_ready))
					return nullptr;

				// check for incoming connections
				if (!read_ready && !write_ready)
					return nullptr;

				// accept in the same lock scope, so close() cannot close the socket in between
				client_socket = accept(this->m_socket, nullptr, nullptr);
			}

			// accept() can still fail after a successful poll (e.g. the peer aborted the
			// connection); never wrap an invalid socket in a client
			if (client_socket == INVALID_SOCKET)
				return nullptr;

			// add to list; p keeps the client alive while on_connect runs
			auto p = create(client_socket);

			{
				std::lock_guard<std::mutex> lk(this->m_lock);

				this->m_clients.emplace_back(p);
			}

			this->on_connect(p);

			return p;
		}

		// thread body (presumably invoked repeatedly by thread::base): accept one connection
		// and start its thread, or back off for 1 ms when nothing was accepted
		virtual void run(void)
		{
			const auto p = _listen();

			if (p == nullptr)
			{
				std::this_thread::sleep_for(std::chrono::milliseconds(1));

				return;
			}

			p->start();
		}

	protected:
		// factory for accepted connections; override to use a custom client type
		virtual std::shared_ptr<client_base_t> create(SOCKET socket)
		{
			return std::make_shared<client>(*this, socket);
		}

		// called after a newly accepted client has been added to the list
		virtual void on_connect(std::shared_ptr<client_base_t> p)
		{
			_debug("Client %s has connected.", p->get_addr().c_str());
		}

		// called from remove(), before the client is taken off the list
		virtual void on_disconnect(std::shared_ptr<client_base_t> p)
		{
			_debug("Client %s has disconnected.", p->get_addr().c_str());
		}

		// called by the default client's on_error
		virtual void on_error(std::shared_ptr<client_base_t> p, const std::string& msg)
		{
			_error("client %s has triggered error \"%s\" !", p->get_addr().c_str(), msg.c_str());
		}

	private:
		// remove client from list, notifying on_disconnect first; null is ignored
		virtual void remove(std::shared_ptr<client_base_t> p)
		{
			if (p == nullptr)
				return;

			this->on_disconnect(p);

			std::lock_guard<std::mutex> lk(this->m_lock);

			this->m_clients.remove(p);
		}

	private:
		SOCKET m_socket = INVALID_SOCKET;	// listening socket

		mutable std::mutex m_lock;	// guards m_socket and m_clients
		std::list<std::shared_ptr<client_base_t>> m_clients;
	};
}
