/*-
 * Copyright (c) 2004-2005 Robert N. M. Watson
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/ioctl.h>

#include <net/ethernet.h>
#include <net/if.h>
#include <net/if_arp.h>

#include <netinet/in.h>
#include <netinet/in_systm.h>
#include <netinet/ip.h>

#include <arpa/inet.h>

#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

/*
 * Largest Ethernet frame, in bytes, that is read from the tap device or
 * accepted from the peer in a single frame_wrapper.
 */
#define	MAXPACKET	32000

/*
 * Header that precedes every frame sent over the TCP connection.  fw_len is
 * the length in bytes of the frame that follows and is converted with
 * htonl()/ntohl() when it is sent and received.
 */
struct frame_wrapper {
	u_int	fw_len;
};

/*
 * Filtering configuration, filled in by main() from the environment.
 *
 * netbridge_ip is the IPv4 address taken from NETBRIDGE_IP as stored by
 * inet_aton() (in.s_addr); zero means no address filtering.  The two flags
 * are set to 1 when NETBRIDGE_NOIPBROADCAST / NETBRIDGE_NOIPMULTICAST are
 * present in the environment.
 */
u_int32_t	netbridge_ip;
int		netbridge_noipbroadcast;
int		netbridge_noipmulticast;

/*
 * DEV_PATH is prepended to the tap device name given on the command line;
 * MAX_PATH is the size of the buffer the resulting path is built in.
 */
#define	DEV_PATH	"/dev/"
#define	MAX_PATH	1024

/*
 * Print the two supported command-line forms to stderr and terminate the
 * process with exit status -1.  Never returns.
 */
static void
usage(void)
{
	static const char *const synopsis[] = {
		"netbridge client tapdev ip port\n",
		"netbridge server tapdev port\n",
	};
	size_t i;

	for (i = 0; i < sizeof(synopsis) / sizeof(synopsis[0]); i++)
		fputs(synopsis[i], stderr);
	exit(-1);
}

/*
 * Return 1 if the six-byte Ethernet address pointed to by address is the
 * all-ones broadcast address (ff:ff:ff:ff:ff:ff), and 0 otherwise.
 */
static int
ether_broadcast(u_char *address)
{
	int i;

	for (i = 0; i < 6; i++) {
		if (address[i] != 0xff)
			return (0);
	}
	return (1);
}

/*
 * Return non-zero (1) if the least-significant bit of the first octet of the
 * Ethernet address is set, i.e. the address is a group (multicast) address,
 * and 0 otherwise.  The broadcast address also has this bit set.
 */
static int
ether_multicast(u_char *address)
{
	const u_char first_octet = *address;

	return (first_octet & 0x01);
}

/*
 * Set IFF_UP on the network interface named ifnet.
 *
 * A temporary raw socket is opened to carry the interface ioctls: the
 * current flags are fetched with SIOCGIFFLAGS, IFF_UP is OR'ed in, and the
 * result is written back with SIOCSIFFLAGS.
 *
 * Returns 0 on success.  On failure a diagnostic is printed with warn(3) and
 * -1 is returned.  The temporary socket is closed on every path once it has
 * been opened.
 */
static int
tap_if_up(const char *ifnet)
{
	struct ifreq ifr;
	int ret, s;

	s = socket(PF_INET, SOCK_RAW, 0);
	if (s < 0) {
		warn("tap_if_up: socket");
		return (-1);
	}

	ret = -1;
	bzero(&ifr, sizeof(ifr));
	strlcpy(ifr.ifr_name, ifnet, sizeof(ifr.ifr_name));

	/* warn() is called before close() so errno still describes the ioctl. */
	if (ioctl(s, SIOCGIFFLAGS, &ifr) < 0) {
		warn("tap_if_up: %s: SIOCGIFFLAGS", ifnet);
		goto out;
	}

	ifr.ifr_flags |= IFF_UP;

	if (ioctl(s, SIOCSIFFLAGS, &ifr) < 0) {
		warn("tap_if_up: %s: SIOCSIFFLAGS", ifnet);
		goto out;
	}
	ret = 0;
out:
	close(s);
	return (ret);
}

