Logo ROOT   6.08/07
Reference Guide
ModulekNN.h
Go to the documentation of this file.
1 // @(#)root/tmva $Id$
2 // Author: Rustem Ospanov
3 
4 /**********************************************************************************
5  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6  * Package: TMVA *
7  * Class : ModulekNN *
8  * Web : http://tmva.sourceforge.net *
9  * *
10  * Description: *
11  * Module for k-nearest neighbor algorithm *
12  * *
13  * Author: *
14  * Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA *
15  * *
16  * Copyright (c) 2007: *
17  * CERN, Switzerland *
18  * MPI-K Heidelberg, Germany *
19  * U. of Texas at Austin, USA *
20  * *
21  * Redistribution and use in source and binary forms, with or without *
22  * modification, are permitted according to the terms listed in LICENSE *
23  * (http://tmva.sourceforge.net/LICENSE) *
24  **********************************************************************************/
25 
26 #ifndef ROOT_TMVA_ModulekNN
27 #define ROOT_TMVA_ModulekNN
28 
//______________________________________________________________________
/*
  kNN::Event describes a point in the input-variable vector space, with
  additional functionality such as the distance between points.
*/
//______________________________________________________________________
35 
36 
// C++
#include <cassert>
#include <iosfwd>
#include <list>
#include <map>
#include <string>
#include <utility>
#include <vector>

