/***************************************************************************
 *                                                                         *
 *                  (begin: Feb 20 2003)                                   *
 *                                                                         *
 *   Parallel IQPNNI - Important Quartet Puzzle with NNI                   *
 *                                                                         *
 *   Copyright (C) 2005 by Le Sy Vinh, Bui Quang Minh, Arndt von Haeseler  *
 *   Copyright (C) 2003-2004 by Le Sy Vinh, Arndt von Haeseler             *
 *   {vinh,minh}@cs.uni-duesseldorf.de                                     *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   along with this program; if not, write to the                         *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 ***************************************************************************/

#include <math.h>
#include <iostream>

#include "optpairseq.h"
#include "mat.h"
#include "model.h"
#include "brent.h"
#include "constant.h"
#include "rate.h"
#include "ptnls.h"

#ifdef _OPENMP
#include <omp.h>
#endif

// File-scope OptPairSeq instance defined by this translation unit.
OptPairSeq opt_pairseq;

// Default constructor. The body is empty: no setup is performed here.
OptPairSeq::OptPairSeq () {}

// Explicit construction hook for this class. The body is empty, so calling it
// currently has no effect.
void OptPairSeq::doConstructor () {}

//-------------------------------------------------------------
//compute the observed distance between two sequences
//
// Counts the sites at which seq1 and seq2 hold different states, ignoring
// every site where either sequence holds BS_UNKNOWN, and divides that count
// by the total number of sites of seq1 (sites with unknown states remain in
// the denominator).
//
// seq1, seq2 : sequences to compare; seq2 is indexed over seq1's site range.
// returns    : the proportion of differing sites, or 0.0 when seq1 reports
//              no sites.
double OptPairSeq::cmpObsDis (Seq &seq1, Seq &seq2) {
	// query the site count once instead of on every loop iteration
	const int nSite_ = seq1.getNSite ();

	// without this guard an empty sequence would evaluate 0/0 and return NaN
	if (nSite_ <= 0)
		return 0.0;

	int nDif = 0;
	for (int site_ = 0; site_ < nSite_; site_ ++) {
		// a site with an unknown state on either side is never a difference
		if (seq1[site_] == BS_UNKNOWN || seq2[site_] == BS_UNKNOWN)
			continue;
		if (seq1[site_] != seq2[site_])
			nDif ++;
	}

	return static_cast<double> (nDif) / nSite_;
}


//-------------------------------------------------------------
//compute the negative log likelihood when the distance between them is brLen
double OptPairSeq::cmpLogLiSpecificRate (double brLen, Seq* seq1_, Seq* seq2_) {
	int nPtn_ = ptnlist.getNPtn ();

	Vec<double> *ptnRateArr_;
	ptnRateArr_ = myrate.getPtnRate ();


	DVec20 stateFrqArr_;
	mymodel.getStateFrq (stateFrqArr_);

	double logLi_ = 0.0;
	for (int ptnNo_ = 0; ptnNo_ < nPtn_; ptnNo_ ++) {
		int siteNo_ = ptnlist.siteArr_ [ptnNo_];
		int stateNo1_ = (*seq1_)[siteNo_];
		int stateNo2_ = (*seq2_)[siteNo_];

		double ptnRate_ = (*ptnRateArr_)[ptnNo_];
		double prob_ = mymodel.cmpProbChange (stateNo1_, stateNo2_, brLen * ptnRate_);

		double logPtnLi_ = log ( stateFrqArr_[stateNo1_] * prob_);
		logLi_ += logPtnLi_ * ptnlist.weightArr_[ptnNo_];
	}

	return logLi_;
}

