//////////////////////////////////////////////////////////////////////////// 
// 
// Copyright (C) DSTC Pty Ltd (ACN 052 372 577) 1993, 1994, 1995.
// Unpublished work.  All Rights Reserved.
// 
// The software contained on this media is the property of the
// DSTC Pty Ltd.  Use of this software is strictly in accordance
// with the license agreement in the accompanying LICENSE.DOC 
// file. If your distribution of this software does not contain 
// a LICENSE.DOC file then you have no rights to use this 
// software in any manner and should contact DSTC at the address 
// below to determine an appropriate licensing arrangement.
// 
//      DSTC Pty Ltd
//      Level 7, GP South
//      University of Queensland
//      St Lucia, 4072
//      Australia
//      Tel: +61 7 3365 4310
//      Fax: +61 7 3365 4311
//      Email: jcsi@dstc.qut.edu.au
// 
// This software is being provided "AS IS" without warranty of
// any kind.  In no event shall DSTC Pty Ltd be liable for
// damage of any kind arising out of or in connection with
// the use or performance of this software.
// 
//////////////////////////////////////////////////////////////////////////// 

package com.dstc.security.ssl;

import java.io.IOException;
import java.io.ByteArrayOutputStream;
import java.util.Iterator;
import java.util.Arrays;
import java.util.Vector;
import java.util.Random;
import java.math.BigInteger;

import java.security.SecureRandom;
import java.security.Signature;
import java.security.KeyPairGenerator;
import java.security.KeyPair;
import java.security.KeyFactory;
import java.security.Key;
import java.security.PublicKey;
import java.security.PrivateKey;
import java.security.MessageDigest;
import java.security.interfaces.RSAPublicKey;
import java.security.interfaces.DSAPublicKey;
import java.security.interfaces.DSAParams;
import java.security.cert.X509Certificate;
import java.security.spec.X509EncodedKeySpec;

import javax.crypto.KeyAgreement;
import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.CipherOutputStream;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.SecretKeySpec;
import javax.crypto.spec.DESKeySpec;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.DHPublicKeySpec;
import javax.crypto.spec.DHParameterSpec;
import javax.crypto.interfaces.DHPublicKey;
import javax.crypto.interfaces.DHPrivateKey;

import javax.net.ssl.SSLException;
import javax.net.ssl.SSLKeyException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLProtocolException;

import com.dstc.security.x509.X500Name;

/**
 * A class for handling server-side handshaking.
 *
 * @author Ming Yung
 */
final class ServerHandShaker extends HandShaker
{
  private Vector serverCerts;
  private byte expectedHandShake;
  private byte expectedContentType;

  protected ServerHandShaker(SSLSocket sock, SecureRandom rand,
                          PrivateKey privKey, Vector serverCerts, 
                          Vector trustedCerts)
    throws SSLException
  {
    super(sock, rand, trustedCerts);

    if (privKey.getAlgorithm().equals("RSA") ||
        privKey.getAlgorithm().equals("DSA"))
    {
      this.serverSigningKey = privKey;
    }

    if (!privKey.getAlgorithm().equals("DSA"))
      this.serverKeyXKey = privKey;

    this.serverCerts = serverCerts;
    this.currentId = new byte[32];
  }

  /**
   * Decides which session id this handshake will use.
   *
   * If the session cache holds a session for the offered id, that id and
   * the cached master secret are adopted and the handshake is flagged as a
   * session reuse.  Otherwise a new 32 byte id is generated: random bytes
   * whose first four bytes are overwritten with the low 32 bits of the
   * current time in milliseconds.
   *
   * The new id is always built in a freshly allocated array.  The reuse
   * path stores the caller's array in currentId, and getSessionID() hands
   * currentId out by reference, so refilling the existing array in place
   * would overwrite an array that other code may still hold.
   *
   * @param id the session id offered by the client
   */
  protected void setSessionId(byte[] id)
  {
    SSLSession sess = (SSLSession)socket.getSessionCache().getSession(id);
    if (sess != null)
    {
      this.currentId = id;
      this.sessionReuse = true;
      this.masterSecret = sess.getMasterSecret();
      return;
    }

    // Never overwrite the previous id array: it may be shared (see above).
    byte[] freshId = new byte[32];
    rand.nextBytes(freshId);

    long time = System.currentTimeMillis();
    freshId[0] = (byte)((time >> 24) & 0xff);
    freshId[1] = (byte)((time >> 16) & 0xff);
    freshId[2] = (byte)((time >> 8) & 0xff);
    freshId[3] = (byte)(time & 0xff);

    this.currentId = freshId;
  }

