// -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
// vi: set ts=2 noet:
//  CVS information:
//  $Revision: 1.1.2.1 $
//  $Date: 2005/11/07 21:05:35 $
//  $Author: rhiju $


// Rosetta Headers
#include "pose_rna_csa.h"
#include "after_opts.h"
#include "files_paths.h"
#include "force_barcode.h"
#include "pose.h"
#include "pose_io.h"
#include "pose_rms.h"
#include "pose_rna.h"
#include "pose_rna_base_doublet_classes.h"
#include "pose_rna_ns.h"
#include "pose_rna_fragments.h"
#include "pose_rna_jumping.h"
#include "silent_input.h"
#include "trajectory.h" //For graphics.

// ObjexxFCL Headers
#include <ObjexxFCL/ObjexxFCL.hh>
#include <ObjexxFCL/string.functions.hh>

// Numeric Headers
#include <numeric/xyz.functions.hh> // for dihedral()

// Utility Headers
#include <utility/basic_sys_util.hh>
#include <utility/io/ozstream.hh>

// C++ Headers
#include <algorithm>
#include <cstdio>
#include <vector>

// BOINC
#ifdef BOINC
#include "boinc_rosetta_util.h"
#include "counters.h"
#include "trajectory.h"
#endif

using namespace pose_ns;

// Bank of poses, held by pointer (the poses in this file are allocated with new).
typedef std::vector < Pose* >  PoseList;
// One float per pose, keyed by pose pointer; used below for the distances
// from a single pose to the members of the bank.
typedef std::map < Pose*, float > PoseFloat;
// One int per pose, keyed by pose pointer; used below to record the order in
// which poses were chosen during a round (-1 marks "not chosen yet").
typedef std::map < Pose*, int > PoseInt;
// Pose-by-pose table of floats; used below for the pairwise distances
// between members of the bank.
typedef std::map < Pose*, PoseFloat > PosePoseFloat;

//////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////
// Pretty simple stand-alone protocol for conformational space annealing
// for RNA. No communication between computers, just a single-processor
// run with a bank of poses.
//////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////
// Runs a batch of fragment-insertion Monte Carlo trials on `pose`, then
// overwrites `pose` with whatever mc.low_pose() returns (presumably the
// lowest-scoring conformation seen -- confirm against Monte_carlo).
//
//   pose      -- conformation to perturb; replaced on return.
//   frag_size -- fragment length forwarded to RNA_move_trial (default 3).
//
// The number of trials comes from the "cycles" option (default 100) and is
// read only once, on the first call.
void
do_some_fragment_insertions( Pose & pose, int const frag_size = 3 )
{
  // A single weight map is built on the first call and reused afterwards.
  static const Score_weight_map rna_weights = setup_rna_weight_map();

  static float const start_temperature( 2.0 );
  Monte_carlo mc( pose, rna_weights, start_temperature );
  mc.set_autotemp( true, start_temperature );

  static const int num_trials = intafteroption( "cycles", 100 );

  // These two could become inputs that vary from round to round.
  static bool const use_smooth_moves( false );
  static std::string const move_type = "frag";

  for ( int done = 0; done < num_trials; ++done ) {
    RNA_move_trial( pose, mc, use_smooth_moves, move_type, frag_size );
  }

  pose = mc.low_pose();
}

