//////////////////////////////////////////////////////////////////////////// 
// 
// Copyright (C) DSTC Pty Ltd (ACN 052 372 577) 1993, 1994, 1995.
// Unpublished work.  All Rights Reserved.
// 
// The software contained on this media is the property of the
// DSTC Pty Ltd.  Use of this software is strictly in accordance
// with the license agreement in the accompanying LICENSE.DOC 
// file. If your distribution of this software does not contain 
// a LICENSE.DOC file then you have no rights to use this 
// software in any manner and should contact DSTC at the address 
// below to determine an appropriate licensing arrangement.
// 
//      DSTC Pty Ltd
//      Level 7, GP South
//      University of Queensland
//      St Lucia, 4072
//      Australia
//      Tel: +61 7 3365 4310
//      Fax: +61 7 3365 4311
//      Email: jcsi@dstc.qut.edu.au
// 
// This software is being provided "AS IS" without warranty of
// any kind.  In no event shall DSTC Pty Ltd be liable for
// damage of any kind arising out of or in connection with
// the use or performance of this software.
// 
//////////////////////////////////////////////////////////////////////////// 

package com.dstc.security.ssl;

import java.io.IOException;
import java.io.ByteArrayInputStream;
import java.util.Arrays;
import java.util.Vector;
import java.util.Calendar;
import java.math.BigInteger;

import java.security.SecureRandom;
import java.security.Signature;
import java.security.PublicKey;
import java.security.PrivateKey;
import java.security.KeyPairGenerator;
import java.security.KeyPair;
import java.security.Key;
import java.security.KeyFactory;
import java.security.MessageDigest;
import java.security.interfaces.RSAPublicKey;
import java.security.interfaces.DSAPublicKey;
import java.security.interfaces.DSAParams;
import java.security.spec.RSAPublicKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.security.cert.X509Certificate;

import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.CipherOutputStream;
import javax.crypto.KeyAgreement;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.SecretKeySpec;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.DHPublicKeySpec;
import javax.crypto.spec.DHParameterSpec;
import javax.crypto.interfaces.DHPublicKey;

import javax.net.ssl.SSLException;
import javax.net.ssl.SSLKeyException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLProtocolException;

import com.dstc.security.x509.X500Name;
import com.dstc.security.asn1.Asn1;

/**
 * A class for handling client-side hand shaking
 *
 * @author Ming Yung
 */
final class ClientHandShaker extends HandShaker
{
  /**
   * The client certificate chain sent when the server requests client
   * authentication, or null if the client has none. Elements are treated
   * as {@link X509Certificate}s (see {@link #getClientDiffieHellmanPublic}).
   */
  private Vector clientCerts;

  /**
   * Creates a client handshaker.
   *
   * @param sock the socket this handshake runs on (passed to the superclass)
   * @param rand source of randomness for the hello random, the pre-master
   *        secret, ephemeral DH keys and signatures
   * @param privKey the client's private key, or null if the client has none.
   *        An "RSA" or "DSA" key is stored as the CertificateVerify signing
   *        key; a key of any other algorithm is stored as the client
   *        key-exchange key (used by the DH key exchange)
   * @param clientCerts the client certificate chain, or null
   * @param trustedCerts passed through to the superclass (presumably the
   *        trust anchors for server certificate checking — verify in
   *        HandShaker)
   * @throws SSLException if the superclass constructor fails
   */
  protected ClientHandShaker(SSLSocket sock, SecureRandom rand,
                          PrivateKey privKey, Vector clientCerts, 
                          Vector trustedCerts)
    throws SSLException
  {
    super(sock, rand, trustedCerts);

    this.clientCerts = clientCerts;

    if (privKey == null)
    {
      return;
    }

    String keyAlg = privKey.getAlgorithm();
    if (keyAlg.equals("RSA") || keyAlg.equals("DSA"))
    {
      this.clientSigningKey = privKey;
    }
    else
    {
      this.clientKeyXKey = privKey;
    }
  }

  /**
   * Records the session ID for the current handshake. This value is compared
   * against the offered ID when a ServerHello arrives. Presumably it is set
   * from the server's ServerHello; confirm with the caller.
   *
   * @param id the session ID
   */
  protected void setSessionId(byte[] id)
  {
    this.currentId = id;
  }

