/*
   Copyright (C) 1996-1997 Robert Muchsel

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

   Please address all correspondence concerning the software to
   muchsel@acm.org.
*/

#include "config.h"
#define __NO_VERSION__
#include <linux/module.h>
#include <netinet/ip.h>
#include <netinet/udp.h>
#include <netinet/ip_icmp.h>
#include <linux/ip.h>
#include <linux/in.h>
#include <linux/skbuff.h>
#include <linux/netdevice.h>
#include <linux/firewall.h>
#include <linux/string.h>
#include <linux/ipsec.h>
#include <net/route.h>
#include <net/icmp.h>
#include "skip_defs.h"
#include "dynamic.h"
#include "memblk.h"
#include "skipcache.h"
#include "ipsp.h"
#include "skip.h"
#include "ipsum.h"
#include "interface.h"

/* #define CHECK_STACK */

/* Forward declarations of the file-local helpers defined further down. */
static struct sk_buff *skb_expand_copy(struct sk_buff *, int);
static inline void skb2memblk(struct sk_buff *, struct memblk *, struct memseg *);
static inline void interface_ship_out(struct sk_buff *);
static int dev_skip(struct device *, char *);

/* Number of bytes subtracted from the MTU of every SKIP interface (and
   added back on detach/exit); set from ipsp_maxheadergrowth() in
   interface_init(). */
static int maxheadergrowth = 0;


/* Copy an skb into a newly allocated buffer of at least "size" bytes.
   Mostly derived from the Linux 2.0.25 skb_copy function.

   skb:  the buffer to copy (headroom, data and tailroom are all copied)
   size: requested size of the new data area in bytes

   Returns the new skb, or NULL if the (atomic) allocation failed. The
   original skb is left untouched; the caller still owns it.

   The complete old data area (head..end) is copied and the old data
   length is skb_put() into the new buffer, so the new buffer must never
   be smaller than the old one. A smaller "size" is therefore rounded up
   to the size of the old data area instead of overrunning the new one. */

static struct sk_buff *skb_expand_copy(struct sk_buff *skb, int size)
{
  struct sk_buff *n;
  unsigned long offset;
  int oldsize;

  IS_SKB(skb);

  /* Never allocate less than what the memcpy below is going to write */
  oldsize = skb->end - skb->head;
  if (size < oldsize)
    size = oldsize;

  /* Allocate the copy buffer */
  n = alloc_skb(size, GFP_ATOMIC);
  if (n == NULL)
    return NULL;

  /* Shift between the two data areas in bytes */
  offset = n->head - skb->head;

  /* Set the data pointer (same headroom as the original) */
  skb_reserve(n, skb->data - skb->head);

  /* Set the tail pointer and length */
  skb_put(n, skb->len);

  /* Copy the bytes of the whole old data area */
  memcpy(n->head, skb->head, oldsize);

  /* The copy is not on any list and belongs to no socket */
  n->link3   = NULL;
  n->list    = NULL;
  n->sk      = NULL;
  n->when    = skb->when;
  n->dev     = skb->dev;

  /* Header pointers point into the data area: relocate them */
  n->h.raw   = skb->h.raw + offset;
  n->mac.raw = skb->mac.raw + offset;
  n->ip_hdr  = (struct iphdr *)(((char *)skb->ip_hdr) + offset);

  n->saddr   = skb->saddr;
  n->daddr   = skb->daddr;
  n->raddr   = skb->raddr;
  n->seq     = skb->seq;
  n->end_seq = skb->end_seq;
  n->ack_seq = skb->ack_seq;
  n->acked   = skb->acked;
  memcpy(n->proto_priv, skb->proto_priv, sizeof(skb->proto_priv));
  n->used    = skb->used;
  n->free    = 1;
  n->arp     = skb->arp;
  n->tries   = 0;
  n->lock    = 0;
  n->users   = 0;
  n->pkt_type= skb->pkt_type;
  n->stamp   = skb->stamp;

  IS_SKB(n);

  return n;
}


/* Describe an skb by a memblk, the internal ENskip representation of
   "memory". (The memblk exists to be compatible with BSD and other
   UNIXes, where IP packets are stored in multiple chained mbuf's or
   mblk's.) Under Linux, we don't need this; the drawback is however
   that prepending a header requires copying the whole data area. Free
   tailroom seems to be unused...

   skb: packet to describe; skb->ip_hdr must point at its IP header
   mb:  memblk to fill in
   ms:  first element of the memseg array that mb will use */

