/*
 * Support routines for encoding/decoding SSH messages.
 *
 * Author:  David Jones
 * Date:    16-MAY-1998
 * Revised:  3-JUN-1998
 * Revised:  5-JUN-1998			Another memory allocation bug fix.
 * Revised:  8-AUG-1998			Propagate timeout settings.
 * Revised:  16-SEP-1998		Change format def for CHANNEL_DATA
 * Revised:  17-SEP-1998		Initialize ref-count in create locus.
 * Revised:  27-OCT-1998		Use special format for open_confirm.
 * Revised:  1-NOV-1998			Use special format for PORT_OPEN
 */
#include <limits.h>
#include <stdarg.h>
#include <stdio.h>
#include <string.h>
#include <ssdef.h>			/* VMS system service error codes */

#include "tmemory.h"
#define CPORT_DRIVER_SPECIFIC struct locus
#include "cport_sshpad.h"
#include "cport_sshmsg.h"
/*
 * Define locus structure.  A locus ties one sshpad packet-layer handle to
 * a table of message buffers (used by sshmsg_local objects through their
 * bufnum) and to a trap table that maps message types to handlers.
 */
struct local_buffer {
    int size;				/* space allocated to buffer (<=0: none) */
    char *data;				/* allocated space */
};
struct trap_vector_entry {
    sshmsg_trap_handler handler;	/* handler for a trapped message type */
    void *arg;				/* handler context (presumably passed as
					   its first argument - TODO confirm) */
};

struct locus {
    struct locus *next;			/* list link (not used in this file) */
    sshpad pad;				/* packet layer handle */
    int buffer_alloc;			/* size of buffer table */
    int buffer_used;			/* Next free buffer index */
    int ref_count;			/* number of streams assigned */
    struct local_buffer *buffer;	/* buffer table */
    int trap_vector_size;		/* number of entries in trap_vector */
    unsigned char trap_table[256];	/* 0: no entry, else trap_table[i]-1
					   is vector index for type i */
    struct trap_vector_entry *trap_vector;
    cport_isb out_pad;			/* outbound packet stream */
    cport_isb in_pad;			/* inbound packet stream */
    char errmsg[256];			/* save last I/O error message */
};
/*
 * Default trap vector, shared by every locus until a custom vector is
 * installed: SSH_MSG_IGNORE and SSH_MSG_DEBUG are routed to these no-op
 * handlers, which discard the message and report 0.
 */
static int ignore_trap ( void *dummy, sshmsg_local *msg )
{
    (void) dummy;
    (void) msg;
    return 0;
}

static int debug_trap ( void *dummy, sshmsg_local *msg )
{
    (void) dummy;
    (void) msg;
    return 0;
}

static struct trap_vector_entry default_trap_vector[2] = {
    { ignore_trap, (void *) 0 },	/* vector[0]: SSH_MSG_IGNORE */
    { debug_trap, (void *) 0 }		/* vector[1]: SSH_MSG_DEBUG */
};
/*
 * Table of format strings, indexed by message type, used by the format
 * and scan functions to pack and unpack data in SSH protocol messages.
 * Each string encodes the layout of the message:
 *	  n	Decimal count prefix, used by a following 'b' (and by 'x'
 *		when scanning).
 *        i	32-bit integer, big-endian in the message; stored into (scan)
 *		or taken from (format) the next arg as a host-order int.
 *	  m	multi-precision integer: 16-bit bit count followed by
 *		(bits+7)/8 data bytes.  Uses next 2 args (bit count, data).
 *	 nb	'n' bytes (e.g. 8b); a single byte if no count is given.
 *	  s	Counted string: 32-bit length followed by the data.  Uses
 *		next 2 args: count and pointer.
 *        nx	Scan: return current buffer pointer, advance by n if n not
 *		zero.  Format: copy raw data from next 2 args (count,
 *		pointer) with no length prefix.
 *        .     Terminates format, REQUIRED.  Remainder of string is
 *		descriptive text.
 */
