SVWorkingSet.cxx
// @(#)root/tmva $Id$
// Author: Andrzej Zemla

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
 * Package: TMVA *
 * Class  : SVWorkingSet *
 * Web    : http://tmva.sourceforge.net *
 *                                                                                *
 * Description: *
 *      Implementation *
 *                                                                                *
 * Authors (alphabetical): *
 *      Marcin Wolter <Marcin.Wolter@cern.ch> - IFJ PAN, Krakow, Poland *
 *      Andrzej Zemla <azemla@cern.ch>        - IFJ PAN, Krakow, Poland *
 *      (IFJ PAN: Henryk Niewodniczanski Inst. Nucl. Physics, Krakow, Poland) *
 *                                                                                *
 * Copyright (c) 2005: *
 *      CERN, Switzerland *
 *      MPI-K Heidelberg, Germany *
 *      PAN, Krakow, Poland *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without *
 * modification, are permitted according to the terms listed in LICENSE *
 * (http://tmva.sourceforge.net/LICENSE) *
 **********************************************************************************/
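
//
// A minimal usage sketch (illustrative only; within TMVA this class is normally
// driven by MethodSVM, and the event vector, kernel object and parameter values
// shown here are placeholders rather than a prescription):
//
//    std::vector<TMVA::SVEvent*>* events = ...;   // SVEvents built from the training data
//    TMVA::SVKernelFunction*      kernel = ...;   // an already configured kernel function
//
//    TMVA::SVWorkingSet workingSet( events, kernel, 0.01 /*tolerance*/, kFALSE /*classification*/ );
//    workingSet.Train( 1000 );                    // run the SMO training loop
//    std::vector<TMVA::SVEvent*>* supportVectors = workingSet.GetSupportVectors();
//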

#include "TMVA/SVWorkingSet.h"

#include "TMVA/MsgLogger.h"
#include "TMVA/SVEvent.h"
#include "TMVA/SVKernelFunction.h"
#include "TMVA/SVKernelMatrix.h"
#include "TMVA/Types.h"


#include "TMath.h"
#include "TRandom3.h"

#include <iostream>
#include <vector>

////////////////////////////////////////////////////////////////////////////////
/// constructor

TMVA::SVWorkingSet::SVWorkingSet()
   : fdoRegression(kFALSE),
     fInputData(0),
     fSupVec(0),
     fKFunction(0),
     fKMatrix(0),
     fTEventUp(0),
     fTEventLow(0),
     fB_low(1.),
     fB_up(-1.),
     fTolerance(0.01),
     fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
{
}

////////////////////////////////////////////////////////////////////////////////
/// constructor
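/// Builds the kernel matrix for the input events and attaches to every event its
/// row of the matrix; in regression mode the error cache is seeded with the target
/// values, while for classification a randomly chosen event of each class is used
/// as the initial fTEventLow / fTEventUp working-set candidate.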

TMVA::SVWorkingSet::SVWorkingSet(std::vector<TMVA::SVEvent*>* inputVectors, SVKernelFunction* kernelFunction,
                                 Float_t tol, Bool_t doreg)
   : fdoRegression(doreg),
     fInputData(inputVectors),
     fSupVec(0),
     fKFunction(kernelFunction),
     fTEventUp(0),
     fTEventLow(0),
     fB_low(1.),
     fB_up(-1.),
     fTolerance(tol),
     fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
{
   fKMatrix = new TMVA::SVKernelMatrix(inputVectors, kernelFunction);
   Float_t *pt;
   for (UInt_t i = 0; i < fInputData->size(); i++) {
      pt = fKMatrix->GetLine(i);
      fInputData->at(i)->SetLine(pt);
      fInputData->at(i)->SetNs(i);
      if (fdoRegression) fInputData->at(i)->SetErrorCache(fInputData->at(i)->GetTarget());
   }
   TRandom3 rand;
   UInt_t kk = rand.Integer(fInputData->size());
   if (fdoRegression) {
      // ...
   }
   else {
      while (1) {
         if (fInputData->at(kk)->GetTypeFlag() == -1) {
            fTEventLow = fInputData->at(kk);
            break;
         }
         kk = rand.Integer(fInputData->size());
      }

      while (1) {
         if (fInputData->at(kk)->GetTypeFlag() == 1) {
            fTEventUp = fInputData->at(kk);
            break;
         }
         kk = rand.Integer(fInputData->size());
      }
   }
}

////////////////////////////////////////////////////////////////////////////////
/// destructor

TMVA::SVWorkingSet::~SVWorkingSet()
{
   if (fKMatrix != 0) { delete fKMatrix; fKMatrix = 0; }
   delete fLogger;
}

////////////////////////////////////////////////////////////////////////////////
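/// Examine a single event as the second member of the working pair: refresh its
/// cached error F_j, update the fB_up / fB_low bounds, and test the KKT conditions
/// within 2*fTolerance. If the event violates them, a partner (fTEventUp or
/// fTEventLow) is selected and a joint optimisation step is attempted via TakeStep().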

Bool_t TMVA::SVWorkingSet::ExamineExample( TMVA::SVEvent* jevt )
{
   SVEvent* ievt = 0;
   Float_t fErrorC_J = 0.;
   if (jevt->GetIdx() == 0) fErrorC_J = jevt->GetErrorCache();
   else {
      Float_t *fKVals = jevt->GetLine();
      fErrorC_J = 0.;
      std::vector<TMVA::SVEvent*>::iterator idIter;

      UInt_t k = 0;
      for (idIter = fInputData->begin(); idIter != fInputData->end(); idIter++) {
         if ((*idIter)->GetAlpha() > 0)
            fErrorC_J += (*idIter)->GetAlpha()*(*idIter)->GetTypeFlag()*fKVals[k];
         k++;
      }


      fErrorC_J -= jevt->GetTypeFlag();
      jevt->SetErrorCache(fErrorC_J);

      if ((jevt->GetIdx() == 1) && (fErrorC_J < fB_up)) {
         fB_up = fErrorC_J;
         fTEventUp = jevt;
      }
      else if ((jevt->GetIdx() == -1) && (fErrorC_J > fB_low)) {
         fB_low = fErrorC_J;
         fTEventLow = jevt;
      }
   }
   Bool_t converged = kTRUE;

   if ((jevt->GetIdx() >= 0) && (fB_low - fErrorC_J > 2*fTolerance)) {
      converged = kFALSE;
      ievt = fTEventLow;
   }

   if ((jevt->GetIdx() <= 0) && (fErrorC_J - fB_up > 2*fTolerance)) {
      converged = kFALSE;
      ievt = fTEventUp;
   }

   if (converged) return kFALSE;

   if (jevt->GetIdx() == 0) {
      if (fB_low - fErrorC_J > fErrorC_J - fB_up) ievt = fTEventLow;
      else ievt = fTEventUp;
   }

   if (TakeStep(ievt, jevt)) return kTRUE;
   else return kFALSE;
}


////////////////////////////////////////////////////////////////////////////////
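/// Jointly optimise the pair of Lagrange multipliers (alpha_I, alpha_J) in the
/// spirit of the SMO algorithm: determine the feasible interval [l, h] from the box
/// and equality constraints, move alpha_J to its unconstrained optimum when eta < 0
/// (otherwise to the better interval end), clip it to [l, h], derive alpha_I from
/// the equality constraint, and finally refresh the error caches and the
/// fB_up / fB_low bounds. Returns kTRUE only if the alphas actually changed.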

Bool_t TMVA::SVWorkingSet::TakeStep( TMVA::SVEvent* ievt , TMVA::SVEvent* jevt )
{
   if (ievt == jevt) return kFALSE;
   std::vector<TMVA::SVEvent*>::iterator idIter;
   const Float_t epsilon = 1e-8; // make it 1e-6 or 1e-5 to speed this up

   Float_t type_I, type_J;
   Float_t errorC_I, errorC_J;
   Float_t alpha_I, alpha_J;

   Float_t newAlpha_I, newAlpha_J;
   Int_t   s;

   Float_t l, h, lobj = 0, hobj = 0;
   Float_t eta;

   type_I   = ievt->GetTypeFlag();
   alpha_I  = ievt->GetAlpha();
   errorC_I = ievt->GetErrorCache();

   type_J   = jevt->GetTypeFlag();
   alpha_J  = jevt->GetAlpha();
   errorC_J = jevt->GetErrorCache();

   s = Int_t( type_I * type_J );

   Float_t c_i = ievt->GetCweight();

   Float_t c_j = jevt->GetCweight();

   // compute the allowed range [l, h] for the new alpha_J
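   // The equality constraint keeps type_I*alpha_I + type_J*alpha_J fixed, so alpha_J
   // can only move along a line segment; [l, h] is that segment clipped to the
   // individual box constraints [0, c_i] and [0, c_j].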

   if (type_I == type_J) {
      Float_t gamma = alpha_I + alpha_J;

      if ( c_i > c_j ) {
         if ( gamma < c_j ) {
            l = 0;
            h = gamma;
         }
         else {
            h = c_j;
            if ( gamma < c_i )
               l = 0;
            else
               l = gamma - c_i;
         }
      }
      else {
         if ( gamma < c_i ) {
            l = 0;
            h = gamma;
         }
         else {
            l = gamma - c_i;
            if ( gamma < c_j )
               h = gamma;
            else
               h = c_j;
         }
      }
   }
   else {
      Float_t gamma = alpha_I - alpha_J;
      if (gamma > 0) {
         l = 0;
         if ( gamma >= (c_i - c_j) )
            h = c_i - gamma;
         else
            h = c_j;
      }
      else {
         l = -gamma;
         if ( (c_i - c_j) >= gamma)
            h = c_j;
         else
            h = c_i - gamma;
      }
   }

   if (l == h) return kFALSE;
   Float_t kernel_II, kernel_IJ, kernel_JJ;

   kernel_II = fKMatrix->GetElement(ievt->GetNs(), ievt->GetNs());
   kernel_IJ = fKMatrix->GetElement(ievt->GetNs(), jevt->GetNs());
   kernel_JJ = fKMatrix->GetElement(jevt->GetNs(), jevt->GetNs());

   eta = 2*kernel_IJ - kernel_II - kernel_JJ;
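   // For a kernel satisfying Mercer's condition eta = 2*k_IJ - k_II - k_JJ is
   // non-positive: if it is strictly negative the objective has a unique optimum
   // along the constraint line (handled below), otherwise the objective is simply
   // compared at the two interval ends (lobj / hobj).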
   if (eta < 0) {
      newAlpha_J = alpha_J + (type_J*( errorC_J - errorC_I ))/eta;
      if      (newAlpha_J < l) newAlpha_J = l;
      else if (newAlpha_J > h) newAlpha_J = h;

   }

   else {

      Float_t c_I = eta/2;
      Float_t c_J = type_J*( errorC_I - errorC_J ) - eta * alpha_J;
      lobj = c_I * l * l + c_J * l;
      hobj = c_I * h * h + c_J * h;

      if      (lobj > hobj + epsilon) newAlpha_J = l;
      else if (lobj < hobj - epsilon) newAlpha_J = h;
      else                            newAlpha_J = alpha_J;
   }

   if (TMath::Abs( newAlpha_J - alpha_J ) < ( epsilon * ( newAlpha_J + alpha_J + epsilon ))) {
      return kFALSE;
      // the algorithm spends far too much time in this branch
   }
   newAlpha_I = alpha_I - s*( newAlpha_J - alpha_J );

   if (newAlpha_I < 0) {
      newAlpha_J += s* newAlpha_I;
      newAlpha_I = 0;
   }
   else if (newAlpha_I > c_i) {
      Float_t temp = newAlpha_I - c_i;
      newAlpha_J += s * temp;
      newAlpha_I = c_i;
   }

   Float_t dL_I = type_I * ( newAlpha_I - alpha_I );
   Float_t dL_J = type_J * ( newAlpha_J - alpha_J );

   Int_t k = 0;
   for (idIter = fInputData->begin(); idIter != fInputData->end(); idIter++) {
      k++;
      if ((*idIter)->GetIdx() == 0) {
         Float_t ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
         Float_t jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());

         (*idIter)->UpdateErrorCache(dL_I * ii + dL_J * jj);
      }
   }
   ievt->SetAlpha(newAlpha_I);
   jevt->SetAlpha(newAlpha_J);
   // set new indexes
   SetIndex(ievt);
   SetIndex(jevt);

   // update error cache
   ievt->SetErrorCache(errorC_I + dL_I*kernel_II + dL_J*kernel_IJ);
   jevt->SetErrorCache(errorC_J + dL_I*kernel_IJ + dL_J*kernel_JJ);

   // recompute fB_low and fB_up over the unbound (Idx == 0) events

   fB_low = -1*1e30;
   fB_up  =  1e30;

   for (idIter = fInputData->begin(); idIter != fInputData->end(); idIter++) {
      if ((*idIter)->GetIdx() == 0) {
         if ((*idIter)->GetErrorCache() > fB_low) {
            fB_low = (*idIter)->GetErrorCache();
            fTEventLow = (*idIter);
         }
         if ((*idIter)->GetErrorCache() < fB_up) {
            fB_up = (*idIter)->GetErrorCache();
            fTEventUp = (*idIter);
         }
      }
   }

   // for the optimized alphas
   if (fB_low < TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
      if (ievt->GetErrorCache() > fB_low) {
         fB_low = ievt->GetErrorCache();
         fTEventLow = ievt;
      }
      else {
         fB_low = jevt->GetErrorCache();
         fTEventLow = jevt;
      }
   }

   if (fB_up > TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
      if (ievt->GetErrorCache() < fB_low) {
         fB_up = ievt->GetErrorCache();
         fTEventUp = ievt;
      }
      else {
         fB_up = jevt->GetErrorCache();
         fTEventUp = jevt;
      }
   }
   return kTRUE;
}

////////////////////////////////////////////////////////////////////////////////
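/// Convergence test for the SMO loop: training is considered finished once the
/// smallest and the largest cached error agree within twice the tolerance,
/// i.e. fB_up > fB_low - 2*fTolerance.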

Bool_t TMVA::SVWorkingSet::Terminated()
{
   if ((fB_up > fB_low - 2*fTolerance)) return kTRUE;
   return kFALSE;
}

////////////////////////////////////////////////////////////////////////////////
/// train the SVM
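/// Outer SMO loop: alternate full sweeps over all events with sweeps restricted to
/// the unbound set I0, counting how many alpha pairs changed. Training stops when a
/// full sweep changes nothing, when Terminated() signals convergence during an I0
/// sweep, or when nMaxIter iterations have been reached.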

void TMVA::SVWorkingSet::Train(UInt_t nMaxIter)
{

   Int_t numChanged = 0;
   Int_t examineAll = 1;

   Float_t numChangedOld = 0;
   Int_t deltaChanges = 0;
   UInt_t numit = 0;

   std::vector<TMVA::SVEvent*>::iterator idIter;

   while ((numChanged > 0) || (examineAll > 0)) {
      if (fIPyCurrentIter) *fIPyCurrentIter = numit;
      if (fExitFromTraining && *fExitFromTraining) break;
      numChanged = 0;
      if (examineAll) {
         for (idIter = fInputData->begin(); idIter != fInputData->end(); idIter++) {
            if (!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
            else                numChanged += (UInt_t)ExamineExampleReg(*idIter);
         }
      }
      else {
         for (idIter = fInputData->begin(); idIter != fInputData->end(); idIter++) {
            if ((*idIter)->IsInI0()) {
               if (!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
               else                numChanged += (UInt_t)ExamineExampleReg(*idIter);
               if (Terminated()) {
                  numChanged = 0;
                  break;
               }
            }
         }
      }

      if (examineAll == 1) examineAll = 0;
      else if (numChanged == 0 || numChanged < 10 || deltaChanges > 3) examineAll = 1;

      if (numChanged == numChangedOld) deltaChanges++;
      else deltaChanges = 0;
      numChangedOld = numChanged;
      ++numit;

      if (numit >= nMaxIter) {
         *fLogger << kWARNING
                  << "Max number of iterations exceeded. "
                  << "Training may not be completed. Try to use a smaller Cost parameter" << Endl;
         break;
      }
   }
}

////////////////////////////////////////////////////////////////////////////////
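/// Assign the event to the index sets used for working-set selection: Idx = 0 for
/// unbound events (0 < alpha < C), and +1 / -1 for events at the bounds, with the
/// sign depending on the class (type flag) and on which bound is reached.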

void TMVA::SVWorkingSet::SetIndex( TMVA::SVEvent* event )
{
   if ((0 < event->GetAlpha()) && (event->GetAlpha() < event->GetCweight()))
      event->SetIdx(0);

   if (event->GetTypeFlag() == 1) {
      if (event->GetAlpha() == 0)
         event->SetIdx(1);
      else if (event->GetAlpha() == event->GetCweight())
         event->SetIdx(-1);
   }
   if (event->GetTypeFlag() == -1) {
      if (event->GetAlpha() == 0)
         event->SetIdx(-1);
      else if (event->GetAlpha() == event->GetCweight())
         event->SetIdx(1);
   }
}

////////////////////////////////////////////////////////////////////////////////

{
   // counts the events with non-zero alpha; the result is not used any further
   std::vector<TMVA::SVEvent*>::iterator idIter;
   UInt_t counter = 0;
   for (idIter = fInputData->begin(); idIter != fInputData->end(); idIter++)
      if ((*idIter)->GetAlpha() != 0) counter++;
}

////////////////////////////////////////////////////////////////////////////////
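/// Collect the support vectors, i.e. all events with a non-zero (delta) alpha.
/// Any vector returned by a previous call is deleted and rebuilt from scratch.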

std::vector<TMVA::SVEvent*>* TMVA::SVWorkingSet::GetSupportVectors()
{
   std::vector<TMVA::SVEvent*>::iterator idIter;
   if (fSupVec != 0) { delete fSupVec; fSupVec = 0; }
   fSupVec = new std::vector<TMVA::SVEvent*>(0);

   for (idIter = fInputData->begin(); idIter != fInputData->end(); idIter++) {
      if ((*idIter)->GetDeltaAlpha() != 0) {
         fSupVec->push_back((*idIter));
      }
   }
   return fSupVec;
}

// for regression
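
////////////////////////////////////////////////////////////////////////////////
/// Regression counterpart of TakeStep(): jointly optimise the multiplier pairs
/// (alpha, alpha*) of two events. The four cases below treat the possible
/// combinations of the two pairs; in each case a feasible interval [low, high] is
/// computed, alpha_j is moved towards the unconstrained optimum and clipped, and
/// the step is kept only if IsDiffSignificant() reports a real change. Afterwards
/// the error caches and the fB_up / fB_low bounds are refreshed.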

Bool_t TMVA::SVWorkingSet::TakeStepReg( TMVA::SVEvent* ievt , TMVA::SVEvent* jevt )
{
   if (ievt == jevt) return kFALSE;
   std::vector<TMVA::SVEvent*>::iterator idIter;
   const Float_t epsilon = 0.001*fTolerance; // TODO

   const Float_t kernel_II = fKMatrix->GetElement(ievt->GetNs(), ievt->GetNs());
   const Float_t kernel_IJ = fKMatrix->GetElement(ievt->GetNs(), jevt->GetNs());
   const Float_t kernel_JJ = fKMatrix->GetElement(jevt->GetNs(), jevt->GetNs());

   // compute eta & gamma
   const Float_t eta   = -2*kernel_IJ + kernel_II + kernel_JJ;
   const Float_t gamma = ievt->GetDeltaAlpha() + jevt->GetDeltaAlpha();

   // TODO: check what happens if eta < 0;
   // under Mercer's conditions it should never happen, but what if it does?

   Bool_t caseA, caseB, caseC, caseD, terminated;
   caseA = caseB = caseC = caseD = terminated = kFALSE;
   Float_t b_alpha_i, b_alpha_j, b_alpha_i_p, b_alpha_j_p; // temporary Lagrange multipliers
   const Float_t b_cost_i = ievt->GetCweight();
   const Float_t b_cost_j = jevt->GetCweight();

   b_alpha_i   = ievt->GetAlpha();
   b_alpha_j   = jevt->GetAlpha();
   b_alpha_i_p = ievt->GetAlpha_p();
   b_alpha_j_p = jevt->GetAlpha_p();

   // calculate deltafi (difference of the cached errors)
   Float_t deltafi = ievt->GetErrorCache() - jevt->GetErrorCache();

   // main loop
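   // Each of the four cases below is tried at most once; the loop stops as soon as
   // no case applies any more or the case at hand finds an empty interval (low >= high).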
   while (!terminated) {
      const Float_t null = 0.; // dummy zero; needed to pick the TMath::Max/Min(Float_t, Float_t) overloads
      Float_t low, high;
      Float_t tmp_alpha_i, tmp_alpha_j;
      tmp_alpha_i = tmp_alpha_j = 0.;

      // TODO: check whether these conditions are correct
      if ((caseA == kFALSE) && (b_alpha_i > 0 || (b_alpha_i_p == 0 && deltafi > 0)) && (b_alpha_j > 0 || (b_alpha_j_p == 0 && deltafi < 0)))
         {
            // compute low, high w.r.t. a_i, a_j
            low  = TMath::Max( null, gamma - b_cost_j );
            high = TMath::Min( b_cost_i, gamma );

            if (low < high) {
               tmp_alpha_j = b_alpha_j - (deltafi/eta);
               tmp_alpha_j = TMath::Min(tmp_alpha_j, high);
               tmp_alpha_j = TMath::Max(low, tmp_alpha_j);
               tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j);

               // update alpha_i and alpha_j if the change is significant
               if (IsDiffSignificant(b_alpha_j, tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i, tmp_alpha_i, epsilon)) {
                  b_alpha_j = tmp_alpha_j;
                  b_alpha_i = tmp_alpha_i;
               }

            }
            else
               terminated = kTRUE;

            caseA = kTRUE;
         }
      else if ((caseB == kFALSE) && (b_alpha_i > 0 || (b_alpha_i_p == 0 && deltafi > 2*epsilon)) && (b_alpha_j_p > 0 || (b_alpha_j == 0 && deltafi > 2*epsilon)))
         {
            // compute low, high w.r.t. a_i, a_j*
            low  = TMath::Max( null, gamma ); // TODO
            high = TMath::Min( b_cost_i, b_cost_j + gamma );


            if (low < high) {
               tmp_alpha_j = b_alpha_j_p - ((deltafi - 2*epsilon)/eta);
               tmp_alpha_j = TMath::Min(tmp_alpha_j, high);
               tmp_alpha_j = TMath::Max(low, tmp_alpha_j);
               tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j_p);

               // update alpha_i and alpha_j*
               if (IsDiffSignificant(b_alpha_j_p, tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i, tmp_alpha_i, epsilon)) {
                  b_alpha_j_p = tmp_alpha_j;
                  b_alpha_i   = tmp_alpha_i;
               }
            }
            else
               terminated = kTRUE;

            caseB = kTRUE;
         }
      else if ((caseC == kFALSE) && (b_alpha_i_p > 0 || (b_alpha_i == 0 && deltafi < -2*epsilon)) && (b_alpha_j > 0 || (b_alpha_j_p == 0 && deltafi < -2*epsilon)))
         {
            // compute low, high w.r.t. alpha_i*, alpha_j
            low  = TMath::Max( null, -gamma );
            high = TMath::Min( b_cost_i, -gamma + b_cost_j );

            if (low < high) {
               tmp_alpha_j = b_alpha_j - ((deltafi + 2*epsilon)/eta);
               tmp_alpha_j = TMath::Min(tmp_alpha_j, high);
               tmp_alpha_j = TMath::Max(low, tmp_alpha_j);
               tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j);

               // update alpha_i* and alpha_j
               if (IsDiffSignificant(b_alpha_j, tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p, tmp_alpha_i, epsilon)) {
                  b_alpha_j   = tmp_alpha_j;
                  b_alpha_i_p = tmp_alpha_i;
               }
            }
            else
               terminated = kTRUE;

            caseC = kTRUE;
         }
      else if ((caseD == kFALSE) &&
               (b_alpha_i_p > 0 || (b_alpha_i == 0 && deltafi < 0)) &&
               (b_alpha_j_p > 0 || (b_alpha_j == 0 && deltafi > 0)))
         {
            // compute low, high w.r.t. alpha_i*, alpha_j*
            low  = TMath::Max( null, -gamma - b_cost_j );
            high = TMath::Min( b_cost_i, -gamma );

            if (low < high) {
               tmp_alpha_j = b_alpha_j_p + (deltafi/eta);
               tmp_alpha_j = TMath::Min(tmp_alpha_j, high);
               tmp_alpha_j = TMath::Max(low, tmp_alpha_j);
               tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j_p);

               if (IsDiffSignificant(b_alpha_j_p, tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p, tmp_alpha_i, epsilon)) {
                  b_alpha_j_p = tmp_alpha_j;
                  b_alpha_i_p = tmp_alpha_i;
               }
            }
            else
               terminated = kTRUE;

            caseD = kTRUE;
         }
      else
         terminated = kTRUE;
   }
   // TODO: add a comment explaining how this is calculated
   deltafi += ievt->GetDeltaAlpha()*(kernel_II - kernel_IJ) + jevt->GetDeltaAlpha()*(kernel_IJ - kernel_JJ);

   if (IsDiffSignificant(b_alpha_i,   ievt->GetAlpha(),   epsilon) ||
       IsDiffSignificant(b_alpha_j,   jevt->GetAlpha(),   epsilon) ||
       IsDiffSignificant(b_alpha_i_p, ievt->GetAlpha_p(), epsilon) ||
       IsDiffSignificant(b_alpha_j_p, jevt->GetAlpha_p(), epsilon)) {

      // TODO: check whether these conditions can be simplified
      // TODO: write documentation for this
      const Float_t diff_alpha_i = ievt->GetDeltaAlpha() + b_alpha_i_p - ievt->GetAlpha();
      const Float_t diff_alpha_j = jevt->GetDeltaAlpha() + b_alpha_j_p - jevt->GetAlpha();

      // update error cache
      Int_t k = 0;
      for (idIter = fInputData->begin(); idIter != fInputData->end(); idIter++) {
         k++;
         // there will be some changes in the Idx notation
         if ((*idIter)->GetIdx() == 0) {
            Float_t k_ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
            Float_t k_jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());

            (*idIter)->UpdateErrorCache(diff_alpha_i * k_ii + diff_alpha_j * k_jj);
         }
      }

      // store the new alphas in the SVEvents
      ievt->SetAlpha(b_alpha_i);
      jevt->SetAlpha(b_alpha_j);
      ievt->SetAlpha_p(b_alpha_i_p);
      jevt->SetAlpha_p(b_alpha_j_p);

      // TODO: update indexes

      // recompute fB_low and fB_up and the corresponding events

      fB_low = -1*1e30;
      fB_up  =  1e30;

      for (idIter = fInputData->begin(); idIter != fInputData->end(); idIter++) {
         if ((!(*idIter)->IsInI3()) && ((*idIter)->GetErrorCache() > fB_low)) {
            fB_low = (*idIter)->GetErrorCache();
            fTEventLow = (*idIter);

         }
         if ((!(*idIter)->IsInI2()) && ((*idIter)->GetErrorCache() < fB_up)) {
            fB_up = (*idIter)->GetErrorCache();
            fTEventUp = (*idIter);
         }
      }
      return kTRUE;
   } else return kFALSE;
}


////////////////////////////////////////////////////////////////////////////////
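/// Regression counterpart of ExamineExample(): refresh the error cache of jevt,
/// update fB_up / fB_low according to the index set the event belongs to
/// (I0a, I0b, I1, I2, I3), test the KKT conditions within 2*fTolerance and, if they
/// are violated, choose the partner event and call TakeStepReg().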

Bool_t TMVA::SVWorkingSet::ExamineExampleReg( TMVA::SVEvent* jevt )
{
   Float_t feps = 1e-7; // TODO: check which value works best
   SVEvent* ievt = 0;
   Float_t fErrorC_J = 0.;
   if (jevt->IsInI0()) {
      fErrorC_J = jevt->GetErrorCache();
   }
   else {
      Float_t *fKVals = jevt->GetLine();
      fErrorC_J = 0.;
      std::vector<TMVA::SVEvent*>::iterator idIter;

      UInt_t k = 0;
      for (idIter = fInputData->begin(); idIter != fInputData->end(); idIter++) {
         fErrorC_J -= (*idIter)->GetDeltaAlpha()*fKVals[k];
         k++;
      }

      fErrorC_J += jevt->GetTarget();
      jevt->SetErrorCache(fErrorC_J);

      if (jevt->IsInI1()) {
         if (fErrorC_J + feps < fB_up) {
            fB_up = fErrorC_J + feps;
            fTEventUp = jevt;
         }
         else if (fErrorC_J - feps > fB_low) {
            fB_low = fErrorC_J - feps;
            fTEventLow = jevt;
         }
      } else if ((jevt->IsInI2()) && (fErrorC_J + feps > fB_low)) {
         fB_low = fErrorC_J + feps;
         fTEventLow = jevt;
      } else if ((jevt->IsInI3()) && (fErrorC_J - feps < fB_up)) {
         fB_up = fErrorC_J - feps;
         fTEventUp = jevt;
      }
   }

   Bool_t converged = kTRUE;
   //case 1
   if (jevt->IsInI0a()) {
      if (fB_low - fErrorC_J + feps > 2*fTolerance) {
         converged = kFALSE;
         ievt = fTEventLow;
         if (fErrorC_J - feps - fB_up > fB_low - fErrorC_J + feps) {
            ievt = fTEventUp;
         }
      } else if (fErrorC_J - feps - fB_up > 2*fTolerance) {
         converged = kFALSE;
         ievt = fTEventUp;
         if (fB_low - fErrorC_J + feps > fErrorC_J - feps - fB_up) {
            ievt = fTEventLow;
         }
      }
   }

   //case 2
   if (jevt->IsInI0b()) {
      if (fB_low - fErrorC_J - feps > 2*fTolerance) {
         converged = kFALSE;
         ievt = fTEventLow;
         if (fErrorC_J + feps - fB_up > fB_low - fErrorC_J - feps) {
            ievt = fTEventUp;
         }
      } else if (fErrorC_J + feps - fB_up > 2*fTolerance) {
         converged = kFALSE;
         ievt = fTEventUp;
         if (fB_low - fErrorC_J - feps > fErrorC_J + feps - fB_up) {
            ievt = fTEventLow;
         }
      }
   }

   //case 3
   if (jevt->IsInI1()) {
      if (fB_low - fErrorC_J - feps > 2*fTolerance) {
         converged = kFALSE;
         ievt = fTEventLow;
         if (fErrorC_J + feps - fB_up > fB_low - fErrorC_J - feps) {
            ievt = fTEventUp;
         }
      } else if (fErrorC_J - feps - fB_up > 2*fTolerance) {
         converged = kFALSE;
         ievt = fTEventUp;
         if (fB_low - fErrorC_J + feps > fErrorC_J - feps - fB_up) {
            ievt = fTEventLow;
         }
      }
   }

   //case 4
   if (jevt->IsInI2()) {
      if (fErrorC_J + feps - fB_up > 2*fTolerance) {
         converged = kFALSE;
         ievt = fTEventUp;
      }
   }

   //case 5
   if (jevt->IsInI3()) {
      if (fB_low - fErrorC_J + feps > 2*fTolerance) {
         converged = kFALSE;
         ievt = fTEventLow;
      }
   }

   if (converged) return kFALSE;
   if (TakeStepReg(ievt, jevt)) return kTRUE;
   else return kFALSE;
}

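////////////////////////////////////////////////////////////////////////////////
/// Helper for the regression updates: returns kTRUE if the two alpha values differ
/// by more than the relative threshold eps*(a_i + a_j + eps).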
Bool_t TMVA::SVWorkingSet::IsDiffSignificant( Float_t a_i, Float_t a_j, Float_t eps )
{
   if (TMath::Abs(a_i - a_j) > eps*(a_i + a_j + eps)) return kTRUE;
   else return kFALSE;
}