// -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
// vi: set ts=2 noet:
//
// (c) Copyright Rosetta Commons Member Institutions.
// (c) This file is part of the Rosetta software suite and is made available under license.
// (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
// (c) For more information, see http://www.rosettacommons.org. Questions about this can be
// (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.

/// @file   core/scoring/custom_pair_distance/FullatomCustomPairDistanceEnergy.cc
/// @brief
/// @author David E Kim


// Unit headers
#include <core/scoring/custom_pair_distance/FullatomCustomPairDistanceEnergy.hh>
#include <core/scoring/custom_pair_distance/FullatomCustomPairDistanceEnergyCreator.hh>

#include <core/options/option.hh>
#include <core/options/keys/score.OptionKeys.gen.hh>

// Package headers
#include <core/scoring/ScoreFunction.hh>
#include <core/scoring/Energies.hh>

// Project headers
#include <core/pose/Pose.hh>
#include <core/conformation/Residue.hh>
#include <core/pose/datacache/CacheableDataType.hh>
#include <core/io/database/open.hh>
#include <core/chemical/ChemicalManager.hh>
#include <core/chemical/ResidueTypeSet.hh>
#include <core/util/Tracer.hh>

// Utility headers

// Numeric headers

// option key includes





// C++
#include <cmath>
#include <list>
#include <sstream>
#include <string>