  /**
   * Chooses the session ID to offer in the ClientHello.
   *
   * <p>If the socket's session cache has a valid session, its ID is offered
   * and its master secret is preloaded so that an abbreviated handshake can
   * proceed if the server agrees to resume. Otherwise an empty ID is
   * offered.
   *
   * @return the offered session ID (empty if no session is being resumed)
   */
  protected byte[] getSessionID()
  {
    SSLSession sess = socket.getSessionCache().getFirstValidSession();

    if (sess == null)
    {
      offeredId = new byte[0];
    }
    else
    {
      offeredId = sess.getId();
      this.masterSecret = sess.getMasterSecret();
    }

    return offeredId;
  }

  /** Installs the client-write keys from the key block on the write side. */
  protected void changeWriteCipher() throws SSLException
  {
     ctx.setKeyData(SSLContext.CLIENT, SSLContext.WRITE, this.keyBlock,
                    this.serverHelloRandom, this.clientHelloRandom);
  }

  /** Installs the server-write keys from the key block on the read side. */
  protected void changeReadCipher() throws SSLException
  {
     ctx.setKeyData(SSLContext.SERVER, SSLContext.READ, this.keyBlock,
                    this.serverHelloRandom, this.clientHelloRandom);
  }

  /**
   * Reacts to a protocol unit received from the server during the handshake.
   *
   * <p>Caution: message ordering is not checked. Each message is acted on
   * whenever it arrives.
   *
   * <p>NOTE(review): any ALERT record, including a warning-level one, is
   * treated as a handshake failure.
   *
   * @param pu the received protocol unit
   * @throws IOException if writing a response fails, if the client has no
   *         certificate when one is requested, or on an unexpected message
   *         or alert
   */
  protected void nextMessage(SSLProtocolUnit pu) throws IOException
  {
    if (pu.getContentType() == SSLProtocolUnit.HANDSHAKE)
    {
      byte msgType = ((HandShake)pu).getMessageType();
      switch (msgType)
      {
        case HandShake.HELLO_REQUEST:
          startHandShake();
          break;

        case HandShake.SERVER_HELLO:
          // Resume only if we offered a non-empty session ID and the
          // current session ID matches it.
          if (this.offeredId != null && this.offeredId.length > 0
              && Arrays.equals(this.offeredId, this.currentId))
          {
            sessionReuse = true;
            computeKeyBlock();
          }
          break;

        case HandShake.CERTIFICATE_REQUEST:
          clientAuthRequired = true;
          break;

        case HandShake.SERVER_HELLO_DONE:
          if (clientAuthRequired)
          {
            if (this.clientCerts == null)
            {
              ctx.getOutputStream().writeProtocolUnit(
                new Alert(Alert.FATAL, Alert.NO_CERTIFICATE));
              this.handShakeState = HANDSHAKE_BAD_COMPLETION;

              throw new IOException("No client certificate");
            }

            ctx.getOutputStream().writeProtocolUnit(
              new Certificate(this, this.clientCerts));
          }

          // The server's key-exchange key selects RSA or DH key exchange.
          ClientKeyExchange clientKeyX;
          if (this.serverKeyXKey.getAlgorithm().equals("RSA"))
          {
            clientKeyX = new RSAClientKeyExchange(this);
          }
          else
          {
            clientKeyX = new DHClientKeyExchange(this);
          }
          ctx.getOutputStream().writeProtocolUnit(clientKeyX);

          // CertificateVerify is sent only when we hold a signing key.
          // A fixed-DH client certificate authenticates without one.
          if (clientAuthRequired && clientSigningKey != null)
          {
            ctx.getOutputStream().writeProtocolUnit(
              new CertificateVerify(this));
          }

          ctx.getOutputStream().writeProtocolUnit(new ChangeCipherSpec());

          if (Debug.debug >= Debug.DEBUG_MSG)
            Debug.debug("Sent ChangeCipherSpec");

          this.computeKeyBlock();
          this.changeWriteCipher();

          ctx.getOutputStream().writeProtocolUnit(new Finished(this));
          break;

        case HandShake.CERTIFICATE:
          break;

        case HandShake.SERVER_KEY_EXCHANGE:
          break;

        case HandShake.FINISHED:
          // On resumption the server finishes first. Reply with our own
          // ChangeCipherSpec and Finished.
          if (sessionReuse)
          {
            ctx.getOutputStream().writeProtocolUnit(new ChangeCipherSpec());

            this.changeWriteCipher();

            ctx.getOutputStream().writeProtocolUnit(new Finished(this));
          }

          this.handShakeState = HANDSHAKE_COMPLETED;

          signalHandshakeCompleted();
          break;

        default:
          this.handShakeState = HANDSHAKE_BAD_COMPLETION;
          ctx.getOutputStream().writeProtocolUnit(
            new Alert(Alert.FATAL, Alert.UNEXPECTED_MESSAGE));
          throw new SSLProtocolException("Bad message");
      }
    }
    else if (pu.getContentType() == SSLProtocolUnit.CHANGE_CIPHER_SPEC)
    {
      this.changeReadCipher();
    }
    else if (pu.getContentType() == SSLProtocolUnit.ALERT)
    {
      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      throw new SSLHandshakeException("handshake failure");
    }
  }

