serval-dna/msp_proxy.c
2014-05-15 16:38:22 +09:30

618 lines
16 KiB
C

/*
Mesh Streaming Protocol (MSP)
Copyright (C) 2013-2014 Serval Project Inc.
This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
#include "serval.h"
#include "str.h"
#include "strbuf.h"
#include "strbuf_helpers.h"
#include "dataformats.h"
#include "mdp_client.h"
#include "msp_client.h"
#include "socket.h"
struct buffer{
size_t position;
size_t limit;
size_t capacity;
uint8_t bytes[];
};
struct connection{
struct sched_ent alarm_in;
struct sched_ent alarm_out;
MSP_SOCKET sock;
struct buffer *in;
struct buffer *out;
char eof;
int last_state;
};
int saw_error=0;
int once =0;
MSP_SOCKET listener = MSP_SOCKET_NULL;
struct mdp_sockaddr remote_addr;
struct socket_address ip_addr;
static int try_send(struct connection *conn);
static void msp_poll(struct sched_ent *alarm);
static void service_poll(struct sched_ent *alarm);
static void listen_poll(struct sched_ent *alarm);
static void io_poll(struct sched_ent *alarm);
static MSP_HANDLER msp_handler;
struct profile_total mdp_sock_stats={
.name="msp_poll"
};
struct sched_ent mdp_sock={
.poll.revents = 0,
.poll.events = POLLIN,
.poll.fd = -1,
.function = msp_poll,
.stats = &mdp_sock_stats,
};
struct profile_total service_sock_stats={
.name="service_poll"
};
struct sched_ent service_sock={
.poll.revents = 0,
.poll.events = POLLIN,
.poll.fd = -1,
.function = service_poll,
.stats = &service_sock_stats,
};
struct profile_total io_stats={
.name="io_stats"
};
struct profile_total listen_stats={
.name="listen_poll"
};
struct sched_ent listen_alarm={
.poll.revents = 0,
.poll.events = POLLIN,
.poll.fd = -1,
.function = listen_poll,
.stats = &listen_stats,
};
const char *service_name=NULL;
mdp_port_t service_port;
static struct connection *alloc_connection(
MSP_SOCKET sock,
int fd_in,
void (*func_in)(struct sched_ent *alarm),
int fd_out,
void (*func_out)(struct sched_ent *alarm))
{
struct connection *conn = emalloc_zero(sizeof(struct connection));
if (!conn)
return NULL;
conn->sock = sock;
conn->alarm_in.poll.fd = fd_in;
conn->alarm_in.poll.events = POLLIN;
conn->alarm_in.function = func_in;
conn->alarm_in.stats = &io_stats;
conn->alarm_in.context = conn;
conn->alarm_out.poll.fd = fd_out;
conn->alarm_out.poll.events = POLLOUT;
conn->alarm_out.function = func_out;
conn->alarm_out.stats = &io_stats;
conn->alarm_out.context = conn;
watch(&conn->alarm_in);
conn->in = emalloc(1024 + sizeof(struct buffer));
if (!conn->in){
free(conn);
return NULL;
}
conn->out = emalloc(1024 + sizeof(struct buffer));
if (!conn->out){
free(conn->in);
free(conn);
return NULL;
}
conn->in->position = conn->out->position = 0;
conn->in->limit = conn->out->limit = 0;
conn->in->capacity = conn->out->capacity = 1024;
return conn;
}
static void free_connection(struct connection *conn)
{
if (!conn)
return;
if (!msp_socket_is_closed(conn->sock)){
msp_set_handler(conn->sock, msp_handler, NULL);
msp_close(conn->sock);
}
if (is_watching(&conn->alarm_in))
unwatch(&conn->alarm_in);
if (is_watching(&conn->alarm_out))
unwatch(&conn->alarm_out);
if (conn->in)
free(conn->in);
if (conn->out)
free(conn->out);
if (conn->alarm_in.poll.fd!=-1)
close(conn->alarm_in.poll.fd);
if (conn->alarm_out.poll.fd!=-1 && conn->alarm_out.poll.fd != conn->alarm_in.poll.fd)
close(conn->alarm_out.poll.fd);
conn->in=NULL;
conn->out=NULL;
conn->alarm_in.poll.fd=-1;
conn->alarm_out.poll.fd=-1;
free(conn);
if (msp_socket_count()==0 && is_watching(&mdp_sock))
unwatch(&mdp_sock);
}
static void process_msp_asap()
{
mdp_sock.alarm = gettime_ms();
mdp_sock.deadline = mdp_sock.alarm+10;
unschedule(&mdp_sock);
schedule(&mdp_sock);
}
static void remote_shutdown(struct connection *conn)
{
struct mdp_sockaddr remote;
if (shutdown(conn->alarm_out.poll.fd, SHUT_WR))
WARNF_perror("shutdown(%d)", conn->alarm_out.poll.fd);
msp_get_remote(conn->sock, &remote);
INFOF(" - Connection with %s:%d remote shutdown", alloca_tohex_sid_t(remote.sid), remote.port);
}
static void local_shutdown(struct connection *conn)
{
struct mdp_sockaddr remote;
msp_get_remote(conn->sock, &remote);
msp_shutdown(conn->sock);
INFOF(" - Connection with %s:%d local shutdown", alloca_tohex_sid_t(remote.sid), remote.port);
}
static size_t msp_handler(MSP_SOCKET sock, msp_state_t state, const uint8_t *payload, size_t len, void *context)
{
struct connection *conn = context;
if (!conn)
return 0;
if (state & MSP_STATE_ERROR)
saw_error=1;
if (payload && len){
if (conn->out->limit){
// attempt to write immediately
conn->alarm_out.poll.revents=POLLOUT;
conn->alarm_out.function(&conn->alarm_out);
}
if (conn->out->capacity < len + conn->out->limit)
return 0;
bcopy(payload, &conn->out->bytes[conn->out->limit], len);
conn->out->limit+=len;
conn->alarm_out.poll.events|=POLLOUT;
watch(&conn->alarm_out);
// attempt to write immediately
conn->alarm_out.poll.revents=POLLOUT;
conn->alarm_out.function(&conn->alarm_out);
}
if ((state & MSP_STATE_SHUTDOWN_REMOTE) && !(conn->last_state & MSP_STATE_SHUTDOWN_REMOTE) && conn->out->limit==0)
remote_shutdown(conn);
conn->last_state=state;
if (state & MSP_STATE_CLOSED){
struct mdp_sockaddr remote;
msp_get_remote(sock, &remote);
INFOF(" - Connection with %s:%d closed", alloca_tohex_sid_t(remote.sid), remote.port);
conn->sock = MSP_SOCKET_NULL;
if (is_watching(&conn->alarm_in))
unwatch(&conn->alarm_in);
if (!is_watching(&conn->alarm_out))
free_connection(conn);
assert(len == 0);
return 0;
}
if (state&MSP_STATE_DATAOUT)
try_send(conn);
return len;
}
static size_t msp_listener(MSP_SOCKET sock, msp_state_t state, const uint8_t *payload, size_t len, void *UNUSED(context))
{
if (state & MSP_STATE_ERROR){
WHY("Error listening for incoming connections");
}
if (state & MSP_STATE_CLOSED){
if (msp_socket_count()==0){
unschedule(&mdp_sock);
if (is_watching(&mdp_sock))
unwatch(&mdp_sock);
}
return len;
}
struct mdp_sockaddr remote;
msp_get_remote(sock, &remote);
INFOF(" - New connection from %s:%d", alloca_tohex_sid_t(remote.sid), remote.port);
int fd_in = STDIN_FILENO;
int fd_out = STDOUT_FILENO;
if (ip_addr.addrlen){
int fd = esocket(PF_INET, SOCK_STREAM, 0);
if (fd==-1){
msp_close(sock);
return 0;
}
if (socket_connect(fd, &ip_addr)==-1){
msp_close(sock);
close(fd);
return 0;
}
fd_in = fd_out = fd;
}
struct connection *conn = alloc_connection(sock, fd_in, io_poll, fd_out, io_poll);
if (!conn)
return 0;
conn->sock = sock;
msp_set_handler(sock, msp_handler, conn);
if (payload)
return msp_handler(sock, state, payload, len, conn);
if (once){
// stop listening after the first incoming connection
msp_close(listener);
}
assert(len == 0);
return 0;
}
static void msp_poll(struct sched_ent *alarm)
{
if (alarm->poll.revents & POLLIN)
// process incoming data packet
msp_recv(alarm->poll.fd);
// do any timed actions that need to be done, either in response to receiving or due to a timed alarm.
msp_processing(&alarm->alarm);
if (alarm->alarm){
time_ms_t now = gettime_ms();
if (alarm->alarm < now)
alarm->alarm = now;
alarm->deadline = alarm->alarm +10;
unschedule(alarm);
schedule(alarm);
}
}
static void service_poll(struct sched_ent *alarm){
if (alarm->poll.revents & POLLIN){
struct mdp_header header;
uint8_t payload[256];
ssize_t len = mdp_recv(alarm->poll.fd, &header, payload, sizeof payload);
if (len==-1)
return;
if (header.flags & (MDP_FLAG_ERROR|MDP_FLAG_BIND))
return;
if (is_sid_t_broadcast(header.local.sid))
header.local.sid = SID_ANY;
len = snprintf((char*)payload, sizeof payload, "%s.msp.port=%d", service_name, service_port);
mdp_send(alarm->poll.fd, &header, payload, len);
}
}
static int try_send(struct connection *conn)
{
if (!conn->in->limit)
return 0;
if (msp_send(conn->sock, conn->in->bytes, conn->in->limit)==-1)
return 0;
// if this packet was acceptted, clear the read buffer
conn->in->limit = conn->in->position = 0;
// hit end of data?
if (conn->eof){
local_shutdown(conn);
}else{
conn->alarm_in.poll.events|=POLLIN;
watch(&conn->alarm_in);
}
return 1;
}
static void io_poll(struct sched_ent *alarm)
{
struct connection *conn = alarm->context;
if (alarm->poll.revents & POLLIN) {
size_t remaining = conn->in->capacity - conn->in->limit;
if (remaining>0){
ssize_t r = read(alarm->poll.fd,
conn->in->bytes + conn->in->limit,
remaining);
if (r<0){
WARNF_perror("read(%d)", alarm->poll.fd);
alarm->poll.revents = POLLERR;
}
if (r==0){
// EOF
r=-1;
alarm->poll.revents = POLLHUP;
}
if (r>0){
conn->in->limit+=r;
if (try_send(conn))
process_msp_asap();
}
}
// stop reading input when the buffer is full
if (conn->in->limit==conn->in->capacity){
alarm->poll.events &= ~POLLIN;
if (alarm->poll.events)
watch(alarm);
else if (is_watching(alarm))
unwatch(alarm);
}
}
if (alarm->poll.revents & POLLOUT) {
// try to write some data
size_t data = conn->out->limit-conn->out->position;
if (data>0){
ssize_t r = write(alarm->poll.fd,
conn->out->bytes+conn->out->position,
data);
if (r < 0)
WARNF_perror("write(%d)", alarm->poll.fd);
if (r > 0)
conn->out->position+=r;
}
// if the buffer is empty now, reset it and unwatch the handle
if (conn->out->position==conn->out->limit){
conn->out->limit = conn->out->position = 0;
alarm->poll.events &= ~POLLOUT;
if (alarm->poll.events)
watch(alarm);
else if (is_watching(alarm))
unwatch(alarm);
if (conn->last_state & MSP_STATE_SHUTDOWN_REMOTE)
remote_shutdown(conn);
}
if (conn->out->limit < conn->out->capacity){
if (!msp_socket_is_null(conn->sock)){
process_msp_asap();
}else{
free_connection(conn);
}
}
}
if (alarm->poll.revents & POLLHUP) {
// EOF? trigger a graceful shutdown
conn->eof=1;
alarm->poll.events &= ~POLLIN;
if (alarm->poll.events)
watch(alarm);
else if (is_watching(alarm))
unwatch(alarm);
if (!conn->in->limit){
local_shutdown(conn);
process_msp_asap();
}
}
if (alarm->poll.revents & POLLERR) {
free_connection(conn);
process_msp_asap();
}
}
static void listen_poll(struct sched_ent *alarm)
{
if (alarm->poll.revents & POLLIN) {
struct socket_address addr;
addr.addrlen = sizeof addr.store;
int fd = accept(alarm->poll.fd, &addr.addr, &addr.addrlen);
if (fd==-1){
WHYF_perror("accept(%d)", alarm->poll.fd);
return;
}
INFOF("- Incoming TCP connection from %s", alloca_socket_address(&addr));
watch(&mdp_sock);
MSP_SOCKET sock = msp_socket(mdp_sock.poll.fd, 0);
if (msp_socket_is_null(sock))
return;
struct connection *connection = alloc_connection(sock, fd, io_poll, fd, io_poll);
if (!connection){
msp_close(sock);
return;
}
msp_set_handler(sock, msp_handler, connection);
msp_connect(sock, &remote_addr);
process_msp_asap();
if (once){
unwatch(alarm);
close(alarm->poll.fd);
alarm->poll.fd=-1;
}
}
}
int app_msp_connection(const struct cli_parsed *parsed, struct cli_context *UNUSED(context))
{
const char *sidhex, *port_string, *local_port_string;
once = cli_arg(parsed, "--once", NULL, NULL, NULL) == 0;
if ( cli_arg(parsed, "--forward", &local_port_string, cli_uint, NULL) == -1
|| cli_arg(parsed, "--service", &service_name, NULL, NULL) == -1
|| cli_arg(parsed, "sid", &sidhex, str_is_subscriber_id, NULL) == -1
|| cli_arg(parsed, "port", &port_string, cli_uint, NULL) == -1)
return -1;
struct mdp_sockaddr addr;
bzero(&addr, sizeof addr);
service_port = addr.port = atoi(port_string);
saw_error=0;
if (sidhex && *sidhex){
if (str_to_sid_t(&addr.sid, sidhex) == -1)
return WHY("str_to_sid_t() failed");
}
int ret=-1;
MSP_SOCKET sock = MSP_SOCKET_NULL;
if (service_name){
// listen for service discovery messages
service_sock.poll.fd = mdp_socket();
if (service_sock.poll.fd==-1)
goto end;
set_nonblock(service_sock.poll.fd);
watch(&service_sock);
// bind
struct mdp_header header;
bzero(&header, sizeof(header));
header.local.sid = BIND_PRIMARY;
header.local.port = MDP_PORT_SERVICE_DISCOVERY;
header.remote.sid = SID_ANY;
header.remote.port = MDP_LISTEN;
header.ttl = PAYLOAD_TTL_DEFAULT;
header.flags = MDP_FLAG_BIND|MDP_FLAG_REUSE;
if (mdp_send(service_sock.poll.fd, &header, NULL, 0)==-1)
goto end;
}else
service_sock.poll.fd=-1;
mdp_sock.poll.fd = mdp_socket();
if (mdp_sock.poll.fd==-1)
goto end;
set_nonblock(STDIN_FILENO);
set_nonblock(STDOUT_FILENO);
bzero(&ip_addr, sizeof ip_addr);
if (local_port_string){
ip_addr.addrlen = sizeof(ip_addr.inet);
ip_addr.inet.sin_family = AF_INET;
ip_addr.inet.sin_port = htons(atoi(local_port_string));
ip_addr.inet.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
}
if (sidhex && *sidhex){
if (local_port_string){
remote_addr = addr;
listen_alarm.poll.fd = esocket(PF_INET, SOCK_STREAM, 0);
if (listen_alarm.poll.fd==-1)
goto end;
if (socket_bind(listen_alarm.poll.fd, &ip_addr)==-1)
goto end;
if (socket_listen(listen_alarm.poll.fd, 0)==-1)
goto end;
watch(&listen_alarm);
INFOF("- Forwarding from %s to %s:%d", alloca_socket_address(&ip_addr), alloca_tohex_sid_t(addr.sid), addr.port);
}else{
watch(&mdp_sock);
sock = msp_socket(mdp_sock.poll.fd, 0);
once = 1;
struct connection *conn=alloc_connection(sock, STDIN_FILENO, io_poll, STDOUT_FILENO, io_poll);
if (!conn)
goto end;
msp_set_handler(sock, msp_handler, conn);
msp_connect(sock, &addr);
INFOF("- Connecting to %s:%d", alloca_tohex_sid_t(addr.sid), addr.port);
}
}else{
watch(&mdp_sock);
sock = msp_socket(mdp_sock.poll.fd, 0);
msp_set_handler(sock, msp_listener, NULL);
msp_set_local(sock, &addr);
// sock will be closed if listen fails
if (msp_listen(sock)==-1)
goto end;
listener=sock;
if (local_port_string){
INFOF("- Forwarding from port %d to %s", addr.port, alloca_socket_address(&ip_addr));
}else{
once = 1;
INFOF(" - Listening on port %d", addr.port);
}
}
process_msp_asap();
sigIntFlag = 0;
signal(SIGINT, sigIntHandler);
signal(SIGTERM, sigIntHandler);
while(sigIntFlag==0 && fd_poll()){
;
}
ret = saw_error;
signal(SIGINT, SIG_DFL);
sigIntFlag = 0;
end:
listener = MSP_SOCKET_NULL;
if (is_watching(&mdp_sock))
unwatch(&mdp_sock);
if (mdp_sock.poll.fd!=-1){
msp_close_all(mdp_sock.poll.fd);
mdp_close(mdp_sock.poll.fd);
mdp_sock.poll.fd=-1;
}
if (service_sock.poll.fd!=-1){
if (is_watching(&service_sock))
unwatch(&service_sock);
mdp_close(service_sock.poll.fd);
service_sock.poll.fd=-1;
}
if (listen_alarm.poll.fd !=-1 && is_watching(&listen_alarm))
unwatch(&listen_alarm);
if (listen_alarm.poll.fd!=-1)
close(listen_alarm.poll.fd);
unschedule(&mdp_sock);
return ret;
}