mpioで作るメモリキャッシュサーバー

高速なイベント駆動IOライブラリ mpioがそこそこ動くようになったので、ちょっとサンプルプログラムを作ってみました。

memcachedプロトコルと互換性のあるメモリキャッシュサーバーだったりするとカッコイイと思ったのですが、即席なので全然関係ないメモリキャッシュサーバーになりました><

バイナリベースのプロトコルで、キーと値をセットするsetと、キーに対応する値を取ってくるgetと、キーを削除するdeleteができます。Linux(IO多重化はepoll)とMac OS X(select)で動くことを確認しました。


ソースコードiosample.tar.gz


…ちなみに、このプログラムにはバグがあります。構造体をそのままソケットに送りつけているのですが、構造体のアラインメントを考慮していません。ダメダメです。あとクライアントの実装がヘタレすぎるとか…そのあたりキリが無いですがご勘弁を><


mpioは抽象度が低い順にmp::event、mp::dispatch、mp::io、mp::asioがありますが、mp::ioはまだ中途半端、mp::asioに至っては一切手つかずなので、今回はmp::dispatchを使っています。ファイルディスクリプタにイベントが発生したら関数を呼ぶというイベント駆動型のプログラムが書けます。可変長メモリプール(mp::mempool)も使っています。

  • サーバー:server.cc
#include <mp/dispatch.h>
#include <mp/mempool.h>
#include <kazuhiki/basic.h>
#include <kazuhiki/network.h>
#include <unistd.h>
#include <errno.h>
#include <fcntl.h>
#include <deque>
#include <map>
#include <string>
#include <iostream>
#include "iosample.h"


static const size_t MIN_BUFFER_SIZE = 32 * 1024;
class buffer_t {
public:
	buffer_t(size_t needs);
	~buffer_t();
	char* rbuffer()            { return m_pos; }
	char* wbuffer()            { return m_pos + m_remain; }
	size_t available() const  { return m_avail; }
	size_t remain()    const  { return m_remain; }
	void produced(size_t len) { m_remain += len; }
	void consumed(size_t len) { m_remain -= len;  m_pos += len; }
private:
	size_t m_avail;
	size_t m_remain;
	char* m_buf;
	char* m_pos;
private:
	buffer_t();
	buffer_t(const buffer_t&);
};

class client_t {
public:
	client_t(int sock);
	~client_t();
public:
	typedef std::deque<buffer_t*> queue_t;
	queue_t in_queue;
	queue_t out_queue;
	void interpret_request();
	void operate_request(request_header_t& header, char* arg, size_t arg_len);
	void append_out_queue(size_t len, char* buf);
private:
	int m_sock;
	bool in_queue_has(size_t needs);
	template <bool Reference, bool Consume>
	void use_in_queue(size_t len, char* buf = NULL);
private:
	client_t();
	client_t(const client_t&);
};


int io_accept(int fd, short event, client_t& c);
int io_client(int fd, short event, client_t& c);

mp::mempool<> mempool;

typedef mp::dispatch<client_t> mpdispatch;
mpdispatch mpdp;

struct value_t {
	char* ptr;
	size_t len;
};

typedef std::map<std::string, value_t> cache_t;
cache_t cache;


buffer_t::buffer_t(size_t needs) :
		m_avail(std::max(needs, MIN_BUFFER_SIZE)),
		m_remain(0),
		m_buf((char*)mempool.malloc(m_avail)),
		m_pos(m_buf) {
	if(!m_buf) { throw std::bad_alloc(); }
}
buffer_t::~buffer_t() {
	mempool.free(m_buf);
}


client_t::client_t(int sock) : m_sock(sock) {}
client_t::~client_t()
{
	for(queue_t::iterator it(in_queue.begin()), it_end(in_queue.end());
			it != it_end;
			++it ) {
		delete *it;
	}
	for(queue_t::iterator it(out_queue.begin()), it_end(out_queue.end());
			it != it_end;
			++it ) {
		delete *it;
	}
	close(m_sock);
}

void client_t::interpret_request()
{
	while( in_queue_has(sizeof(request_header_t)) ) {
		request_header_t header;
		use_in_queue<true, false>(sizeof(header), (char*)&header);  // reference

		uint32_t arg_len = ntohl(header.arg_len);
		if( !in_queue_has(sizeof(header) + arg_len) ) { return; }

		char* arg = (char*)mempool.malloc(arg_len);
		if( arg == NULL ) { throw std::bad_alloc(); }

		use_in_queue<false, true>(sizeof(header));  // consume
		use_in_queue<true, true>(arg_len, arg);  // reference and consume
		operate_request(header, arg, arg_len);

		mempool.free(arg);
	}
}

