Logo ROOT   6.10/09
Reference Guide
NodekNN.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Rustem Ospanov
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : Node *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * kd-tree (binary tree) template *
12  * *
13  * Author: *
14  * Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA *
15  * *
16  * Copyright (c) 2007: *
17  * CERN, Switzerland *
18  * MPI-K Heidelberg, Germany *
19  * U. of Texas at Austin, USA *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://tmva.sourceforge.net/LICENSE) *
24  **********************************************************************************/
25 
26 #ifndef ROOT_TMVA_NodekNN
27 #define ROOT_TMVA_NodekNN
28 
29 // C++
#include <algorithm>
#include <cassert>
#include <iostream>
#include <list>
#include <string>
34 
35 // ROOT
36 #include "Rtypes.h"
37 
38 /*! \class TMVA::kNN::Node
39 \ingroup TMVA
40 This file contains binary tree and global function template
41 that searches tree for k-nearest neighbors
42 
43 Node class template parameter T has to provide these functions:
44  rtype GetVar(UInt_t) const;
45  - rtype is any type convertible to Float_t
46  UInt_t GetNVar(void) const;
47  rtype GetWeight(void) const;
48  - rtype is any type convertible to Double_t
49 
50 Find function template parameter T has to provide these functions:
51 (in addition to above requirements)
52  rtype GetDist(Float_t, UInt_t) const;
53  - rtype is any type convertible to Float_t
54  rtype GetDist(const T &) const;
55  - rtype is any type convertible to Float_t
56 
57  where T::GetDist(Float_t, UInt_t) <= T::GetDist(const T &)
58  for any pair of events and any variable number for these events
59 */
60 
61 namespace TMVA
62 {
63  namespace kNN
64  {
65  template <class T>
66  class Node
67  {
68 
69  public:
70 
71  Node(const Node *parent, const T &event, Int_t mod);
72  ~Node();
73 
74  const Node* Add(const T &event, UInt_t depth);
75 
76  void SetNodeL(Node *node);
77  void SetNodeR(Node *node);
78 
79  const T& GetEvent() const;
80 
81  const Node* GetNodeL() const;
82  const Node* GetNodeR() const;
83  const Node* GetNodeP() const;
84 
85  Double_t GetWeight() const;
86 
87  Float_t GetVarDis() const;
88  Float_t GetVarMin() const;
89  Float_t GetVarMax() const;
90 
91  UInt_t GetMod() const;
92 
93  void Print() const;
94  void Print(std::ostream& os, const std::string &offset = "") const;
95 
96  private:
97 
98  // these methods are private and not implemented by design
99  // use provided public constructor for all uses of this template class
100  Node();
101  Node(const Node &);
102  const Node& operator=(const Node &);
103 
104  private:
105 
106  const Node* fNodeP;
107 
110 
111  const T fEvent;
112 
114 
117 
118  const UInt_t fMod;
119  };
120 
121  // recursive search for k-nearest neighbor: k = nfind
122  template<class T>
123  UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
124  const Node<T> *node, const T &event, UInt_t nfind);
125 
126  // recursive search for k-nearest neighbor
127  // find k events with sum of event weights >= nfind
128  template<class T>
129  UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
130  const Node<T> *node, const T &event, Double_t nfind, Double_t ncurr);
131 
132  // recursively travel upward until root node is reached
133  template <class T>
134  UInt_t Depth(const Node<T> *node);
135 
136  // print node content and content of its children
137  //template <class T>
138  //std::ostream& operator<<(std::ostream& os, const Node<T> &node);
139 
140  //
141  // Inlined functions for Node template
142  //
143  template <class T>
144  inline void Node<T>::SetNodeL(Node<T> *node)
145  {
146  fNodeL = node;
147  }
148 
149  template <class T>
150  inline void Node<T>::SetNodeR(Node<T> *node)
151  {
152  fNodeR = node;
153  }
154 
155  template <class T>
156  inline const T& Node<T>::GetEvent() const
157  {
158  return fEvent;
159  }
160 
161  template <class T>
162  inline const Node<T>* Node<T>::GetNodeL() const
163  {
164  return fNodeL;
165  }
166 
167  template <class T>
168  inline const Node<T>* Node<T>::GetNodeR() const
169  {
170  return fNodeR;
171  }
172 
173  template <class T>
174  inline const Node<T>* Node<T>::GetNodeP() const
175  {
176  return fNodeP;
177  }
178 
179  template <class T>
181  {
182  return fEvent.GetWeight();
183  }
184 
185  template <class T>
187  {
188  return fVarDis;
189  }
190 
191  template <class T>
193  {
194  return fVarMin;
195  }
196 
197  template <class T>
199  {
200  return fVarMax;
201  }
202 
203  template <class T>
204  inline UInt_t Node<T>::GetMod() const
205  {
206  return fMod;
207  }
208 
209  //
210  // Inlined global function(s)
211  //
212  template <class T>
213  inline UInt_t Depth(const Node<T> *node)
214  {
215  if (!node) return 0;
216  else return Depth(node->GetNodeP()) + 1;
217  }
218 
219  } // end of kNN namespace
220 } // end of TMVA namespace
221 
222 ////////////////////////////////////////////////////////////////////////////////
223 template<class T>
224 TMVA::kNN::Node<T>::Node(const Node<T> *parent, const T &event, const Int_t mod)
225 :fNodeP(parent),
226  fNodeL(0),
227  fNodeR(0),
228  fEvent(event),
229  fVarDis(event.GetVar(mod)),
230  fVarMin(fVarDis),
231  fVarMax(fVarDis),
232  fMod(mod)
233 {}
234 
235 ////////////////////////////////////////////////////////////////////////////////
236 template<class T>
238 {
239  if (fNodeL) delete fNodeL;
240  if (fNodeR) delete fNodeR;
241 }
242 
243 ////////////////////////////////////////////////////////////////////////////////
244 /// This is Node member function that adds a new node to a binary tree.
245 /// each node contains maximum and minimum values of splitting variable
246 /// left or right nodes are added based on value of splitting variable
247 
248 template<class T>
249 const TMVA::kNN::Node<T>* TMVA::kNN::Node<T>::Add(const T &event, const UInt_t depth)
250 {
251 
252  assert(fMod == depth % event.GetNVar() && "Wrong recursive depth in Node<>::Add");
253 
254  const Float_t value = event.GetVar(fMod);
255 
256  fVarMin = std::min(fVarMin, value);
257  fVarMax = std::max(fVarMax, value);
258 
259  Node<T> *node = 0;
260  if (value < fVarDis) {
261  if (fNodeL)
262  {
263  return fNodeL->Add(event, depth + 1);
264  }
265  else {
266  fNodeL = new Node<T>(this, event, (depth + 1) % event.GetNVar());
267  node = fNodeL;
268  }
269  }
270  else {
271  if (fNodeR) {
272  return fNodeR->Add(event, depth + 1);
273  }
274  else {
275  fNodeR = new Node<T>(this, event, (depth + 1) % event.GetNVar());
276  node = fNodeR;
277  }
278  }
279 
280  return node;
281 }
282 
283 ////////////////////////////////////////////////////////////////////////////////
284 template<class T>
286 {
287  Print(std::cout);
288 }
289 
290 ////////////////////////////////////////////////////////////////////////////////
291 template<class T>
292 void TMVA::kNN::Node<T>::Print(std::ostream& os, const std::string &offset) const
293 {
294  os << offset << "-----------------------------------------------------------" << std::endl;
295  os << offset << "Node: mod " << fMod
296  << " at " << fVarDis
297  << " with weight: " << GetWeight() << std::endl
298  << offset << fEvent;
299 
300  if (fNodeL) {
301  os << offset << "Has left node " << std::endl;
302  }
303  if (fNodeR) {
304  os << offset << "Has right node" << std::endl;
305  }
306 
307  if (fNodeL) {
308  os << offset << "PrInt_t left node " << std::endl;
309  fNodeL->Print(os, offset + " ");
310  }
311  if (fNodeR) {
312  os << offset << "PrInt_t right node" << std::endl;
313  fNodeR->Print(os, offset + " ");
314  }
315 
316  if (!fNodeL && !fNodeR) {
317  os << std::endl;
318  }
319 }
320 
321 ////////////////////////////////////////////////////////////////////////////////
322 /// This is a global templated function that searches for k-nearest neighbors.
323 /// list contains k or less nodes that are closest to event.
324 /// only nodes with positive weights are added to list.
325 /// each node contains maximum and minimum values of splitting variable
326 /// for all its children - this range is checked to avoid descending into
327 /// nodes that are definitely outside current minimum neighbourhood.
328 ///
329 /// This function should be modified with care.
330 
331 template<class T>
332 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
333  const TMVA::kNN::Node<T> *node, const T &event, const UInt_t nfind)
334 {
335  if (!node || nfind < 1) {
336  return 0;
337  }
338 
339  const Float_t value = event.GetVar(node->GetMod());
340 
341  if (node->GetWeight() > 0.0) {
342 
343  Float_t max_dist = 0.0;
344 
345  if (!nlist.empty()) {
346 
347  max_dist = nlist.back().second;
348 
349  if (nlist.size() == nfind) {
350  if (value > node->GetVarMax() &&
351  event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
352  return 0;
353  }
354  if (value < node->GetVarMin() &&
355  event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
356  return 0;
357  }
358  }
359  }
360 
361  const Float_t distance = event.GetDist(node->GetEvent());
362 
363  Bool_t insert_this = kFALSE;
364  Bool_t remove_back = kFALSE;
365 
366  if (nlist.size() < nfind) {
367  insert_this = kTRUE;
368  }
369  else if (nlist.size() == nfind) {
370  if (distance < max_dist) {
371  insert_this = kTRUE;
372  remove_back = kTRUE;
373  }
374  }
375  else {
376  std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
377  return 1;
378  }
379 
380  if (insert_this) {
381  // need typename keyword because qualified dependent names
382  // are not valid types unless preceded by 'typename'.
383  typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
384 
385  // find a place where current node should be inserted
386  for (; lit != nlist.end(); ++lit) {
387  if (distance < lit->second) {
388  break;
389  }
390  else {
391  continue;
392  }
393  }
394 
395  nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
396 
397  if (remove_back) {
398  nlist.pop_back();
399  }
400  }
401  }
402 
403  UInt_t count = 1;
404  if (node->GetNodeL() && node->GetNodeR()) {
405  if (value < node->GetVarDis()) {
406  count += Find(nlist, node->GetNodeL(), event, nfind);
407  count += Find(nlist, node->GetNodeR(), event, nfind);
408  }
409  else {
410  count += Find(nlist, node->GetNodeR(), event, nfind);
411  count += Find(nlist, node->GetNodeL(), event, nfind);
412  }
413  }
414  else {
415  if (node->GetNodeL()) {
416  count += Find(nlist, node->GetNodeL(), event, nfind);
417  }
418  if (node->GetNodeR()) {
419  count += Find(nlist, node->GetNodeR(), event, nfind);
420  }
421  }
422 
423  return count;
424 }
425 
426 ////////////////////////////////////////////////////////////////////////////////
427 /// This is a global templated function that searches for k-nearest neighbors.
428 /// list contains all nodes that are closest to event
429 /// and have sum of event weights >= nfind.
430 /// Only nodes with positive weights are added to list.
431 /// Requirement for used classes:
432 /// - each node contains maximum and minimum values of splitting variable
433 /// for all its children
434 /// - min and max range is checked to avoid descending into
435 /// nodes that are definitely outside current minimum neighbourhood.
436 ///
437 /// This function should be modified with care.
438 
439 template<class T>
440 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
441  const TMVA::kNN::Node<T> *node, const T &event, const Double_t nfind, Double_t ncurr)
442 {
443 
444  if (!node || !(nfind < 0.0)) {
445  return 0;
446  }
447 
448  const Float_t value = event.GetVar(node->GetMod());
449 
450  if (node->GetWeight() > 0.0) {
451 
452  Float_t max_dist = 0.0;
453 
454  if (!nlist.empty()) {
455 
456  max_dist = nlist.back().second;
457 
458  if (!(ncurr < nfind)) {
459  if (value > node->GetVarMax() &&
460  event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
461  return 0;
462  }
463  if (value < node->GetVarMin() &&
464  event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
465  return 0;
466  }
467  }
468  }
469 
470  const Float_t distance = event.GetDist(node->GetEvent());
471 
472  Bool_t insert_this = kFALSE;
473 
474  if (ncurr < nfind) {
475  insert_this = kTRUE;
476  }
477  else if (!nlist.empty()) {
478  if (distance < max_dist) {
479  insert_this = kTRUE;
480  }
481  }
482  else {
483  std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
484  return 1;
485  }
486 
487  if (insert_this) {
488  // (re)compute total current weight when inserting a new node
489  ncurr = 0;
490 
491  // need typename keyword because qualified dependent names
492  // are not valid types unless preceded by 'typename'.
493  typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
494 
495  // find a place where current node should be inserted
496  for (; lit != nlist.end(); ++lit) {
497  if (distance < lit->second) {
498  break;
499  }
500 
501  ncurr += lit -> first -> GetWeight();
502  }
503 
504  lit = nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
505 
506  for (; lit != nlist.end(); ++lit) {
507  ncurr += lit -> first -> GetWeight();
508  if (!(ncurr < nfind)) {
509  ++lit;
510  break;
511  }
512  }
513 
514  if(lit != nlist.end())
515  {
516  nlist.erase(lit, nlist.end());
517  }
518  }
519  }
520 
521  UInt_t count = 1;
522  if (node->GetNodeL() && node->GetNodeR()) {
523  if (value < node->GetVarDis()) {
524  count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
525  count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
526  }
527  else {
528  count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
529  count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
530  }
531  }
532  else {
533  if (node->GetNodeL()) {
534  count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
535  }
536  if (node->GetNodeR()) {
537  count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
538  }
539  }
540 
541  return count;
542 }
543 
544 #endif
545 
const Node * fNodeP
Definition: NodekNN.h:106
Float_t GetVarMin() const
Definition: NodekNN.h:192
float Float_t
Definition: RtypesCore.h:53
double T(double x)
Definition: ChebyshevPol.h:34
Node * fNodeR
Definition: NodekNN.h:109
void Print() const
Definition: NodekNN.h:285
const Node * GetNodeP() const
Definition: NodekNN.h:174
int Int_t
Definition: RtypesCore.h:41
bool Bool_t
Definition: RtypesCore.h:59
UInt_t Find(std::list< std::pair< const Node< T > *, Float_t > > &nlist, const Node< T > *node, const T &event, Double_t nfind, Double_t ncurr)
const Node * Add(const T &event, UInt_t depth)
This is Node member function that adds a new node to a binary tree.
Definition: NodekNN.h:249
UInt_t Depth(const Node< T > *node)
Definition: NodekNN.h:213
Double_t GetWeight() const
Definition: NodekNN.h:180
const Node & operator=(const Node &)
const Node * GetNodeL() const
Definition: NodekNN.h:162
const T fEvent
Definition: NodekNN.h:111
void SetNodeL(Node *node)
Definition: NodekNN.h:144
const Node * GetNodeR() const
Definition: NodekNN.h:168
const Float_t fVarDis
Definition: NodekNN.h:113
UInt_t GetMod() const
Definition: NodekNN.h:204
unsigned int UInt_t
Definition: RtypesCore.h:42
const UInt_t fMod
Definition: NodekNN.h:118
const Bool_t kFALSE
Definition: RtypesCore.h:92
Float_t fVarMin
Definition: NodekNN.h:115
void Print(std::ostream &os, const OptionType &opt)
double Double_t
Definition: RtypesCore.h:55
Float_t fVarMax
Definition: NodekNN.h:116
Float_t GetVarMax() const
Definition: NodekNN.h:198
UInt_t Find(std::list< std::pair< const Node< T > *, Float_t > > &nlist, const Node< T > *node, const T &event, UInt_t nfind)
Abstract ClassifierFactory template that handles arbitrary types.
Float_t GetVarDis() const
Definition: NodekNN.h:186
void SetNodeR(Node *node)
Definition: NodekNN.h:150
Node * fNodeL
Definition: NodekNN.h:108
Definition: first.py:1
const Bool_t kTRUE
Definition: RtypesCore.h:91
const T & GetEvent() const
Definition: NodekNN.h:156