#define SSHMSG_FORMAT_TABLE_SIZE 40
static char *msg_format[SSHMSG_FORMAT_TABLE_SIZE] = { 
    ".MSG_NONE",			/* 0 - none */
    "s.MSG_DISCONNECT",	       		/* 1 - disconnect */
    "8bimmimmiii.SMSG_PUBLIC_KEY",	/* 2 - public_key */
    "b8bmi.CMSG_SESSION_KEY",		/* 3 - session_key */
    "s.CMSG_USER",			/* 4 - user */
    "s.CMSG_AUTH_RHOSTS",		/* 5 - auth rhosts */
    "m.CMSG_AUTH_RSA",			/* 6 - auth rsa */
    "m.SMSG_AUTH_RSA_CHALLENGE",	/* 7 - auth rsa challenge */
    "16b.CMSG_AUTH_RSA_RESPONSE",	/* 8 - auth rsa response */
    "s.CMSG_AUTH_PASSWORD",		/* 9 - password */

    "siiiix.CMSG_REQUEST_PTY",		/* 10 - request pty */
    "iiii.CMSG_WINDOW_SIZE",		/* 11 - window size */
    ".CMSG_EXEC_SHELL",			/* 12 - exec shell */
    "s.CMSG_EXEC_CMD",			/* 13 - exec cmd */
    ".SMSG_SUCCESS",			/* 14 - success */
    ".SMSG_FAILURE",			/* 15 - failure */
    "s.CMSG_STDIN_DATA",		/* 16 - stdin data */
    "s.SMSG_STDOUT_DATA",		/* 17 - stdout data */
    "s.SMSG_STDERR_DATA",		/* 18 - stderr data */
    ".CMSG_EOF",			/* 19 - stdin EOF */

    "i.SMSG_EXITSTATUS",		/* 20 - exit status*/
    "8b.MSG_CHANNEL_OPEN_CONFIRMATION",	/* 21 - channel open confirm (ii),
					   kept as raw bytes */
    "i.MSG_CHANNEL_OPEN_FAILURE",	/* 22 - channel open failure */
    "4bs.MSG_CHANNEL_DATA",		/* 23 - channel data (raw channel) */
    "4b.MSG_CHANNEL_CLOSE",		/* 24 - channel close */
    "i.MSG_CHANNEL_CLOSE_CONFIRMATION",	/* 25 - channel close confirm */
    ".CMSG_X11_REQUEST_FORWARDING",	/* 26 - X11 request forwarding */
    "i.SMSG_X11_OPEN",			/* 27 - X11 open */
    "isi.CMSG_PORT_FORWARD_REQUEST",	/* 28 - port forward request */
    "4bs4b.MSG_PORT_OPEN",		/* 29 - port open (raw channel/port) */

    ".CMSG_AGENT_REQUEST_FORWARDING",	/* 30 - agent request forwarding */
    "i.SMSG_AGENT_OPEN",		/* 31 - agent open */
    "s.MSG_IGNORE",			/* 32 - NOP message */
    ".CMSG_EXIT_CONFIRMATION",		/* 33 - exit confirmation */
    "ss.CMSG_X11_FWD_WITH_AUTH_SPOOFING",/* 34 - x11 forward with auth spoof */
    "simm.CMSG_AUTH_RHOSTS_RSA",	/* 35 - auth rhosts rsa */
    "s.MSG_DEBUG",			/* 36 - debug message */
    "i.CMSG_REQUEST_COMPRESSION",	/* 37 - Set compression alg. */
    "i.CMSG_MAX_PACKET_SIZE",		/* 38 - Limit max packet size sent */
    "x.undefined39"			/* 39 - undefined */
};

/*************************************************************************/
/* Low level routines for creating/destroying message loci.
 */
/*
 * Allocate and initialize a message locus for packet-layer handle 'pad'
 * with a buffer table of 'initial_locals' entries.  At least 1 entry is
 * allocated, because sshmsg_init_locals grows the table by doubling and a
 * 0-entry table could never grow.  The trap table starts with only
 * SSH_MSG_IGNORE and SSH_MSG_DEBUG routed to the default no-op handlers;
 * no packet streams are assigned yet.
 *
 * Returns the new locus, or a null locus if memory could not be allocated.
 */
