ROOT 6.07/01 Reference Guide
DecisionTree.cxx
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : TMVA::DecisionTree *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation of a Decision Tree *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
16  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
17  * Eckhard v. Toerne <evt@uni-bonn.de> - U of Bonn, Germany *
18  * Jan Therhaag <Jan.Therhaag@cern.ch> - U of Bonn, Germany *
19  * *
20  * Copyright (c) 2005-2011: *
21  * CERN, Switzerland *
22  * U. of Victoria, Canada *
23  * MPI-K Heidelberg, Germany *
24  * U. of Bonn, Germany *
25  * *
26  * Redistribution and use in source and binary forms, with or without *
27  * modification, are permitted according to the terms listed in LICENSE *
28  * (http://tmva.sourceforge.net/license.txt) *
29  * *
30  **********************************************************************************/
31 
32 //_______________________________________________________________________
33 //
34 // Implementation of a Decision Tree
35 //
36 // In a decision tree successive decision nodes are used to categorize the
37 // events out of the sample as either signal or background. Each node
38 // uses only a single discriminating variable to decide if the event is
39 // signal-like ("goes right") or background-like ("goes left"). This
40 // forms a tree-like structure with "baskets" at the end (leaf nodes),
41 // and an event is classified as either signal or background according to
42 // whether the basket where it ends up has been classified as signal or
43 // background during the training. Training of a decision tree is the
44 // process of defining the "cut criteria" for each node. The training
45 // starts with the root node. Here one takes the full training event
46 // sample and selects the variable and corresponding cut value that give
47 // the best separation between signal and background at this stage. Using
48 // this cut criterion, the sample is then divided into two subsamples, a
49 // signal-like (right) and a background-like (left) sample. Two new nodes
50 // are then created, one for each of the two subsamples, and they are
51 // constructed using the same mechanism as described for the root
52 // node. The division is stopped once a certain node has reached either a
53 // minimum number of events, or a minimum or maximum signal purity. These
54 // leaf nodes are then called "signal" or "background" if they contain
55 // more signal or background events, respectively, from the training sample.
56 //_______________________________________________________________________
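//
// A minimal usage sketch (editorial illustration, not part of the original file).
// It assumes a DataSetInfo object "dsi" and a training sample "events"
// (a std::vector<const TMVA::Event*>) prepared elsewhere, e.g. by MethodBDT,
// and uses the constructor and BuildTree defined below:
//
//    TMVA::SeparationBase *sepType = new TMVA::GiniIndex();
//    TMVA::DecisionTree tree( sepType, /*minSize=*/2.5, /*nCuts=*/20, &dsi,
//                             /*cls=*/0, /*randomisedTree=*/kFALSE, /*useNvars=*/0,
//                             /*usePoissonNvars=*/kFALSE, /*nMaxDepth=*/3,
//                             /*iSeed=*/0, /*purityLimit=*/0.5, /*treeID=*/0 );
//    UInt_t nNodes = tree.BuildTree( events );   // grow the tree on the sample
//    UInt_t nLeafs = tree.CountLeafNodes();      // inspect the resulting structure
//_______________________________________________________________________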
57 
58 #include <iostream>
59 #include <algorithm>
60 #include <vector>
61 #include <limits>
62 #include <fstream>
64 #include <cassert>
65 
66 #include "TRandom3.h"
67 #include "TMath.h"
68 #include "TMatrix.h"
69 
70 #include "TMVA/MsgLogger.h"
71 #include "TMVA/DecisionTree.h"
72 #include "TMVA/DecisionTreeNode.h"
73 #include "TMVA/BinarySearchTree.h"
74 
75 #include "TMVA/Tools.h"
76 
77 #include "TMVA/GiniIndex.h"
78 #include "TMVA/CrossEntropy.h"
79 #include "TMVA/MisClassificationError.h"
80 #include "TMVA/SdivSqrtSplusB.h"
81 #include "TMVA/Event.h"
82 #include "TMVA/BDTEventWrapper.h"
83 #include "TMVA/IPruneTool.h"
84 #include "TMVA/CostComplexityPruneTool.h"
85 #include "TMVA/ExpectedErrorPruneTool.h"
86 
87 const Int_t TMVA::DecisionTree::fgRandomSeed = 0; // set nonzero for debugging and zero for random seeds
88 
89 using std::vector;
90 
92 
93 ////////////////////////////////////////////////////////////////////////////////
94 /// default constructor using the GiniIndex as separation criterion,
95 /// no restrictions on minimum number of events in a leaf node or the
96 /// separation gain in the node splitting
97 
98 TMVA::DecisionTree::DecisionTree():
99 BinaryTree(),
100  fNvars (0),
101  fNCuts (-1),
102  fUseFisherCuts (kFALSE),
103  fMinLinCorrForFisher (1),
104  fUseExclusiveVars (kTRUE),
105  fSepType (NULL),
106  fRegType (NULL),
107  fMinSize (0),
108  fMinNodeSize (1),
109  fMinSepGain (0),
110  fUseSearchTree(kFALSE),
111  fPruneStrength(0),
112  fPruneMethod (kNoPruning),
113  fNNodesBeforePruning(0),
114  fNodePurityLimit(0.5),
115  fRandomisedTree (kFALSE),
116  fUseNvars (0),
117  fUsePoissonNvars(kFALSE),
118  fMyTrandom (NULL),
119  fMaxDepth (999999),
120  fSigClass (0),
121  fTreeID (0),
122  fAnalysisType (Types::kClassification),
123  fDataSetInfo (NULL)
124 {
125 }
126 
127 ////////////////////////////////////////////////////////////////////////////////
128 /// constructor specifying the separation type, the min number of
129 /// events in a node that is still subjected to further splitting, the
130 /// number of bins in the grid used in applying the cut for the node
131 /// splitting.
132 
133 TMVA::DecisionTree::DecisionTree( TMVA::SeparationBase *sepType, Float_t minSize, Int_t nCuts, TMVA::DataSetInfo* dataInfo, UInt_t cls,
134  Bool_t randomisedTree, Int_t useNvars, Bool_t usePoissonNvars,
135  UInt_t nMaxDepth, Int_t iSeed, Float_t purityLimit, Int_t treeID):
136  BinaryTree(),
137  fNvars (0),
138  fNCuts (nCuts),
139  fUseFisherCuts (kFALSE),
140  fMinLinCorrForFisher (1),
141  fUseExclusiveVars (kTRUE),
142  fSepType (sepType),
143  fRegType (NULL),
144  fMinSize (0),
145  fMinNodeSize (minSize),
146  fMinSepGain (0),
147  fUseSearchTree (kFALSE),
148  fPruneStrength (0),
149  fPruneMethod (kNoPruning),
150  fNNodesBeforePruning(0),
151  fNodePurityLimit(purityLimit),
152  fRandomisedTree (randomisedTree),
153  fUseNvars (useNvars),
154  fUsePoissonNvars(usePoissonNvars),
155  fMyTrandom (new TRandom3(iSeed)),
156  fMaxDepth (nMaxDepth),
157  fSigClass (cls),
158  fTreeID (treeID),
159  fAnalysisType (Types::kClassification),
160  fDataSetInfo (dataInfo)
161 {
162  if (sepType == NULL) { // it is interpreted as a regression tree, where
163  // currently the separation type (simple least squares)
164  // cannot be chosen freely
165  fAnalysisType = Types::kRegression;
166  fRegType = new RegressionVariance();
167  if ( nCuts <=0 ) {
168  fNCuts = 200;
169  Log() << kWARNING << " You had chosen the training mode using optimal cuts, not\n"
170  << " based on a grid of cuts, by setting the option NCuts < 0. As this does not\n"
171  << " exist yet for regression, I set NCuts to " << fNCuts << " and use the grid"
172  << Endl;
173  }
174  }else{
175  fAnalysisType = Types::kClassification;
176  }
177 }
178 
179 ////////////////////////////////////////////////////////////////////////////////
180 /// copy constructor that creates a true copy, i.e. a completely independent tree;
181 /// the node copy will recursively copy all the nodes
182 
183 TMVA::DecisionTree::DecisionTree( const DecisionTree &d ):
184  BinaryTree(),
185  fNvars (d.fNvars),
186  fNCuts (d.fNCuts),
187  fUseFisherCuts (d.fUseFisherCuts),
188  fMinLinCorrForFisher (d.fMinLinCorrForFisher),
189  fUseExclusiveVars (d.fUseExclusiveVars),
190  fSepType (d.fSepType),
191  fRegType (d.fRegType),
192  fMinSize (d.fMinSize),
193  fMinNodeSize(d.fMinNodeSize),
194  fMinSepGain (d.fMinSepGain),
195  fUseSearchTree (d.fUseSearchTree),
196  fPruneStrength (d.fPruneStrength),
197  fPruneMethod (d.fPruneMethod),
198  fNodePurityLimit(d.fNodePurityLimit),
199  fRandomisedTree (d.fRandomisedTree),
200  fUseNvars (d.fUseNvars),
201  fUsePoissonNvars(d.fUsePoissonNvars),
202  fMyTrandom (new TRandom3(fgRandomSeed)), // well, that means it's not an identical copy. But I only ever intend to really copy trees that are "outgrown" already.
203  fMaxDepth (d.fMaxDepth),
204  fSigClass (d.fSigClass),
205  fTreeID (d.fTreeID),
206  fAnalysisType(d.fAnalysisType),
207  fDataSetInfo (d.fDataSetInfo)
208 {
209  this->SetRoot( new TMVA::DecisionTreeNode ( *((DecisionTreeNode*)(d.GetRoot())) ) );
210  this->SetParentTreeInNodes();
211  fNNodes = d.fNNodes;
212 }
213 
214 
215 ////////////////////////////////////////////////////////////////////////////////
216 /// destructor
217 
218 TMVA::DecisionTree::~DecisionTree( void )
219 {
220  // destruction of the tree nodes done in the "base class" BinaryTree
221 
222  if (fMyTrandom) delete fMyTrandom;
223  if (fRegType) delete fRegType;
224 }
225 
226 ////////////////////////////////////////////////////////////////////////////////
227 /// descend the tree recursively, setting in each node the pointer to its parent
228 /// tree, and record the maximum depth reached at the same time.
229 
230 void TMVA::DecisionTree::SetParentTreeInNodes( Node *n )
231 {
232  if (n == NULL) { //default, start at the tree top, then descend recursively
233  n = this->GetRoot();
234  if (n == NULL) {
235  Log() << kFATAL << "SetParentTreeInNodes: started with undefined ROOT node" <<Endl;
236  return ;
237  }
238  }
239 
240  if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) != NULL) ) {
241  Log() << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
242  return;
243  } else if ((this->GetLeftDaughter(n) != NULL) && (this->GetRightDaughter(n) == NULL) ) {
244  Log() << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
245  return;
246  }
247  else {
248  if (this->GetLeftDaughter(n) != NULL) {
249  this->SetParentTreeInNodes( this->GetLeftDaughter(n) );
250  }
251  if (this->GetRightDaughter(n) != NULL) {
252  this->SetParentTreeInNodes( this->GetRightDaughter(n) );
253  }
254  }
255  n->SetParentTree(this);
256  if (n->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(n->GetDepth());
257  return;
258 }
259 
260 ////////////////////////////////////////////////////////////////////////////////
261 /// re-create a new tree (decision tree or search tree) from XML
262 
263 TMVA::DecisionTree* TMVA::DecisionTree::CreateFromXML( void* node, UInt_t tmva_Version_Code ) {
264  std::string type("");
265  gTools().ReadAttr(node,"type", type);
266  DecisionTree* dt = new DecisionTree();
267 
268  dt->ReadXML( node, tmva_Version_Code );
269  return dt;
270 }
271 
272 
273 ////////////////////////////////////////////////////////////////////////////////
274 /// building the decision tree by recursively calling the splitting of
275 /// one (root-) node into two daughter nodes (returns the number of nodes)
276 
277 UInt_t TMVA::DecisionTree::BuildTree( const std::vector<const TMVA::Event*> & eventSample,
278  TMVA::DecisionTreeNode *node)
279 {
280  if (node==NULL) {
281  //start with the root node
282  node = new TMVA::DecisionTreeNode();
283  fNNodes = 1;
284  this->SetRoot(node);
285  // have to use "s" for start as "r" for "root" would be the same as "r" for "right"
286  this->GetRoot()->SetPos('s');
287  this->GetRoot()->SetDepth(0);
288  this->GetRoot()->SetParentTree(this);
289  fMinSize = fMinNodeSize/100. * eventSample.size();
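 // (e.g. MinNodeSize = 5, i.e. 5%, with a training sample of 1000 events translates to fMinSize = 50 events)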
290  if (GetTreeID()==0){
291  Log() << kINFO << "The minimal node size MinNodeSize=" << fMinNodeSize << "% is translated to an actual number of events = "<< fMinSize<< " for the training sample size of " << eventSample.size() << Endl;
292  Log() << kINFO << "Note: This number will be taken as absolute minimum in the node, " << Endl;
293  Log() << kINFO << " in terms of 'weighted events' and unweighted ones !! " << Endl;
294  }
295  }
296 
297  UInt_t nevents = eventSample.size();
298 
299  if (nevents > 0 ) {
300  if (fNvars==0) fNvars = eventSample[0]->GetNVariables(); // should have been set before, but ... well..
301  fVariableImportance.resize(fNvars);
302  }
303  else Log() << kFATAL << "<BuildTree> eventSample size == 0 " << Endl;
304 
305  Double_t s=0, b=0;
306  Double_t suw=0, buw=0;
307  Double_t sub=0, bub=0; // unboosted!
308  Double_t target=0, target2=0;
309  Float_t *xmin = new Float_t[fNvars];
310  Float_t *xmax = new Float_t[fNvars];
311  for (UInt_t ivar=0; ivar<fNvars; ivar++) {
312  xmin[ivar]=xmax[ivar]=0;
313  }
314  for (UInt_t iev=0; iev<eventSample.size(); iev++) {
315  const TMVA::Event* evt = eventSample[iev];
316  const Double_t weight = evt->GetWeight();
317  const Double_t orgWeight = evt->GetOriginalWeight(); // unboosted!
318  if (evt->GetClass() == fSigClass) {
319  s += weight;
320  suw += 1;
321  sub += orgWeight;
322  }
323  else {
324  b += weight;
325  buw += 1;
326  bub += orgWeight;
327  }
328  if ( DoRegression() ) {
329  const Double_t tgt = evt->GetTarget(0);
330  target +=weight*tgt;
331  target2+=weight*tgt*tgt;
332  }
333 
334  for (UInt_t ivar=0; ivar<fNvars; ivar++) {
335  const Double_t val = evt->GetValue(ivar);
336  if (iev==0) xmin[ivar]=xmax[ivar]=val;
337  if (val < xmin[ivar]) xmin[ivar]=val;
338  if (val > xmax[ivar]) xmax[ivar]=val;
339  }
340  }
341 
342 
343  if (s+b < 0) {
344  Log() << kWARNING << " One of the Decision Tree nodes has negative total number of signal or background events. "
345  << "(Nsig="<<s<<" Nbkg="<<b<<" Probaby you use a Monte Carlo with negative weights. That should in principle "
346  << "be fine as long as on average you end up with something positive. For this you have to make sure that the "
347  << "minimul number of (unweighted) events demanded for a tree node (currently you use: MinNodeSize="<<fMinNodeSize
348  << "% of training events, you can set this via the BDT option string when booking the classifier) is large enough "
349  << "to allow for reasonable averaging!!!" << Endl
350  << " If this does not help.. maybe you want to try the option: NoNegWeightsInTraining which ignores events "
351  << "with negative weight in the training." << Endl;
352  double nBkg=0.;
353  for (UInt_t i=0; i<eventSample.size(); i++) {
354  if (eventSample[i]->GetClass() != fSigClass) {
355  nBkg += eventSample[i]->GetWeight();
356  Log() << kDEBUG << "Event "<< i<< " has (original) weight: " << eventSample[i]->GetWeight()/eventSample[i]->GetBoostWeight()
357  << " boostWeight: " << eventSample[i]->GetBoostWeight() << Endl;
358  }
359  }
360  Log() << kDEBUG << " that gives in total: " << nBkg<<Endl;
361  }
362 
363  node->SetNSigEvents(s);
364  node->SetNBkgEvents(b);
365  node->SetNSigEvents_unweighted(suw);
366  node->SetNBkgEvents_unweighted(buw);
367  node->SetNSigEvents_unboosted(sub);
368  node->SetNBkgEvents_unboosted(bub);
369  node->SetPurity();
370  if (node == this->GetRoot()) {
371  node->SetNEvents(s+b);
372  node->SetNEvents_unweighted(suw+buw);
373  node->SetNEvents_unboosted(sub+bub);
374  }
375  for (UInt_t ivar=0; ivar<fNvars; ivar++) {
376  node->SetSampleMin(ivar,xmin[ivar]);
377  node->SetSampleMax(ivar,xmax[ivar]);
378  }
379  delete[] xmin;
380  delete[] xmax;
381 
382  // I now demand the minimum number of events for both daughter nodes. Hence if the number
383  // of events in the parent node is not at least two times as big, I don't even need to try
384  // splitting
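 // (illustration: with fMinSize = 10, a parent node holding only 19 events can never
 // yield two daughters with >= 10 events each, so no split is even attempted)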
385 
386  // ask here for actual "events" independent of their weight ... OR the weighted events
387  // to exceed the min requested number of events per daughter node
388  // (NOTE: make sure that the eventSample at the ROOT node has sum_of_weights == sample.size() !)
389  // if ((eventSample.size() >= 2*fMinSize ||s+b >= 2*fMinSize) && node->GetDepth() < fMaxDepth
390  // std::cout << "------------------------------------------------------------------"<<std::endl;
391  // std::cout << "------------------------------------------------------------------"<<std::endl;
392  // std::cout << " eveSampleSize = "<< eventSample.size() << " s+b="<<s+b << std::endl;
393  if ((eventSample.size() >= 2*fMinSize && s+b >= 2*fMinSize) && node->GetDepth() < fMaxDepth
394  && ( ( s!=0 && b !=0 && !DoRegression()) || ( (s+b)!=0 && DoRegression()) ) ) {
395  Double_t separationGain;
396  if (fNCuts > 0){
397  separationGain = this->TrainNodeFast(eventSample, node);
398  } else {
399  separationGain = this->TrainNodeFull(eventSample, node);
400  }
401  if (separationGain < std::numeric_limits<double>::epsilon()) { // we could not gain anything, e.g. all events are in one bin,
403  // no cut can actually do anything to improve the node
404  // hence, naturally, the current node is a leaf node
405  if (DoRegression()) {
406  node->SetSeparationIndex(fRegType->GetSeparationIndex(s+b,target,target2));
407  node->SetResponse(target/(s+b));
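 // the RMS of the regression target in this node is sqrt( <t^2> - <t>^2 ), computed
 // from the weighted sums accumulated above; guard against numerically negative variance: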
408  if( (target2/(s+b) - target/(s+b)*target/(s+b)) < std::numeric_limits<double>::epsilon() ){
409  node->SetRMS(0);
410  }else{
411  node->SetRMS(TMath::Sqrt(target2/(s+b) - target/(s+b)*target/(s+b)));
412  }
413  }
414  else {
415  node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));
416 
417  if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
418  else node->SetNodeType(-1);
419  }
420  if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
421 
422  } else {
423 
424  std::vector<const TMVA::Event*> leftSample; leftSample.reserve(nevents);
425  std::vector<const TMVA::Event*> rightSample; rightSample.reserve(nevents);
426 
427  Double_t nRight=0, nLeft=0;
428  Double_t nRightUnBoosted=0, nLeftUnBoosted=0;
429 
430  for (UInt_t ie=0; ie< nevents ; ie++) {
431  if (node->GoesRight(*eventSample[ie])) {
432  rightSample.push_back(eventSample[ie]);
433  nRight += eventSample[ie]->GetWeight();
434  nRightUnBoosted += eventSample[ie]->GetOriginalWeight();
435  }
436  else {
437  leftSample.push_back(eventSample[ie]);
438  nLeft += eventSample[ie]->GetWeight();
439  nLeftUnBoosted += eventSample[ie]->GetOriginalWeight();
440  }
441  }
442  // std::cout << " left:" << leftSample.size()
443  // << " right:" << rightSample.size()
444  // << " total:" << leftSample.size()+rightSample.size()
445  // << std::endl
446  // << " while the separation is thought to be " << separationGain
447  // << std::endl;;
448 
449  // sanity check
450  if (leftSample.empty() || rightSample.empty()) {
451  Log() << kERROR << "<TrainNode> all events went to the same branch" << Endl
452  << "--- Hence new node == old node ... check" << Endl
453  << "--- left:" << leftSample.size()
454  << " right:" << rightSample.size() << Endl
455  << " while the separation is thought to be " << separationGain
456  << kFATAL << "--- this should never happen, please write a bug report to Helge.Voss@cern.ch" << Endl;
457  }
458 
459  // continue building daughter nodes for the left and the right eventsample
460  TMVA::DecisionTreeNode *rightNode = new TMVA::DecisionTreeNode(node,'r');
461  fNNodes++;
462  rightNode->SetNEvents(nRight);
463  rightNode->SetNEvents_unboosted(nRightUnBoosted);
464  rightNode->SetNEvents_unweighted(rightSample.size());
465 
466  TMVA::DecisionTreeNode *leftNode = new TMVA::DecisionTreeNode(node,'l');
467 
468  fNNodes++;
469  leftNode->SetNEvents(nLeft);
470  leftNode->SetNEvents_unboosted(nLeftUnBoosted);
471  leftNode->SetNEvents_unweighted(leftSample.size());
472 
473  node->SetNodeType(0);
474  node->SetLeft(leftNode);
475  node->SetRight(rightNode);
476 
477  this->BuildTree(rightSample, rightNode);
478  this->BuildTree(leftSample, leftNode );
479 
480  }
481  }
482  else{ // it is a leaf node
483  if (DoRegression()) {
484  node->SetSeparationIndex(fRegType->GetSeparationIndex(s+b,target,target2));
485  node->SetResponse(target/(s+b));
486  if( (target2/(s+b) - target/(s+b)*target/(s+b)) < std::numeric_limits<double>::epsilon() ) {
487  node->SetRMS(0);
488  }else{
489  node->SetRMS(TMath::Sqrt(target2/(s+b) - target/(s+b)*target/(s+b)));
490  }
491  }
492  else {
493  node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));
494  if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
495  else node->SetNodeType(-1);
496  // loop through the event sample ending up in this node and check for events with negative weight;
497  // those "cannot" be boosted normally. Hence, if one of those
498  // is misclassified, find randomly as many events with positive weights in this
499  // node as needed to get the same absolute amount of weight, and mark them as
500  // "not to be boosted" in order to make up for not boosting the negative weight event
501  }
502 
503 
504  if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
505  }
506 
507  // if (IsRootNode) this->CleanTree();
508  return fNNodes;
509 }
510 
511 ////////////////////////////////////////////////////////////////////////////////
512 
513 void TMVA::DecisionTree::FillTree( const std::vector<TMVA::Event*> & eventSample )
514 
515 {
516  // fill the existing decision tree structure by filling events
517  // in from the top node and seeing where they happen to end up
518  for (UInt_t i=0; i<eventSample.size(); i++) {
519  this->FillEvent(*(eventSample[i]),NULL);
520  }
521 }
522 
523 ////////////////////////////////////////////////////////////////////////////////
524 /// fill the existing decision tree structure by filling events
525 /// in from the top node and seeing where they happen to end up
526 
527 void TMVA::DecisionTree::FillEvent( const TMVA::Event & event,
528  TMVA::DecisionTreeNode *node )
529 {
530  if (node == NULL) { // that's the start, take the Root node
531  node = this->GetRoot();
532  }
533 
534  node->IncrementNEvents( event.GetWeight() );
535  node->IncrementNEvents_unweighted( );
536 
537  if (event.GetClass() == fSigClass) {
538  node->IncrementNSigEvents( event.GetWeight() );
539  node->IncrementNSigEvents_unweighted( );
540  }
541  else {
542  node->IncrementNBkgEvents( event.GetWeight() );
543  node->IncrementNBkgEvents_unweighted( );
544  }
545  node->SetSeparationIndex(fSepType->GetSeparationIndex(node->GetNSigEvents(),
546  node->GetNBkgEvents()));
547 
548  if (node->GetNodeType() == 0) { //intermediate node --> go down
549  if (node->GoesRight(event))
550  this->FillEvent(event,dynamic_cast<TMVA::DecisionTreeNode*>(node->GetRight())) ;
551  else
552  this->FillEvent(event,dynamic_cast<TMVA::DecisionTreeNode*>(node->GetLeft())) ;
553  }
554 
555 
556 }
557 
558 ////////////////////////////////////////////////////////////////////////////////
559 /// clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree
560 
561 void TMVA::DecisionTree::ClearTree()
562 {
563  if (this->GetRoot()!=NULL) this->GetRoot()->ClearNodeAndAllDaughters();
564 
565 }
566 
567 ////////////////////////////////////////////////////////////////////////////////
568 /// remove those last splits that result in two leaf nodes that
569 /// are both of the same type (i.e. both signal or both background);
570 /// this of course is only a reasonable thing to do when you use
571 /// "YesOrNo" leaves, while it might lose something if you use the
572 /// purity information in the nodes.
573 /// --> hence I don't call it automatically in the tree building
574 
575 UInt_t TMVA::DecisionTree::CleanTree( DecisionTreeNode *node )
576 {
577  if (node==NULL) {
578  node = this->GetRoot();
579  }
580 
581  DecisionTreeNode *l = node->GetLeft();
582  DecisionTreeNode *r = node->GetRight();
583 
584  if (node->GetNodeType() == 0) {
585  this->CleanTree(l);
586  this->CleanTree(r);
587  if (l->GetNodeType() * r->GetNodeType() > 0) {
588 
589  this->PruneNode(node);
590  }
591  }
592  // update the number of nodes after the cleaning
593  return this->CountNodes();
594 
595 }
596 
597 ////////////////////////////////////////////////////////////////////////////////
598 /// prune (get rid of internal nodes) the Decision tree to avoid overtraining;
599 /// several different pruning methods can be applied as selected by the
600 /// variable "fPruneMethod".
601 
602 Double_t TMVA::DecisionTree::PruneTree( const EventConstList* validationSample )
603 {
604  // std::ofstream logfile("dt_pruning.log");
605 
606 
607 
608  IPruneTool* tool(NULL);
609  PruningInfo* info(NULL);
610 
611  if( fPruneMethod == kNoPruning ) return 0.0;
612 
613  if (fPruneMethod == kExpectedErrorPruning)
614  // tool = new ExpectedErrorPruneTool(logfile);
615  tool = new ExpectedErrorPruneTool();
616  else if (fPruneMethod == kCostComplexityPruning)
617  {
618  tool = new CostComplexityPruneTool();
619  }
620  else {
621  Log() << kFATAL << "Selected pruning method not yet implemented "
622  << Endl;
623  }
624 
625  if(!tool) return 0.0;
626 
627  tool->SetPruneStrength(GetPruneStrength());
628  if(tool->IsAutomatic()) {
629  if(validationSample == NULL){
630  Log() << kFATAL << "Cannot automate the pruning algorithm without an "
631  << "independent validation sample!" << Endl;
632  }else if(validationSample->size() == 0) {
633  Log() << kFATAL << "Cannot automate the pruning algorithm with "
634  << "independent validation sample of ZERO events!" << Endl;
635  }
636  }
637 
638  info = tool->CalculatePruningInfo(this,validationSample);
639  Double_t pruneStrength=0;
640  if(!info) {
641  Log() << kFATAL << "Error pruning tree! Check prune.log for more information."
642  << Endl;
643  } else {
644  pruneStrength = info->PruneStrength;
645 
646  // Log() << kDEBUG << "Optimal prune strength (alpha): " << pruneStrength
647  // << " has quality index " << info->QualityIndex << Endl;
648 
649 
650  for (UInt_t i = 0; i < info->PruneSequence.size(); ++i) {
651 
652  PruneNode(info->PruneSequence[i]);
653  }
654  // update the number of nodes after the pruning
655  this->CountNodes();
656  }
657 
658  delete tool;
659  delete info;
660 
661  return pruneStrength;
662 }
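
// A hedged usage sketch (editorial illustration, not part of the original file):
// cost-complexity pruning with an independent validation sample "valSample"
// (an EventConstList prepared by the caller) might look like
//
//    tree.SetPruneMethod( TMVA::DecisionTree::kCostComplexityPruning );
//    tree.SetPruneStrength( 5.0 ); // a negative strength would request automatic optimisation
//    Double_t alpha = tree.PruneTree( &valSample );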
663 
664 
665 ////////////////////////////////////////////////////////////////////////////////
666 /// run the validation sample through the (pruned) tree and fill in the nodes
667 /// the variables NSValidation and NBValidation (i.e. how many of the signal
668 /// and background events from the validation sample end up in each node).
669 /// This is then later used when asking for the "tree quality" ..
670 
671 void TMVA::DecisionTree::ApplyValidationSample( const EventConstList* validationSample ) const
672 {
673  GetRoot()->ResetValidationData();
674  for (UInt_t ievt=0; ievt < validationSample->size(); ievt++) {
675  CheckEventWithPrunedTree((*validationSample)[ievt]);
676  }
677 }
678 
679 ////////////////////////////////////////////////////////////////////////////////
680 /// return the misclassification rate of a pruned tree.
681 /// a "pruned tree" may have the variable "IsTerminal" set to kTRUE at
682 /// any node, in which case this tree quality testing stops there, and hence tests
683 /// the pruned tree (while the full tree is still in place for normal/later use)
684 
685 Double_t TMVA::DecisionTree::TestPrunedTreeQuality( const DecisionTreeNode* n, Int_t mode ) const
686 {
687  if (n == NULL) { // default, start at the tree top, then descend recursively
688  n = this->GetRoot();
689  if (n == NULL) {
690  Log() << kFATAL << "TestPrunedTreeQuality: started with undefined ROOT node" <<Endl;
691  return 0;
692  }
693  }
694 
695  if( n->GetLeft() != NULL && n->GetRight() != NULL && !n->IsTerminal() ) {
696  return (TestPrunedTreeQuality( n->GetLeft(), mode ) +
697  TestPrunedTreeQuality( n->GetRight(), mode ));
698  }
699  else { // terminal leaf (in a pruned subtree of T_max at least)
700  if (DoRegression()) {
701  Double_t sumw = n->GetNSValidation() + n->GetNBValidation();
702  return n->GetSumTarget2() - 2*n->GetSumTarget()*n->GetResponse() + sumw*n->GetResponse()*n->GetResponse();
703  }
704  else {
705  if (mode == 0) {
706  if (n->GetPurity() > this->GetNodePurityLimit()) // this is a signal leaf, according to the training
707  return n->GetNBValidation();
708  else
709  return n->GetNSValidation();
710  }
711  else if ( mode == 1 ) {
712  // calculate the weighted error using the pruning validation sample
713  return (n->GetPurity() * n->GetNBValidation() + (1.0 - n->GetPurity()) * n->GetNSValidation());
714  }
715  else {
716  throw std::string("Unknown ValidationQualityMode");
717  }
718  }
719  }
720 }
721 
722 ////////////////////////////////////////////////////////////////////////////////
723 /// pass a single validation event through a pruned decision tree;
724 /// on the way down the tree, fill in all the "intermediate" information
725 /// that would normally be there from training.
726 
727 void TMVA::DecisionTree::CheckEventWithPrunedTree( const TMVA::Event* e ) const
728 {
729  DecisionTreeNode* current = this->GetRoot();
730  if (current == NULL) {
731  Log() << kFATAL << "CheckEventWithPrunedTree: started with undefined ROOT node" <<Endl;
732  }
733 
734  while(current != NULL) {
735  if(e->GetClass() == fSigClass)
736  current->SetNSValidation(current->GetNSValidation() + e->GetWeight());
737  else
738  current->SetNBValidation(current->GetNBValidation() + e->GetWeight());
739 
740  if (e->GetNTargets() > 0) {
741  current->AddToSumTarget(e->GetWeight()*e->GetTarget(0));
742  current->AddToSumTarget2(e->GetWeight()*e->GetTarget(0)*e->GetTarget(0));
743  }
744 
745  if (current->GetRight() == NULL || current->GetLeft() == NULL) {
746  current = NULL;
747  }
748  else {
749  if (current->GoesRight(*e))
750  current = (TMVA::DecisionTreeNode*)current->GetRight();
751  else
752  current = (TMVA::DecisionTreeNode*)current->GetLeft();
753  }
754  }
755 }
756 
757 ////////////////////////////////////////////////////////////////////////////////
758 /// calculate the normalization factor for a pruning validation sample
759 
760 Double_t TMVA::DecisionTree::GetSumWeights( const EventConstList* validationSample ) const
761 {
762  Double_t sumWeights = 0.0;
763  for( EventConstList::const_iterator it = validationSample->begin();
764  it != validationSample->end(); ++it ) {
765  sumWeights += (*it)->GetWeight();
766  }
767  return sumWeights;
768 }
769 
770 
771 
772 ////////////////////////////////////////////////////////////////////////////////
773 /// return the number of terminal nodes in the sub-tree below Node n
774 
775 UInt_t TMVA::DecisionTree::CountLeafNodes( TMVA::Node *n )
776 {
777  if (n == NULL) { // default, start at the tree top, then descend recursively
778  n = this->GetRoot();
779  if (n == NULL) {
780  Log() << kFATAL << "CountLeafNodes: started with undefined ROOT node" <<Endl;
781  return 0;
782  }
783  }
784 
785  UInt_t countLeafs=0;
786 
787  if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) == NULL) ) {
788  countLeafs += 1;
789  }
790  else {
791  if (this->GetLeftDaughter(n) != NULL) {
792  countLeafs += this->CountLeafNodes( this->GetLeftDaughter(n) );
793  }
794  if (this->GetRightDaughter(n) != NULL) {
795  countLeafs += this->CountLeafNodes( this->GetRightDaughter(n) );
796  }
797  }
798  return countLeafs;
799 }
800 
801 ////////////////////////////////////////////////////////////////////////////////
802 /// descend a tree to find all its leaf nodes
803 
804 void TMVA::DecisionTree::DescendTree( Node* n )
805 {
806  if (n == NULL) { // default, start at the tree top, then descend recursively
807  n = this->GetRoot();
808  if (n == NULL) {
809  Log() << kFATAL << "DescendTree: started with undefined ROOT node" <<Endl;
810  return ;
811  }
812  }
813 
814  if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) == NULL) ) {
815  // do nothing
816  }
817  else if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) != NULL) ) {
818  Log() << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
819  return;
820  }
821  else if ((this->GetLeftDaughter(n) != NULL) && (this->GetRightDaughter(n) == NULL) ) {
822  Log() << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
823  return;
824  }
825  else {
826  if (this->GetLeftDaughter(n) != NULL) {
827  this->DescendTree( this->GetLeftDaughter(n) );
828  }
829  if (this->GetRightDaughter(n) != NULL) {
830  this->DescendTree( this->GetRightDaughter(n) );
831  }
832  }
833 }
834 
835 ////////////////////////////////////////////////////////////////////////////////
836 /// prune away the subtree below the node
837 
838 void TMVA::DecisionTree::PruneNode( DecisionTreeNode* node )
839 {
840  DecisionTreeNode *l = node->GetLeft();
841  DecisionTreeNode *r = node->GetRight();
842 
843  node->SetRight(NULL);
844  node->SetLeft(NULL);
845  node->SetSelector(-1);
846  node->SetSeparationGain(-1);
847  if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
848  else node->SetNodeType(-1);
849  this->DeleteNode(l);
850  this->DeleteNode(r);
851  // update the stored number of nodes in the Tree
852  this->CountNodes();
853 
854 }
855 
856 ////////////////////////////////////////////////////////////////////////////////
857 /// prune a node temporarily (without actually deleting its descendants),
858 /// which allows testing the pruned tree quality for many different
859 /// pruning stages without "touching" the tree.
860 
861 void TMVA::DecisionTree::PruneNodeInPlace( DecisionTreeNode* node ) {
862  if(node == NULL) return;
863  node->SetNTerminal(1);
864  node->SetSubTreeR( node->GetNodeR() );
865  node->SetAlpha( std::numeric_limits<double>::infinity( ) );
866  node->SetAlphaMinSubtree( std::numeric_limits<double>::infinity( ) );
867  node->SetTerminal(kTRUE); // set the node to be terminal without deleting its descendants FIXME not needed
868 }
869 
870 ////////////////////////////////////////////////////////////////////////////////
871 /// retrieve node from the tree. Its position (up to a maximal tree depth of 64)
872 /// is coded as a sequence of left-right moves starting from the root, stored as
873 /// 0-1 bit patterns in the "long-integer" sequence (i.e. 0:left ; 1:right)
874 
875 TMVA::Node* TMVA::DecisionTree::GetNode( ULong_t sequence, UInt_t depth )
876 {
877  Node* current = this->GetRoot();
878 
879  for (UInt_t i =0; i < depth; i++) {
880  ULong_t tmp = ULong_t(1) << i; // 64-bit shift, so that depths beyond 31 work as advertised
881  if ( tmp & sequence) current = this->GetRightDaughter(current);
882  else current = this->GetLeftDaughter(current);
883  }
884 
885  return current;
886 }
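
// worked example of the bit coding: GetNode( /*sequence=*/5, /*depth=*/3 ) reads the
// bits of 5 = 101 from the least significant end and walks right (bit 0 = 1),
// left (bit 1 = 0), right (bit 2 = 1) starting from the root.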
887 
888 
889 ////////////////////////////////////////////////////////////////////////////////
890 /// select a random subset of the variables to be considered for the split of this node
891 
892 void TMVA::DecisionTree::GetRandomisedVariables(Bool_t *useVariable, UInt_t *mapVariable, UInt_t &useNvars){
893  for (UInt_t ivar=0; ivar<fNvars; ivar++) useVariable[ivar]=kFALSE;
894  if (fUseNvars==0) { // no number specified ... choose s.th. which hopefully works well
895  // watch out, should never happen as it is initialised automatically in MethodBDT already!!!
896  fUseNvars = UInt_t(TMath::Sqrt(fNvars)+0.6);
897  }
898  if (fUsePoissonNvars) useNvars=TMath::Min(fNvars,TMath::Max(UInt_t(1),(UInt_t) fMyTrandom->Poisson(fUseNvars)));
899  else useNvars = fUseNvars;
900 
901  UInt_t nSelectedVars = 0;
902  while (nSelectedVars < useNvars) {
903  Double_t bla = fMyTrandom->Rndm()*fNvars;
904  useVariable[Int_t (bla)] = kTRUE;
905  nSelectedVars = 0;
906  for (UInt_t ivar=0; ivar < fNvars; ivar++) {
907  if (useVariable[ivar] == kTRUE) {
908  mapVariable[nSelectedVars] = ivar;
909  nSelectedVars++;
910  }
911  }
912  }
913  if (nSelectedVars != useNvars) { std::cout << "Bug in TrainNode - GetRandomisedVariables()... sorry" << std::endl; std::exit(1);}
914 }
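
// illustration: with fNvars = 16 and fUseNvars left at 0, fUseNvars becomes
// UInt_t(TMath::Sqrt(16)+0.6) = 4, so each node split considers a random subset
// of 4 variables (Poisson-fluctuated around 4 if fUsePoissonNvars is set).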
915 
916 ////////////////////////////////////////////////////////////////////////////////
917 /// Decide how to split a node using one of the variables that gives
918 /// the best separation of signal/background. In order to do this, for each
919 /// variable a scan of the different cut values in a grid (grid = fNCuts) is
920 /// performed and the resulting separation gains are compared.
921 /// In addition to the individual variables, one can also ask for a Fisher
922 /// discriminant being built out of (some of) the variables and used as a
923 /// possible multivariate split.
924 
925 Double_t TMVA::DecisionTree::TrainNodeFast( const EventConstList & eventSample,
926  TMVA::DecisionTreeNode *node )
927 {
928  Double_t separationGainTotal = -1, sepTmp;
929  Double_t *separationGain = new Double_t[fNvars+1];
930  Int_t *cutIndex = new Int_t[fNvars+1]; //-1;
931 
932  for (UInt_t ivar=0; ivar <= fNvars; ivar++) {
933  separationGain[ivar]=-1;
934  cutIndex[ivar]=-1;
935  }
936  Int_t mxVar = -1;
937  Bool_t cutType = kTRUE;
938  Double_t nTotS, nTotB;
939  Int_t nTotS_unWeighted, nTotB_unWeighted;
940  UInt_t nevents = eventSample.size();
941 
942 
943  // the +1 comes from the fact that I treat later on the Fisher output as an
944  // additional possible variable.
945  Bool_t *useVariable = new Bool_t[fNvars+1]; // for performance reasons instead of std::vector<Bool_t> useVariable(fNvars);
946  UInt_t *mapVariable = new UInt_t[fNvars+1]; // map the subset of variables used in randomised trees to the original variable number (used in the Event() )
947 
948  std::vector<Double_t> fisherCoeff;
949 
950  if (fRandomisedTree) { // choose for each node splitting a random subset of variables to choose from
951  UInt_t tmp=fUseNvars;
952  GetRandomisedVariables(useVariable,mapVariable,tmp);
953  }
954  else {
955  for (UInt_t ivar=0; ivar < fNvars; ivar++) {
956  useVariable[ivar] = kTRUE;
957  mapVariable[ivar] = ivar;
958  }
959  }
960  useVariable[fNvars] = kFALSE; //by default fisher is not used..
961 
962  Bool_t fisherOK = kFALSE; // flag showing whether the Fisher discriminant could be calculated correctly or not
963  if (fUseFisherCuts) {
964  useVariable[fNvars] = kTRUE; // that's where I store the "fisher MVA"
965 
966  //use for the Fisher discriminant ONLY those variables that show
967  //some reasonable linear correlation in either Signal or Background
968  Bool_t *useVarInFisher = new Bool_t[fNvars]; // for performance reasons instead of std::vector<Bool_t> useVariable(fNvars);
969  UInt_t *mapVarInFisher = new UInt_t[fNvars]; // map the subset of variables used in randomised trees to the original variable number (used in the Event() )
970  for (UInt_t ivar=0; ivar < fNvars; ivar++) {
971  useVarInFisher[ivar] = kFALSE;
972  mapVarInFisher[ivar] = ivar;
973  }
974 
975  std::vector<TMatrixDSym*>* covMatrices;
976  covMatrices = gTools().CalcCovarianceMatrices( eventSample, 2 ); // currently for 2 classes only
977  if (!covMatrices){
978  Log() << kWARNING << " in TrainNodeFast, the covariance matrices needed for the Fisher cuts returned an error --> reverting to just normal cuts for this node" << Endl;
979  fisherOK = kFALSE;
980  }else{
981  TMatrixD *ss = new TMatrixD(*(covMatrices->at(0)));
982  TMatrixD *bb = new TMatrixD(*(covMatrices->at(1)));
983  const TMatrixD *s = gTools().GetCorrelationMatrix(ss);
984  const TMatrixD *b = gTools().GetCorrelationMatrix(bb);
985 
986  for (UInt_t ivar=0; ivar < fNvars; ivar++) {
987  for (UInt_t jvar=ivar+1; jvar < fNvars; jvar++) {
988  if ( ( TMath::Abs( (*s)(ivar, jvar)) > fMinLinCorrForFisher) ||
989  ( TMath::Abs( (*b)(ivar, jvar)) > fMinLinCorrForFisher) ){
990  useVarInFisher[ivar] = kTRUE;
991  useVarInFisher[jvar] = kTRUE;
992  }
993  }
994  }
995  // now as you know which variables you want to use, count and map them:
996  // such that you can use an array/matrix filled only with THOSE variables
997  // that you used
998  UInt_t nFisherVars = 0;
999  for (UInt_t ivar=0; ivar < fNvars; ivar++) {
1000  // now pick those variables that are used in the Fisher and are also
1001  // part of the "allowed" variables (in case of Randomized Trees)
1002  if (useVarInFisher[ivar] && useVariable[ivar]) {
1003  mapVarInFisher[nFisherVars++]=ivar;
1004  // now exclude the variables used in the Fisher cuts, and don't
1005  // use them anymore in the individual variable scan
1006  if (fUseExclusiveVars) useVariable[ivar] = kFALSE;
1007  }
1008  }
1009 
1010 
1011  fisherCoeff = this->GetFisherCoefficients(eventSample, nFisherVars, mapVarInFisher);
1012  fisherOK = kTRUE;
1013  }
1014  delete [] useVarInFisher;
1015  delete [] mapVarInFisher;
1016 
1017  }
1018 
1019 
1020  UInt_t cNvars = fNvars;
1021  if (fUseFisherCuts && fisherOK) cNvars++; // use the Fisher output simply as an additional variable
1022 
1023  UInt_t* nBins = new UInt_t [cNvars];
1024 
1025  Double_t** nSelS = new Double_t* [cNvars];
1026  Double_t** nSelB = new Double_t* [cNvars];
1027  Double_t** nSelS_unWeighted = new Double_t* [cNvars];
1028  Double_t** nSelB_unWeighted = new Double_t* [cNvars];
1029  Double_t** target = new Double_t* [cNvars];
1030  Double_t** target2 = new Double_t* [cNvars];
1031  Double_t** cutValues = new Double_t* [cNvars];
1032 
1033  for (UInt_t ivar=0; ivar<cNvars; ivar++) {
1034  nBins[ivar] = fNCuts+1;
1035  if (ivar < fNvars) {
1036  if (fDataSetInfo->GetVariableInfo(ivar).GetVarType() == 'I') {
1037  nBins[ivar] = node->GetSampleMax(ivar) - node->GetSampleMin(ivar) + 1;
1038  }
1039  }
1040 
1041  nSelS[ivar] = new Double_t [nBins[ivar]];
1042  nSelB[ivar] = new Double_t [nBins[ivar]];
1043  nSelS_unWeighted[ivar] = new Double_t [nBins[ivar]];
1044  nSelB_unWeighted[ivar] = new Double_t [nBins[ivar]];
1045  target[ivar] = new Double_t [nBins[ivar]];
1046  target2[ivar] = new Double_t [nBins[ivar]];
1047  cutValues[ivar] = new Double_t [nBins[ivar]];
1048 
1049  }
1050 
1051  Double_t *xmin = new Double_t[cNvars];
1052  Double_t *xmax = new Double_t[cNvars];
1053 
1054  for (UInt_t ivar=0; ivar < cNvars; ivar++) {
1055  if (ivar < fNvars){
1056  xmin[ivar]=node->GetSampleMin(ivar);
1057  xmax[ivar]=node->GetSampleMax(ivar);
1058  if (xmax[ivar]-xmin[ivar] < std::numeric_limits<double>::epsilon() ) {
1059  // std::cout << " variable " << ivar << " has no proper range in (xmax[ivar]-xmin[ivar] = " << xmax[ivar]-xmin[ivar] << std::endl;
1060  // std::cout << " will set useVariable[ivar]=false"<<std::endl;
1061  useVariable[ivar]=kFALSE;
1062  }
1063 
1064  } else { // the fisher variable
1065  xmin[ivar]=999;
1066  xmax[ivar]=-999;
1067  // too bad, for the moment I don't know how to do this without looping
1068  // once to get the "min max" and then AGAIN to fill the histogram
1069  for (UInt_t iev=0; iev<nevents; iev++) {
1070  // returns the Fisher value (no fixed range)
1071  Double_t result = fisherCoeff[fNvars]; // the fisher constant offset
1072  for (UInt_t jvar=0; jvar<fNvars; jvar++)
1073  result += fisherCoeff[jvar]*(eventSample[iev])->GetValue(jvar);
1074  if (result > xmax[ivar]) xmax[ivar]=result;
1075  if (result < xmin[ivar]) xmin[ivar]=result;
1076  }
1077  }
1078  for (UInt_t ibin=0; ibin<nBins[ivar]; ibin++) {
1079  nSelS[ivar][ibin]=0;
1080  nSelB[ivar][ibin]=0;
1081  nSelS_unWeighted[ivar][ibin]=0;
1082  nSelB_unWeighted[ivar][ibin]=0;
1083  target[ivar][ibin]=0;
1084  target2[ivar][ibin]=0;
1085  cutValues[ivar][ibin]=0;
1086  }
1087  }
1088 
1089  // fill the cut values for the scan:
1090  for (UInt_t ivar=0; ivar < cNvars; ivar++) {
1091 
1092  if ( useVariable[ivar] ) {
1093 
1094  //set the grid for the cut scan on the variables like this:
1095  //
1096  // | | | | | ... | |
1097  // xmin xmax
1098  //
1099  // cut 0 1 2 3 ... fNCuts-1 (counting from zero)
1100  // bin 0 1 2 3 ..... nBins-1=fNCuts (counting from zero)
1101  // --> nBins = fNCuts+1
1102  // (NOTE, the cuts at xmin or xmax would just give the whole sample and
1103  // hence can be safely omitted)
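 // (worked example: fNCuts = 2, xmin = 0, xmax = 3 gives nBins = 3, istepSize = 1,
 // and cut values at 1 and 2, i.e. the bin edges strictly inside [xmin, xmax])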
1104 
1105  Double_t istepSize =( xmax[ivar] - xmin[ivar] ) / Double_t(nBins[ivar]);
1106  if (ivar < fNvars) {
1107  if (fDataSetInfo->GetVariableInfo(ivar).GetVarType() == 'I') istepSize = 1;
1108  }
1109 
1110  // std::cout << "ivar="<<ivar
1111  // <<" min="<<xmin[ivar]
1112  // << " max="<<xmax[ivar]
1113  // << " widht=" << istepSize
1114  // << " nBins["<<ivar<<"]="<<nBins[ivar]<<std::endl;
1115  for (UInt_t icut=0; icut<nBins[ivar]-1; icut++) {
1116  cutValues[ivar][icut]=xmin[ivar]+(Double_t(icut+1))*istepSize;
1117  // std::cout << " cutValues["<<ivar<<"]["<<icut<<"]=" << cutValues[ivar][icut] << std::endl;
1118  }
1119  }
1120  }
1121 
1122  nTotS=0; nTotB=0;
1123  nTotS_unWeighted=0; nTotB_unWeighted=0;
1124  for (UInt_t iev=0; iev<nevents; iev++) {
1125 
1126  Double_t eventWeight = eventSample[iev]->GetWeight();
1127  if (eventSample[iev]->GetClass() == fSigClass) {
1128  nTotS+=eventWeight;
1129  nTotS_unWeighted++;
1130  }
1131  else {
1132  nTotB+=eventWeight;
1133  nTotB_unWeighted++;
1134  }
1135 
1136  Int_t iBin=-1;
1137  for (UInt_t ivar=0; ivar < cNvars; ivar++) {
1138  // now scan through the cuts for each variable and find which one gives
1139  // the best separationGain at the current stage.
1140  if ( useVariable[ivar] ) {
1141  Double_t eventData;
1142  if (ivar < fNvars) eventData = eventSample[iev]->GetValue(ivar);
1143  else { // the fisher variable
1144  eventData = fisherCoeff[fNvars];
1145  for (UInt_t jvar=0; jvar<fNvars; jvar++)
1146  eventData += fisherCoeff[jvar]*(eventSample[iev])->GetValue(jvar);
1147 
1148  }
1149  // "maximum" is nBins-1 (the "-1" because we start counting from 0 !!)
1150  iBin = TMath::Min(Int_t(nBins[ivar]-1),TMath::Max(0,int (nBins[ivar]*(eventData-xmin[ivar])/(xmax[ivar]-xmin[ivar]) ) ));
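 // (illustration: with nBins = 5, xmin = 0, xmax = 10, an event with value 7.3
 // lands in iBin = int(5*7.3/10) = 3; the Min/Max clamp catches boundary values)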
1151  if (eventSample[iev]->GetClass() == fSigClass) {
1152  nSelS[ivar][iBin]+=eventWeight;
1153  nSelS_unWeighted[ivar][iBin]++;
1154  }
1155  else {
1156  nSelB[ivar][iBin]+=eventWeight;
1157  nSelB_unWeighted[ivar][iBin]++;
1158  }
1159  if (DoRegression()) {
1160  target[ivar][iBin] +=eventWeight*eventSample[iev]->GetTarget(0);
1161  target2[ivar][iBin]+=eventWeight*eventSample[iev]->GetTarget(0)*eventSample[iev]->GetTarget(0);
1162  }
1163  }
1164  }
1165  }
1166  // now turn the "histogram" into a cumulative distribution
1167  for (UInt_t ivar=0; ivar < cNvars; ivar++) {
1168  if (useVariable[ivar]) {
1169  for (UInt_t ibin=1; ibin < nBins[ivar]; ibin++) {
1170  nSelS[ivar][ibin]+=nSelS[ivar][ibin-1];
1171  nSelS_unWeighted[ivar][ibin]+=nSelS_unWeighted[ivar][ibin-1];
1172  nSelB[ivar][ibin]+=nSelB[ivar][ibin-1];
1173  nSelB_unWeighted[ivar][ibin]+=nSelB_unWeighted[ivar][ibin-1];
1174  if (DoRegression()) {
1175  target[ivar][ibin] +=target[ivar][ibin-1] ;
1176  target2[ivar][ibin]+=target2[ivar][ibin-1];
1177  }
1178  }
1179  if (nSelS_unWeighted[ivar][nBins[ivar]-1] +nSelB_unWeighted[ivar][nBins[ivar]-1] != eventSample.size()) {
1180  Log() << kFATAL << "Helge, you have a bug ....nSelS_unw..+nSelB_unw..= "
1181  << nSelS_unWeighted[ivar][nBins[ivar]-1] +nSelB_unWeighted[ivar][nBins[ivar]-1]
1182  << " while eventsample size = " << eventSample.size()
1183  << Endl;
1184  }
1185  double lastBins=nSelS[ivar][nBins[ivar]-1] +nSelB[ivar][nBins[ivar]-1];
1186  double totalSum=nTotS+nTotB;
1187  if (TMath::Abs(lastBins-totalSum)/totalSum>0.01) {
1188  Log() << kFATAL << "Helge, you have another bug ....nSelS+nSelB= "
1189  << lastBins
1190  << " while total number of events = " << totalSum
1191  << Endl;
1192  }
1193  }
1194  }
1195  // now select the optimal cuts for each variable and find which one gives
1196  // the best separationGain at the current stage
1197  for (UInt_t ivar=0; ivar < cNvars; ivar++) {
1198  if (useVariable[ivar]) {
1199  for (UInt_t iBin=0; iBin<nBins[ivar]-1; iBin++) { // the last bin contains "all events" -->skip
1200  // the separationGain is defined as the various indices (Gini, CrossEntropy, etc.)
1201  // calculated by the "SamplePurities" from the branches that would go to the
1202  // left or the right from this node if "these" cuts were used in the Node:
1203  // hereby: nSelS and nSelB would go to the right branch
1204  // (nTotS - nSelS) + (nTotB - nSelB) would go to the left branch;
1205 
1206  // only allow splits where both daughter nodes match the specified minimum number;
1207  // for this use the "unweighted" events, as you are interested in statistically
1208  // significant splits, which is determined by the actual number of entries
1209  // for a node, rather than the sum of event weights.
1210 
1211  Double_t sl = nSelS_unWeighted[ivar][iBin];
1212  Double_t bl = nSelB_unWeighted[ivar][iBin];
1213  Double_t s = nTotS_unWeighted;
1214  Double_t b = nTotB_unWeighted;
1215  Double_t slW = nSelS[ivar][iBin];
1216  Double_t blW = nSelB[ivar][iBin];
1217  Double_t sW = nTotS;
1218  Double_t bW = nTotB;
1219  Double_t sr = s-sl;
1220  Double_t br = b-bl;
1221  Double_t srW = sW-slW;
1222  Double_t brW = bW-blW;
1223  // std::cout << "sl="<<sl << " bl="<<bl<<" fMinSize="<<fMinSize << "sr="<<sr << " br="<<br <<std::endl;
1224  if ( ((sl+bl)>=fMinSize && (sr+br)>=fMinSize)
1225  && ((slW+blW)>=fMinSize && (srW+brW)>=fMinSize)
1226  ) {
1227 
1228  if (DoRegression()) {
1229  sepTmp = fRegType->GetSeparationGain(nSelS[ivar][iBin]+nSelB[ivar][iBin],
1230  target[ivar][iBin],target2[ivar][iBin],
1231  nTotS+nTotB,
1232  target[ivar][nBins[ivar]-1],target2[ivar][nBins[ivar]-1]);
1233  } else {
1234  sepTmp = fSepType->GetSeparationGain(nSelS[ivar][iBin], nSelB[ivar][iBin], nTotS, nTotB);
1235  }
1236  if (separationGain[ivar] < sepTmp) {
1237  separationGain[ivar] = sepTmp;
1238  cutIndex[ivar] = iBin;
1239  }
1240  }
1241  }
1242  }
1243  }
1244 
1245 
1246  // now that you have found the best separation cut for each variable, compare the variables
1247  for (UInt_t ivar=0; ivar < cNvars; ivar++) {
1248  if (useVariable[ivar] ) {
1249  if (separationGainTotal < separationGain[ivar]) {
1250  separationGainTotal = separationGain[ivar];
1251  mxVar = ivar;
1252  }
1253  }
1254  }
1255 
1256  if (mxVar >= 0) {
1257  if (DoRegression()) {
1258  node->SetSeparationIndex(fRegType->GetSeparationIndex(nTotS+nTotB,target[0][nBins[mxVar]-1],target2[0][nBins[mxVar]-1]));
1259  node->SetResponse(target[0][nBins[mxVar]-1]/(nTotS+nTotB));
1260  if ( (target2[0][nBins[mxVar]-1]/(nTotS+nTotB) - target[0][nBins[mxVar]-1]/(nTotS+nTotB)*target[0][nBins[mxVar]-1]/(nTotS+nTotB)) < std::numeric_limits<double>::epsilon() ) {
1261  node->SetRMS(0);
1262  }else{
1263  node->SetRMS(TMath::Sqrt(target2[0][nBins[mxVar]-1]/(nTotS+nTotB) - target[0][nBins[mxVar]-1]/(nTotS+nTotB)*target[0][nBins[mxVar]-1]/(nTotS+nTotB)));
1264  }
1265  }
1266  else {
1267  node->SetSeparationIndex(fSepType->GetSeparationIndex(nTotS,nTotB));
1268  if (mxVar >=0){
1269  if (nSelS[mxVar][cutIndex[mxVar]]/nTotS > nSelB[mxVar][cutIndex[mxVar]]/nTotB) cutType=kTRUE;
1270  else cutType=kFALSE;
1271  }
1272  }
1273  node->SetSelector((UInt_t)mxVar);
1274  node->SetCutValue(cutValues[mxVar][cutIndex[mxVar]]);
1275  node->SetCutType(cutType);
1276  node->SetSeparationGain(separationGainTotal);
1277  if (mxVar < (Int_t) fNvars){ // the Fisher cut is actually not used in this node, hence no need to store Fisher components
1278  node->SetNFisherCoeff(0);
1279  fVariableImportance[mxVar] += separationGainTotal*separationGainTotal * (nTotS+nTotB) * (nTotS+nTotB) ;
1280  //for (UInt_t ivar=0; ivar<fNvars; ivar++) fVariableImportance[ivar] += separationGain[ivar]*separationGain[ivar] * (nTotS+nTotB) * (nTotS+nTotB) ;
1281  }else{
1282  // allocate Fisher coefficients (use fNvars+1, and set the non-used ones to zero; this might
1283  // even take less storage space on average than also storing the mapping used otherwise,
1284  // and can always be changed relatively easily)
1285  node->SetNFisherCoeff(fNvars+1);
1286  for (UInt_t ivar=0; ivar<=fNvars; ivar++) {
1287  node->SetFisherCoeff(ivar,fisherCoeff[ivar]);
1288  // take the Fisher-coefficient-weighted estimate as variable importance; don't fill the offset coefficient though :)
1289  if (ivar<fNvars){
1290  fVariableImportance[ivar] += fisherCoeff[ivar]*fisherCoeff[ivar]*separationGainTotal*separationGainTotal * (nTotS+nTotB) * (nTotS+nTotB) ;
1291  }
1292  }
1293  }
1294  }
1295  else {
1296  separationGainTotal = 0;
1297  }
1298 
1299  // if (mxVar > -1) {
1300  // std::cout << "------------------------------------------------------------------"<<std::endl;
1301  // std::cout << "cutting on Var: " << mxVar << " with cutIndex " << cutIndex[mxVar] << " being: " << cutValues[mxVar][cutIndex[mxVar]] << std::endl;
1302  // std::cout << " nSelS = " << nSelS_unWeighted[mxVar][cutIndex[mxVar]] << " nSelB = " << nSelB_unWeighted[mxVar][cutIndex[mxVar]] << " (right) sum:= " << nSelS_unWeighted[mxVar][cutIndex[mxVar]] + nSelB_unWeighted[mxVar][cutIndex[mxVar]] << std::endl;
1303  // std::cout << " nSelS = " << nTotS_unWeighted - nSelS_unWeighted[mxVar][cutIndex[mxVar]] << " nSelB = " << nTotB_unWeighted-nSelB_unWeighted[mxVar][cutIndex[mxVar]] << " (left) sum:= " << nTotS_unWeighted + nTotB_unWeighted - nSelS_unWeighted[mxVar][cutIndex[mxVar]] - nSelB_unWeighted[mxVar][cutIndex[mxVar]] << std::endl;
1304  // std::cout << " nSelS = " << nSelS[mxVar][cutIndex[mxVar]] << " nSelB = " << nSelB[mxVar][cutIndex[mxVar]] << std::endl;
1305  // std::cout << " s/s+b " << nSelS_unWeighted[mxVar][cutIndex[mxVar]]/( nSelS_unWeighted[mxVar][cutIndex[mxVar]] + nSelB_unWeighted[mxVar][cutIndex[mxVar]])
1306  // << " s/s+b " << (nTotS - nSelS_unWeighted[mxVar][cutIndex[mxVar]])/( nTotS-nSelS_unWeighted[mxVar][cutIndex[mxVar]] + nTotB-nSelB_unWeighted[mxVar][cutIndex[mxVar]]) << std::endl;
1307  // std::cout << " nTotS = " << nTotS << " nTotB = " << nTotB << std::endl;
1308  // std::cout << " separationGainTotal " << separationGainTotal << std::endl;
1309  // } else {
1310  // std::cout << "------------------------------------------------------------------"<<std::endl;
1311  // std::cout << " obviously didn't find new mxVar " << mxVar << std::endl;
1312  // }
1313  for (UInt_t i=0; i<cNvars; i++) {
1314  delete [] nSelS[i];
1315  delete [] nSelB[i];
1316  delete [] nSelS_unWeighted[i];
1317  delete [] nSelB_unWeighted[i];
1318  delete [] target[i];
1319  delete [] target2[i];
1320  delete [] cutValues[i];
1321  }
1322  delete [] nSelS;
1323  delete [] nSelB;
1324  delete [] nSelS_unWeighted;
1325  delete [] nSelB_unWeighted;
1326  delete [] target;
1327  delete [] target2;
1328  delete [] cutValues;
1329 
1330  delete [] xmin;
1331  delete [] xmax;
1332 
1333  delete [] useVariable;
1334  delete [] mapVariable;
1335 
1336  delete [] separationGain;
1337  delete [] cutIndex;
1338 
1339  delete [] nBins;
1340 
1341  return separationGainTotal;
1342 
1343 }
1344 
1345 
1346 
1347 ////////////////////////////////////////////////////////////////////////////////
1348 /// calculate the Fisher coefficients for the event sample and the variables used
1349 
1350 std::vector<Double_t> TMVA::DecisionTree::GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher){
1351  std::vector<Double_t> fisherCoeff(fNvars+1);
1352 
1353  // initialization of global matrices and vectors
1354  // average value of each variable for S, B, S+B
1355  TMatrixD* meanMatx = new TMatrixD( nFisherVars, 3 );
1356 
1357  // the covariance 'within class' and 'between class' matrices
1358  TMatrixD* betw = new TMatrixD( nFisherVars, nFisherVars );
1359  TMatrixD* with = new TMatrixD( nFisherVars, nFisherVars );
1360  TMatrixD* cov = new TMatrixD( nFisherVars, nFisherVars );
1361 
1362  //
1363  // compute mean values of variables in each sample, and the overall means
1364  //
1365 
1366  // initialize internal sum-of-weights variables
1367  Double_t sumOfWeightsS = 0;
1368  Double_t sumOfWeightsB = 0;
1369 
1370 
1371  // init vectors
1372  Double_t* sumS = new Double_t[nFisherVars];
1373  Double_t* sumB = new Double_t[nFisherVars];
1374  for (UInt_t ivar=0; ivar<nFisherVars; ivar++) { sumS[ivar] = sumB[ivar] = 0; }
1375 
1376  UInt_t nevents = eventSample.size();
1377  // compute sample means
1378  for (UInt_t ievt=0; ievt<nevents; ievt++) {
1379 
1380  // read the Training Event into "event"
1381  const Event * ev = eventSample[ievt];
1382 
1383  // sum of weights
1384  Double_t weight = ev->GetWeight();
1385  if (ev->GetClass() == fSigClass) sumOfWeightsS += weight;
1386  else sumOfWeightsB += weight;
1387 
1388  Double_t* sum = ev->GetClass() == fSigClass ? sumS : sumB;
1389  for (UInt_t ivar=0; ivar<nFisherVars; ivar++) {
1390  sum[ivar] += ev->GetValue( mapVarInFisher[ivar] )*weight;
1391  }
1392  }
1393  for (UInt_t ivar=0; ivar<nFisherVars; ivar++) {
1394  (*meanMatx)( ivar, 2 ) = sumS[ivar];
1395  (*meanMatx)( ivar, 0 ) = sumS[ivar]/sumOfWeightsS;
1396 
1397  (*meanMatx)( ivar, 2 ) += sumB[ivar];
1398  (*meanMatx)( ivar, 1 ) = sumB[ivar]/sumOfWeightsB;
1399 
1400  // signal + background
1401  (*meanMatx)( ivar, 2 ) /= (sumOfWeightsS + sumOfWeightsB);
1402  }
1403 
1404  delete [] sumS;
1405 
1406  delete [] sumB;
1407 
1408  // the matrix of covariance 'within class' reflects the dispersion of the
1409  // events relative to the center of gravity of their own class
1410 
1411  // assert required
1412 
1413  assert( sumOfWeightsS > 0 && sumOfWeightsB > 0 );
1414 
1415  // product matrices (x-<x>)(y-<y>) where x;y are variables
1416 
1417  const Int_t nFisherVars2 = nFisherVars*nFisherVars;
1418  Double_t *sum2Sig = new Double_t[nFisherVars2];
1419  Double_t *sum2Bgd = new Double_t[nFisherVars2];
1420  Double_t *xval = new Double_t[nFisherVars2];
1421  memset(sum2Sig,0,nFisherVars2*sizeof(Double_t));
1422  memset(sum2Bgd,0,nFisherVars2*sizeof(Double_t));
1423 
1424  // 'within class' covariance
1425  for (UInt_t ievt=0; ievt<nevents; ievt++) {
1426 
1427  // read the Training Event into "event"
1428  // const Event* ev = eventSample[ievt];
1429  const Event* ev = eventSample.at(ievt);
1430 
1431  Double_t weight = ev->GetWeight(); // may ignore events with negative weights
1432 
1433  for (UInt_t x=0; x<nFisherVars; x++) {
1434  xval[x] = ev->GetValue( mapVarInFisher[x] );
1435  }
1436  Int_t k=0;
1437  for (UInt_t x=0; x<nFisherVars; x++) {
1438  for (UInt_t y=0; y<nFisherVars; y++) {
1439  if ( ev->GetClass() == fSigClass ) sum2Sig[k] += ( (xval[x] - (*meanMatx)(x, 0))*(xval[y] - (*meanMatx)(y, 0)) )*weight;
1440  else sum2Bgd[k] += ( (xval[x] - (*meanMatx)(x, 1))*(xval[y] - (*meanMatx)(y, 1)) )*weight;
1441  k++;
1442  }
1443  }
1444  }
1445  Int_t k=0;
1446  for (UInt_t x=0; x<nFisherVars; x++) {
1447  for (UInt_t y=0; y<nFisherVars; y++) {
1448  (*with)(x, y) = sum2Sig[k]/sumOfWeightsS + sum2Bgd[k]/sumOfWeightsB;
1449  k++;
1450  }
1451  }
1452 
1453  delete [] sum2Sig;
1454  delete [] sum2Bgd;
1455  delete [] xval;
1456 
1457 
1458  // the matrix of covariance 'between class' reflects the dispersion of the
1459  // events of a class relative to the global center of gravity of all the classes,
1460  // hence the separation between classes
1461 
1462 
1463  Double_t prodSig, prodBgd;
1464 
1465  for (UInt_t x=0; x<nFisherVars; x++) {
1466  for (UInt_t y=0; y<nFisherVars; y++) {
1467 
1468  prodSig = ( ((*meanMatx)(x, 0) - (*meanMatx)(x, 2))*
1469  ((*meanMatx)(y, 0) - (*meanMatx)(y, 2)) );
1470  prodBgd = ( ((*meanMatx)(x, 1) - (*meanMatx)(x, 2))*
1471  ((*meanMatx)(y, 1) - (*meanMatx)(y, 2)) );
1472 
1473  (*betw)(x, y) = (sumOfWeightsS*prodSig + sumOfWeightsB*prodBgd) / (sumOfWeightsS + sumOfWeightsB);
1474  }
1475  }
1476 
1477 
1478 
1479  // compute full covariance matrix from sum of within and between matrices
1480  for (UInt_t x=0; x<nFisherVars; x++)
1481  for (UInt_t y=0; y<nFisherVars; y++)
1482  (*cov)(x, y) = (*with)(x, y) + (*betw)(x, y);
1483 
1484  // Fisher = Sum { [coeff]*[variables] }
1485  //
1486  // let Xs be the array of the mean values of variables for signal evts
1487  // let Xb be the array of the mean values of variables for backgd evts
1488  // let InvWith be the inverse matrix of the 'within class' covariance matrix
1489  //
1490  // then the array of Fisher coefficients is
1491  // [coeff] = TMath::Sqrt(fNsig*fNbgd)/fNevt * transpose{Xs-Xb} * InvWith
1492  TMatrixD* theMat = with; // Fisher's original
1493  // TMatrixD* theMat = cov; // Mahalanobis
1494 
1495  TMatrixD invCov( *theMat );
1496  if ( TMath::Abs(invCov.Determinant()) < 10E-24 ) {
1497  Log() << kWARNING << "FisherCoeff matrix is almost singular with determinant="
1498  << TMath::Abs(invCov.Determinant())
1499  << " did you use variables that are linear combinations or highly correlated?"
1500  << Endl;
1501  }
1502  if ( TMath::Abs(invCov.Determinant()) < 10E-120 ) {
1503  Log() << kFATAL << "FisherCoeff matrix is singular with determinant="
1504  << TMath::Abs(invCov.Determinant())
1505  << " did you use the variables that are linear combinations?"
1506  << Endl;
1507  }
1508 
1509  invCov.Invert();
1510 
1511  // apply rescaling factor
1512  Double_t xfact = TMath::Sqrt( sumOfWeightsS*sumOfWeightsB ) / (sumOfWeightsS + sumOfWeightsB);
1513 
1514  // compute difference of mean values
1515  std::vector<Double_t> diffMeans( nFisherVars );
1516 
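 // note: fisherCoeff has fNvars+1 entries; slot fNvars is reserved for the
 // offset stored at the end, hence the "<=" in the zeroing loop below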
1517  for (UInt_t ivar=0; ivar<=fNvars; ivar++) fisherCoeff[ivar] = 0;
1518  for (UInt_t ivar=0; ivar<nFisherVars; ivar++) {
1519  for (UInt_t jvar=0; jvar<nFisherVars; jvar++) {
1520  Double_t d = (*meanMatx)(jvar, 0) - (*meanMatx)(jvar, 1);
1521  fisherCoeff[mapVarInFisher[ivar]] += invCov(ivar, jvar)*d;
1522  }
1523 
1524  // rescale
1525  fisherCoeff[mapVarInFisher[ivar]] *= xfact;
1526  }
1527 
1528  // offset correction
1529  Double_t f0 = 0.0;
1530  for (UInt_t ivar=0; ivar<nFisherVars; ivar++){
1531  f0 += fisherCoeff[mapVarInFisher[ivar]]*((*meanMatx)(ivar, 0) + (*meanMatx)(ivar, 1));
1532  }
1533  f0 /= -2.0;
1534 
1535  fisherCoeff[fNvars] = f0; //as we start counting variables from "zero", I store the fisher offset at the END
1536 
1537  return fisherCoeff;
1538 }
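// Illustrative sketch (editorial addition, not part of the original source):
// applying the returned coefficient vector to an event. Slots 0..fNvars-1
// hold the per-variable coefficients (zero for variables outside the Fisher
// subset); slot fNvars holds the offset f0. "ev" stands for any TMVA::Event.
//
//   std::vector<Double_t> coeff = GetFisherCoefficients(evts, nFisherVars, mapVarInFisher);
//   Double_t fisherVal = coeff[fNvars];                 // offset f0
//   for (UInt_t ivar = 0; ivar < fNvars; ivar++)
//      fisherVal += coeff[ivar] * ev->GetValue(ivar);   // linear combination
//   // fisherVal > 0 : signal-like, fisherVal < 0 : background-like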
1539 
1540 ////////////////////////////////////////////////////////////////////////////////
1541 
1542 Double_t TMVA::DecisionTree::TrainNodeFull( const EventConstList & eventSample,
1543  TMVA::DecisionTreeNode *node )
1544 {
1545  // train a node by finding the single optimal cut for a single variable
1546  // that best separates signal and background (maximizes the separation gain)
1547 
1548  Double_t nTotS = 0.0, nTotB = 0.0;
1549  Int_t nTotS_unWeighted = 0, nTotB_unWeighted = 0;
1550 
1551  std::vector<TMVA::BDTEventWrapper> bdtEventSample;
1552 
1553  // List of optimal cuts, separation gains, and cut types (removed background or signal) - one for each variable
1554  std::vector<Double_t> lCutValue( fNvars, 0.0 );
1555  std::vector<Double_t> lSepGain( fNvars, -1.0e6 );
1556  std::vector<Char_t> lCutType( fNvars ); // <----- bool is stored (for performance reasons, std::vector<bool> is not used)
1557  lCutType.assign( fNvars, Char_t(kFALSE) );
1558 
1559  // Initialize (un)weighted counters for signal & background
1560  // Construct a list of event wrappers that point to the original data
1561  for( std::vector<const TMVA::Event*>::const_iterator it = eventSample.begin(); it != eventSample.end(); ++it ) {
1562  if((*it)->GetClass() == fSigClass) { // signal or background event
1563  nTotS += (*it)->GetWeight();
1564  ++nTotS_unWeighted;
1565  }
1566  else {
1567  nTotB += (*it)->GetWeight();
1568  ++nTotB_unWeighted;
1569  }
1570  bdtEventSample.push_back(TMVA::BDTEventWrapper(*it));
1571  }
1572 
1573  std::vector<Char_t> useVariable(fNvars); // <----- bool is stored (for performance reasons, std::vector<bool> is not used)
1574  useVariable.assign( fNvars, Char_t(kTRUE) );
1575 
1576  for (UInt_t ivar=0; ivar < fNvars; ivar++) useVariable[ivar]=Char_t(kFALSE);
1577  if (fRandomisedTree) { // for each node splitting, choose a random subset of variables to cut on
1578  if (fUseNvars == 0 ) { // no number specified ... choose something which hopefully works well
1579  // watch out: this should never happen, as fUseNvars is initialised automatically in MethodBDT already!!!
1580  fUseNvars = UInt_t(TMath::Sqrt(fNvars)+0.6);
1581  }
1582  Int_t nSelectedVars = 0;
1583  while (nSelectedVars < fUseNvars) {
1584  Double_t bla = fMyTrandom->Rndm()*fNvars;
1585  useVariable[Int_t (bla)] = Char_t(kTRUE);
1586  nSelectedVars = 0;
1587  for (UInt_t ivar=0; ivar < fNvars; ivar++) {
1588  if(useVariable[ivar] == Char_t(kTRUE)) nSelectedVars++;
1589  }
1590  }
1591  }
1592  else {
1593  for (UInt_t ivar=0; ivar < fNvars; ivar++) useVariable[ivar] = Char_t(kTRUE);
1594  }
1595 
1596  for( UInt_t ivar = 0; ivar < fNvars; ivar++ ) { // loop over all discriminating variables
1597  if(!useVariable[ivar]) continue; // only optimize using the selected variables
1598  TMVA::BDTEventWrapper::SetVarIndex(ivar); // select the variable to sort by
1599  std::sort( bdtEventSample.begin(),bdtEventSample.end() ); // sort the event data
1600 
1601  Double_t bkgWeightCtr = 0.0, sigWeightCtr = 0.0;
1602  std::vector<TMVA::BDTEventWrapper>::iterator it = bdtEventSample.begin(), it_end = bdtEventSample.end();
1603  for( ; it != it_end; ++it ) {
1604  if((**it)->GetClass() == fSigClass ) // specify signal or background event
1605  sigWeightCtr += (**it)->GetWeight();
1606  else
1607  bkgWeightCtr += (**it)->GetWeight();
1608  // Store the accumulated signal (background) weights
1609  it->SetCumulativeWeight(false,bkgWeightCtr);
1610  it->SetCumulativeWeight(true,sigWeightCtr);
1611  }
1612 
1613  const Double_t fPMin = 1.0e-6;
1614  Bool_t cutType = kFALSE;
1615  Long64_t index = 0;
1616  Double_t separationGain = -1.0, sepTmp = 0.0, cutValue = 0.0, dVal = 0.0, norm = 0.0;
1617  // Locate the optimal cut for this (ivar-th) variable
1618  for( it = bdtEventSample.begin(); it != it_end; ++it ) {
1619  if( index == 0 ) { ++index; continue; }
1620  if( *(*it) == NULL ) {
1621  Log() << kFATAL << "In TrainNodeFull(): have a null event! Where index="
1622  << index << ", and parent node=" << node->GetParent() << Endl;
1623  break;
1624  }
1625  dVal = bdtEventSample[index].GetVal() - bdtEventSample[index-1].GetVal();
1626  norm = TMath::Abs(bdtEventSample[index].GetVal() + bdtEventSample[index-1].GetVal());
1627  // Only allow splits where both daughter nodes have the specified minimum number of events
1628  // Splits are only sensible when the data are ordered (e.g. don't split inside a sequence of 0's)
1629  if( index >= fMinSize && (nTotS_unWeighted + nTotB_unWeighted) - index >= fMinSize && TMath::Abs(dVal/(0.5*norm + 1)) > fPMin ) {
1630  sepTmp = fSepType->GetSeparationGain( it->GetCumulativeWeight(true), it->GetCumulativeWeight(false), sigWeightCtr, bkgWeightCtr );
1631  if( sepTmp > separationGain ) {
1632  separationGain = sepTmp;
1633  cutValue = it->GetVal() - 0.5*dVal;
1634  Double_t nSelS = it->GetCumulativeWeight(true);
1635  Double_t nSelB = it->GetCumulativeWeight(false);
1636  // Indicate whether this cut is improving the node purity by removing background (enhancing signal)
1637  // or by removing signal (enhancing background)
1638  if( nSelS/sigWeightCtr > nSelB/bkgWeightCtr ) cutType = kTRUE;
1639  else cutType = kFALSE;
1640  }
1641  }
1642  ++index;
1643  }
1644  lCutType[ivar] = Char_t(cutType);
1645  lCutValue[ivar] = cutValue;
1646  lSepGain[ivar] = separationGain;
1647  }
1648 
1649  Double_t separationGain = -1.0;
1650  Int_t iVarIndex = -1;
1651  for( UInt_t ivar = 0; ivar < fNvars; ivar++ ) {
1652  if( lSepGain[ivar] > separationGain ) {
1653  iVarIndex = ivar;
1654  separationGain = lSepGain[ivar];
1655  }
1656  }
1657 
1658  if(iVarIndex >= 0) {
1659  node->SetSelector(iVarIndex);
1660  node->SetCutValue(lCutValue[iVarIndex]);
1661  node->SetSeparationGain(lSepGain[iVarIndex]);
1662  node->SetCutType(lCutType[iVarIndex]);
1663  fVariableImportance[iVarIndex] += separationGain*separationGain * (nTotS+nTotB) * (nTotS+nTotB);
1664  }
1665  else {
1666  separationGain = 0.0;
1667  }
1668 
1669  return separationGain;
1670 }
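// Hedged sketch (editorial illustration, not part of the original source):
// the quantity a call like fSepType->GetSeparationGain(nSelS, nSelB, nTotS, nTotB)
// evaluates when the separation criterion is the Gini index p*(1-p); other
// indices (cross entropy, misclassification error) plug into the same
// weighted-impurity decrease:
//
//   Double_t Gini(Double_t s, Double_t b) { Double_t p = s/(s+b); return p*(1. - p); }
//
//   Double_t nSel = nSelS + nSelB;               // weight on one side of the cut
//   Double_t nTot = nTotS + nTotB;               // total weight in the node
//   Double_t gain = Gini(nTotS, nTotB)
//                 - (nSel/nTot)          * Gini(nSelS, nSelB)
//                 - ((nTot - nSel)/nTot) * Gini(nTotS - nSelS, nTotB - nSelB);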
1671 
1672 ////////////////////////////////////////////////////////////////////////////////
1673 /// get the pointer to the leaf node where a particular event ends up
1674 /// (used in gradient boosting)
1675 
1676 TMVA::DecisionTreeNode* TMVA::DecisionTree::GetEventNode(const TMVA::Event & e) const
1677 {
1678  TMVA::DecisionTreeNode *current = this->GetRoot();
1679  while(current->GetNodeType() == 0) { // intermediate node in a tree
1680  current = (current->GoesRight(e)) ?
1681  (TMVA::DecisionTreeNode*)current->GetRight() :
1682  (TMVA::DecisionTreeNode*)current->GetLeft();
1683  }
1684  return current;
1685 }
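// Illustrative use in gradient boosting (hypothetical caller, not part of
// this file): fetch the leaf a given event falls into and read off the
// regression response stored there.
//
//   TMVA::DecisionTreeNode* leaf = tree->GetEventNode(*ev);
//   Float_t leafValue = leaf->GetResponse();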
1686 
1687 ////////////////////////////////////////////////////////////////////////////////
1688 /// the event e is put into the decision tree (starting at the root node)
1689 /// and the output is the NodeType (signal or background) of the final node (basket)
1690 /// in which the given event ends up, i.e. the result of the classification of
1691 /// the event by this decision tree.
1692 
1693 Double_t TMVA::DecisionTree::CheckEvent( const TMVA::Event * e, Bool_t UseYesNoLeaf ) const
1694 {
1695  TMVA::DecisionTreeNode *current = this->GetRoot();
1696  if (!current){
1697  Log() << kFATAL << "CheckEvent: started with undefined ROOT node" <<Endl;
1698  return 0; // keeps Coverity happy, as it doesn't know that kFATAL causes an exit
1699  }
1700 
1701  while (current->GetNodeType() == 0) { // intermediate node in a (pruned) tree
1702  current = (current->GoesRight(*e)) ?
1703  current->GetRight() :
1704  current->GetLeft();
1705  if (!current) {
1706  Log() << kFATAL << "DT::CheckEvent: inconsistent tree structure" <<Endl;
1707  }
1708 
1709  }
1710 
1711  if ( DoRegression() ){
1712  return current->GetResponse();
1713  }
1714  else {
1715  if (UseYesNoLeaf) return Double_t ( current->GetNodeType() );
1716  else return current->GetPurity();
1717  }
1718 }
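// Illustrative calls (hypothetical caller, not part of this file):
//
//   Double_t p = tree->CheckEvent(ev);          // default: purity S/(S+B) of the leaf
//   Double_t y = tree->CheckEvent(ev, kTRUE);   // discrete leaf NodeType instead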
1719 
1720 ////////////////////////////////////////////////////////////////////////////////
1721 /// calculates the purity S/(S+B) of a given event sample
1722 
1723 Double_t TMVA::DecisionTree::SamplePurity( std::vector<TMVA::Event*> eventSample )
1724 {
1725  Double_t sumsig=0, sumbkg=0, sumtot=0;
1726  for (UInt_t ievt=0; ievt<eventSample.size(); ievt++) {
1727  if (eventSample[ievt]->GetClass() != fSigClass) sumbkg+=eventSample[ievt]->GetWeight();
1728  else sumsig+=eventSample[ievt]->GetWeight();
1729  sumtot+=eventSample[ievt]->GetWeight();
1730  }
1731  // sanity check
1732  if (sumtot!= (sumsig+sumbkg)){
1733  Log() << kFATAL << "<SamplePurity> sumtot != sumsig+sumbkg"
1734  << sumtot << " " << sumsig << " " << sumbkg << Endl;
1735  }
1736  if (sumtot>0) return sumsig/(sumsig + sumbkg);
1737  else return -1;
1738 }
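// Worked example (illustrative): for weighted sums sumsig = 30 and
// sumbkg = 70 the method returns 30/(30+70) = 0.3; for an empty sample
// (sumtot == 0) it returns the sentinel value -1.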
1739 
1740 ////////////////////////////////////////////////////////////////////////////////
1741 /// Return the relative variable importance, normalized so that all
1742 /// variables together have importance 1. The importance is
1743 /// evaluated as the total separation-gain that this variable achieved in
1744 /// the decision tree (weighted by the number of events)
1745 
1746 std::vector< Double_t > TMVA::DecisionTree::GetVariableImportance()
1747 {
1748  std::vector<Double_t> relativeImportance(fNvars);
1749  Double_t sum=0;
1750  for (UInt_t i=0; i< fNvars; i++) {
1751  sum += fVariableImportance[i];
1752  relativeImportance[i] = fVariableImportance[i];
1753  }
1754 
1755  for (UInt_t i=0; i< fNvars; i++) {
1756  if (sum > std::numeric_limits<double>::min())
1757  relativeImportance[i] /= sum;
1758  else
1759  relativeImportance[i] = 0;
1760  }
1761  return relativeImportance;
1762 }
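// Illustrative dump of the normalized importances (hypothetical caller,
// not part of this file):
//
//   std::vector<Double_t> vi = tree->GetVariableImportance();
//   for (UInt_t i = 0; i < vi.size(); i++)
//      std::cout << "var " << i << " : " << vi[i] << std::endl;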
1763 
1764 ////////////////////////////////////////////////////////////////////////////////
1765 /// returns the relative importance of variable ivar
1766 
1767 Double_t TMVA::DecisionTree::GetVariableImportance( UInt_t ivar )
1768 {
1769  std::vector<Double_t> relativeImportance = this->GetVariableImportance();
1770  if (ivar < fNvars) return relativeImportance[ivar];
1771  else {
1772  Log() << kFATAL << "<GetVariableImportance>" << Endl
1773  << "--- ivar = " << ivar << " is out of range " << Endl;
1774  }
1775 
1776  return -1;
1777 }
1778 