MethodKNN.cxx
// @(#)root/tmva $Id$
// Author: Rustem Ospanov

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : MethodKNN                                                             *
 * Web    : http://tmva.sourceforge.net                                          *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation                                                            *
 *                                                                                *
 * Author:                                                                        *
 *      Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA             *
 *                                                                                *
 * Copyright (c) 2007:                                                            *
 *      CERN, Switzerland                                                         *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Texas at Austin, USA                                                *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without            *
 * modification, are permitted according to the terms listed in LICENSE          *
 * (http://tmva.sourceforge.net/LICENSE)                                         *
 **********************************************************************************/

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// MethodKNN                                                            //
//                                                                      //
// Analysis of k-nearest neighbor                                       //
//                                                                      //
//////////////////////////////////////////////////////////////////////////
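
// A minimal usage sketch (not part of the original file): the method is
// normally booked through the TMVA Factory with an option string built from
// the options declared in DeclareOptions() below. Here "outputFile" is a
// user-provided TFile* and the option values shown are just the defaults.
//
//    TMVA::Factory factory("TMVAClassification", outputFile,
//                          "!V:AnalysisType=Classification");
//    // ... declare input variables and signal/background trees here ...
//    factory.BookMethod(TMVA::Types::kKNN, "KNN",
//                       "nkNN=20:ScaleFrac=0.8:UseKernel=F:UseWeight=T:!Trim");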

// C/C++
#include <cmath>
#include <string>
#include <sstream>
#include <cstdlib>

// ROOT
#include "TFile.h"
#include "TMath.h"
#include "TTree.h"

// TMVA
#include "TMVA/ClassifierFactory.h"
#include "TMVA/MethodKNN.h"
#include "TMVA/Ranking.h"
#include "TMVA/Tools.h"

REGISTER_METHOD(KNN)

ClassImp(TMVA::MethodKNN)

////////////////////////////////////////////////////////////////////////////////
/// standard constructor

TMVA::MethodKNN::MethodKNN( const TString& jobName,
                            const TString& methodTitle,
                            DataSetInfo& theData,
                            const TString& theOption,
                            TDirectory* theTargetDir )
   : TMVA::MethodBase(jobName, Types::kKNN, methodTitle, theData, theOption, theTargetDir)
   , fSumOfWeightsS(0)
   , fSumOfWeightsB(0)
   , fModule(0)
   , fnkNN(0)
   , fBalanceDepth(0)
   , fScaleFrac(0)
   , fSigmaFact(0)
   , fTrim(kFALSE)
   , fUseKernel(kFALSE)
   , fUseWeight(kFALSE)
   , fUseLDA(kFALSE)
   , fTreeOptDepth(0)
{
}

////////////////////////////////////////////////////////////////////////////////
/// constructor from weight file

TMVA::MethodKNN::MethodKNN( DataSetInfo& theData,
                            const TString& theWeightFile,
                            TDirectory* theTargetDir )
   : TMVA::MethodBase( Types::kKNN, theData, theWeightFile, theTargetDir)
   , fSumOfWeightsS(0)
   , fSumOfWeightsB(0)
   , fModule(0)
   , fnkNN(0)
   , fBalanceDepth(0)
   , fScaleFrac(0)
   , fSigmaFact(0)
   , fTrim(kFALSE)
   , fUseKernel(kFALSE)
   , fUseWeight(kFALSE)
   , fUseLDA(kFALSE)
   , fTreeOptDepth(0)
{
}

////////////////////////////////////////////////////////////////////////////////
/// destructor

TMVA::MethodKNN::~MethodKNN(void)
{
   if (fModule) delete fModule;
}

////////////////////////////////////////////////////////////////////////////////
/// MethodKNN options

void TMVA::MethodKNN::DeclareOptions()
{
   // fnkNN         = 20;    // number of k-nearest neighbors
   // fBalanceDepth = 6;     // number of binary tree levels used for tree balancing
   // fScaleFrac    = 0.8;   // fraction of events used to compute variable width
   // fSigmaFact    = 1.0;   // scale factor for Gaussian sigma
   // fKernel       = use polynomial (1-x^3)^3 or Gaussian kernel
   // fTrim         = false; // use equal number of signal and background events
   // fUseKernel    = false; // use polynomial kernel weight function
   // fUseWeight    = true;  // count events using weights
   // fUseLDA       = false

   DeclareOptionRef(fnkNN         = 20,     "nkNN",         "Number of k-nearest neighbors");
   DeclareOptionRef(fBalanceDepth = 6,      "BalanceDepth", "Binary tree balance depth");
   DeclareOptionRef(fScaleFrac    = 0.80,   "ScaleFrac",    "Fraction of events used to compute variable width");
   DeclareOptionRef(fSigmaFact    = 1.0,    "SigmaFact",    "Scale factor for sigma in Gaussian kernel");
   DeclareOptionRef(fKernel       = "Gaus", "Kernel",       "Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
   DeclareOptionRef(fTrim         = kFALSE, "Trim",         "Use equal number of signal and background events");
   DeclareOptionRef(fUseKernel    = kFALSE, "UseKernel",    "Use polynomial kernel weight");
   DeclareOptionRef(fUseWeight    = kTRUE,  "UseWeight",    "Use weight to count kNN events");
   DeclareOptionRef(fUseLDA       = kFALSE, "UseLDA",       "Use local linear discriminant - experimental feature");
}

////////////////////////////////////////////////////////////////////////////////
/// options that are used ONLY for the READER to ensure backward compatibility

void TMVA::MethodKNN::DeclareCompatibilityOptions() {
   MethodBase::DeclareCompatibilityOptions();
   DeclareOptionRef(fTreeOptDepth = 6, "TreeOptDepth", "Binary tree optimisation depth");
}

////////////////////////////////////////////////////////////////////////////////
/// process the options specified by the user

void TMVA::MethodKNN::ProcessOptions()
{
   if (!(fnkNN > 0)) {
      fnkNN = 10;
      Log() << kWARNING << "kNN must be a positive integer: set kNN = " << fnkNN << Endl;
   }
   if (fScaleFrac < 0.0) {
      fScaleFrac = 0.0;
      Log() << kWARNING << "ScaleFrac cannot be negative: set ScaleFrac = " << fScaleFrac << Endl;
   }
   if (fScaleFrac > 1.0) {
      fScaleFrac = 1.0;
   }
   if (!(fBalanceDepth > 0)) {
      fBalanceDepth = 6;
      Log() << kWARNING << "BalanceDepth must be a positive integer: set BalanceDepth = " << fBalanceDepth << Endl;
   }

   Log() << kVERBOSE
         << "kNN options:\n"
         << "  kNN          = " << fnkNN      << "\n"
         << "  UseKernel    = " << fUseKernel << "\n"
         << "  SigmaFact    = " << fSigmaFact << "\n"
         << "  ScaleFrac    = " << fScaleFrac << "\n"
         << "  Kernel       = " << fKernel    << "\n"
         << "  Trim         = " << fTrim      << "\n"
         << "  BalanceDepth = " << fBalanceDepth << Endl;
}

////////////////////////////////////////////////////////////////////////////////
/// KNN can handle classification with two classes and regression with one regression target

Bool_t TMVA::MethodKNN::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/ )
{
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   if (type == Types::kRegression) return kTRUE;
   return kFALSE;
}

////////////////////////////////////////////////////////////////////////////////
/// Initialization

void TMVA::MethodKNN::Init(void)
{
   // fScaleFrac <= 0.0 : do not scale input variables
   // fScaleFrac >= 1.0 : use all event coordinates to scale input variables

   fModule = new kNN::ModulekNN();
   fSumOfWeightsS = 0;
   fSumOfWeightsB = 0;
}

////////////////////////////////////////////////////////////////////////////////
/// create kNN

void TMVA::MethodKNN::MakeKNN(void)
{
   if (!fModule) {
      Log() << kFATAL << "ModulekNN is not created" << Endl;
   }

   fModule->Clear();

   std::string option;
   if (fScaleFrac > 0.0) {
      option += "metric";
   }
   if (fTrim) {
      option += "trim";
   }

   Log() << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;

   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
      fModule->Add(*event);
   }

   // create binary tree: balance it to fBalanceDepth levels and use the
   // central fScaleFrac fraction of events (passed as a percentage) to
   // compute the variable widths for the distance metric
   fModule->Fill(static_cast<UInt_t>(fBalanceDepth),
                 static_cast<UInt_t>(100.0*fScaleFrac),
                 option);
}

////////////////////////////////////////////////////////////////////////////////
/// kNN training

void TMVA::MethodKNN::Train(void)
{
   Log() << kINFO << "<Train> start..." << Endl;

   if (IsNormalised()) {
      Log() << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
      fScaleFrac = 0.0;
   }

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }
   if (GetNVariables() < 1)
      Log() << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;

   Log() << kINFO << "Reading " << GetNEvents() << " events" << Endl;

   for (UInt_t ievt = 0; ievt < GetNEvents(); ++ievt) {
      // read the training event
      const Event* evt_ = GetEvent(ievt);
      Double_t weight = evt_->GetWeight();

      // skip events with negative weights if they are to be ignored in training
      if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;

      kNN::VarVec vvec(GetNVariables(), 0.0);
      for (UInt_t ivar = 0; ivar < evt_->GetNVariables(); ++ivar) vvec[ivar] = evt_->GetValue(ivar);

      Short_t event_type = 0;

      if (DataInfo().IsSignal(evt_)) { // signal type = 1
         fSumOfWeightsS += weight;
         event_type = 1;
      }
      else { // background type = 2
         fSumOfWeightsB += weight;
         event_type = 2;
      }

      // create event and add classification variables, weight, type and regression variables
      kNN::Event event_knn(vvec, weight, event_type);
      event_knn.SetTargets(evt_->GetTargets());
      fEvent.push_back(event_knn);
   }

   Log() << kINFO
         << "Number of signal events     " << fSumOfWeightsS << Endl
         << "Number of background events " << fSumOfWeightsB << Endl;

   // create kd-tree (binary tree) structure
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// Compute classifier response

Double_t TMVA::MethodKNN::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   // cannot determine error
   NoErrorCalc(err, errUpper);

   // define local variables
   const Event *ev = GetEvent();
   const Int_t nvar = GetNVariables();
   const Double_t weight = ev->GetWeight();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = ev->GetValue(ivar);
   }

   // search for fnkNN+2 nearest neighbors, padded with two events to guard
   // against Monte Carlo events with zero distance; most of the CPU time is
   // spent in this recursive function
   const kNN::Event event_knn(vvec, weight, 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list has unexpected size" << Endl;
      return -100.0;
   }

   if (fUseLDA) return MethodKNN::getLDAValue(rlist, event_knn);

   // set flags for kernel option = Gaus, Poln
   Bool_t use_gaus = false, use_poln = false;
   if (fUseKernel) {
      if      (fKernel == "Gaus") use_gaus = true;
      else if (fKernel == "Poln") use_poln = true;
   }

   // compute radius for polynomial kernel
   Double_t kradius = -1.0;
   if (use_poln) {
      kradius = MethodKNN::getKernelRadius(rlist);

      if (!(kradius > 0.0)) {
         Log() << kFATAL << "kNN radius is not positive" << Endl;
         return -100.0;
      }

      kradius = 1.0/TMath::Sqrt(kradius);
   }

   // compute RMS of variable differences for Gaussian sigma
   std::vector<Double_t> rms_vec;
   if (use_gaus) {
      rms_vec = TMVA::MethodKNN::getRMS(rlist, event_knn);

      if (rms_vec.empty() || rms_vec.size() != event_knn.GetNVar()) {
         Log() << kFATAL << "Failed to compute RMS vector" << Endl;
         return -100.0;
      }
   }

   UInt_t count_all = 0;
   Double_t weight_all = 0, weight_sig = 0, weight_bac = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);

      // warn about Monte Carlo events with zero distance:
      // this happens when the query event is also in the learning sample
      if (lit->second < 0.0) {
         Log() << kFATAL << "A neighbor has negative distance to query event" << Endl;
      }
      else if (!(lit->second > 0.0)) {
         Log() << kVERBOSE << "A neighbor has zero distance to query event" << Endl;
      }

      // get event weight and scale weight by kernel function
      Double_t evweight = node.GetWeight();
      if      (use_gaus) evweight *= MethodKNN::GausKernel(event_knn, node.GetEvent(), rms_vec);
      else if (use_poln) evweight *= MethodKNN::PolnKernel(TMath::Sqrt(lit->second)*kradius);

      if (fUseWeight) weight_all += evweight;
      else            ++weight_all;

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         if (fUseWeight) weight_sig += evweight;
         else            ++weight_sig;
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
         if (fUseWeight) weight_bac += evweight;
         else            ++weight_bac;
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }

      // use only fnkNN events
      ++count_all;

      if (count_all >= knn) {
         break;
      }
   }

   // check that the total number of events is positive
   if (!(count_all > 0)) {
      Log() << kFATAL << "kNN result list size is not positive" << Endl;
      return -100.0;
   }

   // check that the number of events matches k
   if (count_all < knn) {
      Log() << kDEBUG << "count_all and kNN have different size: " << count_all << " < " << knn << Endl;
   }

   // check that the total weight is positive
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "kNN result total weight is not positive" << Endl;
      return -100.0;
   }

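   // the returned discriminant is the (kernel-weighted) signal fraction
   // among the k nearest neighbors,
   //    D = weight_sig / (weight_sig + weight_bac) = weight_sig / weight_all,
   // so D lies in [0,1] and is close to 1 for signal-like events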
   return weight_sig/weight_all;
}

////////////////////////////////////////////////////////////////////////////////
///
/// Return vector of averages for target values of k-nearest neighbors.
/// Uses a local copy of the regression vector rather than a pointer to it.
///

const std::vector< Float_t >& TMVA::MethodKNN::GetRegressionValues()
{
   if (fRegressionReturnVal == 0)
      fRegressionReturnVal = new std::vector<Float_t>;
   else
      fRegressionReturnVal->clear();

   // define local variables
   const Event *evt = GetEvent();
   const Int_t nvar = GetNVariables();
   const UInt_t knn = static_cast<UInt_t>(fnkNN);
   std::vector<float> reg_vec;

   kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);

   for (Int_t ivar = 0; ivar < nvar; ++ivar) {
      vvec[ivar] = evt->GetValue(ivar);
   }

   // search for fnkNN+2 nearest neighbors, padded with two events to guard
   // against Monte Carlo events with zero distance; most of the CPU time is
   // spent in this recursive function
   const kNN::Event event_knn(vvec, evt->GetWeight(), 3);
   fModule->Find(event_knn, knn + 2);

   const kNN::List &rlist = fModule->GetkNNList();
   if (rlist.size() != knn + 2) {
      Log() << kFATAL << "kNN result list has unexpected size" << Endl;
      return *fRegressionReturnVal;
   }

   // compute regression values
   Double_t weight_all = 0;
   UInt_t count_all = 0;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetTargets();
      const Double_t weight = node.GetEvent().GetWeight();

      if (reg_vec.empty()) {
         reg_vec = kNN::VarVec(tvec.size(), 0.0);
      }

      for (UInt_t ivar = 0; ivar < tvec.size(); ++ivar) {
         if (fUseWeight) reg_vec[ivar] += tvec[ivar]*weight;
         else            reg_vec[ivar] += tvec[ivar];
      }

      if (fUseWeight) weight_all += weight;
      else            ++weight_all;

      // use only fnkNN events
      ++count_all;

      if (count_all == knn) {
         break;
      }
   }

   // check that the total weight sum is positive
   if (!(weight_all > 0.0)) {
      Log() << kFATAL << "Total weight sum is not positive: " << weight_all << Endl;
      return *fRegressionReturnVal;
   }

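   // each output is the (weight-)averaged target of the k neighbors,
   //    t_hat[j] = sum_i w_i * t_i[j] / sum_i w_i,
   // with w_i = 1 for all neighbors when UseWeight is off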
   for (UInt_t ivar = 0; ivar < reg_vec.size(); ++ivar) {
      reg_vec[ivar] /= weight_all;
   }

   // copy result
   fRegressionReturnVal->insert(fRegressionReturnVal->begin(), reg_vec.begin(), reg_vec.end());

   return *fRegressionReturnVal;
}

////////////////////////////////////////////////////////////////////////////////
/// no ranking available

const TMVA::Ranking* TMVA::MethodKNN::CreateRanking()
{
   return 0;
}

////////////////////////////////////////////////////////////////////////////////
/// write weights to XML

void TMVA::MethodKNN::AddWeightsXMLTo( void* parent ) const {
   void* wght = gTools().AddChild(parent, "Weights");
   gTools().AddAttr(wght,"NEvents",fEvent.size());
   if (fEvent.size()>0) gTools().AddAttr(wght,"NVar",fEvent.begin()->GetNVar());
   if (fEvent.size()>0) gTools().AddAttr(wght,"NTgt",fEvent.begin()->GetNTgt());

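   // a sketch of the layout written below (one <Event> node per stored
   // event; the node text holds the variables followed by the targets):
   //
   //    <Weights NEvents="..." NVar="..." NTgt="...">
   //       <Event Type="1" Weight="...">x1 x2 ... t1 ...</Event>
   //       ...
   //    </Weights>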
   for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {

      std::stringstream s("");
      s.precision( 16 );
      for (UInt_t ivar = 0; ivar < event->GetNVar(); ++ivar) {
         if (ivar>0) s << " ";
         s << std::scientific << event->GetVar(ivar);
      }

      for (UInt_t itgt = 0; itgt < event->GetNTgt(); ++itgt) {
         s << " " << std::scientific << event->GetTgt(itgt);
      }

      void* evt = gTools().AddChild(wght, "Event", s.str().c_str());
      gTools().AddAttr(evt,"Type", event->GetType());
      gTools().AddAttr(evt,"Weight", event->GetWeight());
   }
}

////////////////////////////////////////////////////////////////////////////////
/// read weights from XML

void TMVA::MethodKNN::ReadWeightsFromXML( void* wghtnode ) {
   void* ch = gTools().GetChild(wghtnode); // first event
   UInt_t nvar = 0, ntgt = 0;
   gTools().ReadAttr( wghtnode, "NVar", nvar );
   gTools().ReadAttr( wghtnode, "NTgt", ntgt );

   Short_t evtType(0);
   Double_t evtWeight(0);

   while (ch) {
      // build event
      kNN::VarVec vvec(nvar, 0);
      kNN::VarVec tvec(ntgt, 0);

      gTools().ReadAttr( ch, "Type", evtType );
      gTools().ReadAttr( ch, "Weight", evtWeight );
      std::stringstream s( gTools().GetContent(ch) );

      for(UInt_t ivar=0; ivar<nvar; ivar++)
         s >> vvec[ivar];

      for(UInt_t itgt=0; itgt<ntgt; itgt++)
         s >> tvec[itgt];

      ch = gTools().GetNextChild(ch);

      kNN::Event event_knn(vvec, evtWeight, evtType, tvec);
      fEvent.push_back(event_knn);
   }

   // create kd-tree (binary tree) structure
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// read the weights

void TMVA::MethodKNN::ReadWeightsFromStream(std::istream& is)
{
   Log() << kINFO << "Starting ReadWeightsFromStream(std::istream& is) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   UInt_t nvar = 0;

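   // expected line format, inferred from the parser below: one event per
   // line, comma-separated, with lines containing '#' skipped as comments:
   //
   //    ievent,type,weight,var1,var2,...,varN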
   while (!is.eof()) {
      std::string line;
      std::getline(is, line);

      if (line.empty() || line.find("#") != std::string::npos) {
         continue;
      }

      UInt_t count = 0;
      std::string::size_type pos=0;
      while( (pos=line.find(',',pos)) != std::string::npos ) { count++; pos++; }

      if (nvar == 0) {
         nvar = count - 2;
      }
      if (count < 3 || nvar != count - 2) {
         Log() << kFATAL << "Missing comma delimiter(s)" << Endl;
      }

      // Int_t ievent = -1;
      Int_t type = -1;
      Double_t weight = -1.0;

      kNN::VarVec vvec(nvar, 0.0);

      UInt_t vcount = 0;
      std::string::size_type prev = 0;

      for (std::string::size_type ipos = 0; ipos < line.size(); ++ipos) {
         if (line[ipos] != ',' && ipos + 1 != line.size()) {
            continue;
         }

         if (!(ipos > prev)) {
            Log() << kFATAL << "Wrong substring limits" << Endl;
         }

         std::string vstring = line.substr(prev, ipos - prev);
         if (ipos + 1 == line.size()) {
            vstring = line.substr(prev, ipos - prev + 1);
         }

         if (vstring.empty()) {
            Log() << kFATAL << "Failed to parse string" << Endl;
         }

         if (vcount == 0) {
            // ievent = std::atoi(vstring.c_str());
         }
         else if (vcount == 1) {
            type = std::atoi(vstring.c_str());
         }
         else if (vcount == 2) {
            weight = std::atof(vstring.c_str());
         }
         else if (vcount - 3 < vvec.size()) {
            vvec[vcount - 3] = std::atof(vstring.c_str());
         }
         else {
            Log() << kFATAL << "Wrong variable count" << Endl;
         }

         prev = ipos + 1;
         ++vcount;
      }

      fEvent.push_back(kNN::Event(vvec, weight, type));
   }

   Log() << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;

   // create kd-tree (binary tree) structure
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// save weights to ROOT file

void TMVA::MethodKNN::WriteWeightsToStream(TFile &rf) const
{
   Log() << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;

   if (fEvent.empty()) {
      Log() << kWARNING << "MethodKNN contains no events" << Endl;
      return;
   }

   kNN::Event *event = new kNN::Event();
   TTree *tree = new TTree("knn", "event tree");
   tree->SetDirectory(0);
   tree->Branch("event", "TMVA::kNN::Event", &event);

   Double_t size = 0.0;
   for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
      (*event) = (*it);
      size += tree->Fill();
   }

   // !!! hard coded tree name !!!
   rf.WriteTObject(tree, "knn", "Overwrite");

   // scale to megabytes
   size /= 1048576.0;

   Log() << kINFO << "Wrote " << size << "MB and " << fEvent.size()
         << " events to ROOT file" << Endl;

   delete tree;
   delete event;
}

////////////////////////////////////////////////////////////////////////////////
/// read weights from ROOT file

void TMVA::MethodKNN::ReadWeightsFromStream(TFile &rf)
{
   Log() << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;

   if (!fEvent.empty()) {
      Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
      fEvent.clear();
   }

   // !!! hard coded tree name !!!
   TTree *tree = dynamic_cast<TTree *>(rf.Get("knn"));
   if (!tree) {
      Log() << kFATAL << "Failed to find knn tree" << Endl;
      return;
   }

   kNN::Event *event = new kNN::Event();
   tree->SetBranchAddress("event", &event);

   const Int_t nevent = tree->GetEntries();

   Double_t size = 0.0;
   for (Int_t i = 0; i < nevent; ++i) {
      size += tree->GetEntry(i);
      fEvent.push_back(*event);
   }

   // scale to megabytes
   size /= 1048576.0;

   Log() << kINFO << "Read " << size << "MB and " << fEvent.size()
         << " events from ROOT file" << Endl;

   delete event;

   // create kd-tree (binary tree) structure
   MakeKNN();
}

////////////////////////////////////////////////////////////////////////////////
/// write specific classifier response

void TMVA::MethodKNN::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   fout << "   // not implemented for class: \"" << className << "\"" << std::endl;
   fout << "};" << std::endl;
}

////////////////////////////////////////////////////////////////////////////////
/// get help message text
///
/// typical length of text line:
///         "|--------------------------------------------------------------|"

void TMVA::MethodKNN::GetHelpMessage() const
{
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" << Endl
         << "and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" << Endl
         << "training events for which a classification category/regression target is known." << Endl
         << "The k-NN method compares a test event to all training events using a distance" << Endl
         << "function, which is a Euclidean distance in a space defined by the input variables." << Endl
         << "The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" << Endl
         << "quick search for the k events with shortest distance to the test event. The method" << Endl
         << "returns the fraction of signal events among the k neighbors. It is recommended" << Endl
         << "that a histogram which stores the k-NN decision variable is binned with k+1 bins" << Endl
         << "between 0 and 1." << Endl;

   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Performance tuning via configuration options: "
         << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The k-NN method estimates a density of signal and background events in a" << Endl
         << "neighborhood around the test event. The method assumes that the density of the" << Endl
         << "signal and background events is uniform and constant within the neighborhood." << Endl
         << "k is an adjustable parameter and it determines an average size of the" << Endl
         << "neighborhood. Small k values (less than 10) are sensitive to statistical" << Endl
         << "fluctuations and large (greater than 100) values might not sufficiently capture" << Endl
         << "local differences between events in the training set. The evaluation time of" << Endl
         << "the k-NN method also increases with larger values of k." << Endl;
   Log() << Endl;
   Log() << "The k-NN method assigns equal weight to all input variables. Different scales" << Endl
         << "among the input variables are compensated using the ScaleFrac parameter: the" << Endl
         << "input variables are scaled so that the widths for the central ScaleFrac*100%" << Endl
         << "of events are equal among all the input variables." << Endl;

   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Additional configuration options: "
         << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "The method includes an option to use a Gaussian kernel to smooth out the k-NN" << Endl
         << "response. The kernel re-weights events using a distance to the test event." << Endl;
}

////////////////////////////////////////////////////////////////////////////////
/// polynomial kernel
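/// the weight computed below is
///    w(x) = (1 - |x|^3)^3   for |x| < 1,   w(x) = 0 otherwise,
/// where x is the neighbor distance scaled by the kernel radius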

Double_t TMVA::MethodKNN::PolnKernel(Double_t value) const
{
   const Double_t avalue = TMath::Abs(value);

   if (!(avalue < 1.0)) {
      return 0.0;
   }

   const Double_t prod = 1.0 - avalue * avalue * avalue;

   return (prod * prod * prod);
}

////////////////////////////////////////////////////////////////////////////////
/// Gaussian kernel
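/// the weight computed below is the unnormalized Gaussian
///    w = exp( -sum_i (x_i - y_i)^2 / (2*sigma_i^2) ),
/// with one sigma per input variable taken from svec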

Double_t TMVA::MethodKNN::GausKernel(const kNN::Event &event_knn,
                                     const kNN::Event &event, const std::vector<Double_t> &svec) const
{
   if (event_knn.GetNVar() != event.GetNVar() || event_knn.GetNVar() != svec.size()) {
      Log() << kFATAL << "Mismatched vectors in Gaussian kernel function" << Endl;
      return 0.0;
   }

   // compute exponent
   double sum_exp = 0.0;

   for(unsigned int ivar = 0; ivar < event_knn.GetNVar(); ++ivar) {

      const Double_t diff_ = event.GetVar(ivar) - event_knn.GetVar(ivar);
      const Double_t sigm_ = svec[ivar];
      if (!(sigm_ > 0.0)) {
         Log() << kFATAL << "Bad sigma value = " << sigm_ << Endl;
         return 0.0;
      }

      sum_exp += diff_*diff_/(2.0*sigm_*sigm_);
   }

   // return unnormalized(!) Gaussian function, because the normalization
   // cancels in the ratio of weights
   return std::exp(-sum_exp);
}

////////////////////////////////////////////////////////////////////////////////
///
/// Get polynomial kernel radius
///

Double_t TMVA::MethodKNN::getKernelRadius(const kNN::List &rlist) const
{
   Double_t kradius = -1.0;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
   {
      if (!(lit->second > 0.0)) continue;

      if (kradius < lit->second || kradius < 0.0) kradius = lit->second;

      ++kcount;
      if (kcount >= knn) break;
   }

   return kradius;
}

////////////////////////////////////////////////////////////////////////////////
///
/// Get the RMS of the variable differences between the k nearest neighbors
/// and the query event (used as the per-variable Gaussian sigma)
///

const std::vector<Double_t> TMVA::MethodKNN::getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
{
   std::vector<Double_t> rvec;
   UInt_t kcount = 0;
   const UInt_t knn = static_cast<UInt_t>(fnkNN);

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
   {
      if (!(lit->second > 0.0)) continue;

      const kNN::Node<kNN::Event> *node_ = lit->first;
      const kNN::Event &event_ = node_->GetEvent();

      if (rvec.empty()) {
         rvec.insert(rvec.end(), event_.GetNVar(), 0.0);
      }
      else if (rvec.size() != event_.GetNVar()) {
         Log() << kFATAL << "Wrong number of variables, should never happen!" << Endl;
         rvec.clear();
         return rvec;
      }

      for(unsigned int ivar = 0; ivar < event_.GetNVar(); ++ivar) {
         const Double_t diff_ = event_.GetVar(ivar) - event_knn.GetVar(ivar);
         rvec[ivar] += diff_*diff_;
      }

      ++kcount;
      if (kcount >= knn) break;
   }

   if (kcount < 1) {
      Log() << kFATAL << "Bad event kcount = " << kcount << Endl;
      rvec.clear();
      return rvec;
   }

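   // convert the accumulated squared differences into per-variable sigmas,
   //    sigma_i = |SigmaFact| * sqrt( sum_i / kcount )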
   for(unsigned int ivar = 0; ivar < rvec.size(); ++ivar) {
      if (!(rvec[ivar] > 0.0)) {
         Log() << kFATAL << "Bad RMS value = " << rvec[ivar] << Endl;
         rvec.clear();
         return rvec;
      }

      rvec[ivar] = std::abs(fSigmaFact)*std::sqrt(rvec[ivar]/kcount);
   }

   return rvec;
}

////////////////////////////////////////////////////////////////////////////////
/// compute the local LDA-based discriminant (signal probability) from the
/// k nearest neighbors

double TMVA::MethodKNN::getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
{
   LDAEvents sig_vec, bac_vec;

   for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {

      // get reference to current node to make code more readable
      const kNN::Node<kNN::Event> &node = *(lit->first);
      const kNN::VarVec &tvec = node.GetEvent().GetVars();

      if (node.GetEvent().GetType() == 1) { // signal type = 1
         sig_vec.push_back(tvec);
      }
      else if (node.GetEvent().GetType() == 2) { // background type = 2
         bac_vec.push_back(tvec);
      }
      else {
         Log() << kFATAL << "Unknown type for training event" << Endl;
      }
   }

   fLDA.Initialize(sig_vec, bac_vec);

   return fLDA.GetProb(event_knn.GetVars(), 1);
}