static inline void skb2memblk(struct sk_buff *skb,
                              struct memblk *mb, struct memseg *ms)
{
  /* the single segment descriptor we actually use, and the first
     free one behind it (ESP uses that one internally) */
  mb->ms     = ms;
  mb->freems = ms + 1;

  /* there is no "dynamic" part (== M_EXT of mbuf) */
  mb->dynamiclen = 0;
  mb->dynamic    = NULL;

  /* the IP packet starts "offset" bytes into the data area and is
     "len" bytes long (taken from the IP header) */
  mb->offset = (u_char *) skb->ip_hdr - skb->head;
  mb->len    = ntohs(skb->ip_hdr->tot_len);

  /* the segment covers the data area from its head up to the end of
     the IP packet */
  ms->ptr = skb->head;
  ms->len = mb->len + mb->offset;
}


#ifdef CHECK_STACK
/* Check the kernel stack (it is only 4K).

   i: identifies the call site in the log messages

   The address of a local variable is compared against
   current->kernel_stack_page to see how far the stack pointer is from
   that page's start, and the word at kernel_stack_page is compared
   against STACK_MAGIC. Returns 1 only if at least 1000 bytes are left
   and the magic word is intact, 0 otherwise. */

static int check_stack(int i)
{
  char *marker;
  int passed = 0;
  int stack_level = (long)(&marker) - current->kernel_stack_page;

  if (stack_level >= 1000) {
#if 0
    printk("ENSKIP kstack ok %d, stack=%d\n", i, stack_level);
#endif
    passed = 1;
  }
  else if (stack_level >= 0)
    printk("ENSKIP kstack low %d, stack=%d\n", i, stack_level);
  else
    printk("ENSKIP kstack overflow %d, stack=%d\n", i, stack_level);

  if (*(unsigned long *) current->kernel_stack_page == STACK_MAGIC)
    passed++;
  else
    printk("ENSKIP kstack corruption %d\n", i);

  return (passed == 2);
}
#endif


/* This function intercepts incoming IP packets, both local and non-local.

   this, pf, phdr, arg: not used here
   dev:  receiving device; only devices on the SKIP device list are handled
   pskb: address of the skb holding the packet

   Returns FW_BLOCK (which causes the kernel to discard the packet) if
   we queue the packet, re-inject a processed copy, want to dispose of
   it, or run out of memory. Returns FW_SKIP if the packet was not
   enskipped or wasn't changed; in that case the caller's skb is marked
   with RCV_SEC.

   All SKIP processing is done on a clone of *pskb with interrupts
   disabled, because the memblk/memseg work areas are static. */

