// -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
// vi: set ts=2 noet:

//  CVS information:
//  $Revision: 15327 $
//  $Date: 2007-06-05 07:58:57 -0700 (Tue, 05 Jun 2007) $
//  $Author: sarel $

#include "T32S3.h"

#include "aaproperties_pack.h"
#include "after_opts.h"
#include "count_pair.h"
#include "files_paths.h"
#include "misc.h"
#include "pack_geom_inline.h"
#include "pdbstatistics_pack.h"
#include "read_paths.h"

// ObjexxFCL Headers
#include <ObjexxFCL/ObjexxFCL.hh>
#include <ObjexxFCL/FArray1D.hh>
#include <ObjexxFCL/FArray2D.hh>
#include <ObjexxFCL/FArray3D.hh>
#include <ObjexxFCL/FArray4D.hh>
#include <ObjexxFCL/FArray5D.hh>
#include <ObjexxFCL/formatted.o.hh>

// Utility Headers
#include <utility/io/izstream.hh>
#include <utility/io/ozstream.hh>

// C++ Headers
#include <cstdlib>
#include <iostream>
#include <string>
#include <fstream>
#include <vector>
#include <map>
#include <sstream>


namespace t32s3_ns {

  // T32S3 atom classes. The enumerator values (0 .. NUM_ATOM-1) are passed to
  // compute_index() as row/column numbers of the pair-type matrix, so the
  // order of the enumerators must not be changed.
  enum T32S3AtomType {
    NX,   NH,   CO,   OC,   CAH,  CH3,  CH2,  CFH,
    CZ,   OH,   CGTR, CHTR, NDHS, CH2M, SM,   CH3M,
    CH2K, CH2S, CHPR, CH2C, SH,   CGHP, NDHP, CHEP,
    CR2,  NR1,  CR3,  NR2,  NAS,  CH2B, CX1,  OX1
  };

  // number of atom classes in T32S3AtomType
  const int NUM_ATOM = 32;
  // number of energy types: 3 distance bins for every unordered pair of classes
  const int NUM_TYPE = 3 * NUM_ATOM * (NUM_ATOM + 1) / 2;
  // largest atom-atom distance that is scored
  const float MAX_DIS = 6.5f; // change this!!!
  // square of MAX_DIS, compared against squared distances
  const float MAX_DIS2 = MAX_DIS * MAX_DIS;


  /////////////////////////////////////////////////////////////////////////////
  // Maps one rosetta atom onto its T32S3 atom class.
  //
  //   res_type  rosetta residue type number of the residue owning the atom
  //   atype     rosetta atom type number of the atom
  //   anum      number of the atom within its residue, as used by the caller
  //
  // The residue variant is not needed for the mapping and is therefore not a
  // parameter. Atom types 1, 2 and 5-20 are handled explicitly; any other
  // atom type is classified from res_type and anum alone.
  /////////////////////////////////////////////////////////////////////////////
  T32S3AtomType rosetta_type_to_t32s3_type( int const res_type,
                                            //int const res_variant,
                                            int const atype,
                                            int const anum)
  {
    switch (atype) {
      // atom types with a single possible T32S3 class
      case 1:  return CO;
      case 2:  return CX1;
      case 8:  return NDHP;
      case 9:  return NAS;
      case 10: return NX;
      case 12: return NH;
      case 13: return OH;
      case 14: return OC;
      case 15: return OX1;
      case 17: return NH;
      case 19: return CO;
      case 20: return OC;

      // atom types whose class depends on the residue and/or the atom number
      case 5:
        return (res_type == 11) ? CH3M : CH3;        // MET : anything else

      case 6:
        switch (res_type) {
          case 5:                                    // PHE
            return CFH;
          case 20:                                   // TYR
            return (anum == 11) ? CZ : CFH;
          case 19:                                   // TRP
            if (anum == 6 || anum == 8)  return CGTR;
            if (anum == 7 || anum == 10) return CHTR;
            return CFH;
          case 15:                                   // ARG CZ
            return CR3;
          default:                                   // HIS
            return (anum == 9) ? CHEP : CGHP;
        }

      case 7:
        return (res_type == 19) ? NDHS : NDHP;       // TRP : HIS

      case 11:
        return (anum == 8) ? NR1 : NR2;              // ARG NE : other

      case 16:
        return (res_type == 11) ? SM : SH;           // MET : CYS

      case 18:
        return (res_type == 13) ? CHPR : CAH;        // PRO : anything else

      default:
        break;
    }

    // every remaining atom type: first matching rule wins
    if (res_type == 11 && anum == 6)        return CH2M; // MET CG
    if (res_type == 2)                      return CH2C; // CYS
    if (res_type == 15 && anum == 7)        return CR2;  // ARG CD
    if (res_type == 9 && anum == 8)         return CH2K; // LYS CE
    if (res_type == 13 && anum == 7)        return CHPR; // PRO CD
    if (res_type == 16 || res_type == 17)   return CH2S; // SER/THR
    if (res_type == 3)                      return CH2B; // ASP
    if (res_type == 4 && anum == 6)         return CH2B; // GLU CG
    return CH2;
  }

