Logo ROOT  
Reference Guide
VariableTransformBase.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : VariableTransformBase *
8 * Web : http://tmva.sourceforge.net *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
16 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
17 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
18 * *
19 * Copyright (c) 2005: *
20 * CERN, Switzerland *
21 * MPI-K Heidelberg, Germany *
22 * *
23 * Redistribution and use in source and binary forms, with or without *
24 * modification, are permitted according to the terms listed in LICENSE *
25 * (http://tmva.sourceforge.net/LICENSE) *
26 **********************************************************************************/
27
28/*! \class TMVA::VariableTransformBase
29\ingroup TMVA
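Base class for the TMVA variable transformations: it manages the selection of the input/output variables, targets and spectators, the normalisation statistics (min, max, mean, RMS, variance) and the XML (de)serialisation of the selection.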
31*/
32
34
35#include "TMVA/Config.h"
36#include "TMVA/DataSetInfo.h"
37#include "TMVA/MsgLogger.h"
38#include "TMVA/Ranking.h"
39#include "TMVA/Tools.h"
40#include "TMVA/Types.h"
41#include "TMVA/VariableInfo.h"
42#include "TMVA/Version.h"
43
44#include "TH1.h"
45#include "TH2.h"
46#include "THashTable.h"
47#include "TList.h"
48#include "TObjString.h"
49#include "TMath.h"
50#include "TProfile.h"
51#include "TVectorD.h"
52
#include <algorithm>
#include <cassert>
#include <exception>
#include <iomanip>
#include <memory>
#include <set>
#include <stdexcept>
59
61
63
64////////////////////////////////////////////////////////////////////////////////
65/// standard constructor
66
69 const TString& trfName )
70: TObject(),
71 fDsi(dsi),
72 fDsiOutput(NULL),
73 fTransformedEvent(0),
74 fBackTransformedEvent(0),
75 fVariableTransform(tf),
76 fEnabled( kTRUE ),
77 fCreated( kFALSE ),
78 fNormalise( kFALSE ),
79 fTransformName(trfName),
80 fVariableTypesAreCounted(false),
81 fNVariables(0),
82 fNTargets(0),
83 fNSpectators(0),
84 fSortGet(kTRUE),
85 fTMVAVersion(TMVA_VERSION_CODE),
86 fLogger( 0 )
87{
88 fLogger = new MsgLogger(this, kINFO);
89 for (UInt_t ivar = 0; ivar < fDsi.GetNVariables(); ivar++) {
90 fVariables.push_back( VariableInfo( fDsi.GetVariableInfo(ivar) ) );
91 }
92 for (UInt_t itgt = 0; itgt < fDsi.GetNTargets(); itgt++) {
93 fTargets.push_back( VariableInfo( fDsi.GetTargetInfo(itgt) ) );
94 }
95 for (UInt_t ispct = 0; ispct < fDsi.GetNSpectators(); ispct++) {
96 fSpectators.push_back( VariableInfo( fDsi.GetSpectatorInfo(ispct) ) );
97 }
98}
99
100////////////////////////////////////////////////////////////////////////////////
101
103{
104 if (fTransformedEvent!=0) delete fTransformedEvent;
105 if (fBackTransformedEvent!=0) delete fBackTransformedEvent;
106 // destructor
107 delete fLogger;
108}
109
110////////////////////////////////////////////////////////////////////////////////
111/// select the variables/targets/spectators which serve as input to the transformation
112
113void TMVA::VariableTransformBase::SelectInput( const TString& _inputVariables, Bool_t putIntoVariables )
114{
115 TString inputVariables = _inputVariables;
116
117 // unselect all variables first
118 fGet.clear();
119
120 UInt_t nvars = GetNVariables();
121 UInt_t ntgts = GetNTargets();
122 UInt_t nspcts = GetNSpectators();
123
124 typedef std::set<Int_t> SelectedIndices;
125
126 SelectedIndices varIndices;
127 SelectedIndices tgtIndices;
128 SelectedIndices spctIndices;
129
130 if (inputVariables == "") // default is all variables and all targets
131 { // (the default can be changed by decorating this member function in the implementations)
132 inputVariables = "_V_,_T_";
133 }
134
135 TList* inList = gTools().ParseFormatLine( inputVariables, "," );
136 TListIter inIt(inList);
137 while (TObjString* os = (TObjString*)inIt()) {
138
139 TString variables = os->GetString();
140
141 if( variables.BeginsWith("_") && variables.EndsWith("_") ) { // special symbol (keyword)
142 variables.Remove( 0,1); // remove first "_"
143 variables.Remove( variables.Length()-1,1 ); // remove last "_"
144
145 if( variables.BeginsWith("V") ) { // variables
146 variables.Remove(0,1); // remove "V"
147 if( variables.Length() == 0 ){
148 for( UInt_t ivar = 0; ivar < nvars; ++ivar ) {
149 fGet.push_back( std::pair<Char_t,UInt_t>('v',ivar) );
150 varIndices.insert( ivar );
151 }
152 } else {
153 UInt_t idx = variables.Atoi();
154 if( idx >= nvars )
155 Log() << kFATAL << "You selected variable with index : " << idx << " of only " << nvars << " variables." << Endl;
156 fGet.push_back( std::pair<Char_t,UInt_t>('v',idx) );
157 varIndices.insert( idx );
158 }
159 }else if( variables.BeginsWith("T") ) { // targets
160 variables.Remove(0,1); // remove "T"
161 if( variables.Length() == 0 ){
162 for( UInt_t itgt = 0; itgt < ntgts; ++itgt ) {
163 fGet.push_back( std::pair<Char_t,UInt_t>('t',itgt) );
164 tgtIndices.insert( itgt );
165 }
166 } else {
167 UInt_t idx = variables.Atoi();
168 if( idx >= ntgts )
169 Log() << kFATAL << "You selected target with index : " << idx << " of only " << ntgts << " targets." << Endl;
170 fGet.push_back( std::pair<Char_t,UInt_t>('t',idx) );
171 tgtIndices.insert( idx );
172 }
173 }else if( variables.BeginsWith("S") ) { // spectators
174 variables.Remove(0,1); // remove "S"
175 if( variables.Length() == 0 ){
176 for( UInt_t ispct = 0; ispct < nspcts; ++ispct ) {
177 fGet.push_back( std::pair<Char_t,UInt_t>('s',ispct) );
178 spctIndices.insert( ispct );
179 }
180 } else {
181 UInt_t idx = variables.Atoi();
182 if( idx >= nspcts )
183 Log() << kFATAL << "You selected spectator with index : " << idx << " of only " << nspcts << " spectators." << Endl;
184 fGet.push_back( std::pair<Char_t,UInt_t>('s',idx) );
185 spctIndices.insert( idx );
186 }
187 }else if( TString("REARRANGE").BeginsWith(variables) ) { // toggle rearrange sorting (take sort order given in the options)
188 ToggleInputSortOrder( kFALSE );
189 if( !fSortGet )
190 Log() << kINFO << "Variable rearrangement set true: Variable order given in transformation option is used for input to transformation!" << Endl;
191
192 }
193 }else{ // no keyword, ... user provided variable labels
194 Int_t numIndices = varIndices.size()+tgtIndices.size()+spctIndices.size();
195 for( UInt_t ivar = 0; ivar < nvars; ++ivar ) { // search all variables
196 if( fDsi.GetVariableInfo( ivar ).GetLabel() == variables ) {
197 fGet.push_back( std::pair<Char_t,UInt_t>('v',ivar) );
198 varIndices.insert( ivar );
199 break;
200 }
201 }
202 for( UInt_t itgt = 0; itgt < ntgts; ++itgt ) { // search all targets
203 if( fDsi.GetTargetInfo( itgt ).GetLabel() == variables ) {
204 fGet.push_back( std::pair<Char_t,UInt_t>('t',itgt) );
205 tgtIndices.insert( itgt );
206 break;
207 }
208 }
209 for( UInt_t ispct = 0; ispct < nspcts; ++ispct ) { // search all spectators
210 if( fDsi.GetSpectatorInfo( ispct ).GetLabel() == variables ) {
211 fGet.push_back( std::pair<Char_t,UInt_t>('s',ispct) );
212 spctIndices.insert( ispct );
213 break;
214 }
215 }
216 Int_t numIndicesEndOfLoop = varIndices.size()+tgtIndices.size()+spctIndices.size();
217 if( numIndicesEndOfLoop == numIndices )
218 Log() << kWARNING << "Error at parsing the options for the variable transformations: Variable/Target/Spectator '" << variables.Data() << "' not found." << Endl;
219 numIndices = numIndicesEndOfLoop;
220 }
221 }
222
223
224 if( putIntoVariables ) {
225 Int_t idx = 0;
226 for( SelectedIndices::iterator it = varIndices.begin(), itEnd = varIndices.end(); it != itEnd; ++it ) {
227 fPut.push_back( std::pair<Char_t,UInt_t>('v',idx) );
228 ++idx;
229 }
230 for( SelectedIndices::iterator it = tgtIndices.begin(), itEnd = tgtIndices.end(); it != itEnd; ++it ) {
231 fPut.push_back( std::pair<Char_t,UInt_t>('t',idx) );
232 ++idx;
233 }
234 for( SelectedIndices::iterator it = spctIndices.begin(), itEnd = spctIndices.end(); it != itEnd; ++it ) {
235 fPut.push_back( std::pair<Char_t,UInt_t>('s',idx) );
236 ++idx;
237 }
238 }else {
239 for( SelectedIndices::iterator it = varIndices.begin(), itEnd = varIndices.end(); it != itEnd; ++it ) {
240 Int_t idx = (*it);
241 fPut.push_back( std::pair<Char_t,UInt_t>('v',idx) );
242 }
243 for( SelectedIndices::iterator it = tgtIndices.begin(), itEnd = tgtIndices.end(); it != itEnd; ++it ) {
244 Int_t idx = (*it);
245 fPut.push_back( std::pair<Char_t,UInt_t>('t',idx) );
246 }
247 for( SelectedIndices::iterator it = spctIndices.begin(), itEnd = spctIndices.end(); it != itEnd; ++it ) {
248 Int_t idx = (*it);
249 fPut.push_back( std::pair<Char_t,UInt_t>('s',idx) );
250 }
251
252 // if sorting is turned on, fGet should have the indices sorted as fPut has them.
253 if( fSortGet ) {
254 fGet.clear();
255 fGet.assign( fPut.begin(), fPut.end() );
256 }
257 }
258
259 Log() << kHEADER << "Transformation, Variable selection : " << Endl;
260
261 // choose the new dsi for output if present, if not, take the common one
262 const DataSetInfo* outputDsiPtr = (fDsiOutput? &(*fDsiOutput) : &fDsi );
263
264
265
266 ItVarTypeIdx itGet = fGet.begin(), itGetEnd = fGet.end();
267 ItVarTypeIdx itPut = fPut.begin(); // , itPutEnd = fPut.end();
268 for( ; itGet != itGetEnd; ++itGet ) {
269 TString inputTypeString = "?";
270
271 Char_t inputType = (*itGet).first;
272 Int_t inputIdx = (*itGet).second;
273
274 TString inputLabel = "NOT FOND";
275 if( inputType == 'v' ) {
276 inputLabel = fDsi.GetVariableInfo( inputIdx ).GetLabel();
277 inputTypeString = "variable";
278 }
279 else if( inputType == 't' ){
280 inputLabel = fDsi.GetTargetInfo( inputIdx ).GetLabel();
281 inputTypeString = "target";
282 }
283 else if( inputType == 's' ){
284 inputLabel = fDsi.GetSpectatorInfo( inputIdx ).GetLabel();
285 inputTypeString = "spectator";
286 }
287
288 TString outputTypeString = "?";
289
290 Char_t outputType = (*itPut).first;
291 Int_t outputIdx = (*itPut).second;
292
293 TString outputLabel = "NOT FOUND";
294 if( outputType == 'v' ) {
295 outputLabel = outputDsiPtr->GetVariableInfo( outputIdx ).GetLabel();
296 outputTypeString = "variable";
297 }
298 else if( outputType == 't' ){
299 outputLabel = outputDsiPtr->GetTargetInfo( outputIdx ).GetLabel();
300 outputTypeString = "target";
301 }
302 else if( outputType == 's' ){
303 outputLabel = outputDsiPtr->GetSpectatorInfo( outputIdx ).GetLabel();
304 outputTypeString = "spectator";
305 }
306 Log() << kINFO << "Input : " << inputTypeString.Data() << " '" << inputLabel.Data() << "'" << " <---> " << "Output : " << outputTypeString.Data() << " '" << outputLabel.Data() << "'" << Endl;
307 Log() << kDEBUG << "\t(index=" << inputIdx << ")." << "\t(index=" << outputIdx << ")." << Endl;
308
309 ++itPut;
310 }
311 // Log() << kINFO << Endl;
312}
313
314
315////////////////////////////////////////////////////////////////////////////////
316/// select the values from the event
317
318Bool_t TMVA::VariableTransformBase::GetInput( const Event* event, std::vector<Float_t>& input, std::vector<Char_t>& mask, Bool_t backTransformation ) const
319{
320 ItVarTypeIdxConst itEntry;
321 ItVarTypeIdxConst itEntryEnd;
322
323 input.clear();
324 mask.clear();
325
326 if( backTransformation && !fPut.empty() ){
327 itEntry = fPut.begin();
328 itEntryEnd = fPut.end();
329 input.reserve(fPut.size());
330 }
331 else {
332 itEntry = fGet.begin();
333 itEntryEnd = fGet.end();
334 input.reserve(fGet.size() );
335 }
336
337 Bool_t hasMaskedEntries = kFALSE;
338 // event->Print(std::cout);
339 for( ; itEntry != itEntryEnd; ++itEntry ) {
340 Char_t type = (*itEntry).first;
341 Int_t idx = (*itEntry).second;
342
343 try{
344 switch( type ) {
345 case 'v':
346 input.push_back( event->GetValue(idx) );
347 break;
348 case 't':
349 input.push_back( event->GetTarget(idx) );
350 break;
351 case 's':
352 input.push_back( event->GetSpectator(idx) );
353 break;
354 default:
355 Log() << kFATAL << "VariableTransformBase/GetInput : unknown type '" << type << "'." << Endl;
356 }
357 mask.push_back(kFALSE);
358 }
359 catch(std::out_of_range& /* excpt */ ){ // happens when an event is transformed which does not yet have the targets calculated (in the application phase)
360 input.push_back(0.f);
361 mask.push_back(kTRUE);
362 hasMaskedEntries = kTRUE;
363 }
364 }
365 return hasMaskedEntries;
366}
367
368////////////////////////////////////////////////////////////////////////////////
369/// select the values from the event
370
371void TMVA::VariableTransformBase::SetOutput( Event* event, std::vector<Float_t>& output, std::vector<Char_t>& mask, const Event* oldEvent, Bool_t backTransformation ) const
372{
373 std::vector<Float_t>::iterator itOutput = output.begin();
374 std::vector<Char_t>::iterator itMask = mask.begin();
375
376 if( oldEvent )
377 event->CopyVarValues( *oldEvent );
378
379 try {
380
381 ItVarTypeIdxConst itEntry;
382 ItVarTypeIdxConst itEntryEnd;
383
384 if( backTransformation || fPut.empty() ){ // as in GetInput, but the other way round (from fPut for transformation, from fGet for backTransformation)
385 itEntry = fGet.begin();
386 itEntryEnd = fGet.end();
387 }
388 else {
389 itEntry = fPut.begin();
390 itEntryEnd = fPut.end();
391 }
392
393
394 for( ; itEntry != itEntryEnd; ++itEntry ) {
395
396 if( (*itMask) ){ // if the value is masked
397 continue;
398 }
399
400 Char_t type = (*itEntry).first;
401 Int_t idx = (*itEntry).second;
402 if (itOutput == output.end()) Log() << kFATAL << "Read beyond array boundaries in VariableTransformBase::SetOutput"<<Endl;
403 Float_t value = (*itOutput);
404
405 switch( type ) {
406 case 'v':
407 event->SetVal( idx, value );
408 break;
409 case 't':
410 event->SetTarget( idx, value );
411 break;
412 case 's':
413 event->SetSpectator( idx, value );
414 break;
415 default:
416 Log() << kFATAL << "VariableTransformBase/GetInput : unknown type '" << type << "'." << Endl;
417 }
418 if( !(*itMask) ) ++itOutput;
419 ++itMask;
420
421 }
422 }catch( std::exception& except ){
423 Log() << kFATAL << "VariableTransformBase/SetOutput : exception/" << except.what() << Endl;
424 throw;
425 }
426}
427
428
429////////////////////////////////////////////////////////////////////////////////
430/// count variables, targets and spectators
431
433{
434 if( fVariableTypesAreCounted ){
435 nvars = fNVariables;
436 ntgts = fNTargets;
437 nspcts = fNSpectators;
438 return;
439 }
440
441 nvars = ntgts = nspcts = 0;
442
443 for( ItVarTypeIdxConst itEntry = fGet.begin(), itEntryEnd = fGet.end(); itEntry != itEntryEnd; ++itEntry ) {
444 Char_t type = (*itEntry).first;
445
446 switch( type ) {
447 case 'v':
448 nvars++;
449 break;
450 case 't':
451 ntgts++;
452 break;
453 case 's':
454 nspcts++;
455 break;
456 default:
457 Log() << kFATAL << "VariableTransformBase/GetVariableTypeNumbers : unknown type '" << type << "'." << Endl;
458 }
459 }
460
461 fNVariables = nvars;
462 fNTargets = ntgts;
463 fNSpectators = nspcts;
464
465 fVariableTypesAreCounted = true;
466}
467
468////////////////////////////////////////////////////////////////////////////////
469/// TODO --> adapt to variable,target,spectator selection
470/// method to calculate minimum, maximum, mean, and RMS for all
471/// variables used in the MVA
472
473void TMVA::VariableTransformBase::CalcNorm( const std::vector<const Event*>& events )
474{
475 if (!IsCreated()) return;
476
477 const UInt_t nvars = GetNVariables();
478 const UInt_t ntgts = GetNTargets();
479
480 UInt_t nevts = events.size();
481
482 TVectorD x2( nvars+ntgts ); x2 *= 0;
483 TVectorD x0( nvars+ntgts ); x0 *= 0;
484 TVectorD v0( nvars+ntgts ); v0 *= 0;
485
486 Double_t sumOfWeights = 0;
487 for (UInt_t ievt=0; ievt<nevts; ievt++) {
488 const Event* ev = events[ievt];
489
490 Double_t weight = ev->GetWeight();
491 sumOfWeights += weight;
492 for (UInt_t ivar=0; ivar<nvars; ivar++) {
493 Double_t x = ev->GetValue(ivar);
494 if (ievt==0) {
495 Variables().at(ivar).SetMin(x);
496 Variables().at(ivar).SetMax(x);
497 }
498 else {
499 UpdateNorm( ivar, x );
500 }
501 x0(ivar) += x*weight;
502 x2(ivar) += x*x*weight;
503 }
504 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
505 Double_t x = ev->GetTarget(itgt);
506 if (ievt==0) {
507 Targets().at(itgt).SetMin(x);
508 Targets().at(itgt).SetMax(x);
509 }
510 else {
511 UpdateNorm( nvars+itgt, x );
512 }
513 x0(nvars+itgt) += x*weight;
514 x2(nvars+itgt) += x*x*weight;
515 }
516 }
517
518 if (sumOfWeights <= 0) {
519 Log() << kFATAL << " the sum of event weights calculated for your input is == 0"
520 << " or exactly: " << sumOfWeights << " there is obviously some problem..."<< Endl;
521 }
522
523 // set Mean and RMS
524 for (UInt_t ivar=0; ivar<nvars; ivar++) {
525 Double_t mean = x0(ivar)/sumOfWeights;
526
527 Variables().at(ivar).SetMean( mean );
528 if (x2(ivar)/sumOfWeights - mean*mean < 0) {
529 Log() << kFATAL << " the RMS of your input variable " << ivar
530 << " evaluates to an imaginary number: sqrt("<< x2(ivar)/sumOfWeights - mean*mean
531 <<") .. sometimes related to a problem with outliers and negative event weights"
532 << Endl;
533 }
534 Variables().at(ivar).SetRMS( TMath::Sqrt( x2(ivar)/sumOfWeights - mean*mean) );
535 }
536 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
537 Double_t mean = x0(nvars+itgt)/sumOfWeights;
538 Targets().at(itgt).SetMean( mean );
539 if (x2(nvars+itgt)/sumOfWeights - mean*mean < 0) {
540 Log() << kFATAL << " the RMS of your target variable " << itgt
541 << " evaluates to an imaginary number: sqrt(" << x2(nvars+itgt)/sumOfWeights - mean*mean
542 <<") .. sometimes related to a problem with outliers and negative event weights"
543 << Endl;
544 }
545 Targets().at(itgt).SetRMS( TMath::Sqrt( x2(nvars+itgt)/sumOfWeights - mean*mean) );
546 }
547 // calculate variance
548 for (UInt_t ievt=0; ievt<nevts; ievt++) {
549 const Event* ev = events[ievt];
550 Double_t weight = ev->GetWeight();
551 for (UInt_t ivar=0; ivar<nvars; ivar++) {
552 Double_t x = ev->GetValue(ivar);
553 Double_t mean = Variables().at(ivar).GetMean();
554 v0(ivar) += weight*(x-mean)*(x-mean);
555 }
556 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
557 Double_t x = ev->GetTarget(itgt);
558 Double_t mean = Targets().at(itgt).GetMean();
559 v0(nvars+itgt) += weight*(x-mean)*(x-mean);
560 }
561
562 }
563
564 // set variance
565 for (UInt_t ivar=0; ivar<nvars; ivar++) {
566 Double_t variance = v0(ivar)/sumOfWeights;
567 Variables().at(ivar).SetVariance( variance );
568 Log() << kINFO << "Variable " << Variables().at(ivar).GetExpression() <<" variance = " << variance << Endl;
569 }
570 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
571 Double_t variance = v0(nvars+itgt)/sumOfWeights;
572 Targets().at(itgt).SetVariance( variance );
573 Log() << kINFO << "Target " << Targets().at(itgt).GetExpression() <<" variance = " << variance << Endl;
574 }
575
576 Log() << kVERBOSE << "Set minNorm/maxNorm for variables to: " << Endl;
577 Log() << std::setprecision(3);
578 for (UInt_t ivar=0; ivar<GetNVariables(); ivar++)
579 Log() << " " << Variables().at(ivar).GetInternalName()
580 << "\t: [" << Variables().at(ivar).GetMin() << "\t, " << Variables().at(ivar).GetMax() << "\t] " << Endl;
581 Log() << kVERBOSE << "Set minNorm/maxNorm for targets to: " << Endl;
582 Log() << std::setprecision(3);
583 for (UInt_t itgt=0; itgt<GetNTargets(); itgt++)
584 Log() << " " << Targets().at(itgt).GetInternalName()
585 << "\t: [" << Targets().at(itgt).GetMin() << "\t, " << Targets().at(itgt).GetMax() << "\t] " << Endl;
586 Log() << std::setprecision(5); // reset to better value
587}
588
589////////////////////////////////////////////////////////////////////////////////
590/// TODO --> adapt to variable,target,spectator selection
591/// default transformation output
592/// --> only indicate that transformation occurred
593
595{
596 std::vector<TString>* strVec = new std::vector<TString>;
597 for (UInt_t ivar=0; ivar<GetNVariables(); ivar++) {
598 strVec->push_back( Variables()[ivar].GetLabel() + "_[transformed]");
599 }
600
601 return strVec;
602}
603
604////////////////////////////////////////////////////////////////////////////////
605/// TODO --> adapt to variable,target,spectator selection
606/// update min and max of a given variable (target) and a given transformation method
607
609{
610 Int_t nvars = fDsi.GetNVariables();
611 if( ivar < nvars ){
612 if (x < Variables().at(ivar).GetMin()) Variables().at(ivar).SetMin(x);
613 if (x > Variables().at(ivar).GetMax()) Variables().at(ivar).SetMax(x);
614 }else{
615 if (x < Targets().at(ivar-nvars).GetMin()) Targets().at(ivar-nvars).SetMin(x);
616 if (x > Targets().at(ivar-nvars).GetMax()) Targets().at(ivar-nvars).SetMax(x);
617 }
618}
619
620////////////////////////////////////////////////////////////////////////////////
621/// create XML description the transformation (write out info of selected variables)
622
624{
625 void* selxml = gTools().AddChild(parent, "Selection");
626
627 void* inpxml = gTools().AddChild(selxml, "Input");
628 gTools().AddAttr(inpxml, "NInputs", fGet.size() );
629
630 // choose the new dsi for output if present, if not, take the common one
631 const DataSetInfo* outputDsiPtr = (fDsiOutput? fDsiOutput : &fDsi );
632
633 for( ItVarTypeIdx itGet = fGet.begin(), itGetEnd = fGet.end(); itGet != itGetEnd; ++itGet ) {
634 UInt_t idx = (*itGet).second;
635 Char_t type = (*itGet).first;
636
637 TString label = "";
638 TString expression = "";
639 TString typeString = "";
640 switch( type ){
641 case 'v':
642 typeString = "Variable";
643 label = fDsi.GetVariableInfo( idx ).GetLabel();
644 expression = fDsi.GetVariableInfo( idx ).GetExpression();
645 break;
646 case 't':
647 typeString = "Target";
648 label = fDsi.GetTargetInfo( idx ).GetLabel();
649 expression = fDsi.GetTargetInfo( idx ).GetExpression();
650 break;
651 case 's':
652 typeString = "Spectator";
653 label = fDsi.GetSpectatorInfo( idx ).GetLabel();
654 expression = fDsi.GetSpectatorInfo( idx ).GetExpression();
655 break;
656 default:
657 Log() << kFATAL << "VariableTransformBase/AttachXMLTo unknown variable type '" << type << "'." << Endl;
658 }
659
660 void* idxxml = gTools().AddChild(inpxml, "Input");
661 // gTools().AddAttr(idxxml, "Index", idx);
662 gTools().AddAttr(idxxml, "Type", typeString);
663 gTools().AddAttr(idxxml, "Label", label);
664 gTools().AddAttr(idxxml, "Expression", expression);
665 }
666
667
668 void* outxml = gTools().AddChild(selxml, "Output");
669 gTools().AddAttr(outxml, "NOutputs", fPut.size() );
670
671 for( ItVarTypeIdx itPut = fPut.begin(), itPutEnd = fPut.end(); itPut != itPutEnd; ++itPut ) {
672 UInt_t idx = (*itPut).second;
673 Char_t type = (*itPut).first;
674
675 TString label = "";
676 TString expression = "";
677 TString typeString = "";
678 switch( type ){
679 case 'v':
680 typeString = "Variable";
681 label = outputDsiPtr->GetVariableInfo( idx ).GetLabel();
682 expression = outputDsiPtr->GetVariableInfo( idx ).GetExpression();
683 break;
684 case 't':
685 typeString = "Target";
686 label = outputDsiPtr->GetTargetInfo( idx ).GetLabel();
687 expression = outputDsiPtr->GetTargetInfo( idx ).GetExpression();
688 break;
689 case 's':
690 typeString = "Spectator";
691 label = outputDsiPtr->GetSpectatorInfo( idx ).GetLabel();
692 expression = outputDsiPtr->GetSpectatorInfo( idx ).GetExpression();
693 break;
694 default:
695 Log() << kFATAL << "VariableTransformBase/AttachXMLTo unknown variable type '" << type << "'." << Endl;
696 }
697
698 void* idxxml = gTools().AddChild(outxml, "Output");
699 // gTools().AddAttr(idxxml, "Index", idx);
700 gTools().AddAttr(idxxml, "Type", typeString);
701 gTools().AddAttr(idxxml, "Label", label);
702 gTools().AddAttr(idxxml, "Expression", expression);
703 }
704
705
706}
707
708////////////////////////////////////////////////////////////////////////////////
709/// Read the input variables from the XML node
710
712{
713 void* inpnode = gTools().GetChild( selnode );
714 void* outnode = gTools().GetNextChild( inpnode );
715
716 UInt_t nvars = GetNVariables();
717 UInt_t ntgts = GetNTargets();
718 UInt_t nspcts = GetNSpectators();
719
720 // read inputs
721 fGet.clear();
722
723 UInt_t nInputs = 0;
724 gTools().ReadAttr(inpnode, "NInputs", nInputs);
725
726 void* ch = gTools().GetChild( inpnode );
727 while(ch) {
728 TString typeString = "";
729 TString label = "";
730 TString expression = "";
731
732 gTools().ReadAttr(ch, "Type", typeString);
733 gTools().ReadAttr(ch, "Label", label);
734 gTools().ReadAttr(ch, "Expression", expression);
735
736 if( typeString == "Variable" ){
737 for( UInt_t ivar = 0; ivar < nvars; ++ivar ) { // search all variables
738 if( fDsi.GetVariableInfo( ivar ).GetLabel() == label ||
739 fDsi.GetVariableInfo( ivar ).GetExpression() == expression) {
740 fGet.push_back( std::pair<Char_t,UInt_t>('v',ivar) );
741 break;
742 }
743 }
744 }else if( typeString == "Target" ){
745 for( UInt_t itgt = 0; itgt < ntgts; ++itgt ) { // search all targets
746 if( fDsi.GetTargetInfo( itgt ).GetLabel() == label ||
747 fDsi.GetTargetInfo( itgt ).GetExpression() == expression ) {
748 fGet.push_back( std::pair<Char_t,UInt_t>('t',itgt) );
749 break;
750 }
751 }
752 }else if( typeString == "Spectator" ){
753 for( UInt_t ispct = 0; ispct < nspcts; ++ispct ) { // search all spectators
754 if( fDsi.GetSpectatorInfo( ispct ).GetLabel() == label ||
755 fDsi.GetSpectatorInfo( ispct ).GetExpression() == expression ) {
756 fGet.push_back( std::pair<Char_t,UInt_t>('s',ispct) );
757 break;
758 }
759 }
760 }else{
761 Log() << kFATAL << "VariableTransformationBase/ReadFromXML : unknown type '" << typeString << "'." << Endl;
762 }
763 ch = gTools().GetNextChild( ch );
764 }
765
766 assert( nInputs == fGet.size() );
767
768 // read outputs
769 fPut.clear();
770
771 UInt_t nOutputs = 0;
772 gTools().ReadAttr(outnode, "NOutputs", nOutputs);
773
774 void* chOut = gTools().GetChild( outnode );
775 while(chOut) {
776 TString typeString = "";
777 TString label = "";
778 TString expression = "";
779
780 gTools().ReadAttr(chOut, "Type", typeString);
781 gTools().ReadAttr(chOut, "Label", label);
782 gTools().ReadAttr(chOut, "Expression", expression);
783
784 if( typeString == "Variable" ){
785 for( UInt_t ivar = 0; ivar < nvars; ++ivar ) { // search all variables
786 if( fDsi.GetVariableInfo( ivar ).GetLabel() == label ||
787 fDsi.GetVariableInfo( ivar ).GetExpression() == expression ) {
788 fPut.push_back( std::pair<Char_t,UInt_t>('v',ivar) );
789 break;
790 }
791 }
792 }else if( typeString == "Target" ){
793 for( UInt_t itgt = 0; itgt < ntgts; ++itgt ) { // search all targets
794 if( fDsi.GetTargetInfo( itgt ).GetLabel() == label ||
795 fDsi.GetTargetInfo( itgt ).GetExpression() == expression ) {
796 fPut.push_back( std::pair<Char_t,UInt_t>('t',itgt) );
797 break;
798 }
799 }
800 }else if( typeString == "Spectator" ){
801 for( UInt_t ispct = 0; ispct < nspcts; ++ispct ) { // search all spectators
802 if( fDsi.GetSpectatorInfo( ispct ).GetLabel() == label ||
803 fDsi.GetSpectatorInfo( ispct ).GetExpression() == expression ) {
804 fPut.push_back( std::pair<Char_t,UInt_t>('s',ispct) );
805 break;
806 }
807 }
808 }else{
809 Log() << kFATAL << "VariableTransformationBase/ReadFromXML : unknown type '" << typeString << "'." << Endl;
810 }
811 chOut = gTools().GetNextChild( chOut );
812 }
813
814 assert( nOutputs == fPut.size() );
815}
816
817////////////////////////////////////////////////////////////////////////////////
818/// getinput and setoutput equivalent
819
820void TMVA::VariableTransformBase::MakeFunction( std::ostream& fout, const TString& /*fncName*/, Int_t part,
821 UInt_t /*trCounter*/, Int_t /*cls*/ )
822{
823 if( part == 0 ){ // definitions
824 fout << std::endl;
825 fout << " // define the indices of the variables which are transformed by this transformation" << std::endl;
826 fout << " static std::vector<int> indicesGet;" << std::endl;
827 fout << " static std::vector<int> indicesPut;" << std::endl << std::endl;
828 fout << " if ( indicesGet.empty() ) {" << std::endl;
829 fout << " indicesGet.reserve(fNvars);" << std::endl;
830
831 for( ItVarTypeIdxConst itEntry = fGet.begin(), itEntryEnd = fGet.end(); itEntry != itEntryEnd; ++itEntry ) {
832 Char_t type = (*itEntry).first;
833 Int_t idx = (*itEntry).second;
834
835 switch( type ) {
836 case 'v':
837 fout << " indicesGet.push_back( " << idx << ");" << std::endl;
838 break;
839 case 't':
840 Log() << kWARNING << "MakeClass doesn't work with transformation of targets. The results will be wrong!" << Endl;
841 break;
842 case 's':
843 Log() << kWARNING << "MakeClass doesn't work with transformation of spectators. The results will be wrong!" << Endl;
844 break;
845 default:
846 Log() << kFATAL << "VariableTransformBase/GetInput : unknown type '" << type << "'." << Endl;
847 }
848 }
849 fout << " }" << std::endl;
850 fout << " if ( indicesPut.empty() ) {" << std::endl;
851 fout << " indicesPut.reserve(fNvars);" << std::endl;
852
853 for( ItVarTypeIdxConst itEntry = fPut.begin(), itEntryEnd = fPut.end(); itEntry != itEntryEnd; ++itEntry ) {
854 Char_t type = (*itEntry).first;
855 Int_t idx = (*itEntry).second;
856
857 switch( type ) {
858 case 'v':
859 fout << " indicesPut.push_back( " << idx << ");" << std::endl;
860 break;
861 case 't':
862 Log() << kWARNING << "MakeClass doesn't work with transformation of targets. The results will be wrong!" << Endl;
863 break;
864 case 's':
865 Log() << kWARNING << "MakeClass doesn't work with transformation of spectators. The results will be wrong!" << Endl;
866 break;
867 default:
868 Log() << kFATAL << "VariableTransformBase/PutInput : unknown type '" << type << "'." << Endl;
869 }
870 }
871
872 fout << " }" << std::endl;
873 fout << std::endl;
874
875 }else if( part == 1){
876 }
877}
static const double x2[5]
char Char_t
Definition: RtypesCore.h:31
const Bool_t kFALSE
Definition: RtypesCore.h:90
double Double_t
Definition: RtypesCore.h:57
float Float_t
Definition: RtypesCore.h:55
const Bool_t kTRUE
Definition: RtypesCore.h:89
#define ClassImp(name)
Definition: Rtypes.h:361
int type
Definition: TGX11.cxx:120
bool advanced
#define TMVA_VERSION_CODE
Definition: Version.h:47
Iterator of linked list.
Definition: TList.h:200
A doubly linked list.
Definition: TList.h:44
Class that contains all the data information.
Definition: DataSetInfo.h:60
UInt_t GetNVariables() const
Definition: DataSetInfo.h:125
UInt_t GetNSpectators(bool all=kTRUE) const
UInt_t GetNTargets() const
Definition: DataSetInfo.h:126
VariableInfo & GetVariableInfo(Int_t i)
Definition: DataSetInfo.h:103
VariableInfo & GetTargetInfo(Int_t i)
Definition: DataSetInfo.h:117
VariableInfo & GetSpectatorInfo(Int_t i)
Definition: DataSetInfo.h:122
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition: Event.cxx:236
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Definition: Event.cxx:381
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition: Event.cxx:261
Float_t GetTarget(UInt_t itgt) const
Definition: Event.h:102
ostringstream derivative to redirect and format output
Definition: MsgLogger.h:59
TList * ParseFormatLine(TString theString, const char *sep=":")
Parse the string and cut into labels separated by ":".
Definition: Tools.cxx:412
void * GetNextChild(void *prevchild, const char *childname=0)
XML helpers.
Definition: Tools.cxx:1173
void * AddChild(void *parent, const char *childname, const char *content=0, bool isRootNode=false)
add child node
Definition: Tools.cxx:1135
void * GetChild(void *parent, const char *childname=0)
get child node
Definition: Tools.cxx:1161
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition: Tools.h:335
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition: Tools.h:353
EVariableTransform
Definition: Types.h:115
Class for type info of MVA input variable.
Definition: VariableInfo.h:47
const TString & GetLabel() const
Definition: VariableInfo.h:59
const TString & GetExpression() const
Definition: VariableInfo.h:57
Base class for the TMVA variable transformations.
virtual void MakeFunction(std::ostream &fout, const TString &fncName, Int_t part, UInt_t trCounter, Int_t cls)=0
getinput and setoutput equivalent
virtual Bool_t GetInput(const Event *event, std::vector< Float_t > &input, std::vector< Char_t > &mask, Bool_t backTransform=kFALSE) const
select the values from the event
void CalcNorm(const std::vector< const Event * > &)
TODO --> adapt to variable,target,spectator selection method to calculate minimum,...
virtual void ReadFromXML(void *trfnode)=0
Read the input variables from the XML node.
virtual void AttachXMLTo(void *parent)=0
create XML description the transformation (write out info of selected variables)
virtual void SetOutput(Event *event, std::vector< Float_t > &output, std::vector< Char_t > &mask, const Event *oldEvent=0, Bool_t backTransform=kFALSE) const
select the values from the event
std::vector< TMVA::VariableInfo > fVariables
VariableTransformBase(DataSetInfo &dsi, Types::EVariableTransform tf, const TString &trfName)
standard constructor
void UpdateNorm(Int_t ivar, Double_t x)
TODO --> adapt to variable,target,spectator selection update min and max of a given variable (target)...
virtual void CountVariableTypes(UInt_t &nvars, UInt_t &ntgts, UInt_t &nspcts) const
count variables, targets and spectators
virtual std::vector< TString > * GetTransformationStrings(Int_t cls) const
TODO --> adapt to variable,target,spectator selection default transformation output --> only indicate...
virtual void SelectInput(const TString &inputVariables, Bool_t putIntoVariables=kFALSE)
select the variables/targets/spectators which serve as input to the transformation
VectorOfCharAndInt::iterator ItVarTypeIdx
std::vector< TMVA::VariableInfo > fSpectators
std::vector< TMVA::VariableInfo > fTargets
VectorOfCharAndInt::const_iterator ItVarTypeIdxConst
Collectable string class.
Definition: TObjString.h:28
Mother of all ROOT objects.
Definition: TObject.h:37
Basic string class.
Definition: TString.h:131
const char * Data() const
Definition: TString.h:364
Double_t x[n]
Definition: legend1.C:17
bool BeginsWith(const std::string &theString, const std::string &theSubstring)
Tools & gTools()
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
MsgLogger & Endl(MsgLogger &ml)
Definition: MsgLogger.h:158
Double_t Log(Double_t x)
Definition: TMath.h:750
Double_t Sqrt(Double_t x)
Definition: TMath.h:681
static void output(int code)
Definition: gifencode.c:226