// -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
// vi: set ts=2 noet:
// :noTabs=false:tabSize=4:indentSize=4:
//
// (c) Copyright Rosetta Commons Member Institutions.
// (c) This file is part of the Rosetta software suite and is made available under license.
// (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
// (c) For more information, see http://www.rosettacommons.org. Questions about this can be
// (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.

/// @file   protocols/match/output/UpstreamDownstreamCollisionFilter.cc
/// @brief  Implementation for class to filter matches where the upstream residues collide with the downstream partner.
/// @author Alex Zanghellini (zanghell@u.washington.edu)
/// @author Andrew Leaver-Fay (aleaverfay@gmail.com), porting to mini

// Unit headers
#include <protocols/match/output/UpstreamDownstreamCollisionFilter.hh>

// Package headers
#include <protocols/match/BumpGrid.hh>
#include <protocols/match/downstream/DownstreamBuilder.hh>
#include <protocols/match/output/UpstreamHitCacher.hh>

// Project headers
#include <core/conformation/Residue.hh>
#include	<core/pose/Pose.hh>
#include <core/scoring/ScoreFunction.hh>
#include <core/scoring/ScoringManager.hh>
#include <core/scoring/etable/EtableEnergy.hh>
#include <core/scoring/methods/EnergyMethodOptions.hh>

// Utility headers
#include <utility/pointer/ReferenceCount.hh>

// C++ headers
#include <iostream>