  /**
   * Returns the session id currently held by this handshaker.  The array
   * is returned by reference, not copied.
   */
  protected byte[] getSessionID()
  {
    byte[] id = this.currentId;
    return id;
  }

  /**
   * Installs the server-side write key data in the context, derived from
   * the current key block and the two hello randoms.
   *
   * @throws SSLException propagated from the context
   */
  protected void changeWriteCipher() throws SSLException
  {
    ctx.setKeyData(
      SSLContext.SERVER,
      SSLContext.WRITE,
      keyBlock,
      serverHelloRandom,
      clientHelloRandom);
  }

  /**
   * Installs the client-side read key data in the context, derived from
   * the current key block and the two hello randoms.
   *
   * @throws SSLException propagated from the context
   */
  protected void changeReadCipher() throws SSLException
  {
    ctx.setKeyData(
      SSLContext.CLIENT,
      SSLContext.READ,
      keyBlock,
      serverHelloRandom,
      clientHelloRandom);
  }

  /**
   * Returns the handshaker to its initial state, expecting a ClientHello.
   *
   * generateTempRSAPubKey() and generateTempDHPubKey() replace the
   * key-exchange key with a temporary one.  After the superclass reset, a
   * non-DSA signing key is therefore reinstated as the key-exchange key.
   *
   * @throws SSLException propagated from the superclass reset
   */
  protected void reset() throws SSLException
  {
    super.reset();

    if (serverSigningKey != null)
    {
      boolean dsaKey = serverSigningKey.getAlgorithm().equals("DSA");
      if (!dsaKey)
      {
        serverKeyXKey = serverSigningKey;
      }
    }

    // the next unit must be a handshake message carrying a ClientHello
    expectedHandShake = HandShake.CLIENT_HELLO;
    expectedContentType = SSLProtocolUnit.HANDSHAKE;
  }

  /**
   * Records whether the client must authenticate itself.
   *
   * @param flag true if a client certificate is to be requested
   */
  protected void setClientAuthRequired(boolean flag)
  {
    this.clientAuthRequired = flag;
  }