int input_packet(struct firewall_ops *this, int pf, struct device *dev,
		 void *phdr, void *arg, struct sk_buff **pskb)
{
  struct iphdr *ipp = (*pskb)->ip_hdr;
  int tot_len = ntohs(ipp->tot_len);
  struct sk_buff *newm, *qskb;
  int result;
  int retval = FW_BLOCK;
  /* May be static only if we ensure locks: */
  static struct memblk oldmb, newmb;
  static struct memseg oldms[STATIC_MEMSEGS], newms[2];
  unsigned long flags;

  /* ignore devices that are not on the SKIP list (e.g. loopback) */
  if (!dev_skip(dev, "input"))
    return FW_SKIP;

  /* leave short packets to the firewall */
  if (tot_len < ipp->ihl * 4)
    return FW_SKIP;

  /* hack packet to decrease maximum MTU: lower the MTU announced in
     ICMP "fragmentation needed" messages by maxheadergrowth */
  if ((ipp->protocol == IPPROTO_ICMP) &&
      (tot_len >= ipp->ihl * 4 + sizeof(struct icmphdr))) {
    struct icmphdr *icmp = (struct icmphdr *) (((u_char *) ipp) + ipp->ihl * 4);

    if (icmp->type == ICMP_DEST_UNREACH &&
        icmp->code == ICMP_FRAG_NEEDED) {
      __u32 checksum = icmp->checksum;
      /* the MTU field is accessed through un.echo.sequence */
      __u16 mtu = ntohs(icmp->un.echo.sequence);
      __u16 diff = maxheadergrowth;

      /* we calculate the checksum in network byte order */
      mtu -= diff;
      icmp->un.echo.sequence = htons(mtu);

      /* incremental update: the field shrank by diff, so the checksum
         grows by diff; fold the carry back in */
      checksum += htons(diff);
      checksum += checksum >> 16;
      icmp->checksum = checksum;

      ip_send_check(ipp);
    }

  }

  /* already seen by us */
  if ((*pskb)->proto_priv[15] & RCV_SEC)
    return FW_SKIP;

  /* never change transparent UDP ports needed for discovery */
  if ((ipp->protocol == IPPROTO_UDP) &&
      (tot_len >= ipp->ihl * 4 + sizeof(struct udphdr))) {
    struct udphdr *udp = (struct udphdr *) (((u_char *) ipp) + ipp->ihl * 4);

    if ((udp->source == htons(SKIP_UDP_RECV_PORT)) ||
        (udp->source == htons(SKIP_UDP_SEND_PORT)) ||
        (udp->dest   == htons(SKIP_UDP_RECV_PORT)) ||
        (udp->dest   == htons(SKIP_UDP_SEND_PORT)))
      return FW_SKIP;
  }

#ifdef CHECK_STACK
  if (!check_stack(1))
    return FW_BLOCK;
#endif

  /* Unfortunately, we need this to fix the problem of the limited kernel
     stack */

  save_flags(flags);
  cli();

  /* Clone the input skb. This operation locks the data of the input skb,
     so we can queue it. Be careful not to free the skb while it is queued! */
  qskb = newm = skb_clone(*pskb, GFP_ATOMIC);
  if (newm == NULL) {
    /* out of memory: we cannot process the packet, so drop it */
    restore_flags(flags);
    return FW_BLOCK;
  }

  MEMZERO(&newmb, sizeof(newmb));
  newmb.ms = newms;
  MEMZERO(&oldmb, sizeof(oldmb));
  oldmb.ms = oldms;
  skb2memblk(newm, &oldmb, oldms);


  /* skip_process may replace newm by a different skb; qskb keeps the
     clone so both can be released correctly below */
  result = skip_process(SKIP_INPUT, NULL, NULL, (void *) &newm, &oldmb, &newmb);


  if (result == SKIP_PROCESSED) {
    /* nothing happened */

    kfree_skb(newm, FREE_WRITE);
    if (newm != qskb)
      kfree_skb(qskb, FREE_WRITE);

    /* mark the caller's skb (the clone was just freed and must not be
       touched any more); the original continues up the stack */
    (*pskb)->proto_priv[15] = RCV_SEC;
    retval = FW_SKIP;
  }
  else if (result > SKIP_PROCESSED) {
    /* detunneled/decrypted/authenticated */

    /* fix up the skb pointers */
    newm->data      = BLKSTART(&newmb);
    newm->ip_hdr    = (struct iphdr *) newm->data;
    newm->len       = ntohs(newm->ip_hdr->tot_len);
    newm->tail      = newm->data + newm->len;
    newm->protocol  = htons(ETH_P_IP);
    newm->ip_summed = 0;
    newm->h.iph     = newm->ip_hdr;

    /* and mark the skb as "authenticated"/"decrypted" */
    newm->proto_priv[15] = ((result & SKIP_P_AUTH)    ? RCV_AUTH   : 0)
                         | ((result & SKIP_P_DECRYPT) ? RCV_CRYPT  : 0)
                         | ((result & SKIP_P_TUNNEL)  ? RCV_TUNNEL : 0)
                         | RCV_SEC;

    /* and feed the packet back into the input queue (must not switch
       skbs here, because we need the defragmentor) */
    if (newm != qskb)
      kfree_skb(qskb, FREE_WRITE);

    netif_rx(newm); /* frees skb for us */
  }
  else if (result == SKIP_QUEUED) {
    /* the skb was queued; qskb must stay alive */

    if (newm != qskb)
      kfree_skb(newm, FREE_WRITE);
  } else {
    /* bad packet, policy violation, unsupported protocol, etc. */

    kfree_skb(newm, FREE_WRITE);
    if (newm != qskb)
      kfree_skb(qskb, FREE_WRITE);
  }

  restore_flags(flags);

  return retval;
}


