#include "model.h"

m_type	Model(int length, double npseudo, long *counts, a_type A)
/*
 * Create and return a null model: residue frequencies and pseudocounts are
 * derived from counts[1..nAlpha(A)]; no site is recorded (totsites = 0).
 *
 *   length  - number of positions in the model; rows 0..length of the
 *             site_freq and likelihood tables are allocated
 *   npseudo - total pseudocount mass; values below 0.01 are raised to 0.01
 *   counts  - residue counts, read at indices 1..nAlpha(A) and copied
 *   A       - alphabet; the handle is stored in the model, not copied
 *
 * NilModel() frees everything allocated here.
 * NOTE(review): the site_freq/likelihood rows and index 0 of the per-residue
 * arrays are not written here -- presumably NEW returns zeroed storage;
 * confirm against its definition.
 */
{
	m_type	M;
	int	j,r;

	NEW(M,1,model_type);
	M->A = A;
	M->log_alpha = log((double)nAlpha(A));
	M->length = length;
	M->npseudo = MAX(double,npseudo,0.01);
	M->totsites = 0;
	NEW(M->freq,nAlpha(A)+2,double);
	NEW(M->counts,nAlpha(A)+2,long);
	M->tot_cnts = 0.0;
	for(r=1; r<= nAlpha(A); r++){
		M->counts[r]= counts[r];
		M->tot_cnts += counts[r];
	}
	NEW(M->N0,nAlpha(M->A)+2,double);
	NEW(M->temp,nAlpha(M->A)+2,double);
	NEWP(M->site_freq, length+1, long);
	NEWP(M->likelihood, length+1, double);
	for(r=1; r<= nAlpha(A); r++) {
		/* With a zero total the quotient would be 0/0 (NaN) and would
		   poison every pseudocount; fall back to a frequency of 0. */
		if(M->tot_cnts != 0){
			M->freq[r] = (double) counts[r]/(double) M->tot_cnts;
		} else {
			M->freq[r] = 0.0;
		}
		/* Pseudocount of residue r = its share of the npseudo mass. */
		M->N0[r]= M->npseudo * M->freq[r];
	}
	for( j = 0; j <= length; j++){	/* 0 = nonsitefreq */
	   NEW(M->site_freq[j],nAlpha(M->A) +2, long);
	   NEW(M->likelihood[j],nAlpha(M->A) +2, double);
	}
	M->update = TRUE;	/* likelihood table must be built before use */
	return M;
}

void	InitModel(m_type M)
/* Reset M to the null model: clear every count in rows 0..M->length of
   site_freq, set the number of sites to zero and flag the likelihood
   table as stale. */
{
	int	pos,res;

	for(pos = 0; pos <= M->length; pos++){	/* row 0 = nonsite counts */
		long	*row = M->site_freq[pos];

		for(res = 0; res <= nAlpha(M->A); res++) row[res] = 0;
	}
	M->totsites = 0;
	M->update = TRUE;
}

m_type NilModel(m_type M)
/* Release model M together with its per-residue arrays and both tables.
   A NULL M is accepted.  The alphabet M->A is not freed.  Always returns
   NULL. */
{
   int col;

   if(M==NULL) return (m_type) NULL;

   /* rows 0..length of both tables, then the row-pointer arrays */
   for(col = 0; col <= M->length; col++){
	free(M->site_freq[col]);
	free(M->likelihood[col]);
   }
   free(M->site_freq);
   free(M->likelihood);

   free(M->freq);
   free(M->counts);
   free(M->N0);
   free(M->temp);
   free(M);
   return (m_type) NULL;
}

double	RelProbModel(register char *seq, register m_type M)
/*
 * Return the relative probability that the segment seq[0..M->length-1]
 * belongs to the model: the product over j = 1..M->length of
 * M->likelihood[j][seq[j-1]].  If the model is flagged as changed
 * (M->update), the likelihood table is rebuilt first.
 */
{
	register int j;
	register double p=1.0;

	if(M->update) update_model_freqN(M);
	/* Index with j-1 instead of stepping seq back by one element:
	   forming the pointer seq-1 is undefined behaviour when seq is the
	   first element of its array.  Positions are still visited from
	   last to first so the product accumulates in the original order. */
	for(j=M->length; j > 0; j--){
		p *= M->likelihood[j][(seq[j-1])];
	}
	return p;
}