  /**
   * Advances the server-side handshake by one received protocol unit.
   *
   * The unit's content type is checked against the expected one (alerts
   * are always let through to the alert branch).  Handshake messages are
   * then checked against the expected handshake type.  An unexpected unit
   * is answered with a fatal UNEXPECTED_MESSAGE alert and an exception.
   *
   * The expected order enforced here is:
   * ClientHello, [Certificate], ClientKeyExchange, [CertificateVerify],
   * ChangeCipherSpec, Finished.  For a reused session it is
   * ClientHello, ChangeCipherSpec, Finished.
   *
   * Caution: any alert, whatever its level, still aborts the handshake.
   *
   * @param pu the protocol unit just received
   * @throws SSLHandshakeException if the unit is out of order or an alert
   * @throws SSLProtocolException if the handshake type is not handled here
   * @throws IOException if writing a response fails
   */
  protected void nextMessage(SSLProtocolUnit pu) throws IOException
  {
    if (pu.getContentType() != SSLProtocolUnit.ALERT &&
        pu.getContentType() != expectedContentType)
    {
      rejectOutOfOrderMessage();
    }

    if (pu instanceof V2ClientHello)
    {
      respondToClientHello();
      expectAfterClientHello();
    }
    else if (pu.getContentType() == SSLProtocolUnit.HANDSHAKE)
    {
      byte msgType = ((HandShake)pu).getMessageType();

      // -1 means "no particular handshake type expected"
      if (expectedHandShake != -1 && msgType != expectedHandShake)
      {
        rejectOutOfOrderMessage();
      }

      switch (msgType)
      {
        case HandShake.CLIENT_HELLO:
          respondToClientHello();
          expectAfterClientHello();
          break;

        case HandShake.CLIENT_KEY_EXCHANGE:
          if (clientAuthRequired && clientSigningKey != null)
            expectedHandShake = HandShake.CERTIFICATE_VERIFY;
          else
          {
            expectedContentType = SSLProtocolUnit.CHANGE_CIPHER_SPEC;
            expectedHandShake = -1;
          }
          break;

        case HandShake.CERTIFICATE:
          expectedHandShake = HandShake.CLIENT_KEY_EXCHANGE;
          break;

        case HandShake.CERTIFICATE_VERIFY:
          expectedHandShake = HandShake.FINISHED;
          expectedContentType = SSLProtocolUnit.CHANGE_CIPHER_SPEC;
          break;

        case HandShake.FINISHED:
          if (!sessionReuse)
          {
            // On a reused session these were already sent with ServerHello
            ctx.getOutputStream().writeProtocolUnit(new ChangeCipherSpec());
            this.changeWriteCipher();

            Finished finished = new Finished(this);
            ctx.getOutputStream().writeProtocolUnit(finished);
          }

          this.handShakeState = HANDSHAKE_COMPLETED;
          signalHandshakeCompleted();
          reset();
          break;

        default:
          this.handShakeState = HANDSHAKE_BAD_COMPLETION;
          ctx.getOutputStream().writeProtocolUnit(
            new Alert(Alert.FATAL, Alert.UNEXPECTED_MESSAGE));
          throw new SSLProtocolException("Bad message");
      }
    }
    else if (pu.getContentType() == SSLProtocolUnit.CHANGE_CIPHER_SPEC)
    {
      if (!sessionReuse)
        computeKeyBlock();
      this.changeReadCipher();
      expectedContentType = SSLProtocolUnit.HANDSHAKE;
      // The only handshake message valid after ChangeCipherSpec is
      // Finished.  Previously any handshake type could be accepted here.
      expectedHandShake = HandShake.FINISHED;
    }
    else if (pu.getContentType() == SSLProtocolUnit.ALERT)
    {
      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      throw new SSLHandshakeException("hand shake failure");
    }
  }

  /**
   * Sets the handshake type expected after a (V2 or V3) ClientHello has
   * been answered.
   */
  private void expectAfterClientHello()
  {
    if (sessionReuse)
    {
      // respondToClientHello() already expects ChangeCipherSpec as the
      // next content type, so no handshake message is acceptable yet.
      expectedHandShake = -1;
    }
    else if (clientAuthRequired)
    {
      expectedHandShake = HandShake.CERTIFICATE;
    }
    else
    {
      // No certificate was requested: the client must go straight to
      // its key exchange.
      expectedHandShake = HandShake.CLIENT_KEY_EXCHANGE;
    }
  }

  /**
   * Sends a fatal UNEXPECTED_MESSAGE alert, marks the handshake as failed
   * and throws.  This method never returns normally.
   */
  private void rejectOutOfOrderMessage() throws IOException
  {
    this.ctx.getOutputStream().writeProtocolUnit(
      new Alert(Alert.FATAL, Alert.UNEXPECTED_MESSAGE));

    this.handShakeState = HANDSHAKE_BAD_COMPLETION;
    throw new SSLHandshakeException("Messages out of order");
  }