  ////////////////////////////////////////////////////////////////////////////
  // Combines the pair counts with the potential: returns the sum over all
  // NUM_TYPE energy types of type_count[i]*pot[i].
  // Both arrays must provide at least NUM_TYPE elements.
  ////////////////////////////////////////////////////////////////////////////
  float t32s3_score_for_pair(int *type_count, float *pot)
  {
    float total = 0;
    int slot = 0;

    // accumulate in ascending index order
    while (slot != NUM_TYPE) {
      total += type_count[slot] * pot[slot];
      ++slot;
    }

    return total;
  }


   ////////////////////////////////////////////////////////////////////////////
  // Returns the position of the unordered pair (type1, type2) in the
  // one-dimensional vector that stores the triangular pair matrix row by row:
  // index = hi*(hi+1)/2 + lo, where hi is the larger and lo the smaller type.
  // The result does not depend on the order of the two arguments.
  ////////////////////////////////////////////////////////////////////////////
  int compute_index(int type1, int type2) {
    bool const in_order = (type1 <= type2);
    int const lo = in_order ? type1 : type2;
    int const hi = in_order ? type2 : type1;

    return hi*(hi+1)/2 + lo;
  }

    ////////////////////////////////////////////////////////////////////////////
  // Adds 1 to the count of the energy type selected by the pair index and the
  // squared distance dist2. type_count is laid out as three consecutive
  // blocks of NUM_ATOM*(NUM_ATOM+1)/2 entries, one per distance bin:
  //   block 0:  2   <= distance <  3.5
  //   block 1:  3.5 <= distance <  5
  //   block 2:  5   <= distance <= 6.5
  // Nothing is counted when dist2 lies outside these bins.
  ////////////////////////////////////////////////////////////////////////////
  void increment_count(int index, float dist2, int *type_count) {

    // outside [2, 6.5] (the negated test also rejects NaN): nothing to count
    if (!(dist2 >= 2*2) || dist2 > 6.5*6.5)
      return;

    int const block_size = NUM_ATOM*(NUM_ATOM+1)/2;
    int const bin = (dist2 < 3.5*3.5) ? 0
                  : (dist2 < 5*5)     ? 1
                  :                     2;

    ++type_count[bin*block_size + index];
  }


  ////////////////////////////////////////////////////////////////////////////
  // Reads the potential from the data file file_name (opened through
  // try_to_open_data_file) into pot[0 .. NUM_TYPE-1] using formatted stream
  // extraction; pot must provide room for NUM_TYPE values.
  //
  // Prints a message to cerr and terminates the program with exit(1) when
  // the file cannot be opened or when fewer than NUM_TYPE values can be read,
  // so that a truncated or malformed file never leaves part of pot unset.
  ////////////////////////////////////////////////////////////////////////////
  void read_potential(char *file_name, float *pot) {
    using namespace std;

    bool failed = false;
    utility::io::izstream & input( try_to_open_data_file(string(file_name), failed ) );
    if ( failed ) {
      cerr << "T32S3.cc: cannot open file " << file_name << '\n';
      exit(1);
    }

    for (int i=0; i<NUM_TYPE; i++) {
      // a failed extraction would leave pot[i] (and all later entries) unset
      if ( !( input >> pot[i] ) ) {
        cerr << "T32S3.cc: cannot read value " << i+1 << " of " << NUM_TYPE
             << " from file " << file_name << '\n';
        input.close();
        exit(1);
      }
    }

    input.close();
  }

