// -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
// vi: set ts=2 noet:
//
// This file is made available under the Rosetta Commons license.
// See http://www.rosettacommons.org/license
// (C) 199x-2007 University of Washington
// (C) 199x-2007 University of California Santa Cruz
// (C) 199x-2007 University of California San Francisco
// (C) 199x-2007 Johns Hopkins University
// (C) 199x-2007 University of North Carolina, Chapel Hill
// (C) 199x-2007 Vanderbilt University

/// @file   design_functions.cc
/// @brief  Miscellaneous design functions.
/// @author Yih-En Andrew Ban (yab@u.washington.edu)

// unit headers
#include <epigraft/design/design_functions.hh>

// package headers
#include <epigraft/AtomPoint.hh>
#include <epigraft/ResidueRange.hh>

// rootstock headers
#include <rootstock/Octree.hh>

// Rosetta headers
#include <aaproperties_pack.h>
#include <after_opts.h>
#include <design.h>
#include <etable_manager.h>
#include <fullatom_setup.h>
#include <InteractionGraphBase.h>
#include <jumping_refold.h>
#include <PackerTask.h>
#include <pack_fwd.h>
#include <pack.h>
#include <param.h>
#include <param_aa.h>
#include <pose.h>
#include <RotamerSet.h>
#include <template_pack.h>

// ObjexxFCL headers
#include <ObjexxFCL/ObjexxFCL.hh>
#include <ObjexxFCL/FArray1D.hh>
#include <ObjexxFCL/FArray2D.hh>
#include <ObjexxFCL/FArray3D.hh>

// C++ headers
#include <iostream>
#include <set>
#include <vector>