  /**
   * Writes the server's answer to a ClientHello.
   *
   * A ServerHello is always written first.  For a reused session it is
   * followed by ChangeCipherSpec and Finished and nothing else.  Otherwise
   * the server Certificate is written, then a ServerKeyExchange when the
   * suite uses ephemeral Diffie-Hellman or is exportable, then a
   * CertificateRequest when client authentication is required, and
   * finally ServerHelloDone.
   *
   * @throws SSLKeyException if client authentication is required but no
   *         trusted certificates are available
   * @throws IOException if writing a message fails
   */
  private void respondToClientHello() throws IOException
  {
    this.handShakeState = HANDSHAKE_IN_PROGRESS;

    ServerHello hello = new ServerHello(this);
    ctx.getOutputStream().writeProtocolUnit(hello);
 
    if (sessionReuse)
    {
      finishAbbreviatedHandshake();
      return;
    }

    Certificate serverCertMsg = new Certificate(this, this.serverCerts);
    ctx.getOutputStream().writeProtocolUnit(serverCertMsg);
 
    // A temporary key is only sent for ephemeral DH or exportable suites
    ServerKeyExchange keyExchangeMsg = null;
    if (isDiffieHellmanEphKeyX())
    {
      int dhKeyLength = 1024;
      if (isExportable())
      {
        dhKeyLength = 512;
      }
      keyExchangeMsg = new DHServerKeyExchange(this, dhKeyLength);
    }
    else if (isExportable())
    {
      keyExchangeMsg = new RSAServerKeyExchange(this, 512);
    }

    if (keyExchangeMsg != null)
    {
      ctx.getOutputStream().writeProtocolUnit(keyExchangeMsg);
    }

    if (clientAuthRequired)
    {
      requestClientCertificate();
    }

    ServerHelloDone done = new ServerHelloDone(this);
    ctx.getOutputStream().writeProtocolUnit(done);
  }

  /**
   * Completes the server's side of a reused-session handshake: writes
   * ChangeCipherSpec, switches the write cipher, writes Finished, and then
   * expects the client's ChangeCipherSpec.
   */
  private void finishAbbreviatedHandshake() throws IOException
  {
    ctx.getOutputStream().writeProtocolUnit(new ChangeCipherSpec());

    // keys must change only after ChangeCipherSpec has been written
    computeKeyBlock();
    changeWriteCipher();

    Finished serverFinished = new Finished(this);
    ctx.getOutputStream().writeProtocolUnit(serverFinished);

    expectedContentType = SSLProtocolUnit.CHANGE_CIPHER_SPEC;
  }

  /**
   * Writes a CertificateRequest naming the subject of every trusted
   * certificate as an acceptable CA.
   */
  private void requestClientCertificate() throws IOException
  {
    if (trustEngine.getTrustedCerts() == null)
    {
      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      throw new SSLKeyException("No trusted CA certs set");
    }

    Vector acceptableCAs = new Vector();
    for (Iterator it = trustEngine.getTrustedCerts().iterator();
         it.hasNext(); )
    {
      X509Certificate trusted = (X509Certificate)it.next();
      String subject = trusted.getSubjectDN().getName();
      acceptableCAs.addElement(new X500Name(subject));
    }

    // certificate types 1 and 2 -- presumably rsa_sign and dss_sign;
    // confirm against the ClientCertificateType values in the SSL spec
    byte[] certTypes = new byte[] {(byte)0x01, (byte)0x02};

    CertificateRequest request
      = new CertificateRequest(this, certTypes, acceptableCAs);
    ctx.getOutputStream().writeProtocolUnit(request);
  }

  /**
   * Starts a handshake from the server side: resets this handshaker,
   * marks the handshake as in progress and writes a HelloRequest.
   *
   * @throws IOException if the reset or the write fails
   */
  protected void startHandShake() throws IOException
  {
    reset();
    this.handShakeState = HANDSHAKE_IN_PROGRESS;

    HelloRequest request = new HelloRequest(this);
    this.ctx.getOutputStream().writeProtocolUnit(request);
  }

  /**
   * Stores the client's hello random.  The array is kept by reference,
   * not copied.
   *
   * @param data the random bytes taken from the ClientHello
   */
  protected void setClientHelloRandom(byte[] data)
  {
    clientHelloRandom = data;
  }

