/*
     This file is part of GNUnet.
     (C) 2001, 2002 Christian Grothoff (and other contributing authors)

     GNUnet is free software; you can redistribute it and/or modify
     it under the terms of the GNU General Public License as published
     by the Free Software Foundation; either version 2, or (at your
     option) any later version.

     GNUnet is distributed in the hope that it will be useful, but
     WITHOUT ANY WARRANTY; without even the implied warranty of
     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
     General Public License for more details.

     You should have received a copy of the GNU General Public License
     along with GNUnet; see the file COPYING.  If not, write to the
     Free Software Foundation, Inc., 59 Temple Place - Suite 330,
     Boston, MA 02111-1307, USA.
*/

/**
 * This file contains the connection table which lists all the
 * current connections of the node with other hosts and buffers
 * outgoing packets to these hosts. The connection table also
 * contains state information such as sessionkeys, credibility and
 * host activity.
 * @author Tzvetan Horozov
 * @author Christian Grothoff
 **/ 

#include "config.h"
#include "statistics.h"
#include "server/connection.h"
#include "knownhosts.h"
#include "server/trustservice.h"

/**
 * The connection table: an array of CONNECTION_MAX_HOSTS_ entries,
 * allocated in initConnection(). The slot used for a given host is
 * computed by computeIndex().
 **/
static BufferEntry * CONNECTION_buffer_;

/**
 * Number of slots in CONNECTION_buffer_. initConnection() rounds the
 * value from getMaxNodes() up to a power of two so that computeIndex()
 * can use a bit mask.
 **/
static unsigned int CONNECTION_MAX_HOSTS_;

/**
 * Number of table entries believed to be in STAT_UP. Recomputed by
 * cronDecreaseLiveness() and incremented by notifyActive() when a
 * handshake completes; read without locking by forwardQuery().
 * NOTE(review): volatile does not make the ++ in notifyActive() atomic.
 **/
static volatile int CONNECTION_currentActiveHosts;
/* time of the last "not connected" warning; forwardQuery() uses it to
   rate-limit that warning */
static time_t CONNECTION_lastTimeConnectionWarning;

/* ************* internal Methods **************** */

/**
 * Periodic job (registered by initConnection() with a 60 second
 * interval) that halves the liveness of every table entry, marks
 * entries whose liveness dropped to zero as STAT_DOWN and tries to
 * fill empty slots.
 **/
static void cronDecreaseLiveness(void * unused);

/**
 * Append a message to the outgoing buffer of be, flushing the buffer
 * first if the message does not fit. The caller must hold be->sem.
 **/
static void appendToBuffer(BufferEntry * be,
		    MessagePart * message);


/**
 * Look in the list of known hosts and pick a random host that maps
 * to the connection-table slot at the given index. When called, the
 * mutex of the entry at that index must NOT be held (it may be
 * acquired while the chosen host is added).
 **/
static void scanForHosts(unsigned int index);

/* ******************** CODE ********************* */


/**
 * Initialize the connection module.
 *
 * Sizes the connection table (getMaxNodes() rounded up to a power of
 * two), allocates it, puts every entry into STAT_DOWN with an empty
 * buffer that only holds the SEQUENCE header (sequence number 1), and
 * registers cronDecreaseLiveness() as a cron job
 * (arguments 5 and 60 -- presumably initial delay and period in seconds).
 **/
void initConnection() {
  unsigned int i;
  BufferEntry * be;

  CONNECTION_currentActiveHosts = 0;
  CONNECTION_lastTimeConnectionWarning = 0;

  /* computeIndex() masks with (size - 1), so the size must be a power of 2 */
  CONNECTION_MAX_HOSTS_ = getMaxNodes();
  for (i = 1; i < CONNECTION_MAX_HOSTS_; i <<= 1)
    ; /* just searching for the next power of two */
  CONNECTION_MAX_HOSTS_ = i;

  CONNECTION_buffer_ 
    = (BufferEntry*) xmalloc(sizeof(BufferEntry)*CONNECTION_MAX_HOSTS_,
			     "initConnection: connection buffer");

#if PRINT_CONNECTION > 1
  print("CONNECTION: Connection module initialization\n");
#endif
  for (i = 0; i < CONNECTION_MAX_HOSTS_; i++) {
    LOOPPRINT("initConnection");
    be = CONNECTION_buffer_ + i;
    create_mutex(&be->sem);
    be->status = STAT_DOWN;
    be->isAlive = 0;
    be->trust = 0;
    be->prio = 0;
    /* every buffer starts with a SEQUENCE message header */
    be->buffer.sequence.header.requestType = htons(GNET_PROTO_SEQUENCE);
    be->buffer.sequence.header.size = htons(sizeof(SEQUENCE_Message));
    be->buffer.sequence.sequenceNumber = htonl(1);
    be->len = sizeof(SEQUENCE_Message);
  }
  /* Liveness in 60s intervals */
  addCronJob(&cronDecreaseLiveness, 5, 60, NULL);
}

