ROOT  6.06/09
Reference Guide
SVWorkingSet.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andrzej Zemla
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : SVWorkingSet *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation *
12  * *
13  * Authors (alphabetical): *
14  * Marcin Wolter <Marcin.Wolter@cern.ch> - IFJ PAN, Krakow, Poland *
15  * Andrzej Zemla <azemla@cern.ch> - IFJ PAN, Krakow, Poland *
16  * (IFJ PAN: Henryk Niewodniczanski Inst. Nucl. Physics, Krakow, Poland) *
17  * *
18  * Copyright (c) 2005: *
19  * CERN, Switzerland *
20  * MPI-K Heidelberg, Germany *
21  * PAN, Krakow, Poland *
22  * *
23  * Redistribution and use in source and binary forms, with or without *
24  * modification, are permitted according to the terms listed in LICENSE *
25  * (http://tmva.sourceforge.net/LICENSE) *
26  **********************************************************************************/
27 
28 #include "TMath.h"
29 #include "TRandom3.h"
30 
31 #ifndef ROOT_TMVA_MsgLogger
32 #include "TMVA/MsgLogger.h"
33 #endif
34 #include "TMVA/SVWorkingSet.h"
35 #include "TMVA/SVKernelFunction.h"
36 #include "TMVA/SVEvent.h"
37 #include "TMVA/SVKernelMatrix.h"
38 
39 #include <vector>
40 #include <iostream>
41 
42 ////////////////////////////////////////////////////////////////////////////////
43 /// constructor
44 
46  : fdoRegression(kFALSE),
47  fInputData(0),
48  fSupVec(0),
49  fKFunction(0),
50  fKMatrix(0),
51  fTEventUp(0),
52  fTEventLow(0),
53  fB_low(1.),
54  fB_up(-1.),
55  fTolerance(0.01),
56  fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
57 {
58 }
59 
60 ////////////////////////////////////////////////////////////////////////////////
61 /// constructor
/// Builds the kernel matrix for inputVectors with kernelFunction, attaches row i
/// of that matrix to event i (SetLine) and records i as the event's position
/// (SetNs). For regression, each event's error cache starts at its target value.
/// For classification, random events are drawn until one with type flag -1 is
/// found (-> fTEventLow), and then one with type flag +1 (-> fTEventUp).
/// NOTE(review): this listing omits original lines 87-89 (the body of the
/// regression branch) and 108-109, so their effect cannot be reviewed here.
/// NOTE(review): both while(1) searches never terminate if the input contains no
/// event with the requested type flag. With empty input, fInputData->at(kk)
/// throws std::out_of_range. Consider validating the class composition first.
62 
63 TMVA::SVWorkingSet::SVWorkingSet(std::vector<TMVA::SVEvent*>*inputVectors, SVKernelFunction* kernelFunction,
64  Float_t tol, Bool_t doreg)
65  : fdoRegression(doreg),
66  fInputData(inputVectors),
67  fSupVec(0),
68  fKFunction(kernelFunction),
69  fTEventUp(0),
70  fTEventLow(0),
71  fB_low(1.),
72  fB_up(-1.),
73  fTolerance(tol),
74  fLogger( new MsgLogger( "SVWorkingSet", kINFO ) )
75 {
// the kernel matrix is owned by this object and deleted in the destructor
76  fKMatrix = new TMVA::SVKernelMatrix(inputVectors, kernelFunction);
77  Float_t *pt;
// give every event direct access to its kernel-matrix row and its row index
78  for( UInt_t i = 0; i < fInputData->size(); i++){
79  pt = fKMatrix->GetLine(i);
80  fInputData->at(i)->SetLine(pt);
81  fInputData->at(i)->SetNs(i);
82  if(fdoRegression) fInputData->at(i)->SetErrorCache(fInputData->at(i)->GetTarget());
83  }
// random choice of the initial partner events; default-constructed TRandom3
// (presumably a fixed default seed, so the choice is reproducible -- confirm)
84  TRandom3 rand;
85  UInt_t kk = rand.Integer(fInputData->size());
86  if(fdoRegression) {
90  }
91  else{
// search for an event of type -1 (starting event for fTEventLow)
92  while(1){
93  if(fInputData->at(kk)->GetTypeFlag()==-1){
94  fTEventLow = fInputData->at(kk);
95  break;
96  }
97  kk = rand.Integer(fInputData->size());
98  }
99 
// search for an event of type +1 (starting event for fTEventUp)
100  while (1){
101  if (fInputData->at(kk)->GetTypeFlag()==1) {
102  fTEventUp = fInputData->at(kk);
103  break;
104  }
105  kk = rand.Integer(fInputData->size());
106  }
107  }
110 }
111 
112 ////////////////////////////////////////////////////////////////////////////////
113 /// destructor
114 
116 {
117  if (fKMatrix != 0) {delete fKMatrix; fKMatrix = 0;}
118  delete fLogger;
119 }
120 
121 ////////////////////////////////////////////////////////////////////////////////
122 
124 {
125  SVEvent* ievt=0;
126  Float_t fErrorC_J = 0.;
127  if( jevt->GetIdx()==0) fErrorC_J = jevt->GetErrorCache();
128  else{
129  Float_t *fKVals = jevt->GetLine();
130  fErrorC_J = 0.;
131  std::vector<TMVA::SVEvent*>::iterator idIter;
132 
133  UInt_t k=0;
134  for(idIter = fInputData->begin(); idIter != fInputData->end(); idIter++){
135  if((*idIter)->GetAlpha()>0)
136  fErrorC_J += (*idIter)->GetAlpha()*(*idIter)->GetTypeFlag()*fKVals[k];
137  k++;
138  }
139 
140 
141  fErrorC_J -= jevt->GetTypeFlag();
142  jevt->SetErrorCache(fErrorC_J);
143 
144  if((jevt->GetIdx() == 1) && (fErrorC_J < fB_up )){
145  fB_up = fErrorC_J;
146  fTEventUp = jevt;
147  }
148  else if ((jevt->GetIdx() == -1)&&(fErrorC_J > fB_low)) {
149  fB_low = fErrorC_J;
150  fTEventLow = jevt;
151  }
152  }
153  Bool_t converged = kTRUE;
154 
155  if((jevt->GetIdx()>=0) && (fB_low - fErrorC_J > 2*fTolerance)) {
156  converged = kFALSE;
157  ievt = fTEventLow;
158  }
159 
160  if((jevt->GetIdx()<=0) && (fErrorC_J - fB_up > 2*fTolerance)) {
161  converged = kFALSE;
162  ievt = fTEventUp;
163  }
164 
165  if (converged) return kFALSE;
166 
167  if(jevt->GetIdx()==0){
168  if(fB_low - fErrorC_J > fErrorC_J - fB_up) ievt = fTEventLow;
169  else ievt = fTEventUp;
170  }
171 
172  if (TakeStep(ievt, jevt)) return kTRUE;
173  else return kFALSE;
174 }
175 
176 
177 ////////////////////////////////////////////////////////////////////////////////
178 
180 {
181  if (ievt == jevt) return kFALSE;
182  std::vector<TMVA::SVEvent*>::iterator idIter;
183  const Float_t epsilon = 1e-8; //make it 1-e6 or 1-e5 to make it faster
184 
185  Float_t type_I, type_J;
186  Float_t errorC_I, errorC_J;
187  Float_t alpha_I, alpha_J;
188 
189  Float_t newAlpha_I, newAlpha_J;
190  Int_t s;
191 
192  Float_t l, h, lobj = 0, hobj = 0;
193  Float_t eta;
194 
195  type_I = ievt->GetTypeFlag();
196  alpha_I = ievt->GetAlpha();
197  errorC_I = ievt->GetErrorCache();
198 
199  type_J = jevt->GetTypeFlag();
200  alpha_J = jevt->GetAlpha();
201  errorC_J = jevt->GetErrorCache();
202 
203  s = Int_t( type_I * type_J );
204 
205  Float_t c_i = ievt->GetCweight();
206 
207  Float_t c_j = jevt->GetCweight();
208 
209  // compute l, h objective function
210 
211  if (type_I == type_J) {
212  Float_t gamma = alpha_I + alpha_J;
213 
214  if ( c_i > c_j ) {
215  if ( gamma < c_j ) {
216  l = 0;
217  h = gamma;
218  }
219  else{
220  h = c_j;
221  if ( gamma < c_i )
222  l = 0;
223  else
224  l = gamma - c_i;
225  }
226  }
227  else {
228  if ( gamma < c_i ){
229  l = 0;
230  h = gamma;
231  }
232  else {
233  l = gamma - c_i;
234  if ( gamma < c_j )
235  h = gamma;
236  else
237  h = c_j;
238  }
239  }
240  }
241  else {
242  Float_t gamma = alpha_I - alpha_J;
243  if (gamma > 0) {
244  l = 0;
245  if ( gamma >= (c_i - c_j) )
246  h = c_i - gamma;
247  else
248  h = c_j;
249  }
250  else {
251  l = -gamma;
252  if ( (c_i - c_j) >= gamma)
253  h = c_j;
254  else
255  h = c_i - gamma;
256  }
257  }
258 
259  if (l == h) return kFALSE;
260  Float_t kernel_II, kernel_IJ, kernel_JJ;
261 
262  kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
263  kernel_IJ = fKMatrix->GetElement(ievt->GetNs(), jevt->GetNs());
264  kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
265 
266  eta = 2*kernel_IJ - kernel_II - kernel_JJ;
267  if (eta < 0) {
268  newAlpha_J = alpha_J + (type_J*( errorC_J - errorC_I ))/eta;
269  if (newAlpha_J < l) newAlpha_J = l;
270  else if (newAlpha_J > h) newAlpha_J = h;
271 
272  }
273 
274  else {
275 
276  Float_t c_I = eta/2;
277  Float_t c_J = type_J*( errorC_I - errorC_J ) - eta * alpha_J;
278  lobj = c_I * l * l + c_J * l;
279  hobj = c_I * h * h + c_J * h;
280 
281  if (lobj > hobj + epsilon) newAlpha_J = l;
282  else if (lobj < hobj - epsilon) newAlpha_J = h;
283  else newAlpha_J = alpha_J;
284  }
285 
286  if (TMath::Abs( newAlpha_J - alpha_J ) < ( epsilon * ( newAlpha_J + alpha_J+ epsilon ))){
287  return kFALSE;
288  //it spends here to much time... it is stupido
289  }
290  newAlpha_I = alpha_I - s*( newAlpha_J - alpha_J );
291 
292  if (newAlpha_I < 0) {
293  newAlpha_J += s* newAlpha_I;
294  newAlpha_I = 0;
295  }
296  else if (newAlpha_I > c_i) {
297  Float_t temp = newAlpha_I - c_i;
298  newAlpha_J += s * temp;
299  newAlpha_I = c_i;
300  }
301 
302  Float_t dL_I = type_I * ( newAlpha_I - alpha_I );
303  Float_t dL_J = type_J * ( newAlpha_J - alpha_J );
304 
305  Int_t k = 0;
306  for(idIter = fInputData->begin(); idIter != fInputData->end(); idIter++){
307  k++;
308  if((*idIter)->GetIdx()==0){
309  Float_t ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
310  Float_t jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());
311 
312  (*idIter)->UpdateErrorCache(dL_I * ii + dL_J * jj);
313  }
314  }
315  ievt->SetAlpha(newAlpha_I);
316  jevt->SetAlpha(newAlpha_J);
317  // set new indexes
318  SetIndex(ievt);
319  SetIndex(jevt);
320 
321  // update error cache
322  ievt->SetErrorCache(errorC_I + dL_I*kernel_II + dL_J*kernel_IJ);
323  jevt->SetErrorCache(errorC_J + dL_I*kernel_IJ + dL_J*kernel_JJ);
324 
325  // compute fI_low, fB_low
326 
327  fB_low = -1*1e30;
328  fB_up = 1e30;
329 
330  for(idIter = fInputData->begin(); idIter != fInputData->end(); idIter++){
331  if((*idIter)->GetIdx()==0){
332  if((*idIter)->GetErrorCache()> fB_low){
333  fB_low = (*idIter)->GetErrorCache();
334  fTEventLow = (*idIter);
335  }
336  if( (*idIter)->GetErrorCache()< fB_up){
337  fB_up =(*idIter)->GetErrorCache();
338  fTEventUp = (*idIter);
339  }
340  }
341  }
342 
343  // for optimized alfa's
344  if (fB_low < TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
345  if (ievt->GetErrorCache() > fB_low) {
346  fB_low = ievt->GetErrorCache();
347  fTEventLow = ievt;
348  }
349  else {
350  fB_low = jevt->GetErrorCache();
351  fTEventLow = jevt;
352  }
353  }
354 
355  if (fB_up > TMath::Max(ievt->GetErrorCache(), jevt->GetErrorCache())) {
356  if (ievt->GetErrorCache()< fB_low) {
357  fB_up =ievt->GetErrorCache();
358  fTEventUp = ievt;
359  }
360  else {
361  fB_up =jevt->GetErrorCache() ;
362  fTEventUp = jevt;
363  }
364  }
365  return kTRUE;
366 }
367 
368 ////////////////////////////////////////////////////////////////////////////////
369 
371 {
372  if((fB_up > fB_low - 2*fTolerance)) return kTRUE;
373  return kFALSE;
374 }
375 
376 ////////////////////////////////////////////////////////////////////////////////
377 /// train the SVM
378 
380 {
381 
382  Int_t numChanged = 0;
383  Int_t examineAll = 1;
384 
385  Float_t numChangedOld = 0;
386  Int_t deltaChanges = 0;
387  UInt_t numit = 0;
388 
389  std::vector<TMVA::SVEvent*>::iterator idIter;
390 
391  while ((numChanged > 0) || (examineAll > 0)) {
392  numChanged = 0;
393  if (examineAll) {
394  for (idIter = fInputData->begin(); idIter!=fInputData->end(); idIter++){
395  if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
396  else numChanged += (UInt_t)ExamineExampleReg(*idIter);
397  }
398  }
399  else {
400  for (idIter = fInputData->begin(); idIter!=fInputData->end(); idIter++) {
401  if ((*idIter)->IsInI0()) {
402  if(!fdoRegression) numChanged += (UInt_t)ExamineExample(*idIter);
403  else numChanged += (UInt_t)ExamineExampleReg(*idIter);
404  if (Terminated()) {
405  numChanged = 0;
406  break;
407  }
408  }
409  }
410  }
411 
412  if (examineAll == 1) examineAll = 0;
413  else if (numChanged == 0 || numChanged < 10 || deltaChanges > 3 ) examineAll = 1;
414 
415  if (numChanged == numChangedOld) deltaChanges++;
416  else deltaChanges = 0;
417  numChangedOld = numChanged;
418  ++numit;
419 
420  if (numit >= nMaxIter) {
421  *fLogger << kWARNING
422  << "Max number of iterations exceeded. "
423  << "Training may not be completed. Try use less Cost parameter" << Endl;
424  break;
425  }
426  }
427 }
428 
429 ////////////////////////////////////////////////////////////////////////////////
430 
432 {
433  if( (0< event->GetAlpha()) && (event->GetAlpha()< event->GetCweight()))
434  event->SetIdx(0);
435 
436  if( event->GetTypeFlag() == 1){
437  if( event->GetAlpha() == 0)
438  event->SetIdx(1);
439  else if( event->GetAlpha() == event->GetCweight() )
440  event->SetIdx(-1);
441  }
442  if( event->GetTypeFlag() == -1){
443  if( event->GetAlpha() == 0)
444  event->SetIdx(-1);
445  else if( event->GetAlpha() == event->GetCweight() )
446  event->SetIdx(1);
447  }
448 }
449 
450 ////////////////////////////////////////////////////////////////////////////////
/// Counts the events whose alpha is non-zero.
/// NOTE(review): the signature line (original line 452) is missing from this
/// listing. The computed `counter` is never used or reported, so the function
/// has no observable effect. Either report the count via fLogger or drop the
/// loop.
451 
453 {
454  std::vector<TMVA::SVEvent*>::iterator idIter;
455  UInt_t counter = 0;
456  for( idIter = fInputData->begin(); idIter != fInputData->end(); idIter++)
457  if((*idIter)->GetAlpha() !=0) counter++;
458 }
459 
460 ////////////////////////////////////////////////////////////////////////////////
461 
462 std::vector<TMVA::SVEvent*>* TMVA::SVWorkingSet::GetSupportVectors()
463 {
464  std::vector<TMVA::SVEvent*>::iterator idIter;
465  if( fSupVec != 0) {delete fSupVec; fSupVec = 0; }
466  fSupVec = new std::vector<TMVA::SVEvent*>(0);
467 
468  for( idIter = fInputData->begin(); idIter != fInputData->end(); idIter++){
469  if((*idIter)->GetDeltaAlpha() !=0){
470  fSupVec->push_back((*idIter));
471  }
472  }
473  return fSupVec;
474 }
475 
476 //for regression
// Regression SMO step on the pair (ievt, jevt). It jointly adjusts alpha and
// alpha_p of both events. Each of the four cases is tried at most once:
//   A: (alpha_i, alpha_j)     B: (alpha_i, alpha_j_p)
//   C: (alpha_i_p, alpha_j)   D: (alpha_i_p, alpha_j_p)
// If any multiplier changed significantly (IsDiffSignificant), the method:
//   - updates the error cache of Idx==0 events;
//   - stores the new multipliers;
//   - recomputes fB_low (max error cache over events not in I3) and fB_up
//     (min over events not in I2);
//   - returns kTRUE.
// Otherwise it returns kFALSE.
// NOTE(review): the signature line (original line 478) is missing from this
// listing; the declaration summary gives Bool_t TakeStepReg(SVEvent *, SVEvent *).
477 
479 {
480  if (ievt == jevt) return kFALSE;
481  std::vector<TMVA::SVEvent*>::iterator idIter;
482  const Float_t epsilon = 0.001*fTolerance;//TODO
483 
484  const Float_t kernel_II = fKMatrix->GetElement(ievt->GetNs(),ievt->GetNs());
485  const Float_t kernel_IJ = fKMatrix->GetElement(ievt->GetNs(),jevt->GetNs());
486  const Float_t kernel_JJ = fKMatrix->GetElement(jevt->GetNs(),jevt->GetNs());
487 
488  //compute eta & gamma
// eta = K_ii + K_jj - 2 K_ij; gamma = sum of the two events' GetDeltaAlpha()
489  const Float_t eta = -2*kernel_IJ + kernel_II + kernel_JJ;
490  const Float_t gamma = ievt->GetDeltaAlpha() + jevt->GetDeltaAlpha();
491 
492  //TODO CHECK WHAT IF ETA <0
493  //w.r.t Mercer's conditions it should never happen, but what if?
// NOTE(review): eta == 0 (e.g. two identical inputs) also divides by zero in
// every case below.
494 
495  Bool_t caseA, caseB, caseC, caseD, terminated;
496  caseA = caseB = caseC = caseD = terminated = kFALSE;
497  Float_t b_alpha_i, b_alpha_j, b_alpha_i_p, b_alpha_j_p; //temporary lagrange multipliers
498  const Float_t b_cost_i = ievt->GetCweight();
499  const Float_t b_cost_j = jevt->GetCweight();
500 
501  b_alpha_i = ievt->GetAlpha();
502  b_alpha_j = jevt->GetAlpha();
503  b_alpha_i_p = ievt->GetAlpha_p();
504  b_alpha_j_p = jevt->GetAlpha_p();
505 
506  //calculate deltafi
// difference of the two cached errors; it is never updated inside the loop below
507  Float_t deltafi = ievt->GetErrorCache()-jevt->GetErrorCache();
508 
509  // main loop
// NOTE(review): deltafi is not refreshed after a case changes the temporary
// multipliers, so later cases test and use the stale value.
510  while(!terminated) {
511  const Float_t null = 0.; //!!! dummy float null declaration because of problems with TMath::Max/Min(Float_t, Float_t) function
512  Float_t low, high;
513  Float_t tmp_alpha_i, tmp_alpha_j;
514  tmp_alpha_i = tmp_alpha_j = 0.;
515 
516  //TODO check this conditions, are they proper
517  if((caseA == kFALSE) && (b_alpha_i > 0 || (b_alpha_i_p == 0 && deltafi > 0)) && (b_alpha_j > 0 || (b_alpha_j_p == 0 && deltafi < 0)))
518  {
519  //compute low, high w.r.t a_i, a_j
// NOTE(review): tmp_alpha_j is clipped to [max(0, gamma - C_j), min(C_i, gamma)],
// which bounds alpha_i rather than alpha_j; this only agrees when C_i == C_j.
520  low = TMath::Max( null, gamma - b_cost_j );
521  high = TMath::Min( b_cost_i , gamma);
522 
523  if(low<high){
524  tmp_alpha_j = b_alpha_j - (deltafi/eta);
525  tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
526  tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
527  tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j);
528 
529  //update Li & Lj if change is significant (??)
530  if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
531  b_alpha_j = tmp_alpha_j;
532  b_alpha_i = tmp_alpha_i;
533  }
534 
535  }
536  else
537  terminated = kTRUE;
538 
539  caseA = kTRUE;
540  }
541  else if((caseB==kFALSE) && (b_alpha_i>0 || (b_alpha_i_p==0 && deltafi >2*epsilon )) && (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi>2*epsilon)))
542  {
543  //compute LH w.r.t. a_i, a_j*
544  low = TMath::Max( null, gamma ); //TODO
545  high = TMath::Min( b_cost_i , b_cost_j + gamma);
546 
547 
548  if(low<high){
549  tmp_alpha_j = b_alpha_j_p - ((deltafi-2*epsilon)/eta);
550  tmp_alpha_j = TMath::Min(tmp_alpha_j,high);
551  tmp_alpha_j = TMath::Max(low,tmp_alpha_j);
552  tmp_alpha_i = b_alpha_i - (tmp_alpha_j - b_alpha_j_p);
553 
554  //update alphai alphaj_p
555  if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i,tmp_alpha_i, epsilon)){
556  b_alpha_j_p = tmp_alpha_j;
557  b_alpha_i = tmp_alpha_i;
558  }
559  }
560  else
561  terminated = kTRUE;
562 
563  caseB = kTRUE;
564  }
565  else if((caseC==kFALSE) && (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi < -2*epsilon )) && (b_alpha_j>0 || (b_alpha_j_p==0 && deltafi< -2*epsilon)))
566  {
567  //compute LH w.r.t. alphai_p alphaj
568  low = TMath::Max(null, -gamma );
569  high = TMath::Min(b_cost_i, -gamma+b_cost_j);
570 
571  if(low<high){
572  tmp_alpha_j = b_alpha_j - ((deltafi+2*epsilon)/eta);
573  tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
574  tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
575  tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j);
576 
577  //update alphai_p alphaj
578  if( IsDiffSignificant(b_alpha_j,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
579  b_alpha_j = tmp_alpha_j;
580  b_alpha_i_p = tmp_alpha_i;
581  }
582  }
583  else
584  terminated = kTRUE;
585 
586  caseC = kTRUE;
587  }
588  else if((caseD == kFALSE) &&
589  (b_alpha_i_p>0 || (b_alpha_i==0 && deltafi <0 )) &&
590  (b_alpha_j_p>0 || (b_alpha_j==0 && deltafi >0 )))
591  {
592  //compute LH w.r.t. alphai_p alphaj_p
593  low = TMath::Max(null,-gamma - b_cost_j);
594  high = TMath::Min(b_cost_i, -gamma);
595 
596  if(low<high){
597  tmp_alpha_j = b_alpha_j_p + (deltafi/eta);
598  tmp_alpha_j = TMath::Min(tmp_alpha_j,high );
599  tmp_alpha_j = TMath::Max(low ,tmp_alpha_j);
600  tmp_alpha_i = b_alpha_i_p - (tmp_alpha_j - b_alpha_j_p);
601 
602  if( IsDiffSignificant(b_alpha_j_p,tmp_alpha_j, epsilon) || IsDiffSignificant(b_alpha_i_p,tmp_alpha_i, epsilon)){
603  b_alpha_j_p = tmp_alpha_j;
604  b_alpha_i_p = tmp_alpha_i;
605  }
606  }
607  else
608  terminated = kTRUE;
609 
610  caseD = kTRUE;
611  }
612  else
613  terminated = kTRUE;
614  }
615  // TODO ad commment how it was calculated
// NOTE(review): deltafi is not read after this statement, so this update is a
// dead store. It probably belongs inside the while loop, using the new
// multipliers.
616  deltafi += ievt->GetDeltaAlpha()*(kernel_II - kernel_IJ) + jevt->GetDeltaAlpha()*(kernel_IJ - kernel_JJ);
617 
618  if( IsDiffSignificant(b_alpha_i, ievt->GetAlpha(), epsilon) ||
619  IsDiffSignificant(b_alpha_j, jevt->GetAlpha(), epsilon) ||
620  IsDiffSignificant(b_alpha_i_p, ievt->GetAlpha_p(), epsilon) ||
621  IsDiffSignificant(b_alpha_j_p, jevt->GetAlpha_p(), epsilon) ){
622 
623  //TODO check if these conditions might be easier
624  //TODO write documentation for this
// NOTE(review): diff_alpha_* involves only the new alpha_p (b_alpha_*_p), not
// the new alpha (b_alpha_i / b_alpha_j). Verify against the definition of
// SVEvent::GetDeltaAlpha() and UpdateErrorCache().
625  const Float_t diff_alpha_i = ievt->GetDeltaAlpha()+b_alpha_i_p - ievt->GetAlpha();
626  const Float_t diff_alpha_j = jevt->GetDeltaAlpha()+b_alpha_j_p - jevt->GetAlpha();
627 
628  //update error cache
// only events with Idx == 0 get their error cache updated; k is unused
629  Int_t k = 0;
630  for(idIter = fInputData->begin(); idIter != fInputData->end(); idIter++){
631  k++;
632  //there will be some changes in Idx notation
633  if((*idIter)->GetIdx()==0){
634  Float_t k_ii = fKMatrix->GetElement(ievt->GetNs(), (*idIter)->GetNs());
635  Float_t k_jj = fKMatrix->GetElement(jevt->GetNs(), (*idIter)->GetNs());
636 
637  (*idIter)->UpdateErrorCache(diff_alpha_i * k_ii + diff_alpha_j * k_jj);
638  }
639  }
640 
641  //store new alphas in SVevents
642  ievt->SetAlpha(b_alpha_i);
643  jevt->SetAlpha(b_alpha_j);
644  ievt->SetAlpha_p(b_alpha_i_p);
645  jevt->SetAlpha_p(b_alpha_j_p);
646 
647  //TODO update Idexes
648 
649  // compute fI_low, fB_low
// fB_low: max error cache over all events not in I3;
// fB_up:  min error cache over all events not in I2.
// NOTE(review): unlike ExamineExampleReg, no +/- feps shift is applied here.
650 
651  fB_low = -1*1e30;
652  fB_up =1e30;
653 
654  for(idIter = fInputData->begin(); idIter != fInputData->end(); idIter++){
655  if((!(*idIter)->IsInI3()) && ((*idIter)->GetErrorCache()> fB_low)){
656  fB_low = (*idIter)->GetErrorCache();
657  fTEventLow = (*idIter);
658 
659  }
660  if((!(*idIter)->IsInI2()) && ((*idIter)->GetErrorCache()< fB_up)){
661  fB_up =(*idIter)->GetErrorCache();
662  fTEventUp = (*idIter);
663  }
664  }
665  return kTRUE;
666  } else return kFALSE;
667 }
668 
669 
670 ////////////////////////////////////////////////////////////////////////////////
671 
673 {
674  Float_t feps = 1e-7;// TODO check which value is the best
675  SVEvent* ievt=0;
676  Float_t fErrorC_J = 0.;
677  if( jevt->IsInI0()) {
678  fErrorC_J = jevt->GetErrorCache();
679  }
680  else{
681  Float_t *fKVals = jevt->GetLine();
682  fErrorC_J = 0.;
683  std::vector<TMVA::SVEvent*>::iterator idIter;
684 
685  UInt_t k=0;
686  for(idIter = fInputData->begin(); idIter != fInputData->end(); idIter++){
687  fErrorC_J -= (*idIter)->GetDeltaAlpha()*fKVals[k];
688  k++;
689  }
690 
691  fErrorC_J += jevt->GetTarget();
692  jevt->SetErrorCache(fErrorC_J);
693 
694  if(jevt->IsInI1()){
695  if(fErrorC_J + feps < fB_up ){
696  fB_up = fErrorC_J + feps;
697  fTEventUp = jevt;
698  }
699  else if(fErrorC_J -feps > fB_low) {
700  fB_low = fErrorC_J - feps;
701  fTEventLow = jevt;
702  }
703  }else if((jevt->IsInI2()) && (fErrorC_J + feps > fB_low)){
704  fB_low = fErrorC_J + feps;
705  fTEventLow = jevt;
706  }else if((jevt->IsInI3()) && (fErrorC_J - feps < fB_up)){
707  fB_up = fErrorC_J - feps;
708  fTEventUp = jevt;
709  }
710  }
711 
712  Bool_t converged = kTRUE;
713  //case 1
714  if(jevt->IsInI0a()){
715  if( fB_low -fErrorC_J + feps > 2*fTolerance){
716  converged = kFALSE;
717  ievt = fTEventLow;
718  if(fErrorC_J-feps-fB_up > fB_low-fErrorC_J+feps){
719  ievt = fTEventUp;
720  }
721  }else if(fErrorC_J -feps - fB_up > 2*fTolerance){
722  converged = kFALSE;
723  ievt = fTEventUp;
724  if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
725  ievt = fTEventLow;
726  }
727  }
728  }
729 
730  //case 2
731  if(jevt->IsInI0b()){
732  if( fB_low -fErrorC_J - feps > 2*fTolerance){
733  converged = kFALSE;
734  ievt = fTEventLow;
735  if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
736  ievt = fTEventUp;
737  }
738  }else if(fErrorC_J + feps - fB_up > 2*fTolerance){
739  converged = kFALSE;
740  ievt = fTEventUp;
741  if(fB_low - fErrorC_J-feps > fErrorC_J+feps -fB_up){
742  ievt = fTEventLow;
743  }
744  }
745  }
746 
747  //case 3
748  if(jevt->IsInI1()){
749  if( fB_low -fErrorC_J - feps > 2*fTolerance){
750  converged = kFALSE;
751  ievt = fTEventLow;
752  if(fErrorC_J+feps-fB_up > fB_low-fErrorC_J-feps){
753  ievt = fTEventUp;
754  }
755  }else if(fErrorC_J - feps - fB_up > 2*fTolerance){
756  converged = kFALSE;
757  ievt = fTEventUp;
758  if(fB_low - fErrorC_J+feps > fErrorC_J-feps -fB_up){
759  ievt = fTEventLow;
760  }
761  }
762  }
763 
764  //case 4
765  if(jevt->IsInI2()){
766  if( fErrorC_J + feps -fB_up > 2*fTolerance){
767  converged = kFALSE;
768  ievt = fTEventUp;
769  }
770  }
771 
772  //case 5
773  if(jevt->IsInI3()){
774  if(fB_low -fErrorC_J +feps > 2*fTolerance){
775  converged = kFALSE;
776  ievt = fTEventLow;
777  }
778  }
779 
780  if(converged) return kFALSE;
781  if (TakeStepReg(ievt, jevt)) return kTRUE;
782  else return kFALSE;
783 }
784 
786 {
787  if( TMath::Abs(a_i - a_j) > eps*(a_i + a_j + eps)) return kTRUE;
788  else return kFALSE;
789 }
790 
Random number generator class based on M.
Definition: TRandom3.h:29
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
Float_t * GetLine(UInt_t)
returns a row of the kernel matrix
void SetAlpha(Float_t alpha)
Definition: SVEvent.h:53
~SVWorkingSet()
destructor
float Float_t
Definition: RtypesCore.h:53
SVWorkingSet()
constructor
TH1 * h
Definition: legend2.C:5
void SetIndex(TMVA::SVEvent *)
message logger
void Train(UInt_t nIter=1000)
train the SVM
Short_t Min(Short_t a, Short_t b)
Definition: TMathBase.h:170
int Int_t
Definition: RtypesCore.h:41
Bool_t TakeStep(SVEvent *, SVEvent *)
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
Bool_t IsInI0a() const
Definition: SVEvent.h:76
Float_t GetAlpha_p() const
Definition: SVEvent.h:64
Short_t Abs(Short_t d)
Definition: TMathBase.h:110
null_t< F > null()
void SetErrorCache(Float_t err_cache)
Definition: SVEvent.h:55
Float_t GetTarget() const
Definition: SVEvent.h:74
virtual UInt_t Integer(UInt_t imax)
Returns a random integer on [ 0, imax-1 ].
Definition: TRandom.cxx:320
Float_t GetCweight() const
Definition: SVEvent.h:73
Bool_t IsInI2() const
Definition: SVEvent.h:80
Bool_t ExamineExampleReg(SVEvent *)
Int_t GetTypeFlag() const
Definition: SVEvent.h:68
double gamma(double x)
const double tol
TPaveText * pt
Float_t * GetLine() const
Definition: SVEvent.h:71
unsigned int UInt_t
Definition: RtypesCore.h:42
TLine * l
Definition: textangle.C:4
Float_t GetAlpha() const
Definition: SVEvent.h:63
Bool_t ExamineExample(SVEvent *)
REAL epsilon
Definition: triangle.c:617
Float_t GetErrorCache() const
Definition: SVEvent.h:67
void SetIdx(Int_t idx)
Definition: SVEvent.h:58
Bool_t TakeStepReg(SVEvent *, SVEvent *)
std::vector< TMVA::SVEvent * > * GetSupportVectors()
UInt_t GetNs() const
Definition: SVEvent.h:72
Bool_t IsInI1() const
Definition: SVEvent.h:79
Float_t GetDeltaAlpha() const
Definition: SVEvent.h:65
SVKernelMatrix * fKMatrix
Definition: SVWorkingSet.h:73
Bool_t IsDiffSignificant(Float_t, Float_t, Float_t)
Short_t Max(Short_t a, Short_t b)
Definition: TMathBase.h:202
Bool_t IsInI0b() const
Definition: SVEvent.h:77
void SetAlpha_p(Float_t alpha)
Definition: SVEvent.h:54
Bool_t IsInI3() const
Definition: SVEvent.h:81
std::vector< TMVA::SVEvent * > * fInputData
Definition: SVWorkingSet.h:70
Bool_t IsInI0() const
Definition: SVEvent.h:78
const Bool_t kTRUE
Definition: Rtypes.h:91
SVEvent * fTEventUp
Definition: SVWorkingSet.h:75
SVEvent * fTEventLow
Definition: SVWorkingSet.h:76
Int_t GetIdx() const
Definition: SVEvent.h:70