Logo ROOT  
Reference Guide
BinarySearchTree.cxx
Go to the documentation of this file.
// @(#)root/tmva $Id$
// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
 * Package: TMVA *
 * Class : BinarySearchTree *
 * Web : http://tmva.sourceforge.net *
 * *
 * Description: *
 * Implementation (see header file for description) *
 * *
 * Authors (alphabetical): *
 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
 * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
 * *
 * Copyright (c) 2005: *
 * CERN, Switzerland *
 * U. of Victoria, Canada *
 * MPI-K Heidelberg, Germany *
 * LAPP, Annecy, France *
 * *
 * Redistribution and use in source and binary forms, with or without *
 * modification, are permitted according to the terms listed in LICENSE *
 * (http://tmva.sourceforge.net/LICENSE) *
 * *
 **********************************************************************************/

/*! \class TMVA::BinarySearchTree
\ingroup TMVA

A simple binary search tree including a volume search method.

*/
37
38#include <stdexcept>
39#include <cstdlib>
40#include <queue>
41#include <algorithm>
42
43#include "TMath.h"
44
45#include "TMatrixDBase.h"
46#include "TTree.h"
47
48#include "TMVA/MsgLogger.h"
49#include "TMVA/MethodBase.h"
50#include "TMVA/Tools.h"
51#include "TMVA/Event.h"
53
54#include "TMVA/BinaryTree.h"
55#include "TMVA/Types.h"
56#include "TMVA/Node.h"
57
59
60////////////////////////////////////////////////////////////////////////////////
61/// default constructor
62
65 fPeriod ( 1 ),
66 fCurrentDepth( 0 ),
67 fStatisticsIsValid( kFALSE ),
68 fSumOfWeights( 0 ),
69 fCanNormalize( kFALSE )
70{
71 fNEventsW[0]=fNEventsW[1]=0.;
72}
73
74////////////////////////////////////////////////////////////////////////////////
75/// copy constructor that creates a true copy, i.e. a completely independent tree
76
78 : BinaryTree(),
79 fPeriod ( b.fPeriod ),
80 fCurrentDepth( 0 ),
81 fStatisticsIsValid( kFALSE ),
82 fSumOfWeights( b.fSumOfWeights ),
83 fCanNormalize( kFALSE )
84{
85 fNEventsW[0]=fNEventsW[1]=0.;
86 Log() << kFATAL << " Copy constructor not implemented yet " << Endl;
87}
88
89////////////////////////////////////////////////////////////////////////////////
90/// destructor
91
93{
94 for(std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator pIt = fNormalizeTreeTable.begin();
95 pIt != fNormalizeTreeTable.end(); ++pIt) {
96 delete pIt->second;
97 }
98}
99
100////////////////////////////////////////////////////////////////////////////////
101/// re-create a new tree (decision tree or search tree) from XML
102
104 std::string type("");
105 gTools().ReadAttr(node,"type", type);
107 bt->ReadXML( node, tmva_Version_Code );
108 return bt;
109}
110
111////////////////////////////////////////////////////////////////////////////////
112/// insert a new "event" in the binary tree
113
115{
116 fCurrentDepth=0;
117 fStatisticsIsValid = kFALSE;
118
119 if (this->GetRoot() == NULL) { // If the list is empty...
120 this->SetRoot( new BinarySearchTreeNode(event)); //Make the new node the root.
121 // have to use "s" for start as "r" for "root" would be the same as "r" for "right"
122 this->GetRoot()->SetPos('s');
123 this->GetRoot()->SetDepth(0);
124 fNNodes = 1;
125 fSumOfWeights = event->GetWeight();
126 ((BinarySearchTreeNode*)this->GetRoot())->SetSelector((UInt_t)0);
127 this->SetPeriode(event->GetNVariables());
128 }
129 else {
130 // sanity check:
131 if (event->GetNVariables() != (UInt_t)this->GetPeriode()) {
132 Log() << kFATAL << "<Insert> event vector length != Periode specified in Binary Tree" << Endl
133 << "--- event size: " << event->GetNVariables() << " Periode: " << this->GetPeriode() << Endl
134 << "--- and all this when trying filling the "<<fNNodes+1<<"th Node" << Endl;
135 }
136 // insert a new node at the propper position
137 this->Insert(event, this->GetRoot());
138 }
139
140 // normalise the tree to speed up searches
141 if (fCanNormalize) fNormalizeTreeTable.push_back( std::make_pair(0.0,new const Event(*event)) );
142}
143
144////////////////////////////////////////////////////////////////////////////////
145/// private internal function to insert a event (node) at the proper position
146
148 Node *node )
149{
150 fCurrentDepth++;
151 fStatisticsIsValid = kFALSE;
152
153 if (node->GoesLeft(*event)){ // If the adding item is less than the current node's data...
154 if (node->GetLeft() != NULL){ // If there is a left node...
155 // Add the new event to the left node
156 this->Insert(event, node->GetLeft());
157 }
158 else { // If there is not a left node...
159 // Make the new node for the new event
161 fNNodes++;
162 fSumOfWeights += event->GetWeight();
163 current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
164 current->SetParent(node); // Set the new node's previous node.
165 current->SetPos('l');
166 current->SetDepth( node->GetDepth() + 1 );
167 node->SetLeft(current); // Make it the left node of the current one.
168 }
169 }
170 else if (node->GoesRight(*event)) { // If the adding item is less than or equal to the current node's data...
171 if (node->GetRight() != NULL) { // If there is a right node...
172 // Add the new node to it.
173 this->Insert(event, node->GetRight());
174 }
175 else { // If there is not a right node...
176 // Make the new node.
178 fNNodes++;
179 fSumOfWeights += event->GetWeight();
180 current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
181 current->SetParent(node); // Set the new node's previous node.
182 current->SetPos('r');
183 current->SetDepth( node->GetDepth() + 1 );
184 node->SetRight(current); // Make it the left node of the current one.
185 }
186 }
187 else Log() << kFATAL << "<Insert> neither left nor right :)" << Endl;
188}
189
190////////////////////////////////////////////////////////////////////////////////
191///search the tree to find the node matching "event"
192
194{
195 return this->Search( event, this->GetRoot() );
196}
197
198////////////////////////////////////////////////////////////////////////////////
199/// Private, recursive, function for searching.
200
202{
203 if (node != NULL) { // If the node is not NULL...
204 // If we have found the node...
205 if (((BinarySearchTreeNode*)(node))->EqualsMe(*event))
206 return (BinarySearchTreeNode*)node; // Return it
207 if (node->GoesLeft(*event)) // If the node's data is greater than the search item...
208 return this->Search(event, node->GetLeft()); //Search the left node.
209 else //If the node's data is less than the search item...
210 return this->Search(event, node->GetRight()); //Search the right node.
211 }
212 else return NULL; //If the node is NULL, return NULL.
213}
214
215////////////////////////////////////////////////////////////////////////////////
216/// return the sum of event (node) weights
217
219{
220 if (fSumOfWeights <= 0) {
221 Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
222 << " I call CalcStatistics which hopefully fixes things"
223 << Endl;
224 }
225 if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
226
227 return fSumOfWeights;
228}
229
230////////////////////////////////////////////////////////////////////////////////
231/// return the sum of event (node) weights
232
234{
235 if (fSumOfWeights <= 0) {
236 Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
237 << " I call CalcStatistics which hopefully fixes things"
238 << Endl;
239 }
240 if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
241
242 return fNEventsW[ ( theType == Types::kSignal) ? 0 : 1 ];
243}
244
245////////////////////////////////////////////////////////////////////////////////
246/// create the search tree from the event collection
247/// using ONLY the variables specified in "theVars"
248
249Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, const std::vector<Int_t>& theVars,
250 Int_t theType )
251{
252 fPeriod = theVars.size();
253 return Fill(events, theType);
254}
255
256////////////////////////////////////////////////////////////////////////////////
257/// create the search tree from the events in a TTree
258/// using ALL the variables specified included in the Event
259
260Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, Int_t theType )
261{
262 UInt_t n=events.size();
263
264 UInt_t nevents = 0;
265 if (fSumOfWeights != 0) {
266 Log() << kWARNING
267 << "You are filling a search three that is not empty.. "
268 << " do you know what you are doing?"
269 << Endl;
270 }
271 for (UInt_t ievt=0; ievt<n; ievt++) {
272 // insert event into binary tree
273 if (theType == -1 || (Int_t(events[ievt]->GetClass()) == theType) ) {
274 this->Insert( events[ievt] );
275 nevents++;
276 fSumOfWeights += events[ievt]->GetWeight();
277 }
278 } // end of event loop
279 CalcStatistics(0);
280
281 return fSumOfWeights;
282}
283
284////////////////////////////////////////////////////////////////////////////////
285/// normalises the binary-search tree to reduce the branch length and hence speed up the
286/// search procedure (on average).
287
288void TMVA::BinarySearchTree::NormalizeTree ( std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftBound,
289 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightBound,
290 UInt_t actDim )
291{
292
293 if (leftBound == rightBound) return;
294
295 if (actDim == fPeriod) actDim = 0;
296 for (std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator i=leftBound; i!=rightBound; ++i) {
297 i->first = i->second->GetValue( actDim );
298 }
299
300 std::sort( leftBound, rightBound );
301
302 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftTemp = leftBound;
303 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightTemp = rightBound;
304
305 // meet in the middle
306 while (true) {
307 --rightTemp;
308 if (rightTemp == leftTemp ) {
309 break;
310 }
311 ++leftTemp;
312 if (leftTemp == rightTemp) {
313 break;
314 }
315 }
316
317 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator mid = leftTemp;
318 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator midTemp = mid;
319
320 if (mid!=leftBound)--midTemp;
321
322 while (mid != leftBound && mid->second->GetValue( actDim ) == midTemp->second->GetValue( actDim )) {
323 --mid;
324 --midTemp;
325 }
326
327 Insert( mid->second );
328
329 // Print(std::cout);
330 // std::cout << std::endl << std::endl;
331
332 NormalizeTree( leftBound, mid, actDim+1 );
333 ++mid;
334 // Print(std::cout);
335 // std::cout << std::endl << std::endl;
336 NormalizeTree( mid, rightBound, actDim+1 );
337
338
339 return;
340}
341
342////////////////////////////////////////////////////////////////////////////////
343/// Normalisation of tree
344
346{
347 SetNormalize( kFALSE );
348 Clear( NULL );
349 this->SetRoot(NULL);
350 NormalizeTree( fNormalizeTreeTable.begin(), fNormalizeTreeTable.end(), 0 );
351}
352
353////////////////////////////////////////////////////////////////////////////////
354/// clear nodes
355
357{
358 BinarySearchTreeNode* currentNode = (BinarySearchTreeNode*)(n == NULL ? this->GetRoot() : n);
359
360 if (currentNode->GetLeft() != 0) Clear( currentNode->GetLeft() );
361 if (currentNode->GetRight() != 0) Clear( currentNode->GetRight() );
362
363 if (n != NULL) delete n;
364
365 return;
366}
367
368////////////////////////////////////////////////////////////////////////////////
369/// search the whole tree and add up all weights of events that
370/// lie within the given volume
371
373 std::vector<const BinarySearchTreeNode*>* events )
374{
375 return SearchVolume( this->GetRoot(), volume, 0, events );
376}
377
378////////////////////////////////////////////////////////////////////////////////
379/// recursively walk through the daughter nodes and add up all weights of events that
380/// lie within the given volume
381
383 std::vector<const BinarySearchTreeNode*>* events )
384{
385 if (t==NULL) return 0; // Are we at an outer leave?
386
388
389 Double_t count = 0.0;
390 if (InVolume( st->GetEventV(), volume )) {
391 count += st->GetWeight();
392 if (NULL != events) events->push_back( st );
393 }
394 if (st->GetLeft()==NULL && st->GetRight()==NULL) {
395
396 return count; // Are we at an outer leave?
397 }
398
399 Bool_t tl, tr;
400 Int_t d = depth%this->GetPeriode();
401 if (d != st->GetSelector()) {
402 Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
403 << d << " != " << "node "<< st->GetSelector() << Endl;
404 }
405 tl = (*(volume->fLower))[d] < st->GetEventV()[d]; // Should we descend left?
406 tr = (*(volume->fUpper))[d] >= st->GetEventV()[d]; // Should we descend right?
407
408 if (tl) count += SearchVolume( st->GetLeft(), volume, (depth+1), events );
409 if (tr) count += SearchVolume( st->GetRight(), volume, (depth+1), events );
410
411 return count;
412}
413
414////////////////////////////////////////////////////////////////////////////////
415/// test if the data points are in the given volume
416
417Bool_t TMVA::BinarySearchTree::InVolume(const std::vector<Float_t>& event, Volume* volume ) const
418{
419
420 Bool_t result = false;
421 for (UInt_t ivar=0; ivar< fPeriod; ivar++) {
422 result = ( (*(volume->fLower))[ivar] < event[ivar] &&
423 (*(volume->fUpper))[ivar] >= event[ivar] );
424 if (!result) break;
425 }
426 return result;
427}
428
429////////////////////////////////////////////////////////////////////////////////
430/// calculate basic statistics (mean, rms for each variable)
431
433{
434 if (fStatisticsIsValid) return;
435
437
438 // default, start at the tree top, then descend recursively
439 if (n == NULL) {
440 fSumOfWeights = 0;
441 for (Int_t sb=0; sb<2; sb++) {
442 fNEventsW[sb] = 0;
443 fMeans[sb] = std::vector<Float_t>(fPeriod);
444 fRMS[sb] = std::vector<Float_t>(fPeriod);
445 fMin[sb] = std::vector<Float_t>(fPeriod);
446 fMax[sb] = std::vector<Float_t>(fPeriod);
447 fSum[sb] = std::vector<Double_t>(fPeriod);
448 fSumSq[sb] = std::vector<Double_t>(fPeriod);
449 for (UInt_t j=0; j<fPeriod; j++) {
450 fMeans[sb][j] = fRMS[sb][j] = fSum[sb][j] = fSumSq[sb][j] = 0;
451 fMin[sb][j] = FLT_MAX;
452 fMax[sb][j] = -FLT_MAX;
453 }
454 }
455 currentNode = (BinarySearchTreeNode*) this->GetRoot();
456 if (currentNode == NULL) return; // no root-node
457 }
458
459 const std::vector<Float_t> & evtVec = currentNode->GetEventV();
460 Double_t weight = currentNode->GetWeight();
461 // Int_t type = currentNode->IsSignal();
462 // Int_t type = currentNode->IsSignal() ? 0 : 1;
463 Int_t type = Int_t(currentNode->GetClass())== Types::kSignal ? 0 : 1;
464
465 fNEventsW[type] += weight;
466 fSumOfWeights += weight;
467
468 for (UInt_t j=0; j<fPeriod; j++) {
469 Float_t val = evtVec[j];
470 fSum[type][j] += val*weight;
471 fSumSq[type][j] += val*val*weight;
472 if (val < fMin[type][j]) fMin[type][j] = val;
473 if (val > fMax[type][j]) fMax[type][j] = val;
474 }
475
476 if ( (currentNode->GetLeft() != NULL) ) CalcStatistics( currentNode->GetLeft() );
477 if ( (currentNode->GetRight() != NULL) ) CalcStatistics( currentNode->GetRight() );
478
479 if (n == NULL) { // i.e. the root node
480 for (Int_t sb=0; sb<2; sb++) {
481 for (UInt_t j=0; j<fPeriod; j++) {
482 if (fNEventsW[sb] == 0) { fMeans[sb][j] = fRMS[sb][j] = 0; continue; }
483 fMeans[sb][j] = fSum[sb][j]/fNEventsW[sb];
484 fRMS[sb][j] = TMath::Sqrt(fSumSq[sb][j]/fNEventsW[sb] - fMeans[sb][j]*fMeans[sb][j]);
485 }
486 }
487 fStatisticsIsValid = kTRUE;
488 }
489
490 return;
491}
492
493////////////////////////////////////////////////////////////////////////////////
494/// recursively walk through the daughter nodes and add up all weights of events that
495/// lie within the given volume a maximum number of events can be given
496
497Int_t TMVA::BinarySearchTree::SearchVolumeWithMaxLimit( Volume *volume, std::vector<const BinarySearchTreeNode*>* events,
498 Int_t max_points )
499{
500 if (this->GetRoot() == NULL) return 0; // Are we at an outer leave?
501
502 std::queue< std::pair< const BinarySearchTreeNode*, Int_t > > queue;
503 std::pair< const BinarySearchTreeNode*, Int_t > st = std::make_pair( (const BinarySearchTreeNode*)this->GetRoot(), 0 );
504 queue.push( st );
505
506 Int_t count = 0;
507
508 while ( !queue.empty() ) {
509 st = queue.front(); queue.pop();
510
511 if (count == max_points)
512 return count;
513
514 if (InVolume( st.first->GetEventV(), volume )) {
515 count++;
516 if (NULL != events) events->push_back( st.first );
517 }
518
519 Bool_t tl, tr;
520 Int_t d = st.second;
521 if ( d == Int_t(this->GetPeriode()) ) d = 0;
522
523 if (d != st.first->GetSelector()) {
524 Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
525 << d << " != " << "node "<< st.first->GetSelector() << Endl;
526 }
527
528 tl = (*(volume->fLower))[d] < st.first->GetEventV()[d] && st.first->GetLeft() != NULL; // Should we descend left?
529 tr = (*(volume->fUpper))[d] >= st.first->GetEventV()[d] && st.first->GetRight() != NULL; // Should we descend right?
530
531 if (tl) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetLeft(), d+1 ) );
532 if (tr) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetRight(), d+1 ) );
533 }
534
535 return count;
536}
#define d(i)
Definition: RSha256.hxx:102
#define b(i)
Definition: RSha256.hxx:100
int Int_t
Definition: RtypesCore.h:45
unsigned int UInt_t
Definition: RtypesCore.h:46
const Bool_t kFALSE
Definition: RtypesCore.h:101
bool Bool_t
Definition: RtypesCore.h:63
double Double_t
Definition: RtypesCore.h:59
float Float_t
Definition: RtypesCore.h:57
const Bool_t kTRUE
Definition: RtypesCore.h:100
#define ClassImp(name)
Definition: Rtypes.h:364
int type
Definition: TGX11.cxx:121
Node for the BinarySearch or Decision Trees.
const std::vector< Float_t > & GetEventV() const
A simple Binary search tree including a volume search method.
BinarySearchTreeNode * Search(Event *event) const
search the tree to find the node matching "event"
Bool_t InVolume(const std::vector< Float_t > &, Volume *) const
test if the data points are in the given volume
virtual ~BinarySearchTree(void)
destructor
Double_t Fill(const std::vector< TMVA::Event * > &events, const std::vector< Int_t > &theVars, Int_t theType=-1)
create the search tree from the event collection using ONLY the variables specified in "theVars"
void NormalizeTree()
Normalisation of tree.
Double_t GetSumOfWeights(void) const
return the sum of event (node) weights
static BinarySearchTree * CreateFromXML(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
re-create a new tree (decision tree or search tree) from XML
void Clear(TMVA::Node *n=0)
clear nodes
BinarySearchTree(void)
default constructor
void CalcStatistics(TMVA::Node *n=0)
calculate basic statistics (mean, rms for each variable)
Double_t SearchVolume(Volume *, std::vector< const TMVA::BinarySearchTreeNode * > *events=0)
search the whole tree and add up all weights of events that lie within the given volume
void Insert(const Event *)
insert a new "event" in the binary tree
Int_t SearchVolumeWithMaxLimit(TMVA::Volume *, std::vector< const TMVA::BinarySearchTreeNode * > *events=0, Int_t=-1)
recursively walk through the daughter nodes and add up all weights of events that lie within the give...
Base class for BinarySearch and Decision Trees.
Definition: BinaryTree.h:62
virtual void ReadXML(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
read attributes from XML
Definition: BinaryTree.cxx:144
MsgLogger & Log() const
Definition: BinaryTree.cxx:235
Node for the BinarySearch or Decision Trees.
Definition: Node.h:58
virtual Node * GetLeft() const
Definition: Node.h:89
virtual void SetRight(Node *r)
Definition: Node.h:95
virtual Bool_t GoesLeft(const Event &) const =0
virtual void SetLeft(Node *l)
Definition: Node.h:94
void SetPos(char s)
Definition: Node.h:119
void SetDepth(UInt_t d)
Definition: Node.h:113
virtual void SetParent(Node *p)
Definition: Node.h:96
virtual Bool_t GoesRight(const Event &) const =0
UInt_t GetDepth() const
Definition: Node.h:116
virtual Node * GetRight() const
Definition: Node.h:90
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:329
@ kSignal
Definition: Types.h:135
Volume for BinarySearchTree.
Definition: Volume.h:47
std::vector< Double_t > * fLower
Definition: Volume.h:75
std::vector< Double_t > * fUpper
Definition: Volume.h:76
@ kWARNING
Definition: Types.h:59
@ kFATAL
Definition: Types.h:61
const Int_t n
Definition: legend1.C:16
TClass * GetClass(T *)
Definition: TClass.h:658
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:148
Double_t Log(Double_t x)
Definition: TMath.h:710
Double_t Sqrt(Double_t x)
Definition: TMath.h:641