/**
 * Print the contents of the connection buffer: one line for every
 * entry whose status is not STAT_DOWN (status, slot index, host id,
 * trust without its top bit, liveness, buffer length and session key).
 * This function takes no locks itself. May NOT be called from
 * synchronized context!
 **/
void printConnectionBuffer() {
  int i;
  BufferEntry * be;
  HexName hostName;
  SKEYString skey;

  for (i = 0; i < CONNECTION_MAX_HOSTS_; i++) {
    LOOPPRINT("printConnectionBuffer");
    be = CONNECTION_buffer_ + i;
    if (be->status == STAT_DOWN)
      continue; /* nothing interesting in unused slots */
    hash2hex(&be->hostId.hashPubKey, &hostName);
    printSessionKey(&be->skey, &skey);
    print("%d: %d - %s has trust %d, liveness %u, buflen %d\n\t and key %s\n",
	  be->status,
	  i,
	  &hostName,
	  be->trust & 0x7FFFFFFF, 
	  be->isAlive,
	  be->len,
	  &skey);
  }
}

static void sendBufferSynchronized(BufferEntry * be) {
  MUTEX_LOCK(&be->sem);
  sendBuffer(be);
  MUTEX_UNLOCK(&be->sem);
}

/**
 * Periodic job (scheduled by initConnection()) that decreases the
 * liveness of all hosts. Call once every minute.
 *
 * For every slot, under the slot's mutex:
 *  - the liveness value is halved (notifyActive() sets its top bit on
 *    every sign of life);
 *  - an UP host whose liveness was exactly 1 before halving is passed to
 *    whitelistHost();
 *  - a non-DOWN host whose liveness reaches 0 is marked STAT_DOWN,
 *    otherwise it is counted as active;
 *  - for a DOWN slot, the lock is released and scanForHosts() tries to
 *    fill it.
 * Finally CONNECTION_currentActiveHosts is set to the number of active
 * hosts counted.
 **/
static void cronDecreaseLiveness(void * unused) {
  int i;
  BufferEntry * be = NULL;
  unsigned int before;
  int active = 0;

#if PRINT_CRON
  print("CRON: enter cronDecreaseLiveness\n");
#endif
#if PRINT_CONNECTION 
  print("CONNECTION: Connection module cron starts:\n");
  printConnectionBuffer();
#endif
  for (i = 0; i < CONNECTION_MAX_HOSTS_; i++) {
    LOOPPRINT("cronDecreaseLiveness");
    be = &CONNECTION_buffer_[i];
    MUTEX_LOCK(&be->sem);
    before = be->isAlive;
    be->isAlive = before >> 1;
    if ( (be->status == STAT_UP) && (before == 1) )
      whitelistHost(&be->hostId);
    if (be->status != STAT_DOWN) {
      if (be->isAlive > 0)
	active++;
      else
	be->status = STAT_DOWN;
      MUTEX_UNLOCK(&be->sem);
      continue;
    }
    /* scanForHosts must be called without holding the slot's mutex */
    MUTEX_UNLOCK(&be->sem);
    scanForHosts(i);
  }
  CONNECTION_currentActiveHosts = active;

#if PRINT_CRON
  print("CRON: exit cronDecreaseLiveness\n");
#endif
}

/**
 * We received a sign of life from this host.
 *
 * If the host is in the connection table:
 *  - STAT_UP: the top bit of its liveness is set;
 *  - STAT_WAITING_FOR_PING / STAT_WAITING_FOR_PONG: if challenge equals
 *    the expected challenge, the host becomes STAT_UP, the active-host
 *    counter is incremented and the top liveness bit is set;
 *  - STAT_DOWN: nothing happens;
 *  - any other status: a warning is printed.
 * Hosts that are not in the table are ignored.
 *
 * @param hostId the host that showed activity
 * @param challenge the challenge value received from the host
 **/