sshmsg_locus *sshmsg_create_locus ( sshpad pad, int initial_locals )
{
    struct locus *new;
    int i;

    if ( initial_locals < 1 ) initial_locals = 1;
    /*
     * Allocate structure and initial buffer table.
     */
    new = (struct locus *) tm_malloc ( sizeof (struct locus) );
    if ( !new ) return (sshmsg_locus) new;
    new->next = (struct locus *) 0;
    new->pad = pad;
    new->buffer_alloc = initial_locals;
    new->buffer_used = 0;
    new->ref_count = 0;
    new->buffer = (struct local_buffer *) tm_malloc (
        initial_locals * sizeof(struct local_buffer) );
    if ( !new->buffer ) { tm_free ( new ); return (sshmsg_locus) 0; }
    /*
     * Initialize trap table: 0 means no entry, otherwise entry-1 is the
     * index into trap_vector.
     */
    for ( i = 0; i < (int) sizeof(new->trap_table); i++ ) new->trap_table[i] = 0;
    new->trap_table[SSH_MSG_IGNORE] = 1;	/* vector[0] entry */
    new->trap_table[SSH_MSG_DEBUG] = 2;		/* vector[1] entry */
    new->trap_vector_size = 2;
    new->trap_vector = default_trap_vector;
    /*
     * All buffer table entries start out empty.
     */
    for ( i = 0; i < new->buffer_alloc; i++ ) {
        new->buffer[i].size = 0;
        new->buffer[i].data = (char *) 0;
    }
    /*
     * Initialize i/o context.
     */
    new->out_pad = (cport_isb) 0;
    new->in_pad = (cport_isb) 0;
    new->errmsg[0] = '\0';

    return (sshmsg_locus) new;
}
/*
 * Release a locus created by sshmsg_create_locus.  This frees the message
 * buffers and any non-default trap vector (via sshmsg_rundown_locals), then
 * the buffer table and the locus itself.  A null locus is ignored.
 * NOTE(review): in_pad/out_pad streams are not deassigned here.
 */
void sshmsg_destroy_locus ( sshmsg_locus locus )
{
    struct locus *ctx;

    ctx = (struct locus *) locus;
    if ( !ctx ) return;
    sshmsg_rundown_locals ( locus );	/* free buffers and trap vector */

    tm_free ( ctx->buffer );
    tm_free ( ctx );
}
/*
 * Initialize 'count' sshmsg_local structs, passed by the caller as
 * trailing (sshmsg_local *) arguments.  Each local is given the next free
 * buffer of the locus's buffer table, an empty payload and type
 * SSH_MSG_NONE.  The buffer table is doubled whenever it runs out of
 * entries.
 *
 * Returns 1 on success, or 0 if the buffer table could not be grown.  In
 * that case the locals before the failing one stay initialized and the old
 * table is kept intact.
 */
int sshmsg_init_locals ( sshmsg_locus locus, int count, ... )
{
    va_list ap;
    int i, j, new_alloc;
    sshmsg_local *local;
    struct local_buffer *new_table;
    struct locus *ctx;

    ctx = (struct locus *) locus;
    va_start(ap, count);
    for ( i = 0; i < count; i++ ) {
        local = va_arg(ap, sshmsg_local *);
        local->type = SSH_MSG_NONE;
        if ( ctx->buffer_used >= ctx->buffer_alloc ) {
            /*
             * Out of buffer descriptors: double the table.  Start at 4
             * for an empty table, since doubling 0 would never make room.
             * Only replace the table pointer once realloc has succeeded.
             */
            new_alloc = (ctx->buffer_alloc > 0) ? ctx->buffer_alloc * 2 : 4;
            new_table = (struct local_buffer *) tm_realloc ( ctx->buffer,
                sizeof(struct local_buffer) * new_alloc );
            if ( !new_table ) {
                va_end(ap);
                return 0;
            }
            for ( j = ctx->buffer_alloc; j < new_alloc; j++ ) {
                new_table[j].size = 0;
                new_table[j].data = (char *) 0;
            }
            ctx->buffer = new_table;
            ctx->buffer_alloc = new_alloc;
        }
        local->bufnum = ctx->buffer_used++;
        local->length = 0;
        local->data = "";
        local->locus = locus;
    }
    va_end(ap);
    return 1;
}
/*
 * Free storage allocated to locals assigned to locus and reset the trap
 * table and vector to the defaults, where only SSH_MSG_IGNORE and
 * SSH_MSG_DEBUG are trapped.  Every buffer table entry becomes free for
 * reassignment by sshmsg_init_locals.  Always returns 1.
 * NOTE(review): this does not wait for pending async I/O on the buffers.
 */
