//////////////////////////////////////////////////////////////////////////// 
// 
// Copyright (C) DSTC Pty Ltd (ACN 052 372 577) 1993, 1994, 1995.
// Unpublished work.  All Rights Reserved.
// 
// The software contained on this media is the property of the
// DSTC Pty Ltd.  Use of this software is strictly in accordance
// with the license agreement in the accompanying LICENSE.DOC 
// file. If your distribution of this software does not contain 
// a LICENSE.DOC file then you have no rights to use this 
// software in any manner and should contact DSTC at the address 
// below to determine an appropriate licensing arrangement.
// 
//      DSTC Pty Ltd
//      Level 7, GP South
//      University of Queensland
//      St Lucia, 4072
//      Australia
//      Tel: +61 7 3365 4310
//      Fax: +61 7 3365 4311
//      Email: jcsi@dstc.qut.edu.au
// 
// This software is being provided "AS IS" without warranty of
// any kind.  In no event shall DSTC Pty Ltd be liable for
// damage of any kind arising out of or in connection with
// the use or performance of this software.
// 
//////////////////////////////////////////////////////////////////////////// 

package com.dstc.security.ssl;

import java.io.IOException;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.util.Iterator;
import java.util.Arrays;
import java.util.Vector;

import java.security.SecureRandom;
import java.security.Key;
import java.security.KeyFactory;
import java.security.Principal;
import java.security.PublicKey;
import java.security.MessageDigest;
import java.security.NoSuchProviderException;
import java.security.NoSuchAlgorithmException;
import java.security.InvalidKeyException;
import java.security.interfaces.RSAPublicKey;
import java.security.interfaces.DSAPublicKey;
import java.security.interfaces.DSAParams;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateException;

import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.CipherOutputStream;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.SecretKeySpec;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.DHParameterSpec;
import javax.crypto.interfaces.DHPublicKey;

import javax.net.ssl.SSLException;
import javax.net.ssl.SSLKeyException;
import javax.net.ssl.HandshakeCompletedListener;
import javax.net.ssl.HandshakeCompletedEvent;
import javax.net.ssl.SSLHandshakeException;

/**
 * An abstract class for collecting handshaking logic common to both client
 * and server.
 *
 * <p>Subclassed by a client-specific and a server-specific class. This class
 * keeps the running MD5 and SHA-1 hashes over the handshake messages, the
 * client/server hello randoms, the master secret and key block, the peer's
 * certificate chain and keys, and the handshake state. It can also run a
 * background {@link PreEmptor} thread that reads incoming protocol units and
 * feeds handshake messages to the context.
 *
 * @author Ming Yung
 */
abstract class HandShaker
{
  /** Sender id for the client side; passed to {@link #computeHashes(int)}. */
  static final int CLIENT = 0;
  /** Sender id for the server side; passed to {@link #computeHashes(int)}. */
  static final int SERVER = 1;

  // Possible values of handShakeState.
  protected static final int HANDSHAKE_NOT_YET_BEGUN = 0;
  protected static final int HANDSHAKE_IN_PROGRESS = 1;
  protected static final int HANDSHAKE_COMPLETED = 2;
  protected static final int HANDSHAKE_BAD_COMPLETION = 3;

  /** Context of the owning socket (obtained from sock.getContext()). */
  protected SSLContext ctx;
  /** Socket this handshaker works for. */
  protected SSLSocket socket;

  /** Registered HandshakeCompletedListeners, notified in registration order. */
  private Vector listeners = new Vector();

  // Running hashes over the handshake messages (see updateHashes); recreated
  // by reset() and consumed by toBeSignedCV/computeHashes.
  private MessageDigest md5Hash;
  private MessageDigest shaHash;
  /** Bit length of our own public key; assigned by subclasses. */
  protected int pubKeyBitLength;

