// -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
// vi: set ts=2 noet:
//
// (c) Copyright Rosetta Commons Member Institutions.
// (c) This file is part of the Rosetta software suite and is made available under license.
// (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
// (c) For more information, see http://www.rosettacommons.org. Questions about this can be
// (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.

/// @file numeric/kdtree/util.cc
/// @brief
/// @author James Thompson

#include <numeric/types.hh>

#include <numeric/kdtree/constants.hh>
#include <numeric/kdtree/util.hh>
#include <numeric/kdtree/KDNode.hh>
#include <numeric/kdtree/KDTree.hh>
#include <numeric/kdtree/KDPoint.hh>
#include <numeric/kdtree/HyperRectangle.hh>
#include <numeric/kdtree/HyperRectangle.fwd.hh>

#include <utility/vector1.hh>
#include <utility/pointer/ReferenceCount.hh>
#include <utility/pointer/ReferenceCount.fwd.hh>

#include <cmath>
#include <string>
#include <iostream>
#include <algorithm>

namespace numeric {
namespace kdtree {

/// @brief Finds the nearest neighbor of pt by searching the whole tree.
/// @details Convenience entry point: copies the tree's bounding
/// hyper-rectangle, takes the root node and delegates to the node-level
/// nearest_neighbor overload with an initial squared search radius of
/// REALLY_BIG_DISTANCE.
/// @param tree  tree to search. tree.bounds() is dereferenced unconditionally.
/// @param pt  query point.
/// @param nearest  output, assigned by the recursive search; left untouched
/// when the root node is null.
/// @param dist_sq  output, squared distance from pt to nearest; set to
/// REALLY_BIG_DISTANCE when the root node is null.
void
nearest_neighbor(
	KDTree & tree,
	utility::vector1< numeric::Real > const & pt,
	// returns:
	KDNodeOP & nearest,
	numeric::Real & dist_sq
) {
	// The node-level overload takes its bounds and start node by non-const
	// reference, so both have to be named locals.
	HyperRectangle search_box( *tree.bounds() );
	numeric::Real const initial_radius_sq( REALLY_BIG_DISTANCE );
	KDNodeOP start( tree.root() );

	nearest_neighbor( start, pt, search_box, initial_radius_sq, nearest, dist_sq );
}

/// @brief Collects the nearest neighbors of pt by searching the whole tree.
/// @details Convenience entry point: copies the tree's bounding
/// hyper-rectangle, takes the root node and delegates to the node-level
/// nearest_neighbors overload with an initial squared search radius of
/// REALLY_BIG_DISTANCE.
/// @param tree  tree to search. tree.bounds() is dereferenced unconditionally.
/// @param pt  query point.
/// @param wanted  value handed to the KDPointList constructor (presumably the
/// maximum number of neighbors to keep -- confirm against KDPointList).
/// @return the KDPointList filled in by the recursive search.
KDPointList
nearest_neighbors(
	KDTree & tree,
	utility::vector1< numeric::Real > const & pt,
	Size const wanted
) {
	// The node-level overload takes its bounds and start node by non-const
	// reference, so both have to be named locals.
	HyperRectangle search_box( *tree.bounds() );
	numeric::Real const initial_radius_sq( REALLY_BIG_DISTANCE );
	KDNodeOP start( tree.root() );

	KDPointList found( wanted );
	nearest_neighbors( start, pt, search_box, initial_radius_sq, found );
	return found;
}

/// @brief Computes the per-dimension minimum and maximum over a set of points.
/// @details Both extremes are seeded with the first row and then every row
/// (the first one included) is folded in. Only the dimensions that a row
/// shares with the first row are examined, so extra trailing values in longer
/// rows are ignored.
/// @param points  rows to scan; must be non-empty because the first row is
/// read unconditionally. The rows are not modified.
/// @return a newly allocated HyperRectangle constructed from (upper, lower).
HyperRectangleOP get_percentile_bounds(
	utility::vector1< utility::vector1< numeric::Real > > & points
) {
	using numeric::Real;
	using numeric::Size;
	using utility::vector1;

	// seed the running extremes with the first row
	vector1< Real > lower( points.front() );
	vector1< Real > upper( lower );

	for ( Size row = 1; row <= points.size(); ++row ) {
		vector1< Real > const & values( points[row] );
		// lower and upper always have the same length
		Size const n_dims( std::min( values.size(), lower.size() ) );
		for ( Size dim = 1; dim <= n_dims; ++dim ) {
			Real const value( values[dim] );
			if ( value < lower[dim] ) {
				lower[dim] = value;
			}
			if ( upper[dim] < value ) {
				upper[dim] = value;
			}
		}
	} // rows

	return new HyperRectangle( upper, lower );
}

/// @brief Rescales a single point in place relative to the given bounds.
/// @details Each coordinate p becomes ( p - lower ) / ( upper - lower ) for
/// its dimension. A dimension whose bounds coincide (upper == lower) has no
/// extent to scale by; instead of dividing by zero and producing NaN or
/// infinity, such a coordinate is only shifted, i.e. it becomes p - lower.
/// Only the dimensions shared by point, lower and upper are transformed.
/// @param point  point to rescale; modified in place.
/// @param bounds  supplies the per-dimension lower() and upper() values;
/// dereferenced unconditionally.
void transform_percentile_single_pt(
	utility::vector1< numeric::Real > & point,
	HyperRectangleOP bounds
) {
	using numeric::Real;
	using utility::vector1;
	typedef vector1< Real >::iterator pt_iter;
	typedef vector1< Real >::const_iterator bound_iter;

	vector1< Real > const lower( bounds->lower() );
	vector1< Real > const upper( bounds->upper() );

	bound_iter l_it( lower.begin() ), u_it( upper.begin() );
	bound_iter const l_end( lower.end() ), u_end( upper.end() );
	for ( pt_iter p_it = point.begin(), p_end = point.end();
				p_it != p_end && l_it != l_end && u_it != u_end;
				++p_it, ++l_it, ++u_it
	) {
		Real const range( *u_it - *l_it );
		// degenerate dimension: fall back to a unit scale rather than dividing by 0
		Real const scale( range != Real( 0.0 ) ? range : Real( 1.0 ) );
		*p_it = ( *p_it - *l_it ) / scale;
	}
}

/// @brief Rescales every point in place relative to the given bounds.
/// @details Each coordinate p becomes ( p - lower ) / ( upper - lower ) for
/// its dimension. A dimension whose bounds coincide (upper == lower) has no
/// extent to scale by; instead of dividing by zero and producing NaN or
/// infinity, such a coordinate is only shifted, i.e. it becomes p - lower.
/// For each row only the dimensions shared by the row, lower and upper are
/// transformed.
/// @param points  rows to rescale; modified in place.
/// @param bounds  supplies the per-dimension lower() and upper() values;
/// dereferenced unconditionally.
void transform_percentile(
	utility::vector1< utility::vector1< numeric::Real > > & points,
	HyperRectangleOP bounds
) {
	using numeric::Real;
	using utility::vector1;
	typedef vector1< vector1< Real > >::iterator row_iter;
	typedef vector1< Real >::iterator pt_iter;
	typedef vector1< Real >::const_iterator bound_iter;

	// per-dimension lower/upper values, fetched once for all rows
	vector1< Real > const lower( bounds->lower() );
	vector1< Real > const upper( bounds->upper() );
	bound_iter const l_end( lower.end() ), u_end( upper.end() );

	// transform values
	for ( row_iter it = points.begin(), end = points.end();
				it != end; ++it
	) {
		bound_iter l_it( lower.begin() ), u_it( upper.begin() );
		for ( pt_iter p_it = it->begin(), p_end = it->end();
					p_it != p_end && l_it != l_end && u_it != u_end;
					++p_it, ++l_it, ++u_it
		) {
			Real const range( *u_it - *l_it );
			// degenerate dimension: fall back to a unit scale rather than dividing by 0
			Real const scale( range != Real( 0.0 ) ? range : Real( 1.0 ) );
			*p_it = ( *p_it - *l_it ) / scale;
		}
	} // rows
} // transform_percentile


/// @brief Rescales every point in place using bounds derived from the points
/// themselves.
/// @details Computes the bounds with get_percentile_bounds( points ) and
/// forwards them to the two-argument transform_percentile overload.
/// @param points  rows to rescale; modified in place.
void transform_percentile(
	utility::vector1< utility::vector1< numeric::Real > > & points
) {
	transform_percentile( points, get_percentile_bounds( points ) );
}

/// @brief Wraps each coordinate row in a newly allocated KDPoint.
/// @param points  coordinate rows, one per output point.
/// @return one KDPointOP per input row, in input order.
utility::vector1< KDPointOP > make_points(
	utility::vector1< utility::vector1< numeric::Real > > const & points
) {
	using numeric::Size;

	utility::vector1< KDPointOP > wrapped;
	for ( Size ii = 1, n_points = points.size(); ii <= n_points; ++ii ) {
		wrapped.push_back( KDPointOP( new KDPoint( points[ii] ) ) );
	}

	return wrapped;
} // make_points

/// @brief Wraps each coordinate row, paired with its data object, in a newly
/// allocated KDPoint.
/// @details points and data are expected to have the same length (asserted).
/// When assertions are compiled out and the lengths differ, only the leading
/// entries present in both are paired.
/// @param points  coordinate rows.
/// @param data  object passed to the KDPoint constructor alongside the row at
/// the same position.
/// @return one KDPointOP per paired row, in input order.
utility::vector1< KDPointOP > make_points(
	utility::vector1< utility::vector1< numeric::Real > > const & points,
	utility::vector1< utility::pointer::ReferenceCountOP > const & data
) {

	assert( points.size() == data.size() );

	using numeric::Size;

	Size const n_points( points.size() );
	Size const n_data( data.size() );
	Size const n_pairs( n_points < n_data ? n_points : n_data );

	utility::vector1< KDPointOP > wrapped;
	for ( Size ii = 1; ii <= n_pairs; ++ii ) {
		wrapped.push_back( KDPointOP( new KDPoint( points[ii], data[ii] ) ) );
	}

	return wrapped;
} // make_points

/// @brief Recursively builds the subtree holding the given points.
/// @details The split axis cycles with depth: axis = ( depth % width ) + 1,
/// where width is the size of the first point. The points are sorted along
/// that axis, the element at 1-based index size/2 + 1 becomes this node, the
/// elements before it form the left subtree and the elements after it form
/// the right subtree.
/// @param points  points for this subtree; reordered in place by the sort.
/// @param depth  depth of the node being built, used to pick the split axis.
/// @param tree  tree handed to every KDNode constructor.
/// @return the node for this subtree, or a null pointer when points is empty.
KDNodeOP construct_kd_tree(
	utility::vector1< KDPointOP > & points,

	numeric::Size depth,
	KDTree & tree
) {
	using numeric::Size;
	using utility::vector1;

	if ( points.empty() ) {
		return NULL;
	}

	Size const width( points.front()->size() );
	// zero-dimensional points would make the modulo below undefined
	assert( width > 0 );
	Size const axis( ( depth % width ) + 1 );

	// sort points by this axis, which is the index into the points for
	// comparison. This is a big place where we can get a speedup in tree
	// construction. Other smart things to do in the future are:
	// - a linear-time median finding function (smart and hard to implement).
	// - finding a median based on a subset of points and grabbing a pivot
	// point near it (stupid but easy to implement).
	std::sort( points.begin(), points.end(), CompareKDPoints( axis ) );

	// location is the median of points along the split axis. The index is
	// 1-based and always lies in [1, points.size()] for a non-empty input,
	// so no further clamping is needed.
	Size const median_idx( points.size() / 2 + 1 );
	KDNodeOP current( new KDNode( tree ) );
	current->point( points[median_idx] );
	current->split_axis( axis );

	// The points are already sorted along the axis, so the two halves are
	// simply the ranges on either side of the median:
	// - everything before the median goes to the left child
	// - everything after the median goes to the right child
	// With two or more points the left range has size/2 >= 1 elements; with
	// three or more the right range is non-empty as well, so the recursive
	// calls below never return a null child.
	if ( points.size() > 1 ) {
		vector1< KDPointOP > left_points(
			points.begin(), points.begin() + median_idx - 1
		);
		current->left_child( construct_kd_tree( left_points, depth + 1, tree ) );
		current->left_child()->parent( current );
	}
	if ( points.size() > 2 ) {
		vector1< KDPointOP > right_points(
			points.begin() + median_idx, points.end()
		);
		current->right_child( construct_kd_tree( right_points, depth + 1, tree ) );
		current->right_child()->parent( current );
	}

	return current;
} // construct_kd_tree

/// @brief Recursive nearest-neighbors search of the subtree rooted at current.
/// @details Descends first into the child on the query's side of the split,
/// tightens the search radius to the worst distance currently held in
/// neighbors, and then -- if hr_intersects_hs reports that the far side's
/// hyper-rectangle is within that radius of pt -- considers this node's own
/// point and searches the far child into a temporary list that is merged
/// into neighbors.
/// @param current  root of the subtree to search; a null node is a no-op.
/// Its stored distance is overwritten when the node itself is examined.
/// @param pt  query point.
/// @param bounds  hyper-rectangle associated with current.
/// @param max_dist_sq  squared search radius; by-value, tightened locally.
/// @param neighbors  output list, updated in place.
void
nearest_neighbors(
	KDNodeOP & current,
	utility::vector1< numeric::Real > const & pt,
	HyperRectangle & bounds,
	numeric::Real max_dist_sq,

	// returns:
	KDPointList & neighbors
) {
	using numeric::Size;
	using numeric::Real;
	using utility::vector1;

	if ( !current ) {
		return;
	}
	Size const split_axis ( current->split_axis() );
	Real const split_value( current->location()[split_axis] );

	// mid is a hyper-plane through this location and perpendicular
	// to the split axis
	// NOTE(review): get_percentile_bounds in this file calls the
	// HyperRectangle constructor as (upper, lower); the (lower, mid) and
	// (upper, mid) argument order used here, and the use of the full node
	// location as a box corner, should be confirmed against HyperRectangle.hh.
	vector1< Real >
		mid  ( current->location() ),
		lower( bounds.lower()      ),
		upper( bounds.upper()      );
	HyperRectangle left_hr ( lower, mid );
	HyperRectangle right_hr( upper, mid );

	// Pick the child (and its box) on the query's side of the split as
	// "nearer"; the other child and box are "further".
	KDNodeOP nearer, further;
	HyperRectangle nearer_hr, further_hr;
	if ( pt[split_axis] <= split_value ) {
		nearer     = current->left_child();
		further    = current->right_child();
		nearer_hr  = left_hr;
		further_hr = right_hr;
	} else {
		nearer     = current->right_child();
		further    = current->left_child();
		// the boxes must follow the children: right is nearer here
		nearer_hr  = right_hr;
		further_hr = left_hr;
	}

	nearest_neighbors(
		nearer, pt, nearer_hr, max_dist_sq, neighbors
	);
	max_dist_sq = std::min( neighbors.worst_distance(), max_dist_sq );

	// we need to search this point and the further child if there's a
	// part of further_hr within sqrt( max_dist_sq ) of pt.
	if ( hr_intersects_hs( further_hr, pt, std::sqrt( max_dist_sq ) ) ) {
		current->distance(
			sq_vec_distance( current->location(), pt )
		);
		if ( current->distance() < neighbors.worst_distance() ) {
			neighbors.insert( current->point() );
			max_dist_sq = neighbors.worst_distance();
		}
		// 10.2
		KDPointList temp_neighbors( neighbors.max_values() );
		nearest_neighbors(
			further, pt, further_hr, max_dist_sq, temp_neighbors
		);

		neighbors.merge( temp_neighbors );
	} // hr_intersects_hs
} // nearest_neighbors

/// @brief Recursive single-nearest-neighbor search of the subtree rooted at
/// current.
/// @details Descends first into the child on the query's side of the split,
/// tightens the search radius to the best squared distance found there, and
/// then -- if hr_intersects_hs reports that the far side's hyper-rectangle is
/// within that radius of pt -- considers this node itself and searches the
/// far child, keeping whichever candidate is closest.
/// @param current  root of the subtree to search.
/// @param pt  query point.
/// @param bounds  hyper-rectangle associated with current.
/// @param max_dist_sq  squared search radius; by-value, tightened locally.
/// @param nearest  output, closest node found; left untouched when current is
/// null.
/// @param dist_sq  output, squared distance from pt to nearest; set to
/// REALLY_BIG_DISTANCE when current is null.
void
nearest_neighbor(
	KDNodeOP & current,
	utility::vector1< numeric::Real > const & pt,
	HyperRectangle & bounds,
	numeric::Real max_dist_sq,

	// returns:
	KDNodeOP & nearest,
	numeric::Real & dist_sq
) {
	using numeric::Size;
	using numeric::Real;
	using utility::vector1;

	if ( !current ) {
		dist_sq = REALLY_BIG_DISTANCE; // hopefully big enough for most spaces
		return;
	}
	Size const split_axis ( current->split_axis() );
	Real const split_value( current->location()[split_axis] );

	// mid is a hyper-plane through this location and perpendicular
	// to the split axis
	// NOTE(review): get_percentile_bounds in this file calls the
	// HyperRectangle constructor as (upper, lower); the (lower, mid) and
	// (upper, mid) argument order used here, and the use of the full node
	// location as a box corner, should be confirmed against HyperRectangle.hh.
	vector1< Real >
		mid  ( current->location() ),
		lower( bounds.lower()      ),
		upper( bounds.upper()      );
	HyperRectangle left_hr ( lower, mid );
	HyperRectangle right_hr( upper, mid );

	// Pick the child (and its box) on the query's side of the split as
	// "nearer"; the other child and box are "further".
	KDNodeOP nearer, further;
	HyperRectangle nearer_hr, further_hr;
	if ( pt[split_axis] <= split_value ) {
		nearer     = current->left_child();
		further    = current->right_child();
		nearer_hr  = left_hr;
		further_hr = right_hr;
	} else {
		nearer     = current->right_child();
		further    = current->left_child();
		// the boxes must follow the children: right is nearer here
		nearer_hr  = right_hr;
		further_hr = left_hr;
	}

	nearest_neighbor( nearer, pt, nearer_hr, max_dist_sq, nearest, dist_sq );
	max_dist_sq = std::min( dist_sq, max_dist_sq );

	// we need to search this node and the further child if there's a part
	// of further_hr within sqrt( max_dist_sq ) of pt.
	if ( hr_intersects_hs( further_hr, pt, std::sqrt( max_dist_sq ) ) ) {
		numeric::Real this_dist_sq = sq_vec_distance( current->location(), pt );
		if ( this_dist_sq < dist_sq ) {
			nearest = current;
			dist_sq = this_dist_sq;
			max_dist_sq = dist_sq;
		}
		// 10.2
		// search the far side into temporaries so that a worse result there
		// cannot clobber the best candidate found so far
		KDNodeOP temp_nearest;
		numeric::Real temp_dist_sq( REALLY_BIG_DISTANCE );
		nearest_neighbor(
			further, pt, further_hr, max_dist_sq, temp_nearest, temp_dist_sq
		);
		if ( temp_dist_sq < dist_sq ) {
			nearest = temp_nearest;
			dist_sq = temp_dist_sq;
		}
	} // hr_intersects_hs
} // nearest_neighbor

/// @brief Squared Euclidean distance between two vectors.
/// @details The vectors are expected to have the same length (asserted).
/// When assertions are compiled out and the lengths differ, only the leading
/// dimensions present in both contribute to the sum.
/// @param vec1  first vector.
/// @param vec2  second vector.
/// @return sum over the shared dimensions of the squared coordinate
/// differences.
numeric::Real sq_vec_distance(
	utility::vector1< numeric::Real > const & vec1,
	utility::vector1< numeric::Real > const & vec2
) {
	assert( vec1.size() == vec2.size() );
	using numeric::Real;
	using numeric::Size;

	Size const n1( vec1.size() );
	Size const n2( vec2.size() );
	Size const n_shared( n1 < n2 ? n1 : n2 );

	Real sum_sq( 0.0 );
	for ( Size ii = 1; ii <= n_shared; ++ii ) {
		Real const delta( vec1[ii] - vec2[ii] );
		sum_sq += delta * delta;
	}

	return sum_sq;
}

/// @brief Euclidean distance between two vectors.
/// @param vec1  first vector.
/// @param vec2  second vector.
/// @return the square root of sq_vec_distance( vec1, vec2 ).
numeric::Real vec_distance(
	utility::vector1< numeric::Real > const & vec1,
	utility::vector1< numeric::Real > const & vec2
) {
	numeric::Real const squared( sq_vec_distance( vec1, vec2 ) );
	return std::sqrt( squared );
}

void print_points(
	std::ostream & out,
	utility::vector1< utility::vector1< numeric::Real > > points
) {
	using numeric::Real;
	using utility::vector1;
	for ( vector1< vector1< Real > >::const_iterator pt = points.begin(),
			end = points.end(); pt != end; ++pt ) {
		for ( vector1< Real >::const_iterator val = pt->begin(),
				val_end = pt->end(); val != val_end; ++val
		) {
			out << ' ' << *val;
		}
		out << std::endl;
	} // for points
}


/// @brief Intended to report whether the given hyper-rectangle intersects the
/// hypersphere of radius r centered on pt.
/// @param hr  hyper-rectangle to test; taken by value.
/// @param pt  center of the hypersphere.
/// @param r  radius of the hypersphere (not squared; it is squared here).
/// @return true when the squared distance computed below is <= r * r.
///
/// NOTE(review): as written, qt starts out as a copy of pt and the first loop
/// only appends values after that copy. The distance loop stops at the end of
/// pt, so it only ever compares pt with the leading copy of itself. For finite
/// coordinates dist_sq is therefore always 0 and the function returns true for
/// any non-NaN r, i.e. it never rules a hyper-rectangle out. The callers in
/// this file still produce complete results because of that; changing this
/// function to prune for real should be done together with a review of how
/// those callers build the hyper-rectangles they pass in.
bool hr_intersects_hs(
	HyperRectangle hr,
	utility::vector1< numeric::Real > const & pt,
	numeric::Real r
) {
	using numeric::Size;
	using numeric::Real;
	using utility::vector1;

	// qt is seeded with all of pt's coordinates
	vector1< Real > qt( pt ),
		upper( hr.upper() ),
		lower( hr.lower() );
	// gross iteration!
	// Walks pt, lower and upper in lockstep. Per dimension, appends the lower
	// bound when lower <= pt, otherwise the upper bound when upper >= pt, and
	// nothing when neither holds. The values land after the copy of pt that
	// qt already contains.
	for ( vector1< Real >::const_iterator
				pt_it = pt.begin(), pt_end = pt.end(),
				lower_it = lower.begin(), lower_end = lower.end(),
				upper_it = upper.begin(), upper_end = upper.end();
			pt_it != pt_end && lower_it != lower_end && upper_it != upper_end;
			++pt_it, ++lower_it, ++upper_it
	) {
		if ( *lower_it <= *pt_it ) {
			qt.push_back( *lower_it );
		} else if ( *upper_it >= *pt_it ) {
			qt.push_back( *upper_it );
		}
	}

	// compute distance between p and q
	// (only the first pt.size() entries of qt are visited, see the note above)
	numeric::Real dist_sq( 0.0 );
	for ( vector1< Real >::const_iterator
				pt_it = pt.begin(), pt_end = pt.end(),
				qt_it = qt.begin(), qt_end = qt.end();
			pt_it != pt_end && qt_it != qt_end;
			++pt_it, ++qt_it
	) {
		dist_sq += ( *pt_it - *qt_it ) * ( *pt_it - *qt_it );
	}

	return ( dist_sq <= (r * r) );
} // hr_intersects_hs

} // kdtree
} // numeric
