Logo ROOT  
Reference Guide
VarTransformHandler.cxx
Go to the documentation of this file.
1/**********************************************************************************
2 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
3 * Package: TMVA *
4 * Class : VarTransformHandler *
5 * Web : http://tmva.sourceforge.net *
6 * *
7 * Description: *
8 * Implementation of unsupervised variable transformation methods *
9 * *
10 * Authors (alphabetical): *
11 * Abhinav Moudgil <abhinav.moudgil@research.iiit.ac.in> - IIIT-H, India *
12 * *
13 * Copyright (c) 2005: *
14 * CERN, Switzerland *
15 * *
16 * Redistribution and use in source and binary forms, with or without *
17 * modification, are permitted according to the terms listed in LICENSE *
18 * (http://tmva.sourceforge.net/LICENSE) *
19 **********************************************************************************/
20
22
24#include "TMVA/DataLoader.h"
25#include "TMVA/Event.h"
27#include "TMVA/DataSet.h"
28#include "TMVA/DataSetInfo.h"
29#include "TMVA/MethodBase.h"
30#include "TMVA/MethodDNN.h"
31#include "TMVA/MsgLogger.h"
32#include "TMVA/Tools.h"
33#include "TMVA/Types.h"
34#include "TMVA/VariableInfo.h"
35
36#include "TMath.h"
37#include "TVectorD.h"
38#include "TMatrix.h"
39#include "TMatrixTSparse.h"
40#include "TMatrixDSparsefwd.h"
41
42#include <algorithm>
43#include <iomanip>
44#include <vector>
45
46////////////////////////////////////////////////////////////////////////////////
47/// constructor
48
50 : fLogger ( new MsgLogger(TString("VarTransformHandler").Data(), kINFO) ),
51 fDataSetInfo(dl->GetDataSetInfo()),
52 fDataLoader (dl),
53 fEvents (fDataSetInfo.GetDataSet()->GetEventCollection())
54{
55 Log() << kINFO << "Number of events - " << fEvents.size() << Endl;
56}
57
58////////////////////////////////////////////////////////////////////////////////
59/// destructor
60
62{
63 // do something
64 delete fLogger;
65}
66
67////////////////////////////////////////////////////////////////////////////////
68/// Computes variance of all the variables and
69/// returns a new DataLoader with the selected variables whose variance is above a specific threshold.
70/// Threshold can be provided by user otherwise default value is 0 i.e. remove the variables which have same value in all
71/// the events.
72///
73/// \param[in] threshold value (Double)
74///
75/// Transformation Definition String Format: "VT(optional float value)"
76///
77/// Usage examples:
78///
79/// String | Description
80/// ------- |----------------------------------------
81/// "VT" | Select variables whose variance is above threshold value = 0 (Default)
82/// "VT(1.5)" | Select variables whose variance is above threshold value = 1.5
83
85{
86 CalcNorm();
87 const UInt_t nvars = fDataSetInfo.GetNVariables();
88 Log() << kINFO << "Number of variables before transformation: " << nvars << Endl;
89 std::vector<VariableInfo>& vars = fDataSetInfo.GetVariableInfos();
90
91 // return a new dataloader
92 // iterate over all variables, ignore the ones whose variance is below specific threshold
93 // DataLoader *transformedLoader=(DataLoader *)fDataLoader->Clone("vt_transformed_dataset");
94 // TMVA::DataLoader *transformedLoader = new TMVA::DataLoader(fDataSetInfo.GetName());
95 TMVA::DataLoader *transformedLoader = new TMVA::DataLoader("vt_transformed_dataset");
96 Log() << kINFO << "Selecting variables whose variance is above threshold value = " << threshold << Endl;
97 Int_t maxL = fDataSetInfo.GetVariableNameMaxLength();
98 maxL = maxL + 16;
99 Log() << kINFO << "----------------------------------------------------------------" << Endl;
100 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << "Selected Variables";
101 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(10) << "Variance" << Endl;
102 Log() << kINFO << "----------------------------------------------------------------" << Endl;
103 for (UInt_t ivar=0; ivar<nvars; ivar++) {
104 Double_t variance = vars[ivar].GetVariance();
105 if (variance > threshold)
106 {
107 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << vars[ivar].GetExpression();
108 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << variance << Endl;
109 transformedLoader->AddVariable(vars[ivar].GetExpression(), vars[ivar].GetVarType());
110 }
111 }
112 CopyDataLoader(transformedLoader,fDataLoader);
113 Log() << kINFO << "----------------------------------------------------------------" << Endl;
114 // CopyDataLoader(transformedLoader, fDataLoader);
115 // DataLoader *transformedLoader=(DataLoader *)fDataLoader->Clone(fDataSetInfo.GetName());
116 transformedLoader->PrepareTrainingAndTestTree(fDataLoader->GetDataSetInfo().GetCut("Signal"), fDataLoader->GetDataSetInfo().GetCut("Background"), fDataLoader->GetDataSetInfo().GetSplitOptions());
117 Log() << kINFO << "Number of variables after transformation: " << transformedLoader->GetDataSetInfo().GetNVariables() << Endl;
118
119 return transformedLoader;
120}
121
122///////////////////////////////////////////////////////////////////////////////
123////////////////////////////// Utility methods ////////////////////////////////
124///////////////////////////////////////////////////////////////////////////////
125
126////////////////////////////////////////////////////////////////////////////////
127/// Updates maximum and minimum value of a variable or target
128
130{
131 Int_t nvars = fDataSetInfo.GetNVariables();
132 std::vector<VariableInfo>& vars = fDataSetInfo.GetVariableInfos();
133 std::vector<VariableInfo>& tars = fDataSetInfo.GetTargetInfos();
134 if( ivar < nvars ){
135 if (x < vars[ivar].GetMin()) vars[ivar].SetMin(x);
136 if (x > vars[ivar].GetMax()) vars[ivar].SetMax(x);
137 }
138 else{
139 if (x < tars[ivar-nvars].GetMin()) tars[ivar-nvars].SetMin(x);
140 if (x > tars[ivar-nvars].GetMax()) tars[ivar-nvars].SetMax(x);
141 }
142}
143
144////////////////////////////////////////////////////////////////////////////////
145/// Computes maximum, minimum, mean, RMS and variance for all
146/// variables and targets
147
149{
150 const std::vector<TMVA::Event*>& events = fDataSetInfo.GetDataSet()->GetEventCollection();
151
152 const UInt_t nvars = fDataSetInfo.GetNVariables();
153 const UInt_t ntgts = fDataSetInfo.GetNTargets();
154 std::vector<VariableInfo>& vars = fDataSetInfo.GetVariableInfos();
155 std::vector<VariableInfo>& tars = fDataSetInfo.GetTargetInfos();
156
157 UInt_t nevts = events.size();
158
159 TVectorD x2( nvars+ntgts ); x2 *= 0;
160 TVectorD x0( nvars+ntgts ); x0 *= 0;
161 TVectorD v0( nvars+ntgts ); v0 *= 0;
162
163 Double_t sumOfWeights = 0;
164 for (UInt_t ievt=0; ievt<nevts; ievt++) {
165 const Event* ev = events[ievt];
166
167 Double_t weight = ev->GetWeight();
168 sumOfWeights += weight;
169 for (UInt_t ivar=0; ivar<nvars; ivar++) {
170 Double_t x = ev->GetValue(ivar);
171 if (ievt==0) {
172 vars[ivar].SetMin(x);
173 vars[ivar].SetMax(x);
174 }
175 else {
176 UpdateNorm(ivar, x );
177 }
178 x0(ivar) += x*weight;
179 x2(ivar) += x*x*weight;
180 }
181 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
182 Double_t x = ev->GetTarget(itgt);
183 if (ievt==0) {
184 tars[itgt].SetMin(x);
185 tars[itgt].SetMax(x);
186 }
187 else {
188 UpdateNorm( nvars+itgt, x );
189 }
190 x0(nvars+itgt) += x*weight;
191 x2(nvars+itgt) += x*x*weight;
192 }
193 }
194
195 if (sumOfWeights <= 0) {
196 Log() << kFATAL << " the sum of event weights calculated for your input is == 0"
197 << " or exactly: " << sumOfWeights << " there is obviously some problem..."<< Endl;
198 }
199
200 // set Mean and RMS
201 for (UInt_t ivar=0; ivar<nvars; ivar++) {
202 Double_t mean = x0(ivar)/sumOfWeights;
203
204 vars[ivar].SetMean( mean );
205 if (x2(ivar)/sumOfWeights - mean*mean < 0) {
206 Log() << kFATAL << " the RMS of your input variable " << ivar
207 << " evaluates to an imaginary number: sqrt("<< x2(ivar)/sumOfWeights - mean*mean
208 <<") .. sometimes related to a problem with outliers and negative event weights"
209 << Endl;
210 }
211 vars[ivar].SetRMS( TMath::Sqrt( x2(ivar)/sumOfWeights - mean*mean) );
212 }
213 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
214 Double_t mean = x0(nvars+itgt)/sumOfWeights;
215 tars[itgt].SetMean( mean );
216 if (x2(nvars+itgt)/sumOfWeights - mean*mean < 0) {
217 Log() << kFATAL << " the RMS of your target variable " << itgt
218 << " evaluates to an imaginary number: sqrt(" << x2(nvars+itgt)/sumOfWeights - mean*mean
219 <<") .. sometimes related to a problem with outliers and negative event weights"
220 << Endl;
221 }
222 tars[itgt].SetRMS( TMath::Sqrt( x2(nvars+itgt)/sumOfWeights - mean*mean) );
223 }
224
225 // calculate variance
226 for (UInt_t ievt=0; ievt<nevts; ievt++) {
227 const Event* ev = events[ievt];
228 Double_t weight = ev->GetWeight();
229
230 for (UInt_t ivar=0; ivar<nvars; ivar++) {
231 Double_t x = ev->GetValue(ivar);
232 Double_t mean = vars[ivar].GetMean();
233 v0(ivar) += weight*(x-mean)*(x-mean);
234 }
235
236 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
237 Double_t x = ev->GetTarget(itgt);
238 Double_t mean = tars[itgt].GetMean();
239 v0(nvars+itgt) += weight*(x-mean)*(x-mean);
240 }
241 }
242
243 Int_t maxL = fDataSetInfo.GetVariableNameMaxLength();
244 maxL = maxL + 8;
245 Log() << kINFO << "----------------------------------------------------------------" << Endl;
246 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << "Variables";
247 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(10) << "Variance" << Endl;
248 Log() << kINFO << "----------------------------------------------------------------" << Endl;
249
250 // set variance
251 Log() << std::setprecision(5);
252 for (UInt_t ivar=0; ivar<nvars; ivar++) {
253 Double_t variance = v0(ivar)/sumOfWeights;
254 vars[ivar].SetVariance( variance );
255 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << vars[ivar].GetExpression();
256 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << variance << Endl;
257 }
258
259 maxL = fDataSetInfo.GetTargetNameMaxLength();
260 maxL = maxL + 8;
261 Log() << kINFO << "----------------------------------------------------------------" << Endl;
262 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << "Targets";
263 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(10) << "Variance" << Endl;
264 Log() << kINFO << "----------------------------------------------------------------" << Endl;
265
266 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
267 Double_t variance = v0(nvars+itgt)/sumOfWeights;
268 tars[itgt].SetVariance( variance );
269 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << tars[itgt].GetExpression();
270 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << variance << Endl;
271 }
272
273 Log() << kINFO << "Set minNorm/maxNorm for variables to: " << Endl;
274 Log() << std::setprecision(3);
275 for (UInt_t ivar=0; ivar<nvars; ivar++)
276 Log() << " " << vars[ivar].GetExpression()
277 << "\t: [" << vars[ivar].GetMin() << "\t, " << vars[ivar].GetMax() << "\t] " << Endl;
278 Log() << kINFO << "Set minNorm/maxNorm for targets to: " << Endl;
279 Log() << std::setprecision(3);
280 for (UInt_t itgt=0; itgt<ntgts; itgt++)
281 Log() << " " << tars[itgt].GetExpression()
282 << "\t: [" << tars[itgt].GetMin() << "\t, " << tars[itgt].GetMax() << "\t] " << Endl;
283 Log() << std::setprecision(5); // reset to better value
284}
285
286////////////////////////////////////////////////////////////////////////////////
288{
289 for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Sbegin();treeinfo!=src->DataInput().Send();++treeinfo)
290 {
291 des->AddSignalTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
292 }
293
294 for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Bbegin();treeinfo!=src->DataInput().Bend();++treeinfo)
295 {
296 des->AddBackgroundTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
297 }
298}
static const double x2[5]
int Int_t
Definition: RtypesCore.h:45
unsigned int UInt_t
Definition: RtypesCore.h:46
double Double_t
Definition: RtypesCore.h:59
std::vector< TreeInfo >::const_iterator Send() const
std::vector< TreeInfo >::const_iterator Sbegin() const
std::vector< TreeInfo >::const_iterator Bbegin() const
std::vector< TreeInfo >::const_iterator Bend() const
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
add a signal tree (with the given event weight and tree type) to the data set
Definition: DataLoader.cxx:371
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: DataLoader.cxx:632
DataInputHandler & DataInput()
Definition: DataLoader.h:172
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
add a background tree (with the given event weight and tree type) to the data set
Definition: DataLoader.cxx:402
DataSetInfo & GetDataSetInfo()
Definition: DataLoader.cxx:137
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: DataLoader.cxx:485
UInt_t GetNVariables() const
Definition: DataSetInfo.h:127
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:236
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is set or not.
Definition: Event.cxx:381
Float_t GetTarget(UInt_t itgt) const
Definition: Event.h:102
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:57
void UpdateNorm(Int_t ivar, Double_t x)
Updates maximum and minimum value of a variable or target.
MsgLogger & Log() const
message logger
void CopyDataLoader(TMVA::DataLoader *des, TMVA::DataLoader *src)
TMVA::DataLoader * VarianceThreshold(Double_t threshold)
Computes variance of all the variables and returns a new DataLoader with the selected variables whose variance is above a specific threshold.
void CalcNorm()
Computes maximum, minimum, mean, RMS and variance for all variables and targets.
const std::vector< Event * > & fEvents
VarTransformHandler(DataLoader *)
constructor
@ kINFO
Definition: Types.h:58
@ kFATAL
Definition: Types.h:61
Basic string class.
Definition: TString.h:136
Double_t x[n]
Definition: legend1.C:17
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:148
Double_t Log(Double_t x)
Definition: TMath.h:710
Double_t Sqrt(Double_t x)
Definition: TMath.h:641