//////////////////////////////////////////////////////////////////////////// 
// 
// Copyright (C) DSTC Pty Ltd (ACN 052 372 577) 1993, 1994, 1995.
// Unpublished work.  All Rights Reserved.
// 
// The software contained on this media is the property of the
// DSTC Pty Ltd.  Use of this software is strictly in accordance
// with the license agreement in the accompanying LICENSE.DOC 
// file. If your distribution of this software does not contain 
// a LICENSE.DOC file then you have no rights to use this 
// software in any manner and should contact DSTC at the address 
// below to determine an appropriate licensing arrangement.
// 
//      DSTC Pty Ltd
//      Level 7, GP South
//      University of Queensland
//      St Lucia, 4072
//      Australia
//      Tel: +61 7 3365 4310
//      Fax: +61 7 3365 4311
//      Email: jcsi@dstc.qut.edu.au
// 
// This software is being provided "AS IS" without warranty of
// any kind.  In no event shall DSTC Pty Ltd be liable for
// damage of any kind arising out of or in connection with
// the use or performance of this software.
// 
//////////////////////////////////////////////////////////////////////////// 

package com.dstc.security.provider;

import java.io.IOException;
import java.io.ByteArrayOutputStream;
import java.math.BigInteger;
import java.security.SecureRandom;
import java.security.AlgorithmParameters;
import java.security.Key;
import java.security.PublicKey;
import java.security.PrivateKey;
import java.security.KeyFactory;
import java.security.InvalidParameterException;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.SignatureException;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.X509EncodedKeySpec;
import javax.crypto.Cipher;
import javax.crypto.CipherSpi;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.ShortBufferException;
import javax.crypto.BadPaddingException;
import javax.crypto.IllegalBlockSizeException;
import com.dstc.security.asn1.Sequence;
import com.dstc.security.asn1.OctetString;
import com.dstc.security.asn1.Asn1Exception;
import com.dstc.security.x509.AlgorithmId;
import com.dstc.security.x509.SubjectPublicKeyInfo;

/**
 * Implements the RSA Cipher according to PKCS#1 v1.5.
 *
 * <p>Two PKCS#1 modes are supported, selected with {@link #engineSetMode}
 * (e.g. "RSA/1/PKCS1Padding" or "RSA/2/PKCS1Padding"):
 * <ul>
 *   <li>"2" (default): encrypt with the public key and decrypt with the
 *       private key, using block type 02 (random non-zero padding).</li>
 *   <li>"1": encrypt with the private key and decrypt with the public key,
 *       using block type 01 (0xFF padding).</li>
 * </ul>
 * The mode must be chosen before {@code engineInit}, because the padding
 * block type is captured at initialisation.
 *
 * <p>All input passed to the update methods is buffered and processed as a
 * single RSA block by doFinal; the buffer is cleared by every doFinal so the
 * instance can be reused with the same key.
 */
public class RSA extends CipherSpi 
{
  /** Key in use: a DSTC RSAPublicKey (re-encoded via KeyFactory) or the RSAPrivateCrtKey given to init. */
  protected Key rsaKey;
  /** The input bytes consumed by the most recent doFinal. */
  protected byte data[];
  /** Randomness for block type 02 padding; created on demand if none was supplied. */
  protected SecureRandom random;

  private int state;
  private ByteArrayOutputStream bos;
  private PKCS1Padding pad;
  private BigInteger modulus;
  private static final byte PRIVATE_KEY_ENCRYPT_MODE = (byte)0x01;
  private static final byte PUBLIC_KEY_ENCRYPT_MODE = (byte)0x02;
  /** PKCS#1 v1.5 requires a padding string (PS) of at least 8 octets. */
  private static final int MIN_PADDING_LENGTH = 8;
  private byte pkcs1Mode = PUBLIC_KEY_ENCRYPT_MODE;

  /**
   * Accepts only PKCS#1 padding (name compared case-insensitively).
   *
   * @throws NoSuchPaddingException for any other padding name
   */
  public void engineSetPadding(String padding) throws NoSuchPaddingException 
  {
    if (!"PKCS1Padding".equalsIgnoreCase(padding))
    {
      throw new NoSuchPaddingException(padding + " Not supported");
    }
  }

  /**
   * Selects the PKCS#1 mode: "1" for private key encryption (block type 01),
   * "2" for public key encryption (block type 02). Must be called before
   * engineInit.
   *
   * @throws NoSuchAlgorithmException for any other mode string
   */
  public void engineSetMode(String mode) throws NoSuchAlgorithmException 
  {
    if (mode.equals("1"))
    {
      this.pkcs1Mode = PRIVATE_KEY_ENCRYPT_MODE;
    }
    else if (mode.equals("2"))
    {
      this.pkcs1Mode = PUBLIC_KEY_ENCRYPT_MODE;
    }
    else
    {
      throw new NoSuchAlgorithmException(mode + " Not supported");
    }
  }