void client_t::operate_request(request_header_t& header, char* arg, size_t arg_len)
{
	switch(header.op) {
	case REQ_GET:
	{
		response_header_t<op_get_response> res;
		res.op = REQ_GET;
		std::string key(arg, arg_len);
		cache_t::iterator found( cache.find(key) );
		if( found != cache.end() ) {
			value_t& value( found->second );
			res.arg_len = htonl(sizeof(op_get_response) + value.len);
			res.arg.exists = 1;
			append_out_queue(sizeof(res), (char*)&res);
			append_out_queue(value.len, value.ptr);
		} else {
			res.arg_len = htonl(sizeof(res));
			res.arg.exists = 0;
			append_out_queue(sizeof(res), (char*)&res);
		}
		break; }
	case REQ_SET:
	{
		op_set_request* req = (op_set_request*)arg;
		arg += sizeof(op_set_request);
		size_t key_len = ntohl(req->key_len);
		std::string key(arg, key_len);
		arg += key_len;
		size_t value_len = arg_len - sizeof(op_set_request) - key_len;
		value_t value;
		value.ptr = (char*)mempool.malloc(value_len);
		response_header_t<op_set_response> res;
		res.op = REQ_SET;
		res.arg_len = sizeof(op_set_response);
		if( value.ptr != NULL ) {
			memcpy(value.ptr, arg, value_len);
			value.len = value_len;
			cache_t::iterator found( cache.find(key) );
			if( found != cache.end() ) {
				mempool.free(found->second.ptr);
				cache.erase(found);
			}
			res.arg.success = cache.insert(std::make_pair(key, value)).second;
		} else {
			res.arg.success = 0;
		}
		append_out_queue(sizeof(res), (char*)&res);
		break; }
	case REQ_DELETE:
	{
		response_header_t<op_delete_response> res;
		std::string key(arg, arg_len);
		cache_t::iterator found( cache.find(key) );
		res.op = REQ_DELETE;
		res.arg_len = htonl(sizeof(op_delete_response));
		if( found != cache.end() ) {
			mempool.free(found->second.ptr);
			cache.erase(found);
			res.arg.success = 1;
		} else {
			res.arg.success = 0;
		}
		append_out_queue(sizeof(res), (char*)&res);
		break; }
	}
}

bool client_t::in_queue_has(size_t needs)
{
	size_t len = 0;
	for(queue_t::iterator it(in_queue.begin()), it_end(in_queue.end());
			it != it_end;
			++it ) {
		len += (*it)->remain();
		if( len >= needs ) { return true; }
	}
	return false;
}

template <bool Reference, bool Consume>
void client_t::use_in_queue(size_t len, char* buf)
{
	while(1) {
		buffer_t* buffer = in_queue.front();
		if( buffer->remain() > len ) {
			if( Reference ) { memcpy(buf, buffer->rbuffer(), len); }
			if( Consume   ) { buffer->consumed(len); }
			return;
		} else if( buffer->remain() == len ) {
			if( Reference ) { memcpy(buf, buffer->rbuffer(), len); }
			if( Consume   ) { delete in_queue.front(); in_queue.pop_front(); }
			return;
		} else {
			if( Reference ) { memcpy(buf, buffer->rbuffer(), buffer->remain()); }
			buf += buffer->remain();
			len -= buffer->remain();
			if( Consume   ) { delete in_queue.front(); in_queue.pop_front(); }
		}
	}
}

void client_t::append_out_queue(size_t len, char* buf)
{
	if( !out_queue.empty() &&
			out_queue.front()->available() >= len ) {
		buffer_t* buffer = out_queue.front();
		memcpy(buffer->wbuffer(), buf, len);
		buffer->produced(len);
	} else {
		buffer_t* buffer = new buffer_t(len);
		memcpy(buffer->wbuffer(), buf, len);
		buffer->produced(len);
		out_queue.push_back(buffer);
	}
}


