//////////////////////////////////////////////////////////////////////////// 
// 
// Copyright (C) DSTC Pty Ltd (ACN 052 372 577) 1993, 1994, 1995.
// Unpublished work.  All Rights Reserved.
// 
// The software contained on this media is the property of the
// DSTC Pty Ltd.  Use of this software is strictly in accordance
// with the license agreement in the accompanying LICENSE.DOC 
// file. If your distribution of this software does not contain 
// a LICENSE.DOC file then you have no rights to use this 
// software in any manner and should contact DSTC at the address 
// below to determine an appropriate licensing arrangement.
// 
//      DSTC Pty Ltd
//      Level 7, GP South
//      University of Queensland
//      St Lucia, 4072
//      Australia
//      Tel: +61 7 3365 4310
//      Fax: +61 7 3365 4311
//      Email: jcsi@dstc.qut.edu.au
// 
// This software is being provided "AS IS" without warranty of
// any kind.  In no event shall DSTC Pty Ltd be liable for
// damage of any kind arising out of or in connection with
// the use or performance of this software.
// 
//////////////////////////////////////////////////////////////////////////// 

package com.dstc.security.ssl;

import java.util.Arrays;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.IOException;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;

import java.security.MessageDigest;
import java.security.cert.X509Certificate;

import javax.crypto.SecretKeyFactory;
import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.CipherOutputStream;
import javax.crypto.spec.SecretKeySpec;
import javax.crypto.spec.IvParameterSpec;

import javax.net.ssl.SSLException;
import javax.net.ssl.SSLProtocolException;

/**
 * A class to represent the current encryption regime.
 *
 * Non-handshaking crypto code is localized here.
 *
 * @author Ming Yung
 */ 
final class SSLContext 
{
  // Values for the 'flag' argument of setKeyData(). The flag scales the
  // offsets into the key block, so CLIENT (0) selects the first and
  // SERVER (1) the second of each MAC-secret / key / IV pair.
  protected static final int CLIENT = 0;
  protected static final int SERVER = 1;

  // Channel indices into the two-element per-channel arrays below.
  // setKeyData() initialises the READ cipher for decryption and the
  // WRITE cipher for encryption.
  protected static final int READ = 0;
  protected static final int WRITE = 1;

  // Installed through setHandShaker(); also supplies the protocol
  // version passed to the MAC computation.
  private HandShaker handShaker;
  // Protocol-unit streams wrapping the streams given to the constructor.
  private ProtocolUnitInputStream pis;
  private ProtocolUnitOutputStream pos;

  // Active crypto state, indexed by channel (READ / WRITE). An entry of
  // cipher[] stays null until setKeyData() has run for that channel;
  // until then records pass through unprotected.
  private Cipher[] cipher = new Cipher[2];
  // Digest instance created by setKeyData(); within this class the record
  // MAC itself is computed through sslMac, not through this array.
  private MessageDigest[] mac = new MessageDigest[2];
  private byte[][] macSecret = new byte[2][];
  // Per-channel record counter, post-incremented for every MAC generated
  // and reset to 0 by setKeyData().
  private long[] seqNum = new long[2];

  // Per-channel copies of the cipher spec parameters, taken from
  // pendingCipherSpec when setKeyData() is called.
  private int[] blockSize = new int[2];
  private int[] macSize = new int[2]; 
  private int[] keyLength = new int[2];
  private int[] expKeyLength = new int[2];
  private int[] padSize = new int[2];
  private String[] cipherAlgName = new String[2];
  private String[] macAlgName = new String[2];
  private boolean[] isBlockCipher = new boolean[2];

  // Cipher spec chosen via setPendingCipherSpec(); it only takes effect on
  // a channel once setKeyData() is called for that channel.
  private CipherSpec pendingCipherSpec;
  // MAC helper built from the pending spec's MAC algorithm and pad size.
  private SSLMac sslMac;

  SSLContext(InputStream in, OutputStream out) 
    throws IOException
  {
    this.pos = new ProtocolUnitOutputStream(this, new XDROutputStream(out));
    this.pis = new ProtocolUnitInputStream(this, new XDRInputStream(in));
  }

  /**
   * Selects the cipher suite that later setKeyData() calls will activate
   * and prepares the matching MAC helper.
   *
   * @param type the cipher suite identifier, as understood by CipherSpec
   */
  protected void setPendingCipherSpec(byte[] type)
  {
    CipherSpec spec = new CipherSpec(type);
    this.pendingCipherSpec = spec;

    // The MAC helper is built from the new spec's MAC algorithm and pad size.
    this.sslMac = new SSLMac(spec.macAlgName, spec.padSize);
  }

