// -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
// vi: set ts=2 noet:
//  CVS information:
//  $Revision: 14275 $
//  $Date: 2007-04-16 16:52:48 -0700 (Mon, 16 Apr 2007) $
//  $Author: johnk $

//Utility Headers
#include <utility/basic_sys_util.hh>

//////////////////////////////////////////////


// Rosetta Headers
#include "pose_constraints.h"
#include "aaproperties_pack.h" // nheavyatoms
#include "current_pose.h"
#include "jumping_util.h" // build_ideal_*
#include "kin_util.h" // chainbreak_phi_atom
#include "kin_coords.h"
//#include "misc.h" // pose_flag
#include "pose.h"
#include "param_aa.h"
#include "refold.h" // Charlie's GL stuff
#include "refold_ns.h"
#include "util_basic.h"
#include "util_vector.h"

// ObjexxFCL Headers
#include <ObjexxFCL/FArray1D.hh>
#include <ObjexxFCL/FArray3D.hh>

// C++ Headers
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <map>

///////////////////////////////////////////////////////////////////////////////
bool
pose_constraints_exist(
	pose_ns::Pose const & pose,
	int const pos1,
	int const pos2
)
{
	// A pose without a constraint set cannot have a residue-pair constraint.
	if ( !pose.constraints_exist() ) return false;

	// Otherwise defer to the constraint set itself.
	return pose.constraints().residue_pair_constraint_exists( pos1, pos2 );
}


///////////////////////////////////////////////////////////////////////////////
float
pose_get_res_res_cstE(
	pose_ns::Pose const & pose,
	int const res1,
	int const res2,
	int const aa1,
	int const aav1,
	int const aa2,
	int const aav2,
	FArray2Da_float coord1,
	FArray2Da_float coord2,
	bool const bb1,
	bool const sc1,
	bool const bb2,
	bool const sc2
)
{
	// Residue-residue constraint energy between res1 and res2, delegated to the
	// pose's constraint set; zero when the pose carries no constraints.
	float energy( 0.0 );
	if ( pose.constraints_exist() ) {
		energy = pose.constraints().get_res_res_cstE( res1, res2, aa1, aav1,
			aa2, aav2, coord1, coord2, bb1, sc1, bb2, sc2 );
	}
	return energy;
}


////////////////////////////////////////////////////////////////////////
// this function is for calculating derivatives of atom pair constraints
// it now also includes coordinate tethers
//
// this gets called from jmp_get_atompairE_deriv()
//
// For the single atom (atomno, seqpos), adds to F1/F2 the derivative
// contributions of the constraint types held in the pose's Cst_set:
//   - 3D angle and torsion constraints  (scaled by kin_3D_cst_weight)
//   - atom-pair distance constraints    (scaled by atompair_cst_weight)
//   - coordinate tethers                (scaled by coord_cst_weight)
//   - chainbreak overlap constraints    (scaled by chainbreak_cst_weight)
// A term is skipped when its weight is exactly zero. Returns immediately,
// leaving F1/F2 untouched, when the pose has no constraints.
//
// F1, F2: never reset here. The chainbreak section adds to them with +=.
//   The Cst_set *_deriv helpers receive them directly; presumably those also
//   accumulate (TODO confirm).
//
// NOTE(review): calc_pose_constraint_scores() evaluates a term only when
//   |weight| > 0.0001, but this function only skips weights that are exactly
//   zero. For tiny nonzero weights the score and its derivative disagree.
// NOTE(review): uses function-level static scratch storage (overlap_xyz, xyz),
//   so it is not reentrant and not safe to call concurrently.
//