int io_accept(int fd, short event, client_t& c)
{
	int sock = accept(fd, NULL, NULL);
	if( sock < 0 ) {
		if( errno == EAGAIN || errno == EINTR ) { return 0; }
		else { return -1; }
	}

	if( fcntl(sock, F_SETFL, O_NONBLOCK) < 0 ) {
		close(sock);
		perror("set socket nonblock");
		return 0;
	}

	if( mpdp.add(sock, mp::EV_READ, io_client, sock) < 0 ) {
		perror("mp.add sock");
		return 0;
	}

	return 0;
}

int io_client(int fd, short event, client_t& c)
{
	ssize_t len;
	if( event & mp::EV_READ ) {
		if( c.in_queue.empty() ) {
			c.in_queue.push_back(new buffer_t(0));
		}
		buffer_t* buffer = c.in_queue.back();
		len = read(fd, buffer->wbuffer(), buffer->available());
		if( len <= 0 ) {
			if( errno == EAGAIN || errno == EINTR ) { return 0; }
			else { return mpdp.remove(fd, mp::EV_READ); }
		}
		buffer->produced(len);
		if( buffer->available() == 0 ) {
			c.in_queue.push_back(new buffer_t(0));
		}
		bool write_waiting = !c.out_queue.empty();
		c.interpret_request();
		if( !write_waiting && !c.out_queue.empty() ) {
			mpdp.modify(fd, mp::EV_READ, mp::EV_READ|mp::EV_WRITE);
		}
	}
	if( event & mp::EV_WRITE ) {
		buffer_t* buffer = c.out_queue.back();
		len = write(fd, buffer->rbuffer(), buffer->remain());
		if( len <= 0 ) {
			if( errno == EAGAIN || errno == EINTR ) { return 0; }
			else { return mpdp.remove(fd, mp::EV_READ|mp::EV_WRITE); }
		}
		buffer->consumed(len);
		if( buffer->remain() == 0 ) {
			c.out_queue.pop_front(); delete buffer;
		}
		if( c.out_queue.empty() ) {
			mpdp.modify(fd, mp::EV_READ|mp::EV_WRITE, mp::EV_READ);
		}
	}
	return 0;
}


void usage(void)
{
	std::cout
		<< "Usage: cb <listen address>[:<port number>]"
		<< std::endl;
	exit(1);
}

int main(int argc, char* argv[])
{
	struct sockaddr_storage opt_addr;
	bool opt_verbose;
	try {
		using namespace Kazuhiki;
		--argc; ++argv;  // skip argv[0]
		Parser opt;
		opt.on("-v", "--verbose", Accept::Boolean(opt_verbose));
		opt.break_parse(argc, argv);   // parse!
		if( argc < 1 ) { throw ArgumentError(""); }
		Convert::FlexibleListtenAddress(argv[0], opt_addr, DEFAULT_PORT);

	} catch (Kazuhiki::ArgumentError& e) {
		std::cout << e.what() << std::endl;
		usage();
		return 1;
	}

	int lsock = socket(opt_addr.ss_family, SOCK_STREAM, 0);
	if( lsock < 0 ) { pexit("lsock"); }
	int on = 1;
	if( setsockopt(lsock, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0 ) {
		pexit("setsockopt"); }
	socklen_t addr_len = (opt_addr.ss_family == AF_INET) ?
		sizeof(struct sockaddr_in) :
		sizeof(struct sockaddr_in6);
	if( bind(lsock, (struct sockaddr*)&opt_addr, addr_len) < 0 ) {
		pexit("bind"); }
	if( listen(lsock, 5) < 0 ) { pexit("listen"); }

	if( fcntl(lsock, F_SETFL, O_NONBLOCK) < 0 ) { pexit("set socket nonblock"); }

	if( mpdp.add(lsock, mp::EV_READ, io_accept, lsock) < 0 ) { pexit("mp.add lsock"); }

	return mpdp.run();
}


重要なのは最後のmain()関数と、io_accept()関数、io_client()関数です。その他は出来心です。

main()関数を見てみると、コマンドライン引数の解析に怪しげなライブラリが使われていますが、そこはスルーして、最後のmpdp.add(...)がポイントです。ここで「ソケット(lsock変数)が読み込み可能になったら、io_accept関数を呼んでくれ」と登録しています。

登録するとき、引数の最後にlsockを渡していますが、これはclient_tクラスのコンストラクタに渡されます。mpdp.add()のmpdpはmp::dispatchクラスのインスタンス(上の方にグローバル変数で宣言してあります)で、ファイルディスクリプタを1つ登録するごとに、client_tクラスのインスタンスが1つ確保されるようになっています。そのときのコンストラクタに渡される引数をmp.add()で指定できるというわけです。


ここで、client_tクラスはデフォルトコンストラクタもコピーコンストラクタもprivateで封印されている点は要注目です。client_tクラスの寿命はぴったりmpdpにaddされてからremoveされるまでで、コピーは一度も行われません。このため、client_tクラスのデストラクタでクライアントの接続が切れたときの後始末(ファイルディスクリプタをcloseしたり、mallocしたバッファをfreeするなど)ができます。(このあたりの実装の詳細はソースコードの中のmp/sparse_array.hを参照)


lsockが読み込み可能可能になると、io_accept()関数が呼ばれます。ここでクライアントをaccept()して、mpdp.add(sock, mp::EV_READ, io_client, sock)でmp::dispatchに登録しています。これでクライアントからのソケットが読み込み可能になると、io_client()関数が呼ばれます。

io_client関数では、ソケットが読み込み可能になっていたら、要求を読み込んで、キューに追加します。そして実際に要求キューを処理し、応答キューにデータを追加します。応答キューにデータが溜まったら、mpdp.modify()を呼び出して、書き込み可能になった場合でもio_client関数が呼ばれるように変更しています。
書き込み可能になったら、応答キューからデータを取り出して、クライアントに応答を返します。これで一連の処理が終了します。めでたしめでたし。


mpdpやmempoolがグローバル変数だったりと、やや難ありなコードですが、これをクラスにまとめることもできます。そうするとio_accept()関数やio_client()関数はメンバ関数になりますが、mpdp.add()は関数オブジェクトも登録できるので、mp::bind(&FooClass::io_accept, this, _1, _2)でメンバ関数をbindしてやればOKです。
C++0x tr1のfunctionalを先取りしていますが、gccにはもう入っているので使ってしまいます。boostでも代用できます)


