ROOT  6.07/01
Reference Guide
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Properties Friends Macros Groups Pages
newKDTreeTest.cxx
Go to the documentation of this file.
1 // @(#)root/mathcore:$Id$
2 // Author: C. Gumpert 09/2011
3 
4 // program to test new KDTree class
5 
6 #include <time.h>
7 // STL include(s)
8 #include <iostream>
9 #include <stdlib.h>
10 #include <vector>
11 #include "assert.h"
12 
13 // custom include(s)
14 #include "Math/KDTree.h"
15 #include "Math/TDataPoint.h"
16 
/// Append nPoints heap-allocated pseudo-random data points to vDataPoints.
///
/// For every point the coordinates are drawn first (in dimension order) and
/// the weight last, each as rand() % 1000, so the generated data are
/// reproducible for a given srand() seed.
///
/// @param nPoints      number of points to create
/// @param vDataPoints  vector the new points are appended to; the caller owns
///                     the pointed-to objects and has to delete them
template<class _DataPoint>
void CreatePseudoData(const unsigned long int nPoints,std::vector<const _DataPoint*>& vDataPoints)
{
   // grow once up front: avoids repeated reallocations and ensures that the
   // push_back() below never has to reallocate (and therefore cannot throw
   // and leak the freshly allocated point)
   vDataPoints.reserve(vDataPoints.size() + nPoints);

   const unsigned int nDim = _DataPoint::Dimension();
   for(unsigned long int i = 0; i < nPoints; ++i)
   {
      _DataPoint* pData = new _DataPoint();
      for(unsigned int k = 0; k < nDim; ++k)
         pData->SetCoordinate(k,rand() % 1000);
      pData->SetWeight(rand() % 1000);
      vDataPoints.push_back(pData);
   }
}
30 
/// Delete every data point referenced by vDataPoints and empty the vector.
///
/// The stored pointers are deleted front to back; afterwards the vector is
/// cleared so that it no longer holds any of the released pointers.
///
/// @param vDataPoints  vector of owned data points; empty on return
template<class _DataPoint>
void DeletePseudoData(std::vector<const _DataPoint*>& vDataPoints)
{
   typedef typename std::vector<const _DataPoint*>::size_type tIndex;

   const tIndex nPoints = vDataPoints.size();
   for(tIndex iPoint = 0; iPoint < nPoints; ++iPoint)
      delete vDataPoints[iPoint];

   vDataPoints.clear();
}
40 
41 template<class _DataPoint>
42 ROOT::Math::KDTree<_DataPoint>* BuildTree(const std::vector<const _DataPoint*>& vDataPoints,const unsigned int iBucketSize)
43 {
45  try
46  {
47  pTree = new ROOT::Math::KDTree<_DataPoint>(iBucketSize);
48  //pTree->SetSplitOption(TKDTree<_DataPoint>::kBinContent);
49 
50  for(typename std::vector<const _DataPoint*>::const_iterator it = vDataPoints.begin(); it != vDataPoints.end(); ++it)
51  pTree->Insert(**it);
52  }
53  catch (std::exception& e)
54  {
55  std::cerr << "exception caught: " << e.what() << std::endl;
56  if(pTree)
57  delete pTree;
58 
59  pTree = 0;
60  }
61 
62  return pTree;
63 }
64 
65 template<class _DataPoint>
66 bool CheckBasicTreeProperties(const ROOT::Math::KDTree<_DataPoint>* pTree,const std::vector<const _DataPoint*>& vDataPoints)
67 {
68  if(pTree->GetEntries() != vDataPoints.size())
69  {
70  std::cout << " --> wrong number of data points in tree: " << pTree->GetEntries() << " != " << vDataPoints.size() << std::endl;
71  return false;
72  }
73 
74  double fSumw = 0;
75  double fSumw2 = 0;
76  for(typename std::vector<const _DataPoint*>::const_iterator it = vDataPoints.begin();
77  it != vDataPoints.end(); ++it)
78  {
79  fSumw += (*it)->GetWeight();
80  fSumw2 += pow((*it)->GetWeight(),2);
81  }
82 
83  if(fabs(pTree->GetTotalSumw2() - fSumw2)/fSumw2 > 1e-4)
84  {
85  std::cout << " --> inconsistent Sum weights^2 in tree: " << pTree->GetTotalSumw2() << " != " << fSumw2 << std::endl;
86  return false;
87  }
88 
89  if(fabs(pTree->GetTotalSumw() - fSumw)/fSumw > 1e-4)
90  {
91  std::cout << " --> inconsistent Sum weights in tree: " << pTree->GetTotalSumw() << " != " << fSumw << std::endl;
92  return false;
93  }
94 
95  if(fabs(pTree->GetEffectiveEntries() - pow(fSumw,2)/fSumw2)/(pow(fSumw,2)/fSumw2) > 1e-4)
96  {
97  std::cout << " --> inconsistent effective entries in tree: " << pTree->GetEffectiveEntries() << " != " << pow(fSumw,2)/fSumw2 << std::endl;
98  return false;
99  }
100 
101  return true;
102 }
103 
104 template<class _DataPoint>
106 {
107  typedef std::pair<typename _DataPoint::value_type,typename _DataPoint::value_type> tBoundary;
108 
109  std::cout << " --> checking " << pTree->GetNBins() << " bins" << std::endl;
110 
111  unsigned int iBin = 0;
112  for(typename ROOT::Math::KDTree<_DataPoint>::iterator it = pTree->First(); it != pTree->End(); ++it,++iBin)
113  {
114  const std::vector<const _DataPoint*> vDataPoints = it.TN()->GetPoints();
115  assert(vDataPoints.size() == it->GetEntries());
116 
117  std::vector<tBoundary> vBoundaries = it->GetBoundaries();
118  assert(_DataPoint::Dimension() == vBoundaries.size());
119 
120  // check whether all points in this bin are inside the boundaries
121  for(typename std::vector<const _DataPoint*>::const_iterator pit = vDataPoints.begin();
122  pit != vDataPoints.end(); ++pit)
123  {
124  for(unsigned int k = 0; k < _DataPoint::Dimension(); ++k)
125  {
126  if(((*pit)->GetCoordinate(k) < vBoundaries.at(k).first) || ((*pit)->GetCoordinate(k) > vBoundaries.at(k).second))
127  {
128  std::cout << " --> boundaries of bin " << iBin << " in " << k << ". dimension are inconsistent with data point in bucket" << std::endl;
129  return false;
130  }
131  }
132  }
133  }
134 
135  return true;
136 }
137 
138 template<class _DataPoint>
140 {
141  for(typename ROOT::Math::KDTree<_DataPoint>::iterator it = pTree->First(); it != pTree->End(); ++it)
142  {
143  if(it->GetEffectiveEntries() > 2*pTree->GetBucketSize())
144  {
145  std::cout << " --> found bin with " << it->GetEffectiveEntries() << " while the bucketsize is " << pTree->GetBucketSize() << std::endl;
146  return false;
147  }
148  }
149 
150  return true;
151 }
152 
153 template<class _DataPoint>
155 {
156  typedef std::pair<typename _DataPoint::value_type,typename _DataPoint::value_type> tBoundary;
157 
158  _DataPoint test;
159  std::cout << " --> test reference point at (";
160  for(unsigned int k = 0; k < _DataPoint::Dimension(); ++k)
161  {
162  test.SetCoordinate(k,rand() % 1000);
163  std::cout << test.GetCoordinate(k);
164  if(k < _DataPoint::Dimension()-1)
165  std::cout << ",";
166  }
167  std::cout << ")" << std::endl;
168 
169  const typename ROOT::Math::KDTree<_DataPoint>::Bin* bin = pTree->FindBin(test);
170 
171  // check whether test point is actually inside the bin boundaries
172  // is not necessarily the case if the point as the range of the bin which is NOT determined by a splitting but by the minimum coordinate of points inside the bin
173  std::vector<tBoundary> vBoundaries = bin->GetBoundaries();
174  assert(_DataPoint::Dimension() == vBoundaries.size());
175 
176  for(unsigned int k = 0; k < _DataPoint::Dimension(); ++k)
177  {
178  if((test.GetCoordinate(k) < vBoundaries.at(k).first) || (test.GetCoordinate(k) > vBoundaries.at(k).second))
179  {
180  if(pTree->IsFrozen() && bin)
181  {
182  std::cout << " --> " << test.GetCoordinate(k) << " is not within (" << vBoundaries.at(k).first << "," << vBoundaries.at(k).second << ")" << std::endl;
183  return false;
184  }
185  }
186  }
187 
188  return true;
189 }
190 
191 template<class _DataPoint>
192 bool CheckNearestNeighborSearches(const ROOT::Math::KDTree<_DataPoint>* pTree,const std::vector<const _DataPoint*>& vDataPoints)
193 {
194  _DataPoint test;
195  std::cout << " --> test with reference point at (";
196  for(unsigned int k = 0; k < _DataPoint::Dimension(); ++k)
197  {
198  test.SetCoordinate(k,rand() % 1000);
199  std::cout << test.GetCoordinate(k);
200  if(k < _DataPoint::Dimension()-1)
201  std::cout << ",";
202  }
203  std::cout << ")" << std::endl;
204 
205  std::vector<const _DataPoint*> vFoundPoints;
206  std::vector<const _DataPoint*> vFoundPointsCheck;
207 
208  double fDist = rand() % 500;
209  std::cout << " --> look for points within in distance of " << fDist << std::endl;
210  pTree->GetPointsWithinDist(test,fDist,vFoundPoints);
211 
212  // get points by hand
213  for(typename std::vector<const _DataPoint*>::const_iterator it = vDataPoints.begin();
214  it != vDataPoints.end(); ++it)
215  {
216  if((*it)->Distance(test) <= fDist)
217  {
218  vFoundPointsCheck.push_back(*it);
219  // check whether this point was also found by the algorithm
220  bool bChecked = false;
221  for(unsigned int i = 0; i < vFoundPoints.size(); ++i)
222  {
223  if(vFoundPoints.at(i) == *it)
224  {
225  bChecked = true;
226  break;
227  }
228  }
229 
230  if(!bChecked)
231  {
232  std::cout << " --> point (";
233  for(unsigned int k = 0; k < _DataPoint::Dimension(); ++k)
234  {
235  std::cout << (*it)->GetCoordinate(k);
236  if(k < _DataPoint::Dimension()-1)
237  std::cout << ",";
238  }
239  std::cout << ") was not found by the algorithm while its distance to the reference point is " << (*it)->Distance(test) << std::endl;
240 
241  return false;
242  }
243  }
244  }
245 
246  if(vFoundPointsCheck.size() != vFoundPoints.size())
247  {
248  std::cout << " --> GetPointsWithinDist returns wrong number of found points (" << vFoundPointsCheck.size() << " expected/ " << vFoundPoints.size() << " found)" << std::endl;
249  return false;
250  }
251 
252  const int nNeighbors = (int)(rand() % 100/1000.0 * pTree->GetEntries() + 1);
253  std::cout << " --> look for " << nNeighbors << " nearest neighbors" << std::endl;
254 
255  std::vector<std::pair<const _DataPoint*,double> > vFoundNeighbors;
256  std::vector<std::pair<const _DataPoint*,double> > vFoundNeighborsCheck;
257  typename std::vector<std::pair<const _DataPoint*,double> >::iterator nit;
258 
259  pTree->GetClosestPoints(test,nNeighbors,vFoundNeighbors);
260  fDist = vFoundNeighbors.back().second;
261 
262  // check closest points manually
263  for(typename std::vector<const _DataPoint*>::const_iterator it = vDataPoints.begin();
264  it != vDataPoints.end(); ++it)
265  {
266  if((*it)->Distance(test) <= fDist)
267  vFoundNeighborsCheck.push_back(std::make_pair(*it,(*it)->Distance(test)));
268  }
269 
270  // vFoundNeighborsCheck can have more data points because there might be more points with the same (maximal) distance
271  if(vFoundNeighborsCheck.size() < vFoundNeighbors.size())
272  {
273  std::cout << " --> GetClosestPoints returns wrong number of found points (" << vFoundNeighborsCheck.size() << " expected/ " << vFoundNeighbors.size() << " found)" << std::endl;
274  return false;
275  }
276 
277  //check whether all points found by the algorithm are also found manually
278  bool bChecked = false;
279  for(unsigned int i = 0; i < vFoundNeighbors.size(); ++i)
280  {
281  bChecked = false;
282  for(unsigned int j = 0; j < vFoundNeighborsCheck.size(); ++j)
283  {
284  if(vFoundNeighbors.at(i).first == vFoundNeighborsCheck.at(j).first)
285  {
286  if(fabs(vFoundNeighbors.at(i).second - vFoundNeighborsCheck.at(j).second)/vFoundNeighbors.at(i).second < 1e-2)
287  bChecked = true;
288 
289  break;
290  }
291  }
292 
293  if(!bChecked)
294  return false;
295  }
296 
297  return true;
298 }
299 
300 template<class _DataPoint>
301 bool CheckTreeClear(ROOT::Math::KDTree<_DataPoint>* pTree,const std::vector<const _DataPoint*>& vDataPoints)
302 {
303  pTree->Reset();
304  if(pTree->GetEntries() != 0)
305  {
306  std::cout << " --> tree contains still " << pTree->GetEntries() << " data points after calling Clear()" << std::endl;
307  return false;
308  }
309  if(pTree->GetNBins() != 1)
310  {
311  std::cout << " --> tree contains more than one bin after calling Clear()" << std::endl;
312  return false;
313  }
314  if(pTree->GetEffectiveEntries() != 0)
315  {
316  std::cout << " --> tree contains still " << pTree->GetEffectiveEntries() << " effective entries after calling Clear()" << std::endl;
317  return false;
318  }
319 
320  // try to fill tree again
321  try
322  {
323  for(typename std::vector<const _DataPoint*>::const_iterator it = vDataPoints.begin(); it != vDataPoints.end(); ++it)
324  pTree->Insert(**it);
325  }
326  catch (std::exception& e)
327  {
328  std::cout << " --> unable to fill tree after calling Clear()" << std::endl;
329  std::cerr << "exception caught: " << e.what() << std::endl;
330 
331  return false;
332  }
333 
334  return true;
335 }
336 
337 int main()
338 {
339  std::cout << "\nunit test for class KDTree" << std::endl;
340  std::cout << "==========================\n" << std::endl;
341 
342  int iSeed = time(0);
343  std::cout << "using random seed: " << iSeed << std::endl;
344 
345  srand(iSeed);
346 
347  const unsigned long int NPOINTS = 1e5;
348  const unsigned int BUCKETSIZE = 1e2;
349  const unsigned int DIM = 5;
350 
351  typedef ROOT::Math::TDataPoint<DIM> DP;
352 
353  std::cout << "using " << NPOINTS << " data points in " << DIM << " dimensions" << std::endl;
354  std::cout << "bucket size: " << BUCKETSIZE << std::endl;
355 
356  std::vector<const DP*> vDataPoints;
357  CreatePseudoData(NPOINTS,vDataPoints);
358 
359  ROOT::Math::KDTree<DP>* pTree = BuildTree(vDataPoints,BUCKETSIZE);
360 
361  if(CheckBasicTreeProperties(pTree,vDataPoints))
362  std::cerr << "basic tree properties...DONE" << std::endl;
363  else
364  std::cerr << "basic tree properties...FAILED" << std::endl;
365 
366  if(CheckBinBoundaries(pTree))
367  std::cerr << "consistency check of bin boundaries...DONE" << std::endl;
368  else
369  std::cerr << "consistency check of bin boundaries...FAILED" << std::endl;
370 
371  if(CheckEffectiveBinEntries(pTree))
372  std::cerr << "check effective entries per bin...DONE" << std::endl;
373  else
374  std::cerr << "check effective entries per bin...FAILED" << std::endl;
375 
376  if(CheckFindBin(pTree))
377  std::cerr << "check FindBin...DONE" << std::endl;
378  else
379  std::cerr << "check FindBin...FAILED" << std::endl;
380 
381  if(CheckNearestNeighborSearches(pTree,vDataPoints))
382  std::cerr << "check nearest neighbor searches...DONE" << std::endl;
383  else
384  std::cerr << "check nearest neighbor searches...FAILED" << std::endl;
385 
386  if(CheckTreeClear(pTree,vDataPoints))
387  std::cerr << "check KDTree::Clear...DONE" << std::endl;
388  else
389  std::cerr << "check KDTree:Clear...FAILED" << std::endl;
390 
391  //pTree->Print();
392  pTree->Freeze();
393  //pTree->Print();
394  ROOT::Math::KDTree<DP>* pCopy = pTree->GetFrozenCopy();
395  //pCopy->Print();
396 
397  delete pCopy;
398  delete pTree;
399 
400  DeletePseudoData(vDataPoints);
401 
402  return 0;
403 }
void DeletePseudoData(std::vector< const _DataPoint * > &vDataPoints)
#define assert(cond)
Definition: unittest.h:542
Bool_t Insert(const point_type &rData)
Definition: KDTree.h:350
bool CheckNearestNeighborSearches(const ROOT::Math::KDTree< _DataPoint > *pTree, const std::vector< const _DataPoint * > &vDataPoints)
Double_t GetEffectiveEntries() const
Definition: KDTree.icc:220
bool CheckTreeClear(ROOT::Math::KDTree< _DataPoint > *pTree, const std::vector< const _DataPoint * > &vDataPoints)
iterator End()
Definition: KDTree.icc:99
bool CheckBasicTreeProperties(const ROOT::Math::KDTree< _DataPoint > *pTree, const std::vector< const _DataPoint * > &vDataPoints)
double pow(double, double)
ROOT::Math::KDTree< _DataPoint > * BuildTree(const std::vector< const _DataPoint * > &vDataPoints, const unsigned int iBucketSize)
VecExpr< UnaryOp< Fabs< T >, VecExpr< A, T, D >, T >, T, D > fabs(const VecExpr< A, T, D > &rhs)
Bool_t IsFrozen() const
Definition: KDTree.h:351
bool CheckFindBin(const ROOT::Math::KDTree< _DataPoint > *pTree)
Double_t GetTotalSumw2() const
Definition: KDTree.icc:317
void GetClosestPoints(const point_type &rRef, UInt_t nPoints, std::vector< std::pair< const _DataPoint *, Double_t > > &vFoundPoints) const
Definition: KDTree.icc:196
int main()
Double_t GetTotalSumw() const
Definition: KDTree.icc:304
Double_t GetBucketSize() const
Definition: KDTree.h:341
void CreatePseudoData(const unsigned long int nPoints, std::vector< const _DataPoint * > &vDataPoints)
const Bin * FindBin(const point_type &rPoint) const
Definition: KDTree.h:337
bool CheckEffectiveBinEntries(const ROOT::Math::KDTree< _DataPoint > *pTree)
KDTree< _DataPoint > * GetFrozenCopy()
Definition: KDTree.icc:252
UInt_t GetNBins() const
Definition: KDTree.icc:269
virtual const std::vector< tBoundary > & GetBoundaries() const
Definition: KDTree.h:199
bool CheckBinBoundaries(const ROOT::Math::KDTree< _DataPoint > *pTree)
UInt_t GetEntries() const
Definition: KDTree.icc:239
iterator First()
Definition: KDTree.icc:127
void GetPointsWithinDist(const point_type &rRef, value_type fDist, std::vector< const point_type * > &vFoundPoints) const
Definition: KDTree.icc:282