  /**
   * Activates the pending cipher spec on one channel, taking the MAC
   * secret, cipher key and IV for the given side from the key block.
   *
   * Key block layout assumed by the offsets below (two entries each,
   * CLIENT first, SERVER second): MAC secrets, cipher keys, IVs.
   * For exportable suites the final key is the leading expKeyLength bytes
   * of MD5(key material + randoms) and the IV is the leading blockSize
   * bytes of MD5(randoms); the randoms are fed client-first for the
   * CLIENT side and server-first for the SERVER side.
   *
   * @param flag         CLIENT or SERVER: whose keys to extract
   * @param channel      READ or WRITE: which channel to (re)key; the READ
   *                     cipher is initialised for decryption, the WRITE
   *                     cipher for encryption
   * @param keyBlock     key material generated by the handshake
   * @param serverRandom the server's handshake random
   * @param clientRandom the client's handshake random
   * @throws SSLException if anything goes wrong while installing the keys;
   *                      the underlying failure is attached as the cause
   */
  protected synchronized void setKeyData(int flag, int channel,
                            byte[] keyBlock,
                            byte[] serverRandom, byte[] clientRandom)
    throws SSLException
  {
    try
    {
      if (Debug.debug >= Debug.DEBUG_CRYPTO)
      {
        Debug.debug("keyBlock: ", keyBlock);
      }

      IvParameterSpec ivSpec;
      SecretKeySpec keySpec;

      // Snapshot the pending spec's parameters for this channel.
      this.macAlgName[channel] = this.pendingCipherSpec.macAlgName;
      this.cipherAlgName[channel] = this.pendingCipherSpec.cipherAlgName;
      this.padSize[channel] = this.pendingCipherSpec.padSize;
      this.macSize[channel] = this.pendingCipherSpec.macSize;
      this.blockSize[channel] = this.pendingCipherSpec.blockSize;
      this.keyLength[channel] = this.pendingCipherSpec.keyLength;
      this.expKeyLength[channel] = this.pendingCipherSpec.expKeyLength;
      this.isBlockCipher[channel] = this.pendingCipherSpec.isBlockCipher;

      if (isPendingExportable())
      {
        // Exportable suite: derive the final key by hashing the key
        // material from the key block together with both randoms.
        MessageDigest md5 = MessageDigest.getInstance("MD5");
        md5.update(keyBlock, macSize[channel]*2 +
          flag*keyLength[channel], keyLength[channel]);
        if (flag == CLIENT)
          md5.update(clientRandom);
        md5.update(serverRandom);
        if (flag == SERVER)
          md5.update(clientRandom);

        keySpec 
          = new SecretKeySpec(md5.digest(), 0, expKeyLength[channel], 
              cipherAlgName[channel]);

        // digest() reset the MD5 state, so this hashes the randoms only.
        if (flag == CLIENT)
          md5.update(clientRandom);
        md5.update(serverRandom);
        if (flag == SERVER)
          md5.update(clientRandom);

        ivSpec 
          = new IvParameterSpec(md5.digest(), 0, blockSize[channel]);
      }
      else
      {
        // Non-exportable suite: key and IV come straight from the key block.
        keySpec = new SecretKeySpec(keyBlock, 
                   macSize[channel]*2 + flag*keyLength[channel],
                              keyLength[channel], 
                              cipherAlgName[channel]);
        ivSpec
          = new IvParameterSpec(keyBlock, 2*(macSize[channel] 
              + keyLength[channel]) + flag*blockSize[channel], 
                blockSize[channel]);
      }

      if (Debug.debug >= Debug.DEBUG_CRYPTO)
      {
        Debug.debug("key: ", keySpec.getEncoded());
        Debug.debug("iv: ", ivSpec.getIV());
      }

      // Block ciphers run in CBC without provider padding; the record
      // padding is added and removed by this class itself.
      String mode = (isBlockCipher[channel] ? "/CBC/NoPadding" : "");
      String cipherAlgMode = cipherAlgName[channel] + mode;

      SecretKeyFactory keyFact 
        = SecretKeyFactory.getInstance(cipherAlgName[channel]);

      this.cipher[channel] = Cipher.getInstance(cipherAlgMode);

      int encryptionMode;
      if (channel == READ)
        encryptionMode = Cipher.DECRYPT_MODE;
      else
        encryptionMode = Cipher.ENCRYPT_MODE;

      this.cipher[channel].init(encryptionMode,
                            keyFact.translateKey(keySpec), ivSpec);

      this.macSecret[channel] = new byte[macSize[channel]];
      System.arraycopy(keyBlock, flag*macSize[channel], 
                       macSecret[channel], 0, macSize[channel]);

      this.mac[channel] 
        = MessageDigest.getInstance(macAlgName[channel]);

      // A freshly keyed channel starts counting records from zero.
      this.seqNum[channel] = 0;

      if (Debug.debug >= Debug.DEBUG_MSG)
        Debug.debug("cipher changed");
    }
    catch (Exception e)
    {
      // Report every failure as an SSLException, as before, but keep the
      // original exception as the cause instead of printing it and
      // discarding it. Exceptions such as NullPointerException carry no
      // message, so fall back to their string form.
      String detail = (e.getMessage() != null ? e.getMessage() : e.toString());
      throw new SSLException(detail, e);
    }
  }

