17 #include <boost/unordered_map.hpp>
24 #include <basic/Tracer.hh>
29 #include <utility/vector0.hh>
32 #include <utility/vector1.hh>
// File-local tracer constructed with the channel name "protocols.sparta";
// its Error stream is used below to report malformed weight files and
// inconsistent layer dimensions.
36 static basic::Tracer
tr(
"protocols.sparta");
// Function-like max/min helpers.
// NOTE(review): as macros, each argument may be evaluated twice, so never
// pass expressions with side effects; std::max/std::min avoid this.
44 #define MAX(x,y) ((x)>(y)?(x):(y))
45 #define MIN(x,y) ((x)<(y)?(x):(y))
// Hard-coded default layer sizes (I/H/O — presumably input/hidden/output):
// level-1 network 96/20/3, level-2 network 9/6/3.
// NOTE(review): the enclosing function header is not visible in this
// excerpt; presumably the default constructor.
49 N1_NODE_I = 96; N1_NODE_H = 20; N1_NODE_O = 3;
50 N2_NODE_I = 9; N2_NODE_H = 6; N2_NODE_O = 3;
// Constructor taking only the weight-file location.
// Uses the hard-coded layer sizes 96/20/3 (level 1) and 9/6/3 (level 2) and
// stores the weight-file name prefix.
// @param dPATH         directory of the weight tables (its handling is not
//                      visible in this excerpt)
// @param dNAME_PREFIX  file-name prefix stored in DB_NAME_PREFIX
58 ANN::ANN(
const string& dPATH,
const string& dNAME_PREFIX)
63 N1_NODE_I = 96; N1_NODE_H = 20; N1_NODE_O = 3;
64 N2_NODE_I = 9; N2_NODE_H = 6; N2_NODE_O = 3;
68 DB_NAME_PREFIX = dNAME_PREFIX;
// Constructor taking caller-supplied level-1 layer sizes.
// The level-2 sizes stay hard-coded at 9/6/3.
// NOTE(review): the lines that copy N1_nodeI/N1_nodeH/N1_nodeO into the
// N1_NODE_* members are not visible in this excerpt; confirm they exist.
74 ANN::ANN(
int N1_nodeI,
int N1_nodeH,
int N1_nodeO,
const string& dPATH,
const string& dNAME_PREFIX)
84 N2_NODE_I = 9; N2_NODE_H = 6; N2_NODE_O = 3;
87 DB_NAME_PREFIX = dNAME_PREFIX;
// Constructor taking every layer size explicitly for both networks.
// Sizes are given as (I, H, O) for level 1 and level 2, followed by the
// weight-file directory and name prefix.
// NOTE(review): only the N1_NODE_O assignment of the level-1 sizes is visible
// in this excerpt.
93 ANN::ANN(
int N1_nodeI,
int N1_nodeH,
int N1_nodeO,
int N2_nodeI,
int N2_nodeH,
int N2_nodeO,
const string& dPATH,
const string& dNAME_PREFIX)
100 N1_NODE_O = N1_nodeO;
103 N2_NODE_I = N2_nodeI; N2_NODE_H = N2_nodeH; N2_NODE_O = N2_nodeO;
106 DB_NAME_PREFIX = dNAME_PREFIX;
// (Re)initializes the layer sizes of both networks and the weight-file prefix.
// NOTE(review): this repeats the assignments done by the constructors; the
// constructors could delegate to init() so the two cannot drift apart.
112 void ANN::init(
int N1_nodeI,
int N1_nodeH,
int N1_nodeO,
int N2_nodeI,
int N2_nodeH,
int N2_nodeO,
const string& dPATH,
const string& dNAME_PREFIX)
117 N1_NODE_I = N1_nodeI;
118 N1_NODE_H = N1_nodeH;
119 N1_NODE_O = N1_nodeO;
122 N2_NODE_I = N2_nodeI; N2_NODE_H = N2_nodeH; N2_NODE_O = N2_nodeO;
125 DB_NAME_PREFIX = dNAME_PREFIX;
// Loads the level-1 weights from three table files named
// DB_PATH + slash_char + DB_NAME_PREFIX + ".level1.{WI,WL1,WL2}.tab".
// Each file supplies three weight/bias sets (suffix _1, _2, _3).
// The last three arguments are (rows, cols, unused):
//   WI  : N1_NODE_I x N1_NODE_I
//   WL1 : N1_NODE_H x N1_NODE_I
//   WL2 : N1_NODE_O x N1_NODE_H
// With these shapes, each layer's column count equals the previous layer's
// output length.
141 wName = DB_PATH+slash_char+DB_NAME_PREFIX+
".level1.WI.tab";
142 loadWeightBias3(wName, WI_1, BI_1, WI_2, BI_2, WI_3, BI_3, N1_NODE_I, N1_NODE_I, N1_NODE_I);
144 wName = DB_PATH+slash_char+DB_NAME_PREFIX+
".level1.WL1.tab";
145 loadWeightBias3(wName, WL1_1, BL1_1, WL1_2, BL1_2, WL1_3, BL1_3, N1_NODE_H, N1_NODE_I, N1_NODE_H);
147 wName = DB_PATH+slash_char+DB_NAME_PREFIX+
".level1.WL2.tab";
148 loadWeightBias3(wName, WL2_1, BL2_1, WL2_2, BL2_2, WL2_3, BL2_3, N1_NODE_O, N1_NODE_H, N1_NODE_O);
// Trailing parameters: N_W_row = rows per matrix, N_W_col = weights per row;
// the final int is unnamed and unused.
168 int N_W_row,
int N_W_col,
int )
// Each table entry is routed to one of three matrices by index/row:
// entries [0,row) fill W1/B1, [row,2*row) fill W2/B2, and [2*row,3*row)
// fill W3/B3. Later entries log "Wrong size for matrix".
// NOTE(review): the declarations of `index` and `str`, and the increment of
// `index`, are not visible in this excerpt.
176 int row = N_W_row, col = N_W_col;
177 for ( it = W_Tab.
Entries.begin(); it != W_Tab.
Entries.end(); it++ )
179 int check = index/row;
180 for(
int i = 0; i < col; i++ )
// atof() yields 0 for unparsable text, so malformed cells load silently as 0.
// W1/W2/W3 are subscripted by row, which assumes they were pre-sized to at
// least `row` rows (sizing not visible here).
// The error branch appears to sit inside the per-column loop, so it would be
// logged once per column of an out-of-range row.
183 float w = atof((it->second[str]).c_str());
184 if( check == 0) W1[ index ].push_back( w );
185 else if( check == 1) W2[ index-row ].push_back( w );
186 else if( check == 2) W3[ index-row*2 ].push_back( w );
187 else tr.Error <<
"Wrong size for matrix " << fName <<
" ... \n";
// One bias per entry, read from the column keyed "b".
190 if( check == 0) B1.push_back( atof((it->second[
"b"]).c_str()) );
191 else if( check == 1) B2.push_back( atof((it->second[
"b"]).c_str()) );
192 else if( check == 2) B3.push_back( atof((it->second[
"b"]).c_str()) );
// Level-1 prediction. For every keyed input vector in ANN_IN_MTX_LEVEL1:
// - run it through three independent networks (weight sets _1, _2, _3), each
//   with two tanh layers (code 1) and one linear output layer (code 0);
// - average the three outputs element-wise;
// - store the result in ANN_OUT_MTX_LEVEL1 under the same key.
// NOTE(review): applyANNTransformation/applyVecAverage append to their output
// vectors. Clearing IL*/HL*/OL*/OUT1 between iterations is not visible in this
// excerpt; confirm it happens.
207 boost::unordered_map<int, utility::vector0<float> >::iterator itV;
208 for ( itV = ANN_IN_MTX_LEVEL1.begin(); itV != ANN_IN_MTX_LEVEL1.end(); itV++ )
// first layer (WI_*), tanh
214 applyANNTransformation(itV->second, WI_1, BI_1, IL1, 1);
215 applyANNTransformation(itV->second, WI_2, BI_2, IL2, 1);
216 applyANNTransformation(itV->second, WI_3, BI_3, IL3, 1);
// second layer (WL1_*), tanh
220 applyANNTransformation(IL1, WL1_1, BL1_1, HL1, 1);
221 applyANNTransformation(IL2, WL1_2, BL1_2, HL2, 1);
222 applyANNTransformation(IL3, WL1_3, BL1_3, HL3, 1);
// output layer (WL2_*), linear
226 applyANNTransformation(HL1, WL2_1, BL2_1, OL1, 0);
227 applyANNTransformation(HL2, WL2_2, BL2_2, OL2, 0);
228 applyANNTransformation(HL3, WL2_3, BL2_3, OL3, 0);
231 applyVecAverage(OL1,OL2,OL3,OUT1);
233 ANN_OUT_MTX_LEVEL1[itV->first] = OUT1;
// Level-2 prediction. This mirrors the level-1 loop, using the W2*/B2* weight
// sets: for every keyed vector in ANN_IN_MTX_LEVEL2, three tanh/tanh/linear
// networks run, their outputs are averaged, and the result goes to
// ANN_OUT_MTX_LEVEL2.
// NOTE(review): as in level 1, clearing of the intermediate vectors between
// iterations is not visible in this excerpt.
244 boost::unordered_map<int, utility::vector0<float> >::iterator itV;
245 for ( itV = ANN_IN_MTX_LEVEL2.begin(); itV != ANN_IN_MTX_LEVEL2.end(); itV++ )
// first layer (W2I_*), tanh
251 applyANNTransformation(itV->second, W2I_1, B2I_1, IL1, 1);
252 applyANNTransformation(itV->second, W2I_2, B2I_2, IL2, 1);
253 applyANNTransformation(itV->second, W2I_3, B2I_3, IL3, 1);
// second layer (W2L1_*), tanh
257 applyANNTransformation(IL1, W2L1_1, B2L1_1, HL1, 1);
258 applyANNTransformation(IL2, W2L1_2, B2L1_2, HL2, 1);
259 applyANNTransformation(IL3, W2L1_3, B2L1_3, HL3, 1);
// output layer (W2L2_*), linear
263 applyANNTransformation(HL1, W2L2_1, B2L2_1, OL1, 0);
264 applyANNTransformation(HL2, W2L2_2, B2L2_2, OL2, 0);
265 applyANNTransformation(HL3, W2L2_3, B2L2_3, OL3, 0);
268 applyVecAverage(OL1,OL2,OL3,OUT2);
270 ANN_OUT_MTX_LEVEL2[itV->first] = OUT2;
// Replaces the level-1 input matrix with inMatrix (assignment; the enclosing
// function header is not visible in this excerpt).
279 ANN_IN_MTX_LEVEL1 = inMatrix;
// One dense layer: for each weight row i, sum = sum_j inp[j]*w[i][j], then
// append the activated value to `out`.
// - code 1: 2/(1+exp(-2*sum)) - 1, which is mathematically tanh(sum).
// - code 0: identity.
// - any other code appends nothing.
// The lines that reset `sum` and add the bias b[i] are not visible in this
// excerpt.
// NOTE(review): w[0] is read before checking that w is non-empty, which is
// undefined behavior for an empty weight matrix. The early exit after the
// error message is also not visible here.
293 if( inp.size() != w[0].size() || w.size() != b.size() ) {
294 tr.Error <<
" ANN prediction failed with inconsistent data!" << endl;
298 for(
Size i = 0; i < w.size(); i++ ) {
300 for(
Size j = 0; j < inp.size(); j++ ) sum += inp[j]*w[i][j];
303 if( code == 1 ) out.push_back( 2.0/(1.0+exp(-2.0*sum))-1.0 );
304 else if( code == 0 ) out.push_back( sum );
// When the three vectors have equal length, appends their element-wise mean
// (v1[i]+v2[i]+v3[i])/3 to vout. The mismatched-length branch is not visible
// in this excerpt.
313 if( v1.size() == v2.size() && v1.size() == v3.size() ) {
319 for(
Size i=0; i<v1.size(); i++) {
322 vout.push_back( (v1[i]+v2[i]+v3[i])/3.0 );
// Clamps the first three entries of v into [0, 1], divides each by `sum`, and
// appends them back to v.
// NOTE(review): the computation of `sum` (presumably a+b+c) and the clearing
// of v before push_back are not visible in this excerpt. If all three values
// clamp to 0, `sum` could be 0 and the division would produce inf/NaN.
334 float a=v[0],
b=v[1],
c=v[2];
335 if(a>1) a=1.0;
else if(a<0) a=0.0;
336 if(b>1) b=1.0;
else if(b<0) b=0.0;
337 if(c>1) c=1.0;
else if(c<0) c=0.0;
340 a/=sum; b/=sum; c/=sum;
342 v.push_back(a); v.push_back(b); v.push_back(c);
// Returns -1.0 unless v has exactly three entries.
// Otherwise it returns 2*max - (v0+v1+v2) + min. Since v0+v1+v2 equals
// max + median + min, this simplifies to (max - median): the gap between the
// largest and second-largest value.
351 if( v.size() != 3 )
return -1.0;
353 return 2.0*
MAX(v[0],
MAX(v[1],v[2])) - (v[0]+v[1]+v[2]) +
MIN(v[0],
MIN(v[1],v[2]));
// Counts how many of the odd-index entries v[1], v[3], ..., v[11] equal 1.
// Requires v to hold at least 12 elements.
362 cnt+=(v[1]==1); cnt+=(v[3]==1); cnt+=(v[5]==1); cnt+=(v[7]==1); cnt+=(v[9]==1); cnt+=(v[11]==1);
// Writes the decimal text of n into the caller-provided buffer.
// NOTE(review): sprintf does not bound the write. The buffer must hold the
// longest int plus the terminator (12 chars for 32-bit int). snprintf with the
// buffer size would be safer if the size were passed in.
371 sprintf(buff,
"%d", n);
// Formats float n into buff using printf conversion `f` (one of f/F/e/E/g/G)
// with `prec` digits of precision.
// The format string is assembled through `fs`, presumably pointing into the
// local `format` buffer (declarations not visible in this excerpt).
378 char *
ANN::ftoa(
float n,
char *buff,
char f,
int prec )
// Reject unsupported conversion characters; the handling of that case is not
// visible in this excerpt.
380 if ( !(f==
'f' || f==
'F' || f==
'e' || f==
'E' || f==
'g' || f==
'G') ) {
// Precision is encoded as exactly two ASCII digits.
// NOTE(review): unless prec is clamped to [0, 99] in lines not shown, values
// >= 100 or < 0 produce non-digit characters here.
391 *fs++ = prec / 10 +
'0';
392 *fs++ = prec % 10 +
'0';
// NOTE(review): the write into buff is unbounded; the caller must size it for
// the widest possible output.
400 sprintf( buff, format, n );
// Guesses the filesystem path separator from the PATH environment variable:
// - "/" if PATH contains a forward slash;
// - otherwise a single backslash ("\\" literal) if PATH contains one;
// - otherwise falls back to "/".
// NOTE(review): the brace at original line 411 is not visible. Confirm whether
// the final "/" fallback applies when PATH is unset or when it contains
// neither separator. getenv("PATH") is also called twice and could be cached.
407 if( getenv(
"PATH" ) != NULL) {
408 string temp = getenv(
"PATH" );
409 if(temp.find(
"/") != string::npos ) slash_char =
"/";
410 else if(temp.find(
"\\") != string::npos ) slash_char =
"\\";
412 else slash_char =
"/";