// -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
// vi: set ts=2 noet:
//
// (c) Copyright Rosetta Commons Member Institutions.
// (c) This file is part of the Rosetta software suite and is made available under license.
// (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
// (c) For more information, see http://www.rosettacommons.org. Questions about this can be
// (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.

/// @file
/// @brief
/// @author Srivatsan Raman
/// @author Frank DiMaio

#include <protocols/jobdist/JobDistributors.hh>
#include <protocols/jobdist/Jobs.hh>
#include <core/types.hh>

#include <core/kinematics/Jump.hh>
#include <core/pose/Pose.hh>
#include <core/pose/util.hh>
#include <core/util/Tracer.hh>
#include <core/io/pdb/pose_io.hh>
#include <core/scoring/ScoreFunction.hh>
#include <core/scoring/ScoreFunctionFactory.hh>
#include <core/chemical/util.hh>
#include <protocols/loops/loops_main.hh>
#include <core/scoring/Energies.hh>

#include <core/kinematics/FoldTree.hh>
#include <protocols/RBSegmentMoves/AutoRBRelaxMover.hh>
#include <protocols/RBSegmentMoves/RBSegmentRelax.hh>
#include <protocols/RBSegmentMoves/RBSegmentMover.hh>
#include <protocols/RBSegmentMoves/RBSegment.hh>
#include <protocols/RBSegmentMoves/util.hh>
#include <protocols/loops/loops_main.hh>
#include <protocols/loops/ccd_closure.hh>
#include <protocols/moves/MonteCarlo.hh>
#include <protocols/moves/MoverContainer.hh>
#include <protocols/moves/RigidBodyMover.hh>
#include <protocols/moves/PackRotamersMover.hh>
#include <protocols/moves/SwitchResidueTypeSetMover.hh>
#include <protocols/electron_density/util.hh>
#include <protocols/relax/FastRelax.hh>

#include <core/scoring/electron_density/util.hh>
#include <core/scoring/constraints/util.hh>

#include <core/pack/task/PackerTask.hh>
#include <core/pack/task/TaskFactory.hh>
#include <core/pack/task/operation/TaskOperations.hh>
#include <core/pack/task/operation/NoRepackDisulfides.hh>
#include <core/pack/task/operation/OperateOnCertainResidues.hh>
#include <core/optimization/AtomTreeMinimizer.hh>
#include <core/optimization/MinimizerOptions.hh>


#include <protocols/viewer/viewers.hh>
#include <protocols/jumping/Dssp.hh>

#include <core/io/silent/SilentStructFactory.hh>


#include <core/chemical/ChemicalManager.fwd.hh>
#include <core/conformation/Residue.hh>
#include <core/io/silent/SilentStruct.hh>
#include <core/options/option.hh>

#include <protocols/RBSegmentMoves/RBSegment.hh>
#include <protocols/moves/SwitchResidueTypeSetMover.hh>
#include <core/fragment/FragSet.hh>
#include <protocols/abinitio/FragmentMover.hh>
#include <numeric/random/random.hh>

// C++ headers
#include <fstream>
#include <iostream>
#include <string>


//options
#include <core/options/keys/edensity.OptionKeys.gen.hh>
#include <core/options/keys/out.OptionKeys.gen.hh>
#include <core/options/keys/RBSegmentRelax.OptionKeys.gen.hh>
#include <core/options/keys/in.OptionKeys.gen.hh>
#include <core/options/keys/loops.OptionKeys.gen.hh>

using core::util::T;
using core::util::Error;
using core::util::Warning;

namespace protocols {
namespace RBSegment {

// File-local tracer: internal linkage so it cannot collide with another
// translation unit's `tr` in namespace protocols::RBSegment.
static core::util::Tracer tr("protocols.RBSegment.AutoRBRelaxMover");

////////////////
// stupid mover
class CCDMoveWrapper : public protocols::moves::Mover {
public:
	CCDMoveWrapper( core::kinematics::MoveMapOP movemap, core::Size start, core::Size stop, core::Size cut) :
		movemap_(movemap),
		start_(start),
		stop_(stop),
		cut_(cut) { }

	void apply( Pose & pose ) {
		protocols::loops::ccd_moves(25, pose, *movemap_, start_, stop_, cut_ );
	}

