/* (c) by G. Caronni in 1995                  	19.11.95 */
/* rewritten to conform to spec: Robert Muchsel 1996     */

/* #define DEBUG_CDP */

#ifndef USE_SUNCERT
#error "Sorry, won't work without SUNCERT define"
#endif

#include <stdio.h>
#include <stddef.h>
#include <string.h>
#include <assert.h>
#include <unistd.h>
#include <malloc.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include "cert_defs.h"
#include "cert_udp.h"
#include "cert_db.h"
#include "cert_int.h"
#ifdef UBS_CHIPCARD
#include "cert_ubs.h"
#endif
#include "md5.h"
#include "../include/id.h"

#ifdef USE_SUNCERT
#include "suncert/enlink.h"
#ifdef __GNUC__
#ident "$Id: cert_udp.c,v 1.10 1996/06/07 11:32:57 skip Exp $"
#else
static char rcsid[] = "$Id: cert_udp.c,v 1.10 1996/06/07 11:32:57 skip Exp $";
#endif
#endif


/* UDP socket bound to SEND_PORT by udp_init(); udp_request() sends its CDP
 * requests from it.  NOTE(review): presumably the peers' replies arriving
 * on this socket are handled by udp_process_responder() -- confirm. */
int send_socket;
/* UDP socket bound to RECV_PORT by udp_init(); RECV_PORT is also the port
 * udp_request() addresses requests to.  NOTE(review): presumably serviced
 * with udp_process_requestor() -- confirm against the caller. */
int recv_socket;


/* Create the two CDP UDP sockets: send_socket bound to SEND_PORT and
 * recv_socket bound to RECV_PORT, both on INADDR_ANY.
 *
 * Returns 0 on success.  On any failure an error is reported with perror(),
 * whatever was opened is closed again, both globals are set to -1 and -1 is
 * returned. */
int udp_init(void)
{
  struct sockaddr_in addr;

  send_socket = recv_socket = -1;

  if ((send_socket = socket(PF_INET, SOCK_DGRAM, 0)) < 0) {
    perror("CERT udp_init socket send_socket");
    goto fail;
  }
  memset(&addr, 0, sizeof(addr));
  addr.sin_family = AF_INET;
  addr.sin_port   = htons(SEND_PORT);
  addr.sin_addr.s_addr = htonl(INADDR_ANY);
  if (bind(send_socket, (struct sockaddr *) &addr, sizeof(addr)) < 0) {
    /* report before any close(): close() may overwrite errno */
    perror("CERT udp_init bind send_socket");
    goto fail;
  }

  if ((recv_socket = socket(PF_INET, SOCK_DGRAM, 0)) < 0) {
    perror("CERT udp_init socket recv_socket");
    goto fail;
  }
  memset(&addr, 0, sizeof(addr));
  addr.sin_family = AF_INET;
  addr.sin_port   = htons(RECV_PORT);
  addr.sin_addr.s_addr = htonl(INADDR_ANY);
  if (bind(recv_socket, (struct sockaddr *) &addr, sizeof(addr)) < 0) {
    perror("CERT udp_init bind recv_socket");
    goto fail;
  }
  return 0;

fail:
  if (recv_socket >= 0)
    close(recv_socket);
  if (send_socket >= 0)
    close(send_socket);
  /* do not leave descriptor numbers behind that may be reused elsewhere */
  send_socket = recv_socket = -1;
  return -1;
}

/* Close the CDP sockets opened by udp_init().  Both descriptors are reset
 * to -1 afterwards, so a repeated call cannot close a descriptor number
 * that has meanwhile been reused for something else. */
void udp_exit(void)
{
  if (send_socket >= 0)
    close(send_socket);
  if (recv_socket >= 0)
    close(recv_socket);
  send_socket = recv_socket = -1;
}

