// ROOT logo
// @(#)root/tmva $Id: VariableTransformBase.cxx 29195 2009-06-24 10:39:49Z brun $
// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : VariableTransformBase                                                 *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation (see header for description)                               *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
 *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - CERN, Switzerland              *
 *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *      MPI-K Heidelberg, Germany                                                 *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

#include <iomanip>

#include "TMath.h"
#include "TVectorD.h"
#include "TH1.h"
#include "TH2.h"
#include "TProfile.h"

#include "TMVA/VariableTransformBase.h"
#include "TMVA/Ranking.h"
#include "TMVA/Config.h"
#include "TMVA/Tools.h"

#ifndef ROOT_TMVA_MsgLogger
#include "TMVA/MsgLogger.h"
#endif

ClassImp(TMVA::VariableTransformBase)

//_______________________________________________________________________
TMVA::VariableTransformBase::VariableTransformBase( DataSetInfo& dsi,
                                                    Types::EVariableTransform tf,
                                                    const TString& trfName )
   : TObject(),
     fDsi(dsi),
     fTransformedEvent(0),
     fBackTransformedEvent(0),
     fVariableTransform(tf),
     fEnabled( kTRUE ),
     fCreated( kFALSE ),
     fNormalise( kFALSE ),
     fTransformName(trfName),
     fLogger( new MsgLogger(this, kINFO) )
{
   // standard constructor
   //
   // dsi     : data set info the variable/target descriptions are taken from
   // tf      : transformation type stored in fVariableTransform
   // trfName : name stored in fTransformName
   //
   // The transformation starts out enabled, not yet created and without
   // normalisation. It keeps its own copies of the variable and target
   // descriptions found in the data set info at construction time.

   // copy the input variable descriptions
   const UInt_t nVariables = fDsi.GetNVariables();
   for (UInt_t iVariable = 0; iVariable < nVariables; ++iVariable) {
      const VariableInfo variableCopy( fDsi.GetVariableInfo(iVariable) );
      fVariables.push_back( variableCopy );
   }

   // copy the regression target descriptions
   const UInt_t nTargets = fDsi.GetNTargets();
   for (UInt_t iTarget = 0; iTarget < nTargets; ++iTarget) {
      const VariableInfo targetCopy( fDsi.GetTargetInfo(iTarget) );
      fTargets.push_back( targetCopy );
   }
}

//_______________________________________________________________________
TMVA::VariableTransformBase::~VariableTransformBase()
{
   // destructor: releases the cached (back-)transformed events and the logger

   // deleting a null pointer is a no-op, so no explicit null checks are needed
   delete fTransformedEvent;
   delete fBackTransformedEvent;

   delete fLogger;
}

//_______________________________________________________________________
void TMVA::VariableTransformBase::CalcNorm( const std::vector<Event*>& events ) 
{
   // method to calculate minimum, maximum, weighted mean, and RMS for all
   // variables and targets used in the MVA; the results are stored in the
   // corresponding Variables() / Targets() entries:
   //
   //    mean = sum(w*x) / sum(w)
   //    RMS  = sqrt( sum(w*x*x)/sum(w) - mean*mean )
   //
   // Nothing is done if the transformation has not been created yet.
   // If the sum of event weights is not positive (e.g. empty event list),
   // mean and RMS cannot be computed: a warning is issued and the stored
   // mean/RMS values are left untouched instead of being set to NaN/Inf.

   if (!IsCreated()) return;

   const UInt_t nvars = GetNVariables();
   const UInt_t ntgts = GetNTargets();

   const UInt_t nevts = events.size();

   // weighted sums of x (x0) and x*x (x2); variables first, then targets
   TVectorD x2( nvars+ntgts ); x2 *= 0;
   TVectorD x0( nvars+ntgts ); x0 *= 0;   

   Double_t sumOfWeights = 0;
   for (UInt_t ievt=0; ievt<nevts; ievt++) {
      const Event* ev = events[ievt];

      const Double_t weight = ev->GetWeight();
      sumOfWeights += weight;
      for (UInt_t ivar=0; ivar<nvars; ivar++) {
         const Double_t x = ev->GetValue(ivar);
         if (ievt==0) {
            // the first event initialises the range
            Variables().at(ivar).SetMin(x);
            Variables().at(ivar).SetMax(x);
         } 
         else {
            UpdateNorm( ivar,  x );
         }
         x0(ivar) += x*weight;
         x2(ivar) += x*x*weight;
      }
      for (UInt_t itgt=0; itgt<ntgts; itgt++) {
         const Double_t x = ev->GetTarget(itgt);
         if (ievt==0) {
            // the first event initialises the range
            Targets().at(itgt).SetMin(x);
            Targets().at(itgt).SetMax(x);
         } 
         else {
            // targets are addressed with an index offset by the number of variables
            UpdateNorm( nvars+itgt,  x );
         }
         x0(nvars+itgt) += x*weight;
         x2(nvars+itgt) += x*x*weight;
      }
   }

   // set Mean and RMS
   if (sumOfWeights > 0) {
      for (UInt_t ivar=0; ivar<nvars; ivar++) {
         const Double_t mean = x0(ivar)/sumOfWeights;
         // rounding can make the variance marginally negative (e.g. for a
         // constant variable); clamp at zero so that the RMS never becomes NaN
         const Double_t variance = TMath::Max( 0.0, x2(ivar)/sumOfWeights - mean*mean );
         Variables().at(ivar).SetMean( mean ); 
         Variables().at(ivar).SetRMS( TMath::Sqrt( variance ) );
      }
      for (UInt_t itgt=0; itgt<ntgts; itgt++) {
         const Double_t mean = x0(nvars+itgt)/sumOfWeights;
         const Double_t variance = TMath::Max( 0.0, x2(nvars+itgt)/sumOfWeights - mean*mean );
         Targets().at(itgt).SetMean( mean ); 
         Targets().at(itgt).SetRMS( TMath::Sqrt( variance ) );
      }
   }
   else {
      // dividing by a non-positive sum of weights would yield NaN/Inf or meaningless moments
      Log() << kWARNING << "Sum of event weights is " << sumOfWeights
            << " (number of events: " << nevts << ")"
            << " --> cannot compute mean and RMS, values are left unchanged" << Endl;
   }

   // the message level is stated on every line so that the per-variable lines
   // share the verbosity of their header irrespective of the level the logger
   // holds after a message has been sent
   Log() << kVERBOSE << "Set minNorm/maxNorm for variables to: " << Endl;
   Log() << std::setprecision(3);
   for (UInt_t ivar=0; ivar<nvars; ivar++)
      Log() << kVERBOSE << "    " << Variables().at(ivar).GetInternalName()
              << "\t: [" << Variables().at(ivar).GetMin() << "\t, " << Variables().at(ivar).GetMax() << "\t] " << Endl;
   Log() << kVERBOSE << "Set minNorm/maxNorm for targets to: " << Endl;
   Log() << std::setprecision(3);
   for (UInt_t itgt=0; itgt<ntgts; itgt++)
      Log() << kVERBOSE << "    " << Targets().at(itgt).GetInternalName()
              << "\t: [" << Targets().at(itgt).GetMin() << "\t, " << Targets().at(itgt).GetMax() << "\t] " << Endl;
   Log() << std::setprecision(5); // reset to better value       
}

//_______________________________________________________________________
std::vector<TString>* TMVA::VariableTransformBase::GetTransformationStrings( Int_t /*cls*/ ) const
{
   // default transformation output
   // --> one string per input variable that only flags the variable as
   //     transformed (variable label followed by "_[transformed]");
   //     the class argument is not used here
   //
   // The returned vector is allocated with new and not retained by this object.
   const UInt_t nInputs = GetNVariables();

   std::vector<TString>* labels = new std::vector<TString>;
   for (UInt_t iInput = 0; iInput < nInputs; ++iInput) {
      const TString tagged = Variables()[iInput].GetLabel() + "_[transformed]";
      labels->push_back( tagged );
   }

   return labels;
}

//_______________________________________________________________________
void TMVA::VariableTransformBase::UpdateNorm ( Int_t ivar,  Double_t x ) 
{
   // widen the stored [min,max] range of a variable or target so that it contains x
   //
   // ivar : combined index; values below the number of input variables address
   //        a variable, all other values address target number (ivar - nvars)
   // x    : new value to be covered by the range
   const Int_t nInputs = fDsi.GetNVariables();

   // pick the entry once instead of looking it up for every comparison
   VariableInfo& entry = (ivar < nInputs) ? Variables().at(ivar)
                                          : Targets().at(ivar-nInputs);

   if (x < entry.GetMin()) entry.SetMin(x);
   if (x > entry.GetMax()) entry.SetMax(x);
}

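// TODO: decide whether the disabled stream I/O helpers below
//       (WriteVarsToStream / ReadVarsFromStream) should be restored or removed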

// //_______________________________________________________________________
// void TMVA::VariableTransformBase::WriteVarsToStream( std::ostream& o, const TString& prefix ) const 
// {
//    // write the list of variables (name, min, max) for a given data
//    // transformation method to the stream
//    o << prefix << "NVar " << GetNVariables() << endl;
//    std::vector<VariableInfo>::const_iterator varIt = Variables().begin();
//    for (; varIt!=Variables().end(); varIt++) { o << prefix; varIt->WriteToStream(o); }
// }

// //_______________________________________________________________________
// void TMVA::VariableTransformBase::ReadVarsFromStream( std::istream& istr ) 
// {
//    // Read the variables (name, min, max) for a given data
//    // transformation method from the stream. In the stream we only
//    // expect the limits which will be set

//    TString dummy;
//    UInt_t readNVar;
//    istr >> dummy >> readNVar;

//    if (readNVar!=Variables().size()) {
//       Log() << kFATAL << "You declared "<< Variables().size() << " variables in the Reader"
//               << " while there are " << readNVar << " variables declared in the file"
//               << Endl;
//    }

//    // we want to make sure all variables are read in the order they are defined
//    VariableInfo varInfo;
//    std::vector<VariableInfo>::iterator varIt = Variables().begin();
//    int varIdx = 0;
//    for (; varIt!=Variables().end(); varIt++, varIdx++) {
//       varInfo.ReadFromStream(istr);
//       if (varIt->GetExpression() == varInfo.GetExpression()) {
//          varInfo.SetExternalLink((*varIt).GetExternalLink());
//          (*varIt) = varInfo;
//       } 
//       else {
//          Log() << kINFO << "The definition (or the order) of the variables found in the input file"  << Endl;
//          Log() << kINFO << "is not the same as the one declared in the Reader (which is necessary for" << Endl;
//          Log() << kINFO << "the correct working of the classifier):" << Endl;
//          Log() << kINFO << "   var #" << varIdx <<" declared in Reader: " << varIt->GetExpression() << Endl;
//          Log() << kINFO << "   var #" << varIdx <<" declared in file  : " << varInfo.GetExpression() << Endl;
//          Log() << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
//       }
//    }
// }
 VariableTransformBase.cxx:1
 VariableTransformBase.cxx:2
 VariableTransformBase.cxx:3
 VariableTransformBase.cxx:4
 VariableTransformBase.cxx:5
 VariableTransformBase.cxx:6
 VariableTransformBase.cxx:7
 VariableTransformBase.cxx:8
 VariableTransformBase.cxx:9
 VariableTransformBase.cxx:10
 VariableTransformBase.cxx:11
 VariableTransformBase.cxx:12
 VariableTransformBase.cxx:13
 VariableTransformBase.cxx:14
 VariableTransformBase.cxx:15
 VariableTransformBase.cxx:16
 VariableTransformBase.cxx:17
 VariableTransformBase.cxx:18
 VariableTransformBase.cxx:19
 VariableTransformBase.cxx:20
 VariableTransformBase.cxx:21
 VariableTransformBase.cxx:22
 VariableTransformBase.cxx:23
 VariableTransformBase.cxx:24
 VariableTransformBase.cxx:25
 VariableTransformBase.cxx:26
 VariableTransformBase.cxx:27
 VariableTransformBase.cxx:28
 VariableTransformBase.cxx:29
 VariableTransformBase.cxx:30
 VariableTransformBase.cxx:31
 VariableTransformBase.cxx:32
 VariableTransformBase.cxx:33
 VariableTransformBase.cxx:34
 VariableTransformBase.cxx:35
 VariableTransformBase.cxx:36
 VariableTransformBase.cxx:37
 VariableTransformBase.cxx:38
 VariableTransformBase.cxx:39
 VariableTransformBase.cxx:40
 VariableTransformBase.cxx:41
 VariableTransformBase.cxx:42
 VariableTransformBase.cxx:43
 VariableTransformBase.cxx:44
 VariableTransformBase.cxx:45
 VariableTransformBase.cxx:46
 VariableTransformBase.cxx:47
 VariableTransformBase.cxx:48
 VariableTransformBase.cxx:49
 VariableTransformBase.cxx:50
 VariableTransformBase.cxx:51
 VariableTransformBase.cxx:52
 VariableTransformBase.cxx:53
 VariableTransformBase.cxx:54
 VariableTransformBase.cxx:55
 VariableTransformBase.cxx:56
 VariableTransformBase.cxx:57
 VariableTransformBase.cxx:58
 VariableTransformBase.cxx:59
 VariableTransformBase.cxx:60
 VariableTransformBase.cxx:61
 VariableTransformBase.cxx:62
 VariableTransformBase.cxx:63
 VariableTransformBase.cxx:64
 VariableTransformBase.cxx:65
 VariableTransformBase.cxx:66
 VariableTransformBase.cxx:67
 VariableTransformBase.cxx:68
 VariableTransformBase.cxx:69
 VariableTransformBase.cxx:70
 VariableTransformBase.cxx:71
 VariableTransformBase.cxx:72
 VariableTransformBase.cxx:73
 VariableTransformBase.cxx:74
 VariableTransformBase.cxx:75
 VariableTransformBase.cxx:76
 VariableTransformBase.cxx:77
 VariableTransformBase.cxx:78
 VariableTransformBase.cxx:79
 VariableTransformBase.cxx:80
 VariableTransformBase.cxx:81
 VariableTransformBase.cxx:82
 VariableTransformBase.cxx:83
 VariableTransformBase.cxx:84
 VariableTransformBase.cxx:85
 VariableTransformBase.cxx:86
 VariableTransformBase.cxx:87
 VariableTransformBase.cxx:88
 VariableTransformBase.cxx:89
 VariableTransformBase.cxx:90
 VariableTransformBase.cxx:91
 VariableTransformBase.cxx:92
 VariableTransformBase.cxx:93
 VariableTransformBase.cxx:94
 VariableTransformBase.cxx:95
 VariableTransformBase.cxx:96
 VariableTransformBase.cxx:97
 VariableTransformBase.cxx:98
 VariableTransformBase.cxx:99
 VariableTransformBase.cxx:100
 VariableTransformBase.cxx:101
 VariableTransformBase.cxx:102
 VariableTransformBase.cxx:103
 VariableTransformBase.cxx:104
 VariableTransformBase.cxx:105
 VariableTransformBase.cxx:106
 VariableTransformBase.cxx:107
 VariableTransformBase.cxx:108
 VariableTransformBase.cxx:109
 VariableTransformBase.cxx:110
 VariableTransformBase.cxx:111
 VariableTransformBase.cxx:112
 VariableTransformBase.cxx:113
 VariableTransformBase.cxx:114
 VariableTransformBase.cxx:115
 VariableTransformBase.cxx:116
 VariableTransformBase.cxx:117
 VariableTransformBase.cxx:118
 VariableTransformBase.cxx:119
 VariableTransformBase.cxx:120
 VariableTransformBase.cxx:121
 VariableTransformBase.cxx:122
 VariableTransformBase.cxx:123
 VariableTransformBase.cxx:124
 VariableTransformBase.cxx:125
 VariableTransformBase.cxx:126
 VariableTransformBase.cxx:127
 VariableTransformBase.cxx:128
 VariableTransformBase.cxx:129
 VariableTransformBase.cxx:130
 VariableTransformBase.cxx:131
 VariableTransformBase.cxx:132
 VariableTransformBase.cxx:133
 VariableTransformBase.cxx:134
 VariableTransformBase.cxx:135
 VariableTransformBase.cxx:136
 VariableTransformBase.cxx:137
 VariableTransformBase.cxx:138
 VariableTransformBase.cxx:139
 VariableTransformBase.cxx:140
 VariableTransformBase.cxx:141
 VariableTransformBase.cxx:142
 VariableTransformBase.cxx:143
 VariableTransformBase.cxx:144
 VariableTransformBase.cxx:145
 VariableTransformBase.cxx:146
 VariableTransformBase.cxx:147
 VariableTransformBase.cxx:148
 VariableTransformBase.cxx:149
 VariableTransformBase.cxx:150
 VariableTransformBase.cxx:151
 VariableTransformBase.cxx:152
 VariableTransformBase.cxx:153
 VariableTransformBase.cxx:154
 VariableTransformBase.cxx:155
 VariableTransformBase.cxx:156
 VariableTransformBase.cxx:157
 VariableTransformBase.cxx:158
 VariableTransformBase.cxx:159
 VariableTransformBase.cxx:160
 VariableTransformBase.cxx:161
 VariableTransformBase.cxx:162
 VariableTransformBase.cxx:163
 VariableTransformBase.cxx:164
 VariableTransformBase.cxx:165
 VariableTransformBase.cxx:166
 VariableTransformBase.cxx:167
 VariableTransformBase.cxx:168
 VariableTransformBase.cxx:169
 VariableTransformBase.cxx:170
 VariableTransformBase.cxx:171
 VariableTransformBase.cxx:172
 VariableTransformBase.cxx:173
 VariableTransformBase.cxx:174
 VariableTransformBase.cxx:175
 VariableTransformBase.cxx:176
 VariableTransformBase.cxx:177
 VariableTransformBase.cxx:178
 VariableTransformBase.cxx:179
 VariableTransformBase.cxx:180
 VariableTransformBase.cxx:181
 VariableTransformBase.cxx:182
 VariableTransformBase.cxx:183
 VariableTransformBase.cxx:184
 VariableTransformBase.cxx:185
 VariableTransformBase.cxx:186
 VariableTransformBase.cxx:187
 VariableTransformBase.cxx:188
 VariableTransformBase.cxx:189
 VariableTransformBase.cxx:190
 VariableTransformBase.cxx:191
 VariableTransformBase.cxx:192
 VariableTransformBase.cxx:193
 VariableTransformBase.cxx:194
 VariableTransformBase.cxx:195
 VariableTransformBase.cxx:196
 VariableTransformBase.cxx:197
 VariableTransformBase.cxx:198
 VariableTransformBase.cxx:199
 VariableTransformBase.cxx:200
 VariableTransformBase.cxx:201
 VariableTransformBase.cxx:202
 VariableTransformBase.cxx:203
 VariableTransformBase.cxx:204
 VariableTransformBase.cxx:205
 VariableTransformBase.cxx:206
 VariableTransformBase.cxx:207
 VariableTransformBase.cxx:208
 VariableTransformBase.cxx:209
 VariableTransformBase.cxx:210
 VariableTransformBase.cxx:211
 VariableTransformBase.cxx:212
 VariableTransformBase.cxx:213
 VariableTransformBase.cxx:214
 VariableTransformBase.cxx:215
 VariableTransformBase.cxx:216
 VariableTransformBase.cxx:217
 VariableTransformBase.cxx:218
 VariableTransformBase.cxx:219
 VariableTransformBase.cxx:220
 VariableTransformBase.cxx:221
 VariableTransformBase.cxx:222
 VariableTransformBase.cxx:223
 VariableTransformBase.cxx:224
 VariableTransformBase.cxx:225
 VariableTransformBase.cxx:226
 VariableTransformBase.cxx:227
 VariableTransformBase.cxx:228