ROOT  6.06/09
Reference Guide
RuleEnsemble.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : RuleEnsemble *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * A class generating an ensemble of rules *
12  * Input: a forest of decision trees *
13  * Output: an ensemble of rules *
14  * *
15  * Authors (alphabetical): *
16  * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
17  * Helge Voss <Helge.Voss@cern.ch> - MPI-KP Heidelberg, GER *
18  * *
19  * Copyright (c) 2005: *
20  * CERN, Switzerland *
21  * Iowa State U. *
22  * MPI-K Heidelberg, Germany *
23  * *
24  * Redistribution and use in source and binary forms, with or without *
25  * modification, are permitted according to the terms listed in LICENSE *
26  * (http://tmva.sourceforge.net/LICENSE) *
27  **********************************************************************************/
28 
29 #include <algorithm>
30 #include <list>
31 #include <cstdlib>
32 #include <iomanip>
33 
34 #include "TRandom3.h"
35 #include "TH1F.h"
36 #include "TMVA/RuleEnsemble.h"
37 #include "TMVA/RuleFit.h"
38 #include "TMVA/MethodRuleFit.h"
39 #include "TMVA/Tools.h"
40 
41 ////////////////////////////////////////////////////////////////////////////////
42 /// constructor
43 
// NOTE(review): the constructor signature line (scrape line 44) is missing from
// this capture — presumably RuleEnsemble( RuleFit *rf ); confirm against the
// original TMVA source. Members are set to defaults, then Initialize(rf)
// wires the ensemble to its RuleFit owner.
45  : fLearningModel ( kFull )
46  , fImportanceCut ( 0 )
47  , fLinQuantile ( 0.025 ) // default quantile for killing outliers in linear terms
48  , fOffset ( 0 )
49  , fAverageSupport ( 0.8 )
50  , fAverageRuleSigma( 0.4 ) // default value - used if only linear model is chosen
51  , fRuleFSig ( 0 )
52  , fRuleNCave ( 0 )
53  , fRuleNCsig ( 0 )
54  , fRuleMinDist ( 1e-3 ) // closest allowed 'distance' between two rules
55  , fNRulesGenerated ( 0 )
56  , fEvent ( 0 )
57  , fEventCacheOK ( true )
58  , fRuleMapOK ( true )
59  , fRuleMapInd0 ( 0 )
60  , fRuleMapInd1 ( 0 )
61  , fRuleMapEvents ( 0 )
62  , fLogger( new MsgLogger("RuleFit") ) // owned; deleted in the destructor
63 {
64  Initialize( rf );
65 }
66 
67 ////////////////////////////////////////////////////////////////////////////////
68 /// copy constructor
69 
// NOTE(review): the copy-constructor signature line (scrape line 70) is missing
// from this capture. Only a few members are initialized here; the rest are
// filled in by Copy(other). A fresh MsgLogger is allocated (not shared).
71  : fAverageSupport ( 1 )
72  , fEvent(0)
73  , fRuleMapEvents(0)
74  , fRuleFit(0)
75  , fLogger( new MsgLogger("RuleFit") )
76 {
77  Copy( other );
78 }
79 
80 ////////////////////////////////////////////////////////////////////////////////
81 /// constructor
82 
// NOTE(review): the constructor signature line (scrape line 83) is missing from
// this capture — presumably the default constructor. Unlike the RuleFit*
// constructor above, this one also initializes fImportanceRef and fRuleFit,
// and performs no further setup in the body.
84  : fLearningModel ( kFull )
85  , fImportanceCut ( 0 )
86  , fLinQuantile ( 0.025 ) // default quantile for killing outliers in linear terms
87  , fOffset ( 0 )
88  , fImportanceRef ( 1.0 )
89  , fAverageSupport ( 0.8 )
90  , fAverageRuleSigma( 0.4 ) // default value - used if only linear model is chosen
91  , fRuleFSig ( 0 )
92  , fRuleNCave ( 0 )
93  , fRuleNCsig ( 0 )
94  , fRuleMinDist ( 1e-3 ) // closest allowed 'distance' between two rules
95  , fNRulesGenerated ( 0 )
96  , fEvent ( 0 )
97  , fEventCacheOK ( true )
98  , fRuleMapOK ( true )
99  , fRuleMapInd0 ( 0 )
100  , fRuleMapInd1 ( 0 )
101  , fRuleMapEvents ( 0 )
102  , fRuleFit ( 0 )
103  , fLogger( new MsgLogger("RuleFit") )
104 {
105 }
106 
107 ////////////////////////////////////////////////////////////////////////////////
108 /// destructor
109 
// Destructor body (signature line missing from this capture): the ensemble
// owns the Rule objects in fRules and its MsgLogger — both are freed here.
111 {
112  for ( std::vector<Rule *>::iterator itrRule = fRules.begin(); itrRule != fRules.end(); itrRule++ ) {
113  delete *itrRule;
114  }
115  // NOTE: Should not delete the histos fLinPDFB/S since they are delete elsewhere
116  delete fLogger;
117 }
118 
119 ////////////////////////////////////////////////////////////////////////////////
120 /// Initializes all member variables with default values
121 
// Initialize body (signature line missing from this capture): binds the
// ensemble to the given RuleFit, sizes the per-variable containers to the
// number of input variables, and marks every linear term as usable.
123 {
124  SetAverageRuleSigma(0.4); // default value - used if only linear model is chosen
125  fRuleFit = rf;
126  UInt_t nvars = GetMethodBase()->GetNvar();
127  fVarImportance.clear();
128  fLinPDFB.clear();
129  fLinPDFS.clear();
130  //
131  fVarImportance.resize( nvars,0.0 );
132  fLinPDFB.resize( nvars,0 ); // histogram pointers start out null; allocated in MakeLinearTerms()
133  fLinPDFS.resize( nvars,0 );
134  fImportanceRef = 1.0;
135  for (UInt_t i=0; i<nvars; i++) { // a priori all linear terms are equally valid
136  fLinTermOK.push_back(kTRUE);
137  }
138 }
139 
140 ////////////////////////////////////////////////////////////////////////////////
141 
// SetMsgType body (signature line missing from this capture): forwards the
// minimum message severity to the logger.
143  fLogger->SetMinType(t);
144 }
145 
146 
147 ////////////////////////////////////////////////////////////////////////////////
148 ///
149 /// Get a pointer to the original MethodRuleFit.
150 ///
151 
// Body only (signature line missing from this capture): null-safe forwarding
// accessor — returns 0 when no RuleFit is attached.
153 {
154  return ( fRuleFit==0 ? 0:fRuleFit->GetMethodRuleFit());
155 }
156 
157 ////////////////////////////////////////////////////////////////////////////////
158 ///
159 /// Get a pointer to the original MethodRuleFit.
160 ///
161 
163 {
164  return ( fRuleFit==0 ? 0:fRuleFit->GetMethodBase());
165 }
166 
167 ////////////////////////////////////////////////////////////////////////////////
168 /// create model
169 
// MakeModel body (signature line missing from this capture): builds the full
// model in order — rules from the forest, linear terms, the event/rule map,
// rule supports, statistics — then prints a generation summary.
171 {
172  MakeRules( fRuleFit->GetForest() );
173 
174  MakeLinearTerms();
175 
176  MakeRuleMap();
177 
178  CalcRuleSupport();
179 
180  RuleStatistics();
181 
182  PrintRuleGen();
183 }
184 
185 ////////////////////////////////////////////////////////////////////////////////
186 ///
187 /// Calculates sqrt(Sum(a_i^2)), i=1..N (NOTE do not include a0)
188 ///
189 
// Body only (signature line missing from this capture).
// NOTE(review): the doxygen header above says "sqrt(Sum(a_i^2))", but the code
// returns the plain sum of squares — no sqrt is applied. Confirm which is
// intended before relying on the documented contract.
191 {
192  Int_t ncoeffs = fRules.size();
193  if (ncoeffs<1) return 0; // no rules -> radius 0
194  //
195  Double_t sum2=0;
196  Double_t val;
197  for (Int_t i=0; i<ncoeffs; i++) {
198  val = fRules[i]->GetCoefficient();
199  sum2 += val*val;
200  }
201  return sum2;
202 }
203 
204 ////////////////////////////////////////////////////////////////////////////////
205 /// reset all rule coefficients
206 
// Body only (signature line missing from this capture): zeroes the offset a0
// and every rule coefficient.
208 {
209  fOffset = 0.0;
210  UInt_t nrules = fRules.size();
211  for (UInt_t i=0; i<nrules; i++) {
212  fRules[i]->SetCoefficient(0.0);
213  }
214 }
215 
216 ////////////////////////////////////////////////////////////////////////////////
217 /// set all rule coefficients
218 
219 void TMVA::RuleEnsemble::SetCoefficients( const std::vector< Double_t > & v )
220 {
221  UInt_t nrules = fRules.size();
222  if (v.size()!=nrules) {
223  Log() << kFATAL << "<SetCoefficients> - BUG TRAP - input vector worng size! It is = " << v.size()
224  << " when it should be = " << nrules << Endl;
225  }
226  for (UInt_t i=0; i<nrules; i++) {
227  fRules[i]->SetCoefficient(v[i]);
228  }
229 }
230 
231 ////////////////////////////////////////////////////////////////////////////////
232 /// Retrieve all rule coefficients
233 
234 void TMVA::RuleEnsemble::GetCoefficients( std::vector< Double_t > & v )
235 {
236  UInt_t nrules = fRules.size();
237  v.resize(nrules);
238  if (nrules==0) return;
239  //
240  for (UInt_t i=0; i<nrules; i++) {
241  v[i] = (fRules[i]->GetCoefficient());
242  }
243 }
244 
245 ////////////////////////////////////////////////////////////////////////////////
246 /// get list of training events from the rule fitter
247 
248 const std::vector<const TMVA::Event*>* TMVA::RuleEnsemble::GetTrainingEvents() const
249 {
250  return &(fRuleFit->GetTrainingEvents());
251 }
252 
253 ////////////////////////////////////////////////////////////////////////////////
254 /// get the training event from the rule fitter
255 
// Body only (signature line missing from this capture): forwards to the rule
// fitter's i-th training event.
257 {
258  return fRuleFit->GetTrainingEvent(i);
259 }
260 
261 ////////////////////////////////////////////////////////////////////////////////
262 /// remove rules that behave similar
263 
// Body only (signature line missing from this capture).
// For every pair of rules closer than fRuleMinDist (Rule::Equal), one of the
// two is chosen at random for removal; the survivors are then compacted.
265 {
266  Log() << kVERBOSE << "Removing similar rules; distance = " << fRuleMinDist << Endl;
267 
268  UInt_t nrulesIn = fRules.size();
269  TMVA::Rule *first, *second;
270  std::vector< Char_t > removeMe( nrulesIn,false ); // <--- stores boolean
271 
272  Int_t nrem = 0;
273  Int_t remind=-1;
274  Double_t r;
275 
276  for (UInt_t i=0; i<nrulesIn; i++) {
277  if (!removeMe[i]) {
278  first = fRules[i];
279  for (UInt_t k=i+1; k<nrulesIn; k++) {
280  if (!removeMe[k]) {
281  second = fRules[k];
282  Bool_t equal = first->Equal(*second,kTRUE,fRuleMinDist);
283  if (equal) {
284  r = gRandom->Rndm();
285  remind = (r>0.5 ? k:i); // randomly select rule
286  }
287  else {
288  remind = -1;
289  }
290 
291  if (remind>-1) {
292  if (!removeMe[remind]) {
293  removeMe[remind] = true;
294  nrem++;
295  }
296  }
297  }
298  }
299  }
300  }
 // Compaction pass: i walks the ORIGINAL indices while ind tracks the current
 // (shrinking) vector index; ind is decremented after each erase so the two
 // stay in step.
301  UInt_t ind = 0;
302  Rule *theRule;
303  for (UInt_t i=0; i<nrulesIn; i++) {
304  if (removeMe[i]) {
305  theRule = fRules[ind];
306 #if _MSC_VER >= 1400
307  fRules.erase( std::vector<Rule *>::iterator(&fRules[ind], &fRules) );
308 #else
309  fRules.erase( fRules.begin() + ind );
310 #endif
311  delete theRule;
312  ind--;
313  }
314  ind++;
315  }
316  UInt_t nrulesOut = fRules.size();
317  Log() << kVERBOSE << "Removed " << nrulesIn - nrulesOut << " out of " << nrulesIn << " rules" << Endl;
318 }
319 
320 ////////////////////////////////////////////////////////////////////////////////
321 /// cleanup rules
322 
// Body only (signature line missing from this capture).
// Erases (and deletes) every rule whose relative importance is below
// fImportanceCut; a cut of <=0 disables the cleanup. Uses the same
// shifted-index erase idiom as RemoveSimilarRules.
324 {
325  UInt_t nrules = fRules.size();
326  if (nrules==0) return;
327  Log() << kVERBOSE << "Removing rules with relative importance < " << fImportanceCut << Endl;
328  if (fImportanceCut<=0) return;
329  //
330  // Mark rules to be removed
331  //
332  Rule *therule;
333  Int_t ind=0;
334  for (UInt_t i=0; i<nrules; i++) {
335  if (fRules[ind]->GetRelImportance()<fImportanceCut) {
336  therule = fRules[ind];
337 #if _MSC_VER >= 1400
338  fRules.erase( std::vector<Rule *>::iterator(&fRules[ind], &fRules) );
339 #else
340  fRules.erase( fRules.begin() + ind );
341 #endif
342  delete therule;
343  ind--;
344  }
345  ind++;
346  }
347  Log() << kINFO << "Removed " << nrules-ind << " out of a total of " << nrules
348  << " rules with importance < " << fImportanceCut << Endl;
349 }
350 
351 ////////////////////////////////////////////////////////////////////////////////
352 /// cleanup linear model
353 
// Body only (signature line missing from this capture).
// Rebuilds fLinTermOK: a linear term stays enabled only if its importance
// relative to fImportanceRef exceeds fImportanceCut. Terms are flagged off,
// not erased, so indices stay aligned with the variables.
355 {
356  UInt_t nlin = fLinNorm.size();
357  if (nlin==0) return;
358  Log() << kVERBOSE << "Removing linear terms with relative importance < " << fImportanceCut << Endl;
359  //
360  fLinTermOK.clear();
361  for (UInt_t i=0; i<nlin; i++) {
362  fLinTermOK.push_back( (fLinImportance[i]/fImportanceRef > fImportanceCut) );
363  }
364 }
365 
366 ////////////////////////////////////////////////////////////////////////////////
367 /// calculate the support for all rules
368 
// Body only (signature line missing from this capture).
// For each rule, loops over all training events and accumulates the weighted
// fraction of events passing the rule (support s), its binomial spread
// t = sqrt(s(1-s)), and the weighted signal fraction S/(S+B); these are stored
// on the rule. Finally the ensemble averages fAverageSupport and
// fAverageRuleSigma are updated.
370 {
371  Log() << kVERBOSE << "Evaluating Rule support" << Endl;
372  Double_t s,t,stot,ttot,ssb;
373  Double_t ssig, sbkg, ssum;
374  Int_t indrule=0;
375  stot = 0;
376  ttot = 0;
377  // reset to default values
378  SetAverageRuleSigma(0.4);
379  const std::vector<const Event *> *events = GetTrainingEvents();
380  Double_t nrules = static_cast<Double_t>(fRules.size());
381  Double_t ew;
382  //
383  if ((nrules>0) && (events->size()>0)) {
384  for ( std::vector< Rule * >::iterator itrRule=fRules.begin(); itrRule!=fRules.end(); itrRule++ ) {
385  s=0.0;
386  ssig=0.0;
387  sbkg=0.0;
388  for ( std::vector<const Event * >::const_iterator itrEvent=events->begin(); itrEvent!=events->end(); itrEvent++ ) {
389  if ((*itrRule)->EvalEvent( *(*itrEvent) )) {
390  ew = (*itrEvent)->GetWeight();
391  s += ew;
392  if (GetMethodRuleFit()->DataInfo().IsSignal(*itrEvent)) ssig += ew;
393  else sbkg += ew;
394  }
395  }
396  //
397  s = s/fRuleFit->GetNEveEff(); // support = weighted accepted fraction
398  t = s*(1.0-s);
399  t = (t<0 ? 0:sqrt(t)); // guard against tiny negative round-off
400  stot += s;
401  ttot += t;
402  ssum = ssig+sbkg;
403  ssb = (ssum>0 ? Double_t(ssig)/Double_t(ssig+sbkg) : 0.0 );
404  (*itrRule)->SetSupport(s);
405  (*itrRule)->SetNorm(t);
406  (*itrRule)->SetSSB( ssb );
407  (*itrRule)->SetSSBNeve(Double_t(ssig+sbkg));
408  indrule++;
409  }
410  fAverageSupport = stot/nrules;
411  fAverageRuleSigma = TMath::Sqrt(fAverageSupport*(1.0-fAverageSupport));
412  Log() << kVERBOSE << "Standard deviation of support = " << fAverageRuleSigma << Endl;
413  Log() << kVERBOSE << "Average rule support = " << fAverageSupport << Endl;
414  }
415 }
416 
417 ////////////////////////////////////////////////////////////////////////////////
418 /// calculate the importance of each rule
419 
// Body only (signature line missing from this capture): computes rule and
// linear importances, then uses the larger of the two maxima as the common
// reference importance for the whole model.
421 {
422  Double_t maxRuleImp = CalcRuleImportance();
423  Double_t maxLinImp = CalcLinImportance();
424  Double_t maxImp = (maxRuleImp>maxLinImp ? maxRuleImp : maxLinImp);
425  SetImportanceRef( maxImp );
426 }
427 
428 ////////////////////////////////////////////////////////////////////////////////
429 /// set reference importance
430 
// Body only (signature line missing from this capture): propagates the
// reference importance to every rule and caches it in fImportanceRef.
432 {
433  for ( UInt_t i=0; i<fRules.size(); i++ ) {
434  fRules[i]->SetImportanceRef(impref);
435  }
436  fImportanceRef = impref;
437 }
438 ////////////////////////////////////////////////////////////////////////////////
439 /// calculate importance of each rule
440 
// Body only (signature line missing from this capture): has each rule compute
// its own importance, finds the maximum, and sets that maximum as every
// rule's reference importance. Returns the maximum (-1 if there are no rules).
442 {
443  Double_t maxImp=-1.0;
444  Double_t imp;
445  Int_t nrules = fRules.size();
446  for ( int i=0; i<nrules; i++ ) {
447  fRules[i]->CalcImportance();
448  imp = fRules[i]->GetImportance();
449  if (imp>maxImp) maxImp = imp;
450  }
451  for ( Int_t i=0; i<nrules; i++ ) {
452  fRules[i]->SetImportanceRef(maxImp);
453  }
454 
455  return maxImp;
456 }
457 
458 ////////////////////////////////////////////////////////////////////////////////
459 /// calculate the linear importance for each rule
460 
// Body only (signature line missing from this capture): importance of each
// linear term, I = |b'_x| * sigma(r) (see derivation in the comments below).
// Returns the maximum importance, or -1 when linear terms are disabled.
462 {
463  Double_t maxImp=-1.0;
464  UInt_t nvars = fLinCoefficients.size();
465  fLinImportance.resize(nvars,0.0);
466  if (!DoLinear()) return maxImp;
467  //
468  // The linear importance is:
469  // I = |b_x|*sigma(x)
470  // Note that the coefficients in fLinCoefficients are for the normalized x
471  // => b'_x * x' = b'_x * sigma(r)*x/sigma(x)
472  // => b_x = b'_x*sigma(r)/sigma(x)
473  // => I = |b'_x|*sigma(r)
474  //
475  Double_t imp;
476  for ( UInt_t i=0; i<nvars; i++ ) {
477  imp = fAverageRuleSigma*TMath::Abs(fLinCoefficients[i]);
478  fLinImportance[i] = imp;
479  if (imp>maxImp) maxImp = imp;
480  }
481  return maxImp;
482 }
483 
484 ////////////////////////////////////////////////////////////////////////////////
485 ///
486 /// Calculates variable importance using eq (35) in RuleFit paper by Friedman et.al
487 ///
488 
// Body only (signature line missing from this capture).
// Variable importance per eq (35) of the RuleFit paper: each rule's importance
// is split evenly among the variables it uses, enabled linear terms add their
// own importance, and the result is normalized to the strongest variable.
490 {
491  Log() << kVERBOSE << "Compute variable importance" << Endl;
492  Double_t rimp;
493  UInt_t nrules = fRules.size();
494  if (GetMethodBase()==0) Log() << kFATAL << "RuleEnsemble::CalcVarImportance() - should not be here!" << Endl;
495  UInt_t nvars = GetMethodBase()->GetNvar();
496  UInt_t nvarsUsed;
497  Double_t rimpN;
498  fVarImportance.resize(nvars,0);
499  // rules
500  if (DoRules()) {
501  for ( UInt_t ind=0; ind<nrules; ind++ ) {
502  rimp = fRules[ind]->GetImportance();
503  nvarsUsed = fRules[ind]->GetNumVarsUsed();
504  if (nvarsUsed<1)
505  Log() << kFATAL << "<CalcVarImportance> Variables for importance calc!!!??? A BUG!" << Endl;
506  rimpN = (nvarsUsed > 0 ? rimp/nvarsUsed:0.0); // share importance equally among used variables
507  for ( UInt_t iv=0; iv<nvars; iv++ ) {
508  if (fRules[ind]->ContainsVariable(iv)) {
509  fVarImportance[iv] += rimpN;
510  }
511  }
512  }
513  }
514  // linear terms
515  if (DoLinear()) {
516  for ( UInt_t iv=0; iv<fLinTermOK.size(); iv++ ) {
517  if (fLinTermOK[iv]) fVarImportance[iv] += fLinImportance[iv];
518  }
519  }
520  //
521  // Make variable importance relative the strongest variable
522  //
523  Double_t maximp = 0.0;
524  for ( UInt_t iv=0; iv<nvars; iv++ ) {
525  if ( fVarImportance[iv] > maximp ) maximp = fVarImportance[iv];
526  }
527  if (maximp>0) {
528  for ( UInt_t iv=0; iv<nvars; iv++ ) {
529  fVarImportance[iv] *= 1.0/maximp;
530  }
531  }
532 }
533 
534 ////////////////////////////////////////////////////////////////////////////////
535 /// set rules
536 ///
537 /// first clear all
538 
539 void TMVA::RuleEnsemble::SetRules( const std::vector<Rule *> & rules )
540 {
541  DeleteRules();
542  //
543  fRules.resize(rules.size());
544  for (UInt_t i=0; i<fRules.size(); i++) {
545  fRules[i] = rules[i];
546  }
547  fEventCacheOK = kFALSE;
548 }
549 
550 ////////////////////////////////////////////////////////////////////////////////
551 ///
552 /// Makes rules from the given decision tree.
553 /// First node in all rules is ALWAYS the root node.
554 ///
555 
556 void TMVA::RuleEnsemble::MakeRules( const std::vector< const DecisionTree *> & forest )
557 {
558  fRules.clear();
559  if (!DoRules()) return;
560  //
561  Int_t nrulesCheck=0;
562  Int_t nrules;
563  Int_t nendn;
564  Double_t sumnendn=0;
565  Double_t sumn2=0;
566  //
567  // UInt_t prevs;
568  UInt_t ntrees = forest.size();
569  for ( UInt_t ind=0; ind<ntrees; ind++ ) {
570  // prevs = fRules.size();
571  MakeRulesFromTree( forest[ind] );
572  nrules = CalcNRules( forest[ind] );
573  nendn = (nrules/2) + 1;
574  sumnendn += nendn;
575  sumn2 += nendn*nendn;
576  nrulesCheck += nrules;
577  }
578  Double_t nmean = (ntrees>0) ? sumnendn/ntrees : 0;
579  Double_t nsigm = TMath::Sqrt( gTools().ComputeVariance(sumn2,sumnendn,ntrees) );
580  Double_t ndev = 2.0*(nmean-2.0-nsigm)/(nmean-2.0+nsigm);
581  //
582  Log() << kVERBOSE << "Average number of end nodes per tree = " << nmean << Endl;
583  if (ntrees>1) Log() << kVERBOSE << "sigma of ditto ( ~= mean-2 ?) = "
584  << nsigm
585  << Endl;
586  Log() << kVERBOSE << "Deviation from exponential model = " << ndev << Endl;
587  Log() << kVERBOSE << "Corresponds to L (eq. 13, RuleFit ppr) = " << nmean << Endl;
588  // a BUG trap
589  if (nrulesCheck != static_cast<Int_t>(fRules.size())) {
590  Log() << kFATAL
591  << "BUG! number of generated and possible rules do not match! N(rules) = " << fRules.size()
592  << " != " << nrulesCheck << Endl;
593  }
594  Log() << kVERBOSE << "Number of generated rules: " << fRules.size() << Endl;
595 
596  // save initial number of rules
597  fNRulesGenerated = fRules.size();
598 
599  RemoveSimilarRules();
600 
601  ResetCoefficients();
602 
603 }
604 
605 ////////////////////////////////////////////////////////////////////////////////
606 ///
607 /// Make the linear terms as in eq 25, ref 2
608 /// For this the b and (1-b) quatiles are needed
609 ///
610 
// Body only (signature line missing from this capture).
// Builds the linear terms of eq 25 (ref 2): per variable, sorts the weighted
// event values, finds the fLinQuantile / (1-fLinQuantile) winsorization limits
// (fLinDM/fLinDP), fills signal/background PDF histograms of the clipped
// values, and computes the normalization from the weighted variance.
// NOTE(review): line 617 dereferences (*events)[0] — an empty training sample
// would crash here; presumably guaranteed non-empty upstream, confirm.
612 {
613  if (!DoLinear()) return;
614 
615  const std::vector<const Event *> *events = GetTrainingEvents();
616  UInt_t neve = events->size();
617  UInt_t nvars = ((*events)[0])->GetNVariables(); // Event -> GetNVariables();
618  Double_t val,ew;
619  typedef std::pair< Double_t, Int_t> dataType;
620  typedef std::pair< Double_t, dataType > dataPoint;
621 
622  std::vector< std::vector<dataPoint> > vardata(nvars);
623  std::vector< Double_t > varsum(nvars,0.0);
624  std::vector< Double_t > varsum2(nvars,0.0);
625  // first find stats of all variables
626  // vardata[v][i].first -> value of var <v> in event <i>
627  // vardata[v][i].second.first -> the event weight
628  // vardata[v][i].second.second -> the event type
629  for (UInt_t i=0; i<neve; i++) {
630  ew = ((*events)[i])->GetWeight();
631  for (UInt_t v=0; v<nvars; v++) {
632  val = ((*events)[i])->GetValue(v);
633  vardata[v].push_back( dataPoint( val, dataType(ew,((*events)[i])->GetClass()) ) );
634  }
635  }
636  //
637  fLinDP.clear();
638  fLinDM.clear();
639  fLinCoefficients.clear();
640  fLinNorm.clear();
641  fLinDP.resize(nvars,0);
642  fLinDM.resize(nvars,0);
643  fLinCoefficients.resize(nvars,0);
644  fLinNorm.resize(nvars,0);
645 
646  Double_t averageWeight = neve ? fRuleFit->GetNEveEff()/static_cast<Double_t>(neve) : 0;
647  // sort and find limits
648  Double_t stdl;
649 
650  // find normalisation given in ref 2 after eq 26
651  Double_t lx;
652  Double_t nquant;
653  Double_t neff;
654  UInt_t indquantM;
655  UInt_t indquantP;
656 
657  for (UInt_t v=0; v<nvars; v++) {
658  varsum[v] = 0;
659  varsum2[v] = 0;
660  //
661  std::sort( vardata[v].begin(),vardata[v].end() );
662  nquant = fLinQuantile*fRuleFit->GetNEveEff(); // quantile = 0.025
663  neff=0;
664  UInt_t ie=0;
665  // first scan for lower quantile (including weights)
666  while ( (ie<neve) && (neff<nquant) ) {
667  neff += vardata[v][ie].second.first;
668  ie++;
669  }
670  indquantM = (ie==0 ? 0:ie-1);
671  // now for upper quantile
672  ie = neve;
673  neff=0;
674  while ( (ie>0) && (neff<nquant) ) {
675  ie--;
676  neff += vardata[v][ie].second.first;
677  }
 // NOTE(review): side-effecting ternary below — 'ie=neve-1' assigns inside the
 // true branch. The result used is the guarded index; confirm intent.
678  indquantP = (ie==neve ? ie=neve-1:ie);
679  //
680  fLinDM[v] = vardata[v][indquantM].first; // delta-
681  fLinDP[v] = vardata[v][indquantP].first; // delta+
682  if (fLinPDFB[v]) delete fLinPDFB[v];
683  if (fLinPDFS[v]) delete fLinPDFS[v];
684  fLinPDFB[v] = new TH1F(Form("bkgvar%d",v),"bkg temphist",40,fLinDM[v],fLinDP[v]);
685  fLinPDFS[v] = new TH1F(Form("sigvar%d",v),"sig temphist",40,fLinDM[v],fLinDP[v]);
686  fLinPDFB[v]->Sumw2();
687  fLinPDFS[v]->Sumw2();
688  //
689  Int_t type;
690  const Double_t w = 1.0/fRuleFit->GetNEveEff();
691  for (ie=0; ie<neve; ie++) {
692  val = vardata[v][ie].first;
693  ew = vardata[v][ie].second.first;
694  type = vardata[v][ie].second.second;
695  lx = TMath::Min( fLinDP[v], TMath::Max( fLinDM[v], val ) ); // clip to quantile window
696  varsum[v] += ew*lx;
697  varsum2[v] += ew*lx*lx;
698  if (type==1) fLinPDFS[v]->Fill(lx,w*ew);
699  else fLinPDFB[v]->Fill(lx,w*ew);
700  }
701  //
702  // Get normalization.
703  //
704  stdl = TMath::Sqrt( (varsum2[v] - (varsum[v]*varsum[v]/fRuleFit->GetNEveEff()))/(fRuleFit->GetNEveEff()-averageWeight) );
705  fLinNorm[v] = CalcLinNorm(stdl);
706  }
707  // Save PDFs - for debugging purpose
708  for (UInt_t v=0; v<nvars; v++) {
709  fLinPDFS[v]->Write();
710  fLinPDFB[v]->Write();
711  }
712 }
713 
714 
715 ////////////////////////////////////////////////////////////////////////////////
716 ///
717 /// This function returns Pr( y = 1 | x ) for the linear terms.
718 ///
719 
// Body only (signature line missing from this capture).
// Returns Pr(y=1|x) from the linear-term PDFs: sums the signal and background
// PDF bin contents at the event's values over all variables. nsig/ntot are
// output parameters (per-variable averages of the summed PDF contents).
721 {
722  UInt_t nvars=fLinDP.size();
723 
724  Double_t fstot=0;
725  Double_t fbtot=0;
726  nsig = 0;
727  ntot = nvars; // provisional; overwritten below when nvars>=1
728  for (UInt_t v=0; v<nvars; v++) {
729  Double_t val = fEventLinearVal[v];
730  Int_t bin = fLinPDFS[v]->FindBin(val);
731  fstot += fLinPDFS[v]->GetBinContent(bin);
732  fbtot += fLinPDFB[v]->GetBinContent(bin);
733  }
734  if (nvars<1) return 0;
735  ntot = (fstot+fbtot)/Double_t(nvars);
736  nsig = (fstot)/Double_t(nvars);
737  return fstot/(fstot+fbtot);
738 }
739 
740 ////////////////////////////////////////////////////////////////////////////////
741 ///
742 /// This function returns Pr( y = 1 | x ) for rules.
743 /// The probability returned is normalized against the number of rules which are actually passed
744 ///
745 
// Body only (signature line missing from this capture).
// Returns Pr(y=1|x) for rules, normalized to the rules the event actually
// passes: accumulates SSB-weighted event counts over the passing rules.
// nsig/ntot are output parameters (weighted signal / total passed counts).
747 {
748  Double_t sump = 0;
749  Double_t sumok = 0;
750  Double_t sumz = 0; // counts rules NOT passed; accumulated but never used below
751  Double_t ssb;
752  Double_t neve;
753  //
754  UInt_t nrules = fRules.size();
755  for (UInt_t ir=0; ir<nrules; ir++) {
756  if (fEventRuleVal[ir]>0) {
757  ssb = fEventRuleVal[ir]*GetRulesConst(ir)->GetSSB(); // S/(S+B) is evaluated in CalcRuleSupport() using ALL training events
758  neve = GetRulesConst(ir)->GetSSBNeve(); // number of events accepted by the rule
759  sump += ssb*neve; // number of signal events
760  sumok += neve; // total number of events passed
761  }
762  else sumz += 1.0; // all events
763  }
764 
765  nsig = sump;
766  ntot = sumok;
767  //
768  if (ntot>0) return nsig/ntot;
769  return 0.0;
770 }
771 
772 ////////////////////////////////////////////////////////////////////////////////
773 ///
774 /// We want to estimate F* = argmin Eyx( L(y,F(x) ), min wrt F(x)
775 /// F(x) = FL(x) + FR(x) , linear and rule part
776 ///
777 ///
778 
// Body only (signature line missing from this capture): caches the given
// event, refreshes the cached rule/linear responses, then delegates to the
// no-argument FStar().
780 {
781  SetEvent(e);
782  UpdateEventVal();
783  return FStar();
784 }
785 
786 ////////////////////////////////////////////////////////////////////////////////
787 ///
788 /// We want to estimate F* = argmin Eyx( L(y,F(x) ), min wrt F(x)
789 /// F(x) = FL(x) + FR(x) , linear and rule part
790 ///
791 ///
792 
// Body only (signature line missing from this capture).
// Combines the rule and linear Pr(y=1|x) estimates: each contributing part
// (nlt>0, nrt>0) counts once in the divisor nt, and the averaged probability
// is mapped to [-1,1] via 2p-1.
794 {
795  Double_t p=0;
796  Double_t nrs=0, nrt=0;
797  Double_t nls=0, nlt=0;
798  Double_t nt;
799  Double_t pr=0;
800  Double_t pl=0;
801 
802  // first calculate Pr(y=1|X) for rules and linear terms
803  if (DoLinear()) pl = PdfLinear(nls, nlt);
804  if (DoRules()) pr = PdfRule(nrs, nrt);
805  // nr(l)t=0 or 1
806  if ((nlt>0) && (nrt>0)) nt=2.0; // both parts contributed -> average of two
807  else nt=1.0;
808  p = (pl+pr)/nt;
809  return 2.0*p-1.0;
810 }
811 
812 ////////////////////////////////////////////////////////////////////////////////
813 /// calculate various statistics for this rule
814 
// Body only (signature line missing from this capture).
// Per-rule tagging statistics over the full training sample: counts how often
// each rule tags events, split by (rule signal/bkg) x (true signal/bkg), and
// fills the fRuleP* probability vectors. Also counts per-variable usage
// (fRuleVarFrac) and the fraction of signal-like rules (fRuleFSig).
816 {
817  // TODO: NOT YET UPDATED FOR WEIGHTS
818  const std::vector<const Event *> *events = GetTrainingEvents();
819  const UInt_t neve = events->size();
820  const UInt_t nvars = GetMethodBase()->GetNvar();
821  const UInt_t nrules = fRules.size();
822  const Event *eveData;
823  // Flags
824  Bool_t sigRule;
825  Bool_t sigTag;
826  Bool_t bkgTag;
827  // Bool_t noTag;
828  Bool_t sigTrue;
829  Bool_t tagged;
830  // Counters
831  Int_t nsig=0;
832  Int_t nbkg=0;
833  Int_t ntag=0;
834  Int_t nss=0;
835  Int_t nsb=0;
836  Int_t nbb=0;
837  Int_t nbs=0;
838  std::vector<Int_t> varcnt;
839  // Clear vectors
840  fRulePSS.clear();
841  fRulePSB.clear();
842  fRulePBS.clear();
843  fRulePBB.clear();
844  fRulePTag.clear();
845  //
846  varcnt.resize(nvars,0);
847  fRuleVarFrac.clear();
848  fRuleVarFrac.resize(nvars,0);
849  //
850  for ( UInt_t i=0; i<nrules; i++ ) {
851  for ( UInt_t v=0; v<nvars; v++) {
852  if (fRules[i]->ContainsVariable(v)) varcnt[v]++; // count how often a variable occurs
853  }
854  sigRule = fRules[i]->IsSignalRule();
855  if (sigRule) { // rule is a signal rule (ie s/(s+b)>0.5)
856  nsig++;
857  }
858  else {
859  nbkg++;
860  }
861  // reset counters
862  nss=0;
863  nsb=0;
864  nbs=0;
865  nbb=0;
866  ntag=0;
867  // loop over all events
868  for (UInt_t e=0; e<neve; e++) {
869  eveData = (*events)[e];
870  tagged = fRules[i]->EvalEvent(*eveData);
871  sigTag = (tagged && sigRule); // it's tagged as a signal
872  bkgTag = (tagged && (!sigRule)); // ... as bkg
873  // noTag = !(sigTag || bkgTag); // ... not tagged
874  sigTrue = (eveData->GetClass() == 0); // true if event is true signal
875  if (tagged) {
876  ntag++;
877  if (sigTag && sigTrue) nss++;
878  if (sigTag && !sigTrue) nsb++;
879  if (bkgTag && sigTrue) nbs++;
880  if (bkgTag && !sigTrue) nbb++;
881  }
882  }
883  // Fill tagging probabilities
 // NOTE(review): rules with ntag==0 push nothing, so the fRuleP* vectors can
 // end up shorter than fRules — indices are then not aligned with rules.
884  if (ntag>0 && neve > 0) { // should always be the case, but let's make sure and keep coverity quiet
885  fRulePTag.push_back(Double_t(ntag)/Double_t(neve));
886  fRulePSS.push_back(Double_t(nss)/Double_t(ntag));
887  fRulePSB.push_back(Double_t(nsb)/Double_t(ntag));
888  fRulePBS.push_back(Double_t(nbs)/Double_t(ntag));
889  fRulePBB.push_back(Double_t(nbb)/Double_t(ntag));
890  }
891  //
892  }
893  fRuleFSig = (nsig>0) ? static_cast<Double_t>(nsig)/static_cast<Double_t>(nsig+nbkg) : 0;
894  for ( UInt_t v=0; v<nvars; v++) {
895  fRuleVarFrac[v] = (nrules>0) ? Double_t(varcnt[v])/Double_t(nrules) : 0;
896  }
897 }
898 
899 ////////////////////////////////////////////////////////////////////////////////
900 /// calculate various statistics for this rule
901 
// Body only (signature line missing from this capture): mean and spread of the
// number of cuts per rule (fRuleNCave, fRuleNCsig) over the whole ensemble.
903 {
904  const UInt_t nrules = fRules.size();
905  Double_t nc;
906  Double_t sumNc =0;
907  Double_t sumNc2=0;
908  for ( UInt_t i=0; i<nrules; i++ ) {
909  nc = static_cast<Double_t>(fRules[i]->GetNcuts());
910  sumNc += nc;
911  sumNc2 += nc*nc;
912  }
913  fRuleNCave = 0.0;
914  fRuleNCsig = 0.0;
915  if (nrules>0) {
916  fRuleNCave = sumNc/nrules;
917  fRuleNCsig = TMath::Sqrt(gTools().ComputeVariance(sumNc2,sumNc,nrules));
918  }
919 }
920 
921 ////////////////////////////////////////////////////////////////////////////////
922 /// print rule generation info
923 
// Body only (signature line missing from this capture): logs a human-readable
// summary of the rule-generation step (training method, tree/rule counts,
// cut statistics).
925 {
926  Log() << kINFO << "-------------------RULE ENSEMBLE SUMMARY------------------------" << Endl;
927  const MethodRuleFit *mrf = GetMethodRuleFit();
928  if (mrf) Log() << kINFO << "Tree training method : " << (mrf->UseBoost() ? "AdaBoost":"Random") << Endl;
929  Log() << kINFO << "Number of events per tree : " << fRuleFit->GetNTreeSample() << Endl;
930  Log() << kINFO << "Number of trees : " << fRuleFit->GetForest().size() << Endl;
931  Log() << kINFO << "Number of generated rules : " << fNRulesGenerated << Endl;
932  Log() << kINFO << "Idem, after cleanup : " << fRules.size() << Endl;
933  Log() << kINFO << "Average number of cuts per rule : " << Form("%8.2f",fRuleNCave) << Endl;
934  Log() << kINFO << "Spread in number of cuts per rules : " << Form("%8.2f",fRuleNCsig) << Endl;
935  Log() << kVERBOSE << "Complexity : " << Form("%8.2f",fRuleNCave*fRuleNCsig) << Endl;
936  Log() << kINFO << "----------------------------------------------------------------" << Endl;
937  Log() << kINFO << Endl;
938 }
939 
940 ////////////////////////////////////////////////////////////////////////////////
941 /// print function
942 
// Body only (signature line missing from this capture).
// Pretty-prints the full model: variable importances (debug only), the offset,
// the linear terms with weights/importances, rule statistics, and the top
// rules sorted by importance (at most printN of them).
944 {
945  const EMsgType kmtype=kINFO;
946  const Bool_t isDebug = (fLogger->GetMinType()<=kDEBUG);
947  //
948  Log() << kmtype << Endl;
949  Log() << kmtype << "================================================================" << Endl;
950  Log() << kmtype << " M o d e l " << Endl;
951  Log() << kmtype << "================================================================" << Endl;
952 
953  Int_t ind;
954  const UInt_t nvars = GetMethodBase()->GetNvar();
955  const Int_t nrules = fRules.size();
956  const Int_t printN = TMath::Min(10,nrules); //nrules+1;
 // maxL = width of the longest variable label, used for column alignment below
957  Int_t maxL = 0;
958  for (UInt_t iv = 0; iv<fVarImportance.size(); iv++) {
959  if (GetMethodBase()->GetInputLabel(iv).Length() > maxL) maxL = GetMethodBase()->GetInputLabel(iv).Length();
960  }
961  //
962  if (isDebug) {
963  Log() << kDEBUG << "Variable importance:" << Endl;
964  for (UInt_t iv = 0; iv<fVarImportance.size(); iv++) {
965  Log() << kDEBUG << std::setw(maxL) << GetMethodBase()->GetInputLabel(iv)
966  << std::resetiosflags(std::ios::right)
967  << " : " << Form(" %3.3f",fVarImportance[iv]) << Endl;
968  }
969  }
970  //
971  Log() << kmtype << "Offset (a0) = " << fOffset << Endl;
972  //
973  if (DoLinear()) {
974  if (fLinNorm.size() > 0) {
975  Log() << kmtype << "------------------------------------" << Endl;
976  Log() << kmtype << "Linear model (weights unnormalised)" << Endl;
977  Log() << kmtype << "------------------------------------" << Endl;
978  Log() << kmtype << std::setw(maxL) << "Variable"
979  << std::resetiosflags(std::ios::right) << " : "
980  << std::setw(11) << " Weights"
981  << std::resetiosflags(std::ios::right) << " : "
982  << "Importance"
983  << std::resetiosflags(std::ios::right)
984  << Endl;
985  Log() << kmtype << "------------------------------------" << Endl;
986  for ( UInt_t i=0; i<fLinNorm.size(); i++ ) {
987  Log() << kmtype << std::setw(std::max(maxL,8)) << GetMethodBase()->GetInputLabel(i);
988  if (fLinTermOK[i]) {
989  Log() << kmtype
990  << std::resetiosflags(std::ios::right)
991  << " : " << Form(" %10.3e",fLinCoefficients[i]*fLinNorm[i])
992  << " : " << Form(" %3.3f",fLinImportance[i]/fImportanceRef) << Endl;
993  }
994  else {
995  Log() << kmtype << "-> importance below threshhold = "
996  << Form(" %3.3f",fLinImportance[i]/fImportanceRef) << Endl;
997  }
998  }
999  Log() << kmtype << "------------------------------------" << Endl;
1000  }
1001  }
1002  else Log() << kmtype << "Linear terms were disabled" << Endl;
1003 
1004  if ((!DoRules()) || (nrules==0)) {
1005  if (!DoRules()) {
1006  Log() << kmtype << "Rule terms were disabled" << Endl;
1007  }
1008  else {
1009  Log() << kmtype << "Eventhough rules were included in the model, none passed! " << nrules << Endl;
1010  }
1011  }
1012  else {
1013  Log() << kmtype << "Number of rules = " << nrules << Endl;
1014  if (isDebug) {
1015  Log() << kmtype << "N(cuts) in rules, average = " << fRuleNCave << Endl;
1016  Log() << kmtype << " RMS = " << fRuleNCsig << Endl;
1017  Log() << kmtype << "Fraction of signal rules = " << fRuleFSig << Endl;
1018  Log() << kmtype << "Fraction of rules containing a variable (%):" << Endl;
1019  for ( UInt_t v=0; v<nvars; v++) {
1020  Log() << kmtype << " " << std::setw(maxL) << GetMethodBase()->GetInputLabel(v);
1021  Log() << kmtype << Form(" = %2.2f",fRuleVarFrac[v]*100.0) << " %" << Endl;
1022  }
1023  }
1024  //
1025  // Print out all rules sorted in importance
1026  //
 // sort() is ascending, so iterate in reverse to print most important first
1027  std::list< std::pair<double,int> > sortedImp;
1028  for (Int_t i=0; i<nrules; i++) {
1029  sortedImp.push_back( std::pair<double,int>( fRules[i]->GetImportance(),i ) );
1030  }
1031  sortedImp.sort();
1032  //
1033  Log() << kmtype << "Printing the first " << printN << " rules, ordered in importance." << Endl;
1034  int pind=0;
1035  for ( std::list< std::pair<double,int> >::reverse_iterator itpair = sortedImp.rbegin();
1036  itpair != sortedImp.rend(); itpair++ ) {
1037  ind = itpair->second;
1038  // if (pind==0) impref =
1039  // Log() << kmtype << "Rule #" <<
1040  // Log() << kmtype << *fRules[ind] << Endl;
1041  fRules[ind]->PrintLogger(Form("Rule %4d : ",pind+1));
1042  pind++;
1043  if (pind==printN) {
1044  if (nrules==printN) {
1045  Log() << kmtype << "All rules printed" << Endl;
1046  }
1047  else {
1048  Log() << kmtype << "Skipping the next " << nrules-printN << " rules" << Endl;
1049  }
1050  break;
1051  }
1052  }
1053  }
1054  Log() << kmtype << "================================================================" << Endl;
1055  Log() << kmtype << Endl;
1056 }
1057 
1058 ////////////////////////////////////////////////////////////////////////////////
1059 /// write rules to stream
1060 
1061 void TMVA::RuleEnsemble::PrintRaw( std::ostream & os ) const
1062 {
1063  Int_t dp = os.precision();
1064  UInt_t nrules = fRules.size();
1065  // std::sort(fRules.begin(),fRules.end());
1066  //
1067  os << "ImportanceCut= " << fImportanceCut << std::endl;
1068  os << "LinQuantile= " << fLinQuantile << std::endl;
1069  os << "AverageSupport= " << fAverageSupport << std::endl;
1070  os << "AverageRuleSigma= " << fAverageRuleSigma << std::endl;
1071  os << "Offset= " << fOffset << std::endl;
1072  os << "NRules= " << nrules << std::endl;
1073  for (UInt_t i=0; i<nrules; i++){
1074  os << "***Rule " << i << std::endl;
1075  (fRules[i])->PrintRaw(os);
1076  }
1077  UInt_t nlinear = fLinNorm.size();
1078  //
1079  os << "NLinear= " << fLinTermOK.size() << std::endl;
1080  for (UInt_t i=0; i<nlinear; i++) {
1081  os << "***Linear " << i << std::endl;
1082  os << std::setprecision(10) << (fLinTermOK[i] ? 1:0) << " "
1083  << fLinCoefficients[i] << " "
1084  << fLinNorm[i] << " "
1085  << fLinDM[i] << " "
1086  << fLinDP[i] << " "
1087  << fLinImportance[i] << " " << std::endl;
1088  }
1089  os << std::setprecision(dp);
1090 }
1091 
1092 ////////////////////////////////////////////////////////////////////////////////
1093 /// write rules to XML
1094 
1095 void* TMVA::RuleEnsemble::AddXMLTo(void* parent) const
1096 {
1097  void* re = gTools().AddChild( parent, "Weights" ); // this is the "RuleEnsemble"
1098 
1099  UInt_t nrules = fRules.size();
1100  UInt_t nlinear = fLinNorm.size();
1101  gTools().AddAttr( re, "NRules", nrules );
1102  gTools().AddAttr( re, "NLinear", nlinear );
1103  gTools().AddAttr( re, "LearningModel", (int)fLearningModel );
1104  gTools().AddAttr( re, "ImportanceCut", fImportanceCut );
1105  gTools().AddAttr( re, "LinQuantile", fLinQuantile );
1106  gTools().AddAttr( re, "AverageSupport", fAverageSupport );
1107  gTools().AddAttr( re, "AverageRuleSigma", fAverageRuleSigma );
1108  gTools().AddAttr( re, "Offset", fOffset );
1109  for (UInt_t i=0; i<nrules; i++) fRules[i]->AddXMLTo(re);
1110 
1111  for (UInt_t i=0; i<nlinear; i++) {
1112  void* lin = gTools().AddChild( re, "Linear" );
1113  gTools().AddAttr( lin, "OK", (fLinTermOK[i] ? 1:0) );
1114  gTools().AddAttr( lin, "Coeff", fLinCoefficients[i] );
1115  gTools().AddAttr( lin, "Norm", fLinNorm[i] );
1116  gTools().AddAttr( lin, "DM", fLinDM[i] );
1117  gTools().AddAttr( lin, "DP", fLinDP[i] );
1118  gTools().AddAttr( lin, "Importance", fLinImportance[i] );
1119  }
1120  return re;
1121 }
1122 
1123 ////////////////////////////////////////////////////////////////////////////////
1124 /// read rules from XML
1125 
1126 void TMVA::RuleEnsemble::ReadFromXML( void* wghtnode )
1127 {
1128  UInt_t nrules, nlinear;
1129  gTools().ReadAttr( wghtnode, "NRules", nrules );
1130  gTools().ReadAttr( wghtnode, "NLinear", nlinear );
1131  Int_t iLearningModel;
1132  gTools().ReadAttr( wghtnode, "LearningModel", iLearningModel );
1133  fLearningModel = (ELearningModel) iLearningModel;
1134  gTools().ReadAttr( wghtnode, "ImportanceCut", fImportanceCut );
1135  gTools().ReadAttr( wghtnode, "LinQuantile", fLinQuantile );
1136  gTools().ReadAttr( wghtnode, "AverageSupport", fAverageSupport );
1137  gTools().ReadAttr( wghtnode, "AverageRuleSigma", fAverageRuleSigma );
1138  gTools().ReadAttr( wghtnode, "Offset", fOffset );
1139 
1140  // read rules
1141  DeleteRules();
1142 
1143  UInt_t i = 0;
1144  fRules.resize( nrules );
1145  void* ch = gTools().GetChild( wghtnode );
1146  for (i=0; i<nrules; i++) {
1147  fRules[i] = new Rule();
1148  fRules[i]->SetRuleEnsemble( this );
1149  fRules[i]->ReadFromXML( ch );
1150 
1151  ch = gTools().GetNextChild(ch);
1152  }
1153 
1154  // read linear classifier (Fisher)
1155  fLinNorm .resize( nlinear );
1156  fLinTermOK .resize( nlinear );
1157  fLinCoefficients.resize( nlinear );
1158  fLinDP .resize( nlinear );
1159  fLinDM .resize( nlinear );
1160  fLinImportance .resize( nlinear );
1161 
1162  Int_t iok;
1163  i=0;
1164  while(ch) {
1165  gTools().ReadAttr( ch, "OK", iok );
1166  fLinTermOK[i] = (iok == 1);
1167  gTools().ReadAttr( ch, "Coeff", fLinCoefficients[i] );
1168  gTools().ReadAttr( ch, "Norm", fLinNorm[i] );
1169  gTools().ReadAttr( ch, "DM", fLinDM[i] );
1170  gTools().ReadAttr( ch, "DP", fLinDP[i] );
1171  gTools().ReadAttr( ch, "Importance", fLinImportance[i] );
1172 
1173  i++;
1174  ch = gTools().GetNextChild(ch);
1175  }
1176 }
1177 
1178 ////////////////////////////////////////////////////////////////////////////////
1179 /// read rule ensemble from stream
1180 
1181 void TMVA::RuleEnsemble::ReadRaw( std::istream & istr )
1182 {
1183  UInt_t nrules;
1184  //
1185  std::string dummy;
1186  Int_t idum;
1187  //
1188  // First block is general stuff
1189  //
1190  istr >> dummy >> fImportanceCut;
1191  istr >> dummy >> fLinQuantile;
1192  istr >> dummy >> fAverageSupport;
1193  istr >> dummy >> fAverageRuleSigma;
1194  istr >> dummy >> fOffset;
1195  istr >> dummy >> nrules;
1196  //
1197  // Now read in the rules
1198  //
1199  DeleteRules();
1200  //
1201  for (UInt_t i=0; i<nrules; i++){
1202  istr >> dummy >> idum; // read line "***Rule <ind>"
1203  fRules.push_back( new Rule() );
1204  (fRules.back())->SetRuleEnsemble( this );
1205  (fRules.back())->ReadRaw(istr);
1206  }
1207  //
1208  // and now the linear terms
1209  //
1210  UInt_t nlinear;
1211  //
1212  // coverity[tainted_data_argument]
1213  istr >> dummy >> nlinear;
1214  //
1215  fLinNorm .resize( nlinear );
1216  fLinTermOK .resize( nlinear );
1217  fLinCoefficients.resize( nlinear );
1218  fLinDP .resize( nlinear );
1219  fLinDM .resize( nlinear );
1220  fLinImportance .resize( nlinear );
1221  //
1222 
1223  Int_t iok;
1224  for (UInt_t i=0; i<nlinear; i++) {
1225  istr >> dummy >> idum;
1226  istr >> iok;
1227  fLinTermOK[i] = (iok==1);
1228  istr >> fLinCoefficients[i];
1229  istr >> fLinNorm[i];
1230  istr >> fLinDM[i];
1231  istr >> fLinDP[i];
1232  istr >> fLinImportance[i];
1233  }
1234 }
1235 
1236 ////////////////////////////////////////////////////////////////////////////////
1237 /// copy function
1238 
1240 {
1241  if(this != &other) {
1242  fRuleFit = other.GetRuleFit();
1243  fRuleMinDist = other.GetRuleMinDist();
1244  fOffset = other.GetOffset();
1245  fRules = other.GetRulesConst();
1246  fImportanceCut = other.GetImportanceCut();
1247  fVarImportance = other.GetVarImportance();
1248  fLearningModel = other.GetLearningModel();
1249  fLinQuantile = other.GetLinQuantile();
1250  fRuleNCsig = other.fRuleNCsig;
1251  fAverageRuleSigma = other.fAverageRuleSigma;
1252  fEventCacheOK = other.fEventCacheOK;
1253  fImportanceRef = other.fImportanceRef;
1254  fNRulesGenerated = other.fNRulesGenerated;
1255  fRuleFSig = other.fRuleFSig;
1256  fRuleMapInd0 = other.fRuleMapInd0;
1257  fRuleMapInd1 = other.fRuleMapInd1;
1258  fRuleMapOK = other.fRuleMapOK;
1259  fRuleNCave = other.fRuleNCave;
1260  }
1261 }
1262 
1263 ////////////////////////////////////////////////////////////////////////////////
1264 /// calculate the number of rules
1265 
1267 {
1268  if (dtree==0) return 0;
1269  Node *node = dtree->GetRoot();
1270  Int_t nendnodes = 0;
1271  FindNEndNodes( node, nendnodes );
1272  return 2*(nendnodes-1);
1273 }
1274 
1275 ////////////////////////////////////////////////////////////////////////////////
1276 /// find the number of leaf nodes
1277 
1278 void TMVA::RuleEnsemble::FindNEndNodes( const Node *node, Int_t & nendnodes )
1279 {
1280  if (node==0) return;
1281  if ((node->GetRight()==0) && (node->GetLeft()==0)) {
1282  ++nendnodes;
1283  return;
1284  }
1285  const Node *nodeR = node->GetRight();
1286  const Node *nodeL = node->GetLeft();
1287  FindNEndNodes( nodeR, nendnodes );
1288  FindNEndNodes( nodeL, nendnodes );
1289 }
1290 
1291 ////////////////////////////////////////////////////////////////////////////////
1292 /// create rules from the decision tree structure
1293 
1295 {
1296  Node *node = dtree->GetRoot();
1297  AddRule( node );
1298 }
1299 
1300 ////////////////////////////////////////////////////////////////////////////////
1301 /// add a new rule to the tree
1302 
1304 {
1305  if (node==0) return;
1306  if (node->GetParent()==0) { // it's a root node, don't make a rule
1307  AddRule( node->GetRight() );
1308  AddRule( node->GetLeft() );
1309  }
1310  else {
1311  Rule *rule = MakeTheRule(node);
1312  if (rule) {
1313  fRules.push_back( rule );
1314  AddRule( node->GetRight() );
1315  AddRule( node->GetLeft() );
1316  }
1317  else {
1318  Log() << kFATAL << "<AddRule> - ERROR failed in creating a rule! BUG!" << Endl;
1319  }
1320  }
1321 }
1322 
1323 ////////////////////////////////////////////////////////////////////////////////
1324 ///
1325 /// Make a Rule from a given Node.
1326 /// The root node (ie no parent) does not generate a Rule.
1327 /// The first node in a rule is always the root node => fNodes.size()>=2
1328 /// Each node corresponds to a cut and the cut value is given by the parent node.
1329 ///
1330 ///
1331 
1333 {
1334  if (node==0) {
1335  Log() << kFATAL << "<MakeTheRule> Input node is NULL. Should not happen. BUG!" << Endl;
1336  return 0;
1337  }
1338 
1339  if (node->GetParent()==0) { // a root node - ignore
1340  return 0;
1341  }
1342  //
1343  std::vector< const Node * > nodeVec;
1344  const Node *parent = node;
1345  //
1346  // Make list with the input node at the end:
1347  // <root node> <node1> <node2> ... <node given as argument>
1348  //
1349  nodeVec.push_back( node );
1350  while (parent!=0) {
1351  parent = parent->GetParent();
1352  if (!parent) continue;
1353  const DecisionTreeNode* dtn = dynamic_cast<const DecisionTreeNode*>(parent);
1354  if (dtn && dtn->GetSelector()>=0)
1355  nodeVec.insert( nodeVec.begin(), parent );
1356 
1357  }
1358  if (nodeVec.size()<2) {
1359  Log() << kFATAL << "<MakeTheRule> BUG! Inconsistent Rule!" << Endl;
1360  return 0;
1361  }
1362  Rule *rule = new Rule( this, nodeVec );
1363  rule->SetMsgType( Log().GetMinType() );
1364  return rule;
1365 }
1366 
1367 ////////////////////////////////////////////////////////////////////////////////
1368 /// Makes rule map for all events
1369 
1370 void TMVA::RuleEnsemble::MakeRuleMap(const std::vector<const Event *> *events, UInt_t ifirst, UInt_t ilast)
1371 {
1372  Log() << kVERBOSE << "Making Rule map for all events" << Endl;
1373  // make rule response map
1374  if (events==0) events = GetTrainingEvents();
1375  if ((ifirst==0) || (ilast==0) || (ifirst>ilast)) {
1376  ifirst = 0;
1377  ilast = events->size()-1;
1378  }
1379  // check if identical to previous call
1380  if ((events!=fRuleMapEvents) ||
1381  (ifirst!=fRuleMapInd0) ||
1382  (ilast !=fRuleMapInd1)) {
1383  fRuleMapOK = kFALSE;
1384  }
1385  //
1386  if (fRuleMapOK) {
1387  Log() << kVERBOSE << "<MakeRuleMap> Map is already valid" << Endl;
1388  return; // already cached
1389  }
1390  fRuleMapEvents = events;
1391  fRuleMapInd0 = ifirst;
1392  fRuleMapInd1 = ilast;
1393  // check number of rules
1394  UInt_t nrules = GetNRules();
1395  if (nrules==0) {
1396  Log() << kVERBOSE << "No rules found in MakeRuleMap()" << Endl;
1397  fRuleMapOK = kTRUE;
1398  return;
1399  }
1400  //
1401  // init map
1402  //
1403  std::vector<UInt_t> ruleind;
1404  fRuleMap.clear();
1405  for (UInt_t i=ifirst; i<=ilast; i++) {
1406  ruleind.clear();
1407  fRuleMap.push_back( ruleind );
1408  for (UInt_t r=0; r<nrules; r++) {
1409  if (fRules[r]->EvalEvent(*((*events)[i]))) {
1410  fRuleMap.back().push_back(r); // save only rules that are accepted
1411  }
1412  }
1413  }
1414  fRuleMapOK = kTRUE;
1415  Log() << kVERBOSE << "Made rule map for event# " << ifirst << " : " << ilast << Endl;
1416 }
1417 
1418 ////////////////////////////////////////////////////////////////////////////////
1419 /// std::ostream operator
1420 
std::ostream& TMVA::operator<< ( std::ostream& os, const RuleEnsemble & rules )
{
   // Deprecated streaming operator: emits a warning line to 'os' and prints
   // the ensemble via RuleEnsemble::Print() (which writes to the logger, not
   // to 'os'). Kept only for backward compatibility.
   os << "DON'T USE THIS - TO BE REMOVED" << std::endl;
   rules.Print();
   return os;
}
const std::vector< const TMVA::Event * > * GetTrainingEvents() const
get list of training events from the rule fitter
Double_t GetImportanceCut() const
Definition: RuleEnsemble.h:273
Double_t PdfLinear(Double_t &nsig, Double_t &ntot) const
This function returns Pr( y = 1 | x ) for the linear terms.
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
ELearningModel GetLearningModel() const
Definition: RuleEnsemble.h:272
const std::vector< Double_t > & GetVarImportance() const
Definition: RuleEnsemble.h:282
virtual Double_t Rndm(Int_t i=0)
Machine independent random number generator.
Definition: TRandom.cxx:512
bool equal(double d1, double d2, double stol=10000)
RuleEnsemble()
constructor
Rule * MakeTheRule(const Node *node)
Make a Rule from a given Node.
Int_t CalcNRules(const TMVA::DecisionTree *dtree)
calculate the number of rules
virtual ~RuleEnsemble()
destructor
THist< 1, float > TH1F
Definition: THist.h:315
Short_t Min(Short_t a, Short_t b)
Definition: TMathBase.h:170
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
Double_t GetLinQuantile() const
Definition: RuleEnsemble.h:284
const Bool_t kFALSE
Definition: Rtypes.h:92
void Print() const
print function
void PrintRuleGen() const
print rule generation info
void CleanupLinear()
cleanup linear model
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
Definition: Tools.h:308
void MakeRuleMap(const std::vector< const TMVA::Event * > *events=0, UInt_t ifirst=0, UInt_t ilast=0)
Makes rule map for all events.
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1134
Short_t Abs(Short_t d)
Definition: TMathBase.h:110
virtual DecisionTreeNode * GetRoot() const
Definition: DecisionTree.h:102
Bool_t Equal(const Rule &other, Bool_t useCutValue, Double_t maxdist) const
Compare two rules.
Definition: Rule.cxx:164
void SetMsgType(EMsgType t)
const Event * GetTrainingEvent(UInt_t i) const
get the training event from the rule fitter
void SetImportanceRef(Double_t impref)
set reference importance
void RuleResponseStats()
calculate various statistics for this rule
double sqrt(double)
TClass * GetClass(T *)
Definition: TClass.h:555
Tools & gTools()
Definition: Tools.cxx:79
Double_t PdfRule(Double_t &nsig, Double_t &ntot) const
This function returns Pr( y = 1 | x ) for rules.
void RemoveSimilarRules()
remove rules that behave similar
void Copy(RuleEnsemble const &other)
copy function
void PrintRaw(std::ostream &os) const
write rules to stream
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1158
const MethodRuleFit * GetMethodRuleFit() const
Get a pointer to the original MethodRuleFit.
void CalcImportance()
calculate the importance of each rule
void CleanupRules()
cleanup rules
Double_t FStar() const
We want to estimate F* = argmin Eyx( L(y,F(x) ), min wrt F(x) F(x) = FL(x) + FR(x) ...
void CalcVarImportance()
Calculates variable importance using eq (35) in RuleFit paper by Friedman et.al.
ROOT::R::TRInterface & r
Definition: Object.C:4
SVector< double, 2 > v
Definition: Dict.h:5
const RuleFit * GetRuleFit() const
Definition: RuleEnsemble.h:261
Double_t CalcRuleImportance()
calculate importance of each rule
EMsgType
Definition: Types.h:61
void MakeRulesFromTree(const DecisionTree *dtree)
create rules from the decision tree structure
Double_t CoefficientRadius()
Calculates sqrt(Sum(a_i^2)), i=1..N (NOTE do not include a0)
void AddRule(const Node *node)
add a new rule to the tree
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
void RuleStatistics()
calculate various statistics for this rule
void MakeRules(const std::vector< const TMVA::DecisionTree * > &forest)
Makes rules from the given decision tree.
Double_t GetWeight(Double_t x) const
void ReadAttr(void *node, const char *, T &value)
Definition: Tools.h:295
const std::vector< TMVA::Rule * > & GetRulesConst() const
Definition: RuleEnsemble.h:277
R__EXTERN TRandom * gRandom
Definition: TRandom.h:62
std::ostream & operator<<(std::ostream &os, const BinaryTree &tree)
print the tree recursively using the << operator
Definition: BinaryTree.cxx:155
Double_t fImportanceRef
Definition: RuleEnsemble.h:365
void ReadFromXML(void *wghtnode)
read rules from XML
void FindNEndNodes(const TMVA::Node *node, Int_t &nendnodes)
find the number of leaf nodes
void Initialize(const RuleFit *rf)
Initializes all member variables with default values.
void SetCoefficients(const std::vector< Double_t > &v)
set all rule coefficients
double Double_t
Definition: RtypesCore.h:55
int type
Definition: TGX11.cxx:120
static RooMathCoreReg dummy
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1170
virtual Node * GetParent() const
Definition: Node.h:93
void MakeLinearTerms()
Make the linear terms as in eq 25, ref 2 For this the b and (1-b) quantiles are needed.
virtual Node * GetRight() const
Definition: Node.h:92
Double_t fAverageRuleSigma
Definition: RuleEnsemble.h:367
UInt_t GetClass() const
Definition: Event.h:86
void CalcRuleSupport()
calculate the support for all rules
static Vc_ALWAYS_INLINE int_v max(const int_v &x, const int_v &y)
Definition: vector.h:440
Double_t GetRuleMinDist() const
Definition: RuleEnsemble.h:290
void MakeModel()
create model
Short_t GetSelector() const
void SetRules(const std::vector< TMVA::Rule * > &rules)
set rules
Bool_t UseBoost() const
Definition: MethodRuleFit.h:96
void SetMsgType(EMsgType t)
Definition: Rule.cxx:148
Short_t Max(Short_t a, Short_t b)
Definition: TMathBase.h:202
void ResetCoefficients()
reset all rule coefficients
void * AddXMLTo(void *parent) const
write rules to XML
void GetCoefficients(std::vector< Double_t > &v)
Retrieve all rule coefficients.
Double_t Sqrt(Double_t x)
Definition: TMath.h:464
Double_t CalcLinImportance()
calculate the linear importance for each rule
const Bool_t kTRUE
Definition: Rtypes.h:91
const MethodBase * GetMethodBase() const
Get a pointer to the original MethodRuleFit.
Double_t GetOffset() const
Definition: RuleEnsemble.h:275
Definition: math.cpp:60
virtual Node * GetLeft() const
Definition: Node.h:91
void ReadRaw(std::istream &istr)
read rule ensemble from stream