/*
 * Do we allow this particular ARP packet in?
 *
 * packet points at the start of the Ethernet header, len is the number of
 * valid bytes in the frame, and addr is the IPv4 address of interest in the
 * same byte order as it appears on the wire (it is compared directly with
 * the bytes of the packet).
 *
 * Only Ethernet/IPv4 ARP is accepted.  An ARP request is accepted only if
 * its target protocol address equals addr; any other opcode is accepted.
 *
 * Returns 1 if the frame should be forwarded and 0 if it should be dropped.
 */
static int
arp_match(u_char *packet, u_int len, u_int32_t addr)
{
	struct arphdr *ah;
	u_int32_t tpa;

	/*
	 * The frame must hold the Ethernet header, the fixed ARP header and
	 * the four addresses of an Ethernet/IPv4 ARP message.
	 */
	if (len < sizeof(struct ether_header) +
	    arphdr_len2(ETHER_ADDR_LEN, sizeof(tpa)))
		return (0);
	ah = (struct arphdr *)(packet + sizeof(struct ether_header));

	/*
	 * We speak only ethernet.
	 */
	if (ntohs(ah->ar_hrd) != ARPHRD_ETHER)
		return (0);

	if (ntohs(ah->ar_pro) != ETHERTYPE_IP)
		return (0);

	/*
	 * ar_tpa() computes its offset from the address lengths carried in
	 * the packet itself.  The length check above assumed 6 and 4, so
	 * insist on exactly those; otherwise the access below could land
	 * beyond the bytes that were validated.
	 */
	if (ah->ar_hln != ETHER_ADDR_LEN || ah->ar_pln != sizeof(tpa))
		return (0);

	/*
	 * If it's a request, we only want to accept it if it matches the IP
	 * we are interested in.  Otherwise, we accept it regardless,
	 * figuring it's most likely a unicast response.
	 */
	if (ntohs(ah->ar_op) == ARPOP_REQUEST) {
		/*
		 * The target address is not 4-byte aligned relative to the
		 * start of the frame, so copy it out rather than
		 * dereferencing a cast pointer.
		 */
		memcpy(&tpa, ar_tpa(ah), sizeof(tpa));
		return (tpa == addr);
	}
	return (1);
}

/*
 * The tap input routine is fairly simple, except that we do a bit of
 * filtering to prevent sending packets the other endpoint won't be
 * interested in.
 */
static int
netbridge_tap_input(int tap_fd, int sock)
{
	char packet[MAXPACKET + sizeof(struct frame_wrapper)];
	struct frame_wrapper *fw;
	struct ether_header *eh;
	struct ip *ip;
	ssize_t len, recvlen, sendlen;

	len = MAXPACKET + sizeof(struct frame_wrapper);
	recvlen = read(tap_fd, packet + sizeof(struct frame_wrapper), len);
	if (recvlen == -1) {
		perror("netbridge_tap_input.read");
		return (-1);
	}

	/*
	 * Since I'm primarily interested in IP traffic, allow through only
	 * ARP and IPv4.
	 */
	eh = (struct ether_header *)(packet + sizeof(struct frame_wrapper));
	switch (ntohs(eh->ether_type)) {
	case ETHERTYPE_IP:
		if (netbridge_noipbroadcast &&
		    ether_broadcast(eh->ether_dhost))
			return (0);
		if (netbridge_noipmulticast &&
		    ether_multicast(eh->ether_dhost))
			return (0);
		if (recvlen < sizeof(*eh) + sizeof(*ip))
			return (0);
		ip = (struct ip *)(eh + 1);
		if (netbridge_ip) {
			if ((ip->ip_dst.s_addr != netbridge_ip) ||
			    (netbridge_noipmulticast &&
			    IN_CLASSD(ip->ip_dst.s_addr)))
				return (0);
		}
		break;
	case ETHERTYPE_ARP:
		if (netbridge_ip &&
		    !arp_match((u_char *)eh, recvlen, netbridge_ip))
			return (0);
		break;
	default:
		return (0);
	}

	fw = (struct frame_wrapper *)packet;
	fw->fw_len = htonl(recvlen);

	sendlen = send(sock, packet, recvlen + sizeof(struct frame_wrapper),
	    0);
	if (sendlen == -1) {
		perror("netbridge_tap_input.send");
		return (-1);
	}
	if (sendlen != recvlen + sizeof(struct frame_wrapper)) {
		fprintf(stderr,
		    "netbridge_tap_input.send tried %d but send %d\n",
		    recvlen + sizeof(struct frame_wrapper), sendlen);
		return (-1);
	}
	return (0);
}