static dh_data *add_cert(u_char *buf, struct in_addr from)
{
  dh_data *res = NULL;
  int len, nsid, ctype, clen;
  u_char *mkid, *cert, *t, *hex;
  int l, i;
  struct MD5Context ctxt;

  void *cdata;

/* format:
 *   Name Type
 *     CDP_NAME_SKIP = 1
 *   Name Length (8bit)    <-- buf
 *   Name
 *     NSID (8bit), MKID
 *   Cert Type 
 *     DH Public = 4   OR
 *     X.509 = 1
 *   Cert Length (16 bit)
 *   Cert
 */

  len = *buf++;
  nsid = *buf++;
  if (len - 1 != nsid_len[nsid]) {
    printf("CERT add_cert mkid size not correct.\n");
    return NULL;
  }
  mkid = buf;
  buf += len - 1;
  cert = buf;

  ctype = *buf++;
  clen  = *buf++ << 8;
  clen += *buf++;

  if (ctype == CDP_CERT_X509) {
    cdata = enlink_readudp(cert);
    if (!cdata)
      return NULL;
    if (enlink_validate(/* mkid, len, cdata */))
      return NULL;
    res = db_cert_init();
    res->begins  = enlink_getbefore(cdata);
    res->expires = enlink_getafter(cdata);
    if (res->p)
      int_drop(res->p);
    if (res->g)
      int_drop(res->g);
    if (res->ki)
      int_drop(res->ki);
    res->g   = enlink_getg(cdata);
    res->p   = enlink_getp(cdata);
    res->ki  = enlink_geti(cdata);
    res->rns = nsid;
    res->cert_data = cdata;
    res->rmk = malloc(len);
    memcpy(res->rmk, mkid, len);
    db_unique_disk_save(&res, inet_ntoa(from), 0);
    return res;
  }
  else if (ctype == CDP_CERT_DH_PUBLIC && nsid == SKIP_NSID_DH) {
    res = db_cert_init();
    res->rns = nsid;
    t = buf;
    res->begins   = *t++ << 24;
    res->begins  += *t++ << 16;
    res->begins  += *t++ << 8;
    res->begins  += *t++;
#ifdef DEBUG_CDP
    printf("  dh CERT: got begin: %08lx\n", res->begins);
#endif
    res->expires  = *t++ << 24;
    res->expires += *t++ << 16;
    res->expires += *t++ << 8;
    res->expires += *t++;
#ifdef DEBUG_CDP
    printf("  dh CERT: got expire: %08lx\n", res->expires);
#endif
    l  = *t++ << 8;
    l += *t++;
#ifdef DEBUG_CDP
    printf("  dh CERT: got p length: %i\n", l);
#endif
    hex = malloc(2 * l + 1);
    for (i = 0; i < l; i++, t++)
      sprintf(hex + 2 * i, "%02X", *t);
    if (!res->p)
      int_init(&(res->p));
    int_set_str(res->p, hex, 16);
    l  = *t++ << 8;
    l += *t++;
#ifdef DEBUG_CDP
    printf("  dh CERT: got g length: %i\n", l);
#endif
    hex = realloc(hex, 2 * l + 1);
    for (i = 0; i < l; i++, t++)
      sprintf(hex + 2 * i, "%02X", *t);
    if (!res->g)
      int_init(&(res->g));
    int_set_str(res->g, hex, 16);
    l  = *t++ << 8;
    l += *t++;
#ifdef DEBUG_CDP
    printf("  dh CERT: got ki length: %i\n", l);
#endif
    hex = realloc(hex, 2 * l + 1);
    for (i = 0; i < l; i++, t++)
      sprintf(hex + 2 * i, "%02X", *t);
    if (!res->ki)
      int_init(&(res->ki));
    int_set_str(res->ki, hex, 16);
    free(hex);
    if (t - buf != clen) {
#ifdef DEBUG_CDP
      printf("CERT add_cert incoherent length -- expected: %i got: %i\n",
              clen, t - buf);
#endif
      db_cert_drop(res);
      return NULL;
    }
    res->rmk = malloc(nsid_len[nsid]);
    cert_MD5Init(&ctxt);
    cert_MD5Update(&ctxt, buf, t - buf);
    cert_MD5Final(res->rmk, &ctxt);

    if (memcmp(res->rmk, mkid, nsid_len[nsid]) != 0) {
      printf("Unsigned DH CERT: MKID and MD5 hash differ - ignoring CERT.\n");
      db_cert_drop(res);
      return NULL;
    }

    db_unique_disk_save(&res, inet_ntoa(from), 0);
    return res;
  }
  else {
#ifdef DEBUG_CDP
    printf("CERT add_cert unknown CERT-TYPE (%d) or NSID (%d).\n",
           ctype, nsid);
#endif
    /* Huh? What? No comprende! */
    return NULL;
  }
}