  /**
   * Resets handshake state, marks the handshake as in progress and sends a
   * ClientHello.
   *
   * @throws IOException if the ClientHello cannot be written
   */
  protected void startHandShake() throws IOException
  {
    reset();

    handShakeState = HANDSHAKE_IN_PROGRESS;

    ClientHello clientHello = new ClientHello(this);
    ctx.getOutputStream().writeProtocolUnit(clientHello);
  }

  //////////////////
  // ClientHello
  /////////////////

  /**
   * Generates and stores a fresh 32-byte ClientHello random.
   *
   * <p>The first four bytes are the big-endian {@code gmt_unix_time}, in
   * seconds since the Unix epoch. The remaining 28 bytes come from the
   * secure random source.
   *
   * @return a copy of the stored client random
   */
  protected byte[] getClientHelloRandom()
  {
    this.clientHelloRandom = new byte[32];
    this.rand.nextBytes(this.clientHelloRandom);

    // gmt_unix_time is defined in seconds, not milliseconds.
    long time = System.currentTimeMillis() / 1000L;
    this.clientHelloRandom[0] = (byte)((time >> 24) & 0xff);
    this.clientHelloRandom[1] = (byte)((time >> 16) & 0xff);
    this.clientHelloRandom[2] = (byte)((time >> 8) & 0xff);
    this.clientHelloRandom[3] = (byte)(time & 0xff);
    return (byte[])this.clientHelloRandom.clone();
  }

  /////////////////////
  // Certificate
  /////////////////////

  /**
   * Records the bit length of the public key in {@code cert}.
   *
   * <p>For RSA this is the modulus length. For DSA it is the length of the
   * prime p. Any other algorithm is decoded as a DH key, and the length of
   * its prime p is used.
   *
   * @param cert the certificate whose public key is measured
   * @throws SSLKeyException if the key cannot be decoded or measured; a
   *         fatal handshake_failure alert is sent first
   * @throws IOException if sending the alert fails
   */
  protected void setPubKeyBitLength(X509Certificate cert)
    throws IOException
  {
    try
    {
      PublicKey pub = cert.getPublicKey();
      String keyAlg = pub.getAlgorithm();

      if (keyAlg.equals("RSA"))
      {
        this.pubKeyBitLength = ((RSAPublicKey)pub).getModulus().bitLength();
      }
      else if (keyAlg.equals("DSA"))
      {
        this.pubKeyBitLength =
          ((DSAPublicKey)pub).getParams().getP().bitLength();
      }
      else
      {
        // Re-decode through the DH KeyFactory so a DHPublicKey view is
        // available whichever provider produced the certificate's key.
        KeyFactory dhKeyFact = KeyFactory.getInstance("DH");
        DHPublicKey dhPub = (DHPublicKey)dhKeyFact.generatePublic(
          new X509EncodedKeySpec(pub.getEncoded()));

        this.pubKeyBitLength = dhPub.getParams().getP().bitLength();
      }
    }
    catch (Exception e)
    {
      throw abortHandShake(new SSLKeyException(e.getMessage()), e);
    }
  }

  /////////////////////
  // ClientKeyExchange
  /////////////////////

  /**
   * Creates the RSA pre-master secret, derives the master secret from it and
   * returns the secret encrypted under the server's key-exchange key.
   *
   * <p>The 48-byte pre-master secret is the 2-byte client protocol version
   * followed by 46 random bytes.
   *
   * @return the RSA-encrypted pre-master secret
   * @throws SSLException if encryption fails; a fatal handshake_failure
   *         alert is sent first
   * @throws IOException if sending the alert fails
   */
  protected byte[] getEncryptedPreMasterSecret()
    throws IOException
  {
    byte[] preMasterSecret = new byte[48];
    this.rand.nextBytes(preMasterSecret);
    System.arraycopy(this.protocolVersion, 0, preMasterSecret, 0, 2);

    if (Debug.debug >= Debug.DEBUG_CRYPTO)
    {
      Debug.debug("preMasterSecret: ", preMasterSecret);
    }

    computeMasterSecret(preMasterSecret);

    try
    {
      Cipher rsa = Cipher.getInstance("RSA");
      rsa.init(Cipher.ENCRYPT_MODE, this.serverKeyXKey);

      return rsa.doFinal(preMasterSecret);
    }
    catch (Exception e)
    {
      throw abortHandShake(new SSLException(e.getMessage()), e);
    }
  }