int sshmsg_rundown_locals ( sshmsg_locus locus )
{
    struct locus *ctx;
    int i;

    ctx = (struct locus *) locus;
    /*
     * Reset trap vector.  The whole table is cleared first: entries that
     * referred to a custom vector would otherwise index past the end of
     * the 2-entry default vector.
     */
    if ( ctx->trap_vector != default_trap_vector ) {
        if ( ctx->trap_vector ) tm_free ( ctx->trap_vector );
    }
    for ( i = 0; i < (int) sizeof(ctx->trap_table); i++ ) ctx->trap_table[i] = 0;
    ctx->trap_vector_size = 2;
    ctx->trap_table[SSH_MSG_IGNORE] = 1;	/* vector[0] entry */
    ctx->trap_table[SSH_MSG_DEBUG] = 2;		/* vector[1] entry */
    ctx->trap_vector = default_trap_vector;
    /*
     * Free memory allocated to message buffers and forget the pointers.
     */
    for ( i = 0; i < ctx->buffer_alloc; i++ ) {
        if ( ctx->buffer[i].size > 0 ) tm_free ( ctx->buffer[i].data );
        ctx->buffer[i].size = 0;
        ctx->buffer[i].data = (char *) 0;
    }
    ctx->buffer_used = 0;
    return 1;
}
/*************************************************************************/
/* Following routines handle moving data into and out of local buffer
 * objects.
 * 
 */
/*
 * Unpack the payload of a received message into the caller's variables,
 * following the format string for message->type (see msg_format).  Each
 * format item consumes trailing pointer arguments:
 *   nb  char *          receives n raw bytes (1 if no count)
 *   i   int *           receives a 32-bit big-endian integer as host int
 *   m   int *, char **  bit count, and pointer to the mp-int data
 *   s   int *, char **  length, and pointer to the string data
 *   nx  char **         current position; then n bytes are skipped
 * Pointers returned for 'm', 's' and 'x' point into message->data.
 *
 * Returns the number of arguments stored, or -1 if a count is out of range
 * or the message is too short for its format.  All lengths read from the
 * message are checked against the bytes actually present.
 */
int sshmsg_scan_message ( sshmsg_local *message, ... )
{
    va_list ap;
    char c, *fmt, *mp, *mp_end;		/* mp = current msg position */
    char *arg_b, **arg_p;
    int *arg_l;
    int i, count, items;
    unsigned char *mpu;
    unsigned long size;

    /*
     * Look up format string to use.
     */
    if ( message->type > 0 && message->type < SSHMSG_FORMAT_TABLE_SIZE ) {
        fmt = msg_format[message->type];
    } else fmt = "";
    mp = message->data;
    mp_end = mp + (message->length > 0 ? message->length : 0); /* 1 past end */
    /*
     * Scan format string, stopping at '.' (or at the end of a string that
     * has none, such as the empty format used for unknown types).
     */
    va_start(ap, message);
    count = items = 0;
    for ( i = 0; fmt[i] != '\0' && fmt[i] != '.'; i++ ) {
        c = fmt[i];
        if ( c >= '0' && c <= '9' ) {
            count = (count*10) + (c-'0');
            if ( count > 0x40000 ) goto fail;		/* invalid count */
        } else if ( c == 'b' ) {
            if ( count == 0 ) count = 1;
            arg_b = va_arg(ap, char *);
            items++;
            if ( count > mp_end - mp ) goto fail;	/* message too short */
            memmove ( arg_b, mp, count );
            mp += count;
            count = 0;
        } else if ( c == 'i' ) {
            /*
             * 32-bit big-endian integer, decoded arithmetically so the
             * result is correct regardless of host byte order.
             */
            if ( mp_end - mp < 4 ) goto fail;
            mpu = (unsigned char *) mp;
            arg_l = va_arg(ap, int *);
            *arg_l = (int) (((unsigned long) mpu[0] << 24) |
                ((unsigned long) mpu[1] << 16) |
                ((unsigned long) mpu[2] << 8) | (unsigned long) mpu[3]);
            mp += 4;
            items++;
        } else if ( c == 'm' ) {
            /*
             * multi-precision integer: 16-bit bit count followed by bits.
             * Only return pointer to bit data.
             */
            if ( mp_end - mp < 2 ) goto fail;
            mpu = (unsigned char *) mp;
            mp += 2;
            size = ((unsigned long) mpu[0] << 8) | mpu[1];
            arg_l = va_arg(ap, int *);
            *arg_l = (int) size;
            arg_p = va_arg(ap, char **);
            *arg_p = mp;
            size = (size+7)/8;			/* convert bit count to bytes */
            if ( size > (unsigned long) (mp_end - mp) ) goto fail;
            mp += size;
            items += 2;
        } else if ( c == 's' ) {
            /*
             * Counted string: 32-bit length, then data.  The length is
             * unsigned on the wire; validate it before it is handed out so
             * a huge (or "negative") value cannot move mp out of bounds.
             */
            if ( mp_end - mp < 4 ) goto fail;
            mpu = (unsigned char *) mp;
            mp += 4;
            size = ((unsigned long) mpu[0] << 24) |
                ((unsigned long) mpu[1] << 16) |
                ((unsigned long) mpu[2] << 8) | (unsigned long) mpu[3];
            if ( size > (unsigned long) (mp_end - mp) ) goto fail;
            arg_l = va_arg(ap, int *);
            *arg_l = (int) size;
            arg_p = va_arg(ap, char **);
            *arg_p = mp;
            mp += size;
            items += 2;
        } else if ( c == 'x' ) {
            /*
             * Return current position, optionally skipping 'count' bytes.
             */
            arg_p = va_arg(ap, char **);
            *arg_p = mp;
            if ( count > 0 ) {
                if ( count > mp_end - mp ) goto fail;
                mp += count;
                count = 0;
            }
            items++;
        }
    }
    va_end(ap);
    return items;

fail:
    va_end(ap);
    return -1;
}
/*
 * Utility routine to make a local buffer hold at least pos+ext+1 bytes.
 * The size is rounded up to a multiple of 64, which always leaves at least
 * one spare byte.  Existing contents are kept (tm_realloc).
 *
 * Returns a pointer to offset 'pos' in the (possibly moved) buffer.  If the
 * buffer is already large enough it is returned unchanged.  Returns a null
 * pointer, leaving lbuf untouched, when pos/ext are negative, the size
 * would overflow an int, or allocation fails.
 */
