#include "Riostream.h"
#include "TMath.h"
#include "TVectorD.h"
#include "TH1.h"
#include "TH2.h"
#include "TProfile.h"
#include "TMVA/VariableTransformBase.h"
#include "TMVA/Ranking.h"
#include "TMVA/Config.h"
#include "TMVA/Tools.h"
// ROOT ClassImp macro invocation for TMVA::VariableTransformBase
ClassImp(TMVA::VariableTransformBase)
//_______________________________________________________________________
// Constructor.
//
//   varinfo : descriptions of the input variables; used to initialise fVariables
//   tf      : transformation type, stored in fVariableTransform
//
// The transformation starts enabled, not yet created and not normalising.
// Event pointers, current tree, output directory and ranking all start as null.
TMVA::VariableTransformBase::VariableTransformBase( std::vector<VariableInfo>& varinfo, Types::EVariableTransform tf )
   : TObject(),
     fEvent( 0 ),
     fEventRaw( 0 ),
     fVariableTransform(tf),
     fEnabled( kTRUE ),
     fCreated( kFALSE ),
     fNormalise( kFALSE ),
     fTransformName("TransBase"),
     fVariables( varinfo ),
     fCurrentTree(0), 
     fCurrentEvtIdx(0),
     fOutputBaseDir(0),
     fLogger( GetName(), kINFO )
{
   // fRanking is only assigned in PlotVariables; give it a defined (null) value
   // here so that it never holds an indeterminate pointer. Assigned in the body
   // rather than the initialiser list to stay independent of declaration order.
   fRanking = 0;

   // start every variable with a fresh min/max range
   for (std::vector<VariableInfo>::iterator it = fVariables.begin(); it != fVariables.end(); ++it) {
      it->ResetMinMax();
   }
}
//_______________________________________________________________________
// Destructor: releases the transformed and the raw event objects.
TMVA::VariableTransformBase::~VariableTransformBase()
{
   // fEvent may point to the same object as fEventRaw; delete it on its own
   // only when it is a distinct, non-null object to avoid a double delete
   const Bool_t hasSeparateEvent = (fEvent != 0) && (fEvent != fEventRaw);
   if (hasSeparateEvent) {
      delete fEvent;
      fEvent = 0;
   }

   if (fEventRaw != 0) {
      delete fEventRaw;
      fEventRaw = 0;
   }
}
//_______________________________________________________________________
// Drops all branch addresses of 'tree' and re-attaches the raw event to it.
// The cached current-tree pointer is cleared as part of this.
void TMVA::VariableTransformBase::ResetBranchAddresses( TTree* tree ) const
{
   // forget the cached tree
   fCurrentTree = 0;

   // detach whatever was bound to the tree before ...
   tree->ResetBranchAddresses();

   // ... and bind the raw event's variables to its branches
   GetEventRaw().SetBranchAddresses( tree );
}
//_______________________________________________________________________
// Allocates the event object held in fEvent from the variable list.
void TMVA::VariableTransformBase::CreateEvent() const 
{
   // second constructor argument: allowExternalLinks, disabled here
   fEvent = new Event( fVariables, kFALSE );
}
//_______________________________________________________________________
// Reads entry 'evidx' of tree 'tr' into the raw event and applies the
// transformation to it.
//
//   tr    : tree to read from; a null pointer is reported at kFATAL level
//   evidx : index of the entry to load
//   type  : signal/background type handed to ApplyTransformation; when it is
//           Types::kTrueType it is replaced by the type of the event just read
//
// Returns kTRUE on every path visible here.
Bool_t TMVA::VariableTransformBase::ReadEvent( TTree* tr, UInt_t evidx, Types::ESBType type ) const
{
   // needRead collects the reasons why the entry has to be (re)loaded:
   // no raw event yet, a different tree, or a different entry index
   if (tr == 0) fLogger << kFATAL << "<ReadEvent> zero Tree Pointer encountered" << Endl;
   Bool_t needRead = kFALSE;
   if (fEventRaw == 0) {
      needRead = kTRUE;
      // presumably GetEventRaw() creates the raw event on first use -- TODO confirm
      GetEventRaw();
      ResetBranchAddresses( tr );
   }   
   if (tr != fCurrentTree) {
      needRead = kTRUE;
      // release the addresses still bound to the previously used tree
      if (fCurrentTree!=0) fCurrentTree->ResetBranchAddresses();
      fCurrentTree = tr;
      // NOTE(review): ResetBranchAddresses() of this class sets fCurrentTree back
      // to 0, so the assignment above does not survive this call. It looks like
      // 'tr != fCurrentTree' is therefore true again on the next call and the
      // early return below is never taken for a valid tree -- confirm whether
      // re-binding the branches on every call is intended.
      ResetBranchAddresses( tr );
   }
   if (evidx != fCurrentEvtIdx) {
      needRead = kTRUE;
      fCurrentEvtIdx = evidx;
   }
   // nothing changed since the last call: the event is taken to be up to date
   if (!needRead) return kTRUE;
   
   // load the requested entry branch by branch
   std::vector<TBranch*>::iterator brIt = fEventRaw->Branches().begin();
   for (;brIt!=fEventRaw->Branches().end(); brIt++) (*brIt)->GetEntry(evidx);
   // kTrueType: use the signal/background flag carried by the event itself
   if (type == Types::kTrueType ) type = fEventRaw->IsSignal() ? Types::kSignal : Types::kBackground;
   ApplyTransformation(type);
   return kTRUE;
}
//_______________________________________________________________________
// Widens the stored [min, max] range of variable 'ivar' so that it contains x.
void TMVA::VariableTransformBase::UpdateNorm ( Int_t ivar,  Double_t x ) 
{
   VariableInfo& var = fVariables[ivar];

   // both bounds are tested independently
   if (var.GetMin() > x) var.SetMin( x );
   if (var.GetMax() < x) var.SetMax( x );
}
//_______________________________________________________________________
// Loops over all entries of 'tr' and determines, for every variable, the
// min/max range (via UpdateNorm) and the weighted mean and RMS, which are
// stored in fVariables. The resulting ranges are logged at kVERBOSE level.
//
// Does nothing if the transformation has not been created or 'tr' is null.
void TMVA::VariableTransformBase::CalcNorm( TTree * tr )
{
   if (!IsCreated()) return;
   if (tr == 0) return;

   ResetBranchAddresses( tr );

   const UInt_t nvar  = GetNVariables();
   const UInt_t nevts = (UInt_t)tr->GetEntries();

   // weighted sums of x and x^2 per variable
   TVectorD x2( nvar ); x2 *= 0;
   TVectorD x0( nvar ); x0 *= 0;   

   Double_t sumOfWeights = 0;
   for (UInt_t ievt=0; ievt<nevts; ievt++) {
      ReadEvent( tr, ievt, Types::kSignal );

      const Double_t weight = GetEvent().GetWeight();
      sumOfWeights += weight;
      for (UInt_t ivar=0; ivar<nvar; ivar++) {
         const Double_t x = GetEvent().GetVal(ivar);
         UpdateNorm( ivar,  x );
         x0(ivar) += x*weight;
         x2(ivar) += x*x*weight;
      }
   }

   if (sumOfWeights != 0) {
      for (UInt_t ivar=0; ivar<nvar; ivar++) {
         const Double_t mean = x0(ivar)/sumOfWeights;

         // <x^2> - <x>^2 can come out slightly negative through rounding (e.g. for
         // a constant variable); clamp so that the square root does not yield NaN
         Double_t variance = x2(ivar)/sumOfWeights - mean*mean;
         if (variance < 0) variance = 0;

         fVariables[ivar].SetMean( mean ); 
         fVariables[ivar].SetRMS( TMath::Sqrt( variance ) ); 
      }
   }
   else {
      // no events, or weights that cancel: dividing would store NaN mean and RMS
      fLogger << kWARNING << "<CalcNorm> Sum of event weights is zero ==> "
              << "mean and RMS of the variables are left unchanged" << Endl;
   }

   fLogger << kVERBOSE << "Set minNorm/maxNorm for variables to: " << Endl;
   fLogger << setprecision(3);
   for (UInt_t ivar=0; ivar<nvar; ivar++)
      fLogger << "    " << fVariables[ivar].GetInternalVarName()
              << "\t: [" << fVariables[ivar].GetMin() << "\t, " << fVariables[ivar].GetMax() << "\t] " << Endl;
   fLogger << setprecision(5); // reset to better value       
}
//_______________________________________________________________________
// Creates control plots of the input variables from the events in 'theTree'.
//
// Steps:
//   1. first pass over the tree: weighted mean and RMS per variable,
//      separately for signal and background;
//   2. book one 1D histogram per variable and class and, unless the number of
//      variables exceeds the configured maximum, a 2D scatter histogram and a
//      profile for every variable pair (i<j) and class;
//   3. second pass over the tree: fill all histograms;
//   4. build the variable ranking from the signal/background separation of
//      the 1D histograms (stored in fRanking);
//   5. write everything below "InputVariables_<name>" in the output base
//      directory and release the histograms.
//
// Does nothing if the transformation has not been created or 'theTree' is null.
void TMVA::VariableTransformBase::PlotVariables( TTree* theTree )
{
   if (!IsCreated()) return;
   if (theTree == 0) return;

   // the histograms are written below the output base directory; report a missing
   // directory explicitly instead of dereferencing a null pointer further down
   if (GetOutputBaseDir() == 0) {
      fLogger << kFATAL << "<PlotVariables> No output base directory has been set" << Endl;
      return;
   }

   ResetBranchAddresses( theTree );
   
   fLogger << kVERBOSE << "Plot input variables from '" << theTree->GetName() << "'" << Endl;
   
   // suffix appended to all histogram names
   TString transfType = "_"; transfType += GetName();
   const UInt_t nvar = GetNVariables();

   // scatter/profile plots are produced only up to a configurable number of
   // variables; the decision is taken once here and reused in all loops below
   const Bool_t doScatterPlots =
      ( nvar <= (UInt_t)gConfig().GetVariablePlotting().fMaxNumOfAllowedVariablesForScatterPlots );
   
   // weighted sums of x and x^2 for signal (S) and background (B)
   TVectorD x2S( nvar ); x2S *= 0;
   TVectorD x2B( nvar ); x2B *= 0;
   TVectorD x0S( nvar ); x0S *= 0;   
   TVectorD x0B( nvar ); x0B *= 0;      
   TVectorD rmsS( nvar ), meanS( nvar ); 
   TVectorD rmsB( nvar ), meanB( nvar ); 
   
   // first pass: accumulate the sums
   const UInt_t nevts = (UInt_t)theTree->GetEntries();
   Double_t nS = 0, nB = 0;
   for (UInt_t ievt=0; ievt<nevts; ievt++) {
      ReadEvent( theTree, ievt, Types::kSignal );
      Double_t weight = GetEvent().GetWeight();
      if (GetEvent().IsSignal()) nS += weight; 
      else                       nB += weight;
      for (UInt_t ivar=0; ivar<nvar; ivar++) {
         Double_t x = GetEvent().GetVal(ivar);
         if (GetEvent().IsSignal()) {
            x0S(ivar) += x*weight;
            x2S(ivar) += x*x*weight;
         }
         else {
            x0B(ivar) += x*weight;
            x2B(ivar) += x*x*weight;
         }
      }
   }

   // mean and RMS per variable and class
   for (UInt_t ivar=0; ivar<nvar; ivar++) {
      meanS(ivar) = x0S(ivar)/nS;
      meanB(ivar) = x0B(ivar)/nB;
      
      // rmsS/rmsB hold RMS^2 at this point; the square root is taken below
      rmsS(ivar) = x2S(ivar)/nS - x0S(ivar)*x0S(ivar)/nS/nS;   
      rmsB(ivar) = x2B(ivar)/nB - x0B(ivar)*x0B(ivar)/nB/nB;   
      if (rmsS(ivar) <= 0) {
         fLogger << kWARNING << "Variable \"" << Variable(ivar).GetExpression() 
                 << "\" has zero or negative RMS^2 for signal " 
                 << "==> set to zero. Please check the variable content" << Endl;
         rmsS(ivar) = 0;
      }
      if (rmsB(ivar) <= 0) {
         fLogger << kWARNING << "Variable \"" << Variable(ivar).GetExpression() 
                 << "\" has zero or negative RMS^2 for background "
                 << "==> set to zero. Please check the variable content" << Endl;
         rmsB(ivar) = 0;
      }
      rmsS(ivar) = TMath::Sqrt( rmsS(ivar) );   
      rmsB(ivar) = TMath::Sqrt( rmsB(ivar) );   
   }
   
   // histogram containers; the pair containers are indexed [i][j] and only the
   // entries with j>i are ever booked, all others stay null
   std::vector<TH1F*> vS ( nvar );
   std::vector<TH1F*> vB ( nvar );
   std::vector<std::vector<TH2F*> >     mycorrS( nvar );
   std::vector<std::vector<TH2F*> >     mycorrB( nvar );
   std::vector<std::vector<TProfile*> > myprofS( nvar );
   std::vector<std::vector<TProfile*> > myprofB( nvar );
   for (UInt_t ivar=0; ivar < nvar; ivar++) {
      mycorrS[ivar].resize(nvar);
      mycorrB[ivar].resize(nvar);
      myprofS[ivar].resize(nvar);
      myprofB[ivar].resize(nvar);
   }
   
   // tell the user why the scatter plots are skipped
   if (!doScatterPlots) {
      Int_t nhists = nvar*(nvar - 1)/2;
      fLogger << kINFO << Tools::Color("dgreen") << Endl;
      fLogger << kINFO << "<PlotVariables> Will not produce scatter plots ==> " << Endl;
      fLogger << kINFO
              << "|  The number of " << nvar << " input variables would require " << nhists << " two-dimensional" << Endl;
      fLogger << kINFO
              << "|  histograms, which would occupy the computer's memory. Note that this" << Endl;
      fLogger << kINFO
              << "|  suppression does not have any consequences for your analysis, other" << Endl;
      fLogger << kINFO
              << "|  than not disposing of these scatter plots. You can modify the maximum" << Endl;
      fLogger << kINFO
              << "|  number of input variables allowed to generate scatter plots in your" << Endl; 
      fLogger << kINFO
              << "|  script via the command line:" << Endl;
      fLogger << kINFO
              << "|  \"(TMVA::gConfig().GetVariablePlotting()).fMaxNumOfAllowedVariablesForScatterPlots = <some int>;\""
              << Tools::Color("reset") << Endl;
      fLogger << Endl;
      fLogger << kINFO << "Some more output" << Endl;
   }

   Float_t timesRMS  = gConfig().GetVariablePlotting().fTimesRMS;
   UInt_t  nbins1D   = gConfig().GetVariablePlotting().fNbins1D;
   UInt_t  nbins2D   = gConfig().GetVariablePlotting().fNbins2D;

   // book the histograms
   for (UInt_t i=0; i<nvar; i++) {
      TString myVari = Variable(i).GetInternalVarName();  
      
      if (Variable(i).GetVarType() == 'I') {
         // variable type 'I': one bin per integer value between min and max
         Int_t xmin = TMath::Nint( Variable(i).GetMin() );
         Int_t xmax = TMath::Nint( Variable(i).GetMax() + 1 );
         Int_t nbins = xmax - xmin;
         vS[i] = new TH1F( Form("%s__S%s", myVari.Data(), transfType.Data()), Variable(i).GetExpression(), nbins, xmin, xmax );
         vB[i] = new TH1F( Form("%s__B%s", myVari.Data(), transfType.Data()), Variable(i).GetExpression(), nbins, xmin, xmax );
      }
      else {
         // range: mean +- timesRMS*RMS of either class, limited to the variable's min/max
         Double_t xmin = TMath::Max( Variable(i).GetMin(), TMath::Min( meanS(i) - timesRMS*rmsS(i), meanB(i) - timesRMS*rmsB(i) ) );
         Double_t xmax = TMath::Min( Variable(i).GetMax(), TMath::Max( meanS(i) + timesRMS*rmsS(i), meanB(i) + timesRMS*rmsB(i) ) );
         
         // protect against an empty or inverted range
         if (xmin >= xmax) xmax = xmin*1.1; // try first...
         if (xmin >= xmax) xmax = xmin + 1; // this if xmin == xmax == 0 (or xmin is negative)
         vS[i] = new TH1F( Form("%s__S%s", myVari.Data(), transfType.Data()), Variable(i).GetExpression(), nbins1D, xmin, xmax );
         vB[i] = new TH1F( Form("%s__B%s", myVari.Data(), transfType.Data()), Variable(i).GetExpression(), nbins1D, xmin, xmax );         
      }
      vS[i]->SetXTitle(Variable(i).GetExpression());
      vB[i]->SetXTitle(Variable(i).GetExpression());
      vS[i]->SetLineColor(4);
      vB[i]->SetLineColor(2);
      
      // scatter and profile plots for every pair (i, j>i)
      if (doScatterPlots) {
         for (UInt_t j=i+1; j<nvar; j++) {
            TString myVarj = Variable(j).GetInternalVarName();  
            
            mycorrS[i][j] = new TH2F( Form( "scat_%s_vs_%s_sig%s", myVarj.Data(), myVari.Data(), transfType.Data() ), 
                                      Form( "%s versus %s (signal)%s", myVarj.Data(), myVari.Data(), transfType.Data() ), 
                                      nbins2D, Variable(i).GetMin(), Variable(i).GetMax(), 
                                      nbins2D, Variable(j).GetMin(), Variable(j).GetMax() );
            mycorrS[i][j]->SetXTitle(Variable(i).GetExpression());
            mycorrS[i][j]->SetYTitle(Variable(j).GetExpression());
            mycorrB[i][j] = new TH2F( Form( "scat_%s_vs_%s_bgd%s", myVarj.Data(), myVari.Data(), transfType.Data() ), 
                                      Form( "%s versus %s (background)%s", myVarj.Data(), myVari.Data(), transfType.Data() ), 
                                      nbins2D, Variable(i).GetMin(), Variable(i).GetMax(), 
                                      nbins2D, Variable(j).GetMin(), Variable(j).GetMax() );
            mycorrB[i][j]->SetXTitle(Variable(i).GetExpression());
            mycorrB[i][j]->SetYTitle(Variable(j).GetExpression());
            
            myprofS[i][j] = new TProfile( Form( "prof_%s_vs_%s_sig%s", myVarj.Data(), myVari.Data(), transfType.Data() ), 
                                          Form( "profile %s versus %s (signal)%s", myVarj.Data(), myVari.Data(), transfType.Data() ), 
                                          nbins1D, Variable(i).GetMin(), Variable(i).GetMax() );
            myprofB[i][j] = new TProfile( Form( "prof_%s_vs_%s_bgd%s", myVarj.Data(), myVari.Data(), transfType.Data() ), 
                                          Form( "profile %s versus %s (background)%s", myVarj.Data(), myVari.Data(), transfType.Data() ), 
                                          nbins1D, Variable(i).GetMin(), Variable(i).GetMax() );
         }
      }   
   }
   
   // second pass: fill the histograms (same number of entries as in the first pass)
   for (UInt_t ievt=0; ievt<nevts; ievt++) {
      ReadEvent( theTree, ievt, Types::kSignal );
      Float_t weight = GetEvent().GetWeight();
      for (UInt_t i=0; i<nvar; i++) {
         Float_t vali = GetEvent().GetVal(i);
         
         if (GetEvent().IsSignal()) vS[i]->Fill( vali, weight );
         else                       vB[i]->Fill( vali, weight );
         
         if (doScatterPlots) {
            for (UInt_t j=i+1; j<nvar; j++) {
               Float_t valj = GetEvent().GetVal(j);
               if (GetEvent().IsSignal()) {
                  mycorrS[i][j]->Fill( vali, valj, weight );
                  myprofS[i][j]->Fill( vali, valj, weight );
               }
               else {
                  mycorrB[i][j]->Fill( vali, valj, weight );
                  myprofB[i][j]->Fill( vali, valj, weight );
               }
            }
         }
      }
   }
      
   // rank the variables by the separation of their signal/background distributions
   fRanking = new Ranking( GetName(), "Separation" );
   for (UInt_t i=0; i<nvar; i++) {   
      Double_t sep = Tools::GetSeparation( vS[i], vB[i] );
      fRanking->AddRank( *new Rank( vS[i]->GetTitle(), sep ) );
   }
   
   // write the histograms into a dedicated sub-directory, which must not exist yet
   TString outputDir = TString("InputVariables_") + GetName();
   TObject* o = GetOutputBaseDir()->FindObject(outputDir);
   if (o != 0) {
      fLogger << kFATAL << "A " << o->ClassName() << " already exists in " 
              << GetOutputBaseDir()->GetPath() << Endl;
   }
   TDirectory* localDir = GetOutputBaseDir()->mkdir( outputDir );
   localDir->cd();
   fLogger << kVERBOSE << "Create and switch to directory " << localDir->GetPath() << Endl;
   for (UInt_t i=0; i<nvar; i++) {
      vS[i]->Write();
      vB[i]->Write();
      vS[i]->SetDirectory(0);
      vB[i]->SetDirectory(0);
      // detached from any directory and not referenced after this point:
      // release them here, otherwise they are leaked on return
      // NOTE(review): assumes Rank keeps its own copy of the title string
      // passed to it above -- confirm against Ranking.h
      delete vS[i]; vS[i] = 0;
      delete vB[i]; vB[i] = 0;
   }
   
   if (doScatterPlots) {
      localDir = localDir->mkdir( "CorrelationPlots" );
      localDir ->cd();
      fLogger << kINFO << "Create scatter and profile plots in target-file directory: " << Endl;
      fLogger << kINFO << localDir->GetPath() << Endl;
      
      for (UInt_t i=0; i<nvar; i++) {
         for (UInt_t j=i+1; j<nvar; j++) {
            mycorrS[i][j]->Write();
            mycorrB[i][j]->Write();
            myprofS[i][j]->Write();
            myprofB[i][j]->Write();
            mycorrS[i][j]->SetDirectory(0);
            mycorrB[i][j]->SetDirectory(0);
            myprofS[i][j]->SetDirectory(0);
            myprofB[i][j]->SetDirectory(0);
            // same ownership reasoning as for the 1D histograms
            delete mycorrS[i][j]; mycorrS[i][j] = 0;
            delete mycorrB[i][j]; mycorrB[i][j] = 0;
            delete myprofS[i][j]; myprofS[i][j] = 0;
            delete myprofB[i][j]; myprofB[i][j] = 0;
         }
      }         
   }

   // go back to the base directory and release the tree's branch addresses
   GetOutputBaseDir()->cd();
   theTree->ResetBranchAddresses();
}
void TMVA::VariableTransformBase::PrintVariableRanking() const
{
   
   fLogger << kINFO << "Ranking input variables..." << Endl;
   fRanking->Print();
}
void TMVA::VariableTransformBase::WriteVarsToStream( std::ostream& o, const TString& prefix ) const 
{
   
   
   o << prefix << "NVar " << GetNVariables() << endl;
   std::vector<VariableInfo>::const_iterator varIt = fVariables.begin();
   for (; varIt!=fVariables.end(); varIt++) { o << prefix; varIt->WriteToStream(o); }
}
void TMVA::VariableTransformBase::ReadVarsFromStream( std::istream& istr ) 
{
   
   
   
   TString dummy;
   UInt_t readNVar;
   istr >> dummy >> readNVar;
   if (readNVar!=fVariables.size()) {
      fLogger << kFATAL << "You declared "<< fVariables.size() << " variables in the Reader"
              << " while there are " << readNVar << " variables declared in the file"
              << Endl;
   }
   
   VariableInfo varInfo;
   std::vector<VariableInfo>::iterator varIt = fVariables.begin();
   int varIdx = 0;
   for (; varIt!=fVariables.end(); varIt++, varIdx++) {
      varInfo.ReadFromStream(istr);
      if (varIt->GetExpression() == varInfo.GetExpression()) {
         varInfo.SetExternalLink((*varIt).GetExternalLink());
         (*varIt) = varInfo;
      } 
      else {
         fLogger << kINFO << "The definition (or the order) of the variables found in the input file is"  << Endl;
         fLogger << kINFO << "is not the same as the one declared in the Reader (which is necessary for" << Endl;
         fLogger << kINFO << "the correct working of the classifier):" << Endl;
         fLogger << kINFO << "   var #" << varIdx <<" declared in Reader: " << varIt->GetExpression() << Endl;
         fLogger << kINFO << "   var #" << varIdx <<" declared in file  : " << varInfo.GetExpression() << Endl;
         fLogger << kFATAL << "The expression declared to the Reader needs to be checked (name or order are wrong)" << Endl;
      }
   }
}
// This page has been automatically generated. If you have any comments or suggestions about the page layout, please send an e-mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.