Rosetta 3.5
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
trie_vs_trie.hh
Go to the documentation of this file.
1 // -*- mode:c++;tab-width:2;indent-tabs-mode:t;show-trailing-whitespace:t;rm-trailing-spaces:t -*-
2 // vi: set ts=2 noet:
3 //
4 // (c) Copyright Rosetta Commons Member Institutions.
5 // (c) This file is part of the Rosetta software suite and is made available under license.
6 // (c) The Rosetta software is developed by the contributing members of the Rosetta Commons.
7 // (c) For more information, see http://www.rosettacommons.org. Questions about this can be
8 // (c) addressed to University of Washington UW TechTransfer, email: license@u.washington.edu.
9 
10 /// @file core/scoring/trie/trie_vs_trie.hh
11 /// @brief Trie-vs-trie algorithm for computing rotamer-pair energies between two RotamerTries.
12 /// @author Andrew Leaver-Fay (aleaverfay@gmail.com)
13 
14 #ifndef INCLUDED_core_scoring_trie_trie_vs_trie_hh
15 #define INCLUDED_core_scoring_trie_trie_vs_trie_hh
16 
17 // Unit Headers
19 
20 // Package Headers
21 // AUTO-REMOVED #include <core/scoring/trie/RotamerTrie.hh>
22 
23 // Project Headers
24 #include <core/types.hh>
25 
26 
27 // STL Headers
28 // AUTO-REMOVED #include <iostream>
29 
30 // ObjexxFCL Headers
31 // AUTO-REMOVED #include <ObjexxFCL/FArray1D.hh>
32 // AUTO-REMOVED #include <ObjexxFCL/FArray1A.hh>
33 #include <ObjexxFCL/FArray2D.hh>
34 #include <ObjexxFCL/FArray2A.hh>
35 
36 // Utility Headers
37 // AUTO-REMOVED #include <utility/vector1.hh>
38 
39 #include <utility/vector1_bool.hh>
40 
41 
42 namespace core {
43 namespace scoring {
44 namespace trie {
45 
46 //typedef core::PackerEnergy core::PackerEnergy;
47 
48 /// @brief trie vs trie algorithm, templated on the Atom type that each TrieAtom is templated on,
49 /// along with the Count Pair data (two tries must share the same Atom type to be used for the
50 /// trie-vs-trie algorithm, but may contain different pieces of count-pair data), on the kind of
51 /// count pair function used, and finally, on the score function itself.
52 ///
/// @details Fills rot_rot_table with accumulated weight * energy terms between every pair of
/// unique rotamers. rot_rot_table is a proxy over temp_table, dimensioned
/// (trie2.num_unique_rotamers(), trie1.num_unique_rotamers()).
/// For each trie1 heavy atom, whole trie2 subtrees may be skipped. The skipped rotamers are then
/// credited with the energy accumulated on the trie2 path so far, rather than being scored atom
/// by atom.
///
/// @param trie1  Its unique rotamers index the second dimension (columns) of rot_rot_table.
/// @param trie2  Its unique rotamers index the first dimension (rows) of rot_rot_table.
/// @param count_pair  Called as count_pair( r.cp_data(), s.cp_data(), weight, path_dist ).
///        An atom pair is scored only when the call returns true. It presumably writes weight
///        and path_dist through its reference parameters — TODO confirm.
/// @param score_function  Must provide hydrogen_interaction_cutoff2() and the four
///        *_energy methods used below.
/// @param pair_energy_table  Passed, together with rot_rot_table, to a call on doxygen lines
///        352-354. Those lines are missing from this listing, so how the table is updated
///        cannot be determined here.
/// @param temp_table  Scratch storage wrapped by rot_rot_table.
///        NOTE(review): presumably must hold at least
///        trie2.num_unique_rotamers() * trie1.num_unique_rotamers() values — confirm.
///
/// NOTE(review): the function-name line (doxygen line 55) is missing from this listing.
53 template < class AT, class CPDAT1, class CPDAT2, class CPFXN, class SFXN >
54 void
56  RotamerTrie< AT, CPDAT1 > const & trie1,
57  RotamerTrie< AT, CPDAT2 > const & trie2,
58  CPFXN & count_pair,
59  SFXN & score_function,
60  ObjexxFCL::FArray2D< core::PackerEnergy > & pair_energy_table,
61  ObjexxFCL::FArray2D< core::PackerEnergy > & temp_table
62 )
63 {
64  /*
65  std::cout << "I made it" << std::endl;
66  score_function.print();
67  trie1.print();
68  trie2.print();
69  CPFXN::print();
70  */
71 
72  using namespace ObjexxFCL;
73 
// View temp_table as a (trie2 unique rotamers) x (trie1 unique rotamers) table.
74  ObjexxFCL::FArray2A< core::PackerEnergy > rot_rot_table(
75  temp_table,
76  trie2.num_unique_rotamers(),
77  trie1.num_unique_rotamers() );
78 
// NOTE(review): no visible line reads hydrogen_interaction_cutoff.
// It is presumably consumed on the missing doxygen line 302.
79  DistanceSquared const hydrogen_interaction_cutoff = score_function.hydrogen_interaction_cutoff2();
80 
81  typename utility::vector1< TrieNode < AT, CPDAT1 > > const & trie1_atoms = trie1.atoms();
82  Size const trie1_natoms = trie1.atoms().size();
83  //Size const trie1_num_unique_rotamers = trie1.num_unique_rotamers();
84 
85  typename utility::vector1 < TrieNode < AT, CPDAT2 > > const & trie2_atoms = trie2.atoms();
86  Size const trie2_natoms = trie2.atoms().size();
87 
88  Size const trie2_num_heavyatoms = trie2.num_heavy_atoms();
89  Size const trie2_num_unique_rotamers = trie2.num_unique_rotamers();
90 
91  Size const trie1_max_branch_depth = trie1.max_branch_depth();
92  Size const trie1_max_heavyatom_depth = trie1.max_heavyatom_depth();
93 
// at_v_rot_stack( jj, d ): energy accumulated so far between the trie1 atoms on the current
// path (stack level d) and trie2 unique rotamer jj.
94  FArray2D< core::PackerEnergy > at_v_rot_stack( trie2_num_unique_rotamers, trie1_max_branch_depth, 0.0);
95 
// Per-(trie2 heavy atom, trie1 heavy-atom depth) flags, stored as ints. Both are written while
// a trie1 heavy atom is processed. They are read back for trie1 hydrogens, which index them by
// the heavy-atom depth of their path (hydrogens do not increment r_heavy_depth_stack).
96  FArray2D_int parent_heavy_wi_hydrogen_cutoff(trie2_num_heavyatoms, trie1_max_heavyatom_depth, true);
97  FArray2D_int r_heavy_skip_s_subtree(trie2_num_heavyatoms, trie1_max_heavyatom_depth, false);
98 
// energy_stack[ d ]: energy between the current trie1 atom r and the trie2 atoms on the
// current trie2 path at stack level d. Level 1 is reset to 0 for every r.
99  utility::vector1< core::PackerEnergy > energy_stack( std::max( trie1.max_atom_depth() + 1, trie2.max_atom_depth() + 1 ));
100  energy_stack[1] = 0;
101  //energy_stack[0] = 0;
102 
103  utility::vector1< Size > r_heavy_depth_stack( trie1.max_branch_depth() + 1 );
104  utility::vector1< Size > s_heavy_depth_stack( trie2.max_branch_depth() + 1 );
105 
106  //utility::vector1< bool > parent_heavy_wi_hcut_stack( trie2.max_heavyatom_depth() );
// Entry k holds the hydrogen-cutoff flag of the heavy atom at heavy depth k on the current
// trie2 path. Writes always happen after s_heavy_depth_stack is incremented, so entry 0 is
// never written.
// NOTE(review): a trie2 hydrogen reached before any heavy atom on its path would read the
// uninitialized entry 0. Presumably tries never start with a hydrogen — TODO confirm.
// NOTE(review): this raw new[] is released only by the delete[] at the end of the function,
// so it leaks if anything in between throws. A std::vector<int> would free itself and still
// avoid the vector1<bool> lookup cost mentioned below.
107  int * parent_heavy_wi_hcut_stack = new int[ trie2.max_heavyatom_depth() + 2 ]; // 17% of time spent looking up utility::vector1<bool>
108 
// s_sibling_stack[ d ]: trie2 index of the next sibling branch at stack level d.
// Level 1 holds trie2_natoms + 1, so a "jump to sibling" at the bottom ends the trie2 loop.
109  utility::vector1< Size > s_sibling_stack( trie2.max_atom_depth() + 1 );
110  s_sibling_stack[1] = trie2_natoms + 1; // out-of-range sibling
111  for ( Size ii = 2; ii <= trie2.max_atom_depth(); ++ii ) {
112  s_sibling_stack[ii] = 0;
113  }
114 
115  Size r_curr_stack_top = 2; //immediately decremented to 1
116  Size s_curr_stack_top;
117  Size r_rotamers_seen = 0;
118  Size s_rotamers_seen;
119  Size s_heavyatoms_seen;
120 
121  r_heavy_depth_stack[1] = 0;
122 
// Depth-first walk over the trie1 atoms in storage order. Starting a branch pops one stack
// level. An atom that has a sibling pushes a copy of the current level, so the sibling branch
// can later resume from the state shared with this atom.
123  for ( Size ii = 1; ii <= trie1_natoms; ++ii )
124  {
125  //std::cout << "tvt with ii = " << ii << std::endl;
126  /*typename*/ TrieNode< AT, CPDAT1 > const & r = trie1_atoms[ ii ];
127  energy_stack[1] = 0;
128 
129  s_rotamers_seen = 0;
130  s_heavyatoms_seen = 0;
131 
132  if ( r.first_atom_in_branch() ) --r_curr_stack_top;
133 
134  if (r.has_sibling() ) { //push - copy stack downwards
135 
136  ++r_curr_stack_top;
137  //std::cout << "r.has_sibling(): new stack top: " << r_curr_stack_top << std::endl;
138 
139  // when I previously dimensioned by trie1_num_unique_rotamers (a bug) I didn't get an assertion failure here!
140  // wtf?!
141  //FArray1A< core::PackerEnergy > at_rot_array_d_proxy(at_v_rot_stack(1, r_curr_stack_top), trie2_num_unique_rotamers);
142  //FArray1A< core::PackerEnergy > at_rot_array_dminus1_proxy(at_v_rot_stack(1, r_curr_stack_top-1), trie2_num_unique_rotamers);
143  //at_rot_array_d_proxy = at_rot_array_dminus1_proxy;
// Copy column ( r_curr_stack_top - 1 ) into column r_curr_stack_top using linear indices.
144  for ( Size jj = 1, litop = at_v_rot_stack.index( 1, r_curr_stack_top ),
145  liprev = at_v_rot_stack.index( 1, r_curr_stack_top - 1 );
146  jj <= trie2_num_unique_rotamers; ++jj, ++litop, ++liprev ) {
147  at_v_rot_stack[ litop ] = at_v_rot_stack[ liprev ];
148  }
149 
150  r_heavy_depth_stack[ r_curr_stack_top ] = r_heavy_depth_stack[ r_curr_stack_top-1 ];
151  //std::cout << "r_heavy_depth_stack[" << r_curr_stack_top << "] = " << r_heavy_depth_stack[ r_curr_stack_top ] << std::endl;
152  }
153 
154  //++r_tree_depth_stack[ r_curr_stack_top ];
155 
156  if (! r.is_hydrogen() ) ++r_heavy_depth_stack[ r_curr_stack_top ];
157 
158 
159  //FArray1A< core::PackerEnergy > at_rot_array_proxy(at_v_rot_stack(1, r_curr_stack_top), trie2_num_unique_rotamers);
160 
161  //FArray1A_int parent_wi_h_dist( parent_heavy_wi_hydrogen_cutoff( 1,
162  // r_heavy_depth_stack[ r_curr_stack_top ] ), trie2_num_heavyatoms );
163 
164  //FArray1A_int r_heavy_skip_s_subtree_proxy
165  // ( r_heavy_skip_s_subtree(1, r_heavy_depth_stack[ r_curr_stack_top ] ), trie2_num_heavyatoms);
166 
// r is a hydrogen. Score it against trie2 using the flags left by the most recent heavy atom
// on r's path. An (r, s) pair is scored only when two conditions hold: the flag of s's heavy
// atom (s itself, or its nearest heavy ancestor when s is a hydrogen) is set, and count_pair
// accepts the pair. Subtrees flagged in r_heavy_skip_s_subtree are skipped.
167  if (r.is_hydrogen()) {
168  s_curr_stack_top = 2;
169  s_heavy_depth_stack[1] = 0;
170  //s_tree_depth_stack[1] = 0;
171 
172  for ( Size jj = 1; jj <= trie2_natoms; ++jj ) {
173  //std::cerr << " ii: " << ii << " is H, jj : " << jj << " " << s_curr_stack_top << " ";
174  //std::cerr << "s_sibling_stack [";
175  //for ( Size kk = 1; kk <= s_curr_stack_top; ++kk )
176  // std::cerr << s_sibling_stack[kk] << " ";
177  //std::cerr << " " << std::endl;
178 
179 
180  TrieNode< AT, CPDAT2 > const & s = trie2_atoms[ jj ];
181 
182  if ( s.first_atom_in_branch() )
183  --s_curr_stack_top;
184 
185  if (s.has_sibling() ) { //push - copy stack downwards
186  ++s_curr_stack_top;
187  energy_stack[s_curr_stack_top] = energy_stack[s_curr_stack_top - 1];
188  s_heavy_depth_stack[ s_curr_stack_top] = s_heavy_depth_stack[ s_curr_stack_top - 1];
189  //s_tree_depth_stack[ s_curr_stack_top]
190  // = s_tree_depth_stack[ s_curr_stack_top - 1];
191  s_sibling_stack[s_curr_stack_top] = s.sibling();
192  }
193  //++s_tree_depth_stack[ s_curr_stack_top];
194 
195  if (!s.is_hydrogen()) {
196  ++s_heavyatoms_seen;
197  ++s_heavy_depth_stack[s_curr_stack_top];
198  parent_heavy_wi_hcut_stack[s_heavy_depth_stack[s_curr_stack_top]]
199  = parent_heavy_wi_hydrogen_cutoff(s_heavyatoms_seen, r_heavy_depth_stack[ r_curr_stack_top ] );
200 
201  if ( r_heavy_skip_s_subtree(s_heavyatoms_seen, r_heavy_depth_stack[ r_curr_stack_top ]) ) {
// Skip s's subtree. Credit the current path energy to every rotamer in it, then resume at
// s's sibling (the loop's ++jj lands on s_sibling_stack[ top ]).
202  if (energy_stack[s_curr_stack_top] != 0.0f ) {
203  for ( Size kk = s_rotamers_seen + 1;
204  kk <= s_rotamers_seen + s.num_rotamers_in_subtree();
205  ++kk ) {
206  at_v_rot_stack(kk, r_curr_stack_top) += energy_stack[s_curr_stack_top];
207  }
208  }
209  s_rotamers_seen += s.num_rotamers_in_subtree();
210  jj = s_sibling_stack[s_curr_stack_top] - 1;
211  continue;
212  }
213  }
214 
215  Real weight(1.0); core::PackerEnergy e(0.0); Size path_dist(0);
216  if ( parent_heavy_wi_hcut_stack[s_heavy_depth_stack[s_curr_stack_top] ] &&
217  count_pair( r.cp_data(), s.cp_data(), weight, path_dist) ) {
218  if ( s.is_hydrogen() ) {
219  e = score_function.hydrogenatom_hydrogenatom_energy(r.atom(), s.atom(), path_dist );
220  energy_stack[ s_curr_stack_top ] += weight * e;
221  //std::cout << "h/h atom pair energy: " << ii << " & " << jj << " = " << weight * e << "( unweighted: " << e << ") estack: " << energy_stack[ s_curr_stack_top ] << std::endl;
222 
223  } else {
224  e = score_function.hydrogenatom_heavyatom_energy( r.atom(), s.atom(), path_dist );
225  energy_stack[ s_curr_stack_top ] += weight * e;
226  //std::cout << "h/hv atom pair energy: " << ii << " & " << jj << " = " << weight * e << "( unweighted: " << e << ") estack: " << energy_stack[ s_curr_stack_top ] << std::endl;
227  }
228  }
229 
// s ends a trie2 rotamer: credit the path energy to that rotamer.
230  if (s.is_rotamer_terminal() ) {
231  ++s_rotamers_seen;
232  at_v_rot_stack(s_rotamers_seen, r_curr_stack_top) += energy_stack[ s_curr_stack_top ];
233  }
234  }
// r is a heavy atom. Clear the skip flags for this heavy depth, then score r against trie2.
// For each trie2 heavy atom, record whether later trie1 hydrogens should consider it
// (parent_heavy_wi_hydrogen_cutoff) and whether its subtree may be skipped
// (r_heavy_skip_s_subtree).
235  } else { // r is a heavy atom
236  s_curr_stack_top = 2;
237  s_heavy_depth_stack[1] = 0;
238  //s_tree_depth_stack[1] = 0;
239 
240  for ( Size jj = 1,
241  liskip = r_heavy_skip_s_subtree.index( 1, r_heavy_depth_stack[ r_curr_stack_top ] );
242  jj <= trie2_num_heavyatoms; ++jj, ++liskip ) {
243  r_heavy_skip_s_subtree[ liskip ] = false;
244  }
245 
246  for ( Size jj = 1; jj <= trie2_natoms; ++jj ) {
247  //std::cerr << "! ii: " << ii << " is H, jj : " << jj << " " << s_curr_stack_top <<
248  // " " << s_rotamers_seen << " " << s_heavyatoms_seen << std::endl;
249  //trie_node & s = rt_trie[jj];
250  TrieNode< AT, CPDAT2 > const & s = trie2_atoms[ jj ];
251 
252  //std::cerr << "s_sibling_stack [";
253  //for ( Size kk = 1; kk <= s_curr_stack_top; ++kk )
254  // std::cerr << s_sibling_stack[kk] << " ";
255  //std::cerr << " " << std::endl;
256 
257  if ( s.first_atom_in_branch() ) --s_curr_stack_top;
258 
259  if (s.has_sibling() ) { //push - copy stack downwards
260  ++s_curr_stack_top;
261  energy_stack[s_curr_stack_top] = energy_stack[s_curr_stack_top - 1];
262  s_heavy_depth_stack[ s_curr_stack_top] = s_heavy_depth_stack[ s_curr_stack_top - 1];
263  //s_tree_depth_stack[ s_curr_stack_top]
264  // = s_tree_depth_stack[ s_curr_stack_top - 1];
265  s_sibling_stack[s_curr_stack_top] = s.sibling();
266  }
267  //++s_tree_depth_stack[ s_curr_stack_top];
268 
269  if (s.is_hydrogen()) {
270  Real weight(1.0); Size path_dist(0);
271  if (parent_heavy_wi_hcut_stack[s_heavy_depth_stack[s_curr_stack_top]] &&
272  count_pair( r.cp_data(), s.cp_data(), weight, path_dist )) {
273  core::PackerEnergy e = score_function.heavyatom_hydrogenatom_energy( r.atom(), s.atom(), path_dist );
274  energy_stack[ s_curr_stack_top ] += weight * e;
275  //std::cout << "hv/h atom pair energy: " << ii << " & " << jj << " = " << weight * e << "( unweighted: " << e << ") estack: " << energy_stack[ s_curr_stack_top ] << std::endl;
276  /*,trie1.res_id(), trie2.res_id(),
277  tri1.num_neighbors(), trie2.num_neighbors(),
278  Whbond);*/
279  }
280 
281  } else { // s is a heavy atom
282  ++s_heavyatoms_seen;
283  ++s_heavy_depth_stack[s_curr_stack_top];
284 
285  DistanceSquared d2(0.0);
286  core::PackerEnergy e = 0;
287  Real weight = 1;
288  Size path_dist(0);
289 
// d2 is set explicitly in the else-branch. When count_pair accepts the pair,
// heavyatom_heavyatom_energy presumably writes it through its d2 argument — TODO confirm.
290  if ( count_pair( r.cp_data(), s.cp_data(), weight, path_dist) ) {
291  e = score_function.heavyatom_heavyatom_energy(r.atom(), s.atom(), d2, path_dist);
292  //std::cout << "hv/hv atom pair energy: " << ii << " & " << jj << " = " << weight * e << "( unweighted: " << e << ") estack: " << energy_stack[ s_curr_stack_top ] << std::endl;
293  energy_stack[s_curr_stack_top] += weight * e;
294  } else {
295  /// compute d2
296  d2 = r.atom().xyz().distance_squared( s.atom().xyz() );
297  }
298 
// NOTE(review): doxygen lines 302 and 304 are missing from this listing.
// - The assignment chain that starts here evidently ends on line 302; as displayed, it would
//   run into line 305.
// - Line 304 apparently opens the `if` whose body is lines 305-317 and which line 318 closes.
// - d2 is presumably compared against a cutoff in those lines — TODO confirm against the real
//   source.
299  parent_heavy_wi_hcut_stack[s_heavy_depth_stack
300  [s_curr_stack_top]] =
301  parent_heavy_wi_hydrogen_cutoff(s_heavyatoms_seen, r_heavy_depth_stack[ r_curr_stack_top ]) =
303 
// Prune: mark s's subtree as skippable for this trie1 heavy depth, credit the path energy to
// every rotamer in the subtree, then jump to s's sibling.
305  r_heavy_skip_s_subtree(s_heavyatoms_seen, r_heavy_depth_stack[ r_curr_stack_top ]) = true;
306  if (energy_stack[s_curr_stack_top] != 0. ) {
307  for ( Size kk = s_rotamers_seen + 1,
308  li_avrstack = at_v_rot_stack.index( kk, r_curr_stack_top );
309  kk <= s_rotamers_seen + s.num_rotamers_in_subtree();
310  ++kk, ++li_avrstack ) {
311  at_v_rot_stack[ li_avrstack ] += energy_stack[s_curr_stack_top];
312  }
313  }
314  //std::cout << "heavy atom / subtree prune: " << ii << " " << jj << " : jumping to sibling " << s_sibling_stack[s_curr_stack_top] << std::endl;
315  jj = s_sibling_stack[s_curr_stack_top] - 1;
316  s_rotamers_seen += s.num_rotamers_in_subtree();
317  continue;
318  }
319  }
320  if (s.is_rotamer_terminal() ) {
321  ++s_rotamers_seen;
322  //std::cout << "s.is_rotamer_terminal " << s_rotamers_seen << " with energy: " << energy_stack[ s_curr_stack_top ] << std::endl;
323  at_v_rot_stack(s_rotamers_seen, r_curr_stack_top) += energy_stack[s_curr_stack_top];
324  }
325  }
326  }
// r ends a trie1 rotamer: copy its column of accumulated energies into rot_rot_table
// column r_rotamers_seen.
327  if (r.is_rotamer_terminal()) {
328  ++r_rotamers_seen;
329  //std::cout << "terminal #" << r_rotamers_seen << ": writing at_rot_array_proxy: ";
330  //for ( Size jj = 1; jj <= trie2_num_unique_rotamers; ++jj ) {
331  // std::cout << " " << at_rot_array_proxy( jj );
332  //}
333  //std::cout << std::endl;
334  //FArray1A< core::PackerEnergy > rot_rot_table_row( rot_rot_table(1, r_rotamers_seen), trie2_num_unique_rotamers);
335  //rot_rot_table_row.dimension(trie2_num_unique_rotamers);
336  //rot_rot_table_row = at_rot_array_proxy;
337  for ( Size jj = 1; jj <= trie2_num_unique_rotamers; ++jj ) {
338  rot_rot_table( jj, r_rotamers_seen ) = at_v_rot_stack( jj, r_curr_stack_top );
339  }
340 
341 
342  //std::cout << "rot rot table, column " << r_rotamers_seen << std::endl;
343  //for ( Size jj = 1; jj <= trie2_num_unique_rotamers; ++jj ) {
344  // std::cout << " " << rot_rot_table( jj, r_rotamers_seen );
345  //}
346  //std::cout << std::endl;
347 
348  }
349 
350 
351  }
// NOTE(review): doxygen lines 352-354 are missing from this listing. The two arguments below
// are the tail of a call that receives pair_energy_table and rot_rot_table.
355  pair_energy_table,
356  rot_rot_table );
357  //std::cout << "complete trie-vs-trie" << std::endl;
358 
// Release the manually allocated heavy-depth flag stack. This is skipped if an exception
// propagates earlier; see the NOTE at its allocation.
359  delete [] parent_heavy_wi_hcut_stack;
360 }
361 
362 
363 } // namespace trie
364 } // namespace scoring
365 } // namespace core
366 
367 
368 #endif