/* $Id: packet.c,v 1.26 1999/08/10 13:14:27 levitte Exp $ */

#include "gnu_extras.h"

#include <stdio.h>		/* Because some versions of OpenSSL need it */
#include <string.h>
#include <lib$routines.h>
#include <starlet.h>
#ifdef __GNUC__
#define __DECC
#endif
#include <openssl/rand.h>
#ifdef __GNUC__
#undef __DECC
#endif
#include "fish.h"
#include "ssh.h"
#include "crc32.h"
#include "callback.h"
#include "log.h"
#include "util.h"
#include "compress.h"
#include "fishmsg.h"
#include "buffer.h"

/*
 * Install the cipher hooks for CIPHER into STATE.
 *
 * For each supported cipher this logs "<label>: <name>" via on_info(),
 * where <label> is state->info, or "Cipher" when that is unset.  It then
 * sets state->encrypt, state->decrypt and state->cryptclean to the
 * cipher's functions.  Finally it builds state->encryptstate and
 * state->decryptstate by calling the cipher's *_encryptstate() and
 * *_decryptstate() constructors with STATE, KEYLEN, ERF and ERFP, in that
 * order.
 *
 * SSH_CIPHER_NONE sets all five fields to NULL.  Any other value is
 * reported through ssh_F_invciph(), and STATE is left unmodified.
 */
void ssh_crypto_setup(ssh_cipher_type cipher, ssh_state *state, int keylen,
		      Erf erf, void *erfp)
{
    char *label = state->info ? state->info : "Cipher";

    switch (cipher) {
    case SSH_CIPHER_NONE:
	/* Plain text: no hooks and no per-direction cipher state. */
	on_info(ssh_infof("%s: none", label));
	state->encrypt = NULL;
	state->decrypt = NULL;
	state->cryptclean = NULL;
	state->encryptstate = NULL;
	state->decryptstate = NULL;
	break;

    case SSH_CIPHER_DES:
	on_info(ssh_infof("%s: des", label));
	state->encrypt = ssh_des_encrypt;
	state->decrypt = ssh_des_decrypt;
	state->cryptclean = ssh_des_clean;
	state->encryptstate = ssh_des_encryptstate(state, keylen, erf, erfp);
	state->decryptstate = ssh_des_decryptstate(state, keylen, erf, erfp);
	break;

    case SSH_CIPHER_3DES:
	on_info(ssh_infof("%s: 3des", label));
	state->encrypt = ssh_3des_encrypt;
	state->decrypt = ssh_3des_decrypt;
	state->cryptclean = ssh_3des_clean;
	state->encryptstate = ssh_3des_encryptstate(state, keylen, erf, erfp);
	state->decryptstate = ssh_3des_decryptstate(state, keylen, erf, erfp);
	break;

    case SSH_CIPHER_IDEA:
	on_info(ssh_infof("%s: idea", label));
	state->encrypt = ssh_idea_encrypt;
	state->decrypt = ssh_idea_decrypt;
	state->cryptclean = ssh_idea_clean;
	state->encryptstate = ssh_idea_encryptstate(state, keylen, erf, erfp);
	state->decryptstate = ssh_idea_decryptstate(state, keylen, erf, erfp);
	break;

    case SSH_CIPHER_RC4:
	on_info(ssh_infof("%s: rc4", label));
	state->encrypt = ssh_rc4_encrypt;
	state->decrypt = ssh_rc4_decrypt;
	state->cryptclean = ssh_rc4_clean;
	state->encryptstate = ssh_rc4_encryptstate(state, keylen, erf, erfp);
	state->decryptstate = ssh_rc4_decryptstate(state, keylen, erf, erfp);
	break;

    case SSH_CIPHER_BLOWFISH:
	on_info(ssh_infof("%s: blowfish", label));
	state->encrypt = ssh_blowfish_encrypt;
	state->decrypt = ssh_blowfish_decrypt;
	state->cryptclean = ssh_blowfish_clean;
	state->encryptstate = ssh_blowfish_encryptstate(state, keylen, erf, erfp);
	state->decryptstate = ssh_blowfish_decryptstate(state, keylen, erf, erfp);
	break;

    default:
	/* Unknown cipher: report it and leave STATE alone. */
	ssh_F_invciph();
	break;
    }
}