別のサンプルでは、CodeRepos:/lang/c/parttyはもう少しマシなコードになっています。こちらはmultiplexer.ccでmp::event、host.ccでmp::dispatch、server.ccでmp::ioを使っています(mp::ioはまだ試行錯誤中ですが)。
(ちなみに、怪しげなコマンドラインパーサ「Kazuhiki」もCodeRepos:/lang/c/parttyに入っています)



↓クライアントはオマケです。へたれすぎます><

  • クライアント:client.cc
#include <mp/dispatch.h>
#include <mp/mempool.h>
#include <kazuhiki/basic.h>
#include <kazuhiki/network.h>
#include <unistd.h>
#include <errno.h>
#include <fcntl.h>
#include <deque>
#include <map>
#include <string>
#include <iostream>
#include "iosample.h"

mp::mempool<> mempool;

bool send_set(int sock, const char* key, size_t key_len,
		const char* value, size_t value_len)
{
	request_header_t req;
	op_set_request req_set;
	req.op = REQ_SET;
	req.arg_len = htonl(sizeof(req_set) + key_len + value_len);
	req_set.key_len = htonl(key_len);
	write(sock, &req, sizeof(req));
	write(sock, &req_set, sizeof(req_set));
	write(sock, key, key_len);
	write(sock, value, value_len);

	response_header_t<op_set_response> res;
	read(sock, &res, sizeof(res));
	return res.arg.success == 1;
}

bool send_get(int sock, const char* key, size_t key_len, char** value, size_t* value_len)
{
	request_header_t req;
	req.op = REQ_GET;
	req.arg_len = htonl(key_len);
	write(sock, &req, sizeof(req));
	write(sock, key, key_len);

	response_header_t<op_get_response> res;
	read(sock, &res, sizeof(res));
	if( res.arg.exists != 1 ) {
		*value = NULL;
		*value_len = 0;
		return false;
	} else {
		*value_len = ntohl(res.arg_len) - sizeof(op_get_response);
		*value = (char*)mempool.malloc(*value_len);
		read(sock, *value, *value_len);
		return true;
	}
}

bool send_delete(int sock, const char* key, size_t key_len)
{
	request_header_t req;
	req.op = REQ_DELETE;
	req.arg_len = htonl(key_len);
	write(sock, &req, sizeof(req));
	write(sock, key, key_len);

	response_header_t<op_delete_response> res;
	read(sock, &res, sizeof(res));
	return res.arg.success == 1;
}


