// -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
// vi: set ts=2 noet:
//
// (c) Copyright Rosetta Commons Member Institutions.
// (c) This file is part of the Rosetta software suite and is made available under license.
// (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
// (c) For more information, see http://www.rosettacommons.org. Questions about this can be
// (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.

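/// @file protocols/ProteinInterfaceDesign/util.cc
/// @brief Utility functions for protein interface design: backbone-stub constraint queries, cutpoint selection and residue-number parsing.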
/// @author Sarel Fleishman (sarelf@u.washington.edu), Jacob Corn (jecorn@u.washington.edu)
// Project Headers
#include <protocols/ProteinInterfaceDesign/util.hh>
#include <core/types.hh>
#include <core/pose/Pose.hh>
#include <core/pose/PDBPoseMap.hh>
#include <core/pose/PDBInfo.hh>
#include <core/conformation/Conformation.hh>
#include <core/scoring/constraints/BackboneStubConstraint.hh>
#include <core/scoring/constraints/CoordinateConstraint.hh>
#include <core/scoring/constraints/AmbiguousConstraint.hh>
#include <core/scoring/ScoreFunction.hh>
#include <core/scoring/ScoreFunctionFactory.hh>
#include <core/scoring/ScoreTypeManager.hh>
#include <core/scoring/ScoreType.hh>
#include <core/scoring/dssp/Dssp.hh>
#include <core/chemical/util.hh>

#include <protocols/moves/Mover.hh>
#include <protocols/hotspot_hashing/HotspotStub.hh>

#include <protocols/moves/MoverStatus.hh>
#include <numeric/random/random.hh>
#include <core/options/option.hh>

#include <protocols/ProteinInterfaceDesign/DockDesign.fwd.hh>
#include <protocols/loops/loops_main.hh> // for remove_cutpoint_variants

// C++ headers
#include <map>
#include <algorithm>
#include <core/chemical/util.hh>

// Utility Headers
#include <utility/string_util.hh>
#include <utility/vector1.hh>
#include <utility/Tag/Tag.hh>

// Unit Headers
#include <protocols/filters/Filter.hh>
#include <protocols/ProteinInterfaceDesign/dock_design_filters.hh>


// option key includes
#include <core/options/keys/out.OptionKeys.gen.hh>
#include <core/options/keys/packing.OptionKeys.gen.hh>

static core::util::Tracer TR( "protocols.ProteinInterfaceDesign.util" );

namespace protocols {
namespace ProteinInterfaceDesign {

using namespace core::scoring;
using namespace protocols::moves;
using namespace core;
using namespace std;
using utility::vector1;

/// @brief Collects the currently active backbone-stub constraints of a pose.
/// @details A copy of the pose is scored with a score function that contains only
/// backbone_stub_constraint (weight 1.0), so that the active member of each ambiguous
/// constraint is current. Only constraints whose type() is "AmbiguousConstraint" are inspected;
/// the active member of each is collected when its type() is "BackboneStub".
/// @param pose the pose whose constraint set is examined. It is not modified (a copy is scored).
/// @return the active BackboneStub member constraints, in the order their parent constraints
/// are listed by the constraint set. Empty if there are none.
core::scoring::constraints::ConstraintCOPs
get_bbcsts( core::pose::Pose const & pose ) {
	using namespace core::scoring;
	using namespace core::scoring::constraints;

	ScoreFunctionOP scorefxn( new ScoreFunction );
	scorefxn->set_weight( backbone_stub_constraint, 1.0 );
	core::pose::Pose  nonconst_pose = pose;
	// pre-score to get the active cst's
	(*scorefxn)(nonconst_pose);
	/// Now handled automatically.  scorefxn->accumulate_residue_total_energies( nonconst_pose );

	// sort through pose's constraint set and pull out only active bbcst's
	utility::vector1< ConstraintCOP > original_csts = nonconst_pose.constraint_set()->get_all_constraints();
	ConstraintCOPs new_csts;
	for( utility::vector1< ConstraintCOP >::const_iterator it = original_csts.begin(), end = original_csts.end(); it != end; ++it ) {
		ConstraintCOP cst( *it );
		if( cst->type() != "AmbiguousConstraint" ) continue;

		// downcast to the derived ambiguous constraint; skip if the type string lied
		AmbiguousConstraintCOP ambiguous_cst = AmbiguousConstraintCOP( dynamic_cast<AmbiguousConstraint const *>( cst() ) );
		if( !ambiguous_cst ) continue;

		// fetch the active member once and make sure it exists before dereferencing it
		ConstraintCOP active_cst( ambiguous_cst->active_constraint() );
		if( active_cst && active_cst->type() == "BackboneStub" ) {
			new_csts.push_back( active_cst );
		}
	}
	return new_csts;
}

/// @details a utility function to evaluate backbone_stub_constraints for each residue in a chain and return a vector with the top n_return residue numbers by cst score
/// note that this function is NOT guaranteed to return n_return residues! It will return the best n<=n_return:
/// only residues whose weighted backbone_stub_constraint energy is negative are reported, and never more
/// residues than the chain contains.
/// @param pose the pose to evaluate; it is not modified (a copy is scored)
/// @param chain the chain whose residues are ranked
/// @param n_return the maximum number of residues to return
/// @return residue numbers ordered from lowest (best) constraint energy upwards
utility::vector1< core::Size >
best_bbcst_residues( core::pose::Pose const & pose, core::Size const chain, core::Size const n_return )
{
	using namespace core::scoring;
	using namespace core::scoring::constraints;
	core::pose::Pose  nonconst_pose = pose;
	utility::vector1< std::pair<core::Real, core::Size> > all_residues; // score, seqpos
	utility::vector1< core::Size > best_residues;

	// scorefxn containing only the constraint energy
	ScoreFunctionOP scorefxn( new ScoreFunction );
	scorefxn->set_weight( backbone_stub_constraint, 1.0 );

	// score to make sure that the cst energies are current
	(*scorefxn)(nonconst_pose);
	/// Now handled automatically.  scorefxn->accumulate_residue_total_energies( nonconst_pose );

	// get pairs of residue number and weighted cst energy
	core::Real const weight( (*scorefxn)[ backbone_stub_constraint ] );
	for( core::Size i=nonconst_pose.conformation().chain_begin( chain ); i<=nonconst_pose.conformation().chain_end( chain ); ++i ){
		core::Real const score( nonconst_pose.energies().residue_total_energies( i )[ backbone_stub_constraint ] );
		core::Real const curr_energy( weight * score );
		all_residues.push_back(std::make_pair( curr_energy, i ));
	}
	// pairs compare by energy first, so the lowest energies end up at the front
	std::sort( all_residues.begin(), all_residues.end() );

	// never look past the end of the list: the chain may hold fewer than n_return residues
	core::Size const n_candidates( std::min< core::Size >( n_return, all_residues.size() ) );
	for( core::Size i=1; i<=n_candidates; ++i ) {
		// only use it if the cst actually evaluates
		if( all_residues[i].first < 0 ) best_residues.push_back( all_residues[i].second );
	}
	assert( best_residues.size() <= n_return );
	return best_residues;
}

/// @brief Scans a chain for the residue with the lowest backbone_stub_constraint energy.
/// @details Per-residue energies are obtained from an EnergyPerResidueFilter built on a score
/// function created from STANDARD_WTS with SCORE12_PATCH. On ties the later residue wins.
/// @param pose the pose to evaluate
/// @param chain the chain whose residues are scanned
/// @param[out] resi the residue with the lowest energy, or 0 if no residue had an energy <= 100000.0
/// @param[out] lowest_energy the energy of that residue, or 100000.0 if none qualified
void
find_lowest_constraint_energy_residue( core::pose::Pose const & pose, core::Size const chain, core::Size & resi, core::Real & lowest_energy )
{
	using namespace core::scoring;

	core::scoring::ScoreFunctionCOP scorefxn( ScoreFunctionFactory::create_score_function( STANDARD_WTS, SCORE12_PATCH ) );
	core::Size const first_seqpos( pose.conformation().chain_begin( chain ) );
	core::Size const last_seqpos( pose.conformation().chain_end( chain ) );

	// start from "nothing found" sentinels
	lowest_energy = 100000.0;
	resi = 0;

	core::Size seqpos( first_seqpos );
	while( seqpos <= last_seqpos ) {
		// the filter threshold is irrelevant here; only compute() is used
		EnergyPerResidueFilter const residue_filter( seqpos, scorefxn, backbone_stub_constraint, 10000.0/*dummy threshold*/ );
		core::Real const energy_here( residue_filter.compute( pose ) );
		if( !( energy_here > lowest_energy ) ) {
			resi = seqpos;
			lowest_energy = energy_here;
		}
		++seqpos;
	}
}

/// @details Strips every coordinate constraint from a pose's constraint set.
/// A constraint is removed when its type() is "CoordinateConstraint" and it can be
/// dynamically cast to CoordinateConstraint.
/// @param pose the pose to modify
/// @return the constraints that were removed (empty if there were none)
core::scoring::constraints::ConstraintCOPs
remove_coordinate_constraints_from_pose( core::pose::Pose & pose )
{
	using namespace core::scoring::constraints;

	ConstraintCOPs all_csts = pose.constraint_set()->get_all_constraints();
	ConstraintCOPs removed_csts;
	for( ConstraintCOPs::const_iterator cst_it = all_csts.begin(); cst_it != all_csts.end(); ++cst_it ) {
		ConstraintCOP candidate( *cst_it );
		if( candidate->type() != "CoordinateConstraint" ) continue;

		// the type string alone is not trusted: confirm the dynamic type as well
		ConstraintCOP as_coordinate_cst = dynamic_cast< CoordinateConstraint const *>( candidate() );
		if( !as_coordinate_cst ) continue;

		removed_csts.push_back( candidate );
	}

	// drop all collected coordinate constraints from the pose in one call
	pose.remove_constraints( removed_csts );
	return removed_csts;
}


/// @details utility function for stub_based_atom_tree. Tries to find a cutpoint in a pose given two
/// residue ranges: [prev_d, d) on the downstream partner and [prev_u, u) on the upstream partner.
/// The search runs in four passes and returns at the first hit:
///   1. the first residue in [prev_d, d) that starts three consecutive 'L' positions (res, res+1, res+2);
///   2. the same search in [prev_u, u);
///   3. any 'L' residue in [prev_d, d] (this pass, unlike the others, includes d itself);
///   4. any 'L' residue in [prev_u, u).
/// If no residue of the pose is assigned 'H' or 'S', DSSP is run first and its result is inserted into the pose.
/// NOTE(review): earlier documentation said the *middle* residue of the 3-residue loop stretch is returned;
/// the code returns the first residue of the stretch. Behaviour is kept as is -- confirm which is intended.
/// NOTE(review): res+1 / res+2 may lie beyond the searched range (or the pose); this relies on
/// Pose::secstruct() tolerating such positions -- TODO confirm.
/// @param pose the pose to search; its secondary structure may be overwritten by DSSP
/// @param prev_u first residue of the upstream search range
/// @param prev_d first residue of the downstream search range
/// @param u end of the upstream search range (exclusive)
/// @param d end of the downstream search range (exclusive, except in pass 3)
/// @return the chosen residue number, or 0 if no loop residue was found
core::Size
best_cutpoint( core::pose::Pose & pose, core::Size const prev_u, core::Size const prev_d, core::Size const u, core::Size const d )
{
	// if the pose is all loops (mini default), then run dssp
	// this logic may cause a problem for miniproteins that are indeed all loop. But for now it's certainly OK
	// NOTE(review): strands are tested as 'S' here; verify this matches the strand code used by Pose::secstruct().
	core::Size const nres( pose.total_residue() );
	bool has_assigned_ss( false );
	for( core::Size i=1; i<=nres; ++i ) {
		char const ss = pose.secstruct( i );
		if( ss == 'H' || ss == 'S' ) {
			has_assigned_ss = true;
			break;
		}
	}
	if( nres > 0 && !has_assigned_ss ) {
		core::scoring::dssp::Dssp dssp( pose );
		dssp.insert_ss_into_pose( pose );
	}

	// Passes 1 and 2: three consecutive loop residues. The bounds are written as "res < d"
	// rather than "res <= d-1" so that d == 0 (or u == 0) cannot wrap around on unsigned Size.
	for( core::Size res = prev_d; res < d; ++res ){
		if( pose.secstruct( res ) == 'L' ){
			if( pose.secstruct( res + 1 ) == 'L' && pose.secstruct( res + 2 ) == 'L' ) return res;
		}
	}
	for( core::Size res = prev_u; res < u; ++res ){
		if( pose.secstruct( res ) == 'L' ){
			if( pose.secstruct( res + 1 ) == 'L' && pose.secstruct( res + 2 ) == 'L' ) return res;
		}
	}

	// Passes 3 and 4: settle for any single loop residue.
	for( core::Size res = prev_d; res <= d; ++res ){
		if( pose.secstruct( res ) == 'L' ) return res;
	}
	for( core::Size res = prev_u; res < u; ++res ){
		if( pose.secstruct( res ) == 'L' ) return res;
	}
	return 0; // sign of trouble
}

/// @brief Reads a residue number from a tag, accepting either PDB or Rosetta numbering.
/// @details If the tag carries the option prefix+"pdb_num", its value is read as a residue number
/// followed by a single chain character (the last character of the value) and is mapped to a pose
/// residue through the pose's PDBInfo. Otherwise the option prefix+"res_num" is returned as is.
/// The function exits with an error message if pdb_num is used but the pose has no PDBInfo
/// (e.g., a pose read from a silent file), or if the pdb_num value is empty or has no leading number.
/// A resulting residue number of 0 trips a runtime_assert.
/// @param tag_ptr the tag to read the options from
/// @param pose the pose that provides the PDB-to-pose numbering map
/// @param prefix string prepended to the option names "pdb_num" and "res_num"
/// @return the pose residue number (non-zero)
core::Size
get_resnum( utility::Tag::TagPtr const tag_ptr, core::pose::Pose const & pose, std::string const & prefix/*=""*/ ) {
	core::Size resnum( 0 );
	bool const pdb_num_used( tag_ptr->hasOption( prefix + "pdb_num" ) );
	if( !pdb_num_used ) {
		resnum = tag_ptr->getOption<core::Size>( prefix + "res_num" );
	}
	else {
		if( pose.pdb_info().get() == NULL ){//no pdbinfo for this pose (e.g., silent file)
			TR<<"Bad tag: "<< *tag_ptr<<std::endl;
			utility_exit_with_message( "pdb_num used but no pdb_info found. Use res_num instead" );
			return( 0 );
		}

		std::string const pdbnum( tag_ptr->getOption<std::string>( prefix + "pdb_num" ) );
		core::Size number( 0 );
		bool number_parsed( false );
		if( !pdbnum.empty() ) { // guard: indexing the last character of an empty string is undefined
			// everything but the trailing chain character must be the residue number
			std::stringstream ss( pdbnum.substr( 0, pdbnum.length() - 1 ) );
			number_parsed = !( ss >> number ).fail();
		}
		if( !number_parsed ) {
			TR<<"Bad tag: "<< *tag_ptr<<std::endl;
			utility_exit_with_message( "pdb_num must be a residue number followed by a one-character chain ID (e.g. 45A)" );
			return( 0 );
		}
		char const chain( pdbnum[ pdbnum.length() - 1 ] );
		// query the map in place; no need to copy it
		resnum = pose.pdb_info()->pdb2pose().find( chain, number );
	}

	runtime_assert( resnum );
	return( resnum );
}

/// @brief Extracts a residue number from a string.
/// @detail Recognizes two forms of numbering:
///   - Rosetta residue numbers (numbered sequentially from 1 to the last residue
///     in the pose). These have the form [0-9]+
///   - PDB numbers. These have the form [0-9]+[A-Za-z_], where the single trailing
///     character is the chain ID.
/// The whole string must match one of these forms. A PDB number is translated through the
/// pose's PDBInfo; a missing PDBInfo trips a runtime_assert.
/// @param resnum the string to parse
/// @param pose the pose that provides the PDB-to-pose numbering
/// @return the rosetta residue number for the string, or 0 if the string cannot be parsed
///   (including a number too large to be represented). For PDB numbers the result of
///   PDBInfo::pdb2pose is returned unchanged.
core::Size parse_resnum(std::string const& resnum, core::pose::Pose const& pose) {

	std::string::const_iterator const input_end = resnum.end();

	// Leading run of digits: [0-9]*
	std::string::const_iterator number_end = resnum.begin();
	while( number_end != input_end && *number_end >= '0' && *number_end <= '9' ) {
		++number_end;
	}

	// Following run of chain characters: letters or underscore
	std::string::const_iterator chain_end = number_end;
	while(  chain_end != input_end
		&& ( ( 'A' <= *chain_end && *chain_end <= 'Z' ) ||
			( 'a' <= *chain_end && *chain_end <= 'z' ) ||
			'_' == *chain_end ) )
	{
		++chain_end;
	}

	std::string const number( resnum.begin(), number_end );
	std::string const chain( number_end, chain_end );

	//Require that the whole string match, and that the chain be a single char
	if( chain_end != input_end || chain.size() > 1 || number.size() < 1) {
		TR.Error << "Could not parse '" << resnum << "' into a residue number." << std::endl;
		return core::Size(0);
	}

	// The digits are known to be well-formed, but extraction can still fail on overflow;
	// report that as a parse error instead of using an unspecified value.
	core::Size n( 0 );
	std::istringstream ss( number );
	if( !( ss >> n ) ) {
		TR.Error << "Could not parse '" << resnum << "' into a residue number." << std::endl;
		return core::Size(0);
	}

	if( chain.size() == 1 ) { // PDB Number
		TR.Trace << "Interpretting " << n << chain << " as a pdb number." << std::endl;
		core::pose::PDBInfoCOP info = pose.pdb_info();
		runtime_assert(info);
		return info->pdb2pose( chain[0], n );
	}
	else { // Rosetta Number
		TR.Trace << "Interpreting " << n << " as a Rosetta residue number." << std::endl;
		return n;
	}
}


/// @brief Extracts a list of residue numbers from a tag.
/// @details The option named by @p tag should hold a comma-separated list of numbers, each in either
///   pdb or rosetta format (@see parse_resnum for details). Empty entries (e.g. from consecutive
///   commas) are skipped silently; entries that parse_resnum rejects are reported on the error
///   tracer and skipped.
/// @param tag_ptr the tag to read from
/// @param tag the name of the option holding the list
/// @param pose the pose used to translate pdb numbers
/// @return the successfully parsed residue numbers in list order; empty if the option is absent
vector1<Size> get_resnum_list(utility::Tag::TagPtr const tag_ptr, string const& tag, pose::Pose const& pose)
{
	vector1<Size> parsed_resnums;

	if( tag_ptr->hasOption( tag ) ) {
		vector<string> tokens = utility::string_split( tag_ptr->getOption<string>( tag ), ',' );

		for( vector<string>::size_type idx = 0; idx < tokens.size(); ++idx ) {
			string const & token( tokens[ idx ] );
			if( token.empty() ) continue; // nothing between two commas

			Size const seqpos = parse_resnum( token, pose );
			if( seqpos != 0 ) {
				parsed_resnums.push_back( seqpos );
			}
			else { // parse_resnum signals failure with 0
				TR.Error << "Could not convert '" << token << "' into a residue number." << std::endl;
			}
		}
	}
	else {
		TR.Error << "[Error] No '" << tag << "' tag was found." << std::endl;
	}

	return parsed_resnums;
}

/// @details Finds the residue on the target chain that lies closest to residue res.
/// The distance is measured from the neighbour atom (nbr_atom) of each target-chain residue to
/// the atom named @p atom on residue res. On ties the later residue wins.
/// A runtime_assert fires if no residue lies within 100000.0 of res (e.g. an empty chain).
/// @param pose the pose to search
/// @param target_chain the chain whose residues are candidates
/// @param res the reference residue
/// @param atom name of the atom on res from which distances are measured
/// @return the residue number of the nearest target-chain residue
core::Size
find_nearest_residue( core::pose::Pose const & pose, core::Size const target_chain, core::Size const res, std::string const atom/*=="CA"*/ )
{
	core::Size const chain_first( pose.conformation().chain_begin( target_chain ) );
	core::Size const chain_last( pose.conformation().chain_end( target_chain ) );

	core::Real closest_dist( 100000.0 );
	core::Size closest_seqpos( 0 );
	for( core::Size seqpos = chain_first; seqpos <= chain_last; ++seqpos ) {
		core::Size const nbr_index( pose.residue( seqpos ).nbr_atom() );
		core::Real const dist( pose.residue( seqpos ).xyz( nbr_index ).distance( pose.residue( res ).xyz( atom ) ) );
		if( dist > closest_dist ) continue;
		closest_dist = dist;
		closest_seqpos = seqpos;
	}

	runtime_assert( closest_seqpos );
	return closest_seqpos;
}

} //ProteinInterfaceDesign
} //protocols