// Log-likelihood of a sequence pair at distance brLen, averaged over the
// discrete rate classes.
//
// One transition-probability matrix is built per rate class. For every
// ordered state pair (s1, s2) the pair likelihood is
//     sum over classes of  classWeight * freq(s1) * P_class(s1 -> s2)
// and its logarithm is weighted by the number of sites showing that pair
// (ptnMat_[s1][s2]). The class weight is mymodel.getClassProb(class) when
// myrate.isNsSyHeterogenous() holds, and myrate.getProb() otherwise.
//
// brLen   : distance (branch length) between the two sequences.
// ptnMat_ : pair-pattern counts, indexed [state1][state2].
// returns : the log-likelihood (not negated).
//
// NOTE(review): nRate_ is assumed not to exceed MAX_NUM_RATE, the size of the
// local matrix array -- confirm against the rate model.
double OptPairSeq::cmpLogLiGammaRate(double brLen, PtnMat& ptnMat_) {
	const int nState_ = mymodel.getNState ();
	const double probRate_ = myrate.getProb ();
	const int nRate_ = myrate.getNRate ();

	// queried once: the answer was previously re-fetched inside the innermost loop
	const bool heterogenous_ = myrate.isNsSyHeterogenous ();

	// transition probabilities for every rate class at this branch length
	DMat20 probChangeCube_[MAX_NUM_RATE];
	for (int rateNo_ = 0; rateNo_ < nRate_; rateNo_ ++) {
		if (heterogenous_) {
			mymodel.cmpProbChange (brLen, rateNo_, probChangeCube_[rateNo_]);
		} else {
			double rate_ = myrate.getRate (rateNo_);
			mymodel.cmpProbChange (brLen * rate_, probChangeCube_[rateNo_]);
		}
	}

	DVec20 stateFrqArr_;
	mymodel.getStateFrq (stateFrqArr_);

	double logLi_ = 0.0;
	for (int stateNo1_ = 0; stateNo1_ < nState_; stateNo1_ ++)
		for (int stateNo2_ = 0; stateNo2_ < nState_; stateNo2_ ++) {
			// A state pair that never occurs contributes nothing. Skipping it
			// saves the work and keeps 0 * log(0) = NaN out of the sum when
			// such a pair happens to have zero probability.
			if (ptnMat_[stateNo1_][stateNo2_] == 0)
				continue;

			double liPtn_ = 0.0;

			// average the pair likelihood over the rate classes
			for (int rateNo_ = 0; rateNo_ < nRate_; rateNo_ ++) {
				double probChange_ = probChangeCube_[rateNo_][stateNo1_][stateNo2_];
				if (heterogenous_)
					liPtn_ += mymodel.getClassProb(rateNo_) * stateFrqArr_[stateNo1_] * probChange_;
				else
					liPtn_ += probRate_ * stateFrqArr_[stateNo1_] * probChange_;
			}

			logLi_ += log (liPtn_) * ptnMat_[stateNo1_][stateNo2_];
		}

	return logLi_;
}

// Negative log-likelihood of the sequence pair at distance brLen.
//
// The site-specific-rate likelihood is used when the rate model reports type
// SITE_SPECIFIC and isOptedSpecificRate () equals 1; in every other case the
// pair-pattern likelihood over rate classes is used. The chosen
// log-likelihood is returned with its sign flipped.
double OptPairSeq::cmpNegLogLi(double brLen, Seq* seq1, Seq* seq2, PtnMat &ptnMat) {
	const bool useSiteRates = myrate.getType () == SITE_SPECIFIC
	                          && myrate.isOptedSpecificRate () == 1;

	if (useSiteRates)
		return -cmpLogLiSpecificRate (brLen, seq1, seq2);

	return -cmpLogLiGammaRate (brLen, ptnMat);
}

//-------------------------------------------------------------
// Build the pair-pattern count matrix of two sequences.
//
// The nState x nState part of ptnMat_ is cleared, then every site of seq1_
// (as reported by getSize ()) increments ptnMat_[state1][state2]. A site is
// skipped when either of its states is not below the model's state count.
void OptPairSeq::createPairPtn (Seq* seq1_, Seq* seq2_, PtnMat &ptnMat_) {
	const int stateCount = mymodel.getNState ();

	// reset the counters
	for (int row = 0; row < stateCount; ++row) {
		for (int col = 0; col < stateCount; ++col) {
			ptnMat_[row][col] = 0;
		}
	}

	// tally the state pair observed at each site
	int site = 0;
	while (site < seq1_->getSize ()) {
		const int first = (*seq1_)[site];
		const int second = (*seq2_)[site];
		++site;

		if (first >= stateCount || second >= stateCount)
			continue;

		ptnMat_[first][second] ++;
	}
}

