SVWorkingSet.cxx
1 // @(#)root/tmva $Id$
2 // Author: Andrzej Zemla
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : SVWorkingSet *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation *
12  * *
13  * Authors (alphabetical): *
14  * Marcin Wolter <Marcin.Wolter@cern.ch> - IFJ PAN, Krakow, Poland *
15  * Andrzej Zemla <azemla@cern.ch> - IFJ PAN, Krakow, Poland *
16  * (IFJ PAN: Henryk Niewodniczanski Inst. Nucl. Physics, Krakow, Poland) *
17  * *
18  * Copyright (c) 2005: *
19  * CERN, Switzerland *
20  * MPI-K Heidelberg, Germany *
21  * PAN, Krakow, Poland *
22  * *
23  * Redistribution and use in source and binary forms, with or without *
24  * modification, are permitted according to the terms listed in LICENSE *
25  * (http://tmva.sourceforge.net/LICENSE) *
26  **********************************************************************************/
27 
28 /*! \class TMVA::SVWorkingSet
29 \ingroup TMVA
30 Working set class for the Support Vector Machine. It runs the sequential minimal optimization (SMO) training loop for classification and regression; a minimal usage sketch is given after the source listing.
31 */
32 
33 #include "TMVA/SVWorkingSet.h"
34 
35 #include "TMVA/MsgLogger.h"
36 #include "TMVA/SVEvent.h"
37 #include "TMVA/SVKernelFunction.h"
38 #include "TMVA/SVKernelMatrix.h"
39 #include "TMVA/Types.h"
40 
41 
42 #include "TMath.h"
43 #include "TRandom3.h"
44 
45 #include <vector>
46 
47 ////////////////////////////////////////////////////////////////////////////////
48 /// constructor
49 
50 TMVA::SVWorkingSet::SVWorkingSet()
51  : fdoRegression(kFALSE),
52  fInputData(0),
53  fSupVec(0),
54  fKFunction(0),
55  fKMatrix(0),
56  fTEventUp(0),
57  fTEventLow(0),
58  fB_low(1.),
59  fB_up(-1.),
60  fTolerance(0.01),
61  fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
62 {
63 }
64 
65 ////////////////////////////////////////////////////////////////////////////////
66 /// constructor
67 
68 TMVA::SVWorkingSet::SVWorkingSet(std::vector<TMVA::SVEvent*>*inputVectors, SVKernelFunction* kernelFunction,
69  Float_t tol, Bool_t doreg)
70  : fdoRegression(doreg),
71  fInputData(inputVectors),
72  fSupVec(0),
73  fKFunction(kernelFunction),
74  fTEventUp(0),
75  fTEventLow(0),
76  fB_low(1.),
77  fB_up(-1.),
78  fTolerance(tol),
79  fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
80 {
81  fKMatrix = new TMVA::SVKernelMatrix(inputVectors, kernelFunction);
82  Float_t *pt;
83  for( UInt_t i = 0; i < fInputData->size(); i++){
84  pt = fKMatrix->GetLine(i);
85  fInputData->at(i)->SetLine(pt);
86  fInputData->at(i)->SetNs(i);
87  if(fdoRegression) fInputData->at(i)->SetErrorCache(fInputData->at(i)->GetTarget());
88  }
89  TRandom3 rand;
90  UInt_t kk = rand.Integer(fInputData->size());
91  if(fdoRegression) {
95  }
96  else{
97  while(1){
98  if(fInputData->at(kk)->GetTypeFlag()==-1){
99  fTEventLow = fInputData->at(kk);
100  break;
101  }
102  kk = rand.Integer(fInputData->size());
103  }
104 
105  while (1){
106  if (fInputData->at(kk)->GetTypeFlag()==1) {
107  fTEventUp = fInputData->at(kk);
108  break;
109  }
110  kk = rand.Integer(fInputData->size());
111  }
112  }
115 }
116 
117 ////////////////////////////////////////////////////////////////////////////////
118 /// destructor
119 
120 TMVA::SVWorkingSet::~SVWorkingSet()
121 {
122  if (fKMatrix != 0) {delete fKMatrix; fKMatrix = 0;}
123  delete fLogger;
124 }
125 
126 ////////////////////////////////////////////////////////////////////////////////
127 /// SMO outer-loop step: refresh the error cache of jevt and, if the optimality conditions are violated, choose a partner event and call TakeStep.
128 Bool_t TMVA::SVWorkingSet::ExamineExample(SVEvent* jevt)
129 {
130  SVEvent* ievt=0;
131  Float_t fErrorC_J = 0.;
132  if( jevt->GetIdx()==0) fErrorC_J = jevt->GetErrorCache();
133  else{
134  Float_t *fKVals = jevt->GetLine();
135  fErrorC_J = 0.;
136  std::vector<TMVA::SVEvent*>::iterator idIter;
137 
138  UInt_t k=0;
139  for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
140  if((*idIter)->GetAlpha()>0)
141  fErrorC_J += (*idIter)->GetAlpha()*(*idIter)->GetTypeFlag()*fKVals[k];
142  k++;
143  }
144 
145 
146  fErrorC_J -= jevt->GetTypeFlag();
147  jevt->SetErrorCache(fErrorC_J);
148 
149  if((jevt->GetIdx() == 1) && (fErrorC_J < fB_up )){
150  fB_up = fErrorC_J;
151  fTEventUp = jevt;
152  }
153  else if ((jevt->GetIdx() == -1)&&(fErrorC_J > fB_low)) {
154  fB_low = fErrorC_J;
155  fTEventLow = jevt;
156  }
157  }
158  Bool_t converged = kTRUE;
159 
160  if((jevt->GetIdx()>=0) && (fB_low - fErrorC_J > 2*fTolerance)) {
161  converged = kFALSE;
162  ievt = fTEventLow;
163  }
164 
165  if((jevt->GetIdx()<=0) && (fErrorC_J - fB_up > 2*fTolerance)) {
166  converged = kFALSE;
167  ievt = fTEventUp;
168  }
169 
170  if (converged) return kFALSE;
171 
172  if(jevt->GetIdx()==0){
173  if(fB_low - fErrorC_J > fErrorC_J - fB_up) ievt = fTEventLow;
174  else ievt = fTEventUp;
175  }
176 
177  if (TakeStep(ievt, jevt)) return kTRUE;
178  else return kFALSE;
179 }
180 
181 
182 ////////////////////////////////////////////////////////////////////////////////
183 /// Analytic SMO step for the pair (ievt, jevt): clip the new alphas to the feasible segment, then update the error caches and fB_low/fB_up.
184 Bool_t TMVA::SVWorkingSet::TakeStep(SVEvent* ievt, SVEvent* jevt)
185 {
186  if (ievt == jevt) return kFALSE;
187  std::vector<TMVA::SVEvent*>::iterator idIter;
188  const Float_t epsilon = 1e-8; // increase to 1e-6 or 1e-5 to make it faster
189 
190  Float_t type_I, type_J;
191  Float_t errorC_I, errorC_J;
192  Float_t alpha_I, alpha_J;
193 
194  Float_t newAlpha_I, newAlpha_J;
195  Int_t s;
196 
197  Float_t l, h, lobj = 0, hobj = 0;
198  Float_t eta;
199 
200  type_I = ievt->GetTypeFlag();
201  alpha_I = ievt->GetAlpha();
202  errorC_I = ievt->GetErrorCache();
203 
204  type_J = jevt->GetTypeFlag();
205  alpha_J = jevt->GetAlpha();
206  errorC_J = jevt->GetErrorCache();
207 
208  s = Int_t( type_I * type_J );
209 
210  Float_t c_i = ievt->GetCweight();
211 
212  Float_t c_j = jevt->GetCweight();
213 
214  // compute the clipping bounds l, h for the new alpha_J
215 
216  if (type_I == type_J) {
217  Float_t gamma = alpha_I + alpha_J;
218 
219  if ( c_i > c_j ) {
220  if ( gamma < c_j ) {
221  l = 0;
222  h = gamma;
223  }
224  else{
225  h = c_j;
226  if ( gamma < c_i )
227  l = 0;
228  else
229  l = gamma - c_i;
230  }
231  }
232  else {
233  if ( gamma < c_i ){
234  l = 0;
235  h = gamma;
236  }
237  else {
238  l = gamma - c_i;
239  if ( gamma < c_j )
240  h = gamma;
241  else
242  h = c_j;
243  }
244  }
245  }
246  else {
247  Float_t gamma = alpha_I - alpha_J;
248  if (gamma > 0) {
249  l = 0;
250  if ( gamma >= (c_i - c_j) )
251  h = c_i - gamma;
252  else
253  h = c_j;
254  }
255  else {
256  l = -gamma;
257  if ( (c_i - c_j) >= gamma)
258  h = c_j;
259  else
260  h = c_i - gamma;
261  }
262  }
263 
264  if (l == h) return kFALSE;
265  Float_t kernel_II, kernel_IJ, kernel_JJ;
266 
267  kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
268  kernel_IJ = fKMatrix->GetElement(ievt->GetNs(), jevt->GetNs());
269  kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
270 
271  eta = 2*kernel_IJ - kernel_II - kernel_JJ;
272  if (eta < 0) {
273  newAlpha_J = alpha_J + (type_J*( errorC_J - errorC_I ))/eta;
274  if (newAlpha_J < l) newAlpha_J = l;
275  else if (newAlpha_J > h) newAlpha_J = h;
276 
277  }
278 
279  else {
280 
281  Float_t c_I = eta/2;
282  Float_t c_J = type_J*( errorC_I - errorC_J ) - eta * alpha_J;
283  lobj = c_I * l * l + c_J * l;
284  hobj = c_I * h * h + c_J * h;
285 
286  if (lobj > hobj + epsilon) newAlpha_J = l;
287  else if (lobj < hobj - epsilon) newAlpha_J = h;
288  else newAlpha_J = alpha_J;
289  }
290 
291  if (TMath::Abs( newAlpha_J - alpha_J ) < ( epsilon * ( newAlpha_J + alpha_J+ epsilon ))){
292  return kFALSE;
293  // too much time is spent reaching this early return without making progress
294  }
295  newAlpha_I = alpha_I - s*( newAlpha_J - alpha_J );
296 
297  if (newAlpha_I < 0) {
298  newAlpha_J += s* newAlpha_I;
299  newAlpha_I = 0;
300  }
301  else if (newAlpha_I > c_i) {
302  Float_t temp = newAlpha_I - c_i;
303  newAlpha_J += s * temp;
304  newAlpha_I = c_i;
305  }
306 
307  Float_t dL_I = type_I * ( newAlpha_I - alpha_I );
308  Float_t dL_J = type_J * ( newAlpha_J - alpha_J );
309 
310  Int_t k = 0;
311  for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
312  k++;
313  if((*idIter)->GetIdx()==0){
314  Float_t ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
315  Float_t jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());
316 
317  (*idIter)->UpdateErrorCache(dL_I * ii + dL_J * jj);
318  }
319  }
320  ievt->SetAlpha(newAlpha_I);
321  jevt->SetAlpha(newAlpha_J);
322  // set new indexes
323  SetIndex(ievt);
324  SetIndex(jevt);
325 
326  // update error cache
327  ievt->SetErrorCache(errorC_I + dL_I*kernel_II + dL_J*kernel_IJ);
328  jevt->SetErrorCache(errorC_J + dL_I*kernel_IJ + dL_J*kernel_JJ);
329 
330  // recompute fB_low and fB_up over the events with 0 < alpha < C (Idx == 0)
331 
332  fB_low = -1*1e30;
333  fB_up = 1e30;
334 
335  for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
336  if((*idIter)->GetIdx()==0){
337  if((*idIter)->GetErrorCache()> fB_low){
338  fB_low = (*idIter)->GetErrorCache();
339  fTEventLow = (*idIter);
340  }
341  if( (*idIter)->GetErrorCache()< fB_up){
342  fB_up =(*idIter)->GetErrorCache();
343  fTEventUp = (*idIter);
344  }
345  }
346  }
347 
348  // for the optimized alphas
349  if (fB_low < TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
350  if (ievt->GetErrorCache() > fB_low) {
351  fB_low = ievt->GetErrorCache();
352  fTEventLow = ievt;
353  }
354  else {
355  fB_low = jevt->GetErrorCache();
356  fTEventLow = jevt;
357  }
358  }
359 
360  if (fB_up > TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
361  if (ievt->GetErrorCache()< fB_low) {
362  fB_up =ievt->GetErrorCache();
363  fTEventUp = ievt;
364  }
365  else {
366  fB_up =jevt->GetErrorCache() ;
367  fTEventUp = jevt;
368  }
369  }
370  return kTRUE;
371 }
372 
373 ////////////////////////////////////////////////////////////////////////////////
374 /// Convergence check: the working set is optimal when fB_up > fB_low - 2*fTolerance.
375 Bool_t TMVA::SVWorkingSet::Terminated()
376 {
377  if((fB_up > fB_low - 2*fTolerance)) return kTRUE;
378  return kFALSE;
379 }
380 
381 ////////////////////////////////////////////////////////////////////////////////
382 /// Train the SVM: run the SMO main loop for at most nMaxIter iterations.
383 
384 void TMVA::SVWorkingSet::Train(UInt_t nMaxIter)
385 {
386 
387  Int_t numChanged = 0;
388  Int_t examineAll = 1;
389 
390  Float_t numChangedOld = 0;
391  Int_t deltaChanges = 0;
392  UInt_t numit = 0;
393 
394  std::vector<TMVA::SVEvent*>::iterator idIter;
395 
396  while ((numChanged > 0) || (examineAll > 0)) {
397  if (fIPyCurrentIter) *fIPyCurrentIter = numit;
398  if (fExitFromTraining && *fExitFromTraining) break;
399  numChanged = 0;
400  if (examineAll) {
401  for (idIter = fInputData->begin(); idIter!=fInputData->end(); ++idIter){
402  if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
403  else numChanged += (UInt_t)ExamineExampleReg(*idIter);
404  }
405  }
406  else {
407  for (idIter = fInputData->begin(); idIter!=fInputData->end(); ++idIter) {
408  if ((*idIter)->IsInI0()) {
409  if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
410  else numChanged += (UInt_t)ExamineExampleReg(*idIter);
411  if (Terminated()) {
412  numChanged = 0;
413  break;
414  }
415  }
416  }
417  }
418 
419  if (examineAll == 1) examineAll = 0;
420  else if (numChanged == 0 || numChanged < 10 || deltaChanges > 3 ) examineAll = 1;
421 
422  if (numChanged == numChangedOld) deltaChanges++;
423  else deltaChanges = 0;
424  numChangedOld = numChanged;
425  ++numit;
426 
427  if (numit >= nMaxIter) {
428  *fLogger << kWARNING
429  << "Max number of iterations exceeded. "
430  << "Training may not be completed. Try using a smaller Cost parameter" << Endl;
431  break;
432  }
433  }
434 }
435 
436 ////////////////////////////////////////////////////////////////////////////////
437 /// Set the working-set index of an event from its alpha value: 0 for 0 < alpha < C, otherwise +1/-1 depending on the type flag.
438 void TMVA::SVWorkingSet::SetIndex(TMVA::SVEvent* event)
439 {
440  if( (0< event->GetAlpha()) && (event->GetAlpha()< event->GetCweight()))
441  event->SetIdx(0);
442 
443  if( event->GetTypeFlag() == 1){
444  if( event->GetAlpha() == 0)
445  event->SetIdx(1);
446  else if( event->GetAlpha() == event->GetCweight() )
447  event->SetIdx(-1);
448  }
449  if( event->GetTypeFlag() == -1){
450  if( event->GetAlpha() == 0)
451  event->SetIdx(-1);
452  else if( event->GetAlpha() == event->GetCweight() )
453  event->SetIdx(1);
454  }
455 }
456 
457 ////////////////////////////////////////////////////////////////////////////////
458 /// Count the events with non-zero alpha (candidate support vectors); the counter is currently not printed.
459 void TMVA::SVWorkingSet::PrintStat()
460 {
461  std::vector<TMVA::SVEvent*>::iterator idIter;
462  UInt_t counter = 0;
463  for( idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter)
464  if((*idIter)->GetAlpha() !=0) counter++;
465 }
466 
467 ////////////////////////////////////////////////////////////////////////////////
468 /// Collect all events with non-zero DeltaAlpha into fSupVec and return it.
469 std::vector<TMVA::SVEvent*>* TMVA::SVWorkingSet::GetSupportVectors()
470 {
471  std::vector<TMVA::SVEvent*>::iterator idIter;
472  if( fSupVec != 0) {delete fSupVec; fSupVec = 0; }
473  fSupVec = new std::vector<TMVA::SVEvent*>(0);
474 
475  for( idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
476  if((*idIter)->GetDeltaAlpha() !=0){
477  fSupVec->push_back((*idIter));
478  }
479  }
480  return fSupVec;
481 }
482 
483 //for regression
484 /// SMO optimization step for regression: update the alpha/alpha* pair of (ievt, jevt).
485 Bool_t TMVA::SVWorkingSet::TakeStepReg(SVEvent* ievt, SVEvent* jevt)
486 {
487  if (ievt == jevt) return kFALSE;
488  std::vector<TMVA::SVEvent*>::iterator idIter;
489  const Float_t epsilon = 0.001*fTolerance;//TODO
490 
491  const Float_t kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
492  const Float_t kernel_IJ = fKMatrix->GetElement(ievt->GetNs(),jevt->GetNs());
493  const Float_t kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
494 
495  //compute eta & gamma
496  const Float_t eta = -2*kernel_IJ + kernel_II + kernel_JJ;
497  const Float_t gamma = ievt->GetDeltaAlpha() + jevt->GetDeltaAlpha();
498 
499  // TODO: check what happens if eta < 0
500  // for a kernel satisfying Mercer's conditions this should never happen, but the case is not handled
501 
502  Bool_t caseA, caseB, caseC, caseD, terminated;
503  caseA = caseB = caseC = caseD = terminated = kFALSE;
504  Float_t b_alpha_i, b_alpha_j, b_alpha_i_p, b_alpha_j_p; //temporary Lagrange multipliers
505  const Float_t b_cost_i = ievt->GetCweight();
506  const Float_t b_cost_j = jevt->GetCweight();
507 
508  b_alpha_i = ievt->GetAlpha();
509  b_alpha_j = jevt->GetAlpha();
510  b_alpha_i_p = ievt->GetAlpha_p();
511  b_alpha_j_p = jevt->GetAlpha_p();
512 
513  //calculate deltafi
514  Float_t deltafi = ievt->GetErrorCache()-jevt->GetErrorCache();
515 
516  // main loop
517  while(!terminated) {
518  const Float_t null = 0.; // dummy Float_t zero, needed to select the TMath::Max/Min(Float_t, Float_t) overloads
519  Float_t low, high;
520  Float_t tmp_alpha_i, tmp_alpha_j;
521  tmp_alpha_i = tmp_alpha_j = 0.;
522 
523  // TODO: check whether these conditions are correct
524  if((caseA == kFALSE) && (b_alpha_i > 0 || (b_alpha_i_p == 0 && deltafi > 0)) && (b_alpha_j > 0 || (b_alpha_j_p == 0 && deltafi < 0)))
525  {
526  //compute low, high w.r.t a_i, a_j
527  low = TMath::Max( null, gamma - b_cost_j );
528  high = TMath::Min( b_cost_i , gamma);
529 
530  if(low<high){
531  tmp_alpha_j = b_alpha_j - (deltafi/eta);
532  tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
533  tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
534  tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j);
535 
536  // update the Lagrange multipliers if the change is significant
537  if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
538  b_alpha_j = tmp_alpha_j;
539  b_alpha_i = tmp_alpha_i;
540  }
541 
542  }
543  else
544  terminated = kTRUE;
545 
546  caseA = kTRUE;
547  }
548  else if((caseB==kFALSE) && (b_alpha_i>0 || (b_alpha_i_p==0 && deltafi >2*epsilon )) && (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi>2*epsilon)))
549  {
550  //compute LH w.r.t. a_i, a_j*
551  low = TMath::Max( null, gamma ); //TODO
552  high = TMath::Min( b_cost_i , b_cost_j + gamma);
553 
554 
555  if(low<high){
556  tmp_alpha_j = b_alpha_j_p - ((deltafi-2*epsilon)/eta);
557  tmp_alpha_j = TMath::Min(tmp_alpha_j,high);
558  tmp_alpha_j = TMath::Max(low,tmp_alpha_j);
559  tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j_p);
560 
561  //update alphai alphaj_p
562  if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
563  b_alpha_j_p = tmp_alpha_j;
564  b_alpha_i = tmp_alpha_i;
565  }
566  }
567  else
568  terminated = kTRUE;
569 
570  caseB = kTRUE;
571  }
572  else if((caseC==kFALSE) && (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi < -2*epsilon )) && (b_alpha_j>0 || (b_alpha_j_p==0 && deltafi< -2*epsilon)))
573  {
574  //compute LH w.r.t. alphai_p alphaj
575  low = TMath::Max(null, -gamma );
576  high = TMath::Min(b_cost_i, -gamma+b_cost_j);
577 
578  if(low<high){
579  tmp_alpha_j = b_alpha_j - ((deltafi+2*epsilon)/eta);
580  tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
581  tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
582  tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j);
583 
584  //update alphai_p alphaj
585  if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
586  b_alpha_j = tmp_alpha_j;
587  b_alpha_i_p = tmp_alpha_i;
588  }
589  }
590  else
591  terminated = kTRUE;
592 
593  caseC = kTRUE;
594  }
595  else if((caseD == kFALSE) &&
596  (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi <0 )) &&
597  (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi >0 )))
598  {
599  //compute LH w.r.t. alphai_p alphaj_p
600  low = TMath::Max(null,-gamma - b_cost_j);
601  high = TMath::Min(b_cost_i, -gamma);
602 
603  if(low<high){
604  tmp_alpha_j = b_alpha_j_p + (deltafi/eta);
605  tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
606  tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
607  tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j_p);
608 
609  if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
610  b_alpha_j_p = tmp_alpha_j;
611  b_alpha_i_p = tmp_alpha_i;
612  }
613  }
614  else
615  terminated = kTRUE;
616 
617  caseD = kTRUE;
618  }
619  else
620  terminated = kTRUE;
621  }
622  // TODO: add a comment explaining how this is calculated
623  deltafi += ievt->GetDeltaAlpha()*(kernel_II - kernel_IJ) + jevt->GetDeltaAlpha()*(kernel_IJ - kernel_JJ);
624 
625  if( IsDiffSignificant(b_alpha_i, ievt->GetAlpha(), epsilon) ||
626  IsDiffSignificant(b_alpha_j, jevt->GetAlpha(), epsilon) ||
627  IsDiffSignificant(b_alpha_i_p, ievt->GetAlpha_p(), epsilon) ||
628  IsDiffSignificant(b_alpha_j_p, jevt->GetAlpha_p(), epsilon) ){
629 
630  // TODO: check whether these conditions can be simplified
631  // TODO: write documentation for this
632  const Float_t diff_alpha_i = ievt->GetDeltaAlpha()+b_alpha_i_p - ievt->GetAlpha();
633  const Float_t diff_alpha_j = jevt->GetDeltaAlpha()+b_alpha_j_p - jevt->GetAlpha();
634 
635  //update error cache
636  Int_t k = 0;
637  for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
638  k++;
639  //there will be some changes in Idx notation
640  if((*idIter)->GetIdx()==0){
641  Float_t k_ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
642  Float_t k_jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());
643 
644  (*idIter)->UpdateErrorCache(diff_alpha_i * k_ii + diff_alpha_j * k_jj);
645  }
646  }
647 
648  //store new alphas in SVevents
649  ievt->SetAlpha(b_alpha_i);
650  jevt->SetAlpha(b_alpha_j);
651  ievt->SetAlpha_p(b_alpha_i_p);
652  jevt->SetAlpha_p(b_alpha_j_p);
653 
654  // TODO: update indexes
655 
656  // recompute fB_low and fB_up
657 
658  fB_low = -1*1e30;
659  fB_up =1e30;
660 
661  for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
662  if((!(*idIter)->IsInI3()) && ((*idIter)->GetErrorCache()> fB_low)){
663  fB_low = (*idIter)->GetErrorCache();
664  fTEventLow = (*idIter);
665 
666  }
667  if((!(*idIter)->IsInI2()) && ((*idIter)->GetErrorCache()< fB_up)){
668  fB_up =(*idIter)->GetErrorCache();
669  fTEventUp = (*idIter);
670  }
671  }
672  return kTRUE;
673  } else return kFALSE;
674 }
675 
676 
677 ////////////////////////////////////////////////////////////////////////////////
678 /// SMO outer-loop step for regression: refresh the error cache of jevt and call TakeStepReg if the optimality conditions are violated.
679 Bool_t TMVA::SVWorkingSet::ExamineExampleReg(SVEvent* jevt)
680 {
681  Float_t feps = 1e-7;// TODO check which value is the best
682  SVEvent* ievt=0;
683  Float_t fErrorC_J = 0.;
684  if( jevt->IsInI0()) {
685  fErrorC_J = jevt->GetErrorCache();
686  }
687  else{
688  Float_t *fKVals = jevt->GetLine();
689  fErrorC_J = 0.;
690  std::vector<TMVA::SVEvent*>::iterator idIter;
691 
692  UInt_t k=0;
693  for(idIter = fInputData->begin(); idIter != fInputData->end(); ++idIter){
694  fErrorC_J -= (*idIter)->GetDeltaAlpha()*fKVals[k];
695  k++;
696  }
697 
698  fErrorC_J += jevt->GetTarget();
699  jevt->SetErrorCache(fErrorC_J);
700 
701  if(jevt->IsInI1()){
702  if(fErrorC_J + feps < fB_up ){
703  fB_up = fErrorC_J + feps;
704  fTEventUp = jevt;
705  }
706  else if(fErrorC_J -feps > fB_low) {
707  fB_low = fErrorC_J - feps;
708  fTEventLow = jevt;
709  }
710  }else if((jevt->IsInI2()) && (fErrorC_J + feps > fB_low)){
711  fB_low = fErrorC_J + feps;
712  fTEventLow = jevt;
713  }else if((jevt->IsInI3()) && (fErrorC_J - feps < fB_up)){
714  fB_up = fErrorC_J - feps;
715  fTEventUp = jevt;
716  }
717  }
718 
719  Bool_t converged = kTRUE;
720  //case 1
721  if(jevt->IsInI0a()){
722  if( fB_low -fErrorC_J + feps > 2*fTolerance){
723  converged = kFALSE;
724  ievt = fTEventLow;
725  if(fErrorC_J-feps-fB_up > fB_low-fErrorC_J+feps){
726  ievt = fTEventUp;
727  }
728  }else if(fErrorC_J -feps - fB_up > 2*fTolerance){
729  converged = kFALSE;
730  ievt = fTEventUp;
731  if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
732  ievt = fTEventLow;
733  }
734  }
735  }
736 
737  //case 2
738  if(jevt->IsInI0b()){
739  if( fB_low -fErrorC_J - feps > 2*fTolerance){
740  converged = kFALSE;
741  ievt = fTEventLow;
742  if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
743  ievt = fTEventUp;
744  }
745  }else if(fErrorC_J + feps - fB_up > 2*fTolerance){
746  converged = kFALSE;
747  ievt = fTEventUp;
748  if(fB_low - fErrorC_J-feps > fErrorC_J+feps -fB_up){
749  ievt = fTEventLow;
750  }
751  }
752  }
753 
754  //case 3
755  if(jevt->IsInI1()){
756  if( fB_low -fErrorC_J - feps > 2*fTolerance){
757  converged = kFALSE;
758  ievt = fTEventLow;
759  if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
760  ievt = fTEventUp;
761  }
762  }else if(fErrorC_J - feps - fB_up > 2*fTolerance){
763  converged = kFALSE;
764  ievt = fTEventUp;
765  if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
766  ievt = fTEventLow;
767  }
768  }
769  }
770 
771  //case 4
772  if(jevt->IsInI2()){
773  if( fErrorC_J + feps -fB_up > 2*fTolerance){
774  converged = kFALSE;
775  ievt = fTEventUp;
776  }
777  }
778 
779  //case 5
780  if(jevt->IsInI3()){
781  if(fB_low -fErrorC_J +feps > 2*fTolerance){
782  converged = kFALSE;
783  ievt = fTEventLow;
784  }
785  }
786 
787  if(converged) return kFALSE;
788  if (TakeStepReg(ievt, jevt)) return kTRUE;
789  else return kFALSE;
790 }
791 /// Return kTRUE if a_i and a_j differ by more than eps*(a_i + a_j + eps).
792 Bool_t TMVA::SVWorkingSet::IsDiffSignificant(Float_t a_i, Float_t a_j, Float_t eps)
793 {
794  if( TMath::Abs(a_i - a_j) > eps*(a_i + a_j + eps)) return kTRUE;
795  else return kFALSE;
796 }
797 
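For orientation, the following is a minimal usage sketch of the class shown above; it is not part of the original file. It assumes that the vector of TMVA::SVEvent objects has already been prepared by the caller (in TMVA this is done by MethodSVM), and that the single-Float_t constructor of SVKernelFunction builds an RBF kernel of the given width. The helper name TrainClassifier, the kernel width 0.5, the tolerance 0.01 and the iteration limit 1000 are illustrative choices only.

#include <vector>

#include "TMVA/SVEvent.h"
#include "TMVA/SVKernelFunction.h"
#include "TMVA/SVWorkingSet.h"

// Train an SVM classifier on a prepared set of events and return the support
// vectors. The SVEvent objects remain owned by the caller.
std::vector<TMVA::SVEvent*>* TrainClassifier(std::vector<TMVA::SVEvent*>* events)
{
   // Assumed RBF kernel of width (gamma) 0.5, built with the single-argument
   // SVKernelFunction constructor; illustrative value only.
   TMVA::SVKernelFunction kernel(0.5);

   // Working set for classification: tolerance 0.01, doreg = kFALSE.
   TMVA::SVWorkingSet workingSet(events, &kernel, 0.01, kFALSE);

   // Run the SMO main loop for at most 1000 iterations.
   workingSet.Train(1000);

   // Events with non-zero DeltaAlpha are the support vectors. The returned
   // vector is allocated by GetSupportVectors() and holds non-owning pointers
   // to the caller's events; the caller deletes the vector itself.
   return workingSet.GetSupportVectors();
}

Regression training follows the same pattern with doreg = kTRUE; Train() then drives ExamineExampleReg()/TakeStepReg() instead of ExamineExample()/TakeStep().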