/* This is the output interceptor. Basically this function works like
   the input interceptor, however the feedback function is a kludge
   (we use ip_forward in lack of something better). There's also a new
   return code, FW_QUEUE, which works like FW_BLOCK but does not cause a
   "permission denied" message on local packets (kernel function
   ip_build_xmit).
   Note: The dummies are all 0, arg is for debugging only.

   dev:  outgoing device; only devices on the SKIP device list are handled
   pskb: address of the skb holding the packet

   Returns FW_ACCEPT for packets that already carry the SND_SEC marker
   or are SKIP packets, FW_SKIP for packets that need no processing (or
   that passed unchanged and satisfy the socket policy), FW_QUEUE if a
   processed copy was shipped out or the packet was queued, and FW_BLOCK
   for policy violations, bad packets and memory shortage. */

int output_packet(struct firewall_ops *dummy1, int pf, struct device *dev,
		  void *dummy3, void *arg, struct sk_buff **pskb)
{
  struct iphdr *ipp = (*pskb)->ip_hdr;
  int tot_len = ntohs(ipp->tot_len);
  struct sk_buff *newm, *qskb;
  int result;
  int retval = FW_BLOCK;
  /* May be static only if we ensure locks */
  static struct memblk oldmb, newmb;
  static struct memseg oldms[STATIC_MEMSEGS], newms[2];
  unsigned long flags;

  /* ignore devices that are not on the SKIP list (e.g. loopback) */
  if (!dev_skip(dev, "output"))
    return FW_SKIP;

  /* Recursion happens if the datagram has been fragmented:
     ip_queue_xmit -> skip -> ip_fragment -> ip_queue_xmit -> SKIP */

  if ((*pskb)->proto_priv[15] & SND_SEC)
    return FW_ACCEPT;

  /* this happens, too */
  if (ipp->protocol == IPPROTO_SKIP)
    return FW_ACCEPT;

  /* never change transparent UDP packets */
  if ((ipp->protocol == IPPROTO_UDP) &&
      (tot_len >= ipp->ihl * 4 + sizeof(struct udphdr))) {
    struct udphdr *udp = (struct udphdr *) (((u_char *) ipp) + ipp->ihl * 4);

    if ((udp->source == htons(SKIP_UDP_RECV_PORT)) ||
        (udp->source == htons(SKIP_UDP_SEND_PORT)) ||
        (udp->dest   == htons(SKIP_UDP_RECV_PORT)) ||
        (udp->dest   == htons(SKIP_UDP_SEND_PORT)))
      return FW_SKIP;
  }

  /* the sending socket explicitly asked for no security at all */
  if ((*pskb)->sk != NULL) {
    if ((*pskb)->sk->authentication == IPSEC_LEVEL_NONE &&
        (*pskb)->sk->encryption     == IPSEC_LEVEL_NONE)
      return FW_SKIP;
  }

#ifdef CHECK_STACK
  if (!check_stack(2))
    return FW_BLOCK;
#endif

  /* interrupts off, as in the input interceptor: the work areas are
     static */
  save_flags(flags);
  cli();

  /* Clone the input skb; locks its data. Never free it while it is queued! */
  qskb = newm = skb_clone(*pskb, GFP_ATOMIC);
  if (newm == NULL) {
    /* out of memory: we cannot process the packet, so refuse it */
    restore_flags(flags);
    return FW_BLOCK;
  }

  MEMZERO(&newmb, sizeof(newmb));
  newmb.ms = newms;
  MEMZERO(&oldmb, sizeof(oldmb));
  oldmb.ms = oldms;
  skb2memblk(newm, &oldmb, oldms);


  /* skip_process may replace newm by a different skb; qskb keeps the
     clone so both can be released correctly below */
  result = skip_process(SKIP_OUTPUT, NULL, NULL, (void *) &newm, &oldmb, &newmb);


  if (result == SKIP_PROCESSED) {
    /* nothing happened - if not enskipped, we just say "OK" */

    kfree_skb(newm, FREE_WRITE);
    if (newm != qskb)
      kfree_skb(qskb, FREE_WRITE);

    /* check user level policy: an unprotected packet is only let
       through if the socket does not ask for authentication or
       encryption; otherwise retval stays FW_BLOCK */
    if (!((*pskb)->sk &&
         ((*pskb)->sk->authentication >= IPSEC_LEVEL_USE ||
          (*pskb)->sk->encryption     >= IPSEC_LEVEL_USE))) {

      /* Set marker in skb + accept packet */
      (*pskb)->proto_priv[15] |= SND_SEC;
      retval = FW_SKIP;
    }
  }
  else if (result > SKIP_PROCESSED) {
    /* tunneled/encrypted/authenticated */

    /* check user level policy: drop the packet if the socket asked for
       a service that was not applied */
    if ((*pskb)->sk &&
        ((!(result & SKIP_P_AUTH) &&
          (*pskb)->sk->authentication >= IPSEC_LEVEL_USE) ||
         (!(result & SKIP_P_ENCRYPT) &&
          (*pskb)->sk->encryption >= IPSEC_LEVEL_USE))) {

      kfree_skb(newm, FREE_WRITE);
      if (newm != qskb)
        kfree_skb(qskb, FREE_WRITE);
    }
    else {
      /* fix up skb */
      newm->data      = BLKSTART(&newmb);
      newm->ip_hdr    = (struct iphdr *) newm->data;
      newm->len       = ntohs(newm->ip_hdr->tot_len);
      newm->tail      = newm->data + newm->len;
      newm->protocol  = htons(ETH_P_IP);
      newm->ip_summed = 0;
      newm->h.iph     = newm->ip_hdr;

      newm->dev       = dev;

      if (newm != qskb)
        kfree_skb(qskb, FREE_WRITE);

      /* Set marker in skb */
      newm->proto_priv[15] |= SND_SEC;

      interface_ship_out(newm);

      retval = FW_QUEUE;
    }
  }
  else if (result == SKIP_QUEUED) {
    /* queued, will be fed back to us; qskb must stay alive */

    if (newm != qskb)
      kfree_skb(newm, FREE_WRITE);

    qskb->dev = dev;

    retval = FW_QUEUE;
  }
  else {
    /* bad packet/policy/etc. */

    kfree_skb(newm, FREE_WRITE);
    if (newm != qskb)
      kfree_skb(qskb, FREE_WRITE);
  }

  restore_flags(flags);

  return retval;
}