  /** Source of randomness supplied at construction. */
  protected SecureRandom rand;
  protected byte[] protocolVersion = V3Constants.VERSION;
  protected byte[] compressionMethods = V3Constants.COMPRESSION_METHODS;
  // Hello randoms; mixed into the master secret, key block and the
  // ServerKeyExchange signature input.
  protected byte[] clientHelloRandom;
  protected byte[] serverHelloRandom;
  // Keys taken from the peer's certificate in processReceivedCerts.
  protected Key serverSigningKey;
  protected Key serverKeyXKey;
  protected Key clientSigningKey;
  protected Key clientKeyXKey;
  protected byte[] masterSecret;
  protected byte[] keyBlock;
  /** Cipher suite selected for the pending state, or null after reset(). */
  protected byte[] pendingCipherSuite;
  /** Enabled cipher suites, stored as CipherSuites.suiteType entries. */
  protected Vector enabledCipherSuites;
  protected boolean clientAuthRequired;

  private CertificateFactory certFact;
  // Scratch digests used only by toBeSignedSKX (separate from the running
  // handshake hashes above).
  private MessageDigest md5;
  private MessageDigest sha;
  /** Certificate chain received from the peer (index 0 is the peer's own). */
  private X509Certificate[] peerCertChain;

  /** One of the HANDSHAKE_* constants. */
  protected int handShakeState;
  protected byte[] offeredId;
  /** Id of the session being established or resumed. */
  protected byte[] currentId;
  /** True when an existing cached session is being resumed. */
  protected boolean sessionReuse;
  protected TrustEngine trustEngine;

  // Socket-level locks; assigned only in startPreEmptor().
  private Object readLock;
  private Object writeLock;
  protected PreEmptor preEmptor;

  /**
   * Creates a handshaker bound to the given socket.
   *
   * @param sock the socket being handshaken; its context is captured
   * @param rand source of randomness
   * @param trustedCerts trusted certificates handed to a new TrustEngine
   * @throws SSLException if any part of initialisation fails (the original
   *         exception is printed and only its message is kept)
   */
  protected HandShaker(SSLSocket sock, SecureRandom rand, Vector trustedCerts) 
    throws SSLException
  {
    try
    {
      this.socket = sock;
      this.ctx = sock.getContext();
      this.rand = rand;
      this.trustEngine = new TrustEngine(trustedCerts);

      certFact = CertificateFactory.getInstance("X509");
      md5 = MessageDigest.getInstance("MD5");
      sha = MessageDigest.getInstance("SHA");

      // NOTE(review): reset() is overridable and runs before any subclass
      // constructor body; overrides must not rely on subclass fields.
      reset();
    }
    catch (Exception e)
    {
      e.printStackTrace();
      throw new SSLException(e.getMessage());
    }
  }

  /**
   * Captures the socket's read/write locks and starts the background
   * PreEmptor thread as a daemon (it does not keep the JVM alive).
   *
   * <p>The lock fields are assigned only here, so this must run before
   * signalHandshakeCompleted(), which synchronizes on them.
   */
  void startPreEmptor()
  {
    this.readLock = this.socket.getReadLock();
    this.writeLock = this.socket.getWriteLock();
    preEmptor = new PreEmptor();
    preEmptor.setDaemon(true);
    preEmptor.start();
  }

  /** @return true if the handshake has not started yet */
  boolean handShakeNotYetBegun()
  {
    return (this.handShakeState == HANDSHAKE_NOT_YET_BEGUN);
  }

  /** @return true while the handshake is in progress */
  boolean handShakeStarted()
  {
    return (this.handShakeState == HANDSHAKE_IN_PROGRESS);
  }

  /** @return true once the handshake has finished, successfully or not */
  boolean handShakeCompleted()
  {
    return ((this.handShakeState == HANDSHAKE_COMPLETED) ||
            (this.handShakeState == HANDSHAKE_BAD_COMPLETION));
  }

  /** @return true if the handshake finished unsuccessfully */
  boolean badCompletion()
  {
    return (this.handShakeState == HANDSHAKE_BAD_COMPLETION);
  }

