// -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
// vi: set ts=2 noet:
//
// This file is part of the Rosetta software suite and is made available under license.
// The Rosetta software is developed by the contributing members of the Rosetta Commons consortium.
// (C) 199x-2009 Rosetta Commons participating institutions and developers.
// For more information, see http://www.rosettacommons.org/.

/// @file   protocols/jd2/archive/MPIArchiveJobDistributor.cc
/// @brief  implementation of MPIArchiveJobDistributor
/// @author Oliver Lange olange@u.washington.edu

// MPI headers
#ifdef USEMPI
#include <mpi.h> //keep this first
#endif

// Unit headers
#include <protocols/jd2/archive/MPIArchiveJobDistributor.hh>
#include <protocols/jd2/archive/ArchiveManager.hh>

// Package headers
#include <protocols/jd2/JobOutputter.hh>
#include <protocols/jd2/Job.hh>

#include <protocols/moves/Mover.hh>

#include <protocols/jd2/MpiFileBuffer.hh>
#include <utility/io/ozstream.hh> //to toggle MPI rerouting

// Utility headers
#include <core/util/Tracer.hh>
#include <core/options/option.hh>
#include <utility/exit.hh>
#include <utility/assert.hh>
#include <core/util/prof.hh>
#include <ObjexxFCL/string.functions.hh>

// Option headers
#include <core/options/keys/out.OptionKeys.gen.hh>
#include <core/options/keys/jd2.OptionKeys.gen.hh>

// C++ headers
#include <string>
#include <ctime>

// File-local tracer; messages in this file are written to the "protocols.jd2.MPIArchiveJobDistributor" channel.
static core::util::Tracer tr("protocols.jd2.MPIArchiveJobDistributor");