//////////////////////////////////////////////////////////////////////////
// Fills `pose_bank` with NUM_BANK freshly allocated starting poses.
//
// Each pose starts as a copy of the query pose, optionally receives random
// base pairings (when "pairing_file" or "jumps_from_barcode" is set), is
// scrambled with 5 * total_residue random fragment insertions, and is then
// passed through do_some_fragment_insertions().
//
//   pose_bank -- receives the new Pose pointers (appended with push_back).
//   NUM_BANK  -- number of poses to create.
//
// If the "n" option is given, the named PDB file is loaded as the native
// pose and attached to every bank member; a failed load terminates the run.
void
setup_rna_csa_initial_bank( PoseList & pose_bank, int const NUM_BANK )
{
  Pose start_pose;
  initialize_query_pose_rna( start_pose );
  read_rna_fragments();

  // OK, will put this in later.
  //  bool const user_input = truefalseoption("s");
  //  Silent_file_data decoys;
  //  if (user_input){
  //    std::string silent_file_name = stringafteroption("s");
  //    decoys.read_file( silent_file_name, true /* fullatom */);
  //  }

  // The native pose object is allocated unconditionally, but it is only
  // filled in when a native PDB was requested.
  // NOTE(review): native_pose_p is never deleted in this function. Whether
  // it could be freed here depends on whether set_native_pose() copies its
  // argument -- confirm before changing.
  Pose * native_pose_p = new Pose;
  bool const native_exists = truefalseoption("n");

  if ( native_exists ) {
    std::string native_file_name = stringafteroption("n","blah.pdb");

    if ( !pose_from_pdb( *native_pose_p, native_file_name, true, false, true ) ) {
      std::cout << "Had trouble with native pdb " << native_file_name << std::endl;
      utility::exit( EXIT_FAILURE, __FILE__, __LINE__);
    }

    pose_to_native( *native_pose_p );
  }

  barcode_initialize_start( start_pose.total_residue() );
  barcode_initialize_decoy();

  bool const jumping = truefalseoption("pairing_file") || truefalseoption("jumps_from_barcode");

  for ( int made = 0; made < NUM_BANK; ++made ) {

    Pose * pose_p = new Pose;
    *pose_p = start_pose;

    if ( jumping ) pose_from_random_base_pairings( *pose_p );
    if ( native_exists ) pose_p->set_native_pose( *native_pose_p );

    // Heat 'er up: randomize the pose before the Monte Carlo pass.
    int const HEAT_CYCLES = 5 * pose_p->total_residue();
    int const frag_size = 3;
    for ( int heat = 0; heat < HEAT_CYCLES; ++heat ) {
      random_fragment_insertions( *pose_p, frag_size, 1 );
    }

    do_some_fragment_insertions( *pose_p );

    pose_bank.push_back( pose_p );
  }
}


//////////////////////////////////////////////////////////////////////////
// Returns the mean of the stored pairwise distances over all unordered
// pairs of poses in the bank.
//
//   pose_bank         -- poses to average over.
//   pairwise_distance -- distance table, read as [pose m][pose n] with m
//                        earlier in the bank than n.
//
// A bank with fewer than two poses has no pairs; 0 is returned in that case
// instead of evaluating 0/0.
float
get_average_distance( PoseList & pose_bank, PosePoseFloat & pairwise_distance)
{

  float sum_distance( 0.0 );
  int count( 0 );

  int const num_bank = pose_bank.size();

  for (int m = 0; m < num_bank; m++ ){
    for (int n = m+1; n < num_bank; n++ ){

      sum_distance += pairwise_distance[ pose_bank[m]] [ pose_bank[n] ];
      count++;

    }
  }

  // No pairs to average: avoid dividing zero by zero (which would yield NaN).
  if ( count == 0 ) return 0.0;

  float const average_start_distance = sum_distance/count;

  return average_start_distance;
}


//////////////////////////////////////////////////////////////////////////
// Symmetric base-pair distance between two poses: the count reported by
// num_unshared_basepairs( pose1, pose2 ) plus the count reported with the
// arguments swapped.
float
base_pair_distance( Pose & pose1, Pose & pose2 ){
	float total = 0.0;
	total += num_unshared_basepairs( pose1, pose2 );
	total += num_unshared_basepairs( pose2, pose1 );
	return total;
}

//////////////////////////////////////////////////////////////////////////
// Distance measure used to compare two poses in the bank.
//
// Both CA_rmsd() and base_pair_distance() are always evaluated; the options
// (each read once, on the first call) select which value is returned:
//   "csa_use_rmsd_only" -- return the rmsd;
//   "csa_no_rmsd"       -- return the base-pair distance;
//   neither             -- return the smaller of the two.
// "csa_use_rmsd_only" wins when both options are set.
float
get_distance( Pose & pose1, Pose & pose2 )
{
	static bool const use_rmsd_only = truefalseoption("csa_use_rmsd_only");
	static bool const csa_no_rmsd = truefalseoption("csa_no_rmsd");

	float const rmsd_part = CA_rmsd( pose1, pose2 );
	float const base_pair_part = base_pair_distance( pose1, pose2 );

	if ( use_rmsd_only ) return rmsd_part;
	if ( csa_no_rmsd ) return base_pair_part;

	return std::min( rmsd_part, base_pair_part );
}