static int encode_cert(u_char *buf, dh_data *find, int nsid)
{
  int l;
  u_char *raw;
  u_char *base = buf;

  if (!find->cert_data) {
    /* according to the draft, the layout of unsigned DH public 
       values is (~~~ denotes variable length):
         u32 not valid before
         u32 not valid after
         u16 prime len
         ~~~ prime p
         u16 generator len
         ~~~ generator g
         u16 public value len
         ~~~ public value ki
     */
    *buf = CDP_CERT_DH_PUBLIC;
    buf += 3; /* skip length field */
#ifdef DEBUG_CDP
    printf("     UDH: begins: %08lx\n", find->begins);
#endif
    *buf++ = find->begins >> 24;
    *buf++ = (find->begins >> 16) & 0xff;
    *buf++ = (find->begins >> 8) & 0xff;
    *buf++ = find->begins & 0xff;
#ifdef DEBUG_CDP
    printf("     UDH: expires: %08lx\n", find->expires);
#endif
    *buf++ = find->expires >> 24;
    *buf++ = (find->expires >> 16) & 0xff;
    *buf++ = (find->expires >> 8) & 0xff;
    *buf++ = find->expires & 0xff;

    l = int_extract_raw(&raw, find->p, 0);
    assert(l < 65536);
    *buf++ = (l >> 8) & 0xff;
    *buf++ = l & 0xff;
#ifdef DEBUG_CDP
    printf("     UDH: p length: %i\n", l);
#endif
    memcpy(buf, raw, l);
    buf += l;
    free(raw);

    l = int_extract_raw(&raw, find->g, 0);
    assert(l < 65536);
    *buf++ = (l >> 8) & 0xff;
    *buf++ = l & 0xff;
#ifdef DEBUG_CDP
    printf("     UDH: g length: %i\n", l);
#endif
    memcpy(buf, raw, l);
    buf += l;
    free(raw);

    l = int_extract_raw(&raw, find->ki, 0);
    assert(l < 65536);
    *buf++ = (l >> 8) & 0xff;
    *buf++ = l & 0xff;
#ifdef DEBUG_CDP
    printf("     UDH: ki length: %i\n", l);
#endif
    memcpy(buf, raw, l);
    buf += l;
    free(raw);

    base[1] = (buf-base-3) >> 8 & 0xff;
    base[2] = (buf-base-3) & 0xff;
  }
  else
    buf += enlink_saveudp(buf, find->cert_data);

  return buf-base;
}

