MethodRuleFit.cxx
// @(#)root/tmva $Id$
// Author: Fredrik Tegenfeldt

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : MethodRuleFit                                                         *
 *                                                                                *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation (see header file for description)                          *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch>  - Iowa State U., USA     *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *      Iowa State U.                                                             *
 *      MPI-K Heidelberg, Germany                                                 *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (see tmva/doc/LICENSE)                                                         *
 **********************************************************************************/

/*! \class TMVA::MethodRuleFit
\ingroup TMVA
J Friedman's RuleFit method
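
Example booking via the Factory (a minimal sketch; `factory` and `dataloader`
are the user's objects, and the option values are illustrative, not defaults):

~~~ {.cpp}
factory->BookMethod( dataloader, TMVA::Types::kRuleFit, "RuleFit",
                     "Model=ModRuleLinear:GDTau=-1.0:GDNSteps=10000:nTrees=20" );
~~~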
*/

#include "TMVA/MethodRuleFit.h"

#include "TMVA/ClassifierFactory.h"
#include "TMVA/Config.h"
#include "TMVA/Configurable.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/DataSet.h"
#include "TMVA/DecisionTree.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/IMethod.h"
#include "TMVA/MethodBase.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/MsgLogger.h"
#include "TMVA/Ranking.h"
#include "TMVA/RuleFitAPI.h"
#include "TMVA/SdivSqrtSplusB.h"
#include "TMVA/SeparationBase.h"
#include "TMVA/Timer.h"
#include "TMVA/Tools.h"
#include "TMVA/Types.h"

#include "TRandom3.h"
#include "TMatrix.h"

#include <iostream>
#include <iomanip>
#include <algorithm>
#include <list>
#include <random>

using std::min;

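// register the method with the classifier factory so that "RuleFit" can be
// booked by name through the TMVA Factory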
REGISTER_METHOD(RuleFit)

ClassImp(TMVA::MethodRuleFit);

////////////////////////////////////////////////////////////////////////////////
/// standard constructor

TMVA::MethodRuleFit::MethodRuleFit( const TString& jobName,
                                    const TString& methodTitle,
                                    DataSetInfo& theData,
                                    const TString& theOption) :
   MethodBase( jobName, Types::kRuleFit, methodTitle, theData, theOption)
   , fSignalFraction(0)
   , fNTImportance(0)
   , fNTCoefficient(0)
   , fNTSupport(0)
   , fNTNcuts(0)
   , fNTNvars(0)
   , fNTPtag(0)
   , fNTPss(0)
   , fNTPsb(0)
   , fNTPbs(0)
   , fNTPbb(0)
   , fNTSSB(0)
   , fNTType(0)
   , fUseRuleFitJF(kFALSE)
   , fRFNrules(0)
   , fRFNendnodes(0)
   , fNTrees(0)
   , fTreeEveFrac(0)
   , fSepType(0)
   , fMinFracNEve(0)
   , fMaxFracNEve(0)
   , fNCuts(0)
   , fPruneMethod(TMVA::DecisionTree::kCostComplexityPruning)
   , fPruneStrength(0)
   , fUseBoost(kFALSE)
   , fGDPathEveFrac(0)
   , fGDValidEveFrac(0)
   , fGDTau(0)
   , fGDTauPrec(0)
   , fGDTauMin(0)
   , fGDTauMax(0)
   , fGDTauScan(0)
   , fGDPathStep(0)
   , fGDNPathSteps(0)
   , fGDErrScale(0)
   , fMinimp(0)
   , fRuleMinDist(0)
   , fLinQuantile(0)
{
}

////////////////////////////////////////////////////////////////////////////////
/// constructor from weight file

TMVA::MethodRuleFit::MethodRuleFit( DataSetInfo& theData,
                                    const TString& theWeightFile) :
   MethodBase( Types::kRuleFit, theData, theWeightFile)
   , fSignalFraction(0)
   , fNTImportance(0)
   , fNTCoefficient(0)
   , fNTSupport(0)
   , fNTNcuts(0)
   , fNTNvars(0)
   , fNTPtag(0)
   , fNTPss(0)
   , fNTPsb(0)
   , fNTPbs(0)
   , fNTPbb(0)
   , fNTSSB(0)
   , fNTType(0)
   , fUseRuleFitJF(kFALSE)
   , fRFNrules(0)
   , fRFNendnodes(0)
   , fNTrees(0)
   , fTreeEveFrac(0)
   , fSepType(0)
   , fMinFracNEve(0)
   , fMaxFracNEve(0)
   , fNCuts(0)
   , fPruneMethod(TMVA::DecisionTree::kCostComplexityPruning)
   , fPruneStrength(0)
   , fUseBoost(kFALSE)
   , fGDPathEveFrac(0)
   , fGDValidEveFrac(0)
   , fGDTau(0)
   , fGDTauPrec(0)
   , fGDTauMin(0)
   , fGDTauMax(0)
   , fGDTauScan(0)
   , fGDPathStep(0)
   , fGDNPathSteps(0)
   , fGDErrScale(0)
   , fMinimp(0)
   , fRuleMinDist(0)
   , fLinQuantile(0)
{
}

////////////////////////////////////////////////////////////////////////////////
/// destructor

TMVA::MethodRuleFit::~MethodRuleFit( void )
{
   for (UInt_t i=0; i<fEventSample.size(); i++) delete fEventSample[i];
   for (UInt_t i=0; i<fForest.size(); i++)      delete fForest[i];
}

////////////////////////////////////////////////////////////////////////////////
/// RuleFit can handle classification with 2 classes

Bool_t TMVA::MethodRuleFit::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/ )
{
   return (type == Types::kClassification && numberClasses == 2);
}
////////////////////////////////////////////////////////////////////////////////
/// define the options (their key words) that can be set in the option string
/// known options.
///
/// #### general
///
///  - RuleFitModule `<string>`
///    available values are:
///    - RFTMVA     - use TMVA implementation
///    - RFFriedman - use Friedman's original implementation
///
/// #### Path search (fitting)
///
///  - GDTau       `<float>`  gradient-directed path: fit threshold; default -1 (automatic)
///  - GDTauPrec   `<float>`  gradient-directed path: precision of estimated tau
///  - GDStep      `<float>`  gradient-directed path: step size
///  - GDNSteps    `<int>`    gradient-directed path: number of steps
///  - GDErrScale  `<float>`  stop scan when error > scale*errmin
///
/// #### Tree generation
///
///  - fEventsMin  `<float>`  minimum fraction of events in a splittable node
///  - fEventsMax  `<float>`  maximum fraction of events in a splittable node
///  - nTrees      `<int>`    number of trees in forest
///  - ForestType  `<string>`
///    available values are:
///    - Random   - create forest using a random subsample and only a random variable subset at each node
///    - AdaBoost - create forest with boosted events
///
/// #### Model creation
///
///  - RuleMinDist `<float>`  minimum distance allowed between rules
///  - MinImp      `<float>`  minimum rule importance accepted
///  - Model       `<string>` model to be used
///    available values are:
///    - ModRuleLinear `<default>`
///    - ModRule
///    - ModLinear
///
/// #### Friedman's module
///
///  - RFWorkDir   `<string>` directory where Friedman's module (rf_go.exe) is installed
///  - RFNrules    `<int>`    maximum number of rules allowed
///  - RFNendnodes `<int>`    average number of end nodes in the forest of trees
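///
/// Example option string, spelling out the documented defaults (illustrative only):
///
///     "RuleFitModule=RFTMVA:Model=ModRuleLinear:GDTau=-1.0:GDStep=0.01:GDNSteps=10000:fEventsMin=0.1:fEventsMax=0.9:nTrees=20:RuleMinDist=0.001:MinImp=0.01"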

void TMVA::MethodRuleFit::DeclareOptions()
{
   DeclareOptionRef(fGDTau=-1, "GDTau", "Gradient-directed (GD) path: default fit cut-off");
   DeclareOptionRef(fGDTauPrec=0.01, "GDTauPrec", "GD path: precision of tau");
   DeclareOptionRef(fGDPathStep=0.01, "GDStep", "GD path: step size");
   DeclareOptionRef(fGDNPathSteps=10000, "GDNSteps", "GD path: number of steps");
   DeclareOptionRef(fGDErrScale=1.1, "GDErrScale", "Stop scan when error > scale*errmin");
   DeclareOptionRef(fLinQuantile, "LinQuantile", "Quantile of linear terms (removes outliers)");
   DeclareOptionRef(fGDPathEveFrac=0.5, "GDPathEveFrac", "Fraction of events used for the path search");
   DeclareOptionRef(fGDValidEveFrac=0.5, "GDValidEveFrac", "Fraction of events used for the validation");
   // tree options
   DeclareOptionRef(fMinFracNEve=0.1, "fEventsMin", "Minimum fraction of events in a splittable node");
   DeclareOptionRef(fMaxFracNEve=0.9, "fEventsMax", "Maximum fraction of events in a splittable node");
   DeclareOptionRef(fNTrees=20, "nTrees", "Number of trees in forest.");

   DeclareOptionRef(fForestTypeS="AdaBoost", "ForestType", "Method to use for forest generation (AdaBoost or RandomForest)");
   AddPreDefVal(TString("AdaBoost"));
   AddPreDefVal(TString("Random"));
   // rule cleanup options
   DeclareOptionRef(fRuleMinDist=0.001, "RuleMinDist", "Minimum distance between rules");
   DeclareOptionRef(fMinimp=0.01, "MinImp", "Minimum rule importance accepted");
   // rule model option
   DeclareOptionRef(fModelTypeS="ModRuleLinear", "Model", "Model to be used");
   AddPreDefVal(TString("ModRule"));
   AddPreDefVal(TString("ModRuleLinear"));
   AddPreDefVal(TString("ModLinear"));
   DeclareOptionRef(fRuleFitModuleS="RFTMVA", "RuleFitModule", "Which RuleFit module to use");
   AddPreDefVal(TString("RFTMVA"));
   AddPreDefVal(TString("RFFriedman"));

   DeclareOptionRef(fRFWorkDir="./rulefit", "RFWorkDir", "Friedman's RuleFit module (RFF): working dir");
   DeclareOptionRef(fRFNrules=2000, "RFNrules", "RFF: Maximum number of rules");
   DeclareOptionRef(fRFNendnodes=4, "RFNendnodes", "RFF: Average number of end nodes");
}

////////////////////////////////////////////////////////////////////////////////
/// process the options specified by the user

void TMVA::MethodRuleFit::ProcessOptions()
{
   if (IgnoreEventsWithNegWeightsInTraining()) {
      Log() << kFATAL << "Mechanism to ignore events with negative weights in training not yet available for method: "
            << GetMethodTypeName()
            << " --> please remove \"IgnoreNegWeightsInTraining\" option from booking string."
            << Endl;
   }

   fRuleFitModuleS.ToLower();
   if      (fRuleFitModuleS == "rftmva")     fUseRuleFitJF = kFALSE;
   else if (fRuleFitModuleS == "rffriedman") fUseRuleFitJF = kTRUE;
   else                                      fUseRuleFitJF = kTRUE;

   fSepTypeS.ToLower();
   if      (fSepTypeS == "misclassificationerror") fSepType = new MisClassificationError();
   else if (fSepTypeS == "giniindex")              fSepType = new GiniIndex();
   else if (fSepTypeS == "crossentropy")           fSepType = new CrossEntropy();
   else                                            fSepType = new SdivSqrtSplusB();

   fModelTypeS.ToLower();
   if      (fModelTypeS == "modlinear") fRuleFit.SetModelLinear();
   else if (fModelTypeS == "modrule")   fRuleFit.SetModelRules();
   else                                 fRuleFit.SetModelFull();

   fPruneMethodS.ToLower();
   if      (fPruneMethodS == "expectederror")  fPruneMethod = DecisionTree::kExpectedErrorPruning;
   else if (fPruneMethodS == "costcomplexity") fPruneMethod = DecisionTree::kCostComplexityPruning;
   else                                        fPruneMethod = DecisionTree::kNoPruning;

   fForestTypeS.ToLower();
   if      (fForestTypeS == "random")   fUseBoost = kFALSE;
   else if (fForestTypeS == "adaboost") fUseBoost = kTRUE;
   else                                 fUseBoost = kTRUE;
   //
   // if creating the forest by boosting the events,
   // the full training sample is used per tree
   // -> only true for the TMVA version of RuleFit.
   if (fUseBoost && (!fUseRuleFitJF)) fTreeEveFrac = 1.0;

   // check event fraction for tree generation
   // if <0, set to automatic number
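   // the automatic choice below aims at roughly 100 events (plus a 6*sqrt(N)
   // margin) per tree, but never more than half of the training sample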
   if (fTreeEveFrac<=0) {
      Int_t nevents = Data()->GetNTrainingEvents();
      Double_t n = static_cast<Double_t>(nevents);
      fTreeEveFrac = min( 0.5, (100.0 + 6.0*sqrt(n))/n );
   }
   // verify ranges of options
   VerifyRange(Log(), "nTrees",         fNTrees, 0, 100000, 20);
   VerifyRange(Log(), "MinImp",         fMinimp, 0.0, 1.0, 0.0);
   VerifyRange(Log(), "GDTauPrec",      fGDTauPrec, 1e-5, 5e-1);
   VerifyRange(Log(), "GDTauMin",       fGDTauMin, 0.0, 1.0);
   VerifyRange(Log(), "GDTauMax",       fGDTauMax, fGDTauMin, 1.0);
   VerifyRange(Log(), "GDPathStep",     fGDPathStep, 0.0, 100.0, 0.01);
   VerifyRange(Log(), "GDErrScale",     fGDErrScale, 1.0, 100.0, 1.1);
   VerifyRange(Log(), "GDPathEveFrac",  fGDPathEveFrac, 0.01, 0.9, 0.5);
   VerifyRange(Log(), "GDValidEveFrac", fGDValidEveFrac, 0.01, 1.0-fGDPathEveFrac, 1.0-fGDPathEveFrac);
   VerifyRange(Log(), "fEventsMin",     fMinFracNEve, 0.0, 1.0);
   VerifyRange(Log(), "fEventsMax",     fMaxFracNEve, fMinFracNEve, 1.0);

   fRuleFit.GetRuleEnsemblePtr()->SetLinQuantile(fLinQuantile);
   fRuleFit.GetRuleFitParamsPtr()->SetGDTauRange(fGDTauMin,fGDTauMax);
   fRuleFit.GetRuleFitParamsPtr()->SetGDTau(fGDTau);
   fRuleFit.GetRuleFitParamsPtr()->SetGDTauPrec(fGDTauPrec);
   fRuleFit.GetRuleFitParamsPtr()->SetGDTauScan(fGDTauScan);
   fRuleFit.GetRuleFitParamsPtr()->SetGDPathStep(fGDPathStep);
   fRuleFit.GetRuleFitParamsPtr()->SetGDNPathSteps(fGDNPathSteps);
   fRuleFit.GetRuleFitParamsPtr()->SetGDErrScale(fGDErrScale);
   fRuleFit.SetImportanceCut(fMinimp);
   fRuleFit.SetRuleMinDist(fRuleMinDist);


   // check if Friedman's module is used;
   // print a message concerning the options.
   if (fUseRuleFitJF) {
      Log() << kINFO << "" << Endl;
      Log() << kINFO << "--------------------------------------" << Endl;
      Log() << kINFO << "Friedman's RuleFit module is selected." << Endl;
      Log() << kINFO << "Only the following options are used:" << Endl;
      Log() << kINFO << Endl;
      Log() << kINFO << gTools().Color("bold") << "   Model"        << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   RFWorkDir"    << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   RFNrules"     << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   RFNendnodes"  << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   GDNPathSteps" << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   GDPathStep"   << gTools().Color("reset") << Endl;
      Log() << kINFO << gTools().Color("bold") << "   GDErrScale"   << gTools().Color("reset") << Endl;
      Log() << kINFO << "--------------------------------------" << Endl;
      Log() << kINFO << Endl;
   }

   // Select what weight to use in the 'importance' rule visualisation plots.
   // Note that if UseCoefficientsVisHists() is selected, the following weight is used:
   //    w = rule coefficient * rule support
   // The support is a positive number which is 0 if no events are accepted by the rule.
   // Normally the importance gives more useful information.
   //
   //fRuleFit.UseCoefficientsVisHists();
   fRuleFit.UseImportanceVisHists();

   fRuleFit.SetMsgType( Log().GetMinType() );

   if (HasTrainingTree()) InitEventSample();

}

////////////////////////////////////////////////////////////////////////////////
/// initialize the monitoring ntuple

void TMVA::MethodRuleFit::InitMonitorNtuple()
{
   BaseDir()->cd();
   fMonitorNtuple = new TTree("MonitorNtuple_RuleFit","RuleFit variables");
   fMonitorNtuple->Branch("importance",&fNTImportance,"importance/D");
   fMonitorNtuple->Branch("support",&fNTSupport,"support/D");
   fMonitorNtuple->Branch("coefficient",&fNTCoefficient,"coefficient/D");
   fMonitorNtuple->Branch("ncuts",&fNTNcuts,"ncuts/I");
   fMonitorNtuple->Branch("nvars",&fNTNvars,"nvars/I");
   fMonitorNtuple->Branch("type",&fNTType,"type/I");
   fMonitorNtuple->Branch("ptag",&fNTPtag,"ptag/D");
   fMonitorNtuple->Branch("pss",&fNTPss,"pss/D");
   fMonitorNtuple->Branch("psb",&fNTPsb,"psb/D");
   fMonitorNtuple->Branch("pbs",&fNTPbs,"pbs/D");
   fMonitorNtuple->Branch("pbb",&fNTPbb,"pbb/D");
   fMonitorNtuple->Branch("soversb",&fNTSSB,"soversb/D");
}
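
// Reading the monitoring ntuple back in a ROOT session (a sketch; the in-file
// directory depends on the dataset and method names, so the path below is only
// a plausible example, not a guaranteed location):
//
//    TFile *f = TFile::Open("TMVA.root");
//    TTree *t = nullptr;
//    f->GetObject("dataset/Method_RuleFit/RuleFit/MonitorNtuple_RuleFit", t);
//    if (t) t->Draw("importance");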

////////////////////////////////////////////////////////////////////////////////
/// default initialization

void TMVA::MethodRuleFit::Init( void )
{
   // the minimum requirement to declare an event signal-like
   SetSignalReferenceCut( 0.0 );

   // set variables that used to be options;
   // any modifications are then made in ProcessOptions()
   fLinQuantile   = 0.025;       // Quantile of linear terms (remove outliers)
   fTreeEveFrac   = -1.0;        // Fraction of events used to train each tree
   fNCuts         = 20;          // Number of steps during node cut optimisation
   fSepTypeS      = "GiniIndex"; // Separation criterion for node splitting; see BDT
   fPruneMethodS  = "NONE";      // Pruning method; see BDT
   fPruneStrength = 3.5;         // Pruning strength; see BDT
   fGDTauMin      = 0.0;         // Gradient-directed path: min fit threshold (tau)
   fGDTauMax      = 1.0;         // Gradient-directed path: max fit threshold (tau)
   fGDTauScan     = 1000;        // Gradient-directed path: number of points scanning for best tau

}

////////////////////////////////////////////////////////////////////////////////
/// write all Events from the Tree into a vector of Events that are
/// more easily manipulated.
/// This method should never be called without an existing training tree, as it
/// fills the vector of events from the ROOT training tree.

void TMVA::MethodRuleFit::InitEventSample( void )
{
   if (Data()->GetNEvents()==0) Log() << kFATAL << "<Init> Data().TrainingTree() is zero pointer" << Endl;

   Int_t nevents = Data()->GetNEvents();
   for (Int_t ievt=0; ievt<nevents; ievt++){
      const Event * ev = GetEvent(ievt);
      fEventSample.push_back( new Event(*ev) );
   }
   if (fTreeEveFrac<=0) {
      Double_t n = static_cast<Double_t>(nevents);
      fTreeEveFrac = min( 0.5, (100.0 + 6.0*sqrt(n))/n );
   }
   if (fTreeEveFrac>1.0) fTreeEveFrac=1.0;
   //
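   // note: a default-constructed std::default_random_engine starts from a fixed
   // default seed, so this shuffle is deterministic between runs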
   std::shuffle(fEventSample.begin(), fEventSample.end(), std::default_random_engine{});
   //
   Log() << kDEBUG << "Set sub-sample fraction to " << fTreeEveFrac << Endl;
}

////////////////////////////////////////////////////////////////////////////////

void TMVA::MethodRuleFit::Train( void )
{
   TMVA::DecisionTree::SetIsTraining(true);
   // training of rules

   if(!IsSilentFile()) InitMonitorNtuple();

   // fill the STL Vector with the event sample
   this->InitEventSample();

   if (fUseRuleFitJF) {
      TrainJFRuleFit();
   }
   else {
      TrainTMVARuleFit();
   }
   fRuleFit.GetRuleEnsemblePtr()->ClearRuleMap();
   TMVA::DecisionTree::SetIsTraining(false);
   ExitFromTraining();
}

////////////////////////////////////////////////////////////////////////////////
/// training of rules using TMVA implementation

void TMVA::MethodRuleFit::TrainTMVARuleFit()
{
   if (IsNormalised()) Log() << kFATAL << "\"Normalise\" option cannot be used with RuleFit; "
                             << "please remove the option from the configuration string, or "
                             << "use \"!Normalise\""
                             << Endl;

   // timer
   Timer timer( 1, GetName() );

   // test tree nmin cut -> for debug purposes
   // the routine will generate trees with stopping cut on N(eve) given by
   // a fraction between [20,N(eve)-1].
   //
   //   MakeForestRnd();
   //   exit(1);
   //

   // Init RuleFit object and create rule ensemble
   // + make forest & rules
   fRuleFit.Initialize( this );

   // Make forest of decision trees
   // if (fRuleFit.GetRuleEnsemble().DoRules()) fRuleFit.MakeForest();

   // Fit the rules
   Log() << kDEBUG << "Fitting rule coefficients ..." << Endl;
   fRuleFit.FitCoefficients();

   // Calculate importance
   Log() << kDEBUG << "Computing rule and variable importance" << Endl;
   fRuleFit.CalcImportance();

   // Output results and fill monitor ntuple
   fRuleFit.GetRuleEnsemblePtr()->Print();
   //
   if(!IsSilentFile())
   {
      Log() << kDEBUG << "Filling rule ntuple" << Endl;
      UInt_t nrules = fRuleFit.GetRuleEnsemble().GetRulesConst().size();
      const Rule *rule;
      for (UInt_t i=0; i<nrules; i++ ) {
         rule           = fRuleFit.GetRuleEnsemble().GetRulesConst(i);
         fNTImportance  = rule->GetRelImportance();
         fNTSupport     = rule->GetSupport();
         fNTCoefficient = rule->GetCoefficient();
         fNTType        = (rule->IsSignalRule() ? 1:-1 );
         fNTNvars       = rule->GetRuleCut()->GetNvars();
         fNTNcuts       = rule->GetRuleCut()->GetNcuts();
         fNTPtag        = fRuleFit.GetRuleEnsemble().GetRulePTag(i); // should be identical with support
         fNTPss         = fRuleFit.GetRuleEnsemble().GetRulePSS(i);
         fNTPsb         = fRuleFit.GetRuleEnsemble().GetRulePSB(i);
         fNTPbs         = fRuleFit.GetRuleEnsemble().GetRulePBS(i);
         fNTPbb         = fRuleFit.GetRuleEnsemble().GetRulePBB(i);
         fNTSSB         = rule->GetSSB();
         fMonitorNtuple->Fill();
      }

      fRuleFit.MakeVisHists();
      fRuleFit.MakeDebugHists();
   }
   Log() << kDEBUG << "Training done" << Endl;

}

////////////////////////////////////////////////////////////////////////////////
/// training of rules using Jerome Friedman's implementation

void TMVA::MethodRuleFit::TrainJFRuleFit()
{
   fRuleFit.InitPtrs( this );
   Data()->SetCurrentType(Types::kTraining);
   UInt_t nevents = Data()->GetNTrainingEvents();
   std::vector<const TMVA::Event*> tmp;
   for (Long64_t ievt=0; ievt<nevents; ievt++) {
      const Event *event = GetEvent(ievt);
      tmp.push_back(event);
   }
   fRuleFit.SetTrainingEvents( tmp );

   RuleFitAPI *rfAPI = new RuleFitAPI( this, &fRuleFit, Log().GetMinType() );

   rfAPI->WelcomeMessage();

   // timer
   Timer timer( 1, GetName() );

   Log() << kINFO << "Training ..." << Endl;
   rfAPI->TrainRuleFit();

   Log() << kDEBUG << "reading model summary from rf_go.exe output" << Endl;
   rfAPI->ReadModelSum();

   // fRuleFit.GetRuleEnsemblePtr()->MakeRuleMap();

   Log() << kDEBUG << "calculating rule and variable importance" << Endl;
   fRuleFit.CalcImportance();

   // Output results and fill monitor ntuple
   fRuleFit.GetRuleEnsemblePtr()->Print();
   //
   if(!IsSilentFile()) fRuleFit.MakeVisHists();

   delete rfAPI;

   Log() << kDEBUG << "done training" << Endl;
}

////////////////////////////////////////////////////////////////////////////////
/// computes ranking of input variables

const TMVA::Ranking* TMVA::MethodRuleFit::CreateRanking()
{
   // create the ranking object
   fRanking = new Ranking( GetName(), "Importance" );

   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
      fRanking->AddRank( Rank( GetInputLabel(ivar), fRuleFit.GetRuleEnsemble().GetVarImportance(ivar) ) );
   }

   return fRanking;
}

////////////////////////////////////////////////////////////////////////////////
/// add the rules to XML node

void TMVA::MethodRuleFit::AddWeightsXMLTo( void* parent ) const
{
   fRuleFit.GetRuleEnsemble().AddXMLTo( parent );
}

////////////////////////////////////////////////////////////////////////////////
/// read rules from an std::istream

void TMVA::MethodRuleFit::ReadWeightsFromStream( std::istream& istr )
{
   fRuleFit.GetRuleEnsemblePtr()->ReadRaw( istr );
}

////////////////////////////////////////////////////////////////////////////////
/// read rules from XML node

void TMVA::MethodRuleFit::ReadWeightsFromXML( void* wghtnode )
{
   fRuleFit.GetRuleEnsemblePtr()->ReadFromXML( wghtnode );
}

////////////////////////////////////////////////////////////////////////////////
/// returns MVA value for given event
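///
/// The returned score has the generic RuleFit form (a sketch of the model; see
/// MakeClassRuleCuts() and MakeClassLinear() for the standalone-code version):
///
///     F(x) = offset + sum_r c_r * r(x) + sum_l b_l * l(x_l)
///
/// i.e. an offset plus the weighted rule responses and, optionally, the
/// clipped linear terms.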

Double_t TMVA::MethodRuleFit::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   // cannot determine error
   NoErrorCalc(err, errUpper);

   return fRuleFit.EvalEvent( *GetEvent() );
}

////////////////////////////////////////////////////////////////////////////////
/// write special monitoring histograms to file (here ntuple)

void TMVA::MethodRuleFit::WriteMonitoringHistosToFile( void ) const
{
   BaseDir()->cd();
   Log() << kINFO << "Write monitoring ntuple to file: " << BaseDir()->GetPath() << Endl;
   fMonitorNtuple->Write();
}

////////////////////////////////////////////////////////////////////////////////
/// write specific classifier response

void TMVA::MethodRuleFit::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   Int_t dp = fout.precision();
   fout << "   // not implemented for class: \"" << className << "\"" << std::endl;
   fout << "};" << std::endl;
   fout << "void   " << className << "::Initialize(){}" << std::endl;
   fout << "void   " << className << "::Clear(){}" << std::endl;
   fout << "double " << className << "::GetMvaValue__( const std::vector<double>& inputValues ) const {" << std::endl;
   fout << "   double rval=" << std::setprecision(10) << fRuleFit.GetRuleEnsemble().GetOffset() << ";" << std::endl;
   MakeClassRuleCuts(fout);
   MakeClassLinear(fout);
   fout << "   return rval;" << std::endl;
   fout << "}" << std::endl;
   fout << std::setprecision(dp);
}

////////////////////////////////////////////////////////////////////////////////
/// print out the rule cuts
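///
/// Each rule is emitted as a single conditional of the (illustrative) shape
///
///     if ((0.1234<inputValues[0])&&(inputValues[0]<5.678)) rval+=0.25;   // importance = 0.420
///
/// i.e. a conjunction of one- and/or two-sided cuts, followed by the rule coefficient.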

void TMVA::MethodRuleFit::MakeClassRuleCuts( std::ostream& fout ) const
{
   Int_t dp = fout.precision();
   if (!fRuleFit.GetRuleEnsemble().DoRules()) {
      fout << "   //" << std::endl;
      fout << "   // ==> MODEL CONTAINS NO RULES <==" << std::endl;
      fout << "   //" << std::endl;
      return;
   }
   const RuleEnsemble *rens = &(fRuleFit.GetRuleEnsemble());
   const std::vector< Rule* > *rules = &(rens->GetRulesConst());
   const RuleCut *ruleCut;
   //
   std::list< std::pair<Double_t,Int_t> > sortedRules;
   for (UInt_t ir=0; ir<rules->size(); ir++) {
      sortedRules.push_back( std::pair<Double_t,Int_t>( (*rules)[ir]->GetImportance()/rens->GetImportanceRef(), ir ) );
   }
   sortedRules.sort();
   //
   fout << "   //" << std::endl;
   fout << "   // here follows all rules ordered in importance (most important first)" << std::endl;
   fout << "   // at the end of each line, the relative importance of the rule is given" << std::endl;
   fout << "   //" << std::endl;
   //
   for ( std::list< std::pair<double,int> >::reverse_iterator itpair = sortedRules.rbegin();
         itpair != sortedRules.rend(); ++itpair ) {
      UInt_t ir     = itpair->second;
      Double_t impr = itpair->first;
      ruleCut = (*rules)[ir]->GetRuleCut();
      if (impr<rens->GetImportanceCut()) fout << "   //" << std::endl;
      fout << "   if (" << std::flush;
      for (UInt_t ic=0; ic<ruleCut->GetNvars(); ic++) {
         Double_t sel    = ruleCut->GetSelector(ic);
         Double_t valmin = ruleCut->GetCutMin(ic);
         Double_t valmax = ruleCut->GetCutMax(ic);
         Bool_t   domin  = ruleCut->GetCutDoMin(ic);
         Bool_t   domax  = ruleCut->GetCutDoMax(ic);
         //
         if (ic>0) fout << "&&" << std::flush;
         if (domin) {
            fout << "(" << std::setprecision(10) << valmin << std::flush;
            fout << "<inputValues[" << sel << "])" << std::flush;
         }
         if (domax) {
            if (domin) fout << "&&" << std::flush;
            fout << "(inputValues[" << sel << "]" << std::flush;
            fout << "<" << std::setprecision(10) << valmax << ")" << std::flush;
         }
      }
      fout << ") rval+=" << std::setprecision(10) << (*rules)[ir]->GetCoefficient() << ";" << std::flush;
      fout << "   // importance = " << TString::Format("%3.3f",impr) << std::endl;
   }
   fout << std::setprecision(dp);
}

////////////////////////////////////////////////////////////////////////////////
/// print out the linear terms
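///
/// Each surviving linear term contributes
///
///     rval += c_l * norm_l * min( dp_l, max( x_l, dm_l ) )
///
/// i.e. the input x_l is clipped to its quantile range [dm_l, dp_l]
/// (cf. LinQuantile) before being scaled by the normalisation and the
/// fitted coefficient.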

void TMVA::MethodRuleFit::MakeClassLinear( std::ostream& fout ) const
{
   if (!fRuleFit.GetRuleEnsemble().DoLinear()) {
      fout << "   //" << std::endl;
      fout << "   // ==> MODEL CONTAINS NO LINEAR TERMS <==" << std::endl;
      fout << "   //" << std::endl;
      return;
   }
   fout << "   //" << std::endl;
   fout << "   // here follows all linear terms" << std::endl;
   fout << "   // at the end of each line, the relative importance of the term is given" << std::endl;
   fout << "   //" << std::endl;
   const RuleEnsemble *rens = &(fRuleFit.GetRuleEnsemble());
   UInt_t nlin = rens->GetNLinear();
   for (UInt_t il=0; il<nlin; il++) {
      if (rens->IsLinTermOK(il)) {
         Double_t norm = rens->GetLinNorm(il);
         Double_t imp  = rens->GetLinImportance(il)/rens->GetImportanceRef();
         fout << "   rval+="
              // << std::setprecision(10) << rens->GetLinCoefficients(il)*norm << "*std::min(" << std::setprecision(10) << rens->GetLinDP(il)
              // << ", std::max( inputValues[" << il << "]," << std::setprecision(10) << rens->GetLinDM(il) << "));"
              << std::setprecision(10) << rens->GetLinCoefficients(il)*norm
              << "*std::min( double(" << std::setprecision(10) << rens->GetLinDP(il)
              << "), std::max( double(inputValues[" << il << "]), double(" << std::setprecision(10) << rens->GetLinDM(il) << ")));"
              << std::flush;
         fout << "   // importance = " << TString::Format("%3.3f",imp) << std::endl;
      }
   }
}

////////////////////////////////////////////////////////////////////////////////
/// get help message text
///
/// typical length of text line:
///   "|--------------------------------------------------------------|"

void TMVA::MethodRuleFit::GetHelpMessage() const
{
   TString col    = gConfig().WriteOptionsReference() ? TString() : gTools().Color("bold");
   TString colres = gConfig().WriteOptionsReference() ? TString() : gTools().Color("reset");
   TString brk    = gConfig().WriteOptionsReference() ? "<br>" : "";

   Log() << Endl;
   Log() << col << "--- Short description:" << colres << Endl;
   Log() << Endl;
   Log() << "This method uses a collection of so called rules to create a" << Endl;
   Log() << "discriminating scoring function. Each rule consists of a series" << Endl;
   Log() << "of cuts in parameter space. The ensemble of rules is created" << Endl;
   Log() << "from a forest of decision trees, trained using the training data." << Endl;
   Log() << "Each node (apart from the root) corresponds to one rule." << Endl;
   Log() << "The scoring function is then obtained by linearly combining" << Endl;
   Log() << "the rules. A fitting procedure is applied to find the optimum" << Endl;
   Log() << "set of coefficients. The goal is to find a model with few rules" << Endl;
   Log() << "but with strong discriminating power." << Endl;
   Log() << Endl;
   Log() << col << "--- Performance optimisation:" << colres << Endl;
   Log() << Endl;
   Log() << "There are two important considerations to make when optimising:" << Endl;
   Log() << Endl;
   Log() << "  1. Topology of the decision tree forest" << brk << Endl;
   Log() << "  2. Fitting of the coefficients" << Endl;
   Log() << Endl;
   Log() << "The maximum complexity of the rules is defined by the size of" << Endl;
   Log() << "the trees. Large trees will yield many complex rules and capture" << Endl;
   Log() << "higher order correlations. On the other hand, small trees will" << Endl;
   Log() << "lead to a smaller ensemble with simple rules, only capable of" << Endl;
   Log() << "modeling simple structures." << Endl;
   Log() << "Several parameters exist for controlling the complexity of the" << Endl;
   Log() << "rule ensemble." << Endl;
   Log() << Endl;
   Log() << "The fitting procedure searches for a minimum using a gradient-" << Endl;
   Log() << "directed path. Apart from step size and number of steps, the" << Endl;
   Log() << "evolution of the path is defined by a cut-off parameter, tau." << Endl;
   Log() << "This parameter is unknown and depends on the training data." << Endl;
   Log() << "A large value will tend to give large weights to a few rules." << Endl;
   Log() << "Similarly, a small value will lead to a large set of rules" << Endl;
   Log() << "with similar weights." << Endl;
   Log() << Endl;
   Log() << "A final point is the model used: rules and/or linear terms." << Endl;
   Log() << "For a given training sample, the result may improve by adding" << Endl;
   Log() << "linear terms. If best performance is obtained using only linear" << Endl;
   Log() << "terms, it is very likely that the Fisher discriminant would be" << Endl;
   Log() << "a better choice. Ideally the fitting procedure should be able to" << Endl;
   Log() << "make this choice by giving appropriate weights to either kind" << Endl;
   Log() << "of term." << Endl;
   Log() << Endl;
   Log() << col << "--- Performance tuning via configuration options:" << colres << Endl;
   Log() << Endl;
   Log() << "I. TUNING OF RULE ENSEMBLE:" << Endl;
   Log() << Endl;
   Log() << "   " << col << "ForestType  " << colres
         << ": it is recommended to use the default \"AdaBoost\"." << brk << Endl;
   Log() << "   " << col << "nTrees      " << colres
         << ": more trees lead to more rules, but also to slower" << Endl;
   Log() << "                 performance. With too few trees the risk is" << Endl;
   Log() << "                 that the rule ensemble becomes too simple." << brk << Endl;
   Log() << "   " << col << "fEventsMin  " << colres << brk << Endl;
   Log() << "   " << col << "fEventsMax  " << colres
         << ": with a lower min, more large trees will be generated," << Endl;
   Log() << "                 leading to more complex rules." << Endl;
   Log() << "                 With a higher max, more small trees will be" << Endl;
   Log() << "                 generated, leading to simpler rules." << Endl;
   Log() << "                 By changing this range, the average complexity" << Endl;
   Log() << "                 of the rule ensemble can be controlled." << brk << Endl;
   Log() << "   " << col << "RuleMinDist " << colres
         << ": by increasing the minimum distance between" << Endl;
   Log() << "                 rules, fewer and more diverse rules will remain." << Endl;
   Log() << "                 Initially it is a good idea to keep this small" << Endl;
   Log() << "                 or zero and let the fitting do the selection of" << Endl;
   Log() << "                 rules. In order to reduce the ensemble size," << Endl;
   Log() << "                 the value can then be increased." << Endl;
   Log() << Endl;
   //         "|--------------------------------------------------------------|"
   Log() << "II. TUNING OF THE FITTING:" << Endl;
   Log() << Endl;
   Log() << "   " << col << "GDPathEveFrac " << colres
         << ": fraction of events used in the path evaluation." << Endl;
   Log() << "                 Increasing this fraction will improve the path" << Endl;
   Log() << "                 finding. However, too high a value will leave few" << Endl;
   Log() << "                 unique events available for error estimation." << Endl;
   Log() << "                 It is recommended to use the default = 0.5." << brk << Endl;
   Log() << "   " << col << "GDTau         " << colres
         << ": cut-off parameter tau." << Endl;
   Log() << "                 By default this value is set to -1.0." << Endl;
   //         "|----------------|---------------------------------------------|"
   Log() << "                 This means that the cut-off parameter is" << Endl;
   Log() << "                 automatically estimated. In most cases" << Endl;
   Log() << "                 this should be fine. However, you may want" << Endl;
   Log() << "                 to fix this value if you already know it" << Endl;
   Log() << "                 and want to reduce training time." << brk << Endl;
   Log() << "   " << col << "GDTauPrec     " << colres
         << ": precision of the estimated tau." << Endl;
   Log() << "                 Increase this precision to find a more" << Endl;
   Log() << "                 optimal cut-off parameter." << brk << Endl;
   Log() << "   " << col << "GDNSteps      " << colres
         << ": number of steps in the path search." << Endl;
   Log() << "                 If the number of steps is too small, the" << Endl;
   Log() << "                 program will give a warning message." << Endl;
   Log() << Endl;
   Log() << "III. WARNING MESSAGES" << Endl;
   Log() << Endl;
   Log() << col << "Risk(i+1)>=Risk(i) in path" << colres << brk << Endl;
   Log() << col << "Chaotic behaviour of risk evolution." << colres << Endl;
   //         "|----------------|---------------------------------------------|"
   Log() << "                 By construction the Risk should always decrease." << Endl;
   Log() << "                 However, if the training sample is too small or" << Endl;
   Log() << "                 the model is overtrained, such warnings can" << Endl;
   Log() << "                 occur." << Endl;
   Log() << "                 The warnings can safely be ignored if only a" << Endl;
   Log() << "                 few (<3) occur. If more warnings are generated," << Endl;
   Log() << "                 the fitting fails." << Endl;
   Log() << "                 A remedy may be to increase the value" << brk << Endl;
   Log() << "                 "
         << col << "GDValidEveFrac" << colres
         << " to 1.0 (or a larger value)." << brk << Endl;
   Log() << "                 In addition, if "
         << col << "GDPathEveFrac" << colres
         << " is too high," << Endl;
   Log() << "                 the same warnings may occur, since the events" << Endl;
   Log() << "                 used for error estimation are also used for" << Endl;
   Log() << "                 path estimation." << Endl;
   Log() << "                 Another possibility is to modify the model -" << Endl;
   Log() << "                 see above on tuning the rule ensemble." << Endl;
   Log() << Endl;
   Log() << col << "The error rate was still decreasing at the end of the path"
         << colres << Endl;
   Log() << "                 Too few steps in path! Increase "
         << col << "GDNSteps" << colres << "." << Endl;
   Log() << Endl;
   Log() << col << "Reached minimum early in the search" << colres << Endl;

   Log() << "                 The minimum was found early in the fitting. This" << Endl;
   Log() << "                 may indicate that the chosen step size "
         << col << "GDStep" << colres << Endl;
   Log() << "                 was too large. Reduce it and rerun." << Endl;
   Log() << "                 If the results are still not OK, modify the" << Endl;
   Log() << "                 model, either by changing the rule ensemble" << Endl;
   Log() << "                 or by adding/removing linear terms." << Endl;
}