//////////////////////////////////////////////////////////////////////////
// Computes the distance from `pose_p` to every member of the bank.
//
//   pairwise_distance_to_new_pose -- output; one entry per bank member,
//                                    keyed by that member's pointer.
//   pose_p                        -- pose to measure from.
//   pose_bank                     -- poses to measure to.
//
// If `pose_p` is itself in the bank it is skipped, so no self-distance
// entry is written.
void
get_pairwise_distance( PoseFloat & pairwise_distance_to_new_pose,
		       Pose * pose_p,
		       PoseList & pose_bank )
{
  for ( PoseList::iterator member = pose_bank.begin(); member != pose_bank.end(); ++member ) {
    Pose * const other_p = *member;
    if ( other_p != pose_p ) {
      float const separation = get_distance( *other_p, *pose_p );
      pairwise_distance_to_new_pose[ other_p ] = separation;
    }
  }
}

//////////////////////////////////////////////////////////////////////////
// Builds the full pose-by-pose distance table for the bank by calling
// get_pairwise_distance() once per bank member. Each ordered pair is
// computed separately, so every distance is evaluated twice.
//
//   pose_bank         -- poses to compare.
//   pairwise_distance -- output table; row [p] receives the distances from
//                        p to every other bank member.
void
fill_pairwise_distance(
    PoseList & pose_bank,
    PosePoseFloat & pairwise_distance
)
{
  for ( PoseList::iterator member = pose_bank.begin(); member != pose_bank.end(); ++member ) {
    PoseFloat & row = pairwise_distance[ *member ];
    get_pairwise_distance( row, *member, pose_bank );
  }
}


//////////////////////////////////////////////////////////////////////////
// Distance cutoff for CSA round `n`.
//
// The cutoff is interpolated geometrically from `average_start_distance`
// (at n = 0) to the "final_cutoff" option value (default 1.0, read once),
// which is reached at n = 0.75 * NUM_ROUNDS; for later rounds the final
// cutoff is returned unchanged.
//
//   average_start_distance -- cutoff used for round 0.
//   n                      -- current round index.
//   NUM_ROUNDS             -- total number of rounds.
float
get_cutoff( float const average_start_distance,
	    int const n,
	    int const NUM_ROUNDS)
{
  static float const FINAL_CUTOFF = realafteroption( "final_cutoff", 1.0 );

  // Past three quarters of the run the cutoff stays at its final value.
  if ( n > 0.75 * NUM_ROUNDS ) return FINAL_CUTOFF;

  double const ratio = static_cast<double>( FINAL_CUTOFF/average_start_distance );
  double const progress = static_cast<double>(n) / ( 0.75 * NUM_ROUNDS );

  return average_start_distance * std::pow( ratio, progress );
}


//////////////////////////////////////////////////////////////////////////
// Returns the pose in `pose_subbank` with the smallest SCORE value. On ties
// the earliest pose in the list wins. The list must not be empty.
Pose * get_lowest_score_decoy( PoseList & pose_subbank )
{
  PoseList::iterator candidate = pose_subbank.begin();

  Pose * best_pose_p = *candidate;
  float best_energy = best_pose_p->get_0D_score( SCORE );

  for ( ++candidate; candidate != pose_subbank.end(); ++candidate ) {
    float const energy = (*candidate)->get_0D_score( SCORE );
    if ( energy < best_energy ) {
      best_energy = energy;
      best_pose_p = *candidate;
    }
  }

  return best_pose_p;
}

//////////////////////////////////////////////////////////////////////////
// Returns the pose in `pose_subbank` with the largest SCORE value. On ties
// the earliest pose in the list wins. The list must not be empty.
Pose * get_highest_score_decoy( PoseList & pose_subbank )
{
  PoseList::iterator candidate = pose_subbank.begin();

  Pose * worst_pose_p = *candidate;
  float worst_energy = worst_pose_p->get_0D_score( SCORE );

  for ( ++candidate; candidate != pose_subbank.end(); ++candidate ) {
    float const energy = (*candidate)->get_0D_score( SCORE );
    if ( energy > worst_energy ) {
      worst_energy = energy;
      worst_pose_p = *candidate;
    }
  }

  return worst_pose_p;
}