//-------------------------------------------------------------
//return the genetic distance between two sequences
//
// seq1, seq2 : the pair of sequences.
// initGenDis : starting value for the search. A negative value means "none
//              given"; the starting value is then derived from the observed
//              distance p as -(3/4) * log (1 - (4/3) * p), falling back to p
//              itself when the argument of the logarithm drops below ZERO.
// returns    : 0.0 when both sequences report the same id, otherwise the
//              distance found by optOneDim on [MIN_BR_LEN, MAX_BR_LEN].
double OptPairSeq::cmp (Seq &seq1, Seq &seq2, double initGenDis) {
	// A sequence compared with itself has distance 0.0. Decide this first so
	// that no per-site work is done for a result that is already known.
	if (seq1.getId () == seq2.getId () )
		return 0.0;

	double genDis_;
	if (initGenDis < 0.0) {
		// the observed distance is only needed when no starting value was given
		const double observedDis_ = cmpObsDis (seq1, seq2);
		const double logArg_ = 1.0 - (4.0 / 3.0) * observedDis_;

		//the case, two sequences are too far from each other
		if (logArg_ < ZERO)
			genDis_ = observedDis_;
		else
			genDis_ = - (3.0 / 4.0) * log (logArg_);
	} else
		genDis_ = initGenDis;

	PtnMat ptnMat;
	createPairPtn (&seq1, &seq2, ptnMat);

	//put the distance into the range of a possible length
	// NOTE(review): the 1.0 offsets are kept from the original code; whether
	// they are sensible depends on the values of MIN_BR_LEN / MAX_BR_LEN --
	// confirm.
	if (genDis_ <= MIN_BR_LEN)
		genDis_ = MIN_BR_LEN + 1.0;
	if (genDis_ >= MAX_BR_LEN)
		genDis_ = MAX_BR_LEN - 1.0;

	// optimizer outputs that this function does not use; initialized so they
	// never hold indeterminate values
	double fx = 0.0;
	double f2x = 0.0;

	return optOneDim (MIN_BR_LEN, genDis_, MAX_BR_LEN,
	                  EPS_OPT_PAIR_SEQ_BR, &fx, &f2x, &seq1, &seq2, ptnMat);
}

//--------------------------------------------------------------------
// Release hook invoked by the destructor. The body is empty, so nothing is
// freed here.
void OptPairSeq::release () {}

//-------------------------------------------------------------
// Destructor: delegates all cleanup to release ().
OptPairSeq::~OptPairSeq () {
	release ();
	// disabled debug trace, kept for reference:
//	std::cout << "this is destructor of optimal pair sequences class " << endl;
}


/******************************************************************************/
/* minimization of a function by Brent's method (Numerical Recipes)           */
/******************************************************************************/


//-------------------------------------------------------------------------------------------------------
// Upper bound on the number of iterations of the main loop in opt_brent.
#define ITMAX 100
// Golden-section fraction: scales a step into the larger part of the bracket.
#define CGOLD 0.3819660
// GOLD, GLIMIT and TINY are defined here and undefined below, but no code
// between the definitions and the #undef lines uses them.
#define GOLD 1.618034
#define GLIMIT 100.0
#define TINY 1.0e-20
// Small absolute term added to the fractional tolerance (see tol1).
#define ZEPS 1.0e-10
// Shift: a <- b, b <- c, c <- d. Expands to three separate statements, so it
// may only be used where a statement sequence is valid (e.g. inside braces).
#define SHFT(a,b,c,d) (a)=(b);(b)=(c);(c)=(d);
// Magnitude of a with the sign of b (b >= 0.0 counts as positive).
#define SIGN(a,b) ((b) >= 0.0 ? fabs(a) : -fabs(a))

/*
 * Brent's method in one dimension: minimizes cmpNegLogLi over a bracket.
 *
 * ax, cx        : the two ends of the bracket (in either order).
 * bx            : starting point, used as the initial best abscissa.
 * tol           : fractional tolerance on the abscissa.
 * foptx         : out - function value at the returned abscissa.
 * f2optx        : out - second-derivative estimate from the parabola through
 *                 the three points x, w, v kept by the algorithm.
 * fax, fbx, fcx : function values at ax, bx and cx, supplied by the caller.
 * seq1, seq2, ptnMat : passed unchanged to cmpNegLogLi.
 *
 * Returns the best abscissa found. When ITMAX iterations pass without
 * meeting the convergence test, the current best point is returned with the
 * same outputs.
 *
 * Note: the denominator of the f2optx expression is zero whenever two of
 * x, w, v coincide; no guard is applied here.
 */