  /**
   * Sets the pending cipher suite based on an offered Vector of cipher suites.
   *
   * The offered suites are examined in the client's order.  The first one
   * that is both enabled on this side and usable with the algorithm of the
   * server certificate's public key becomes the pending suite.
   *
   * Two defects of the earlier version are fixed here.  RSA suites are
   * only chosen when the server key really is an RSA key; previously any
   * suite naming RSA was chosen whatever the key type.  A failed
   * negotiation is also detected from this call's own result, so a
   * pending suite left over from an earlier handshake cannot mask it.
   *
   * @param suites the offered suites, each element a byte[]
   * @throws SSLHandshakeException if no offered suite can be supported; a
   *         fatal HANDSHAKE_FAILURE alert is written first
   * @throws IOException if writing the alert fails
   */
  protected void setCipherSuite(Vector suites) throws IOException
  {
    if (Debug.debug >= Debug.DEBUG_MSG)
    {
      Debug.debug("\nOffered Cipher Suite");
      for (int i=0; i< suites.size(); i++)
      {
        Debug.debug(CipherSuites.getSuiteName(
          (byte[])suites.elementAt(i)));
      }
      Debug.debug("");
    }

    byte enabledSuites[][] = new byte[enabledCipherSuites.size()][];
    enabledCipherSuites.toArray(enabledSuites);

    String keyAlg 
      = ((X509Certificate)serverCerts.elementAt(0)
          ).getPublicKey().getAlgorithm();
    if (keyAlg.equals("Diffie-Hellman"))
      keyAlg = "DH";

    // First offered suite that is enabled here and fits the server key
    byte[] chosen = null;
    for (int i=0; i<suites.size() && chosen == null; i++)
    {
      byte[] offered = (byte[])suites.elementAt(i);
      if (!isEnabledSuite(offered, enabledSuites))
        continue;

      String keyXAlg = CipherSuites.getKeyXAlgName(offered);
      if (keyExchangeFitsServerKey(keyXAlg, keyAlg))
        chosen = offered;
    }

    if (chosen != null)
      setPendingCipherSuite(chosen);

    if (chosen == null || this.pendingCipherSuite == null)
    {
      this.ctx.getOutputStream().writeProtocolUnit(
        new Alert(Alert.FATAL, Alert.HANDSHAKE_FAILURE));

      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      throw new SSLHandshakeException("Cannot support offered cipher suites");
    }
  }

  /**
   * Returns true if the given suite equals one of the enabled suites.
   */
  private static boolean isEnabledSuite(byte[] suite, byte[][] enabledSuites)
  {
    for (int j=0; j<enabledSuites.length; j++)
    {
      if (Arrays.equals(suite, enabledSuites[j]))
        return true;
    }
    return false;
  }

  /**
   * Returns true if a key exchange algorithm can be served with a server
   * key of the given algorithm ("RSA", "DSA" or "DH").
   */
  private static boolean keyExchangeFitsServerKey(String keyXAlg,
                                                  String keyAlg)
  {
    // Any key exchange name containing "RSA" (this includes DHE_RSA)
    // needs an RSA server key.
    if (keyXAlg.indexOf("RSA") != -1)
      return keyAlg.equals("RSA");

    if (keyXAlg.indexOf("DHE_DSS") != -1)
      return keyAlg.equals("DSA");

    if (keyXAlg.equals("DH") || keyXAlg.equals("DH_EXP"))
      return keyAlg.equals("DH");

    return false;
  }

  /**
   * Returns the pending cipher suite.  At message debug level the name of
   * the suite is also written to the debug output.
   */
  protected byte[] getPendingCipherSuite()
  {
    boolean verbose = Debug.debug >= Debug.DEBUG_MSG;
    if (verbose)
    {
      String suiteName = CipherSuites.getSuiteName(pendingCipherSuite);
      Debug.debug("\nChosen Cipher Suite");
      Debug.debug(suiteName + "\n");
    }

    return pendingCipherSuite;
  }

  ////////////////////////
  // ServerHello
  ////////////////////////

  /**
   * Generates a new 32 byte server hello random.  A private copy is kept
   * in this handshaker and the generated array is returned.
   *
   * @return the 32 newly generated random bytes
   */
  protected final byte[] getServerHelloRandom()
  {
    byte[] generated = new byte[32];
    rand.nextBytes(generated);

    // keep our own copy so the caller's array is independent of ours
    byte[] kept = new byte[generated.length];
    System.arraycopy(generated, 0, kept, 0, generated.length);
    serverHelloRandom = kept;

    return generated;
  }

