// -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
// vi: set ts=2 noet:
//  CVS information:
//  $Revision: 1.1.2.1 $
//  $Date: 2005/11/07 21:05:35 $
//  $Author: rhiju $


// Rosetta Headers
#include "pose_rna_featurizer.h"
#include "after_opts.h"
#include "dna_ns.h"
#include "pose.h"
#include "pose_rna.h"
#include "pose_rna_base_doublet_classes.h"
#include "pose_rna_fragments.h"
#include "pose_rna_ns.h"
#include "pose_rna_pdbstats.h"
#include "silent_input.h"

// ObjexxFCL Headers
#include <ObjexxFCL/ObjexxFCL.hh>
#include <ObjexxFCL/FArray1D.hh>
#include <ObjexxFCL/FArray2D.hh>
#include <ObjexxFCL/FArray3D.hh>

// Numeric Headers
#include <numeric/xyz.functions.hh> // for dihedral()

// Utility Headers
#include <utility/basic_sys_util.hh>
#include <utility/io/ozstream.hh>

// C++ Headers
#include <map>

////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////
//
// RNA Featurizer.
//
// Three intended applications:
//
// 1. Visualization of RNA features in ab initio fragment-assembled decoys.
//
// 2. Creation of refined low-resolution ab initio potentials based on
//    fitting high-resolution energies with a linear sum of features.
//
// 3. Barcodes to allow control of ab initio runs, including the creation
//    of highly native-like decoys using native features.
//
////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////

//cheating, should make this a class...
namespace rna_features {

  // File-wide accumulators shared by the featurizer routines in this file.
  // Everything here is a running total over the decoys processed so far.

  // Incremented once per rna_base_doublet_featurize() call.
  int numdecoys = 0;
  // Residue count; assigned by initialize_feature_maps().
  int numres = 0;

  // Occurrence counts keyed by the base pair / base stack itself.
  std::map< Base_pair, int > base_pair_count_map;
  std::map< Base_stack, int > base_stack_count_map;

  // Per-residue tallies, indexed by residue number.
  FArray1D_int is_watson_crick_base_paired( param::MAX_RES() );
  FArray1D_int is_hoogsteen_base_paired( param::MAX_RES() );
  FArray1D_int is_sugar_base_paired( param::MAX_RES() );
  FArray1D_int is_bulged( param::MAX_RES() );

  // Indexed as ( number of stacks, residue ).
  FArray2D_int stack_count( 4, param::MAX_RES() );

  // Backbone / glycosidic torsion histograms, indexed as ( bin, residue ).
  FArray2D_int alpha_count( 3, param::MAX_RES() );
  FArray2D_int beta_count( 3, param::MAX_RES() );
  FArray2D_int gamma_count( 3, param::MAX_RES() );
  FArray2D_int delta_count( 2, param::MAX_RES() );   // 3'-endo/2'-endo
  FArray2D_int epsilon_count( 2, param::MAX_RES() );
  FArray2D_int zeta_count( 2, param::MAX_RES() );
  FArray2D_int chi_count( 2, param::MAX_RES() );     // syn/anti

  // Duarte/Pyle pseudo-torsions (eta, theta), indexed as ( bin, residue ).
  FArray2D_int pseudo_torsion_count( 4, param::MAX_RES() );

}