void notifyActive(HostIdentity * hostId,
		  int challenge) {
  BufferEntry * be;

#if PRINT_CONNECTION > 1
  HexName hex;

  hash2hex(&hostId->hashPubKey, &hex);
  print("CONNECTION: Marking host %s active.\n",
	&hex);
#endif
  be = lookForHost(hostId);
  if (be == NULL)
    return; /* not in the table */

  /* lookForHost returned the entry locked, so the status is stable here */
  if (be->status == STAT_UP) {
    be->isAlive |= 0x80000000;
  } else if ( (be->status == STAT_WAITING_FOR_PING) ||
	      (be->status == STAT_WAITING_FOR_PONG) ) {
    if (be->challenge == challenge) {
      be->status = STAT_UP;
      CONNECTION_currentActiveHosts++;
      be->isAlive |= 0x80000000;
    }
  } else if (be->status != STAT_DOWN) {
    print("WARNING: unknown status!\n");
  }
  MUTEX_UNLOCK(&be->sem);
}

/**
 * Send a message to every host in the connection table whose status is
 * STAT_UP, by calling unicast() for each of them.
 *
 * @param message the message to send
 * @param priority how important is the message? The higher, the more important
 **/
void broadcast(MessagePart * message,
	       unsigned int priority) {
  unsigned int i;
  BufferEntry * be = NULL;

#if PRINT_CONNECTION > 1
  print("CONNECTION: broadcasting\n");
#endif
  for (i = 0; i < CONNECTION_MAX_HOSTS_; i++) {
    LOOPPRINT("broadcast");
    /* unlocked read of the status: a stale value only means we send to
       one host more or less; unicast() locks the entry itself */
    be = CONNECTION_buffer_ + i;
    /* FIXME: be smarter, do not ALWAYS send! */
    if (be->status == STAT_UP)
      unicast(message, &be->hostId, priority);
  }
}

/**
 * We have received a query and decided to forward it (or it is our own
 * query). This method decides to whom to send the query.
 * <p>
 * Attention: if the returnTo-identity is not us, then the receiver must
 * also be given the identity of the host we are forwarding (but not
 * indirecting) for.
 * <br>
 * Choosing to whom to forward uses the following heuristic:
 * choose a host using a BIASED random generator.
 * Prefer:
 * - hosts that are alive (be->isAlive is measured mostly with powers of
 *   2, so addition is too much; we shift by 16 first...)
 * - hosts that have little credit (don't bother the good guys if
 *   somebody else may do, but this should not be the dominant effect,
 *   so linear though small in number)
 * - hosts where the key is close to the query (this should be
 *   significant, but not totally dominant; it must be possible for any
 *   host to get the query anyway. The distance is in the range of
 *   0-65535, so we go linear again)
 * - the random number should be in the same range as the distance; we
 *   multiply by an increasing value ('j') so that later passes over the
 *   table accept more hosts. At most three passes are made.
 *
 * The query's priority is divided by the number of hosts it is sent to.
 * If no host is connected, a connection attempt is triggered (and, with
 * PRINT_WARNINGS, a rate-limited warning printed) and nothing is sent.
 *
 * @param msg the message to forward.
 **/
