// -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
// vi: set ts=2 noet:
// :noTabs=false:tabSize=4:indentSize=4:
//
// (c) Copyright Rosetta Commons Member Institutions.
// (c) This file is part of the Rosetta software suite and is made available under license.
// (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
// (c) For more information, see http://www.rosettacommons.org. Questions about this can be
// (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.

/// @file
///
/// @brief
/// @author

#include <core/scoring/sasa.hh>

#include <core/types.hh>

#include <core/chemical/AtomType.hh>

#include <core/conformation/Residue.hh>

#include <core/id/AtomID_Map.hh>
#include <core/id/AtomID_Map.Pose.hh>

#include <core/pose/Pose.hh>

#include <core/io/database/open.hh>
#include <core/util/Tracer.hh>

// ObjexxFCL Headers
#include <ObjexxFCL/ubyte.hh>
//#include <ObjexxFCL/FArray1D.hh>
#include <ObjexxFCL/FArray2D.hh>
//#include <ObjexxFCL/FArray3D.hh>
#include <ObjexxFCL/Fmath.hh>
#include <ObjexxFCL/formatted.io.hh>

// Numeric Headers
#include <numeric/constants.hh>
#include <numeric/trig.functions.hh>

// Utility Headers
#include <utility/basic_sys_util.hh>
#include <utility/io/izstream.hh>

// C++ Headers
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>

using namespace ObjexxFCL;
using namespace ObjexxFCL::fmt;