double OptPairSeq::opt_brent (double ax, double bx, double cx, double tol,
                   double *foptx, double *f2optx, double fax, double fbx, double fcx, Seq* seq1, Seq* seq2, PtnMat &ptnMat) {
	int iter;
	double a,b,d=0,etemp,fu,fv,fw,fx,p,q,r,tol1,tol2,u,v,w,x,xm;
	double xw,wv,vx;
	double e=0.0;

	// a and b are the lower and upper ends of the bracket
	a=(ax < cx ? ax : cx);
	b=(ax > cx ? ax : cx);
	// x is the best point so far
	x=bx;
	fx=fbx;
	// w takes the bracket end with the smaller function value, v the other end
	if (fax < fcx) {
		w=ax;
		fw=fax;
		v=cx;
		fv=fcx;
	} else {
		w=cx;
		fw=fcx;
		v=ax;
		fv=fax;
	}

	for (iter=1;iter<=ITMAX;iter++) {
		xm=0.5*(a+b);
		// tol1: tolerance at the current x; tol2: twice that
		tol2=2.0*(tol1=tol*fabs(x)+ZEPS);
		// convergence test: x is close to the midpoint and the bracket is small
		if (fabs(x-xm) <= (tol2-0.5*(b-a))) {
			*foptx = fx;
			xw = x-w;
			wv = w-v;
			vx = v-x;
			*f2optx = 2.0*(fv*xw + fx*wv + fw*vx)/
			          (v*v*xw + x*x*wv + w*w*vx);
			return x;
		}

		if (fabs(e) > tol1) {
			// trial parabolic fit through x, w and v
			r=(x-w)*(fx-fv);
			q=(x-v)*(fx-fw);
			p=(x-v)*q-(x-w)*r;
			q=2.0*(q-r);
			if (q > 0.0)
				p = -p;
			q=fabs(q);
			// etemp keeps the previous value of e; e takes the last step d
			etemp=e;
			e=d;
			// reject the parabolic step when it is too large relative to etemp
			// or would leave (a, b); take a golden-section step instead
			if (fabs(p) >= fabs(0.5*q*etemp) || p <= q*(a-x) || p >= q*(b-x))
				d=CGOLD*(e=(x >= xm ? a-x : b-x));
			else {
				d=p/q;
				u=x+d;
				// too close to a bracket end: step by tol1 towards the midpoint
				if (u-a < tol2 || b-u < tol2)
					d=SIGN(tol1,xm-x);
			}
		} else {
			// golden-section step into the larger of the two segments
			d=CGOLD*(e=(x >= xm ? a-x : b-x));
		}

		// the trial point is never placed closer than tol1 to x
		u=(fabs(d) >= tol1 ? x+d : x+SIGN(tol1,d));
		fu=cmpNegLogLi(u, seq1, seq2, ptnMat);
		if (fu <= fx) {
			// u is the new best point: the old x becomes a bracket end
			if (u >= x)
				a=x;
			else
				b=x;

			SHFT(v,w,x,u)
			SHFT(fv,fw,fx,fu)
		} else {
			// u is worse than x: u becomes a bracket end
			if (u < x)
				a=u;
			else
				b=u;
			// keep w as the second-best and v as the third-best point
			if (fu <= fw || w == x) {
				v=w;
				w=u;
				fv=fw;
				fw=fu;
			} else
				if (fu <= fv || v == x || v == w) {
					v=u;
					fv=fu;
				}
		}
	}

	// iteration limit reached: report the current best point
	*foptx = fx;
	xw = x-w;
	wv = w-v;
	vx = v-x;
	*f2optx = 2.0*(fv*xw + fx*wv + fw*vx)/(v*v*xw + x*x*wv + w*w*vx);

	return x;
}

// the helper macros are local to opt_brent; remove them again
#undef ITMAX
#undef CGOLD
#undef ZEPS
#undef SHFT
#undef SIGN
#undef GOLD
#undef GLIMIT
#undef TINY