  /** RSA is not a block cipher in the JCE sense, so this returns 0. */
  protected int engineGetBlockSize() 
  {
    return 0;
  }

  /**
   * Returns the modulus length in bytes: one RSA operation never produces
   * more than that, however much input is buffered.
   *
   * @throws IllegalStateException if the cipher has not been initialised
   */
  protected int engineGetOutputSize(int inputLen)
  {
    return modulusLength();
  }

  /** RSA uses no IV; always null. */
  protected byte[] engineGetIV() 
  {
    return null;
  }

  /** RSA uses no algorithm parameters; always null. */
  protected AlgorithmParameters engineGetParameters()
  {
    return null;
  }

  /**
   * Initialises the cipher for encryption or decryption.
   *
   * <p>In mode "2" encryption needs a public key and decryption a private
   * key; in mode "1" it is the other way round. Private keys must be
   * RSAPrivateCrtKey instances (the private operation uses the CRT).
   * State is only modified once the key has been fully validated.
   *
   * @throws InvalidKeyException if the key is not a usable RSA key for the
   *         requested operation, or a public key cannot be re-encoded
   * @throws UnsupportedOperationException for opmodes other than
   *         ENCRYPT_MODE and DECRYPT_MODE
   */
  protected void engineInit(int opmode, Key key, SecureRandom random) 
    throws InvalidKeyException 
  {
    if (opmode != Cipher.ENCRYPT_MODE && opmode != Cipher.DECRYPT_MODE)
      throw new UnsupportedOperationException("Only ENCRYPT_MODE and DECRYPT_MODE are supported");

    if (key == null || !"RSA".equals(key.getAlgorithm()))
      throw new InvalidKeyException("Not an RSA Key");

    // Public-key operation is used for (encrypt, mode 2) and (decrypt, mode 1).
    boolean needPublic =
      (opmode == Cipher.ENCRYPT_MODE) == (pkcs1Mode == PUBLIC_KEY_ENCRYPT_MODE);

    Key newKey;
    BigInteger newModulus;

    if (key instanceof PublicKey)
    {
      if (!needPublic)
        throw new InvalidKeyException(opmode == Cipher.DECRYPT_MODE
          ? "Public Key decrypt not supported"
          : "Public Key encrypt not supported");

      try
      {  
        // Re-encode through the DSTC factory so the key is a DSTC RSAPublicKey.
        KeyFactory keyFact = KeyFactory.getInstance("RSA", "DSTC");
        X509EncodedKeySpec keySpec = new X509EncodedKeySpec(key.getEncoded());
        newKey = keyFact.generatePublic(keySpec);
        newModulus = ((RSAPublicKey)newKey).getModulus();
      }
      catch (Exception e)
      {
        InvalidKeyException ike = new InvalidKeyException("Bad Key encoding");
        ike.initCause(e);
        throw ike;
      }
    }
    else if (key instanceof PrivateKey)
    {
      if (needPublic)
        throw new InvalidKeyException(opmode == Cipher.DECRYPT_MODE
          ? "Private Key decrypt not supported"
          : "Private Key encrypt not supported");

      if (!(key instanceof RSAPrivateCrtKey))
        throw new InvalidKeyException("Only RSA CRT private keys are supported");

      newKey = key;
      newModulus = ((RSAPrivateCrtKey)key).getModulus();
    }
    else
    {
      throw new InvalidKeyException("Not an RSA Key");
    }

    this.rsaKey = newKey;
    this.modulus = newModulus;
    this.state = opmode;
    this.data = new byte[0];
    this.random = random;
    this.bos = new ByteArrayOutputStream();
    this.pad = new PKCS1Padding(this.pkcs1Mode);
  }

  /** Same as {@link #engineInit(int, Key, SecureRandom)}; params are ignored. */
  protected void engineInit(int opmode, Key key, AlgorithmParameterSpec params,
                            SecureRandom random) 
    throws InvalidKeyException, InvalidAlgorithmParameterException 
  {
    engineInit(opmode, key, random);
  }