////////////////////////////////////////////////////////////////////////////
void
initialize_feature_maps( int const total_residue )
{

  using namespace rna_features;

  static bool init = false;

  if (init) return;

  numres = total_residue;

  for ( int i = 1; i <= numres; i++ ) {
    for ( int j = i+1; j <= numres; j++ ) {
      for ( int k = 1; k <= 3; k++ ) {
	for ( int m = 1; m <= 3; m++ ) {
	  for ( int o = 1; o <= 2; o++ ) {

	    Base_pair base_pair;

	    base_pair.res1 = i;
	    base_pair.edge1 = k;

	    base_pair.res2 = j;
	    base_pair.edge2 = m;

	    base_pair.orientation = o;

	    base_pair_count_map[ base_pair ] = 0;
	  }
	}
      }
    }
  }


  for ( int i = 1; i <= numres; i++ ) {
    for ( int j = 1; j <= numres; j++ ) {
      for ( int o = 1; o <= 2; o++ ) {
	for ( int w = 1; w <= 2; w++ ) {

	  Base_stack base_stack;

	  base_stack.res1 = i;
	  base_stack.res2 = j;

	  base_stack.orientation = o;

	  base_stack.which_side = w;

	  base_stack_count_map[ base_stack ] = 0;
	}
      }
    }
  }

  is_watson_crick_base_paired = 0;
  is_hoogsteen_base_paired = 0;
  is_sugar_base_paired = 0;
  is_bulged = 0;

  stack_count = 0;

  alpha_count   = 0;
  beta_count    = 0;
  gamma_count   = 0;
  delta_count   = 0;
  epsilon_count = 0;
  zeta_count    = 0;
  chi_count = 0;

  pseudo_torsion_count = 0;

  init = true;

}

////////////////////////////////////////////////////////////////////////////
void
print_feature_count( int const i, FArray1D_int & feature_array, std::string const feature_tag,
		     utility::io::ozstream & out)
{
  int const count = feature_array(i);
  if ( count > 0 ) out << feature_tag << " " << i << " " << count << std::endl;
}

////////////////////////////////////////////////////////////////////////////
void
print_feature_count( int const i, FArray2D_int & feature_array, std::string const feature_tag,
		     utility::io::ozstream & out)
{
  int const numtypes = feature_array.size1();
  for (int j = 1; j <= numtypes; j++ ){
    int const count = feature_array(j, i);
    if ( count > 0 ) out << feature_tag << " " << i << " " << j << " " << count << std::endl;
  }
}

////////////////////////////////////////////////////////////////////////////
void
output_feature_frequencies( utility::io::ozstream & out ){

  // Dumps every accumulated feature count in rna_features to `out`.
  // A two-line header (decoy and residue totals) comes first, followed by
  // one line per feature whose count is greater than zero.

  using namespace rna_features;

  out << "NUM_DECOYS "   << numdecoys << std::endl;
  out << "NUM_RESIDUES " << numres << std::endl;

  // Base pairs: every key with res1 < res2, in a fixed enumeration order.
  for ( int res1 = 1; res1 <= numres; ++res1 ) {
    for ( int res2 = res1 + 1; res2 <= numres; ++res2 ) {
      for ( int edge1 = 1; edge1 <= 3; ++edge1 ) {
        for ( int edge2 = 1; edge2 <= 3; ++edge2 ) {
          for ( int orientation = 1; orientation <= 2; ++orientation ) {

            Base_pair pair_key;
            pair_key.res1 = res1;
            pair_key.edge1 = edge1;
            pair_key.res2 = res2;
            pair_key.edge2 = edge2;
            pair_key.orientation = orientation;

            int const tally = base_pair_count_map[ pair_key ];
            if ( tally <= 0 ) continue;

            out << "BASE_PAIR " << pair_key << " " << tally << std::endl;
          }
        }
      }
    }
  }

  // Base stacks: only keys with res1 < res2 are reported.
  for ( int res1 = 1; res1 <= numres; ++res1 ) {
    for ( int res2 = res1 + 1; res2 <= numres; ++res2 ) {
      for ( int orientation = 1; orientation <= 2; ++orientation ) {
        for ( int side = 1; side <= 2; ++side ) {

          Base_stack stack_key;
          stack_key.res1 = res1;
          stack_key.res2 = res2;
          stack_key.orientation = orientation;
          stack_key.which_side = side;

          int const tally = base_stack_count_map[ stack_key ];
          if ( tally <= 0 ) continue;

          out << "BASE_STACK " << stack_key << " " << tally << std::endl;
        }
      }
    }
  }

  // Per-residue features. Each feature is written for all residues before
  // the next feature starts, so the file is grouped by feature tag.
  int res;

  for ( res = 1; res <= numres; ++res ) {
    print_feature_count( res, is_watson_crick_base_paired, "WATSON_CRICK_BP", out );
  }
  for ( res = 1; res <= numres; ++res ) {
    print_feature_count( res, is_hoogsteen_base_paired, "HOOGSTEEN_BP", out );
  }
  for ( res = 1; res <= numres; ++res ) {
    print_feature_count( res, is_sugar_base_paired, "SUGAR_BP ", out );
  }
  for ( res = 1; res <= numres; ++res ) {
    print_feature_count( res, is_bulged, "BULGED", out );
  }
  for ( res = 1; res <= numres; ++res ) {
    print_feature_count( res, stack_count, "NUM_STACKS", out );
  }

  for ( res = 1; res <= numres; ++res ) {
    print_feature_count( res, alpha_count, "ALPHA_BIN", out );
  }
  for ( res = 1; res <= numres; ++res ) {
    print_feature_count( res, beta_count, "BETA_BIN", out );
  }
  for ( res = 1; res <= numres; ++res ) {
    print_feature_count( res, gamma_count, "GAMMA_BIN", out );
  }
  for ( res = 1; res <= numres; ++res ) {
    print_feature_count( res, delta_count, "DELTA_BIN", out );
  }
  for ( res = 1; res <= numres; ++res ) {
    print_feature_count( res, epsilon_count, "EPSILON_BIN", out );
  }
  for ( res = 1; res <= numres; ++res ) {
    print_feature_count( res, zeta_count, "ZETA_BIN", out );
  }
  for ( res = 1; res <= numres; ++res ) {
    print_feature_count( res, chi_count, "CHI_BIN", out );
  }
  for ( res = 1; res <= numres; ++res ) {
    print_feature_count( res, pseudo_torsion_count, "PSEUDO_TORSION_BIN", out );
  }

}

