// -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
// vi: set ts=2 noet:
//
// (c) Copyright Rosetta Commons Member Institutions.
// (c) This file is part of the Rosetta software suite and is made available under license.
// (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
// (c) For more information, see http://www.rosettacommons.org. Questions about this can be
// (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.

/// @file ConformerSwitchMover.cc
/// @brief code for the conformer switch mover in ensemble docking
/// @author Sid Chaudhury

#include <protocols/moves/ConformerSwitchMover.hh>
#include <protocols/moves/ConformerSwitchMoverCreator.hh>

// Rosetta Headers
#include <core/conformation/Interface.hh>
#include <core/io/pdb/pose_io.hh>
#include <core/id/types.hh>
#include <core/id/AtomID_Map.Pose.hh>
#include <core/options/option.hh>
#include <core/options/keys/OptionKeys.hh>
#include <core/pose/Pose.hh>
#include <core/scoring/Energies.hh>
#include <core/scoring/rms_util.hh>
#include <core/scoring/ScoreFunction.hh>
#include <core/types.hh>

#include <protocols/moves/ReturnSidechainMover.hh>
#include <protocols/moves/Mover.fwd.hh>
#include <protocols/moves/SwitchResidueTypeSetMover.hh>


// Random number generator
#include <numeric/random/random.hh>
//
#include <string>

#include <core/util/Tracer.hh>

#include <utility/io/izstream.hh>

//Auto Headers
#include <core/chemical/ChemicalManager.fwd.hh>
#include <core/options/keys/docking.OptionKeys.gen.hh>

// Bring the core::util names T, Error and Warning into file scope.
// NOTE(review): none of the three appears to be used in this file -- confirm before removing.
using core::util::T;
using core::util::Error;
using core::util::Warning;