static char *grow_buffer ( struct local_buffer *lbuf, int pos, int ext )
{
    int new_size;
    char *newbuf;

    if ( pos < 0 || ext < 0 || ext > INT_MAX - 64 - pos ) return (char *) 0;
    /*
     * Round pos+ext+1 up to a multiple of 64.  Use ~63 rather than a fixed
     * width mask so large buffers are not truncated.
     */
    new_size = (pos + ext + 64) & ~63;
    if ( lbuf->size <= 0 ) {
        newbuf = (char *) tm_malloc ( new_size );
    } else if ( new_size > lbuf->size ) {
        newbuf = (char *) tm_realloc ( lbuf->data, new_size );
    } else {
        return &lbuf->data[pos];		/* already large enough */
    }

    if ( !newbuf ) return newbuf;
    lbuf->size = new_size;
    lbuf->data = newbuf;
    return &newbuf[pos];
}

int sshmsg_format_message ( sshmsg_local *message, int type, ... )
{
    va_list ap;
    char c, *fmt, *mp, *arg_b, *arg_s, *arg_i;	/* current msg position */
    char **arg_p;
    int i, count, items;
    int *arg_l, length, value;
    struct locus *locus;
    struct local_buffer *lbuf;
    /*
     * Find locus associate with message and lookup format string to use.
     */
    locus = (struct locus *) message->locus;
    if ( type > 0 && type < SSHMSG_FORMAT_TABLE_SIZE ) {
	fmt = msg_format[type];
    } else fmt = "";
    message->type = type;
    mp = message->data;
    lbuf = &locus->buffer[message->bufnum];	/* buffer table entry */
    /*
     * scan format string.
     */
    va_start(ap, type);			/* initialize argument pointer */
    length = 0;
    mp = lbuf->data;
    for ( i = count = items = 0; fmt[i] != '.'; i++ ) {
	c = fmt[i];
	if ( c >= '0' && c <= '9' ) {
	    count = (count*10) + (c-'0');
	    if ( count > 0x40000 ) {
		message->length = length;
		message->data = lbuf->data;
		return -1;		/* invalid count */
	    }
	} else if ( c == 'b' ) {
	    if ( count == 0 ) count = 1;
	    arg_b = va_arg(ap, char *);
	    items++;
	    if ( length+count > lbuf->size ) 
		mp = grow_buffer ( lbuf, length, count );
	    memmove ( mp, arg_b, count );
	    mp += count;
	    length += count;
	    count = 0;
	} else if ( c == 'i' ) {
	    /*
	     * Assume int * and char * are same on argument list.
	     */
	    unsigned char *mpu;
	    if ( length+4 > lbuf->size ) mp = grow_buffer ( lbuf, length, 4 );
	    value = va_arg(ap, int);
	    mpu = (unsigned char *) mp;
	    mpu[0] = (value>>24)&255;
	    mpu[1] = (value>>16)&255;
	    mpu[2] = (value>>8)&255;
	    mpu[3] = value&255;
	    mp += 4;
	    length += 4;
	    items++;
	} else if ( c == 'm' ) {
	    /*
	     * multi-precision integer, word bit count followed by bits.
	     */
	    unsigned char *mpu;
	    int size;
	    size = va_arg(ap, int);
	    if ( length+((size+7)/8)+2 > lbuf->size ) 
		mp = grow_buffer ( lbuf, length, ((size+7)/8)+2 );
	    mpu = (unsigned char *) mp;
	    mpu[0] = (size >> 8)&255;
	    mpu[1] = size&255;
            mp += 2;
	    arg_b = va_arg(ap, char *);
	    size = (size+7)/8;
	    memmove ( mp, arg_b, size );
	    mp += size;
	    length += (2 + size);
	    items+=2;
	} else if ( c == 's' ) {
	    /*
	     * Counted string.
	     */
	    unsigned char *mpu;
	    int size;
	    size = va_arg(ap, int);
	    arg_b = va_arg(ap, char *);
	    if ( length+size+4 > lbuf->size ) 
		mp = grow_buffer ( lbuf, length, size+4 );

	    mpu = (unsigned char *) mp;
            mp += 4;
	    mpu[0] = (size>>24)&255;
	    mpu[1] = (size>>16)&255;
	    mpu[2] = (size>>8)&255;
	    mpu[3] = size&255;
	    memmove ( mp, arg_b, size );
	    mp += size;
	    length += (4+size);
	    items+=2;
	} else if ( c == 'x' ) {
	    /*
	     * Raw data, 2 arguments.
	     */
	    unsigned char *mpu;
	    int size;
	    size = va_arg(ap, int);
	    arg_b = va_arg(ap, char *);
	    if ( length+size > lbuf->size ) 
		mp = grow_buffer ( lbuf, length, size );

	    memmove ( mp, arg_b, size );
	    mp += size;
	    length += (size);
	    items+=2;
	}
    }
    /*
     * save final length.
     */
    message->length = length;
    message->data = lbuf->data;
    return items;
}

