Logo ROOT  
Reference Guide
NodekNN.h
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Rustem Ospanov
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : Node *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * kd-tree (binary tree) template *
12 * *
13 * Author: *
14 * Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA *
15 * *
16 * Copyright (c) 2007: *
17 * CERN, Switzerland *
18 * MPI-K Heidelberg, Germany *
19 * U. of Texas at Austin, USA *
20 * *
21 * Redistribution and use in source and binary forms, with or without *
22 * modification, are permitted according to the terms listed in LICENSE *
23 * (http://tmva.sourceforge.net/LICENSE) *
24 **********************************************************************************/
25
26#ifndef ROOT_TMVA_NodekNN
27#define ROOT_TMVA_NodekNN
28
// C++
#include <algorithm>
#include <cassert>
#include <list>
#include <string>
#include <iostream>
#include <utility>

// ROOT
#include "RtypesCore.h"
38
39/*! \class TMVA::kNN::Node
40\ingroup TMVA
This file contains a binary tree and a global function template
that searches the tree for k-nearest neighbors.
43
44Node class template parameter T has to provide these functions:
45 rtype GetVar(UInt_t) const;
46 - rtype is any type convertible to Float_t
47 UInt_t GetNVar(void) const;
48 rtype GetWeight(void) const;
49 - rtype is any type convertible to Double_t
50
51Find function template parameter T has to provide these functions:
52(in addition to above requirements)
53 rtype GetDist(Float_t, UInt_t) const;
54 - rtype is any type convertible to Float_t
55 rtype GetDist(const T &) const;
56 - rtype is any type convertible to Float_t
57
58 where T::GetDist(Float_t, UInt_t) <= T::GetDist(const T &)
59 for any pair of events and any variable number for these events
60*/
61
62namespace TMVA
63{
64 namespace kNN
65 {
66 template <class T>
67 class Node
68 {
69
70 public:
71
72 Node(const Node *parent, const T &event, Int_t mod);
73 ~Node();
74
75 const Node* Add(const T &event, UInt_t depth);
76
77 void SetNodeL(Node *node);
78 void SetNodeR(Node *node);
79
80 const T& GetEvent() const;
81
82 const Node* GetNodeL() const;
83 const Node* GetNodeR() const;
84 const Node* GetNodeP() const;
85
86 Double_t GetWeight() const;
87
88 Float_t GetVarDis() const;
89 Float_t GetVarMin() const;
90 Float_t GetVarMax() const;
91
92 UInt_t GetMod() const;
93
94 void Print() const;
95 void Print(std::ostream& os, const std::string &offset = "") const;
96
97 private:
98
99 // these methods are private and not implemented by design
100 // use provided public constructor for all uses of this template class
102 Node(const Node &);
103 const Node& operator=(const Node &);
104
105 private:
106
107 const Node* fNodeP;
108
111
112 const T fEvent;
113
115
118
120 };
121
122 // recursive search for k-nearest neighbor: k = nfind
123 template<class T>
124 UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
125 const Node<T> *node, const T &event, UInt_t nfind);
126
127 // recursive search for k-nearest neighbor
128 // find k events with sum of event weights >= nfind
129 template<class T>
130 UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
131 const Node<T> *node, const T &event, Double_t nfind, Double_t ncurr);
132
133 // recursively travel upward until root node is reached
134 template <class T>
135 UInt_t Depth(const Node<T> *node);
136
137 // prInt_t node content and content of its children
138 //template <class T>
139 //std::ostream& operator<<(std::ostream& os, const Node<T> &node);
140
141 //
142 // Inlined functions for Node template
143 //
144 template <class T>
145 inline void Node<T>::SetNodeL(Node<T> *node)
146 {
147 fNodeL = node;
148 }
149
150 template <class T>
151 inline void Node<T>::SetNodeR(Node<T> *node)
152 {
153 fNodeR = node;
154 }
155
156 template <class T>
157 inline const T& Node<T>::GetEvent() const
158 {
159 return fEvent;
160 }
161
162 template <class T>
163 inline const Node<T>* Node<T>::GetNodeL() const
164 {
165 return fNodeL;
166 }
167
168 template <class T>
169 inline const Node<T>* Node<T>::GetNodeR() const
170 {
171 return fNodeR;
172 }
173
174 template <class T>
175 inline const Node<T>* Node<T>::GetNodeP() const
176 {
177 return fNodeP;
178 }
179
180 template <class T>
182 {
183 return fEvent.GetWeight();
184 }
185
186 template <class T>
188 {
189 return fVarDis;
190 }
191
192 template <class T>
194 {
195 return fVarMin;
196 }
197
198 template <class T>
200 {
201 return fVarMax;
202 }
203
204 template <class T>
205 inline UInt_t Node<T>::GetMod() const
206 {
207 return fMod;
208 }
209
210 //
211 // Inlined global function(s)
212 //
213 template <class T>
214 inline UInt_t Depth(const Node<T> *node)
215 {
216 if (!node) return 0;
217 else return Depth(node->GetNodeP()) + 1;
218 }
219
220 } // end of kNN namespace
221} // end of TMVA namespace
222
223////////////////////////////////////////////////////////////////////////////////
224template<class T>
225TMVA::kNN::Node<T>::Node(const Node<T> *parent, const T &event, const Int_t mod)
226:fNodeP(parent),
227 fNodeL(0),
228 fNodeR(0),
229 fEvent(event),
230 fVarDis(event.GetVar(mod)),
231 fVarMin(fVarDis),
232 fVarMax(fVarDis),
233 fMod(mod)
234{}
235
236////////////////////////////////////////////////////////////////////////////////
237template<class T>
239{
240 if (fNodeL) delete fNodeL;
241 if (fNodeR) delete fNodeR;
242}
243
244////////////////////////////////////////////////////////////////////////////////
245/// This is Node member function that adds a new node to a binary tree.
246/// each node contains maximum and minimum values of splitting variable
247/// left or right nodes are added based on value of splitting variable
248
249template<class T>
251{
252
253 assert(fMod == depth % event.GetNVar() && "Wrong recursive depth in Node<>::Add");
254
255 const Float_t value = event.GetVar(fMod);
256
257 fVarMin = std::min(fVarMin, value);
258 fVarMax = std::max(fVarMax, value);
259
260 Node<T> *node = 0;
261 if (value < fVarDis) {
262 if (fNodeL)
263 {
264 return fNodeL->Add(event, depth + 1);
265 }
266 else {
267 fNodeL = new Node<T>(this, event, (depth + 1) % event.GetNVar());
268 node = fNodeL;
269 }
270 }
271 else {
272 if (fNodeR) {
273 return fNodeR->Add(event, depth + 1);
274 }
275 else {
276 fNodeR = new Node<T>(this, event, (depth + 1) % event.GetNVar());
277 node = fNodeR;
278 }
279 }
280
281 return node;
282}
283
284////////////////////////////////////////////////////////////////////////////////
285template<class T>
287{
288 Print(std::cout);
289}
290
291////////////////////////////////////////////////////////////////////////////////
292template<class T>
293void TMVA::kNN::Node<T>::Print(std::ostream& os, const std::string &offset) const
294{
295 os << offset << "-----------------------------------------------------------" << std::endl;
296 os << offset << "Node: mod " << fMod
297 << " at " << fVarDis
298 << " with weight: " << GetWeight() << std::endl
299 << offset << fEvent;
300
301 if (fNodeL) {
302 os << offset << "Has left node " << std::endl;
303 }
304 if (fNodeR) {
305 os << offset << "Has right node" << std::endl;
306 }
307
308 if (fNodeL) {
309 os << offset << "PrInt_t left node " << std::endl;
310 fNodeL->Print(os, offset + " ");
311 }
312 if (fNodeR) {
313 os << offset << "PrInt_t right node" << std::endl;
314 fNodeR->Print(os, offset + " ");
315 }
316
317 if (!fNodeL && !fNodeR) {
318 os << std::endl;
319 }
320}
321
322////////////////////////////////////////////////////////////////////////////////
323/// This is a global templated function that searches for k-nearest neighbors.
324/// list contains k or less nodes that are closest to event.
325/// only nodes with positive weights are added to list.
326/// each node contains maximum and minimum values of splitting variable
327/// for all its children - this range is checked to avoid descending into
328/// nodes that are definitely outside current minimum neighbourhood.
329///
330/// This function should be modified with care.
331
332template<class T>
333UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
334 const TMVA::kNN::Node<T> *node, const T &event, const UInt_t nfind)
335{
336 if (!node || nfind < 1) {
337 return 0;
338 }
339
340 const Float_t value = event.GetVar(node->GetMod());
341
342 if (node->GetWeight() > 0.0) {
343
344 Float_t max_dist = 0.0;
345
346 if (!nlist.empty()) {
347
348 max_dist = nlist.back().second;
349
350 if (nlist.size() == nfind) {
351 if (value > node->GetVarMax() &&
352 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
353 return 0;
354 }
355 if (value < node->GetVarMin() &&
356 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
357 return 0;
358 }
359 }
360 }
361
362 const Float_t distance = event.GetDist(node->GetEvent());
363
364 Bool_t insert_this = kFALSE;
365 Bool_t remove_back = kFALSE;
366
367 if (nlist.size() < nfind) {
368 insert_this = kTRUE;
369 }
370 else if (nlist.size() == nfind) {
371 if (distance < max_dist) {
372 insert_this = kTRUE;
373 remove_back = kTRUE;
374 }
375 }
376 else {
377 std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
378 return 1;
379 }
380
381 if (insert_this) {
382 // need typename keyword because qualified dependent names
383 // are not valid types unless preceded by 'typename'.
384 typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
385
386 // find a place where current node should be inserted
387 for (; lit != nlist.end(); ++lit) {
388 if (distance < lit->second) {
389 break;
390 }
391 else {
392 continue;
393 }
394 }
395
396 nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
397
398 if (remove_back) {
399 nlist.pop_back();
400 }
401 }
402 }
403
404 UInt_t count = 1;
405 if (node->GetNodeL() && node->GetNodeR()) {
406 if (value < node->GetVarDis()) {
407 count += Find(nlist, node->GetNodeL(), event, nfind);
408 count += Find(nlist, node->GetNodeR(), event, nfind);
409 }
410 else {
411 count += Find(nlist, node->GetNodeR(), event, nfind);
412 count += Find(nlist, node->GetNodeL(), event, nfind);
413 }
414 }
415 else {
416 if (node->GetNodeL()) {
417 count += Find(nlist, node->GetNodeL(), event, nfind);
418 }
419 if (node->GetNodeR()) {
420 count += Find(nlist, node->GetNodeR(), event, nfind);
421 }
422 }
423
424 return count;
425}
426
427////////////////////////////////////////////////////////////////////////////////
428/// This is a global templated function that searches for k-nearest neighbors.
429/// list contains all nodes that are closest to event
430/// and have sum of event weights >= nfind.
431/// Only nodes with positive weights are added to list.
432/// Requirement for used classes:
433/// - each node contains maximum and minimum values of splitting variable
434/// for all its children
435/// - min and max range is checked to avoid descending into
436/// nodes that are definitely outside current minimum neighbourhood.
437///
438/// This function should be modified with care.
439
440template<class T>
441UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
442 const TMVA::kNN::Node<T> *node, const T &event, const Double_t nfind, Double_t ncurr)
443{
444
445 if (!node || !(nfind < 0.0)) {
446 return 0;
447 }
448
449 const Float_t value = event.GetVar(node->GetMod());
450
451 if (node->GetWeight() > 0.0) {
452
453 Float_t max_dist = 0.0;
454
455 if (!nlist.empty()) {
456
457 max_dist = nlist.back().second;
458
459 if (!(ncurr < nfind)) {
460 if (value > node->GetVarMax() &&
461 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
462 return 0;
463 }
464 if (value < node->GetVarMin() &&
465 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
466 return 0;
467 }
468 }
469 }
470
471 const Float_t distance = event.GetDist(node->GetEvent());
472
473 Bool_t insert_this = kFALSE;
474
475 if (ncurr < nfind) {
476 insert_this = kTRUE;
477 }
478 else if (!nlist.empty()) {
479 if (distance < max_dist) {
480 insert_this = kTRUE;
481 }
482 }
483 else {
484 std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
485 return 1;
486 }
487
488 if (insert_this) {
489 // (re)compute total current weight when inserting a new node
490 ncurr = 0;
491
492 // need typename keyword because qualified dependent names
493 // are not valid types unless preceded by 'typename'.
494 typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
495
496 // find a place where current node should be inserted
497 for (; lit != nlist.end(); ++lit) {
498 if (distance < lit->second) {
499 break;
500 }
501
502 ncurr += lit -> first -> GetWeight();
503 }
504
505 lit = nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
506
507 for (; lit != nlist.end(); ++lit) {
508 ncurr += lit -> first -> GetWeight();
509 if (!(ncurr < nfind)) {
510 ++lit;
511 break;
512 }
513 }
514
515 if(lit != nlist.end())
516 {
517 nlist.erase(lit, nlist.end());
518 }
519 }
520 }
521
522 UInt_t count = 1;
523 if (node->GetNodeL() && node->GetNodeR()) {
524 if (value < node->GetVarDis()) {
525 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
526 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
527 }
528 else {
529 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
530 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
531 }
532 }
533 else {
534 if (node->GetNodeL()) {
535 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
536 }
537 if (node->GetNodeR()) {
538 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
539 }
540 }
541
542 return count;
543}
544
545#endif
546
int Int_t
Definition: RtypesCore.h:45
unsigned int UInt_t
Definition: RtypesCore.h:46
const Bool_t kFALSE
Definition: RtypesCore.h:101
bool Bool_t
Definition: RtypesCore.h:63
double Double_t
Definition: RtypesCore.h:59
float Float_t
Definition: RtypesCore.h:57
const Bool_t kTRUE
Definition: RtypesCore.h:100
This file contains binary tree and global function template that searches tree for k-nearest neigbors...
Definition: NodekNN.h:68
Float_t GetVarMin() const
Definition: NodekNN.h:193
void SetNodeL(Node *node)
Definition: NodekNN.h:145
const Node * fNodeP
Definition: NodekNN.h:107
void SetNodeR(Node *node)
Definition: NodekNN.h:151
Double_t GetWeight() const
Definition: NodekNN.h:181
Float_t GetVarDis() const
Definition: NodekNN.h:187
Float_t GetVarMax() const
Definition: NodekNN.h:199
const Node * GetNodeL() const
Definition: NodekNN.h:163
Float_t fVarMin
Definition: NodekNN.h:116
const Node * GetNodeP() const
Definition: NodekNN.h:175
const Float_t fVarDis
Definition: NodekNN.h:114
const UInt_t fMod
Definition: NodekNN.h:119
void Print() const
Definition: NodekNN.h:286
Node * fNodeR
Definition: NodekNN.h:110
const Node & operator=(const Node &)
Node(const Node &)
const Node * GetNodeR() const
Definition: NodekNN.h:169
UInt_t GetMod() const
Definition: NodekNN.h:205
const T fEvent
Definition: NodekNN.h:112
const Node * Add(const T &event, UInt_t depth)
This is Node member function that adds a new node to a binary tree.
Definition: NodekNN.h:250
Node * fNodeL
Definition: NodekNN.h:109
Float_t fVarMax
Definition: NodekNN.h:117
const T & GetEvent() const
Definition: NodekNN.h:157
kNN::Event describes point in input variable vector-space, with additional functionality like distanc...
double T(double x)
Definition: ChebyshevPol.h:34
void Print(std::ostream &os, const OptionType &opt)
static constexpr double second
UInt_t Find(std::list< std::pair< const Node< T > *, Float_t > > &nlist, const Node< T > *node, const T &event, UInt_t nfind)
UInt_t Depth(const Node< T > *node)
Definition: NodekNN.h:214
create variable transformations
Definition: first.py:1