void forwardQuery(QUERY_Message * msg) {
  int m;
  int i;
  int j;
  long threshold;
#if PRINT_QUERY > 1
  HexName hex;
#endif
  BufferEntry * be = NULL;
#if PRINT_WARNINGS
  time_t now;
#endif

  /* The TTL is presumably signed (it is printed with %d below).
     ntohl() returns an unsigned value, so without the cast a negative
     (expired) TTL would look like a huge positive one and the query
     would be forwarded instead of dropped. */
  if ((int) ntohl(msg->ttl) <= 0)
    return;
  m = CONNECTION_currentActiveHosts; /* the obvious max */
  if ( (ntohl(msg->priority) <= 1) &&
       (m > 4) )
    m = 4; /* if priority is REALLY low, send only to 4 hosts */
  else
    if (m > ntohl(msg->priority))
      m = ntohl(msg->priority); /* make sure priority does not
				   drop under 1 if possible */
  if (m == 0) {
    if (CONNECTION_currentActiveHosts == 0) {
      cronDecreaseLiveness(NULL); /* this also tries to connect... */
#if PRINT_WARNINGS
      time(&now);
      if (now - CONNECTION_lastTimeConnectionWarning > 60) {  /* fixme: > 300 */
	print("WARNING: Not connected to any hosts, can not forward query.\n");
	downloadHostlist(NULL);
	time(&CONNECTION_lastTimeConnectionWarning);
      }
#endif
    }
    return;
  }
  msg->priority = htonl(ntohl(msg->priority)/m);
  
#if PRINT_QUERY > 1
  hash2hex(&msg->query, &hex);
  print("CONNECTION: Forwarding query %s (broadcast) with ttl %d and priority %u.\n",
	&hex, 
	ntohl(msg->ttl),
	ntohl(msg->priority));
#endif
  i = 0;
  j = 1;
  while (m > 0) {
    LOOPPRINT("forwardQuery");
    /* we need no sync here as we only read,
       and concurrent rw access does not hurt */
    be = &CONNECTION_buffer_[i++];
    if (i == CONNECTION_MAX_HOSTS_) {
      i=0;    
      j++;
      if (j==4)
	break;
    }
    if (be->status != STAT_UP)
      continue;
    /* HEURISTIC (see above).
       Computed in signed arithmetic: isAlive looks unsigned (it is
       printed with %u and copied into an unsigned int elsewhere), so in
       the original unsigned expression subtracting a large trust
       wrapped around to a huge threshold and made high-trust hosts
       ALWAYS receive the query -- the opposite of the intent. Assuming
       32-bit fields every term fits in a 32-bit long:
       (isAlive>>16) <= 65535, (rand()&65535)*j <= 196605 and
       trust&0x7FFFFFFF <= 2^31-1. */
    threshold = (long) (be->isAlive >> 16)
      + (long) (rand() & 65535) * j
      - (long) (be->trust & 0x7FFFFFFF);
    if ((long) distanceHashCode160(&msg->query,
				   &be->hostId.hashPubKey) > threshold)
      continue;
    unicast(&msg->header,
	    &be->hostId,
	    ntohl(msg->priority));
    m--;
  }
}

/**
 * Send a message to a specific host (reply, enqueue).
 *
 * If the host is not in the connection table, or its entry is
 * STAT_DOWN, it is (re-)added via addHost() and a key exchange is
 * started if the entry is still STAT_DOWN. The message is then appended
 * to the host's outgoing buffer and its priority added to the entry.
 *
 * @param message the message to send (unencrypted!), first BLOCK_LENGTH_SIZE bytes give size
 * @param hostId the identity of the receiver
 * @param priority how important is the message?
 **/
void unicast(MessagePart * message,
	     HostIdentity * hostId,
	     unsigned int priority) {
  BufferEntry * be;
#if PRINT_CONNECTION > 1
  HexName hex;

  hash2hex(&hostId->hashPubKey, &hex);
  print("CONNECTION: unicasting to host %s message of type %d\n",
	&hex,
	ntohs(message->requestType));
#endif  
  be = lookForHost(hostId);
  /* a DOWN entry is treated like a missing one; release it first
     because addHost() will look it up (and lock it) again */
  if ( (be != NULL) && (be->status == STAT_DOWN) ) {
    MUTEX_UNLOCK(&be->sem);
    be = NULL;
  }
  if (be == NULL) {
    be = addHost(hostId);  
    if (be->status == STAT_DOWN)
      exchangeKey(be);
  }
  be->prio += priority;
  appendToBuffer(be, message);
  MUTEX_UNLOCK(&be->sem);
}

/**
 * Perform an operation for every slot of the connection table.
 * The BufferEntry structure is passed to the method.
 *
 * Note: despite the name, the method is invoked for EVERY slot,
 * including entries whose status is STAT_DOWN; the method must check
 * be->status itself if it only cares about connected hosts.
 * No synchronization or other checks are performed (be->sem is not
 * taken).
 *
 * @param method the method to invoke
 * @param arg the second argument to the method
 **/ 
void forAllConnectedHosts(void (*method)(BufferEntry *, void*),
			   void * arg) {
  unsigned int i;

#if PRINT_CONNECTION > 1
  print("CONNECTION: iterating over all connection table slots\n");
#endif
  for (i=0;i<CONNECTION_MAX_HOSTS_;i++) {
    LOOPPRINT("forAllConnectedHosts");
    method(&CONNECTION_buffer_[i], arg);
  }
}

/**
 * Check the sequence number of a packet from hostId and update the
 * per-host replay window as a side-effect.
 *
 * The window is lastSequenceNumberReceived (the highest number accepted
 * so far) plus lastPacketsBitmap, in which bit k records whether
 * lastSequenceNumberReceived - (k+1) has already been accepted; the
 * window covers the 32 numbers below the newest one.
 * NOTE(review): assumes lastPacketsBitmap is an unsigned type of at
 * least 32 bits -- confirm against the BufferEntry declaration.
 *
 * @param hostId the sender of the packet (NULL is rejected)
 * @param sequenceNumber the sequence number (host byte order)
 * @returns OK if ok, SYSERR if the number is a duplicate, is too old,
 *          or the host is not in the connection table.
 **/