  ////////////////////////
  // Certificate
  ////////////////////////

  /**
   * Records the bit length of the public key in the given certificate.
   *
   * The length is that of the modulus for an RSA key and of the prime P
   * for a DSA key.  Any other key is re-decoded as a Diffie-Hellman key
   * and the length of its prime P is used.
   *
   * @param cert the certificate whose public key is measured
   * @throws SSLKeyException if the key cannot be examined; a fatal
   *         HANDSHAKE_FAILURE alert is written first
   * @throws IOException if writing the alert fails
   */
  protected void setPubKeyBitLength(X509Certificate cert)
    throws IOException
  {
    try
    {
      PublicKey certKey = cert.getPublicKey();
      String alg = certKey.getAlgorithm();
      int bits;

      if (alg.equals("RSA"))
      {
        bits = ((RSAPublicKey)certKey).getModulus().bitLength();
      }
      else if (alg.equals("DSA"))
      {
        DSAParams dsaParams = ((DSAPublicKey)certKey).getParams();
        bits = dsaParams.getP().bitLength();
      }
      else
      {
        // treat everything else as Diffie-Hellman
        KeyFactory dhKeyFact = KeyFactory.getInstance("DH");
        X509EncodedKeySpec encoded
          = new X509EncodedKeySpec(certKey.getEncoded());
        DHPublicKey dhPub = (DHPublicKey)dhKeyFact.generatePublic(encoded);

        DHParameterSpec dhParams = (DHParameterSpec)dhPub.getParams();
        bits = dhParams.getP().bitLength();
      }

      this.pubKeyBitLength = bits;
    }
    catch (Exception e)
    {
      e.printStackTrace();
      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      ctx.getOutputStream().writeProtocolUnit(
        new Alert(Alert.FATAL, Alert.HANDSHAKE_FAILURE));
      throw new SSLKeyException(e.getMessage());
    }
  }

  /////////////////////////
  // ClientKeyExchange
  ////////////////////////

  /**
   * Recovers the pre-master secret from the client's key exchange data and
   * computes the master secret from it.
   *
   * With a Diffie-Hellman key-exchange key the data is read as the
   * client's public value and a key agreement is performed.  With any
   * other key the data is decrypted with the "RSA" cipher.  The key block
   * is not computed here; nextMessage() does that when ChangeCipherSpec
   * arrives.
   *
   * NOTE(review): every failure, including an RSA decryption failure, is
   * answered with an immediate fatal alert.  Check whether this lets a
   * peer tell malformed ciphertexts apart from well-formed ones.
   *
   * @param exchangeKeys the key exchange bytes sent by the client
   * @throws SSLHandshakeException on any failure; a fatal
   *         HANDSHAKE_FAILURE alert is written first
   * @throws IOException if writing the alert fails
   */
  protected void setMasterSecret(byte[] exchangeKeys)
    throws IOException
  {
    try
    {
      if (Debug.debug >= Debug.DEBUG_CRYPTO)
      {
        Debug.debug("exchangeKeys: ", exchangeKeys);
      }

      String keyXAlg = this.serverKeyXKey.getAlgorithm();
      boolean dhExchange
        = keyXAlg.equals("DH") || keyXAlg.equals("Diffie-Hellman");

      byte[] preMasterSecret;
      if (!dhExchange)
      {
        Cipher rsa = Cipher.getInstance("RSA");
        rsa.init(Cipher.DECRYPT_MODE, this.serverKeyXKey);
        preMasterSecret = rsa.doFinal(exchangeKeys);
      }
      else
      {
        KeyAgreement agreement = KeyAgreement.getInstance("DH");
        agreement.init(this.serverKeyXKey, this.rand);

        // rebuild the client's public key over our own group parameters
        KeyFactory dhFactory = KeyFactory.getInstance("DH");
        BigInteger clientY = new BigInteger(1, exchangeKeys);
        DHParameterSpec group
          = ((DHPrivateKey)this.serverKeyXKey).getParams();
        DHPublicKeySpec clientSpec
          = new DHPublicKeySpec(clientY, group.getP(), group.getG());
        DHPublicKey clientKeyXKey
          = (DHPublicKey)dhFactory.generatePublic(clientSpec);

        agreement.doPhase(clientKeyXKey, true);
        preMasterSecret = agreement.generateSecret();
      }

      if (Debug.debug >= Debug.DEBUG_CRYPTO)
      {
        Debug.debug("isDiffieHellmanKeyX(): " + 
          isDiffieHellmanKeyX());
        Debug.debug("preMasterSecret: ", preMasterSecret);
      }

      computeMasterSecret(preMasterSecret);
    }
    catch (Exception e)
    {
      e.printStackTrace();
      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      ctx.getOutputStream().writeProtocolUnit(
        new Alert(Alert.FATAL, Alert.HANDSHAKE_FAILURE));
      throw new SSLHandshakeException(e.getMessage());
    }
  }

