ROOT  6.06/09
Reference Guide
BinarySearchTree.cxx
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : BinarySearchTree *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Implementation (see header file for description) *
12  * *
13  * Authors (alphabetical): *
14  * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15  * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
16  * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
17  * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
18  * *
19  * Copyright (c) 2005: *
20  * CERN, Switzerland *
21  * U. of Victoria, Canada *
22  * MPI-K Heidelberg, Germany *
23  * LAPP, Annecy, France *
24  * *
25  * Redistribution and use in source and binary forms, with or without *
26  * modification, are permitted according to the terms listed in LICENSE *
27  * (http://tmva.sourceforge.net/LICENSE) *
28  * *
29  **********************************************************************************/
30 
31 //////////////////////////////////////////////////////////////////////////
32 // //
33 // BinarySearchTree //
34 // //
35 // A simple Binary search tree including a volume search method //
36 // //
37 //////////////////////////////////////////////////////////////////////////
38 
39 #include <stdexcept>
40 #include <cstdlib>
41 #include <queue>
42 #include <algorithm>
43 
44 // #if ROOT_VERSION_CODE >= 364802
45 // #ifndef ROOT_TMathBase
46 // #include "TMathBase.h"
47 // #endif
48 // #else
49 // #ifndef ROOT_TMath
50 #include "TMath.h"
51 // #endif
52 // #endif
53 
54 #include "TMatrixDBase.h"
55 #include "TObjString.h"
56 #include "TTree.h"
57 
58 #ifndef ROOT_TMVA_MsgLogger
59 #include "TMVA/MsgLogger.h"
60 #endif
61 #ifndef ROOT_TMVA_MethodBase
62 #include "TMVA/MethodBase.h"
63 #endif
64 #ifndef ROOT_TMVA_Tools
65 #include "TMVA/Tools.h"
66 #endif
67 #ifndef ROOT_TMVA_Event
68 #include "TMVA/Event.h"
69 #endif
70 #ifndef ROOT_TMVA_BinarySearchTree
71 #include "TMVA/BinarySearchTree.h"
72 #endif
73 
75 
76 ////////////////////////////////////////////////////////////////////////////////
77 /// Default constructor.
   /// Builds an empty search tree: period 1, depth counter 0, statistics
   /// flagged as invalid, zero sum of weights, normalisation bookkeeping
   /// switched off, and both per-class weight counters set to zero.
78 
79 TMVA::BinarySearchTree::BinarySearchTree( void ) :
80  BinaryTree(),
81  fPeriod ( 1 ),
82  fCurrentDepth( 0 ),
83  fStatisticsIsValid( kFALSE ),
84  fSumOfWeights( 0 ),
85  fCanNormalize( kFALSE )
86 {
87  fNEventsW[0]=fNEventsW[1]=0.; // index 0 / 1: the two event classes tracked by the statistics
88 }
89 
90 ////////////////////////////////////////////////////////////////////////////////
91 /// Copy constructor -- NOT implemented.
   /// Only fPeriod and fSumOfWeights are taken over from the argument b; no
   /// node is copied and the body reports a kFATAL message through Log().
   /// (The signature line is not part of this listing.)
92 
94  : BinaryTree(),
95  fPeriod ( b.fPeriod ),
96  fCurrentDepth( 0 ),
97  fStatisticsIsValid( kFALSE ),
98  fSumOfWeights( b.fSumOfWeights ),
99  fCanNormalize( kFALSE )
100 {
101  fNEventsW[0]=fNEventsW[1]=0.;
     // a real deep copy of the node structure does not exist yet
102  Log() << kFATAL << " Copy constructor not implemented yet " << Endl;
103 }
104 
105 ////////////////////////////////////////////////////////////////////////////////
106 /// Destructor.
    /// Deletes every Event referenced from fNormalizeTreeTable. Those entries
    /// are the private copies created with "new const Event(*event)" in
    /// Insert() while normalisation bookkeeping is enabled.
    /// NOTE(review): the tree nodes are not released here -- presumably the
    /// BinaryTree base class takes care of them; confirm.
107 
109 {
110  for(std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator pIt = fNormalizeTreeTable.begin();
111  pIt != fNormalizeTreeTable.end(); pIt++) {
112  delete pIt->second;
113  }
114 }
115 
116 ////////////////////////////////////////////////////////////////////////////////
117 /// Re-create a tree (decision tree or search tree) from an XML node.
    /// Reads the "type" attribute of the node, lets the tree object bt read
    /// itself via ReadXML(), and returns bt.
    /// NOTE(review): the signature and the line that creates bt are missing
    /// from this listing; the value read into "type" is not used in the lines
    /// shown here.
118 
120  std::string type("");
121  gTools().ReadAttr(node,"type", type);
123  bt->ReadXML( node, tmva_Version_Code );
124  return bt;
125 }
126 
127 ////////////////////////////////////////////////////////////////////////////////
128 /// Insert a new "event" into the binary tree.
    /// An empty tree gets the event as its root node (position 's', depth 0,
    /// selector 0) and adopts the event's number of variables as period.
    /// Otherwise the event is handed to the recursive Insert(event, node),
    /// starting at the root. While fCanNormalize is set, a heap copy of the
    /// event is additionally stored in fNormalizeTreeTable.