/*
 * Return the locus's saved I/O error message buffer.  sshmsg_create_locus
 * initializes it to an empty string.
 */
char *sshmsg_last_error_text(sshmsg_locus locus)
{
    struct locus *ctx = (struct locus *) locus;

    return &ctx->errmsg[0];
}
/***************************************************************************/
/* Declare functions used in the start_io function table.  The 'buffer' for
 * reads and writes is a sshmsg_local pointer.
 */
/*
 * Completion callback for start_write.  Copies the 8-byte IOSB of the
 * underlying out_pad stream into the caller's stream IOSB and returns 2
 * (presumably "I/O complete" to the cport layer - TODO confirm).
 */
int finish_write ( cport_isb isb )
{
    struct locus *ctx = isb->drv;

    memcpy ( isb->iosb, ctx->out_pad->iosb, 8 );
    return 2;
}
/*
 * Start writing the sshmsg_local passed as 'data' (func and length are
 * unused).  The message type is first handed to the packet layer with
 * CPORT_SSHPAD_SEND_TYPE.  The payload write is then queued on out_pad,
 * and finish_write completes the caller's I/O.  Returns the first failing
 * cport_start_io status, or the status of the payload write.
 */
int start_write ( cport_isb isb, int func, void *data, int length )
{
    sshmsg_local *msg = (sshmsg_local *) data;
    struct locus *ctx = isb->drv;
    int status;

    /* Tell the packet layer which message type the next write carries. */
    status = cport_start_io ( ctx->out_pad, CPORT_SSHPAD_SEND_TYPE,
        &msg->type, 0 );
    if ( !(status & 1) ) return status;

    /* Queue the payload; pick up in finish_write after completion. */
    if ( isb->timer || ctx->out_pad->timer ) {
        cport_copy_timeout ( isb, ctx->out_pad );
    }
    isb->completion_callback = finish_write;
    isb->default_iosb[0] = 0;
    isb->default_iosb[1] = 0;
    return cport_start_io ( ctx->out_pad, CPORT_WRITE,
        msg->data, msg->length );
}