/*
 * Receive one wrapped frame from the socket and write it to the tap device.
 *
 * The frame_wrapper is received first (MSG_WAITALL), its length field is
 * converted with ntohl() and checked against MAXPACKET, then exactly that
 * many bytes are received and written to tap_fd in a single write().
 *
 * Returns 0 if a frame was delivered, or if the first recv() failed with
 * EINTR, EAGAIN or EWOULDBLOCK.  Returns -1 on end of file, on any other
 * receive or write error, on a short read or write, or if the advertised
 * length exceeds MAXPACKET.
 */
static int
netbridge_sock_input_real(int tap_fd, int sock)
{
	char packet[MAXPACKET];
	struct frame_wrapper fw;
	ssize_t recvlen, sendlen;

	/*
	 * Read the frame wrapper first so that we know how much data to
	 * expect for the frame itself.
	 */
	recvlen = recv(sock, &fw, sizeof(fw), MSG_WAITALL);
	if (recvlen == -1 && (errno == EINTR || errno == EAGAIN || errno
	    == EWOULDBLOCK))
		return (0);
	if (recvlen == -1) {
		perror("netbridge_socket_input.recv_fw");
		return (-1);
	}
	if (recvlen == 0) {
		fprintf(stderr, "netbridge_socket_input.recv_fw eof\n");
		return (-1);
	}
	if ((size_t)recvlen != sizeof(fw)) {
		fprintf(stderr,
		    "netbridge_socket_input.recv_fw %zd, but not %zu\n",
		    recvlen, sizeof(fw));
		return (-1);
	}
	fw.fw_len = ntohl(fw.fw_len);
	if (fw.fw_len > MAXPACKET) {
		fprintf(stderr,
		    "netbridge_socket_input.recv_fw %u is too big\n",
		    fw.fw_len);
		return (-1);
	}

	recvlen = recv(sock, packet, fw.fw_len, MSG_WAITALL);
	if (recvlen == -1) {
		perror("netbridge_socket_input.recv_packet");
		return (-1);
	}
	if ((size_t)recvlen != fw.fw_len) {
		fprintf(stderr, "netbridge_socket_input.recv_packet didn't "
		    "receive %u bytes\n", fw.fw_len);
		return (-1);
	}

	sendlen = write(tap_fd, packet, fw.fw_len);
	if (sendlen == -1) {
		perror("netbridge_socket_input.write");
		return (-1);
	}
	if ((size_t)sendlen != fw.fw_len) {
		fprintf(stderr,
		    "netbridge_socket_input.write %zd rather than %u bytes\n",
		    sendlen, fw.fw_len);
		return (-1);
	}

	return (0);
}

/*
 * Wrapper around the real input routine to temporarily turn off non-blocking
 * mode.  Really, we should read incrementally into a packet buffer, but I
 * haven't done that yet.
 *
 * The current file status flags are fetched first so that only O_NONBLOCK is
 * toggled; any other status flags on the socket are left as they were.
 * O_NONBLOCK is set again before returning, whatever the outcome of the
 * input routine.
 *
 * Returns the value returned by netbridge_sock_input_real().  A failing
 * fcntl() is fatal: a message is printed and the process exits with -1.
 */