129 
131 {
132  fCurrentDepth=0; // recursion depth counter, advanced by the recursive Insert
133  fStatisticsIsValid = kFALSE; // cached statistics no longer describe the tree
134 
135  if (this->GetRoot() == NULL) { // If the tree is empty...
136  this->SetRoot( new BinarySearchTreeNode(event)); //Make the new node the root.
137  // have to use "s" for start as "r" for "root" would be the same as "r" for "right"
138  this->GetRoot()->SetPos('s');
139  this->GetRoot()->SetDepth(0);
140  fNNodes = 1;
141  fSumOfWeights = event->GetWeight();
142  ((BinarySearchTreeNode*)this->GetRoot())->SetSelector((UInt_t)0);
143  this->SetPeriode(event->GetNVariables());
144  }
145  else {
146  // sanity check: every event must have as many variables as the tree period
147  if (event->GetNVariables() != (UInt_t)this->GetPeriode()) {
148  Log() << kFATAL << "<Insert> event vector length != Periode specified in Binary Tree" << Endl
149  << "--- event size: " << event->GetNVariables() << " Periode: " << this->GetPeriode() << Endl
150  << "--- and all this when trying filling the "<<fNNodes+1<<"th Node" << Endl;
151  }
152  // insert a new node at the proper position
153  this->Insert(event, this->GetRoot());
154  }
155 
156  // keep a private copy of the event so that NormalizeTree() can rebuild the tree later;
     // the first pair member (0.0 here) is overwritten with the sort key in NormalizeTree
157  if (fCanNormalize) fNormalizeTreeTable.push_back( std::make_pair(0.0,new const Event(*event)) );
158 }
159 
160 ////////////////////////////////////////////////////////////////////////////////
161 /// Private internal function that inserts an event (node) at the proper position.
    /// Descends from "node": follows the left child when node->GoesLeft(*event),
    /// the right child when node->GoesRight(*event), and attaches a new
    /// BinarySearchTreeNode where the respective child is missing. For each new
    /// node fNNodes and fSumOfWeights are updated and the selector is set to
    /// fCurrentDepth modulo the event's number of variables.
    /// (The first line of the signature is not part of this listing.)
162 
164  Node *node )
165 {
166  fCurrentDepth++; // one level deeper than the caller; reset to 0 by the public Insert
167  fStatisticsIsValid = kFALSE;
168 
169  if (node->GoesLeft(*event)){ // If the current node sends this event to the left...
170  if (node->GetLeft() != NULL){ // If there is a left node...
171  // Add the new event to the left node
172  this->Insert(event, node->GetLeft());
173  }
174  else { // If there is not a left node...
175  // Make the new node for the new event
176  BinarySearchTreeNode* current = new BinarySearchTreeNode(event);
177  fNNodes++;
178  fSumOfWeights += event->GetWeight();
179  current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
180  current->SetParent(node); // Set the new node's previous node.
181  current->SetPos('l');
182  current->SetDepth( node->GetDepth() + 1 );
183  node->SetLeft(current); // Make it the left node of the current one.
184  }
185  }
186  else if (node->GoesRight(*event)) { // If the current node sends this event to the right...
187  if (node->GetRight() != NULL) { // If there is a right node...
188  // Add the new node to it.
189  this->Insert(event, node->GetRight());
190  }
191  else { // If there is not a right node...
192  // Make the new node.
193  BinarySearchTreeNode* current = new BinarySearchTreeNode(event);
194  fNNodes++;
195  fSumOfWeights += event->GetWeight();
196  current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
197  current->SetParent(node); // Set the new node's previous node.
198  current->SetPos('r');
199  current->SetDepth( node->GetDepth() + 1 );
200  node->SetRight(current); // Make it the right node of the current one.
201  }
202  }
     // reached only if the node claims the event goes neither left nor right
203  else Log() << kFATAL << "<Insert> neither left nor right :)" << Endl;
204 }
205 
206 ////////////////////////////////////////////////////////////////////////////////
207 /// Search the tree for the node matching "event".
    /// Delegates to the recursive Search(event, node), starting at the root,
    /// and returns whatever that search returns (NULL when nothing matches).
