32 #include <basic/Tracer.hh>
38 #include <utility/exit.hh>
39 #include <utility/string_util.hh>
40 #include <utility/vector1.hh>
41 #include <utility/pointer/owning_ptr.hh>
42 #include <utility/pointer/ReferenceCount.hh>
45 #include <ObjexxFCL/format.hh>
53 #include <basic/options/keys/optE.OptionKeys.gen.hh>
56 #include <basic/options/option.hh>
61 using namespace scoring;
62 using namespace optimization;
65 namespace optimize_weights {
67 using namespace ObjexxFCL::fmt;
71 static basic::Tracer
TR(
"protocols.optimize_weights.OptEMultifunc");
73 using namespace numeric::expression_parser;
// OptEMultifunc constructor (first visible overload).
// Only part of the parameter list / initializer list / body is visible in this excerpt.
80 OptEMultifunc::OptEMultifunc(
// The number of energy dofs is taken directly from num_free_in.
88 num_energy_dofs_( num_free_in ),
91 opte_data_( opte_data_in ),
92 fixed_terms_( fixed_terms_in ),
93 score_list_( score_list_in ),
94 fixed_score_list_( fixed_score_list_in ),
// Both flags start out false at construction.
95 fix_reference_energies_( false ),
97 component_weights_( component_weights ),
100 distribute_over_mpi_( false )
// Store this process's rank within MPI_COMM_WORLD in mpi_rank_.
103 MPI_Comm_rank( MPI_COMM_WORLD, &
mpi_rank_);
// Conditional on the optE::mpi_weight_minimization option; the branch body is not visible here.
105 if ( basic::options::option[ basic::options::OptionKeys::optE::mpi_weight_minimization ] ) {
// Initializer list of a second OptEMultifunc constructor that also receives
// reference_energies_in (the constructor header is not visible in this excerpt).
125 num_energy_dofs_( num_free_in ),
// One reference dof per entry of reference_energies_in.
126 num_ref_dofs_( reference_energies_in.
size() ),
// Total dof count = free energy dofs + reference-energy dofs.
127 num_total_dofs_( num_free_in + reference_energies_in.
size() ),
128 opte_data_( opte_data_in ),
129 fixed_terms_( fixed_terms_in ),
130 score_list_( score_list_in ),
131 fixed_score_list_( fixed_score_list_in ),
132 fix_reference_energies_( false ),
// Keep a copy of the incoming reference energies as the starting values.
133 starting_reference_energies_( reference_energies_in ),
134 component_weights_( component_weights ),
137 distribute_over_mpi_( false )
// Store this process's rank within MPI_COMM_WORLD in mpi_rank_.
140 MPI_Comm_rank( MPI_COMM_WORLD, &
mpi_rank_);
// Conditional on the optE::mpi_weight_minimization option; the branch body is not visible here.
142 if ( basic::options::option[ basic::options::OptionKeys::optE::mpi_weight_minimization ] ) {
// Fragment of the OptEMultifunc score evaluation (name taken from the trace message below;
// the function header is not visible in this excerpt).
// Copy the incoming vars element-by-element into local_vars (1-based indexing).
169 for (
Size ii = 1; ii <= vars.size(); ++ii ) {
170 local_vars[ ii ] = vars[ ii ];
// Zero-initialized scratch vector with the same length as local_vars.
181 Multivec dummy( local_vars.size(), 0.0 );
// Only taken when the user explicitly supplied optE::limit_bad_scores.
198 if ( basic::options::option[ basic::options::OptionKeys::optE::limit_bad_scores ].user() )
// Static counter; the code that increments and tests it is not visible here.
200 static Size count = 0;
// Abort the program with the message below.
203 utility::exit(__FILE__,__LINE__,
"Counted over 100,000 inf/NaN scores. Admitting defeat now.");
// Dump every variable as "index value", comma-separated, to stderr.
205 std::cerr <<
"vars: " << std::endl;
206 for (
Size ii = 1; ii <= local_vars.size(); ++ii ) {
207 if ( ii != 1 ) std::cerr <<
", ";
208 std::cerr << ii <<
" " << local_vars[ ii ];
210 std::cerr << std::endl;
// Trace output: the score, followed by the contents of local_vars and of dummy.
221 TR <<
"OptEMultifunc " << score <<
"\n";
223 for (
Size ii = 1; ii <= local_vars.size(); ++ii ) {
224 TR <<
" " << local_vars[ ii ];
228 for (
Size ii = 1; ii <= local_vars.size(); ++ii ) {
229 TR <<
" " << dummy[ ii ];
// Fragment of the derivative evaluation (presumably OptEMultifunc::dfunc; the header is not
// visible in this excerpt).
// Copy the incoming vars element-by-element into local_vars (1-based indexing).
250 for (
Size ii = 1; ii <= vars.size(); ++ii ) {
251 local_vars[ ii ] = vars[ ii ];
// Reset every entry of the output derivative vector to zero.
260 for (
Size ii(1); ii <= dE_dvars.size(); ++ii ) dE_dvars[ ii ] = 0.0;
// Scan the local derivatives and report any inf / NaN entry to stderr, together with the
// type name and tag of the position-data object that *itr currently refers to.
269 for(
Size ii = 1 ; ii <= local_dE_dvars.size() ; ++ii ) {
271 if( std::isinf(local_dE_dvars[ ii ]) ) std::cerr <<
"Introduced INF deriv at " << ii <<
" with " <<
OptEPositionDataFactory::optE_type_name( (*itr)->type() ) <<
" " << (*itr)->tag() << std::endl;
272 else if( std::isnan(local_dE_dvars[ ii ]) ) std::cerr <<
"Introduced NAN deriv at " << ii <<
" with " <<
OptEPositionDataFactory::optE_type_name( (*itr)->type() ) <<
" " << (*itr)->tag() << std::endl;
// Copy the locally computed derivatives into the caller's output vector.
278 for (
Size ii = 1; ii <= dE_dvars.size(); ++ii ) {
279 dE_dvars[ ii ] = local_dE_dvars[ ii ];
// Trace output: formatted score, then the derivatives, then the variables.
288 TR <<
"score: " << F(9,5,score) << std::endl;
290 for (
Size ii = 1 ; ii <= dE_dvars.size() ; ++ii )
TR <<
" " << F(9,3,dE_dvars[ ii ]);
293 for (
Size ii = 1; ii <= vars.size(); ++ii ) {
294 TR <<
" " << F(9,3,vars[ ii ]);
// Fragments of the dof <-> energy-map conversion routines (function headers are not visible
// in this excerpt).
// Walk score_list_ in order and store each score type's starting value in the next dof slot.
313 for( ScoreTypes::const_iterator itr =
score_list_.begin(),
314 end_itr =
score_list_.end(); itr != end_itr; ++itr ) {
315 dofs[ dof_index++ ] = start_vals[ *itr ];
// Inverse direction: walk score_list_ in order and write consecutive dofs into return_map.
339 for( ScoreTypes::const_iterator itr =
score_list_.begin(),
341 itr != end_itr ; ++itr ) {
342 return_map[ *itr ] = dofs[ dof_index++ ];
// A further iterator loop follows; its initializer and body are not visible here.
349 itr != end_itr ; ++itr ) {
// Fragment of the MPI message-handling code (presumably run on non-root ranks, given the
// "message from root node" error text below -- confirm against the full function).
// Participate in a broadcast of a single int message rooted at rank 0.
381 MPI_Bcast( &message, 1, MPI_INT, 0, MPI_COMM_WORLD );
// Send one double (my_func) to rank 0 using tag 1.
389 MPI_Send( & my_func, 1, MPI_DOUBLE, 0, 1, MPI_COMM_WORLD );
// Size and zero the derivative vector, then compute the derivatives for vars.
392 dE_dvars.resize( vars.size() );
393 std::fill( dE_dvars.begin(), dE_dvars.end(), 0.0 );
394 dfunc( vars, dE_dvars );
// Repack the 1-based derivative vector into a 0-based raw buffer for MPI.
// NOTE(review): the matching delete[] for dE_dvars_raw is not visible in this excerpt -- confirm it is freed.
395 double * dE_dvars_raw =
new double[ vars.size() ];
396 for (
Size ii = 1; ii <= vars.size(); ++ii ) { dE_dvars_raw[ ii - 1 ] = dE_dvars[ ii ]; }
// Send the vars.size() derivative values to rank 0 using tag 1.
397 MPI_Send( dE_dvars_raw, vars.size(), MPI_DOUBLE, 0, 1, MPI_COMM_WORLD );
// Report a message value that was not handled by the (not visible) preceding branches.
401 std::cerr <<
"ERROR: Unrecognized message from root node: " << message << std::endl;
// Fragments of several MPI helper routines (function headers are not visible in this excerpt).
// Three separate broadcasts of a single int message rooted at rank 0; the value assigned to
// message before each call is not visible here.
416 MPI_Bcast( &message, 1, MPI_INT, 0, MPI_COMM_WORLD );
433 MPI_Bcast( &message, 1, MPI_INT, 0, MPI_COMM_WORLD );
454 MPI_Bcast( &message, 1, MPI_INT, 0, MPI_COMM_WORLD );
// Broadcast the number of variables first (narrowed to int for MPI), rooted at rank 0.
470 int vars_size = vars.size();
471 MPI_Bcast( & vars_size, 1, MPI_INT, 0, MPI_COMM_WORLD );
// Repack the 1-based vars into a 0-based raw buffer and broadcast it.
// NOTE(review): the matching delete[] for raw_vars is not visible in this excerpt -- confirm it is freed.
473 double * raw_vars =
new double[ vars_size ];
474 for (
int ii = 1; ii <= vars_size; ++ii ) {
475 raw_vars[ ii - 1 ] = vars[ ii ];
477 MPI_Bcast( raw_vars, vars_size, MPI_DOUBLE, 0, MPI_COMM_WORLD );
// Counterpart that fills vars from a broadcast: obtain the size, resize vars, then unpack
// the 0-based raw buffer into the 1-based vector.
492 MPI_Bcast( & vars_size, 1, MPI_INT, 0, MPI_COMM_WORLD );
493 vars.resize( vars_size );
// NOTE(review): the matching delete[] for raw_vars is not visible in this excerpt -- confirm it is freed.
495 double * raw_vars =
new double[ vars_size ];
496 MPI_Bcast( raw_vars, vars_size, MPI_DOUBLE, 0, MPI_COMM_WORLD );
497 for (
int ii = 1; ii <= vars_size; ++ii ) {
498 vars[ ii ] = raw_vars[ ii - 1 ];
// Fragments of the code that collects results from other ranks (enclosing loops over ii are
// not visible in this excerpt).
// Receive one double from rank ii (tag 1) into ii_func.
516 MPI_Recv( & ii_func, 1, MPI_DOUBLE, ii, 1, MPI_COMM_WORLD, & stat );
// Temporary 0-based buffer sized to match dE_dvars.
536 double * dE_dvars_raw =
new double[ dE_dvars.size() ];
// Receive rank ii's derivative values (tag 1) and add them element-wise into dE_dvars.
538 MPI_Recv( dE_dvars_raw, dE_dvars.size(), MPI_DOUBLE, ii, 1, MPI_COMM_WORLD, & stat );
539 for (
Size jj = 1; jj <= dE_dvars.size(); ++jj ) {
540 dE_dvars[ jj ] += dE_dvars_raw[ jj - 1 ];
// Release the temporary buffer.
543 delete [] dE_dvars_raw;
// Fragment of the WrapperOptEMultifunc initialization (named in the error messages below; the
// function header is not visible in this excerpt).
// Store the free / fixed score lists, the fixed scores, and the wrapped multifunc.
555 free_score_list_( free_score_list ),
556 fixed_score_list_( fixed_score_list ),
557 fixed_scores_( fixed_scores ),
560 multifunc_( optEfunc ),
562 n_real_dofs_( free_count )
564 using namespace basic::options;
565 using namespace basic::options::OptionKeys;
566 using namespace core::scoring;
// The wrapper cannot be built unless optE::wrap_dof_optimization was given by the user.
574 if ( ! option[ optE::wrap_dof_optimization ].user() ) {
575 utility_exit_with_message(
"Error in WrapperOptEMultifunc constructor. Cannot create WrapperOptEMultifunc if optE::wrap_dof_optimization is not on the command line");
// Scanner used to tokenize the DEPENDENT_DOF expressions read below.
578 ArithmeticScanner as;
// Register one scanner variable per free score type (how iiname is derived is not visible here).
580 for (
Size ii = 1; ii <= free_score_list.size(); ++ii ) {
585 as.add_variable( iiname );
// Likewise for each fixed score type.
588 for (
Size ii = 1; ii <= fixed_score_list.size(); ++ii ) {
591 as.add_variable( iiname );
// Open the file named by optE::wrap_dof_optimization.
599 std::ifstream wrapper_file( option[ optE::wrap_dof_optimization ]()().c_str() );
// Tracks whether the NEW_DOF header section has ended (see the DEPENDENT_DOF branch).
600 bool finished_new_dof_header(
false );
// Read declarations while the stream is in a good state.
603 while ( wrapper_file ) {
606 wrapper_file >> dof_dec;
// An empty declaration token: the handling when the stream has failed is not visible here;
// the visible exit reports that a declaration keyword was expected.
607 if ( dof_dec ==
"" ) {
608 if ( !wrapper_file ) {
611 utility_exit_with_message(
"Expected NEW_DOF or DEPENDENT_DOF from " + option[ optE::wrap_dof_optimization ]()() +
" but got empty string" );
// Read the dof name and echo what was read to stdout.
614 wrapper_file >> dof_name;
615 std::cout <<
"READ: " << dof_dec <<
" " << dof_name << std::endl;
// NEW_DOF declarations are rejected once a DEPENDENT_DOF declaration has been seen.
616 if ( dof_dec ==
"NEW_DOF" ) {
617 if ( finished_new_dof_header ) {
618 utility_exit_with_message(
"Encountered NEW_DOF declaration after a DEPENDENT_DOF delcaration. All NEW_DOF declarations must be at the top of the file" );
// Make the new dof name available to the scanner.
622 as.add_variable( dof_name );
624 }
else if ( dof_dec ==
"DEPENDENT_DOF" ) {
// The first DEPENDENT_DOF marks the end of the NEW_DOF header.
625 if ( ! finished_new_dof_header ) {
626 finished_new_dof_header =
true;
// Exit when dof_name fails the (not visible) validity check described in the message.
631 utility_exit_with_message(
"Error in WrapperOptEMultifunc::WrapperOptEMultifunc()\nDid not find dof " + dof_name +
" in valid_variable_names_ set; either it is not a free dof or is listed as a dependent dof twice" );
// The token after the dof name must be a literal "=".
635 wrapper_file >> equals_sign;
636 if ( equals_sign !=
"=" ) {
637 utility_exit_with_message(
"Expected an equals sign after reading 'DEPENDENT_DOF " + dof_name +
"' but instead read " + equals_sign );
// The remainder of the line is the expression text; echo it to stdout.
641 getline( wrapper_file, expression );
642 std::cout <<
"READ EXPRESSION: " << expression << std::endl;
// Tokenize, parse into an AST, then build an expression tree from the AST.
643 TokenSetOP tokens = as.scan( expression );
644 ArithmeticASTExpression ast_expression;
645 ast_expression.parse( *tokens );
648 numeric::expression_parser::ExpressionCOP derived_dof_expression = expression_creator.create_expression_tree( ast_expression );
653 std::cout <<
"Created expression for " << dof_name <<
" index# " << derived_dof_index << std::endl;
// Any other declaration keyword is an error.
655 utility_exit_with_message(
"Expected either NEW_DOF or DEPENDENT_DOF, but got " + dof_dec +
" from file " + option[ optE::wrap_dof_optimization ]()() );
// Fragment that assigns ids to the variable expressions and builds partial-derivative
// expressions (enclosing function and outer loops are not visible in this excerpt).
// Id counter starts at 1; where it is incremented is not visible here.
690 Size count_real_dofs( 1 );
// Give each variable expression in the map the current id.
691 for ( std::map< std::string, OptEVariableExpressionOP >::iterator
693 iter != iter_end; ++iter ) {
694 iter->second->set_id( count_real_dofs );
// For each variable name in the iterated set, differentiate iiexp with respect to it.
701 for ( std::set< std::string >::const_iterator
704 variter != variter_end; ++variter ) {
705 numeric::expression_parser::ExpressionCOP iiexp_dvar = iiexp->differentiate( *variter );
// A null derivative expression is treated as a fatal error.
706 if ( iiexp_dvar == 0 ) {
707 utility_exit_with_message(
"Error constructing parital derivative for '" +
709 "' by variable '" + *variter +
"'. Null pointer returned." );
// Log the variable, its index, and the optE dof whose expression it appears in.
712 std::cout <<
"Adding dof derivative expression for " << *variter <<
" index#: " << varindex <<
" which appears in the expression for optEdof # " << ii << std::endl;
// Fragment of the WrapperOptEMultifunc function evaluation (name taken from the trace message;
// the header and the construction of optEvars are not visible in this excerpt).
// Evaluate the wrapped multifunc on optEvars.
728 Real score = (*multifunc_)( optEvars );
// Trace the formatted score followed by the incoming vars.
730 TR <<
"WrapperOptEMultifunc func: " << F(7,2,score) << std::endl;
732 for (
Size ii = 1; ii <= vars.size(); ++ii ) {
733 TR <<
" " << vars[ ii ];
// Fragment of the WrapperOptEMultifunc derivative evaluation (name taken from the trace
// message; the header and the outer loop over ii are not visible in this excerpt).
// Buffer for the wrapped multifunc's derivatives, one entry per element of optEvars.
747 Multivec dmultifunc_dvars( optEvars.size() );
// Zero the output, then ask the wrapped multifunc for its derivatives at optEvars.
748 std::fill( dE_dvars.begin(), dE_dvars.end(), 0.0 );
749 multifunc_->dfunc( optEvars, dmultifunc_dvars );
// For each (index, expression) pair in the iterated list, add the product of the wrapped
// multifunc's derivative at that index and the evaluated expression into dE_dvars[ ii ].
752 for ( std::list< std::pair< Size, numeric::expression_parser::ExpressionCOP > >::const_iterator
755 iter != iter_end; ++iter ) {
756 dE_dvars[ ii ] += (dmultifunc_dvars[ iter->first ]) * ( (*(iter->second))());
// Emit the derivative trace only when the tracer is visible.
760 if (
TR.visible() ) {
761 TR <<
"WrapperOptEMultifunc dfuncs:";
762 for(
Size ii = 1 ; ii <= dE_dvars.size() ; ++ii )
TR <<
" " << F(7,2,dE_dvars[ ii ]);
// Fragments of two routines that iterate the name -> variable-expression map (function
// headers and iterator initializers are not visible in this excerpt).
// Refresh every variable expression's value from vars.
778 for ( std::map< std::string, OptEVariableExpressionOP >::const_iterator
780 iter != iter_end; ++iter ) {
781 iter->second->update_value_from_list( vars );
// Print each variable as "name : value" to ostr, refreshing its value from vars first.
800 ostr <<
"WrapperOptEMultifunc dofs:\n";
801 for ( std::map< std::string, OptEVariableExpressionOP >::const_iterator
803 iter != iter_end; ++iter ) {
804 iter->second->update_value_from_list( vars );
805 ostr << iter->first <<
" : " << (*(iter->second))() <<
"\n";
// Fragment of the variable-expression registration routine (presumably
// register_variable_expression, which is called elsewhere in this file; the function name line
// is not visible in this excerpt).
810 numeric::expression_parser::VariableExpressionOP
// Set the variable's value to the fixed score stored for fixed_variable_scoretype.
827 varexp->set_value(
fixed_scores_[ fixed_variable_scoretype ] );
// A variable already declared as dependent may not appear on a right-hand side.
837 utility_exit_with_message(
"Variable '" + varname +
"' appearing on the right hand side of a DEPENDENT_DOF statement\nwas previously listed as a dependent variable" );
// Unknown name: list the free, fixed and new variable names on stderr before exiting.
839 std::cerr <<
"Error: variable expression with name '" << varname <<
"' is not a valid variable name." << std::endl;
840 std::cerr <<
"Free variables:" << std::endl;
841 for ( std::set< std::string >::const_iterator
843 iter != iter_end; ++iter ) {
844 std::cerr << *iter << std::endl;
846 std::cerr <<
"Fixed variables:" << std::endl;
847 for ( std::set< std::string >::const_iterator
849 iter != iter_end; ++iter ) {
850 std::cerr << *iter << std::endl;
852 std::cerr <<
"New variables:" << std::endl;
853 for ( std::set< std::string >::const_iterator
855 iter != iter_end; ++iter ) {
856 std::cerr << *iter << std::endl;
860 utility_exit_with_message(
"Could not register variable '" + varname +
"' as it is neither a valid free, fixed nor new DOF" );
// Set this object's value to the entry of value_vector at position id_
// (the enclosing function header is not visible in this excerpt).
891 set_value( value_vector[
id_ ] );
// Fragments of WrappedOptEExpressionCreator (class name taken from the error message below;
// function headers are only partially visible in this excerpt).
// Constructor initializer: keep the multifunc handed in by the caller.
900 : multifunc_( multifunc )
905 numeric::expression_parser::ExpressionCOP
// Variable nodes are resolved by registering the variable's name with the multifunc.
908 return multifunc_->register_variable_expression( node.variable_name() );
911 numeric::expression_parser::ExpressionCOP
913 FunctionTokenCOP
function,
// Exit with an error naming the function token that cannot be processed.
917 utility_exit_with_message(
"WrappedOptEExpressionCreator cannot process function " + function->name() );