void fill_atom_cst_F1_F2(
	pose_ns::Pose const & pose,
	float const kin_3D_cst_weight,
	float const atompair_cst_weight,
	float const coord_cst_weight,
	float const chainbreak_cst_weight,
	int const seqpos,
	int const atomno,
	numeric::xyzVector_float & F1,
	numeric::xyzVector_float & F2
)
{
	using namespace cst_set_ns;


	if ( !pose.constraints_exist() ) return;

	Cst_set const & pose_constraints( pose.constraints() );
	FArray1D_int const & jmp_domain_map ( pose.get_domain_map() );//

	kin::Atom_id atom( atomno, seqpos );
	kin::Coords_FArray_const coords( pose.full_coord() );

	assert( atomno <= param::MAX_ATOM()() );
	assert( seqpos <= pose.total_residue() );

	// 3D angle and torsion constraints involving this atom
	if ( kin_3D_cst_weight ) {
		pose_constraints.atom_angle_deriv( atom, coords, kin_3D_cst_weight,
																					 F1, F2 );

		pose_constraints.atom_torsion_deriv( atom, coords, kin_3D_cst_weight,
																						 F1, F2 );
	}

	// atom-pair distance constraints; the domain map is forwarded to the
	// Cst_set helper (its use there is not visible in this file)
	if ( atompair_cst_weight != 0.0 ) {
		pose_constraints.atompair_cst_atom_deriv( atom, coords, jmp_domain_map,
																							atompair_cst_weight, F1, F2 );
	}

	// coordinate tethers on this atom
	if ( coord_cst_weight != 0 ) {
		pose_constraints.coord_cst_atom_deriv( atom, coords, coord_cst_weight,
																					 F1, F2 );
	}

	/////////////////////////////////////////////////////////////////////////////
	// chainbreak derivative -- could move this to a separate function...
	//
	// For cutpoint c, the chainbreak score in calc_pose_constraint_scores is
	//   chainbreak_weight(c) * sum over six atoms of |pseudo - real|^2
	// where the pseudo atoms come from build_overlap_coords (omega fixed at
	// 180.0). This atom can take part in two chainbreaks: cutpoint == seqpos
	// and cutpoint == seqpos-1. Each contribution uses the F1/F2 form
	//   F1 += factor * cross(moving, fixed),  F2 += factor * (moving - fixed)
	// with factor = 2 * weight (dE/dr divided by r for a squared distance).
	if ( chainbreak_cst_weight != 0 ) {
		// scratch storage, static so it persists (and is shared) across calls
		static FArray2D_float overlap_xyz(3,6);
		static numeric::xyzVector_float xyz;

		// case 1: this residue is the cutpoint (the residue before the break).
		// NOTE(review): build_overlap_coords reads residue cutpoint+1; this
		//   presumably relies on is_chainbreak never being true for the last residue.
		if ( pose_constraints.is_chainbreak( seqpos ) ) {
			int const cutpoint( seqpos );
			// dE/dr divided by r:
			float const factor( 2.0 * chainbreak_cst_weight *
													pose_constraints.chainbreak_weight( cutpoint ) );
			if ( atomno == 1 || atomno == 2 || atomno == 3 ) {
				// N,CA,C
				// the real atom is compared with pseudo atom overlap_xyz(:,atomno)
				coords.get_xyz( atom, xyz );
				build_overlap_coords( cutpoint, pose.psi(cutpoint), 180.0,
					pose.phi( cutpoint+1 ), pose.full_coord(), overlap_xyz );
				numeric::xyzVector_float
					pseudo_xyz ( &(overlap_xyz( 1, atomno ) ) ),
					f1         ( cross( xyz, pseudo_xyz ) ),
					f2         ( xyz - pseudo_xyz );
				F1 += factor * f1;
				F2 += factor * f2;

			} else if ( atomno == 4 ) {
				// O -- when O is placed the overlap pseudo atoms are placed also
				// so sum the contributions from the pseudo atoms
				// (pseudo atoms 4-6 are compared with N,CA,C of cutpoint+1)
				build_overlap_coords( cutpoint, pose.psi(cutpoint), 180.0,
					pose.phi( cutpoint+1 ), pose.full_coord(), overlap_xyz );
				for ( int i=1; i<= 3; ++i ) {
					numeric::xyzVector_float
						xyz        ( &(pose.full_coord()(1,i,cutpoint+1))),
						pseudo_xyz ( &(overlap_xyz(1,3+i))),
						f1         ( cross( pseudo_xyz, xyz ) ),
						f2         ( pseudo_xyz - xyz );
					F1 += factor * f1;
					F2 += factor * f2;
				}
			}
		}
		// case 2: the residue before this one is the cutpoint.
		// NOTE(review): when seqpos == 1 this queries is_chainbreak( 0 );
		//   confirm that Cst_set accepts that index.
		if ( pose_constraints.is_chainbreak( seqpos-1 ) ) {
			int const cutpoint( seqpos-1 );
			// dE/dr divided by r:
			float const factor( 2.0 * chainbreak_cst_weight *
													pose_constraints.chainbreak_weight( cutpoint ) );
			// atom whose placement carries the pseudo atoms 1-3 on this side
			int const phi_atomno( chainbreak_phi_atomno( pose.res( cutpoint+1 ),
															pose.res_variant( cutpoint+1 ) ) );
			if ( atomno == 1 || atomno == 2 || atomno == 3 ) {
				// N,CA,C
				// the real atom is compared with pseudo atom overlap_xyz(:,3+atomno)
				coords.get_xyz( atom, xyz );
				build_overlap_coords( cutpoint, pose.psi(cutpoint), 180.0,
					pose.phi( cutpoint+1 ), pose.full_coord(), overlap_xyz );

				numeric::xyzVector_float
					pseudo_xyz ( &(overlap_xyz( 1, 3+atomno ) ) ),
					f1         ( cross( xyz, pseudo_xyz ) ),
					f2         ( xyz - pseudo_xyz );
				F1 += factor * f1;
				F2 += factor * f2;

			} else if ( atomno == phi_atomno ) {
				// HN or pro-CD, when this is placed the overlap pseudo atoms are
				// placed also so sum the contributions from the pseudo atoms
				// (pseudo atoms 1-3 are compared with N,CA,C of cutpoint)
				build_overlap_coords( cutpoint, pose.psi(cutpoint), 180.0,
					pose.phi( cutpoint+1 ), pose.full_coord(), overlap_xyz );

				for ( int i=1; i<= 3; ++i ) {
					numeric::xyzVector_float
						xyz        ( &(pose.full_coord()(1,i,cutpoint))),
						pseudo_xyz ( &(overlap_xyz(1,i))),
						f1         ( cross( pseudo_xyz, xyz ) ),
						f2         ( pseudo_xyz - xyz );
					F1 += factor * f1;
					F2 += factor * f2;
				}
			}
		}
	}
}


