Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
VariableTransformBase.cxx
Go to the documentation of this file.
1// @(#)root/tmva $Id$
2// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : VariableTransformBase *
8 * *
9 * *
10 * Description: *
11 * Implementation (see header for description) *
12 * *
13 * Authors (alphabetical): *
14 * Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland *
15 * Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland *
16 * Joerg Stelzer <Joerg.Stelzer@cern.ch> - CERN, Switzerland *
17 * Helge Voss <Helge.Voss@cern.ch> - MPI-K Heidelberg, Germany *
18 * *
19 * Copyright (c) 2005: *
20 * CERN, Switzerland *
21 * MPI-K Heidelberg, Germany *
22 * *
23 * Redistribution and use in source and binary forms, with or without *
24 * modification, are permitted according to the terms listed in LICENSE *
25 * (see tmva/doc/LICENSE) *
26 **********************************************************************************/
27
28/*! \class TMVA::VariableTransformBase
29\ingroup TMVA
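30Base class for TMVA variable transformations: selects the input variables/targets/spectators, reads their values from an event and writes the transformed values back.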
31*/
32
34
35#include "TMVA/Config.h"
36#include "TMVA/DataSetInfo.h"
37#include "TMVA/MsgLogger.h"
38#include "TMVA/Ranking.h"
39#include "TMVA/Tools.h"
40#include "TMVA/Types.h"
41#include "TMVA/VariableInfo.h"
42#include "TMVA/Version.h"
43
44#include "THashTable.h"
45#include "TList.h"
46#include "TObjString.h"
47#include "TMath.h"
48#include "TVectorD.h"
49
50#include <algorithm>
51#include <cassert>
52#include <exception>
53#include <iomanip>
54#include <stdexcept>
55#include <set>
56
58
60
61////////////////////////////////////////////////////////////////////////////////
62/// standard constructor
63
66 const TString& trfName )
67: TObject(),
68 fDsi(dsi),
69 fDsiOutput(NULL),
70 fTransformedEvent(0),
71 fBackTransformedEvent(0),
72 fVariableTransform(tf),
73 fEnabled( kTRUE ),
74 fCreated( kFALSE ),
75 fNormalise( kFALSE ),
76 fTransformName(trfName),
77 fVariableTypesAreCounted(false),
78 fNVariables(0),
79 fNTargets(0),
80 fNSpectators(0),
81 fSortGet(kTRUE),
82 fTMVAVersion(TMVA_VERSION_CODE),
83 fLogger( 0 )
84{
85 fLogger = new MsgLogger(this, kINFO);
86 for (UInt_t ivar = 0; ivar < fDsi.GetNVariables(); ivar++) {
87 fVariables.push_back( VariableInfo( fDsi.GetVariableInfo(ivar) ) );
88 }
89 for (UInt_t itgt = 0; itgt < fDsi.GetNTargets(); itgt++) {
90 fTargets.push_back( VariableInfo( fDsi.GetTargetInfo(itgt) ) );
91 }
92 for (UInt_t ispct = 0; ispct < fDsi.GetNSpectators(); ispct++) {
93 fSpectators.push_back( VariableInfo( fDsi.GetSpectatorInfo(ispct) ) );
94 }
95}
96
97////////////////////////////////////////////////////////////////////////////////
98
100{
101 if (fTransformedEvent!=0) delete fTransformedEvent;
102 if (fBackTransformedEvent!=0) delete fBackTransformedEvent;
103 // destructor
104 delete fLogger;
105}
106
107////////////////////////////////////////////////////////////////////////////////
108/// select the variables/targets/spectators which serve as input to the transformation
109
110void TMVA::VariableTransformBase::SelectInput( const TString& _inputVariables, Bool_t putIntoVariables )
111{
112 TString inputVariables = _inputVariables;
113
114 // unselect all variables first
115 fGet.clear();
116
117 UInt_t nvars = GetNVariables();
118 UInt_t ntgts = GetNTargets();
119 UInt_t nspcts = GetNSpectators();
120
121 typedef std::set<Int_t> SelectedIndices;
122
123 SelectedIndices varIndices;
124 SelectedIndices tgtIndices;
125 SelectedIndices spctIndices;
126
127 if (inputVariables == "") // default is all variables and all targets
128 { // (the default can be changed by decorating this member function in the implementations)
129 inputVariables = "_V_,_T_";
130 }
131
132 TList* inList = gTools().ParseFormatLine( inputVariables, "," );
133 TListIter inIt(inList);
134 while (TObjString* os = (TObjString*)inIt()) {
135
136 TString variables = os->GetString();
137
138 if( variables.BeginsWith("_") && variables.EndsWith("_") ) { // special symbol (keyword)
139 variables.Remove( 0,1); // remove first "_"
140 variables.Remove( variables.Length()-1,1 ); // remove last "_"
141
142 if( variables.BeginsWith("V") ) { // variables
143 variables.Remove(0,1); // remove "V"
144 if( variables.Length() == 0 ){
145 for( UInt_t ivar = 0; ivar < nvars; ++ivar ) {
146 fGet.push_back( std::pair<Char_t,UInt_t>('v',ivar) );
147 varIndices.insert( ivar );
148 }
149 } else {
150 UInt_t idx = variables.Atoi();
151 if( idx >= nvars )
152 Log() << kFATAL << "You selected variable with index : " << idx << " of only " << nvars << " variables." << Endl;
153 fGet.push_back( std::pair<Char_t,UInt_t>('v',idx) );
154 varIndices.insert( idx );
155 }
156 }else if( variables.BeginsWith("T") ) { // targets
157 variables.Remove(0,1); // remove "T"
158 if( variables.Length() == 0 ){
159 for( UInt_t itgt = 0; itgt < ntgts; ++itgt ) {
160 fGet.push_back( std::pair<Char_t,UInt_t>('t',itgt) );
161 tgtIndices.insert( itgt );
162 }
163 } else {
164 UInt_t idx = variables.Atoi();
165 if( idx >= ntgts )
166 Log() << kFATAL << "You selected target with index : " << idx << " of only " << ntgts << " targets." << Endl;
167 fGet.push_back( std::pair<Char_t,UInt_t>('t',idx) );
168 tgtIndices.insert( idx );
169 }
170 }else if( variables.BeginsWith("S") ) { // spectators
171 variables.Remove(0,1); // remove "S"
172 if( variables.Length() == 0 ){
173 for( UInt_t ispct = 0; ispct < nspcts; ++ispct ) {
174 fGet.push_back( std::pair<Char_t,UInt_t>('s',ispct) );
175 spctIndices.insert( ispct );
176 }
177 } else {
178 UInt_t idx = variables.Atoi();
179 if( idx >= nspcts )
180 Log() << kFATAL << "You selected spectator with index : " << idx << " of only " << nspcts << " spectators." << Endl;
181 fGet.push_back( std::pair<Char_t,UInt_t>('s',idx) );
182 spctIndices.insert( idx );
183 }
184 }else if( TString("REARRANGE").BeginsWith(variables) ) { // toggle rearrange sorting (take sort order given in the options)
185 ToggleInputSortOrder( kFALSE );
186 if( !fSortGet )
187 Log() << kINFO << "Variable rearrangement set true: Variable order given in transformation option is used for input to transformation!" << Endl;
188
189 }
190 }else{ // no keyword, ... user provided variable labels
191 Int_t numIndices = varIndices.size()+tgtIndices.size()+spctIndices.size();
192 for( UInt_t ivar = 0; ivar < nvars; ++ivar ) { // search all variables
193 if( fDsi.GetVariableInfo( ivar ).GetLabel() == variables ) {
194 fGet.push_back( std::pair<Char_t,UInt_t>('v',ivar) );
195 varIndices.insert( ivar );
196 break;
197 }
198 }
199 for( UInt_t itgt = 0; itgt < ntgts; ++itgt ) { // search all targets
200 if( fDsi.GetTargetInfo( itgt ).GetLabel() == variables ) {
201 fGet.push_back( std::pair<Char_t,UInt_t>('t',itgt) );
202 tgtIndices.insert( itgt );
203 break;
204 }
205 }
206 for( UInt_t ispct = 0; ispct < nspcts; ++ispct ) { // search all spectators
207 if( fDsi.GetSpectatorInfo( ispct ).GetLabel() == variables ) {
208 fGet.push_back( std::pair<Char_t,UInt_t>('s',ispct) );
209 spctIndices.insert( ispct );
210 break;
211 }
212 }
213 Int_t numIndicesEndOfLoop = varIndices.size()+tgtIndices.size()+spctIndices.size();
214 if( numIndicesEndOfLoop == numIndices )
215 Log() << kWARNING << "Error at parsing the options for the variable transformations: Variable/Target/Spectator '" << variables.Data() << "' not found." << Endl;
216 numIndices = numIndicesEndOfLoop;
217 }
218 }
219
220
221 if( putIntoVariables ) {
222 Int_t idx = 0;
223 for( SelectedIndices::iterator it = varIndices.begin(), itEnd = varIndices.end(); it != itEnd; ++it ) {
224 fPut.push_back( std::pair<Char_t,UInt_t>('v',idx) );
225 ++idx;
226 }
227 for( SelectedIndices::iterator it = tgtIndices.begin(), itEnd = tgtIndices.end(); it != itEnd; ++it ) {
228 fPut.push_back( std::pair<Char_t,UInt_t>('t',idx) );
229 ++idx;
230 }
231 for( SelectedIndices::iterator it = spctIndices.begin(), itEnd = spctIndices.end(); it != itEnd; ++it ) {
232 fPut.push_back( std::pair<Char_t,UInt_t>('s',idx) );
233 ++idx;
234 }
235 }else {
236 for( SelectedIndices::iterator it = varIndices.begin(), itEnd = varIndices.end(); it != itEnd; ++it ) {
237 Int_t idx = (*it);
238 fPut.push_back( std::pair<Char_t,UInt_t>('v',idx) );
239 }
240 for( SelectedIndices::iterator it = tgtIndices.begin(), itEnd = tgtIndices.end(); it != itEnd; ++it ) {
241 Int_t idx = (*it);
242 fPut.push_back( std::pair<Char_t,UInt_t>('t',idx) );
243 }
244 for( SelectedIndices::iterator it = spctIndices.begin(), itEnd = spctIndices.end(); it != itEnd; ++it ) {
245 Int_t idx = (*it);
246 fPut.push_back( std::pair<Char_t,UInt_t>('s',idx) );
247 }
248
249 // if sorting is turned on, fGet should have the indices sorted as fPut has them.
250 if( fSortGet ) {
251 fGet.clear();
252 fGet.assign( fPut.begin(), fPut.end() );
253 }
254 }
255
256 Log() << kHEADER << "Transformation, Variable selection : " << Endl;
257
258 // choose the new dsi for output if present, if not, take the common one
259 const DataSetInfo* outputDsiPtr = (fDsiOutput? &(*fDsiOutput) : &fDsi );
260
261
262
263 ItVarTypeIdx itGet = fGet.begin(), itGetEnd = fGet.end();
264 ItVarTypeIdx itPut = fPut.begin(); // , itPutEnd = fPut.end();
265 for( ; itGet != itGetEnd; ++itGet ) {
266 TString inputTypeString = "?";
267
268 Char_t inputType = (*itGet).first;
269 Int_t inputIdx = (*itGet).second;
270
271 TString inputLabel = "NOT FOND";
272 if( inputType == 'v' ) {
273 inputLabel = fDsi.GetVariableInfo( inputIdx ).GetLabel();
274 inputTypeString = "variable";
275 }
276 else if( inputType == 't' ){
277 inputLabel = fDsi.GetTargetInfo( inputIdx ).GetLabel();
278 inputTypeString = "target";
279 }
280 else if( inputType == 's' ){
281 inputLabel = fDsi.GetSpectatorInfo( inputIdx ).GetLabel();
282 inputTypeString = "spectator";
283 }
284
285 TString outputTypeString = "?";
286
287 Char_t outputType = (*itPut).first;
288 Int_t outputIdx = (*itPut).second;
289
290 TString outputLabel = "NOT FOUND";
291 if( outputType == 'v' ) {
292 outputLabel = outputDsiPtr->GetVariableInfo( outputIdx ).GetLabel();
293 outputTypeString = "variable";
294 }
295 else if( outputType == 't' ){
296 outputLabel = outputDsiPtr->GetTargetInfo( outputIdx ).GetLabel();
297 outputTypeString = "target";
298 }
299 else if( outputType == 's' ){
300 outputLabel = outputDsiPtr->GetSpectatorInfo( outputIdx ).GetLabel();
301 outputTypeString = "spectator";
302 }
303 Log() << kINFO << "Input : " << inputTypeString.Data() << " '" << inputLabel.Data() << "'" << " <---> " << "Output : " << outputTypeString.Data() << " '" << outputLabel.Data() << "'" << Endl;
304 Log() << kDEBUG << "\t(index=" << inputIdx << ")." << "\t(index=" << outputIdx << ")." << Endl;
305
306 ++itPut;
307 }
308 // Log() << kINFO << Endl;
309}
310
311
312////////////////////////////////////////////////////////////////////////////////
313/// select the values from the event
314
315Bool_t TMVA::VariableTransformBase::GetInput( const Event* event, std::vector<Float_t>& input, std::vector<Char_t>& mask, Bool_t backTransformation ) const
316{
317 ItVarTypeIdxConst itEntry;
318 ItVarTypeIdxConst itEntryEnd;
319
320 input.clear();
321 mask.clear();
322
323 if( backTransformation && !fPut.empty() ){
324 itEntry = fPut.begin();
325 itEntryEnd = fPut.end();
326 input.reserve(fPut.size());
327 }
328 else {
329 itEntry = fGet.begin();
330 itEntryEnd = fGet.end();
331 input.reserve(fGet.size() );
332 }
333
334 Bool_t hasMaskedEntries = kFALSE;
335 // event->Print(std::cout);
336 for( ; itEntry != itEntryEnd; ++itEntry ) {
337 Char_t type = (*itEntry).first;
338 Int_t idx = (*itEntry).second;
339
340 try{
341 switch( type ) {
342 case 'v':
343 input.push_back( event->GetValue(idx) );
344 break;
345 case 't':
346 input.push_back( event->GetTarget(idx) );
347 break;
348 case 's':
349 input.push_back( event->GetSpectator(idx) );
350 break;
351 default:
352 Log() << kFATAL << "VariableTransformBase/GetInput : unknown type '" << type << "'." << Endl;
353 }
354 mask.push_back(kFALSE);
355 }
356 catch(std::out_of_range& /* excpt */ ){ // happens when an event is transformed which does not yet have the targets calculated (in the application phase)
357 input.push_back(0.f);
358 mask.push_back(kTRUE);
359 hasMaskedEntries = kTRUE;
360 }
361 }
362 return hasMaskedEntries;
363}
364
365////////////////////////////////////////////////////////////////////////////////
366/// select the values from the event
367
368void TMVA::VariableTransformBase::SetOutput( Event* event, std::vector<Float_t>& output, std::vector<Char_t>& mask, const Event* oldEvent, Bool_t backTransformation ) const
369{
370 std::vector<Float_t>::iterator itOutput = output.begin();
371 std::vector<Char_t>::iterator itMask = mask.begin();
372
373 if( oldEvent )
374 event->CopyVarValues( *oldEvent );
375
376 try {
377
378 ItVarTypeIdxConst itEntry;
379 ItVarTypeIdxConst itEntryEnd;
380
381 if( backTransformation || fPut.empty() ){ // as in GetInput, but the other way round (from fPut for transformation, from fGet for backTransformation)
382 itEntry = fGet.begin();
383 itEntryEnd = fGet.end();
384 }
385 else {
386 itEntry = fPut.begin();
387 itEntryEnd = fPut.end();
388 }
389
390
391 for( ; itEntry != itEntryEnd; ++itEntry ) {
392
393 if( (*itMask) ){ // if the value is masked
394 continue;
395 }
396
397 Char_t type = (*itEntry).first;
398 Int_t idx = (*itEntry).second;
399 if (itOutput == output.end()) Log() << kFATAL << "Read beyond array boundaries in VariableTransformBase::SetOutput"<<Endl;
400 Float_t value = (*itOutput);
401
402 switch( type ) {
403 case 'v':
404 event->SetVal( idx, value );
405 break;
406 case 't':
407 event->SetTarget( idx, value );
408 break;
409 case 's':
410 event->SetSpectator( idx, value );
411 break;
412 default:
413 Log() << kFATAL << "VariableTransformBase/GetInput : unknown type '" << type << "'." << Endl;
414 }
415 if( !(*itMask) ) ++itOutput;
416 ++itMask;
417
418 }
419 }catch( std::exception& except ){
420 Log() << kFATAL << "VariableTransformBase/SetOutput : exception/" << except.what() << Endl;
421 throw;
422 }
423}
424
425
426////////////////////////////////////////////////////////////////////////////////
427/// count variables, targets and spectators
428
430{
431 if( fVariableTypesAreCounted ){
432 nvars = fNVariables;
433 ntgts = fNTargets;
434 nspcts = fNSpectators;
435 return;
436 }
437
438 nvars = ntgts = nspcts = 0;
439
440 for( ItVarTypeIdxConst itEntry = fGet.begin(), itEntryEnd = fGet.end(); itEntry != itEntryEnd; ++itEntry ) {
441 Char_t type = (*itEntry).first;
442
443 switch( type ) {
444 case 'v':
445 nvars++;
446 break;
447 case 't':
448 ntgts++;
449 break;
450 case 's':
451 nspcts++;
452 break;
453 default:
454 Log() << kFATAL << "VariableTransformBase/GetVariableTypeNumbers : unknown type '" << type << "'." << Endl;
455 }
456 }
457
458 fNVariables = nvars;
459 fNTargets = ntgts;
460 fNSpectators = nspcts;
461
462 fVariableTypesAreCounted = true;
463}
464
465////////////////////////////////////////////////////////////////////////////////
466/// TODO --> adapt to variable,target,spectator selection
467/// method to calculate minimum, maximum, mean, and RMS for all
468/// variables used in the MVA
469
470void TMVA::VariableTransformBase::CalcNorm( const std::vector<const Event*>& events )
471{
472 if (!IsCreated()) return;
473
474 const UInt_t nvars = GetNVariables();
475 const UInt_t ntgts = GetNTargets();
476
477 UInt_t nevts = events.size();
478
479 TVectorD x2( nvars+ntgts ); x2 *= 0;
480 TVectorD x0( nvars+ntgts ); x0 *= 0;
481 TVectorD v0( nvars+ntgts ); v0 *= 0;
482
483 Double_t sumOfWeights = 0;
484 for (UInt_t ievt=0; ievt<nevts; ievt++) {
485 const Event* ev = events[ievt];
486
487 Double_t weight = ev->GetWeight();
488 sumOfWeights += weight;
489 for (UInt_t ivar=0; ivar<nvars; ivar++) {
490 Double_t x = ev->GetValue(ivar);
491 if (ievt==0) {
492 Variables().at(ivar).SetMin(x);
493 Variables().at(ivar).SetMax(x);
494 }
495 else {
496 UpdateNorm( ivar, x );
497 }
498 x0(ivar) += x*weight;
499 x2(ivar) += x*x*weight;
500 }
501 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
502 Double_t x = ev->GetTarget(itgt);
503 if (ievt==0) {
504 Targets().at(itgt).SetMin(x);
505 Targets().at(itgt).SetMax(x);
506 }
507 else {
508 UpdateNorm( nvars+itgt, x );
509 }
510 x0(nvars+itgt) += x*weight;
511 x2(nvars+itgt) += x*x*weight;
512 }
513 }
514
515 if (sumOfWeights <= 0) {
516 Log() << kFATAL << " the sum of event weights calculated for your input is == 0"
517 << " or exactly: " << sumOfWeights << " there is obviously some problem..."<< Endl;
518 }
519
520 // set Mean and RMS
521 for (UInt_t ivar=0; ivar<nvars; ivar++) {
522 Double_t mean = x0(ivar)/sumOfWeights;
523
524 Variables().at(ivar).SetMean( mean );
525 if (x2(ivar)/sumOfWeights - mean*mean < 0) {
526 Log() << kFATAL << " the RMS of your input variable " << ivar
527 << " evaluates to an imaginary number: sqrt("<< x2(ivar)/sumOfWeights - mean*mean
528 <<") .. sometimes related to a problem with outliers and negative event weights"
529 << Endl;
530 }
531 Variables().at(ivar).SetRMS( TMath::Sqrt( x2(ivar)/sumOfWeights - mean*mean) );
532 }
533 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
534 Double_t mean = x0(nvars+itgt)/sumOfWeights;
535 Targets().at(itgt).SetMean( mean );
536 if (x2(nvars+itgt)/sumOfWeights - mean*mean < 0) {
537 Log() << kFATAL << " the RMS of your target variable " << itgt
538 << " evaluates to an imaginary number: sqrt(" << x2(nvars+itgt)/sumOfWeights - mean*mean
539 <<") .. sometimes related to a problem with outliers and negative event weights"
540 << Endl;
541 }
542 Targets().at(itgt).SetRMS( TMath::Sqrt( x2(nvars+itgt)/sumOfWeights - mean*mean) );
543 }
544 // calculate variance
545 for (UInt_t ievt=0; ievt<nevts; ievt++) {
546 const Event* ev = events[ievt];
547 Double_t weight = ev->GetWeight();
548 for (UInt_t ivar=0; ivar<nvars; ivar++) {
549 Double_t x = ev->GetValue(ivar);
550 Double_t mean = Variables().at(ivar).GetMean();
551 v0(ivar) += weight*(x-mean)*(x-mean);
552 }
553 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
554 Double_t x = ev->GetTarget(itgt);
555 Double_t mean = Targets().at(itgt).GetMean();
556 v0(nvars+itgt) += weight*(x-mean)*(x-mean);
557 }
558
559 }
560
561 // set variance
562 for (UInt_t ivar=0; ivar<nvars; ivar++) {
563 Double_t variance = v0(ivar)/sumOfWeights;
564 Variables().at(ivar).SetVariance( variance );
565 Log() << kINFO << "Variable " << Variables().at(ivar).GetExpression() <<" variance = " << variance << Endl;
566 }
567 for (UInt_t itgt=0; itgt<ntgts; itgt++) {
568 Double_t variance = v0(nvars+itgt)/sumOfWeights;
569 Targets().at(itgt).SetVariance( variance );
570 Log() << kINFO << "Target " << Targets().at(itgt).GetExpression() <<" variance = " << variance << Endl;
571 }
572
573 Log() << kVERBOSE << "Set minNorm/maxNorm for variables to: " << Endl;
574 Log() << std::setprecision(3);
575 for (UInt_t ivar=0; ivar<GetNVariables(); ivar++)
576 Log() << " " << Variables().at(ivar).GetInternalName()
577 << "\t: [" << Variables().at(ivar).GetMin() << "\t, " << Variables().at(ivar).GetMax() << "\t] " << Endl;
578 Log() << kVERBOSE << "Set minNorm/maxNorm for targets to: " << Endl;
579 Log() << std::setprecision(3);
580 for (UInt_t itgt=0; itgt<GetNTargets(); itgt++)
581 Log() << " " << Targets().at(itgt).GetInternalName()
582 << "\t: [" << Targets().at(itgt).GetMin() << "\t, " << Targets().at(itgt).GetMax() << "\t] " << Endl;
583 Log() << std::setprecision(5); // reset to better value
584}
585
586////////////////////////////////////////////////////////////////////////////////
587/// TODO --> adapt to variable,target,spectator selection
588/// default transformation output
589/// --> only indicate that transformation occurred
590
592{
593 std::vector<TString>* strVec = new std::vector<TString>;
594 for (UInt_t ivar=0; ivar<GetNVariables(); ivar++) {
595 strVec->push_back( Variables()[ivar].GetLabel() + "_[transformed]");
596 }
597
598 return strVec;
599}
600
601////////////////////////////////////////////////////////////////////////////////
602/// TODO --> adapt to variable,target,spectator selection
603/// update min and max of a given variable (target) and a given transformation method
604
606{
607 Int_t nvars = fDsi.GetNVariables();
608 if( ivar < nvars ){
609 if (x < Variables().at(ivar).GetMin()) Variables().at(ivar).SetMin(x);
610 if (x > Variables().at(ivar).GetMax()) Variables().at(ivar).SetMax(x);
611 }else{
612 if (x < Targets().at(ivar-nvars).GetMin()) Targets().at(ivar-nvars).SetMin(x);
613 if (x > Targets().at(ivar-nvars).GetMax()) Targets().at(ivar-nvars).SetMax(x);
614 }
615}
616
617////////////////////////////////////////////////////////////////////////////////
618/// create XML description of the transformation (write out info of selected variables)
619
621{
622 void* selxml = gTools().AddChild(parent, "Selection");
623
624 void* inpxml = gTools().AddChild(selxml, "Input");
625 gTools().AddAttr(inpxml, "NInputs", fGet.size() );
626
627 // choose the new dsi for output if present, if not, take the common one
628 const DataSetInfo* outputDsiPtr = (fDsiOutput? fDsiOutput : &fDsi );
629
630 for( ItVarTypeIdx itGet = fGet.begin(), itGetEnd = fGet.end(); itGet != itGetEnd; ++itGet ) {
631 UInt_t idx = (*itGet).second;
632 Char_t type = (*itGet).first;
633
634 TString label = "";
635 TString expression = "";
636 TString typeString = "";
637 switch( type ){
638 case 'v':
639 typeString = "Variable";
640 label = fDsi.GetVariableInfo( idx ).GetLabel();
641 expression = fDsi.GetVariableInfo( idx ).GetExpression();
642 break;
643 case 't':
644 typeString = "Target";
645 label = fDsi.GetTargetInfo( idx ).GetLabel();
646 expression = fDsi.GetTargetInfo( idx ).GetExpression();
647 break;
648 case 's':
649 typeString = "Spectator";
650 label = fDsi.GetSpectatorInfo( idx ).GetLabel();
651 expression = fDsi.GetSpectatorInfo( idx ).GetExpression();
652 break;
653 default:
654 Log() << kFATAL << "VariableTransformBase/AttachXMLTo unknown variable type '" << type << "'." << Endl;
655 }
656
657 void* idxxml = gTools().AddChild(inpxml, "Input");
658 // gTools().AddAttr(idxxml, "Index", idx);
659 gTools().AddAttr(idxxml, "Type", typeString);
660 gTools().AddAttr(idxxml, "Label", label);
661 gTools().AddAttr(idxxml, "Expression", expression);
662 }
663
664
665 void* outxml = gTools().AddChild(selxml, "Output");
666 gTools().AddAttr(outxml, "NOutputs", fPut.size() );
667
668 for( ItVarTypeIdx itPut = fPut.begin(), itPutEnd = fPut.end(); itPut != itPutEnd; ++itPut ) {
669 UInt_t idx = (*itPut).second;
670 Char_t type = (*itPut).first;
671
672 TString label = "";
673 TString expression = "";
674 TString typeString = "";
675 switch( type ){
676 case 'v':
677 typeString = "Variable";
678 label = outputDsiPtr->GetVariableInfo( idx ).GetLabel();
679 expression = outputDsiPtr->GetVariableInfo( idx ).GetExpression();
680 break;
681 case 't':
682 typeString = "Target";
683 label = outputDsiPtr->GetTargetInfo( idx ).GetLabel();
684 expression = outputDsiPtr->GetTargetInfo( idx ).GetExpression();
685 break;
686 case 's':
687 typeString = "Spectator";
688 label = outputDsiPtr->GetSpectatorInfo( idx ).GetLabel();
689 expression = outputDsiPtr->GetSpectatorInfo( idx ).GetExpression();
690 break;
691 default:
692 Log() << kFATAL << "VariableTransformBase/AttachXMLTo unknown variable type '" << type << "'." << Endl;
693 }
694
695 void* idxxml = gTools().AddChild(outxml, "Output");
696 // gTools().AddAttr(idxxml, "Index", idx);
697 gTools().AddAttr(idxxml, "Type", typeString);
698 gTools().AddAttr(idxxml, "Label", label);
699 gTools().AddAttr(idxxml, "Expression", expression);
700 }
701
702
703}
704
705////////////////////////////////////////////////////////////////////////////////
706/// Read the input variables from the XML node
707
709{
710 void* inpnode = gTools().GetChild( selnode );
711 void* outnode = gTools().GetNextChild( inpnode );
712
713 UInt_t nvars = GetNVariables();
714 UInt_t ntgts = GetNTargets();
715 UInt_t nspcts = GetNSpectators();
716
717 // read inputs
718 fGet.clear();
719
720 UInt_t nInputs = 0;
721 gTools().ReadAttr(inpnode, "NInputs", nInputs);
722
723 void* ch = gTools().GetChild( inpnode );
724 while(ch) {
725 TString typeString = "";
726 TString label = "";
727 TString expression = "";
728
729 gTools().ReadAttr(ch, "Type", typeString);
730 gTools().ReadAttr(ch, "Label", label);
731 gTools().ReadAttr(ch, "Expression", expression);
732
733 if( typeString == "Variable" ){
734 for( UInt_t ivar = 0; ivar < nvars; ++ivar ) { // search all variables
735 if( fDsi.GetVariableInfo( ivar ).GetLabel() == label ||
736 fDsi.GetVariableInfo( ivar ).GetExpression() == expression) {
737 fGet.push_back( std::pair<Char_t,UInt_t>('v',ivar) );
738 break;
739 }
740 }
741 }else if( typeString == "Target" ){
742 for( UInt_t itgt = 0; itgt < ntgts; ++itgt ) { // search all targets
743 if( fDsi.GetTargetInfo( itgt ).GetLabel() == label ||
744 fDsi.GetTargetInfo( itgt ).GetExpression() == expression ) {
745 fGet.push_back( std::pair<Char_t,UInt_t>('t',itgt) );
746 break;
747 }
748 }
749 }else if( typeString == "Spectator" ){
750 for( UInt_t ispct = 0; ispct < nspcts; ++ispct ) { // search all spectators
751 if( fDsi.GetSpectatorInfo( ispct ).GetLabel() == label ||
752 fDsi.GetSpectatorInfo( ispct ).GetExpression() == expression ) {
753 fGet.push_back( std::pair<Char_t,UInt_t>('s',ispct) );
754 break;
755 }
756 }
757 }else{
758 Log() << kFATAL << "VariableTransformationBase/ReadFromXML : unknown type '" << typeString << "'." << Endl;
759 }
760 ch = gTools().GetNextChild( ch );
761 }
762
763 assert( nInputs == fGet.size() );
764
765 // read outputs
766 fPut.clear();
767
768 UInt_t nOutputs = 0;
769 gTools().ReadAttr(outnode, "NOutputs", nOutputs);
770
771 void* chOut = gTools().GetChild( outnode );
772 while(chOut) {
773 TString typeString = "";
774 TString label = "";
775 TString expression = "";
776
777 gTools().ReadAttr(chOut, "Type", typeString);
778 gTools().ReadAttr(chOut, "Label", label);
779 gTools().ReadAttr(chOut, "Expression", expression);
780
781 if( typeString == "Variable" ){
782 for( UInt_t ivar = 0; ivar < nvars; ++ivar ) { // search all variables
783 if( fDsi.GetVariableInfo( ivar ).GetLabel() == label ||
784 fDsi.GetVariableInfo( ivar ).GetExpression() == expression ) {
785 fPut.push_back( std::pair<Char_t,UInt_t>('v',ivar) );
786 break;
787 }
788 }
789 }else if( typeString == "Target" ){
790 for( UInt_t itgt = 0; itgt < ntgts; ++itgt ) { // search all targets
791 if( fDsi.GetTargetInfo( itgt ).GetLabel() == label ||
792 fDsi.GetTargetInfo( itgt ).GetExpression() == expression ) {
793 fPut.push_back( std::pair<Char_t,UInt_t>('t',itgt) );
794 break;
795 }
796 }
797 }else if( typeString == "Spectator" ){
798 for( UInt_t ispct = 0; ispct < nspcts; ++ispct ) { // search all spectators
799 if( fDsi.GetSpectatorInfo( ispct ).GetLabel() == label ||
800 fDsi.GetSpectatorInfo( ispct ).GetExpression() == expression ) {
801 fPut.push_back( std::pair<Char_t,UInt_t>('s',ispct) );
802 break;
803 }
804 }
805 }else{
806 Log() << kFATAL << "VariableTransformationBase/ReadFromXML : unknown type '" << typeString << "'." << Endl;
807 }
808 chOut = gTools().GetNextChild( chOut );
809 }
810
811 assert( nOutputs == fPut.size() );
812}
813
814////////////////////////////////////////////////////////////////////////////////
815/// getinput and setoutput equivalent
816
817void TMVA::VariableTransformBase::MakeFunction( std::ostream& fout, const TString& /*fncName*/, Int_t part,
818 UInt_t /*trCounter*/, Int_t /*cls*/ )
819{
820 if( part == 0 ){ // definitions
821 fout << std::endl;
822 fout << " // define the indices of the variables which are transformed by this transformation" << std::endl;
823 fout << " static std::vector<int> indicesGet;" << std::endl;
824 fout << " static std::vector<int> indicesPut;" << std::endl << std::endl;
825 fout << " if ( indicesGet.empty() ) {" << std::endl;
826 fout << " indicesGet.reserve(fNvars);" << std::endl;
827
828 for( ItVarTypeIdxConst itEntry = fGet.begin(), itEntryEnd = fGet.end(); itEntry != itEntryEnd; ++itEntry ) {
829 Char_t type = (*itEntry).first;
830 Int_t idx = (*itEntry).second;
831
832 switch( type ) {
833 case 'v':
834 fout << " indicesGet.push_back( " << idx << ");" << std::endl;
835 break;
836 case 't':
837 Log() << kWARNING << "MakeClass doesn't work with transformation of targets. The results will be wrong!" << Endl;
838 break;
839 case 's':
840 Log() << kWARNING << "MakeClass doesn't work with transformation of spectators. The results will be wrong!" << Endl;
841 break;
842 default:
843 Log() << kFATAL << "VariableTransformBase/GetInput : unknown type '" << type << "'." << Endl;
844 }
845 }
846 fout << " }" << std::endl;
847 fout << " if ( indicesPut.empty() ) {" << std::endl;
848 fout << " indicesPut.reserve(fNvars);" << std::endl;
849
850 for( ItVarTypeIdxConst itEntry = fPut.begin(), itEntryEnd = fPut.end(); itEntry != itEntryEnd; ++itEntry ) {
851 Char_t type = (*itEntry).first;
852 Int_t idx = (*itEntry).second;
853
854 switch( type ) {
855 case 'v':
856 fout << " indicesPut.push_back( " << idx << ");" << std::endl;
857 break;
858 case 't':
859 Log() << kWARNING << "MakeClass doesn't work with transformation of targets. The results will be wrong!" << Endl;
860 break;
861 case 's':
862 Log() << kWARNING << "MakeClass doesn't work with transformation of spectators. The results will be wrong!" << Endl;
863 break;
864 default:
865 Log() << kFATAL << "VariableTransformBase/PutInput : unknown type '" << type << "'." << Endl;
866 }
867 }
868
869 fout << " }" << std::endl;
870 fout << std::endl;
871
872 }else if( part == 1){
873 }
874}
char Char_t
Definition RtypesCore.h:37
float Float_t
Definition RtypesCore.h:57
constexpr Bool_t kFALSE
Definition RtypesCore.h:101
constexpr Bool_t kTRUE
Definition RtypesCore.h:100
#define ClassImp(name)
Definition Rtypes.h:377
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t mask
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void value
Option_t Option_t TPoint TPoint const char x2
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
bool advanced
#define TMVA_VERSION_CODE
Definition Version.h:47
Iterator of linked list.
Definition TList.h:193
A doubly linked list.
Definition TList.h:38
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNVariables() const
UInt_t GetNSpectators(bool all=kTRUE) const
UInt_t GetNTargets() const
VariableInfo & GetVariableInfo(Int_t i)
VariableInfo & GetTargetInfo(Int_t i)
VariableInfo & GetSpectatorInfo(Int_t i)
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Definition Event.cxx:236
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is or not.
Definition Event.cxx:389
Float_t GetSpectator(UInt_t ivar) const
return spectator content
Definition Event.cxx:261
Float_t GetTarget(UInt_t itgt) const
Definition Event.h:102
ostringstream derivative to redirect and format output
Definition MsgLogger.h:57
TList * ParseFormatLine(TString theString, const char *sep=":")
Parse the string and cut into labels separated by ":".
Definition Tools.cxx:401
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
void * GetChild(void *parent, const char *childname=nullptr)
get child node
Definition Tools.cxx:1150
void AddAttr(void *node, const char *, const T &value, Int_t precision=16)
add attribute to xml
Definition Tools.h:347
void * AddChild(void *parent, const char *childname, const char *content=nullptr, bool isRootNode=false)
add child node
Definition Tools.cxx:1124
void * GetNextChild(void *prevchild, const char *childname=nullptr)
XML helpers.
Definition Tools.cxx:1162
EVariableTransform
Definition Types.h:114
Class for type info of MVA input variable.
const TString & GetLabel() const
const TString & GetExpression() const
Linear interpolation class.
virtual void MakeFunction(std::ostream &fout, const TString &fncName, Int_t part, UInt_t trCounter, Int_t cls)=0
getinput and setoutput equivalent
virtual void SetOutput(Event *event, std::vector< Float_t > &output, std::vector< Char_t > &mask, const Event *oldEvent=nullptr, Bool_t backTransform=kFALSE) const
select the values from the event
virtual Bool_t GetInput(const Event *event, std::vector< Float_t > &input, std::vector< Char_t > &mask, Bool_t backTransform=kFALSE) const
select the values from the event
void CalcNorm(const std::vector< const Event * > &)
TODO --> adapt to variable,target,spectator selection method to calculate minimum,...
virtual void ReadFromXML(void *trfnode)=0
Read the input variables from the XML node.
virtual void AttachXMLTo(void *parent)=0
create XML description the transformation (write out info of selected variables)
std::vector< TMVA::VariableInfo > fVariables
event variables [saved to weight file]
VariableTransformBase(DataSetInfo &dsi, Types::EVariableTransform tf, const TString &trfName)
standard constructor
void UpdateNorm(Int_t ivar, Double_t x)
TODO --> adapt to variable,target,spectator selection update min and max of a given variable (target)...
virtual void CountVariableTypes(UInt_t &nvars, UInt_t &ntgts, UInt_t &nspcts) const
count variables, targets and spectators
virtual std::vector< TString > * GetTransformationStrings(Int_t cls) const
TODO --> adapt to variable,target,spectator selection default transformation output --> only indicate...
virtual void SelectInput(const TString &inputVariables, Bool_t putIntoVariables=kFALSE)
select the variables/targets/spectators which serve as input to the transformation
VectorOfCharAndInt::iterator ItVarTypeIdx
std::vector< TMVA::VariableInfo > fSpectators
event spectators [saved to weight file --> TODO ]
std::vector< TMVA::VariableInfo > fTargets
event targets [saved to weight file --> TODO ]
MsgLogger * fLogger
! message logger
VectorOfCharAndInt::const_iterator ItVarTypeIdxConst
Collectable string class.
Definition TObjString.h:28
Mother of all ROOT objects.
Definition TObject.h:41
Basic string class.
Definition TString.h:139
const char * Data() const
Definition TString.h:376
TString & Remove(Ssiz_t pos)
Definition TString.h:685
Double_t x[n]
Definition legend1.C:17
Tools & gTools()
void variables(TString dataset, TString fin="TMVA.root", TString dirName="InputVariables_Id", TString title="TMVA Input Variables", Bool_t isRegression=kFALSE, Bool_t useTMVAStyle=kTRUE)
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148
Double_t Sqrt(Double_t x)
Returns the square root of x.
Definition TMath.h:662
static void output()