int checkSequenceNumber(HostIdentity * hostId,
			unsigned int sequenceNumber) {
  BufferEntry * be;
  int res;
  unsigned int delta;

  if (hostId == NULL) {
    return SYSERR;
  }
  be = lookForHost(hostId);
  if (be == NULL) 
    return SYSERR; /* host not found */
  res = OK;
  if (be->lastSequenceNumberReceived >= sequenceNumber) {
    /* not newer than the newest packet: must be inside the window and
       not yet marked as received */
    delta = be->lastSequenceNumberReceived - sequenceNumber;
    if ( (delta >= 1) && (delta <= 32) ) {
      unsigned int rotbit = 1u << (delta - 1);
      if ( (be->lastPacketsBitmap & rotbit) == 0)
	be->lastPacketsBitmap |= rotbit; /* first time we see it */
      else
	res = SYSERR; /* replay */
    } else
      res = SYSERR; /* repeat of the newest packet, or too old */
#if PRINT_WARNINGS 
    if (res == SYSERR) {
      print("WARNING: Invalid sequence number %u < %u, dropping rest of packet\n",
	    sequenceNumber, 
	    be->lastSequenceNumberReceived);
    }    
#endif
  } else {    
    /* newer packet: slide the window forward by delta positions */
    delta = sequenceNumber - be->lastSequenceNumberReceived;
    if (delta >= 32)
      be->lastPacketsBitmap = 0; /* shifting by >= the type width is UB */
    else
      be->lastPacketsBitmap = be->lastPacketsBitmap << delta;
    /* the previous newest number was accepted too; record it in the
       window, otherwise it could be replayed once */
    if (delta <= 32)
      be->lastPacketsBitmap |= 1u << (delta - 1);
    be->lastSequenceNumberReceived = sequenceNumber;
  }
  MUTEX_UNLOCK(&be->sem);
  return res;
}
	       
/**
 * Decipher data coming in from a foreign host, using the session key
 * stored in that host's connection-table entry.
 *
 * @param data the data to decrypt
 * @param size the size of the encrypted data
 * @param hostId the sender host that encrypted the data 
 * @param result where to store the decrypted data; must be large enough
 *        for the plaintext (presumably at least size bytes -- confirm
 *        against decryptBlock)
 * @returns the value of decryptBlock() (the size of the decrypted
 *          data), or -1 if data/hostId is NULL or the host is not in
 *          the connection table
 **/
int decryptFromHost(BLOWFISHEncryptedData * data,
		    BLOCK_LENGTH size,
		    HostIdentity * hostId,
		    void * result) {  
  BufferEntry * be;
  int decrypted;
#if (PRINT_CONNECTION > 2) || PRINT_WARNINGS
  HexName hex;
#endif

  if ( (hostId == NULL) ||
       (data == NULL) ) {
    print("CONNECTION: could not decrypt, message or hostId was NULL\n");
    return -1;
  }

#if PRINT_CONNECTION > 2
  hash2hex(&hostId->hashPubKey, &hex);
  print("CONNECTION: decrypting message from host %s\n",
	&hex);
#endif
  be = lookForHost(hostId);
  if (be != NULL) {
    /* entry is locked; the session key cannot change under us */
    decrypted = decryptBlock(&be->skey, data, size, result);
    MUTEX_UNLOCK(&be->sem);
    return decrypted;
  }
#if PRINT_WARNINGS
#if PRINT_CONNECTION <= 2
  hash2hex(&hostId->hashPubKey, &hex);
#endif
  print("CONNECTION: decrypting message from host %s failed, no sessionkey!\n",
	&hex);
#endif
  return -1; /* could not decrypt */
}


/* ********************* internal methods ************************ */

/**
 * State shared by the two forEachHost() passes of scanForHosts():
 * scanHelperCount() counts the known hosts that map to slot 'index';
 * scanHelperSelect() counts 'matchCount' down and copies the host at
 * which it reaches zero into 'match'.
 **/
typedef struct IndexMatch {
  /* connection-table slot a host is being searched for */
  int index;
  /* number of candidates (count pass) / countdown (select pass) */
  int matchCount;
  /* the selected host; only meaningful after a successful select pass */
  HostIdentity match;
} IndexMatch;