  ///////////////////////////////////////////////////////////////////////////
  // Computes the total T32S3 score for the structure currently loaded into
  // rosetta's global arrays (coordinates are taken from misc::full_coord).
  //
  // Every atom of every residue is paired with every atom of every later
  // residue. Atoms whose rosetta atom type is above 20 are skipped, as are
  // pairs of atoms numbered 1-4 in sequence-adjacent residues (backbone
  // neighbours). Each remaining pair within MAX_DIS is counted by
  // increment_count(); the counts are then weighted with the potential read
  // from "T32S3.pot".
  //
  // NOTE: the potential file is re-read on every call.
  //////////////////////////////////////////////////////////////////////////
  float compute_t32s3_score_from_misc() {

    using namespace aaproperties_pack;
    using namespace misc;
    using namespace std;
    using namespace files_paths;   // for start_file
    using namespace param;
    using namespace param_aa;

    cout << "T32S3.cc: compute_t32s3_score_from_misc() called" << endl;

    int type_count[NUM_TYPE];
    for (int i=0; i<NUM_TYPE; i++)
      type_count[i]=0;

    for ( int res1 = 1; res1 < total_residue; ++res1 ) {
      int const aa1 = res(res1);
      int const aav1 = res_variant(res1);

      for ( int atom1 = 1, atom1e = natoms(aa1,aav1); atom1 <= atom1e; ++atom1 ) {
        int const atype1 = fullatom_type(atom1,aa1,aav1);
        if (atype1>20) continue;

        // the class of atom1 does not depend on its partner: map it once here
        // instead of once per atom pair
        T32S3AtomType const jiantype1 = rosetta_type_to_t32s3_type(aa1,atype1,atom1);

        float const X1 = misc::full_coord(1,atom1,res1);
        float const Y1 = misc::full_coord(2,atom1,res1);
        float const Z1 = misc::full_coord(3,atom1,res1);

        for ( int res2 = res1+1; res2 <= total_residue; ++res2 ) {
          int const aa2 = res(res2);
          int const aav2 = res_variant(res2);

          for ( int atom2 = 1, atom2e = natoms( aa2, aav2 ); atom2 <= atom2e; ++atom2 ) {

            int const atype2 = fullatom_type(atom2,aa2,aav2);
            if (atype2>20) continue;
            //ignore backbone interactions between neighboring residues
            if (res2==res1+1 && atom1<=4 && atom2<=4) continue;

            // quick per-axis bounds for speed; written as explicit comparisons
            // so the result does not depend on which abs() overload is visible
            // (the integer abs() would truncate the difference)
            float const dx = misc::full_coord(1,atom2,res2) - X1;
            if ( dx > MAX_DIS || dx < -MAX_DIS )
              continue;
            float const dy = misc::full_coord(2,atom2,res2) - Y1;
            if ( dy > MAX_DIS || dy < -MAX_DIS )
              continue;
            float const dz = misc::full_coord(3,atom2,res2) - Z1;
            if ( dz > MAX_DIS || dz < -MAX_DIS )
              continue;

            // get the squared distance
            float const dis2 = dx*dx + dy*dy + dz*dz;
            if (dis2 > MAX_DIS2) continue;

            // get the type of the partner atom and count the pair
            T32S3AtomType const jiantype2 = rosetta_type_to_t32s3_type(aa2,atype2,atom2);
            int const index = compute_index(jiantype1, jiantype2);
            increment_count(index, dis2, type_count);

          } // atom2
        } // res2
      } // atom1
    } // res1

    // read_potential() takes a non-const char*; a string literal must not be
    // passed to it directly, so hand it a writable copy of the file name
    char pot_file[] = "T32S3.pot";
    float pot[NUM_TYPE];
    read_potential(pot_file, pot);

    // compute the T32S3 energy based on the potential and the counts of the energy types
    return t32s3_score_for_pair(type_count, pot);

  }

  ///////////////////////////////////////////////////////////////////////////////
  // Test hook: prints the T32S3 score of the currently loaded structure to
  // cout, surrounded by blank lines.
  // for now, this function will get called after each decoy is scored/generated
  // so you can test everything
  ///////////////////////////////////////////////////////////////////////////////
  void test_t32s3_score() {
    using namespace std;

    cout << endl;
    cout << endl;

    // compute_t32s3_score_from_misc() writes to cout itself. Evaluate it in its
    // own statement so that its output can never end up between the label and
    // the number, whatever evaluation order the compiler uses for operator<<.
    float const score = compute_t32s3_score_from_misc();

    cout << "Jian's Score:" << score << endl;
    cout << endl;
    cout << endl;

  }


} // end of namespace t32s3_ns