////////////////////////////////////////////////////////////////////////////
void
rna_base_doublet_featurize( pose_ns::Pose & pose )
{
  // Scores one pose with eval_rna_score() and folds the resulting base pair
  // and base stack lists into the rna_features accumulators:
  //   - base_pair_count_map / base_stack_count_map (one hit per list entry),
  //   - per-residue Watson-Crick / Hoogsteen / sugar-edge pairing tallies,
  //   - per-residue "bulged" tally (no pairing on any edge and no stacks),
  //   - per-residue histogram of the number of stacks.
  // Also increments rna_features::numdecoys.

  using namespace rna_features;

  int const total_residue = pose.total_residue();

  // Lists of scored base pairs and base stacks, filled in by eval_rna_score().
  Energy_base_pair_list  scored_base_pair_list;
  Energy_base_stack_list scored_base_stack_list;

  // The remaining eval_rna_score() outputs are required by its signature
  // but are not used below.
  FArray2D_bool scored_base_pair( total_residue, total_residue );
  float rna_bs_score,	rna_bp_w_score,	rna_bp_h_score,	rna_bp_s_score,	rna_axis_score;
  float rna_stagger_score, rna_bulge_score, rna_contact_score,	rna_long_range_contact_score;
  FArray2D_bool edge_is_base_pairing( total_residue, rna_scoring::NUM_EDGES, false );

  eval_rna_score( pose, rna_bs_score, rna_bp_w_score, rna_bp_h_score,
                  rna_bp_s_score, rna_axis_score, rna_stagger_score, rna_bulge_score,
                  rna_contact_score, rna_long_range_contact_score,
                  scored_base_pair,
                  edge_is_base_pairing,
                  scored_base_pair_list, scored_base_stack_list);

  numdecoys++;

  // Per-pose scratch flags, so a residue contributes at most one hit per
  // pose to each of the is_*_base_paired tallies.
  FArray1D_bool watson_crick_bp( total_residue, false );
  FArray1D_bool hoogsteen_bp( total_residue, false );
  FArray1D_bool sugar_bp( total_residue, false );
  FArray1D_int num_bs( total_residue, 0 );

  std::cout << "Filling base pair info" << std::endl;

  for ( Energy_base_pair_list::const_iterator it = scored_base_pair_list.begin();
        it != scored_base_pair_list.end(); ++it ){

    // Bind by reference: the entry is only read, no copy is needed.
    Base_pair const & base_pair = it->second;
    base_pair_count_map[ base_pair ]++;

    if (base_pair.edge1 == rna_scoring::WATSON_CRICK ) watson_crick_bp( base_pair.res1) = true;
    if (base_pair.edge1 == rna_scoring::HOOGSTEEN )    hoogsteen_bp( base_pair.res1) = true;
    if (base_pair.edge1 == rna_scoring::SUGAR )        sugar_bp( base_pair.res1) = true;

    if (base_pair.edge2 == rna_scoring::WATSON_CRICK ) watson_crick_bp( base_pair.res2) = true;
    if (base_pair.edge2 == rna_scoring::HOOGSTEEN )    hoogsteen_bp( base_pair.res2) = true;
    if (base_pair.edge2 == rna_scoring::SUGAR )        sugar_bp( base_pair.res2) = true;

  }

  std::cout << "Filling base stack info" << std::endl;

  for ( Energy_base_stack_list::const_iterator it = scored_base_stack_list.begin();
        it != scored_base_stack_list.end(); ++it ){

    Base_stack const & base_stack = it->second;
    base_stack_count_map[ base_stack ]++;

    // Both partners of a stack get credit for it.
    num_bs( base_stack.res1)++;
    num_bs( base_stack.res2)++;
  }

  std::cout << "Filling feature counts" << std::endl;

  // stack_count's first index starts at 1, so a stack count of zero has no
  // bin of its own. Using it directly as an index would land on a
  // neighbouring column (or outside the array for residue 1), and a count
  // above the last bin would overrun it as well.
  int const max_stack_bin = stack_count.size1();

  for (int i = 1; i <= total_residue; i++ ){
    if (watson_crick_bp(i))    is_watson_crick_base_paired(i)++;
    if (hoogsteen_bp(i))       is_hoogsteen_base_paired(i)++;
    if (sugar_bp(i))           is_sugar_base_paired(i)++;
    if (!( watson_crick_bp(i) || hoogsteen_bp(i) || sugar_bp(i) || num_bs(i) ))  is_bulged(i)++;

    // Residues with no stacks are not entered in the histogram; anything
    // beyond the last bin is folded into the last bin.
    int const nstack = num_bs(i);
    if ( nstack > 0 ) {
      int const bin = ( nstack < max_stack_bin ) ? nstack : max_stack_bin;
      stack_count( bin, i )++;
    }
  }

}