/*
 * Send DATA as one SSH protocol 1 message of type TYPE through
 * state->down_write.
 *
 * SSH_INTERNAL messages are written raw, with no packet framing.
 * SSH_MSG_NONE sends nothing and returns 0.  Every other type is framed
 * as follows (multi-byte integers big-endian):
 *
 *   4 bytes   length = type byte + payload + 4 CRC bytes
 *   1..8      padding (random when encrypting, zeros otherwise)
 *   1 byte    message type  \  compressed together when
 *   N bytes   payload       /  state->compress is set
 *   4 bytes   CRC over padding, type and payload
 *
 * When state->encrypt is set, everything after the length field is
 * encrypted in place before sending.
 *
 * Returns:
 *   - the down_write result unchanged if it is negative (error);
 *   - buf_amount(DATA), the caller's payload size, if the whole packet
 *     was written (ssh_write() relies on this);
 *   - for a short write, the count minus the framing overhead, or the
 *     raw count if even the overhead was not covered.
 */
long ssh_write_type(ssh_msg_type type, general_buffer *data,
    void *state_, Erf erf, void *erfp)
{
    unsigned char *sendbuf;
    unsigned long crc;
    int padlen;
    long olen;
    long ret;
    long total;
    ssh_state *state = state_;
    unsigned int datalen = data ? buf_amount(data) : 0;
    unsigned int len = datalen;
    general_buffer_reader br;

    bufread_init_static(&br, data);

    on_state(log_verbose(type, &br, state));

    /* Is this supposed to be packetized? */
    if (type == SSH_INTERNAL) {
	on_packet(dump_buffer(stderr, "Sending string", buf_bytes(data), len));
	return (state->down_write)(buf_bytes(data), len,
				   state->down_state, erf, erfp);
    }
    if (type == SSH_MSG_NONE)
	return 0;

    /* Layout: [4 length][room for up to 8 padding][1 type][data][4 crc].
       The type byte always sits at offset 4+8.  Padding and length are
       filled in backwards from there once the padding size is known, so
       the packet starts at sendbuf+4+8-padlen-4 (>= sendbuf). */
    sendbuf = xmalloc(4+8+1+len+4);
    if(!sendbuf) {
	ssh_F_memfull();
	return -1;
    }

    /* Type */
    sendbuf[4+8] = type;

    /* Data */
    if (data)
	memmove(sendbuf+4+8+1, buf_bytes(data), len);

    /* Make sure the type byte gets counted */
    len++;

    if (state->compress) {
	/* Compress the type byte and the data */
	general_buffer *buf = compress_buffer(0,sendbuf+4+8,len);
	unsigned int clen;

	if (!buf) {
	    /* No compressed result: treat it as an allocation failure. */
	    xfree(sendbuf);
	    ssh_F_memfull();
	    return -1;
	}

	on_packet(dump_buffer(stderr, "Pre compression", sendbuf+4+8, len));

	clen = buf_amount(buf);
	if (clen > len) {
	    /* Compressed output can be longer than its input (very short or
	       incompressible payloads).  sendbuf was sized for the
	       uncompressed payload, and copying plus the CRC would overrun
	       it.  Nothing in it needs keeping, so get a fresh buffer. */
	    xfree(sendbuf);
	    sendbuf = xmalloc(4+8+clen+4);
	    if (!sendbuf) {
		buf_destroy(buf);
		ssh_F_memfull();
		return -1;
	    }
	}
	memcpy(sendbuf+4+8, buf_bytes(buf), clen);
	len = clen;

	buf_destroy(buf);

	on_packet(dump_buffer(stderr, "Post compression", sendbuf+4+8, len));
    }

    /* Find out how much padding we need.  The CRC is counted with. */
    padlen = 8 - ((len+4)%8);

    /* Padding */
    if (state->encrypt) {
	/* random pad */
	RAND_bytes(sendbuf+4+8-padlen,padlen);
    } else {
	/* zero pad */
	memset(sendbuf+4+8-padlen,0,padlen);
    }

    /* CRC over padding + type + (possibly compressed) data, big-endian */
    crc = do_crc(sendbuf+4+8-padlen, padlen+len);
    sendbuf[4+8+len+3] = crc & 0xff; crc >>= 8;
    sendbuf[4+8+len+2] = crc & 0xff; crc >>= 8;
    sendbuf[4+8+len+1] = crc & 0xff; crc >>= 8;
    sendbuf[4+8+len] = crc & 0xff;

    on_packet(dump_buffer(stderr, "Pre encryption",
			  sendbuf+4+8-padlen, padlen+len+4));

    /* Crypt it (everything except the length field) */
    if (state->encrypt) {
	(state->encrypt)(sendbuf+4+8-padlen, sendbuf+4+8-padlen, padlen+len+4,
			 state->encryptstate);
    }

    on_packet(dump_buffer(stderr, "Post encryption",
			  sendbuf+4+8-padlen, padlen+len+4));

    /* Insert the length (type + data + CRC, excluding padding) */
    olen = len+4;
    sendbuf[4+8-padlen-1] = olen & 0xff; olen >>= 8;
    sendbuf[4+8-padlen-2] = olen & 0xff; olen >>= 8;
    sendbuf[4+8-padlen-3] = olen & 0xff; olen >>= 8;
    sendbuf[4+8-padlen-4] = olen & 0xff;

    /* Send it away */
    total = (long)padlen + (long)len + 8;
    ret = (state->down_write)(sendbuf+4+8-padlen-4, total,
			      state->down_state, erf, erfp);

    xfree(sendbuf);

    /* On a complete write, report the caller's payload size rather than
       the (type byte + possibly compressed) on-wire size, so that callers
       such as ssh_write() can account in terms of the data they passed. */
    if (ret >= total)
	ret = datalen;
    else if (ret > padlen + 8)
	ret -= (padlen + 8);

    return ret;
}