//////////////////////////////////////////////////////////////////////////
// This currently follows the Lee/Scheraga prescription for
// going through the bank one by one.
//
// A potential gain might come from recombining this pose with
// another one in the bank -- NOT YET CODED!
//

// Picks the next bank member to perturb and returns a heap-allocated COPY
// of it; the caller owns the returned pose.
//
//   pose_bank         -- current bank.
//   pose_used         -- order in which poses were chosen this round
//                        (-1 = not chosen yet); the chosen pose is marked
//                        with the next index.
//   pairwise_distance -- pose-by-pose distance table for the bank.
//
// Selection rule: if nothing has been chosen yet, take the lowest-scoring
// pose of the whole bank. Otherwise look only at unused poses whose distance
// to the most recently chosen pose is at least half the mean of those
// distances, and take the lowest-scoring one of them.
Pose*
choose_next_decoy(
     PoseList & pose_bank,
     PoseInt & pose_used,
     PosePoseFloat & pairwise_distance )
{
  // Find the pose that was chosen most recently (largest index).
  int last_used = -1;
  Pose * last_pose_p = pose_bank[ 0 ];

  for ( PoseList::iterator member = pose_bank.begin(); member != pose_bank.end(); ++member ) {
    int const order = pose_used[ *member ];
    if ( order > last_used ) {
      last_used = order;
      last_pose_p = *member;
    }
  }

  // Candidates from which the lowest-energy decoy is taken.
  PoseList pose_subbank;

  if ( last_used != -1 ) {

    // Following could be done more efficiently, if required.

    // Collect the unused poses together with their distance to the last
    // chosen pose, accumulating the sum of those distances on the way.
    PoseList unused_poses;
    std::vector< float > unused_distances;
    float sum_distance = 0.0;

    for ( PoseList::iterator member = pose_bank.begin(); member != pose_bank.end(); ++member ) {
      if ( pose_used[ *member ] != -1 ) continue; // already used
      float const distance = pairwise_distance[ last_pose_p ][ *member ];
      unused_poses.push_back( *member );
      unused_distances.push_back( distance );
      sum_distance += distance;
    }

    int const count = unused_poses.size();

    // Cutoff is half the mean distance to the last chosen pose.
    float const cutoff = 0.5 * (sum_distance/count);

    // Keep only the unused poses that are far enough away.
    for ( int k = 0; k < count; k++ ) {
      if ( unused_distances[ k ] >= cutoff ) {
        pose_subbank.push_back( unused_poses[ k ] );
      }
    }

    assert( !pose_subbank.empty() );

  } else {

    // No decoys chosen yet. Look at whole list.
    pose_subbank = pose_bank;

  }

  // Lowest-energy member of the candidate set.
  Pose * chosen_pose_p = get_lowest_score_decoy( pose_subbank );

  // Mark it as chosen.
  pose_used[ chosen_pose_p ] = last_used + 1;

  // Hand back a copy so the bank member itself stays untouched.
  Pose * copy_of_chosen_pose_p = new Pose;
  *copy_of_chosen_pose_p = *chosen_pose_p;

  return copy_of_chosen_pose_p;
}



//////////////////////////////////////////////////////////////////////////
// Removes `old_pose_p` from the bank and from all bookkeeping, then deletes
// the pose itself.
//
//   old_pose_p        -- pose to drop; must currently be in `pose_bank`
//                        (it is erased without checking the search result).
//                        The pointer is invalid after this call.
//   pose_bank         -- bank to remove it from.
//   pose_used         -- its usage entry is erased.
//   pairwise_distance -- its row, and its column in every remaining row,
//                        are erased.
void
remove_pose_from_bank( Pose * old_pose_p,
		       PoseList & pose_bank,
		       PoseInt & pose_used,
		       PosePoseFloat & pairwise_distance )
{
  PoseList::iterator const position = find( pose_bank.begin(), pose_bank.end(), old_pose_p );
  pose_bank.erase( position );

  pose_used.erase( old_pose_p );
  pairwise_distance.erase( old_pose_p );

  // Purge the column that referred to the removed pose.
  for ( PoseList::iterator member = pose_bank.begin(); member != pose_bank.end(); ++member ) {
    pairwise_distance[ *member ].erase( old_pose_p );
  }

  //Yer gone!
  delete old_pose_p;
}