namespace epigraft {
namespace design {


/// @brief Copy sidechains from a full-atom "sidechain" Pose onto a full-atom
///  "backbone" Pose over a residue range.
/// @param[in,out] bb_pose  Pose that receives the sidechains
/// @param[in]     sc_pose  source of residue identities, variants and coordinates
/// @param[in]     range    inclusive residue range; numbering is the same in both Poses
/// @warning Both Poses must be full-atom; otherwise an error is printed to
///  std::cerr and nothing is copied.
void
copy_sidechains(
	Pose & bb_pose,
	Pose & sc_pose,
	ResidueRange const & range
)
{
	bool const both_fullatom = bb_pose.fullatom() && sc_pose.fullatom();
	if ( !both_fullatom ) {
		std::cerr << "ERROR -- epigraft::design::copy_sidechains() called with a Pose that is not full-atom, refusing to copy!" << std::endl;
		return;
	}

	FArray3D_float const & source_xyz = sc_pose.full_coord();

	Integer const stop = range.end();
	for ( Integer r = range.begin(); r <= stop; ++r ) {
		// hand over identity, variant and the first coordinate of residue r
		bb_pose.copy_sidechain( r, sc_pose.res( r ), sc_pose.res_variant( r ), source_xyz( 1, 1, r ) );
	}
}


/// @brief setup design globals and options
/// @details Enables trying both histidine tautomers and using the input sidechains
///  as rotamers, then applies command-line options:
///   - large_rotamer_set:    selects the "large" rotamer set
///   - favor_native_residue: applies a native-residue bonus whose value is read
///     with realafteroption() (default -1.5)
///  Finally calls pose_to_misc() on the input, which redimensions the global
///  design matrix.
/// @note The command-line options are looked up only on the first call
///  (function-local statics).
/// @param[in] input  Pose used to (re)dimension the global design state
void
setup_design_globals(
	Pose const & input
)
{
	// set globals
	::design::design_commands::try_both_his_tautomers = true;
	::design::active_rotamer_options.use_input_sc = true;

	// lookup command line options (cached after the first call)
	static bool use_large_rotamer_set = truefalseoption( "large_rotamer_set" );
	static bool favor_native_residue = truefalseoption( "favor_native_residue" );
	// Must be floating point: this used to be a bool, which collapsed every
	// non-zero bonus (including the -1.5 default) to 1.0.
	static float native_bonus = realafteroption( "favor_native_residue", -1.5 );

	// set from command line options
	if ( use_large_rotamer_set ) {
		select_rotamer_set( "large" );
	}
	if ( favor_native_residue ) {
		set_native_bonus( native_bonus );
	}

	// pose_to_misc call forces redimensioning of global design matrix
	pose_to_misc( input );
}


/// @brief Design the input Pose according to a design matrix.
/// @param[in,out] input          Pose designed in place
/// @param[in]     design_matrix  (aa, residue) table of allowed amino acids
/// @details Convenience wrapper. It builds throw-away packer objects and forwards
///  to the overload that hands those objects back. It then deletes the
///  interaction graph left in the pointer, if any.
void
design(
	Pose & input,
	FArray2D< bool > const & design_matrix
)
{
	PackerTask packer_task( input );
	pack::InteractionGraphBase * graph = NULL;
	RotamerSet rotamers;
	FArray1D_int rot_index( input.total_residue() );
	FArray1D_float ligand_1b_energies;

	epigraft::design::design( input, design_matrix, packer_task, graph, rotamers, rot_index, ligand_1b_energies );

	delete graph; // deleting a NULL pointer is a no-op, so no guard is needed
}


/// @brief Design using a design matrix, returning the packer objects so they
///  can be re-used in a subsequent design call.
/// @param[in,out] input              Pose designed in place
/// @param[in]     design_matrix      (aa, residue) table; copied into the global
///                                   ::design::design_matrix for residues
///                                   1..input.total_residue()
/// @param[in,out] task               must be initialized with PackerTask( input )
///                                   prior to feeding into this routine; it is
///                                   configured here for "design" mode
/// @param[in,out] ig                 interaction graph pointer, set up by the
///                                   delegated design() call (setup = true)
/// @param[in,out] rs                 rotamer set, set up by the delegated call
/// @param[in,out] current_rot_index  forwarded to the packer
/// @param[in,out] ligenergy1b_old    forwarded to the packer
/// @note NOTE(review): the delegated design() call sets ::design::use_design_matrix
///  and ::design::design_matrix back to false after the task has been configured here.
///  Presumably the task already holds the per-residue design info at that point.
///  TODO confirm.
void
design(
	Pose & input,
	FArray2D< bool > const & design_matrix,
	PackerTask & task,
	pack::InteractionGraphBase * & ig,
	RotamerSet & rs,
	FArray1D_int & current_rot_index,
	FArray1D_float & ligenergy1b_old
)
{
	::design::use_design_matrix = true;
	::design::design_matrix = false; // reset global design matrix, has size: MAX_AA, MAX_RES

	setup_design_globals( input );

	// mirror the caller's matrix into the global one consulted by the packer task
	int const nres = input.total_residue();
	for ( int r = 1; r <= nres; ++r ) {
		for ( int a = 1; a <= param::MAX_AA(); ++a ) {
			::design::design_matrix( a, r ) = design_matrix( a, r );
		}
	}

	// packer task parameters
	std::string mode( "design" );
	bool write_output( false );
	FArray1D_bool repack_mask( param::MAX_RES()(), true );
	bool keep_current( true );
	bool use_extra( false );
	FArray2D_int dummy_extra_rot( param::MAX_CHI, param::MAX_RES()() ); // unused for this mode
	FArray2D_float dummy_extra_chi( param::MAX_CHI, param::MAX_RES()() ); // unused for this mode

	task.set_task( mode, write_output, repack_mask,
	               keep_current, use_extra, dummy_extra_rot, dummy_extra_chi );
	task.setup_residues_to_vary();

	bool const run_setup = true; // build rotamers / interaction graph from scratch
	epigraft::design::design( input, task, ig, rs, current_rot_index, ligenergy1b_old, run_setup );
}


/// @brief Run the packer with an already-configured PackerTask and packer objects.
/// @param[in,out] input              Pose packed in place
/// @param[in,out] task               configured packer task
/// @param[in,out] ig                 interaction graph pointer
/// @param[in,out] rs                 rotamer set
/// @param[in,out] current_rot_index  forwarded to the packer
/// @param[in,out] ligenergy1b_old    forwarded to the packer
/// @param[in]     setup              if true, pack_rotamers_setup() runs before
///                                   pack_rotamers_run(); if false the given rs/ig
///                                   are re-used as-is
/// @note Clears ::design::use_design_matrix and ::design::design_matrix first.
void
design(
	Pose & input,
	PackerTask & task,
	pack::InteractionGraphBase * & ig,
	RotamerSet & rs,
	FArray1D_int & current_rot_index,
	FArray1D_float & ligenergy1b_old,
	bool setup
)
{
	// clear global design-matrix state
	::design::use_design_matrix = false;
	::design::design_matrix = false;

	setup_design_globals( input );
	initialize_fullatom();

	// copy backbone torsions from the input pose into template_pack
	for ( int seqpos = 1, nres = input.total_residue(); seqpos <= nres; ++seqpos ) {
		template_pack::phi( seqpos ) = input.phi( seqpos );
		template_pack::psi( seqpos ) = input.psi( seqpos );
	}

	// soft rep is switched on here, if activated through command line ...
	enable_packing_etables( input.total_residue(), input.res(), input.res_variant(), input.full_coord() );

	if ( setup ) pack_rotamers_setup( task, rs, ig, input, current_rot_index, ligenergy1b_old );
	pack_rotamers_run( task, rs, ig, input, current_rot_index, ligenergy1b_old );

	// ... and switched back off using the post-packing pose state
	disable_packing_etables( input.total_residue(), input.res(), input.res_variant(), input.full_coord() );
}


/// @brief mutate range to residue type
/// @details The mutation runs in two design passes:
///  - First, every residue in the range is designed to glycine, which removes the
///    old sidechain. Residues that were glycine on entry are skipped when keep_gly
///    is set.
///  - Second, the same residues are designed to the target amino acid.
///  Only residues inside the range have chi moves enabled during design. Chi moves
///  are locked for the whole Pose on exit.
/// @param[in,out] pose      Pose mutated in place
/// @param[in]     range     inclusive residue range to mutate
/// @param[in]     aa        target amino acid type (same numbering as param_aa::aa_gly)
/// @param[in]     keep_gly  if true, residues that are glycine on entry are left alone
void
mutate_range(
	Pose & pose,
	ResidueRange const & range,
	Integer const & aa,
	bool const & keep_gly
)
{
	FArray2D< bool > design_matrix( param::MAX_AA(), pose.total_residue(), false );

	// lock chi everywhere; only residues inside the range are unlocked below
	pose.set_allow_chi_move( false );

	// find gly residues that must be preserved
	std::set< Integer > original_gly;
	if ( keep_gly ) {
		for ( Integer i = range.begin(), ie = range.end(); i <= ie; ++i ) {
			if ( pose.res( i ) == param_aa::aa_gly ) {
				original_gly.insert( i );
			}
		}
	}

	// REVERSE via gly first: set residue identities
	for ( Integer i = range.begin(), ie = range.end(); i <= ie; ++i ) {
		if ( original_gly.find( i ) == original_gly.end() ) {
			design_matrix( param_aa::aa_gly, i ) = true;
		}
		pose.set_allow_chi_move( i, true );
	}

	// repack to get right coordinates
	epigraft::design::design( pose, design_matrix );

	// CORRECT second: set residue identities
	for ( Integer i = range.begin(), ie = range.end(); i <= ie; ++i ) {
		if ( original_gly.find( i ) == original_gly.end() ) {
			// Revoke the glycine permission from the first pass. Otherwise the packer
			// may keep the residue as glycine instead of mutating it to aa.
			design_matrix( param_aa::aa_gly, i ) = false;
			// set after the reset, so that aa == aa_gly still works
			design_matrix( aa, i ) = true;
		}
		pose.set_allow_chi_move( i, true );
	}

	// repack to get right coordinates
	epigraft::design::design( pose, design_matrix );

	// re-lock chi
	pose.set_allow_chi_move( false );
}


/// @brief fill octree with the sidechain atoms of every residue in the Pose
/// @details Forwards to the range overload with residues 1..total_residue().
/// @param[in,out] oc     octree that receives the atoms
/// @param[in]     input  Pose whose sidechain atoms are added
void
fill_octree_sc(
	rootstock::Octree< AtomPoint > & oc,
	Pose const & input
)
{
	ResidueRange const whole_pose( 1, input.total_residue() );
	fill_octree_sc( oc, input, whole_pose );
}


/// @brief fill octree with sidechain atoms only given range
/// @details Non-glycine residues contribute atoms 5..nheavyatoms(aa, variant).
///  Presumably atoms 1-4 are the backbone N, CA, C, O; TODO confirm.
///  Glycine has no sidechain heavy atoms, so atom 2 (presumably CA) stands in.
///  Each point is tagged with its residue number and atom number.
/// @param[in,out] oc     octree that receives the atoms
/// @param[in]     input  Pose supplying coordinates and residue types
/// @param[in]     range  inclusive residue range to add
void
fill_octree_sc(
	rootstock::Octree< AtomPoint > & oc,
	Pose const & input,
	ResidueRange const & range
)
{
	// cache
	FArray3D_float const & fc = input.full_coord();

	for ( Integer res = range.begin(), last_res = range.end(); res <= last_res; ++res ) {
		Integer const aa = input.res( res );
		Integer const aav = input.res_variant( res );
		Integer const first_atom = ( aa == param_aa::aa_gly ) ? 2 : 5;
		Integer const last_atom = ( aa == param_aa::aa_gly ) ? 2 : aaproperties_pack::nheavyatoms( aa, aav );

		for ( Integer q_atom = first_atom; q_atom <= last_atom; ++q_atom ) {
			// tag with the actual atom number; this was hard-coded to 2 before, which
			// mislabeled every sidechain atom
			oc.add( AtomPoint( res, q_atom, fc( 1, q_atom, res ), fc( 2, q_atom, res ), fc( 3, q_atom, res ) ) );
		}
	}
}


/// @brief given a query residue, return the residues that are nearby (in the octree)
/// @param[in] distance_cutoff  search radius passed to Octree::near_neighbors()
/// @param[in] oc               octree of AtomPoints to search
/// @param[in] query            Pose containing the query residue
/// @param[in] query_residue    index of the query residue within query
/// @return residue_id() of every octree point near any atom of the query residue.
///  All atoms 1..natoms are searched, not only heavy atoms.
std::set< Integer >
near_residues(
	Real const & distance_cutoff,
	rootstock::Octree< AtomPoint > const & oc,
	Pose const & query,
	Integer const & query_residue
)
{
	FArray3D_float const & xyz = query.full_coord();

	Integer const res_aa = query.res( query_residue );
	Integer const res_aav = query.res_variant( query_residue );
	Integer const n_atoms = aaproperties_pack::natoms( res_aa, res_aav );

	std::set< Integer > found;

	for ( Integer atm = 1; atm <= n_atoms; ++atm ) {
		Real const x = static_cast< Real >( xyz( 1, atm, query_residue ) );
		Real const y = static_cast< Real >( xyz( 2, atm, query_residue ) );
		Real const z = static_cast< Real >( xyz( 3, atm, query_residue ) );

		std::vector< AtomPoint > const hits = oc.near_neighbors( distance_cutoff, x, y, z );
		for ( std::vector< AtomPoint >::const_iterator it = hits.begin(); it != hits.end(); ++it ) {
			found.insert( it->residue_id() );
		}
	}

	return found;
}


/// @brief given a query residue's sidechain, return the residues that are nearby (in the octree)
/// @details Searches atoms 5..nheavyatoms for non-glycine residues, and atom 2 for glycine.
/// @param[in] distance_cutoff  search radius passed to Octree::near_neighbors()
/// @param[in] oc               octree of AtomPoints to search
/// @param[in] query            Pose containing the query residue
/// @param[in] query_residue    index of the query residue within query
/// @return residue_id() of every octree point near any searched atom
std::set< Integer >
near_residues_to_sc(
	Real const & distance_cutoff,
	rootstock::Octree< AtomPoint > const & oc,
	Pose const & query,
	Integer const & query_residue
)
{
	FArray3D_float const & xyz = query.full_coord();

	Integer const res_aa = query.res( query_residue );
	Integer const res_aav = query.res_variant( query_residue );
	bool const is_gly = ( res_aa == param_aa::aa_gly );
	Integer const first_atom = is_gly ? 2 : 5;
	Integer const last_atom = is_gly ? 2 : aaproperties_pack::nheavyatoms( res_aa, res_aav );

	std::set< Integer > found;

	for ( Integer atm = first_atom; atm <= last_atom; ++atm ) {
		Real const x = static_cast< Real >( xyz( 1, atm, query_residue ) );
		Real const y = static_cast< Real >( xyz( 2, atm, query_residue ) );
		Real const z = static_cast< Real >( xyz( 3, atm, query_residue ) );

		std::vector< AtomPoint > const hits = oc.near_neighbors( distance_cutoff, x, y, z );
		for ( std::vector< AtomPoint >::const_iterator it = hits.begin(); it != hits.end(); ++it ) {
			found.insert( it->residue_id() );
		}
	}

	return found;
}


/// @brief given a query structure, run through its sidechains and grab the residues that are nearby (in the octree)
/// @details Returns the union, over every residue of query, of the single-residue
///  sidechain search. That search uses atoms 5..nheavyatoms for non-glycine
///  residues and atom 2 for glycine.
/// @param[in] distance_cutoff  search radius passed to Octree::near_neighbors()
/// @param[in] oc               octree of AtomPoints to search
/// @param[in] query            Pose whose sidechains are used as search centers
/// @return residue_id() of every octree point near any query sidechain atom
std::set< Integer >
near_residues_to_sc(
	Real const & distance_cutoff,
	rootstock::Octree< AtomPoint > const & oc,
	Pose const & query
)
{
	std::set< Integer > neighbors;

	Integer const n_query = query.total_residue();
	for ( Integer seqpos = 1; seqpos <= n_query; ++seqpos ) {
		std::set< Integer > const per_residue = near_residues_to_sc( distance_cutoff, oc, query, seqpos );
		neighbors.insert( per_residue.begin(), per_residue.end() );
	}

	return neighbors;
}


/// @brief given a query structure, return the residues that are nearby on the query structure
/// @details A query residue is reported when at least one of its searched atoms has
///  an octree neighbor within distance_cutoff. Searched atoms are 5..nheavyatoms
///  for non-glycine residues and atom 2 for glycine.
/// @param[in] distance_cutoff  search radius passed to Octree::near_neighbors()
/// @param[in] oc               octree of AtomPoints to search
/// @param[in] query            Pose whose residues are tested
/// @return indices (in query) of residues that touch the octree contents
std::set< Integer >
near_residues_to_sc_in_query(
	Real const & distance_cutoff,
	rootstock::Octree< AtomPoint > const & oc,
	Pose const & query
)
{
	FArray3D_float const & xyz = query.full_coord();

	std::set< Integer > contacting;

	Integer const n_query = query.total_residue();
	for ( Integer seqpos = 1; seqpos <= n_query; ++seqpos ) {
		Integer const res_aa = query.res( seqpos );
		Integer const res_aav = query.res_variant( seqpos );
		bool const is_gly = ( res_aa == param_aa::aa_gly );
		Integer const first_atom = is_gly ? 2 : 5;
		Integer const last_atom = is_gly ? 2 : aaproperties_pack::nheavyatoms( res_aa, res_aav );

		// stop at the first searched atom that has any octree neighbor
		for ( Integer atm = first_atom; atm <= last_atom; ++atm ) {
			Real const x = static_cast< Real >( xyz( 1, atm, seqpos ) );
			Real const y = static_cast< Real >( xyz( 2, atm, seqpos ) );
			Real const z = static_cast< Real >( xyz( 3, atm, seqpos ) );

			std::vector< AtomPoint > const hits = oc.near_neighbors( distance_cutoff, x, y, z );
			if ( !hits.empty() ) {
				contacting.insert( seqpos );
				break;
			}
		}
	}

	return contacting;
}


} // namespace design
} // namespace epigraft