  ///////////////////////
  // ServerKeyExchange
  ///////////////////////

  /**
   * Generates a temporary 512 bit RSA key pair.  The private half replaces
   * the current key-exchange key and the public half is returned.
   *
   * @return the public key of the temporary pair
   * @throws SSLException if generation fails; a fatal HANDSHAKE_FAILURE
   *         alert is written first
   * @throws IOException if writing the alert fails
   */
  protected RSAPublicKey generateTempRSAPubKey()
    throws IOException
  {
    try
    {
      KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA");
      generator.initialize(512, this.rand);

      KeyPair tempPair = generator.generateKeyPair();
      this.serverKeyXKey = tempPair.getPrivate();

      RSAPublicKey tempPublic = (RSAPublicKey)tempPair.getPublic();
      return tempPublic;
    }
    catch (Exception e)
    {
      e.printStackTrace();
      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      ctx.getOutputStream().writeProtocolUnit(
        new Alert(Alert.FATAL, Alert.HANDSHAKE_FAILURE));
      throw new SSLException(e.getMessage());
    }
  }

  /**
   * Generates a temporary Diffie-Hellman key pair of the requested size.
   * The private half replaces the current key-exchange key and the public
   * half is returned.
   *
   * The recorded public key bit length is that of the prime P.  This
   * matches setPubKeyBitLength(), which also uses P for Diffie-Hellman
   * keys.  The earlier version recorded the length of the public value Y,
   * which differs from one generated key to the next.
   *
   * @param keyLength the size in bits passed to the key pair generator
   * @return the public key of the temporary pair
   * @throws SSLException if generation fails; a fatal HANDSHAKE_FAILURE
   *         alert is written first
   * @throws IOException if writing the alert fails
   */
  protected DHPublicKey generateTempDHPubKey(int keyLength)
    throws IOException
  {
    try
    {
      KeyPairGenerator kpg = KeyPairGenerator.getInstance("DH");
      kpg.initialize(keyLength, this.rand);

      KeyPair kp = kpg.generateKeyPair();
      this.serverKeyXKey = kp.getPrivate();

      DHPublicKey tempPublic = (DHPublicKey)kp.getPublic();
      this.pubKeyBitLength = tempPublic.getParams().getP().bitLength();
      return tempPublic;
    }
    catch (Exception e)
    {
      e.printStackTrace();
      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      ctx.getOutputStream().writeProtocolUnit(
        new Alert(Alert.FATAL, Alert.HANDSHAKE_FAILURE));
      throw new SSLException(e.getMessage());
    }
  }