/**
 * forEachHost() callback: count the known hosts (other than ourselves)
 * whose connection-table slot equals im->index.
 *
 * The second parameter is void* because scanForHosts() passes this
 * function to forEachHost() as void (*)(HostIdentity*, void*); calling
 * it through that pointer type is only well-defined if the definition
 * matches it.
 *
 * @param id a known host
 * @param arg the IndexMatch being filled in
 **/
static void scanHelperCount(HostIdentity * id,
			    void * arg) {
  IndexMatch * im = (IndexMatch *) arg;

  if (hostIdentityEquals(&myIdentity, id)) 
    return;
  if (computeIndex(id) != im->index)
    return;
  im->matchCount++;
}

/**
 * forEachHost() callback: among the known hosts (other than ourselves)
 * that map to slot im->index, copy the one at which im->matchCount
 * reaches zero into im->match. The counter keeps decreasing afterwards
 * (going negative), so at most one host is selected, and a negative
 * final value tells the caller that a host was copied.
 *
 * The second parameter is void* so that the definition matches the
 * void (*)(HostIdentity*, void*) type it is called through by
 * forEachHost().
 *
 * @param id a known host
 * @param arg the IndexMatch holding the countdown and the result
 **/
static void scanHelperSelect(HostIdentity * id,
			     void * arg) {
  IndexMatch * im = (IndexMatch *) arg;

  if (hostIdentityEquals(&myIdentity, id)) 
    return;
  if (computeIndex(id) != im->index)
    return;
  if (im->matchCount == 0) {
    memcpy(&im->match,
	   id,
	   sizeof(HostIdentity));
  }
  im->matchCount--;
}


/**
 * Look in the list of known hosts and pick a random host that maps to
 * the connection-table slot at index. If one is found, it is added to
 * the table and, if its entry is STAT_DOWN, it is blacklisted (so that
 * we do not retry too soon) and a key exchange is started. When called,
 * the mutex of the entry at the given index must NOT be held.
 *
 * @param index the connection-table slot to fill
 **/
static void scanForHosts(unsigned int index) {
  BufferEntry * be;
  IndexMatch indexMatch;
#if PRINT_CONNECTION > 1
  HexName hn;
#endif
  time_t now;

  time(&now);
  indexMatch.index = index;
  indexMatch.matchCount = 0;
  forEachHost((void (*)(HostIdentity*, void *))&scanHelperCount, 
	      now,
	      &indexMatch);
  if (indexMatch.matchCount == 0) 
    return;
  indexMatch.matchCount 
    = randomi(indexMatch.matchCount); /* randomly choose host */
  forEachHost((void (*)(HostIdentity*, void *))&scanHelperSelect,
	      now,
	      &indexMatch);
  /* scanHelperSelect copies a host exactly when its countdown reaches
     zero and then keeps decrementing, so a non-negative counter means
     nothing was copied (e.g. the known-hosts list shrank between the
     two passes). indexMatch.match would be uninitialized then. */
  if (indexMatch.matchCount >= 0)
    return;

  assert(!hostIdentityEquals(&myIdentity, &indexMatch.match),
	 "CONNECTION: adding myself to connection table, how can that be?");
  assert(computeIndex(&indexMatch.match) == index,
	 "CONNECTION: index of host to add is wrong!");
#if PRINT_CONNECTION > 1
  hash2hex(&indexMatch.match.hashPubKey,
	   &hn);
  print("CONNECTION: Connecting to %s at slot %u.\n",
	 &hn, index);
#endif
  be = addHost(&indexMatch.match);
  if (be != NULL) {
    if (be->status == STAT_DOWN) {
      blacklistHost(&be->hostId); /* we're trying now, don't try again too soon */
      exchangeKey(be);
    }
    MUTEX_UNLOCK(&be->sem);
  }
}

/**
 * Look for a host in the connection table.
 *
 * Locks the mutex of the host's slot (as given by computeIndex()). If
 * the slot currently holds this host, the entry is returned and the
 * lock is KEPT; the caller must release it with MUTEX_UNLOCK.
 * Otherwise the lock is released again and NULL is returned.
 *
 * @param hostId the host to look for
 * @returns the locked table entry, or NULL if the host is not in the table
 **/
BufferEntry * lookForHost(HostIdentity * hostId) {
  BufferEntry * be = &CONNECTION_buffer_[computeIndex(hostId)];

  MUTEX_LOCK(&be->sem);
  if (!equalsHashCode160(&hostId->hashPubKey,
			 &be->hostId.hashPubKey)) {
    MUTEX_UNLOCK(&be->sem);
    return NULL; /* slot is occupied by someone else (or empty) */
  }
  return be;
}