208 
210 {
211  return this->Search( event, this->GetRoot() );
212 }
213 
214 ////////////////////////////////////////////////////////////////////////////////
215 /// Private, recursive, function for searching.
    /// Returns "node" if its EqualsMe(*event) test succeeds; otherwise continues
    /// in the left subtree when node->GoesLeft(*event) and in the right subtree
    /// in every other case. Returns NULL once a missing child is reached.
216 
218 {
219  if (node != NULL) { // If the node is not NULL...
220  // If we have found the node...
221  if (((BinarySearchTreeNode*)(node))->EqualsMe(*event))
222  return (BinarySearchTreeNode*)node; // Return it
223  if (node->GoesLeft(*event)) // If the node routes the event to the left...
224  return this->Search(event, node->GetLeft()); //Search the left node.
225  else // Otherwise...
226  return this->Search(event, node->GetRight()); //Search the right node.
227  }
228  else return NULL; //If the node is NULL, return NULL.
229 }
230 
231 ////////////////////////////////////////////////////////////////////////////////
232 /// Return the sum of event (node) weights, fSumOfWeights.
    /// A non-positive sum triggers a kWARNING followed by a kFATAL message.
    /// NOTE(review): the warning text announces a call to CalcStatistics, but
    /// no such call is made in this function; the second check therefore
    /// always fires together with the first.
233 
235 {
236  if (fSumOfWeights <= 0) {
237  Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
238  << " I call CalcStatistics which hopefully fixes things"
239  << Endl;
240  }
241  if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
242 
243  return fSumOfWeights;
244 }
245 
246 ////////////////////////////////////////////////////////////////////////////////
247 /// Return the accumulated weight of one event class: fNEventsW[0] when
    /// theType equals Types::kSignal, fNEventsW[1] for any other value.
    /// A non-positive total fSumOfWeights triggers a kWARNING followed by a
    /// kFATAL message.
    /// NOTE(review): the warning text announces a call to CalcStatistics, but
    /// no such call is made in this function.
248 
250 {
251  if (fSumOfWeights <= 0) {
252  Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
253  << " I call CalcStatistics which hopefully fixes things"
254  << Endl;
255  }
256  if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
257 
258  return fNEventsW[ ( theType == Types::kSignal) ? 0 : 1 ];
259 }
260 
261 ////////////////////////////////////////////////////////////////////////////////
262 /// Create the search tree from the event collection.
263 /// Sets the tree period to the number of entries in "theVars" and then
    /// forwards to Fill(events, theType); returns that call's result.
    /// NOTE(review): only the size of "theVars" is used here -- the variable
    /// indices themselves are not consulted in this function.
264 
265 Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, const std::vector<Int_t>& theVars,
266  Int_t theType )
267 {
268  fPeriod = theVars.size();
269  return Fill(events, theType);
270 }
271 
272 ////////////////////////////////////////////////////////////////////////////////
273 /// Create the search tree from a vector of events.
274 /// Inserts every event whose class equals "theType" (or every event when
    /// theType is -1), recomputes the statistics with CalcStatistics(0) and
    /// returns fSumOfWeights. A kWARNING is issued when the tree already
    /// carries a non-zero sum of weights.