//////////////////////////////////////////////////////////////////////////
// Appends `new_pose_p` to the bank and records its distances in the
// pairwise table, symmetrically.
//
//   new_pose_p                    -- pose to add.
//   pose_bank                     -- bank to append to.
//   pairwise_distance             -- table to update.
//   pairwise_distance_to_new_pose -- distances from the new pose to the bank
//                                    members, keyed by member pointer.
//
// The loop runs over the bank after the new pose has been appended, so the
// new pose is looked up against itself as well: a missing entry is created
// in `pairwise_distance_to_new_pose` with a value-initialized float (0) and
// stored as its self-distance.
void
add_pose_to_bank( Pose * new_pose_p,
		  PoseList & pose_bank,
		  PosePoseFloat & pairwise_distance,
		  PoseFloat & pairwise_distance_to_new_pose)
{
  pose_bank.push_back( new_pose_p );

  // References into a std::map stay valid while other keys are inserted.
  PoseFloat & row_of_new_pose = pairwise_distance[ new_pose_p ];

  for ( PoseList::iterator member = pose_bank.begin(); member != pose_bank.end(); ++member ) {
    Pose * const other_p = *member;
    float const distance = pairwise_distance_to_new_pose[ other_p ];
    row_of_new_pose[ other_p ] = distance;
    pairwise_distance[ other_p ][ new_pose_p ] = distance;
  }
}

//////////////////////////////////////////////////////////////////////////
// Replaces `old_pose_p` with `new_pose_p` in the bank.
//
// The new pose inherits the old pose's entry in `pose_used`; the old pose is
// then removed from the bank (and deleted) and the new pose is added together
// with its distances from `pairwise_distance_to_new_pose`.
void
switch_out_poses_in_bank(
     Pose * new_pose_p,
     Pose * old_pose_p,
     PoseList & pose_bank,
     PoseInt & pose_used,
     PosePoseFloat & pairwise_distance,
     PoseFloat & pairwise_distance_to_new_pose )
{
  // Read the usage mark before the old pose's entry is erased.
  int const inherited_mark = pose_used[ old_pose_p ];
  pose_used[ new_pose_p ] = inherited_mark;

  remove_pose_from_bank( old_pose_p, pose_bank, pose_used, pairwise_distance );
  add_pose_to_bank( new_pose_p, pose_bank, pairwise_distance, pairwise_distance_to_new_pose );
}

//////////////////////////////////////////////////////////////////////////
// Decides whether the trial pose `pose_p` should enter the bank.
//
//   pose_p            -- trial pose (not yet in the bank).
//   pose_bank         -- current bank; must not be empty.
//   pose_used         -- usage bookkeeping, updated on replacement.
//   pairwise_distance -- distance table, updated on replacement.
//   cutoff_distance   -- below this distance the trial pose counts as
//                        "the same" as its closest bank member.
//
// If the closest bank member is nearer than the cutoff, that member is the
// replacement candidate; otherwise the highest-scoring bank member is. The
// candidate is replaced only when the trial pose has a strictly lower SCORE.
//
// Ownership: on replacement the bank holds `pose_p` and the replaced pose is
// deleted. When no replacement happens the bank is unchanged and `pose_p` is
// NOT freed here.
void
update_bank( pose_ns::Pose* pose_p,
	     PoseList & pose_bank,
	     PoseInt & pose_used,
	     PosePoseFloat & pairwise_distance,
	     float const cutoff_distance )
{
  PoseFloat pairwise_distance_to_new_pose;
  get_pairwise_distance( pairwise_distance_to_new_pose, pose_p, pose_bank );

  // Which bank member is closest to the trial pose?
  Pose * closest_pose_p = pose_bank[ 0 ];
  float best_distance = pairwise_distance_to_new_pose[ closest_pose_p ];

  for ( PoseList::iterator member = pose_bank.begin(); member != pose_bank.end(); ++member ) {
    float const distance = pairwise_distance_to_new_pose[ *member ];
    if ( distance < best_distance ) {
      best_distance  = distance;
      closest_pose_p = *member;
    }
  }

  float const new_score = pose_p->get_0D_score( SCORE );

  // Too close to an existing member: compete with that member.
  // Otherwise: compete with the highest-energy pose in the bank.
  bool const too_close = ( best_distance < cutoff_distance );
  Pose * old_pose_p = too_close ? closest_pose_p : get_highest_score_decoy( pose_bank );
  float const bank_score = old_pose_p->get_0D_score( SCORE );

  //    std::cout << "Considering replacement " << new_score << " " << bank_score << "    pairwise distance = " << best_distance << std::endl;

  if ( new_score < bank_score ) {
    switch_out_poses_in_bank( pose_p, old_pose_p,
                              pose_bank, pose_used, pairwise_distance,
                              pairwise_distance_to_new_pose );
    //    std::cout << "Made the replacement!" << std::endl;
  }
}