  /**
   * Performs the client side of the Diffie-Hellman key agreement.
   *
   * <p>If the client already has a DH key-exchange key (fixed DH), the
   * matching public key is taken from the first client certificate.
   * Otherwise an ephemeral key pair is generated with the server's DH
   * parameters. The agreed secret is used as the pre-master secret to derive
   * the master secret.
   *
   * @return the client's public value Y as an unsigned big-endian byte array
   *         (the BigInteger sign byte is stripped)
   * @throws SSLException if any DH operation fails; a fatal
   *         handshake_failure alert is sent first
   * @throws IOException if sending the alert fails
   */
  protected byte[] getClientDiffieHellmanPublic()
    throws IOException
  {
    try
    {
      KeyFactory dhKeyFact = KeyFactory.getInstance("DH");
      DHPublicKey serverDHPub = (DHPublicKey)dhKeyFact.generatePublic(
        new X509EncodedKeySpec(serverKeyXKey.getEncoded()));

      DHPublicKey pubKey;
      if (this.clientKeyXKey != null)
      {
        X509Certificate clientCert = (X509Certificate)clientCerts.elementAt(0);
        pubKey = (DHPublicKey)dhKeyFact.generatePublic(
          new X509EncodedKeySpec(clientCert.getPublicKey().getEncoded()));
      }
      else
      {
        KeyPairGenerator kpg = KeyPairGenerator.getInstance("DH");
        kpg.initialize(serverDHPub.getParams(), this.rand);

        KeyPair kp = kpg.generateKeyPair();
        this.clientKeyXKey = kp.getPrivate();
        pubKey = (DHPublicKey)kp.getPublic();
      }

      KeyAgreement keyAg = KeyAgreement.getInstance("DH");
      keyAg.init(this.clientKeyXKey, this.rand);
      keyAg.doPhase(serverDHPub, true);

      byte[] preMasterSecret = keyAg.generateSecret();

      if (Debug.debug >= Debug.DEBUG_CRYPTO)
      {
        Debug.debug("preMasterSecret: ", preMasterSecret);
      }

      computeMasterSecret(preMasterSecret);

      byte[] exchangeKeys = pubKey.getY().toByteArray();

      if (Debug.debug >= Debug.DEBUG_CRYPTO)
      {
        Debug.debug("exchangeKeys: ", exchangeKeys);
      }

      // BigInteger.toByteArray() prepends a 0x00 sign byte when the top bit
      // is set; the wire format wants the unsigned magnitude only.
      if (exchangeKeys[0] == 0)
      {
        byte[] retval = new byte[exchangeKeys.length - 1];
        System.arraycopy(exchangeKeys, 1, retval, 0, retval.length);
        return retval;
      }
      return exchangeKeys;
    }
    catch (Exception e)
    { 
      throw abortHandShake(new SSLException(e.getMessage()), e);
    }
  }

  //////////////////////
  // ServerKeyExchange
  /////////////////////

  /**
   * Sets the server's key-exchange key from RSA parameters in a
   * ServerKeyExchange message.
   *
   * @param modulus the unsigned big-endian RSA modulus
   * @param exponent the unsigned big-endian RSA public exponent
   * @throws SSLException if the key cannot be built; a fatal
   *         handshake_failure alert is sent first
   * @throws IOException if sending the alert fails
   */
  protected void setServerEncryptionKey(byte[] modulus, byte[] exponent)
    throws IOException
  {
    try
    {
      KeyFactory keyFact = KeyFactory.getInstance("RSA");

      this.serverKeyXKey =
        (RSAPublicKey)keyFact.generatePublic(new RSAPublicKeySpec(
          new BigInteger(1, modulus),
          new BigInteger(1, exponent)));
    }
    catch (Exception e)
    { 
      throw abortHandShake(new SSLException(e.getMessage()), e);
    }
  }

  /**
   * Sets the server's key-exchange key from Diffie-Hellman parameters in a
   * ServerKeyExchange message.
   *
   * @param p the unsigned big-endian DH prime
   * @param g the unsigned big-endian DH generator
   * @param y the unsigned big-endian server public value
   * @throws SSLException if the key cannot be built; a fatal
   *         handshake_failure alert is sent first
   * @throws IOException if sending the alert fails
   */
  protected void setServerEncryptionKey(byte[] p, byte[] g, byte[] y)
    throws IOException
  {
    try
    {
      KeyFactory keyFact = KeyFactory.getInstance("DH");

      this.serverKeyXKey =
        (DHPublicKey)keyFact.generatePublic(new DHPublicKeySpec(
          new BigInteger(1, y), new BigInteger(1, p),
          new BigInteger(1, g)));
    }
    catch (Exception e)
    { 
      throw abortHandShake(new SSLException(e.getMessage()), e);
    }
  }