/*
 * Upper-layer write entry point: send LEN bytes of DATA as
 * SSH_CMSG_STDIN_DATA messages.  Each message carries at most
 * MAX_SSH_PACKET_LEN-20 data bytes, preceded by a 4-byte length from
 * buf_append_int().
 *
 * LEN < 0 closes the connection.  It sends SSH_MSG_DISCONNECT with the
 * reason "User closed connection" and frees the state's dispatch/term/connid
 * buffers.  It runs the cipher cleanup hook, then passes the negative LEN on
 * to down_write.  Finally it frees STATE itself and returns -1.
 *
 * Outside SSH_PHASE_INTERACTIVE the data is discarded and 0 is returned.
 * Otherwise the function returns the number of data bytes sent, or a
 * negative value if a write failed.
 */
long ssh_write(void *data, long len, void *state_, Erf erf, void *erfp)
{
    unsigned char *_data = data;
    long ret, totret = 0;
    ssh_state *state = state_;

    CALLBACK_PROLOGUE

    if (len < 0) {
	/* SSH string: 4-byte big-endian length (0x16 = 22), then the reason
	   text.  A writable local array is used, not a string literal. */
	char reason[] = "\0\0\0\x16User closed connection";
	general_buffer tmpbuf;

	/* Send the closing message.
	   NOTE(review): buf_init_static() appears to only attach storage with
	   an amount of 0.  This file always follows buf_init()/
	   buf_init_static() with an append or buf_adjust_amount_by().
	   Without the adjust, the DISCONNECT would go out with no reason
	   string -- confirm against buffer.c. */
	buf_init_static(&tmpbuf, reason, sizeof reason - 1);
	buf_adjust_amount_by(&tmpbuf, sizeof reason - 1);
	ssh_write_type(SSH_MSG_DISCONNECT, &tmpbuf, state_, erf, erfp);

	/* Clean up */
	if (state->dispatchpkt) xfree(state->dispatchpkt);
	if (state->term) xfree(state->term);
	if (state->connid) xfree(state->connid);
	if (state->cryptclean) (state->cryptclean)(state);

	/* Propagate the close (negative length) to the layer below */
	(state->down_write)(_data,len,state->down_state,erf,erfp);

	xfree(state);
	totret = -1;
	goto sshwrite_out;
    }

    /* If we're not in interactive mode, dump the data */
    if (state->protophase != SSH_PHASE_INTERACTIVE) {
	totret = 0;
	goto sshwrite_out;
    }

    while (len) {
	long piece;
	general_buffer *newdata;

	/* Write the data in pieces of size at most MAX_SSH_PACKET_LEN-20 */
	if (len > MAX_SSH_PACKET_LEN-20) {
	    piece = MAX_SSH_PACKET_LEN-20;
	} else {
	    piece = len;
	}
	newdata = buf_init(piece+4);
	if (!newdata) {
	    ssh_F_memfull();
	    totret = -1;
	    goto sshwrite_out;
	}
	buf_append_int(newdata, piece);
	buf_append_bytes(newdata, _data, piece);

	ret = ssh_write_type(SSH_CMSG_STDIN_DATA, newdata, state, erf, erfp);

	buf_destroy(newdata);

	/* A full write returns piece+4 (the 4-byte length plus the data).
	   Anything less is an error or a short write: stop there. */
	if (ret < piece+4) {
	    if (ret < 0) totret = ret;
	    else if (ret > 4) totret += ret-4;
	    break;
	}
	totret += ret-4;

	_data += piece;
	len -= piece;
    }

sshwrite_out:
    CALLBACK_EPILOGUE
    return totret;
}