namespace core {
namespace scoring {


// Population-count lookup table: bit_count[v] is the number of set bits in the
// byte value v (0..255). Used to count buried-dot bits in the SASA masks here,
// and (per the original note) also in void.cc.
// Layout: one row per high nibble; row k equals popcount(k) added to the
// low-nibble counts 0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4.
short const bit_count[] = {
	0,1,1,2,1,2,2,3, 1,2,2,3,2,3,3,4,  // 0x
	1,2,2,3,2,3,3,4, 2,3,3,4,3,4,4,5,  // 1x
	1,2,2,3,2,3,3,4, 2,3,3,4,3,4,4,5,  // 2x
	2,3,3,4,3,4,4,5, 3,4,4,5,4,5,5,6,  // 3x
	1,2,2,3,2,3,3,4, 2,3,3,4,3,4,4,5,  // 4x
	2,3,3,4,3,4,4,5, 3,4,4,5,4,5,5,6,  // 5x
	2,3,3,4,3,4,4,5, 3,4,4,5,4,5,5,6,  // 6x
	3,4,4,5,4,5,5,6, 4,5,5,6,5,6,6,7,  // 7x
	1,2,2,3,2,3,3,4, 2,3,3,4,3,4,4,5,  // 8x
	2,3,3,4,3,4,4,5, 3,4,4,5,4,5,5,6,  // 9x
	2,3,3,4,3,4,4,5, 3,4,4,5,4,5,5,6,  // Ax
	3,4,4,5,4,5,5,6, 4,5,5,6,5,6,6,7,  // Bx
	2,3,3,4,3,4,4,5, 3,4,4,5,4,5,5,6,  // Cx
	3,4,4,5,4,5,5,6, 4,5,5,6,5,6,6,7,  // Dx
	3,4,4,5,4,5,5,6, 4,5,5,6,5,6,6,7,  // Ex
	4,5,5,6,5,6,6,7, 5,6,6,7,6,7,7,8,  // Fx
};

// Dimensions of the precomputed SASA lookup tables loaded by input_sasa_dats().
int const nbytes( 21 );    // ubytes stored per atom surface mask (21*8 = 168 bits)
int const nphi( 64 );      // polar-angle bins of the orientation lookup table
int const ntheta( 64 );    // azimuthal-angle bins of the orientation lookup table
int const nolp( 100 );     // overlap bins per orientation point (get_overlap yields 1..100)
int const nori( 162 );     // orientation points; masks holds nolp*nori rows
int const maskbits( 162 ); // mask bits counted per atom when computing the buried fraction

// angles(aphi,theta): orientation point index for an (aphi,theta) bin pair
FArray2D_int angles( nphi, ntheta );
// masks(1..nbytes, point*100 + olp): surface mask for one orientation point and overlap bin
FArray2D_ubyte masks( nbytes, nolp*nori );


////////////////////////////////////////////////////////////////////////////////
/// @begin input_sasa_dats
///
/// @brief
///cj    Reads in SASA-angles.dat  SASA-masks.dat
///
/// @detailed
///   Fills the file-scope lookup tables used by calc_per_atom_sasa:
///   - masks  : nolp*nori records of nbytes numbers each, one record per line
///   - angles : nphi lines of ntheta numbers each
///   The files are only read on the first call; later calls return at once.
///   If either database file cannot be opened the program prints a message and
///   exits, instead of silently leaving the tables zeroed (which would make
///   every atom appear fully exposed).
///
/// @global_read nbytes, nolp, nori, nphi, ntheta
///
/// @global_write masks, angles
///
/// @remarks
///   The one-time-initialization flag is a plain function-local static with no
///   locking, so concurrent first calls are not protected.
///
/// @references
///
/// @authors
///
/// @last_modified
/////////////////////////////////////////////////////////////////////////////////
void
input_sasa_dats()
{
	static bool init = { false };

	if ( init ) return;
	init = true;

//cj    inputting the masks they are 21 ubytes long, 162x100 (see header)
//cj    expects file to be complete
	utility::io::izstream masks_stream( io::database::full_name("SASA-masks.dat" ) );
	if ( !masks_stream ) {
		std::cout << "input_sasa_dats: unable to open SASA-masks.dat" << std::endl;
		utility::exit( EXIT_FAILURE, __FILE__, __LINE__);
	}

	// each record's entries are parsed as shorts, then narrowed into the ubyte table
	FArray1D_short tmp( nbytes );
	for ( int i = 1; i <= nolp*nori; ++i ) {
		for ( int j = 1; j <= nbytes; ++j ) {
			masks_stream >> tmp(j);
		}
		masks_stream >> skip; // discard the rest of the line

		for ( int j = 1; j <= nbytes; ++j ) {
			masks(j,i) = static_cast< ubyte >( tmp(j) );
		}
	}
	masks_stream.close();

//cj    inputting the angle lookup for the mask, need to add a 1 to each number
	utility::io::izstream angles_stream( io::database::full_name( "SASA-angles.dat" ) );
	if ( !angles_stream ) {
		std::cout << "input_sasa_dats: unable to open SASA-angles.dat" << std::endl;
		utility::exit( EXIT_FAILURE, __FILE__, __LINE__);
	}

//cj    2D array is aphi by theta
	for ( int i = 1; i <= nphi; ++i ) {
		for ( int j = 1; j <= ntheta; ++j ) {
			angles_stream >> angles(i,j);
		}
		angles_stream >> skip; // discard the rest of the line
//cj       for ( j = 1; j <= ntheta; ++j ) {
//cj          ++angles(i,j);
//cj       }
	}
	angles_stream.close();
}

////////////////////////////////////////////////////////////////////////////////
/// @begin get_overlap
///
/// @brief
///cj    returns the degree of overlap of atom b onto atom a as an integer
///cj    bin from 1 (a not buried by b) to 100 (a completely buried by b)
///
/// @detailed
///cj    getting overlap from a to b (or i to j, see below)
///cj    adapted from erics code in area.c GetD2
///cj    This calculation is based on the law of cosines,
///cj    see LeGrand and Merz, Journal of Computational
///cj    Chemistry 14(3):349-52 (1993).
///cj    Note that equation (4) is wrong: the denominator
///cj    should be 2 r_i r_iq instead of 2 r_i r_q.
///
/// @param  ra - [in] - radius of atom a (calc_per_atom_sasa passes SASA radius + probe radius)
/// @param  rb - [in] - radius of atom b (same convention as ra)
/// @param  dist - [in] - distance between the centers of a and b (not squared)
/// @param  olp - [out] - overlap bin, 1..100
///
/// @global_read
///
/// @global_write
///
/// @remarks
///   Callers are expected to pass only touching pairs (dist <= ra + rb); for
///   those the law-of-cosines bin lies in 1..100 up to round-off. Any bin
///   below 1 indicates a bug: diagnostics are printed and the program exits.
///
/// @references
///
/// @authors
///
/// @last_modified
/////////////////////////////////////////////////////////////////////////////////
void
get_overlap(
	Real const ra,
	Real const rb,
	Real const dist,
	int & olp
)
{
//cj    min distance cutoff
	Real const epsilon = 0.01;

	if ( dist < epsilon ) {
//cj    atoms too close, causes round off error
//cj    use this cutoff: a smaller atom a is fully buried, otherwise not at all
		if ( ra < rb ) {
			olp = 100;
		} else {
			olp = 1;
		}
	} else if ( rb+dist <= ra ) {
//cj    If atom a completely engulfs atom b, consider a to have
//cj    no overlap due to atom b.
		olp = 1;
	} else if ( ra+dist <= rb ) {
//cj    If atom a is completely engulfed by atom b, then turn it
//cj    completely off.
		olp = 100;
	} else {
//cj    Otherwise, compute the amount of overlap using the law of cosines.
//cj    "costh" is the cosine of the angle, at the center of atom a, between
//cj    the a->b axis and the rim of the cone of intersection that atom b
//cj    imposes on atom a. "ra" and "rb" are the radii of atoms a and b,
//cj    and "dist" is the (unsquared) distance between their centers.
		Real const costh = (ra*ra+dist*dist-rb*rb)/(2*ra*dist);
		olp = static_cast< int >((1.0-costh)*50)+1;
		if ( olp > 100 ) {
			olp = 100;
		} else if ( olp < 1 ) {
//cj       We already hopefully accounted for this possibility by requiring that
//cj       dist < epsilon, but in case not we don't want a potential bug to go
//cj       unnoticed. (olp == 0 would silently alias the previous orientation's
//cj       mask row in the caller, so it is rejected too.)
			std::cout << "problem in calculating overlap between:" << std::endl;
			std::cout << "ra=" << SS( ra ) << std::endl;
			std::cout << "rb=" << SS( rb ) << std::endl;
			std::cout << "dist=" << SS( dist ) << std::endl;
			std::cout << "costh=" << SS( costh ) << std::endl;
			std::cout << "Terminating calculation" << std::endl;
			utility::exit( EXIT_FAILURE, __FILE__, __LINE__);
		}
	}
}

////////////////////////////////////////////////////////////////////////////////
/// @begin get_orientation
///
/// @brief
///cj    gets the orientation of a to b (i to j, see below)
///cj    does this by calculating two angles, aphi and theta
///
/// @detailed
///   The unit vector (a - b) / dist is converted to two angles:
///   aphi from acos of its z component, theta from atan2(y, x). Each angle is
///   multiplied by (number of bins / 2pi), truncated to int, and shifted to a
///   1-based bin index; out-of-range bins wrap back into 1..nphi / 1..ntheta.
///   NOTE(review): acos lies in [0, pi], so with the 2pi scale only aphi bins
///   1..nphi/2+1 are produced; presumably the angles table was generated with
///   the same convention.
///
/// @param  a - [in] - coordinates of atom a
/// @param  b - [in] - coordinates of atom b
/// @param  aphi - [out] - polar-angle bin index into angles(aphi, theta)
/// @param  theta - [out] - azimuthal-angle bin index into angles(aphi, theta)
/// @param  dist - [in] - |a - b|, precomputed by the caller; must be nonzero
///
/// @global_read
///
/// @global_write
///
/// @remarks
///
/// @references
///
/// @authors
///
/// @last_modified
/////////////////////////////////////////////////////////////////////////////////
void
get_orientation(
	Vector const & a,
	Vector const & b,
	int & aphi,
	int & theta,
	Real dist
)
{
	using namespace numeric::constants::d;
	using numeric::sin_cos_range;

	// direction from b to a, normalized with the caller's precomputed distance
	// (a local, not a static, so concurrent callers do not share state)
	Vector const unit( ( a - b ) / dist );

	// polar angle -> 1-based bin (fortran-style indexing), wrapping past nphi
	Real const phi_scaled = std::acos( sin_cos_range( unit(3) ) ) * ( nphi / pi_2 );
	aphi = static_cast< int >( phi_scaled ) + 1;
	if ( aphi > nphi ) aphi = 1;

	// azimuth in [-pi, pi] -> 1-based bin; negative bins wrap up by ntheta
	Real const theta_scaled = std::atan2( unit(2), unit(1) ) * ( ntheta / pi_2 );
	theta = static_cast< int >( theta_scaled ) + 1;
	if ( theta < 1 ) {
		theta += ntheta;
	} else if ( theta > ntheta ) {
		theta = 1;
	}
}

////////////////////////////////////////////////////////////////////////////////
/// @begin get_2way_orientation
///
/// @brief
///cj    gets the orientation of a to b AND of b to a (i to j and j to i)
///cj    in one pass, each as two angle bins, aphi and theta
///cj    NOTE: the implementation below is commented out and not compiled.
///
/// @detailed
///
/////////////////////////////////////////////////////////////////////////////////
// void
// get_2way_orientation(
// 	FArray1DB_Real const & a,
// 	FArray1DB_Real const & b,
// 	int & aphi_a2b,
// 	int & theta_a2b,
// 	int & aphi_b2a,
// 	int & theta_b2a,
// 	Real dist
// )
// {
// 	using namespace fullatom_sasa;
// 	using namespace numeric::constants::f;
// 	using numeric::sin_cos_range;

// 	//apl allocate this once only
// 	static FArray1D_Real diff( 3 );

// 	//cj    figure the difference between a and b
// 	//apl - You've already computed the distance! reuse it.
// 	diff(1) = (a(1)-b(1) ) / dist;
// 	diff(2) = (a(2)-b(2) ) / dist;
// 	diff(3) = (a(3)-b(3) ) / dist;

// 	//diff(4) = (b(1)-a(1) ) / dist;
// 	//diff(5) = (b(2)-a(2) ) / dist;
// 	//diff(6) = (b(3)-a(3) ) / dist;

// 	//cj    now figure out polar values of aphi and theta
// 	//cj    Normalize the difference
// 	//cj    first get the length of the vector
// 	//vector_normalize(diff);

// 	//figuring aphi_a2b
// 	Real p_a2b = std::acos( sin_cos_range( diff(3) ) );
// 	p_a2b *= nphi / pi_2;
// 	Real p_b2a = nphi / 2 - p_a2b;

// 	aphi_a2b = static_cast< int >( p_a2b );
// 	aphi_b2a = static_cast< int >( p_b2a );

// 	++aphi_a2b; // for fortran goes from 1 to n
// 	++aphi_b2a; // for fortran goes from 1 to n

// 	if ( aphi_a2b > nphi ) aphi_a2b = 1;
// 	if ( aphi_b2a > nphi ) aphi_b2a = 1;

// 	//figuring theta_a2b
// 	Real t_a2b = std::atan2(diff(2),diff(1));
// 	t_a2b *= ntheta / pi_2;

// 	theta_a2b = static_cast< int >( t_a2b );
// 	++theta_a2b; // for fortran goes from 1 to n
// 	if ( theta_a2b < 1 ) {
// 		theta_a2b += ntheta;
// 	} else if ( theta_a2b > ntheta ) {
// 		theta_a2b = 1;
// 	}

// 	//figuring theta_b2a
// 	Real t_b2a = ntheta / 2.0f + t_a2b;
// 	if (t_b2a > ntheta / 2.0f ) t_b2a -= ntheta;


// 	theta_b2a = static_cast< int >( t_b2a );
// 	++theta_b2a; // for fortran goes from 1 to n
// 	if ( theta_b2a < 1 ) {
// 		theta_b2a += ntheta;
// 	} else if ( theta_b2a > ntheta ) {
// 		theta_b2a = 1;
// 	}
// }



////////////////////////////////////////////////////////////////////////////////
/// @brief Total SASA of the pose, computed over every atom.
/// @details Convenience wrapper around calc_per_atom_sasa (default settings);
///          the per-atom and per-residue breakdowns it fills are thrown away.
/// @param pose - [in] - structure to evaluate
/// @param probe_radius - [in] - solvent probe radius added to each atom radius
/// @return total SASA as reported by calc_per_atom_sasa
Real // returns total sasa
calc_total_sasa(
	pose::Pose const & pose,
	Real const probe_radius
)
{
	id::AtomID_Map< Real > discarded_atom_sasa;
	utility::vector1< Real > discarded_rsd_sasa;
	Real const total(
		calc_per_atom_sasa( pose, discarded_atom_sasa, discarded_rsd_sasa, probe_radius ) );
	return total;
}

////////////////////////////////////////////////////////////////////////////////
/// @brief Per-atom / per-residue SASA over all atoms of the pose.
/// @details Builds an atom subset with every atom set to true and forwards to
///          the subset-taking overload of calc_per_atom_sasa.
/// @return total SASA as reported by the subset overload
Real // returns total sasa
calc_per_atom_sasa(
	pose::Pose const & pose,
	id::AtomID_Map< Real > & atom_sasa,
	utility::vector1< Real > & rsd_sasa,
	Real const probe_radius,
	bool const use_big_polar_H // = false
)
{
	// jk use all atoms if atom_subset is not specified
	id::AtomID_Map< bool > every_atom;
	every_atom.clear();
	id::initialize( every_atom, pose, true );

	Real const total(
		calc_per_atom_sasa( pose, atom_sasa, rsd_sasa, probe_radius, use_big_polar_H, every_atom ) );
	return total;
}


////////////////////////////////////////////////////////////////////////////////
Real // returns total sasa
calc_per_atom_sasa(
	pose::Pose const & pose,
	id::AtomID_Map< Real > & atom_sasa,
	utility::vector1< Real > & rsd_sasa,
	Real const probe_radius,
	bool const use_big_polar_H,
	id::AtomID_Map< bool > & atom_subset
)
{
	using conformation::Residue;
	using conformation::Atom;
	using id::AtomID;

	Real const big_polar_H_radius( 1.08 ); // increase radius of polar hydrogens, eg when doing unsatisfied donorH check

	Size const nres( pose.total_residue() );
	if ( nres < 1 ) return 0.0; // nothing to do

	// setup the radii array, indexed by the atom type int
	// atom index for looking up an extra data type stored in the AtomTypes
	utility::vector1< Real > radii;
	{
		chemical::AtomTypeSet const & atom_set( pose.residue(1).atom_type_set() );
		Size const SASA_RADIUS_INDEX( atom_set.extra_parameter_index( "SASA_RADIUS" ) );
		radii.resize( atom_set.n_atomtypes() );
		for ( Size i=1; i<= radii.size(); ++i ) {
			chemical::AtomType const & at( atom_set[i] );
			radii[i] = atom_set[i].extra_parameter( SASA_RADIUS_INDEX );
			if ( use_big_polar_H && at.is_polar_hydrogen() && big_polar_H_radius > radii[i] ) {
				util::T("core.scoring.sasa") << "Using " << big_polar_H_radius << " instead of " << radii[i] <<
					" for atomtype " << at.name() << " in sasa calculation!\n";
				radii[i] = big_polar_H_radius;
			}
		}
	}

	typedef utility::vector1< ubyte > ubytes;
	id::AtomID_Map< ubytes > atm_masks;
	{ // initialize atm_masks
		ubytes zero_mask( nbytes, ubyte( 0 ) );
		id::initialize( atm_masks, pose, zero_mask );
	} // scope

// local:
	//FArray3D_ubyte atm_masks( nbytes, MAX_ATOM()(), MAX_RES()() );

	// read sasa datafiles
	input_sasa_dats(); //returns if already done

	// for use in skipping residue pairs if their nbr atoms are too far apart:
	Real cutoff_distance(0.0);
	{ // scope
		Real max_radius( 0.0 );
		for ( Size i=1; i<= nres; ++i ) {
			Residue const & rsd( pose.residue( i ) );
			for ( Size j=1; j<= rsd.natoms(); ++j ) {
				max_radius = std::max( max_radius, radii[ rsd.atom(j).type() ] );
				assert( std::abs( atm_masks[ AtomID(j,i) ][1] + atm_masks[ AtomID(j,i) ][nbytes] ) < 1e-3 );
			}
		}
		cutoff_distance = 2 * ( max_radius + probe_radius );
	}


//cj----now do calculations: get the atm_masks by looping over all_atoms x all_atoms
	for ( Size ir = 1; ir <= nres; ++ir ) {
		Residue const & irsd( pose.residue( ir ) );

		for ( Size jr = ir; jr <= nres; ++jr ) {
			Residue const & jrsd( pose.residue( jr ) );
			// use distance rather than distance_squared since the nrbr_radii might be negative
			if ( distance( irsd.atom( irsd.nbr_atom() ).xyz(), jrsd.atom( jrsd.nbr_atom() ).xyz() ) >
					 irsd.nbr_radius() + jrsd.nbr_radius() + cutoff_distance ) continue;

			for ( Size ia = 1; ia <= irsd.natoms(); ++ia ) {
				AtomID const iid( ia, ir );
				if ( ! atom_subset[ iid ] ) continue; // jk skip this atom if not part of the subset
				Atom const & iatom( irsd.atom( ia ) );
				Vector const & ic( iatom.xyz() );
				Real const irad = radii[ iatom.type() ] + probe_radius;
				//chemical::AtomType const & it( irsd.atom_type( ia ) );
				//Real const irad = it.extra_parameter( SASA_RADIUS_INDEX ) + probe_radius;

				for ( Size ja = 1; ja <= jrsd.natoms(); ++ja ) {
					AtomID const jid( ja, jr );
					if ( ! atom_subset[ jid ] ) continue; // jk skip this atom if not part of the subset
					Atom const & jatom( jrsd.atom( ja ) );
					Vector const & jc( jrsd.atom( ja ).xyz() );
					Real const jrad = radii[ jatom.type() ] + probe_radius;
					//chemical::AtomType const & jt( jrsd.atom_type( ja ) );
					//Real const jrad = jt.extra_parameter( SASA_RADIUS_INDEX ) + probe_radius;

					Real const dist( distance( ic, jc ) ); // could be faster w/o sqrt, using Jeff Gray's rsq_min stuff

					if ( dist <= irad + jrad ) {
						if ( dist <= 0.0 ) continue;

						// account for j overlapping i:
						// jk Note: compute the water SASA, but DON'T allow the water
						// jk to contribute to the burial of non-water atoms
						int olp, aphi, theta, point, masknum;

						if ( !jrsd.atom_type(ja).is_h2o() ) {
							get_overlap( irad, jrad, dist, olp );
							get_orientation( ic, jc, aphi, theta, dist );
							point = angles( aphi, theta );
							masknum = point * 100 + olp;
							ubytes & imasks( atm_masks[ AtomID( ia, ir ) ] );
							for ( int bb = 1, m = masks.index(bb,masknum); bb <= nbytes; ++bb, ++m ) {
								imasks[ bb ] = bit_or( imasks[ bb ], masks[ m ] );
								//atm_masks(bb,ia,ir) = bit_or(atm_masks(bb,ia,ir),masks(bb,masknum));
							}
						}

						// account for i overlapping j:
						// jk Note: compute the water SASA, but DON'T allow the water
						// jk to contribute to the burial of non-water atoms
						if ( !irsd.atom_type(ia).is_h2o() ) {
							get_overlap( jrad, irad, dist, olp );
							get_orientation( jc, ic, aphi, theta, dist );
							point = angles( aphi, theta );
							masknum = point * 100 + olp;
							ubytes & jmasks( atm_masks[ AtomID( ja, jr ) ] );
							for ( int bb = 1, m = masks.index(bb,masknum); bb <= nbytes; ++bb, ++m ) {
								jmasks[ bb ] = bit_or( jmasks[ bb ], masks[ m ] );
								// atm_masks(bb,ja,jr) = bit_or(atm_masks(bb,ja,jr),masks(bb,masknum));
							}
						}
					}         // dist <= irad + jrad
				}          // ja
			}           // jr
		}            // ia
	}             // ir

//-----calculate the residue and atom sasa
	atom_sasa.clear();
	id::initialize( atom_sasa, pose, (Real) -1.0 ); // jk initialize to -1 for "not computed"

	rsd_sasa.clear();
	rsd_sasa.resize( nres, 0.0 );

	Real total_sasa( 0.0 );

	Real const four_pi = 4.0f * Real( numeric::constants::d::pi );
	for ( Size ir=1; ir<= nres; ++ir ) {
		Residue const & rsd( pose.residue(ir) );
		rsd_sasa[ ir ] = 0.0;
		for ( Size ia = 1; ia<= rsd.natoms(); ++ia ) {

			AtomID const id( ia, ir );
			if ( ! atom_subset[ id ] ) continue; // jk skip this atom if not part of the subset

			chemical::AtomType const & it( rsd.atom_type(ia) );

			Real const irad = radii[ rsd.atom(ia).type() ] + probe_radius;

			//cj       to get SASA:
			//cj       - count the number of 1's
			//cj       - figure fraction that they are
			//cj       - multiply by 4*pi*r_sqared

			int ctr = 0;
			ubytes const & imasks( atm_masks[ id ] );
			for ( int bb = 1; bb <= nbytes; ++bb ) {
				ctr += bit_count[ imasks[bb] ]; // atm_masks(bb,ia,ir)
			}

			Real const fraction = static_cast< Real >( ctr ) / maskbits;
			Real const total_sa = four_pi * ( irad * irad );
			Real const expose = ( 1.0f - fraction ) * total_sa;

			atom_sasa[ id ] = expose;
			// jk Water SASA doesn't count toward the residue's SASA
			if ( !it.is_h2o() ) {
				rsd_sasa[ ir ] += expose;
				total_sasa += expose;
			}
		}                  // ia
	}                     // ir
	return total_sasa;
}

} // ns scoring
} // ns core
