/****************************************************************************
*																			*
*						 cryptlib MySQL Mapping Routines					*
*						Copyright Peter Gutmann 1997-1999					*
*																			*
****************************************************************************/

/* TODO:

  - This is mostly a direct conversion of the mSQL code to MySQL.  Since I
	don't run MySQL I haven't been able to check the code much.
  - The docs are very vague on the C bindings, places where I wasn't sure
	how things work have '!!!!' inside comments explaining the problem.
*/

#include <stdio.h>
#include <string.h>
#ifdef INC_CHILD
  #include "../crypt.h"
  #include "dbms.h"
#else
  #include "crypt.h"
  #include "misc/dbms.h"
#endif /* INC_CHILD */

/* !!!! dbtest-only !!!! */
/* Debug hook: expands to the wrapped statement unchanged, so every
   DEBUG( ... ) call in this file is compiled in and executed.
   NOTE(review): the dbtest-only markers suggest this should be compiled
   out for non-test builds - confirm */
#define DEBUG( statement )	statement
/* !!!! dbtest-only !!!! */

/****************************************************************************
*																			*
*							Unix Database Access Functions					*
*																			*
****************************************************************************/

#ifdef DBX_MYSQL

/* The length of the date/time field when encoded as a text string */

#define DATETIME_SIZE		14
#define TEXT_DATETIME_SIZE	"14"

/* MySQL has a few limits compared to standard SQL.  It doesn't implement
   the usual CREATE INDEX commands (that is, they're present but have no
   effect), so we need to emulate these using ALTER TABLE.  This is somewhat
   inefficient, since each ALTER requires copying all information from the
   current table to a temporary one that contains the alteration, deleting
   the old table, and renaming the temporary one to the original name.
   However, since the CREATE INDEX is done when the table is created, there's
   virtually no overhead as the table hasn't been populated yet.

   There are two variants we have to handle.  The first is the rewrite of
   CREATE INDEX <name>In ON <table> (<column>), which becomes ALTER TABLE
   <table> ADD INDEX <name>In (<column>).  The second is CREATE UNIQUE
   INDEX ..., which is handled as before but with ADD UNIQUE in place of
   ADD INDEX.  This gets really ugly, since we almost have to write a
   miniature SQL parser to obtain the various names and move them around
   for the rewritten expression */

/* Rewrite "CREATE [UNIQUE] INDEX <index> ON <table> (<column>)" in command
   as "ALTER TABLE <table> ADD INDEX|UNIQUE <index> (<column>)", writing the
   result to query.  Returns TRUE if the command was rewritten.  Returns
   FALSE, leaving query untouched, if command isn't a CREATE INDEX or doesn't
   have the expected layout.  Every name is located with bounded searches so
   that a malformed command can't cause reads past the end of the string */

static BOOLEAN rewriteCreateIndex( char *query, const char *command )
	{
	const char *indexName, *tableName, *columnName;
	BOOLEAN uniqueIndex;
	int indexNameLen, tableNameLen, columnNameLen;

	/* Identify the variant.  Matching the trailing space as part of the
	   prefix guarantees that indexName points within the string */
	if( !strncmp( command, "CREATE UNIQUE INDEX ", 20 ) )
		{
		uniqueIndex = TRUE;
		indexName = command + 20;
		}
	else
		{
		if( strncmp( command, "CREATE INDEX ", 13 ) )
			return( FALSE );
		uniqueIndex = FALSE;
		indexName = command + 13;
		}

	/* Locate the table name (after " ON ") and the column name (inside
	   " (...)").  If any part is missing, leave the command alone */
	tableName = strstr( indexName, " ON " );
	if( tableName == NULL )
		return( FALSE );
	tableName += 4;
	columnName = strstr( tableName, " (" );
	if( columnName == NULL )
		return( FALSE );
	columnName += 2;
	columnNameLen = ( int ) strcspn( columnName, ")" );
	if( columnName[ columnNameLen ] != ')' )
		return( FALSE );

	/* The index and table names each end at the next space.  Both searches
	   are bounded by the " ON " and " (" matches found above */
	indexNameLen = ( int ) strcspn( indexName, " " );
	tableNameLen = ( int ) strcspn( tableName, " " );

	sprintf( query, "ALTER TABLE %.*s ADD %s %.*s (%.*s)",
			 tableNameLen, tableName, uniqueIndex ? "UNIQUE" : "INDEX",
			 indexNameLen, indexName, columnNameLen, columnName );
	return( TRUE );
	}