void    RmModel(char *seq, int site, m_type M)
/** Take the segment seq[site .. site+M->length-1] out of the model:
    model position k loses one count of residue seq[site+k-1], the site
    total drops by one and the likelihood table is flagged as stale. **/
{
	int	col = 1;

	while(col <= M->length){
		M->site_freq[col][seq[site + col - 1]]--;
		col++;
	}
	M->update = TRUE;
	M->totsites--;
}

void	Add2Model(char *seq, int site, m_type M)
/** Record the segment seq[site .. site+M->length-1] in the model:
    model position k gains one count of residue seq[site+k-1], the site
    total grows by one and the likelihood table is flagged as stale. **/
{
	int	off;

	for(off = 0; off < M->length; off++){
		M->site_freq[off + 1][seq[site + off]]++;
	}
	M->update = TRUE;
	M->totsites++;
}

void	update_model_freqN(register m_type M)
/*
 * Rebuild M->likelihood from the current counts and clear M->update.
 *
 * Row 0 receives the nonsite frequency of each residue (nonsite count
 * plus pseudocount over the nonsite total).  Rows 1..M->length receive
 * the site frequency of each residue divided by the row-0 value; entry 0
 * of every site row is set to 1.0.  M->site_freq[0] is overwritten with
 * the nonsite counts as a side effect.
 */
{
	a_type	alphabet = M->A;
	long	*nonsite = M->site_freq[0];
	double	*background = M->likelihood[0];
	double	nonsite_total,site_total;
	int	col,res;

	/* nonsite counts = overall counts minus everything held in sites */
	for(res = 1; res <= nAlpha(alphabet); res++){
		nonsite[res] = M->counts[res];
	}
	for(col = 1; col <= M->length; col++){
		long	*site = M->site_freq[col];

		/* NOTE(review): index 0 is decremented here but is not
		   reloaded from M->counts above -- confirm this is intended. */
		for(res = 0; res <= nAlpha(alphabet); res++){
			nonsite[res] -= site[res];
		}
	}

	nonsite_total = 0.0;
	for(res = 1; res <= nAlpha(alphabet); res++){
		nonsite_total += (double) nonsite[res];
	}
	nonsite_total += M->npseudo;

	/* residues whose pseudocount is not positive are skipped and keep
	   whatever value row 0 already held */
	for(res = 0; res <= nAlpha(alphabet); res++){
		if(M->N0[res] > 0.0){
			background[res] = (nonsite[res]+M->N0[res])/nonsite_total;
		}
	}

	site_total = (double) (M->totsites+M->npseudo);
	for(col = 1; col <= M->length; col++){
		long	*site = M->site_freq[col];
		double	*rel = M->likelihood[col];

		rel[0] = 1.0;
		for(res = 1; res <= nAlpha(alphabet); res++){
			rel[res] = (site[res]+M->N0[res])/(site_total);
			/* NOTE(review): no check that background[res] is
			   nonzero for residues skipped above -- confirm. */
			rel[res] /= background[res];
		}
	}
	M->update = FALSE;
}

double	RatioModel(boolean left, long *site_freq, int d, m_type M)
/*
 * Return the ratio (not its logarithm) of new to old column counts:
 *
 *     exp( sum over b of  lgamma(site_freq[b] + N0[b])
 *                       - lgamma(old[b]       + N0[b]) )
 *
 * where old is model column d when left is true and column
 * M->length-d+1 otherwise.  Residues whose pseudocount N0[b] is zero
 * are left out of the sum.  Returns 0.0 when site_freq is NULL, d < 1
 * or 2*d > M->length.
 */
{
	long	*current;
	double	logsum = 0.0;
	int	res;

	if(site_freq == NULL || d < 1 || 2*d > M->length) return 0.0;

	current = left ? M->site_freq[d] : M->site_freq[(M->length - d +1)];
	for(res = 1; res <= nAlpha(M->A); res++){
		if(M->N0[res] == 0.0) continue;
		logsum += lgamma((site_freq[res]+M->N0[res]))
			- lgamma((current[res]+M->N0[res]));
	}
	return exp(logsum);
}

