Rosetta 3.5
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
GeneticAlgorithm.cc
Go to the documentation of this file.
1 // -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
2 // vi: set ts=2 noet:
3 //
4 // (c) Copyright Rosetta Commons Member Institutions.
5 // (c) This file is part of the Rosetta software suite and is made available under license.
6 // (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
7 // (c) For more information, see http://www.rosettacommons.org. Questions about this can be
8 // (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.
9 
10 /// @file GeneticAlgorithm.cc
11 /// @brief template class for genetic algorithm protocols
12 /// @author ashworth, based on template "pseudo"code by Colin Smith
13 
15 
19 
20 #include <utility/pointer/ReferenceCount.hh>
21 
22 #include <core/types.hh>
23 #include <basic/Tracer.hh>
24 
25 #include <utility/file/file_sys_util.hh> // file_exists
26 #include <utility/io/izstream.hh>
27 #include <utility/io/ozstream.hh>
28 #include <utility/pointer/owning_ptr.hh>
29 // AUTO-REMOVED #include <utility/pointer/access_ptr.hh>
30 #include <utility/vector1.hh>
31 
32 #include <numeric/random/random.fwd.hh>
33 
34 #include <boost/unordered_map.hpp>
35 
36 #include <algorithm> // std::copy
37 #include <iostream>
38 #include <set>
39 
40 #include <utility/exit.hh>
41 
42 
43 namespace protocols {
44 namespace genetic_algorithm {
45 
46 static basic::Tracer TR("protocols.genetic_algorithm");
47 
49  utility::pointer::ReferenceCount(),
50  fitness_function_(0),
51  entity_randomizer_(0),
52  entity_template_(0),
53  current_generation_(1),
54  max_generations_(0),
55  max_population_size_(0),
56  number_to_propagate_(1),
57  fraction_by_recombination_(0.5),
58  checkpoint_prefix_(""),
59  checkpoint_write_interval_(0),
60  checkpoint_gzip_(false),
61  checkpoint_rename_(false)
62 {}
63 
65 {
66  // ensure that checkpoint files are not accidentally reused
68 }
69 
72 {
73  EntityOP entity;
74  TraitEntityHashMap::iterator hash_it( entity_cache_.find( traits ) );
75  //Vec1Hash h;
76  if ( hash_it == entity_cache_.end() ) {
77  // make a new entity if it does not exist in the cache
78  entity = new_entity();
79  entity->set_traits( traits );
80  // add the new entity to the cache
81  entity_cache_[entity->traits()] = entity;
82  //TR << "(traits) Adding new entity " << *entity << " " << h( traits ) << std::endl;
83  } else {
84  // otherwise use the cached entity
85  entity = hash_it->second;
86  //TR << "(traits) Using cached entity" << h( traits ) << std::endl;
87  }
88  //TR << "Adding entity " << entity.get() << " to the current generation " << std::endl;
89  generations_[current_generation_].push_back( entity );
90  return entity;
91 }
92 
95 {
96  runtime_assert( entity );
97  TraitEntityHashMap::iterator hash_it( entity_cache_.find( entity->traits() ) );
98  //Vec1Hash h;
99  if ( hash_it == entity_cache_.end() ) {
100  // add the entity to the cache if it does not exist
101  entity_cache_[entity->traits()] = entity;
102  //TR << "(entity) Adding new entity " << *entity << " " << h( entity->traits() ) << std::endl;
103  } else {
104  // otherwise substitute the equivalent cached entity for the one given
105  entity = hash_it->second;
106  //TR << "(entity) Using cached entity" << h( entity->traits() ) << std::endl;
107  }
108  //TR << "Adding entity " << entity.get() << " to the current generation " << std::endl;
109  generations_[current_generation_].push_back( entity );
110  return entity;
111 }
112 
115 {
116  EntityOP entity;
117  TraitEntityHashMap::iterator hash_it( entity_cache_.find( traits ) );
118  if ( hash_it == entity_cache_.end() ) {
119  // make a new entity if it does not exist in the cache
120  entity = new_entity();
121  entity->set_traits( traits );
122  //std::cout << __LINE__ << " Parent entity was not found in entity_cache_" << std::endl;
123  entity_cache_[ traits ] = entity;
124  } else {
125  // otherwise use the cached entity
126  entity = hash_it->second;
127  }
128  parent_entities_.push_back( entity );
129  return entity;
130 }
131 
134 {
135  runtime_assert( entity );
136  TraitEntityHashMap::iterator hash_it( entity_cache_.find( entity->traits() ) );
137  if ( hash_it != entity_cache_.end() ) {
138  // use an equivalent hashed entity if it exists
139  entity = hash_it->second;
140  //std::cout << __LINE__ << " Parent entity was not found in entity_cache_" << std::endl;
141  entity_cache_[ entity->traits() ] = entity;
142  }
143  parent_entities_.push_back( entity );
144  return entity;
145 }
146 
148 
149 
150 void
152 {
153  runtime_assert(generations_[current_generation_].size());
154 
155  parent_entities_.insert(
156  parent_entities_.end(),
159  );
160 }
161 
162 ///@brief add the best entities from the previous generation
163 void
165 {
166  runtime_assert(current_generation_ > 1);
167 
169  std::sort( sorted_entities.begin(), sorted_entities.end(), lt_OP_deref< Entity > );
170 
171  // this will hopefully help keep population diversity higher
172  if (unique && size > 1) {
173  pop_iter new_last(std::unique(sorted_entities.begin(), sorted_entities.end(), eq_OP_deref< Entity >));
174  sorted_entities.erase(new_last, sorted_entities.end());
175  }
176 
177  for (core::Size i = 1; i <= size && i <= sorted_entities.size(); ++i) {
178  add_entity(sorted_entities[i]);
179  }
180 }
181 
184 {
185  core::Real best( 0.0 ); bool first_found = false;
186  for ( Size ii = 1; ii <= generations_[ current_generation_].size(); ++ii ) {
187  if ( generations_[ current_generation_ ][ ii ] ) {
188  if ( ! first_found || best > generations_[ current_generation_ ][ ii ]->fitness() ) {
189  best = generations_[ current_generation_ ][ ii ]->fitness();
190  first_found = true;
191  }
192  }
193  }
194  return best;
195 }
196 
197 
198 void
200 {
201  if ( size == 0 ) size = max_population_size_;
202  while ( generations_[current_generation_].size() < size ) {
203  add_entity( entity_randomizer_->random_entity() );
204  }
205 }
206 
207 ///@brief add entities that are recombinants of fit parents
208 void
210 {
211  runtime_assert(parent_entities_.size());
212 
213  if ( size == 0 ) size = max_population_size_;
214 
215  while ( generations_[current_generation_].size() < size ) {
216  // pick two random parents
219  entity_randomizer_->crossover( *child1, *child2 );
220  add_entity( child1 );
221  // rosetta++ only produced one child upon recombination
222  //if ( generations_[current_generation_].size() < size ) add_entity( child2 );
223  }
224 }
225 
226 ///@brief add entities that are mutants of fit parents
227 void
229 {
230  runtime_assert(parent_entities_.size());
231 
232  if ( size == 0 ) size = max_population_size_;
233 
234  while ( generations_[current_generation_].size() < size ) {
236  entity_randomizer_->mutate( *child );
237  add_entity( child );
238  }
239 }
240 
241 void
243 {
245  core::Size num_uncached(0);
247  it != end; ++it ) {
248  if ( ! (*it)->fitness_valid() ) {
249  //TR << "Evaluating fitness for entity " << (*it).get() << std::endl;
250  //if ( (*it) != entity_cache_[ (*it)->traits() ] ) {
251  // TR << "WEIRD: Evaluating fitness for entity that is not in the cache" << std::endl;
252  //}
253  // entity's fitness not valid
254  fitness_function_->evaluate( **it );
255  // allow intermediate checkpointing for very large populations
256  ++num_uncached;
257  if ( checkpoint_write_interval_ && num_uncached % checkpoint_write_interval_ == 0 ) {
259  }
260  }
261  }
263 }
264 
265 /// @brief progress to the next generation and generate new entities
266 /// @details
267 /// This method performs the following steps:
268 ///
269 /// 1. If parent entities were not already specified, sets the parents to the current generation
270 /// 2. Increments the generation counter
271 /// 3. Copies the best scoring entities from the previous generation to the new one
272 /// 4. Generates new entities by crossover and/or mutation
273 /// 5. Clears the parent entities now that they have been used
274 void
276 {
277  // if parents were not already added, use the current generation as parents
279  // increment the generation counter
281  // copy the best scoring entities to the next generation
283  // add entities via crossover and mutation
284  core::Size pop_size = generations_[current_generation_].size();
285  fill_by_crossover( static_cast<core::Size>( fraction_by_recombination_*(max_population_size_-pop_size)+pop_size ) );
287  // reset parent entities now that they've been used
288  parent_entities_.clear();
289 }
290 
291 bool
293 {
294  for (core::Size i = 1; i <= generations_[current_generation_].size(); ++i) {
295  if (!generations_[current_generation_][i]->fitness_valid()) return false;
296  }
297 
298  return true;
299 }
300 
301 bool
303 {
304  if (current_generation_ < max_generations_) return false;
306  return true;
307 }
308 
311 
312 void
314 {
315  max_generations_ = s;
317 }
318 
319 ///@brief returns variable number of best (const) entities via vector of pointers to them
321 GeneticAlgorithm::best_entities( core::Size num ) // nonconst method to permit sort
322 {
323  std::sort( generations_[current_generation_].begin(), generations_[current_generation_].end(), lt_OP_deref< Entity > );
326  while ( best_entities.size() < num && entity != generations_[current_generation_].end() ) {
327  best_entities.push_back( (*entity)() );
328  ++entity;
329  }
330  return best_entities;
331 }
332 
333 
334 ///@brief pick two random entities from an unordered vector, return the one whose fitness is better
335 Entity const &
337  utility::vector1< EntityCOP > const & pvec
338 ) const
339 {
340  using numeric::random::uniform;
341  Entity const & e1( *pvec[ static_cast<Size>( uniform() * pvec.size() ) + 1 ] );
342  Entity const & e2( *pvec[ static_cast<Size>( uniform() * pvec.size() ) + 1 ] );
343  return ( e1.fitness() < e2.fitness() ? e1 : e2 );
344 }
345 
348 
351 
354 
355 ///@brief true const (read-only) access to entity population
358 {
359  return generations_[gen_num];
360 }
361 
362 void
364  std::ostream & os,
365  core::Size gen_num
366 ) const
367 {
368  std::set<EntityOP> earlier_generations_set;
369  std::set<EntityOP> previous_generation_set;
370  std::set<EntityOP> current_generation_set;
371 
372  // make a set of entities in the previous generation
373  if (gen_num-1 >= 1) {
374  previous_generation_set.insert(generations_[gen_num-1].begin(), generations_[gen_num-1].end());
375  }
376 
377  // make a set of entities in earlier generations but not in the previous one
378  for (core::Size i = 1; i+2 <= gen_num; ++i) {
379  for (pop_const_iter iter(generations_[i].begin()), iter_end(generations_[i].end()); iter != iter_end;
380  ++iter) {
381  if (previous_generation_set.find(*iter) == previous_generation_set.end()) {
382  earlier_generations_set.insert(*iter);
383  }
384  }
385  }
386 
387  core::Size resurrected_entities(0);
388  core::Size heldover_entities(0);
389  core::Size duplicate_new_entities(0);
390  core::Size new_entities(0);
391  EntityOP best_new_entity(NULL);
392 
393  for (pop_const_iter iter(generations_[gen_num].begin()), iter_end(generations_[gen_num].end()); iter != iter_end;
394  ++iter) {
395  if (previous_generation_set.find(*iter) != previous_generation_set.end()) {
396  ++heldover_entities;
397  } else if (earlier_generations_set.find(*iter) != earlier_generations_set.end()) {
398  ++resurrected_entities;
399  } else if (current_generation_set.find(*iter) != current_generation_set.end()) {
400  ++duplicate_new_entities;
401  } else {
402  current_generation_set.insert(*iter);
403  ++new_entities;
404  if (!best_new_entity || ((*iter)->fitness_valid() && (*iter)->fitness() < best_new_entity->fitness())) {
405  best_new_entity = *iter;
406  }
407  }
408  }
409 
410  utility::vector1<EntityOP> sorted_generation(generations_[gen_num]);
411  std::sort( sorted_generation.begin(), sorted_generation.end(), lt_OP_deref< Entity > );
412  core::Real gen_size(static_cast<core::Real>(sorted_generation.size()));
413 
414  os << "Distinct new entities: " << new_entities << std::endl;
415  os << "Duplicate new entities: " << duplicate_new_entities << std::endl;
416  os << "Entities from previous generation: " << heldover_entities << std::endl;
417  os << "Entities resurrected from earlier generations: " << resurrected_entities << std::endl;
418  os << "Fitness Percentiles: 0%=" << sorted_generation.front()->fitness()
419  << " 25%=" << sorted_generation[static_cast<core::Size>(ceil(gen_size*.25))]->fitness()
420  << " 50%=" << sorted_generation[static_cast<core::Size>(ceil(gen_size*.50))]->fitness()
421  << " 75%=" << sorted_generation[static_cast<core::Size>(ceil(gen_size*.75))]->fitness()
422  << " 100%=" << sorted_generation.back()->fitness() << std::endl;
423  os << "Best new entity:" << '\n';
424  os << *best_new_entity << std::endl;
425 }
426 
427 void
428 GeneticAlgorithm::print_population( std::ostream & os ) const
429 {
431  it != end; ++it ) {
432  os << **it << '\n';
433  }
434  os << std::flush;
435 }
436 
437 
438 void
439 GeneticAlgorithm::print_cache( std::ostream & os ) const
440 {
441  for ( TraitEntityHashMap::const_iterator it( entity_cache_.begin() ), end( entity_cache_.end() );
442  it != end; ++it ) {
443  if (it->second->fitness_valid()) {
444  it->second->write_checkpoint(os);
445  os << '\n';
446  }
447  }
448  os << std::flush;
449 }
450 
453 {
454  if ( checkpoint_prefix_ == "" ) return "";
455  std::string filename(checkpoint_prefix_ + ".ga.entities");
456  filename += suffix;
457  if (checkpoint_gzip_) filename += ".gz";
458  return filename;
459 }
460 
461 bool
462 GeneticAlgorithm::read_entities_checkpoint( bool overwrite /* = false */ )
463 {
464  // if cache is not empty, then loading from checkpoint file is assumed not necessary by default
465  if ( !entity_cache_.empty() && !overwrite ) return false;
466  if ( checkpoint_prefix_ == "" ) return false;
468  utility::io::izstream file;
469  file.open( filename.c_str() );
470  if ( !file ) return false;
471  TR(basic::t_info) << "Reading cached entities from file " << filename << '\n';
472 
473  std::string line;
474  core::Size counter(0);
475  EntityOP entity = new_entity();
476  while ( entity->read_checkpoint(file) ) {
477  TR(basic::t_debug) << *entity << '\n';
478  entity_cache_[ entity->traits() ] = entity;
479  ++counter;
480  entity = new_entity();
481  }
482  TR(basic::t_debug) << std::flush;
483  file.close();
484 
485  TR(basic::t_info) << "Read " << counter << " cached fitnesses" << std::endl;
486  return true;
487 }
488 
489 bool
491 {
492  if ( checkpoint_prefix_ == "" ) return false;
494  std::string const filename_tmp(entities_checkpoint_filename(".tmp"));
495  utility::io::ozstream file( filename_tmp.c_str() );
496  if ( !file ) {
497  std::cerr << "trouble opening file " << filename << " for writing" << '\n';
498  return false;
499  }
500 
501  print_cache( file );
502 
503  file.close();
504 
505  // atomically replace the previous checkpoint file
506  if ( std::rename(filename_tmp.c_str(), filename.c_str()) ) {
507  std::cerr << "trouble renaming file " << filename_tmp << " to " << filename << std::endl;
508  return false;
509  }
510 
511  return true;
512 }
513 
516 {
517  if ( checkpoint_prefix_ == "" ) return "";
518  std::string filename(checkpoint_prefix_ + ".ga.generations");
519  filename += suffix;
520  if (checkpoint_gzip_) filename += ".gz";
521  return filename;
522 }
523 
524 
525 /// This seems to duplicate the functionality of the Entity's write_checkpoint function...
526 bool
528 {
529  if ( checkpoint_prefix_ == "" ) return false;
531  std::string const filename_tmp(generations_checkpoint_filename(".tmp"));
532  utility::io::ozstream file( filename_tmp.c_str() );
533  if ( !file ) {
534  std::cerr << "trouble opening file " << filename << " for writing" << std::endl;
535  return false;
536  }
537 
538  for (core::Size i = 1; i <= generations_.size(); ++i) {
539  utility::vector1< EntityOP > const & generation(generations_[i]);
540  if (generation.size()) {
541  file << "generation " << i << '\n';
542  for (core::Size j = 1; j <= generation.size(); ++j) {
543  EntityElements const & traits(generation[j]->traits());
544  for (core::Size k = 1; k <= traits.size(); ++k) {
545  if (k != 1) file << ' ';
546  file << traits[k]->to_string();
547  }
548  file << '\n';
549  }
550  }
551  }
552 
553  file.close();
554 
555  // atomically replace the previous checkpoint file
556  if ( std::rename(filename_tmp.c_str(), filename.c_str()) ) {
557  std::cerr << "trouble renaming file " << filename_tmp << " to " << filename << std::endl;
558  return false;
559  }
560 
561  return true;
562 }
563 
564 bool
566 {
567  if ( checkpoint_prefix_ == "" ) return false;
569  utility::io::izstream file( filename.c_str() );
570  if ( !file ) {
571  std::cerr << "trouble opening file " << filename << " for reading" << std::endl;
572  return false;
573  }
574 
575  core::Size gen_num = 0;
576  std::string line;
577  while (file.getline(line)) {
578  std::istringstream iss(line);
579  std::string word;
580  if (!(iss >> word)) return false;
581  if ( word == "generation" ) {
582  if (!(iss >> gen_num)) return false;
583  runtime_assert(gen_num > 0);
584  runtime_assert(gen_num <= max_generations_);
585  if (generations_.size() < gen_num) generations_.resize(gen_num);
586  current_generation_ = gen_num;
587  } else {
588  // make sure a generation number has been set
589  runtime_assert(gen_num > 0);
590  EntityElements traits;
591  traits.push_back( EntityElementFactory::get_instance()->element_from_string(word));
592  while (iss >> word) {
593  traits.push_back( EntityElementFactory::get_instance()->element_from_string(word));
594  }
595  add_entity(traits);
596  }
597  }
598 
599  file.close();
600 
601  return false;
602 }
603 
604 bool
606 {
607  // entities should be read first
608  if ( !read_entities_checkpoint() ) return false;
609  // only read generations if the entities were read successfully
610  return (read_generations_checkpoint());
611 }
612 
613 void
615 {
616  if ( checkpoint_prefix_ == "" ) return;
617 
618  std::string suffix(".old");
619 
621  std::rename( entities_checkpoint_filename().c_str(), entities_checkpoint_filename(suffix).c_str() );
622  }
623 
625  std::rename( generations_checkpoint_filename().c_str(), generations_checkpoint_filename(suffix).c_str() );
626  }
627 }
628 
631 
632 
635 {
636  if (entity_template_) {
637  return entity_template_->clone();
638  } else {
639  return new Entity;
640  }
641 }
642 
643 
644 } // namespace genetic_algorithm
645 } // namespace protocols
646