  /**
   * Finalises the session once the handshake is done and wakes waiters.
   *
   * <p>For a new session, an SSLSession is built from the current id, master
   * secret, pending cipher suite and peer chain, then added to the cache.
   * For a resumed session, the cached session for currentId is fetched, its
   * access time is updated and it is installed on the socket. Registered
   * listeners are then notified, and all threads waiting on the write and
   * read locks are woken.
   */
  void signalHandshakeCompleted()
  {
    SSLSession session;

    if (!sessionReuse)
    { 
      // NOTE(review): unlike the resume branch, socket.setSession() is not
      // called here; presumably the SSLSession constructor or addToCache()
      // attaches it. Confirm.
      session = new SSLSession(socket, this.currentId, this.masterSecret,
                  CipherSuites.getSuiteName(this.pendingCipherSuite),
                  this.peerCertChain); 
      session.addToCache();

      if (Debug.debug >= Debug.DEBUG_MSG)
      {
        System.out.println("\nNew session");
        System.out.println(session.toString());
      }
    }
    else
    {
      session 
        = (SSLSession)socket.getSessionCache().getSession(this.currentId);
      session.updateAccessTime();
      socket.setSession(session);

      if (Debug.debug >= Debug.DEBUG_MSG)
      {
        System.out.println("\nReusing session");
        System.out.println(session.toString());
      }
    }
 
    HandshakeCompletedEvent ev
      = new HandshakeCompletedEvent(this.socket, session);

    for (int i=0; i<listeners.size(); i++)
    {
      ((HandshakeCompletedListener)
          listeners.elementAt(i)).handshakeCompleted(ev);
    }

    // Wake any writer/reader blocked waiting for the handshake to finish.
    synchronized (writeLock)
    {
      writeLock.notifyAll();
    }

    synchronized (readLock)
    {
      readLock.notifyAll();
    }
  }

  /** Registers a listener to be told when the handshake completes. */
  void addHandshakeCompletedListener(HandshakeCompletedListener listener)
  {
    listeners.add(listener);
  }

  /** @return the protocol version bytes (the internal array, not a copy) */
  byte[] getProtocolVersion()
  {
    return this.protocolVersion;
  }

  /** Sets the protocol version bytes (stored without copying). */
  void setProtocolVersion(byte[] version)
  {
    this.protocolVersion = version;
  }

  /** @return the compression methods (the internal array, not a copy) */
  byte[] getCompressionMethods()
  {
    return this.compressionMethods;
  }

  /** Sets the compression methods (stored without copying). */
  void setCompressionMethods(byte[] methods)
  {
    this.compressionMethods = methods;
  }
 
  protected abstract void changeWriteCipher() throws SSLException;

  protected abstract void changeReadCipher() throws SSLException;

  /** Processes the next incoming handshake protocol unit. */
  protected abstract void nextMessage(SSLProtocolUnit pu) throws IOException;

  protected abstract void startHandShake() throws IOException;

  protected abstract void setSessionId(byte[] id);

  protected abstract byte[] getSessionID();

  /**
   * Restarts the running handshake hashes with fresh SHA-1 and MD5 digests
   * and clears the pending cipher suite.
   *
   * @throws SSLException if a digest algorithm is unavailable
   */
  protected void reset() throws SSLException
  {
    try
    {
      this.shaHash = MessageDigest.getInstance("SHA");
      this.md5Hash = MessageDigest.getInstance("MD5");
      this.pendingCipherSuite = null;
    }
    catch (Exception e)
    {
      e.printStackTrace();
      throw new SSLException(e.getMessage());
    }
  }

  /**
   * Returns a Vector of cipher suites to offer (CipherSuites.suiteType
   * entries), or null if none have been set.
   */
  protected Vector getEnabledCipherSuites()
  {
    return this.enabledCipherSuites;
  }

  /**
   * Sets the enabled cipher suites from their names. Names not found in
   * CipherSuites.suiteName are silently ignored.
   *
   * @param suites cipher suite names, in preference order
   */
  protected void setEnabledCipherSuites(String suites[])
  {
    this.enabledCipherSuites = new Vector();

    for (int i=0; i<suites.length; i++)
    {
      for (int j=0; j< CipherSuites.suiteName.length; j++)
      {
        if (suites[i].equals(CipherSuites.suiteName[j]))
        {
          this.enabledCipherSuites.add(CipherSuites.suiteType[j]);
          break;
        }
      }
    }
  }