////////////////////////////////////////////////////////////////////////////
void
rna_torsion_featurize( pose_ns::Pose & pose )
{
  // Bins the alpha, beta, gamma, delta, epsilon, zeta and chi torsions of
  // residues 1..rna_features::numres of `pose` and increments the matching
  // ( bin, residue ) entry of each *_count histogram.

  using namespace rna_features;

  // Torsion numbers handed to Pose::get_torsion_by_number().
  enum{ WHATEVER, ALPHA, BETA, GAMMA, DELTA, EPSILON, ZETA, CHI, NU2, NU3};

  // Boundary between the two delta bins.
  float const DELTA_CUTOFF = 115.0;

  for ( int res = 1; res <= numres; ++res ) {

    int bin;

    // alpha: bin 1 also collects residues whose alpha fails the strict
    // insertability check.
    float const alpha = pose.get_torsion_by_number( res, ALPHA );
    if ( ( alpha > -120.0 && alpha < 0.0 ) || !check_RNA_torsion_insertable_strict( pose, res, ALPHA ) ) {
      bin = 1;
    } else if ( alpha > 0.0 && alpha < 100.0 ) {
      bin = 2;
    } else {
      bin = 3;
    }
    alpha_count( bin, res )++;

    // beta
    float const beta = pose.get_torsion_by_number( res, BETA );
    if ( beta > 0.0 && beta < 150.0 ) {
      bin = 3;
    } else if ( beta > 150.0 || beta < -150.0 ) {
      bin = 1;
    } else {
      bin = 2;
    }
    beta_count( bin, res )++;

    // gamma
    float const gamma = pose.get_torsion_by_number( res, GAMMA );
    if ( gamma > 0.0 && gamma < 150.0 ) {
      bin = 1;
    } else if ( gamma > 150.0 || gamma < -150.0 ) {
      bin = 2;
    } else {
      bin = 3;
    }
    gamma_count( bin, res )++;

    // delta
    float const delta = pose.get_torsion_by_number( res, DELTA );
    bin = ( delta <= DELTA_CUTOFF ) ? 1 : 2;
    delta_count( bin, res )++;

    // epsilon: the insertability check is only consulted when epsilon > 0.
    float const epsilon = pose.get_torsion_by_number( res, EPSILON );
    if ( epsilon <= 0.0 || !check_RNA_torsion_insertable_strict( pose, res, EPSILON ) ) {
      bin = 1;
    } else {
      bin = 2;
    }
    epsilon_count( bin, res )++;

    // zeta
    float const zeta = pose.get_torsion_by_number( res, ZETA );
    if ( ( zeta > -100.0 && zeta < 0.0 ) || !check_RNA_torsion_insertable_strict( pose, res, ZETA ) ) {
      bin = 1;
    } else {
      bin = 2;
    }
    zeta_count( bin, res )++;

    // chi: for 'a' and 'g' residues the stored angle is first wrapped to be
    // non-negative and then shifted down by 180 (historical convention).
    float chi = pose.get_torsion_by_number( res, CHI );
    bool const needs_shift =
      ( param_aa::aa_name1( pose.res(res) ) == 'a' || param_aa::aa_name1( pose.res(res) ) == 'g' );
    if ( needs_shift ) {
      if ( chi < 0.0 ) chi += 360.0;
      chi -= 180.0;
    }
    bin = ( chi < -20.0 ) ? 1 : 2;
    chi_count( bin, res )++;

  }

}