void udp_process_requestor(int sock)
{
  dh_data *find = NULL;
  u_char in[20000], out[20000];
  struct sockaddr_in addr;
  int len, addrlen = sizeof(addr);
  int nr;
  cdp_header *req_hdr = (cdp_header *) in;
  cdp_header *ans_hdr = (cdp_header *) out;
  cdp_record *req_rec = (cdp_record *) &(in[12]);
  cdp_record *ans_rec = (cdp_record *) &(out[12]);
  cdp_cert   *req_cert, *ans_cert;

  len = recvfrom(sock, in, sizeof(in), 0, (struct sockaddr *) &addr, &addrlen);

  if (len < 4)
    return;			/* Whoa! */

#ifdef DEBUG_CDP
  printf("CDP query from %s VERSION=%d ACTION=%s RECORDS=%d\n",
	 inet_ntoa(addr.sin_addr), req_hdr->version, 
         req_hdr->action == CDP_REQUEST ? "REQUEST" : "unknown", 
         req_hdr->num_recs);
#endif

  /* We can handle Version 1 only */
  if (req_hdr->version != CDP_VERSION)
    return;

  /* Only requests on this port */
  if (req_hdr->action != CDP_REQUEST)
    return;

  nr = req_hdr->num_recs;

  len -= sizeof(*req_hdr);
  if (len < 2)
    return;

  /* build reply header */
  ans_hdr->version  = CDP_VERSION;
  ans_hdr->action   = CDP_RESPONSE;
  ans_hdr->num_recs = 0;
  ans_hdr->reserved = 0;
  ans_hdr->req_cookie  = req_hdr->req_cookie;
  ans_hdr->resp_cookie = 0;
  
  /* loop through the records and add reply records on the fly */
  while (nr--) {
    req_cert = (cdp_cert *) &(req_rec->name[req_rec->name_len]);

#ifdef DEBUG_CDP
    printf("    record: ACTION=%s NAME TYPE=%d NAME LEN=%d CERT TYPE=%d CERT LEN=%i\n",
	   req_rec->action == CDP_REC_GET ? "GET" : "PUT/other", 
           req_rec->name_type, req_rec->name_len, req_cert->cert_type, 
           (req_cert->cert_len[0] << 8) + req_cert->cert_len[1]);
#endif

    if (req_rec->name_type == CDP_NAME_SKIP) {
      /* We only support SKIP name types and ignore all others */
 
      /* We support GET and PUT commands and ignore all others */
      if (req_rec->action == CDP_REC_GET) {
#ifdef DEBUG_CDP
        printf("            GET: %s\n", 
               make_name(req_rec->name[0], &(req_rec->name[1])));
#endif
	memcpy(ans_rec, req_rec, (u_char *) req_cert - (u_char *) req_rec); 
        ans_cert = (cdp_cert *) &(ans_rec->name[ans_rec->name_len]);

        /* for SKIP, the name contains first NSID and then MKID */
        find = db_public_find_load(NULL, req_rec->name[0], &(req_rec->name[1]));

        if (find == (dh_data *) (-1))
          find = db_public_find_load(find, req_rec->name[0], &(req_rec->name[1]));

        if ((find == NULL) ||
            (encode_cert((u_char *) &(ans_cert->cert_type), find,
                         req_rec->name[0]) == 0)) {
#ifdef DEBUG_CDP
          printf("    answer: request failed.\n");
#endif
          /* The request failed. Return error record. */
          ans_rec->action = CDP_REC_GET_FAIL;
          ans_cert->cert_type   =
	  ans_cert->cert_len[0] =
          ans_cert->cert_len[1] = 0;
          ans_rec = (cdp_record *) ans_cert->cert;
          ans_hdr->num_recs++;
        }
        else {
          int ans_cert_len = (ans_cert->cert_len[0] << 8)
                             + ans_cert->cert_len[1];
#ifdef DEBUG_CDP
          printf("    answer: cert type=%i, cert length=%i\n", 
                 ans_cert->cert_type, ans_cert_len);
#endif
          ans_rec->action = CDP_REC_GET_OK;
          ans_rec = (cdp_record *) &(ans_cert->cert[ans_cert_len]);
          ans_hdr->num_recs++;
        }
      }
      else if (req_rec->action == CDP_REC_PUT) {
        /* XXX currently unsupported XXX */

        /* The request failed. Return error record. */
	memcpy(ans_rec, req_rec, (u_char *) req_cert - (u_char *) req_rec); 
        ans_rec->action = CDP_REC_PUT_FAIL;
        ans_cert = (cdp_cert *) &(ans_rec->name[ans_rec->name_len]);
        ans_cert->cert_type   = 
	ans_cert->cert_len[0] =
	ans_cert->cert_len[1] = 0;
        ans_rec = (cdp_record *) ans_cert->cert;
        ans_hdr->num_recs++;
      }
    }

    req_rec = (cdp_record *) &(req_cert->cert[(req_cert->cert_len[0] << 8) 
                                              + req_cert->cert_len[1]]);
  } /* while nr-- */

#ifdef DEBUG_CDP
  printf("CDP sendto(%s) STATUS=%d RECORDS=%d length=%i\n", 
         inet_ntoa(addr.sin_addr), ans_hdr->action, ans_hdr->num_recs,
         (u_char *) ans_rec - out); 
#endif
  sendto(sock, out, (u_char *) ans_rec - out, 0, 
         (struct sockaddr *) &addr, sizeof(addr));
}