275 
276 Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, Int_t theType )
277 {
278  UInt_t n=events.size();
279 
280  UInt_t nevents = 0; // NOTE(review): counted below but never read
281  if (fSumOfWeights != 0) {
282  Log() << kWARNING
283  << "You are filling a search three that is not empty.. "
284  << " do you know what you are doing?"
285  << Endl;
286  }
287  for (UInt_t ievt=0; ievt<n; ievt++) {
288  // insert event into binary tree
289  if (theType == -1 || (Int_t(events[ievt]->GetClass()) == theType) ) {
290  this->Insert( events[ievt] );
291  nevents++;
     // NOTE(review): Insert() already adds this weight to fSumOfWeights; the
     // extra addition is wiped out when CalcStatistics(0) below re-accumulates
     // the sum from the nodes (Insert has invalidated the statistics).
292  fSumOfWeights += events[ievt]->GetWeight();
293  }
294  } // end of event loop
295  CalcStatistics(0);
296 
297  return fSumOfWeights;
298 }
299 
300 ////////////////////////////////////////////////////////////////////////////////
301 
302 void TMVA::BinarySearchTree::NormalizeTree ( std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftBound,
303  std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightBound,
304  UInt_t actDim )
305 {
306  // normalises the binary-search tree to reduce the branch length and hence speed up the
307  // search procedure (on average)
     //
     // Works on the half-open range [leftBound, rightBound) of (key, event) pairs:
     //  1. the key of every pair is set to the event's value in dimension actDim
     //     (actDim wraps around to 0 when it reaches fPeriod),
     //  2. the range is sorted (pairs compare by key first),
     //  3. the middle element is located and moved back to the first element of a
     //     run of equal values, so that equal values all end up on the same side,
     //  4. that event is inserted, and both remaining sub-ranges are processed
     //     recursively with the next dimension.
     // An empty range returns immediately.
308  if (leftBound == rightBound) return;
309 
310  if (actDim == fPeriod) actDim = 0;
311  for (std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator i=leftBound; i!=rightBound; i++) {
312  i->first = i->second->GetValue( actDim );
313  }
314 
315  std::sort( leftBound, rightBound );
316 
317  std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftTemp = leftBound;
318  std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightTemp = rightBound;
319 
320  // meet in the middle: step both ends inwards alternately until they coincide
321  while (true) {
322  rightTemp--;
323  if (rightTemp == leftTemp ) {
324  break;
325  }
326  leftTemp++;
327  if (leftTemp == rightTemp) {
328  break;
329  }
330  }
331 
332  std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator mid = leftTemp;
333  std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator midTemp = mid;
334 
335  if (mid!=leftBound) midTemp--; // midTemp trails mid by one element
336 
     // walk mid back to the first element of a run of equal values
337  while (mid != leftBound && mid->second->GetValue( actDim ) == midTemp->second->GetValue( actDim )) {
338  mid--;
     // never step in front of leftBound: on the top-level call leftBound is
     // begin(), and decrementing a begin() iterator is undefined behaviour
339  if (midTemp != leftBound) midTemp--;
340  }
341 
342  Insert( mid->second );
343 
344  // Print(std::cout);
345  // std::cout << std::endl << std::endl;
346 
347  NormalizeTree( leftBound, mid, actDim+1 );
348  mid++; // skip the element that has just been inserted
349  // Print(std::cout);
350  // std::cout << std::endl << std::endl;
351  NormalizeTree( mid, rightBound, actDim+1 );
352 
353 
354  return;
355 }
356 
357 ////////////////////////////////////////////////////////////////////////////////
358 /// Normalisation of tree.
    /// Switches the normalisation bookkeeping off (so the re-insertions below
    /// do not append to fNormalizeTreeTable again), clears the current nodes,
    /// resets the root, and rebuilds the tree from the events stored in
    /// fNormalizeTreeTable, starting with dimension 0.
    /// NOTE(review): Clear(NULL) dereferences the root node; calling this on a
    /// tree without a root looks unsafe -- confirm.
359 
361 {
362  SetNormalize( kFALSE );
363  Clear( NULL );
364  this->SetRoot(NULL);
365  NormalizeTree( fNormalizeTreeTable.begin(), fNormalizeTreeTable.end(), 0 );
366 }
367 
368 ////////////////////////////////////////////////////////////////////////////////
369 /// Clear nodes.
    /// Recursively clears the left and right subtrees below "n" and then deletes
    /// "n" itself. With n == NULL the walk starts at the root; in that case the
    /// root node object itself is not deleted by this function.
    /// NOTE(review): with n == NULL and no root node, currentNode is NULL and
    /// is dereferenced below -- confirm callers never do this on an empty tree.