// T32S3 atom types: each row below lists <T32S3AtomType enum value> <residue> <atom name>
// 0	LYS	NZ
// 1	ALA	N
// 1	ARG	N
// 1	ASN	N
// 1	ASP	N
// 1	CYS	N
// 1	GLN	N
// 1	GLU	N
// 1	GLY	N
// 1	HIS	N
// 1	ILE	N
// 1	LEU	N
// 1	LYS	N
// 1	MET	N
// 1	PHE	N
// 1	PRO	N
// 1	SER	N
// 1	THR	N
// 1	TRP	N
// 1	TYR	N
// 1	VAL	N
// 2	ALA	C
// 2	ARG	C
// 2	ASN	C
// 2	ASP	C
// 2	CYS	C
// 2	GLN	C
// 2	GLU	C
// 2	GLY	C
// 2	HIS	C
// 2	ILE	C
// 2	LEU	C
// 2	LYS	C
// 2	MET	C
// 2	PHE	C
// 2	PRO	C
// 2	SER	C
// 2	THR	C
// 2	TRP	C
// 2	TYR	C
// 2	VAL	C
// 2	ASN	CG
// 2	GLN	CD
// 3	ALA	O
// 3	ARG	O
// 3	ASN	O
// 3	ASP	O
// 3	CYS	O
// 3	GLN	O
// 3	GLU	O
// 3	GLY	O
// 3	HIS	O
// 3	ILE	O
// 3	LEU	O
// 3	LYS	O
// 3	MET	O
// 3	PHE	O
// 3	PRO	O
// 3	SER	O
// 3	THR	O
// 3	TRP	O
// 3	TYR	O
// 3	VAL	O
// 3	ASN	OD1
// 3	GLN	OE1
// 4	ALA	CA
// 4	ARG	CA
// 4	ASN	CA
// 4	ASP	CA
// 4	CYS	CA
// 4	GLN	CA
// 4	GLU	CA
// 4	GLY	CA
// 4	HIS	CA
// 4	ILE	CA
// 4	LEU	CA
// 4	LYS	CA
// 4	MET	CA
// 4	PHE	CA
// 4	SER	CA
// 4	THR	CA
// 4	TRP	CA
// 4	TYR	CA
// 4	VAL	CA
// 5	ALA	CB
// 5	ILE	CG2
// 5	ILE	CD1
// 5	LEU	CD1
// 5	LEU	CD2
// 5	THR	CG2
// 5	VAL	CG1
// 5	VAL	CG2
// 6	ARG	CB
// 6	ARG	CG
// 6	ASN	CB
// 6	GLN	CB
// 6	GLN	CG
// 6	GLU	CB
// 6	HIS	CB
// 6	ILE	CB
// 6	ILE	CG1
// 6	LEU	CB
// 6	LEU	CG
// 6	LYS	CB
// 6	LYS	CG
// 6	LYS	CD
// 6	MET	CB
// 6	PHE	CB
// 6	PRO	CB
// 6	PRO	CG
// 6	TRP	CB
// 6	TYR	CB
// 6	VAL	CB
// 7	PHE	CG
// 7	PHE	CD1
// 7	PHE	CD2
// 7	PHE	CE1
// 7	PHE	CE2
// 7	PHE	CZ
// 7	TRP	CE3
// 7	TRP	CZ2
// 7	TRP	CZ3
// 7	TRP	CH2
// 7	TYR	CG
// 7	TYR	CD1
// 7	TYR	CD2
// 7	TYR	CE1
// 7	TYR	CE2
// 8	TYR	CZ
// 9	SER	OG
// 9	THR	OG1
// 9	TYR	OH
// 10	TRP	CG
// 10	TRP	CD2
// 11	TRP	CD1
// 11	TRP	CE2
// 12	TRP	NE1
// 13	MET	CG
// 14	MET	SD
// 15	MET	CE
// 16	LYS	CE
// 17	SER	CB
// 17	THR	CB
// 18	PRO	CA
// 18	PRO	CD
// 19	CYS	CB
// 20	CYS	SG
// 21	HIS	CG
// 21	HIS	CD2
// 22	HIS	ND1
// 22	HIS	NE2
// 23	HIS	CE1
// 24	ARG	CD
// 25	ARG	NE
// 26	ARG	CZ
// 27	ARG	NH1
// 27	ARG	NH2
// 28	ASN	ND2
// 28	GLN	NE2
// 29	ASP	CB
// 29	GLU	CG
// 30	ASP	CG
// 30	GLU	CD
// 31	ASP	OD1
// 31	ASP	OD2
// 31	GLU	OE1
// 31	GLU	OE2

