Logo ROOT  
Reference Guide
VarTransformHandler.cxx
Go to the documentation of this file.
/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : VarTransformHandler                                                   *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation of unsupervised variable transformation methods            *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Abhinav Moudgil <abhinav.moudgil@research.iiit.ac.in> - IIIT-H, India     *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/
20
22
24#include "TMVA/DataLoader.h"
25#include "TMVA/Event.h"
27#include "TMVA/DataSet.h"
28#include "TMVA/DataSetInfo.h"
29#include "TMVA/MethodBase.h"
30#include "TMVA/MethodDNN.h"
31#include "TMVA/MsgLogger.h"
32#include "TMVA/Tools.h"
33#include "TMVA/Types.h"
34#include "TMVA/VariableInfo.h"
35
36#include "TMath.h"
37#include "TVectorD.h"
38#include "TFile.h"
39#include "TTree.h"
40#include "TMatrix.h"
41#include "TMatrixTSparse.h"
42#include "TMatrixDSparsefwd.h"
43#include "TCanvas.h"
44#include "TGraph.h"
45#include "TStyle.h"
46#include "TLegend.h"
47#include "TH2.h"
48
49#include <algorithm>
50#include <iomanip>
51#include <vector>
52
53////////////////////////////////////////////////////////////////////////////////
54/// constructor
55
57 : fLogger ( new MsgLogger(TString("VarTransformHandler").Data(), kINFO) ),
58 fDataSetInfo(dl->GetDataSetInfo()),
59 fDataLoader (dl),
60 fEvents (fDataSetInfo.GetDataSet()->GetEventCollection())
61{
62 Log() << kINFO << "Number of events - " << fEvents.size() << Endl;
63}
64
65////////////////////////////////////////////////////////////////////////////////
66/// destructor
67
69{
70 // do something
71 delete fLogger;
72}
73
74////////////////////////////////////////////////////////////////////////////////
75/// Computes variance of all the variables and
76/// returns a new DataLoader with the selected variables whose variance is above a specific threshold.
77/// Threshold can be provided by user otherwise default value is 0 i.e. remove the variables which have same value in all
78/// the events.
79///
80/// \param[in] threshold value (Double)
81///
82/// Transformation Definition String Format: "VT(optional float value)"
83///
84/// Usage examples:
85///
86/// String | Description
87/// ------- |----------------------------------------
88/// "VT" | Select variables whose variance is above threshold value = 0 (Default)
89/// "VT(1.5)" | Select variables whose variance is above threshold value = 1.5
90
92{
93 CalcNorm();
94 const UInt_t nvars = fDataSetInfo.GetNVariables();
95 Log() << kINFO << "Number of variables before transformation: " << nvars << Endl;
96 std::vector<VariableInfo>& vars = fDataSetInfo.GetVariableInfos();
97
98 // return a new dataloader
99 // iterate over all variables, ignore the ones whose variance is below specific threshold
100 // DataLoader *transformedLoader=(DataLoader *)fDataLoader->Clone("vt_transformed_dataset");
101 // TMVA::DataLoader *transformedLoader = new TMVA::DataLoader(fDataSetInfo.GetName());
102 TMVA::DataLoader *transformedLoader = new TMVA::DataLoader("vt_transformed_dataset");
103 Log() << kINFO << "Selecting variables whose variance is above threshold value = " << threshold << Endl;
104 Int_t maxL = fDataSetInfo.GetVariableNameMaxLength();
105 maxL = maxL + 16;
106 Log() << kINFO << "----------------------------------------------------------------" << Endl;
107 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << "Selected Variables";
108 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(10) << "Variance" << Endl;
109 Log() << kINFO << "----------------------------------------------------------------" << Endl;
110 for (UInt_t ivar=0; ivar<nvars; ivar++) {
111 Double_t variance = vars[ivar].GetVariance();
112 if (variance > threshold)
113 {
114 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << vars[ivar].GetExpression();
115 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << variance << Endl;
116 transformedLoader->AddVariable(vars[ivar].GetExpression(), vars[ivar].GetVarType());
117 }
118 }
119 CopyDataLoader(transformedLoader,fDataLoader);
120 Log() << kINFO << "----------------------------------------------------------------" << Endl;
121 // CopyDataLoader(transformedLoader, fDataLoader);
122 // DataLoader *transformedLoader=(DataLoader *)fDataLoader->Clone(fDataSetInfo.GetName());
123 transformedLoader->PrepareTrainingAndTestTree(fDataLoader->GetDataSetInfo().GetCut("Signal"), fDataLoader->GetDataSetInfo().GetCut("Background"), fDataLoader->GetDataSetInfo().GetSplitOptions());
124 Log() << kINFO << "Number of variables after transformation: " << transformedLoader->GetDataSetInfo().GetNVariables() << Endl;
125
126 return transformedLoader;
127}
128
///////////////////////////////////////////////////////////////////////////////
////////////////////////////// Utility methods ////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
132
133////////////////////////////////////////////////////////////////////////////////
134/// Updates maximum and minimum value of a variable or target
135
137{
138 Int_t nvars = fDataSetInfo.GetNVariables();
139 std::vector<VariableInfo>& vars = fDataSetInfo.GetVariableInfos();
140 std::vector<VariableInfo>& tars = fDataSetInfo.GetTargetInfos();
141 if( ivar < nvars ){
142 if (x < vars[ivar].GetMin()) vars[ivar].SetMin(x);
143 if (x > vars[ivar].GetMax()) vars[ivar].SetMax(x);
144 }
145 else{
146 if (x < tars[ivar-nvars].GetMin()) tars[ivar-nvars].SetMin(x);
147 if (x > tars[ivar-nvars].GetMax()) tars[ivar-nvars].SetMax(x);
148 }
149}
150
151////////////////////////////////////////////////////////////////////////////////
152/// Computes maximum, minimum, mean, RMS and variance for all
153/// variables and targets
154
156{
157 const std::vector<TMVA::Event*>& events = fDataSetInfo.GetDataSet()->GetEventCollection();
158
159 const UInt_t nvars = fDataSetInfo.GetNVariables();
160 const UInt_t ntgts = fDataSetInfo.GetNTargets();
161 std::vector<VariableInfo>& vars = fDataSetInfo.GetVariableInfos();
162 std::vector<VariableInfo>& tars = fDataSetInfo.GetTargetInfos();
163
164 UInt_t nevts = events.size();
165
166 TVectorD x2( nvars+ntgts ); x2 *= 0;
167 TVectorD x0( nvars+ntgts ); x0 *= 0;
168 TVectorD v0( nvars+ntgts ); v0 *= 0;
169
170 Double_t sumOfWeights = 0;
171 for (UInt_t ievt=0; ievt<nevts; ievt++) {
172 const Event* ev = events[ievt];
173
174 Double_t weight = ev->GetWeight();
175 sumOfWeights += weight;
176 for (UInt_t ivar=0; ivar<nvars; ivar++) {
177 Double_t x = ev->GetValue(ivar);
178 if (ievt==0) {
179 vars[ivar].SetMin(x);
180 vars[ivar].SetMax(x);
181 }
182 else {
183 UpdateNorm(ivar, x );
184 }
185 x0(ivar) += x*weight;
186 x2(ivar) += x*x*weight;
187 }
188 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
189 Double_t x = ev->GetTarget(itgt);
190 if (ievt==0) {
191 tars[itgt].SetMin(x);
192 tars[itgt].SetMax(x);
193 }
194 else {
195 UpdateNorm( nvars+itgt, x );
196 }
197 x0(nvars+itgt) += x*weight;
198 x2(nvars+itgt) += x*x*weight;
199 }
200 }
201
202 if (sumOfWeights <= 0) {
203 Log() << kFATAL << " the sum of event weights calculated for your input is == 0"
204 << " or exactly: " << sumOfWeights << " there is obviously some problem..."<< Endl;
205 }
206
207 // set Mean and RMS
208 for (UInt_t ivar=0; ivar<nvars; ivar++) {
209 Double_t mean = x0(ivar)/sumOfWeights;
210
211 vars[ivar].SetMean( mean );
212 if (x2(ivar)/sumOfWeights - mean*mean < 0) {
213 Log() << kFATAL << " the RMS of your input variable " << ivar
214 << " evaluates to an imaginary number: sqrt("<< x2(ivar)/sumOfWeights - mean*mean
215 <<") .. sometimes related to a problem with outliers and negative event weights"
216 << Endl;
217 }
218 vars[ivar].SetRMS( TMath::Sqrt( x2(ivar)/sumOfWeights - mean*mean) );
219 }
220 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
221 Double_t mean = x0(nvars+itgt)/sumOfWeights;
222 tars[itgt].SetMean( mean );
223 if (x2(nvars+itgt)/sumOfWeights - mean*mean < 0) {
224 Log() << kFATAL << " the RMS of your target variable " << itgt
225 << " evaluates to an imaginary number: sqrt(" << x2(nvars+itgt)/sumOfWeights - mean*mean
226 <<") .. sometimes related to a problem with outliers and negative event weights"
227 << Endl;
228 }
229 tars[itgt].SetRMS( TMath::Sqrt( x2(nvars+itgt)/sumOfWeights - mean*mean) );
230 }
231
232 // calculate variance
233 for (UInt_t ievt=0; ievt<nevts; ievt++) {
234 const Event* ev = events[ievt];
235 Double_t weight = ev->GetWeight();
236
237 for (UInt_t ivar=0; ivar<nvars; ivar++) {
238 Double_t x = ev->GetValue(ivar);
239 Double_t mean = vars[ivar].GetMean();
240 v0(ivar) += weight*(x-mean)*(x-mean);
241 }
242
243 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
244 Double_t x = ev->GetTarget(itgt);
245 Double_t mean = tars[itgt].GetMean();
246 v0(nvars+itgt) += weight*(x-mean)*(x-mean);
247 }
248 }
249
250 Int_t maxL = fDataSetInfo.GetVariableNameMaxLength();
251 maxL = maxL + 8;
252 Log() << kINFO << "----------------------------------------------------------------" << Endl;
253 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << "Variables";
254 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(10) << "Variance" << Endl;
255 Log() << kINFO << "----------------------------------------------------------------" << Endl;
256
257 // set variance
258 Log() << std::setprecision(5);
259 for (UInt_t ivar=0; ivar<nvars; ivar++) {
260 Double_t variance = v0(ivar)/sumOfWeights;
261 vars[ivar].SetVariance( variance );
262 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << vars[ivar].GetExpression();
263 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << variance << Endl;
264 }
265
266 maxL = fDataSetInfo.GetTargetNameMaxLength();
267 maxL = maxL + 8;
268 Log() << kINFO << "----------------------------------------------------------------" << Endl;
269 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << "Targets";
270 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(10) << "Variance" << Endl;
271 Log() << kINFO << "----------------------------------------------------------------" << Endl;
272
273 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
274 Double_t variance = v0(nvars+itgt)/sumOfWeights;
275 tars[itgt].SetVariance( variance );
276 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << tars[itgt].GetExpression();
277 Log() << kINFO << std::setiosflags(std::ios::left) << std::setw(maxL) << variance << Endl;
278 }
279
280 Log() << kINFO << "Set minNorm/maxNorm for variables to: " << Endl;
281 Log() << std::setprecision(3);
282 for (UInt_t ivar=0; ivar<nvars; ivar++)
283 Log() << " " << vars[ivar].GetExpression()
284 << "\t: [" << vars[ivar].GetMin() << "\t, " << vars[ivar].GetMax() << "\t] " << Endl;
285 Log() << kINFO << "Set minNorm/maxNorm for targets to: " << Endl;
286 Log() << std::setprecision(3);
287 for (UInt_t itgt=0; itgt<ntgts; itgt++)
288 Log() << " " << tars[itgt].GetExpression()
289 << "\t: [" << tars[itgt].GetMin() << "\t, " << tars[itgt].GetMax() << "\t] " << Endl;
290 Log() << std::setprecision(5); // reset to better value
291}
292
293////////////////////////////////////////////////////////////////////////////////
295{
296 for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Sbegin();treeinfo!=src->DataInput().Send();++treeinfo)
297 {
298 des->AddSignalTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
299 }
300
301 for( std::vector<TreeInfo>::const_iterator treeinfo=src->DataInput().Bbegin();treeinfo!=src->DataInput().Bend();++treeinfo)
302 {
303 des->AddBackgroundTree( (*treeinfo).GetTree(), (*treeinfo).GetWeight(),(*treeinfo).GetTreeType());
304 }
305}
static const double x2[5]
double Double_t
Definition: RtypesCore.h:57
std::vector< TreeInfo >::const_iterator Send() const
std::vector< TreeInfo >::const_iterator Sbegin() const
std::vector< TreeInfo >::const_iterator Bbegin() const
std::vector< TreeInfo >::const_iterator Bend() const
void AddSignalTree(TTree *signal, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
Definition: DataLoader.cxx:372
void PrepareTrainingAndTestTree(const TCut &cut, const TString &splitOpt)
prepare the training and test trees -> same cuts for signal and background
Definition: DataLoader.cxx:633
DataInputHandler & DataInput()
Definition: DataLoader.h:173
void AddBackgroundTree(TTree *background, Double_t weight=1.0, Types::ETreeType treetype=Types::kMaxTreeType)
number of signal events (used to compute significance)
Definition: DataLoader.cxx:403
DataSetInfo & GetDataSetInfo()
Definition: DataLoader.cxx:139
void AddVariable(const TString &expression, const TString &title, const TString &unit, char type='F', Double_t min=0, Double_t max=0)
user inserts discriminating variable in data set info
Definition: DataLoader.cxx:486
UInt_t GetNVariables() const
Definition: DataSetInfo.h:125
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:236
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Definition: Event.cxx:381
Float_t GetTarget(UInt_t itgt) const
Definition: Event.h:102
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
void UpdateNorm(Int_t ivar, Double_t x)
Updates maximum and minimum value of a variable or target.
MsgLogger & Log() const
message logger
void CopyDataLoader(TMVA::DataLoader *des, TMVA::DataLoader *src)
TMVA::DataLoader * VarianceThreshold(Double_t threshold)
Computes variance of all the variables and returns a new DataLoader with the selected variables whose...
void CalcNorm()
Computes maximum, minimum, mean, RMS and variance for all variables and targets.
const std::vector< Event * > & fEvents
VarTransformHandler(DataLoader *)
constructor
Basic string class.
Definition: TString.h:131
Double_t x[n]
Definition: legend1.C:17
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:750
Double_t Sqrt(Double_t x)
Definition: TMath.h:681