static int
netbridge_sock_input(int tap_fd, int sock)
{
	int flags, ret;

	flags = fcntl(sock, F_GETFL);
	if (flags < 0) {
		perror("F_GETFL on sock");
		exit(-1);
	}

	if (fcntl(sock, F_SETFL, flags & ~O_NONBLOCK) < 0) {
		perror("clear O_NONBLOCK on sock");
		exit(-1);
	}

	ret = netbridge_sock_input_real(tap_fd, sock);

	if (fcntl(sock, F_SETFL, flags | O_NONBLOCK) < 0) {
		perror("O_NONBLOCK on sock");
		exit(-1);
	}
	return (ret);
}

/*
 * Main bridging loop: wait with poll() for input on the tap device and on
 * the socket, and shuttle frames between them.
 *
 * The loop ends, and the function returns, when the socket reports an
 * exceptional condition, when netbridge_sock_input() returns -1, or when
 * poll() fails with an error other than EINTR or EAGAIN.  The return value
 * of netbridge_tap_input() is ignored, so a failure to forward a frame from
 * the tap does not end the session.
 */
static void
netbridge_run(int tap_fd, int sock)
{
	const int errmask = POLLERR | POLLHUP | POLLNVAL;
	const int inmask = POLLIN;
	struct pollfd pollfd[2];
	int ret, saved_errno;

	printf("bridging fd %d and fd %d\n", tap_fd, sock);

	while (1) {
		pollfd[0].fd = tap_fd;
		pollfd[0].events = errmask | inmask;
		pollfd[0].revents = 0;
		pollfd[1].fd = sock;
		pollfd[1].events = errmask | inmask;
		pollfd[1].revents = 0;
		ret = poll(pollfd, 2, INFTIM);
		if (ret < 0) {
			/* perror() may modify errno; keep the original. */
			saved_errno = errno;
			perror("poll");
			/*
			 * Only transient failures are worth retrying; any
			 * other error would recur immediately and spin.
			 */
			if (saved_errno == EINTR || saved_errno == EAGAIN)
				continue;
			break;
		}

		/*
		 * Need to handle a disconnect?
		 *
		 * NOTE(review): an exceptional condition on the tap is only
		 * reported and the loop polls again; if the condition
		 * persists this repeats without blocking.
		 */
		if (pollfd[0].revents & errmask) {
			fprintf(stderr, "exceptional condition on tap: %d\n",
			    pollfd[0].revents);
			continue;
		}
		if (pollfd[1].revents & errmask) {
			fprintf(stderr, "exceptional condition on socket: %d\n",
			    pollfd[1].revents);
			break;
		}

		/*
		 * Data from pseudo-interface?
		 */
		if (pollfd[0].revents & inmask) {
			netbridge_tap_input(tap_fd, sock);
		}

		/*
		 * Data from socket?  Could also return EOF.
		 */
		if (pollfd[1].revents & inmask) {
			if (netbridge_sock_input(tap_fd, sock) == -1)
				break;
		}
	}
}

/*
 * Minimal sanity check of an interface name before it is appended to a
 * directory prefix: the name must not contain a '/' character.  Nothing else
 * is checked.
 *
 * Returns 0 if the string contains no '/'.  Otherwise sets errno to EINVAL
 * and returns -1.
 */
static int
validate_pathname(const char *string)
{
	const char *cp;

	for (cp = string; *cp != '\0'; cp++) {
		if (*cp == '/') {
			errno = EINVAL;
			return (-1);
		}
	}
	return (0);
}

/*
 * Client mode: open the tap device, connect to the server and bridge.
 *
 * argv must hold exactly three strings: the tap device name (relative to
 * DEV_PATH), the server's IPv4 address in a form accepted by inet_aton(),
 * and the TCP port (1-65535).  A wrong argument count or an invalid port
 * goes through usage(); any other setup failure prints a diagnostic and
 * exits with status -1.
 *
 * Once connected, netbridge_run() is called; when it returns, the socket and
 * the tap device are closed and the function returns.
 */