/* The forward function is used to get the true device MTU only.
   This is a hack... Real packets are forwarded via a pair of
   input_packet/output_packet. */

int forward_packet(struct firewall_ops *dummy1, int dummy2, struct device *dev,
		  void *dummy3, void *arg, struct sk_buff **dummy4)
{
  /* hack to return true device MTU; arg is pointer to MTU */

  if (dev_skip(dev, "forward") && arg)
    *((unsigned short *) arg) += maxheadergrowth;

  return FW_ACCEPT;
}


/* The null operation: ignores all its arguments and always answers
   FW_SKIP. */

int packet_nop(struct firewall_ops *this, int pf, struct device *dev,
               void *phdr, void *arg, struct sk_buff **pskb)
{
  const int verdict = FW_SKIP;

  return verdict;
}

/* Hook for keeping away standard firewalls -- on forward, a standard
   forward firewall call would be made. This call would also be made
   with local packets AND it would get the encrypted packet.

   Returns FW_ACCEPT for packets carrying the SND_SEC marker in
   proto_priv[15], FW_SKIP for all others. */

int packet_accept(struct firewall_ops *this, int pf, struct device *dev,
                  void *phdr, void *arg, struct sk_buff **pskb)
{
  return ((*pskb)->proto_priv[15] & SND_SEC) ? FW_ACCEPT : FW_SKIP;
}


/* Hand an skb back to the kernel output processing by way of
   ip_forward. Caution! This causes the output interceptor to be called
   recursively. Since it is done via the timer, there should be no way
   to crash the stack!?

   skb: packet to send; skb->dev and skb->h.iph must be set up. The skb
        is freed here if ip_forward returns non-zero. */

static inline void interface_ship_out(struct sk_buff *skb)
{
  int failed;

  IS_SKB(skb);

  failed = ip_forward(skb, skb->dev, IPFWD_NOTTLDEC, skb->h.iph->daddr);
  if (failed)
    kfree_skb(skb, FREE_WRITE);
}