  /** Records the negotiated suite here and as the context's pending spec. */
  protected void setPendingCipherSuite(byte[] type)
  {
    this.pendingCipherSuite = type;
    this.ctx.setPendingCipherSpec(type);
  }

  /**
   * Sets the serverHelloRandom to the supplied data
   */
  protected void setServerHelloRandom(byte[] data)
  {
    this.serverHelloRandom = data;
  }

  /** @return whether the pending cipher spec is an export suite */
  protected boolean isExportable()
  {
    return this.ctx.isPendingExportable();
  }

  /** @return true if the pending key-exchange algorithm name starts with "DH" */
  protected boolean isDiffieHellmanKeyX()
  {
    return this.ctx.getPendingKeyXAlgName().startsWith("DH");
  }

  /** @return true if the pending key-exchange algorithm name starts with "DHE" */
  protected boolean isDiffieHellmanEphKeyX()
  {
    return this.ctx.getPendingKeyXAlgName().startsWith("DHE");
  }

  /** Records the bit length of our key, derived from the given certificate. */
  protected abstract void setPubKeyBitLength(X509Certificate cert)
    throws IOException;

  protected int getPubKeyBitLength()
  {
    return this.pubKeyBitLength;
  }

  /**
   * Updates both the MD5 and SHA-1 MessageDigests with the supplied data
   */
  protected void updateHashes(byte[] fragment)
  {
    md5Hash.update(fragment);
    shaHash.update(fragment);
  }

  //////////////////////////
  // ClientKeyExchange
  /////////////////////////

  /**
   * Derives the master secret from the pre-master secret and both hello
   * randoms, using the context's MAC implementation.
   */
  protected final void computeMasterSecret(byte[] preMasterSecret)
  {
    this.masterSecret 
      = this.ctx.getMac().computeMasterSecret(preMasterSecret, 
                                              this.clientHelloRandom,
                                              this.serverHelloRandom);
  }

  /** Derives the key block from the master secret and both hello randoms. */
  protected final void computeKeyBlock()
  {
    this.keyBlock 
      = this.ctx.getMac().computeKeyBlock(this.masterSecret, 
                                          this.clientHelloRandom,
                                          this.serverHelloRandom);
  }
    
  //////////////////////////
  // ServerKeyExchange
  /////////////////////////

  /**
   * Builds the data to be signed for a ServerKeyExchange message.
   *
   * @param params the encoded key-exchange parameters
   * @param alg signature algorithm name
   * @return for "DSA", SHA-1(clientRandom || serverRandom || params)
   *         (20 bytes); otherwise MD5(...) followed by SHA-1(...) (36 bytes)
   */
  protected final byte[] toBeSignedSKX(byte[] params, String alg) 
  {
    sha.update(this.clientHelloRandom);
    sha.update(this.serverHelloRandom);
    sha.update(params);
  
    byte[] shaDigest =  sha.digest();

    if (alg.equals("DSA"))
    {
      return shaDigest;
    }

    md5.update(this.clientHelloRandom);
    md5.update(this.serverHelloRandom);
    md5.update(params);
    byte temp[] = md5.digest();

    // MD5 (16 bytes) first, then SHA-1 (20 bytes).
    byte[] retval = new byte[36];
    System.arraycopy(temp, 0, retval, 0, 16);
  
    System.arraycopy(shaDigest, 0, retval, 16, 20);
    return retval;
  }

  ////////////////////////
  // CertificateVerify
  ///////////////////////

  /**
   * Builds the data to be signed for a CertificateVerify message from the
   * running handshake hashes: both hashes for "RSA", SHA only otherwise.
   */
  protected final byte[] toBeSignedCV(String alg) throws IOException
  {
    if (alg.equals("RSA"))
    {
      return this.ctx.getMac().hash(shaHash, md5Hash, this.masterSecret,
                                    null, V3Constants.HASH_TYPE_BOTH);
    }
    else 
      return this.ctx.getMac().hash(shaHash, md5Hash, this.masterSecret,
                                  null, V3Constants.HASH_TYPE_SHA);
  }

  ///////////////////////
  // Finished
  //////////////////////