static void
netbridge_client(int argc, char *argv[])
{
	char pathname[MAX_PATH];
	struct sockaddr_in sin;
	char *dummy;
	int tap_fd;
	long port;
	int sock;
	int n;

	if (argc != 3)
		usage();

	if (validate_pathname(argv[0]) != 0) {
		perror(argv[0]);
		exit(-1);
	}

	bzero(&sin, sizeof(sin));
	sin.sin_len = sizeof(sin);
	sin.sin_family = AF_INET;
	if (inet_aton(argv[1], &sin.sin_addr) == 0) {
		/*
		 * inet_aton() reports failure through its return value only
		 * and does not set errno, so perror() would print an
		 * unrelated message here.
		 */
		fprintf(stderr, "%s: invalid IPv4 address\n", argv[1]);
		exit(-1);
	}

	port = strtoul(argv[2], &dummy, 10);
	if (port < 1 || port > 65535 || *dummy != '\0')
		usage();
	sin.sin_port = htons(port);

	/* Refuse a name that does not fit rather than open a truncated path. */
	n = snprintf(pathname, sizeof(pathname), "%s%s", DEV_PATH, argv[0]);
	if (n < 0 || (size_t)n >= sizeof(pathname)) {
		fprintf(stderr, "%s: device name too long\n", argv[0]);
		exit(-1);
	}
	tap_fd = open(pathname, O_RDWR);
	if (tap_fd == -1) {
		perror(argv[0]);
		exit(-1);
	}

	if (fcntl(tap_fd, F_SETFL, O_NONBLOCK) < 0) {
		perror("O_NONBLOCK on tap_fd");
		exit(-1);
	}

	if (tap_if_up(argv[0]) < 0)
		exit(-1);

	sock = socket(PF_INET, SOCK_STREAM, 0);
	if (sock == -1) {
		perror("socket");
		exit(-1);
	}

	if (connect(sock, (struct sockaddr *) &sin, sizeof(sin)) == -1) {
		perror("connect");
		exit(-1);
	}

	/* Switch to non-blocking only after the blocking connect() is done. */
	if (fcntl(sock, F_SETFL, O_NONBLOCK) < 0) {
		perror("O_NONBLOCK on sock");
		exit(-1);
	}

	netbridge_run(tap_fd, sock);

	close(sock);
	close(tap_fd);
}

/*
 * Server mode: open the tap device, listen on a TCP port and bridge each
 * accepted connection in turn.
 *
 * argv must hold exactly two strings: the tap device name (relative to
 * DEV_PATH) and the TCP port (1-65535) to listen on, bound to INADDR_ANY.
 * A wrong argument count or an invalid port goes through usage(); any other
 * setup failure prints a diagnostic and exits with status -1.
 *
 * Connections are served one at a time: netbridge_run() is called for each
 * accepted socket and the next accept() happens only after it returns.  The
 * accept loop has no exit, so this function does not return.
 */