  /**
   * Signs server key exchange parameters with the server signing key.
   *
   * The signature algorithm requested is "Raw" followed by the signing
   * key's algorithm name.  The bytes actually signed are produced by
   * toBeSignedSKX().
   *
   * @param params the key exchange parameters to be covered
   * @return the signature bytes
   * @throws SSLException if signing fails; a fatal HANDSHAKE_FAILURE
   *         alert is written first
   * @throws IOException if writing the alert fails
   */
  protected final byte[] generateServerSignature(byte[] params) 
    throws IOException
  {
    try
    {
      final String keyAlg = this.serverSigningKey.getAlgorithm();
      final Signature signer = Signature.getInstance("Raw" + keyAlg);

      signer.initSign((PrivateKey)this.serverSigningKey, this.rand);
      signer.update(this.toBeSignedSKX(params, keyAlg));

      byte[] signature = signer.sign();
      return signature;
    }
    catch (Exception e)
    {
      e.printStackTrace();
      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      ctx.getOutputStream().writeProtocolUnit(
        new Alert(Alert.FATAL, Alert.HANDSHAKE_FAILURE));
      throw new SSLException(e.getMessage());
    }
  }

  /////////////////////////
  // CertificateVerify
  ////////////////////////

  /**
   * Verifies the signature carried in the client's CertificateVerify.
   *
   * The signature algorithm requested is "Raw" followed by the client
   * signing key's algorithm name.  The bytes verified are produced by
   * toBeSignedCV().
   *
   * The "bad signature" outcome is handled outside the try block.  It was
   * previously thrown inside it and caught by the generic handler, which
   * wrote a second fatal alert and replaced the SSLHandshakeException
   * with a plain SSLException.
   *
   * @param signature the signature bytes sent by the client
   * @throws SSLHandshakeException if the signature does not verify; one
   *         fatal HANDSHAKE_FAILURE alert is written first
   * @throws SSLException if the verification cannot be carried out; a
   *         fatal HANDSHAKE_FAILURE alert is written first
   * @throws IOException if writing the alert fails
   */
  protected final void verifyClientSignature(byte[] signature)
    throws IOException
  {
    boolean verified;

    try
    {
      String alg = this.clientSigningKey.getAlgorithm();
      Signature sig = Signature.getInstance("Raw" + alg);
      sig.initVerify((PublicKey)this.clientSigningKey);
      sig.update(this.toBeSignedCV(alg));
      verified = sig.verify(signature);
    }
    catch (Exception e)
    {
      // The check itself could not be performed (no key, unknown
      // algorithm, malformed signature encoding, ...).
      e.printStackTrace();
      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      ctx.getOutputStream().writeProtocolUnit(
        new Alert(Alert.FATAL, Alert.HANDSHAKE_FAILURE));
      throw new SSLException(e.getMessage());
    }

    if (!verified)
    {
      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      ctx.getOutputStream().writeProtocolUnit(new Alert(Alert.FATAL, 
        Alert.HANDSHAKE_FAILURE));
      throw new SSLHandshakeException("Bad signature");
    }
  }

  ///////////////////////
  // CertificateRequest
  //////////////////////

  /**
   * Encodes a list of acceptable CA names for a CertificateRequest.
   *
   * The encoding of each name is written in turn through
   * XDROutputStream.writeVector with a first argument of 2.  At message
   * debug level each name is also dumped via its info() method.
   *
   * @param cAs the acceptable CAs, each element an X500Name
   * @return the concatenated encoded names
   * @throws SSLException if encoding fails; a fatal HANDSHAKE_FAILURE
   *         alert is written first
   * @throws IOException if writing the alert fails
   */
  protected final byte[] processAcceptableCAs(Vector cAs)
    throws IOException
  {
    try
    {
      ByteArrayOutputStream encoded = new ByteArrayOutputStream();
      XDROutputStream xdr = new XDROutputStream(encoded);

      for (Iterator names = cAs.iterator(); names.hasNext(); )
      {
        X500Name caName = (X500Name)names.next();

        if (Debug.debug >= Debug.DEBUG_MSG)
        {
          caName.info();
        }
        xdr.writeVector(2, caName.encode());
      }

      xdr.flush();
      return encoded.toByteArray();
    }
    catch (Exception e)
    {
      e.printStackTrace();
      this.handShakeState = HANDSHAKE_BAD_COMPLETION;
      ctx.getOutputStream().writeProtocolUnit(
        new Alert(Alert.FATAL, Alert.HANDSHAKE_FAILURE));
      throw new SSLException(e.getMessage());
    }
  }
}
