Logo ROOT  
Reference Guide
VarTransformHandler.cxx
Go to the documentation of this file.
1 /**********************************************************************************
2  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3  * Package: TMVA *
4  * Class : VarTransformHandler *
5  * Web : http://tmva.sourceforge.net *
6  * *
7  * Description: *
8  * Implementation of unsupervised variable transformation methods *
9  * *
10  * Authors (alphabetical): *
11  * Abhinav Moudgil <abhinav.moudgil@research.iiit.ac.in> - IIIT-H, India *
12  * *
13  * Copyright (c) 2005: *
14  * CERN, Switzerland *
15  * *
16  * Redistribution and use in source and binary forms, with or without *
17  * modification, are permitted according to the terms listed in LICENSE *
18  * (http://tmva.sourceforge.net/LICENSE) *
19  **********************************************************************************/
20 
22 
23 #include "TMVA/ClassifierFactory.h"
24 #include "TMVA/DataLoader.h"
25 #include "TMVA/Event.h"
26 #include "TMVA/DataInputHandler.h"
27 #include "TMVA/DataSet.h"
28 #include "TMVA/DataSetInfo.h"
29 #include "TMVA/MethodBase.h"
30 #include "TMVA/MethodDNN.h"
31 #include "TMVA/MsgLogger.h"
32 #include "TMVA/Tools.h"
33 #include "TMVA/Types.h"
34 #include "TMVA/VariableInfo.h"
35 
36 #include "TMath.h"
37 #include "TVectorD.h"
38 #include "TMatrix.h"
39 #include "TMatrixTSparse.h"
40 #include "TMatrixDSparsefwd.h"
41 
42 #include <algorithm>
43 #include <iomanip>
44 #include <vector>
45 
46 ////////////////////////////////////////////////////////////////////////////////
47 /// constructor
48 
50  : fLogger ( new MsgLogger(TString("VarTransformHandler").Data(), kINFO) ),
51  fDataSetInfo(dl->GetDataSetInfo()),
52  fDataLoader (dl),
53  fEvents (fDataSetInfo.GetDataSet()->GetEventCollection())
54 {
55  Log() << kINFO << "Number of events - " << fEvents.size() << Endl;
56 }
57 
58 ////////////////////////////////////////////////////////////////////////////////
59 /// destructor
60 
62 {
63  // do something
64  delete fLogger;
65 }
66 
67 ////////////////////////////////////////////////////////////////////////////////
68 /// Computes variance of all the variables and
69 /// returns a new DataLoader with the selected variables whose variance is above a specific threshold.
70 /// Threshold can be provided by user otherwise default value is 0 i.e. remove the variables which have same value in all
71 /// the events.
72 ///
73 /// \param[in] threshold value (Double)
74 ///
75 /// Transformation Definition String Format: "VT(optional float value)"
76 ///
77 /// Usage examples:
78 ///
79 /// String | Description
80 /// ------- |----------------------------------------
81 /// "VT" | Select variables whose variance is above threshold value = 0 (Default)
82 /// "VT(1.5)" | Select variables whose variance is above threshold value = 1.5
83 
85 {
86  CalcNorm();
87  const UInt_t nvars = fDataSetInfo.GetNVariables();
88  Log() << kINFO << "Number of variables before transformation: " << nvars << Endl;
89  std::vector<VariableInfo>& vars = fDataSetInfo.GetVariableInfos();
90 
91  // return a new dataloader
92  // iterate over all variables, ignore the ones whose variance is below specific threshold
93  // DataLoader *transformedLoader=(DataLoader *)fDataLoader->Clone("vt_transformed_dataset");
94  // TMVA::DataLoader *transformedLoader = new TMVA::DataLoader(fDataSetInfo.GetName());
95  TMVA::DataLoader *transformedLoader = new TMVA::DataLoader("vt_transformed_dataset");
96  Log() << kINFO << "Selecting variables whose variance is above threshold value = " << threshold << Endl;
97  Int_t maxL = fDataSetInfo.GetVariableNameMaxLength();
98  maxL = maxL + 16;
99  Log() << kINFO << "----------------------------------------------------------------" << Endl;
100  Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << "Selected Variables";
101  Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(10) << "Variance" << Endl;
102  Log() << kINFO << "----------------------------------------------------------------" << Endl;
103  for (UInt_t ivar=0; ivar<nvars; ivar++) {
104  Double_t variance = vars[ivar].GetVariance();
105  if (variance > threshold)
106  {
107  Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << vars[ivar].GetExpression();
108  Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << variance << Endl;
109  transformedLoader->AddVariable(vars[ivar].GetExpression(), vars[ivar].GetVarType());
110  }
111  }
112  CopyDataLoader(transformedLoader,fDataLoader);
113  Log() << kINFO << "----------------------------------------------------------------" << Endl;
114  // CopyDataLoader(transformedLoader, fDataLoader);
115  // DataLoader *transformedLoader=(DataLoader *)fDataLoader->Clone(fDataSetInfo.GetName());
116  transformedLoader->PrepareTrainingAndTestTree(fDataLoader->GetDataSetInfo().GetCut("Signal"), fDataLoader->GetDataSetInfo().GetCut("Background"), fDataLoader->GetDataSetInfo().GetSplitOptions());
117  Log() << kINFO << "Number of variables after transformation: " << transformedLoader->GetDataSetInfo().GetNVariables() << Endl;
118 
119  return transformedLoader;
120 }
121 
122 ///////////////////////////////////////////////////////////////////////////////
123 ////////////////////////////// Utility methods ////////////////////////////////
124 ///////////////////////////////////////////////////////////////////////////////
125 
126 ////////////////////////////////////////////////////////////////////////////////
127 /// Updates maximum and minimum value of a variable or target
128 
130 {
131  Int_t nvars = fDataSetInfo.GetNVariables();
132  std::vector<VariableInfo>& vars = fDataSetInfo.GetVariableInfos();
133  std::vector<VariableInfo>& tars = fDataSetInfo.GetTargetInfos();
134  if( ivar < nvars ){
135  if (x < vars[ivar].GetMin()) vars[ivar].SetMin(x);
136  if (x > vars[ivar].GetMax()) vars[ivar].SetMax(x);
137  }
138  else{
139  if (x < tars[ivar-nvars].GetMin()) tars[ivar-nvars].SetMin(x);
140  if (x > tars[ivar-nvars].GetMax()) tars[ivar-nvars].SetMax(x);
141  }
142 }
143 
144 ////////////////////////////////////////////////////////////////////////////////
145 /// Computes maximum, minimum, mean, RMS and variance for all
146 /// variables and targets
147 
149 {
150  const std::vector<TMVA::Event*>& events = fDataSetInfo.GetDataSet()->GetEventCollection();
151 
152  const UInt_t nvars = fDataSetInfo.GetNVariables();
153  const UInt_t ntgts = fDataSetInfo.GetNTargets();
154  std::vector<VariableInfo>& vars = fDataSetInfo.GetVariableInfos();
155  std::vector<VariableInfo>& tars = fDataSetInfo.GetTargetInfos();
156 
157  UInt_t nevts = events.size();
158 
159  TVectorD x2( nvars+ntgts ); x2 *= 0;
160  TVectorD x0( nvars+ntgts ); x0 *= 0;
161  TVectorD v0( nvars+ntgts ); v0 *= 0;
162 
163  Double_t sumOfWeights = 0;
164  for (UInt_t ievt=0; ievt<nevts; ievt++) {
165  const Event* ev = events[ievt];
166 
167  Double_t weight = ev->GetWeight();
168  sumOfWeights += weight;
169  for (UInt_t ivar=0; ivar<nvars; ivar++) {
170  Double_t x = ev->GetValue(ivar);
171  if (ievt==0) {
172  vars[ivar].SetMin(x);
173  vars[ivar].SetMax(x);
174  }
175  else {
176  UpdateNorm(ivar, x );
177  }
178  x0(ivar) += x*weight;
179  x2(ivar) += x*x*weight;
180  }
181  for (UInt_t itgt=0; itgt<ntgts; itgt++) {
182  Double_t x = ev->GetTarget(itgt);
183  if (ievt==0) {
184  tars[itgt].SetMin(x);
185  tars[itgt].SetMax(x);
186  }
187  else {
188  UpdateNorm( nvars+itgt, x );
189  }
190  x0(nvars+itgt) += x*weight;
191  x2(nvars+itgt) += x*x*weight;
192  }
193  }
194 
195  if (sumOfWeights <= 0) {
196  Log() << kFATAL << " the sum of event weights calculated for your input is == 0"
197  << " or exactly: " << sumOfWeights << " there is obviously some problem..."<< Endl;
198  }
199 
200  // set Mean and RMS
201  for (UInt_t ivar=0; ivar<nvars; ivar++) {
202  Double_t mean = x0(ivar)/sumOfWeights;
203 
204  vars[ivar].SetMean( mean );
205  if (x2(ivar)/sumOfWeights - mean*mean < 0) {
206  Log() << kFATAL << " the RMS of your input variable " << ivar
207  << " evaluates to an imaginary number: sqrt("<< x2(ivar)/sumOfWeights - mean*mean
208  <<") .. sometimes related to a problem with outliers and negative event weights"
209  << Endl;
210  }
211  vars[ivar].SetRMS( TMath::Sqrt( x2(ivar)/sumOfWeights - mean*mean) );
212  }
213  for (UInt_t itgt=0; itgt<ntgts; itgt++) {
214  Double_t mean = x0(nvars+itgt)/sumOfWeights;
215  tars[itgt].SetMean( mean );
216  if (x2(nvars+itgt)/sumOfWeights - mean*mean < 0) {
217  Log() << kFATAL << " the RMS of your target variable " << itgt
218  << " evaluates to an imaginary number: sqrt(" << x2(nvars+itgt)/sumOfWeights - mean*mean
219  <<") .. sometimes related to a problem with outliers and negative event weights"
220  << Endl;
221  }
222  tars[itgt].SetRMS( TMath::Sqrt( x2(nvars+itgt)/sumOfWeights - mean*mean) );
223  }
224 
225  // calculate variance
226  for (UInt_t ievt=0; ievt<nevts; ievt++) {
227  const Event* ev = events[ievt];
228  Double_t weight = ev->GetWeight();
229 
230  for (UInt_t ivar=0; ivar<nvars; ivar++) {
231  Double_t x = ev->GetValue(ivar);
232  Double_t mean = vars[ivar].GetMean();
233  v0(ivar) += weight*(x-mean)*(x-mean);
234  }
235 
236  for (UInt_t itgt=0; itgt<ntgts; itgt++) {
237  Double_t x = ev->GetTarget(itgt);
238  Double_t mean = tars[itgt].GetMean();
239  v0(nvars+itgt) += weight*(x-mean)*(x-mean);
240  }
241  }
242 
243  Int_t maxL = fDataSetInfo.GetVariableNameMaxLength();
244  maxL = maxL + 8;
245  Log() << kINFO << "----------------------------------------------------------------" << Endl;
246  Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << "Variables";
247  Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(10) << "Variance" << Endl;
248  Log() << kINFO << "----------------------------------------------------------------" << Endl;
249 
250  // set variance
251  Log() << std::setprecision(5);
252  for (UInt_t ivar=0; ivar<nvars; ivar++) {
253  Double_t variance = v0(ivar)/sumOfWeights;
254  vars[ivar].SetVariance( variance );
255  Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << vars[ivar].GetExpression();
256  Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << variance << Endl;
257  }
258 
259  maxL = fDataSetInfo.GetTargetNameMaxLength();
260  maxL = maxL + 8;
261  Log() << kINFO << "----------------------------------------------------------------" << Endl;
262  Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << "Targets";
263  Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(10) << "Variance" << Endl;
264  Log() << kINFO << "----------------------------------------------------------------" << Endl;
265 
266  for (UInt_t itgt=0; itgt<ntgts; itgt++) {
267  Double_t variance = v0(nvars+itgt)/sumOfWeights;
268  tars[itgt].SetVariance( variance );
269  Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << tars[itgt].GetExpression();
270  Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << variance << Endl;
271  }
272 
273  Log() << kINFO << "Set minNorm/maxNorm for variables to: " << Endl;
274  Log() << std::setprecision(3);
275  for (UInt_t ivar=0; ivar<nvars; ivar++)
276  Log() << " " << vars[ivar].GetExpression()
277  << "\t: [" << vars[ivar].GetMin() << "\t, " << vars[ivar].GetMax() << "\t] " << Endl;
278  Log() << kINFO << "Set minNorm/maxNorm for targets to: " << Endl;
279  Log() << std::setprecision(3);
280  for (UInt_t itgt=0; itgt<ntgts; itgt++)
281  Log() << " " << tars[itgt].GetExpression()
282  << "\t: [" << tars[itgt].GetMin() << "\t, " << tars[itgt].GetMax() << "\t] " << Endl;
283  Log() << std::setprecision(5); // reset to better value
284 }
285 
286 ////////////////////////////////////////////////////////////////////////////////
288 {
289  for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Sbegin();treeinfo!=src->DataInput().Send();++treeinfo)
290  {
291  des->AddSignalTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
292  }
293 
294  for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Bbegin();treeinfo!=src->DataInput().Bend();++treeinfo)
295  {
296  des->AddBackgroundTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
297  }
298 }
v0
@ v0
Definition: rootcling_impl.cxx:3636
TVectorD.h
TMVA::DataInputHandler::Sbegin
std::vector< TreeInfo >::const_iterator Sbegin() const
Definition: DataInputHandler.h:113
TMVA::DataLoader::PrepareTrainingAndTestTree
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: DataLoader.cxx:631
TMatrixDSparsefwd.h
TMVA::VarTransformHandler::UpdateNorm
void UpdateNorm(Int_t ivar, Double_t x)
Updates maximum and minimum value of a variable or target.
Definition: VarTransformHandler.cxx:129
DataSetInfo.h
TMath::Log
Double_t Log(Double_t x)
Definition: TMath.h:760
TMath::Sqrt
Double_t Sqrt(Double_t x)
Definition: TMath.h:691
TMVA::DataSetInfo::GetNVariables
UInt_t GetNVariables() const
Definition: DataSetInfo.h:127
DataLoader.h
VariableInfo.h
TMVA::VarTransformHandler::fEvents
const std::vector< Event * > & fEvents
Definition: VarTransformHandler.h:55
MethodDNN.h
TMVA::VarTransformHandler::CalcNorm
void CalcNorm()
Computes maximum, minimum, mean, RMS and variance for all variables and targets.
Definition: VarTransformHandler.cxx:148
TMVA::DataInputHandler::Bend
std::vector< TreeInfo >::const_iterator Bend() const
Definition: DataInputHandler.h:116
x
Double_t x[n]
Definition: legend1.C:17
MethodBase.h
TMVA::Event::GetTarget
Float_t GetTarget(UInt_t itgt) const
Definition: Event.h:102
VarTransformHandler.h
TMVA::DataLoader::AddSignalTree
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
add a signal tree to the dataset (with optional global event weight and tree type)
Definition: DataLoader.cxx:370
TString
Basic string class.
Definition: TString.h:136
TMatrix.h
TMVA::Event::GetValue
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:236
DataInputHandler.h
MsgLogger.h
Event.h
Types.h
TMVA::Endl
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
unsigned int
TVectorT< Double_t >
TMVA::VarTransformHandler::CopyDataLoader
void CopyDataLoader(TMVA::DataLoader *des, TMVA::DataLoader *src)
Definition: VarTransformHandler.cxx:287
TMVA::VarTransformHandler::VarTransformHandler
VarTransformHandler(DataLoader *)
constructor
Definition: VarTransformHandler.cxx:49
Double_t
double Double_t
Definition: RtypesCore.h:59
TMVA::MsgLogger
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
TMVA::Event::GetWeight
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is set or not.
Definition: Event.cxx:381
TMVA::VarTransformHandler::VarianceThreshold
TMVA::DataLoader * VarianceThreshold(Double_t threshold)
Computes variance of all the variables and returns a new DataLoader with the selected variables whose...
Definition: VarTransformHandler.cxx:84
TMVA::DataLoader::DataInput
DataInputHandler & DataInput()
Definition: DataLoader.h:172
TMVA::VarTransformHandler::Log
MsgLogger & Log() const
message logger
Definition: VarTransformHandler.h:49
TMVA::Event
Definition: Event.h:51
x2
static const double x2[5]
Definition: RooGaussKronrodIntegrator1D.cxx:364
TMVA::DataInputHandler::Send
std::vector< TreeInfo >::const_iterator Send() const
Definition: DataInputHandler.h:114
TMVA::DataInputHandler::Bbegin
std::vector< TreeInfo >::const_iterator Bbegin() const
Definition: DataInputHandler.h:115
TMVA::VarTransformHandler::~VarTransformHandler
~VarTransformHandler()
destructor
Definition: VarTransformHandler.cxx:61
TMVA::DataLoader::AddBackgroundTree
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
add a background tree to the dataset (with optional global event weight and tree type)
Definition: DataLoader.cxx:401
Tools.h
ClassifierFactory.h
TMVA::DataLoader::AddVariable
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: DataLoader.cxx:484
TMatrixTSparse.h
TMVA::DataLoader::GetDataSetInfo
DataSetInfo & GetDataSetInfo()
Definition: DataLoader.cxx:137
DataSet.h
TMath.h
int
TMVA::DataLoader
Definition: DataLoader.h:50