// ROOT
#ifndef ROOT_Rtypes
#include "Rtypes.h"
#endif
#ifndef ROOT_TRandom3
#include "TRandom3.h"
#endif
#ifndef ROOT_ThreadLocalStorage
#include "ThreadLocalStorage.h"
#endif
#ifndef ROOT_TMVA_NodekNN
#include "TMVA/NodekNN.h"
#endif
57 
58 namespace TMVA {
59 
60  class MsgLogger;
61 
62  namespace kNN {
63 
64  typedef Float_t VarType;
65  typedef std::vector<VarType> VarVec;
66 
67  class Event {
68  public:
69 
70  Event();
71  Event(const VarVec &vec, Double_t weight, Short_t type);
72  Event(const VarVec &vec, Double_t weight, Short_t type, const VarVec &tvec);
73  ~Event();
74 
75  Double_t GetWeight() const;
76 
77  VarType GetVar(UInt_t i) const;
78  VarType GetTgt(UInt_t i) const;
79 
80  UInt_t GetNVar() const;
81  UInt_t GetNTgt() const;
82 
83  Short_t GetType() const;
84 
85  // keep these two function separate
86  VarType GetDist(VarType var, UInt_t ivar) const;
87  VarType GetDist(const Event &other) const;
88 
89  void SetTargets(const VarVec &tvec);
90  const VarVec& GetTargets() const;
91  const VarVec& GetVars() const;
92 
93  void Print() const;
94  void Print(std::ostream& os) const;
95 
96  private:
97 
98  VarVec fVar; // coordinates (variables) for knn search
99  VarVec fTgt; // targets for regression analysis
100 
101  Double_t fWeight; // event weight
102  Short_t fType; // event type ==0 or == 1, expand it to arbitrary class types?
103  };
104 
105  typedef std::vector<TMVA::kNN::Event> EventVec;
106  typedef std::pair<const Node<Event> *, VarType> Elem;
107  typedef std::list<Elem> List;
108 
109  std::ostream& operator<<(std::ostream& os, const Event& event);
110 
111  class ModulekNN
112  {
113  public:
114 
115  typedef std::map<int, std::vector<Double_t> > VarMap;
116 
117  public:
118 
119  ModulekNN();
120  ~ModulekNN();
121 
122  void Clear();
123 
124  void Add(const Event &event);
125 
126  Bool_t Fill(const UShort_t odepth, UInt_t ifrac, const std::string &option = "");
127 
128  Bool_t Find(Event event, UInt_t nfind = 100, const std::string &option = "count") const;
129  Bool_t Find(UInt_t nfind, const std::string &option) const;
130 
131  const EventVec& GetEventVec() const;
132 
133  const List& GetkNNList() const;
134  const Event& GetkNNEvent() const;
135 
136  const VarMap& GetVarMap() const;
137 
138  const std::map<Int_t, Double_t>& GetMetric() const;
139 
140  void Print() const;
141  void Print(std::ostream &os) const;
142 
143  private:
144 
145  Node<Event>* Optimize(UInt_t optimize_depth);
146 
147  void ComputeMetric(UInt_t ifrac);
148 
149  const Event Scale(const Event &event) const;
150 
151  private:
152 
153  // This is a workaround for OSx where static thread_local data members are
154  // not supported. The C++ solution would indeed be the following:
155  static TRandom3& GetRndmThreadLocal() {TTHREAD_TLS_DECL_ARG(TRandom3,fgRndm,1); return fgRndm;};
156 
157  UInt_t fDimn;
158 
160 
161  std::map<Int_t, Double_t> fVarScale;
162 
163  mutable List fkNNList; // latest result from kNN search
164  mutable Event fkNNEvent; // latest event used for kNN search
165 
166  std::map<Short_t, UInt_t> fCount; // count number of events of each type
167 
168  EventVec fEvent; // vector of all events used to build tree and analysis
169  VarMap fVar; // sorted map of variables in each dimension for all event types
170 
171  mutable MsgLogger* fLogger; // message logger
172  MsgLogger& Log() const { return *fLogger; }
173  };
174 
175  //
176  // inlined functions for Event class
177  //
178  inline VarType Event::GetDist(const VarType var1, const UInt_t ivar) const
179  {
180  const VarType var2 = GetVar(ivar);
181  return (var1 - var2) * (var1 - var2);
182  }
183  inline Double_t Event::GetWeight() const
184  {
185  return fWeight;
186  }
187  inline VarType Event::GetVar(const UInt_t i) const
188  {
189  return fVar[i];
190  }
191  inline VarType Event::GetTgt(const UInt_t i) const
192  {
193  return fTgt[i];
194  }
195 
196  inline UInt_t Event::GetNVar() const
197  {
198  return fVar.size();
199  }
200  inline UInt_t Event::GetNTgt() const
201  {
202  return fTgt.size();
203  }
204  inline Short_t Event::GetType() const
205  {
206  return fType;
207  }
208 
209  //
210  // inline functions for ModulekNN class
211  //
212  inline const List& ModulekNN::GetkNNList() const
213  {
214  return fkNNList;
215  }
216  inline const Event& ModulekNN::GetkNNEvent() const
217  {
218  return fkNNEvent;
219  }
220  inline const EventVec& ModulekNN::GetEventVec() const
221  {
222  return fEvent;
223  }
225  {
226  return fVar;
227  }
228  inline const std::map<Int_t, Double_t>& ModulekNN::GetMetric() const
229  {
230  return fVarScale;
231  }
232 
233  } // end of kNN namespace
234 } // end of TMVA namespace
235 
236 #endif
237 
Event()
default constructor
Definition: ModulekNN.cxx:44
RooCmdArg Optimize(Int_t flag=2)
Random number generator class based on the Mersenne Twister algorithm of M. Matsumoto and T. Nishimura.
Definition: TRandom3.h:29
float Float_t
Definition: RtypesCore.h:53
unsigned short UShort_t
Definition: RtypesCore.h:36
const List & GetkNNList() const
Definition: ModulekNN.h:212
bool Bool_t
Definition: RtypesCore.h:59
VarType GetVar(UInt_t i) const
Definition: ModulekNN.h:187
std::ostream & operator<<(std::ostream &os, const Event &event)
streamer
Definition: ModulekNN.cxx:158
Double_t GetWeight() const
Definition: ModulekNN.h:183
std::map< int, std::vector< Double_t > > VarMap
Definition: ModulekNN.h:115
std::map< Int_t, Double_t > fVarScale
Definition: ModulekNN.h:161
const std::map< Int_t, Double_t > & GetMetric() const
Definition: ModulekNN.h:228
MsgLogger & Log() const
Definition: ModulekNN.h:172
Short_t GetType() const
Definition: ModulekNN.h:204
std::list< Elem > List
Definition: ModulekNN.h:107
const VarVec & GetTargets() const
Definition: ModulekNN.cxx:108
const EventVec & GetEventVec() const
Definition: ModulekNN.h:220
void Add(THist< DIMENSIONS, PRECISION_TO, STAT_TO... > &to, THist< DIMENSIONS, PRECISION_FROM, STAT_FROM... > &from)
Add two histograms.
Definition: THist.hxx:327
std::vector< TMVA::kNN::Event > EventVec
Definition: ModulekNN.h:105
Double_t fWeight
Definition: ModulekNN.h:101
void SetTargets(const VarVec &tvec)
Definition: ModulekNN.cxx:101
VarType GetDist(VarType var, UInt_t ivar) const
Definition: ModulekNN.h:178
Float_t VarType
Definition: ModulekNN.h:64
UInt_t GetNVar() const
Definition: ModulekNN.h:196
unsigned int UInt_t
Definition: RtypesCore.h:42
short Short_t
Definition: RtypesCore.h:35
VarType GetTgt(UInt_t i) const
Definition: ModulekNN.h:191
const VarMap & GetVarMap() const
Definition: ModulekNN.h:224
double Double_t
Definition: RtypesCore.h:55
std::pair< const Node< Event > *, VarType > Elem
Definition: ModulekNN.h:106
int type
Definition: TGX11.cxx:120
UInt_t GetNTgt() const
Definition: ModulekNN.h:200
static TRandom3 & GetRndmThreadLocal()
Definition: ModulekNN.h:155
UInt_t Find(std::list< std::pair< const Node< T > *, Float_t > > &nlist, const Node< T > *node, const T &event, UInt_t nfind)
MsgLogger * fLogger
Definition: ModulekNN.h:171
const VarVec & GetVars() const
Definition: ModulekNN.cxx:115
const Event & GetkNNEvent() const
Definition: ModulekNN.h:216
Abstract ClassifierFactory template that handles arbitrary types.
void Print() const
print
Definition: ModulekNN.cxx:123
std::map< Short_t, UInt_t > fCount
Definition: ModulekNN.h:166
std::vector< VarType > VarVec
Definition: ModulekNN.h:65
Node< Event > * fTree
Definition: ModulekNN.h:159
~Event()
destructor
Definition: ModulekNN.cxx:75