  /**
   * Computes the record MAC for a fragment on the given channel and
   * advances that channel's sequence number by one.
   *
   * @param contentType record content type covered by the MAC
   * @param raw         plaintext fragment covered by the MAC
   * @param channel     READ or WRITE: selects the MAC secret and counter
   * @return the MAC bytes produced by the SSLMac helper
   */
  private byte[] generateMac(byte contentType, byte[] raw, int channel)
  {
    SSLMac macHelper = sslMac;
    byte[] secret = macSecret[channel];

    // Use the current counter value for this record, then bump it.
    long current = seqNum[channel];
    seqNum[channel] = current + 1;
    byte[] sequence = getSequenceNumber(current);

    return macHelper.macMessage(secret, sequence, contentType, raw,
                                this.handShaker.getProtocolVersion());
  }

  /**
   * Protects an outgoing fragment: appends its MAC and encrypts the
   * result with the WRITE cipher.
   *
   * @param contentType record content type, covered by the MAC
   * @param raw         plaintext fragment
   * @return the encrypted fragment, or raw itself while no WRITE cipher
   *         has been installed
   * @throws IOException if writing through the cipher stream fails
   */
  protected byte[] macAndEncrypt(byte contentType, byte[] raw)
    throws IOException
  {
    // No cipher on the write channel yet: records go out unprotected.
    if (cipher[WRITE] == null)
      return raw;

    byte[] recordMac = generateMac(contentType, raw, WRITE);
    return encrypt(raw, recordMac);
  }

  /**
   * Encrypts a fragment followed by its MAC with the WRITE cipher.
   *
   * For block ciphers the data is extended with padLength padding bytes
   * plus one trailing length byte so that the total is a multiple of the
   * cipher's block size. Every padding byte is set to padLength: TLS 1.0
   * requires that, and SSL 3.0 places no constraint on the padding
   * content, so the same bytes are acceptable to both.
   *
   * @param raw      plaintext fragment
   * @param macBytes MAC computed over the fragment
   * @return the bytes emitted by the cipher for this record
   * @throws IOException if writing through the cipher stream fails
   */
  private byte[] encrypt(byte[] raw, byte[] macBytes) throws IOException
  {
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    // The cipher stream is not closed: close() would call doFinal() on
    // the shared cipher and reset it.
    // NOTE(review): presumably the cipher state is meant to carry over
    // from one record to the next - confirm.
    CipherOutputStream cos = new CipherOutputStream(bos, cipher[WRITE]);
    cos.write(raw);
    cos.write(macBytes);


    if (isBlockCipher[WRITE])
    {
      int cipherBlockSize = cipher[WRITE].getBlockSize();
      int rawAndMacLength = raw.length + macBytes.length;
      // The +1 accounts for the trailing padding-length byte.
      int temp = (rawAndMacLength + 1) % cipherBlockSize;
      int padLength = (temp == 0 ? 0 : cipherBlockSize - temp);
      byte pad[] = new byte[padLength];
      Arrays.fill(pad, (byte)padLength);

      cos.write(pad);
      cos.write((byte)padLength);

      if (Debug.debug >= Debug.DEBUG_CRYPTO)
      {
        Debug.debug("padLength: " + padLength);
        Debug.debug("rawAndMacLength: " + rawAndMacLength);
      }
    }

    byte[] retval =  bos.toByteArray();

    if (Debug.debug >= Debug.DEBUG_CRYPTO)
    {
      Debug.debug("raw: ", raw);
      Debug.debug("macBytes: ", macBytes);
      Debug.debug("cipherText: ", retval);
    }

    return retval;
  }