/* Convert a cryptlib-generated SQL command into a form that MySQL can
   handle, writing the result to query.  CREATE [UNIQUE] INDEX commands are
   rewritten as ALTER TABLE; everything else is copied unchanged.  query must
   be large enough for the result.  The ALTER TABLE form can be up to six
   characters longer than the CREATE INDEX it replaces */

static void convertQuery( char *query, const char *command )
	{
	if( !rewriteCreateIndex( query, command ) )
		strcpy( query, command );

	DEBUG( printf( "XFM: %s\n", query ) );
	}

/* Get information on a MySQL error */

static int getErrorInfo( KEYSET_INFO *keysetInfo, const int defaultStatus )
	{
	/* Record the text of the last MySQL error on this keyset's connection
	   in keysetInfo->errorMessage and map it to a cryptlib status.  Returns
	   CRYPT_DATA_DUPLICATE if the message reports that a table already
	   exists, otherwise defaultStatus */
	const char *mysqlErrorMsg = mysql_error( keysetInfo->keysetDBMS.connection );
	size_t length = strlen( mysqlErrorMsg );

	/* Copy the error string returned by mysql_error(), truncating it if
	   necessary to fit the error message buffer.  No numeric error code is
	   recorded.
	   NOTE(review): MySQL also reports numeric codes via mysql_errno(), which
	   would be more robust than the message-text matching below */
	if( length > ( size_t ) ( MAX_ERRMSG_SIZE - 1 ) )
		length = MAX_ERRMSG_SIZE - 1;
	memcpy( keysetInfo->errorMessage, mysqlErrorMsg, length );
	keysetInfo->errorMessage[ length ] = '\0';
	keysetInfo->errorCode = CRYPT_ERROR;	/* No numeric code recorded */

	/* Pick apart the message to detect "Table ... exists" errors.  This is
	   pretty nasty, since it may break if the error messages are ever
	   changed.  The length check requires room for both "Table" and
	   "exists" (5 + 6 characters), which also keeps the suffix comparison
	   from reading before the start of the buffer */
	if( length >= 11 && \
		!strncmp( keysetInfo->errorMessage, "Table", 5 ) && \
		!strncmp( keysetInfo->errorMessage + length - 6, "exists", 6 ) )
		return( CRYPT_DATA_DUPLICATE );

	DEBUG( printf( "Error message:%s\n", keysetInfo->errorMessage ) );
	return( defaultStatus );
	}

/* Open and close a connection to a MySQL server */

static int performUpdate( KEYSET_INFO *keysetInfo, const char *command );

