/*-
 * Copyright (c) 2004 Robert N. M. Watson
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * [id for your version control system, if any]
 */
#include <sys/types.h>
#include <sys/socket.h>

#include <net/ethernet.h>
#include <net/if_arp.h>

#include <netinet/in.h>

#include <arpa/inet.h>

#include <assert.h>
#include <errno.h>
#include <pthread.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>

#include "ether.h"
#include "arp.h"
#include "ip.h"

/*
 * One cached IPv4-to-Ethernet mapping.  A slot whose ate_inaddr is all
 * zeroes is treated as unused.
 */
struct arptabentry {
	struct in_addr	 ate_inaddr;			/* protocol address */
	u_char		 ate_addr[ETHER_ADDR_LEN];	/* hardware address bytes */
};

/* Fixed-size ARP cache; arp_process() and arp_lookup() hold arp_mutex. */
#define	ARPTAB_LEN	128
static struct arptabentry	arptab[ARPTAB_LEN];
static pthread_mutex_t		arp_mutex;

static void	arp_process(u_char *new_ether_addr,
		    struct in_addr *new_ipaddr);
static void	arp_query(struct in_addr *ipaddr);
static void	arp_reply(u_char *target_ether_addr,
		    struct in_addr *target_ipaddr);

/*
 * Initialize arp_mutex.  This must run before arp_process() or
 * arp_lookup(), both of which lock it.  Aborts if initialization fails.
 *
 * The call is kept outside assert() so the mutex is still initialized
 * when the file is built with NDEBUG.
 */
void
arp_init(void)
{

	if (pthread_mutex_init(&arp_mutex, NULL) != 0)
		abort();
}

/*
 * Handle an incoming ARP packet.
 *
 * eh:         Ethernet header of the received frame (not used here).
 * payload:    start of the ARP header.
 * payloadlen: number of bytes available at payload.
 *
 * Only Ethernet/IPv4 ARP packets whose target protocol address is
 * my_ipaddr are considered.  A request is answered with arp_reply().
 * A reply is recorded in the ARP table via arp_process().  Everything
 * else is silently dropped.
 */
void
arp_input(struct ether_header *eh, u_char *payload, u_int payloadlen)
{
	struct arphdr *ah;
	struct in_addr spa;

	if (payloadlen < arphdr_len2(ETHER_ADDR_LEN, sizeof(struct in_addr)))
		return;
	ah = (struct arphdr *)payload;

	if (ah->ar_hrd != htons(ARPHRD_ETHER))
		return;
	if (ah->ar_pro != htons(ETHERTYPE_IP))
		return;
	if (ah->ar_hln != ETHER_ADDR_LEN)
		return;
	if (ah->ar_pln != sizeof(struct in_addr))
		return;

	/* Both handled opcodes only care about packets aimed at us. */
	if (bcmp(ar_tpa(ah), &my_ipaddr, sizeof(my_ipaddr)) != 0)
		return;

	/*
	 * This function does not know the alignment of payload, so the
	 * sender protocol address is copied into a properly aligned local
	 * rather than being accessed through a cast pointer.
	 */
	bcopy(ar_spa(ah), &spa, sizeof(spa));

	switch (ntohs(ah->ar_op)) {
	case ARPOP_REQUEST:
		arp_reply(ar_sha(ah), &spa);
		break;

	case ARPOP_REPLY:
		arp_process(ar_sha(ah), &spa);
		break;

	default:
		break;
	}
}

/*
 * Record in the ARP table that new_ipaddr is at new_ether_addr
 * (ETHER_ADDR_LEN bytes).
 *
 * Slot selection, in order of preference:
 *  1. an existing entry for new_ipaddr, which is updated in place;
 *  2. the first unused (all-zeroes) slot;
 *  3. slot 0, which is overwritten when the table is full.
 *
 * The all-zeroes address is not stored, because it is what marks an
 * unused slot.  Aborts if arp_mutex cannot be locked or unlocked.
 */
static void
arp_process(u_char *new_ether_addr, struct in_addr *new_ipaddr)
{
	int i, slot;

	printf("Told about HW address for %s\n", inet_ntoa(*new_ipaddr));
	if (new_ipaddr->s_addr == 0)
		return;

	if (pthread_mutex_lock(&arp_mutex) != 0)
		abort();

	slot = -1;
	for (i = 0; i < ARPTAB_LEN; i++) {
		if (arptab[i].ate_inaddr.s_addr == new_ipaddr->s_addr) {
			/* Refresh an existing mapping; avoids duplicates. */
			slot = i;
			break;
		}
		if (slot == -1 && arptab[i].ate_inaddr.s_addr == 0)
			slot = i;	/* remember the first free slot */
	}
	if (slot == -1)
		slot = 0;		/* table full: evict slot 0 */

	arptab[slot].ate_inaddr = *new_ipaddr;
	bcopy(new_ether_addr, arptab[slot].ate_addr, ETHER_ADDR_LEN);

	if (pthread_mutex_unlock(&arp_mutex) != 0)
		abort();
}

/*
 * Look up the Ethernet address cached for ipaddr.
 *
 * On a hit, copies ETHER_ADDR_LEN bytes into etheraddr (unless it is
 * NULL) and returns 0.  On a miss, sends an ARP request for ipaddr via
 * arp_query() and returns -1 with errno set to ENOENT.
 *
 * The all-zeroes address is never treated as a hit, because unused
 * table slots are all zeroes.  Aborts if arp_mutex cannot be locked or
 * unlocked.
 */