/**
 * Force adding of a host to the connection table. If the host is
 * already in the table, its entry is returned. Otherwise the current
 * occupant of the host's slot is evicted (its connection is shut down
 * if it is not STAT_DOWN, and its credit is flushed) and the slot is
 * re-initialised for hostId.
 *
 * The entry is returned with its mutex (be->sem) HELD; the callers
 * release it with MUTEX_UNLOCK. The caller must therefore not hold the
 * slot's mutex when calling this function.
 *
 * @param hostId the host to add
 * @return the locked table entry for the host (no key exchange
 *         performed so far)
 **/
BufferEntry * addHost(HostIdentity * hostId) {
  BufferEntry * be;
#if PRINT_CONNECTION > 1
  HexName hex;

  hash2hex(&hostId->hashPubKey, 
	   &hex);
  print("CONNECTION: adding host %s to the table.\n",
	&hex);
#endif 

  be = lookForHost(hostId);
  if (be != NULL)
    return be; /* already present; lookForHost returned it locked */

  be = &CONNECTION_buffer_[computeIndex(hostId)];
  /* Callers unlock the returned entry, so the lock must be held on this
     path too; taking it before the eviction also protects the entry
     while it is overwritten.
     NOTE(review): assumes shutdownConnection, flushHostCredit and
     initHostTrust do not lock be->sem themselves -- confirm. */
  MUTEX_LOCK(&be->sem);
  /* another thread may have installed this host between lookForHost
     releasing the lock and us acquiring it again */
  if (equalsHashCode160(&hostId->hashPubKey,
			&be->hostId.hashPubKey))
    return be;

  /* flush old entry */
  if (be->status != STAT_DOWN)
    shutdownConnection(be);
  flushHostCredit(be, NULL);
  
  /* overwrite with new entry */
  memcpy(&be->hostId,
	 hostId,
	 sizeof(HostIdentity));	 
  initHostTrust(be);
  return be;
}

/**
 * Compute the connection-table slot of a host id: the low bits of the
 * 'a' word of the host's public-key hash, masked with
 * CONNECTION_MAX_HOSTS_ - 1 (the table size is a power of two).
 *
 * @param hostId the host
 * @return an index in [0, CONNECTION_MAX_HOSTS_)
 **/
unsigned int computeIndex(HostIdentity * hostId) {
  unsigned int mask;
  unsigned int slot;

  mask = (unsigned int) (CONNECTION_MAX_HOSTS_ - 1);
  slot = mask & (unsigned int) hostId->hashPubKey.a;
#if EXTRA_CHECKS
  if (slot >= CONNECTION_MAX_HOSTS_)
    errexit("FATAL: CONNECTION_MAX_HOSTS_ not power of 2? (%d)\n",
	    CONNECTION_MAX_HOSTS_);
#endif
  return slot;
}

/**
 * Append a message to the outgoing buffer of be. This method assumes
 * that the access to be is already synchronized (be->sem held).
 *
 * If the message does not fit behind the data already buffered, the
 * buffer is sent first. Messages that cannot fit even into an empty
 * buffer (BUFSIZE minus the SEQUENCE_Message header) are dropped.
 * After appending, a cron job is scheduled to flush the buffer.
 *
 * @param be the (locked) table entry
 * @param message the message to append; NULL is ignored
 **/
static void appendToBuffer(BufferEntry * be,
			   MessagePart * message) {
  BLOCK_LENGTH exSize;
  BLOCK_LENGTH newSize;
  unsigned char * buffer;
#if PRINT_CONNECTION > 2
  HexName hex;
#endif

  /* check before the debug output below dereferences message */
  if (message == NULL) {
#if PRINT_WARNINGS
    print("appendToBuffer got garbage. Ignored.\n");
#endif
    return;
  }
#if PRINT_CONNECTION > 2
  hash2hex(&be->hostId.hashPubKey, &hex);
  print("CONNECTION: adding message of size %d to buffer of host %s.\n",
	ntohs(message->size),
	&hex);
#endif 
  buffer = (unsigned char *) &be->buffer;
  if (ntohs(message->requestType) < GNET_PROTO_MAX)
    GNUNET_STATISTICS.udp_out_counts[ntohs(message->requestType)]++;

  newSize = ntohs(message->size);
  exSize = be->len;
  /* Every buffer starts with a SEQUENCE_Message, so even after a flush
     at most BUFSIZE - sizeof(SEQUENCE_Message) bytes are free. A larger
     message would overrun be->buffer in the memcpy below. */
  if (newSize > BUFSIZE - sizeof(SEQUENCE_Message)) {
#if PRINT_WARNINGS      
    print("appendToBuffer fatal error: exSize: %d newsize: %d BUFSIZE %d\n",
	  exSize, newSize, BUFSIZE);
#endif
    return;
  }
  if (exSize + newSize > BUFSIZE) {
    sendBuffer(be); /* clean up*/
    exSize = sizeof(SEQUENCE_Message); /* sendBuffer resets be->len to this */
  }
  memcpy(&buffer[exSize], 	 
	 message, newSize); 
  be->len = exSize + newSize;
  /* flush the buffer after an average of TTL_DECREMENT seconds,
     but not earlier than with 1s delay (to allow accumulation of
     messages) */
  addCronJob((void (*)(void*))sendBufferSynchronized,
	     1+TTL_DECREMENT/2+randomi(TTL_DECREMENT/2),
	     0,
	     be);
#if PRINT_CONNECTION > 3
  print("CONNECTION: done appending message.\n");
#endif 
}

