ROOT  6.06/09
Reference Guide
Rule.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : Rule *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * A class describing a 'rule' *
12  * Each internal node of a tree defines a rule from all the parental nodes. *
13  * A rule consists of at least 2 nodes. *
14  * Input: a decision tree (in the constructor) *
15  * *
16  * Authors (alphabetical): *
17  * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
18  * Helge Voss <Helge.Voss@cern.ch> - MPI-KP Heidelberg, Ger. *
19  * *
20  * Copyright (c) 2005: *
21  * CERN, Switzerland *
22  * Iowa State U. *
23  * MPI-K Heidelberg, Germany *
24  * *
25  * Redistribution and use in source and binary forms, with or without *
26  * modification, are permitted according to the terms listed in LICENSE *
27  * (http://tmva.sourceforge.net/LICENSE) *
28  **********************************************************************************/
29 
30 //________________________________________________________________________________
31 //
32 // Implementation of a rule
33 //
34 // A rule is simply a branch or a part of a branch in a tree.
35 // It fulfills the following:
36 // * First node is the root node of the originating tree
37 // * Consists of a minimum of 2 nodes
38 // * A rule returns for a given event:
39 // 0 : if the event fails at any node
40 // 1 : otherwise
41 // * If the rule contains <2 nodes, it returns 0 SHOULD NOT HAPPEN!
42 //
43 // The coefficient is found by either brute force or some sort of
44 // intelligent fitting. See the RuleEnsemble class for more info.
45 //________________________________________________________________________________
46 
47 #include "TMVA/Event.h"
48 #include "TMVA/RuleCut.h"
49 #include "TMVA/Rule.h"
50 #include "TMVA/RuleFit.h"
51 #include "TMVA/RuleEnsemble.h"
52 #include "TMVA/MethodRuleFit.h"
53 #include "TMVA/Tools.h"
54 
55 ////////////////////////////////////////////////////////////////////////////////
56 /// the main constructor for a Rule
57 
59  const std::vector< const Node * >& nodes )
60  : fCut ( 0 )
61  , fNorm ( 1.0 )
62  , fSupport ( 0.0 )
63  , fSigma ( 0.0 )
64  , fCoefficient ( 0.0 )
65  , fImportance ( 0.0 )
66  , fImportanceRef ( 1.0 )
67  , fRuleEnsemble ( re )
68  , fSSB ( 0 )
69  , fSSBNeve ( 0 )
70  , fLogger( new MsgLogger("RuleFit") )
71 {
72  //
73  // input:
74  // nodes - a vector of Node; from these all possible rules will be created
75  //
76  //
77 
78  fCut = new RuleCut( nodes );
79  fSSB = fCut->GetPurity();
81 }
82 
83 ////////////////////////////////////////////////////////////////////////////////
84 /// the simple constructor
85 
87  : fCut ( 0 )
88  , fNorm ( 1.0 )
89  , fSupport ( 0.0 )
90  , fSigma ( 0.0 )
91  , fCoefficient ( 0.0 )
92  , fImportance ( 0.0 )
93  , fImportanceRef ( 1.0 )
94  , fRuleEnsemble ( re )
95  , fSSB ( 0 )
96  , fSSBNeve ( 0 )
97  , fLogger( new MsgLogger("RuleFit") )
98 {
99 }
100 
101 ////////////////////////////////////////////////////////////////////////////////
102 /// the simple constructor
103 
105  : fCut ( 0 )
106  , fNorm ( 1.0 )
107  , fSupport ( 0.0 )
108  , fSigma ( 0.0 )
109  , fCoefficient ( 0.0 )
110  , fImportance ( 0.0 )
111  , fImportanceRef ( 1.0 )
112  , fRuleEnsemble ( 0 )
113  , fSSB ( 0 )
114  , fSSBNeve ( 0 )
115  , fLogger( new MsgLogger("RuleFit") )
116 {
117 }
118 
119 ////////////////////////////////////////////////////////////////////////////////
120 /// destructor
121 
123 {
124  delete fCut;
125  delete fLogger;
126 }
127 
128 ////////////////////////////////////////////////////////////////////////////////
129 /// check if variable in node
130 
132 {
133  Bool_t found = kFALSE;
134  Bool_t doneLoop = kFALSE;
135  UInt_t nvars = fCut->GetNvars();
136  UInt_t i = 0;
137  //
138  while (!doneLoop) {
139  found = (fCut->GetSelector(i) == iv);
140  i++;
141  doneLoop = (found || (i==nvars));
142  }
143  return found;
144 }
145 
146 ////////////////////////////////////////////////////////////////////////////////
147 
149 {
150  fLogger->SetMinType(t);
151 }
152 
153 
154 ////////////////////////////////////////////////////////////////////////////////
155 ///
156 /// Compare two rules.
157 /// useCutValue: true -> calculate a distance between the two rules based on the cut values
158 /// if the rule cuts are not equal, the distance is < 0 (-1.0)
159 /// return true if d<mindist
160 /// false-> ignore mindist, return true if rules are equal, ignoring cut values
161 /// mindist: min distance allowed between rules; if < 0 => set useCutValue=false;
162 ///
163 
164 Bool_t TMVA::Rule::Equal( const Rule& other, Bool_t useCutValue, Double_t mindist ) const
165 {
166  Bool_t rval=kFALSE;
167  if (mindist<0) useCutValue=kFALSE;
168  Double_t d = RuleDist( other, useCutValue );
169  // cut value used - return true if 0<=d<mindist
170  if (useCutValue) rval = ( (!(d<0)) && (d<mindist) );
171  else rval = (!(d<0));
172  // cut value not used, return true if <> -1
173  return rval;
174 }
175 
176 ////////////////////////////////////////////////////////////////////////////////
177 /// Returns:
178 /// -1.0 : rules are NOT equal, i.e, variables and/or cut directions are wrong
179 /// >=0: rules are equal apart from the cut values; returns d = sqrt(sum(((c1-c2)/rms)^2)), i.e. the cut-value distance in units of the variable RMS
180 /// If not useCutValue, the distance is exactly zero if they are equal
181 ///
182 
183 Double_t TMVA::Rule::RuleDist( const Rule& other, Bool_t useCutValue ) const
184 {
185  if (fCut->GetNvars()!=other.GetRuleCut()->GetNvars()) return -1.0; // check number of cuts
186  //
187  const UInt_t nvars = fCut->GetNvars();
188  //
189  Int_t sel; // cut variable
190  Double_t rms; // rms of cut variable
191  Double_t smin; // distance between the lower range
192  Double_t smax; // distance between the upper range
193  Double_t vminA,vmaxA; // min,max range of cut A (cut from this Rule)
194  Double_t vminB,vmaxB; // idem from other Rule
195  //
196  // compare nodes
197  // A 'distance' is assigned if the two rules has exactly the same set of cuts but with
198  // different cut values.
199  // The distance is given in number of sigmas
200  //
201  UInt_t in = 0; // cut index
202  Double_t sumdc2 = 0; // sum of 'distances'
203  Bool_t equal = true; // flag if cut are equal
204  //
205  const RuleCut *otherCut = other.GetRuleCut();
206  while ((equal) && (in<nvars)) {
207  // check equality in cut topology
208  equal = ( (fCut->GetSelector(in) == (otherCut->GetSelector(in))) &&
209  (fCut->GetCutDoMin(in) == (otherCut->GetCutDoMin(in))) &&
210  (fCut->GetCutDoMax(in) == (otherCut->GetCutDoMax(in))) );
211  // if equal topology, check cut values
212  if (equal) {
213  if (useCutValue) {
214  sel = fCut->GetSelector(in);
215  vminA = fCut->GetCutMin(in);
216  vmaxA = fCut->GetCutMax(in);
217  vminB = other.GetRuleCut()->GetCutMin(in);
218  vmaxB = other.GetRuleCut()->GetCutMax(in);
219  // messy - but ok...
220  rms = fRuleEnsemble->GetRuleFit()->GetMethodBase()->GetRMS(sel);
221  smin=0;
222  smax=0;
223  if (fCut->GetCutDoMin(in))
224  smin = ( rms>0 ? (vminA-vminB)/rms : 0 );
225  if (fCut->GetCutDoMax(in))
226  smax = ( rms>0 ? (vmaxA-vmaxB)/rms : 0 );
227  sumdc2 += smin*smin + smax*smax;
228  // sumw += 1.0/(rms*rms); // TODO: probably not needed
229  }
230  }
231  in++;
232  }
233  if (!useCutValue) sumdc2 = (equal ? 0.0:-1.0); // ignore cut values
234  else sumdc2 = (equal ? sqrt(sumdc2) : -1.0);
235 
236  return sumdc2;
237 }
238 
239 ////////////////////////////////////////////////////////////////////////////////
240 /// comparison operator ==
241 
242 Bool_t TMVA::Rule::operator==( const Rule& other ) const
243 {
244  return this->Equal( other, kTRUE, 1e-3 );
245 }
246 
247 ////////////////////////////////////////////////////////////////////////////////
248 /// comparison operator <
249 
250 Bool_t TMVA::Rule::operator<( const Rule& other ) const
251 {
252  return (fImportance < other.GetImportance());
253 }
254 
255 ////////////////////////////////////////////////////////////////////////////////
256 /// std::ostream operator
257 
258 std::ostream& TMVA::operator<< ( std::ostream& os, const Rule& rule )
259 {
260  rule.Print( os );
261  return os;
262 }
263 
264 ////////////////////////////////////////////////////////////////////////////////
265 /// returns the name (input label) of input variable i
266 
268 {
269  return fRuleEnsemble->GetMethodBase()->GetInputLabel(i);
270 }
271 
272 ////////////////////////////////////////////////////////////////////////////////
273 /// copy function
274 
275 void TMVA::Rule::Copy( const Rule& other )
276 {
277  if(this != &other) {
278  SetRuleEnsemble( other.GetRuleEnsemble() );
279  fCut = new RuleCut( *(other.GetRuleCut()) );
280  fSSB = other.GetSSB();
281  fSSBNeve = other.GetSSBNeve();
282  SetCoefficient(other.GetCoefficient());
283  SetSupport( other.GetSupport() );
284  SetSigma( other.GetSigma() );
285  SetNorm( other.GetNorm() );
286  CalcImportance();
287  SetImportanceRef( other.GetImportanceRef() );
288  }
289 }
290 
291 ////////////////////////////////////////////////////////////////////////////////
292 /// print function
293 
294 void TMVA::Rule::Print( std::ostream& os ) const
295 {
296  const UInt_t nvars = fCut->GetNvars();
297  if (nvars<1) os << " *** WARNING - <EMPTY RULE> ***" << std::endl; // TODO: Fix this, use fLogger
298  //
299  Int_t sel;
300  Double_t valmin, valmax;
301  //
302  os << " Importance = " << Form("%1.4f", fImportance/fImportanceRef) << std::endl;
303  os << " Coefficient = " << Form("%1.4f", fCoefficient) << std::endl;
304  os << " Support = " << Form("%1.4f", fSupport) << std::endl;
305  os << " S/(S+B) = " << Form("%1.4f", fSSB) << std::endl;
306 
307  for ( UInt_t i=0; i<nvars; i++) {
308  os << " ";
309  sel = fCut->GetSelector(i);
310  valmin = fCut->GetCutMin(i);
311  valmax = fCut->GetCutMax(i);
312  //
313  os << Form("* Cut %2d",i+1) << " : " << std::flush;
314  if (fCut->GetCutDoMin(i)) os << Form("%10.3g",valmin) << " < " << std::flush;
315  else os << " " << std::flush;
316  os << GetVarName(sel) << std::flush;
317  if (fCut->GetCutDoMax(i)) os << " < " << Form("%10.3g",valmax) << std::flush;
318  else os << " " << std::flush;
319  os << std::endl;
320  }
321 }
322 
323 ////////////////////////////////////////////////////////////////////////////////
324 /// print function
325 
326 void TMVA::Rule::PrintLogger(const char *title) const
327 {
328  const UInt_t nvars = fCut->GetNvars();
329  if (nvars<1) Log() << kWARNING << "BUG TRAP: EMPTY RULE!!!" << Endl;
330  //
331  Int_t sel;
332  Double_t valmin, valmax;
333  //
334  if (title) Log() << kINFO << title;
335  Log() << kINFO
336  << "Importance = " << Form("%1.4f", fImportance/fImportanceRef) << Endl;
337 
338  for ( UInt_t i=0; i<nvars; i++) {
339 
340  Log() << kINFO << " ";
341  sel = fCut->GetSelector(i);
342  valmin = fCut->GetCutMin(i);
343  valmax = fCut->GetCutMax(i);
344  //
345  Log() << kINFO << Form("Cut %2d",i+1) << " : ";
346  if (fCut->GetCutDoMin(i)) Log() << kINFO << Form("%10.3g",valmin) << " < ";
347  else Log() << kINFO << " ";
348  Log() << kINFO << GetVarName(sel);
349  if (fCut->GetCutDoMax(i)) Log() << kINFO << " < " << Form("%10.3g",valmax);
350  else Log() << kINFO << " ";
351  Log() << Endl;
352  }
353 }
354 
355 ////////////////////////////////////////////////////////////////////////////////
356 /// extensive print function used to print info for the weight file
357 
358 void TMVA::Rule::PrintRaw( std::ostream& os ) const
359 {
360  Int_t dp = os.precision();
361  const UInt_t nvars = fCut->GetNvars();
362  os << "Parameters: "
363  << std::setprecision(10)
364  << fImportance << " "
365  << fImportanceRef << " "
366  << fCoefficient << " "
367  << fSupport << " "
368  << fSigma << " "
369  << fNorm << " "
370  << fSSB << " "
371  << fSSBNeve << " "
372  << std::endl; \
373  os << "N(cuts): " << nvars << std::endl; // mark end of nodes
374  for ( UInt_t i=0; i<nvars; i++) {
375  os << "Cut " << i << " : " << std::flush;
376  os << fCut->GetSelector(i)
377  << std::setprecision(10)
378  << " " << fCut->GetCutMin(i)
379  << " " << fCut->GetCutMax(i)
380  << " " << (fCut->GetCutDoMin(i) ? "T":"F")
381  << " " << (fCut->GetCutDoMax(i) ? "T":"F")
382  << std::endl;
383  }
384  os << std::setprecision(dp);
385 }
386 
387 ////////////////////////////////////////////////////////////////////////////////
388 
389 void* TMVA::Rule::AddXMLTo( void* parent ) const
390 {
391  void* rule = gTools().AddChild( parent, "Rule" );
392  const UInt_t nvars = fCut->GetNvars();
393 
394  gTools().AddAttr( rule, "Importance", fImportance );
395  gTools().AddAttr( rule, "Ref", fImportanceRef );
396  gTools().AddAttr( rule, "Coeff", fCoefficient );
397  gTools().AddAttr( rule, "Support", fSupport );
398  gTools().AddAttr( rule, "Sigma", fSigma );
399  gTools().AddAttr( rule, "Norm", fNorm );
400  gTools().AddAttr( rule, "SSB", fSSB );
401  gTools().AddAttr( rule, "SSBNeve", fSSBNeve );
402  gTools().AddAttr( rule, "Nvars", nvars );
403 
404  for (UInt_t i=0; i<nvars; i++) {
405  void* cut = gTools().AddChild( rule, "Cut" );
406  gTools().AddAttr( cut, "Selector", fCut->GetSelector(i) );
407  gTools().AddAttr( cut, "Min", fCut->GetCutMin(i) );
408  gTools().AddAttr( cut, "Max", fCut->GetCutMax(i) );
409  gTools().AddAttr( cut, "DoMin", (fCut->GetCutDoMin(i) ? "T":"F") );
410  gTools().AddAttr( cut, "DoMax", (fCut->GetCutDoMax(i) ? "T":"F") );
411  }
412 
413  return rule;
414 }
415 
416 ////////////////////////////////////////////////////////////////////////////////
417 /// read rule from XML
418 
419 void TMVA::Rule::ReadFromXML( void* wghtnode )
420 {
421  TString nodeName = TString( gTools().GetName(wghtnode) );
422  if (nodeName != "Rule") Log() << kFATAL << "<ReadFromXML> Unexpected node name: " << nodeName << Endl;
423 
424  gTools().ReadAttr( wghtnode, "Importance", fImportance );
425  gTools().ReadAttr( wghtnode, "Ref", fImportanceRef );
426  gTools().ReadAttr( wghtnode, "Coeff", fCoefficient );
427  gTools().ReadAttr( wghtnode, "Support", fSupport );
428  gTools().ReadAttr( wghtnode, "Sigma", fSigma );
429  gTools().ReadAttr( wghtnode, "Norm", fNorm );
430  gTools().ReadAttr( wghtnode, "SSB", fSSB );
431  gTools().ReadAttr( wghtnode, "SSBNeve", fSSBNeve );
432 
433  UInt_t nvars;
434  gTools().ReadAttr( wghtnode, "Nvars", nvars );
435  if (fCut) delete fCut;
436  fCut = new RuleCut();
437  fCut->SetNvars( nvars );
438 
439  // read Cut
440  void* ch = gTools().GetChild( wghtnode );
441  UInt_t i = 0;
442  UInt_t ui;
443  Double_t d;
444  Char_t c;
445  while (ch) {
446  gTools().ReadAttr( ch, "Selector", ui );
447  fCut->SetSelector( i, ui );
448  gTools().ReadAttr( ch, "Min", d );
449  fCut->SetCutMin ( i, d );
450  gTools().ReadAttr( ch, "Max", d );
451  fCut->SetCutMax ( i, d );
452  gTools().ReadAttr( ch, "DoMin", c );
453  fCut->SetCutDoMin( i, (c == 'T' ? kTRUE : kFALSE ) );
454  gTools().ReadAttr( ch, "DoMax", c );
455  fCut->SetCutDoMax( i, (c == 'T' ? kTRUE : kFALSE ) );
456 
457  i++;
458  ch = gTools().GetNextChild(ch);
459  }
460 
461  // sanity check
462  if (i != nvars) Log() << kFATAL << "<ReadFromXML> Mismatch in number of cuts: " << i << " != " << nvars << Endl;
463 }
464 
465 ////////////////////////////////////////////////////////////////////////////////
466 /// read function (format is the same as written by PrintRaw)
467 
468 void TMVA::Rule::ReadRaw( std::istream& istr )
469 {
470  TString dummy;
471  UInt_t nvars;
472  istr >> dummy
473  >> fImportance
474  >> fImportanceRef
475  >> fCoefficient
476  >> fSupport
477  >> fSigma
478  >> fNorm
479  >> fSSB
480  >> fSSBNeve;
481  // coverity[tainted_data_argument]
482  istr >> dummy >> nvars;
483  Double_t cutmin,cutmax;
484  UInt_t sel,idum;
485  Char_t bA, bB;
486  //
487  if (fCut) delete fCut;
488  fCut = new RuleCut();
489  fCut->SetNvars( nvars );
490  for ( UInt_t i=0; i<nvars; i++) {
491  istr >> dummy >> idum; // get 'Node' and index
492  istr >> dummy; // get ':'
493  istr >> sel >> cutmin >> cutmax >> bA >> bB;
494  fCut->SetSelector(i,sel);
495  fCut->SetCutMin(i,cutmin);
496  fCut->SetCutMax(i,cutmax);
497  fCut->SetCutDoMin(i,(bA=='T' ? kTRUE:kFALSE));
498  fCut->SetCutDoMax(i,(bB=='T' ? kTRUE:kFALSE));
499  }
500 }
Double_t GetPurity() const
Definition: RuleCut.h:79
Rule()
the simple constructor
Definition: Rule.cxx:104
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
bool equal(double d1, double d2, double stol=10000)
Basic string class.
Definition: TString.h:137
Double_t GetCutMax(Int_t is) const
Definition: RuleCut.h:75
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
RuleCut * fCut
Definition: Rule.h:178
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
Definition: Tools.h:308
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1134
Double_t RuleDist(const Rule &other, Bool_t useCutValue) const
Returns: -1.0 : rules are NOT equal, i.e, variables and/or cut directions are wrong >=0: rules are eq...
Definition: Rule.cxx:183
UInt_t GetSelector(Int_t is) const
Definition: RuleCut.h:73
Bool_t Equal(const Rule &other, Bool_t useCutValue, Double_t maxdist) const
Compare two rules.
Definition: Rule.cxx:164
double sqrt(double)
Tools & gTools()
Definition: Tools.cxx:79
Double_t GetSSBNeve() const
Definition: Rule.h:124
Double_t GetImportanceRef() const
Definition: Rule.h:152
Char_t GetCutDoMax(Int_t is) const
Definition: RuleCut.h:77
Double_t GetCutMin(Int_t is) const
Definition: RuleCut.h:74
Double_t fSSBNeve
Definition: Rule.h:187
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1158
Double_t GetSSB() const
Definition: Rule.h:123
const RuleCut * GetRuleCut() const
Definition: Rule.h:145
UInt_t GetNvars() const
Definition: RuleCut.h:72
const TString & GetVarName(Int_t i) const
returns the name (input label) of input variable i
Definition: Rule.cxx:267
Bool_t operator==(const Rule &other) const
comparison operator ==
Definition: Rule.cxx:242
Double_t GetNorm() const
Definition: Rule.h:150
Bool_t ContainsVariable(UInt_t iv) const
check if variable in node
Definition: Rule.cxx:131
void ReadFromXML(void *wghtnode)
read rule from XML
Definition: Rule.cxx:419
Double_t GetImportance() const
Definition: Rule.h:151
EMsgType
Definition: Types.h:61
void * AddXMLTo(void *parent) const
Definition: Rule.cxx:389
const RuleEnsemble * GetRuleEnsemble() const
Definition: Rule.h:146
Double_t GetSigma() const
Definition: Rule.h:149
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
void PrintRaw(std::ostream &os) const
extensive print function used to print info for the weight file
Definition: Rule.cxx:358
void ReadRaw(std::istream &os)
read function (format is the same as written by PrintRaw)
Definition: Rule.cxx:468
void ReadAttr(void *node, const char *, T &value)
Definition: Tools.h:295
void PrintLogger(const char *title=0) const
print function
Definition: Rule.cxx:326
void Copy(const Rule &other)
copy function
Definition: Rule.cxx:275
std::ostream & operator<<(std::ostream &os, const BinaryTree &tree)
print the tree recursively using the << operator
Definition: BinaryTree.cxx:155
Double_t fSSB
Definition: Rule.h:186
double Double_t
Definition: RtypesCore.h:55
void Print(std::ostream &os) const
print function
Definition: Rule.cxx:294
static RooMathCoreReg dummy
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1170
Char_t GetCutDoMin(Int_t is) const
Definition: RuleCut.h:76
Double_t GetCoefficient() const
Definition: Rule.h:147
Bool_t operator<(const Rule &other) const
comparison operator <
Definition: Rule.cxx:250
char Char_t
Definition: RtypesCore.h:29
void SetMsgType(EMsgType t)
Definition: Rule.cxx:148
Double_t GetSupport() const
Definition: Rule.h:148
Double_t GetCutNeve() const
Definition: RuleCut.h:78
const Bool_t kTRUE
Definition: Rtypes.h:91
virtual ~Rule()
destructor
Definition: Rule.cxx:122
Definition: math.cpp:60