// -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
// vi: set ts=2 noet:
//  CVS information:
//  $Revision: 10040 $
//  $Date: 2006-08-29 00:01:52 -0700 (Tue, 29 Aug 2006) $
//  $Author: rhiju $

#include "saxs_model.h"
#include "after_opts.h"
#include "files_paths.h"
#include "input_pdb.h"
#include "misc.h"
#include "native.h"
#include "output_decoy.h"
#include "param.h"

// ObjexxFCL Headers
#include <ObjexxFCL/FArray1D.hh>
#include <ObjexxFCL/FArray2D.hh>
#include <ObjexxFCL/FArray2Da.hh>
#include <ObjexxFCL/FArray3Da.hh>
#include <ObjexxFCL/formatted.o.hh>

// Numeric Headers
#include <numeric/all.fwd.hh>
#include <numeric/xyzVector.hh>
#include <numeric/xyzVector.io.hh>
#include <numeric/xyz.functions.hh>

// Utility Headers
#include <utility/basic_sys_util.hh>
#include <utility/io/izstream.hh>
#include <utility/io/ozstream.hh>

// C++ Headers
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <list>
#include <vector>

//////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////
//
// Very basic code to calculate small-angle x-ray scattering (SAXS)
// profiles from centroid decoys, and compare to experimental profiles
// during ab initio folding.
//
//
// This computes the SAXS profile by the Debye formula:
//
//  I( q ) = SUM f_i * f_j sin( q d_ij) / (q d_ij)
//
//  here q is the wavevector
//
//     q = 4 * pi * sin( theta / 2) / lambda
//     [theta is scattering angle, lambda=wavelength]
//
// d_ij is the distance between two atoms i and j, f_i is the number of
// electrons in the atom i.
//
// Actually, I found that just using the N atom and Centroid atoms as
// placeholders for all the atoms in the backbone and sidechain,
// respectively, I get basically the same profile as using all atoms or
// the same profile as Svergun's CRYSOL program, up to
// q = 1.0 inverse Angstroms (nominally corresponding to a distance scale of
// 2*pi/q ~ 6 Angstroms).
//
// This is currently still not ready for prime time. What is missing
// is a treatment of hydration shells -- modeled in CRYSOL by an extra
// layer of density around the protein, and giving better agreement with
// experimental profiles. We could probably
// put this in with mock atoms to represent water...
//
// The SAXS score is given by a chi-squared formula:
//
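//  SCORE = SUM W(q) * { [I_pred(q) - I_exp(q)] / errI(q) }^2
//
// where I_exp(q) and errI(q) are first rescaled so that the experimental and
// predicted profiles have the same sum over all q bins.
//
// Here W(q) = const * (q/q_max)^n (n = 0 by default), and the constant was chosen
// to make a typical non-native decoy have a score of 20 over the native.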
//
//
// rhiju october, 2006
//
//////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////

// Shared state for the SAXS profile calculation and scoring routines in this file.
namespace saxs_model{
	// Capacity used to size the fixed-length profile arrays below.
	int const MAXQBINS = 10000; //can get a lot of bins from a CCD detector!

  // Number of q bins currently in use (1..num_qval_bins), and their q values.
  // Set by initialize_saxs_profile() and input_saxs_profile_from_file().
  int num_qval_bins( 0 );
  FArray1D_float qval( MAXQBINS, 0.0);

  // Target (input) profile and its per-bin error, as read in or computed from a
  // model, plus the copies rescaled against the current decoy in
  // eval_saxs_model_score().
  FArray1D_double input_saxs_profile( MAXQBINS, 0.0);
  FArray1D_double input_saxs_profile_error( MAXQBINS, 1.0);
  FArray1D_double input_saxs_profile_norm( MAXQBINS, 0.0);
  FArray1D_double input_saxs_profile_error_norm( MAXQBINS, 1.0);

  // Profile most recently computed for the current decoy.
  FArray1D_double decoy_saxs_profile( MAXQBINS, 0.0);

	// Binning of pair distances for the histogram built in calculate_saxs_profile():
	// bin width, largest distance covered, and the resulting bin count.
	// Units are presumably Angstroms (same as the coordinates) -- confirm.
	// NOTE: num_dist_bins depends on the two values above, so declaration order matters.
	float dist_bin_width = 0.2;
	float dist_max = 1000.0;
	int num_dist_bins = static_cast <int> (dist_max / dist_bin_width);

	// Scattering weights: one value for the backbone placeholder atom, and one value
	// per residue type for the centroid (indexed by the residue type code from res()).
	int atomicnumber_N = 39;
	FArray1D_int const atomicnumber_centroid( 20, atomicnumber_centroid_initializer );
	// Table of sin(q*d)/(q*d) indexed (distance bin, q bin); filled by initialize_saxs_basis().
	FArray2D_double saxs_basis;
} //saxs_model


/////////////////////////////////////////////////////////////////////
void
eval_saxs_model_score( float & saxs_model_score)
{
	using namespace misc;
	using namespace saxs_model;

  saxs_model_score = 0.0;
  if (!get_saxs_model_score_flag()) return;

  calculate_saxs_profile(Eposition, centroid, decoy_saxs_profile);

  // Normalize, based on areas under profiles? There are other ways
  // to do this, based on say, optimal norm factor for computing the least squares sum below.
  float sum_input_saxs_profile = 0.0;
  float sum_decoy_saxs_profile = 0.0;
  for (int i = 1; i<= num_qval_bins; i++){
    sum_input_saxs_profile +=  input_saxs_profile(i);
    sum_decoy_saxs_profile +=  decoy_saxs_profile(i);
  }

	input_saxs_profile_norm = input_saxs_profile * sum_decoy_saxs_profile/ sum_input_saxs_profile;
	input_saxs_profile_error_norm = input_saxs_profile_error * sum_decoy_saxs_profile/ sum_input_saxs_profile;

  // Here it is, the least squares sum.
  double power  = get_saxs_model_power(); //default 0.0
  float  saxs_model_weight  = get_saxs_model_weight(); //default 0.0
  float deviation;
	float qmax = qval(num_qval_bins);
  for (int i = 1; i<= num_qval_bins; i++){
    deviation =  (decoy_saxs_profile(i) - input_saxs_profile_norm(i)) /
			input_saxs_profile_error_norm(i);
    saxs_model_score += saxs_model_weight * deviation * deviation
			* pow( static_cast<double>(qval(i)/qmax) , power);
  }

  return;
}

/////////////////////////////////////////////////////////////////////
void
calculate_saxs_profile(
			     FArray3Da_float decoy_Eposition,
			     FArray2Da_float decoy_centroid,
					 FArray1D_double & saxs_profile
)
{
  using namespace saxs_model;
	using namespace misc;
	using namespace numeric;

	FArray1D_int dist_histogram(num_dist_bins, 0);

	//Use N as a placeholder for backbone... N-N
  for (int i = 1; i <= total_residue; i++){
		xyzVector_float N_i ( &decoy_Eposition(1, 1, i) );
		for (int j = 1; j <= total_residue; j++){
			xyzVector_float N_j ( &decoy_Eposition(1, 1, j) );
			float distance = (N_i - N_j).length();
			int bin = static_cast <int> (distance/ dist_bin_width) + 1;
			if (bin > num_dist_bins) bin = num_dist_bins;
			dist_histogram( bin ) += atomicnumber_N * atomicnumber_N;
		}
	}

	//N-centroid
  for (int i = 1; i <= total_residue; i++){
		xyzVector_float N_i ( &decoy_Eposition(1, 1, i) );
		for (int j = 1; j <= total_residue; j++){
			xyzVector_float centroid_j ( &decoy_centroid(1, j) );
			int res_j = res( j );
			float distance = (N_i - centroid_j).length();
			int bin = static_cast <int> (distance/ dist_bin_width) + 1;
			if (bin > num_dist_bins) bin = num_dist_bins;
			dist_histogram( bin ) += 2 * atomicnumber_N * atomicnumber_centroid(res_j);
		}
	}

	//centroid-centroid
  for (int i = 1; i <= total_residue; i++){
		xyzVector_float centroid_i ( &decoy_centroid(1, i) );
			int res_i = res( i );
		for (int j = 1; j <= total_residue; j++){
			xyzVector_float centroid_j ( &decoy_centroid(1, j) );
			int res_j = res( j );
			float distance = (centroid_i - centroid_j).length();
			int bin = static_cast <int> (distance/ dist_bin_width) + 1;
			if (bin > num_dist_bins) bin = num_dist_bins;
			dist_histogram( bin ) += atomicnumber_centroid(res_i) *
				atomicnumber_centroid(res_j);
		}
	}

	//Calculate saxs profiles!
	initialize_saxs_basis(); //returns immediately if already init.
	saxs_profile.dimension( num_qval_bins );
	saxs_profile = 0.0;
	for (int i = 1; i <= num_dist_bins; i++){
		if( dist_histogram(i) > 0) {
			for (int j = 1; j <= num_qval_bins; j++){
				saxs_profile(j) += saxs_basis(i,j) * dist_histogram(i);
			}
		}
	}
	//Normalize.
	if (saxs_profile(1) > 0) {
			 saxs_profile *= (1.0 / saxs_profile(1));
	}
}


/////////////////////////////////////////////////////////////////////
// Set up the q grid and, when the "saxs_model" option is on, load the target
// profile either from a profile file ("input_saxs_profile" option) or by
// computing it from a model structure.
void
initialize_saxs_profile() {
	using namespace saxs_model;

	//Default q values: 0.00, 0.01, 0.02, ... 1.0 inverse Angstroms.
	num_qval_bins = 101;
	double const q_step = 1.0 / static_cast< float >( num_qval_bins - 1 );
	for ( int bin = 1; bin <= num_qval_bins; ++bin ) {
		qval( bin ) = q_step * ( bin - 1 );
	}

	if ( get_saxs_model_flag() ) {
		if ( !truefalseoption("input_saxs_profile") ) {
			input_saxs_model_from_file();
		} else {
			input_saxs_profile_from_file();
		}
	}

	//redimension, to conserve memory. (currently disabled)
// 	qval.dimension( num_qval_bins);
// 	input_saxs_profile.dimension( num_qval_bins);
// 	input_saxs_profile_norm.dimension( num_qval_bins);
// 	input_saxs_profile_error_norm.dimension( num_qval_bins);
// 	decoy_saxs_profile.dimension( num_qval_bins);

	// Start the normalized copies off equal to the raw input values.
	input_saxs_profile_norm = input_saxs_profile;
	input_saxs_profile_error_norm = input_saxs_profile_error;

	std::cout << " Profile defined from q = " << qval( 1 ) << " to "
		<< qval( num_qval_bins ) << "." << std::endl;
}

/////////////////////////////////////////////////////////////////////
void
input_saxs_profile_from_file(){
  using namespace saxs_model;

	if (!get_saxs_model_flag()) return;

	std::string filename;
	stringafteroption("input_saxs_profile","hey",filename);
	utility::io::izstream saxs_profile_stream( filename );

	if ( !saxs_profile_stream ) {
		std::cerr << "Couldn't find file " << filename << " for reading in saxs profile." << std::endl;
		utility::exit( EXIT_FAILURE, __FILE__, __LINE__);
		return;
	}

	int i = 0;
	float profile_error;
	while( saxs_profile_stream ) {
		i++;
		saxs_profile_stream >> qval(i) >> input_saxs_profile(i) >> profile_error >> utility::io::skip;
		if (profile_error)
			input_saxs_profile_error(i) = profile_error;
		else
			input_saxs_profile_error(i) = 0.01 * input_saxs_profile(i);
	}
	saxs_profile_stream.close();

	num_qval_bins = i;

	std::cout << "Saxs profile read in: " << filename << std::endl;
	if (profile_error)
		std::cout << "Saxs profile errors read in: " << filename << std::endl;
	else
		std::cout << "Saxs profile errors not found or zero: errors assumed to be 1% of profile." << std::endl;
}

/////////////////////////////////////////////////////////////////////
// Build the target SAXS profile from a model structure: read the file named by
// the "saxs_model" option with input_pdb(), compute its profile from
// misc::Eposition / misc::centroid, and assume a 1% error on every bin.
// Does nothing unless the "saxs_model" option is on. The program exits if the
// file cannot be opened or input_pdb() reports a failure.
void
input_saxs_model_from_file(){
	using namespace saxs_model;
	using namespace misc;

	if ( !get_saxs_model_flag() ) return;

	// Assume saxs model is in PDB format, with atoms at CA positions.
	std::string filename;
	stringafteroption("saxs_model","hey",filename);
	utility::io::izstream saxs_model_stream( filename );

	if ( !saxs_model_stream ) {
		std::cerr << "Couldn't find file " << filename << " for reading in saxs model." << std::endl;
		utility::exit( EXIT_FAILURE, __FILE__, __LINE__);
		return;
	}

	// NOTE(review): 'fail' is presumably cleared by input_pdb() on success and
	// left/set true on failure -- confirm against input_pdb().
	bool fail = true;
	input_pdb( saxs_model_stream, false /* seq_defined */, false /* fullatom */, fail );
	saxs_model_stream.close();

	// Do not compute a target profile from coordinates that were not read properly.
	if ( fail ) {
		std::cerr << "Couldn't read saxs model from file " << filename << "." << std::endl;
		utility::exit( EXIT_FAILURE, __FILE__, __LINE__);
		return;
	}

	std::cout << "Saxs model read in: " << filename << std::endl;
	calculate_saxs_profile( Eposition, centroid, input_saxs_profile);
	input_saxs_profile_error = input_saxs_profile * 0.01;
	input_saxs_profile_norm = input_saxs_profile;

	std::cout << "Errors assumed to be 1% of profile." << std::endl;
}


/////////////////////////////////////////////////////////////////////
// Write the SAXS profile of the current decoy to "<fullname>.saxs", one
// "q  I(q)" pair per line. Does nothing unless the "output_saxs_profile"
// option is on.
void
output_saxs_profile( std::string fullname) {
	using namespace saxs_model;
	using namespace misc;

	if ( get_output_saxs_profile_flag() ) {
		calculate_saxs_profile( Eposition, centroid, decoy_saxs_profile );

		fullname.append( ".saxs" );
		utility::io::ozstream pdb_out_stream( fullname );

		if ( !pdb_out_stream ) {
			std::cout << "Open failed for file: " << pdb_out_stream.filename() << std::endl;
			utility::exit( EXIT_FAILURE, __FILE__, __LINE__);
		}
		std::cout << "writing saxs profile: " << pdb_out_stream.filename() << std::endl;

		// The rescaled input profile could be written as a third column:
		//   <<  " " << F(7, 5, input_saxs_profile_norm(bin))
		for ( int bin = 1; bin <= num_qval_bins; ++bin ) {
			pdb_out_stream << F( 7, 5, qval( bin ) ) << " "
				<< F( 7, 5, decoy_saxs_profile( bin ) ) << std::endl;
		}
	}
}


/////////////////////////////////////////////////////////////////////
// True if the "saxs_model_score" option was given. The option is looked up on
// the first call and the answer is cached for all later calls.
bool
get_saxs_model_score_flag() {
	static bool const saxs_model_score_flag = truefalseoption("saxs_model_score");
	return saxs_model_score_flag;
}
/////////////////////////////////////////////////////////////////////
// True if the "saxs_model" option was given. The option is looked up on the
// first call and the answer is cached for all later calls.
bool
get_saxs_model_flag() {
	static bool const saxs_model_flag = truefalseoption("saxs_model");
	return saxs_model_flag;
}
/////////////////////////////////////////////////////////////////////
// True if the "output_saxs_profile" option was given. The option is looked up
// on the first call and the answer is cached for all later calls.
bool
get_output_saxs_profile_flag() {
	static bool const output_saxs_profile_flag = truefalseoption("output_saxs_profile");
	return output_saxs_profile_flag;
}
/////////////////////////////////////////////////////////////////////
// Exponent applied to q/qmax in the SAXS score, taken from the
// "saxs_model_power" option (default 0.0). Parsed once, then cached.
double get_saxs_model_power(){
	static bool parsed = false;
	static double saxs_model_power;

	if ( parsed ) return saxs_model_power;

	Drealafteroption("saxs_model_power", 0.0, saxs_model_power );
	parsed = true;
	return saxs_model_power;
}

/////////////////////////////////////////////////////////////////////
// Overall weight multiplying the SAXS score, taken from the
// "saxs_model_weight" option (default 0.0005). Parsed once, then cached.
float get_saxs_model_weight(){
	static bool parsed = false;
	static float saxs_model_weight;

	if ( parsed ) return saxs_model_weight;

	realafteroption("saxs_model_weight", 0.0005, saxs_model_weight );
	parsed = true;
	return saxs_model_weight;
}

/////////////////////////////////////////////////////////////////////
void initialize_saxs_basis()
{
	using namespace saxs_model;
	static bool init = {false};

	saxs_basis.dimension( num_dist_bins, num_qval_bins);

	if (init) return;

	for (int i = 1; i <= num_dist_bins; i++){
		for (int j = 1; j <= num_qval_bins; j++){
			float distance = (i - 0.5) * dist_bin_width;
			//padding to prevent divide by zero
			double qd = qval( j ) * distance +  1e-8;
			saxs_basis(i,j) = std::sin( qd )/ qd;
		}
	}
	init = true;
}

/////////////////////////////////////////////////////////////////////
// Fill the 20-entry table of centroid scattering weights, one per residue type.
// Entries are stored at indices 1..20 in the order listed below; the index is
// the residue type code used with res() in calculate_saxs_profile().
// NOTE(review): the ordering presumably follows the project's standard
// amino-acid numbering -- confirm.
void
atomicnumber_centroid_initializer( FArray1D_int & atomic_number_centroid ){
	int const num_types = 20;
	int const centroid_weights[ num_types ] = {
		 1, 17, 23, 31, 49,
		-8, 33, 25, 33, 25,
		33, 23, 15, 31, 47,
		 9, 17, 17, 61, 49
	};

	for ( int k = 0; k < num_types; ++k ) {
		atomic_number_centroid( k + 1 ) = centroid_weights[ k ];
	}
}