/**
 * Send a buffer; assumes that access to be is already synchronized
 * (be->sem held).
 *
 * Does nothing if the buffer only holds the SEQUENCE header. If
 * outgoingCheck() rejects the accumulated priority, the buffered data is
 * dropped. Otherwise the buffer is padded to BUFSIZE with a NOISE
 * message (when there is room for its header), encrypted with the
 * entry's session key and handed to sendToHost() if encryption produced
 * be->len bytes. In the last two cases the buffer is reset to the
 * SEQUENCE header and the sequence number is incremented.
 *
 * @param be the entry whose buffer should be sent (NULL is rejected)
 **/
void sendBuffer(BufferEntry * be){
  int crc;
  BLOWFISHEncryptedDataMax content;
  MessagePart * part;
  int i;
#if (PRINT_CONNECTION > 2) || (DUMP_PACKETS_SEND && (0 == PRINT_CONNECTION))
  HexName hex;
#endif

  /* check before any (debug) code dereferences be */
  if (be == NULL) {
#if PRINT_WARNINGS      
    print("sendBuffer called with invalid arguments (NULL).\n");
#endif
    return;
  }  
#if (PRINT_CONNECTION > 2) || (DUMP_PACKETS_SEND && (0 == PRINT_CONNECTION))
  hash2hex(&be->hostId.hashPubKey, &hex);
#endif
#if PRINT_CONNECTION > 2
  print("CONNECTION: sending buffer of host %s.\n",
	&hex);
#endif 
  if (be->len == sizeof(SEQUENCE_Message))       
    return; /* nothing to send */  
  if (SYSERR == outgoingCheck(be->prio)) {
    be->len = sizeof(SEQUENCE_Message); /* too busy, truncate and drop! */
    return;
  }
  /* here, we could try to add HELOs */
  if (be->len  + sizeof(HELO_Message) <= BUFSIZE) {
    /* not implemented */
  }
  /* finally pad with noise */
  if (be->len + sizeof(MessagePart) <= BUFSIZE) {
    /* append noise */
    /* we may also check here at some point if 
       a HELO may be applicable! */
    part = (MessagePart *) &((char*)&be->buffer)[be->len];
    part->size = htons(BUFSIZE - be->len);
    part->requestType = htons(GNET_PROTO_NOISE);
    GNUNET_STATISTICS.udp_out_counts[GNET_PROTO_NOISE]++;
    for (i=be->len+sizeof(MessagePart);i<BUFSIZE;i++)
      be->buffer.variable[i] = (char) rand();
    be->len = BUFSIZE;
  }

  crc = crc32N(&be->buffer, 
	       be->len);
#if PRINT_CONNECTION > 2
  print("Sending buffer of length %d\n",
	be->len);
  printConnectionBuffer();
#endif
  if (be->len ==
      encryptBlock(&be->buffer,
		   be->len,
		   &be->skey,
		   (BLOWFISHEncryptedData*)&content)) {
    sendToHost(NULL,
	       &be->hostId,
	       be->len,
	       &content,
	       ENCRYPTED_FLAG,
	       crc);  
  }
  be->len = sizeof(SEQUENCE_Message); /* ok, mark buffer as sent */
  GNUNET_STATISTICS.udp_out_counts[GNET_PROTO_SEQUENCE]++;
  be->buffer.sequence.sequenceNumber 
    = htonl(ntohl(be->buffer.sequence.sequenceNumber)+1);
}

/* end of connection.c */