370 
372 {
373  BinarySearchTreeNode* currentNode = (BinarySearchTreeNode*)(n == NULL ? this->GetRoot() : n);
374 
375  if (currentNode->GetLeft() != 0) Clear( currentNode->GetLeft() );
376  if (currentNode->GetRight() != 0) Clear( currentNode->GetRight() );
377 
378  if (n != NULL) delete n;
379 
380  return;
381 }
382 
383 ////////////////////////////////////////////////////////////////////////////////
384 /// Search the whole tree and add up all weights of events that
385 /// lie within the given volume.
    /// Starts the recursive SearchVolume at the root with depth 0 and returns
    /// its result; "events" is passed through unchanged.
    /// (The first line of the signature is not part of this listing.)
386 
388  std::vector<const BinarySearchTreeNode*>* events )
389 {
390  return SearchVolume( this->GetRoot(), volume, 0, events );
391 }
392 
393 ////////////////////////////////////////////////////////////////////////////////
394 /// Recursively walk through the daughter nodes and add up all weights of events that
395 /// lie within the given volume.
    /// The weight of the current node is counted when InVolume() accepts its
    /// event vector; the node is also appended to "events" when that pointer
    /// is non-NULL. The coordinate d = depth modulo the tree period decides
    /// which subtrees can still contain matching events.
    /// NOTE(review): the first signature line and the declaration of "st" are
    /// missing from this listing.
396 
398  std::vector<const BinarySearchTreeNode*>* events )
399 {
400  if (t==NULL) return 0; // nothing below a missing child
401 
403 
404  Double_t count = 0.0;
405  if (InVolume( st->GetEventV(), volume )) {
406  count += st->GetWeight();
407  if (NULL != events) events->push_back( st );
408  }
409  if (st->GetLeft()==NULL && st->GetRight()==NULL) {
410 
411  return count; // Are we at an outer leaf?
412  }
413 
414  Bool_t tl, tr;
415  Int_t d = depth%this->GetPeriode();
     // consistency check: the coordinate derived from the depth must match the node's selector
416  if (d != st->GetSelector()) {
417  Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
418  << d << " != " << "node "<< st->GetSelector() << Endl;
419  }
420  tl = (*(volume->fLower))[d] < st->GetEventV()[d]; // Should we descend left?
421  tr = (*(volume->fUpper))[d] >= st->GetEventV()[d]; // Should we descend right?
422 
423  if (tl) count += SearchVolume( st->GetLeft(), volume, (depth+1), events );
424  if (tr) count += SearchVolume( st->GetRight(), volume, (depth+1), events );
425 
426  return count;
427 }
428 
429 Bool_t TMVA::BinarySearchTree::InVolume(const std::vector<Float_t>& event, Volume* volume ) const
430 {
431  // test if the data points are in the given volume
     //
     // Returns true only if, for each of the first fPeriod coordinates,
     //    lower bound < value <= upper bound
     // holds (lower edge exclusive, upper edge inclusive). The loop stops at the
     // first coordinate that fails. With fPeriod == 0 the result is false.
432 
433  Bool_t result = false;
434  for (UInt_t ivar=0; ivar< fPeriod; ivar++) {
435  result = ( (*(volume->fLower))[ivar] < event[ivar] &&
436  (*(volume->fUpper))[ivar] >= event[ivar] );
437  if (!result) break;
438  }
439  return result;
440 }
441 
442 ////////////////////////////////////////////////////////////////////////////////
443 /// Calculate basic statistics (mean, rms for each variable).
    /// Does nothing while fStatisticsIsValid is set. Called with n == NULL it
    /// resets all accumulators, walks the whole tree starting at the root and
    /// finally converts the weighted sums into means and RMS values, separately
    /// for the two class slots (index 0: class equals Types::kSignal, index 1:
    /// everything else). Per-variable minima and maxima are tracked as well.
    /// Recursive calls pass a non-NULL node and only accumulate.
    /// (The signature line is not part of this listing.)