/*
 * Completion callback for start_read, run when the in_pad read finishes.
 *  - success: the packet type and length from the lower IOSB are stored in
 *    the sshmsg_local, whose data is pinned to the buffer read into.  The
 *    caller's IOSB gets status 1 and the length.
 *  - SS$_DATAOVERUN: the message buffer is grown to the size the packet
 *    layer reported and the read is retried; returns 0 (still pending).
 *    If the buffer cannot be grown, the overrun IOSB is passed up instead
 *    of retrying into the same, too-small buffer.
 *  - anything else: the lower IOSB is copied to the caller's IOSB.
 * Returns 2 when the caller's I/O is complete.
 */
int finish_read ( cport_isb isb )
{
    struct locus *locus;
    sshmsg_local *msg;
    struct local_buffer *buf;
    unsigned short int *iosb, *out_iosb;
    int status, new_size;

    locus = isb->drv;
    msg = (sshmsg_local *) isb->buffer;
    buf = &locus->buffer[msg->bufnum];
    iosb = (unsigned short int *) locus->in_pad->iosb;
    out_iosb = (unsigned short int *) isb->iosb;

    if ( (iosb[0]&1) == 1 ) {
        /*
         * Packet received, fill in message buffer with info in data.
         */
        msg->type = iosb[3];
        msg->length = (int) ((unsigned long) iosb[1] |
            ((unsigned long) iosb[2] << 16));
        msg->data = buf->data;		/* packet was read into this buffer */

        out_iosb[0] = 1;
        out_iosb[1] = msg->length;
        return 2;
    }
    if ( iosb[0] == SS$_DATAOVERUN ) {
        /*
         * Buffer was too small, allocate larger one and continue.
         */
        new_size = (int) ((unsigned long) iosb[1] |
            ((unsigned long) iosb[2] << 16));
        if ( new_size > buf->size ) {
            if ( !grow_buffer ( buf, 0, new_size ) ) {
                memcpy ( isb->iosb, iosb, 8 );	/* report the overrun */
                return 2;
            }
            msg->data = buf->data;
        }

        status = cport_start_io ( locus->in_pad, CPORT_SSHPAD_RETRY_READ,
            buf->data, buf->size );
        if ( (status&1) == 0 ) {
            out_iosb[0] = status;
            return 2;		/* abort */
        }
        return 0;
    }
    /*
     * I/O error, propagate up.
     */
    memcpy ( isb->iosb, iosb, 8 );
    return 2;
}

/*
 * Start reading the next message from the packet layer into the
 * sshmsg_local passed as 'data' (func and length are unused).  The
 * message's buffer is first given room for at least 64 bytes.
 * finish_read handles completion and any regrowth.
 *
 * Returns the cport_start_io status, or SS$_INSFMEM if the initial buffer
 * could not be allocated (rather than starting a read into no buffer).
 */
int start_read ( cport_isb isb, int func, void *data, int length )
{
    struct locus *locus;
    sshmsg_local *msg;
    struct local_buffer *buf;

    locus = isb->drv;
    msg = (sshmsg_local *) data;
    /*
     * locate buffer being used by msg.
     */
    buf = &locus->buffer[msg->bufnum];
    if ( buf->size < 64 ) {
        if ( !grow_buffer ( buf, 0, 64 ) ) return SS$_INSFMEM;
        msg->data = buf->data;
    }
    /*
     * Read into buffer from pad layer.
     */
    isb->completion_callback = finish_read;
    isb->default_iosb[0] = isb->default_iosb[1] = 0;
    if ( isb->timer || locus->in_pad->timer )
        cport_copy_timeout ( isb, locus->in_pad );

    return cport_start_io ( locus->in_pad, CPORT_READ,
        buf->data, buf->size );
}
/****************************************************************************/
/* Declare the functions used in the stream handler table.
 */
/*
 * Start-I/O function table handed to the cport layer: entry 0 is
 * start_write, entry 1 is start_read.  Presumably selected by the I/O
 * function code masked with the driver's function mask - TODO confirm.
 */
static cport_start_function ftable[2] = {
    start_write,
    start_read
};

/*
 * Stream-handler "new stream" entry: bind the caller's stream to the locus
 * passed as 'context'.  The first stream on a locus (ref_count 0) also
 * assigns the outbound (0) and inbound (1) sshpad streams on the same
 * port.  If either assignment fails, the failure status (and errmsg) comes
 * from cport_last_assign_status, and an out_pad stream already assigned is
 * deassigned again so it does not leak.  dir 0 makes this stream the
 * wrapper of out_pad, dir 1 of in_pad.
 * Returns 1 on success, else the assign failure status.
 */