void udp_process_responder(int sock)
{
  extern void do_discovered(u_char, u_char *, dh_data *, u_long);
  dh_data *response = NULL;
  u_char in[20000];
  struct sockaddr_in addr;
  int len, addrlen = sizeof(addr);
  int nr;
  cdp_header *hdr = (cdp_header *) in;
  cdp_record *rec = (cdp_record *) &(in[12]);
  cdp_cert   *cert;

  len = recvfrom(sock, in, sizeof(in), 0, (struct sockaddr *) &addr, &addrlen);

  if (len < 4)
    return;			/* Whoa! */

#ifdef DEBUG_CDP
  printf("CDP reply from %s VERSION=%d STATUS=%d RECORDS=%d\n",
	 inet_ntoa(addr.sin_addr), hdr->version, hdr->action, hdr->num_recs);
#endif

  /* We can handle Version 1 only */
  if (hdr->version != CDP_VERSION)
    return;

  /* Only replies on this port */
  if (hdr->action != CDP_RESPONSE)
    return;

  nr = hdr->num_recs;

  len -= sizeof(*hdr);
  if (len < 2)
    return;

  /* loop through the records */
  while (nr--) {
    cert = (cdp_cert *) &(rec->name[rec->name_len]);

#ifdef DEBUG_CDP
    printf("    record: STATUS=%s NAME TYPE=%d NAME LEN=%d CERT TYPE=%d CERT LEN=%i\n",
	   rec->action == CDP_REC_GET_OK ? "GET_OK" : "GET_FAIL", 
           rec->name_type, rec->name_len, cert->cert_type, 
           (cert->cert_len[0] << 8) + cert->cert_len[1]);
#endif

    if (rec->name_type == CDP_NAME_SKIP) {
      /* We only support SKIP name types and ignore all others */
 
      /* We support GET SUCCEEDED and ignore all others */
      if (rec->action == CDP_REC_GET_OK) {
#ifdef DEBUG_CDP
        printf("            OK: %s\n", 
               make_name(rec->name[0], &(rec->name[1])));
#endif
        /* if positive response, find/create remote certificate and dh_data,
         * if the incoming data corresponds to an already existing cert,
         * it may be that it just contains new certificate data, which should
         * be appended. if different public value certificate with same name 
         * already exists, store the new one nevertheless, should later check 
         * validity times, and refuse the new one if overlap 
         * (or remove old one)
         * pass it with nsid, mkid, ip to do_discovered
         * if failure, just pass remote=NULL to do_discovered
         */
         response = add_cert(&(rec->name_len), addr.sin_addr);

         if (response)
           do_discovered(response->rns, response->rmk, response,
		         addr.sin_addr.s_addr);
      }
      else if (rec->action == CDP_REC_GET_FAIL) {
        /* XXX currently ignored XXX */

#ifdef DEBUG_CDP
        printf("            FAIL: %s\n", make_name(rec->name[0], &(rec->name[1])));
#endif
      }
    }

    rec = (cdp_record *) &(cert->cert[(cert->cert_len[0] << 8) 
                                      + cert->cert_len[1]]);
  } /* while nr-- */

}


/* ask about certificates. includes own public value if available */
/* Send one CDP GET request for the remote name <rns, rmk> to the peer at
 * addr (network byte order).  The request is sent from send_socket to the
 * peer's RECV_PORT.
 *
 * If our own name <ons, omk> resolves to an entry that carries certificate
 * data, that certificate is attached to the request record via
 * enlink_saveudp().  Otherwise the record carries an all-zero (empty)
 * certificate header. */
void udp_request(u_long addr, u_char ons, u_char *omk, u_char rns, u_char *rmk)
{
  u_char pkt[20000];
  cdp_header *hdr = (cdp_header *) pkt;
  cdp_record *rec = (cdp_record *) &(pkt[12]);
  cdp_cert *crt;
  struct sockaddr_in peer;
  dh_data *own;
  u_char *tail;

  /* look up our own entry; a (dh_data *) -1 result is passed back in for
     a second call (NOTE(review): presumably "load it" -- confirm) */
  own = db_public_find_load(NULL, ons, omk);
  if (own == (dh_data *) (-1))
    own = db_public_find_load(own, ons, omk);

  /* header: version-1 request with a single record, cookies cleared */
  hdr->version     = CDP_VERSION;
  hdr->action      = CDP_REQUEST;
  hdr->num_recs    = 1;
  hdr->reserved    = 0;
  hdr->req_cookie  = 0;
  hdr->resp_cookie = 0;

  /* GET record; a SKIP name is the NSID byte followed by the MKID */
  rec->action    = CDP_REC_GET;
  rec->name_type = CDP_NAME_SKIP;
  rec->name_len  = nsid_len[rns] + 1;
  rec->name[0]   = rns;
  memcpy(rec->name + 1, rmk, nsid_len[rns]);
#ifdef DEBUG_CDP
  printf("CDP_REQUEST GET: %s\n", make_name(rec->name[0], &(rec->name[1])));
#endif

  crt = (cdp_cert *) (rec->name + rec->name_len);
  if (own == NULL || own->cert_data == NULL) {
    /* nothing to attach: empty certificate */
    crt->cert_type   = 0;
    crt->cert_len[0] = 0;
    crt->cert_len[1] = 0;
    tail = (u_char *) crt->cert;
  }
  else {
    /* attach our own certificate */
    tail = (u_char *) crt + enlink_saveudp((u_char *) crt, own->cert_data);
  }

  memset(&peer, 0, sizeof(peer));
  peer.sin_family      = AF_INET;
  peer.sin_port        = htons(RECV_PORT);
  peer.sin_addr.s_addr = addr;	/* caller supplies network byte order */
#ifdef DEBUG_CDP
  printf("CDP sendto(%s) CDP_REQUEST\n", inet_ntoa(peer.sin_addr));
#endif
  sendto(send_socket, pkt, tail - pkt, 0,
         (struct sockaddr *) &peer, sizeof(peer));
}