  /** @return our own Finished hashes (sender = this side) */
  protected final byte[] generateHashes() throws IOException
  {
    if (this instanceof ClientHandShaker)
      return computeHashes(CLIENT);
    else
      return computeHashes(SERVER);
  }

  /**
   * Verifies the Finished hashes received from the peer.
   *
   * <p>On mismatch, marks the handshake as badly completed, sends a fatal
   * handshake_failure alert and throws.
   *
   * @param hashes the hashes received from the peer
   * @throws SSLHandshakeException if the hashes do not match
   */
  protected final void checkHashes(byte[] hashes)
    throws IOException
  {
    byte[] computed;

    // The peer's Finished was computed with the peer as the sender.
    if (this instanceof ClientHandShaker)
      computed = computeHashes(SERVER);
    else
      computed = computeHashes(CLIENT);

    // MessageDigest.isEqual is the JDK's digest comparison (constant-time
    // in current JDKs) and avoids the early exit of Arrays.equals when
    // checking peer-supplied data.
    if (!MessageDigest.isEqual(computed, hashes))
    {
      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      ctx.getOutputStream().writeProtocolUnit(
        new Alert(Alert.FATAL, Alert.HANDSHAKE_FAILURE));
      throw new SSLHandshakeException("Bad handshake hashes");
    }
  }

  /**
   * Computes Finished hashes for the given sender (CLIENT or SERVER) from
   * the running handshake hashes and the master secret.
   */
  protected final byte[] computeHashes(int sender) throws IOException
  {
    return this.ctx.getMac().computeHashes(shaHash, md5Hash, 
        this.masterSecret, sender);
  }

  /////////////////////
  // Certificate
  ////////////////////

  /**
   * Encodes our certificate chain as the body of a Certificate message.
   *
   * <p>Element 0 is our own certificate and is also passed to
   * setPubKeyBitLength. Each certificate is written with
   * XDROutputStream.writeVector(3, encoded).
   *
   * @param certs our certificate chain (X509Certificate elements)
   * @return the encoded certificate list
   * @throws SSLException if certs is null or empty, or a certificate cannot
   *         be encoded (after sending a fatal handshake_failure alert)
   */
  protected byte[] processSenderCerts(Vector certs)
    throws IOException
  {
    // An empty chain would otherwise fail on elementAt(0) with an unchecked
    // ArrayIndexOutOfBoundsException.
    if (certs == null || certs.isEmpty())
      throw new SSLException("No available certificate");

    try
    {
      ByteArrayOutputStream bos = new ByteArrayOutputStream();
      XDROutputStream xos = new XDROutputStream(bos);

      X509Certificate cert = (X509Certificate)certs.elementAt(0);

      setPubKeyBitLength(cert);

      xos.writeVector(3, cert.getEncoded());

      for (int i=1; i<certs.size(); i++)
      {
        cert = (X509Certificate)certs.elementAt(i);
        xos.writeVector(3, cert.getEncoded());
      }
      xos.flush();

      return bos.toByteArray();
    }
    catch (CertificateEncodingException e)
    {
      e.printStackTrace();
      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      ctx.getOutputStream().writeProtocolUnit(
        new Alert(Alert.FATAL, Alert.HANDSHAKE_FAILURE));
      throw new SSLException("Bad certificate encoding");
    }
  }