///////////////////////////////////////////////////////////////////////////////
// Builds the six "overlap" pseudo-atom positions used to score a chain break
// between residues `cutpoint` and `cutpoint+1`.
//
// overlap_xyz(1..3, 1..3): N, CA, C of `cutpoint`, rigidly moved by the
//   transform that maps (C of cutpoint, n_xyz, ca_xyz) onto
//   (c_xyz, N of cutpoint+1, CA of cutpoint+1). Built when dir == 0 or -1.
// overlap_xyz(1..3, 4..6): N, CA, C of `cutpoint+1`, moved by the transform
//   computed from the reverse correspondence. Built when dir == 0 or +1.
//
// Here, judging by the helper names, c_xyz is an ideal C placed before
// cutpoint+1 using `phi`, and n_xyz / ca_xyz are an ideal N / CA placed after
// cutpoint using `psi` and `omega`. For a geometrically consistent input the
// following should hold: overlap_xyz(:,3) ~= c_xyz, overlap_xyz(:,4) ~= n_xyz,
// and the direction (overlap 4 -> overlap 5) ~= (n_xyz -> ca_xyz).
//
// Epos_index: when true, C is read from atom slot 4 instead of 3 (presumably
//   an Epos-style layout with CB in slot 3 -- confirm against callers).
// dir: 0 builds both halves, -1 (c2n) only 1-3, +1 (n2c) only 4-6.
void
build_overlap_coords(
	int const cutpoint,
	float const psi, // cutpoint
	float const omega, // cutpoint
	float const phi, // cutpoint + 1
	FArray3DB_float const & coords,
	FArray2D_float & overlap_xyz,
	bool const Epos_index, // = false
	int const dir // = 0
)
{
	int const N_index(1), CA_index(2), C_index( Epos_index ? 4 : 3 );
	int const n2c(1), c2n(-1);

	// Any other dir value would silently leave overlap_xyz unwritten.
	assert( dir == 0 || dir == n2c || dir == c2n );

	FArray1D_float c_xyz(3), n_xyz(3), ca_xyz(3);

	// ideal C preceding cutpoint+1, from its N/CA/C and phi
	build_ideal_C_coords( phi, coords(1,N_index,cutpoint+1),
		coords(1,CA_index,cutpoint+1), coords(1,C_index,cutpoint+1), c_xyz );

	// ideal N and CA following cutpoint, from its N/CA/C, psi and omega
	build_ideal_N_CA_coords( psi, omega,
		coords(1,N_index,cutpoint), coords(1,CA_index,cutpoint),
		coords(1,C_index,cutpoint), n_xyz, ca_xyz );

	FArray2D_float Mgl( 4, 4 );
	if ( !dir || dir == c2n ) {
		// build overlap_xyz 1-3
		// by mapping the existing coords for n,ca,c of cutpoint
		// by a transformation that maps c,n_xyz,ca_xyz to c_xyz,n,ca
		get_GL_matrix( ca_xyz,
									 n_xyz,
									 coords(1,  C_index, cutpoint),
									 coords(1, CA_index, cutpoint+1),
									 coords(1,  N_index, cutpoint+1),
									 c_xyz,
									 Mgl);
		GL_rot( Mgl, coords(1,  N_index, cutpoint ), overlap_xyz(1,1) );
		GL_rot( Mgl, coords(1, CA_index, cutpoint ), overlap_xyz(1,2) );
		GL_rot( Mgl, coords(1,  C_index, cutpoint ), overlap_xyz(1,3) );
	}

	if ( !dir || dir == n2c ) {
		// build overlap_xyz 4-6 -- could just invert the transform??
		get_GL_matrix( coords(1, CA_index, cutpoint+1),
									 coords(1,  N_index, cutpoint+1),
									 c_xyz,
									 ca_xyz,
									 n_xyz,
									 coords(1,  C_index, cutpoint),
									 Mgl);
		GL_rot( Mgl, coords(1,  N_index, cutpoint+1 ), overlap_xyz(1,4) );
		GL_rot( Mgl, coords(1, CA_index, cutpoint+1 ), overlap_xyz(1,5) );
		GL_rot( Mgl, coords(1,  C_index, cutpoint+1 ), overlap_xyz(1,6) );
	}
}