static int openDatabase( KEYSET_INFO *keysetInfo, const char *name,
						 const char *host, const char *user,
						 const char *password )
	{
	/* Connect to the MySQL server and select the database.

	   name is the database to select, host is the server (NULL for the
	   local host), and user/password are the login credentials.  Returns
	   CRYPT_OK on success or CRYPT_DATA_OPEN on failure.  On failure the
	   MySQL error text is recorded via getErrorInfo(), and the connection
	   handle is closed and reset to NULL */
	const char *hostName = ( host != NULL ) ? host : "localhost";

	/* mysql_connect() is deprecated, so allocate a connection handle with
	   mysql_init() and connect with mysql_real_connect().  Because the
	   handle exists even if the connect fails, the error text can be
	   retrieved */
	keysetInfo->keysetDBMS.connection = mysql_init( NULL );
	if( keysetInfo->keysetDBMS.connection == NULL )
		return( CRYPT_DATA_OPEN );
	if( mysql_real_connect( keysetInfo->keysetDBMS.connection, hostName,
							user, password, NULL, 0, NULL, 0 ) == NULL )
		{
		getErrorInfo( keysetInfo, CRYPT_DATA_OPEN );
		mysql_close( keysetInfo->keysetDBMS.connection );
		keysetInfo->keysetDBMS.connection = NULL;
		return( CRYPT_DATA_OPEN );
		}

	/* mysql_select_db() returns zero on success and nonzero on error */
	if( mysql_select_db( keysetInfo->keysetDBMS.connection, name ) != 0 )
		{
		getErrorInfo( keysetInfo, CRYPT_DATA_OPEN );
		mysql_close( keysetInfo->keysetDBMS.connection );
		keysetInfo->keysetDBMS.connection = NULL;
		return( CRYPT_DATA_OPEN );
		}
	keysetInfo->keysetDBMS.databaseOpen = TRUE;

	/* Set some options to improve performance.
	   - The select limit is set to 1, since we're only ever going to
	     retrieve one row.
	   - SQL_BIG_SELECTS is turned off so that MySQL aborts a select that
	     would take a very long time.  This shouldn't affect anything
	     created by cryptlib, but it's worth doing anyway for general
	     bulletproofing.
	   These are plain statements, so they're sent directly rather than
	   through performUpdate().  Failures are ignored, since the settings
	   are only an optimisation */
	( void ) mysql_query( keysetInfo->keysetDBMS.connection,
						  "SET SQL_SELECT_LIMIT=1" );
	( void ) mysql_query( keysetInfo->keysetDBMS.connection,
						  "SET SQL_BIG_SELECTS=0" );

	/* Get the name of the blob data type for this database */
	strcpy( keysetInfo->keysetDBMS.blobName, "BLOB" );

	/* Set source-specific information which we may need later on */
	keysetInfo->keysetDBMS.maxTableNameLen = \
		keysetInfo->keysetDBMS.maxColumnNameLen = 63;
	keysetInfo->keysetDBMS.hasBinaryBlobs = TRUE;
	return( CRYPT_OK );
	}

static void closeDatabase( KEYSET_INFO *keysetInfo )
	{
	/* Close the MySQL connection (if there is one) and mark the keyset's
	   database as closed.  The handle is reset to NULL, the same "no
	   connection" value that openDatabase() uses on failure; the original
	   assigned the integer CRYPT_ERROR to this pointer */
	if( keysetInfo->keysetDBMS.connection != NULL )
		mysql_close( keysetInfo->keysetDBMS.connection );
	keysetInfo->keysetDBMS.connection = NULL;
	keysetInfo->keysetDBMS.databaseOpen = FALSE;
	}

/* Perform a transaction which updates the database without returning any
   data */