namespace protocols {
namespace match {
namespace output {

/// @brief Construct a filter that obtains upstream residue conformations from
/// coordinate_cacher. By default the hard-sphere filter is used with zero
/// tolerated overlap.
/// @details max_overlap_dis_ is not in the initializer list. It is computed by
/// set_tolerated_overlap(), so that routine is invoked here with the default
/// tolerance. Otherwise passes_hardsphere_filter() would read an uninitialized
/// padding value if the caller never sets a tolerance explicitly.
UpstreamDownstreamCollisionFilter::UpstreamDownstreamCollisionFilter(
	UpstreamHitCacherOP coordinate_cacher
) :
	filter_by_lj_( false ),
	wfa_atr_( 0.8 ),
	wfa_rep_( 0.44 ),
	wfa_sol_( 0.6 ),
	lj_cutoff_( 10 ),
	tolerated_overlap_( 0.0 ),
	empty_pose_( new core::pose::Pose ),
	empty_sfxn_( new core::scoring::ScoreFunction ),
	cacher_( coordinate_cacher ),
	etable_energy_( 0 ),
	bump_grid_( new BumpGrid )
{
	// Sync the bump grid's tolerance and initialize max_overlap_dis_.
	set_tolerated_overlap( tolerated_overlap_ );
	std::cout << "Created UpstreamDownstreamCollisionFilter" << std::endl;
}

/// @brief Destructor; performs no explicit cleanup.
UpstreamDownstreamCollisionFilter::~UpstreamDownstreamCollisionFilter() {}

/// @brief Store a private copy of the downstream pose and build the atom-index
/// tables used to map hit coordinates onto that pose.
/// @details Every atom of every residue gets a consecutive 1-based "flat" index,
/// assigned residue by residue:
///   - per_res_atom_ind_[ res ][ atom ] -> flat index
///   - downstream_atoms_[ flat index ]  -> AtomID( atom, res )
/// coords_ is sized to hold one coordinate per flat index.
void
UpstreamDownstreamCollisionFilter::set_downstream_pose( core::pose::Pose const & downstream_pose )
{
	downstream_pose_ = new core::pose::Pose( downstream_pose );
	Size const nres = downstream_pose.total_residue();

	// Pass 1: number the atoms consecutively and record the forward map.
	Size count_atoms( 0 );
	per_res_atom_ind_.resize( nres );
	for ( Size ii = 1; ii <= nres; ++ii ) {
		Size const ii_natoms = downstream_pose.residue( ii ).natoms();
		per_res_atom_ind_[ ii ].resize( ii_natoms );
		for ( Size jj = 1; jj <= ii_natoms; ++jj ) {
			per_res_atom_ind_[ ii ][ jj ] = ++count_atoms;
		}
	}

	// Pass 2: record the inverse map. Index through the flat numbers from pass 1
	// so both tables agree and every slot 1..count_atoms is filled.
	downstream_atoms_.resize( count_atoms );
	coords_.resize( count_atoms );
	for ( Size ii = 1; ii <= nres; ++ii ) {
		Size const ii_natoms = per_res_atom_ind_[ ii ].size();
		for ( Size jj = 1; jj <= ii_natoms; ++jj ) {
			downstream_atoms_[ per_res_atom_ind_[ ii ][ jj ] ] = core::id::AtomID( jj, ii );
		}
	}
}

/// @brief Resize the table of downstream builders to one entry per geometric constraint.
/// @details Entries added by growing the table are default-constructed.
/// passes_filter() skips geometric constraints whose entry tests false.
void
UpstreamDownstreamCollisionFilter::set_num_geometric_constraints( Size n_geomcst )
{
	Size const table_size = n_geomcst;
	dsbuilders_.resize( table_size );
}

/// @brief Register the downstream builder for geometric constraint geomcst_id.
/// @param geomcst_id 1-based index; must lie within the size previously given to
///        set_num_geometric_constraints() (checked with runtime_assert).
/// @param builder    builder used to regenerate downstream coordinates from hits
///                   originating at this geometric constraint.
void
UpstreamDownstreamCollisionFilter::set_downstream_builder(
	Size geomcst_id,
	downstream::DownstreamBuilderCOP builder
)
{
	runtime_assert( dsbuilders_.size() >= geomcst_id && geomcst_id > 0 );

	// Overwrite whatever builder (if any) was previously stored in this slot.
	dsbuilders_[ geomcst_id ] = builder;
}



/// @brief Decide whether a full match passes the collision filter.
/// @details Geometric constraints are visited in order. Only those that have a
/// downstream builder are considered. The match passes as soon as the downstream
/// placement derived from one of them passes the single-position filter.
/// Returns false if no such placement passes, including when no geometric
/// constraint has a builder.
bool
UpstreamDownstreamCollisionFilter::passes_filter(
	match const & m
) const
{
	bool any_placement_passes = false;
	for ( Size cst = 1; cst <= m.size() && ! any_placement_passes; ++cst ) {
		if ( dsbuilders_[ cst ] ) {
			any_placement_passes = passes_filter( match_dspos1( m, cst ) );
		}
	}
	return any_placement_passes;
}

/// @brief Filter a match that carries a single downstream position.
/// @details Uses the etable (LJ) filter when filter_by_lj_ is set, otherwise the
/// hard-sphere filter.
bool
UpstreamDownstreamCollisionFilter::passes_filter(
	match_dspos1 const & m
) const
{
	return filter_by_lj_ ? passes_etable_filter( m ) : passes_hardsphere_filter( m );
}

void UpstreamDownstreamCollisionFilter::set_filter_by_lj( bool setting )
{
	filter_by_lj_ = setting;
	if ( filter_by_lj_ ) {
		using namespace core::scoring;
		using namespace core::scoring::etable;
		using namespace core::scoring::methods;
		EnergyMethodOptions eopts;
		etable_energy_ = new EtableEnergy(
			*(ScoringManager::get_instance()->etable( eopts.etable_type() )), eopts );
	}
}

/// @brief Set the weighted etable energy above which an upstream/downstream
/// residue pair fails the LJ filter.
void UpstreamDownstreamCollisionFilter::set_lj_cutoff( Real setting ) { this->lj_cutoff_ = setting; }

/// @brief Set the weight on fa_atr in the LJ filter's combined pair energy.
void UpstreamDownstreamCollisionFilter::set_lj_atr_weight( Real setting ) { this->wfa_atr_ = setting; }

/// @brief Set the weight on fa_rep in the LJ filter's combined pair energy.
void UpstreamDownstreamCollisionFilter::set_lj_rep_weight( Real setting ) { this->wfa_rep_ = setting; }

/// @brief Set the weight on fa_sol in the LJ filter's combined pair energy.
void UpstreamDownstreamCollisionFilter::set_lj_sol_weight( Real setting ) { this->wfa_sol_ = setting; }

/// @brief Set the atom-pair overlap the bump grid tolerates, then recompute
/// max_overlap_dis_.
/// @details max_overlap_dis_ becomes the largest required separation distance
/// over all unordered pairs of probe radii 1..n_probe_radii. It pads the
/// neighbor-radius test in the hard-sphere filter.
void UpstreamDownstreamCollisionFilter::set_tolerated_overlap( Real setting )
{
	tolerated_overlap_ = setting;
	bump_grid_->set_general_overlap_tolerance( tolerated_overlap_ );

	max_overlap_dis_ = 0;
	for ( Size r1 = 1; r1 <= n_probe_radii; ++r1 ) {
		ProbeRadius const rad1 = ProbeRadius( r1 );
		for ( Size r2 = r1; r2 <= n_probe_radii; ++r2 ) {
			Real const pair_sep = bump_grid_->required_separation_distance( rad1, ProbeRadius( r2 ) );
			if ( pair_sep > max_overlap_dis_ ) max_overlap_dis_ = pair_sep;
		}
	}
}

/// @brief Etable (LJ) collision check between the upstream residues of a match
/// and the downstream pose placed at the match's downstream coordinates.
/// @details
/// 1. The downstream pose's coordinates are overwritten with those the originating
///    geometric constraint's builder derives from full_hit( m ).
/// 2. Each upstream residue, except the one from the originating geometric
///    constraint, is scored against every downstream residue as
///    wfa_atr_*fa_atr + wfa_rep_*fa_rep + wfa_sol_*fa_sol.
/// 3. Returns false as soon as one pair exceeds lj_cutoff_.
///
/// Requires set_filter_by_lj( true ) to have created etable_energy_.
bool UpstreamDownstreamCollisionFilter::passes_etable_filter( match_dspos1 const & m ) const
{
	using namespace core;
	using namespace core::conformation;
	using namespace core::pose;

	runtime_assert( dsbuilders_[ m.originating_geom_cst_for_dspos ] );
	dsbuilders_[ m.originating_geom_cst_for_dspos ]->coordinates_from_hit(
		full_hit( m ), downstream_atoms_, coords_ );
	for ( Size ii = 1; ii <= downstream_atoms_.size(); ++ii ) {
		downstream_pose_->set_xyz( downstream_atoms_[ ii ], coords_[ ii ] );
	}

	using namespace core::scoring;
	EnergyMap emap;
	// Inclusive bound: every upstream hit (1..size) must be examined.
	for ( Size ii = 1; ii <= m.upstream_hits.size(); ++ii ) {
		if ( ii == m.originating_geom_cst_for_dspos ) continue; // don't collision check since we've presumably done so already

		// The upstream conformation does not depend on jj; fetch it once.
		core::conformation::ResidueCOP iires = cacher_->upstream_conformation_for_hit( ii, fake_hit( m.upstream_hits[ ii ] ) );
		for ( Size jj = 1; jj <= downstream_pose_->total_residue(); ++jj ) {
			emap[ fa_atr ] = 0; emap[ fa_rep ] = 0; emap[ fa_sol ] = 0;
			etable_energy_->residue_pair_energy(
				*iires,
				downstream_pose_->residue( jj ),
				*empty_pose_, *empty_sfxn_,
				emap );
			Real const energy = wfa_atr_ * emap[ fa_atr ] + wfa_rep_ * emap[ fa_rep ] + wfa_sol_ * emap[ fa_sol ];
			if ( energy > lj_cutoff_ ) return false;
		}
	}
	return true;
}

/// @brief Hard-sphere collision check between upstream side-chain heavy atoms and
/// the downstream heavy atoms at the match's downstream coordinates.
/// @details Downstream coordinates come from the originating geometric
/// constraint's builder (full_hit( m )). The upstream residue of the originating
/// geometric constraint is skipped, as in passes_etable_filter.
///
/// For each remaining upstream residue:
/// - Downstream residues whose neighbor atoms are farther apart than the summed
///   neighbor radii plus max_overlap_dis_ are skipped.
/// - Otherwise, every upstream side-chain heavy atom is compared against every
///   downstream heavy atom. The match fails if any pair is closer than the bump
///   grid's required separation distance for their probe radii.
bool UpstreamDownstreamCollisionFilter::passes_hardsphere_filter( match_dspos1 const & m ) const
{
	runtime_assert( dsbuilders_[ m.originating_geom_cst_for_dspos ] );
	dsbuilders_[ m.originating_geom_cst_for_dspos ]->coordinates_from_hit(
		full_hit( m ), downstream_atoms_, coords_ );

	// Inclusive bound: every upstream hit (1..size) must be examined.
	for ( Size ii = 1; ii <= m.upstream_hits.size(); ++ii ) {
		if ( ii == m.originating_geom_cst_for_dspos ) continue; // its contact with the downstream partner was built by design

		core::conformation::ResidueCOP iires = cacher_->upstream_conformation_for_hit( ii, fake_hit( m.upstream_hits[ ii ] ) );
		Size const ii_first_sc = iires->first_sidechain_atom();
		for ( Size jj = 1; jj <= downstream_pose_->total_residue(); ++jj ) {
			core::conformation::Residue const & jjres = downstream_pose_->residue( jj );

			// Cheap neighbor-sphere rejection before the atom-pair loop.
			Real const intxn_dis = iires->nbr_radius() + jjres.nbr_radius() + max_overlap_dis_;
			if ( iires->xyz( iires->nbr_atom() ).distance_squared(
					coords_[ per_res_atom_ind_[ jj ][ jjres.nbr_atom() ]] ) >
					intxn_dis * intxn_dis ) {
				continue;
			}

			for ( Size kk = ii_first_sc; kk <= iires->nheavyatoms(); ++kk ) {
				// Probe radius of upstream atom kk (not of the downstream residue index jj).
				ProbeRadius const kk_rad = probe_radius_for_atom_type( iires->atom_type_index( kk ) );

				for ( Size ll = 1; ll <= jjres.nheavyatoms(); ++ll ) {
					ProbeRadius const ll_rad = probe_radius_for_atom_type( jjres.atom_type_index( ll ) );
					Real const minsep = bump_grid_->required_separation_distance( kk_rad, ll_rad );
					if ( iires->xyz( kk ).distance_squared( coords_[ per_res_atom_ind_[ jj ][ ll ]] ) < minsep * minsep ) {
						return false;
					}
				}
			}
		}
	}
	return true;
}


}
}
}