///////////////////////////////////////////////////////////////////////////////


// Fills cst_set with constraints that pin the pose to its current geometry:
//   1) an atom-pair constraint at the current distance for every atom pair of
//      every residue pair (i < j) whose CA atoms lie within the cutoff;
//   2) a torsion constraint at the current value for every backbone torsion
//      and, in fullatom mode, every chi angle;
//   3) in fullatom mode only, a coordinate tether on every atom.
// just_use_bb_heavy_atoms restricts (1) to atoms 1-4 of each residue. It is
// the only mode allowed for a non-fullatom pose; any other combination exits.
void
fill_cst_set(
	cst_set_ns::Cst_set & cst_set,
	pose_ns::Pose const & pose,
	bool const & just_use_bb_heavy_atoms
)
{
	using namespace cst_set_ns;

	// Residue pairs are constrained only when their CA atoms (atom 2) are
	// closer than sqrt( dist2_threshold ) = 12.0 coordinate units.
	int const cutoff_atm( 2 ); // c-alpha
	float const dist2_threshold( 144.0 );
	// Atom count per residue in backbone-only mode (slots 1-4: N, CA, C, O).
	int const n_bb_heavy_atoms( 4 );

	bool const fullatom( pose.fullatom() );
	if ( !fullatom && !just_use_bb_heavy_atoms ) {
		std::cout << "just for fullatom!!" << std::endl;
		utility::exit( EXIT_FAILURE, __FILE__, __LINE__);
	}

	int const nres( pose.total_residue() );
	FArray1D_int const & res( pose.res() );
	FArray1D_int const & res_variant( pose.res_variant() );
	FArray3D_float const & full_coord( pose.full_coord() );

	// 1) atom-pair distance constraints at the current distances
	int count(0);
	for ( int i=1; i<= nres; ++i ) {
		int const natm1( just_use_bb_heavy_atoms ? n_bb_heavy_atoms :
			aaproperties_pack::nheavyatoms( res(i), res_variant(i) ) );
		for ( int j=i+1; j<= nres; ++j ) {
			if ( vec_dist2( full_coord( 1, cutoff_atm, i ),
											full_coord( 1, cutoff_atm, j ) ) < dist2_threshold ) {
				// atom count for j is only needed for pairs that pass the CA filter
				int const natm2( just_use_bb_heavy_atoms ? n_bb_heavy_atoms :
					aaproperties_pack::nheavyatoms( res(j), res_variant(j) ) );
				for ( int ii=1; ii<= natm1; ++ii ) {
					kin::Atom_id atom1(ii,i);
					for ( int jj=1; jj<= natm2; ++jj ) {
						kin::Atom_id atom2(jj,j);
						float const r ( std::sqrt( vec_dist2( full_coord( 1, ii, i ),
							full_coord( 1, jj, j ) ) ) );
						Cst cst(r);
						cst_set.add_atompair_constraint( atom1, atom2, cst );
						++count;
					}
				}
			}
		}
	}
	std::cout << "total atompair constraints: " << count << std::endl;

	// 2) torsion constraints: all backbone torsions, plus chi angles if fullatom
	for ( int i=1; i<= nres; ++i ) {
		int const aa(res(i));
		int const aav(res_variant(i));
		int const nchi( fullatom ? aaproperties_pack::nchi(aa,aav) : 0 );
		for ( int tor=1; tor<= param_torsion::total_bb_torsion + nchi; ++tor ) {
			cst_set.add_rosetta_torsion_constraint( i, tor,
				pose.get_torsion_by_number( i, tor ) );
		}
	}

	// 3) coordinate tether on every atom (fullatom only)
	if ( fullatom ) {
		for ( int i=1; i<= nres; ++i ) {
			int const aa(res(i));
			int const aav(res_variant(i));
			int const natoms( aaproperties_pack::natoms(aa,aav) );
			for ( int j=1; j<= natoms; ++j ) {
				kin::Atom_id  atom(j,i);
				cst_set.add_coordinate_constraint( atom,
					numeric::xyzVector_float( &full_coord(1,j,i) ) );
			}
		}
	}
}



