#include <stdlib.h>
#include <fstream>

#include "TLeaf.h"
#include "TEventList.h"
#include "TObjString.h"
#include "TROOT.h"
#include "TTree.h"
#include "TMultiLayerPerceptron.h"

#include "TMVA/MethodTMlpANN.h"
#ifndef ROOT_TMVA_Tools
#include "TMVA/Tools.h"
#endif
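
// when kTRUE, every input variable is prefixed with '@' in the network
// layout string, which instructs TMultiLayerPerceptron to normalise it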
const Bool_t EnforceNormalization__=kTRUE;
ClassImp(TMVA::MethodTMlpANN)
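
//______________________________________________________________________________
// standard constructor: interprets the booking options and sets up the method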
TMVA::MethodTMlpANN::MethodTMlpANN( const TString& jobName, const TString& methodTitle, DataSet& theData, 
                                    const TString& theOption, TDirectory* theTargetDir)
   : TMVA::MethodBase(jobName, methodTitle, theData, theOption, theTargetDir  ),
     fMLP(0),
     fLearningMethod( "" )
{
   
   InitTMlpANN();

   // interpret the configuration option string
   DeclareOptions();
   ParseOptions();
   ProcessOptions();
}
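
//______________________________________________________________________________
// second constructor, used when reconstructing the method from a weight file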
TMVA::MethodTMlpANN::MethodTMlpANN( DataSet& theData, 
                                    const TString& theWeightFile,  
                                    TDirectory* theTargetDir )
   : TMVA::MethodBase( theData, theWeightFile, theTargetDir ),
     fMLP(0),
     fLearningMethod( "" )
{
   
   
   InitTMlpANN();
   DeclareOptions();
}
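
//______________________________________________________________________________
// default initialisations called by all constructors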
void TMVA::MethodTMlpANN::InitTMlpANN( void )
{
   
   SetMethodName( "TMlpANN" );
   SetMethodType( TMVA::Types::kTMlpANN );
   SetTestvarName();
}
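
//______________________________________________________________________________
// destructor: deletes the network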
TMVA::MethodTMlpANN::~MethodTMlpANN( void )
{
   
   if (fMLP != 0) delete fMLP;
}
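
//______________________________________________________________________________
// translate the layer specification into the structure string expected by
// TMultiLayerPerceptron; e.g. with four input variables var1..var4 (names
// illustrative), the specification "N,N-1" yields fHiddenLayer = ":4:3:"
// and fMLPBuildOptions = "@var1,@var2,@var3,@var4:4:3:type"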
void TMVA::MethodTMlpANN::CreateMLPOptions( TString layerSpec )
{
   
   fHiddenLayer = ":";
   while (layerSpec.Length()>0) {
      TString sToAdd="";
      if (layerSpec.First(',')<0) {
         sToAdd = layerSpec;
         layerSpec = "";
      } 
      else {
         sToAdd = layerSpec(0,layerSpec.First(','));
         layerSpec = layerSpec(layerSpec.First(',')+1,layerSpec.Length());
      }
      int nNodes = 0;
      // a leading "N" stands for the number of input variables
      if (sToAdd.BeginsWith("N")) { sToAdd.Remove(0,1); nNodes = GetNvar(); }
      nNodes += atoi(sToAdd);
      fHiddenLayer = Form( "%s%i:", (const char*)fHiddenLayer, nNodes );
   }
   
   // assemble the comma-separated list of input variables; the '@' prefix
   // tells TMultiLayerPerceptron to normalise the variable
   std::vector<TString>::iterator itrVar    = (*fInputVars).begin();
   std::vector<TString>::iterator itrVarEnd = (*fInputVars).end();
   fMLPBuildOptions = "";
   for (; itrVar != itrVarEnd; itrVar++) {
      if (EnforceNormalization__) fMLPBuildOptions += "@";
      fMLPBuildOptions += *itrVar;
      fMLPBuildOptions += ",";
   }
   fMLPBuildOptions.Chop(); // remove the trailing ","

   // append the hidden-layer specification and the "type" branch, which
   // acts as the training target (signal/background) of the network
   fMLPBuildOptions += fHiddenLayer;
   fMLPBuildOptions += "type";
   fLogger << kINFO << "Use " << fNcycles << " training cycles" << Endl;
   fLogger << kINFO << "Use configuration (nodes per hidden layer): " << fHiddenLayer << Endl;  
}
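
//______________________________________________________________________________
// define the options (their key words) that can be set in the option string.
// A typical booking in a user training macro might look like this (option
// values are illustrative; 'factory' is a TMVA::Factory created by the user):
//
//    factory->BookMethod( TMVA::Types::kTMlpANN, "TMlpANN",
//                         "NCycles=200:HiddenLayers=N+1,N:LearningMethod=BFGS" );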
void TMVA::MethodTMlpANN::DeclareOptions() 
{
   // known options:
   //
   //   NCycles       <integer>   number of training cycles (too many cycles
   //                             can lead to overtraining)
   //   HiddenLayers  <string>    nodes per hidden layer, separated by commas;
   //                             "N" is replaced by the number of input
   //                             variables, and offsets such as "N-1" are
   //                             allowed (see CreateMLPOptions)
   DeclareOptionRef( fNcycles  = 3000,     "NCycles",      "Number of training cycles" );
   DeclareOptionRef( fLayerSpec="N-1,N-2", "HiddenLayers", "Specification of hidden layer architecture" );
   
   DeclareOptionRef( fValidationFraction = 0.5, "ValidationFraction", 
                     "Fraction of events in training tree used for cross validation" );
   DeclareOptionRef( fLearningMethod = "Stochastic", "LearningMethod", "Learning method" );
   AddPreDefVal( TString("Stochastic") );
   AddPreDefVal( TString("Batch") );
   AddPreDefVal( TString("SteepestDescent") );
   AddPreDefVal( TString("RibierePolak") );
   AddPreDefVal( TString("FletcherReeves") );
   AddPreDefVal( TString("BFGS") );
}
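
//______________________________________________________________________________
// process the user options and build a minimal network for evaluation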
void TMVA::MethodTMlpANN::ProcessOptions() 
{
   
   MethodBase::ProcessOptions();
   
   CreateMLPOptions(fLayerSpec);

   // create a dummy tree: TMultiLayerPerceptron infers the network structure
   // from tree branches, so build a minimal tree with one branch per input
   // variable plus the "type" target branch; the buffers are static because
   // the tree stores their addresses
   static Double_t* d = new Double_t[Data().GetNVariables()];
   static Int_t   type;
   gROOT->cd();
   TTree* dummyTree = new TTree("dummy","Empty dummy tree", 1);
   for (UInt_t ivar = 0; ivar<Data().GetNVariables(); ivar++) {
      TString vn = Data().GetInternalVarName(ivar);
      dummyTree->Branch(Form("%s",vn.Data()), d+ivar, Form("%s/D",vn.Data()));
   }
   dummyTree->Branch("type", &type, "type/I");
   if (fMLP != 0) delete fMLP;
   fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(), dummyTree );
}
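
//______________________________________________________________________________
// calculate the value of the neural net for the current event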
Double_t TMVA::MethodTMlpANN::GetMvaValue()
{
   
   // copy the input variables of the current event into the evaluation buffer
   static Double_t* d = new Double_t[Data().GetNVariables()];
   for (UInt_t ivar = 0; ivar < Data().GetNVariables(); ivar++) d[ivar] = (Double_t)GetEventVal(ivar);
   return fMLP->Evaluate( 0, d );
}
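
//______________________________________________________________________________
// performs the TMlpANN training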
void TMVA::MethodTMlpANN::Train( void )
{
   // available learning methods:
   //
   //    TMultiLayerPerceptron::kStochastic
   //    TMultiLayerPerceptron::kBatch
   //    TMultiLayerPerceptron::kSteepestDescent
   //    TMultiLayerPerceptron::kRibierePolak
   //    TMultiLayerPerceptron::kFletcherReeves
   //    TMultiLayerPerceptron::kBFGS
   if (!CheckSanity()) fLogger << kFATAL << "<Train> sanity check failed" << Endl;

   fLogger << kVERBOSE << "Option string: " << GetOptions() << Endl;

   // TMultiLayerPerceptron draws training and validation events from the
   // same tree, so use the complete training tree of the factory
   TTree* trainingTree = Data().GetTrainingTree();

   // TEventList-style cuts select which entries are used for training and
   // which for internal validation: the first (1 - fValidationFraction) of
   // the signal events and of the background events go into training, the
   // remainder into validation; the training tree is ordered with all
   // signal events first, followed by all background events
   TString trainList = "Entry$<";
   trainList += 1.0-fValidationFraction;
   trainList += "*";
   trainList += (Int_t)Data().GetNEvtSigTrain();
   trainList += " || (Entry$>";
   trainList += (Int_t)Data().GetNEvtSigTrain();
   trainList += " && Entry$<";
   trainList += (Int_t)(Data().GetNEvtSigTrain() + (1.0 - fValidationFraction)*Data().GetNEvtBkgdTrain());
   trainList += ")";
   TString testList  = TString("!(") + trainList + ")";
   
   fLogger << kINFO << "Requirement for training   events: \"" << trainList << "\"" << Endl;
   fLogger << kINFO << "Requirement for validation events: \"" << testList << "\"" << Endl;
   
   if (fMLP != 0) delete fMLP;
   fMLP = new TMultiLayerPerceptron( fMLPBuildOptions.Data(), 
                                     trainingTree,
                                     trainList,
                                      testList );

   // set the learning method (the enum type was renamed after ROOT 5.13/06)
#if ROOT_VERSION_CODE > ROOT_VERSION(5,13,06)
   TMultiLayerPerceptron::ELearningMethod learningMethod = TMultiLayerPerceptron::kStochastic; 
#else
   TMultiLayerPerceptron::LearningMethod  learningMethod = TMultiLayerPerceptron::kStochastic; 
#endif
   fLearningMethod.ToLower();
   if      (fLearningMethod == "stochastic"      ) learningMethod = TMultiLayerPerceptron::kStochastic;
   else if (fLearningMethod == "batch"           ) learningMethod = TMultiLayerPerceptron::kBatch;
   else if (fLearningMethod == "steepestdescent" ) learningMethod = TMultiLayerPerceptron::kSteepestDescent;
   else if (fLearningMethod == "ribierepolak"    ) learningMethod = TMultiLayerPerceptron::kRibierePolak;
   else if (fLearningMethod == "fletcherreeves"  ) learningMethod = TMultiLayerPerceptron::kFletcherReeves;
   else if (fLearningMethod == "bfgs"            ) learningMethod = TMultiLayerPerceptron::kBFGS;
   else {
      fLogger << kFATAL << "Unknown Learning Method: \"" << fLearningMethod << "\"" << Endl;
   }
   fMLP->SetLearningMethod( learningMethod );

   // train the network; "text,update=50" prints a text summary every 50 cycles
   fMLP->Train( fNcycles, "text,update=50" );
}
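
//______________________________________________________________________________
// write the network weights to the output stream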
void TMVA::MethodTMlpANN::WriteWeightsToStream( std::ostream& o ) const
{
   // TMultiLayerPerceptron cannot write its weights directly to a stream, so
   // dump them to a temporary file first and copy that file into the stream
   fMLP->DumpWeights( "weights/TMlp.nn.weights.temp" );
   
   std::ifstream inf( "weights/TMlp.nn.weights.temp" );
   
   o << inf.rdbuf();
   inf.close();
   
   
}
  
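//______________________________________________________________________________
// read the network weights back from the input stream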
void TMVA::MethodTMlpANN::ReadWeightsFromStream( std::istream& istr )
{
   // TMultiLayerPerceptron cannot load weights from a stream, so write the
   // stream contents to a temporary file and load the weights from there
   std::ofstream fout( "weights/TMlp.nn.weights.temp" );
   fout << istr.rdbuf();
   fout.close();

   // load the new weights into the network
   fLogger << kINFO << "Load TMLP weights" << Endl;
   fMLP->LoadWeights( "weights/TMlp.nn.weights.temp" );
   
   
}
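
//______________________________________________________________________________
// write C++ code for the classifier response into the standalone class file
// (not implemented for this method)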
void TMVA::MethodTMlpANN::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   
   fout << "   // not implemented for class: \"" << className << "\"" << std::endl;
   fout << "};" << std::endl;
}
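
//______________________________________________________________________________
// get help message text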
void TMVA::MethodTMlpANN::GetHelpMessage() const
{
   fLogger << Endl;
   fLogger << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   fLogger << Endl;
   fLogger << "<None>" << Endl;
   fLogger << Endl;
   fLogger << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
   fLogger << Endl;
   fLogger << "<None>" << Endl;
   fLogger << Endl;
   fLogger << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
   fLogger << Endl;
   fLogger << "<None>" << Endl;
}