int
arp_lookup(struct in_addr *ipaddr, u_char *etheraddr)
{
	int i;

	if (pthread_mutex_lock(&arp_mutex) != 0)
		abort();
	if (ipaddr->s_addr != 0) {
		for (i = 0; i < ARPTAB_LEN; i++) {
			if (arptab[i].ate_inaddr.s_addr != ipaddr->s_addr)
				continue;
			if (etheraddr != NULL)
				bcopy(arptab[i].ate_addr, etheraddr,
				    ETHER_ADDR_LEN);
			if (pthread_mutex_unlock(&arp_mutex) != 0)
				abort();
			return (0);
		}
	}
	if (pthread_mutex_unlock(&arp_mutex) != 0)
		abort();

	/* Query outside the lock; set errno last so nothing clobbers it. */
	arp_query(ipaddr);
	errno = ENOENT;
	return (-1);
}

/*
 * Broadcast an ARP request asking which station owns ipaddr.
 *
 * The sender fields carry my_ether_addr / my_ipaddr.  The target
 * hardware address is left all-zeroes, since that is what is being
 * asked for.  If the frame cannot be allocated, nothing is sent.
 */
static void
arp_query(struct in_addr *ipaddr)
{
	struct ether_header *ehdr;
	struct arphdr *arp;
	u_char *frame;
	u_int framelen;

	framelen = sizeof(struct ether_header) +
	    arphdr_len2(ETHER_ADDR_LEN, sizeof(struct in_addr));
	frame = calloc(1, framelen);	/* zero-filled, like malloc+bzero */
	if (frame == NULL)
		return;

	/* Ethernet header: broadcast from our own station. */
	ehdr = (struct ether_header *)frame;
	memcpy(ehdr->ether_dhost, broadcast_ether_addr, ETHER_ADDR_LEN);
	memcpy(ehdr->ether_shost, my_ether_addr, ETHER_ADDR_LEN);
	ehdr->ether_type = htons(ETHERTYPE_ARP);

	/* ARP header and Ethernet/IPv4 address fields. */
	arp = (struct arphdr *)(frame + sizeof(struct ether_header));
	arp->ar_hrd = htons(ARPHRD_ETHER);
	arp->ar_pro = htons(ETHERTYPE_IP);
	arp->ar_hln = ETHER_ADDR_LEN;
	arp->ar_pln = sizeof(struct in_addr);
	arp->ar_op = htons(ARPOP_REQUEST);
	memcpy(ar_sha(arp), my_ether_addr, ETHER_ADDR_LEN);
	memcpy(ar_spa(arp), &my_ipaddr, sizeof(my_ipaddr));
	memcpy(ar_tpa(arp), ipaddr, sizeof(struct in_addr));

	ether_output(frame, framelen);
	free(frame);
}

/*
 * Send an ARP reply saying that my_ipaddr is at my_ether_addr.
 *
 * target_ether_addr: requester's hardware address.  It is used both
 *                    as the unicast Ethernet destination and as the ARP
 *                    target hardware address.
 * target_ipaddr:     requester's protocol address, echoed back as the
 *                    ARP target protocol address.
 *
 * If the frame cannot be allocated, no reply is sent.
 */
static void
arp_reply(u_char *target_ether_addr, struct in_addr *target_ipaddr)
{
	struct ether_header *eh;
	struct arphdr *ah;
	u_int packetlen;
	u_char *packet;

	packetlen = sizeof(*eh) + arphdr_len2(ETHER_ADDR_LEN,
	    sizeof(*target_ipaddr));
	packet = calloc(1, packetlen);
	if (packet == NULL)
		return;
	eh = (struct ether_header *)packet;
	ah = (struct arphdr *)(packet + sizeof(*eh));

	eh->ether_type = htons(ETHERTYPE_ARP);
	bcopy(my_ether_addr, eh->ether_shost, ETHER_ADDR_LEN);
	bcopy(target_ether_addr, eh->ether_dhost, ETHER_ADDR_LEN);

	ah->ar_hrd = htons(ARPHRD_ETHER);
	ah->ar_pro = htons(ETHERTYPE_IP);
	ah->ar_hln = ETHER_ADDR_LEN;
	ah->ar_pln = sizeof(*target_ipaddr);
	ah->ar_op = htons(ARPOP_REPLY);
	bcopy(my_ether_addr, ar_sha(ah), ETHER_ADDR_LEN);
	bcopy(&my_ipaddr, ar_spa(ah), sizeof(my_ipaddr));
	/* Copy all ETHER_ADDR_LEN bytes of the requester's address. */
	bcopy(target_ether_addr, ar_tha(ah), ETHER_ADDR_LEN);
	bcopy(target_ipaddr, ar_tpa(ah), sizeof(*target_ipaddr));

	ether_output(packet, packetlen);
	free(packet);
}

/*
 * Look up our own address (my_ipaddr) in the ARP table and discard the
 * result.  On a miss, arp_lookup() sends an ARP request whose sender and
 * target protocol addresses are both my_ipaddr.
 * NOTE(review): presumably called after my_ipaddr changes; confirm with
 * callers.
 */
void
arp_ipaddrset(void)
{
	struct in_addr *ouraddr = &my_ipaddr;

	(void)arp_lookup(ouraddr, NULL);
}