	virtual std::string get_name() const {
		return ("CCDMoveWrapper");
	}

private:
	core::kinematics::MoveMapOP movemap_;
	core::Size start_, stop_, cut_;
};

typedef utility::pointer::owning_ptr< CCDMoveWrapper >  CCDMoveWrapperOP;



///////////////////
///
AutoRBMover::AutoRBMover() {
	using namespace core::options;

	scorefxn_ = core::scoring::ScoreFunctionFactory::create_score_function(
		option[ OptionKeys::RBSegmentRelax::rb_scorefxn ]() );

	movemap_ = new core::kinematics::MoveMap();

	nouter_cycles_ = option[ OptionKeys::RBSegmentRelax::nrboutercycles ]();
	ninner_cycles_ = option[ OptionKeys::RBSegmentRelax::nrbmoves ]();

	loop_melt_ = 3;

	// load frags
	loops::read_loop_fragments( frag_libs_ );

	allowSeqShiftMoves_ = !(option[ OptionKeys::RBSegmentRelax::skip_seqshift_moves]());

	allowSSFragInserts_ = false;

	// fa stuff
	tf_ = new core::pack::task::TaskFactory();
	tf_->push_back( new core::pack::task::operation::RestrictToRepacking );
	tf_->push_back( new core::pack::task::operation::InitializeFromCommandline );
	tf_->push_back( new core::pack::task::operation::IncludeCurrent );
	tf_->push_back( new core::pack::task::operation::NoRepackDisulfides );

	fa_scorefxn_ = core::scoring::getScoreFunction();
	fa_scorefxn_->set_weight( core::scoring::chainbreak, 10.0/3.0);

	if ( option[ OptionKeys::edensity::mapfile ].user() ) {
		core::scoring::electron_density::add_dens_scores_from_cmdline_to_scorefxn( *scorefxn_ );
		core::scoring::electron_density::add_dens_scores_from_cmdline_to_scorefxn( *fa_scorefxn_ );
	}
}

void
AutoRBMover::apply( core::pose::Pose & pose ) {
	// ensure pose is in centroid mode
	protocols::moves::SwitchResidueTypeSetMover to_centroid("centroid");
	to_centroid.apply( pose );

	// ensure pose is rooted on VRT
	core::pose::addVirtualResAsRoot( pose );

	// Get DSSP parse; use it to:
	//    build fold tree (star topology)
	//    set movemap
	//    set variant types
	setup_topology( pose );

	// load constraints
	core::scoring::constraints::add_constraints_from_cmdline( pose, *scorefxn_ );

	core::Size nres =  pose.total_residue()-1;
	bool loops_closed = false;
	grow_flexible( loop_melt_, nres, 1 );

	core::pose::Pose start_pose = pose;

	while( !loops_closed ) {
		pose = start_pose;

		// setup fragment movers
		utility::vector1< protocols::abinitio::FragmentMoverOP > fragmover;
		for ( utility::vector1< core::fragment::FragSetOP >::const_iterator
					it = frag_libs_.begin(), it_end = frag_libs_.end();
					it != it_end; it++ ) {
			protocols::abinitio::ClassicFragmentMover *cfm = new protocols::abinitio::ClassicFragmentMover( *it, movemap_ );
			cfm->set_check_ss( false );
			cfm->enable_end_bias_check( false );
			fragmover.push_back( cfm );
		}

		// mc object
		core::Real init_temp = 2.0;
		core::Real temperature = init_temp;
		core::Real final_temp = 0.6;
		core::Real const gamma = std::pow(final_temp/init_temp, (1.0/(nouter_cycles_*ninner_cycles_)) );
		moves::MonteCarloOP mc_ = new moves::MonteCarlo( pose, *scorefxn_, init_temp );

		// movement loop
		float final_chain_break_weight = 1.0;
		float delta_weight( final_chain_break_weight/nouter_cycles_ );

		// random mover
		protocols::moves::RandomMover random_move;

		// loop fragment insertion
		for ( std::vector< protocols::abinitio::FragmentMoverOP >::const_iterator
				it = fragmover.begin(),it_end = fragmover.end(); it != it_end; it++ )
			random_move.add_mover(*it, rb_chunks_.size());

		// rigid-body move
		for (int i=1; i<=(int)rb_chunks_.size(); ++i)
			random_move.add_mover(new protocols::moves::RigidBodyPerturbMover( i , 3.0 , 1.0 ));

		//TODO rigid-chunk fragment insertion
		if (allowSSFragInserts_) ;

		// sequence shift
		if (allowSeqShiftMoves_) {
			for (int i=1; i<=(int)rb_chunks_.size(); ++i)
			for (int j=1; j<=(int)rb_chunks_[i].nContinuousSegments(); ++j) {
				protocols::moves::SequenceMoverOP seq_shift_move = new protocols::moves::SequenceMover;
				seq_shift_move->add_mover( new SequenceShiftMover(rb_chunks_[i][j]) );

				// find adjacent loops
				for (core::Size k=1; k<=loops_.size(); ++k) {
					bool adjLoopN = (loops_[k].stop() >= rb_chunks_[i][j].start()-1) && (loops_[k].stop() <= rb_chunks_[i][j].end()+1);
					bool adjLoopC = (loops_[k].start() >= rb_chunks_[i][j].start()-1) && (loops_[k].start() <= rb_chunks_[i][j].end()+1);
					if ( adjLoopN || adjLoopC ) {
						seq_shift_move->add_mover( new CCDMoveWrapper(movemap_, loops_[k].start(), loops_[k].stop(), loops_[k].cut() ) );
					}
				}
				random_move.add_mover(seq_shift_move, 0.5);
			}
		}

		scorefxn_->set_weight( core::scoring::chainbreak, 0.0 );
		for( int n1 = 1; n1 <= (int)nouter_cycles_; ++n1 ) {
			mc_->recover_low( pose );
			scorefxn_->set_weight( core::scoring::chainbreak, n1*delta_weight );

			(*scorefxn_)(pose);
			if ( tr.visible() ) { scorefxn_->show_line( tr.Info , pose ); }
			tr.Info << std::endl;
			mc_->score_function( *scorefxn_ );

			for( int n2 = 1; n2 <= (int)ninner_cycles_; ++n2 ) {
				// cool
				temperature *= gamma;
				mc_->set_temperature( temperature );

				// randomly do something
				if( numeric::random::uniform()*nouter_cycles_ > n1 ) {
					random_move.apply(pose);
				} else {
					protocols::loops::Loops::const_iterator it( loops_.one_random_loop() );
					protocols::loops::ccd_moves(5, pose, *movemap_, it->start(), it->stop(), it->cut() );
				}

				mc_->boltzmann( pose );
			}
		}
		mc_->recover_low( pose );
		mc_->show_counters();

		scorefxn_->set_weight( core::scoring::chainbreak, 1.0 );
		(*scorefxn_)(pose);

		loops_closed = ( pose.energies().total_energies()[ core::scoring::chainbreak ] ) <= loops_.size()*0.5;
		if (!loops_closed) {
			tr << "Loops not closed! ("
			   << pose.energies().total_energies()[ core::scoring::chainbreak ]
			   << " > " << loops_.size()*0.5 << ")" << std::endl;
			grow_flexible( loop_melt_, nres );
		} else {
			tr << "Loops closed! ("
			   << pose.energies().total_energies()[ core::scoring::chainbreak ]
			   << " <= " << loops_.size()*0.5 << ")" << std::endl;
		}
	}

	//////////////
	// fastrelax -- keep foldtree
	relax::FastRelax fast_relax( fa_scorefxn_ );
	fast_relax.apply( pose );
}

////
//// grow loops
//// dont allow jump residues to be flexible
void
AutoRBMover::grow_flexible( core::Size maxlen , core::Size nres , core::Size minlen ) {
	if (maxlen == 0) return;

	tr << "EXTENDING LOOPS:" << std::endl;
	for ( core::Size i=1; i <= loops_.size(); i++ ) {
		core::Size extend_start = (core::Size) numeric::random::random_range(minlen, maxlen-minlen);
		core::Size extend_stop  = (core::Size) numeric::random::random_range(minlen, maxlen-minlen);
		if ( ( extend_start == 0 ) && ( extend_stop == 0 ) ) {
			if ( numeric::random::uniform() > 0.5) extend_start = 1;
			else extend_stop  = 1;
		}

		// dont go past termini
		if (loops_[i].start()  < 1 + extend_start) extend_start = loops_[i].start()-1;
		if (loops_[i].stop() + extend_stop > nres ) extend_stop = nres - loops_[i].stop();

		// make sure we dont go over a jump point
		for (core::Size j=1; j<=jumps_.size(); ++j) {
			if (loops_[i].start()-extend_start <= jumps_[j] && loops_[i].start() > jumps_[j])
				extend_start = loops_[i].start() - jumps_[j] + 1;
			if (loops_[i].stop()+extend_stop >= jumps_[j] && loops_[i].stop() < jumps_[j])
				extend_stop = jumps_[j] - loops_[i].stop() +1;
		}

		loops_[i].set_start( loops_[i].start()-extend_start );
		loops_[i].set_stop( loops_[i].stop()+extend_stop );
	}
	tr << loops_ << std::endl;
}

////
//// set up foldtree, variants, movemap, etc.
void
AutoRBMover::setup_topology( core::pose::Pose & pose ) {
	core::Size nres = pose.total_residue()-1; // terminal VRT

	rigid_segs_.clear();
	rb_chunks_.clear();
	jumps_.clear();
	loops_.clear();

	// dssp parse
	protocols::jumping::Dssp secstruct( pose );
	ObjexxFCL::FArray1D< char > dssp_pose( nres );
	secstruct.dssp_reduced (dssp_pose);
	secstruct.insert_ss_into_pose( pose );

	// find all helices > 5 residues
	//          strands > 3 residues
	utility::vector1< RBSegment > simple_segments;
	bool in_helix=false, in_strand=false;
	int ss_start = -1;
	for (int i=1; i<=(int)nres; ++i) {
		// strand end
		if (dssp_pose(i) != 'E' && in_strand) {
			in_strand = false;
			if (i-ss_start >= 3)
				simple_segments.push_back( RBSegment( ss_start, i-1, 'E' ) );
		}
		// helix end
		if (dssp_pose(i) != 'H' && in_helix) {
			in_helix = false;
			if (i-ss_start >= 5)
				simple_segments.push_back( RBSegment( ss_start, i-1, 'H' ) );
		}
		// strand start
		if (dssp_pose(i) == 'E' && !in_strand)  {
			in_strand = true;
			ss_start = i;
		}
		// helix start
		if (dssp_pose(i) == 'H' && !in_helix)  {
			in_helix = true;
			ss_start = i;
		}
	}

	// put at least 2 "loop residues" inbetween each RB segment.
	// always eat into the helix (even if this leaves a helix < 3 reses)
	for (int i=1; i< (int)simple_segments.size(); ++i) {
		if (simple_segments[i+1][1].start() - simple_segments[i][1].end() <= 2) {
			if (simple_segments[i][1].char_type() == 'H') {
				simple_segments[i][1].set_end( simple_segments[i+1][1].start()-3 );
			} else if (simple_segments[i+1][1].char_type() == 'H') {
				simple_segments[i+1][1].set_start( simple_segments[i][1].end()+3 );
			} else {
				// eat into longer strand (will only need to chomp 1 res)
				if (simple_segments[i+1][1].length() > simple_segments[i][1].length() )
					simple_segments[i+1][1].set_start( simple_segments[i][1].end()+3 );
				else
					simple_segments[i][1].set_end( simple_segments[i+1][1].start()-3 );
			}
		}
	}

	// check for "1-residue" loops at termini; extend RB segments if necessary
	if (simple_segments[1][1].start() == 2) simple_segments[1][1].set_start(1);
	if (simple_segments[simple_segments.size()][1].end() == nres-1) simple_segments[simple_segments.size()][1].set_end(nres);


	// auto-gen loops
	if ( simple_segments.size() > 1 ) {
		std::sort( simple_segments.begin(), simple_segments.end(), RB_lt());
		int start_res=1, end_res=simple_segments[1][1].start()-1;
		int cutpt = (start_res+end_res)/2;
		int nsegs = simple_segments.size();

		if (end_res >= start_res)
			loops_.push_back( protocols::loops::Loop(start_res, end_res, 0, 0.0, false) );
		for (int i=1; i<nsegs; ++i) {
			start_res = simple_segments[i][1].end()+1;
			end_res   = simple_segments[i+1][1].start()-1;
			loops_.push_back( protocols::loops::Loop(start_res, end_res, 0, 0.0, false) );
		}
		start_res = simple_segments[nsegs][1].end()+1;
		end_res   = nres;
		cutpt = (start_res+end_res)/2;
		if (end_res >= start_res)
			loops_.push_back( protocols::loops::Loop(start_res, end_res, 0, 0.0, false) );

		// TODO: split loops on cutpoints from original pose
	}


	// now combine paired strands into a compound segment
	//   look for NH--O distance < 2.6A (?)
	utility::vector1< utility::vector1< int > > compound;
	utility::vector1< core::Size > parent_seg( simple_segments.size(), 0 );
	for (int i=1; i< (int)simple_segments.size(); ++i) {
		if (simple_segments[i][1].char_type() != 'E') continue;

		utility::vector1< int > this_seg( 1, i );
		for (int j=i+1; j<=(int)simple_segments.size(); ++j) {
			if (simple_segments[j][1].char_type() != 'E') continue;

			// foreach res in i,j
			bool found = false;
			for (int ii=(int)simple_segments[i][1].start(); ii<=(int)simple_segments[i][1].end() && !found; ++ii)
			for (int jj=(int)simple_segments[j][1].start(); jj<=(int)simple_segments[j][1].end() && !found; ++jj) {
				core::Real d2=10;

				if (pose.residue(ii).aa() != core::chemical::aa_pro && pose.residue(jj).aa() != core::chemical::aa_pro)
					d2 = std::min(
						(pose.residue(ii).atom("H").xyz() - pose.residue(jj).atom("O").xyz()).length_squared() ,
						(pose.residue(ii).atom("O").xyz() - pose.residue(jj).atom("H").xyz()).length_squared() );
				else if (pose.residue(jj).aa() != core::chemical::aa_pro)
					d2 = (pose.residue(ii).atom("O").xyz() - pose.residue(jj).atom("H").xyz()).length_squared();
				else if (pose.residue(ii).aa() != core::chemical::aa_pro)
					d2 = (pose.residue(jj).atom("O").xyz() - pose.residue(ii).atom("H").xyz()).length_squared();

				if (d2 < 2.6*2.6) {
					this_seg.push_back(j);
					if (parent_seg[i] == 0)
						parent_seg[i] = i;
					if (parent_seg[j] == 0)
						parent_seg[j] = parent_seg[i];
					else {
						// tricky case ... j is already mapped
						// in this case map everything mapped to i to parent_seg[j]
						for (int k=1; k<j; ++k)
							if ((int)parent_seg[k] == i) parent_seg[k] = parent_seg[j];
					}

					tr << "Merging " << j << " (" << simple_segments[j][1].start() << "," << simple_segments[j][1].end() << ") ";
					tr << " to " << i << " (" <<  simple_segments[i][1].start() << "," << simple_segments[i][1].end() << ") " << std::endl;
					found = true;
				}
			}
		}
	}

	// make the compound segments
	for (int i=1; i< (int)simple_segments.size(); ++i) {
		if (((int)parent_seg[i]) != i) continue; // not compound or already added
		utility::vector1< RBSegment > thisLockSeg;
		for (core::Size j=i; j<=simple_segments.size(); ++j) {
			if ( ((int)parent_seg[j]) == i )
				thisLockSeg.push_back( simple_segments[ j ] );
		}
		rigid_segs_.push_back( RBSegment( thisLockSeg ) );
	}
	// add in all other simple segs
	for (int i=1; i<=(int)simple_segments.size(); ++i)
		if (parent_seg[i] == 0)
			rigid_segs_.push_back( simple_segments[i] );

	// sort loops & rbsegs, choose cutpoints
	std::sort( rigid_segs_.begin(), rigid_segs_.end(), RB_lt());
	std::sort( loops_.v_begin(), loops_.v_end(), protocols::loops::Loop_lt());
	loops_.auto_choose_cutpoints( pose );

	// define chunks
	rb_chunks_ = rigid_segs_;
	for (int i=1; i<=(int)rigid_segs_.size(); ++i) {
		for (int j=1; j<=(int)rigid_segs_[i].nContinuousSegments() ; ++j ) {
			core::Size c_low=1, c_high=nres;
			for (int k=1; k<=(int)loops_.size(); ++k) {
				if (loops_[k].cut() < rigid_segs_[i][j].start() && loops_[k].cut() > c_low )
					c_low = loops_[k].cut();
				if (loops_[k].cut()+1 > rigid_segs_[i][j].end() && loops_[k].cut()+1 < c_high )
					c_high = loops_[k].cut()+1;
			}
			rb_chunks_[i][j].set_start(c_low);
			rb_chunks_[i][j].set_end(c_high);
		}
	}

	// call help function to set fold tree
	jumps_ = setup_pose_rbsegs_keep_loops( pose,  rigid_segs_ , loops_,  movemap_ );
}

}
}