void test_set(int sock, const char* key, const char* value)
{
	std::cout << "setting  key " << key << " = " << value << ": " << std::flush;
	bool res = send_set(sock, key, strlen(key), value, strlen(value));
	std::cout << (res ? "success" : "failed") << std::endl;
}

void test_get(int sock, const char* key)
{
	std::cout << "getting  key " << key << " = ";
	char* value;
	size_t value_len;
	bool res = send_get(sock, key, strlen(key), &value, &value_len);
	if(res) {
		std::cout.write(value, value_len);
		std::cout << std::endl;
		mempool.free(value);
	} else {
		std::cout << "<NOT FOUND>" << std::endl;
	}
}

void test_delete(int sock, const char* key)
{
	std::cout << "deleting key " << key << ": ";
	bool res = send_delete(sock, key, strlen(key));
	std::cout << (res ? "success" : "failed") << std::endl;
}


void usage(void)
{
	std::cout
		<< "Usage: cb <listen address>[:<port number>]"
		<< std::endl;
	exit(1);
}

int main(int argc, char* argv[])
{
	struct sockaddr_storage opt_addr;
	bool opt_verbose;
	try {
		using namespace Kazuhiki;
		--argc; ++argv;  // skip argv[0]
		Parser opt;
		opt.on("-v", "--verbose", Accept::Boolean(opt_verbose));
		opt.break_parse(argc, argv);   // parse!
		if( argc < 1 ) { throw ArgumentError(""); }
		Convert::FlexibleActiveHost(argv[0], opt_addr, DEFAULT_PORT);

	} catch (Kazuhiki::ArgumentError& e) {
		std::cout << e.what() << std::endl;
		usage();
		return 1;
	}

	int sock = socket(opt_addr.ss_family, SOCK_STREAM, 0);
	if( sock < 0 ) { pexit("sock"); }
	socklen_t addr_len = (opt_addr.ss_family == AF_INET) ?
		sizeof(struct sockaddr_in) :
		sizeof(struct sockaddr_in6);
	if( connect(sock, (struct sockaddr*)&opt_addr, addr_len) < 0 ) {
		pexit("connect");
	}


	test_set(sock, "test", "TEST_VALUE");
	test_get(sock, "test");

	test_set(sock, "test", "overwrite");
	test_get(sock, "test");

	test_get(sock, "no_value");

	test_delete(sock, "test");
	test_get(sock, "test");
}
  1. "test" = "TEST_VALUE"をセット
  2. "test"をゲットしてみる
  3. "test" = "overwrite"をセット
  4. "test"をゲットしてみる
  5. "no_value"をゲットしてみる
  6. "test"を削除してみる
  7. "test"をゲットしてみる

という要求を行っています。実際に実行すると、↓こうなります。

$ ./server 2000
$ ./client localhost:2000  # 別の端末で
setting  key test = TEST_VALUE: success
getting  key test = TEST_VALUE
setting  key test = overwrite: success
getting  key test = overwrite
getting  key no_value = <NOT FOUND>
deleting key test: success
getting  key test = <NOT FOUND>

確かにset/get/deleteができているようです。





  • ヘッダ:iosample.h

ふつうのヘッダです。response_header_tがテンプレートになっているあたりが変態仕様ですが、受信したバッファをresponse_header_t<応答の種類>型にキャストして手抜きしてやろうという魂胆です。(しかしそれはアラインメントが…)

#ifndef IOSAMPLE_H__
#define IOSAMPLE_H__

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

static const uint8_t REQ_GET = 0;
static const uint8_t REQ_SET = 1;
static const uint8_t REQ_DELETE = 2;


struct request_header_t {
	uint8_t op;
	uint32_t arg_len;
	// char arg[arg_len];
};


template <typename T>
struct response_header_t {
	uint8_t op;
	uint32_t arg_len;
	T arg;
};

struct op_get_response {
	uint8_t exists;
};

struct op_set_request {
	uint32_t key_len;
};
struct op_set_response {
	uint8_t success;
};

struct op_delete_response {
	uint8_t success;
};


inline void pexit(const char* msg) {
	perror(msg);
	exit(1);
}

unsigned short DEFAULT_PORT = 2000;

#endif /* iosample.h */