static int new_stream ( cport_isb isb, void *context, int dir,
    char errmsg[256] )
{
    struct locus *ctx;
    int status;

    isb->drv = (struct locus *) context;
    ctx = isb->drv;

    if ( ctx->ref_count == 0 ) {
        /*
         * Assign streams to lower level driver.
         */
        ctx->in_pad = (cport_isb) 0;
        ctx->out_pad = cport_assign_stream ( isb->port,
            &cportsshpad_driver, ctx->pad, 0 );
        if ( ctx->out_pad ) ctx->in_pad = cport_assign_stream ( isb->port,
            &cportsshpad_driver, ctx->pad, 1 );
        if ( !ctx->out_pad || !ctx->in_pad ) {
            /* Capture the assign status before any cleanup I/O. */
            status = cport_last_assign_status ( isb->port, errmsg );
            if ( ctx->out_pad ) {
                cport_deassign ( ctx->out_pad );
                ctx->out_pad = (cport_isb) 0;
            }
            return status;
        }
    }

    if ( dir == 0 ) ctx->out_pad->wrapper = isb;
    if ( dir == 1 ) ctx->in_pad->wrapper = isb;
    ctx->ref_count++;
    isb->channel = 0;
    return 1;
}

/*
 * Stream-handler "destroy stream" entry: detach the caller's stream from
 * its locus.  A pad stream (in_pad/out_pad) wrapped by this stream is
 * deassigned and its pointer cleared, so a later destroy cannot touch the
 * deassigned stream.  When the last stream goes away, any remaining pad
 * streams are deassigned too.  The locus is then released together with
 * its buffers and trap vector (sshmsg_destroy_locus).
 * Always returns 1.
 */
static int destroy_stream ( cport_isb isb )
{
    struct locus *ctx;

    ctx = isb->drv;
    ctx->ref_count--;
    if ( ctx->in_pad && ctx->in_pad->wrapper == isb ) {
        ctx->in_pad->wrapper = (cport_isb) 0;
        cport_deassign ( ctx->in_pad );
        ctx->in_pad = (cport_isb) 0;
    }
    if ( ctx->out_pad && ctx->out_pad->wrapper == isb ) {
        ctx->out_pad->wrapper = (cport_isb) 0;
        cport_deassign ( ctx->out_pad );
        ctx->out_pad = (cport_isb) 0;
    }
    if ( ctx->ref_count <= 0 ) {
        /*
         * Last reference: release pad streams nobody wraps any more (e.g.
         * only one direction was opened, or the wrapper was cleared by
         * cancel_io), then the locus and everything it owns.
         */
        if ( ctx->in_pad ) {
            ctx->in_pad->wrapper = (cport_isb) 0;
            cport_deassign ( ctx->in_pad );
            ctx->in_pad = (cport_isb) 0;
        }
        if ( ctx->out_pad ) {
            ctx->out_pad->wrapper = (cport_isb) 0;
            cport_deassign ( ctx->out_pad );
            ctx->out_pad = (cport_isb) 0;
        }
        sshmsg_destroy_locus ( (sshmsg_locus) ctx );
        isb->drv = (struct locus *) 0;
    }
    isb->channel = 0;
    return 1;
}

/*
 * Stream-handler "cancel" entry: for each pad stream (in_pad first, then
 * out_pad) that this stream wraps, clear the wrapper link and cancel its
 * outstanding I/O.  Returns the status of the last cport_cancel issued, or
 * 1 if this stream wraps neither pad.
 * NOTE(review): after the wrapper is cleared, destroy_stream's
 * "wrapper == isb" check no longer matches this stream - confirm intent.
 */
static int cancel_io ( cport_isb isb )
{
    struct locus *ctx = isb->drv;
    int status = 1;

    if ( ctx->in_pad && ctx->in_pad->wrapper == isb ) {
        ctx->in_pad->wrapper = (cport_isb) 0;
        status = cport_cancel ( ctx->in_pad );
    }
    if ( ctx->out_pad && ctx->out_pad->wrapper == isb ) {
        ctx->out_pad->wrapper = (cport_isb) 0;
        status = cport_cancel ( ctx->out_pad );
    }
    return status;
}
/*
 * The cportsshmsg_driver table will be specified in cport_assign_stream
 * calls as the handler argument.
 */
cport_stream_handler cportsshmsg_driver = {
	1,					/* mask 1 => 2 functions (ftable) */
	ftable,
	new_stream,
	destroy_stream,
	cancel_io
};