//-------------------------------------------------------------------------------------------------------
/* one-dimensional minimization of cmpNegLogLi on [xmin, xmax]

   xmin, xguess, xmax : lower limit, trial value and upper limit, with
                        xmin < xguess < xmax.
   tol                : fractional tolerance; also the half-width of the first
                        probe around xguess.
   fx                 : out - function value at the returned x.
   f2x                : out - second-derivative estimate at the returned x
                        (0.0 when the three points used for it do not give a
                        usable estimate).
   seq1, seq2, ptnMat : passed unchanged to cmpNegLogLi.

   Returns the optimal x value. fx and f2x are assigned on every return path,
   including the two early returns that do not go through opt_brent.
*/

// Second derivative of the parabola through (x0,f0), (x1,f1), (x2,f2), using
// the same expression as opt_brent. Returns 0.0 when the denominator vanishes
// (two of the abscissas coincide).
static double threePointSecondDeriv (double x0, double f0, double x1, double f1,
                                     double x2, double f2) {
	const double d01 = x0 - x1;
	const double d12 = x1 - x2;
	const double d20 = x2 - x0;
	const double denom = x2*x2*d01 + x0*x0*d12 + x1*x1*d20;
	if (denom == 0.0)
		return 0.0;
	return 2.0*(f2*d01 + f0*d12 + f1*d20) / denom;
}

double OptPairSeq::optOneDim(double xmin, double xguess, double xmax,
                        double tol, double *fx, double *f2x, Seq* seq1, Seq* seq2, PtnMat &ptnMat) {
	double eps, optx, ax, bx, cx, fa, fb, fc;

	// first probe: a narrow interval of half-width tol around the guess
	eps = tol;

	ax = xguess - eps;
	if (ax < xmin)
		ax = xmin;
	bx = xguess;
	cx = xguess + eps;
	if (cx > xmax)
		cx = xmax;

	fa = cmpNegLogLi(ax, seq1, seq2, ptnMat);
	fb = cmpNegLogLi(bx, seq1, seq2, ptnMat);
	fc = cmpNegLogLi(cx, seq1, seq2, ptnMat);

	// the guess is not worse than both of its close neighbours: accept it
	if (fa >= fb && fc >= fb) {
		*fx = fb;
		*f2x = threePointSecondDeriv (bx, fb, ax, fa, cx, fc);
		return bx;
	}

	// otherwise widen the side(s) on which the function decreases
	eps = xguess * 0.5;

	if (fa <= fb) {
		ax = xguess - eps;
		if (ax < xmin)
			ax = xmin;
		fa = cmpNegLogLi (ax, seq1, seq2, ptnMat);
	}

	if (fc <= fb) {
		cx = xguess + eps;
		if (cx > xmax)
			cx = xmax;
		fc = cmpNegLogLi (cx, seq1, seq2, ptnMat);
	}

	/* if it works use these borders else be conservative */
	if ((fa < fb) || (fc < fb)) {
		// the widened interval still does not bracket a minimum: fall back to
		// the full range [xmin, xmax]
		if (ax != xmin) {
			ax = xmin;
			fa = cmpNegLogLi(ax, seq1, seq2, ptnMat);
		}

		// ax equals xmin from here on; when the function rises over the first
		// two small steps away from it, the lower limit is taken as the minimum
		const double x1 = ax + 0.5 * tol;
		const double fa1 = cmpNegLogLi(x1, seq1, seq2, ptnMat);
		if (fa1 > fa) {
			const double x2 = ax + tol;
			const double fa2 = cmpNegLogLi(x2, seq1, seq2, ptnMat);
			if (fa2 > fa1) {
				*fx = fa;
				*f2x = threePointSecondDeriv (ax, fa, x1, fa1, x2, fa2);
				return xmin;
			}
		}

		if (cx != xmax)
			fc = cmpNegLogLi(xmax, seq1, seq2, ptnMat);
		optx = opt_brent(xmin, xguess, xmax, tol, fx, f2x, fa, fb, fc, seq1, seq2, ptnMat);
	} else {
		optx = opt_brent(ax, bx, cx, tol, fx, f2x, fa, fb, fc, seq1, seq2, ptnMat);
	}

	return optx; /* return optimal x */
}
