/*
 * Handle setup of session: session key generation/exchange, cipher selection,
 * cipher initialization.
 *
 * Author:  David Jones
 * Date:    20-MAY-1998
 * Revised: 19-JUN-1998		Isolate crypto portion to another module.
 * Revised: 27-JUN-1998		Include parameters.h.
 * Revised:  1-JUL-1998		Test for correct message scan.
 * Revised: 13-JUL-1998		Plug memory leak in destroy server context.
 */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "sslinc:rand.h"

#include "cport_sshsess.h"
#include "tmemory.h"
#include "tutil.h"

#include "sshcipher.h"

#ifdef DEBUG
/*
 * Debug helper: print heading and bufsize, then a hex dump of buffer,
 * 16 bytes per line, each line prefixed by the hex offset of its first byte.
 */
static void dump_buffer ( const char *heading, const unsigned char *buffer,
	int bufsize )
{
    int i;

    printf ( "%s (%d bytes):", heading, bufsize );
    for ( i = 0; i < bufsize; i++ ) {
	if ( (i % 16) == 0 ) printf ( "\n   %04x:", (unsigned) i );
	printf ( " %02x", (unsigned) buffer[i] );
    }
    printf ( "\n" );
}
#else
/*
 * Non-debug build: expand to a void no-op so "dump_buffer(...);" remains a
 * valid statement without "statement with no effect" warnings.
 */
#define dump_buffer(a,b,c) ((void) 0)
#endif

/************************************************************************/
/*
 * Tear down a session: destroy its PAD (closing the connection), release
 * the RSA keys the session owns, and free the session block.
 * Always returns 1.
 *
 * NOTE(review): unlike abort_new_session(), this does not call
 * sshmsg_destroy_locus() on session->locus -- confirm whether the locus
 * is released elsewhere.
 */
int sshsess_rundown ( sshsess_session session )
{
    sshpad_destroy ( session->pad );

    /*
     * Server-mode sessions borrow their keys from server_ctx; only a
     * client-mode session (server_ctx unset) owns and destroys them.
     */
    if ( session->server_ctx == 0 ) {
	if ( session->server_key != 0 ) sshrsa_destroy ( session->server_key );
	if ( session->host_key != 0 ) sshrsa_destroy ( session->host_key );
    }

    tm_free ( session );
    return 1;
}
/************************************************************************/
/* Rundown allocated context: destroy both RSA keys, free the cached
 * public-key integer buffers and the context block itself.  A NULL ctx
 * is ignored.
 *
 * NOTE(review): a context produced by sshsess_server_copy() shares its
 * host_key with the original context, so destroying both contexts would
 * destroy that key twice -- confirm callers never do that.
 */
void sshsess_destroy_server_context ( sshsess_server ctx )
{
    if ( !ctx ) return;

    if ( ctx->server_key ) sshrsa_destroy ( ctx->server_key );
    if ( ctx->host_key ) sshrsa_destroy ( ctx->host_key );
    /* free(NULL) is a no-op, so the buffers need no guard. */
    free ( ctx->hkey_exp_value );
    free ( ctx->hkey_value );
    free ( ctx->skey_exp_value );
    free ( ctx->skey_value );
    free ( ctx );
}
/************************************************************************/
/*  Allocate context for processing server side of SSH connection.
 *
 *  Loads the server and host RSA keys from PEM files, extracts their public
 *  exponents/moduli for later SSH message construction, pre-digests the
 *  moduli for session-id computation and returns a malloc'd context (the
 *  structure is shared, release it with sshsess_destroy_server_context()).
 *  On failure returns NULL with a diagnostic in errmsg, after releasing
 *  every key and buffer acquired so far.
 */
sshsess_server sshsess_server_context ( 
	char *skey_file, 		/* PEM file with server key */
	char *hkey_file,		/* PEM file with  host key */
	int ciphers, int auth_types, 	/* Ciphers to advertise */
	char errmsg[256] ) 		/* diagnostic if NULL returned */
{
    struct sshsess_server_st template, *new;
    int status, delta_bits;
    /*
     * build template structure.  First load keys (the key reader supplies
     * errmsg when it fails).
     */
    template.server_key = sshrsa_read_keyfile ( skey_file, 
	&template.skey_bits, errmsg );
    if ( !template.server_key ) return (sshsess_server) 0;
    template.host_key = sshrsa_read_keyfile ( hkey_file,
	&template.hkey_bits, errmsg );
    if ( !template.host_key ) {
	sshrsa_destroy ( template.server_key );
	return (sshsess_server) 0;
    }
    /*
     * Convert keys to multi-byte integer format for future reference
     * in SSH message construction.
     */
    template.hkey_exp_value = template.hkey_value = (char *) 0;
    template.skey_exp_value = template.skey_value = (char *) 0;
    status = sshrsa_extract_public_key ( template.server_key, 
	&template.skey_exp_bits, &template.skey_exp_value,
	&template.skey_bits, &template.skey_value );

    if ( status&1 ) status = sshrsa_extract_public_key ( template.host_key, 
	&template.hkey_exp_bits, &template.hkey_exp_value,
	&template.hkey_bits, &template.hkey_value );
    if ( (status&1) == 0 ) {
	strcpy ( errmsg, "Error converting key format" );
	goto fail;
    }
    /*
     * Require the two moduli to differ in size by at least 256 bits.
     */
    delta_bits = template.hkey_bits - template.skey_bits;
    if ( (delta_bits < 256) && (delta_bits > -256) ) {
	strcpy ( errmsg, 
	    "Server key and host key do not differ by at least 256 bits" );
	goto fail;
    }

    template.protocols = 0;
    template.ciphers = (ciphers & KNOWN_CIPHERS);
    template.authentications = auth_types;
    /*
     * Initialize MD5 context and preload with digest of common part
     * of session_ids (don't do final yet).
     */
    MD5_Init ( &template.pre_digest );
    MD5_Update ( &template.pre_digest, (unsigned char *) template.hkey_value,
		(unsigned long) (template.hkey_bits+7)/8 );
    MD5_Update ( &template.pre_digest, (unsigned char *) template.skey_value,
		(unsigned long) (template.skey_bits+7)/8 );
    /*
     * Initialize per-thread heaps for use by session creators.
     */
    tm_initialize();
    /*
     * Allocate new block and initialize with template.  Note that this
     * structure is shared, do not use tm_malloc().
     */
    new = malloc ( sizeof(*new) );
    if ( !new ) {
	strcpy (errmsg, "Error allocating memory for session_server context");
	goto fail;
    }
    *new = template;
    return (sshsess_server) new;

fail:
    /*
     * Release everything acquired above.  Both keys are valid here and
     * unset buffers are still NULL (free(NULL) is harmless).
     */
    sshrsa_destroy ( template.server_key );
    sshrsa_destroy ( template.host_key );
    free ( template.hkey_exp_value );
    free ( template.hkey_value );
    free ( template.skey_exp_value );
    free ( template.skey_value );
    return (sshsess_server) 0;
}
/*
 * Create a new server context that reuses original's host key, ciphers and
 * authentication types but loads a fresh server key from skey_file.
 *
 * The host_key RSA structure is shared by the copy and the original (take
 * care when destroying); the extracted integer buffers are the copy's own.
 * On failure returns NULL with a diagnostic in errmsg and releases only
 * what this call created.
 */
sshsess_server sshsess_server_copy ( 
	char *skey_file, 		/* PEM file with new server key */
	sshsess_server original,
	char errmsg[256] ) 		/* diagnostic if NULL returned */
{
    struct sshsess_server_st template, *new;
    int status;

    errmsg[0] = '\0';
    if ( !original ) {
	tu_strcpy ( errmsg, "Original server context was null" );
	return (sshsess_server) 0;
    }
    /*
     * build template structure.  First load keys.
     */
    template.server_key = sshrsa_read_keyfile ( skey_file, 
	&template.skey_bits, errmsg );
    if ( !template.server_key ) {
	/* Keep the key reader's (more specific) diagnostic if it set one. */
	if ( !errmsg[0] ) tu_strcpy ( errmsg, "error reading skey file" );
	return (sshsess_server) 0;
    }
    template.host_key = original->host_key;
    if ( !template.host_key ) {
	tu_strcpy ( errmsg, "Orginal host key was null" );
	sshrsa_destroy ( template.server_key );
	return (sshsess_server) 0;
    }
    /*
     * Convert keys to multi-byte integer format for future reference
     * in SSH message construction.
     */
    template.hkey_exp_value = template.hkey_value = (char *) 0;
    template.skey_exp_value = template.skey_value = (char *) 0;
    status = sshrsa_extract_public_key ( template.server_key, 
	&template.skey_exp_bits, &template.skey_exp_value,
	&template.skey_bits, &template.skey_value );

    if ( status&1 ) status = sshrsa_extract_public_key ( template.host_key, 
	&template.hkey_exp_bits, &template.hkey_exp_value,
	&template.hkey_bits, &template.hkey_value );
    if ( (status&1) == 0 ) {
	tu_strcpy ( errmsg, "Error converting key format" );
	goto fail;
    }
    template.protocols = 0;
    template.ciphers = original->ciphers;
    template.authentications = original->authentications;
    /*
     * Initialize MD5 context and preload with digest of common part
     * of session_ids (don't do final yet).
     */
    MD5_Init ( &template.pre_digest );
    MD5_Update ( &template.pre_digest, (unsigned char *) template.hkey_value,
		(unsigned long) (template.hkey_bits+7)/8 );
    MD5_Update ( &template.pre_digest, (unsigned char *) template.skey_value,
		(unsigned long) (template.skey_bits+7)/8 );
    /*
     * Allocate new block and initialize with template.  Note that this
     * structure is shared, do not use tm_malloc().
     */
    new = malloc ( sizeof(*new) );
    if ( !new ) {
	tu_strcpy (errmsg, "Error allocating memory for context");
	goto fail;
    }
    *new = template;
    return (sshsess_server) new;

fail:
    /*
     * host_key belongs to the original context; release only the new
     * server key and the buffers extracted above (free(NULL) is harmless).
     */
    sshrsa_destroy ( template.server_key );
    free ( template.hkey_exp_value );
    free ( template.hkey_value );
    free ( template.skey_exp_value );
    free ( template.skey_value );
    return (sshsess_server) 0;
}

/*****************************************************************************/
/* Deallocate session block, copy error message to errmsg and return null.
 *
 * diagnostic is appended to any text already in errmsg as
 * "<errmsg> : <diagnostic>" when the result fits in 256 bytes; otherwise
 * errmsg is overwritten with diagnostic.  Passing errmsg itself as
 * diagnostic leaves it unchanged.  RSA keys are destroyed only for
 * client-mode sessions (no server_ctx); the PAD is not destroyed here.
 */
static sshsess_session abort_new_session ( 
	struct sshsess_session_st *session, const char *diagnostic,
	char *errmsg )
{
    /*
     * Compose the message BEFORE releasing anything: some callers pass an
     * errmsg obtained from sshmsg_last_error_text(session->locus), which
     * may be storage owned by the locus destroyed below.
     */
    if ( diagnostic == errmsg ) {
	/* message already in place */
    } else if ( errmsg[0] && (strlen(errmsg) + strlen(diagnostic) < 252) ) {
	strcat ( errmsg, " : " );
	strcat ( errmsg, diagnostic );
    } else {
	strcpy ( errmsg, diagnostic );
    }

    if ( !session->server_ctx ) {
	/*
	 * Free up RSA key structures if allocated.  In server mode,
	 * keys are aliased to those in server_ctx so don't delete.
	 */
	if ( session->server_key ) sshrsa_destroy(session->server_key);
	if ( session->host_key ) sshrsa_destroy(session->host_key);
    }
    sshmsg_destroy_locus ( session->locus );
    tm_free ( session );
    return (sshsess_session) 0;
}
	

/******************************************************************************/
/* Initialize new session - exchange session key and negotiate and initialize
 * cipher.  All communication over the returned session context will be
 * encrypted.
 */
sshsess_session sshsess_new_session ( 
	sshsess_server ctx, 		/* Server context or NULL if client */
	cport_port port,		/* I/O completion queue */
	sshpad pad, 			/* TCP connection */
	char *cipher_prefs,		/* allowed ciphers ordered by pref. */
	char *auth_prefs,		/* allowed auths ordered by pref */
	char errmsg[256] ) 		/* diagnostic for failure */
{
    struct sshsess_session_st *new;
    struct sshsess_server_st *server, client_copy;
    int status, enc_size, length, i, xcount, size;
    int flags, bytes, *locusdmp;
    MD5_CTX id_digest;
    char enc_buf[4096];
    sshmsg_local msg1, msg2;		/* dynamic message buffers */
    /*
     * Allocate block and locus, zero local execution state.
     */
    errmsg[0] = '\0';
    size = sizeof(struct sshsess_session_st) + sshciph_state_size();
    new = (struct sshsess_session_st *) tm_malloc ( size );
    if ( !new ) {
	strcpy ( errmsg, "Error allocating memory for session context" );
	return (sshsess_session) 0;
    }
    new->server_ctx = (struct sshsess_server_st *) ctx;
    new->pad = pad;
    new->cport = port;
    new->server_key = new->host_key = (sshrsa) 0;
    new->locus = sshmsg_create_locus ( pad, 4 );
    new->out_msg = cport_assign_stream ( new->cport,
	&cportsshmsg_driver, new->locus, 0 );
    new->in_msg = cport_assign_stream ( new->cport,
	&cportsshmsg_driver, new->locus, 1 );
    sshmsg_init_locals ( new->locus, 2, &msg1, &msg2 );
    new->exec_state = 0;
    new->exec_env = (void *) 0;
    /*
     * Initialize server context, either copy from ctx supplied by caller
     * (server mode) or read public_key message and build temporary
     * structure (client mode).
     */
    if ( ctx ) {
	server = ctx;
	new->server_key = server->server_key;
	new->host_key = server->host_key;
	/*
	 * Format message.
	 */
	RAND_bytes ( (unsigned char *) new->cookie, sizeof(new->cookie) );
	status = sshmsg_format_message ( &msg1,	SSH_SMSG_PUBLIC_KEY, 
		new->cookie,
		server->skey_bits, 
		server->skey_exp_bits, server->skey_exp_value,
		server->skey_bits, server->skey_value,
		server->hkey_bits,
		server->hkey_exp_bits, server->hkey_exp_value,
		server->hkey_bits, server->hkey_value,
		server->protocols,
		server->ciphers,
		server->authentications );
	if ( status != 14 ) { printf ("Wrong result for format: %d\n", status); }
	dump_buffer ( "formatted message", (unsigned char *)msg1.data, msg1.length );
	/*
	 * Send message to client and confirm we sent it.
	 */
	status = cport_do_io ( new->out_msg, CPORT_WRITE, &msg1, 1, &xcount );
	if ( (status*1) == 0 ) {
	    return abort_new_session ( new,
	        "Error writing public key message to client", 
		sshmsg_last_error_text ( new->locus ) );
	}
	/*
	 * Make copy of saved pre-digested partial MD5 of session id.
	 */
	id_digest = server->pre_digest;	
    } else {
	/*
	 * Set server pointer to temporary shadow copy.
	 */
	server = &client_copy;
	new->server_key = server->server_key = sshrsa_create();
	if (new->server_key) new->host_key = server->host_key = sshrsa_create();
	if ( !new->host_key || !new->server_key ) {
	    return abort_new_session ( new, "Error initialize RSA keys",
		errmsg );
	}
	/*
         * Read public key from server and use to initialize server struct.
	 */
	status = cport_do_io ( new->in_msg, CPORT_READ, &msg2, 1, &xcount );
	if ( (status&1) == 0 ) return abort_new_session ( new,
		"Error reading server key", 
		sshmsg_last_error_text ( new->locus ) );
        else if ( msg2.type == SSH_SMSG_PUBLIC_KEY ) {
		/*
		 * read completed and public key received. extract data
		 * into local variables.  Remember that returned char
		 * pointers into message buffer.
		 */
		int i, skey_bits, skey_exp_bits, skey_mod, skey_mod_p;
		i = sshmsg_scan_message ( &msg2,
			&new->cookie,
			&server->skey_bits, 
			&server->skey_exp_bits, &server->skey_exp_value,
			&server->skey_bits, &server->skey_value,
			&server->hkey_bits,
			&server->hkey_exp_bits, &server->hkey_exp_value,
			&server->hkey_bits, &server->hkey_value,
			&server->protocols,
			&server->ciphers,
			&server->authentications );
		if ( i != 14 ) return abort_new_session ( new,
			"Error scanning PUBLIC_KEY message", errmsg );
		/*
		 * Convert keys to format used by SSLeay.
		 */
		new->server_key = sshrsa_create();
		new->host_key = sshrsa_create();
		status = sshrsa_set_public_key ( server->skey_exp_bits, 
			server->skey_exp_value, server->skey_bits, 
			server->skey_value, new->server_key );
		status = sshrsa_set_public_key ( server->hkey_exp_bits, 
			server->hkey_exp_value, server->hkey_bits,
			server->hkey_value, new->host_key );
		/*
		 * Compute first part of digest for session id.
		 */
		MD5_Init ( &id_digest );
		MD5_Update ( &id_digest, 
			(unsigned char *) server->hkey_value,
			(unsigned long) (server->hkey_bits+7)/8 );
		MD5_Update ( &id_digest, 
			(unsigned char *) server->skey_value,
			(unsigned long) (server->skey_bits+7)/8 );
	} else {
		strcpy ( errmsg, "Unexpected response from ssh PAD layer" );
		status = 0;
	}
    }
    /*
     * Compute session ID, appending cookie.
     */
    MD5_Update ( &id_digest, (unsigned char *) new->cookie,
	sizeof(new->cookie) );
    MD5_Final ( new->session_id, &id_digest );
    /*
     * Get cipher type and and session key.
     */
   if ( ctx ) {
	/*
	 * Read response from client and decrypt session key.
	 */
	status = cport_do_io ( new->in_msg, CPORT_READ, &msg1, 1, &xcount );
	if ( status&1 ) {
	    if ( msg1.type == SSH_CMSG_SESSION_KEY ) {
		char cookie[sizeof(new->cookie)], *session_key;
		int session_key_bits;
		/*
		 * Dis-assemble fields in packet.
		 */
		new->cipher = 0;
		if ( 5 != sshmsg_scan_message  ( &msg1,
			&new->cipher, cookie, &session_key_bits, 
			&session_key, &new->protocol_flags ) ) return abort_new_session (
		    new, "Error scanning SESSION_KEY message", errmsg );
		/*
		 * Verify the cookie matches.
		 */
		for ( i = 0; i < sizeof(new->cookie); i++ ) {
		    if ( new->cookie[i] != cookie[i] ) {
			return abort_new_session ( new,
				"Wrong cookie in client response", errmsg );
		    }
		}
	        /*
		 * Verify cipher is supported/allowed.
		 */
		if ( new->cipher > 31 ) return abort_new_session ( new,
			"Cipher number out of range", errmsg );
		if ( 0 == (ctx->ciphers&(1<<new->cipher)) )
			return abort_new_session ( new, "Disallowed cipher",
			errmsg );
		/*
		 * Decrypt the session key, using key with larger modulus first.
		 */
		if ( server->skey_bits > server->hkey_bits ) {
		    sshrsa_decrypt_number ( 
			(session_key_bits+7)/8, session_key, new->server_key, 
			enc_buf, (server->hkey_bits+7)/8, &length);
		    sshrsa_decrypt_number ( length, enc_buf, new->host_key, 
			(char *) new->session_key, sizeof(new->session_key), 
			&length );
		} else {
		    sshrsa_decrypt_number ( 
			(session_key_bits+7)/8, session_key, new->host_key, 
			enc_buf, (server->skey_bits)/8, &length);
		    sshrsa_decrypt_number ( length, enc_buf, new->server_key, 
			(char *) new->session_key, sizeof(new->session_key), 
			&length );
		}
	    } else {
		strcpy ( errmsg, "Unexpected response from ssh PAD layer" );
		status = 0;
	    }
	} else return abort_new_session ( new, "Error reading session key",
		sshmsg_last_error_text ( new->locus ) );

   } else {
	int cipher;
	char iobuffer[4096];
	/*
	 * Choose cipher from preferred cipher list and what server can handle.
	 * Lower in list means most preferred.
	 */
	new->cipher = -1;
	for ( i = 0; cipher_prefs[i]; i++ ) {
	    /*
	     * See if cipher is in list supported by server.
	     */
	    if ( cipher_prefs[i] < 0 || cipher_prefs[i] > 31 ) {
		return abort_new_session ( new, "Cipher pref out of range",
			errmsg );
	    }
	    if ( ((1<<cipher_prefs[i]) & KNOWN_CIPHERS) == 0) continue;
	    if ( (1<<cipher_prefs[i]) & server->ciphers ) {
		/*
		 * Found compatible cipher.
		 */
		new->cipher = cipher_prefs[i];
		break;
	    }
	}
	if ( new->cipher == -1 ) {
	    /*
	     * No match found, last chance for match is bit 0.
	     */
	    if ( KNOWN_CIPHERS & server->ciphers & 1 ) {
		new->cipher = 0;		/* cipher 'none' */
	    } else return abort_new_session ( new,
		"No compatible cipher's found with server", errmsg );
	}
	/*
	 * build random session key and encrypt.
	 */
        RAND_bytes ( new->session_key, sizeof(new->session_key) );
	for ( i=0; i < sizeof(new->session_id); i++ )
		new->session_key[i] ^= new->session_id[i];

	if ( server->skey_bits > server->hkey_bits ) {
	    sshrsa_encrypt_number ( 256, (char *) new->session_key, 
		new->host_key, iobuffer, sizeof(iobuffer), &length );
	    sshrsa_encrypt_number ( length*8, iobuffer, new->server_key, 
		enc_buf, sizeof(enc_buf), &length );
	} else {
	    sshrsa_encrypt_number ( 256, (char *) new->session_key, 
		new->server_key, iobuffer, sizeof(iobuffer), &length );
	    sshrsa_encrypt_number ( length*8, iobuffer, new->host_key, 
		enc_buf, sizeof(enc_buf), &length );
	}
	/*
	 * Construct SESSION_KEY message and reply to server.
	 */
	sshmsg_format_message ( &msg2, SSH_CMSG_SESSION_KEY, &new->cipher, 
		new->cookie, length*8, enc_buf, server->protocols );

	status = cport_do_io ( new->out_msg, CPORT_WRITE, &msg2, 1, &xcount );
	if ( (status&1) == 0 ) {
	    return abort_new_session ( new, 
		"Error sending session key to server", 
		sshmsg_last_error_text ( new->locus ) );
	}
    }
    /*
     * Do final xor to restore original session key.
     */
    for ( i=0; i < sizeof(new->session_id); i++ ) {
	char t;
	new->session_key[i] ^= new->session_id[i];
    }
    dump_buffer ( "session id", (unsigned char *) new->session_id, 16 );
    dump_buffer ( "Final session key", (unsigned char *) new->session_key, 32 );
    /*
     * set PAD to use selected cipher and initialize cipher context.
     * (use memset to set initialization vectors to zero).
     */
    if ( new->cipher != SSH_CIPHER_NONE ) {
        if ( 0 == sshciph_setup ( new, errmsg ) ) {
	    return abort_new_session ( new, "cipher init err", errmsg );
        } else {
	    sshpad_set_encryption ( new->pad, 
	        new->inbound,  new->in_state, new->outbound, new->out_state );
	}
    }
    /*
     * If in server mode, acknolege receipt of session key.
     */
    if ( new->server_ctx ) {
	sshmsg_format_message ( &msg1, SSH_SMSG_SUCCESS );
	status = cport_do_io ( new->out_msg, CPORT_WRITE, &msg1, 1, &xcount );
	if ( (status&1) == 0 ) return abort_new_session ( new,
		"Error acknowleging session key", 
		sshmsg_last_error_text(new->locus) );
    } else {
	status = cport_do_io ( new->in_msg, CPORT_READ, &msg2, 1, &xcount );
	if ( (status&1) == 0 ) return abort_new_session ( new,
		"Error reading ack of session key", 
		sshmsg_last_error_text(new->locus) );
	if ( msg2.type != SSH_SMSG_SUCCESS ) {
	    return abort_new_session ( new,
		"Server did not acknowlege session key", errmsg );
	}
    }
    sshmsg_rundown_locals ( new->locus );
   return (sshsess_session) new;
}
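/*
	Encrypted data format for key exchange (sent by client):
	cookie[8] = random data (from server_key message)
	session_id[0..15] = MD5(host_key//server_key//cookie)
	session_key[0..31] = random data.

	PKCS1(data) = 0//2//rand-fill//0//data

        transfer = RSA(larger_key, PKCS1(RSA(smaller_key,
                       PKCS1(session_key^session_id))))

	"//" denotes concatenation; session_id is XORed into the first 16
	bytes of session_key; smaller_key/larger_key are whichever of
	server_key and host_key has the smaller/larger modulus.
 */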