///////////////////////////////////////////////////////////////////////////////
// Computes the weighted sum of all constraint scores for `pose`. It also
// stores each unweighted term on the pose via set_0D_score. A term is
// evaluated, and its 0D score set, only when the magnitude of its weight
// exceeds min_weight. Returns 0 when the pose has no constraints.
float
calc_pose_constraint_scores(
	pose_ns::Pose & pose, // non-const for setting scores
	float const phipsi_cst_weight,
	float const omega_cst_weight,
	float const chi_cst_weight,
	float const kin_1D_cst_weight,
	float const kin_3D_cst_weight,
	float const atompair_cst_weight,
	float const coord_cst_weight,
	float const chainbreak_cst_weight
)
{
	using namespace pose_ns;

	if ( !pose.constraints_exist() ) return 0.0;

	// Weights at or below this magnitude switch their term off. It is a double
	// on purpose: each float weight is promoted and compared in double.
	double const min_weight( 0.0001 );

	cst_set_ns::Cst_set const & csts( pose.constraints() );
	float total( 0.0 );

	// degree-of-freedom (1D torsion) constraints: a single call scores all four
	bool const any_1D_term(
		std::abs( phipsi_cst_weight ) > min_weight ||
		std::abs( omega_cst_weight ) > min_weight ||
		std::abs( chi_cst_weight ) > min_weight ||
		std::abs( kin_1D_cst_weight ) > min_weight );
	if ( any_1D_term ) {
		float phipsi_score, omega_score, chi_score, kin_1D_score;
		csts.torsion1D_score( pose, phipsi_score, omega_score,
			chi_score, kin_1D_score );
		pose.set_0D_score( PHIPSI_CST, phipsi_score );
		pose.set_0D_score( OMEGA_CST, omega_score );
		pose.set_0D_score( CHI_CST, chi_score );
		pose.set_0D_score( KIN_1D_CST, kin_1D_score );
		total += phipsi_cst_weight * phipsi_score;
		total += omega_cst_weight * omega_score;
		total += chi_cst_weight * chi_score;
		total += kin_1D_cst_weight * kin_1D_score;
	}

	// atom-atom distance constraints
	if ( std::abs( atompair_cst_weight ) > min_weight ) {
		kin::Coords_FArray_const atom_coords( pose.full_coord() );
		float const atompair_score( csts.atompair_cst_score( atom_coords ) );
		pose.set_0D_score( ATOMPAIR_CST, atompair_score );
		total += atompair_cst_weight * atompair_score;
	}

	// coordinate tethers
	if ( std::abs( coord_cst_weight ) > min_weight ) {
		kin::Coords_FArray_const atom_coords( pose.full_coord() );
		float const tether_score( csts.coord_cst_score( atom_coords ) );
		pose.set_0D_score( COORD_CST, tether_score );
		total += coord_cst_weight * tether_score;
	}

	// 3D angle + torsion constraints on arbitrary atoms, stored as one term
	if ( std::abs( kin_3D_cst_weight ) > min_weight ) {
		kin::Coords_FArray_const atom_coords( pose.full_coord() );
		float const angle_score( csts.atom_angle_score( atom_coords ) );
		float const torsion_score( csts.atom_torsion_score( atom_coords ) );
		pose.set_0D_score( KIN_3D_CST, angle_score + torsion_score );
		total += kin_3D_cst_weight * ( angle_score + torsion_score );
	}

	// chainbreaks: squared deviation of the six overlap pseudo atoms from the
	// real N,CA,C of the cutpoint and of the following residue
	if ( std::abs( chainbreak_cst_weight ) > min_weight ) {
		int const last_cut( pose.total_residue() - 1 );
		FArray3D_float const & actual( pose.full_coord() );
		FArray2D_float pseudo( 3, 6 );
		float break_score( 0.0 );
		for ( int cut = 1; cut <= last_cut; ++cut ) {
			if ( !csts.is_chainbreak( cut ) ) continue;

			build_overlap_coords( cut, pose.psi( cut ), 180.0, pose.phi( cut+1 ),
				pose.full_coord(), pseudo );
			float dev( 0.0 );
			for ( int k = 1; k <= 3; ++k ) {
				dev +=
					square( pseudo(k,1) - actual(k,1,cut  ) ) +
					square( pseudo(k,2) - actual(k,2,cut  ) ) +
					square( pseudo(k,3) - actual(k,3,cut  ) ) +
					square( pseudo(k,4) - actual(k,1,cut+1) ) +
					square( pseudo(k,5) - actual(k,2,cut+1) ) +
					square( pseudo(k,6) - actual(k,3,cut+1) );
			}
			break_score += dev * csts.chainbreak_weight( cut );
		}
		pose.set_0D_score( CHAINBREAK_CST, break_score );
		total += chainbreak_cst_weight * break_score;
	}

	return total;
}