// File-local random number generator, constructed with the constant 38627226
// (presumably a per-file identifier/seed -- confirm against the RandomGenerator docs).
// apply() draws from it via uniform() and random_range().
static numeric::random::RandomGenerator RG(38627226);
// File-local tracer writing to the "protocols.moves.ConformerSwitchMover" channel.
static core::util::Tracer TR("protocols.moves.ConformerSwitchMover");
namespace protocols {
namespace moves {

std::string
ConformerSwitchMoverCreator::keyname() const
{
	return ConformerSwitchMoverCreator::mover_name();
}

/// @brief Allocates a default-constructed ConformerSwitchMover and hands it back as a MoverOP.
protocols::moves::MoverOP
ConformerSwitchMoverCreator::create_mover() const {
	protocols::moves::MoverOP fresh_mover( new ConformerSwitchMover );
	return fresh_mover;
}

/// @brief Returns the literal name "ConformerSwitchMover"; keyname() and
/// ConformerSwitchMover::get_name() both forward to this function.
std::string
ConformerSwitchMoverCreator::mover_name()
{
	return std::string( "ConformerSwitchMover" );
}

/// @brief Constructor taking the segment to swap, the docking jump, a score function and an ensemble list file.
/// @param start_res  first pose residue of the segment that conformers replace
/// @param end_res  last pose residue of that segment (inclusive)
/// @param jump_id  jump number handed to core::conformation::Interface in SwitchConformer()
/// @param score_fxn  score function used by apply() and GenerateProbTable()
/// @param ensemble_file_path  path of the list file read by load_ensemble()
/// @details Defaults set here: temperature 0.8, centroid mode (fullatom_ false) and
/// partition-function switching enabled. Exits through runtime_assert if the ensemble is
/// empty or if the first conformer's residue count differs from end_res - start_res + 1.
ConformerSwitchMover::ConformerSwitchMover(
	Size start_res,
	Size end_res,
	Size jump_id,
	core::scoring::ScoreFunctionCOP score_fxn,
	std::string ensemble_file_path
) :
	Mover(),
	start_res_(start_res),
	end_res_(end_res),
	jump_id_(jump_id),
	temperature_(0.8),
	fullatom_( false ),
	partition_switch_( true )
{
	score_fxn_ = score_fxn;

	// No conformer has been selected yet; apply() assigns a 1-based index.
	conf_num_ = 0;

	// Read every conformer named in the list file and cache how many there are.
	load_ensemble( ensemble_file_path );
	ensemble_size_ = ensemble_.size();
	runtime_assert(ensemble_size_ > 0);

	// The first conformer defines the segment length; it has to match the requested range.
	conf_size_ = ensemble_[1].total_residue();
	runtime_assert((end_res_ - start_res_ + 1) == conf_size_);

	TR << "ensemble summary: start_res_ " << start_res_
		<< " end_res_ " << end_res_
		<< " conf_size_ " << conf_size_
		<< " ensemble_size_ " << ensemble_size_ << std::endl;
}

void ConformerSwitchMover::load_ensemble( std::string ensemble_file_path )
{
	utility::vector1< std::string > filenames;
	utility::io::izstream file( ensemble_file_path );
	std::string line;
	while( getline(file, line) ) {
		if ( line.find("pdb") != std::string::npos )
			filenames.push_back( line );
	}

	ensemble_ = core::io::pdb::poses_from_pdbs( filenames);
}

/// @brief Set the temperature that divides the energy differences in GenerateProbTable().
/// @param temp_in  new temperature; stored as given, without validation
void ConformerSwitchMover::set_temperature( core::Real temp_in )
{
	this->temperature_ = temp_in;
}

void ConformerSwitchMover::set_fullatom( bool fullatom_in )
{
	fullatom_ = fullatom_in;
}

void ConformerSwitchMover::set_partition_switch( bool partition_switch_in )
{
	partition_switch_ = partition_switch_in;
}

/// @brief Replace the ensemble segment of @p pose with a conformer chosen from ensemble_.
/// @param pose  the pose to modify; it is scored with score_fxn_ before a conformer is chosen
/// @details In full-atom mode the low-resolution filter is skipped; otherwise the pose must
/// pass docking_lowres_filter(). When partition_switch_ is set and the filter passed, the
/// conformer is drawn from the cumulative table built by GenerateProbTable(). In every other
/// case it is picked with RG.random_range( 1, ensemble_size_ ). The chosen 1-based index is
/// stored in conf_num_ and the swap is carried out by SwitchConformer().
void ConformerSwitchMover::apply( core::pose::Pose & pose )
{
	// Scoring must happen first: docking_lowres_filter() reads the pose's total energies.
	(*score_fxn_)(pose);

	bool const filter_pass( fullatom_ || docking_lowres_filter( pose ) );

	if ( partition_switch_ && filter_pass ){		//only calculate partition function if filters are passed
		GenerateProbTable( pose );
		core::Real const rand_num( RG.uniform() );

		// prob_table_ is cumulative, so every entry at or below rand_num moves the pick one
		// conformer further. The last entry is deliberately not tested: it is 1.0 only up to
		// rounding, and letting it advance the index would select ensemble_size_ + 1.
		conf_num_ = 1;
		for ( Size i = 1; i < ensemble_size_; i++ ){
			if ( rand_num >= prob_table_[i] ) conf_num_++;
		}
	} else {
		conf_num_ = RG.random_range( 1, ensemble_size_ );
	}

	TR.Debug << "Switching partner with conformer: " << conf_num_ << std::endl;
	SwitchConformer( pose, conf_num_ );
}

/// @brief Apply a ReturnSidechainMover built from the most recently selected conformer.
/// @param pose  pose handed to the ReturnSidechainMover; residues start_res_ .. end_res_
/// are the range passed to that mover
/// @details conf_num_ is 0 until apply() has run, and ensemble_ is 1-based, so calling this
/// first would index outside the ensemble. That misuse is now reported through
/// runtime_assert instead of being left as an out-of-range access.
void ConformerSwitchMover::recover_conformer_sidechains( core::pose::Pose & pose )
{
	runtime_assert( conf_num_ >= 1 && conf_num_ <= ensemble_.size() );

	core::pose::Pose recover_pose = ensemble_[conf_num_];
	ReturnSidechainMoverOP recover_mover = new ReturnSidechainMover( recover_pose, start_res_, end_res_ );
	recover_mover->apply( pose );
}

/// @brief Rebuild prob_table_, the cumulative Boltzmann probability of each conformer.
/// @param pose  the current complex; it is only copied, never modified
/// @details Each conformer i is swapped into a fresh copy of @p pose and scored with
/// score_fxn_. The scores are turned into weights exp( -(E_i - E_min) / temperature_ ),
/// normalised by their sum, and accumulated so that prob_table_[i] is the probability of
/// picking conformer 1..i. Shifting by E_min leaves the probabilities unchanged but keeps the
/// largest weight at exactly 1, so the sum cannot underflow to zero.
void ConformerSwitchMover::GenerateProbTable( core::pose::Pose & pose )
{
	prob_table_.clear();

	// Score the complex once per conformer, always starting from the unmodified input pose.
	utility::vector1< core::Real > e_table;
	for ( Size i = 1; i <= ensemble_size_; i++ ){
		core::pose::Pose complex_pose = pose;
		SwitchConformer( complex_pose, i );
		e_table.push_back( (*score_fxn_)(complex_pose) );
	}

	// Nothing to normalise for an empty ensemble; leave prob_table_ empty.
	if ( e_table.empty() ) return;

	// Seed the minimum with a real score. Starting from 0.0 would ignore the minimum
	// whenever every score is positive and let all weights underflow together.
	core::Real min_energy( e_table[1] );
	for ( Size i = 2; i <= e_table.size(); i++ ){
		if ( e_table[i] < min_energy ) min_energy = e_table[i];
	}

	// Convert scores to Boltzmann weights in place and accumulate the partition sum.
	core::Real partition_sum( 0.0 );
	for ( Size i = 1; i <= e_table.size(); i++ ){
		e_table[i] = std::exp( -( e_table[i] - min_energy ) / temperature_ );
		partition_sum += e_table[i];
	}

	// Running sum of the normalised weights gives the cumulative table.
	core::Real cumulative( 0.0 );
	for ( Size i = 1; i <= e_table.size(); i++ ){
		cumulative += e_table[i] / partition_sum;
		prob_table_.push_back( cumulative );
	}
}

/// @brief Superimpose ensemble member @p conf_num onto @p pose and copy it into the pose.
/// @param pose  pose whose segment beginning at start_res_ is overwritten
/// @param conf_num  1-based index into ensemble_
/// @details A copy of the conformer is converted to centroid unless fullatom_ is set. It is
/// then superimposed on the pose using the CA atoms of the residues in start_res_ .. end_res_
/// that core::conformation::Interface reports as interface residues. If five or fewer such
/// residues are found, every residue of the range is used instead. Finally conf_size_
/// residues are copied from the conformer into the pose with Pose::copy_segment().
void ConformerSwitchMover::SwitchConformer(
	core::pose::Pose & pose,
	core::Size conf_num
	)
{
	core::conformation::Interface partner_interface( jump_id_ );

	// Work on a private copy so the stored ensemble member is never altered.
	core::pose::Pose replacement = ensemble_[conf_num];

	SwitchResidueTypeSetMover centroid_converter( core::chemical::CENTROID );
	if ( !fullatom_ ) {
		centroid_converter.apply( replacement );
	}

	partner_interface.calculate( pose );

	// Collect the residues of the segment that sit at the interface.
	utility::vector1< Size > align_residues;
	for ( Size resnum = start_res_; resnum <= end_res_; ++resnum ) {
		if ( partner_interface.is_interface( resnum ) ) {
			align_residues.push_back( resnum );
		}
	}

	// Too few interface residues: fall back to aligning on the whole segment.
	if ( align_residues.size() <= 5 ) {
		align_residues.clear();
		for ( Size resnum = start_res_; resnum <= end_res_; ++resnum ) {
			align_residues.push_back( resnum );
		}
	}

	core::id::AtomID_Map< core::id::AtomID > ca_map;
	core::id::initialize( ca_map, replacement, core::id::BOGUS_ATOM_ID ); // every atom starts unmapped

	// Pair each selected CA of the conformer (numbered from 1) with the same residue's CA in the pose.
	for ( Size k = 1; k <= align_residues.size(); ++k ) {
		Size const pose_resnum( align_residues[k] );
		Size const conf_resnum( pose_resnum - start_res_ + 1 );
		core::id::AtomID const conf_ca( replacement.residue( conf_resnum ).atom_index("CA"), conf_resnum );
		core::id::AtomID const pose_ca( pose.residue( pose_resnum ).atom_index("CA"), pose_resnum );
		ca_map[ conf_ca ] = pose_ca;
	}

	core::scoring::superimpose_pose( replacement, pose, ca_map );
	pose.copy_segment( conf_size_, replacement, start_res_, 1 );
}

/// @brief Low-resolution docking filter applied to an already scored pose.
/// @param pose  pose whose total energies are inspected; it must have been scored beforehand
/// @return true when both interchain_contact and interchain_vdw are below their cutoffs
/// @details The default cutoffs are 10.0 (interchain_contact) and 1.0 (interchain_vdw). When
/// the docking::dock_lowres_filter option is set by the user, its first value replaces the
/// contact cutoff and its second value replaces the vdw cutoff. A value that was not supplied
/// leaves the corresponding default in place instead of reading past the end of the vector.
bool
ConformerSwitchMover::docking_lowres_filter( core::pose::Pose & pose){

	using namespace core;
	using namespace scoring;
	using namespace options;

	//criterion for failure
	Real interchain_contact_cutoff  = 10.0;
	Real interchain_vdw_cutoff = 1.0;

	if( option[ OptionKeys::docking::dock_lowres_filter ].user() ) {
		utility::vector1< Real > const dock_filters = option[ OptionKeys::docking::dock_lowres_filter ]();
		// The option may carry fewer than two values; only override what was actually given.
		if ( dock_filters.size() >= 1 ) interchain_contact_cutoff = dock_filters[1];
		if ( dock_filters.size() >= 2 ) interchain_vdw_cutoff = dock_filters[2];
	}

	bool passed_filter = true;
	if (pose.energies().total_energies()[ interchain_contact ] >= interchain_contact_cutoff ) passed_filter = false;
	if (pose.energies().total_energies()[ interchain_vdw ] >= interchain_vdw_cutoff ) passed_filter = false;

	return passed_filter;
}

/// @brief Returns the string produced by ConformerSwitchMoverCreator::mover_name().
std::string ConformerSwitchMover::get_name() const {
	// Forward to the creator so the name is spelled out in a single place.
	std::string const name( ConformerSwitchMoverCreator::mover_name() );
	return name;
}

}  // namespace moves
}  // namespace protocols
