Logo ROOT  
Reference Guide
MethodKNN.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Rustem Ospanov
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodKNN *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation *
12 * *
13 * Author: *
14 * Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA *
15 * *
16 * Copyright (c) 2007: *
17 * CERN, Switzerland *
18 * MPI-K Heidelberg, Germany *
19 * U. of Texas at Austin, USA *
20 * *
21 * Redistribution and use in source and binary forms, with or without *
22 * modification, are permitted according to the terms listed in LICENSE *
23 * (http://tmva.sourceforge.net/LICENSE) *
24 **********************************************************************************/
25
26/*! \class TMVA::MethodKNN
27\ingroup TMVA
28
29Analysis of k-nearest neighbor.
30
31*/
32
33#include "TMVA/MethodKNN.h"
34
36#include "TMVA/Configurable.h"
37#include "TMVA/DataSetInfo.h"
38#include "TMVA/Event.h"
39#include "TMVA/LDA.h"
40#include "TMVA/IMethod.h"
41#include "TMVA/MethodBase.h"
42#include "TMVA/MsgLogger.h"
43#include "TMVA/Ranking.h"
44#include "TMVA/Tools.h"
45#include "TMVA/Types.h"
46
47#include "TFile.h"
48#include "TMath.h"
49#include "TTree.h"
50
51#include <cmath>
52#include <string>
53#include <cstdlib>
54
56
58
59////////////////////////////////////////////////////////////////////////////////
60/// standard constructor
61
63 const TString& methodTitle,
64 DataSetInfo& theData,
65 const TString& theOption )
66 : TMVA::MethodBase(jobName, Types::kKNN, methodTitle, theData, theOption)
67 , fSumOfWeightsS(0)
68 , fSumOfWeightsB(0)
69 , fModule(0)
70 , fnkNN(0)
71 , fBalanceDepth(0)
72 , fScaleFrac(0)
73 , fSigmaFact(0)
74 , fTrim(kFALSE)
75 , fUseKernel(kFALSE)
76 , fUseWeight(kFALSE)
77 , fUseLDA(kFALSE)
78 , fTreeOptDepth(0)
79{
80}
81
82////////////////////////////////////////////////////////////////////////////////
83/// constructor from weight file
84
86 const TString& theWeightFile)
87 : TMVA::MethodBase( Types::kKNN, theData, theWeightFile)
88 , fSumOfWeightsS(0)
89 , fSumOfWeightsB(0)
90 , fModule(0)
91 , fnkNN(0)
92 , fBalanceDepth(0)
93 , fScaleFrac(0)
94 , fSigmaFact(0)
95 , fTrim(kFALSE)
96 , fUseKernel(kFALSE)
97 , fUseWeight(kFALSE)
98 , fUseLDA(kFALSE)
99 , fTreeOptDepth(0)
100{
101}
102
103////////////////////////////////////////////////////////////////////////////////
104/// destructor
105
107{
108 if (fModule) delete fModule;
109}
110
111////////////////////////////////////////////////////////////////////////////////
112/// MethodKNN options
113///
114/// - fnkNN = 20; // number of k-nearest neighbors
115/// - fBalanceDepth = 6; // number of binary tree levels used for tree balancing
116/// - fScaleFrac = 0.8; // fraction of events used to compute variable width
117/// - fSigmaFact = 1.0; // scale factor for Gaussian sigma
118/// - fKernel = use polynomial (1-x^3)^3 or Gaussian kernel
119/// - fTrim = false; // use equal number of signal and background events
120/// - fUseKernel = false; // use polynomial kernel weight function
121/// - fUseWeight = true; // count events using weights
122/// - fUseLDA = false
123
125{
126 DeclareOptionRef(fnkNN = 20, "nkNN", "Number of k-nearest neighbors");
127 DeclareOptionRef(fBalanceDepth = 6, "BalanceDepth", "Binary tree balance depth");
128 DeclareOptionRef(fScaleFrac = 0.80, "ScaleFrac", "Fraction of events used to compute variable width");
129 DeclareOptionRef(fSigmaFact = 1.0, "SigmaFact", "Scale factor for sigma in Gaussian kernel");
130 DeclareOptionRef(fKernel = "Gaus", "Kernel", "Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
131 DeclareOptionRef(fTrim = kFALSE, "Trim", "Use equal number of signal and background events");
132 DeclareOptionRef(fUseKernel = kFALSE, "UseKernel", "Use polynomial kernel weight");
133 DeclareOptionRef(fUseWeight = kTRUE, "UseWeight", "Use weight to count kNN events");
134 DeclareOptionRef(fUseLDA = kFALSE, "UseLDA", "Use local linear discriminant - experimental feature");
135}
136
137////////////////////////////////////////////////////////////////////////////////
138/// options that are used ONLY for the READER to ensure backward compatibility
139
142 DeclareOptionRef(fTreeOptDepth = 6, "TreeOptDepth", "Binary tree optimisation depth");
143}
144
145////////////////////////////////////////////////////////////////////////////////
146/// process the options specified by the user
147
149{
150 if (!(fnkNN > 0)) {
151 fnkNN = 10;
152 Log() << kWARNING << "kNN must be a positive integer: set kNN = " << fnkNN << Endl;
153 }
154 if (fScaleFrac < 0.0) {
155 fScaleFrac = 0.0;
156 Log() << kWARNING << "ScaleFrac can not be negative: set ScaleFrac = " << fScaleFrac << Endl;
157 }
158 if (fScaleFrac > 1.0) {
159 fScaleFrac = 1.0;
160 }
161 if (!(fBalanceDepth > 0)) {
162 fBalanceDepth = 6;
163 Log() << kWARNING << "Optimize must be a positive integer: set Optimize = " << fBalanceDepth << Endl;
164 }
165
166 Log() << kVERBOSE
167 << "kNN options: \n"
168 << " kNN = \n" << fnkNN
169 << " UseKernel = \n" << fUseKernel
170 << " SigmaFact = \n" << fSigmaFact
171 << " ScaleFrac = \n" << fScaleFrac
172 << " Kernel = \n" << fKernel
173 << " Trim = \n" << fTrim
174 << " Optimize = " << fBalanceDepth << Endl;
175}
176
177////////////////////////////////////////////////////////////////////////////////
178/// kNN can handle classification with 2 classes and regression with any number of regression targets
179
181{
182 if (type == Types::kClassification && numberClasses == 2) return kTRUE;
183 if (type == Types::kRegression) return kTRUE;
184 return kFALSE;
185}
186
187////////////////////////////////////////////////////////////////////////////////
188/// Initialization
189
191{
192 // fScaleFrac <= 0.0 then do not scale input variables
193 // fScaleFrac >= 1.0 then use all event coordinates to scale input variables
194
195 fModule = new kNN::ModulekNN();
196 fSumOfWeightsS = 0;
197 fSumOfWeightsB = 0;
198}
199
200////////////////////////////////////////////////////////////////////////////////
201/// create kNN
202
204{
205 if (!fModule) {
206 Log() << kFATAL << "ModulekNN is not created" << Endl;
207 }
208
209 fModule->Clear();
210
211 std::string option;
212 if (fScaleFrac > 0.0) {
213 option += "metric";
214 }
215 if (fTrim) {
216 option += "trim";
217 }
218
219 Log() << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;
220
221 for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
222 fModule->Add(*event);
223 }
224
225 // create binary tree
226 fModule->Fill(static_cast<UInt_t>(fBalanceDepth),
227 static_cast<UInt_t>(100.0*fScaleFrac),
228 option);
229}
230
231////////////////////////////////////////////////////////////////////////////////
232/// kNN training
233
235{
236 Log() << kHEADER << "<Train> start..." << Endl;
237
238 if (IsNormalised()) {
239 Log() << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
240 fScaleFrac = 0.0;
241 }
242
243 if (!fEvent.empty()) {
244 Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
245 fEvent.clear();
246 }
247 if (GetNVariables() < 1)
248 Log() << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;
249
250
251 Log() << kINFO << "Reading " << GetNEvents() << " events" << Endl;
252
253 for (UInt_t ievt = 0; ievt < GetNEvents(); ++ievt) {
254 // read the training event
255 const Event* evt_ = GetEvent(ievt);
256 Double_t weight = evt_->GetWeight();
257
258 // in case event with neg weights are to be ignored
259 if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;
260
261 kNN::VarVec vvec(GetNVariables(), 0.0);
262 for (UInt_t ivar = 0; ivar < evt_ -> GetNVariables(); ++ivar) vvec[ivar] = evt_->GetValue(ivar);
263
264 Short_t event_type = 0;
265
266 if (DataInfo().IsSignal(evt_)) { // signal type = 1
267 fSumOfWeightsS += weight;
268 event_type = 1;
269 }
270 else { // background type = 2
271 fSumOfWeightsB += weight;
272 event_type = 2;
273 }
274
275 //
276 // Create event and add classification variables, weight, type and regression variables
277 //
278 kNN::Event event_knn(vvec, weight, event_type);
279 event_knn.SetTargets(evt_->GetTargets());
280 fEvent.push_back(event_knn);
281
282 }
283 Log() << kINFO
284 << "Number of signal events " << fSumOfWeightsS << Endl
285 << "Number of background events " << fSumOfWeightsB << Endl;
286
287 // create kd-tree (binary tree) structure
288 MakeKNN();
289
290 ExitFromTraining();
291}
292
293////////////////////////////////////////////////////////////////////////////////
294/// Compute classifier response
295
297{
298 // cannot determine error
299 NoErrorCalc(err, errUpper);
300
301 //
302 // Define local variables
303 //
304 const Event *ev = GetEvent();
305 const Int_t nvar = GetNVariables();
306 const Double_t weight = ev->GetWeight();
307 const UInt_t knn = static_cast<UInt_t>(fnkNN);
308
309 kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);
310
311 for (Int_t ivar = 0; ivar < nvar; ++ivar) {
312 vvec[ivar] = ev->GetValue(ivar);
313 }
314
315 // search for fnkNN+2 nearest neighbors, pad with two
316 // events to avoid Monte-Carlo events with zero distance
317 // most of CPU time is spent in this recursive function
318 const kNN::Event event_knn(vvec, weight, 3);
319 fModule->Find(event_knn, knn + 2);
320
321 const kNN::List &rlist = fModule->GetkNNList();
322 if (rlist.size() != knn + 2) {
323 Log() << kFATAL << "kNN result list is empty" << Endl;
324 return -100.0;
325 }
326
327 if (fUseLDA) return MethodKNN::getLDAValue(rlist, event_knn);
328
329 //
330 // Set flags for kernel option=Gaus, Poln
331 //
332 Bool_t use_gaus = false, use_poln = false;
333 if (fUseKernel) {
334 if (fKernel == "Gaus") use_gaus = true;
335 else if (fKernel == "Poln") use_poln = true;
336 }
337
338 //
339 // Compute radius for polynomial kernel
340 //
341 Double_t kradius = -1.0;
342 if (use_poln) {
343 kradius = MethodKNN::getKernelRadius(rlist);
344
345 if (!(kradius > 0.0)) {
346 Log() << kFATAL << "kNN radius is not positive" << Endl;
347 return -100.0;
348 }
349
350 kradius = 1.0/TMath::Sqrt(kradius);
351 }
352
353 //
354 // Compute RMS of variable differences for Gaussian sigma
355 //
356 std::vector<Double_t> rms_vec;
357 if (use_gaus) {
358 rms_vec = TMVA::MethodKNN::getRMS(rlist, event_knn);
359
360 if (rms_vec.empty() || rms_vec.size() != event_knn.GetNVar()) {
361 Log() << kFATAL << "Failed to compute RMS vector" << Endl;
362 return -100.0;
363 }
364 }
365
366 UInt_t count_all = 0;
367 Double_t weight_all = 0, weight_sig = 0, weight_bac = 0;
368
369 for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
370
371 // get reference to current node to make code more readable
372 const kNN::Node<kNN::Event> &node = *(lit->first);
373
374 // Warn about Monte-Carlo event with zero distance
375 // this happens when this query event is also in learning sample
376 if (lit->second < 0.0) {
377 Log() << kFATAL << "A neighbor has negative distance to query event" << Endl;
378 }
379 else if (!(lit->second > 0.0)) {
380 Log() << kVERBOSE << "A neighbor has zero distance to query event" << Endl;
381 }
382
383 // get event weight and scale weight by kernel function
384 Double_t evweight = node.GetWeight();
385 if (use_gaus) evweight *= MethodKNN::GausKernel(event_knn, node.GetEvent(), rms_vec);
386 else if (use_poln) evweight *= MethodKNN::PolnKernel(TMath::Sqrt(lit->second)*kradius);
387
388 if (fUseWeight) weight_all += evweight;
389 else ++weight_all;
390
391 if (node.GetEvent().GetType() == 1) { // signal type = 1
392 if (fUseWeight) weight_sig += evweight;
393 else ++weight_sig;
394 }
395 else if (node.GetEvent().GetType() == 2) { // background type = 2
396 if (fUseWeight) weight_bac += evweight;
397 else ++weight_bac;
398 }
399 else {
400 Log() << kFATAL << "Unknown type for training event" << Endl;
401 }
402
403 // use only fnkNN events
404 ++count_all;
405
406 if (count_all >= knn) {
407 break;
408 }
409 }
410
411 // check that total number of events or total weight sum is positive
412 if (!(count_all > 0)) {
413 Log() << kFATAL << "Size kNN result list is not positive" << Endl;
414 return -100.0;
415 }
416
417 // check that number of events matches number of k in knn
418 if (count_all < knn) {
419 Log() << kDEBUG << "count_all and kNN have different size: " << count_all << " < " << knn << Endl;
420 }
421
422 // Check that total weight is positive
423 if (!(weight_all > 0.0)) {
424 Log() << kFATAL << "kNN result total weight is not positive" << Endl;
425 return -100.0;
426 }
427
428 return weight_sig/weight_all;
429}
430
431////////////////////////////////////////////////////////////////////////////////
432/// Return vector of averages for target values of k-nearest neighbors.
433/// Use own copy of the regression vector, I do not like using a pointer to vector.
434
435const std::vector< Float_t >& TMVA::MethodKNN::GetRegressionValues()
436{
437 if( fRegressionReturnVal == 0 )
438 fRegressionReturnVal = new std::vector<Float_t>;
439 else
440 fRegressionReturnVal->clear();
441
442 //
443 // Define local variables
444 //
445 const Event *evt = GetEvent();
446 const Int_t nvar = GetNVariables();
447 const UInt_t knn = static_cast<UInt_t>(fnkNN);
448 std::vector<float> reg_vec;
449
450 kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);
451
452 for (Int_t ivar = 0; ivar < nvar; ++ivar) {
453 vvec[ivar] = evt->GetValue(ivar);
454 }
455
456 // search for fnkNN+2 nearest neighbors, pad with two
457 // events to avoid Monte-Carlo events with zero distance
458 // most of CPU time is spent in this recursive function
459 const kNN::Event event_knn(vvec, evt->GetWeight(), 3);
460 fModule->Find(event_knn, knn + 2);
461
462 const kNN::List &rlist = fModule->GetkNNList();
463 if (rlist.size() != knn + 2) {
464 Log() << kFATAL << "kNN result list is empty" << Endl;
465 return *fRegressionReturnVal;
466 }
467
468 // compute regression values
469 Double_t weight_all = 0;
470 UInt_t count_all = 0;
471
472 for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
473
474 // get reference to current node to make code more readable
475 const kNN::Node<kNN::Event> &node = *(lit->first);
476 const kNN::VarVec &tvec = node.GetEvent().GetTargets();
477 const Double_t weight = node.GetEvent().GetWeight();
478
479 if (reg_vec.empty()) {
480 reg_vec= kNN::VarVec(tvec.size(), 0.0);
481 }
482
483 for(UInt_t ivar = 0; ivar < tvec.size(); ++ivar) {
484 if (fUseWeight) reg_vec[ivar] += tvec[ivar]*weight;
485 else reg_vec[ivar] += tvec[ivar];
486 }
487
488 if (fUseWeight) weight_all += weight;
489 else ++weight_all;
490
491 // use only fnkNN events
492 ++count_all;
493
494 if (count_all == knn) {
495 break;
496 }
497 }
498
499 // check that number of events matches number of k in knn
500 if (!(weight_all > 0.0)) {
501 Log() << kFATAL << "Total weight sum is not positive: " << weight_all << Endl;
502 return *fRegressionReturnVal;
503 }
504
505 for (UInt_t ivar = 0; ivar < reg_vec.size(); ++ivar) {
506 reg_vec[ivar] /= weight_all;
507 }
508
509 // copy result
510 fRegressionReturnVal->insert(fRegressionReturnVal->begin(), reg_vec.begin(), reg_vec.end());
511
512 return *fRegressionReturnVal;
513}
514
515////////////////////////////////////////////////////////////////////////////////
516/// no ranking available
517
519{
520 return 0;
521}
522
523////////////////////////////////////////////////////////////////////////////////
524/// write weights to XML
525
526void TMVA::MethodKNN::AddWeightsXMLTo( void* parent ) const {
527 void* wght = gTools().AddChild(parent, "Weights");
528 gTools().AddAttr(wght,"NEvents",fEvent.size());
529 if (fEvent.size()>0) gTools().AddAttr(wght,"NVar",fEvent.begin()->GetNVar());
530 if (fEvent.size()>0) gTools().AddAttr(wght,"NTgt",fEvent.begin()->GetNTgt());
531
532 for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
533
534 std::stringstream s("");
535 s.precision( 16 );
536 for (UInt_t ivar = 0; ivar < event->GetNVar(); ++ivar) {
537 if (ivar>0) s << " ";
538 s << std::scientific << event->GetVar(ivar);
539 }
540
541 for (UInt_t itgt = 0; itgt < event->GetNTgt(); ++itgt) {
542 s << " " << std::scientific << event->GetTgt(itgt);
543 }
544
545 void* evt = gTools().AddChild(wght, "Event", s.str().c_str());
546 gTools().AddAttr(evt,"Type", event->GetType());
547 gTools().AddAttr(evt,"Weight", event->GetWeight());
548 }
549}
550
551////////////////////////////////////////////////////////////////////////////////
552
554 void* ch = gTools().GetChild(wghtnode); // first event
555 UInt_t nvar = 0, ntgt = 0;
556 gTools().ReadAttr( wghtnode, "NVar", nvar );
557 gTools().ReadAttr( wghtnode, "NTgt", ntgt );
558
559
560 Short_t evtType(0);
561 Double_t evtWeight(0);
562
563 while (ch) {
564 // build event
565 kNN::VarVec vvec(nvar, 0);
566 kNN::VarVec tvec(ntgt, 0);
567
568 gTools().ReadAttr( ch, "Type", evtType );
569 gTools().ReadAttr( ch, "Weight", evtWeight );
570 std::stringstream s( gTools().GetContent(ch) );
571
572 for(UInt_t ivar=0; ivar<nvar; ivar++)
573 s >> vvec[ivar];
574
575 for(UInt_t itgt=0; itgt<ntgt; itgt++)
576 s >> tvec[itgt];
577
578 ch = gTools().GetNextChild(ch);
579
580 kNN::Event event_knn(vvec, evtWeight, evtType, tvec);
581 fEvent.push_back(event_knn);
582 }
583
584 // create kd-tree (binary tree) structure
585 MakeKNN();
586}
587
588////////////////////////////////////////////////////////////////////////////////
589/// read the weights
590
592{
593 Log() << kINFO << "Starting ReadWeightsFromStream(std::istream& is) function..." << Endl;
594
595 if (!fEvent.empty()) {
596 Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
597 fEvent.clear();
598 }
599
600 UInt_t nvar = 0;
601
602 while (is) {
603 std::string line;
604 std::getline(is, line);
605
606 if (line.empty() || line.find("#") != std::string::npos) {
607 continue;
608 }
609
610 UInt_t count = 0;
611 std::string::size_type pos=0;
612 while( (pos=line.find(',',pos)) != std::string::npos ) { count++; pos++; }
613
614 if (nvar == 0) {
615 nvar = count - 2;
616 }
617 if (count < 3 || nvar != count - 2) {
618 Log() << kFATAL << "Missing comma delimeter(s)" << Endl;
619 }
620
621 // Int_t ievent = -1;
622 Int_t type = -1;
623 Double_t weight = -1.0;
624
625 kNN::VarVec vvec(nvar, 0.0);
626
627 UInt_t vcount = 0;
628 std::string::size_type prev = 0;
629
630 for (std::string::size_type ipos = 0; ipos < line.size(); ++ipos) {
631 if (line[ipos] != ',' && ipos + 1 != line.size()) {
632 continue;
633 }
634
635 if (!(ipos > prev)) {
636 Log() << kFATAL << "Wrong substring limits" << Endl;
637 }
638
639 std::string vstring = line.substr(prev, ipos - prev);
640 if (ipos + 1 == line.size()) {
641 vstring = line.substr(prev, ipos - prev + 1);
642 }
643
644 if (vstring.empty()) {
645 Log() << kFATAL << "Failed to parse string" << Endl;
646 }
647
648 if (vcount == 0) {
649 // ievent = std::atoi(vstring.c_str());
650 }
651 else if (vcount == 1) {
652 type = std::atoi(vstring.c_str());
653 }
654 else if (vcount == 2) {
655 weight = std::atof(vstring.c_str());
656 }
657 else if (vcount - 3 < vvec.size()) {
658 vvec[vcount - 3] = std::atof(vstring.c_str());
659 }
660 else {
661 Log() << kFATAL << "Wrong variable count" << Endl;
662 }
663
664 prev = ipos + 1;
665 ++vcount;
666 }
667
668 fEvent.push_back(kNN::Event(vvec, weight, type));
669 }
670
671 Log() << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;
672
673 // create kd-tree (binary tree) structure
674 MakeKNN();
675}
676
677////////////////////////////////////////////////////////////////////////////////
678/// save weights to ROOT file
679
681{
682 Log() << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;
683
684 if (fEvent.empty()) {
685 Log() << kWARNING << "MethodKNN contains no events " << Endl;
686 return;
687 }
688
689 kNN::Event *event = new kNN::Event();
690 TTree *tree = new TTree("knn", "event tree");
691 tree->SetDirectory(0);
692 tree->Branch("event", "TMVA::kNN::Event", &event);
693
694 Double_t size = 0.0;
695 for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
696 (*event) = (*it);
697 size += tree->Fill();
698 }
699
700 // !!! hard coded tree name !!!
701 rf.WriteTObject(tree, "knn", "Overwrite");
702
703 // scale to MegaBytes
704 size /= 1048576.0;
705
706 Log() << kINFO << "Wrote " << size << "MB and " << fEvent.size()
707 << " events to ROOT file" << Endl;
708
709 delete tree;
710 delete event;
711}
712
713////////////////////////////////////////////////////////////////////////////////
714/// read weights from ROOT file
715
717{
718 Log() << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;
719
720 if (!fEvent.empty()) {
721 Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
722 fEvent.clear();
723 }
724
725 // !!! hard coded tree name !!!
726 TTree *tree = dynamic_cast<TTree *>(rf.Get("knn"));
727 if (!tree) {
728 Log() << kFATAL << "Failed to find knn tree" << Endl;
729 return;
730 }
731
732 kNN::Event *event = new kNN::Event();
733 tree->SetBranchAddress("event", &event);
734
735 const Int_t nevent = tree->GetEntries();
736
737 Double_t size = 0.0;
738 for (Int_t i = 0; i < nevent; ++i) {
739 size += tree->GetEntry(i);
740 fEvent.push_back(*event);
741 }
742
743 // scale to MegaBytes
744 size /= 1048576.0;
745
746 Log() << kINFO << "Read " << size << "MB and " << fEvent.size()
747 << " events from ROOT file" << Endl;
748
749 delete event;
750
751 // create kd-tree (binary tree) structure
752 MakeKNN();
753}
754
755////////////////////////////////////////////////////////////////////////////////
756/// write specific classifier response
757
758void TMVA::MethodKNN::MakeClassSpecific( std::ostream& fout, const TString& className ) const
759{
760 fout << " // not implemented for class: \"" << className << "\"" << std::endl;
761 fout << "};" << std::endl;
762}
763
764////////////////////////////////////////////////////////////////////////////////
765/// get help message text
766///
767/// typical length of text line:
768/// "|--------------------------------------------------------------|"
769
771{
772 Log() << Endl;
773 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
774 Log() << Endl;
775 Log() << "The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" << Endl
776 << "and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" << Endl
777 << "training events for which a classification category/regression target is known. " << Endl
778 << "The k-NN method compares a test event to all training events using a distance " << Endl
779 << "function, which is an Euclidean distance in a space defined by the input variables. "<< Endl
780 << "The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" << Endl
781 << "quick search for the k events with shortest distance to the test event. The method" << Endl
782 << "returns a fraction of signal events among the k neighbors. It is recommended" << Endl
783 << "that a histogram which stores the k-NN decision variable is binned with k+1 bins" << Endl
784 << "between 0 and 1." << Endl;
785
786 Log() << Endl;
787 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options: "
788 << gTools().Color("reset") << Endl;
789 Log() << Endl;
790 Log() << "The k-NN method estimates a density of signal and background events in a "<< Endl
791 << "neighborhood around the test event. The method assumes that the density of the " << Endl
792 << "signal and background events is uniform and constant within the neighborhood. " << Endl
793 << "k is an adjustable parameter and it determines an average size of the " << Endl
794 << "neighborhood. Small k values (less than 10) are sensitive to statistical " << Endl
795 << "fluctuations and large (greater than 100) values might not sufficiently capture " << Endl
796 << "local differences between events in the training set. The speed of the k-NN" << Endl
797 << "method also increases with larger values of k. " << Endl;
798 Log() << Endl;
799 Log() << "The k-NN method assigns equal weight to all input variables. Different scales " << Endl
800 << "among the input variables is compensated using ScaleFrac parameter: the input " << Endl
801 << "variables are scaled so that the widths for central ScaleFrac*100% events are " << Endl
802 << "equal among all the input variables." << Endl;
803
804 Log() << Endl;
805 Log() << gTools().Color("bold") << "--- Additional configuration options: "
806 << gTools().Color("reset") << Endl;
807 Log() << Endl;
808 Log() << "The method inclues an option to use a Gaussian kernel to smooth out the k-NN" << Endl
809 << "response. The kernel re-weights events using a distance to the test event." << Endl;
810}
811
812////////////////////////////////////////////////////////////////////////////////
813/// polynomial kernel
814
816{
817 const Double_t avalue = TMath::Abs(value);
818
819 if (!(avalue < 1.0)) {
820 return 0.0;
821 }
822
823 const Double_t prod = 1.0 - avalue * avalue * avalue;
824
825 return (prod * prod * prod);
826}
827
828////////////////////////////////////////////////////////////////////////////////
829/// Gaussian kernel
830
832 const kNN::Event &event, const std::vector<Double_t> &svec) const
833{
834 if (event_knn.GetNVar() != event.GetNVar() || event_knn.GetNVar() != svec.size()) {
835 Log() << kFATAL << "Mismatched vectors in Gaussian kernel function" << Endl;
836 return 0.0;
837 }
838
839 //
840 // compute exponent
841 //
842 double sum_exp = 0.0;
843
844 for(unsigned int ivar = 0; ivar < event_knn.GetNVar(); ++ivar) {
845
846 const Double_t diff_ = event.GetVar(ivar) - event_knn.GetVar(ivar);
847 const Double_t sigm_ = svec[ivar];
848 if (!(sigm_ > 0.0)) {
849 Log() << kFATAL << "Bad sigma value = " << sigm_ << Endl;
850 return 0.0;
851 }
852
853 sum_exp += diff_*diff_/(2.0*sigm_*sigm_);
854 }
855
856 //
857 // Return unnormalized(!) Gaussian function, because normalization
858 // cancels for the ratio of weights.
859 //
860
861 return std::exp(-sum_exp);
862}
863
864////////////////////////////////////////////////////////////////////////////////
865///
866/// Get polynomial kernel radius
867///
868
869Double_t TMVA::MethodKNN::getKernelRadius(const kNN::List &rlist) const
870{
871 Double_t kradius = -1.0;
872 UInt_t kcount = 0;
873 const UInt_t knn = static_cast<UInt_t>(fnkNN);
874
875 for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
876 {
877 if (!(lit->second > 0.0)) continue;
878
879 if (kradius < lit->second || kradius < 0.0) kradius = lit->second;
880
881 ++kcount;
882 if (kcount >= knn) break;
883 }
884
885 return kradius;
886}
887
888////////////////////////////////////////////////////////////////////////////////
889///
890/// Get RMS of variable differences between the query event and its nearest neighbors (widths for the Gaussian kernel)
891///
892
893const std::vector<Double_t> TMVA::MethodKNN::getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
894{
895 std::vector<Double_t> rvec;
896 UInt_t kcount = 0;
897 const UInt_t knn = static_cast<UInt_t>(fnkNN);
898
899 for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
900 {
901 if (!(lit->second > 0.0)) continue;
902
903 const kNN::Node<kNN::Event> *node_ = lit -> first;
904 const kNN::Event &event_ = node_-> GetEvent();
905
906 if (rvec.empty()) {
907 rvec.insert(rvec.end(), event_.GetNVar(), 0.0);
908 }
909 else if (rvec.size() != event_.GetNVar()) {
910 Log() << kFATAL << "Wrong number of variables, should never happen!" << Endl;
911 rvec.clear();
912 return rvec;
913 }
914
915 for(unsigned int ivar = 0; ivar < event_.GetNVar(); ++ivar) {
916 const Double_t diff_ = event_.GetVar(ivar) - event_knn.GetVar(ivar);
917 rvec[ivar] += diff_*diff_;
918 }
919
920 ++kcount;
921 if (kcount >= knn) break;
922 }
923
924 if (kcount < 1) {
925 Log() << kFATAL << "Bad event kcount = " << kcount << Endl;
926 rvec.clear();
927 return rvec;
928 }
929
930 for(unsigned int ivar = 0; ivar < rvec.size(); ++ivar) {
931 if (!(rvec[ivar] > 0.0)) {
932 Log() << kFATAL << "Bad RMS value = " << rvec[ivar] << Endl;
933 rvec.clear();
934 return rvec;
935 }
936
937 rvec[ivar] = std::abs(fSigmaFact)*std::sqrt(rvec[ivar]/kcount);
938 }
939
940 return rvec;
941}
942
943////////////////////////////////////////////////////////////////////////////////
944
945Double_t TMVA::MethodKNN::getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
946{
947 LDAEvents sig_vec, bac_vec;
948
949 for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
950
951 // get reference to current node to make code more readable
952 const kNN::Node<kNN::Event> &node = *(lit->first);
953 const kNN::VarVec &tvec = node.GetEvent().GetVars();
954
955 if (node.GetEvent().GetType() == 1) { // signal type = 1
956 sig_vec.push_back(tvec);
957 }
958 else if (node.GetEvent().GetType() == 2) { // background type = 2
959 bac_vec.push_back(tvec);
960 }
961 else {
962 Log() << kFATAL << "Unknown type for training event" << Endl;
963 }
964 }
965
966 fLDA.Initialize(sig_vec, bac_vec);
967
968 return fLDA.GetProb(event_knn.GetVars(), 1);
969}
#define REGISTER_METHOD(CLASS)
for example
std::vector< std::vector< Float_t > > LDAEvents
Definition: LDA.h:38
const Bool_t kFALSE
Definition: RtypesCore.h:90
short Short_t
Definition: RtypesCore.h:37
double Double_t
Definition: RtypesCore.h:57
const Bool_t kTRUE
Definition: RtypesCore.h:89
#define ClassImp(name)
Definition: Rtypes.h:361
int type
Definition: TGX11.cxx:120
double sqrt(double)
double exp(double)
Int_t WriteTObject(const TObject *obj, const char *name=nullptr, Option_t *option="", Int_t bufsize=0) override
Write object obj to this directory.
TObject * Get(const char *namecycle) override
Return pointer to object identified by namecycle.
A ROOT file is a suite of consecutive data records (TKey instances) with a well defined format.
Definition: TFile.h:53
Class that contains all the data information.
Definition: DataSetInfo.h:60
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:236
std::vector< Float_t > & GetTargets()
Definition: Event.h:103
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Definition: Event.cxx:381
Virtual base Class for all MVA method.
Definition: MethodBase.h:111
virtual void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility they are hence without any...
Definition: MethodBase.cxx:598
Analysis of k-nearest neighbor.
Definition: MethodKNN.h:54
void Init(void)
Initialization.
Definition: MethodKNN.cxx:190
void MakeKNN(void)
create kNN
Definition: MethodKNN.cxx:203
virtual ~MethodKNN(void)
destructor
Definition: MethodKNN.cxx:106
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
kNN can handle classification with 2 classes and regression with any number of regression targets.
Definition: MethodKNN.cxx:180
const std::vector< Double_t > getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
Get RMS of variable differences between the query event and its nearest neighbors (Gaussian kernel widths).
Definition: MethodKNN.cxx:893
const Ranking * CreateRanking()
no ranking available
Definition: MethodKNN.cxx:518
void MakeClassSpecific(std::ostream &, const TString &) const
write specific classifier response
Definition: MethodKNN.cxx:758
MethodKNN(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="KNN")
standard constructor
Definition: MethodKNN.cxx:62
void DeclareCompatibilityOptions()
options that are used ONLY for the READER to ensure backward compatibility
Definition: MethodKNN.cxx:140
Double_t getKernelRadius(const kNN::List &rlist) const
Get polynomial kernel radius.
Definition: MethodKNN.cxx:869
void GetHelpMessage() const
get help message text
Definition: MethodKNN.cxx:770
const std::vector< Float_t > & GetRegressionValues()
Return vector of averages for target values of k-nearest neighbors.
Definition: MethodKNN.cxx:435
void Train(void)
kNN training
Definition: MethodKNN.cxx:234
void ReadWeightsFromStream(std::istream &istr)
read the weights
Definition: MethodKNN.cxx:591
double getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
Definition: MethodKNN.cxx:945
Double_t PolnKernel(Double_t value) const
polynomial kernel
Definition: MethodKNN.cxx:815
void ProcessOptions()
process the options specified by the user
Definition: MethodKNN.cxx:148
void ReadWeightsFromXML(void *wghtnode)
Definition: MethodKNN.cxx:553
void AddWeightsXMLTo(void *parent) const
write weights to XML
Definition: MethodKNN.cxx:526
void DeclareOptions()
MethodKNN options.
Definition: MethodKNN.cxx:124
void WriteWeightsToStream(TFile &rf) const
save weights to ROOT file
Definition: MethodKNN.cxx:680
Double_t GausKernel(const kNN::Event &event_knn, const kNN::Event &event, const std::vector< Double_t > &svec) const
Gaussian kernel.
Definition: MethodKNN.cxx:831
Double_t GetMvaValue(Double_t *err=0, Double_t *errUpper=0)
Compute classifier response.
Definition: MethodKNN.cxx:296
Ranking for variables in method (implementation)
Definition: Ranking.h:48
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1173
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1135
const TString & Color(const TString &)
human readable color strings
Definition: Tools.cxx:839
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1161
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
Singleton class for Global types used by TMVA.
Definition: Types.h:73
EAnalysisType
Definition: Types.h:127
@ kClassification
Definition: Types.h:128
@ kRegression
Definition: Types.h:129
void SetTargets(const VarVec &tvec)
Definition: ModulekNN.cxx:107
UInt_t GetNVar() const
Definition: ModulekNN.h:187
VarType GetVar(UInt_t i) const
Definition: ModulekNN.h:178
const VarVec & GetVars() const
Definition: ModulekNN.cxx:121
This file contains binary tree and global function template that searches tree for k-nearest neighbors...
Definition: NodekNN.h:67
Double_t GetWeight() const
Definition: NodekNN.h:180
const T & GetEvent() const
Definition: NodekNN.h:156
virtual void Clear(Option_t *="")
Definition: TObject.h:115
Basic string class.
Definition: TString.h:131
A TTree represents a columnar dataset.
Definition: TTree.h:78
TLine * line
static constexpr double s
static constexpr double second
create variable transformations
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:750
Double_t Sqrt(Double_t x)
Definition: TMath.h:681
Short_t Abs(Short_t d)
Definition: TMathBase.h:120
Definition: first.py:1
Definition: tree.py:1