// -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
// vi: set ts=2 noet:
//
// (c) Copyright Rosetta Commons Member Institutions.
// (c) This file is part of the Rosetta software suite and is made available under license.
// (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
// (c) For more information, see http://www.rosettacommons.org. Questions about this can be
// (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.

/// @file src/protocols/loophash/MPI_LoopHashRefine.cc 
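/// @brief MPI LoopHash refinement: manages the central structure library (loading, saving/restoring state, and feeding returned structures back into it).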
/// @author Mike Tyka

#define TRDEBUG TR.Debug

// MPI headers
#ifdef USEMPI
#include <mpi.h> //keep this first
#endif

#include <protocols/loophash/MPI_LoopHashRefine.hh>
#include <protocols/loophash/WorkUnit_LoopHash.hh>
#include <protocols/wum/WorkUnit_BatchRelax.hh>
#include <protocols/wum/WorkUnitBase.hh>
#include <protocols/wum/SilentStructStore.hh>
#include <core/io/pdb/pose_io.hh>
#include <core/pose/util.hh>
#include <core/chemical/ChemicalManager.hh>
#include <core/chemical/ResidueTypeSet.hh>
#include <core/chemical/util.hh>
#include <core/io/pose_stream/MetaPoseInputStream.hh>
#include <core/io/pose_stream/util.hh>
#include <core/io/silent/SilentFileData.hh>
#include <core/io/silent/SilentStructFactory.hh>
#include <core/io/silent/SilentStruct.hh>
#include <core/io/silent/ProteinSilentStruct.hh>
#include <core/options/keys/in.OptionKeys.gen.hh>
#include <core/options/keys/out.OptionKeys.gen.hh>
#include <core/options/keys/relax.OptionKeys.gen.hh>
#include <core/options/keys/lh.OptionKeys.gen.hh>
#include <core/options/option.hh>
#include <core/pose/Pose.hh>
#include <core/scoring/ScoreFunctionFactory.hh>
#include <core/scoring/ScoreFunction.hh>
#include <core/util/Tracer.hh>
#include <ObjexxFCL/format.hh>
/// ObjexxFCL headers
#include <ObjexxFCL//string.functions.hh>
#include <ObjexxFCL/format.hh>

#include <numeric/random/random.hh>

#include <unistd.h>

using namespace ObjexxFCL;
using namespace ObjexxFCL::fmt;

namespace protocols {
namespace loophash {

using namespace protocols::wum;

// Tracer for all log output of this file (channel "MPI.LHR").
static core::util::Tracer TR("MPI.LHR");

// File-local random generator; used to stagger save-state dumps and for the Metropolis acceptance test.
static numeric::random::RandomGenerator RG(9788321);  // <- Magic seed number: do not change it, and don't reuse it for any other generator.



/// @brief Construct the LoopHash refinement manager for the given machine letter.
/// @details The initializer list seeds placeholder values; set_defaults() then overwrites
/// several of them (library size, save interval, save timestamp, ident string) from the
/// command-line options.
MPI_LoopHashRefine::MPI_LoopHashRefine( char machine_letter ):
	MPI_WorkUnitManager( machine_letter ),
	max_lib_size_(50),
	save_state_interval_(300),
	last_save_state_(0),
	totaltime_loophash_(0),
	n_loophash_(0),
	totaltime_batchrelax_(0),
	n_batchrelax_(0),
	total_structures_(0),
	total_structures_relax_(0),
	total_metropolis_(1),          // starts at 1: print_stats() divides by this value
	total_metropolis_accepts_(0),
	ident_string_("ident")
{
	set_defaults();  // constructors must always call this: it loads the option-driven settings.
}

void
MPI_LoopHashRefine::set_defaults(){
	using namespace core::options;
	using namespace core::options::OptionKeys;
	
	max_lib_size_   = option[ OptionKeys::lh::max_lib_size ](); 
	save_state_interval_        = option[ OptionKeys::lh::mpi_save_state_interval ]();
	mpi_feedback_               = option[ OptionKeys::lh::mpi_feedback ]();
	mpi_metropolis_temp_        = option[ OptionKeys::lh::mpi_metropolis_temp ]();
	rms_limit_                  = option[ OptionKeys::lh::rms_limit ]();
	mpi_resume_ = "";
	if( option[ OptionKeys::lh::mpi_resume ].user() ){
		mpi_resume_ = option[ OptionKeys::lh::mpi_resume ]();
	}
	jobname_ = option[ OptionKeys::lh::jobname ]();

//	if ( option[ in::file::native ].user() ) {
//		native_pose_ = new core::pose::Pose();
//		core::io::pdb::pose_from_pdb( *native_pose_, option[ in::file::native ]() );
//		core::pose::set_ss_from_phipsi( *native_pose_ );
//		core::chemical::switch_to_residue_type_set( *native_pose_, core::chemical::CENTROID);
//	}

	// Make ident string:
	// Make a medley of a string and the PID 
	ident_string_ =  option[ lh::jobname ]; 
	TR << "IDENT: " << ident_string_ << std::endl;

	// make sure the state saves are randomly staggered - otherwise all the masters dump several hundred megs at once!
	last_save_state_ = time(NULL)  + core::Size( RG.uniform()  * (core::Real) save_state_interval_) ;
	TR << "Interlace dumps: " << last_save_state_ << "  " << time(NULL)  << "  " << last_save_state_ - time(NULL) << "  " << save_state_interval_ << "  " << std::endl;
}



/// @brief Read all command-line input poses and fill the central library with copies of them.
/// @param structure_read_offset Index (taken modulo the number of inputs) of the first input
///        structure copied into the library; copying then proceeds round-robin.
/// @details Exits if no input structure could be read. Each library entry receives a
/// consecutive "ssid" energy starting at 1; the library ends up with max_lib_size_ entries.
void
MPI_LoopHashRefine::load_structures_from_cmdline_into_library( core::Size structure_read_offset )
{
	start_timer( TIMING_IO_READ  );
	TR << "Reading in structures..." << std::endl;

	core::chemical::ResidueTypeSetCAP rsd_set;
	rsd_set = core::chemical::ChemicalManager::get_instance()->residue_type_set( "fa_standard" );

	core::io::pose_stream::MetaPoseInputStream input = core::io::pose_stream::streams_from_cmd_line();

	// Stage 1: pull every input pose into a temporary store, stamped with bookkeeping energies.
	SilentStructStore temp_lib;
	while( input.has_another_pose() ) {
		core::pose::Pose pose;
		input.fill_pose( pose, *rsd_set );
		//core::chemical::switch_to_residue_type_set( pose, core::chemical::CENTROID);
		core::pose::set_ss_from_phipsi( pose );

		core::io::silent::ProteinSilentStruct pss;
		pss.fill_struct( pose );
		pss.add_energy( "lhcount", 0 );
		pss.add_energy( "ltime", time(NULL) );
		pss.add_energy( "master", mpi_rank() );
		temp_lib.add( pss );
	}

	core::Size const n_loaded = temp_lib.size();
	if ( n_loaded == 0 ){
		TR.Error << "Error reading starting structures: 0 valid structures read. Check your input files and parameters" << std::endl;
		utility_exit_with_message( "Error reading starting structures: 0 valid structures read. Check your input files and parameters" );
	}

	TR << "Loaded " << n_loaded << " starting structures" << std::endl;

	// Stage 2: copy structures round-robin (starting at the offset) until the library is full.
	core::Size position = structure_read_offset % n_loaded;
	for( core::Size next_ssid = 1; library_central_.size() < max_lib_size_; ++next_ssid ){
		core::io::silent::SilentStructOP copy = temp_lib.get_struct( position )->clone();
		copy->add_energy( "ssid", next_ssid );
		library_central_.add( copy );
		position = ( position + 1 ) % n_loaded;
	}

	TR << "Added " << library_central_.size() << " structures to library " << std::endl;

	runtime_assert(  library_central_.size() == max_lib_size_ );
	start_timer( TIMING_CPU );
}


void 
MPI_LoopHashRefine::save_state(std::string prefix ){
	start_timer( TIMING_IO_WRITE );
	long starttime = time(NULL);
	write_queues_to_file( prefix + "." + string_of(mpi_rank()) );
	library_central_.serialize_to_file( prefix + "." + string_of(mpi_rank())+ ".lib.library_central" );
	long endtime = time(NULL);
	TR << "Saved state: " << endtime - starttime << "s " << inbound().size() << " + " << outbound().size() << " + " <<  library_central_.size() << " + " << ( inbound().size() + outbound().size() + library_central_.size() ) << std::endl;
	start_timer( TIMING_CPU );
}


void 
MPI_LoopHashRefine::save_state_auto(){
	if( (core::Size)(last_save_state_ + save_state_interval_ ) < (core::Size)time(NULL) ){ 
		TR << "Saving state.. " << std::endl;
		last_save_state_ = time(NULL);
		save_state( ident_string_ );	
	}
}


/// @brief Replace the current queues and central library with the state saved under prefix for this rank.
/// @param prefix Same prefix that was passed to save_state().
void 
MPI_LoopHashRefine::load_state(std::string prefix ){
	start_timer( TIMING_IO_READ );
	std::string const base = prefix + "." + string_of( mpi_rank() );

	// Discard whatever is queued now; the saved queues replace it wholesale.
	inbound().clear();
	outbound().clear();
	read_queues_from_file( base );

	library_central_.clear();
	library_central_.read_from_file( base + ".lib.library_central" );
	start_timer( TIMING_CPU );
}






/// @brief Log a one-line summary (STATL) of throughput, acceptance and timing statistics.
/// @details Rate-limited: at most one line every 300 seconds of wall-clock time.
void
MPI_LoopHashRefine::print_stats(){
	// time_t rather than int: an int would truncate the epoch seconds (and overflows in 2038 on 32-bit int).
	static time_t lasttime = 0;
	time_t const now = time(NULL);
	if( ( now - lasttime ) < 300 ) return;
	lasttime = now;

	// Note: total_metropolis_ is initialized to 1 in the constructor, so the Acc ratio never divides by zero.
	TR << "STATL: " 
	   << wall_time() << "s  "
	   << total_structures_ << "  " 
	   << total_structures_relax_ 
	   << " Acc: " << F(5,3, core::Real(total_metropolis_accepts_)/ core::Real(total_metropolis_) ) 
	   << " CPU: " << int((totaltime_batchrelax_+totaltime_loophash_)/3600) << " hrs  " 
	   << " r/l: " << F(5,2, float(totaltime_batchrelax_)/(float(totaltime_loophash_)+0.1)) 
	   << " LHav: " << int(totaltime_loophash_/(n_loophash_+1)) << "s "
	   << " BRav: " << int(totaltime_batchrelax_/(n_batchrelax_+1)) << "s "
	   << " MEM: " << int(library_central_.mem_footprint()/1024) << " kB "  << std::endl;
}






/// @brief Offer a returned structure to the central library according to the mpi_feedback_ mode.
/// @param pss Structure to add; its "lhcount" is reset to 0 and "ltime" set to now.
/// @return true if the library was changed; false for mode "no", for an unrecognized
///         mode, or when the selected strategy rejected the structure.
bool 
MPI_LoopHashRefine::add_structure_to_library( core::io::silent::ProteinSilentStruct &pss ){
	// reset the lhcount to 0
	pss.add_energy( "lhcount", 0 );
	pss.add_energy( "ltime", time(NULL) );

	if( mpi_feedback_ == "add_n_limit" )    return add_structure_to_library_direct( pss );
	if( mpi_feedback_ == "add_n_replace" )  return add_structure_to_library_add_n_replace( pss );
	if( mpi_feedback_ == "single_replace" ) return add_structure_to_library_single_replace( pss );

	// "no" means feedback is disabled. Any other value used to return an uninitialized
	// bool (undefined behaviour); now it is reported and treated as "no".
	if( mpi_feedback_ != "no" ) {
		TR.Error << "Unknown mpi_feedback mode '" << mpi_feedback_ << "': structure not added to library" << std::endl;
	}
	return false;
}


/// @brief Append the structure to the central library unconditionally.
/// @return Always true.
bool
MPI_LoopHashRefine::add_structure_to_library_direct( core::io::silent::ProteinSilentStruct &pss )
{
	// No size cap is enforced here; limit_library() trims the library separately.
	library_central_.add( pss );
	const bool added = true;
	return added;
}


/// @brief "add_n_replace" feedback: replace the nearest library member if close and better, else add.
/// @param pss Incoming structure (its "lhcount" energy is reset to 0).
/// @return true if the library changed.
/// @details
///  - Empty library: the structure is added directly.
///  - Score worse than the worst library member: ignored.
///  - Nearest member (by CA RMSD) within rms_limit_: that member is replaced if pss scores lower.
///  - Otherwise: pss is added as a new member and the library re-sorted.
bool 
MPI_LoopHashRefine::add_structure_to_library_add_n_replace( core::io::silent::ProteinSilentStruct &pss )
{
	core::Real new_struct_score = pss.get_energy("score");
	pss.add_energy( "lhcount", 0 );
	start_timer( TIMING_CPU );
	TRDEBUG << "Checking for zero size .. " << std::endl;
	if( library_central_.size() == 0 ){
		add_structure_to_library_direct( pss );
		return true;
	}

	library_central_.sort_by();
	// if energy is worse than the worst member - ignore structure
	if ( new_struct_score > library_central_.store().back()->get_energy("score") ){
		TR << "Ignoring struc: " << new_struct_score << " > " << format_silent_struct(library_central_.store().back()) << std::endl;
		return false;
	}

	// now find the closest RMS
	core::Real closest_rms = 1000000;
	SilentStructStore::iterator closest_struct; // only dereferenced when closest_rms was lowered below, i.e. after being set

	for( SilentStructStore::iterator jt = library_central_.begin();
	     jt != library_central_.end(); ++jt )
	{
		// downcast to protein silent_structs:
		core::io::silent::ProteinSilentStruct *jt_pss = dynamic_cast < core::io::silent::ProteinSilentStruct * > ( &(*(*jt)) );
		if ( jt_pss == NULL ) utility_exit_with_message( "FATAL ERROR:  This code only runs with Protein SilentStructus " );
		core::Real the_rms = pss.CA_rmsd( *jt_pss );
		TRDEBUG << "The rms: " << the_rms << std::endl;
		if ( the_rms < closest_rms ){
			// endl added: previously successive messages ran together on one line
			TR << "Found better: " << the_rms << std::endl;
			closest_rms = the_rms;
			closest_struct = jt;
		}
	}

	core::Real rms_time = start_timer( TIMING_CPU );
	TR << "RMS_Time: " << rms_time << std::endl; 

	if ( closest_rms < rms_limit_ ){
		// replace if lower in energy
		core::Real energy_old = (*closest_struct)->get_energy("score");
		if( new_struct_score < energy_old ){
			// Replace the owning pointer in the slot. The old code assigned pss through a
			// SilentStruct& ("*(*closest_struct) = pss"), which slices the object: only the
			// base-class part was copied and the ProteinSilentStruct data stayed stale.
			core::io::silent::SilentStructOP replacement = new core::io::silent::ProteinSilentStruct( pss );
			*closest_struct = replacement;
			return true;
		}
	}else{
		// no library member is close enough: add as a new member
		library_central_.add( pss );
		library_central_.sort_by();
		return true;
	}
	return false;	
}



/// @brief "single_replace" feedback: possibly replace the library entry that has the same "ssid".
/// @param pss Incoming structure; must carry an "ssid" present in the library (exits otherwise).
/// @return true if the matching entry was replaced.
/// @details The incoming structure is eligible if it comes from a later "round" or has a lower
/// "score". An eligible structure is then accepted by a Metropolis-style test on the score
/// difference, and the metropolis counters are updated.
bool
MPI_LoopHashRefine::add_structure_to_library_single_replace( core::io::silent::ProteinSilentStruct &pss )
{
	core::Real const wanted_ssid = pss.get_energy("ssid");

	// Locate the library entry carrying the same ssid.
	SilentStructStore::iterator slot =
		std::find_if( library_central_.begin(), library_central_.end(), find_SilentStructOPs( "ssid", wanted_ssid ) );

	if( slot == library_central_.end() ){
		std::string const msg = "FATAL ERROR: Cannot find ssid: " + ObjexxFCL::string_of( wanted_ssid );
		TR << msg << std::endl;
		utility_exit_with_message( msg );
		return false;
	}

	// Eligible only if from a more advanced round, or (otherwise) lower in energy.
	bool const eligible =
		( pss.get_energy("round") > (*slot)->get_energy("round") ) ||
		( pss.get_energy("score") < (*slot)->get_energy("score") );
	if( !eligible ) return false;

	core::io::silent::SilentStructOP candidate = new core::io::silent::ProteinSilentStruct( pss );
	core::Real const new_energy = candidate->get_energy("score");
	core::Real const old_energy = (*slot)->get_energy("score");

	// NOTE(review): mpi_metropolis_temp_ only switches the test on/off here; the energy
	// difference is not divided by it, despite the "_T" name. Confirm this is intended.
	core::Real const energy_diff_T = ( mpi_metropolis_temp_ > 0.0 ) ? ( old_energy - new_energy ) : 0.0;

	// Downhill (or test disabled) always accepts; uphill by less than 10 units accepts with
	// probability exp(energy_diff_T); anything further uphill is rejected outright.
	bool accept = ( energy_diff_T >= 0.0 );
	if( !accept && energy_diff_T > (-10.0) ){
		accept = ( RG.uniform() < exp( energy_diff_T ) );
	}

	++total_metropolis_;
	if( accept ) ++total_metropolis_accepts_;
	TR << ( accept ? "ReplacingACC: " : "ReplacingREJ: " ) << format_silent_struct( *slot ) << " with " << format_silent_struct( candidate ) << "  " << energy_diff_T << std::endl;
	if( !accept ) return false;

	*slot = candidate;
	return true;
}







void
MPI_LoopHashRefine::print_library(){
	TR << "-----<ssid>  <score> <censc>  <rms>  <round>  <time>  <master>  <lhcount> ----" << std::endl;
	for( SilentStructStore::const_iterator it = library_central_.begin(); it != library_central_.end(); ++ it ){
		TR << "LIB: " << format_silent_struct( *it ) << std::endl;
	}
	TR << "-< Total size: " << library_central_.store().size()  << " >---------------------------------" << std::endl;
}


/// @brief Offer every structure in new_structs to the library, then trim and log the library.
/// @param new_structs Incoming structures; each must be a ProteinSilentStruct (asserted).
/// @return true if at least one structure changed the library.
bool
MPI_LoopHashRefine::add_structures_to_library( SilentStructStore &new_structs ){
	bool any_accepted = false;

	for( SilentStructStore::const_iterator cur = new_structs.begin(); cur != new_structs.end(); ++cur ){
		runtime_assert( *cur );
		TR << "Add structure... " << format_silent_struct( *cur ) << std::endl; 
		core::io::silent::ProteinSilentStruct *protein_ss =
			dynamic_cast< core::io::silent::ProteinSilentStruct* >( &(*(*cur)) );
		runtime_assert( protein_ss );
		// Always make the call (no short-circuit): every structure must be offered.
		if( add_structure_to_library( *protein_ss ) ) any_accepted = true;
	}

	// Enforce the size cap once for the whole batch, then log the resulting library.
	limit_library();
	print_library();

	return any_accepted;
}


/// @brief Sort the central library and drop entries from the back until it holds at most max_lib_size_.
void 
MPI_LoopHashRefine::limit_library(){
	library_central_.sort_by();
	// After sorting, the tail holds the entries to discard.
	for( ; max_lib_size_ < library_central_.size(); ) {
		library_central_.store().pop_back();
	}
}

void
MPI_LoopHashRefine::dump_structures( const SilentStructStore &new_structs, bool score_only ) const {
	start_timer( TIMING_IO_WRITE  );
	core::io::silent::SilentFileData sfd;
	std::string filename = jobname_ + "." + string_of( mpi_rank() ) + ".out";
	for( SilentStructStore::const_iterator it = new_structs.begin();
		 it != new_structs.end(); ++it )
	{
		(*it)->print_score_header( TR );
		sfd.write_silent_struct( *(*it), filename, score_only );
	}
	core::Real write_time = start_timer( TIMING_CPU );
	TR << "Write time: " << write_time << std::endl;
}

/// @brief Send a copy of a random library structure to dest_rank as a "resultpack" work unit.
/// @param dest_rank MPI rank to send to.
/// @param newssid Value written into the copy's "ssid" energy (the library entry itself is untouched).
/// @details Logs an error and sends nothing if the library is empty.
void
MPI_LoopHashRefine::send_random_library_struct( core::Size dest_rank, core::Size newssid ) const {
	if( library_central_.size() == 0 ){
		// Routed through the tracer (like other errors in this file) and typo "Havce" fixed.
		TR.Error << "ERROR: Have no structure to send" << std::endl;
		return;
	}
	
	// now fabricate a return WU
	WorkUnit_SilentStructStoreOP resultpack = new WorkUnit_SilentStructStore( );
	resultpack->set_wu_type( "resultpack" );
	core::io::silent::SilentStructOP new_struct = library_central_.get_struct_random()->clone();
	new_struct->add_energy("ssid", newssid);   // overwrite the ssid on the copy only
	resultpack->decoys().add( new_struct ); 
	send_MPI_workunit( resultpack, dest_rank );
}



/// @brief Render a one-line bracketed summary of a silent struct's bookkeeping energies.
/// @details Columns: ssid | score | censcore | rms | round | seconds since "ltime" | master | lhcount.
std::string format_silent_struct( const core::io::silent::SilentStructOP &ss ){
	std::stringstream line;
	line << "["  << I(4,   ss->get_energy("ssid") );
	line << " |" << F(8,1, ss->get_energy("score"));
	line << " |" << F(8,1, ss->get_energy("censcore"));
	line << " |" << F(5,1, ss->get_energy("rms") );
	line << " |" << I(3,   ss->get_energy("round") );
	line << " |" << I(5,   time(NULL) - ss->get_energy("ltime") );
	line << " |" << I(3,   ss->get_energy("master") );
	line << " |" << I(1,   ss->get_energy("lhcount")) << "]";
	return line.str();
}




} // namespace loophash
} // namespace protocols