  /**
   * Verifies the server's signature over the ServerKeyExchange parameters.
   *
   * <p>A "Raw" signature of the server signing key's algorithm (for example
   * "RawRSA") is checked. This assumes the installed provider supplies that
   * algorithm name.
   *
   * @param params the signed ServerKeyExchange parameters
   * @param signature the server's signature
   * @throws SSLHandshakeException if the signature is invalid or cannot be
   *         checked; exactly one fatal handshake_failure alert is sent first
   * @throws IOException if sending the alert fails
   */
  protected final void verifyServerSignature(byte[] params, byte[] signature)
    throws IOException
  {
    boolean verified;
    try
    {
      String alg = this.serverSigningKey.getAlgorithm();
      Signature sig = Signature.getInstance("Raw" + alg);
      sig.initVerify((PublicKey)this.serverSigningKey);
      sig.update(this.toBeSignedSKX(params, alg));
      verified = sig.verify(signature);
    }
    catch (Exception e)
    {
      throw abortHandShake(new SSLHandshakeException(e.getMessage()), e);
    }

    // Checked outside the try block. Otherwise the generic catch above
    // would intercept this exception and send a second alert.
    if (!verified)
    {
      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      ctx.getOutputStream().writeProtocolUnit(new Alert(Alert.FATAL, 
        Alert.HANDSHAKE_FAILURE));
      throw new SSLHandshakeException("Bad signature");
    }
  }

  ///////////////////////
  // CertificateVerify
  ///////////////////////

  /**
   * Signs the CertificateVerify data with the client signing key, using the
   * "Raw" variant of the key's algorithm.
   *
   * @return the signature bytes
   * @throws SSLException if signing fails; a fatal handshake_failure alert
   *         is sent first
   * @throws IOException if sending the alert fails
   */
  protected final byte[] generateClientSignature()
    throws IOException
  {
    try
    {
      String alg = this.clientSigningKey.getAlgorithm();
      Signature sig = Signature.getInstance("Raw" + alg);
      sig.initSign((PrivateKey)this.clientSigningKey, this.rand);
      sig.update(this.toBeSignedCV(alg));

      return sig.sign();
    }
    catch (Exception e)
    {
      throw abortHandShake(new SSLException(e.getMessage()), e);
    }
  }

  ////////////////////////
  // CertificateRequest
  ///////////////////////

  /**
   * Parses the list of acceptable CA names from a CertificateRequest.
   *
   * <p>Each entry is a vector with a 2-byte length prefix, decoded as an
   * X.500 name. The names are only validated and, at message debug level,
   * printed. They are not used to choose a client certificate.
   *
   * @param cas the encoded CA-name list
   * @throws SSLHandshakeException if the list cannot be parsed; a fatal
   *         handshake_failure alert is sent first
   * @throws IOException if sending the alert fails
   */
  protected final void processAcceptableCAs(byte[] cas)
    throws IOException
  {
    try
    {
      XDRInputStream sis 
        = new XDRInputStream(new ByteArrayInputStream(cas));

      while (sis.available() != 0)
      {
        X500Name dn = new X500Name(sis.readVector(2));

        if (Debug.debug >= Debug.DEBUG_MSG)
          dn.info();
      }
    }
    catch (Exception e)
    {
      throw abortHandShake(new SSLHandshakeException(e.getMessage()), e);
    }
  }

  /**
   * Shared failure path for the catch blocks above.
   *
   * <p>Prints the cause, marks the handshake as badly completed, sends a
   * fatal handshake_failure alert and attaches {@code cause} to {@code ex}.
   *
   * @param ex a freshly constructed exception with no cause set yet
   * @param cause the exception that caused the failure
   * @return {@code ex}, so callers can write {@code throw abortHandShake(...)}
   * @throws IOException if sending the alert fails
   */
  private SSLException abortHandShake(SSLException ex, Exception cause)
    throws IOException
  {
    cause.printStackTrace();
    this.handShakeState = HANDSHAKE_BAD_COMPLETION;
    ctx.getOutputStream().writeProtocolUnit(
      new Alert(Alert.FATAL, Alert.HANDSHAKE_FAILURE));
    ex.initCause(cause);
    return ex;
  }
}