/*
 * Process one received SSH protocol 1 binary packet.
 *
 * SIZE is the packet's length field (type + data + 4 CRC bytes).  PKT holds
 * SIZE + (8 - SIZE % 8) bytes: padding, type, data and CRC.  The buffer is
 * decrypted in place.  The trailing big-endian CRC is then checked against
 * the CRC of everything before it.  The type byte and data (decompressed
 * when state->compress is set) are finally dispatched through
 * ssh_handle_packet_type().
 *
 * Packets with SIZE < 5 are rejected via ssh_F_shortpkt(), CRC mismatches
 * via ssh_F_crc(), and a missing buffer via ssh_F_memfull().
 */
static void ssh_handle_data(unsigned char *pkt, long size, ssh_state *state,
			    Erf erf, void *erfp)
{
    unsigned long crc, crc_recv;
    long padlen = 8 - (size % 8);
    long padsize = size + padlen;
    general_buffer *buf = 0;
    general_buffer_reader br;
    /* A single byte on the wire: read it into a byte so that the value is
       correct regardless of host byte order. */
    unsigned char type = 0;

    if (size < 5) {
	ssh_F_shortpkt();
	return;
    }

    on_packet(ssh_infof("%ld bytes padding, %ld size contents, 4 bytes crc",
			padlen, size-4));

    on_packet(dump_buffer(stderr, "Pre decryption", pkt, padsize));

    /* First we'll have to decrypt it */
    if (state->decrypt) {
	(state->decrypt)(pkt, pkt, padsize, state->decryptstate);
    }

    on_packet(dump_buffer(stderr, "Post decryption", pkt, padsize));

    /* Now check the CRC (covers padding, type and data) */
    crc = do_crc(pkt, padsize-4);
    crc_recv = pkt[padsize-4]; crc_recv <<= 8;
    crc_recv += pkt[padsize-3]; crc_recv <<= 8;
    crc_recv += pkt[padsize-2]; crc_recv <<= 8;
    crc_recv += pkt[padsize-1];
    if (crc_recv != crc) {
#if 0 /* we might need it again */
	ssh_debugf("size = %d, padsize = %d", size, padsize);
	ssh_debugf("crc = 0x%X, crc_recv = 0x%X", crc, crc_recv);
	ssh_debugf("around crc spot: 0x%X, 0x%X, 0x%X",
		   *(unsigned long *)(&pkt[padsize - 8]),
		   *(unsigned long *)(&pkt[padsize - 4]),
		   *(unsigned long *)(&pkt[padsize]));
#endif
	ssh_F_crc();
	return;
    }

    /* Skip the padding and drop the CRC: what remains is type + data */
    pkt += padlen;
    size -= 4;

    if (state->compress) {
	on_packet(dump_buffer(stderr, "Pre decompression", pkt, size));

	buf = decompress_buffer(0, pkt, size);
	if (buf) {
	    pkt = buf_chars_noadjust(buf);
	    size = buf_amount(buf);

	    on_packet(dump_buffer(stderr, "Post decompression", pkt, size));
	}
    } else {
	buf = buf_init(size);
	if (buf)
	    buf_append_bytes(buf, pkt, size);
    }
    if (!buf) {
	/* Allocation (or decompression) produced no buffer */
	ssh_F_memfull();
	return;
    }
    bufread_init_static(&br, buf);
    bufread_bytes(&br, &type, 1);

    /* Handle the packet */
    ssh_handle_packet_type(type, &br, state, erf, erfp);

    buf_destroy(buf);
}

/* This function is called when data arrives */
/*
 * Input handler: called with each chunk of DATA (LEN bytes) received from
 * the layer below.  LEN < 0 means the connection was closed.
 *
 * During the version exchange (protophase <= SSH_PHASE_VERSION_WAIT),
 * input is collected until a '\n'.  Each complete line is passed to
 * ssh_handle_packet_type() as SSH_INTERNAL.
 *
 * After the version exchange, input is reassembled into binary packets:
 * a 4-byte big-endian length L, followed by L + (8 - L % 8) bytes that go
 * to ssh_handle_data().
 *
 * Incomplete lines, lengths and packets are kept in STATE (dispatchpkt,
 * dispatchsize, dispatchread, dispatchpadamt, dispatchsizebuf,
 * dispatchsizeread) until the next call.
 */