static int performUpdate( KEYSET_INFO *keysetInfo, const char *command )
	{
	/* Send a data-modifying command to the server.  If isDataUpdate is
	   set, the first '?' in the converted query is replaced by the keyset's
	   date as a quoted YYYYMMDDHHMMSS literal.  Returns CRYPT_OK, or:
	   - CRYPT_OVERFLOW if the date won't fit in the query buffer;
	   - CRYPT_ERROR if there's no placeholder or the date can't be encoded;
	   - the getErrorInfo() mapping of a MySQL error */
	char query[ MAX_SQL_QUERY_SIZE ];

	convertQuery( query, command );
	if( keysetInfo->keysetDBMS.isDataUpdate )
		{
		/* !!!! Unsure about how MySQL handles binding for dates !!!! */
		const struct tm *timeInfo = gmtime( &keysetInfo->keysetDBMS.date );
		char *datePtr = strchr( query, '?' );
		char dateString[ DATETIME_SIZE + 3 ];
		const size_t length = strlen( query );

		/* The one-character '?' is replaced by the quoted date string
		   (DATETIME_SIZE + 2 characters), so the query grows by
		   DATETIME_SIZE + 1 characters.  If the result plus its terminator
		   won't fit in the query buffer, return a data overflow error */
		if( length + DATETIME_SIZE + 2 > ( size_t ) MAX_SQL_QUERY_SIZE )
			return( CRYPT_OVERFLOW );
		if( datePtr == NULL )
			return( CRYPT_ERROR );	/* Internal error, should never happen */

		/* gmtime() can fail for out-of-range times, and a year outside
		   0...9999 wouldn't fit the fixed-width encoding (or dateString) */
		if( timeInfo == NULL || timeInfo->tm_year < -1900 || \
			timeInfo->tm_year > 9999 - 1900 )
			return( CRYPT_ERROR );

		/* Encode the date in fixed-width most-significant-first form
		   (YYYYMMDDHHMMSS, ISO 8601 field order), so that comparisons like
		   < and > work properly.  The value is quoted as an SQL string
		   literal */
		sprintf( dateString, "'%04d%02d%02d%02d%02d%02d'",
				 timeInfo->tm_year + 1900, timeInfo->tm_mon + 1,
				 timeInfo->tm_mday, timeInfo->tm_hour, timeInfo->tm_min,
				 timeInfo->tm_sec );

		/* Shift everything after the '?' (including the terminator) right
		   to open a gap.  Then overwrite the '?' and the gap with the
		   quoted date */
		memmove( datePtr + DATETIME_SIZE + 2, datePtr + 1,
				 strlen( datePtr + 1 ) + 1 );
		memcpy( datePtr, dateString, DATETIME_SIZE + 2 );
		}

	/* mysql_query() returns zero on success and nonzero on error */
	if( mysql_query( keysetInfo->keysetDBMS.connection, query ) != 0 )
		return( getErrorInfo( keysetInfo, CRYPT_DATA_WRITE ) );

	return( CRYPT_OK );
	}

/* Perform a transaction which checks for the existence of an object */

static int performCheck( KEYSET_INFO *keysetInfo, const char *command )
	{
	/* Run a query and return the number of rows it produced (0 if there is
	   no result set), or the getErrorInfo() mapping of a MySQL error */
	MYSQL_RES *result;
	char query[ MAX_SQL_QUERY_SIZE ];
	int count = 0;

	/* Submit the query to the MySQL server.  mysql_query() returns zero on
	   success and nonzero on error */
	convertQuery( query, command );
	if( mysql_query( keysetInfo->keysetDBMS.connection, query ) != 0 )
		return( getErrorInfo( keysetInfo, CRYPT_DATA_READ ) );

	/* Buffer the result client-side with mysql_store_result() so that
	   mysql_num_rows() is meaningful.  With mysql_use_result(), the row
	   count isn't correct until every row has been fetched.  A NULL result
	   is treated as "no rows" */
	result = mysql_store_result( keysetInfo->keysetDBMS.connection );
	if( result != NULL )
		{
		count = ( int ) mysql_num_rows( result );
		mysql_free_result( result );
		}

	DEBUG( printf( "performCheck:count = %d\n", count ) );
	return( count );
	}

/* Perform a transaction which returns information */