  /**
   * Finishes the operation, writing the result into {@code output}.
   *
   * <p>Capacity is checked against the modulus length before any input is
   * consumed, so a caller may retry with a larger buffer.
   *
   * @return number of bytes written
   * @throws ShortBufferException if fewer than engineGetOutputSize bytes
   *         are available at outputOffset
   */
  protected int engineDoFinal(byte input[], int inputOffset, int inputLen,
                              byte output[], int outputOffset) 
    throws ShortBufferException, IllegalBlockSizeException, BadPaddingException 
  {
    if (output.length - outputOffset < engineGetOutputSize(inputLen))
      throw new ShortBufferException("Output Buffer too short");

    byte out[] = engineDoFinal(input, inputOffset, inputLen);
    System.arraycopy(out, 0, output, outputOffset, out.length);
    return out.length;
  }

  /**
   * Finishes the operation and returns the result: a k-byte ciphertext
   * when encrypting, or the recovered message when decrypting.
   */
  protected byte[] engineDoFinal(byte input[], int inputOffset, int inputLen) 
    throws IllegalBlockSizeException, BadPaddingException 
  {
    buffer(input, inputOffset, inputLen);
    return finish();
  }

  /** Buffers the input; RSA produces no output until doFinal. */
  protected int engineUpdate(byte input[], int inputOffset, int inputLen,
                             byte output[], int outputOffset) 
    throws ShortBufferException 
  {
    buffer(input, inputOffset, inputLen);
    return 0;
  }

  /** Buffers the input; RSA produces no output until doFinal, so returns null. */
  protected byte[] engineUpdate(byte input[], int inputOffset, int inputLen)
  {
    buffer(input, inputOffset, inputLen);
    return null;
  }

  /** Appends input to the pending buffer; null/empty input is a no-op. */
  private void buffer(byte input[], int inputOffset, int inputLen)
  {
    // Cipher.doFinal() without arguments may hand the SPI a null array.
    if (input != null && inputLen > 0)
      this.bos.write(input, inputOffset, inputLen);
  }

  /**
   * Processes the buffered input as one RSA block according to the opmode
   * and PKCS#1 mode, clearing the buffer first so the cipher is reusable
   * even if this throws.
   */
  private byte[] finish()
    throws IllegalBlockSizeException, BadPaddingException
  {
    this.data = this.bos.toByteArray();
    this.bos.reset();

    int k = modulusLength();

    if (state == Cipher.ENCRYPT_MODE)
    {
      // EB = 00 || BT || PS || 00 || D with |PS| >= 8, so |D| <= k - 11.
      int maxData = k - 3 - MIN_PADDING_LENGTH;
      if (this.data.length > maxData)
        throw new IllegalBlockSizeException(
          "Data must not be longer than " + maxData + " bytes");

      byte block[] = pad.doPadding(k, this.data);
      byte out[] = (pkcs1Mode == PUBLIC_KEY_ENCRYPT_MODE)
        ? publicKeyOp(block)
        : privateKeyOp(block);
      // Ciphertext is always exactly k bytes (I2OSP), even with leading zeros.
      return leftPad(out, k);
    }
    else if (state == Cipher.DECRYPT_MODE)
    {
      if (this.data.length > k)
        throw new IllegalBlockSizeException(
          "Data must not be longer than " + k + " bytes");
      if (new BigInteger(1, this.data).compareTo(modulus) >= 0)
        throw new BadPaddingException("Message is larger than modulus");

      byte block[] = (pkcs1Mode == PUBLIC_KEY_ENCRYPT_MODE)
        ? privateKeyOp(this.data)
        : publicKeyOp(this.data);
      return pad.doUnPadding(k, block);
    }
    else
    {
      throw new IllegalStateException("Cipher not initialized");
    }
  }

  /** Length of the modulus in bytes (k in PKCS#1). */
  private int modulusLength()
  {
    if (modulus == null)
      throw new IllegalStateException("Cipher not initialized");
    return (modulus.bitLength() + 7) / 8;
  }

  /** Returns the configured SecureRandom, creating a default one if none was supplied. */
  private SecureRandom randomSource()
  {
    if (this.random == null)
      this.random = new SecureRandom();
    return this.random;
  }

  /** Computes input^e mod n with the public key; result has no sign byte. */
  private byte[] publicKeyOp(byte input[]) 
  {
    BigInteger n = ((RSAPublicKey)rsaKey).getModulus();
    BigInteger e = ((RSAPublicKey)rsaKey).getPublicExponent();

    BigInteger x = new BigInteger(1, input);
    BigInteger y = x.modPow(e, n);
    
    return removeLeadingZero(y.toByteArray());
  }