444 
446 {
447  if (fStatisticsIsValid) return;
448 
449  BinarySearchTreeNode * currentNode = (BinarySearchTreeNode*)n;
450 
451  // default, start at the tree top, then descend recursively
452  if (n == NULL) {
453  fSumOfWeights = 0;
454  for (Int_t sb=0; sb<2; sb++) {
455  fNEventsW[sb] = 0;
456  fMeans[sb] = std::vector<Float_t>(fPeriod);
457  fRMS[sb] = std::vector<Float_t>(fPeriod);
458  fMin[sb] = std::vector<Float_t>(fPeriod);
459  fMax[sb] = std::vector<Float_t>(fPeriod);
460  fSum[sb] = std::vector<Double_t>(fPeriod);
461  fSumSq[sb] = std::vector<Double_t>(fPeriod);
462  for (UInt_t j=0; j<fPeriod; j++) {
463  fMeans[sb][j] = fRMS[sb][j] = fSum[sb][j] = fSumSq[sb][j] = 0;
     // start min/max at the extremes so that the first value always replaces them
464  fMin[sb][j] = FLT_MAX;
465  fMax[sb][j] = -FLT_MAX;
466  }
467  }
468  currentNode = (BinarySearchTreeNode*) this->GetRoot();
469  if (currentNode == NULL) return; // no root-node
470  }
471 
472  const std::vector<Float_t> & evtVec = currentNode->GetEventV();
473  Double_t weight = currentNode->GetWeight();
474 // Int_t type = currentNode->IsSignal();
475 // Int_t type = currentNode->IsSignal() ? 0 : 1;
476  Int_t type = Int_t(currentNode->GetClass())== Types::kSignal ? 0 : 1;
477 
478  fNEventsW[type] += weight;
479  fSumOfWeights += weight;
480 
     // accumulate weighted first and second moments and the value range
481  for (UInt_t j=0; j<fPeriod; j++) {
482  Float_t val = evtVec[j];
483  fSum[type][j] += val*weight;
484  fSumSq[type][j] += val*val*weight;
485  if (val < fMin[type][j]) fMin[type][j] = val;
486  if (val > fMax[type][j]) fMax[type][j] = val;
487  }
488 
489  if ( (currentNode->GetLeft() != NULL) ) CalcStatistics( currentNode->GetLeft() );
490  if ( (currentNode->GetRight() != NULL) ) CalcStatistics( currentNode->GetRight() );
491 
492  if (n == NULL) { // i.e. the root node
493  for (Int_t sb=0; sb<2; sb++) {
494  for (UInt_t j=0; j<fPeriod; j++) {
     // a class slot without any weight gets mean = rms = 0 instead of a division by zero
495  if (fNEventsW[sb] == 0) { fMeans[sb][j] = fRMS[sb][j] = 0; continue; }
496  fMeans[sb][j] = fSum[sb][j]/fNEventsW[sb];
     // rms^2 = <x^2> - <x>^2
     // NOTE(review): rounding could make the argument slightly negative -- confirm that is acceptable
497  fRMS[sb][j] = TMath::Sqrt(fSumSq[sb][j]/fNEventsW[sb] - fMeans[sb][j]*fMeans[sb][j]);
498  }
499  }
500  fStatisticsIsValid = kTRUE;
501  }
502 
503  return;
504 }
505 
506 Int_t TMVA::BinarySearchTree::SearchVolumeWithMaxLimit( Volume *volume, std::vector<const BinarySearchTreeNode*>* events,
507  Int_t max_points )
508 {
509  // walk through the tree breadth-first (using a queue) and count the events that
510  // lie within the given volume; a maximum number of events can be given
     //
     // Returns the number of matching events (an unweighted count). Matching nodes
     // are appended to "events" when that pointer is non-NULL. The walk stops as
     // soon as the count equals max_points; a negative max_points never matches
     // the count and therefore imposes no limit.
511  if (this->GetRoot() == NULL) return 0; // empty tree: nothing to count
512 
     // each queue entry pairs a node with the coordinate index used at that node
513  std::queue< std::pair< const BinarySearchTreeNode*, Int_t > > queue;
514  std::pair< const BinarySearchTreeNode*, Int_t > st = std::make_pair( (const BinarySearchTreeNode*)this->GetRoot(), 0 );
515  queue.push( st );
516 
517  Int_t count = 0;
518 
519  while ( !queue.empty() ) {
520  st = queue.front(); queue.pop();
521 
522  if (count == max_points)
523  return count;
524 
525  if (InVolume( st.first->GetEventV(), volume )) {
526  count++;
527  if (NULL != events) events->push_back( st.first );
528  }
529 
530  Bool_t tl, tr;
531  Int_t d = st.second;
532  if ( d == Int_t(this->GetPeriode()) ) d = 0; // wrap the coordinate index around
533 
     // consistency check: the tracked coordinate must match the node's selector
534  if (d != st.first->GetSelector()) {
535  Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
536  << d << " != " << "node "<< st.first->GetSelector() << Endl;
537  }
538 
539  tl = (*(volume->fLower))[d] < st.first->GetEventV()[d] && st.first->GetLeft() != NULL; // Should we descend left?
540  tr = (*(volume->fUpper))[d] >= st.first->GetEventV()[d] && st.first->GetRight() != NULL; // Should we descend right?
541 
542  if (tl) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetLeft(), d+1 ) );
543  if (tr) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetRight(), d+1 ) );
544  }
545 
546  return count;
547 }
std::vector< Double_t > * fLower
Definition: Volume.h:75
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:162
virtual ~BinarySearchTree(void)
destructor
Bool_t InVolume(const std::vector< Float_t > &, Volume *) const
float Float_t
Definition: RtypesCore.h:53
std::vector< Double_t > * fUpper
Definition: Volume.h:76
Int_t SearchVolumeWithMaxLimit(TMVA::Volume *, std::vector< const TMVA::BinarySearchTreeNode * > *events=0, Int_t=-1)
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
virtual void SetRight(Node *r)
Definition: Node.h:97
const Bool_t kFALSE
Definition: Rtypes.h:92
void NormalizeTree()
Normalisation of tree.
void SetDepth(UInt_t d)
Definition: Node.h:115
TClass * GetClass(T *)
Definition: TClass.h:555
Tools & gTools()
Definition: Tools.cxx:79
UInt_t GetDepth() const
Definition: Node.h:118
virtual Bool_t GoesLeft(const Event &) const =0
virtual void SetLeft(Node *l)
Definition: Node.h:96
void Clear(TMVA::Node *n=0)
clear nodes
ClassImp(TMVA::BinarySearchTree) TMVA
default constructor
Double_t SearchVolume(Volume *, std::vector< const TMVA::BinarySearchTreeNode * > *events=0)
search the whole tree and add up all weights of events that lie within the given volume ...
Double_t GetSumOfWeights(void) const
return the sum of event (node) weights
void CalcStatistics(TMVA::Node *n=0)
calculate basic statistics (mean, rms for each variable)
UInt_t GetNVariables() const
accessor to the number of variables
Definition: Event.cxx:303
static BinarySearchTree * CreateFromXML(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
re-create a new tree (decision tree or search tree) from XML
unsigned int UInt_t
Definition: RtypesCore.h:42
void ReadAttr(void *node, const char *, T &value)
Definition: Tools.h:295
void Insert(const Event *)
insert a new "event" in the binary tree
virtual void ReadXML(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
read attributes from XML
Definition: BinaryTree.cxx:141
virtual void SetParent(Node *p)
Definition: Node.h:98
double Double_t
Definition: RtypesCore.h:55
int type
Definition: TGX11.cxx:120
virtual Node * GetRight() const
Definition: Node.h:92
Abstract ClassifierFactory template that handles arbitrary types.
virtual Bool_t GoesRight(const Event &) const =0
Double_t Fill(const std::vector< TMVA::Event * > &events, const std::vector< Int_t > &theVars, Int_t theType=-1)
create the search tree from the event collection using ONLY the variables specified in "theVars" ...
const std::vector< Float_t > & GetEventV() const
#define NULL
Definition: Rtypes.h:82
double result[121]
Double_t Sqrt(Double_t x)
Definition: TMath.h:464
MsgLogger & Log() const
Definition: BinaryTree.cxx:232
const Bool_t kTRUE
Definition: Rtypes.h:91
BinarySearchTreeNode * Search(Event *event) const
search the tree to find the node matching "event"
const Int_t n
Definition: legend1.C:16
Definition: math.cpp:60
void SetPos(char s)
Definition: Node.h:121
virtual Node * GetLeft() const
Definition: Node.h:91