//////////////////////////////////////////////////////////////////////////
// Passes every pose in the bank to put_the_final_touch_on_rna() with the
// silent-file writer `out`.
//
//   pose_bank  -- poses to output.
//   tag_prefix -- each pose is tagged "<tag_prefix>_<index>", where index is
//                 the pose's zero-based position in the bank.
//   out        -- silent-file writer handed to put_the_final_touch_on_rna().
//
// The "minimize_rna" option is read once and forwarded for every pose.
void
output_all_decoys( PoseList & pose_bank, std::string const tag_prefix, silent_io::Silent_out & out )
{
	static bool const minimize_rna = truefalseoption("minimize_rna");

	int index = 0;
	for ( PoseList::iterator member = pose_bank.begin(); member != pose_bank.end(); ++member, ++index ) {
		std::string const tag = tag_prefix + "_" + string_of( index );
		put_the_final_touch_on_rna( **member, out, tag, minimize_rna );
	}
}

//////////////////////////////////////////////////////////////////////////
// Writes every pose in the bank to the silent file `filename`.
//
//   pose_bank  -- poses to output.
//   tag_prefix -- prefix for the decoy tags (see the Silent_out overload).
//   filename   -- output file. When empty, the default
//                 "<pdb_out_prefix>.out" is used and left in place; when
//                 given, any previous copy of that file is removed first.
void
output_all_decoys( PoseList & pose_bank, std::string const tag_prefix, std::string filename )
{

  if ( filename.empty() ) {
    filename = files_paths_pdb_out_prefix_nochain()+".out";
  } else {
    // Remove previous copy. std::remove() deletes the file directly, so the
    // name is never interpreted by a shell (spaces or metacharacters are
    // safe) and no external "rm" is required. The result is ignored on
    // purpose: a file that does not exist is not an error, as with "rm -f".
    std::remove( filename.c_str() );
  }

  silent_io::Silent_out out( filename );
  output_all_decoys( pose_bank, tag_prefix, out );
  out.close();

}