static int performQuery( KEYSET_INFO *keysetInfo, const char *command,
						 char *data, int *dataLength, const int maxLength )
	{
	/* Run a query returning a single value and copy the first column of the
	   first row into data, NUL-terminated, with its length in *dataLength.
	   Returns:
	   - CRYPT_OK on success;
	   - CRYPT_DATA_NOTFOUND if there is no row or the value is NULL;
	   - CRYPT_BADDATA if the value needs maxLength bytes or more;
	   - the getErrorInfo() mapping of a MySQL error.
	   *dataLength is 0 on every non-success path */
	MYSQL_RES *result;
	char query[ MAX_SQL_QUERY_SIZE ];
	int status = CRYPT_DATA_NOTFOUND;

	*dataLength = 0;

	/* Submit the query to the MySQL server.  mysql_query() returns zero on
	   success and nonzero on error */
	convertQuery( query, command );
	if( mysql_query( keysetInfo->keysetDBMS.connection, query ) != 0 )
		return( getErrorInfo( keysetInfo, CRYPT_DATA_READ ) );

	/* Fetch the returned row (this is always just a single value, the key
	   data) */
	result = mysql_use_result( keysetInfo->keysetDBMS.connection );
	if( result != NULL )
		{
		MYSQL_ROW row = mysql_fetch_row( result );
		unsigned long *lengths = ( row != NULL ) ? \
								 mysql_fetch_lengths( result ) : NULL;

		if( lengths != NULL && row[ 0 ] != NULL )
			{
			/* Use the length reported by MySQL rather than strlen(), since
			   the value may be a binary blob containing zero bytes */
			if( maxLength <= 0 || lengths[ 0 ] >= ( unsigned long ) maxLength )
				status = CRYPT_BADDATA;	/* Too much data returned */
			else
				{
				memcpy( data, row[ 0 ], ( size_t ) lengths[ 0 ] );
				data[ lengths[ 0 ] ] = '\0';
				*dataLength = ( int ) lengths[ 0 ];
				status = CRYPT_OK;
				}
			}
		mysql_free_result( result );
		}

	DEBUG( printf( "performQuery:dataLength = %d\n", *dataLength ) );
	return( status );
	}

/* Initialise, perform, and wind up a bulk update transaction */

static int performBulkUpdate( KEYSET_INFO *keysetInfo, const char *command )
	{
	/* Handle one step of a bulk update, depending on bulkUpdateState:
	   - BULKUPDATE_START locks the name table for writing;
	   - BULKUPDATE_FINISH unlocks it;
	   - any other state performs command as a normal update.
	   Returns the status of the underlying performUpdate() call, or
	   CRYPT_OVERFLOW if the table name is too long for the LOCK command.

	   Locking the table helps performance: for more than about 5 updates in
	   a row, this can be up to 20 times faster than individual updates on a
	   non-locked table */
	if( keysetInfo->keysetDBMS.bulkUpdateState == BULKUPDATE_START )
		{
		char buffer[ 50 + CRYPT_MAX_TEXTSIZE ];

		/* The fixed text of the command takes well under 50 characters, so
		   any name of up to CRYPT_MAX_TEXTSIZE characters fits */
		if( strlen( keysetInfo->keysetDBMS.nameTable ) > CRYPT_MAX_TEXTSIZE )
			return( CRYPT_OVERFLOW );
		sprintf( buffer, "LOCK TABLES %s WRITE", keysetInfo->keysetDBMS.nameTable );
		return( performUpdate( keysetInfo, buffer ) );
		}
	if( keysetInfo->keysetDBMS.bulkUpdateState == BULKUPDATE_FINISH )
		return( performUpdate( keysetInfo, "UNLOCK TABLES" ) );

	/* We're in the middle of a bulk update, perform the update */
	return( performUpdate( keysetInfo, command ) );
	}

/* Set up the function pointers to the access methods */

int setAccessMethodMySQL( KEYSET_INFO *keysetInfo )
	{
	/* Install the MySQL implementations of the keyset DBMS access methods
	   into keysetInfo->keysetDBMS.  Always returns CRYPT_OK */

	/* Read-side methods */
	keysetInfo->keysetDBMS.performQuery = performQuery;
	keysetInfo->keysetDBMS.performCheck = performCheck;

	/* Write-side methods */
	keysetInfo->keysetDBMS.performBulkUpdate = performBulkUpdate;
	keysetInfo->keysetDBMS.performUpdate = performUpdate;

	/* Connection management */
	keysetInfo->keysetDBMS.closeDatabase = closeDatabase;
	keysetInfo->keysetDBMS.openDatabase = openDatabase;

	return( CRYPT_OK );
	}
#endif /* DBX_MYSQL */