float
get_eta( pose_ns::Pose & pose, int const i ){

  // Pseudo-torsion eta for residue i: the dihedral defined by
  //   c4star(i-1), p(i), c4star(i), p(i+1).
  // Negative values are shifted up by 360. The caller must make sure that
  // residues i-1 and i+1 exist.

  using namespace kin;
  using namespace rna_variables;
  using namespace numeric;

  Atom_id prev_c4star( c4star, i - 1 );
  Atom_id this_p     (      p, i     );
  Atom_id this_c4star( c4star, i     );
  Atom_id next_p     (      p, i + 1 );

  Coords_FArray_const coords( pose.full_coord());

  float const raw_eta = dihedral( coords.get_xyz( prev_c4star ),
                                  coords.get_xyz( this_p ),
                                  coords.get_xyz( this_c4star ),
                                  coords.get_xyz( next_p ) );

  return ( raw_eta < 0.0 ) ? static_cast< float >( raw_eta + 360.0 ) : raw_eta;
}

////////////////////////////////////////////////////////////////////////////
float
get_theta( pose_ns::Pose & pose, int const i ){

  // Pseudo-torsion theta for residue i: the dihedral defined by
  //   p(i), c4star(i), p(i+1), c4star(i+1).
  // Negative values are shifted up by 360. The caller must make sure that
  // residue i+1 exists.

  using namespace kin;
  using namespace rna_variables;
  using namespace numeric;

  Atom_id this_p     (      p, i     );
  Atom_id this_c4star( c4star, i     );
  Atom_id next_p     (      p, i + 1 );
  Atom_id next_c4star( c4star, i + 1 );

  Coords_FArray_const coords( pose.full_coord());

  float const raw_theta = dihedral( coords.get_xyz( this_p ),
                                    coords.get_xyz( this_c4star ),
                                    coords.get_xyz( next_p ),
                                    coords.get_xyz( next_c4star ) );

  return ( raw_theta < 0.0 ) ? static_cast< float >( raw_theta + 360.0 ) : raw_theta;

}