void	ShiftModel(long *site_freq, boolean left, m_type M)
/*
 * Slide the model by one column and install site_freq at the vacated end.
 *
 *   left true : column 1 is freed, columns 2..w move to 1..w-1 and
 *               site_freq becomes column w  (w = M->length)
 *   left false: column w is freed, columns 1..w-1 move to 2..w and
 *               site_freq becomes column 1
 *
 * site_freq is stored, not copied, so the model takes it over (NilModel
 * frees it); it is also modified in place before being stored.  The
 * likelihood table is flagged as stale.
 */
{
	long	**col = M->site_freq;
	int	w = M->length;
	int	k,res;

	/* NOTE(review): N0 is double and site_freq is long, so each sum is
	   converted back to long on assignment, and the pseudocounts are
	   added again when likelihoods are computed -- confirm intended. */
	for(res = 1; res <= nAlpha(M->A); res++){
		site_freq[res] +=  M->N0[res];
	}

	if(!left){
		free(col[w]);
		for(k = w; k > 1; k--){
			col[k] = col[k-1];
		}
		col[1] = site_freq;
	} else {
		free(col[1]);
		for(k = 1; k < w; k++){
			col[k] = col[k+1];
		}
		col[w] = site_freq;
	}
	M->update = TRUE;
}

double	LogLikeModel(m_type  M)
/*
 * Return the log likelihood of the model.
 *
 * Two sums of n*log(frequency) are accumulated with natural logarithms:
 *   term1  - over every site position j and residue b with a positive
 *            count n, using (n + N0[b]) / (column total + npseudo)
 *   term1b - over every residue b with a positive nonsite count n, using
 *            (n + N0[b]) / (nonsite total + npseudo)
 * The result is (term1 + term1b) * 1.4427; 1.4427 is approximately
 * 1/ln(2), so this presumably expresses the value in bits.
 *
 * Side effect: M->temp[1..nAlpha] is overwritten with the nonsite counts.
 */
{
	double	term1,term1b;
	double	n,r,d,total;
	int	j,b;
	long	*counts = M->counts;

	/* site positions: d is the column total including pseudocounts */
	for(term1=0.0, j=1; j<= M->length; j++){
		for(d=0.0,b=1; b<= nAlpha(M->A); b++){
			d += (double) (M->site_freq[j][b]);
		} 
		d += M->npseudo;
		for(b=1; b<= nAlpha(M->A); b++){
		    if((n=(double)M->site_freq[j][b]) > 0.0){
			term1 += n * log((n+M->N0[b])/d);
		    }
		}
	}
	/* nonsite count of b = overall count minus its count in all sites */
	for(total=0.0, b=1; b<= nAlpha(M->A); b++){
		for(d=0.0,j=1; j<= M->length; j++){
			d += (double) M->site_freq[j][b]; 
		}
		M->temp[b] = (double) counts[b] - d;
		total += M->temp[b];
	}
	total += M->npseudo;
	for(term1b=0.0, b=1; b<= nAlpha(M->A); b++){
	    if(M->temp[b] > 0.0){
		n = M->temp[b];
		r = (n+M->N0[b])/total;
		term1b += n * log(r);
	    }
	}
	return (1.4427*(term1 + term1b));
}

void	PutModel(FILE *fptr, m_type M)
/*
 * Write the current frequency model to fptr as a table: a header row of
 * residue characters, one row of rounded percentages per site position,
 * and a final "non-site" row of 100 * N0[b] / npseudo.
 *
 * Every site row is scaled by the sum of the raw counts in column 1.
 * NOTE(review): that sum excludes the pseudocounts and is not checked
 * for zero -- confirm callers only print models that contain sites.
 * No newline is written after the final row.
 */
{
	double	col1_total = 0.0;
	int	pos,res,pct;

	fputs("POS  ", fptr);
	for(res = 1; res <= nAlpha(M->A); res++){
		col1_total += (double) M->site_freq[1][res];
		fprintf(fptr,"%3c", AlphaChar(res, M->A));
	}
	fputs("\n", fptr);

	for(pos = 1; pos <= M->length; pos++){
		fprintf(fptr,"%4d ",pos);
		for(res = 1; res <= nAlpha(M->A); res++){
			pct = (int)((100.0*(M->site_freq[pos][res]+M->N0[res])/col1_total)+0.49);
			fprintf(fptr,"%3d", pct);
		}
		fputs("\n", fptr);
	}

	fputs("non-\nsite ", fptr);
	for(res = 1; res <= nAlpha(M->A); res++){
		pct = (int) ((100.0 * (M->N0[res]/M->npseudo))+0.49);
		fprintf(fptr,"%3d", pct);
	}
}