void ssh_dispatch(unsigned char *data, long len, void *state_, Erf erf,
    void *erfp)
{
    ssh_state *state = state_;

    CALLBACK_PROLOGUE

    if (len < 0) {
	/* The connection was closed; there is no data to process */
	goto ssh_dispatch_out;
    }

    while(len) {
	/* Is this supposed to be depacketized? */
	if (state->protophase <= SSH_PHASE_VERSION_WAIT) {
	    long i;
	    long piece;

	    /* Collect until we see a \n */
	    piece = len;
	    for(i=1;i<piece && data[i-1]!='\n';++i) ;
	    /* data+i-1 now points to \n if there was one, or the end of the
	       received data if not */

	    if (!state->dispatchpkt) {
		state->dispatchsize = piece;
		state->dispatchpkt = xmalloc(piece);
		state->dispatchread = 0;
	    } else if (state->dispatchread + i > state->dispatchsize) {
		state->dispatchsize += piece;
		state->dispatchpkt = xrealloc(state->dispatchpkt,
					      state->dispatchsize);
	    }
	    if (!state->dispatchpkt) {
		ssh_F_memfull();
		goto ssh_dispatch_out;
	    }
	    memmove(state->dispatchpkt+state->dispatchread, data, i);
	    data += i;
	    len -= i;
	    state->dispatchread += i;

	    /* i < piece means the scan stopped on a '\n' before the end */
	    if (i < piece || data[-1] == '\n') {
		general_buffer buf;
		general_buffer_reader br;

		on_packet(ssh_infof("Packet piece length %ld, %ld more bytes left",
				     i, len));
		on_packet(dump_buffer(stderr, "Received string",
				      state->dispatchpkt, state->dispatchread));

		buf_init_static(&buf, state->dispatchpkt, state->dispatchread);
		buf_adjust_amount_by(&buf, state->dispatchread);
		bufread_init_static(&br, &buf);

		ssh_handle_packet_type(SSH_INTERNAL, &br, state, erf, erfp);
		xfree(state->dispatchpkt);
		state->dispatchpkt = NULL;
		state->dispatchread = 0;
		state->dispatchsize = 0;
		state->dispatchpadamt = 0;
		state->dispatchsizeread = 0;
	    }
	} else {
	    /* See if we need to complete an existing partial packet */
	    if (state->dispatchsize + state->dispatchpadamt) {
		/* Yes; how much more to read? */
		long left = (state->dispatchsize + state->dispatchpadamt)
		    - state->dispatchread;
		if (left > len) left = len;

		memmove(state->dispatchpkt+state->dispatchread, data, left);
		state->dispatchread += left;
		data += left;
		len -= left;

		if ((state->dispatchsize + state->dispatchpadamt) ==
		    state->dispatchread) {
		    /* We've got the whole thing now */
		    ssh_handle_data(state->dispatchpkt, state->dispatchsize,
				    state, erf, erfp);
		    state->dispatchsize = 0;
		    state->dispatchpadamt = 0;
		    xfree(state->dispatchpkt);
		    state->dispatchpkt = NULL;
		}
	    } else {
		/* We're waiting for (some, possibly all) of the next packet
		   size */
		long left = 4 - state->dispatchsizeread;
		if (left > len) left = len;

		memmove(state->dispatchsizebuf+state->dispatchsizeread, data,
			left);
		state->dispatchsizeread += left;
		data += left;
		len -= left;

		if (state->dispatchsizeread == 4) {
		    /* We have the whole length now (big-endian).
		       NOTE(review): this length comes from the peer and is
		       not bounded here before being used as an allocation
		       size -- consider a maximum-packet-length check. */
		    state->dispatchsize = state->dispatchsizebuf[0];
		    state->dispatchsize <<= 8;
		    state->dispatchsize += state->dispatchsizebuf[1];
		    state->dispatchsize <<= 8;
		    state->dispatchsize += state->dispatchsizebuf[2];
		    state->dispatchsize <<= 8;
		    state->dispatchsize += state->dispatchsizebuf[3];
		    state->dispatchsizeread = 0;
		    state->dispatchpadamt = 8 - (state->dispatchsize % 8);
		    state->dispatchpkt = xmalloc(state->dispatchsize +
						 state->dispatchpadamt);
		    state->dispatchread = 0;
		    if (!state->dispatchpkt) {
			/* Forget the length just read.  Otherwise the next
			   call would see a "packet in progress" and copy
			   into a NULL buffer. */
			state->dispatchsize = 0;
			state->dispatchpadamt = 0;
			ssh_F_memfull();
			goto ssh_dispatch_out;
		    }
		}
	    }
	}
    }

ssh_dispatch_out:
    ;
    CALLBACK_EPILOGUE
}

/* Emacs local variables

Local variables:
eval: (set-c-style "BSD")
end:

*/