namespace core {
namespace scoring {
namespace custom_pair_distance {


/// @brief Factory hook used by the energy-method registry.
/// @details Hands back a brand-new FullatomCustomPairDistanceEnergy on every call (never a
/// shared or previously used instance). The EnergyMethodOptions argument is not consulted.
methods::EnergyMethodOP
FullatomCustomPairDistanceEnergyCreator::create_energy_method(
	methods::EnergyMethodOptions const &
) const {
	methods::EnergyMethodOP fresh_method( new FullatomCustomPairDistanceEnergy );
	return fresh_method;
}

/// @brief Lists the score types computed by this energy method: only fa_cust_pair_dist.
ScoreTypes
FullatomCustomPairDistanceEnergyCreator::score_types_for_method() const {
	ScoreTypes provided_types;
	provided_types.push_back( fa_cust_pair_dist );
	return provided_types;
}


/// @brief File-local tracer; set_pair_and_func_map() writes its Debug-level messages here
/// (which file is read, residue type sets, pairs and score functions encountered).
static util::Tracer tr("core.scoring.custom_pair_distance.FullatomCustomPairDistanceEnergy");

/// @brief Default constructor: loads the custom pair-distance definitions via
/// set_pair_and_func_map().
/// @details max_dis_ must start at zero. set_pair_and_func_map() only ever raises it
/// (it compares each histogram maximum against the current max_dis_). Left uninitialized,
/// that comparison would read an indeterminate value.
FullatomCustomPairDistanceEnergy::FullatomCustomPairDistanceEnergy() :
	parent( new FullatomCustomPairDistanceEnergyCreator ),
	max_dis_( 0.0 )
{
	set_pair_and_func_map();
}

/// @brief Copy constructor (used by clone()).
/// @details A user-defined copy constructor default-constructs every member that is not
/// named in its initializer list. Both the parsed pair/function table and the cutoff must
/// therefore be copied explicitly. Otherwise a copy would carry an empty table (and score
/// nothing) plus an indeterminate cutoff. The DistanceFunc objects referenced from the
/// table are shared between copies, not duplicated.
FullatomCustomPairDistanceEnergy::FullatomCustomPairDistanceEnergy( FullatomCustomPairDistanceEnergy const & src ):
	parent( src ),
	pair_and_func_map_( src.pair_and_func_map_ ),
	max_dis_( src.max_dis_ )
{}

/// @brief Polymorphic copy: returns a new FullatomCustomPairDistanceEnergy built with the
/// copy constructor from *this.
methods::EnergyMethodOP
FullatomCustomPairDistanceEnergy::clone() const
{
	return methods::EnergyMethodOP( new FullatomCustomPairDistanceEnergy( *this ) );
}


/// @brief Adds the fa_cust_pair_dist two-body energy between rsd1 and rsd2 to emap.
/// @details The (rsd1 type, rsd2 type) pair is looked up in pair_and_func_map_. If no entry
/// exists, emap is left untouched. Otherwise each registered atom pair contributes its
/// DistanceFunc evaluated at the *squared* distance between the rsd1 atom
/// (resA_atom_index_) and the rsd2 atom (resB_atom_index_). The summed value is added to
/// emap[ fa_cust_pair_dist ].
/// NOTE(review): when both residues share a type, only one atom orientation is stored
/// (set_pair_and_func_map mirrors pairs only for differing types), so the score is not
/// symmetric in rsd1/rsd2 for such pairs — confirm intended.
void
FullatomCustomPairDistanceEnergy::residue_pair_energy(
	conformation::Residue const & rsd1,
	conformation::Residue const & rsd2,
	pose::Pose const & /*pose*/,
	ScoreFunction const &,
	TwoBodyEnergyMap & emap
) const
{
	ResTypePair type_pair;
	type_pair[1] = rsd1.type();
	type_pair[2] = rsd2.type();

	PairFuncMap::const_iterator const found = pair_and_func_map_.find( type_pair );
	if ( found == pair_and_func_map_.end() ) return; // no custom pairs for this type combination

	std::list<atoms_and_func_struct> const & atom_pairs( found->second );
	Energy pair_total = 0.0;
	for ( std::list<atoms_and_func_struct>::const_iterator it = atom_pairs.begin(); it != atom_pairs.end(); ++it ) {
		atoms_and_func_struct const & entry( *it );
		Vector const & xyz_in_rsd1( rsd1.xyz( entry.resA_atom_index_ ) );
		Vector const & xyz_in_rsd2( rsd2.xyz( entry.resB_atom_index_ ) );
		pair_total += entry.func_->func( xyz_in_rsd1.distance_squared( xyz_in_rsd2 ) );
	}
	emap[ fa_cust_pair_dist ] += pair_total;
}

/// @brief Atom-atom distance beyond which this term contributes nothing.
/// @details The DistanceFunc histograms are indexed by *squared* distance:
/// residue_pair_energy() evaluates func() at distance_squared, and eval_atom_derivative()
/// compares distance_squared against min_dis()/max_dis(). max_dis_ therefore holds the
/// largest squared distance covered by any histogram. Its square root is the cutoff
/// expressed as a distance, which is what this interface returns.
Distance
FullatomCustomPairDistanceEnergy::atomic_interaction_cutoff() const
{
	return std::sqrt( max_dis_ );
}

/// @brief Builds the per-atom partner table that eval_atom_derivative() reads.
/// @details For every registered residue-type pair (type A, type B) and every ordered pair
/// of pose residues (i, j) with those types, the atom (i, resA_atom_index_) receives a
/// partner entry (j, resB_atom_index_, func_). The same atom paired with itself (i == j and
/// identical atom indices) is skipped. The table is stored in the pose's data cache under
/// CUSTOM_PAIR_DIST_SCORE_INFO and is created on first use.
/// NOTE(review): intra-residue pairs (i == j, different atoms) are cached here for
/// derivatives; residue_pair_energy() only covers two distinct residues — confirm that
/// intra-residue scoring is handled elsewhere.
void
FullatomCustomPairDistanceEnergy::setup_for_minimizing(
	pose::Pose & pose,
	ScoreFunction const & /*sfxn*/,
	optimization::MinimizerMap const & /*min_map*/
) const
{
	using namespace util::datacache;

	if ( !pose.data().has( pose::datacache::CacheableDataType::CUSTOM_PAIR_DIST_SCORE_INFO ) ) {
		pose.data().set( pose::datacache::CacheableDataType::CUSTOM_PAIR_DIST_SCORE_INFO, new CacheableAtomPairFuncMap );
	}
	CacheableDataOP dat( pose.data().get_ptr( pose::datacache::CacheableDataType::CUSTOM_PAIR_DIST_SCORE_INFO ) );
	CacheableAtomPairFuncMap * cachemap = static_cast< CacheableAtomPairFuncMap * >( dat() );
	ResAtomIndexFuncMap & resatomfuncmap( cachemap->map() );

	// The cache lives on the pose and survives between minimizations. Rebuild it from
	// scratch: appending to the existing table would duplicate every partner entry on each
	// call (multiplying the derivatives) and keep entries for residues that have changed.
	resatomfuncmap.clear();

	Size const nres( pose.total_residue() );

	// Identify which residue and atom pairs to evaluate in the pose by looping
	// through the pair_and_func_map keys (residue-type pairs to evaluate).
	for ( PairFuncMap::const_iterator respairiter = pair_and_func_map_.begin(),
			respairiter_end = pair_and_func_map_.end();
			respairiter != respairiter_end; ++respairiter ) {
		ResTypePair respair = respairiter->first.data();
		std::list<atoms_and_func_struct> const & atom_pairs( respairiter->second );
		for ( Size i = 1; i <= nres; ++i ) {
			if ( chemical::ResidueTypeCAP( & ( pose.residue_type( i ) ) ) == respair[1] ) {
				for ( Size j = 1; j <= nres; ++j ) {
					if ( chemical::ResidueTypeCAP( & ( pose.residue_type( j ) ) ) == respair[2] ) {
						for ( std::list<atoms_and_func_struct>::const_iterator iter_a = atom_pairs.begin(),
								iter_end_a = atom_pairs.end(); iter_a != iter_end_a; ++iter_a ) {
							if ( i == j && iter_a->resA_atom_index_ == iter_a->resB_atom_index_ ) continue; // skip same atom
							ResAtomIndex rai_a;
							rai_a[1] = i;
							rai_a[2] = iter_a->resA_atom_index_;
							resatom_and_func_struct resatom_b;
							resatom_b.res_index_ = j;
							resatom_b.atom_index_ = iter_a->resB_atom_index_;
							resatom_b.func_ = iter_a->func_;
							resatomfuncmap[ rai_a ].push_back( resatom_b );
						}
					}
				}
			}
		}
	}
}


/// @brief Accumulates the fa_cust_pair_dist derivative for atom `id` into F1/F2.
/// @details Partners of `id` come from the table built by setup_for_minimizing(). Each
/// partner whose squared distance lies within its histogram's [min_dis, max_dis] range
/// contributes. The energy is a function of the squared distance, E = g(d^2), so
/// (dE/dd)/d = 2 g'(d^2), where g' is DistanceFunc::dfunc. That factor, times the term
/// weight (emap[ fa_cust_pair_dist ]), scales f1 = a x b and f2 = a - b, matching how the
/// original code scaled these vectors by a "derivative over distance" factor.
void
FullatomCustomPairDistanceEnergy::eval_atom_derivative(
	id::AtomID const & id,
	pose::Pose const & pose,
	kinematics::DomainMap const &, //domain_map,
	ScoreFunction const &, // sfxn,
	EnergyMap const & emap,
	Vector & F1,
	Vector & F2
) const
{
	using namespace util::datacache;

	// The partner table only exists after setup_for_minimizing(); fail with a clear message
	// instead of dereferencing a missing cache entry.
	if ( !pose.data().has( pose::datacache::CacheableDataType::CUSTOM_PAIR_DIST_SCORE_INFO ) ) {
		utility_exit_with_message( "FullatomCustomPairDistanceEnergy::eval_atom_derivative called before setup_for_minimizing" );
	}
	CacheableDataCOP dat( pose.data().get_const_ptr( pose::datacache::CacheableDataType::CUSTOM_PAIR_DIST_SCORE_INFO ) );
	CacheableAtomPairFuncMap const * cachemap = static_cast< CacheableAtomPairFuncMap const * >( dat() );
	ResAtomIndexFuncMap const & resatomfuncmap( cachemap->map() );

	ResAtomIndex resatom;
	resatom[1] = id.rsd();
	resatom[2] = id.atomno();
	ResAtomIndexFuncMap::const_iterator resatomiter = resatomfuncmap.find( resatom );
	if ( resatomiter == resatomfuncmap.end() ) return;

	// Loop-invariant: the position of the atom being differentiated and the term weight.
	Vector const & atom_a_xyz( pose.residue( id.rsd() ).atom( id.atomno() ).xyz() );
	Real const weight( emap[ fa_cust_pair_dist ] );

	for ( std::list<resatom_and_func_struct>::const_iterator iter_a = resatomiter->second.begin(),
			iter_end_a = resatomiter->second.end(); iter_a != iter_end_a; ++iter_a ) {
		Vector const & atom_b_xyz( pose.residue( iter_a->res_index_ ).atom( iter_a->atom_index_ ).xyz() );
		Real const dist_sq( atom_a_xyz.distance_squared( atom_b_xyz ) );
		if ( dist_sq < iter_a->func_->min_dis() || dist_sq > iter_a->func_->max_dis() ) continue;
		// Chain rule for E = g(d^2): dE/dd = 2 d g'(d^2)  =>  (dE/dd)/d = 2 g'(d^2).
		// No division by d is needed, so coincident atoms are not a special case.
		Real const dE_dd_over_d( 2.0 * iter_a->func_->dfunc( dist_sq ) );
		Vector const f1( atom_a_xyz.cross( atom_b_xyz ) );
		Vector const f2( atom_a_xyz - atom_b_xyz );
		F1 += ( dE_dd_over_d * weight ) * f1;
		F2 += ( dE_dd_over_d * weight ) * f2;
	}
}

void
FullatomCustomPairDistanceEnergy::set_pair_and_func_map()
{
	using namespace options;
	using namespace options::OptionKeys;
	using namespace chemical;

	std::string pairfuncfile = (option[score::fa_custom_pair_distance_file].user()) ?
		option[score::fa_custom_pair_distance_file]() :
		io::database::full_name( "scoring/score_functions/custom_pair_distance/fa_custom_pair_distance" );

	tr.Debug << "Reading fa_custom_pair_distance_file: " << pairfuncfile << std::endl;

	utility::io::izstream in( pairfuncfile );
	if ( !in.good() )
		utility_exit_with_message( "Unable to open fa_custom_pair_distance file: " + pairfuncfile );

	std::string line, residue_type_set, score_function_name;
	utility::vector1< std::string > resA, resB, atomA, atomB;

	while ( getline( in, line ) ) {
		if ( line.size() < 1 || line[0] == '#' ) continue;
		std::istringstream l( line );
		std::string tag;
		l >> tag;
		if (tag == "RESIDUE_TYPE_SET") {
			resA.clear(); resB.clear(); atomA.clear(); atomB.clear();
			l >> residue_type_set;
			tr.Debug << "Residue type set: " <<  residue_type_set << std::endl;
			continue;
		}
		if (tag == "PAIR") {
			std::string buf;
			utility::vector1< std::string > pair_cols;
			while( l >> buf ) {
				pair_cols.push_back( buf );
			}
			if (pair_cols.size() == 4) {
				resA.push_back( pair_cols[1] );
				atomA.push_back( pair_cols[2] );
				resB.push_back( pair_cols[3] );
				atomB.push_back( pair_cols[4] );
				tr.Debug << "Pair: " << pair_cols[1] << " " <<  pair_cols[2] << " " << pair_cols[3] << " " << pair_cols[4] << std::endl;
				continue;
			}
		}
		if (tag == "SCORE_FUNCION") {
			l >> score_function_name;

			atoms_and_func_struct pair_func;
			pair_func.func_ = new DistanceFunc( score_function_name );
			if (pair_func.func_->max_dis() > max_dis_)
				max_dis_ = pair_func.func_->max_dis();
			tr.Debug << "SCORE_FUNCION: " << score_function_name << " (min: " <<
				pair_func.func_->min_dis() << " max: " << pair_func.func_->max_dis() << ")" << std::endl;

			ResidueTypeSetCAP restype_set =
				ChemicalManager::get_instance()->residue_type_set( residue_type_set );

			// get all possible residue types for each residue pair
			for (Size i = 1; i <= resA.size(); ++i) {
				ResidueTypeCAPs const & possible_res_types_a = restype_set->name3_map( resA[i] );
				ResidueTypeCAPs const & possible_res_types_b = restype_set->name3_map( resB[i] );
				for ( Size j = 1; j <= possible_res_types_a.size(); ++j ) {
					ResidueTypeCAP const & rsd_type_a = possible_res_types_a[ j ];
					Size atom_index_a = rsd_type_a->atom_index( atomA[i] );
					for ( Size k = 1; k <= possible_res_types_b.size(); ++k ) {
						ResidueTypeCAP const & rsd_type_b = possible_res_types_b[ k ];
						Size atom_index_b = rsd_type_b->atom_index( atomB[i] );
						ResTypePair respair;

						respair[1] = rsd_type_a;
						respair[2] = rsd_type_b;
						pair_func.resA_atom_index_ =  atom_index_a;
						pair_func.resB_atom_index_ =  atom_index_b;
						pair_and_func_map_[ respair ].push_back( pair_func );

						// save the mirrored pair if the types are not equal so if AB or
						// BA gets passed to residue_pair_energy they'll be evaluated.
						// I tried to prevent having to do this by sorting the pair but
						// couldn't get it to work
						if (rsd_type_a != rsd_type_b) {
							respair[1] = rsd_type_b;
							respair[2] = rsd_type_a;
							pair_func.resA_atom_index_ =  atom_index_b;
							pair_func.resB_atom_index_ =  atom_index_a;
							pair_and_func_map_[ respair ].push_back( pair_func );
						}
					}
				}
			}
		}
	}
}


/// @brief Loads the histogram named `name` from the database directory
/// scoring/score_functions/custom_pair_distance/ into scores_hist_.
DistanceFunc::DistanceFunc( std::string const name ) {
	std::string const relative_path( "scoring/score_functions/custom_pair_distance/" + name );
	utility::io::izstream hist_stream;
	io::database::open( hist_stream, relative_path );
	scores_hist_ = new numeric::interpolation::Histogram<Real,Real>( hist_stream() );
	hist_stream.close();
}

/// @brief Destructor; performs no explicit cleanup.
/// NOTE(review): scores_hist_ is assigned from `new` in the constructor and presumably is a
/// reference-counted OP that releases the histogram itself — confirm in the header.
DistanceFunc::~DistanceFunc() {}
/// @brief Interpolated histogram value at the squared distance `dist_sq`.
/// @return 0 when dist_sq is below the histogram minimum or above its maximum.
Real DistanceFunc::func( Real const dist_sq ) const {
	bool const outside = dist_sq < scores_hist_->minimum() || dist_sq > scores_hist_->maximum();
	Real value( 0.0 );
	if ( !outside ) scores_hist_->interpolate( dist_sq, value );
	return value;
}

/// @brief Interpolated histogram slope with respect to its argument (the squared distance).
/// @return 0 when dist_sq is below the histogram minimum or above its maximum.
Real DistanceFunc::dfunc( Real const dist_sq ) const {
	Real slope( 0.0 );
	bool const outside = dist_sq < scores_hist_->minimum() || dist_sq > scores_hist_->maximum();
	if ( !outside ) {
		Real value( 0.0 ); // required by interpolate(); only the slope is returned
		scores_hist_->interpolate( dist_sq, value, slope );
	}
	return slope;
}

/// @brief Upper end of the histogram domain. This is a squared distance, because
/// func()/dfunc() are evaluated at squared distances.
Real DistanceFunc::max_dis() const {
	Real const upper_bound = scores_hist_->maximum();
	return upper_bound;
}

/// @brief Lower end of the histogram domain (also a squared distance).
Real DistanceFunc::min_dis() const {
	Real const lower_bound = scores_hist_->minimum();
	return lower_bound;
}

}
}
}