  /**
   * Unprotects an incoming fragment: decrypts it with the READ cipher and
   * checks its MAC.
   *
   * @param contentType record content type, covered by the MAC
   * @param encrypted   fragment as received
   * @return the plaintext fragment, or encrypted itself while no READ
   *         cipher has been installed
   * @throws SSLProtocolException if the MAC check fails
   * @throws IOException          if decryption fails
   */
  protected byte[] decryptAndVerifyMac(byte contentType, byte[] encrypted)
    throws SSLProtocolException, IOException
  {
    // No cipher on the read channel yet: records arrive unprotected.
    if (cipher[READ] == null)
      return encrypted;

    return decrypt(contentType, encrypted);
  }

  /**
   * Decrypts a received fragment with the READ cipher, strips the block
   * cipher padding (if any) and the trailing MAC, and verifies the MAC.
   *
   * The ciphertext is passed to Cipher.update() in one call rather than
   * through a single CipherInputStream.read(), which is not guaranteed to
   * return the whole record. update() is used instead of doFinal() so the
   * cipher is not reset between records.
   *
   * @param contentType record content type, covered by the MAC
   * @param cipherText  encrypted fragment as received
   * @return the plaintext fragment without padding and MAC
   * @throws SSLProtocolException if the record length does not fit the
   *         cipher, the record is too short for its padding and MAC, or
   *         the MAC does not match
   * @throws IOException declared for callers; see SSLProtocolException
   */
  private byte[] decrypt(byte contentType, byte[] cipherText)
    throws IOException
  {
    if (Debug.debug >= Debug.DEBUG_CRYPTO)
      Debug.debug("cipherText.length: " + cipherText.length);

    // update() may return null when it produces no output.
    byte decrypted[] = cipher[READ].update(cipherText);
    if (decrypted == null)
      decrypted = new byte[0];
    int done = decrypted.length;

    if (Debug.debug >= Debug.DEBUG_CRYPTO)
    {
      Debug.debug("decrypted: ", decrypted);
      Debug.debug("done: " + done);
    }

    // With NoPadding (or a stream cipher) a well-formed record decrypts
    // to exactly as many bytes as were received; anything else means the
    // record was not a whole number of cipher blocks.
    if (done != cipherText.length)
      throw new SSLProtocolException("Bad record length: "
                                     + cipherText.length);

    // The padding length is an unsigned byte; mask it so that values of
    // 128 and above are not read as negative numbers.
    int padLength = 0;
    if (isBlockCipher[READ])
    {
      if (done == 0)
        throw new SSLProtocolException("Bad mac received");
      padLength = decrypted[done - 1] & 0xff;
    }

    if (Debug.debug >= Debug.DEBUG_CRYPTO)
    {
      Debug.debug("padlength: " + padLength);
      Debug.debug("macSize[READ]: " + macSize[READ]);
      Debug.debug("isBlockCipher[READ]: " + isBlockCipher[READ]);
    }

    // Bytes following the payload: MAC, then padding and its length byte.
    int trailerLength
      = macSize[READ] + (isBlockCipher[READ] ? padLength + 1 : 0);

    // A record shorter than its claimed trailer is reported with the same
    // message as a MAC failure, so the two cases look alike to the peer.
    if (trailerLength > done)
      throw new SSLProtocolException("Bad mac received");

    byte retval[] = new byte[done - trailerLength];
    System.arraycopy(decrypted, 0, retval, 0, retval.length);

    if (Debug.debug >= Debug.DEBUG_CRYPTO)
      Debug.debug("retval.length: " + retval.length);

    byte macReceived[] = new byte[macSize[READ]];
    System.arraycopy(decrypted, retval.length, macReceived,
                     0, macReceived.length);

    // MessageDigest.isEqual is used in place of Arrays.equals, which
    // stops at the first differing byte.
    if (!MessageDigest.isEqual(macReceived,
                               generateMac(contentType, retval, READ)))
    {
      throw new SSLProtocolException("Bad mac received");
    }
    return retval;
  }