static void
netbridge_server(int argc, char *argv[])
{
	struct sockaddr_in sin, other_sin;
	int listen_sock, accept_sock;
	char pathname[MAX_PATH];
	socklen_t addrlen;
	char *dummy;
	int tap_fd;
	long port;
	int n;

	if (argc != 2)
		usage();

	if (validate_pathname(argv[0]) != 0)
		errx(-1, "validate_pathname: %s: %s", argv[0],
		    strerror(errno));

	bzero(&sin, sizeof(sin));
	sin.sin_len = sizeof(sin);
	sin.sin_family = AF_INET;
	sin.sin_addr.s_addr = htonl(INADDR_ANY);

	port = strtoul(argv[1], &dummy, 10);
	if (port < 1 || port > 65535 || *dummy != '\0')
		usage();
	sin.sin_port = htons(port);

	/* Refuse a name that does not fit rather than open a truncated path. */
	n = snprintf(pathname, sizeof(pathname), "%s%s", DEV_PATH, argv[0]);
	if (n < 0 || (size_t)n >= sizeof(pathname))
		errx(-1, "%s: device name too long", argv[0]);
	tap_fd = open(pathname, O_RDWR);
	if (tap_fd == -1)
		errx(-1, "%s: open: %s:", argv[0], strerror(errno));

	if (fcntl(tap_fd, F_SETFL, O_NONBLOCK) < 0)
		errx(-1, "%s: fcntl: F_SETFL: %s:", argv[0], strerror(errno));

	if (tap_if_up(argv[0]))
		exit(-1);

	listen_sock = socket(PF_INET, SOCK_STREAM, 0);
	if (listen_sock == -1)
		errx(-1, "socket(PF_INET, SOCK_STREAM): %s", strerror(errno));

	if (bind(listen_sock, (struct sockaddr *)&sin, sizeof(sin)) == -1)
		errx(-1, "bind(%ld): %s:", port, strerror(errno));

	if (listen(listen_sock, -1) == -1)
		errx(-1, "listen: %s", strerror(errno));

	while (1) {
		bzero(&other_sin, sizeof(other_sin));
		addrlen = sizeof(other_sin);

		accept_sock = accept(listen_sock, (struct sockaddr *)
		    &other_sin, &addrlen);
		if (accept_sock == -1) {
			perror("accept");
			continue;
		}
		printf("connection opened from %s:%d\n",
		    inet_ntoa(other_sin.sin_addr), ntohs(other_sin.sin_port));

		if (fcntl(accept_sock, F_SETFL, O_NONBLOCK) < 0) {
			perror("O_NONBLOCK on sock");
			exit(-1);
		}

		netbridge_run(tap_fd, accept_sock);

		close(accept_sock);
	}
	/* NOTREACHED */
}

/*
 * Program entry point.  Behaviour is tuned by three optional environment
 * variables, all examined before the mode is dispatched:
 *
 *   NETBRIDGE_IP             IPv4 address; when it parses, only traffic for
 *                            that address is bridged, which helps cut down
 *                            on undesired ARP chatter.  An unparsable value
 *                            produces a warning and no address filtering.
 *   NETBRIDGE_NOIPBROADCAST  if present (any value), IPv4 frames sent to the
 *                            Ethernet broadcast address are not forwarded.
 *   NETBRIDGE_NOIPMULTICAST  if present (any value), IPv4 multicast frames
 *                            are not forwarded.
 *
 * argv[1] selects "client" or "server" mode; the remaining arguments are
 * handed to the corresponding routine.  Anything else prints the usage
 * message.
 */
int
main(int argc, char *argv[])
{
	struct in_addr in;
	const char *ipstr, *mode;

	if (argc < 2)
		usage();

	if (getenv("NETBRIDGE_NOIPBROADCAST") != NULL) {
		printf("Stripping IP broadcast traffic on bridge\n");
		netbridge_noipbroadcast = 1;
	}
	if (getenv("NETBRIDGE_NOIPMULTICAST") != NULL) {
		printf("Stripping IP multicast traffic on bridge\n");
		netbridge_noipmulticast = 1;
	}

	ipstr = getenv("NETBRIDGE_IP");
	if (ipstr != NULL) {
		if (inet_aton(ipstr, &in) == 0) {
			printf("WARNING: Unable to parse IP address\n");
		} else {
			printf("Only forwarding IP traffic for %s\n",
			    inet_ntoa(in));
			netbridge_ip = in.s_addr;
		}
	}

	mode = argv[1];
	if (strcmp(mode, "server") == 0)
		netbridge_server(argc - 2, argv + 2);
	else if (strcmp(mode, "client") == 0)
		netbridge_client(argc - 2, argv + 2);
	else
		usage();

	exit(0);
}
