Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RuleEnsemble.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : RuleEnsemble *
8 * *
9 * *
10 * Description: *
11 * A class generating an ensemble of rules *
12 * Input: a forest of decision trees *
13 * Output: an ensemble of rules *
14 * *
15 * Authors (alphabetical): *
16 * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
17 * Helge Voss <Helge.Voss@cern.ch> - MPI-KP Heidelberg, GER *
18 * *
19 * Copyright (c) 2005: *
20 * CERN, Switzerland *
21 * Iowa State U. *
22 * MPI-K Heidelberg, Germany *
23 * *
24 * Redistribution and use in source and binary forms, with or without *
25 * modification, are permitted according to the terms listed in LICENSE *
26 * (see tmva/doc/LICENSE) *
27 **********************************************************************************/
28
29/*! \class TMVA::RuleEnsemble
30\ingroup TMVA
31
32*/
33#include "TMVA/RuleEnsemble.h"
34
35#include "TMVA/DataSetInfo.h"
36#include "TMVA/DecisionTree.h"
37#include "TMVA/Event.h"
38#include "TMVA/MethodBase.h"
39#include "TMVA/MethodRuleFit.h"
40#include "TMVA/MsgLogger.h"
41#include "TMVA/RuleFit.h"
42#include "TMVA/Tools.h"
43#include "TMVA/Types.h"
44
45#include "TRandom3.h"
46#include "TH1F.h"
47
48#include <algorithm>
49#include <cstdlib>
50#include <iomanip>
51#include <list>
52
53////////////////////////////////////////////////////////////////////////////////
54/// constructor
55
57 : fLearningModel ( kFull )
58 , fImportanceCut ( 0 )
59 , fLinQuantile ( 0.025 ) // default quantile for killing outliers in linear terms
60 , fOffset ( 0 )
61 , fAverageSupport ( 0.8 )
62 , fAverageRuleSigma( 0.4 ) // default value - used if only linear model is chosen
63 , fRuleFSig ( 0 )
64 , fRuleNCave ( 0 )
65 , fRuleNCsig ( 0 )
66 , fRuleMinDist ( 1e-3 ) // closest allowed 'distance' between two rules
67 , fNRulesGenerated ( 0 )
68 , fEvent ( 0 )
69 , fEventCacheOK ( true )
70 , fRuleMapOK ( true )
71 , fRuleMapInd0 ( 0 )
72 , fRuleMapInd1 ( 0 )
73 , fRuleMapEvents ( 0 )
74 , fLogger( new MsgLogger("RuleFit") )
75{
76 Initialize( rf );
77}
78
79////////////////////////////////////////////////////////////////////////////////
80/// copy constructor
81
83 : fAverageSupport ( 1 )
84 , fEvent(0)
85 , fRuleMapEvents(0)
86 , fRuleFit(0)
87 , fLogger( new MsgLogger("RuleFit") )
88{
89 Copy( other );
90}
91
92////////////////////////////////////////////////////////////////////////////////
93/// constructor
94
96 : fLearningModel ( kFull )
97 , fImportanceCut ( 0 )
98 , fLinQuantile ( 0.025 ) // default quantile for killing outliers in linear terms
99 , fOffset ( 0 )
100 , fImportanceRef ( 1.0 )
101 , fAverageSupport ( 0.8 )
102 , fAverageRuleSigma( 0.4 ) // default value - used if only linear model is chosen
103 , fRuleFSig ( 0 )
104 , fRuleNCave ( 0 )
105 , fRuleNCsig ( 0 )
106 , fRuleMinDist ( 1e-3 ) // closest allowed 'distance' between two rules
107 , fNRulesGenerated ( 0 )
108 , fEvent ( 0 )
109 , fEventCacheOK ( true )
110 , fRuleMapOK ( true )
111 , fRuleMapInd0 ( 0 )
112 , fRuleMapInd1 ( 0 )
113 , fRuleMapEvents ( 0 )
114 , fRuleFit ( 0 )
115 , fLogger( new MsgLogger("RuleFit") )
116{
117}
118
119////////////////////////////////////////////////////////////////////////////////
120/// destructor
121
123{
124 for ( std::vector<Rule *>::iterator itrRule = fRules.begin(); itrRule != fRules.end(); ++itrRule ) {
125 delete *itrRule;
126 }
127 // NOTE: Should not delete the histos fLinPDFB/S since they are delete elsewhere
128 delete fLogger;
129}
130
131////////////////////////////////////////////////////////////////////////////////
132/// Initializes all member variables with default values
133
135{
136 SetAverageRuleSigma(0.4); // default value - used if only linear model is chosen
137 fRuleFit = rf;
138 UInt_t nvars = GetMethodBase()->GetNvar();
139 fVarImportance.clear();
140 fLinPDFB.clear();
141 fLinPDFS.clear();
142 //
143 fVarImportance.resize( nvars,0.0 );
144 fLinPDFB.resize( nvars,0 );
145 fLinPDFS.resize( nvars,0 );
146 fImportanceRef = 1.0;
147 for (UInt_t i=0; i<nvars; i++) { // a priori all linear terms are equally valid
148 fLinTermOK.push_back(kTRUE);
149 }
150}
151
152////////////////////////////////////////////////////////////////////////////////
153
155 fLogger->SetMinType(t);
156}
157
158////////////////////////////////////////////////////////////////////////////////
159///
160/// Get a pointer to the original MethodRuleFit.
161
163{
164 return ( fRuleFit==0 ? 0:fRuleFit->GetMethodRuleFit());
165}
166
167////////////////////////////////////////////////////////////////////////////////
168///
/// Get a pointer to the MethodBase obtained from the associated RuleFit object (null when no RuleFit is set).
170
172{
173 return ( fRuleFit==0 ? 0:fRuleFit->GetMethodBase());
174}
175
176////////////////////////////////////////////////////////////////////////////////
177/// create model
178
180{
181 MakeRules( fRuleFit->GetForest() );
182
183 MakeLinearTerms();
184
185 MakeRuleMap();
186
187 CalcRuleSupport();
188
189 RuleStatistics();
190
191 PrintRuleGen();
192}
193
194////////////////////////////////////////////////////////////////////////////////
195///
/// Calculates Sum(a_i^2), i=1..N, over the rule coefficients. NOTE: no square root is taken and a0 is not included.
197
199{
200 Int_t ncoeffs = fRules.size();
201 if (ncoeffs<1) return 0;
202 //
203 Double_t sum2=0;
204 Double_t val;
205 for (Int_t i=0; i<ncoeffs; i++) {
206 val = fRules[i]->GetCoefficient();
207 sum2 += val*val;
208 }
209 return sum2;
210}
211
212////////////////////////////////////////////////////////////////////////////////
213/// reset all rule coefficients
214
216{
217 fOffset = 0.0;
218 UInt_t nrules = fRules.size();
219 for (UInt_t i=0; i<nrules; i++) {
220 fRules[i]->SetCoefficient(0.0);
221 }
222}
223
224////////////////////////////////////////////////////////////////////////////////
225/// set all rule coefficients
226
227void TMVA::RuleEnsemble::SetCoefficients( const std::vector< Double_t > & v )
228{
229 UInt_t nrules = fRules.size();
230 if (v.size()!=nrules) {
231 Log() << kFATAL << "<SetCoefficients> - BUG TRAP - input vector wrong size! It is = " << v.size()
232 << " when it should be = " << nrules << Endl;
233 }
234 for (UInt_t i=0; i<nrules; i++) {
235 fRules[i]->SetCoefficient(v[i]);
236 }
237}
238
239////////////////////////////////////////////////////////////////////////////////
240/// Retrieve all rule coefficients
241
242void TMVA::RuleEnsemble::GetCoefficients( std::vector< Double_t > & v )
243{
244 UInt_t nrules = fRules.size();
245 v.resize(nrules);
246 if (nrules==0) return;
247 //
248 for (UInt_t i=0; i<nrules; i++) {
249 v[i] = (fRules[i]->GetCoefficient());
250 }
251}
252
253////////////////////////////////////////////////////////////////////////////////
254/// get list of training events from the rule fitter
255
256const std::vector<const TMVA::Event*>* TMVA::RuleEnsemble::GetTrainingEvents() const
257{
258 return &(fRuleFit->GetTrainingEvents());
259}
260
261////////////////////////////////////////////////////////////////////////////////
262/// get the training event from the rule fitter
263
265{
266 return fRuleFit->GetTrainingEvent(i);
267}
268
269////////////////////////////////////////////////////////////////////////////////
270/// remove rules that behave similar
271
273{
274 Log() << kVERBOSE << "Removing similar rules; distance = " << fRuleMinDist << Endl;
275
276 UInt_t nrulesIn = fRules.size();
277 TMVA::Rule *first, *second;
278 std::vector< Char_t > removeMe( nrulesIn,false ); // <--- stores boolean
279
280 Int_t remind=-1;
281 Double_t r;
282
283 for (UInt_t i=0; i<nrulesIn; i++) {
284 if (!removeMe[i]) {
285 first = fRules[i];
286 for (UInt_t k=i+1; k<nrulesIn; k++) {
287 if (!removeMe[k]) {
288 second = fRules[k];
289 Bool_t equal = first->Equal(*second,kTRUE,fRuleMinDist);
290 if (equal) {
291 r = gRandom->Rndm();
292 remind = (r>0.5 ? k:i); // randomly select rule
293 }
294 else {
295 remind = -1;
296 }
297
298 if (remind>-1) {
299 if (!removeMe[remind]) {
300 removeMe[remind] = true;
301 }
302 }
303 }
304 }
305 }
306 }
307 UInt_t ind = 0;
308 Rule *theRule;
309 for (UInt_t i=0; i<nrulesIn; i++) {
310 if (removeMe[i]) {
311 theRule = fRules[ind];
312 fRules.erase( fRules.begin() + ind );
313 delete theRule;
314 ind--;
315 }
316 ind++;
317 }
318 UInt_t nrulesOut = fRules.size();
319 Log() << kVERBOSE << "Removed " << nrulesIn - nrulesOut << " out of " << nrulesIn << " rules" << Endl;
320}
321
322////////////////////////////////////////////////////////////////////////////////
323/// cleanup rules
324
326{
327 UInt_t nrules = fRules.size();
328 if (nrules==0) return;
329 Log() << kVERBOSE << "Removing rules with relative importance < " << fImportanceCut << Endl;
330 if (fImportanceCut<=0) return;
331 //
332 // Mark rules to be removed
333 //
334 Rule *therule;
335 Int_t ind=0;
336 for (UInt_t i=0; i<nrules; i++) {
337 if (fRules[ind]->GetRelImportance()<fImportanceCut) {
338 therule = fRules[ind];
339 fRules.erase( fRules.begin() + ind );
340 delete therule;
341 ind--;
342 }
343 ind++;
344 }
345 Log() << kINFO << "Removed " << nrules-ind << " out of a total of " << nrules
346 << " rules with importance < " << fImportanceCut << Endl;
347}
348
349////////////////////////////////////////////////////////////////////////////////
350/// cleanup linear model
351
353{
354 UInt_t nlin = fLinNorm.size();
355 if (nlin==0) return;
356 Log() << kVERBOSE << "Removing linear terms with relative importance < " << fImportanceCut << Endl;
357 //
358 fLinTermOK.clear();
359 for (UInt_t i=0; i<nlin; i++) {
360 fLinTermOK.push_back( (fLinImportance[i]/fImportanceRef > fImportanceCut) );
361 }
362}
363
364////////////////////////////////////////////////////////////////////////////////
365/// calculate the support for all rules
366
368{
369 Log() << kVERBOSE << "Evaluating Rule support" << Endl;
370 Double_t s,t,stot,ssb;
371 Double_t ssig, sbkg, ssum;
372 stot = 0;
373 // reset to default values
374 SetAverageRuleSigma(0.4);
375 const std::vector<const Event *> *events = GetTrainingEvents();
376 Double_t nrules = static_cast<Double_t>(fRules.size());
377 Double_t ew;
378 //
379 if ((nrules>0) && (events->size()>0)) {
380 for ( std::vector< Rule * >::iterator itrRule=fRules.begin(); itrRule!=fRules.end(); ++itrRule ) {
381 s=0.0;
382 ssig=0.0;
383 sbkg=0.0;
384 for ( std::vector<const Event * >::const_iterator itrEvent=events->begin(); itrEvent!=events->end(); ++itrEvent ) {
385 if ((*itrRule)->EvalEvent( *(*itrEvent) )) {
386 ew = (*itrEvent)->GetWeight();
387 s += ew;
388 if (GetMethodRuleFit()->DataInfo().IsSignal(*itrEvent)) ssig += ew;
389 else sbkg += ew;
390 }
391 }
392 //
393 s = s/fRuleFit->GetNEveEff();
394 t = s*(1.0-s);
395 t = (t<0 ? 0:sqrt(t));
396 stot += s;
397 ssum = ssig+sbkg;
398 ssb = (ssum>0 ? Double_t(ssig)/Double_t(ssig+sbkg) : 0.0 );
399 (*itrRule)->SetSupport(s);
400 (*itrRule)->SetNorm(t);
401 (*itrRule)->SetSSB( ssb );
402 (*itrRule)->SetSSBNeve(Double_t(ssig+sbkg));
403 }
404 fAverageSupport = stot/nrules;
405 fAverageRuleSigma = TMath::Sqrt(fAverageSupport*(1.0-fAverageSupport));
406 Log() << kVERBOSE << "Standard deviation of support = " << fAverageRuleSigma << Endl;
407 Log() << kVERBOSE << "Average rule support = " << fAverageSupport << Endl;
408 }
409}
410
411////////////////////////////////////////////////////////////////////////////////
412/// calculate the importance of each rule
413
415{
416 Double_t maxRuleImp = CalcRuleImportance();
417 Double_t maxLinImp = CalcLinImportance();
418 Double_t maxImp = (maxRuleImp>maxLinImp ? maxRuleImp : maxLinImp);
419 SetImportanceRef( maxImp );
420}
421
422////////////////////////////////////////////////////////////////////////////////
423/// set reference importance
424
426{
427 for ( UInt_t i=0; i<fRules.size(); i++ ) {
428 fRules[i]->SetImportanceRef(impref);
429 }
430 fImportanceRef = impref;
431}
432////////////////////////////////////////////////////////////////////////////////
433/// calculate importance of each rule
434
436{
437 Double_t maxImp=-1.0;
438 Double_t imp;
439 Int_t nrules = fRules.size();
440 for ( int i=0; i<nrules; i++ ) {
441 fRules[i]->CalcImportance();
442 imp = fRules[i]->GetImportance();
443 if (imp>maxImp) maxImp = imp;
444 }
445 for ( Int_t i=0; i<nrules; i++ ) {
446 fRules[i]->SetImportanceRef(maxImp);
447 }
448
449 return maxImp;
450}
451
452////////////////////////////////////////////////////////////////////////////////
/// calculate the importance of each linear term and return the largest one
454
456{
457 Double_t maxImp=-1.0;
458 UInt_t nvars = fLinCoefficients.size();
459 fLinImportance.resize(nvars,0.0);
460 if (!DoLinear()) return maxImp;
461 //
462 // The linear importance is:
463 // I = |b_x|*sigma(x)
464 // Note that the coefficients in fLinCoefficients are for the normalized x
465 // => b'_x * x' = b'_x * sigma(r)*x/sigma(x)
466 // => b_x = b'_x*sigma(r)/sigma(x)
467 // => I = |b'_x|*sigma(r)
468 //
469 Double_t imp;
470 for ( UInt_t i=0; i<nvars; i++ ) {
471 imp = fAverageRuleSigma*TMath::Abs(fLinCoefficients[i]);
472 fLinImportance[i] = imp;
473 if (imp>maxImp) maxImp = imp;
474 }
475 return maxImp;
476}
477
478////////////////////////////////////////////////////////////////////////////////
/// Calculates variable importance using eq (35) in the RuleFit paper by Friedman et al.
480
482{
483 Log() << kVERBOSE << "Compute variable importance" << Endl;
484 Double_t rimp;
485 UInt_t nrules = fRules.size();
486 if (GetMethodBase()==0) Log() << kFATAL << "RuleEnsemble::CalcVarImportance() - should not be here!" << Endl;
487 UInt_t nvars = GetMethodBase()->GetNvar();
488 UInt_t nvarsUsed;
489 Double_t rimpN;
490 fVarImportance.resize(nvars,0);
491 // rules
492 if (DoRules()) {
493 for ( UInt_t ind=0; ind<nrules; ind++ ) {
494 rimp = fRules[ind]->GetImportance();
495 nvarsUsed = fRules[ind]->GetNumVarsUsed();
496 if (nvarsUsed<1)
497 Log() << kFATAL << "<CalcVarImportance> Variables for importance calc!!!??? A BUG!" << Endl;
498 rimpN = (nvarsUsed > 0 ? rimp/nvarsUsed:0.0);
499 for ( UInt_t iv=0; iv<nvars; iv++ ) {
500 if (fRules[ind]->ContainsVariable(iv)) {
501 fVarImportance[iv] += rimpN;
502 }
503 }
504 }
505 }
506 // linear terms
507 if (DoLinear()) {
508 for ( UInt_t iv=0; iv<fLinTermOK.size(); iv++ ) {
509 if (fLinTermOK[iv]) fVarImportance[iv] += fLinImportance[iv];
510 }
511 }
512 //
513 // Make variable importance relative the strongest variable
514 //
515 Double_t maximp = 0.0;
516 for ( UInt_t iv=0; iv<nvars; iv++ ) {
517 if ( fVarImportance[iv] > maximp ) maximp = fVarImportance[iv];
518 }
519 if (maximp>0) {
520 for ( UInt_t iv=0; iv<nvars; iv++ ) {
521 fVarImportance[iv] *= 1.0/maximp;
522 }
523 }
524}
525
526////////////////////////////////////////////////////////////////////////////////
527/// set rules
528///
529/// first clear all
530
531void TMVA::RuleEnsemble::SetRules( const std::vector<Rule *> & rules )
532{
533 DeleteRules();
534 //
535 fRules.resize(rules.size());
536 for (UInt_t i=0; i<fRules.size(); i++) {
537 fRules[i] = rules[i];
538 }
539 fEventCacheOK = kFALSE;
540}
541
542////////////////////////////////////////////////////////////////////////////////
543/// Makes rules from the given decision tree.
544/// First node in all rules is ALWAYS the root node.
545
546void TMVA::RuleEnsemble::MakeRules( const std::vector< const DecisionTree *> & forest )
547{
548 fRules.clear();
549 if (!DoRules()) return;
550 //
551 Int_t nrulesCheck=0;
552 Int_t nrules;
553 Int_t nendn;
554 Double_t sumnendn=0;
555 Double_t sumn2=0;
556 //
557 // UInt_t prevs;
558 UInt_t ntrees = forest.size();
559 for ( UInt_t ind=0; ind<ntrees; ind++ ) {
560 // prevs = fRules.size();
561 MakeRulesFromTree( forest[ind] );
562 nrules = CalcNRules( forest[ind] );
563 nendn = (nrules/2) + 1;
564 sumnendn += nendn;
565 sumn2 += nendn*nendn;
566 nrulesCheck += nrules;
567 }
568 Double_t nmean = (ntrees>0) ? sumnendn/ntrees : 0;
569 Double_t nsigm = TMath::Sqrt( gTools().ComputeVariance(sumn2,sumnendn,ntrees) );
570 Double_t ndev = 2.0*(nmean-2.0-nsigm)/(nmean-2.0+nsigm);
571 //
572 Log() << kVERBOSE << "Average number of end nodes per tree = " << nmean << Endl;
573 if (ntrees>1) Log() << kVERBOSE << "sigma of ditto ( ~= mean-2 ?) = "
574 << nsigm
575 << Endl;
576 Log() << kVERBOSE << "Deviation from exponential model = " << ndev << Endl;
577 Log() << kVERBOSE << "Corresponds to L (eq. 13, RuleFit ppr) = " << nmean << Endl;
578 // a BUG trap
579 if (nrulesCheck != static_cast<Int_t>(fRules.size())) {
580 Log() << kFATAL
581 << "BUG! number of generated and possible rules do not match! N(rules) = " << fRules.size()
582 << " != " << nrulesCheck << Endl;
583 }
584 Log() << kVERBOSE << "Number of generated rules: " << fRules.size() << Endl;
585
586 // save initial number of rules
587 fNRulesGenerated = fRules.size();
588
589 RemoveSimilarRules();
590
591 ResetCoefficients();
592
593}
594
595////////////////////////////////////////////////////////////////////////////////
596/// Make the linear terms as in eq 25, ref 2
597/// For this the b and (1-b) quantiles are needed
598
600{
601 if (!DoLinear()) return;
602
603 const std::vector<const Event *> *events = GetTrainingEvents();
604 UInt_t neve = events->size();
605 UInt_t nvars = ((*events)[0])->GetNVariables(); // Event -> GetNVariables();
606 Double_t val,ew;
607 typedef std::pair< Double_t, Int_t> dataType;
608 typedef std::pair< Double_t, dataType > dataPoint;
609
610 std::vector< std::vector<dataPoint> > vardata(nvars);
611 std::vector< Double_t > varsum(nvars,0.0);
612 std::vector< Double_t > varsum2(nvars,0.0);
613 // first find stats of all variables
614 // vardata[v][i].first -> value of var <v> in event <i>
615 // vardata[v][i].second.first -> the event weight
616 // vardata[v][i].second.second -> the event type
617 for (UInt_t i=0; i<neve; i++) {
618 ew = ((*events)[i])->GetWeight();
619 for (UInt_t v=0; v<nvars; v++) {
620 val = ((*events)[i])->GetValue(v);
621 vardata[v].push_back( dataPoint( val, dataType(ew,((*events)[i])->GetClass()) ) );
622 }
623 }
624 //
625 fLinDP.clear();
626 fLinDM.clear();
627 fLinCoefficients.clear();
628 fLinNorm.clear();
629 fLinDP.resize(nvars,0);
630 fLinDM.resize(nvars,0);
631 fLinCoefficients.resize(nvars,0);
632 fLinNorm.resize(nvars,0);
633
634 Double_t averageWeight = neve ? fRuleFit->GetNEveEff()/static_cast<Double_t>(neve) : 0;
635 // sort and find limits
636 Double_t stdl;
637
638 // find normalisation given in ref 2 after eq 26
639 Double_t lx;
640 Double_t nquant;
641 Double_t neff;
642 UInt_t indquantM;
643 UInt_t indquantP;
644
645 MethodBase *fMB=const_cast<MethodBase *>(fRuleFit->GetMethodBase());
646
647 for (UInt_t v=0; v<nvars; v++) {
648 varsum[v] = 0;
649 varsum2[v] = 0;
650 //
651 std::sort( vardata[v].begin(),vardata[v].end() );
652 nquant = fLinQuantile*fRuleFit->GetNEveEff(); // quantile = 0.025
653 neff=0;
654 UInt_t ie=0;
655 // first scan for lower quantile (including weights)
656 while ( (ie<neve) && (neff<nquant) ) {
657 neff += vardata[v][ie].second.first;
658 ie++;
659 }
660 indquantM = (ie==0 ? 0:ie-1);
661 // now for upper quantile
662 ie = neve;
663 neff=0;
664 while ( (ie>0) && (neff<nquant) ) {
665 ie--;
666 neff += vardata[v][ie].second.first;
667 }
668 indquantP = (ie==neve ? ie=neve-1:ie);
669 //
670 fLinDM[v] = vardata[v][indquantM].first; // delta-
671 fLinDP[v] = vardata[v][indquantP].first; // delta+
672
673 if(!fMB->IsSilentFile())
674 {
675 if (fLinPDFB[v]) delete fLinPDFB[v];
676 if (fLinPDFS[v]) delete fLinPDFS[v];
677 fLinPDFB[v] = new TH1F(TString::Format("bkgvar%d",v),"bkg temphist",40,fLinDM[v],fLinDP[v]);
678 fLinPDFS[v] = new TH1F(TString::Format("sigvar%d",v),"sig temphist",40,fLinDM[v],fLinDP[v]);
679 fLinPDFB[v]->Sumw2();
680 fLinPDFS[v]->Sumw2();
681 }
682 //
683 Int_t type;
684 const Double_t w = 1.0/fRuleFit->GetNEveEff();
685 for (ie=0; ie<neve; ie++) {
686 val = vardata[v][ie].first;
687 ew = vardata[v][ie].second.first;
688 type = vardata[v][ie].second.second;
689 lx = TMath::Min( fLinDP[v], TMath::Max( fLinDM[v], val ) );
690 varsum[v] += ew*lx;
691 varsum2[v] += ew*lx*lx;
692 if(!fMB->IsSilentFile())
693 {
694 if (type==1) fLinPDFS[v]->Fill(lx,w*ew);
695 else fLinPDFB[v]->Fill(lx,w*ew);
696 }
697 }
698 //
699 // Get normalization.
700 //
701 stdl = TMath::Sqrt( (varsum2[v] - (varsum[v]*varsum[v]/fRuleFit->GetNEveEff()))/(fRuleFit->GetNEveEff()-averageWeight) );
702 fLinNorm[v] = CalcLinNorm(stdl);
703 }
704 // Save PDFs - for debugging purpose
705 if(!fMB->IsSilentFile())
706 {
707 for (UInt_t v=0; v<nvars; v++) {
708 fLinPDFS[v]->Write();
709 fLinPDFB[v]->Write();
710 }
711 }
712}
713
714////////////////////////////////////////////////////////////////////////////////
715/// This function returns Pr( y = 1 | x ) for the linear terms.
716
718{
719 UInt_t nvars=fLinDP.size();
720
721 Double_t fstot=0;
722 Double_t fbtot=0;
723 nsig = 0;
724 ntot = nvars;
725 for (UInt_t v=0; v<nvars; v++) {
726 Double_t val = fEventLinearVal[v];
727 Int_t bin = fLinPDFS[v]->FindBin(val);
728 fstot += fLinPDFS[v]->GetBinContent(bin);
729 fbtot += fLinPDFB[v]->GetBinContent(bin);
730 }
731 if (nvars<1) return 0;
732 ntot = (fstot+fbtot)/Double_t(nvars);
733 nsig = (fstot)/Double_t(nvars);
734 return fstot/(fstot+fbtot);
735}
736
737////////////////////////////////////////////////////////////////////////////////
738/// This function returns Pr( y = 1 | x ) for rules.
739/// The probability returned is normalized against the number of rules which are actually passed
740
742{
743 Double_t sump = 0;
744 Double_t sumok = 0;
745 Double_t ssb;
746 Double_t neve;
747 //
748 UInt_t nrules = fRules.size();
749 for (UInt_t ir=0; ir<nrules; ir++) {
750 if (fEventRuleVal[ir]>0) {
751 ssb = fEventRuleVal[ir]*GetRulesConst(ir)->GetSSB(); // S/(S+B) is evaluated in CalcRuleSupport() using ALL training events
752 neve = GetRulesConst(ir)->GetSSBNeve(); // number of events accepted by the rule
753 sump += ssb*neve; // number of signal events
754 sumok += neve; // total number of events passed
755 }
756 }
757
758 nsig = sump;
759 ntot = sumok;
760 //
761 if (ntot>0) return nsig/ntot;
762 return 0.0;
763}
764
765////////////////////////////////////////////////////////////////////////////////
766/// We want to estimate F* = argmin Eyx( L(y,F(x) ), min wrt F(x)
767/// F(x) = FL(x) + FR(x) , linear and rule part
768
770{
771 SetEvent(e);
772 UpdateEventVal();
773 return FStar();
774}
775
776////////////////////////////////////////////////////////////////////////////////
777/// We want to estimate F* = argmin Eyx( L(y,F(x) ), min wrt F(x)
778/// F(x) = FL(x) + FR(x) , linear and rule part
779
781{
782 Double_t p=0;
783 Double_t nrs=0, nrt=0;
784 Double_t nls=0, nlt=0;
785 Double_t nt;
786 Double_t pr=0;
787 Double_t pl=0;
788
789 // first calculate Pr(y=1|X) for rules and linear terms
790 if (DoLinear()) pl = PdfLinear(nls, nlt);
791 if (DoRules()) pr = PdfRule(nrs, nrt);
792 // nr(l)t=0 or 1
793 if ((nlt>0) && (nrt>0)) nt=2.0;
794 else nt=1.0;
795 p = (pl+pr)/nt;
796 return 2.0*p-1.0;
797}
798
799////////////////////////////////////////////////////////////////////////////////
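/// calculate the response statistics of all rules: per-rule tagging probabilities, the fraction of signal rules and the fraction of rules containing each variable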
801
803{
804 // TODO: NOT YET UPDATED FOR WEIGHTS
805 const std::vector<const Event *> *events = GetTrainingEvents();
806 const UInt_t neve = events->size();
807 const UInt_t nvars = GetMethodBase()->GetNvar();
808 const UInt_t nrules = fRules.size();
809 const Event *eveData;
810 // Flags
811 Bool_t sigRule;
812 Bool_t sigTag;
813 Bool_t bkgTag;
814 // Bool_t noTag;
815 Bool_t sigTrue;
816 Bool_t tagged;
817 // Counters
818 Int_t nsig=0;
819 Int_t nbkg=0;
820 Int_t ntag=0;
821 Int_t nss=0;
822 Int_t nsb=0;
823 Int_t nbb=0;
824 Int_t nbs=0;
825 std::vector<Int_t> varcnt;
826 // Clear vectors
827 fRulePSS.clear();
828 fRulePSB.clear();
829 fRulePBS.clear();
830 fRulePBB.clear();
831 fRulePTag.clear();
832 //
833 varcnt.resize(nvars,0);
834 fRuleVarFrac.clear();
835 fRuleVarFrac.resize(nvars,0);
836 //
837 for ( UInt_t i=0; i<nrules; i++ ) {
838 for ( UInt_t v=0; v<nvars; v++) {
839 if (fRules[i]->ContainsVariable(v)) varcnt[v]++; // count how often a variable occurs
840 }
841 sigRule = fRules[i]->IsSignalRule();
842 if (sigRule) { // rule is a signal rule (ie s/(s+b)>0.5)
843 nsig++;
844 }
845 else {
846 nbkg++;
847 }
848 // reset counters
849 nss=0;
850 nsb=0;
851 nbs=0;
852 nbb=0;
853 ntag=0;
854 // loop over all events
855 for (UInt_t e=0; e<neve; e++) {
856 eveData = (*events)[e];
857 tagged = fRules[i]->EvalEvent(*eveData);
858 sigTag = (tagged && sigRule); // it's tagged as a signal
859 bkgTag = (tagged && (!sigRule)); // ... as bkg
860 // noTag = !(sigTag || bkgTag); // ... not tagged
861 sigTrue = (eveData->GetClass() == 0); // true if event is true signal
862 if (tagged) {
863 ntag++;
864 if (sigTag && sigTrue) nss++;
865 if (sigTag && !sigTrue) nsb++;
866 if (bkgTag && sigTrue) nbs++;
867 if (bkgTag && !sigTrue) nbb++;
868 }
869 }
870 // Fill tagging probabilities
871 if (ntag>0 && neve > 0) { // should always be the case, but let's make sure and keep coverity quiet
872 fRulePTag.push_back(Double_t(ntag)/Double_t(neve));
873 fRulePSS.push_back(Double_t(nss)/Double_t(ntag));
874 fRulePSB.push_back(Double_t(nsb)/Double_t(ntag));
875 fRulePBS.push_back(Double_t(nbs)/Double_t(ntag));
876 fRulePBB.push_back(Double_t(nbb)/Double_t(ntag));
877 }
878 //
879 }
880 fRuleFSig = (nsig>0) ? static_cast<Double_t>(nsig)/static_cast<Double_t>(nsig+nbkg) : 0;
881 for ( UInt_t v=0; v<nvars; v++) {
882 fRuleVarFrac[v] = (nrules>0) ? Double_t(varcnt[v])/Double_t(nrules) : 0;
883 }
884}
885
886////////////////////////////////////////////////////////////////////////////////
/// calculate the average and the spread of the number of cuts per rule
888
890{
891 const UInt_t nrules = fRules.size();
892 Double_t nc;
893 Double_t sumNc =0;
894 Double_t sumNc2=0;
895 for ( UInt_t i=0; i<nrules; i++ ) {
896 nc = static_cast<Double_t>(fRules[i]->GetNcuts());
897 sumNc += nc;
898 sumNc2 += nc*nc;
899 }
900 fRuleNCave = 0.0;
901 fRuleNCsig = 0.0;
902 if (nrules>0) {
903 fRuleNCave = sumNc/nrules;
904 fRuleNCsig = TMath::Sqrt(gTools().ComputeVariance(sumNc2,sumNc,nrules));
905 }
906}
907
908////////////////////////////////////////////////////////////////////////////////
909/// print rule generation info
910
912{
913 Log() << kHEADER << "-------------------RULE ENSEMBLE SUMMARY------------------------" << Endl;
914 const MethodRuleFit *mrf = GetMethodRuleFit();
915 if (mrf) Log() << kINFO << "Tree training method : " << (mrf->UseBoost() ? "AdaBoost":"Random") << Endl;
916 Log() << kINFO << "Number of events per tree : " << fRuleFit->GetNTreeSample() << Endl;
917 Log() << kINFO << "Number of trees : " << fRuleFit->GetForest().size() << Endl;
918 Log() << kINFO << "Number of generated rules : " << fNRulesGenerated << Endl;
919 Log() << kINFO << "Idem, after cleanup : " << fRules.size() << Endl;
920 Log() << kINFO << "Average number of cuts per rule : " << Form("%8.2f",fRuleNCave) << Endl;
921 Log() << kINFO << "Spread in number of cuts per rules : " << Form("%8.2f",fRuleNCsig) << Endl;
922 Log() << kVERBOSE << "Complexity : " << Form("%8.2f",fRuleNCave*fRuleNCsig) << Endl;
923 Log() << kINFO << "----------------------------------------------------------------" << Endl;
924 Log() << kINFO << Endl;
925}
926
927////////////////////////////////////////////////////////////////////////////////
928/// print function
929
931{
932 const EMsgType kmtype=kINFO;
933 const Bool_t isDebug = (fLogger->GetMinType()<=kDEBUG);
934 //
935 Log() << kmtype << Endl;
936 Log() << kmtype << "================================================================" << Endl;
937 Log() << kmtype << " M o d e l " << Endl;
938 Log() << kmtype << "================================================================" << Endl;
939
940 Int_t ind;
941 const UInt_t nvars = GetMethodBase()->GetNvar();
942 const Int_t nrules = fRules.size();
943 const Int_t printN = TMath::Min(10,nrules); //nrules+1;
944 Int_t maxL = 0;
945 for (UInt_t iv = 0; iv<fVarImportance.size(); iv++) {
946 if (GetMethodBase()->GetInputLabel(iv).Length() > maxL) maxL = GetMethodBase()->GetInputLabel(iv).Length();
947 }
948 //
949 if (isDebug) {
950 Log() << kDEBUG << "Variable importance:" << Endl;
951 for (UInt_t iv = 0; iv<fVarImportance.size(); iv++) {
952 Log() << kDEBUG << std::setw(maxL) << GetMethodBase()->GetInputLabel(iv)
953 << std::resetiosflags(std::ios::right)
954 << " : " << Form(" %3.3f",fVarImportance[iv]) << Endl;
955 }
956 }
957 //
958 Log() << kHEADER << "Offset (a0) = " << fOffset << Endl;
959 //
960 if (DoLinear()) {
961 if (fLinNorm.size() > 0) {
962 Log() << kmtype << "------------------------------------" << Endl;
963 Log() << kmtype << "Linear model (weights unnormalised)" << Endl;
964 Log() << kmtype << "------------------------------------" << Endl;
965 Log() << kmtype << std::setw(maxL) << "Variable"
966 << std::resetiosflags(std::ios::right) << " : "
967 << std::setw(11) << " Weights"
968 << std::resetiosflags(std::ios::right) << " : "
969 << "Importance"
970 << std::resetiosflags(std::ios::right)
971 << Endl;
972 Log() << kmtype << "------------------------------------" << Endl;
973 for ( UInt_t i=0; i<fLinNorm.size(); i++ ) {
974 Log() << kmtype << std::setw(std::max(maxL,8)) << GetMethodBase()->GetInputLabel(i);
975 if (fLinTermOK[i]) {
976 Log() << kmtype
977 << std::resetiosflags(std::ios::right)
978 << " : " << Form(" %10.3e",fLinCoefficients[i]*fLinNorm[i])
979 << " : " << Form(" %3.3f",fLinImportance[i]/fImportanceRef) << Endl;
980 }
981 else {
982 Log() << kmtype << "-> importance below threshold = "
983 << Form(" %3.3f",fLinImportance[i]/fImportanceRef) << Endl;
984 }
985 }
986 Log() << kmtype << "------------------------------------" << Endl;
987 }
988 }
989 else Log() << kmtype << "Linear terms were disabled" << Endl;
990
991 if ((!DoRules()) || (nrules==0)) {
992 if (!DoRules()) {
993 Log() << kmtype << "Rule terms were disabled" << Endl;
994 }
995 else {
996 Log() << kmtype << "Even though rules were included in the model, none passed! " << nrules << Endl;
997 }
998 }
999 else {
1000 Log() << kmtype << "Number of rules = " << nrules << Endl;
1001 if (isDebug) {
1002 Log() << kmtype << "N(cuts) in rules, average = " << fRuleNCave << Endl;
1003 Log() << kmtype << " RMS = " << fRuleNCsig << Endl;
1004 Log() << kmtype << "Fraction of signal rules = " << fRuleFSig << Endl;
1005 Log() << kmtype << "Fraction of rules containing a variable (%):" << Endl;
1006 for ( UInt_t v=0; v<nvars; v++) {
1007 Log() << kmtype << " " << std::setw(maxL) << GetMethodBase()->GetInputLabel(v);
1008 Log() << kmtype << Form(" = %2.2f",fRuleVarFrac[v]*100.0) << " %" << Endl;
1009 }
1010 }
1011 //
1012 // Print out all rules sorted in importance
1013 //
1014 std::list< std::pair<double,int> > sortedImp;
1015 for (Int_t i=0; i<nrules; i++) {
1016 sortedImp.push_back( std::pair<double,int>( fRules[i]->GetImportance(),i ) );
1017 }
1018 sortedImp.sort();
1019 //
1020 Log() << kmtype << "Printing the first " << printN << " rules, ordered in importance." << Endl;
1021 int pind=0;
1022 for ( std::list< std::pair<double,int> >::reverse_iterator itpair = sortedImp.rbegin();
1023 itpair != sortedImp.rend(); ++itpair ) {
1024 ind = itpair->second;
1025 // if (pind==0) impref =
1026 // Log() << kmtype << "Rule #" <<
1027 // Log() << kmtype << *fRules[ind] << Endl;
1028 fRules[ind]->PrintLogger(TString::Format("Rule %4d : ",pind+1));
1029 pind++;
1030 if (pind==printN) {
1031 if (nrules==printN) {
1032 Log() << kmtype << "All rules printed" << Endl;
1033 }
1034 else {
1035 Log() << kmtype << "Skipping the next " << nrules-printN << " rules" << Endl;
1036 }
1037 break;
1038 }
1039 }
1040 }
1041 Log() << kmtype << "================================================================" << Endl;
1042 Log() << kmtype << Endl;
1043}
1044
1045////////////////////////////////////////////////////////////////////////////////
1046/// write rules to stream
1047
1048void TMVA::RuleEnsemble::PrintRaw( std::ostream & os ) const
1049{
1050 Int_t dp = os.precision();
1051 UInt_t nrules = fRules.size();
1052 // std::sort(fRules.begin(),fRules.end());
1053 //
1054 os << "ImportanceCut= " << fImportanceCut << std::endl;
1055 os << "LinQuantile= " << fLinQuantile << std::endl;
1056 os << "AverageSupport= " << fAverageSupport << std::endl;
1057 os << "AverageRuleSigma= " << fAverageRuleSigma << std::endl;
1058 os << "Offset= " << fOffset << std::endl;
1059 os << "NRules= " << nrules << std::endl;
1060 for (UInt_t i=0; i<nrules; i++){
1061 os << "***Rule " << i << std::endl;
1062 (fRules[i])->PrintRaw(os);
1063 }
1064 UInt_t nlinear = fLinNorm.size();
1065 //
1066 os << "NLinear= " << fLinTermOK.size() << std::endl;
1067 for (UInt_t i=0; i<nlinear; i++) {
1068 os << "***Linear " << i << std::endl;
1069 os << std::setprecision(10) << (fLinTermOK[i] ? 1:0) << " "
1070 << fLinCoefficients[i] << " "
1071 << fLinNorm[i] << " "
1072 << fLinDM[i] << " "
1073 << fLinDP[i] << " "
1074 << fLinImportance[i] << " " << std::endl;
1075 }
1076 os << std::setprecision(dp);
1077}
1078
1079////////////////////////////////////////////////////////////////////////////////
1080/// write rules to XML
1081
1082void* TMVA::RuleEnsemble::AddXMLTo(void* parent) const
1083{
1084 void* re = gTools().AddChild( parent, "Weights" ); // this is the "RuleEnsemble"
1085
1086 UInt_t nrules = fRules.size();
1087 UInt_t nlinear = fLinNorm.size();
1088 gTools().AddAttr( re, "NRules", nrules );
1089 gTools().AddAttr( re, "NLinear", nlinear );
1090 gTools().AddAttr( re, "LearningModel", (int)fLearningModel );
1091 gTools().AddAttr( re, "ImportanceCut", fImportanceCut );
1092 gTools().AddAttr( re, "LinQuantile", fLinQuantile );
1093 gTools().AddAttr( re, "AverageSupport", fAverageSupport );
1094 gTools().AddAttr( re, "AverageRuleSigma", fAverageRuleSigma );
1095 gTools().AddAttr( re, "Offset", fOffset );
1096 for (UInt_t i=0; i<nrules; i++) fRules[i]->AddXMLTo(re);
1097
1098 for (UInt_t i=0; i<nlinear; i++) {
1099 void* lin = gTools().AddChild( re, "Linear" );
1100 gTools().AddAttr( lin, "OK", (fLinTermOK[i] ? 1:0) );
1101 gTools().AddAttr( lin, "Coeff", fLinCoefficients[i] );
1102 gTools().AddAttr( lin, "Norm", fLinNorm[i] );
1103 gTools().AddAttr( lin, "DM", fLinDM[i] );
1104 gTools().AddAttr( lin, "DP", fLinDP[i] );
1105 gTools().AddAttr( lin, "Importance", fLinImportance[i] );
1106 }
1107 return re;
1108}
1109
1110////////////////////////////////////////////////////////////////////////////////
1111/// read rules from XML
1112
1114{
1115 UInt_t nrules, nlinear;
1116 gTools().ReadAttr( wghtnode, "NRules", nrules );
1117 gTools().ReadAttr( wghtnode, "NLinear", nlinear );
1118 Int_t iLearningModel;
1119 gTools().ReadAttr( wghtnode, "LearningModel", iLearningModel );
1120 fLearningModel = (ELearningModel) iLearningModel;
1121 gTools().ReadAttr( wghtnode, "ImportanceCut", fImportanceCut );
1122 gTools().ReadAttr( wghtnode, "LinQuantile", fLinQuantile );
1123 gTools().ReadAttr( wghtnode, "AverageSupport", fAverageSupport );
1124 gTools().ReadAttr( wghtnode, "AverageRuleSigma", fAverageRuleSigma );
1125 gTools().ReadAttr( wghtnode, "Offset", fOffset );
1126
1127 // read rules
1128 DeleteRules();
1129
1130 UInt_t i = 0;
1131 fRules.resize( nrules );
1132 void* ch = gTools().GetChild( wghtnode );
1133 for (i=0; i<nrules; i++) {
1134 fRules[i] = new Rule();
1135 fRules[i]->SetRuleEnsemble( this );
1136 fRules[i]->ReadFromXML( ch );
1137
1138 ch = gTools().GetNextChild(ch);
1139 }
1140
1141 // read linear classifier (Fisher)
1142 fLinNorm .resize( nlinear );
1143 fLinTermOK .resize( nlinear );
1144 fLinCoefficients.resize( nlinear );
1145 fLinDP .resize( nlinear );
1146 fLinDM .resize( nlinear );
1147 fLinImportance .resize( nlinear );
1148
1149 Int_t iok;
1150 i=0;
1151 while(ch) {
1152 gTools().ReadAttr( ch, "OK", iok );
1153 fLinTermOK[i] = (iok == 1);
1154 gTools().ReadAttr( ch, "Coeff", fLinCoefficients[i] );
1155 gTools().ReadAttr( ch, "Norm", fLinNorm[i] );
1156 gTools().ReadAttr( ch, "DM", fLinDM[i] );
1157 gTools().ReadAttr( ch, "DP", fLinDP[i] );
1158 gTools().ReadAttr( ch, "Importance", fLinImportance[i] );
1159
1160 i++;
1161 ch = gTools().GetNextChild(ch);
1162 }
1163}
1164
1165////////////////////////////////////////////////////////////////////////////////
1166/// read rule ensemble from stream
1167
1168void TMVA::RuleEnsemble::ReadRaw( std::istream & istr )
1169{
1170 UInt_t nrules;
1171 //
1172 std::string dummy;
1173 Int_t idum;
1174 //
1175 // First block is general stuff
1176 //
1177 istr >> dummy >> fImportanceCut;
1178 istr >> dummy >> fLinQuantile;
1179 istr >> dummy >> fAverageSupport;
1180 istr >> dummy >> fAverageRuleSigma;
1181 istr >> dummy >> fOffset;
1182 istr >> dummy >> nrules;
1183 //
1184 // Now read in the rules
1185 //
1186 DeleteRules();
1187 //
1188 for (UInt_t i=0; i<nrules; i++){
1189 istr >> dummy >> idum; // read line "***Rule <ind>"
1190 fRules.push_back( new Rule() );
1191 (fRules.back())->SetRuleEnsemble( this );
1192 (fRules.back())->ReadRaw(istr);
1193 }
1194 //
1195 // and now the linear terms
1196 //
1197 UInt_t nlinear;
1198 //
1199 // coverity[tainted_data_argument]
1200 istr >> dummy >> nlinear;
1201 //
1202 fLinNorm .resize( nlinear );
1203 fLinTermOK .resize( nlinear );
1204 fLinCoefficients.resize( nlinear );
1205 fLinDP .resize( nlinear );
1206 fLinDM .resize( nlinear );
1207 fLinImportance .resize( nlinear );
1208 //
1209
1210 Int_t iok;
1211 for (UInt_t i=0; i<nlinear; i++) {
1212 istr >> dummy >> idum;
1213 istr >> iok;
1214 fLinTermOK[i] = (iok==1);
1215 istr >> fLinCoefficients[i];
1216 istr >> fLinNorm[i];
1217 istr >> fLinDM[i];
1218 istr >> fLinDP[i];
1219 istr >> fLinImportance[i];
1220 }
1221}
1222
1223////////////////////////////////////////////////////////////////////////////////
1224/// copy function
1225
1227{
1228 if(this != &other) {
1229 fRuleFit = other.GetRuleFit();
1230 fRuleMinDist = other.GetRuleMinDist();
1231 fOffset = other.GetOffset();
1232 fRules = other.GetRulesConst();
1233 fImportanceCut = other.GetImportanceCut();
1234 fVarImportance = other.GetVarImportance();
1235 fLearningModel = other.GetLearningModel();
1236 fLinQuantile = other.GetLinQuantile();
1237 fRuleNCsig = other.fRuleNCsig;
1238 fAverageRuleSigma = other.fAverageRuleSigma;
1239 fEventCacheOK = other.fEventCacheOK;
1240 fImportanceRef = other.fImportanceRef;
1241 fNRulesGenerated = other.fNRulesGenerated;
1242 fRuleFSig = other.fRuleFSig;
1243 fRuleMapInd0 = other.fRuleMapInd0;
1244 fRuleMapInd1 = other.fRuleMapInd1;
1245 fRuleMapOK = other.fRuleMapOK;
1246 fRuleNCave = other.fRuleNCave;
1247 }
1248}
1249
1250////////////////////////////////////////////////////////////////////////////////
1251/// calculate the number of rules
1252
1254{
1255 if (dtree==0) return 0;
1256 Node *node = dtree->GetRoot();
1257 Int_t nendnodes = 0;
1258 FindNEndNodes( node, nendnodes );
1259 return 2*(nendnodes-1);
1260}
1261
1262////////////////////////////////////////////////////////////////////////////////
1263/// find the number of leaf nodes
1264
1265void TMVA::RuleEnsemble::FindNEndNodes( const Node *node, Int_t & nendnodes )
1266{
1267 if (node==0) return;
1268 if ((node->GetRight()==0) && (node->GetLeft()==0)) {
1269 ++nendnodes;
1270 return;
1271 }
1272 const Node *nodeR = node->GetRight();
1273 const Node *nodeL = node->GetLeft();
1274 FindNEndNodes( nodeR, nendnodes );
1275 FindNEndNodes( nodeL, nendnodes );
1276}
1277
1278////////////////////////////////////////////////////////////////////////////////
1279/// create rules from the decision tree structure
1280
1282{
1283 Node *node = dtree->GetRoot();
1284 AddRule( node );
1285}
1286
1287////////////////////////////////////////////////////////////////////////////////
1288/// add a new rule to the tree
1289
1291{
1292 if (node==0) return;
1293 if (node->GetParent()==0) { // it's a root node, don't make a rule
1294 AddRule( node->GetRight() );
1295 AddRule( node->GetLeft() );
1296 }
1297 else {
1298 Rule *rule = MakeTheRule(node);
1299 if (rule) {
1300 fRules.push_back( rule );
1301 AddRule( node->GetRight() );
1302 AddRule( node->GetLeft() );
1303 }
1304 else {
1305 Log() << kFATAL << "<AddRule> - ERROR failed in creating a rule! BUG!" << Endl;
1306 }
1307 }
1308}
1309
1310////////////////////////////////////////////////////////////////////////////////
1311/// Make a Rule from a given Node.
1312/// The root node (ie no parent) does not generate a Rule.
1313/// The first node in a rule is always the root node => fNodes.size()>=2
1314/// Each node corresponds to a cut and the cut value is given by the parent node.
1315
1317{
1318 if (node==0) {
1319 Log() << kFATAL << "<MakeTheRule> Input node is NULL. Should not happen. BUG!" << Endl;
1320 return 0;
1321 }
1322
1323 if (node->GetParent()==0) { // a root node - ignore
1324 return 0;
1325 }
1326 //
1327 std::vector< const Node * > nodeVec;
1328 const Node *parent = node;
1329 //
1330 // Make list with the input node at the end:
1331 // <root node> <node1> <node2> ... <node given as argument>
1332 //
1333 nodeVec.push_back( node );
1334 while (parent!=0) {
1335 parent = parent->GetParent();
1336 if (!parent) continue;
1337 const DecisionTreeNode* dtn = dynamic_cast<const DecisionTreeNode*>(parent);
1338 if (dtn && dtn->GetSelector()>=0)
1339 nodeVec.insert( nodeVec.begin(), parent );
1340
1341 }
1342 if (nodeVec.size()<2) {
1343 Log() << kFATAL << "<MakeTheRule> BUG! Inconsistent Rule!" << Endl;
1344 return 0;
1345 }
1346 Rule *rule = new Rule( this, nodeVec );
1347 rule->SetMsgType( Log().GetMinType() );
1348 return rule;
1349}
1350
1351////////////////////////////////////////////////////////////////////////////////
1352/// Makes rule map for all events
1353
1354void TMVA::RuleEnsemble::MakeRuleMap(const std::vector<const Event *> *events, UInt_t ifirst, UInt_t ilast)
1355{
1356 Log() << kVERBOSE << "Making Rule map for all events" << Endl;
1357 // make rule response map
1358 if (events==0) events = GetTrainingEvents();
1359 if ((ifirst==0) || (ilast==0) || (ifirst>ilast)) {
1360 ifirst = 0;
1361 ilast = events->size()-1;
1362 }
1363 // check if identical to previous call
1364 if ((events!=fRuleMapEvents) ||
1365 (ifirst!=fRuleMapInd0) ||
1366 (ilast !=fRuleMapInd1)) {
1367 fRuleMapOK = kFALSE;
1368 }
1369 //
1370 if (fRuleMapOK) {
1371 Log() << kVERBOSE << "<MakeRuleMap> Map is already valid" << Endl;
1372 return; // already cached
1373 }
1374 fRuleMapEvents = events;
1375 fRuleMapInd0 = ifirst;
1376 fRuleMapInd1 = ilast;
1377 // check number of rules
1378 UInt_t nrules = GetNRules();
1379 if (nrules==0) {
1380 Log() << kVERBOSE << "No rules found in MakeRuleMap()" << Endl;
1381 fRuleMapOK = kTRUE;
1382 return;
1383 }
1384 //
1385 // init map
1386 //
1387 std::vector<UInt_t> ruleind;
1388 fRuleMap.clear();
1389 for (UInt_t i=ifirst; i<=ilast; i++) {
1390 ruleind.clear();
1391 fRuleMap.push_back( ruleind );
1392 for (UInt_t r=0; r<nrules; r++) {
1393 if (fRules[r]->EvalEvent(*((*events)[i]))) {
1394 fRuleMap.back().push_back(r); // save only rules that are accepted
1395 }
1396 }
1397 }
1398 fRuleMapOK = kTRUE;
1399 Log() << kVERBOSE << "Made rule map for event# " << ifirst << " : " << ilast << Endl;
1400}
1401
1402////////////////////////////////////////////////////////////////////////////////
1403/// std::ostream operator
1404
1405std::ostream& TMVA::operator<< ( std::ostream& os, const RuleEnsemble & rules )
1406{
1407 os << "DON'T USE THIS - TO BE REMOVED" << std::endl;
1408 rules.Print();
1409 return os;
1410}
#define e(i)
Definition RSha256.hxx:103
constexpr Bool_t kFALSE
Definition RtypesCore.h:101
double Double_t
Definition RtypesCore.h:59
constexpr Bool_t kTRUE
Definition RtypesCore.h:100
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
R__EXTERN TRandom * gRandom
Definition TRandom.h:62
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2489
1-D histogram with a float per channel (see TH1 documentation)
Definition TH1.h:621
Short_t GetSelector() const
return index of variable used for discrimination at this node
Implementation of a Decision Tree.
virtual DecisionTreeNode * GetRoot() const
UInt_t GetClass() const
Definition Event.h:86
Virtual base Class for all MVA method.
Definition MethodBase.h:111
Bool_t IsSilentFile() const
Definition MethodBase.h:379
J Friedman's RuleFit method.
Bool_t UseBoost() const
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
Node for the BinarySearch or Decision Trees.
Definition Node.h:58
virtual Node * GetLeft() const
Definition Node.h:89
virtual Node * GetParent() const
Definition Node.h:91
virtual Node * GetRight() const
Definition Node.h:90
virtual ~RuleEnsemble()
destructor
void CalcVarImportance()
Calculates variable importance using eq (35) in RuleFit paper by Friedman et.al.
void SetImportanceRef(Double_t impref)
set reference importance
void CalcImportance()
calculate the importance of each rule
void PrintRuleGen() const
print rule generation info
Int_t CalcNRules(const TMVA::DecisionTree *dtree)
calculate the number of rules
UInt_t fNRulesGenerated
number of rules generated, before cleanup
void ResetCoefficients()
reset all rule coefficients
void SetMsgType(EMsgType t)
Double_t GetLinQuantile() const
void ReadRaw(std::istream &istr)
read rule ensemble from stream
void AddRule(const Node *node)
add a new rule to the tree
void ReadFromXML(void *wghtnode)
read rules from XML
Double_t GetImportanceCut() const
const Event * GetTrainingEvent(UInt_t i) const
get the training event from the rule fitter
const std::vector< const TMVA::Event * > * GetTrainingEvents() const
get list of training events from the rule fitter
Double_t GetRuleMinDist() const
void SetRules(const std::vector< TMVA::Rule * > &rules)
set rules
void MakeRules(const std::vector< const TMVA::DecisionTree * > &forest)
Makes rules from the given decision tree.
void RemoveSimilarRules()
remove rules that behave similar
Double_t fRuleFSig
N(sig)/N(sig)+N(bkg)
void FindNEndNodes(const TMVA::Node *node, Int_t &nendnodes)
find the number of leaf nodes
Bool_t fRuleMapOK
true if MakeRuleMap() has been called
RuleEnsemble()
constructor
UInt_t fRuleMapInd1
last index
const std::vector< Double_t > & GetVarImportance() const
void CleanupRules()
cleanup rules
void Initialize(const RuleFit *rf)
Initializes all member variables with default values.
void CleanupLinear()
cleanup linear model
void RuleResponseStats()
calculate various statistics for this rule
const RuleFit * GetRuleFit() const
void * AddXMLTo(void *parent) const
write rules to XML
const std::vector< TMVA::Rule * > & GetRulesConst() const
const MethodRuleFit * GetMethodRuleFit() const
Get a pointer to the original MethodRuleFit.
void MakeModel()
create model
void RuleStatistics()
calculate various statistics for this rule
void SetCoefficients(const std::vector< Double_t > &v)
set all rule coefficients
void Print() const
print function
Double_t PdfRule(Double_t &nsig, Double_t &ntot) const
This function returns Pr( y = 1 | x ) for rules.
void MakeRuleMap(const std::vector< const TMVA::Event * > *events=nullptr, UInt_t ifirst=0, UInt_t ilast=0)
Makes rule map for all events.
const MethodBase * GetMethodBase() const
Get a pointer to the original MethodRuleFit.
Double_t GetOffset() const
Double_t fRuleNCave
N(cuts) average.
void Copy(RuleEnsemble const &other)
copy function
Double_t CalcLinImportance()
calculate the linear importance for each rule
Double_t CalcRuleImportance()
calculate importance of each rule
Double_t fImportanceRef
reference importance (max)
void PrintRaw(std::ostream &os) const
write rules to stream
Double_t fAverageRuleSigma
average rule sigma
Bool_t fEventCacheOK
true if rule/linear respons are updated
void CalcRuleSupport()
calculate the support for all rules
ELearningModel GetLearningModel() const
Double_t PdfLinear(Double_t &nsig, Double_t &ntot) const
This function returns Pr( y = 1 | x ) for the linear terms.
Double_t CoefficientRadius()
Calculates sqrt(Sum(a_i^2)), i=1..N (NOTE do not include a0)
UInt_t fRuleMapInd0
start index
void MakeRulesFromTree(const DecisionTree *dtree)
create rules from the decision tree structure
Double_t fRuleNCsig
idem sigma
void MakeLinearTerms()
Make the linear terms as in eq 25, ref 2 For this the b and (1-b) quantiles are needed.
Rule * MakeTheRule(const Node *node)
Make a Rule from a given Node.
void GetCoefficients(std::vector< Double_t > &v)
Retrieve all rule coefficients.
Double_t FStar() const
We want to estimate F* = argmin Eyx( L(y,F(x) ), min wrt F(x) F(x) = FL(x) + FR(x) ,...
A class implementing various fits of rule ensembles.
Definition RuleFit.h:46
Implementation of a rule.
Definition Rule.h:50
void SetMsgType(EMsgType t)
Definition Rule.cxx:156
Bool_t Equal(const Rule &other, Bool_t useCutValue, Double_t maxdist) const
Compare two rules.
Definition Rule.cxx:172
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
void * GetChild(void *parent, const char *childname=nullptr)
get child node
Definition Tools.cxx:1150
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition Tools.h:347
void * AddChild(void *parent, const char *childname, const char *content=nullptr, bool isRootNode=false)
add child node
Definition Tools.cxx:1124
void * GetNextChild(void *prevchild, const char *childname=nullptr)
XML helpers.
Definition Tools.cxx:1162
Double_t Rndm() override
Machine independent random number generator.
Definition TRandom.cxx:559
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2378
Tools & gTools()
std::ostream & operator<<(std::ostream &os, const BinaryTree &tree)
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
Short_t Max(Short_t a, Short_t b)
Returns the largest of a and b.
Definition TMathBase.h:250
Double_t Sqrt(Double_t x)
Returns the square root of x.
Definition TMath.h:662
Short_t Min(Short_t a, Short_t b)
Returns the smallest of a and b.
Definition TMathBase.h:198
Short_t Abs(Short_t d)
Returns the absolute value of parameter Short_t d.
Definition TMathBase.h:123