Logo ROOT  
Reference Guide
BinarySearchTree.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : BinarySearchTree *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation (see header file for description) *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Joerg Stelzer <stelzer@cern.ch> - DESY, Germany *
16 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
17 * Kai Voss <Kai.Voss@cern.ch> - U. of Victoria, Canada *
18 * *
19 * Copyright (c) 2005: *
20 * CERN, Switzerland *
21 * U. of Victoria, Canada *
22 * MPI-K Heidelberg, Germany *
23 * LAPP, Annecy, France *
24 * *
25 * Redistribution and use in source and binary forms, with or without *
26 * modification, are permitted according to the terms listed in LICENSE *
27 * (http://tmva.sourceforge.net/LICENSE) *
28 * *
29 **********************************************************************************/
30
31/*! \class TMVA::BinarySearchTree
32\ingroup TMVA
33
34A simple Binary search tree including a volume search method.
35
36*/
37
38#include <stdexcept>
39#include <cstdlib>
40#include <queue>
41#include <algorithm>
42
43#include "TMath.h"
44
45#include "TMatrixDBase.h"
46#include "TObjString.h"
47#include "TTree.h"
48
49#include "TMVA/MsgLogger.h"
50#include "TMVA/MethodBase.h"
51#include "TMVA/Tools.h"
52#include "TMVA/Event.h"
54
55#include "TMVA/BinaryTree.h"
56#include "TMVA/Types.h"
57#include "TMVA/Node.h"
58
60
61////////////////////////////////////////////////////////////////////////////////
62/// default constructor
63
// NOTE(review): the constructor's signature lines are missing from this extract;
// only the member-initializer list and the body are visible.
66 fPeriod ( 1 ),
67 fCurrentDepth( 0 ),
68 fStatisticsIsValid( kFALSE ),
69 fSumOfWeights( 0 ),
70 fCanNormalize( kFALSE )
71{
// per-class weight sums: [0] = Types::kSignal, [1] = every other class (filled by CalcStatistics)
72 fNEventsW[0]=fNEventsW[1]=0.;
73}
74
75////////////////////////////////////////////////////////////////////////////////
76/// copy constructor that creates a true copy, i.e. a completely independent tree
77
// NOTE(review): despite the "true copy" description, no nodes are copied here.
// Only fPeriod and fSumOfWeights are taken from b, and the body logs a kFATAL
// "not implemented" message. The copied fSumOfWeights therefore does not
// describe this (empty) tree.
79 : BinaryTree(),
80 fPeriod ( b.fPeriod ),
81 fCurrentDepth( 0 ),
82 fStatisticsIsValid( kFALSE ),
83 fSumOfWeights( b.fSumOfWeights ),
84 fCanNormalize( kFALSE )
85{
86 fNEventsW[0]=fNEventsW[1]=0.;
87 Log() << kFATAL << " Copy constructor not implemented yet " << Endl;
88}
89
90////////////////////////////////////////////////////////////////////////////////
91/// destructor
92
94{
// Release the Event copies that Insert() allocates (new const Event(*event))
// while fCanNormalize is set. The tree nodes themselves are not deleted here;
// presumably the BinaryTree base destructor handles them — TODO confirm.
95 for(std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator pIt = fNormalizeTreeTable.begin();
96 pIt != fNormalizeTreeTable.end(); ++pIt) {
97 delete pIt->second;
98 }
99}
100
101////////////////////////////////////////////////////////////////////////////////
102/// re-create a new tree (decision tree or search tree) from XML
103
// NOTE(review): the signature and the line that creates `bt` are missing from
// this extract. The "type" attribute is read into `type`, but no visible line
// uses it.
105 std::string type("");
106 gTools().ReadAttr(node,"type", type);
108 bt->ReadXML( node, tmva_Version_Code );
109 return bt;
110}
111
112////////////////////////////////////////////////////////////////////////////////
113/// insert a new "event" in the binary tree
114
116{
// reset the level counter used by the recursive Insert(event, node) to pick each new node's selector
117 fCurrentDepth=0;
118 fStatisticsIsValid = kFALSE;
119
120 if (this->GetRoot() == NULL) { // If the tree is empty...
121 this->SetRoot( new BinarySearchTreeNode(event)); //Make the new node the root.
122 // have to use "s" for start as "r" for "root" would be the same as "r" for "right"
123 this->GetRoot()->SetPos('s');
124 this->GetRoot()->SetDepth(0);
125 fNNodes = 1;
// the first node restarts the weight sum; later nodes add to it in Insert(event, node)
126 fSumOfWeights = event->GetWeight();
127 ((BinarySearchTreeNode*)this->GetRoot())->SetSelector((UInt_t)0);
// NOTE(review): presumably sets fPeriod, overriding any value set by
// Fill(events, theVars, ...) — confirm.
128 this->SetPeriode(event->GetNVariables());
129 }
130 else {
131 // sanity check:
132 if (event->GetNVariables() != (UInt_t)this->GetPeriode()) {
133 Log() << kFATAL << "<Insert> event vector length != Periode specified in Binary Tree" << Endl
134 << "--- event size: " << event->GetNVariables() << " Periode: " << this->GetPeriode() << Endl
135 << "--- and all this when trying filling the "<<fNNodes+1<<"th Node" << Endl;
136 }
137 // insert a new node at the proper position
138 this->Insert(event, this->GetRoot());
139 }
140
141 // keep a private copy of the event so NormalizeTree() can later rebuild the tree (speeds up searches); the copies are deleted in the destructor
142 if (fCanNormalize) fNormalizeTreeTable.push_back( std::make_pair(0.0,new const Event(*event)) );
143}
144
145////////////////////////////////////////////////////////////////////////////////
146/// private internal function to insert a event (node) at the proper position
147
149 Node *node )
150{
// Recursive helper: descend from `node` and attach `event` as a new leaf.
// fCurrentDepth (reset to 0 by the public Insert) counts the levels descended;
// the new node's selector is fCurrentDepth % (number of variables).
151 fCurrentDepth++;
152 fStatisticsIsValid = kFALSE;
153
154 if (node->GoesLeft(*event)){ // If the node's GoesLeft test sends the item to the left...
155 if (node->GetLeft() != NULL){ // If there is a left node...
156 // Add the new event to the left node
157 this->Insert(event, node->GetLeft());
158 }
159 else { // If there is not a left node...
160 // Make the new node for the new event
161 BinarySearchTreeNode* current = new BinarySearchTreeNode(event);
162 fNNodes++;
163 fSumOfWeights += event->GetWeight();
164 current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
165 current->SetParent(node); // Set the new node's previous node.
166 current->SetPos('l');
167 current->SetDepth( node->GetDepth() + 1 );
168 node->SetLeft(current); // Make it the left node of the current one.
169 }
170 }
171 else if (node->GoesRight(*event)) { // If the node's GoesRight test sends the item to the right...
172 if (node->GetRight() != NULL) { // If there is a right node...
173 // Add the new node to it.
174 this->Insert(event, node->GetRight());
175 }
176 else { // If there is not a right node...
177 // Make the new node.
178 BinarySearchTreeNode* current = new BinarySearchTreeNode(event);
179 fNNodes++;
180 fSumOfWeights += event->GetWeight();
181 current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
182 current->SetParent(node); // Set the new node's previous node.
183 current->SetPos('r');
184 current->SetDepth( node->GetDepth() + 1 );
185 node->SetRight(current); // Make it the right node of the current one.
186 }
187 }
188 else Log() << kFATAL << "<Insert> neither left nor right :)" << Endl;
189}
190
191////////////////////////////////////////////////////////////////////////////////
192///search the tree to find the node matching "event"
193
195{
// search starting from the root; the recursive helper returns NULL when no node matches
196 return this->Search( event, this->GetRoot() );
197}
198
199////////////////////////////////////////////////////////////////////////////////
200/// Private, recursive, function for searching.
201
203{
// Recursive lookup: return the first node on the search path whose EqualsMe()
// accepts `event`. Unlike Insert, there is no separate GoesRight check:
// anything that does not go left is searched on the right.
204 if (node != NULL) { // If the node is not NULL...
205 // If we have found the node...
206 if (((BinarySearchTreeNode*)(node))->EqualsMe(*event))
207 return (BinarySearchTreeNode*)node; // Return it
208 if (node->GoesLeft(*event)) // If the node's GoesLeft test sends the item to the left...
209 return this->Search(event, node->GetLeft()); //Search the left node.
210 else //Otherwise...
211 return this->Search(event, node->GetRight()); //Search the right node.
212 }
213 else return NULL; //If the node is NULL, return NULL.
214}
215
216////////////////////////////////////////////////////////////////////////////////
217/// return the sum of event (node) weights
218
220{
// NOTE(review): the warning text says CalcStatistics is called, but nothing here
// calls it. The method is declared const, and CalcStatistics is not.
// The kFATAL check repeats the same condition, so any non-positive sum
// triggers both the warning and the kFATAL message.
221 if (fSumOfWeights <= 0) {
222 Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
223 << " I call CalcStatistics which hopefully fixes things"
224 << Endl;
225 }
226 if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
227
228 return fSumOfWeights;
229}
230
231////////////////////////////////////////////////////////////////////////////////
232/// return the sum of event (node) weights of class theType: index 0 for Types::kSignal, index 1 for every other class
233
235{
// NOTE(review): fNEventsW is only filled by CalcStatistics, so the value can be
// stale after further inserts. The guards check the total fSumOfWeights, not
// the per-class sum being returned. As in GetSumOfWeights(), the warning's
// claim that CalcStatistics is called is not backed by any call here.
236 if (fSumOfWeights <= 0) {
237 Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
238 << " I call CalcStatistics which hopefully fixes things"
239 << Endl;
240 }
241 if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
242
243 return fNEventsW[ ( theType == Types::kSignal) ? 0 : 1 ];
244}
245
246////////////////////////////////////////////////////////////////////////////////
247/// create the search tree from the event collection
248/// using ONLY the variables specified in "theVars"
249
250Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, const std::vector<Int_t>& theVars,
251 Int_t theType )
252{
253 fPeriod = theVars.size();
254 return Fill(events, theType);
255}
256
257////////////////////////////////////////////////////////////////////////////////
258/// create the search tree from the events in a TTree
259/// using ALL the variables specified included in the Event
260
261Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, Int_t theType )
262{
263 UInt_t n=events.size();
264
265 UInt_t nevents = 0;
266 if (fSumOfWeights != 0) {
267 Log() << kWARNING
268 << "You are filling a search three that is not empty.. "
269 << " do you know what you are doing?"
270 << Endl;
271 }
272 for (UInt_t ievt=0; ievt<n; ievt++) {
273 // insert event into binary tree
274 if (theType == -1 || (Int_t(events[ievt]->GetClass()) == theType) ) {
275 this->Insert( events[ievt] );
276 nevents++;
277 fSumOfWeights += events[ievt]->GetWeight();
278 }
279 } // end of event loop
280 CalcStatistics(0);
281
282 return fSumOfWeights;
283}
284
285////////////////////////////////////////////////////////////////////////////////
286/// normalises the binary-search tree to reduce the branch length and hence speed up the
287/// search procedure (on average).
288
289void TMVA::BinarySearchTree::NormalizeTree ( std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftBound,
290 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightBound,
291 UInt_t actDim )
292{
293
294 if (leftBound == rightBound) return;
295
296 if (actDim == fPeriod) actDim = 0;
297 for (std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator i=leftBound; i!=rightBound; ++i) {
298 i->first = i->second->GetValue( actDim );
299 }
300
301 std::sort( leftBound, rightBound );
302
303 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftTemp = leftBound;
304 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightTemp = rightBound;
305
306 // meet in the middle
307 while (true) {
308 --rightTemp;
309 if (rightTemp == leftTemp ) {
310 break;
311 }
312 ++leftTemp;
313 if (leftTemp == rightTemp) {
314 break;
315 }
316 }
317
318 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator mid = leftTemp;
319 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator midTemp = mid;
320
321 if (mid!=leftBound)--midTemp;
322
323 while (mid != leftBound && mid->second->GetValue( actDim ) == midTemp->second->GetValue( actDim )) {
324 --mid;
325 --midTemp;
326 }
327
328 Insert( mid->second );
329
330 // Print(std::cout);
331 // std::cout << std::endl << std::endl;
332
333 NormalizeTree( leftBound, mid, actDim+1 );
334 ++mid;
335 // Print(std::cout);
336 // std::cout << std::endl << std::endl;
337 NormalizeTree( mid, rightBound, actDim+1 );
338
339
340 return;
341}
342
343////////////////////////////////////////////////////////////////////////////////
344/// Normalisation of tree
345
347{
// Rebuild the whole tree from fNormalizeTreeTable (filled by Insert while
// normalisation was enabled), inserting medians first.
// SetNormalize(kFALSE) presumably clears fCanNormalize, so the re-insertions below
// do not append to the table again — confirm.
348 SetNormalize( kFALSE );
// NOTE(review): Clear(NULL) dereferences the root without a NULL check, and it
// deletes only the root's descendants, not the root itself. The old root is then
// dropped by SetRoot(NULL), which is a possible leak unless SetRoot frees it — confirm.
349 Clear( NULL );
350 this->SetRoot(NULL);
351 NormalizeTree( fNormalizeTreeTable.begin(), fNormalizeTreeTable.end(), 0 );
352}
353
354////////////////////////////////////////////////////////////////////////////////
355/// clear nodes
356
358{
// Recursively delete every node below n (below the root when n is NULL).
// NOTE(review): if n is NULL and the tree is empty, GetRoot() is NULL and the
// GetLeft() call below dereferences a null pointer.
// NOTE(review): the root itself is never deleted (only a non-NULL n is), and the
// parents' left/right pointers are not reset after their children are deleted.
359 BinarySearchTreeNode* currentNode = (BinarySearchTreeNode*)(n == NULL ? this->GetRoot() : n);
360
361 if (currentNode->GetLeft() != 0) Clear( currentNode->GetLeft() );
362 if (currentNode->GetRight() != 0) Clear( currentNode->GetRight() );
363
364 if (n != NULL) delete n;
365
366 return;
367}
368
369////////////////////////////////////////////////////////////////////////////////
370/// search the whole tree and add up all weights of events that
371/// lie within the given volume
372
374 std::vector<const BinarySearchTreeNode*>* events )
375{
// start the recursive search at the root with depth 0; matching nodes are
// appended to `events` when it is non-NULL
376 return SearchVolume( this->GetRoot(), volume, 0, events );
377}
378
379////////////////////////////////////////////////////////////////////////////////
380/// recursively walk through the daughter nodes and add up all weights of events that
381/// lie within the given volume
382
384 std::vector<const BinarySearchTreeNode*>* events )
385{
// Recursively sum the weights (not counts) of all nodes below t whose event
// lies inside `volume`, appending them to `events` when it is non-NULL.
386 if (t==NULL) return 0; // Are we at an outer leaf?
387
// NOTE(review): the line defining `st` (presumably `t` cast to
// BinarySearchTreeNode*) is missing from this extract.
389
390 Double_t count = 0.0;
391 if (InVolume( st->GetEventV(), volume )) {
392 count += st->GetWeight();
393 if (NULL != events) events->push_back( st );
394 }
395 if (st->GetLeft()==NULL && st->GetRight()==NULL) {
396
397 return count; // Are we at an outer leaf?
398 }
399
400 Bool_t tl, tr;
// the splitting variable of a node is its depth modulo the period; it must match the node's selector
401 Int_t d = depth%this->GetPeriode();
402 if (d != st->GetSelector()) {
403 Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
404 << d << " != " << "node "<< st->GetSelector() << Endl;
405 }
// prune: descend left only if the box extends below the split value, right only if it reaches it
406 tl = (*(volume->fLower))[d] < st->GetEventV()[d]; // Should we descend left?
407 tr = (*(volume->fUpper))[d] >= st->GetEventV()[d]; // Should we descend right?
408
409 if (tl) count += SearchVolume( st->GetLeft(), volume, (depth+1), events );
410 if (tr) count += SearchVolume( st->GetRight(), volume, (depth+1), events );
411
412 return count;
413}
414
415////////////////////////////////////////////////////////////////////////////////
416/// test if the data points are in the given volume
417
418Bool_t TMVA::BinarySearchTree::InVolume(const std::vector<Float_t>& event, Volume* volume ) const
419{
420
421 Bool_t result = false;
422 for (UInt_t ivar=0; ivar< fPeriod; ivar++) {
423 result = ( (*(volume->fLower))[ivar] < event[ivar] &&
424 (*(volume->fUpper))[ivar] >= event[ivar] );
425 if (!result) break;
426 }
427 return result;
428}
429
430////////////////////////////////////////////////////////////////////////////////
431/// calculate basic statistics (mean, rms for each variable)
432
434{
// Accumulate, per class (index 0 = Types::kSignal, 1 = every other class), the
// weight sums and the weighted sum, sum of squares, min and max of each of the
// fPeriod variables over all nodes, then derive means and RMS.
// Call with n == NULL to (re)compute over the whole tree; the recursive calls
// pass child nodes. The flag is only set at the end of the top-level call, so
// the recursive calls are not cut short by the check below.
435 if (fStatisticsIsValid) return;
436
// NOTE(review): the line declaring currentNode (presumably initialised from n)
// is missing from this extract.
438
439 // default, start at the tree top, then descend recursively
440 if (n == NULL) {
441 fSumOfWeights = 0;
442 for (Int_t sb=0; sb<2; sb++) {
443 fNEventsW[sb] = 0;
444 fMeans[sb] = std::vector<Float_t>(fPeriod);
445 fRMS[sb] = std::vector<Float_t>(fPeriod);
446 fMin[sb] = std::vector<Float_t>(fPeriod);
447 fMax[sb] = std::vector<Float_t>(fPeriod);
448 fSum[sb] = std::vector<Double_t>(fPeriod);
449 fSumSq[sb] = std::vector<Double_t>(fPeriod);
450 for (UInt_t j=0; j<fPeriod; j++) {
451 fMeans[sb][j] = fRMS[sb][j] = fSum[sb][j] = fSumSq[sb][j] = 0;
// FLT_MAX comes from <cfloat>, which this file does not include directly (presumably reached transitively)
452 fMin[sb][j] = FLT_MAX;
453 fMax[sb][j] = -FLT_MAX;
454 }
455 }
456 currentNode = (BinarySearchTreeNode*) this->GetRoot();
// empty tree: sums stay reset, and fStatisticsIsValid is not set
457 if (currentNode == NULL) return; // no root-node
458 }
459
460 const std::vector<Float_t> & evtVec = currentNode->GetEventV();
461 Double_t weight = currentNode->GetWeight();
462 // Int_t type = currentNode->IsSignal();
463 // Int_t type = currentNode->IsSignal() ? 0 : 1;
464 Int_t type = Int_t(currentNode->GetClass())== Types::kSignal ? 0 : 1;
465
466 fNEventsW[type] += weight;
467 fSumOfWeights += weight;
468
469 for (UInt_t j=0; j<fPeriod; j++) {
470 Float_t val = evtVec[j];
471 fSum[type][j] += val*weight;
472 fSumSq[type][j] += val*val*weight;
473 if (val < fMin[type][j]) fMin[type][j] = val;
474 if (val > fMax[type][j]) fMax[type][j] = val;
475 }
476
477 if ( (currentNode->GetLeft() != NULL) ) CalcStatistics( currentNode->GetLeft() );
478 if ( (currentNode->GetRight() != NULL) ) CalcStatistics( currentNode->GetRight() );
479
480 if (n == NULL) { // i.e. the root node
481 for (Int_t sb=0; sb<2; sb++) {
482 for (UInt_t j=0; j<fPeriod; j++) {
483 if (fNEventsW[sb] == 0) { fMeans[sb][j] = fRMS[sb][j] = 0; continue; }
484 fMeans[sb][j] = fSum[sb][j]/fNEventsW[sb];
// NOTE(review): E[x^2] - mean^2 can come out slightly negative through rounding,
// which makes Sqrt return NaN — consider clamping the argument at 0.
485 fRMS[sb][j] = TMath::Sqrt(fSumSq[sb][j]/fNEventsW[sb] - fMeans[sb][j]*fMeans[sb][j]);
486 }
487 }
488 fStatisticsIsValid = kTRUE;
489 }
490
491 return;
492}
493
494////////////////////////////////////////////////////////////////////////////////
495/// recursively walk through the daughter nodes and add up all weights of events that
496/// lie within the given volume a maximum number of events can be given
497
498Int_t TMVA::BinarySearchTree::SearchVolumeWithMaxLimit( Volume *volume, std::vector<const BinarySearchTreeNode*>* events,
499 Int_t max_points )
500{
501 if (this->GetRoot() == NULL) return 0; // Are we at an outer leave?
502
503 std::queue< std::pair< const BinarySearchTreeNode*, Int_t > > queue;
504 std::pair< const BinarySearchTreeNode*, Int_t > st = std::make_pair( (const BinarySearchTreeNode*)this->GetRoot(), 0 );
505 queue.push( st );
506
507 Int_t count = 0;
508
509 while ( !queue.empty() ) {
510 st = queue.front(); queue.pop();
511
512 if (count == max_points)
513 return count;
514
515 if (InVolume( st.first->GetEventV(), volume )) {
516 count++;
517 if (NULL != events) events->push_back( st.first );
518 }
519
520 Bool_t tl, tr;
521 Int_t d = st.second;
522 if ( d == Int_t(this->GetPeriode()) ) d = 0;
523
524 if (d != st.first->GetSelector()) {
525 Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
526 << d << " != " << "node "<< st.first->GetSelector() << Endl;
527 }
528
529 tl = (*(volume->fLower))[d] < st.first->GetEventV()[d] && st.first->GetLeft() != NULL; // Should we descend left?
530 tr = (*(volume->fUpper))[d] >= st.first->GetEventV()[d] && st.first->GetRight() != NULL; // Should we descend right?
531
532 if (tl) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetLeft(), d+1 ) );
533 if (tr) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetRight(), d+1 ) );
534 }
535
536 return count;
537}
#define d(i)
Definition: RSha256.hxx:102
#define b(i)
Definition: RSha256.hxx:100
int Int_t
Definition: RtypesCore.h:41
unsigned int UInt_t
Definition: RtypesCore.h:42
const Bool_t kFALSE
Definition: RtypesCore.h:88
bool Bool_t
Definition: RtypesCore.h:59
double Double_t
Definition: RtypesCore.h:55
float Float_t
Definition: RtypesCore.h:53
const Bool_t kTRUE
Definition: RtypesCore.h:87
#define ClassImp(name)
Definition: Rtypes.h:365
int type
Definition: TGX11.cxx:120
Node for the BinarySearch or Decision Trees.
const std::vector< Float_t > & GetEventV() const
A simple Binary search tree including a volume search method.
BinarySearchTreeNode * Search(Event *event) const
search the tree to find the node matching "event"
Bool_t InVolume(const std::vector< Float_t > &, Volume *) const
test if the data points are in the given volume
virtual ~BinarySearchTree(void)
destructor
Double_t Fill(const std::vector< TMVA::Event * > &events, const std::vector< Int_t > &theVars, Int_t theType=-1)
create the search tree from the event collection using ONLY the variables specified in "theVars"
void NormalizeTree()
Normalisation of tree.
Double_t GetSumOfWeights(void) const
return the sum of event (node) weights
static BinarySearchTree * CreateFromXML(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
re-create a new tree (decision tree or search tree) from XML
void Clear(TMVA::Node *n=0)
clear nodes
BinarySearchTree(void)
default constructor
void CalcStatistics(TMVA::Node *n=0)
calculate basic statistics (mean, rms for each variable)
Double_t SearchVolume(Volume *, std::vector< const TMVA::BinarySearchTreeNode * > *events=0)
search the whole tree and add up all weights of events that lie within the given volume
void Insert(const Event *)
insert a new "event" in the binary tree
Int_t SearchVolumeWithMaxLimit(TMVA::Volume *, std::vector< const TMVA::BinarySearchTreeNode * > *events=0, Int_t=-1)
recursively walk through the daughter nodes and add up all weights of events that lie within the given volume; a maximum number of events can be given
Base class for BinarySearch and Decision Trees.
Definition: BinaryTree.h:62
virtual void ReadXML(void *node, UInt_t tmva_Version_Code=TMVA_VERSION_CODE)
read attributes from XML
Definition: BinaryTree.cxx:144
MsgLogger & Log() const
Definition: BinaryTree.cxx:235
UInt_t GetNVariables() const
accessor to the number of variables
Definition: Event.cxx:309
Node for the BinarySearch or Decision Trees.
Definition: Node.h:56
virtual Node * GetLeft() const
Definition: Node.h:87
virtual void SetRight(Node *r)
Definition: Node.h:93
virtual Bool_t GoesLeft(const Event &) const =0
virtual void SetLeft(Node *l)
Definition: Node.h:92
void SetPos(char s)
Definition: Node.h:117
void SetDepth(UInt_t d)
Definition: Node.h:111
virtual void SetParent(Node *p)
Definition: Node.h:94
virtual Bool_t GoesRight(const Event &) const =0
UInt_t GetDepth() const
Definition: Node.h:114
virtual Node * GetRight() const
Definition: Node.h:88
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:337
@ kSignal
Definition: Types.h:136
Volume for BinarySearchTree.
Definition: Volume.h:48
std::vector< Double_t > * fLower
Definition: Volume.h:76
std::vector< Double_t > * fUpper
Definition: Volume.h:77
const Int_t n
Definition: legend1.C:16
TClass * GetClass(T *)
Definition: TClass.h:608
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:750
Double_t Sqrt(Double_t x)
Definition: TMath.h:681