Preface

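Because C++ lacks a unified framework for back-end services, developers who choose it have to interact with the operating system directly and do a great deal of low-level communication work. Even very common application-layer protocols must be adapted from scratch, which is tedious and has significantly limited the use of C++ in back-end service development. Most of the common C++ frameworks and libraries focus on a single problem domain, so they are hard to use as-is and need heavy modification. Others are too heavyweight and carry too many dependencies to be convenient.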

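Workflow is a lightweight, elegantly designed, enterprise-grade program engine. Sogou open-sourced it in 2019 after years of production use, during which it supported almost all of Sogou's online C++ back-end services, such as search, cloud input method, and online advertising, handling more than ten billion requests per day. It supports Linux, macOS, Windows, and other operating systems, and runs on all CPU platforms, including 32-bit and 64-bit x86 processors and big-endian or little-endian ARM processors. The project is based on the C++11 standard and has no dependencies other than the OpenSSL security library. Together with its highly flexible extension interfaces, this has gradually turned Workflow into a general-purpose back-end service framework that can meet most C++ back-end development needs.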

Analysis of the Workflow Network Model

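Workflow has many commendable aspects: the idea that everything is a task; its abstractions of protocols, algorithms, and task flows; its plain implementation of asynchrony; and its elegant timeout mechanism and error handling. For a back-end framework, however, the most central part is the design of the network model, because it determines the upper limit of network service performance. This case study therefore dissects Workflow's network model to understand how it works internally and to learn the techniques used to encapsulate a network module.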

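The Workflow-windows branch is the implementation of Workflow for Windows. To make full use of the network performance the system offers, it uses IOCP as the basic I/O multiplexing component. Compared with epoll on Linux, IOCP encapsulates more: the application no longer has to copy data from the socket's kernel buffer into a user-space buffer itself, and even multithreading is handled inside IOCP, which avoids thread scheduling at the user level and makes it more convenient to use. For C++ developers accustomed to operating the socket API directly through read/write, however, these changes are also a conceptual shift that takes some getting used to.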

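The I/O completion port is wrapped mainly in the WinPoller class (source location: workflow-windows/src/kernel_win/), which provides asynchronous reads and writes for network I/O, forced wake-up, early termination, and similar operations. WinPoller also wraps an internal timer for delayed tasks and supports forwarding user events, which is used to cooperate with external task flows. The official Workflow project provides detailed tutorials. The http_server example shows the whole service startup flow; the figure below shows where WinPoller sits in the Workflow network service framework, along with some of its main interface methods. (On Linux the network module is built on epoll and is encapsulated differently from WinPoller; see the Communicator source on the master branch, which creates multiple epoll threads and assigns I/O tasks to different epoll instances by file descriptor to balance the load.)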

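[Figure: service startup flow, showing the position of WinPoller in the Workflow network service framework and its main interface methods (f4d62cf8672b4dcfa638d30eb74d7d33.png)]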

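In the figure above, WFServer is a class template that can be specialized into different service types for different application-layer protocols. Its parent class, WFServerBase, holds a communication scheduler, CommScheduler, which acts as a proxy for the Communicator class and controls the communication process. Communicator owns a thread pool and an event poller (WinPoller or mpoller). The poller detects events and adds them to the task queue (producer), and the thread pool takes event tasks from the queue and handles them (consumer); together they form a producer-consumer model.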
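
The producer-consumer relationship described above can be sketched in portable C++. This is only an illustration under stated assumptions, not Workflow's actual implementation: `Event`, `EventQueue`, and `run_pipeline` are hypothetical names, the "poller" is just a loop that pushes numbered events, and the "thread pool" is a handful of `std::thread` workers.

```cpp
#include <atomic>
#include <condition_variable>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Hypothetical event record; the real framework passes poller results instead.
struct Event
{
	int id;
};

// A minimal blocking queue shared by one producer and several consumers.
class EventQueue
{
public:
	void push(Event e)
	{
		{
			std::lock_guard<std::mutex> lock(mutex_);
			queue_.push(e);
		}
		cond_.notify_one();
	}

	// Marks the queue as closed and wakes every waiting consumer.
	void close()
	{
		{
			std::lock_guard<std::mutex> lock(mutex_);
			closed_ = true;
		}
		cond_.notify_all();
	}

	// Blocks until an event is available; returns false once the queue
	// has been closed and fully drained.
	bool pop(Event &out)
	{
		std::unique_lock<std::mutex> lock(mutex_);
		cond_.wait(lock, [this] { return closed_ || !queue_.empty(); });
		if (queue_.empty())
			return false;
		out = queue_.front();
		queue_.pop();
		return true;
	}

private:
	std::mutex mutex_;
	std::condition_variable cond_;
	std::queue<Event> queue_;
	bool closed_ = false;
};

// One producer (standing in for the poller) and n_workers consumers
// (standing in for the thread pool). Returns the sum of all handled ids.
long run_pipeline(int n_events, int n_workers)
{
	EventQueue q;
	std::atomic<long> sum(0);
	std::vector<std::thread> workers;

	for (int i = 0; i < n_workers; i++)
		workers.emplace_back([&q, &sum] {
			Event e;
			while (q.pop(e))
				sum += e.id; // "handle" the event
		});

	for (int i = 1; i <= n_events; i++)
		q.push(Event{i});

	q.close();
	for (auto &t : workers)
		t.join();

	return sum.load();
}
```

Calling `run_pipeline(100, 4)` pushes events 1 to 100 through four workers. Each event is handled exactly once, whichever worker picks it up, which is the property the task queue has to guarantee.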

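Through WinPoller's put_io interface, several types of asynchronous events can be posted (registered) to the event poller: read I/O events, write I/O events, connect events, accept events, delay events (sleep events), and user-defined events (user events). Calling get_io_result in the thread pool's worker callback retrieves the result of an asynchronous event so that its follow-up work can be completed.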

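When acting as a server, post ACCEPT events; when acting as a client, post CONNECT events. Once a connection is ready, asynchronous read and write events can be posted for the corresponding socket. To accommodate different application protocols, an event can carry a reference to a custom session context that includes the session state and a session handler callback. When the event is ready, the handler callback is invoked according to the session information to perform the business logic. Starting a service and connecting to other services can both be done dynamically at run time, so a process can be a server and a client at the same time. The framework handles the relationships between connections within the same event poller, which largely blurs the distinction between server and client. This makes it well suited to building large distributed systems, or back-end applications that need to access databases and message middleware.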
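
The idea of an event that carries a session context and a handler callback can be illustrated with a small portable sketch. Everything here is hypothetical and exists only for illustration: `Op`, `Session`, `Completed`, and `dispatch` are made-up names, and the operation codes merely echo the spirit of the PD_OP_* constants; no WinPoller API is used.

```cpp
#include <functional>
#include <string>

// Hypothetical operation codes, in the spirit of PD_OP_READ / PD_OP_ACCEPT.
enum Op
{
	OP_READ = 1,
	OP_ACCEPT = 3
};

// Hypothetical session context attached to every posted event.
struct Session
{
	int state = 0; // e.g. how many messages this session has processed
	std::function<void(Session &, const std::string &)> on_ready; // business callback
};

// Hypothetical completed event: operation code, opaque context, payload.
struct Completed
{
	int operation;
	void *context;
	std::string payload;
};

// Recovers the session from the opaque context pointer and runs its callback.
void dispatch(const Completed &ev)
{
	Session *s = static_cast<Session *>(ev.context);

	switch (ev.operation)
	{
	case OP_ACCEPT:
	case OP_READ:
		if (s && s->on_ready)
			s->on_ready(*s, ev.payload);
		break;
	default:
		break; // unknown operations are ignored in this sketch
	}
}
```

A caller would create a `Session`, assign a lambda to `on_ready`, and pass `&session` as the context of each event. The dispatcher never needs to know which protocol the session speaks, which is what lets one poller serve both server-side and client-side connections.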

Applying the Workflow Network Model

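Using the Workflow framework is already simple enough. The official documentation gives clear usage examples for a variety of situations, and with its rules for extending private protocols a respectable network service can be built quickly. This comes at some cost in flexibility, though, and because the internal implementation is hidden it defeats the purpose of studying and mastering Workflow's network model in depth. It is therefore worth separating the network model from the framework and operating the event poller directly to observe how the model runs (see the code later in this section).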

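The module drops Communicator and the more complex logic above it and focuses on how the network I/O model is scheduled. Following Communicator, it implements the handling logic for accept, connect, read, sleep, and other events. In a little over 600 lines of code it implements a high-concurrency TCP server with strong performance and a great deal of freedom.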

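To test the concurrency of the TCP server, a test plan was set up in JMeter with 50 test threads, each sending 1,000 requests, for 50,000 requests in total. The resulting metrics are shown in the figure below: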

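[Figure: JMeter test results (3342215ce92f469491ec1039f3fbb6d4.png)]
Note: this was a rough stress test on an office laptop more than ten years old and close to being scrapped; the throughput still looked decent, comfortably exceeding 5,000.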

#include <iostream>
#include <string>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <atomic>
#include "WinPoller.h"
#pragma comment(lib, "Ws2_32.lib")
#pragma comment(lib, "Mswsock.lib")

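// Read context used by this sample. Its leading members mirror ReadContext in
// WinPoller.h (one pointer, DWORD msgsize, WSABUF buffer), because WinPoller
// casts the context to ReadContext* and reads `buffer` from it.
class ReadContextEx
{
public:
	char* buffer_internal_cached;
	DWORD msgsize;
	WSABUF buffer;
	int read_timeout = -1;
	const DWORD max_msg_size = 64 * 1024;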
	ReadContextEx(size_t buffer_size)
	{
		if (buffer_size > max_msg_size)
		{
			buffer_internal_cached = (char*)malloc(max_msg_size);
			msgsize = max_msg_size;
		}
		else
		{
			buffer_internal_cached = (char*)malloc(buffer_size);
			msgsize = buffer_size;
		}
		if (buffer_internal_cached != nullptr)
		{
			buffer.buf = buffer_internal_cached;
			buffer.len = msgsize;
		}
	}
	~ReadContextEx()
	{
		free(buffer_internal_cached);
	}
};

class ConnectContextEx
{
public:
	void* entry;
	struct sockaddr* addr;
	socklen_t addrlen;
	struct sockaddr_in addr_in;
	ConnectContextEx(void* e)
	{
		entry = e;
		addrlen = sizeof(addr_in);
		memset(&addr_in, 0x00, sizeof(addr_in));
		addr = (struct sockaddr*)&addr_in;
	}
};

static inline int __set_fd_nonblock(SOCKET fd)
{
	unsigned long mode = 1;
	int ret = ioctlsocket(fd, FIONBIO, &mode);

	if (ret == SOCKET_ERROR)
	{
		errno = WSAGetLastError();
		return -1;
	}

	return 0;
}

static int __bind_and_listen(SOCKET listen_sockfd, const struct sockaddr* addr, socklen_t addrlen)
{
	struct sockaddr_storage ss;
	socklen_t len = sizeof(struct sockaddr_storage);

	if (getsockname(listen_sockfd, (struct sockaddr*)&ss, &len) == SOCKET_ERROR)
	{
		if (WSAGetLastError() == WSAEINVAL)
		{
			if (bind(listen_sockfd, addr, addrlen) == SOCKET_ERROR)
				return -1;
		}
	}
	if (listen(listen_sockfd, SOMAXCONN) == SOCKET_ERROR)
		return -1;
	return 0;
}

static int __bind_local(SOCKET sockfd, const struct sockaddr* addr, socklen_t addrlen)
{
	struct sockaddr_storage ss;
	socklen_t len = sizeof(struct sockaddr_storage);

	if (getsockname(sockfd, (struct sockaddr*)&ss, &len) == SOCKET_ERROR)
	{
		if (WSAGetLastError() == WSAEINVAL)
		{
			if (bind(sockfd, addr, addrlen) == SOCKET_ERROR)
				return -1;
		}
	}
	return 0;
}


static int __bind_any(SOCKET sockfd, int sa_family)
{
	struct sockaddr_storage addr;
	socklen_t addrlen;

	memset(&addr, 0, sizeof(struct sockaddr_storage));
	addr.ss_family = sa_family;
	if (sa_family == AF_INET)
	{
		struct sockaddr_in* sin = (struct sockaddr_in*)&addr;
		sin->sin_addr.s_addr = INADDR_ANY;
		sin->sin_port = 0;
		addrlen = sizeof(struct sockaddr_in);
	}
	else if (sa_family == AF_INET6)
	{
		struct sockaddr_in6* sin6 = (struct sockaddr_in6*)&addr;
		sin6->sin6_addr = in6addr_any;
		sin6->sin6_port = 0;
		addrlen = sizeof(struct sockaddr_in6);
	}
	else
		addrlen = sizeof(struct sockaddr_storage);

	if (bind(sockfd, (struct sockaddr*)&addr, addrlen) == SOCKET_ERROR)
		return -1;

	return 0;
}

static int __sync_send(SOCKET sockfd, const void* buf, size_t size)
{
	int ret;
	if (size == 0 || !buf)
		return 0;
	ret = send(sockfd, (const char*)buf, size, 0);
	if (ret == size)
		return size;
	if (ret > 0)
	{
		errno = ENOBUFS;
		ret = -1;
	}
	return ret;
}

int __create_stream_socket(unsigned short address_family, int type, bool is_blocked = false)
{
	SOCKET sock = (int)socket(address_family, type, 0);
	if (sock != INVALID_SOCKET)
	{
		if (!is_blocked && __set_fd_nonblock(sock) < 0)
		{
			closesocket(sock);
			return -1;
		}
	}
	else
		return -1;
	return (int)sock;
}

void handle_accept_result(WinPoller* poller, struct poller_result* res)
{
	AcceptConext* ctx = (AcceptConext*)res->data.context;
	SOCKET listen_fd = (SOCKET)res->data.handle;
	SOCKET sockfd = ctx->accept_sockfd;
	switch (res->state)
	{
	case PR_ST_SUCCESS://todo error???
	case PR_ST_FINISHED:
		if (sockfd != INVALID_SOCKET)
		{
			if (poller->bind((HANDLE)sockfd) >= 0)
			{
				struct poller_data data;
				auto* new_ctx = new ReadContextEx(1024);

				data.operation = PD_OP_READ;
				data.handle = (HANDLE)sockfd;
				data.context = new_ctx;
				if (poller->put_io(&data, -1) < 0)
				{
					delete new_ctx;
					poller->unbind_socket(sockfd);
					closesocket(sockfd);// unbind_socket() only cancels I/O and shuts down; release the handle too
				}
				else
				{
					char buf[20] = { 0 };
					inet_ntop(AF_INET, &((sockaddr_in*)ctx->remote)->sin_addr, buf, sizeof(buf));
					printf(" Accept a new connection:  ip=[%s], port=%d.\n", buf, ntohs(((sockaddr_in*)ctx->remote)->sin_port));
				}
			}
			else
			{
				closesocket(sockfd);
				ctx->accept_sockfd = INVALID_SOCKET;
			}
		}
		break;
	case PR_ST_ERROR:
	case PR_ST_STOPPED:
	case PR_ST_TIMEOUT:
	{
		closesocket(sockfd);
		//poller->unbind_socket(listen_fd);// terminate server
		//listen_fd = INVALID_SOCKET;
	}
	break;
	default:
		assert(0);
		break;
	}
	ctx->accept_sockfd = __create_stream_socket(AF_INET,SOCK_STREAM);
	if (listen_fd != INVALID_SOCKET && ctx->accept_sockfd != INVALID_SOCKET)
	{
		if (poller->put_io(&res->data, -1) >= 0)
			return;//reuse context
		closesocket(ctx->accept_sockfd);
		ctx->accept_sockfd = INVALID_SOCKET;
	}
	if (listen_fd != INVALID_SOCKET)
		poller->unbind_socket(listen_fd);
	delete ctx;
}

void handle_connect_result(WinPoller* poller, struct poller_result* res)
{
	ConnectContextEx* ctx = (ConnectContextEx*)res->data.context;
	struct sockaddr_in target_address = *(struct sockaddr_in*)(ctx->addr);
	SOCKET handle = (SOCKET)res->data.handle;
	delete ctx;
	char target_ip_str[30] = {};
	switch (res->state)
	{
	case PR_ST_SUCCESS://todo error???
	case PR_ST_FINISHED:
		if (handle != INVALID_SOCKET)
		{
			inet_ntop(AF_INET, &target_address.sin_addr, target_ip_str, 30);

			// greet message.
			printf("connect to server success[%s].\n", target_ip_str);
			__sync_send(handle, "", 0);

			auto* new_ctx = new ReadContextEx(1024);
			struct poller_data data;
			data.operation = PD_OP_READ;
			data.handle = (HANDLE)handle;
			data.context = new_ctx;
			if (poller->put_io(&data, -1) < 0)
			{
				delete new_ctx;
				poller->unbind_socket(handle);
			}
			else
			{
				return;
			}
		}
		res->error = errno;
		break;

	case PR_ST_ERROR:
	{
		inet_ntop(AF_INET, &target_address.sin_addr, target_ip_str, 30);
		printf("connect to %s failed, error=%d.\n", target_ip_str, res->error);
	}
	break;
	case PR_ST_TIMEOUT:
	{
		poller->unbind_socket(handle);
		printf("connect timeout, error=%d.\n", res->error);
	}
	break;

	case PR_ST_STOPPED:
		poller->unbind_socket(handle);
		break;

	default:
		assert(0);
		break;
	}
}

void handle_read_result(WinPoller* poller, struct poller_result* res)
{
	ReadContextEx* ctx = (ReadContextEx*)res->data.context;
	std::string buf;
	int timeout = ctx->read_timeout;
	switch (res->state)
	{
	case PR_ST_SUCCESS:
	{
		buf = std::string(ctx->buffer.buf, res->iobytes);
		if (buf.size() > 0 && buf[buf.size() - 1] != '\n')
		{
			buf = buf + "\n";
		}
		else if (buf.size() == 0)
			buf = "\n";

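		printf(" Recv data from client: dataLen=%lu, msg body=%s", (unsigned long)res->iobytes, buf.c_str());
		/*
		*     Process the data and send the result back.
		*     An asynchronous write event could be posted here. (Replies are generally
		*     sent more efficiently with the synchronous interface; the asynchronous
		*     one is only worthwhile when a large amount of data is sent at once.)
		*/
		// Build the reply in a std::string: the echoed message can be as large as the
		// read buffer (1024 bytes), which would overflow a fixed 100-byte array.
		std::string send_msg = "Hello there, already get your msg : " + buf + ".\n";
		__sync_send((SOCKET)res->data.handle, send_msg.data(), send_msg.size());

		// post the next read event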
		res->data.operation = PD_OP_READ;
		if (poller->put_io(&res->data, timeout) >= 0)
		{
			ctx = NULL;//reuse context
		}
		else
		{
			printf("Internal error.");
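			poller->unbind_socket((SOCKET)res->data.handle);
			closesocket((SOCKET)res->data.handle);// unbind_socket() does not close the handle
		}
	}
	break;
	case PR_ST_FINISHED:
	case PR_ST_TIMEOUT:
	case PR_ST_ERROR:
	{
		printf("client disconnected or dead, win sock err=%d.\n",res->error);
		poller->unbind_socket((SOCKET)res->data.handle);
		closesocket((SOCKET)res->data.handle);// this sample closes connection sockets nowhere else
	}
	break;
	case PR_ST_STOPPED:
	{
		poller->unbind_socket((SOCKET)res->data.handle);
		closesocket((SOCKET)res->data.handle);
		printf("client has been kicked off.\n");// the connection was shut down locally, most commonly by a timeout mechanism that kicks inactive connections
	}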
	break;
	default:
		assert(0);
		break;
	}
	delete ctx;
}

void handle_sleep_result(WinPoller* poller, struct poller_result* res)
{
    int io_type=0;
    socklen_t optlen = sizeof(io_type);
    int ret = getsockopt((SOCKET)res->data.handle, SOL_SOCKET, SO_TYPE,(char*)&io_type, &optlen);
    if(ret >= 0 && (io_type == SOCK_DGRAM || io_type == SOCK_STREAM))
    {
        printf("Network Timer Event Triggered.\n");// a network timer event has fired
        poller->unbind_socket((SOCKET)res->data.handle);
    }
}

int lanuch_async_connect(WinPoller* poller, const char* target_ip, unsigned int target_port, int timeout = -1, unsigned int local_port = 0)
{
	SOCKET sockfd = __create_stream_socket(AF_INET,SOCK_STREAM);

	if (sockfd != INVALID_SOCKET)
	{
		if (poller->bind((HANDLE)sockfd) >= 0)
		{
			int bind_local_result = 0;
			if (local_port == 0)
			{
				bind_local_result = __bind_any(sockfd, AF_INET);
			}
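			else
			{
				// Bind to the requested local port on any local address; target_port stays untouched.
				struct sockaddr_in local_addr;
				memset(&local_addr, 0, sizeof(local_addr));
				local_addr.sin_family = AF_INET;
				local_addr.sin_addr.s_addr = htonl(INADDR_ANY);
				local_addr.sin_port = htons((unsigned short)local_port);
				bind_local_result = __bind_local(sockfd, (struct sockaddr*)&local_addr, sizeof(local_addr));
			}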
			if (bind_local_result >= 0)
			{
				poller_data data;
				auto* new_ctx = new ConnectContextEx(nullptr);
				data.operation = PD_OP_CONNECT;
				data.handle = (HANDLE)sockfd;
				data.context = new_ctx;
				new_ctx->addr_in.sin_family = AF_INET;
				inet_pton(AF_INET, target_ip, &new_ctx->addr_in.sin_addr);
				new_ctx->addr_in.sin_port = htons(target_port);
				int err = poller->put_io(&data, timeout);
				if (err >= 0)
					return sockfd;
				else
				{
					printf(" put async connect event failed: error=%d.\n", errno);
				}
				delete new_ctx;
				poller->unbind_socket(sockfd);
			}
			else
			{
				poller->unbind_socket(sockfd);
			}
		}
		closesocket(sockfd);
	}
	return -1;
}

void handle_udp_accept_result(WinPoller* poller, struct poller_result* res)
{
	char client_ip_str[30] = {};
	UdpAcceptCtx* ctx = (UdpAcceptCtx*)res->data.context;
	SOCKET listen_fd = (SOCKET)res->data.handle;
	//SOCKET sock = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, NULL, 0, WSA_FLAG_OVERLAPPED);
	SOCKET sock = __create_stream_socket(AF_INET, SOCK_DGRAM);
	sockaddr_in address;
	address.sin_family = AF_INET;
	address.sin_addr.s_addr = htonl(INADDR_ANY);
	address.sin_port = 0;//htons(UDP_SERVER_PORT);
	//int reuse = 1;
	//setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (const char*)&reuse, sizeof(int));
	int ret = -1;
	if (sock != INVALID_SOCKET)
		ret = __bind_local(sock, (SOCKADDR*)&address, sizeof(SOCKADDR));
	if (ret >= 0)
		ret = connect(sock, (struct sockaddr*)&ctx->remoteAddr, ctx->remoteAddrLen);
	if (ret >= 0 && poller->bind((HANDLE)sock) >= 0)
	{
		poller_data data;
		data.handle = (HANDLE)sock;
		data.operation = PD_OP_READ;
		auto* new_ctx = new ReadContextEx(1024);
		new_ctx->read_timeout = 5000;
		data.context = new_ctx;
		if (poller->put_io(&data, new_ctx->read_timeout) < 0)// if the client doesn't reply within 5 seconds, consider it a dead connection.
		{
			delete new_ctx;
			poller->unbind_socket(sock);
			closesocket(sock);
		}
		else
		{
			// The datagram is not NUL-terminated and can be larger than a small fixed
			// buffer, so append exactly res->iobytes bytes to a std::string.
			char len_str[16] = {};
			sprintf_s(len_str, sizeof(len_str), "%lu", (unsigned long)res->iobytes);
			std::string greet_msg = "hello there,  get your first msg: ";
			greet_msg.append(ctx->wsaBuf.buf, res->iobytes);
			greet_msg += ", length=";
			greet_msg += len_str;
			greet_msg += ".\n";
			__sync_send(sock, greet_msg.data(), greet_msg.size());
			inet_ntop(AF_INET, &ctx->remoteAddr.sin_addr, client_ip_str, 30);
			int port = ntohs(ctx->remoteAddr.sin_port);
			printf(" udp client reached: %s, %d.\n", client_ip_str, port);
		}
	}
	else if (sock != INVALID_SOCKET)
	{
		closesocket(sock);// local bind, connect, or IOCP association failed
	}
	if (listen_fd != INVALID_SOCKET)
	{
		if (poller->put_io(&res->data, -1) >= 0)
			return;//reuse context
	}
	if (listen_fd != INVALID_SOCKET)
		poller->unbind_socket(listen_fd);
	delete ctx;
    return;
}

int main()
{
	WSADATA wsaData;
	int port = 9218;
	const char* ip_str = "127.0.0.1";
	int ret = WSAStartup(MAKEWORD(2, 2), &wsaData);
	if (ret != 0)
		return ret;
	struct sockaddr_in bind_addr;
	memset(&bind_addr, 0, sizeof(bind_addr));
	bind_addr.sin_family = AF_INET;
	inet_pton(AF_INET, ip_str, &bind_addr.sin_addr);
	bind_addr.sin_port = htons(port);
	int type = SOCK_DGRAM;// SOCK_DGRAM runs the UDP server; use SOCK_STREAM for the TCP server
	SOCKET listen_fd = __create_stream_socket(AF_INET,type);
	if (listen_fd == INVALID_SOCKET)// SOCKET is unsigned, so "< 0" can never be true
	{
		WSACleanup();
		return -1;
	}
	WinPoller* poller = new WinPoller(1);
	bool server_succ = false, server_bind_ok = false;
	if (poller->bind((HANDLE)listen_fd) >= 0)
	{
		server_bind_ok = true;
		if (type == SOCK_STREAM && __bind_and_listen(listen_fd, (struct sockaddr*)(&bind_addr), sizeof(bind_addr)) >= 0)
		{
			poller_data data;
			auto* new_ctx = new AcceptConext(nullptr);

			data.operation = PD_OP_ACCEPT;
			data.handle = (HANDLE)listen_fd;
			data.context = new_ctx;
			new_ctx->accept_sockfd = __create_stream_socket(AF_INET,SOCK_STREAM);
			if (new_ctx->accept_sockfd == INVALID_SOCKET)
			{
				delete new_ctx;
			}
			else
			{
				if (poller->put_io(&data, -1) < 0)
				{
					closesocket(new_ctx->accept_sockfd);
					delete new_ctx;
				}
				else
				{
					server_succ = true;
				}
			}
		}
		if (type == SOCK_DGRAM && __bind_local(listen_fd, (struct sockaddr*)(&bind_addr), sizeof(bind_addr)) >= 0)
		{
			poller_data data;
			data.handle= (HANDLE)listen_fd;
			data.operation = PD_OP_ACCEPT + 100;
			UdpAcceptCtx* new_ctx =  new UdpAcceptCtx(1024);
			data.context = new_ctx;
			if (poller->put_io(&data, -1) < 0)
			{
				delete new_ctx;
			}
			else
			{
				server_succ = true;
			}
		}
	}
	if (server_succ)
	{
        if(type == SOCK_STREAM)
		    printf(" TCP server launched success: address = %s, port = %d.\n", ip_str, port);
        if(type == SOCK_DGRAM)
            printf(" UDP server launched success: address = %s, port = %d.\n", ip_str, port);
	}
	else
	{
		if (server_bind_ok)
			poller->unbind_socket(listen_fd);
		else
			closesocket(listen_fd);
		delete poller;
		WSACleanup();
		return -1;
	}
	std::cout << "Hello World!\n";
	poller->start();//start timer, otherwise timeout mechanism will make no sense

	//lanuch_async_connect(poller,"127.0.0.1",8277);// connect to another server if you like.

	poller_result res;
	while (1)
	{
		int ret = poller->get_io_result(&res, -1);
		if (ret < 0)// poller->stop() has been called somewhere, maybe in another thread
		{
			break;
		}
		else if (ret > 0)
		{
			//printf("%lld %d\n", res.data.handle, res.data.operation);
			switch (res.data.operation & 0xFF)
			{
			case PD_OP_READ:
			{
				handle_read_result(poller, &res);
			}
			break;
			case PD_OP_WRITE:
				//handle_write_result(&res);
				break;
			case PD_OP_CONNECT:
			{
				handle_connect_result(poller, &res);
			}
			break;
			case PD_OP_ACCEPT:
			{
				handle_accept_result(poller, &res);
			}
			break;
			case PD_OP_SLEEP:
				handle_sleep_result(poller, &res);
				break;
			case PD_OP_USER:
				//handle_event_result(&res);
				break;
			case 100+ PD_OP_ACCEPT:
			   { 
					handle_udp_accept_result(poller,&res);
			   }
				break;
			default:
				assert(0);
				break;
			}
		}
	}

	poller->unbind_socket(listen_fd);
	closesocket(listen_fd);
	delete poller;
	WSACleanup();// must come last: Winsock calls made after it fail
	return 0;
}

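WinPoller can be used directly as a core component for implementing all kinds of network service frameworks, and it can also be used on the client side to wrap various protocols for interacting with other service frameworks.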

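To support a concurrent UDP server, WinPoller was extended with an __accept_udp_io method that handles UDP connection requests. (The sample above also includes a handler for the result of the asynchronous UDP accept event, which makes a concurrent UDP server straightforward to implement.)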
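
The UDP extension reaches put_io through an extra operation code, PD_OP_ACCEPT + 100, and both put_io and the sample's dispatch loop switch on `operation & 0xFF`. The sketch below shows why that value is a workable choice: it fits in the low byte and does not collide with a built-in code. The numeric values are copied from WinPoller.h, while the `OP_*` names, `dispatch_key`, and `is_safe_extension` are hypothetical helpers introduced only for this illustration.

```cpp
#include <cstdint>

// Values copied from the PD_OP_* definitions in WinPoller.h; names are illustrative.
constexpr uint16_t OP_READ = 1;
constexpr uint16_t OP_WRITE = 2;
constexpr uint16_t OP_ACCEPT = 3;
constexpr uint16_t OP_CONNECT = 4;
constexpr uint16_t OP_SLEEP = 5;
constexpr uint16_t OP_USER = 16;
constexpr uint16_t OP_UDP_ACCEPT = OP_ACCEPT + 100; // the extension code, 103

// The switch statements look only at the low byte of the operation code.
constexpr int dispatch_key(uint16_t operation)
{
	return operation & 0xFF;
}

// True when `op` survives the 0xFF mask unchanged and differs from every
// built-in operation code, so it reaches its own case label.
constexpr bool is_safe_extension(uint16_t op)
{
	return dispatch_key(op) == op &&
		   op != OP_READ && op != OP_WRITE && op != OP_ACCEPT &&
		   op != OP_CONNECT && op != OP_SLEEP && op != OP_USER;
}

static_assert(is_safe_extension(OP_UDP_ACCEPT),
			  "the UDP accept code must fit in the low byte without collisions");
```

By the same check, an extension code such as PD_OP_ACCEPT + 256 would be a poor choice: after masking it becomes 3 and would be handled as an ordinary TCP accept.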

/*
  Copyright (c) 2019 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.

  Authors: Wu Jiaxu (wujiaxu@sogou-inc.com)
*/

#ifndef _WPOLLER_H_
#define _WPOLLER_H_

#include <stdint.h>
#include <thread>
#include <mutex>
#include <Winsock2.h>
#include <Ws2tcpip.h>
#include <Ws2def.h>

#define ACCEPT_ADDR_SIZE	(sizeof (struct sockaddr_storage) + 16)

struct poller_data
{
	HANDLE handle;
	void *context;
#define PD_OP_READ			1
#define PD_OP_WRITE			2
#define PD_OP_ACCEPT		3
#define PD_OP_CONNECT		4
#define PD_OP_SLEEP			5
#define PD_OP_USER			16
	uint16_t operation;
};

struct poller_result
{
#define PR_ST_SUCCESS		0
#define PR_ST_FINISHED		1
#define PR_ST_ERROR			2
#define PR_ST_STOPPED		5
#define PR_ST_TIMEOUT		6
	int state;
	int error;
	DWORD iobytes;
	struct poller_data data;
};

class AcceptConext
{
public:
	void *service;
	SOCKET accept_sockfd;

	char *buf;
	struct sockaddr *remote;
	int remote_len;

	AcceptConext(void *sc)
	{
		service = sc;
		accept_sockfd = INVALID_SOCKET;
		remote = NULL;
		remote_len = 0;
		buf = new char[ACCEPT_ADDR_SIZE * 2];
	}

	~AcceptConext()
	{
		delete []buf;
	}
};

class UdpAcceptCtx {
public:
	SOCKADDR_IN remoteAddr;
	int remoteAddrLen;
	char* greet_buff;
	size_t buff_size;
	WSABUF wsaBuf;
	UdpAcceptCtx(size_t size)
	{
		if (size > 1024)
			buff_size = 1024;
		else
			buff_size = size;
		greet_buff = new char[buff_size];
		memset(greet_buff,0x00, buff_size);
		remoteAddrLen = sizeof(SOCKADDR_IN);
		memset(&remoteAddr, 0x00, remoteAddrLen);
		wsaBuf.buf = greet_buff;
		wsaBuf.len = buff_size;
	}
	~UdpAcceptCtx()
	{
		delete[] greet_buff;
	}
};

class ConnectContext
{
public:
	void *entry;
	struct sockaddr *addr;
	socklen_t addrlen;

	ConnectContext(void *e, struct sockaddr *a, socklen_t l)
	{
		entry = e;
		addr = a;
		addrlen = l;
	}
};

class ReadContext
{
public:
	void *entry;
	DWORD msgsize;
	WSABUF buffer;

	ReadContext(void *e)
	{
		entry = e;
		msgsize = 0;
	}
};

class WriteContext
{
public:
	char *buf;
	void *entry;
	WSABUF *buffers;
	DWORD count;

	WriteContext(void *e)
	{
		buf = NULL;
		entry = e;
	}

	~WriteContext()
	{
		delete []buf;
	}
};

class WinPoller
{
public:
	WinPoller(size_t poller_threads);
	~WinPoller();

	int start();
	void stop();

	int bind(HANDLE handle);
	void unbind_socket(SOCKET sockfd) const;

	int transfer(const struct poller_data *data, DWORD iobytes);
	int put_io(const struct poller_data *data, int timeout);
	int get_io_result(struct poller_result *res, int timeout);
	int cancel_pending_io(HANDLE handle) const;

	void timer_routine();

private:
	void *timer_queue_;
	std::mutex timer_mutex_;
	std::thread *timer_thread_;
	HANDLE timer_handle_;
	HANDLE iocp_;
	SOCKET lpfn_sockfd_;
	void *lpfn_connectex_;
	//void *lpfn_disconnectex_;
	volatile bool stop_;
};

#endif

/*
  Copyright (c) 2019 Sogou, Inc.

  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.

  Authors: Wu Jiaxu (wujiaxu@sogou-inc.com)
*/

#include <Winsock2.h>
#include <Ioapiset.h>
#include <Mswsock.h>
#include <Synchapi.h>
#include <stdint.h>
#include <string.h>
#include <atomic>
#include <chrono>
#include <set>
#include "WinPoller.h"

#define GET_CURRENT_MS	std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now().time_since_epoch()).count()

#define IOCP_KEY_HANDLE		1
#define IOCP_KEY_STOP		2

static OVERLAPPED __stop_overlap;

class IOCPData
{
public:
	poller_data data;
	OVERLAPPED overlap;
	int64_t deadline;
	bool cancel_by_timer;
	bool in_rbtree;
	bool queue_out;

	IOCPData(const struct poller_data *d, int t)
	{
		data = *d;
		memset(&overlap, 0, sizeof (OVERLAPPED));
		deadline = t;
		cancel_by_timer = false;
		in_rbtree = false;
		queue_out = false;
		ref = 1;
	}

	void incref()
	{
		ref++;
	}

	void decref()
	{
		if (--ref == 0)
			delete this;
	}

private:
	~IOCPData() { }

	std::atomic<int> ref;
};

static inline bool operator<(const IOCPData& x, const IOCPData& y)
{
	if (x.deadline != y.deadline)
		return x.deadline < y.deadline;

	return (const ULONG_PTR)(&x.overlap) < (const ULONG_PTR)(&y.overlap);
}

class CMP
{
public:
	bool operator() (IOCPData *x, IOCPData *y) const
	{
		return *x < *y;
	}
};

WinPoller::WinPoller(size_t poller_threads)
{
	timer_queue_ = new std::set<IOCPData *, CMP>();
	timer_thread_ = NULL;
	stop_ = false;
	timer_handle_ = CreateWaitableTimer(NULL, FALSE, NULL);
	iocp_ = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, (DWORD)poller_threads);

	GUID GuidConnectEx = WSAID_CONNECTEX;
	//GUID GuidDisconnectEx = WSAID_DISCONNECTEX;
	DWORD dwBytes;

	lpfn_sockfd_ = socket(AF_INET, SOCK_STREAM, 0);
	if (WSAIoctl(lpfn_sockfd_, SIO_GET_EXTENSION_FUNCTION_POINTER,
				&GuidConnectEx, sizeof(GuidConnectEx),
				&lpfn_connectex_, sizeof(lpfn_connectex_),
				&dwBytes, NULL, NULL) == SOCKET_ERROR)
		lpfn_connectex_ = NULL;
/*
	if (WSAIoctl(lpfn_sockfd_, SIO_GET_EXTENSION_FUNCTION_POINTER,
				&GuidDisconnectEx, sizeof(GuidDisconnectEx),
				&lpfn_disconnectex_, sizeof(lpfn_disconnectex_),
				&dwBytes, NULL, NULL) == SOCKET_ERROR)
		lpfn_disconnectex_ = NULL;*/

	if (!timer_handle_ || !iocp_ || !lpfn_connectex_)
		abort();
}

WinPoller::~WinPoller()
{
	closesocket(lpfn_sockfd_);
	CloseHandle(iocp_);
	CloseHandle(timer_handle_);
	delete (std::set<IOCPData *, CMP> *)timer_queue_;
}

int WinPoller::start()
{
	timer_thread_ = new std::thread(&WinPoller::timer_routine, this);
	stop_ = false;
	return 0;
}

void WinPoller::stop()
{
	LARGE_INTEGER due;

	due.QuadPart = -1;
	stop_ = true;
	SetWaitableTimer(timer_handle_, &due, 0, NULL, NULL, FALSE);// fire almost immediately (-1 is a relative due time of one 100 ns tick) so the timer thread stops waiting

	if (timer_thread_)
	{
		timer_thread_->join();
		delete timer_thread_;
		timer_thread_ = NULL;
	}

	PostQueuedCompletionStatus(iocp_, sizeof (OVERLAPPED),
							   IOCP_KEY_STOP, &__stop_overlap);
}

void WinPoller::timer_routine()
{
	auto *timer_queue = (std::set<IOCPData *, CMP> *)timer_queue_;

	while (!stop_)
	{
		if (WaitForSingleObject(timer_handle_, INFINITE) == WAIT_OBJECT_0)
		{
			std::lock_guard<std::mutex> lock(timer_mutex_);

			if (timer_queue->empty())
				continue;

			int64_t cur_ms = GET_CURRENT_MS;

			while (!timer_queue->empty())
			{
				const auto it = timer_queue->cbegin();
				IOCPData *iocp_data = *it;

				if (cur_ms < iocp_data->deadline)
				{
					LARGE_INTEGER due;

					due.QuadPart = iocp_data->deadline - cur_ms;
					due.QuadPart *= -10000;
					SetWaitableTimer(timer_handle_, &due, 0, NULL, NULL, FALSE);
					break;
				}

				iocp_data->in_rbtree = false;
				iocp_data->cancel_by_timer = true;
				if (iocp_data->data.operation == PD_OP_SLEEP)
					PostQueuedCompletionStatus(iocp_, sizeof IOCPData, IOCP_KEY_HANDLE, &iocp_data->overlap);
				else if (CancelIoEx(iocp_data->data.handle, &iocp_data->overlap) == 0 && GetLastError() == ERROR_NOT_FOUND)
					iocp_data->cancel_by_timer = false;

				timer_queue->erase(it);
				iocp_data->decref();
			}
		}
	}

	std::lock_guard<std::mutex> lock(timer_mutex_);

	while (!timer_queue->empty())
	{
		const auto it = timer_queue->cbegin();
		IOCPData *iocp_data = *it;
		iocp_data->in_rbtree = false;
		if (iocp_data->data.operation == PD_OP_SLEEP)
			PostQueuedCompletionStatus(iocp_, sizeof IOCPData, IOCP_KEY_HANDLE, &iocp_data->overlap);
		else
			CancelIoEx(iocp_data->data.handle, &iocp_data->overlap);
		
		timer_queue->erase(it);
		iocp_data->decref();
	}
}

int WinPoller::bind(HANDLE handle)
{
	if (CreateIoCompletionPort(handle, iocp_, IOCP_KEY_HANDLE, 0) != NULL)
		return 0;

	errno = GetLastError();
	return -1;
}

void WinPoller::unbind_socket(SOCKET sockfd) const
{
	CancelIoEx((HANDLE)sockfd, NULL);
	shutdown(sockfd, SD_BOTH);
}

int WinPoller::cancel_pending_io(HANDLE handle) const
{
	if (CancelIoEx(handle, NULL) != 0)
		return 0;

	errno = GetLastError();
	return -1;
}

static int __accept_io(IOCPData *iocp_data, int timeout)
{
	AcceptConext *ctx = (AcceptConext *)iocp_data->data.context;
	DWORD dwBytes;
	BOOL ret = AcceptEx((SOCKET)iocp_data->data.handle, ctx->accept_sockfd,
						ctx->buf, 0, ACCEPT_ADDR_SIZE, ACCEPT_ADDR_SIZE,
						&dwBytes, &iocp_data->overlap);
	if (ret == TRUE || WSAGetLastError() == ERROR_IO_PENDING)
	{
		if (ret != TRUE && timeout == 0)
			CancelIoEx(iocp_data->data.handle, &iocp_data->overlap);

		return 0;
	}
	else
		errno = WSAGetLastError();

	return -1;
}

static int __accept_udp_io(IOCPData* iocp_data, int timeout)
{
	UdpAcceptCtx* ctx = (UdpAcceptCtx*)iocp_data->data.context;
	DWORD dwFlag = 0, dwRecv = 0;
	// WSARecvFrom returns 0 on immediate success and SOCKET_ERROR otherwise;
	// an overlapped operation still in progress is reported as WSA_IO_PENDING.
	int ret = WSARecvFrom((SOCKET)iocp_data->data.handle, &(ctx->wsaBuf), 1, &dwRecv, &dwFlag, (struct sockaddr*)&(ctx->remoteAddr), &ctx->remoteAddrLen, &iocp_data->overlap, NULL);
	int err = (ret == 0) ? 0 : WSAGetLastError();
	if (ret == 0 || err == WSA_IO_PENDING)
	{
		if (ret != 0 && timeout == 0)
			CancelIoEx(iocp_data->data.handle, &iocp_data->overlap);
		return 0;
	}
	else
		errno = WSAGetLastError();

	return -1;
}

static int __connect_io(IOCPData *iocp_data, int timeout, void *lpfn)
{
	ConnectContext *ctx = (ConnectContext *)iocp_data->data.context;
	LPFN_CONNECTEX lpfn_connectex = (LPFN_CONNECTEX)lpfn;
	BOOL ret = lpfn_connectex((SOCKET)iocp_data->data.handle,
							  ctx->addr, ctx->addrlen, NULL, 0, NULL,
							  &iocp_data->overlap);

	if (ret == TRUE || WSAGetLastError() == ERROR_IO_PENDING)
	{
		if (ret != TRUE && timeout == 0)
			CancelIoEx(iocp_data->data.handle, &iocp_data->overlap);

		return 0;
	}

	errno = WSAGetLastError();
	return -1;
}

static int __read_io(IOCPData *iocp_data, int timeout)
{
	ReadContext *ctx = (ReadContext *)iocp_data->data.context;
	DWORD Flags = 0;
	int ret = WSARecv((SOCKET)iocp_data->data.handle, &ctx->buffer, 1, NULL, &Flags, &iocp_data->overlap, NULL);

	if (ret == 0 || WSAGetLastError() == WSA_IO_PENDING)
	{
		if (ret != 0 && timeout == 0)
			CancelIoEx(iocp_data->data.handle, &iocp_data->overlap);

		return 0;
	}

	errno = WSAGetLastError();
	return -1;
}

static int __write_io(IOCPData *iocp_data, int timeout)
{
	WriteContext *ctx = (WriteContext *)iocp_data->data.context;
	int ret = WSASend((SOCKET)iocp_data->data.handle, ctx->buffers, ctx->count, NULL, 0, &iocp_data->overlap, NULL);

	if (ret == 0 || WSAGetLastError() == WSA_IO_PENDING)
	{
		if (ret != 0 && timeout == 0)
			CancelIoEx(iocp_data->data.handle, &iocp_data->overlap);

		return 0;
	}

	errno = WSAGetLastError();
	return -1;
}

static int __sleep_io(IOCPData *iocp_data, int timeout, HANDLE iocp)
{
	if (timeout == 0)
	{
		if (PostQueuedCompletionStatus(iocp, sizeof IOCPData, IOCP_KEY_HANDLE, &iocp_data->overlap) != 0)
			return 0;

		errno = GetLastError();
		return -1;
	}

	return 0;
}

int WinPoller::transfer(const struct poller_data *data, DWORD iobytes)
{
	if (data->operation != PD_OP_USER)
	{
		errno = EINVAL;
		return -1;
	}

	IOCPData *iocp_data = new IOCPData(data, -1);
	if (PostQueuedCompletionStatus(iocp_, iobytes, IOCP_KEY_HANDLE, &iocp_data->overlap) != 0)
		return 0;

	iocp_data->decref();
	errno = GetLastError();
	return -1;
}

int WinPoller::put_io(const struct poller_data *data, int timeout)
{
	auto *timer_queue = (std::set<IOCPData *, CMP> *)timer_queue_;
	IOCPData *iocp_data = new IOCPData(data, timeout);
	bool succ;

	iocp_data->incref();//for timeout
	switch (data->operation & 0xFF)
	{
	case PD_OP_READ:
		succ = (__read_io(iocp_data, timeout) >= 0);

		break;
	case PD_OP_WRITE:
		succ = (__write_io(iocp_data, timeout) >= 0);

		break;
	case PD_OP_ACCEPT:
		succ = (__accept_io(iocp_data, timeout) >= 0);

		break;
	case PD_OP_ACCEPT+100:
		succ = (__accept_udp_io(iocp_data, timeout) >= 0);
		break;
	case PD_OP_CONNECT:
		succ = (__connect_io(iocp_data, timeout, lpfn_connectex_) >= 0);

		break;
	case PD_OP_SLEEP:
		succ = (__sleep_io(iocp_data, timeout, iocp_) >= 0);

		break;
	default:
		succ = false;
		errno = EINVAL;
		break;
	}

	if (timeout <= 0)
		iocp_data->decref();

	if (!succ)
	{
		iocp_data->decref();
		return -1;
	}

	if (timeout > 0)
	{
		iocp_data->deadline += GET_CURRENT_MS;
		timer_mutex_.lock();
		if (!iocp_data->queue_out)
		{
			iocp_data->in_rbtree = true;
			timer_queue->insert(iocp_data);
			if (*timer_queue->cbegin() == iocp_data)
			{
				LARGE_INTEGER due;

				due.QuadPart = timeout;
				due.QuadPart *= -10000;
				SetWaitableTimer(timer_handle_, &due, 0, NULL, NULL, FALSE);
			}
		}

		timer_mutex_.unlock();
	}

	return 0;
}

static void __accept_on_success(struct poller_result *res)
{
	SOCKET listen_sockfd = (SOCKET)res->data.handle;
	AcceptConext *ctx = (AcceptConext *)res->data.context;
	struct sockaddr *local;
	struct sockaddr *remote;
	int local_len = sizeof (struct sockaddr);
	int remote_len = sizeof (struct sockaddr);
	int seconds;
	int seconds_len = sizeof (int);

	if (getsockopt(ctx->accept_sockfd, SOL_SOCKET, SO_CONNECT_TIME, (char *)&seconds, &seconds_len) == 0)
	{
		if (setsockopt(ctx->accept_sockfd, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, (char*)&listen_sockfd, sizeof (listen_sockfd)) == 0)
		{
			GetAcceptExSockaddrs(ctx->buf, 0, ACCEPT_ADDR_SIZE, ACCEPT_ADDR_SIZE, &local, &local_len, &remote, &remote_len);
			ctx->remote = remote;
			ctx->remote_len = remote_len;
			return;
		}
	}

	res->state = PR_ST_ERROR;
	res->error = WSAGetLastError();
}

static void __connect_on_success(struct poller_result *res)
{
	SOCKET sockfd = (SOCKET)res->data.handle;
	ConnectContext *ctx = (ConnectContext *)res->data.context;
	int seconds;
	int seconds_len = sizeof (int);

	if (getsockopt(sockfd, SOL_SOCKET, SO_CONNECT_TIME, (char *)&seconds, &seconds_len) == 0)
	{
		//if (seconds == 0xFFFFFFFF) error?
		if (setsockopt(sockfd, SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0) == 0)
			return;
	}

	res->state = PR_ST_ERROR;
	res->error = WSAGetLastError();
}

int WinPoller::get_io_result(struct poller_result *res, int timeout)
{
	DWORD bytes_transferred;
	ULONG_PTR completion_key;
	OVERLAPPED* pOverlapped;
	DWORD dwMilliseconds;

	if (stop_)
		dwMilliseconds = 100;
	else if (timeout >= 0)
		dwMilliseconds = timeout;
	else
		dwMilliseconds = INFINITE;

	if (GetQueuedCompletionStatus(iocp_, &bytes_transferred, &completion_key,
								  &pOverlapped, dwMilliseconds) == FALSE)
	{
		res->state = PR_ST_ERROR;
		res->error = GetLastError();
		if (pOverlapped == NULL && res->error == ERROR_ABANDONED_WAIT_0)
			return -1;// IOCP closed

		if (res->error == ERROR_OPERATION_ABORTED)
			res->state = PR_ST_STOPPED;
	}
	else if (pOverlapped == NULL)
	{
		// An unrecoverable error occurred in the completion port.
		// Wait for the next notification
		res->state = PR_ST_ERROR;
		res->error = ENOENT;
	}
	else if (bytes_transferred == 0)
	{
		res->state = PR_ST_FINISHED;
		res->error = ECONNRESET;
	}
	else
	{
		res->state = PR_ST_SUCCESS;
		res->error = 0;
	}

	if (!pOverlapped)
		return 0;

	res->iobytes = bytes_transferred;
	if (completion_key == IOCP_KEY_STOP)
	{
		PostQueuedCompletionStatus(iocp_, sizeof (OVERLAPPED),
								   IOCP_KEY_STOP, &__stop_overlap);

		//return 0;
		return -1;// Thread over
	}

	IOCPData *iocp_data = CONTAINING_RECORD(pOverlapped, IOCPData, overlap);

	if (iocp_data->deadline > 0)// timeout > 0
	{
		timer_mutex_.lock();
		iocp_data->queue_out = true;
		if (iocp_data->in_rbtree)
		{
			iocp_data->in_rbtree = false;
			((std::set<IOCPData *, CMP> *)timer_queue_)->erase(iocp_data);
			iocp_data->decref();
		}

		timer_mutex_.unlock();

		if (res->state == PR_ST_STOPPED)
		{
			std::lock_guard<std::mutex> lock(timer_mutex_);

			if (iocp_data->cancel_by_timer)
			{
				res->state = PR_ST_TIMEOUT;
				res->error = ETIMEDOUT;
			}
		}
	}
	else if (iocp_data->deadline == 0 && res->state == PR_ST_STOPPED)// timeout == 0
	{
		res->state = PR_ST_TIMEOUT;
		res->error = ETIMEDOUT;
	}

	res->data = iocp_data->data;
	if (res->state == PR_ST_SUCCESS || res->state == PR_ST_FINISHED)
	{
		switch (res->data.operation)
		{
		case PD_OP_ACCEPT:
			__accept_on_success(res);

			break;
		case PD_OP_CONNECT:
			__connect_on_success(res);

			break;
		}
	}

	iocp_data->decref();

	return 1;
}
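
The timeout handling in put_io, timer_routine, and get_io_result above relies on a set ordered by deadline: the earliest deadline is always at the front, so the timer thread only needs to look at the first element and sleep until it is due. The portable sketch below isolates that idea. It is a simplified illustration rather than the WinPoller code: `TimerEntry`, `EntryLess`, `TimerQueue`, `pop_expired`, and `next_wait_ms` are hypothetical names, and no locking, waitable timer, or I/O cancellation is modelled.

```cpp
#include <cstdint>
#include <functional>
#include <set>
#include <vector>

// Hypothetical stand-in for one pending operation that has a timeout.
struct TimerEntry
{
	int64_t deadline_ms; // absolute deadline in milliseconds
	int id;              // illustrative identifier of the operation
};

// Order by deadline first; break ties by address so that two entries
// with the same deadline can both live in the set.
struct EntryLess
{
	bool operator()(const TimerEntry *x, const TimerEntry *y) const
	{
		if (x->deadline_ms != y->deadline_ms)
			return x->deadline_ms < y->deadline_ms;
		return std::less<const TimerEntry *>()(x, y);
	}
};

class TimerQueue
{
public:
	void add(TimerEntry *e) { entries_.insert(e); }
	void remove(TimerEntry *e) { entries_.erase(e); }

	// Removes every entry whose deadline is <= now_ms and returns their
	// ids, earliest deadline first.
	std::vector<int> pop_expired(int64_t now_ms)
	{
		std::vector<int> expired;

		while (!entries_.empty())
		{
			auto it = entries_.begin();
			if (now_ms < (*it)->deadline_ms)
				break; // the earliest entry is not due, so nothing later is
			expired.push_back((*it)->id);
			entries_.erase(it);
		}

		return expired;
	}

	// Milliseconds until the earliest deadline, or -1 when the queue is empty.
	int64_t next_wait_ms(int64_t now_ms) const
	{
		if (entries_.empty())
			return -1;
		int64_t d = (*entries_.begin())->deadline_ms - now_ms;
		return d > 0 ? d : 0;
	}

private:
	std::set<TimerEntry *, EntryLess> entries_;
};
```

In terms of the code above, `pop_expired` corresponds to the inner loop of timer_routine, and `next_wait_ms` to the value used to re-arm the waitable timer. An operation that completes before its deadline is taken out with `remove`, as get_io_result does when `in_rbtree` is set.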