  /**
   * Computes input^d mod n using the CRT components of the private key;
   * result has no sign byte.
   * NOTE(review): no blinding is applied; consider it to mitigate timing
   * side channels.
   */
  private byte[] privateKeyOp(byte input[]) 
  {
    RSAPrivateCrtKey key = (RSAPrivateCrtKey)rsaKey;
    BigInteger p = key.getPrimeP();
    BigInteger q = key.getPrimeQ();
    BigInteger dP = key.getPrimeExponentP();
    BigInteger dQ = key.getPrimeExponentQ();
    BigInteger qInv = key.getCrtCoefficient();

    BigInteger y = new BigInteger(1, input);

    // Garner recombination: x = j2 + q * ((j1 - j2) * qInv mod p)
    BigInteger j1 = y.modPow(dP, p);
    BigInteger j2 = y.modPow(dQ, q);
    BigInteger h = j1.subtract(j2).multiply(qInv).mod(p);
    BigInteger x = h.multiply(q).add(j2);

    return removeLeadingZero(x.toByteArray());
  }

  /** PKCS#1 v1.5 encryption-block formatting for a fixed block type. */
  protected class PKCS1Padding
  {
    private byte blockType;

    /** @param BT the block type byte written to / expected in EB[1] */
    protected PKCS1Padding(byte BT)
    {
      this.blockType = BT;
    }

    /**
     * Builds EB = 00 || BT || PS || 00 || D of exactly k bytes. PS is 0xFF
     * for block type 01 and random non-zero octets otherwise.
     *
     * @throws IllegalArgumentException if data leaves fewer than 8 PS bytes
     */
    protected byte[] doPadding(int k, byte[] data)
    {
      int psLen = k - 3 - data.length;
      if (psLen < MIN_PADDING_LENGTH)
        throw new IllegalArgumentException("Data too long for PKCS#1 padding");

      byte retval[] = new byte[k];
      retval[0] = (byte)0x00;
      retval[1] = this.blockType;

      if (this.blockType == (byte)0x01)
      {
        for (int i = 0; i < psLen; i++)
        {
          retval[2 + i] = (byte)0xff;
        }
      }
      else
      {
        SecureRandom rng = randomSource();
        byte ps[] = new byte[psLen];
        byte one[] = new byte[1];
        rng.nextBytes(ps);
        // Redraw zero bytes so PS is uniformly distributed over 1..255.
        for (int i = 0; i < psLen; i++)
        {
          while (ps[i] == 0)
          {
            rng.nextBytes(one);
            ps[i] = one[0];
          }
        }
        System.arraycopy(ps, 0, retval, 2, psLen);
      }

      // Separator octet between PS and D.
      retval[2 + psLen] = (byte)0x00;
      System.arraycopy(data, 0, retval, k - data.length, data.length);
      return retval;
    }

    /**
     * Strips PKCS#1 padding. {@code block} is EB without its leading 00
     * octet (as produced by removeLeadingZero), so it must be k - 1 bytes
     * long and start with the block type.
     * NOTE(review): distinguishable failures can serve as a padding oracle
     * (Bleichenbacher); callers should not expose failure details.
     *
     * @throws BadPaddingException if the block is malformed
     */
    protected byte[] doUnPadding(int k, byte block[])
      throws BadPaddingException
    {
      if (block.length == 0 || block.length != k - 1
          || block[0] != this.blockType)
      {
        throw new BadPaddingException("Bad block type");
      }

      // Locate the 00 separator ending PS; for BT 01, PS must be all 0xFF.
      int sep = 1;
      while (sep < block.length && block[sep] != (byte)0x00)
      {
        if (this.blockType == (byte)0x01 && block[sep] != (byte)0xff)
          throw new BadPaddingException("Bad padding");
        sep++;
      }

      if (sep == block.length || sep - 1 < MIN_PADDING_LENGTH)
        throw new BadPaddingException("Bad padding");

      byte retval[] = new byte[block.length - sep - 1];
      System.arraycopy(block, sep + 1, retval, 0, retval.length);
      return retval;
    }
  }

  /** Left-pads {@code in} with zero bytes to {@code len}; returns it unchanged if already that long. */
  private static byte[] leftPad(byte[] in, int len)
  {
    if (in.length >= len)
      return in;
    byte out[] = new byte[len];
    System.arraycopy(in, 0, out, len - in.length, in.length);
    return out;
  }

  /** Drops the single sign byte that BigInteger.toByteArray() may prepend. */
  private byte[] removeLeadingZero(byte[] in)
  {
    if (in[0] == 0)
    {
      byte retval[] = new byte[in.length-1];
      System.arraycopy(in, 1, retval, 0, retval.length);
      return retval;
    }
    else
      return in;
  }
}