////////////////////////////////////////////////////////////////////////////
void
rna_pseudo_torsion_featurize( pose_ns::Pose & pose )
{
  // Classifies the ( eta, theta ) pseudo-torsion pair of every residue
  // 1..rna_features::numres into one of four bins and increments
  // pseudo_torsion_count( bin, residue ):
  //   1: eta in window,     theta in window
  //   2: eta out of window, theta in window
  //   3: eta in window,     theta out of window
  //   4: eta out of window, theta out of window

  using namespace rna_features;

  for ( int res = 1; res <= numres; ++res ) {

    // Fallback values, used when an angle cannot be measured (chain end,
    // cutpoint or chainbreak). Both lie inside the windows tested below.
    float eta = 180.0;
    float theta = 200.0;

    // eta needs intact connections to both neighbours.
    if (res > 1 && res < numres &&
        !pose.is_cutpoint(res-1) && !pose.is_cutpoint(res) &&
        !check_chainbreak(res-1,pose) && !check_chainbreak(res,pose)) {
      eta = get_eta( pose, res );
    }

    // theta only needs an intact connection to the next residue.
    if (res < numres && !pose.is_cutpoint(res) && !check_chainbreak(res,pose)) {
      theta = get_theta( pose, res );
    }

    std::cout << "ETA: " << eta << "  THETA "  << theta << std::endl;

    bool const eta_in_window   = ( eta > 150.0 && eta < 190.0 );
    bool const theta_in_window = ( theta > 190.0 && theta < 260.0 );

    // Encode the two flags as a bin number 1..4 (see the table above).
    int const bin = 1 + ( eta_in_window ? 0 : 1 ) + ( theta_in_window ? 0 : 2 );
    pseudo_torsion_count( bin, res )++;
  }

}


////////////////////////////////////////////////////////////////////////////
void
rna_featurize( pose_ns::Pose & pose )
{
  // Runs all three featurizers (base doublets, torsions, pseudo-torsions)
  // on `pose`, printing a progress line before each stage.

  // The accumulators are set up on the first call only; on later calls
  // initialize_feature_maps() returns without doing anything.
  int const nres = pose.total_residue();
  initialize_feature_maps( nres );

  std::cout << "  Base doublets..." << std::endl;
  rna_base_doublet_featurize( pose );

  std::cout << "  Torsions..." << std::endl;
  rna_torsion_featurize( pose );

  std::cout << "  Pseudotorsions..." << std::endl;
  rna_pseudo_torsion_featurize( pose );

}

////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////
void
rna_featurizer_test()
{
  // Driver: featurizes every decoy named by read_decoys_silent_or_not() and
  // writes the accumulated feature counts to the file given by the "o"
  // option (default "features.txt").

  using namespace pose_ns;
  using namespace silent_io;

  // Start from well-defined values so that nothing indeterminate is passed
  // on to the helpers if read_decoys_silent_or_not() leaves one of these
  // untouched.
  // NOTE(review): ownership of decoys_p is not visible from here; if
  // read_decoys_silent_or_not() allocates it, it is never released -- confirm.
  Silent_file_data * decoys_p = 0;
  std::vector< std::string > files;
  bool silent_input = false;

  read_decoys_silent_or_not( decoys_p, files, silent_input );

  for ( std::vector< std::string >::const_iterator file = files.begin();
        file != files.end(); ++file ) {

    // A fresh pose for each decoy.
    Pose pose;

    std::string const & start_file( *file );

    std::cout << "Featurizing: " << start_file << std::endl;
    fill_decoy_silent_or_not( pose, decoys_p, start_file, silent_input );

    rna_featurize( pose );

  }

  utility::io::ozstream out ( stringafteroption("o", "features.txt") );
  output_feature_frequencies( out );
  out.close();

}