namespace protocols {
namespace jd2 {
namespace archive {

// Fixed MPI rank layout, listed in rank order. The constructor below hands the master,
// file-buffer and first-client ranks to the MPIFileBufJobDistributor base class and stores
// the archive rank itself.  (keep const for now)
int const in_file_buf_rank_ = 0;
int const in_master_rank_ = 1;
int const in_archive_rank_ = 2;
int const in_min_client_rank_ = 3;

using namespace core::options;
using namespace OptionKeys;
using namespace core;
///@details Constructor. Forwards the fixed master / file-buffer / first-client ranks to the
///MPIFileBufJobDistributor parent (with the "start empty" flag set) and records which rank
///runs the archive. nr_notify_ is the number of newly completed jobs of a batch after which
///an intermediate notification is queued for the archive (see notify_archive( batch_id )).
MPIArchiveJobDistributor::MPIArchiveJobDistributor()
	: MPIFileBufJobDistributor(
			in_master_rank_,
			in_file_buf_rank_,
			in_min_client_rank_,
			true /*start empty*/
		),
		nr_notify_( 100 ),  //make this a cmdline option ...
		archive_rank_( in_archive_rank_ )
{}

///@brief Entry point: run the role that belongs to this process' rank, then shut MPI down.
///@details The master rank runs master_go(), the file-buffer rank runs a WriteOut_MpiFileBuffer,
///the archive rank runs an ArchiveManager and afterwards sends STOP to the file buffer, and every
///other rank runs slave_go().
///@param mover mover handed on to master_go() / slave_go()
void
MPIArchiveJobDistributor::go( protocols::moves::MoverOP mover )
{
	// Modelled on MPIFileJobDistributor; the difference is that here the archive process
	// is the one that sends STOP to the FileBuf.
	utility::io::ozstream::enable_MPI_reroute( min_client_rank(), file_buf_rank() );

	if ( rank() == master_rank() ) {
		tr.Warning << "Master JD starts" << std::endl;
		master_go( mover );
	} else if ( rank() == file_buf_rank() ) {
		protocols::jd2::WriteOut_MpiFileBuffer file_buffer( file_buf_rank() );
		tr.Warning << "FileBuffer starts " << std::endl;
		file_buffer.run();
	} else if ( rank() == archive_rank() ) {
		tr.Warning << "Archive starts... " << std::endl;
		archive::ArchiveManager manager( archive_rank(), master_rank(), file_buf_rank() );
		manager.go();
		// the archive is done: tell the file buffer to stop as well
		tr.Warning << "send STOP to FileBuffer " << std::endl;
		protocols::jd2::WriteOut_MpiFileBuffer file_buffer( file_buf_rank() );
		file_buffer.stop();
	} else {
		slave_go( mover );
	}

	// Ideally the MPI shutdown would live in the destructor, but with the way the singleton
	// pattern is set up the destructors are never called.
#ifdef USEMPI
	MPI_Barrier( MPI_COMM_WORLD );
	MPI_Finalize();
#endif
	if ( rank() == master_rank() ) {
		std::cerr << "MPI FINALIZED closing down... " << std::endl;
		std::cout << "MPI FINALIZED closing down... " << std::endl;
	}
}

///@brief Receive one batch announcement from source_rank and register it via add_batch().
///@details Two messages are consumed: first two ints (length of the batch string, batch id),
///then the batch string itself. A batch id of 0 is treated as the STOP signal, in which case
///nothing is added. Without USEMPI the function only calls core::util::prof_show() and
///returns true.
///@param source_rank rank both messages are received from
///@return false if the STOP signal ( id == 0 ) was received, true otherwise
bool
MPIArchiveJobDistributor::receive_batch( Size MPI_ONLY( source_rank ) ) {
	core::util::prof_show();
#ifdef USEMPI
	MPI_Status status;
	int buf[ 2 ];
	//receive size of string and batch id
	MPI_Recv( buf, 2, MPI_INT, source_rank, MPI_JOB_DIST_TAG, MPI_COMM_WORLD, &status );
	Size size( buf[ 0 ]);
	Size id( buf[ 1 ] );

	//receive string -- the payload message is consumed before the id is inspected
	char *cbuf = new char[ size+1 ];
	MPI_Recv( cbuf, size, MPI_CHAR, source_rank, MPI_JOB_DIST_TAG, MPI_COMM_WORLD, &status );
	std::string new_batch( cbuf, size );
	// release the receive buffer before any return path; previously the STOP path
	// returned without freeing it
	delete[] cbuf;

	if ( id == 0 ) { //use this as STOP signal!
		tr.Debug << "received STOP signal from Archive " << std::endl;
		return false;
	}

	tr.Debug << "received new batch " << new_batch << " with id " << id << std::endl;
	add_batch( new_batch, id );
#endif
	return true;
}


///@brief Bring a slave's list of batches up to date with the master's list.
///@details Executed on both ends of the exchange; which side is taken depends on whether
///rank() equals master_rank():
///  1. slave -> master: number of batches the slave currently has
///  2. master -> slave: number of batches the master currently has
///  3. master -> slave: every batch the slave is missing, each as a header of
///     { string length, batch id } followed by the string; the slave consumes them
///     with receive_batch().
///Without USEMPI only the profiling start/stop calls remain.
///@param slave_rank rank of the slave taking part; the MPI peer on the master side
void
MPIArchiveJobDistributor::sync_batches( Size MPI_ONLY( slave_rank ) ) {
	PROF_START( util::ARCHIVE_SYNC_BATCHES );
#ifdef USEMPI
	tr.Trace << "Node " << rank() << " sync batches with " << slave_rank << std::endl;
	int msg[ 4 ];
	msg[ 1 ] = ADD_BATCH;
	MPI_Status status;
	Size known_to_slave( nr_batches() );
	Size known_to_master;

	if ( rank() == master_rank() ) {
		// ---------------- MASTER ----------------
		// step 1: learn the last batch the slave knows about
		MPI_Recv( &msg, 1, MPI_INT, slave_rank, MPI_JOB_DIST_TAG, MPI_COMM_WORLD, &status );
		known_to_slave = msg[ 0 ];
		tr.Trace << "Node " << rank() << " slave_batch_size " << known_to_slave << std::endl;

		// step 2: announce how many batches the slave should end up with
		known_to_master = nr_batches();
		msg[ 0 ] = known_to_master;
		MPI_Send( &msg, 1, MPI_INT, slave_rank, MPI_JOB_DIST_TAG, MPI_COMM_WORLD );
		tr.Trace << "Node " << rank() << " master_batch_size " << known_to_master << std::endl;

		// step 3: ship the individual batches
		for ( Size send_id = known_to_slave + 1; send_id <= known_to_master; ++send_id ) {
			//header: size of string and batch id
			msg[ 0 ] = batch( send_id ).size();
			msg[ 1 ] = send_id;
			MPI_Send(msg, 2, MPI_INT, slave_rank, MPI_JOB_DIST_TAG, MPI_COMM_WORLD );
			//payload: the string
			MPI_Send(const_cast<char*> ( batch( send_id ).data()), batch( send_id ).size(), MPI_CHAR, slave_rank, MPI_JOB_DIST_TAG, MPI_COMM_WORLD );
		}
	} else {
		// ---------------- SLAVE ----------------
		// step 1: tell the master what our last known batch is
		msg[ 0 ] = known_to_slave;
		MPI_Send( &msg, 1, MPI_INT, master_rank(), MPI_JOB_DIST_TAG, MPI_COMM_WORLD );
		tr.Trace << "Node " << rank() << " slave_batch_size " << known_to_slave << std::endl;

		// step 2: hear how many batches we are supposed to have
		known_to_master = nr_batches();
		MPI_Recv( &msg, 1, MPI_INT, master_rank(), MPI_JOB_DIST_TAG, MPI_COMM_WORLD, &status );
		known_to_master = msg[ 0 ];
		tr.Trace << "Node " << rank() << " master_batch_size " << known_to_master << std::endl;

		// step 3: collect the individual batches; each one must land at the expected position
		for ( Size send_id = known_to_slave + 1; send_id <= known_to_master; ++send_id ) {
			receive_batch( master_rank() );
			tr.Trace << "nr_batches() " << nr_batches() << " send_id " << send_id << std::endl;
			runtime_assert( nr_batches() == send_id );
		}
	}
#endif
	PROF_STOP( util::ARCHIVE_SYNC_BATCHES );
}

///@brief Send a short message from the master to the archive process.
///@details The message is a block of six ints sent with MPI_ARCHIVE_TAG: slot 0 carries the
///message tag, slot 1 the current batch id; the remaining slots are not used by this message
///and are sent as 0. Must be called on the master rank only. No-op without USEMPI.
///@param tag message tag placed in slot 0 (e.g. QUEUE_EMPTY, as used by batch_underflow())
void
MPIArchiveJobDistributor::master_to_archive( Size MPI_ONLY(tag) ) {
#ifdef USEMPI
	runtime_assert( rank() == master_rank() );
	runtime_assert( rank() != archive_rank() );
	Size const mpi_size( 6 );
	// zero the whole buffer: all mpi_size ints go over the wire, and previously
	// slots 2..5 were sent with indeterminate contents
	int mpi_buf[ mpi_size ] = { 0 };
	mpi_buf[ 0 ] = tag;
	mpi_buf[ 1 ] = current_batch_id();
	MPI_Send( &mpi_buf, mpi_size, MPI_INT, archive_rank(), MPI_ARCHIVE_TAG, MPI_COMM_WORLD );
#endif
}

///@brief Reaction to running out of batches.
///@details On the master: send QUEUE_EMPTY to the archive, consume the ADD_BATCH signal it
///answers with and receive the batch that follows (blocking). On any other rank: send
///BATCH_SYNC to the master and synchronise the batch list with it.
void
MPIArchiveJobDistributor::batch_underflow() {
	if ( rank() == master_rank() ) {
		PROF_START( util::MPI_JD2_WAITS_FOR_ARCHIVE );
		tr.Debug << "no more batches... ask ArchiveManager if there is some more to do... wait..." << std::endl;
		master_to_archive( QUEUE_EMPTY );
		tr.Debug << "wait for answer on QUEUE-EMPTY msg..."<< std::endl;
		eat_signal( ADD_BATCH, archive_rank() );
		receive_batch( archive_rank() ); //how about some time-out
		tr.Debug << "...received " << std::endl;
		PROF_STOP( util::MPI_JD2_WAITS_FOR_ARCHIVE );
	} else {
		// slave: request a batch sync from the master and take part in it
		slave_to_master( BATCH_SYNC );
		sync_batches( rank() );
	}
}


///@brief Master-side message dispatch, extending the parent's handling by the batch messages.
///@details Every call first gives queued archive notifications a chance to be sent.
///BATCH_SYNC and ADD_BATCH are handled here; every other tag is passed to
///Parent::process_message(). (A NOTIFICATION_QUERY case used to exist here and is disabled.)
///@param msg_tag        tag of the incoming message
///@param slave_rank     rank the message came from
///@param slave_job_id   forwarded to the parent for tags not handled here
///@param slave_batch_id forwarded to the parent for tags not handled here
///@return true for BATCH_SYNC / ADD_BATCH, otherwise whatever the parent returns
bool
MPIArchiveJobDistributor::process_message( core::Size msg_tag, core::Size slave_rank, core::Size slave_job_id, core::Size slave_batch_id ) {
	runtime_assert( rank() == master_rank() );

	// flush a pending notification to the archive, if any -- non-blocking.
	// We should get here often enough... (basically every finished job)
	_notify_archive();

	switch ( msg_tag ) {
	case BATCH_SYNC:
		sync_batches( slave_rank );
		return true;
	case ADD_BATCH:
		// only the archive may announce new batches
		runtime_assert( slave_rank == archive_rank() );
		receive_batch( archive_rank() );
		return true;
	default:
		break;
	}

	// not one of ours: let the parent deal with it
	return Parent::process_message( msg_tag, slave_rank, slave_job_id, slave_batch_id );
}

///@brief Queue a notification for the archive; it is appended to the end of
///pending_notifications_ and sent later by _notify_archive().
///@param msg message to queue (copied into the queue)
void
MPIArchiveJobDistributor::notify_archive( CompletionMessage const& msg ) {
	//TODO: check if there are older messages regarding this batch... if so ... remove
	pending_notifications_.insert( pending_notifications_.end(), msg );
}

#ifdef USEMPI
// State of the non-blocking send issued by _notify_archive(). It lives outside the function
// because the send buffer and the request handle have to outlive the call that started the send.
MPI_Request notify_request;  // handle of the most recent MPI_Isend
int notify_buf[ 6 ];         // send buffer of the most recent MPI_Isend
bool notify_first = true;    // true until the first MPI_Isend has been issued
#endif

///@brief Send the oldest pending notification to the archive without blocking.
///@details Returns immediately if nothing is pending. Otherwise, unless this is the very first
///send, the previous MPI_Isend is tested for completion; a new message is only sent (and
///removed from the queue) once the previous one has completed, because both share
///notify_buf. At most one message is sent per call. Without USEMPI nothing is sent and
///the queue is left untouched.
void MPIArchiveJobDistributor::_notify_archive() {
	PROF_START( util::MPI_NOTIFY_ARCHIVE );
	static core::util::Tracer tr("protocols.jd2.notifications");
	if ( pending_notifications_.size() == 0 ) {
		// keep PROF_START / PROF_STOP paired on the early-out path as well
		PROF_STOP( util::MPI_NOTIFY_ARCHIVE );
		return;
	}
#ifdef USEMPI
	int flag( 1 );
	if ( !notify_first ) {
		tr.Debug << "test MPI-Send completion of last JOB_COMPLETION ( batch_" << notify_buf[ 1 ] << " ) message...";
		MPI_Status status;
		MPI_Test( &notify_request, &flag, &status ); //has last communication succeeded ? --- buffer is free again.
		// status is only defined when MPI_Test reports completion, so query it only then
		int flag2( 0 );
		if ( flag ) {
			MPI_Test_cancelled( &status, &flag2 );
		}
		tr.Debug << ( flag ? "completed " : "pending " ) << ( !flag2 ? "/ test succeeded " : "/ test cancelled" ) << std::endl;
	}
	if ( flag ) {
		// previous send is done (or there was none): notify_buf may be reused
		CompletionMessage const& msg(	pending_notifications_.front() );
		tr.Debug << "send out JOB_COMPLETION " << msg.batch_id << std::endl;
		notify_buf[ 0 ] = msg.msg_tag;//JOB_COMPLETION;
		notify_buf[ 1 ] = msg.batch_id;
		notify_buf[ 2 ] = msg.final ? 1 : 0;
		notify_buf[ 3 ] = msg.bad;
		notify_buf[ 4 ] = msg.good;
		notify_buf[ 5 ] = msg.njobs;
		MPI_Isend( &notify_buf, 6, MPI_INT,  archive_rank(), MPI_ARCHIVE_TAG, MPI_COMM_WORLD, &notify_request ); //don't block JobDistributor
		// msg refers to the queue's front element: pop only after its fields were copied
		pending_notifications_.pop_front();
		notify_first = false;
	}
#endif
	PROF_STOP( util::MPI_NOTIFY_ARCHIVE );
}


///@brief Decide whether the archive should hear about progress of batch batch_id, and queue
///the corresponding notification(s).
///@details A notification is queued when the batch is finished (completed + newly completed +
///bad == number of jobs; flagged as final) or when at least nr_notify_ jobs were newly completed
///since the last notification (flagged as not final). In both cases the newly completed count is
///folded into the completed count first. Independently, if batch_id is the last known batch and
///the current job id has reached ( number of jobs - number of processors ), a QUEUE_EMPTY
///message is put at the front of the queue and a send is attempted right away.
///@param batch_id batch whose counters are inspected
void
MPIArchiveJobDistributor::notify_archive( core::Size batch_id ) {
	bool const batch_finished(
		nr_completed_[ batch_id ] + nr_new_completed_[ batch_id ] + nr_bad_[ batch_id ] == nr_jobs_[ batch_id ]
	);
	if ( batch_finished || nr_new_completed_[ batch_id ] >= nr_notify_ ) {
		// fold the fresh completions into the running total before reporting
		nr_completed_[ batch_id ] += nr_new_completed_[ batch_id ];
		nr_new_completed_[ batch_id ] = 0;
		notify_archive( CompletionMessage( batch_id, batch_finished, nr_bad_[ batch_id ], nr_completed_[ batch_id ], nr_jobs_[ batch_id ] ) );
	}

	//are we quickly running out of jobs? -- checking for equality to reduce number of messages -- is this safe? do we ever skip jobs?
	bool const on_last_batch( nr_batches() == batch_id );
	if ( on_last_batch && ( (int) current_job_id() == (int) get_jobs().size() - (int) number_of_processors() ) ) {
		tr.Debug << "jobs are low... send QUEUE_EMPTY" << std::endl;
		pending_notifications_.push_front( CompletionMessage( batch_id, QUEUE_EMPTY ) );
		_notify_archive();
	}
}

///@brief Mark a job as completed in the parent and, on the master, count it for its batch
///and check whether the archive has to be notified.
///@param job_id   passed on to the parent
///@param batch_id batch the job belongs to; index into the per-batch counters
void MPIArchiveJobDistributor::mark_job_as_completed( core::Size job_id, core::Size batch_id ) {
	Parent::mark_job_as_completed( job_id, batch_id );

	// the per-batch bookkeeping is maintained on the master only
	if ( rank() != master_rank() ) return;

	runtime_assert( batch_id <= nr_jobs_.size() );
	++nr_new_completed_[ batch_id ];
	notify_archive( batch_id );
}

///@brief Mark a job as bad in the parent and, on the master, add the batch's nstruct value
///to its bad-job count and check whether the archive has to be notified.
///@param job_id   passed on to the parent
///@param batch_id batch the job belongs to; index into the per-batch counters
void MPIArchiveJobDistributor::mark_job_as_bad( core::Size job_id, core::Size batch_id ) {
	Parent::mark_job_as_bad( job_id, batch_id );

	// the per-batch bookkeeping is maintained on the master only
	if ( rank() != master_rank() ) return;

	runtime_assert( batch_id <= nr_jobs_.size() );
	// one bad job is booked as nstruct_[ batch_id ] bad jobs
	nr_bad_[ batch_id ] += nstruct_[ batch_id ];
	notify_archive( batch_id );
}

///@brief Load the next batch via the parent and, on the master, extend the per-batch
///counters so that they have an entry for every batch id up to current_batch_id().
void MPIArchiveJobDistributor::load_new_batch() {
	Parent::load_new_batch();

	// the per-batch bookkeeping is maintained on the master only
	if ( rank() != master_rank() ) return;

	// In principle this would rather be done in add_batch(), but the options of the new
	// batch are needed here for nstruct...
	while ( nr_jobs_.size() < current_batch_id() ) {
		nr_jobs_.push_back( get_jobs().size() );
		nr_new_completed_.push_back( 0 );
		nr_completed_.push_back( 0 );
		nr_bad_.push_back( 0 );
		/// Assuming that all JobInputters create nstruct jobs per input_tag...
		nstruct_.push_back( option[ out::nstruct ] );
	}
	runtime_assert( nr_jobs_.size() == current_batch_id() );
}



}//archive
}//jd2
}//protocols