  /**
   * Decodes and verifies the peer's certificate chain, then extracts keys.
   *
   * <p>The chain is stored as peerCertChain and checked by the trust engine.
   * On the client side, the leaf's public key becomes serverKeyXKey and,
   * for RSA/DSA keys, also serverSigningKey. On the server side, a DH key
   * becomes clientKeyXKey and any other key becomes clientSigningKey.
   *
   * @param certData the encoded certificate list from the Certificate message
   * @throws SSLHandshakeException if the chain is empty, cannot be parsed or
   *         is rejected (after sending a fatal bad_certificate alert)
   */
  protected void processReceivedCerts(byte[] certData)
    throws IOException
  {
    try
    {
      XDRInputStream sis
        = new XDRInputStream(new ByteArrayInputStream(certData));

      Vector certs = new Vector();

      while(sis.available() != 0)
      {
        byte[] data = sis.readVector(3);
        ByteArrayInputStream bais = new ByteArrayInputStream(data);

        certs.addElement((X509Certificate)certFact.generateCertificate(bais));
      }

      // Peer-supplied data may contain no certificates at all. Reject it via
      // the bad_certificate path below rather than failing on
      // peerCertChain[0] with an unchecked exception.
      if (certs.isEmpty())
        throw new CertificateException("Empty certificate chain");

      peerCertChain = new X509Certificate[certs.size()];
      certs.toArray(peerCertChain);
      trustEngine.verifyCertChain(peerCertChain);

      X509Certificate leafCert = peerCertChain[0];
      PublicKey pub = leafCert.getPublicKey();

      if (this instanceof ClientHandShaker)
      {
        if (pub.getAlgorithm().equals("RSA") || 
            pub.getAlgorithm().equals("DSA"))
        {
          serverSigningKey = pub;
        }
        serverKeyXKey = pub;
      }
      else
      {
        if (pub.getAlgorithm().equals("DH") || 
            pub.getAlgorithm().equals("Diffie-Hellman"))
        {
          clientKeyXKey = pub;
        }
        else
        {
          clientSigningKey = pub;
        }
      }

      if (Debug.debug >= Debug.DEBUG_MSG)
      {
        Debug.debug("\nReceived certs");
        for (int i=0; i<peerCertChain.length; i++)
        {
          Debug.debug(peerCertChain[i].toString());
        }
        Debug.debug("");
      }
    }
    catch (CertificateException e)
    {
      e.printStackTrace();
      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      ctx.getOutputStream().writeProtocolUnit(
        new Alert(Alert.FATAL, Alert.BAD_CERTIFICATE));
      throw new SSLHandshakeException(e.getMessage());
    }
  }

  /** Sends a close_notify warning alert to the peer. */
  protected synchronized void cleanup() throws IOException
  {
    if (Debug.debug >= Debug.DEBUG_MSG)
      Debug.debug("start cleanup");

    ctx.getOutputStream().writeProtocolUnit(
      new Alert(Alert.WARNING, Alert.CLOSE_NOTIFY));

    if (Debug.debug >= Debug.DEBUG_MSG)
      Debug.debug("finish cleanup");
  }

  /**
   * Background reader started by startPreEmptor().
   *
   * <p>While the handshake is COMPLETED it waits on readLock, notifying
   * other waiters first. Otherwise it reads protocol units until the input
   * buffer is empty and passes non-Alert, non-ApplicationData units to
   * ctx.nextHandShake. It stops when the stream ends (null unit) or an
   * IOException occurs, marking the handshake as badly completed.
   */
  private final class PreEmptor extends Thread
  {
    public void run()
    {
      while (true)
      {
        synchronized (readLock)
        {
          try
          {
            // Let the application reader own the stream while the
            // handshake is complete.
            while (handShakeState == HANDSHAKE_COMPLETED)
            {
              readLock.notifyAll();
              readLock.wait();
            }
    
            do
            {
              SSLProtocolUnit pu = ctx.getInputStream().readProtocolUnit();
              
              if (pu == null)
              {
                // End of stream.
                handShakeState = HANDSHAKE_BAD_COMPLETION;
                readLock.notifyAll();
                return;
              }
              else if (pu instanceof Alert)
              {
                // NOTE(review): alerts are silently dropped here, including
                // fatal ones received mid-handshake. Confirm this is intended.
              }  
              else if (pu instanceof ApplicationData)
              {
                // Application data is ignored by the pre-emptor.
              }  
              else
                ctx.nextHandShake(pu);
            }
            while (!ctx.getInputStream().bufferEmpty());
          }
          catch (InterruptedException e)
          {
            // Woken by interrupt: loop round and re-check the state. The
            // interrupt flag is not restored, because the next wait() would
            // then throw immediately and spin.
          }
          catch (IOException e)
          {
            e.printStackTrace();
            handShakeState = HANDSHAKE_BAD_COMPLETION;
            try
            {
              ctx.getOutputStream().close();
            }
            catch (IOException ie)
            {
              ie.printStackTrace();
            }
            readLock.notifyAll();
            break;
          }
        }
      }
    }
  } 
}