  /**
   * Encodes a sequence number as an unsigned 64-bit integer in
   * big-endian byte order.
   *
   * @param seqNum the sequence number
   * @return eight bytes, most significant byte first
   */
  private static byte[] getSequenceNumber(long seqNum)
  {
    byte[] encoded = new byte[8];
    long remaining = seqNum;

    // Peel off the low byte each round, filling the array back to front.
    for (int pos = encoded.length - 1; pos >= 0; pos--)
    {
      encoded[pos] = (byte) remaining;
      remaining >>>= 8;
    }
    return encoded;
  }

  /**
   * Reports the key exchange algorithm of the pending cipher spec.
   * A pending spec must already have been set via setPendingCipherSpec().
   *
   * @return the pending spec's key exchange algorithm name
   */
  protected String getPendingKeyXAlgName()
  {
    CipherSpec pending = this.pendingCipherSpec;
    return pending.keyXAlgName;
  }

  /**
   * Reports whether the pending cipher spec is flagged as exportable.
   * A pending spec must already have been set via setPendingCipherSpec().
   *
   * @return the pending spec's isExportable flag
   */
  protected boolean isPendingExportable()
  {
    CipherSpec pending = this.pendingCipherSpec;
    return pending.isExportable;
  }

  /**
   * The parameters of one cipher suite, copied out of the CipherSuites
   * lookup tables.
   */
  class CipherSpec
  {
    // Suite identification and key exchange.
    protected String suiteName;
    protected String keyXAlgName;
    protected boolean isExportable;

    // Bulk cipher parameters.
    protected String cipherAlgName;
    protected boolean isBlockCipher;
    protected int blockSize;
    protected int keyLength;
    protected int expKeyLength;

    // MAC parameters.
    protected String macAlgName;
    protected int macSize;
    protected int padSize;

    /**
     * Looks up the suite with the given identifier and copies its
     * entry from each CipherSuites table.
     *
     * @param type the cipher suite identifier passed to
     *             CipherSuites.getIndex()
     */
    CipherSpec(byte[] type)
    {
      final int idx = CipherSuites.getIndex(type);

      this.suiteName = CipherSuites.suiteName[idx];
      this.keyXAlgName = CipherSuites.keyXAlgName[idx];
      this.isExportable = CipherSuites.isExportable[idx];

      this.cipherAlgName = CipherSuites.cipherAlgName[idx];
      this.isBlockCipher = CipherSuites.isBlockCipher[idx];
      this.blockSize = CipherSuites.blockSize[idx];
      this.keyLength = CipherSuites.keyLength[idx];
      this.expKeyLength = CipherSuites.expKeyLength[idx];

      this.macAlgName = CipherSuites.macAlgName[idx];
      this.macSize = CipherSuites.macSize[idx];
      this.padSize = CipherSuites.padSize[idx];
    }
  }

  /**
   * Gives access to the handshaker attached to this context.
   *
   * @return the handshaker set via setHandShaker(), or null if none
   *         has been set
   */
  HandShaker getHandShaker()
  {
    HandShaker current = this.handShaker;
    return current;
  }

  /**
   * Attaches a handshaker to this context and immediately calls its
   * startPreEmptor() method.
   *
   * @param handShaker the handshaker to attach; must not be null
   */
  protected void setHandShaker(HandShaker handShaker)
  {
    // Store first, so the field is set even if the call below fails.
    this.handShaker = handShaker;
    handShaker.startPreEmptor();
  }

  /**
   * Hands a protocol unit to the attached handshaker.
   *
   * @param pu the protocol unit to pass to HandShaker.nextMessage()
   * @throws IOException if the handshaker reports one
   */
  protected void nextHandShake(SSLProtocolUnit pu) throws IOException
  {
    HandShaker current = this.handShaker;
    current.nextMessage(pu);
  }

  /**
   * Gives access to the protocol-unit input stream created by the
   * constructor.
   *
   * @return this context's input stream
   */
  ProtocolUnitInputStream getInputStream()
  {
    return this.pis;
  }

  /**
   * Gives access to the protocol-unit output stream created by the
   * constructor.
   *
   * @return this context's output stream
   */
  ProtocolUnitOutputStream getOutputStream()
  {
    return this.pos;
  }

  /**
   * Gives access to the MAC helper built by setPendingCipherSpec().
   *
   * @return the current SSLMac, or null if no cipher spec has been
   *         selected yet
   */
  SSLMac getMac()
  {
    return this.sslMac;
  }
}
