00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00027 #error "This include file is not yet finished!"
00028
00029 #ifndef _HESS_EVALUATOR_H_
00030 #define _HESS_EVALUATOR_H_
00031
00032 #include <coconut_config.h>
00033 #include <evaluator.h>
00034 #include <expression.h>
00035 #include <model.h>
00036 #include <eval_main.h>
00037 #include <linalg.h>
00038 #include <math.h>
00039
00040 using namespace vgtl;
00041
00042 typedef bool (*prep_h_evaluator)();
00043 typedef double (*func_d_evaluator)(const std::vector<double>* __x,
00044 const variable_indicator& __v,
00045 std::vector<double>& __d_data);
00046 typedef std::vector<double>& (*der_evaluator)(const std::vector<double>& __d_dat,
00047 const variable_indicator& __v);
00048
00049 struct prep_h_eval_tp
00050 {
00051 std::vector<std::vector<double> >* d;
00052 std::vector<matrix<double> >* h;
00053 }
00054
00055 class prep_h_eval : public
00056 cached_forward_evaluator_base<prep_h_eval_tp, expression_node, bool, model::walker>
00057 {
00058 private:
00059 typedef cached_forward_evaluator_base<prep_h_eval_tp,
00060 expression_node,bool,model::walker> _Base;
00061
00062 public:
00063 prep_h_eval(std::vector<std::vector<double> >& __d,
00064 std::vector<matrix<double> >& __h,
00065 unsigned int _num_of_nodes)
00066 {
00067 eval_data.d = &__d;
00068 eval_data.h = &__h;
00069 if((*eval_data.d).size() < _num_of_nodes)
00070 (*eval_data.d).insert((*eval_data.d).end(),
00071 _num_of_nodes-(*eval_data.d).size(), std::vector<double>());
00072 if((*eval_data.h).size() < _num_of_nodes)
00073 (*eval_data.h).insert((*eval_data.h).end(),
00074 _num_of_nodes-(*eval_data.h).size(), matrix<double>());
00075 }
00076
00077 prep_h_eval(const prep_h_eval& __x) { eval_data.d = __x.eval_data.d;
00078 eval_data.h = __x.eval_data.h; }
00079
00080 ~prep_h_eval() {}
00081
00082 void initialize() { return; }
00083
00084 bool is_cached(const expression_node& __data)
00085 {
00086 return (*eval_data.d)[__data.node_num].size() > 0;
00087 }
00088
00089 void retrieve_from_cache(const expression_node& __data) { return; }
00090
00091 int initialize(const expression_node& __data)
00092 {
00093 (*eval_data.d)[__data.node_num].insert(
00094 (*eval_data.d)[__data.node_num].end(), __data.n_children, 0);
00095 (*eval_data.h)[__data.node_num] =
00096 matrix<double>(__data.n_children, __data.n_children);
00097 return 1;
00098 }
00099
00100 void calculate(const expression_node& __data) { return; }
00101
00102 int update(bool __rval) { return 0; }
00103
00104 int update(const expression_node& __data, bool __rval)
00105 { return 0; }
00106
00107 bool calculate_value(bool eval_all) { return true; }
00108 };
00109
00110
00111
00112
00113 struct func_h_eval_ret
00114 {
00115 double f;
00116 unsigned int nn;
00117 tristate in_chn;
00118 };
00119
00120 struct func_h_eval_type
00121 {
00122 const std::vector<double>* x;
00123 std::vector<double>* f_cache;
00124 std::vector<std::vector<double> >* d_data;
00125 std::vector<matrix<double> >* h_data;
00126 vector<unsigned int>* t;
00127 vector<bool>* b;
00128 func_h_eval_ret r;
00129 const model* mod;
00130 union { void* p; double d; } u;
00131 unsigned int n,
00132 info;
00133 };
00134
00135 class func_h_eval : public
00136 cached_forward_evaluator_base<func_h_eval_type, expression_node,
00137 func_h_eval_ret, model::walker>
00138 {
00139 private:
00140 typedef cached_forward_evaluator_base<func_h_eval_type,expression_node,
00141 func_h_eval_ret, model::walker> _Base;
00142
00143 protected:
00144 bool is_cached(const node_data_type& __data)
00145 {
00146 if(__data.operator_type == EXPRINFO_LIN ||
00147 __data.operator_type == EXPRINFO_QUAD)
00148 return true;
00149 if(eval_data.f_cache && __data.n_parents > 1 && __data.n_children > 0 &&
00150 v_ind->match(__data.var_indicator()))
00151 return true;
00152 else
00153 return false;
00154 }
00155
00156 private:
// Returns (__coeff*__x) raised to the integer power __exp.
// Exponents 0, +-1 and +-2 are handled without calling pow(). For the other
// exponents pow() only ever sees a non-negative base; the sign is restored
// by hand when the exponent is odd.
double __power(double __coeff, double __x, int __exp)
{
  if(__exp == 0)
    return 1.;

  const double base = __coeff*__x;
  if(__exp == 1)
    return base;
  if(__exp == 2)
    return base*base;
  if(__exp == -1)
    return 1./base;
  if(__exp == -2)
    return 1./(base*base);

  const bool odd_exponent = (__exp & 1) != 0;
  if(!odd_exponent)
    return pow(fabs(base), __exp);
  return base < 0 ? -pow(-base, __exp) : pow(base, __exp);
}
00192
00193 void __calc_max(double h, const expression_node& __data)
00194 {
00195 if(h >= eval_data.r.f)
00196 {
00197 (*eval_data.d_data)[__data.node_num][eval_data.info] = 0;
00198 if(h > eval_data.r.f)
00199 (*eval_data.d_data)[__data.node_num][eval_data.n] =
00200 __data.coeffs[eval_data.n];
00201 else
00202
00203 (*eval_data.d_data)[__data.node_num][eval_data.n] = 0;
00204 eval_data.r.f = h;
00205 eval_data.info = eval_data.n;
00206 }
00207 else
00208 {
00209 (*eval_data.d_data)[__data.node_num][eval_data.n] = 0;
00210 }
00211 }
00212
00213 public:
00214 func_h_eval(const std::vector<double>& __x, const variable_indicator& __v,
00215 const model& __m, std::vector<std::vector<double> >& __d,
00216 std::vector<matrix<double> >& __h,
00217 std::vector<unsigned int>& __t, std::vector<bool>& __b,
00218 std::vector<double>* __c) : _Base()
00219 {
00220 eval_data.t = &__t;
00221 eval_data.b = &__b;
00222 eval_data.x = &__x;
00223 eval_data.f_cache = __c;
00224 eval_data.mod = &__m;
00225 eval_data.d_data = &__d;
00226 eval_data.h_data = &__h;
00227 eval_data.n = 0;
00228 eval_data.r.f = 0;
00229 eval_data.u.d = 0;
00230 v_ind = &__v;
00231 }
00232
00233 func_h_eval(const func_h_eval& __x) : _Base(__x) {}
00234
00235 ~func_h_eval() {}
00236
00237 model::walker short_cut_to(const expression_node& __data)
00238 { return eval_data.mod->node(0); }
00239
00240 void new_point(const std::vector<double>& __x, const variable_indicator& __v)
00241 {
00242 eval_data.x = &__x;
00243 v_ind = &__v;
00244 }
00245
// Nothing to set up before a walk.
void initialize() {}
00247
00248 int initialize(const expression_node& __data)
00249 {
00250 eval_data.r.nn = __data.node_num;
00251 eval_data.r.in_chn = __data.sem.info_flags.has_1chnbase;
00252 eval_data.n = 0;
00253 if(__data.ev != NULL && (*__data.ev)[FUNC_D_EVALUATOR] != NULL)
00254
00255 {
00256 eval_data.r.f =
00257 (*(func_d_evaluator)(*__data.ev)[FUNC_D_EVALUATOR])(eval_data.x,
00258 *v_ind, (*eval_data.d_data)[__data.node_num]);
00259 return 0;
00260 }
00261 else
00262 {
00263 switch(__data.operator_type)
00264 {
00265 case EXPRINFO_MAX:
00266 case EXPRINFO_MIN:
00267 eval_data.info = 0;
00268
00269 case EXPRINFO_SUM:
00270 case EXPRINFO_PROD:
00271 case EXPRINFO_INVERT:
00272 eval_data.r.f = __data.params.nd();
00273 break;
00274 case EXPRINFO_IN:
00275 case EXPRINFO_AND:
00276 case EXPRINFO_NOGOOD:
00277 eval_data.r.f = 1.;
00278 break;
00279 case EXPRINFO_ALLDIFF:
00280 eval_data.u.p = (void*) new std::vector<double>;
00281 ((std::vector<double>*)eval_data.u.p)->reserve(__data.n_children);
00282
00283 case EXPRINFO_MEAN:
00284 case EXPRINFO_IF:
00285 case EXPRINFO_OR:
00286 case EXPRINFO_NOT:
00287 case EXPRINFO_COUNT:
00288 case EXPRINFO_SCPROD:
00289 case EXPRINFO_LEVEL:
00290 eval_data.r.f = 0.;
00291 break;
00292 case EXPRINFO_NORM:
00293 eval_data.info = 0;
00294 eval_data.r.f = 0.;
00295 break;
00296 case EXPRINFO_DET:
00297 case EXPRINFO_PSD:
00298
00299 break;
00300 case EXPRINFO_COND:
00301 case EXPRINFO_FEM:
00302 case EXPRINFO_MPROD:
00303
00304 break;
00305 }
00306 return 1;
00307 }
00308 }
00309
00310 void calculate(const expression_node& __data)
00311 {
00312 if(__data.operator_type > 0)
00313 {
00314 eval_data.r.f = __data.f_evaluate(-1, __data.params.nn(), *eval_data.x,
00315 *v_ind, eval_data.r.f, 0,
00316 &((*eval_data.d_data)[__data.node_num]));
00317 }
00318 if(!(*eval_data.b)[__data.node_num])
00319 {
00320 if(eval_data.r.in_chn == t_false)
00321 (*eval_data.t).push_back(__data.node_num);
00322 (*eval_data.b)[__data.node_num] = true;
00323 }
00324 }
00325
00326 void retrieve_from_cache(const expression_node& __data)
00327 {
00328
00329 if(__data.operator_type == EXPRINFO_LIN)
00330 eval_data.r.f = linalg_dot(eval_data.mod->lin[__data.params.nn()],
00331 *eval_data.x,0.);
00332 else if(__data.operator_type == EXPRINFO_QUAD)
00333 {
00334 std::vector<double> irslt = *eval_data.x;
00335 unsigned int r = __data.params.m().nrows();
00336
00337
00338 irslt.push_back(0);
00339 linalg_matvec(__data.params.m(), irslt, irslt);
00340 irslt.pop_back();
00341
00342 eval_data.r.f = linalg_dot(__data.params.m()[r-1], *eval_data.x, 0.);
00343
00344 eval_data.r.f += linalg_dot(irslt,*eval_data.x,0.);
00345
00346 linalg_add(linalg_scale(__data.params.m()[r-1], 0.5), irslt,
00347 (*eval_data.d_data)[__data.node_num]);
00348 }
00349 else
00350 eval_data.r.f = (*eval_data.f_cache)[__data.node_num];
00351 eval_data.r.nn = __data.node_num;
00352 eval_data.r.in_chn = __data.sem.info_flags.has_1chnbase;
00353 }
00354
00355 int update(const func_h_eval_ret& __rval)
00356 {
00357 eval_data.r.f = __rval;
00358 if(__rval.in_chn != t_false)
00359 (*eval_data.t).push_back(__rval.nn);
00360 return 0;
00361 }
00362
00363 int update(const expression_node& __data, const func_h_eval_ret& __rval)
00364 {
00365 int ret = 0;
00366 double __x;
00367 if(eval_data.r.in_chn == t_false && __rval.in_chn != t_false)
00368 (*eval_data.t).push_back(__rval.nn);
00369 if(__data.operator_type < 0)
00370 {
00371 switch(__data.operator_type)
00372 {
00373 case EXPRINFO_CONSTANT:
00374 eval_data.r.f = __data.params.nd();
00375
00376 break;
00377 case EXPRINFO_VARIABLE:
00378 eval_data.r.f = (*eval_data.x)[__data.params.nn()];
00379
00380 break;
00381 case EXPRINFO_SUM:
00382 case EXPRINFO_MEAN:
00383 { double h = __data.coeffs[eval_data.n];
00384 eval_data.r.f += h*__rval;
00385 (*eval_data.d_data)[__data.node_num][eval_data.n++] = h;
00386 }
00387 break;
00388 case EXPRINFO_PROD:
00389 if(eval_data.n == 0)
00390 {
00391 eval_data.r.f *= __rval;
00392 (*eval_data.d_data)[__data.node_num][0] = __data.params.nd();
00393 }
00394 else
00395 {
00396 (*eval_data.d_data)[__data.node_num][eval_data.n] = eval_data.r.f;
00397 eval_data.r.f *= __rval;
00398 for(int i = eval_data.n-1; i >= 0; i--)
00399 (*eval_data.d_data)[__data.node_num][i] *= __rval;
00400 }
00401 ++eval_data.n;
00402 break;
00403 case EXPRINFO_MONOME:
00404 if(eval_data.n == 0)
00405 {
00406 int n = __data.params.n()[0];
00407 if(n != 0)
00408 {
00409 __x = __power(__data.coeffs[0], __rval, n-1)*__data.coeffs[0];
00410 eval_data.r.f = __x*__rval;
00411 (*eval_data.d_data)[__data.node_num][0] = n*__x;
00412 }
00413 else
00414 {
00415 (*eval_data.d_data)[__data.node_num][0] = 0;
00416 eval_data.r.f = 1.;
00417 }
00418 }
00419 else
00420 {
00421 int n = __data.params.n()[eval_data.n];
00422 if(n != 0)
00423 {
00424 __x = __power(__data.coeffs[eval_data.n], __rval, n-1)*
00425 __data.coeffs[eval_data.n];
00426 (*eval_data.d_data)[__data.node_num][eval_data.n] =
00427 eval_data.r.f*n*__x;
00428 __x *= __rval;
00429 eval_data.r.f *= __x;
00430 for(int i = eval_data.n-1; i >= 0; i--)
00431 (*eval_data.d_data)[__data.node_num][i] *= __x;
00432 }
00433 else
00434 (*eval_data.d_data)[__data.node_num][eval_data.n] = 0;
00435 }
00436 ++eval_data.n;
00437 break;
00438 case EXPRINFO_MAX:
00439 __calc_max(__rval * __data.coeffs[eval_data.n], __data);
00440 ++eval_data.n;
00441 break;
00442 case EXPRINFO_MIN:
00443 { double h = __rval * __data.coeffs[eval_data.n];
00444 if(h <= eval_data.r.f)
00445 {
00446 (*eval_data.d_data)[__data.node_num][eval_data.info] = 0;
00447 if(h < eval_data.r.f)
00448 (*eval_data.d_data)[__data.node_num][eval_data.n] =
00449 __data.coeffs[eval_data.n];
00450 else
00451
00452 (*eval_data.d_data)[__data.node_num][eval_data.n] = 0;
00453 eval_data.r.f = h;
00454 eval_data.info = eval_data.n;
00455 }
00456 else
00457 {
00458 (*eval_data.d_data)[__data.node_num][eval_data.n] = 0;
00459 }
00460 }
00461 ++eval_data.n;
00462 break;
00463 case EXPRINFO_SCPROD:
00464 { double h = __data.coeffs[eval_data.n]*__rval;
00465
00466
00467 if(eval_data.n & 1)
00468 {
00469 eval_data.r.f += eval_data.u.d*h;
00470 (*eval_data.d_data)[__data.node_num][eval_data.n] =
00471 eval_data.u.d*__data.coeffs[eval_data.n-1];
00472 (*eval_data.d_data)[__data.node_num][eval_data.n-1] =
00473 h*__data.coeffs[eval_data.n];
00474 }
00475 else
00476 eval_data.u.d = h;
00477 }
00478 eval_data.n++;
00479 break;
00480 case EXPRINFO_NORM:
00481 if(__data.params.nd() == INFINITY)
00482 __calc_max(fabs(__rval * __data.coeffs[eval_data.n]), __data);
00483 else
00484 {
00485 double h = __data.coeffs[eval_data.n]*fabs(__rval);
00486 double O = pow(h, __data.params.nd()-1);
00487 eval_data.r.f += O*h;
00488 (*eval_data.d_data)[__data.node_num][eval_data.n] = O;
00489 }
00490 eval_data.n++;
00491 if(eval_data.n == __data.n_children &&
00492 __data.params.nd() != INFINITY)
00493 {
00494 double h = pow(eval_data.r.f,1./(__data.params.nd())-1.);
00495 for(unsigned int i = 0; i < eval_data.n; ++i)
00496 (*eval_data.d_data)[__data.node_num][eval_data.n] *= h;
00497 eval_data.r.f = pow(eval_data.r.f,1./(__data.params.nd()));
00498 }
00499 break;
00500 case EXPRINFO_INVERT:
00501 { double h = 1/__rval;
00502 eval_data.r.f *= h;
00503 (*eval_data.d_data)[__data.node_num][0] = -__data.params.nd()*h*h;
00504 }
00505 break;
00506 case EXPRINFO_DIV:
00507 if(eval_data.n++ == 0)
00508 eval_data.r.f = __rval;
00509 else
00510 {
00511 double h = 1/__rval;
00512 eval_data.r.f *=
00513 (*eval_data.d_data)[__data.node_num][0] = __data.params.nd()*h;
00514 (*eval_data.d_data)[__data.node_num][1] = -eval_data.r.f*h;
00515 }
00516 break;
00517 case EXPRINFO_SQUARE:
00518 { double h = __data.coeffs[0]*__rval+__data.params.nd();
00519 eval_data.r.f = h*h;
00520 (*eval_data.d_data)[__data.node_num][0] = 2*h*__data.coeffs[0];
00521 (*eval_data.h_data)[__data.node_num](0,0) =
00522 2*__data.coeffs[0]*__data.coeffs[0];
00523 }
00524 break;
00525 case EXPRINFO_INTPOWER:
00526 { int hl = __data.params.nn();
00527 if(hl == 0)
00528 {
00529 eval_data.r.f = 1;
00530 (*eval_data.d_data)[__data.node_num][0] = 0;
00531 }
00532 else
00533 {
00534 double kl = __data.coeffs[0]*__rval;
00535 switch(hl)
00536 {
00537 case 1:
00538 eval_data.r.f = kl;
00539 (*eval_data.d_data)[__data.node_num][0] = __data.coeffs[0];
00540 break;
00541 case 2:
00542 eval_data.r.f = kl*kl;
00543 (*eval_data.d_data)[__data.node_num][0] =
00544 2*kl*__data.coeffs[0];
00545 break;
00546 case -1:
00547 { double h = 1/kl;
00548 eval_data.r.f = h;
00549 (*eval_data.d_data)[__data.node_num][0] =
00550 -h*h*__data.coeffs[0];
00551 }
00552 break;
00553 case -2:
00554 { double h = 1/kl;
00555 double k = h*h;
00556 eval_data.r.f = k;
00557 (*eval_data.d_data)[__data.node_num][0] =
00558 -2*h*k*__data.coeffs[0];
00559 }
00560 break;
00561 default:
00562 { double h;
00563 if(hl & 1)
00564 h = pow(fabs(kl), hl-1);
00565 else
00566 {
00567 if(kl < 0)
00568 h = -pow(-kl, hl-1);
00569 else
00570 h = pow(kl, hl-1);
00571 }
00572 eval_data.r.f = h*kl;
00573 (*eval_data.d_data)[__data.node_num][0] =
00574 hl*h*__data.coeffs[0];
00575 }
00576 break;
00577 }
00578 }
00579 }
00580 break;
00581 case EXPRINFO_SQROOT:
00582 { double h = sqrt(__data.coeffs[0]*__rval+__data.params.nd());
00583 eval_data.r.f = h;
00584 double k = 0.5*__data.coeffs[0]/h;
00585 (*eval_data.d_data)[__data.node_num][0] = k;
00586 (*eval_data.h_data)[__data.node_num](0,0) = -k*k/h;
00587 }
00588 break;
00589 case EXPRINFO_ABS:
00590 { double h = __data.coeffs[0]*__rval+__data.params.nd();
00591 eval_data.r.f = fabs(h);
00592 (*eval_data.d_data)[__data.node_num][0] =
00593 h > 0 ? __data.coeffs[0] : (h < 0 ? -__data.coeffs[0] : 0);
00594 }
00595 break;
00596 case EXPRINFO_POW:
00597 { double hh = __rval * __data.coeffs[eval_data.n];
00598 if(eval_data.n++ == 0)
00599 eval_data.r.f = hh+__data.params.nd();
00600 else
00601 {
00602 if(hh == 0)
00603 {
00604 (*eval_data.d_data)[__data.node_num][0] = 0;
00605 (*eval_data.d_data)[__data.node_num][1] =
00606 log(eval_data.r.f)*__data.coeffs[1];
00607 eval_data.r.f = 1;
00608 }
00609 else
00610 {
00611 double h = pow(eval_data.r.f, hh);
00612
00613 (*eval_data.d_data)[__data.node_num][0] =
00614 hh*pow(eval_data.r.f, hh-1)*__data.coeffs[0];
00615 (*eval_data.d_data)[__data.node_num][1] =
00616 log(eval_data.r.f)*h*__data.coeffs[1];
00617 eval_data.r.f = h;
00618 }
00619 }
00620 }
00621 break;
00622 case EXPRINFO_EXP:
00623 { double h = exp(__rval*__data.coeffs[0]+__data.params.nd());
00624 eval_data.r.f = h;
00625 (*eval_data.d_data)[__data.node_num][0] = h*__data.coeffs[0];
00626 }
00627 break;
00628 case EXPRINFO_LOG:
00629 { double h = __rval*__data.coeffs[0]+__data.params.nd();
00630 eval_data.r.f = log(h);
00631 (*eval_data.d_data)[__data.node_num][0] = __data.coeffs[0]/h;
00632 }
00633 break;
00634 case EXPRINFO_SIN:
00635 { double h = __rval*__data.coeffs[0]+__data.params.nd();
00636 eval_data.r.f = sin(h);
00637 (*eval_data.d_data)[__data.node_num][0] = __data.coeffs[0]*cos(h);
00638 }
00639 break;
00640 case EXPRINFO_COS:
00641 { double h = __rval*__data.coeffs[0]+__data.params.nd();
00642 eval_data.r.f = cos(h);
00643 (*eval_data.d_data)[__data.node_num][0] = -__data.coeffs[0]*sin(h);
00644 }
00645 break;
00646 case EXPRINFO_ATAN2:
00647 { double hh = __rval * __data.coeffs[eval_data.n];
00648 if(eval_data.n++ == 0)
00649 eval_data.r.f = hh;
00650 else
00651 { double h = eval_data.r.f;
00652 h *= h;
00653 h += hh*hh;
00654 (*eval_data.d_data)[__data.node_num][0] = __data.coeffs[0]*hh/h;
00655 (*eval_data.d_data)[__data.node_num][1] =
00656 -__data.coeffs[1]*eval_data.r.f/h;
00657 eval_data.r.f = atan2(eval_data.r.f,hh);
00658 }
00659 }
00660 break;
00661 case EXPRINFO_GAUSS:
00662 { double h = (__data.coeffs[0]*__rval-__data.params.d()[0])/
00663 __data.params.d()[1];
00664 double k = exp(-h*h);
00665 eval_data.r.f = k;
00666 (*eval_data.d_data)[__data.node_num][0] =
00667 -2*__data.coeffs[0]*k*h/__data.params.d()[1];
00668 }
00669 break;
00671 case EXPRINFO_POLY:
00672 std::cerr << "func_d_evaluator: Polynomes NYI" << std::endl;
00673 throw "NYI";
00674 break;
00675 case EXPRINFO_LIN:
00676 case EXPRINFO_QUAD:
00677
00678 break;
00679 case EXPRINFO_IN:
00680 {
00681 __x = __data.coeffs[eval_data.n]*__rval;
00682 const interval& i(__data.params.i()[eval_data.n]);
00683 if(eval_data.r.f != -1 && i.contains(__x))
00684 {
00685 if(eval_data.r.f == 1 && (__x == i.inf() || __x == i.sup()))
00686 eval_data.r.f = 0;
00687 }
00688 else
00689 {
00690 eval_data.r.f = -1;
00691 ret = -1;
00692 }
00693 }
00694
00695 if(eval_data.n == 0)
00696 {
00697 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00698 v.erase(v.begin(),v.end());
00699 v.insert(v.begin(),__data.n_children,0.);
00700 }
00701 eval_data.n++;
00702 break;
00703 case EXPRINFO_IF:
00704 __x = __rval * __data.coeffs[eval_data.n];
00705 if(eval_data.n == 0)
00706 {
00707 const interval& i(__data.params.ni());
00708 if(!i.contains(__x))
00709 {
00710 ret = 1;
00711 (*eval_data.d_data)[__data.node_num][1] = 0.;
00712 }
00713 else
00714 (*eval_data.d_data)[__data.node_num][2] = 0.;
00715 (*eval_data.d_data)[__data.node_num][0] = 0.;
00716 }
00717 else
00718 {
00719 eval_data.r.f = __x;
00720 (*eval_data.d_data)[__data.node_num][eval_data.n] =
00721 __data.coeffs[eval_data.n];
00722 ret = -1;
00723 }
00724 eval_data.n += ret+1;
00725 break;
00726 case EXPRINFO_AND:
00727 { __x = __data.coeffs[eval_data.n]*__rval;
00728 const interval& i(__data.params.i()[eval_data.n]);
00729 if(eval_data.r.f == 1 && !i.contains(__x))
00730 {
00731 eval_data.r.f = 0;
00732 ret = -1;
00733 }
00734 }
00735
00736 if(eval_data.n == 0)
00737 {
00738 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00739 v.erase(v.begin(),v.end());
00740 v.insert(v.begin(),__data.n_children,0.);
00741 }
00742 eval_data.n++;
00743 break;
00744 case EXPRINFO_OR:
00745 { __x = __data.coeffs[eval_data.n]*__rval;
00746 const interval& i(__data.params.i()[eval_data.n]);
00747 if(eval_data.r.f == 0 && i.contains(__x))
00748 {
00749 eval_data.r.f = 1;
00750 ret = -1;
00751 }
00752 }
00753
00754 if(eval_data.n == 0)
00755 {
00756 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00757 v.erase(v.begin(),v.end());
00758 v.insert(v.begin(),__data.n_children,0.);
00759 }
00760 eval_data.n++;
00761 break;
00762 case EXPRINFO_NOT:
00763 { __x = __data.coeffs[0]*__rval;
00764 const interval& i(__data.params.ni());
00765 if(i.contains(__x))
00766 eval_data.r.f = 0;
00767 else
00768 eval_data.r.f = 1;
00769
00770 (*eval_data.d_data)[__data.node_num][0] = 0.;
00771 }
00772 break;
00773 case EXPRINFO_IMPLIES:
00774 { const interval& i(__data.params.i()[eval_data.n]);
00775 __x = __rval * __data.coeffs[eval_data.n];
00776 if(eval_data.n == 0)
00777 {
00778 if(!i.contains(__x))
00779 {
00780 eval_data.r.f = 1;
00781 ret = -1;
00782 }
00783
00784 (*eval_data.d_data)[__data.node_num][0] = 0.;
00785 (*eval_data.d_data)[__data.node_num][1] = 0.;
00786 }
00787 else
00788 eval_data.r.f = i.contains(__x) ? 1 : 0;
00789 ++eval_data.n;
00790 }
00791 break;
00792 case EXPRINFO_COUNT:
00793 { __x = __data.coeffs[eval_data.n]*__rval;
00794 const interval& i(__data.params.i()[eval_data.n]);
00795 if(i.contains(__x))
00796 eval_data.r.f += 1;
00797 }
00798
00799 if(eval_data.n == 0)
00800 {
00801 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00802 v.erase(v.begin(),v.end());
00803 v.insert(v.begin(),__data.n_children,0.);
00804 }
00805 eval_data.n++;
00806 break;
00807 case EXPRINFO_ALLDIFF:
00808 { __x = __data.coeffs[eval_data.n]*__rval;
00809 for(std::vector<double>::const_iterator _b =
00810 ((std::vector<double>*)eval_data.u.p)->begin();
00811 _b != ((std::vector<double>*)eval_data.u.p)->end(); ++_b)
00812 {
00813 if(fabs(__x-*_b) <= __data.params.nd())
00814 {
00815 eval_data.r.f = 0;
00816 ret = -1;
00817 break;
00818 }
00819 }
00820 if(ret != -1)
00821 ((std::vector<double>*) eval_data.u.p)->push_back(__x);
00822 }
00823
00824 if(eval_data.n == 0)
00825 {
00826 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00827 v.erase(v.begin(),v.end());
00828 v.insert(v.begin(),__data.n_children,0.);
00829 }
00830 eval_data.n++;
00831 if(eval_data.n == __data.n_children || ret == -1)
00832 delete (std::vector<double>*) eval_data.u.p;
00833 break;
00834 case EXPRINFO_HISTOGRAM:
00835 std::cerr << "func_d_evaluator: histogram NYI" << std::endl;
00836 throw "NYI";
00837 break;
00838 case EXPRINFO_LEVEL:
00839 { int h = (int)eval_data.r.f;
00840 __x = __data.coeffs[eval_data.n]*__rval;
00841 interval _h;
00842
00843 if(h != INT_MAX)
00844 {
00845 while(h < __data.params.im().nrows())
00846 {
00847 _h = __data.params.im()[h][eval_data.n];
00848 if(_h.contains(__x))
00849 break;
00850 h++;
00851 }
00852 if(h == __data.params.im().nrows())
00853 {
00854 ret = -1;
00855 eval_data.r.f = INT_MAX;
00856 }
00857 else
00858 eval_data.r.f = h;
00859 }
00860 }
00861
00862 if(eval_data.n == 0)
00863 {
00864 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00865 v.erase(v.begin(),v.end());
00866 v.insert(v.begin(),__data.n_children,0.);
00867 }
00868 eval_data.n++;
00869 break;
00870 case EXPRINFO_NEIGHBOR:
00871 if(eval_data.n == 0)
00872 eval_data.r.f = __data.coeffs[0]*__rval;
00873 else
00874 {
00875 double h = eval_data.r.f;
00876 eval_data.r.f = 0;
00877 __x = __data.coeffs[1]*__rval;
00878 for(unsigned int i = 0; i < __data.params.n().size(); i+=2)
00879 {
00880 if(h == __data.params.n()[i] && __x == __data.params.n()[i+1])
00881 {
00882 eval_data.r.f = 1;
00883 break;
00884 }
00885 }
00886 }
00887
00888 if(eval_data.n == 0)
00889 {
00890 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00891 v.erase(v.begin(),v.end());
00892 v.insert(v.begin(),__data.n_children,0.);
00893 }
00894 eval_data.n++;
00895 break;
00896 case EXPRINFO_NOGOOD:
00897 {
00898 __x = __data.coeffs[eval_data.n]*__rval;
00899 if(eval_data.r.f == 0 || __data.params.n()[eval_data.n] != __x)
00900 {
00901 eval_data.r.f = 0;
00902 ret = -1;
00903 }
00904 }
00905
00906 if(eval_data.n == 0)
00907 {
00908 std::vector<double>& v((*eval_data.d_data)[__data.node_num]);
00909 v.erase(v.begin(),v.end());
00910 v.insert(v.begin(),__data.n_children,0.);
00911 }
00912 eval_data.n++;
00913 break;
00914 case EXPRINFO_EXPECTATION:
00915 std::cerr << "func_d_evaluator: E NYI" << std::endl;
00916 throw "NYI";
00917 break;
00918 case EXPRINFO_INTEGRAL:
00919 std::cerr << "func_d_evaluator: INT NYI" << std::endl;
00920 throw "NYI";
00921 break;
00922 case EXPRINFO_LOOKUP:
00923 case EXPRINFO_PWLIN:
00924 case EXPRINFO_SPLINE:
00925 case EXPRINFO_PWCONSTLC:
00926 case EXPRINFO_PWCONSTRC:
00927 std::cerr << "func_d_evaluator: Table Operations NYI" << std::endl;
00928 throw "NYI";
00929 break;
00930 case EXPRINFO_DET:
00931 case EXPRINFO_COND:
00932 case EXPRINFO_PSD:
00933 case EXPRINFO_MPROD:
00934 case EXPRINFO_FEM:
00935 std::cerr << "func_d_evaluator: Matrix Operations NYI" << std::endl;
00936 throw "NYI";
00937 break;
00938 case EXPRINFO_RE:
00939 case EXPRINFO_IM:
00940 case EXPRINFO_ARG:
00941 case EXPRINFO_CPLXCONJ:
00942 std::cerr << "func_d_evaluator: Complex Numbers NYI" << std::endl;
00943 throw "NYI";
00944 break;
00945 case EXPRINFO_CMPROD:
00946 case EXPRINFO_CGFEM:
00947 std::cerr << "func_d_evaluator: Const Matrix Operations NYI" << std::endl;
00948 throw "NYI";
00949 break;
00950 default:
00951 std::cerr << "func_d_evaluator: unknown function type " <<
00952 __data.operator_type << std::endl;
00953 throw "Programming error";
00954 break;
00955 }
00956 }
00957 else if(__data.operator_type > 0)
00958
00959 eval_data.r.f = __data.f_evaluate(eval_data.n++, __data.params.nn(),
00960 *eval_data.x, *v_ind, eval_data.r.f, __rval,
00961 &(*eval_data.d_data)[__data.node_num]);
00962
00963 if(eval_data.f_cache && __data.n_parents > 1 && __data.n_children > 0)
00964 (*eval_data.f_cache)[__data.node_num] = eval_data.r.f;
00965 return ret;
00966 }
00967
00968 func_h_eval_ret calculate_value(bool eval_all)
00969 {
00970 return eval_data.r;
00971 }
00972 };
00973
00974 struct der_eval_type
00975 {
00976 std::vector<std::vector<double> >* d_data;
00977 std::vector<std::vector<double> >* d_cache;
00978 std::vector<double>* grad_vec;
00979 const model* mod;
00980 double mult;
00981 double mult_trans;
00982 unsigned int child_n;
00983 };
00984
00985 class der_eval : public
00986 cached_backward_evaluator_base<der_eval_type,expression_node,bool,
00987 model::walker>
00988 {
00989 private:
00990 typedef cached_backward_evaluator_base<der_eval_type,expression_node,
00991 bool,model::walker> _Base;
00992
00993 protected:
00994 bool is_cached(const node_data_type& __data)
00995 {
00996 if(eval_data.d_cache && __data.n_parents > 1 && __data.n_children > 0
00997 && (*eval_data.d_cache)[__data.node_num].size() > 0 &&
00998 v_ind->match(__data.var_indicator()))
00999 {
01000 return true;
01001 }
01002 else
01003 return false;
01004 }
01005
01006 public:
01007 der_eval(std::vector<std::vector<double> >& __der_data, variable_indicator& __v,
01008 const model& __m, std::vector<std::vector<double > >* __d,
01009 std::vector<double>& __grad)
01010 {
01011 eval_data.d_data = &__der_data;
01012 eval_data.d_cache = __d;
01013 eval_data.mod = &__m;
01014 eval_data.grad_vec = &__grad;
01015 eval_data.mult_trans = 1;
01016 eval_data.mult = 0;
01017 v_ind = &__v;
01018 }
01019
01020 der_eval(const der_eval& __d) { eval_data = __d.eval_data; }
01021
01022 ~der_eval() {}
01023
01024 void new_point(std::vector<std::vector<double> >& __der_data,
01025 const variable_indicator& __v)
01026 {
01027 eval_data.d_data = &__der_data;
01028 v_ind = &__v;
01029 }
01030
01031 void new_result(std::vector<double>& __grad)
01032 {
01033 eval_data.grad_vec = &__grad;
01034 }
01035
01036 void set_mult(double scal)
01037 {
01038 eval_data.mult_trans = scal;
01039 }
01040 public:
01041
01042 model::walker short_cut_to(const expression_node& __data)
01043 { return eval_data.mod->node(0); }
01044
01045
01046 void initialize()
01047 {
01048 eval_data.child_n = 0;
01049 }
01050
01051
01052 int calculate(const expression_node& __data)
01053 {
01054 if(__data.operator_type == EXPRINFO_CONSTANT)
01055 return 0;
01056 else if(__data.operator_type == EXPRINFO_VARIABLE)
01057 {
01058
01059 (*eval_data.grad_vec)[__data.params.nn()] += eval_data.mult_trans;
01060 return 0;
01061 }
01062 else if(__data.operator_type == EXPRINFO_LIN)
01063 {
01064 linalg_add(linalg_scale(eval_data.mod->lin[__data.params.nn()], eval_data.mult_trans),
01065 *eval_data.grad_vec,*eval_data.grad_vec);
01066 return 0;
01067 }
01068 else if(__data.operator_type == EXPRINFO_QUAD)
01069 {
01070 linalg_ssum(*eval_data.grad_vec, 2*eval_data.mult_trans,
01071 (*eval_data.d_data)[__data.node_num]);
01072 return 0;
01073 }
01074 else if(__data.ev && (*__data.ev)[HESS_EVALUATOR])
01075
01076 {
01077 linalg_ssum(*eval_data.grad_vec, eval_data.mult,
01078 (*(der_evaluator)(*__data.ev)[HESS_EVALUATOR])(
01079 (*eval_data.d_data)[__data.node_num], *v_ind));
01080 return 0;
01081 }
01082 else if(eval_data.mult_trans == 0)
01083
01084 return 0;
01085 else
01086 {
01087 eval_data.child_n = 1;
01088 eval_data.mult = eval_data.mult_trans;
01089 if(__data.n_parents > 1 && __data.n_children > 0 && eval_data.d_cache)
01090 {
01091 eval_data.mult_trans = (*eval_data.d_data)[__data.node_num][0];
01092 }
01093 else
01094 eval_data.mult_trans *= (*eval_data.d_data)[__data.node_num][0];
01095 return 1;
01096 }
01097 }
01098
01099
01100 void cleanup(const expression_node& __data)
01101 {
01102
01103 if(__data.n_parents > 1 && __data.n_children > 0 && eval_data.d_cache
01104 && (*eval_data.d_cache)[__data.node_num].size() == 0)
01105 {
01106 (*eval_data.d_cache)[__data.node_num] = *eval_data.grad_vec;
01107 linalg_smult(*eval_data.grad_vec, eval_data.mult);
01108 }
01109 }
01110
01111 void retrieve_from_cache(const expression_node& __data)
01112 {
01113
01114 linalg_ssum(*eval_data.grad_vec, eval_data.mult_trans,
01115 (*eval_data.d_cache)[__data.node_num]);
01116 }
01117
01118 int update(const bool& __rval)
01119 {
01120 eval_data.child_n++;
01121 return 0;
01122 }
01123
01124
01125 int update(const expression_node& __data, const bool& __rval)
01126 {
01127 if(__data.n_children == 0)
01128 return 0;
01129 if(__data.n_parents > 1 && __data.n_children > 0 && eval_data.d_cache)
01130 {
01131 if(eval_data.child_n < __data.n_children)
01132 eval_data.mult_trans =
01133 (*eval_data.d_data)[__data.node_num][eval_data.child_n];
01134 }
01135 else if(eval_data.child_n < __data.n_children)
01136 {
01137 eval_data.mult_trans = eval_data.mult *
01138 (*eval_data.d_data)[__data.node_num][eval_data.child_n];
01139 }
01140 eval_data.child_n++;
01141 return 0;
01142 }
01143
// The walk has no value of its own to report; always returns true.
bool calculate_value(bool eval_all)
{
  const bool done = true;
  return done;
}
01148 };
01149
01150 #endif