/*-
 * Copyright (c) 2004-2005 Robert N. M. Watson
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

#include <sys/types.h>
#include <sys/socket.h>

#include <netinet/in.h>

#include <arpa/inet.h>

#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <poll.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

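/*
 * INFTIM is not defined on every platform (Darwin appears to lack it);
 * fall back to -1, which poll() treats as an infinite timeout.
 */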
#ifndef INFTIM
#define	INFTIM	-1
#endif

/*
 * Given a proxy server ip and port, and destination hostname and port,
 * connect through the web server proxy to the destination.  Bind a TCP port
 * and forward data back and forth between the local TCP port and the
 * forwarded connection.
 */
/*
 * Write the command-line synopsis to stderr and terminate the process with
 * a failure status.  This function does not return.
 */
static void
usage(void)
{
	static const char synopsis[] =
	    "wwwproxy proxy_hostname proxy_port target_hostname "
	    "target_port listen_port\n";

	fputs(synopsis, stderr);
	exit(-1);
}

/*
 * Create a TCP socket and connect it to the IPv4 address described by sin.
 *
 * Returns the connected socket descriptor.  If either the socket cannot be
 * created or the connection attempt fails, the error is reported with
 * perror() and the process exits.
 */
static int
proxy_connect(struct sockaddr_in *sin)
{
	int fd;

	if ((fd = socket(PF_INET, SOCK_STREAM, 0)) == -1) {
		perror("socket");
		exit(-1);
	}

	if (connect(fd, (struct sockaddr *)sin, sizeof(*sin)) >= 0)
		return (fd);

	perror("connect");
	exit(-1);
}

/*
 * Read one '\n'-terminated line from sock, one byte at a time, into buffer.
 *
 * The stored line keeps its trailing newline and is always NUL-terminated.
 * buflen is the total size of buffer in bytes; a line that does not fit
 * together with its terminating NUL is a fatal error, as are a recv()
 * failure and the peer closing the connection before a newline arrives.
 *
 * Returns the number of bytes stored in buffer, counting the terminating
 * NUL (i.e. strlen(buffer) + 1).
 */
static ssize_t
proxy_getline(int sock, char *buffer, int buflen)
{
	ssize_t len;
	int offset;
	char ch;

	offset = 0;
	ch = '\0';
	while (ch != '\n') {
		/*
		 * Check before storing: there must be room for this byte
		 * and for the NUL that terminates the line, so that nothing
		 * is ever written at or beyond buffer[buflen].
		 */
		if (offset + 1 >= buflen) {
			fprintf(stderr, "proxy_getline: line too long\n");
			exit(-1);
		}
		len = recv(sock, &ch, 1, 0);
		if (len == -1) {
			perror("recv");
			exit(-1);
		}
		if (len == 0) {
			fprintf(stderr, "recv: connection closed\n");
			exit(-1);
		}
		buffer[offset] = ch;
		offset++;
	}
	/* The loop guard guarantees offset <= buflen - 1 here. */
	buffer[offset] = '\0';
	return (offset + 1);
}

/*
 * Send one already-formatted piece of the request to the proxy and echo it
 * on stdout.  A failed or short send() is fatal.
 */
static void
proxy_send_string(int sock, const char *string)
{
	ssize_t len;

	len = strlen(string);
	if (send(sock, string, len, 0) != len) {
		perror("send");
		exit(-1);
	}
	printf("Sent: %s", string);
}

/*
 * Ask the HTTP proxy connected on sock to open a tunnel to
 * target_hostname:target_port using the CONNECT method.
 *
 * The request consists of the CONNECT request line, a User-Agent header and
 * a Host header, each terminated by CRLF as HTTP requires, followed by the
 * empty line that ends the header section.  The proxy's status line must
 * report 200 for either HTTP/1.0 or HTTP/1.1; anything else is printed as a
 * proxy error and the process exits.  Every remaining response header is
 * then consumed up to and including the blank line, so that on return the
 * next byte read from sock belongs to the tunnelled connection.
 */
static void
proxy_negotiate(int sock, const char *target_hostname,
    const char *target_port)
{
	char buffer[1024];
	char *string;
	ssize_t len;

	/*
	 * Test asprintf()'s return value rather than the pointer: on failure
	 * some implementations leave the pointer undefined.
	 */
	if (asprintf(&string, "CONNECT %s:%s HTTP/1.0\r\n", target_hostname,
	    target_port) == -1) {
		perror("asprintf");
		exit(-1);
	}
	proxy_send_string(sock, string);
	free(string);

	proxy_send_string(sock, "User-Agent: Mozilla/5.0 (compatible; na; na)"
	    " (na)\r\n");

	if (asprintf(&string, "Host: %s\r\n\r\n", target_hostname) == -1) {
		perror("asprintf");
		exit(-1);
	}
	proxy_send_string(sock, string);
	free(string);

	len = proxy_getline(sock, buffer, (int)sizeof(buffer));
#ifdef DEBUG
	printf("Received %zd\n", len);
	printf("Received: %s\n", buffer);
#endif

#define	HTTP10_GOOD_STRING	"HTTP/1.0 200"
#define	HTTP11_GOOD_STRING	"HTTP/1.1 200"
	if (strncmp(buffer, HTTP10_GOOD_STRING,
	    strlen(HTTP10_GOOD_STRING)) != 0 &&
	    strncmp(buffer, HTTP11_GOOD_STRING,
	    strlen(HTTP11_GOOD_STRING)) != 0) {
		fprintf(stderr, "Proxy error: %s\n", buffer);
		exit(-1);
	}
#undef HTTP11_GOOD_STRING
#undef HTTP10_GOOD_STRING

	/*
	 * Drain the rest of the response header.  The proxy may send any
	 * number of header lines before the empty line; leaving them unread
	 * would inject them into the forwarded data stream.
	 */
	do {
		len = proxy_getline(sock, buffer, (int)sizeof(buffer));
#ifdef DEBUG
		printf("Received %zd\n", len);
		printf("Received: %s\n", buffer);
#endif
	} while (strcmp(buffer, "\r\n") != 0 && strcmp(buffer, "\n") != 0);
	(void)len;
}

/*
 * Move exactly one byte from socka to sockb.
 *
 * The byte is read with recv() and written with send().  Any outcome other
 * than a one-byte transfer on either side (error, closed connection, or an
 * unexpected count) is reported on stderr and terminates the process.
 */
static void
proxy_char(int socka, int sockb)
{
	ssize_t n;
	char byte;

	n = recv(socka, &byte, 1, 0);
	switch (n) {
	case -1:
		perror("recv");
		exit(-1);
	case 0:
		fprintf(stderr, "recv: connection closed\n");
		exit(-1);
	case 1:
		break;
	default:
		fprintf(stderr, "recv: overflow\n");
		exit(-1);
	}

	n = send(sockb, &byte, 1, 0);
	switch (n) {
	case -1:
		perror("send");
		exit(-1);
	case 0:
		fprintf(stderr, "send: connection closed\n");
		exit(-1);
	case 1:
		break;
	default:
		fprintf(stderr, "send: overflow\n");
		exit(-1);
	}
}

/*
 * Relay data between sock0 and sock1 until one side fails.
 *
 * Each pass polls both sockets for input.  When one socket is readable, the
 * other is additionally polled for writability on the next pass, and once
 * both conditions hold a single byte is moved with proxy_char().  Any of
 * POLLERR, POLLHUP or POLLNVAL on either socket, or a poll() failure other
 * than EINTR, terminates the process; this function never returns.
 */
static void
proxy(int sock0, int sock1)
{
	int want_write_sock0, want_write_sock1;
	int errmask, inmask, outmask, ret;
	struct pollfd pollfd[2];

	errmask = POLLERR | POLLHUP | POLLNVAL;
	inmask = POLLIN;
	outmask = POLLOUT;

	want_write_sock0 = 0;
	want_write_sock1 = 0;
	while (1) {
		/*
		 * Only ask for POLLOUT on a socket when the peer has data
		 * pending for it; otherwise poll() would return immediately
		 * on every call for an idle, writable socket.
		 */
		pollfd[0].fd = sock0;
		if (want_write_sock0)
			pollfd[0].events = errmask | inmask | outmask;
		else
			pollfd[0].events = errmask | inmask;
		pollfd[0].revents = 0;

		pollfd[1].fd = sock1;
		if (want_write_sock1)
			pollfd[1].events = errmask | inmask | outmask;
		else
			pollfd[1].events = errmask | inmask;
		pollfd[1].revents = 0;

#ifdef DEBUG
		printf("poll: requesting fd[0] 0x%x, fd[1] 0x%x\n",
		    pollfd[0].events, pollfd[1].events);
#endif

		ret = poll(pollfd, 2, INFTIM);
		if (ret < 0) {
			/* An interrupted call is harmless: just retry. */
			if (errno == EINTR)
				continue;
			/*
			 * Any other failure will not fix itself; retrying
			 * would spin forever printing the same error.
			 */
			perror("poll");
			exit(-1);
		}
#ifdef DEBUG
		printf("poll returned fd[0] = 0x%x/0x%x, fd[1] = 0x%x/0x%x\n",
		    pollfd[0].events, pollfd[0].revents,
		    pollfd[1].events, pollfd[1].revents);
#endif

		if (pollfd[0].revents & errmask ||
		    pollfd[1].revents & errmask) {
			fprintf(stderr, "Connection error\n");
			exit(-1);
		}

		/* Direction sock0 -> sock1. */
		if (pollfd[0].revents & inmask) {
			want_write_sock1 = 1;
			if (pollfd[1].revents & outmask) {
#ifdef DEBUG
				printf("char from 0 to 1\n");
#endif
				proxy_char(sock0, sock1);
			}
#ifdef DEBUG
			else
				printf("sock0 readable but 1 not writable\n");
#endif
		} else
			want_write_sock1 = 0;

		/* Direction sock1 -> sock0. */
		if (pollfd[1].revents & inmask) {
			want_write_sock0 = 1;
			if (pollfd[0].revents & outmask) {
#ifdef DEBUG
				printf("char from 1 to 0\n");
#endif
				proxy_char(sock1, sock0);
			}
#ifdef DEBUG
			else
				printf("sock1 readable but 0 not writable\n");
#endif
		} else
			want_write_sock0 = 0;
	}
}

/*
 * Program entry point.
 *
 * Arguments (all required): proxy_hostname proxy_port target_hostname
 * target_port listen_port.  The proxy hostname is resolved to an IPv4
 * address, a listening TCP socket is bound to the loopback address on
 * listen_port, and for every accepted local connection a child process is
 * forked which connects to the proxy, negotiates a tunnel to the target and
 * then relays data between the two sockets.  The parent loops accepting
 * connections indefinitely.
 */
int
main(int argc, char *argv[])
{
	const char *target_hostname, *target_port;
	int proxy_sock, listen_sock, accept_sock;
	struct sockaddr_in proxy_sin, listen_sin;
	struct addrinfo *ai, hint;
	char *dummy;
	int error;
	long port;
	pid_t pid;

	if (argc != 6)
		usage();

	memset(&proxy_sin, 0, sizeof(proxy_sin));
	proxy_sin.sin_len = sizeof(proxy_sin);
	proxy_sin.sin_family = AF_INET;

	/* Resolve the proxy host; only IPv4 stream addresses are wanted. */
	memset(&hint, 0, sizeof(hint));
	hint.ai_flags = AI_CANONNAME;
	hint.ai_family = AF_INET;
	hint.ai_socktype = SOCK_STREAM;
	error = getaddrinfo(argv[1], NULL, &hint, &ai);
	if (error)
		errx(-1, "%s: %s", argv[1], gai_strerror(error));

	proxy_sin.sin_addr = ((struct sockaddr_in *)ai->ai_addr)->sin_addr;
	proxy_sin.sin_len = ((struct sockaddr_in *)ai->ai_addr)->sin_len;

	freeaddrinfo(ai);

	/* strtol() matches the signed type of port used in the range check. */
	port = strtol(argv[2], &dummy, 10);
	if (port < 1 || port > 65535 || *dummy != '\0')
		usage();
	proxy_sin.sin_port = htons(port);

	/* The target is passed to the proxy verbatim; it is not resolved here. */
	target_hostname = argv[3];
	target_port = argv[4];

	printf("Proxy server %s:%ld\n", inet_ntoa(proxy_sin.sin_addr), port);
	printf("Target server %s:%s\n", target_hostname, target_port);
	printf("\n");

	/*
	 * Ignore SIGCHLD so that finished children need not be waited for.
	 * signal() returns a handler pointer, so failure is SIG_ERR rather
	 * than a negative number.
	 */
	if (signal(SIGCHLD, SIG_IGN) == SIG_ERR) {
		perror("signal(SIGCHLD, SIG_IGN)");
		exit(-1);
	}

	memset(&listen_sin, 0, sizeof(listen_sin));
	listen_sin.sin_len = sizeof(listen_sin);
	listen_sin.sin_family = AF_INET;
	listen_sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
	port = strtol(argv[5], &dummy, 10);
	if (port < 1 || port > 65535 || *dummy != '\0')
		usage();
	listen_sin.sin_port = htons(port);

	listen_sock = socket(PF_INET, SOCK_STREAM, 0);
	if (listen_sock == -1) {
		perror("socket");
		exit(-1);
	}

	if (bind(listen_sock, (struct sockaddr *) &listen_sin,
	    sizeof(listen_sin)) < 0) {
		perror("bind");
		exit(-1);
	}

	if (listen(listen_sock, -1) == -1) {
		perror("listen");
		exit(-1);
	}

	printf("Waiting for local connection\n");
	while (1) {
		accept_sock = accept(listen_sock, NULL, NULL);
		if (accept_sock == -1) {
			perror("accept");
			continue;
		}
		printf("Received local connection\n");

		pid = fork();
		if (pid < 0) {
			perror("fork");
			close(accept_sock);
			continue;
		}
		if (pid > 0) {
			/* Parent daemon. */
			close(accept_sock);
			continue;
		}

		/*
		 * Child proxy.  It never accepts connections itself, so drop
		 * the inherited listening socket rather than holding the
		 * port open for the lifetime of the session.
		 */
		close(listen_sock);

		proxy_sock = proxy_connect(&proxy_sin);

		proxy_negotiate(proxy_sock, target_hostname, target_port);

		/* The relay loop is poll()-driven; use non-blocking I/O. */
		if (fcntl(proxy_sock, F_SETFL, O_NONBLOCK) < 0) {
			perror("O_NONBLOCK");
			exit(-1);
		}
		if (fcntl(accept_sock, F_SETFL, O_NONBLOCK) < 0) {
			perror("O_NONBLOCK");
			exit(-1);
		}
		proxy(proxy_sock, accept_sock);
		exit(0);
	}
	/* NOTREACHED */
}
