ROOT  6.07/01
Reference Guide
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
Rule.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : Rule *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * A class describing a 'rule' *
12  * Each internal node of a tree defines a rule from all the parental nodes. *
13  * A rule consists of at least 2 nodes. *
14  * Input: a decision tree (in the constructor) *
15  * *
16  * Authors (alphabetical): *
17  * Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA *
18  * Helge Voss <Helge.Voss@cern.ch> - MPI-KP Heidelberg, Ger. *
19  * *
20  * Copyright (c) 2005: *
21  * CERN, Switzerland *
22  * Iowa State U. *
23  * MPI-K Heidelberg, Germany *
24  * *
25  * Redistribution and use in source and binary forms, with or without *
26  * modification, are permitted according to the terms listed in LICENSE *
27  * (http://tmva.sourceforge.net/LICENSE) *
28  **********************************************************************************/
29 
30 //________________________________________________________________________________
31 //
32 // Implementation of a rule
33 //
34 // A rule is simply a branch or a part of a branch in a tree.
35 // It fulfills the following:
36 // * First node is the root node of the originating tree
37 // * Consists of a minimum of 2 nodes
38 // * A rule returns for a given event:
39 // 0 : if the event fails at any node
40 // 1 : otherwise
41 // * If the rule contains <2 nodes, it returns 0 SHOULD NOT HAPPEN!
42 //
43 // The coefficient is found by either brute force or some sort of
44 // intelligent fitting. See the RuleEnsemble class for more info.
45 //________________________________________________________________________________
46 
47 #include "TMVA/Rule.h"
48 
49 #include "TMVA/Event.h"
50 #include "TMVA/MethodBase.h"
51 #include "TMVA/MethodRuleFit.h"
52 #include "TMVA/MsgLogger.h"
53 #include "TMVA/RuleCut.h"
54 #include "TMVA/RuleFit.h"
55 #include "TMVA/RuleEnsemble.h"
56 #include "TMVA/Tools.h"
57 #include "TMVA/Types.h"
58 
59 ////////////////////////////////////////////////////////////////////////////////
60 /// Main constructor: builds this rule's RuleCut from a sequence of tree nodes.
61 
63  const std::vector< const Node * >& nodes )
64  : fCut ( 0 )
65  , fNorm ( 1.0 )
66  , fSupport ( 0.0 )
67  , fSigma ( 0.0 )
68  , fCoefficient ( 0.0 )
69  , fImportance ( 0.0 )
70  , fImportanceRef ( 1.0 )
71  , fRuleEnsemble ( re )
72  , fSSB ( 0 )
73  , fSSBNeve ( 0 )
74  , fLogger( new MsgLogger("RuleFit") )
75 {
76  //
77  // input:
78  //   re    - owning rule ensemble, stored in fRuleEnsemble
79  //   nodes - tree nodes from which this rule's single RuleCut (fCut) is built;
80  //           fSSB is then initialised from the purity of that cut
81 
82  fCut = new RuleCut( nodes );
83  fSSB = fCut->GetPurity();
85 }
86 
87 ////////////////////////////////////////////////////////////////////////////////
88 /// Constructor taking only the owning rule ensemble (re); no RuleCut is created, so fCut stays null.
89 
91  : fCut ( 0 )
92  , fNorm ( 1.0 )
93  , fSupport ( 0.0 )
94  , fSigma ( 0.0 )
95  , fCoefficient ( 0.0 )
96  , fImportance ( 0.0 )
97  , fImportanceRef ( 1.0 )
98  , fRuleEnsemble ( re )
99  , fSSB ( 0 )
100  , fSSBNeve ( 0 )
101  , fLogger( new MsgLogger("RuleFit") )
102 {
103 }
104 
105 ////////////////////////////////////////////////////////////////////////////////
106 /// the simple constructor
107 
109  : fCut ( 0 )
110  , fNorm ( 1.0 )
111  , fSupport ( 0.0 )
112  , fSigma ( 0.0 )
113  , fCoefficient ( 0.0 )
114  , fImportance ( 0.0 )
115  , fImportanceRef ( 1.0 )
116  , fRuleEnsemble ( 0 )
117  , fSSB ( 0 )
118  , fSSBNeve ( 0 )
119  , fLogger( new MsgLogger("RuleFit") )
120 {
121 }
122 
123 ////////////////////////////////////////////////////////////////////////////////
124 /// destructor
125 
127 {
128  delete fCut;
129  delete fLogger;
130 }
131 
132 ////////////////////////////////////////////////////////////////////////////////
133 /// check if variable in node
134 
136 {
137  Bool_t found = kFALSE;
138  Bool_t doneLoop = kFALSE;
139  UInt_t nvars = fCut->GetNvars();
140  UInt_t i = 0;
141  //
142  while (!doneLoop) {
143  found = (fCut->GetSelector(i) == iv);
144  i++;
145  doneLoop = (found || (i==nvars));
146  }
147  return found;
148 }
149 
150 ////////////////////////////////////////////////////////////////////////////////
151 
153 {
154  fLogger->SetMinType(t);
155 }
156 
157 
158 ////////////////////////////////////////////////////////////////////////////////
159 ///
160 /// Compare two rules.
161 /// useCutValue: true -> calculate a distance between the two rules based on the cut values
162 /// if the rule cuts are not equal, the distance is < 0 (-1.0)
163 /// return true if d<mindist
164 /// false-> ignore mindist, return true if rules are equal, ignoring cut values
165 /// mindist: min distance allowed between rules; if < 0 => set useCutValue=false;
166 ///
167 
168 Bool_t TMVA::Rule::Equal( const Rule& other, Bool_t useCutValue, Double_t mindist ) const
169 {
170  Bool_t rval=kFALSE;
171  if (mindist<0) useCutValue=kFALSE;
172  Double_t d = RuleDist( other, useCutValue );
173  // cut value used - return true if 0<=d<mindist
174  if (useCutValue) rval = ( (!(d<0)) && (d<mindist) );
175  else rval = (!(d<0));
176  // cut value not used, return true if <> -1
177  return rval;
178 }
179 
180 ////////////////////////////////////////////////////////////////////////////////
181 /// Returns:
182 /// -1.0 : rules are NOT equal, i.e, variables and/or cut directions are wrong
183 /// >=0: rules are equal apart from the cutvalue, returns d = sqrt(sum(c1-c2)^2)
184 /// If not useCutValue, the distance is exactly zero if they are equal
185 ///
186 
187 Double_t TMVA::Rule::RuleDist( const Rule& other, Bool_t useCutValue ) const
188 {
189  if (fCut->GetNvars()!=other.GetRuleCut()->GetNvars()) return -1.0; // check number of cuts
190  //
191  const UInt_t nvars = fCut->GetNvars();
192  //
193  Int_t sel; // cut variable
194  Double_t rms; // rms of cut variable
195  Double_t smin; // distance between the lower range
196  Double_t smax; // distance between the upper range
197  Double_t vminA,vmaxA; // min,max range of cut A (cut from this Rule)
198  Double_t vminB,vmaxB; // idem from other Rule
199  //
200  // compare nodes
201  // A 'distance' is assigned if the two rules has exactly the same set of cuts but with
202  // different cut values.
203  // The distance is given in number of sigmas
204  //
205  UInt_t in = 0; // cut index
206  Double_t sumdc2 = 0; // sum of 'distances'
207  Bool_t equal = true; // flag if cut are equal
208  //
209  const RuleCut *otherCut = other.GetRuleCut();
210  while ((equal) && (in<nvars)) {
211  // check equality in cut topology
212  equal = ( (fCut->GetSelector(in) == (otherCut->GetSelector(in))) &&
213  (fCut->GetCutDoMin(in) == (otherCut->GetCutDoMin(in))) &&
214  (fCut->GetCutDoMax(in) == (otherCut->GetCutDoMax(in))) );
215  // if equal topology, check cut values
216  if (equal) {
217  if (useCutValue) {
218  sel = fCut->GetSelector(in);
219  vminA = fCut->GetCutMin(in);
220  vmaxA = fCut->GetCutMax(in);
221  vminB = other.GetRuleCut()->GetCutMin(in);
222  vmaxB = other.GetRuleCut()->GetCutMax(in);
223  // messy - but ok...
224  rms = fRuleEnsemble->GetRuleFit()->GetMethodBase()->GetRMS(sel);
225  smin=0;
226  smax=0;
227  if (fCut->GetCutDoMin(in))
228  smin = ( rms>0 ? (vminA-vminB)/rms : 0 );
229  if (fCut->GetCutDoMax(in))
230  smax = ( rms>0 ? (vmaxA-vmaxB)/rms : 0 );
231  sumdc2 += smin*smin + smax*smax;
232  // sumw += 1.0/(rms*rms); // TODO: probably not needed
233  }
234  }
235  in++;
236  }
237  if (!useCutValue) sumdc2 = (equal ? 0.0:-1.0); // ignore cut values
238  else sumdc2 = (equal ? sqrt(sumdc2) : -1.0);
239 
240  return sumdc2;
241 }
242 
243 ////////////////////////////////////////////////////////////////////////////////
244 /// comparison operator ==
245 
246 Bool_t TMVA::Rule::operator==( const Rule& other ) const
247 {
248  return this->Equal( other, kTRUE, 1e-3 );
249 }
250 
251 ////////////////////////////////////////////////////////////////////////////////
252 /// comparison operator <
253 
254 Bool_t TMVA::Rule::operator<( const Rule& other ) const
255 {
256  return (fImportance < other.GetImportance());
257 }
258 
259 ////////////////////////////////////////////////////////////////////////////////
260 /// std::ostream operator
261 
262 std::ostream& TMVA::operator<< ( std::ostream& os, const Rule& rule )
263 {
264  rule.Print( os );
265  return os;
266 }
267 
268 ////////////////////////////////////////////////////////////////////////////////
269 /// returns the name of a rule
270 
272 {
273  return fRuleEnsemble->GetMethodBase()->GetInputLabel(i);
274 }
275 
276 ////////////////////////////////////////////////////////////////////////////////
277 /// copy function
278 
279 void TMVA::Rule::Copy( const Rule& other )
280 {
281  if(this != &other) {
282  SetRuleEnsemble( other.GetRuleEnsemble() );
283  fCut = new RuleCut( *(other.GetRuleCut()) );
284  fSSB = other.GetSSB();
285  fSSBNeve = other.GetSSBNeve();
286  SetCoefficient(other.GetCoefficient());
287  SetSupport( other.GetSupport() );
288  SetSigma( other.GetSigma() );
289  SetNorm( other.GetNorm() );
290  CalcImportance();
291  SetImportanceRef( other.GetImportanceRef() );
292  }
293 }
294 
295 ////////////////////////////////////////////////////////////////////////////////
296 /// print function
297 
298 void TMVA::Rule::Print( std::ostream& os ) const
299 {
300  const UInt_t nvars = fCut->GetNvars();
301  if (nvars<1) os << " *** WARNING - <EMPTY RULE> ***" << std::endl; // TODO: Fix this, use fLogger
302  //
303  Int_t sel;
304  Double_t valmin, valmax;
305  //
306  os << " Importance = " << Form("%1.4f", fImportance/fImportanceRef) << std::endl;
307  os << " Coefficient = " << Form("%1.4f", fCoefficient) << std::endl;
308  os << " Support = " << Form("%1.4f", fSupport) << std::endl;
309  os << " S/(S+B) = " << Form("%1.4f", fSSB) << std::endl;
310 
311  for ( UInt_t i=0; i<nvars; i++) {
312  os << " ";
313  sel = fCut->GetSelector(i);
314  valmin = fCut->GetCutMin(i);
315  valmax = fCut->GetCutMax(i);
316  //
317  os << Form("* Cut %2d",i+1) << " : " << std::flush;
318  if (fCut->GetCutDoMin(i)) os << Form("%10.3g",valmin) << " < " << std::flush;
319  else os << " " << std::flush;
320  os << GetVarName(sel) << std::flush;
321  if (fCut->GetCutDoMax(i)) os << " < " << Form("%10.3g",valmax) << std::flush;
322  else os << " " << std::flush;
323  os << std::endl;
324  }
325 }
326 
327 ////////////////////////////////////////////////////////////////////////////////
328 /// print function
329 
330 void TMVA::Rule::PrintLogger(const char *title) const
331 {
332  const UInt_t nvars = fCut->GetNvars();
333  if (nvars<1) Log() << kWARNING << "BUG TRAP: EMPTY RULE!!!" << Endl;
334  //
335  Int_t sel;
336  Double_t valmin, valmax;
337  //
338  if (title) Log() << kINFO << title;
339  Log() << kINFO
340  << "Importance = " << Form("%1.4f", fImportance/fImportanceRef) << Endl;
341 
342  for ( UInt_t i=0; i<nvars; i++) {
343 
344  Log() << kINFO << " ";
345  sel = fCut->GetSelector(i);
346  valmin = fCut->GetCutMin(i);
347  valmax = fCut->GetCutMax(i);
348  //
349  Log() << kINFO << Form("Cut %2d",i+1) << " : ";
350  if (fCut->GetCutDoMin(i)) Log() << kINFO << Form("%10.3g",valmin) << " < ";
351  else Log() << kINFO << " ";
352  Log() << kINFO << GetVarName(sel);
353  if (fCut->GetCutDoMax(i)) Log() << kINFO << " < " << Form("%10.3g",valmax);
354  else Log() << kINFO << " ";
355  Log() << Endl;
356  }
357 }
358 
359 ////////////////////////////////////////////////////////////////////////////////
360 /// extensive print function used to print info for the weight file
361 
362 void TMVA::Rule::PrintRaw( std::ostream& os ) const
363 {
364  Int_t dp = os.precision();
365  const UInt_t nvars = fCut->GetNvars();
366  os << "Parameters: "
367  << std::setprecision(10)
368  << fImportance << " "
369  << fImportanceRef << " "
370  << fCoefficient << " "
371  << fSupport << " "
372  << fSigma << " "
373  << fNorm << " "
374  << fSSB << " "
375  << fSSBNeve << " "
376  << std::endl; \
377  os << "N(cuts): " << nvars << std::endl; // mark end of nodes
378  for ( UInt_t i=0; i<nvars; i++) {
379  os << "Cut " << i << " : " << std::flush;
380  os << fCut->GetSelector(i)
381  << std::setprecision(10)
382  << " " << fCut->GetCutMin(i)
383  << " " << fCut->GetCutMax(i)
384  << " " << (fCut->GetCutDoMin(i) ? "T":"F")
385  << " " << (fCut->GetCutDoMax(i) ? "T":"F")
386  << std::endl;
387  }
388  os << std::setprecision(dp);
389 }
390 
391 ////////////////////////////////////////////////////////////////////////////////
392 
393 void* TMVA::Rule::AddXMLTo( void* parent ) const
394 {
395  void* rule = gTools().AddChild( parent, "Rule" );
396  const UInt_t nvars = fCut->GetNvars();
397 
398  gTools().AddAttr( rule, "Importance", fImportance );
399  gTools().AddAttr( rule, "Ref", fImportanceRef );
400  gTools().AddAttr( rule, "Coeff", fCoefficient );
401  gTools().AddAttr( rule, "Support", fSupport );
402  gTools().AddAttr( rule, "Sigma", fSigma );
403  gTools().AddAttr( rule, "Norm", fNorm );
404  gTools().AddAttr( rule, "SSB", fSSB );
405  gTools().AddAttr( rule, "SSBNeve", fSSBNeve );
406  gTools().AddAttr( rule, "Nvars", nvars );
407 
408  for (UInt_t i=0; i<nvars; i++) {
409  void* cut = gTools().AddChild( rule, "Cut" );
410  gTools().AddAttr( cut, "Selector", fCut->GetSelector(i) );
411  gTools().AddAttr( cut, "Min", fCut->GetCutMin(i) );
412  gTools().AddAttr( cut, "Max", fCut->GetCutMax(i) );
413  gTools().AddAttr( cut, "DoMin", (fCut->GetCutDoMin(i) ? "T":"F") );
414  gTools().AddAttr( cut, "DoMax", (fCut->GetCutDoMax(i) ? "T":"F") );
415  }
416 
417  return rule;
418 }
419 
420 ////////////////////////////////////////////////////////////////////////////////
421 /// read rule from XML
422 
423 void TMVA::Rule::ReadFromXML( void* wghtnode )
424 {
425  TString nodeName = TString( gTools().GetName(wghtnode) );
426  if (nodeName != "Rule") Log() << kFATAL << "<ReadFromXML> Unexpected node name: " << nodeName << Endl;
427 
428  gTools().ReadAttr( wghtnode, "Importance", fImportance );
429  gTools().ReadAttr( wghtnode, "Ref", fImportanceRef );
430  gTools().ReadAttr( wghtnode, "Coeff", fCoefficient );
431  gTools().ReadAttr( wghtnode, "Support", fSupport );
432  gTools().ReadAttr( wghtnode, "Sigma", fSigma );
433  gTools().ReadAttr( wghtnode, "Norm", fNorm );
434  gTools().ReadAttr( wghtnode, "SSB", fSSB );
435  gTools().ReadAttr( wghtnode, "SSBNeve", fSSBNeve );
436 
437  UInt_t nvars;
438  gTools().ReadAttr( wghtnode, "Nvars", nvars );
439  if (fCut) delete fCut;
440  fCut = new RuleCut();
441  fCut->SetNvars( nvars );
442 
443  // read Cut
444  void* ch = gTools().GetChild( wghtnode );
445  UInt_t i = 0;
446  UInt_t ui;
447  Double_t d;
448  Char_t c;
449  while (ch) {
450  gTools().ReadAttr( ch, "Selector", ui );
451  fCut->SetSelector( i, ui );
452  gTools().ReadAttr( ch, "Min", d );
453  fCut->SetCutMin ( i, d );
454  gTools().ReadAttr( ch, "Max", d );
455  fCut->SetCutMax ( i, d );
456  gTools().ReadAttr( ch, "DoMin", c );
457  fCut->SetCutDoMin( i, (c == 'T' ? kTRUE : kFALSE ) );
458  gTools().ReadAttr( ch, "DoMax", c );
459  fCut->SetCutDoMax( i, (c == 'T' ? kTRUE : kFALSE ) );
460 
461  i++;
462  ch = gTools().GetNextChild(ch);
463  }
464 
465  // sanity check
466  if (i != nvars) Log() << kFATAL << "<ReadFromXML> Mismatch in number of cuts: " << i << " != " << nvars << Endl;
467 }
468 
469 ////////////////////////////////////////////////////////////////////////////////
470 /// read function (format is the same as written by PrintRaw)
471 
472 void TMVA::Rule::ReadRaw( std::istream& istr )
473 {
474  TString dummy;
475  UInt_t nvars;
476  istr >> dummy
477  >> fImportance
478  >> fImportanceRef
479  >> fCoefficient
480  >> fSupport
481  >> fSigma
482  >> fNorm
483  >> fSSB
484  >> fSSBNeve;
485  // coverity[tainted_data_argument]
486  istr >> dummy >> nvars;
487  Double_t cutmin,cutmax;
488  UInt_t sel,idum;
489  Char_t bA, bB;
490  //
491  if (fCut) delete fCut;
492  fCut = new RuleCut();
493  fCut->SetNvars( nvars );
494  for ( UInt_t i=0; i<nvars; i++) {
495  istr >> dummy >> idum; // get 'Node' and index
496  istr >> dummy; // get ':'
497  istr >> sel >> cutmin >> cutmax >> bA >> bB;
498  fCut->SetSelector(i,sel);
499  fCut->SetCutMin(i,cutmin);
500  fCut->SetCutMax(i,cutmax);
501  fCut->SetCutDoMin(i,(bA=='T' ? kTRUE:kFALSE));
502  fCut->SetCutDoMax(i,(bB=='T' ? kTRUE:kFALSE));
503  }
504 }
Double_t GetPurity() const
Definition: RuleCut.h:79
Rule()
the simple constructor
Definition: Rule.cxx:108
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
return c
bool equal(double d1, double d2, double stol=10000)
Basic string class.
Definition: TString.h:137
Double_t GetCutMax(Int_t is) const
Definition: RuleCut.h:75
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
const Bool_t kFALSE
Definition: Rtypes.h:92
RuleCut * fCut
Definition: Rule.h:178
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
Definition: Tools.h:308
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1134
Double_t RuleDist(const Rule &other, Bool_t useCutValue) const
Returns: -1.0 : rules are NOT equal, i.e, variables and/or cut directions are wrong >=0: rules are eq...
Definition: Rule.cxx:187
UInt_t GetSelector(Int_t is) const
Definition: RuleCut.h:73
Bool_t Equal(const Rule &other, Bool_t useCutValue, Double_t maxdist) const
Compare two rules.
Definition: Rule.cxx:168
double sqrt(double)
Tools & gTools()
Definition: Tools.cxx:79
Double_t GetSSBNeve() const
Definition: Rule.h:124
Double_t GetImportanceRef() const
Definition: Rule.h:152
Char_t GetCutDoMax(Int_t is) const
Definition: RuleCut.h:77
Double_t GetCutMin(Int_t is) const
Definition: RuleCut.h:74
int d
Definition: tornado.py:11
Double_t fSSBNeve
Definition: Rule.h:187
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1158
Double_t GetSSB() const
Definition: Rule.h:123
const RuleCut * GetRuleCut() const
Definition: Rule.h:145
UInt_t GetNvars() const
Definition: RuleCut.h:72
const TString & GetVarName(Int_t i) const
returns the name of a rule
Definition: Rule.cxx:271
Bool_t operator==(const Rule &other) const
comparison operator ==
Definition: Rule.cxx:246
TThread * t[5]
Definition: threadsh1.C:13
Double_t GetNorm() const
Definition: Rule.h:150
Bool_t ContainsVariable(UInt_t iv) const
check if variable in node
Definition: Rule.cxx:135
TPaveLabel title(3, 27.1, 15, 28.7,"ROOT Environment and Tools")
void ReadFromXML(void *wghtnode)
read rule from XML
Definition: Rule.cxx:423
Double_t GetImportance() const
Definition: Rule.h:151
EMsgType
Definition: Types.h:61
void * AddXMLTo(void *parent) const
Definition: Rule.cxx:393
const RuleEnsemble * GetRuleEnsemble() const
Definition: Rule.h:146
Double_t GetSigma() const
Definition: Rule.h:149
unsigned int UInt_t
Definition: RtypesCore.h:42
char * Form(const char *fmt,...)
void PrintRaw(std::ostream &os) const
extensive print function used to print info for the weight file
Definition: Rule.cxx:362
void ReadRaw(std::istream &os)
read function (format is the same as written by PrintRaw)
Definition: Rule.cxx:472
void ReadAttr(void *node, const char *, T &value)
Definition: Tools.h:295
void PrintLogger(const char *title=0) const
print function
Definition: Rule.cxx:330
void Copy(const Rule &other)
copy function
Definition: Rule.cxx:279
std::ostream & operator<<(std::ostream &os, const BinaryTree &tree)
print the tree recursinvely using the << operator
Definition: BinaryTree.cxx:157
Double_t fSSB
Definition: Rule.h:186
double Double_t
Definition: RtypesCore.h:55
void Print(std::ostream &os) const
print function
Definition: Rule.cxx:298
static RooMathCoreReg dummy
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1170
Char_t GetCutDoMin(Int_t is) const
Definition: RuleCut.h:76
Double_t GetCoefficient() const
Definition: Rule.h:147
Bool_t operator<(const Rule &other) const
comparison operator <
Definition: Rule.cxx:254
char Char_t
Definition: RtypesCore.h:29
void SetMsgType(EMsgType t)
Definition: Rule.cxx:152
Double_t GetSupport() const
Definition: Rule.h:148
Double_t GetCutNeve() const
Definition: RuleCut.h:78
const Bool_t kTRUE
Definition: Rtypes.h:91
virtual ~Rule()
destructor
Definition: Rule.cxx:126
Definition: math.cpp:60