/* Feed a previously queued outgoing packet back: run it through
   output_packet once more and, if the verdict is at least FW_ACCEPT,
   ship it out; otherwise the skb is freed. Always returns 0. */
int interface_feed_out(struct sk_buff *skb)
{
  int verdict;

  IS_SKB(skb);

  /* fix the packet for ip_forward (because ip_build_xmit might not have) */
  skb->h.iph = skb->ip_hdr;

  /* check it */
  verdict = output_packet(NULL, PF_IPSEC, skb->dev, NULL, NULL, &skb);

  if (verdict >= FW_ACCEPT)
    interface_ship_out(skb);       /* ...and ship it */
  else
    kfree_skb(skb, FREE_WRITE);

  return 0;
}


/* Feed a previously queued incoming packet back into the kernel input
   queue as an IP packet with no checksum state. Always returns 0. */
int interface_feed_in(struct sk_buff *skb)
{
  skb->ip_summed = 0;
  skb->protocol  = htons(ETH_P_IP);

  /* netif_rx takes over the skb and frees it for us */
  netif_rx(skb);

  return 0;
}


/* Get an skb with room for a packet of "size" bytes. The original
   "headroom" and "tailroom" needed for tcpdump and MAC headers are
   kept. The original packet is copied as well, which is slow but the
   only way to make this *really* stable.

   m:    memblk that is set up to describe the new buffer
   pskb: in: the current skb; out: the new skb (on success only)
   size: length of the packet the new buffer has to hold

   Returns 0 on success, -ENOMEM if no buffer could be allocated. The
   old skb is not freed here. */

int interface_getbuf(struct memblk *m, void **pskb, int size)
{
  struct sk_buff *old = (struct sk_buff *) *pskb;
  struct sk_buff *fresh;

  fresh = skb_expand_copy(old, skb_headroom(old) + size + skb_tailroom(old));
  if (fresh == NULL)
    return -ENOMEM;

  *pskb = fresh;

  /* describe the copy, then override the length taken from the (old)
     IP header with the requested size */
  skb2memblk(fresh, m, m->ms);
  m->len     = size;
  m->ms->len = m->len + m->offset;

  return 0;
}

/* Singly linked list of the names of all interfaces SKIP is attached
   to; maintained by dev_addlist/dev_rmlist and searched by dev_skip. */
struct devlist {
  char *name;             /* kmalloc'ed copy of the device name */
  struct devlist *next;   /* next entry, NULL at the end of the list */
};
/* Head of the list; NULL while no interface is attached. */
static struct devlist *skip_devs = NULL;

/* Check whether an interface is on the SKIP device list.

   dev: device to look up (compared by name); may be NULL
   s:   name of the caller, used in the warning message only

   Returns 1 if the device is on the list, 0 if it is not (an empty list
   included) or if dev is NULL. A NULL device is logged, since its name
   could not be read without crashing. */
static int dev_skip(struct device *dev, char *s)
{
  struct devlist *tmp;

  /* XXX */
  if (dev == NULL) {
    printk("Warning: dev_skip %s: dev==NULL\n", s);
    return 0;
  }
  /* XXX */

  for (tmp = skip_devs; tmp != NULL; tmp = tmp->next) {
    if (strcmp(tmp->name, dev->name) == 0)
      return 1;
  }

  return 0;
}

/* Add an interface to the end of the SKIP device list.

   dev: device whose name is recorded (a private copy is made)

   Memory is allocated with GFP_ATOMIC and may therefore fail; in that
   case a message is logged and the list is left unchanged. The new
   entry is linked in only after it has been completely initialized. */
static void dev_addlist(struct device *dev)
{
  struct devlist *node;
  struct devlist **link;

  node = kmalloc(sizeof(*node), GFP_ATOMIC);
  if (node == NULL) {
    printk("enskip: %s: out of memory, interface not added\n", dev->name);
    return;
  }

  node->name = kmalloc(strlen(dev->name) + 1, GFP_ATOMIC);
  if (node->name == NULL) {
    kfree(node);
    printk("enskip: %s: out of memory, interface not added\n", dev->name);
    return;
  }
  strcpy(node->name, dev->name);
  node->next = NULL;

  /* find the terminating NULL link (skip_devs itself if the list is
     empty) and hang the new entry there */
  for (link = &skip_devs; *link != NULL; link = &(*link)->next)
    ;
  *link = node;
}