//////////////////////////////////////////////////////////////////////////
// Runs one complete conformational space annealing (CSA) trajectory and
// writes the final bank to `out`.
//
//   out        -- silent-file writer that receives the final bank.
//   tag_prefix -- prefix for the decoy tags.
//
// Options read: "num_rounds" (default 100), "num_bank" (default 200),
// "start_cutoff" (default: average pairwise distance of the initial bank),
// "frag_size" (default 3), "output_each_csa_round".
//
// Every round, each bank slot is visited once: a copy of a bank member is
// perturbed by fragment insertions and offered back to the bank. The
// distance cutoff used by the bank update shrinks from round to round
// according to get_cutoff().
//
// All poses allocated during the run are released before returning.
void
rna_csa( silent_io::Silent_out & out, std::string const tag_prefix )
{

  PoseList  pose_bank;

  int const NUM_ROUNDS = intafteroption("num_rounds",100);
  int const NUM_BANK   = intafteroption("num_bank",200);

  static int const output_each_csa_round = truefalseoption( "output_each_csa_round" );

  // Setup initial bank
  setup_rna_csa_initial_bank( pose_bank, NUM_BANK );

  // Define initial distance cutoff
  PosePoseFloat pairwise_distance;
  fill_pairwise_distance( pose_bank, pairwise_distance );
  float const average_start_distance  = realafteroption( "start_cutoff", get_average_distance( pose_bank, pairwise_distance ) );

  std::cout << "Average pairwise distance in starting bank: " << average_start_distance << std::endl;

  int frag_size = intafteroption( "frag_size", 3);

  // Main loop
  for (int n = 0; n < NUM_ROUNDS; n++ ){

    float const current_distance_cutoff =
      get_cutoff( average_start_distance, n, NUM_ROUNDS);

    std::cout << " Starting ROUND " << n << " of " << NUM_ROUNDS << ". Distance cutoff: " << current_distance_cutoff << std::endl;
    files_paths::mode_title = "CSA round "+string_of(n);
    clear_trajectory();

    // Nothing has been chosen yet in this round.
    PoseInt pose_used;
    for (int i = 0; i < NUM_BANK; i++ ) {
      pose_used[ pose_bank[ i ] ] = -1;
    }

    update_frag_size( n, NUM_ROUNDS, frag_size );

    // Go through current bank, one by one
    for (int i = 0; i < NUM_BANK; i++ ){

      // Follow prescription of Lee and Scheraga for the order
      // with which to march through decoys.
      // choose_next_decoy() returns a freshly allocated copy that we own.
      Pose* pose_p = choose_next_decoy( pose_bank, pose_used, pairwise_distance );

      // Do some fragment insertions
      do_some_fragment_insertions( *pose_p, frag_size );

      // update bank.
      update_bank( pose_p, pose_bank, pose_used, pairwise_distance, current_distance_cutoff );

      // update_bank() only takes the pose over when it swaps it into the
      // bank. A rejected trial pose is still ours and must be freed here,
      // otherwise one Pose leaks per rejected trial.
      if ( std::find( pose_bank.begin(), pose_bank.end(), pose_p ) == pose_bank.end() ) {
        delete pose_p;
      }

    }

    //Done with all bank decoys -- next round!
    if (output_each_csa_round) output_all_decoys( pose_bank, tag_prefix, "round"+string_of(n)+".out" );

  }

  // Output all decoys.
  output_all_decoys( pose_bank, tag_prefix, out );

  // The bank owns its poses and goes out of scope here; release them so that
  // repeated calls (one per requested structure) do not accumulate memory.
  for ( PoseList::iterator member = pose_bank.begin(); member != pose_bank.end(); ++member ) {
    delete *member;
  }
  pose_bank.clear();

}




//////////////////////////////////////////////////////////////////////////
// Driver for the RNA CSA protocol.
//
// Opens the silent file "<pdb_out_prefix>.out" and runs rna_csa() once for
// each structure number 1..nstruct ("nstruct" option, default 1), using the
// tag prefix "S_<number>". A structure is skipped when start_decoy() reports
// that its first tag, "S_<number>_0", was already done or started.
//
// Under BOINC, bookkeeping and a checkpoint follow each run, and the
// function returns early when the checkpoint says so.
void
rna_csa_test()
{

	using namespace silent_io;

	Silent_out out( files_paths_pdb_out_prefix_nochain()+".out" );

	int nstruct = intafteroption("nstruct", 1);

	////////////////////////////////
	// MAIN LOOP Here we go!!!
	////////////////////////////////
	for ( int decoy_number = 1; decoy_number <= nstruct; decoy_number++ ) {
		std::string const tag_prefix( "S_"+string_of(decoy_number) );
		std::string const tag = tag_prefix+"_0"; //First of the CSA decoys.

		// start_decoy() returns false when the decoy is already done or started.
		if ( out.start_decoy(tag) ) {

			rna_csa( out, tag_prefix );

#ifdef BOINC
			out.append_to_list( tag ); // mark as done
			store_low_info(); // for trajectory plotting.
			clear_trajectory();
			counters::monte_carlo_ints::ntrials = 0;
			int farlx_stage = 0;
			bool ready_for_boinc_end = boinc_checkpoint_in_main_loop(decoy_number, decoy_number, nstruct, farlx_stage);
			if (ready_for_boinc_end)	return; // Go back to main, which will then go to BOINC_END and shut off BOINC.
#endif

		}

	}

}