/* Remove an interface from the SKIP device list.

   dev: device to remove (matched by name)

   The first entry with the same name is unlinked and its memory
   released; nothing happens if the list is empty or holds no such
   entry. */
static void dev_rmlist(struct device *dev)
{
  struct devlist **link = &skip_devs;
  struct devlist *victim;

  /* walk the links (starting with the list head itself) until one
     points at the wanted entry or at the end of the list */
  while (*link != NULL && strcmp((*link)->name, dev->name) != 0)
    link = &(*link)->next;

  victim = *link;
  if (victim == NULL)
    return;

  *link = victim->next;
  kfree(victim->name);
  kfree(victim);
}

/* Attach SKIP to every AF_INET interface that has the given address
   and is not attached yet: its MTU is reduced by maxheadergrowth and
   it is put on the SKIP device list.

   dummy:  not used
   ipaddr: points to the interface address, read as a __u32

   Returns 0 if at least one interface was attached, -1 otherwise.
   Interfaces whose MTU is too small are reported and left alone. */
int interface_attach(void *dummy, u_char *ipaddr)
{
  struct device *ifp;
  int status = -1;

  for (ifp = dev_base; ifp != NULL; ifp = ifp->next) {
    if (ifp->family != AF_INET)
      continue;
    if (ifp->pa_addr != *((__u32 *) ipaddr))
      continue;
    if (dev_skip(ifp, "attach"))
      continue;                 /* already attached */

    if (ifp->mtu < 68 + maxheadergrowth) {
      printk("enskip: %s: interface mtu of %d is too small\n",
              ifp->name, ifp->mtu);
      continue;
    }

    ifp->mtu -= maxheadergrowth;
    dev_addlist(ifp);
    status = 0;
  }

  return status;
}

/* Detach SKIP from every AF_INET interface that has the given address
   and is on the SKIP device list: it is taken off the list and its MTU
   is increased by maxheadergrowth again.

   dummy:  not used
   ipaddr: points to the interface address, read as a __u32

   Returns 0 if at least one interface was detached, -1 otherwise. */
int interface_detach(void *dummy, u_char *ipaddr)
{
  struct device *ifp;
  int status = -1;

  for (ifp = dev_base; ifp != NULL; ifp = ifp->next) {
    if (ifp->family != AF_INET)
      continue;
    if (ifp->pa_addr != *((__u32 *) ipaddr))
      continue;
    if (!dev_skip(ifp, "detach"))
      continue;                 /* not attached */

    dev_rmlist(ifp);
    ifp->mtu += maxheadergrowth;
    status = 0;
  }

  return status;
}


/* Initialize the interface layer: fetch the maximum header growth from
   ipsp_maxheadergrowth() and attach to all AF_INET interfaces except
   loopback ones, i.e. decrease their MTU by that amount and put them
   on the SKIP device list. Interfaces whose MTU is too small are
   reported and left alone. Always returns 0. */

int interface_init(void)
{
  struct device *ifp;

  maxheadergrowth = ipsp_maxheadergrowth();

  for (ifp = dev_base; ifp != NULL; ifp = ifp->next) {
    if (ifp->family != AF_INET)
      continue;

    if (ifp->mtu < 68 + maxheadergrowth) {
      printk("enskip: %s: interface mtu of %d is too small\n",
              ifp->name, ifp->mtu);
      continue;
    }

    if ((ifp->flags & IFF_LOOPBACK) != 0)
      continue;

    ifp->mtu -= maxheadergrowth;
    dev_addlist(ifp);
  }

  return 0;
}


/* Shut down the interface layer: every AF_INET interface found on the
   SKIP device list is taken off the list and gets its MTU increased by
   maxheadergrowth again. Always returns 0. */

int interface_exit(void)
{
  struct device *ifp;

  for (ifp = dev_base; ifp != NULL; ifp = ifp->next) {
    if (ifp->family != AF_INET)
      continue;
    if (!dev_skip(ifp, "exit"))
      continue;

    dev_rmlist(ifp);
    ifp->mtu += maxheadergrowth;
  }

  return 0;
}
