#include <algorithm>
#include <iomanip>
#include <vector>
#include <iostream>
#include "Riostream.h"
#include "TRandom3.h"
#include "TMath.h"
#include "TObjString.h"
#include "TH1F.h"
#include "TGraph.h"
#include "TSpline.h"
#include "TDirectory.h"
#include "TTreeFormula.h"
#include "TMVA/MethodCategory.h"
#include "TMVA/Tools.h"
#include "TMVA/ClassifierFactory.h"
#include "TMVA/Timer.h"
#include "TMVA/Types.h"
#include "TMVA/PDF.h"
#include "TMVA/Config.h"
#include "TMVA/Ranking.h"
#include "TMVA/VariableInfo.h"
#include "TMVA/DataSetManager.h"
#include "TMVA/VariableRearrangeTransform.h"
// Registration macro for the "Category" method.
// NOTE(review): presumably makes the method creatable by name through
// ClassifierFactory (see the Create() calls below) — confirm against the macro definition.
REGISTER_METHOD(Category)
// ROOT class-implementation macro for TMVA::MethodCategory.
ClassImp(TMVA::MethodCategory)
/// Constructor taking a job name, method title, data set info and option string.
///
/// All arguments are forwarded to MethodCompositeBase together with the
/// Types::kCategory method type. The categorisation tree and the data set
/// manager pointers start out null.
TMVA::MethodCategory::MethodCategory( const TString& jobName,
                                      const TString& methodTitle,
                                      DataSetInfo& theData,
                                      const TString& theOption,
                                      TDirectory* theTargetDir )
   : TMVA::MethodCompositeBase( jobName, Types::kCategory, methodTitle, theData, theOption, theTargetDir ),
     fCatTree( nullptr ),
     fDataSetManager( nullptr )
{
}
/// Constructor taking a data set info and a weight file name.
///
/// Arguments are forwarded to MethodCompositeBase together with the
/// Types::kCategory method type. The categorisation tree and the data set
/// manager pointers start out null.
TMVA::MethodCategory::MethodCategory( DataSetInfo& dsi,
                                      const TString& theWeightFile,
                                      TDirectory* theTargetDir )
   : TMVA::MethodCompositeBase( Types::kCategory, dsi, theWeightFile, theTargetDir ),
     fCatTree( nullptr ),
     fDataSetManager( nullptr )
{
}
/// Destructor: releases every category formula, then the categorisation tree.
TMVA::MethodCategory::~MethodCategory( void )
{
   // The formulas are deleted before the tree they were built on.
   for (TTreeFormula* formula : fCatFormulas) {
      delete formula;
   }
   delete fCatTree;
}
/// Report whether the requested analysis type is supported.
///
/// @return kTRUE only if every registered sub-method reports support for
///         (type, numberClasses, numberTargets); kTRUE as well when no
///         sub-method is registered.
Bool_t TMVA::MethodCategory::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets )
{
   for (IMethod* subMethod : fMethods) {
      const Bool_t supported = subMethod->HasAnalysisType(type, numberClasses, numberTargets);
      if (!supported) return kFALSE;
   }
   return kTRUE;
}
/// Option declaration hook; MethodCategory declares no options of its own here.
void TMVA::MethodCategory::DeclareOptions()
{
}
/// Create and register a sub-classifier for one category.
///
/// @param theCut        cut that defines the category
/// @param theVariables  ':'-separated list of variables used by the sub-classifier
/// @param theMethod     type of the sub-classifier
/// @param theTitle      title of the sub-classifier
/// @param theOptions    option string handed to the sub-classifier
/// @return the newly created method, or null if the created object is not a MethodBase
///
/// Besides creating the method, this records the cut and variable list and
/// adds a 'C'-type spectator to the primary data set info whose expression is
/// the category cut; its index is stored in fCategorySpecIdx.
TMVA::IMethod* TMVA::MethodCategory::AddMethod( const TCut& theCut,
                                                const TString& theVariables,
                                                Types::EMVA theMethod ,
                                                const TString& theTitle,
                                                const TString& theOptions )
{
   std::string methodTypeName = std::string(Types::Instance().GetMethodName(theMethod));

   Log() << kINFO << "Adding sub-classifier: " << methodTypeName << "::" << theTitle << Endl;

   // Dedicated data set info restricted to this category's cut and variables.
   DataSetInfo& categoryDSI = CreateCategoryDSI(theCut, theVariables, theTitle);

   IMethod* created = ClassifierFactory::Instance().Create(methodTypeName,GetJobName(),theTitle,categoryDSI,theOptions);

   MethodBase* method = dynamic_cast<MethodBase*>(created);
   if (!method) return nullptr;

   method->SetAnalysisType( fAnalysisType );
   method->SetupMethod();
   method->ParseOptions();
   method->ProcessSetup();

   // Reuse the per-method-type directory if it exists, otherwise create it.
   const TString dirName(Form("Method_%s",method->GetMethodTypeName().Data()));
   TDirectory* methodDir = BaseDir()->GetDirectory(dirName);
   if (methodDir == 0) {
      methodDir = BaseDir()->mkdir(dirName,Form("Directory for all %s methods", method->GetMethodTypeName().Data()));
   }
   method->SetMethodBaseDir( methodDir );

   method->CheckSetup();
   method->DisableWriting( kTRUE );

   // Book-keeping: method, its cut and its variable list share the same index.
   fMethods.push_back(method);
   fCategoryCuts.push_back(theCut);
   fVars.push_back(theVariables);

   // The new spectator will be appended, so its index is the current spectator count.
   DataSetInfo& primaryDSI = DataInfo();
   const UInt_t categorySpecIdx = primaryDSI.GetSpectatorInfos().size();
   fCategorySpecIdx.push_back(categorySpecIdx);

   const TString specExpression = Form("%s_cat%i:=%s", GetName(),static_cast<int>(fMethods.size()),theCut.GetTitle());
   const TString specTitle      = Form("%s:%s",GetName(),method->GetName());
   primaryDSI.AddSpectator( specExpression, specTitle, "pass", 0, 0, 'C' );

   return method;
}
/// Build the data set info used by one category's sub-classifier.
///
/// A new DataSetInfo named "<theTitle>_dsi" is allocated and handed to
/// fDataSetManager. It receives
///  - all targets and spectators of the primary data set info,
///  - the variables listed in theVariables (':'-separated), looked up by label
///    first among the primary variables and then among the primary spectators;
///    if theVariables is empty, all primary variables are used,
///  - every class of the primary info with its cut, extended by theCut, and
///    its weight expression,
///  - the split options, root directory and normalisation of the primary info.
///
/// The index map of the selected variables is appended to fVarMaps. Spectator
/// matches are numbered after the variables, i.e. their index is offset by the
/// number of primary variables.
///
/// @return reference to the newly created data set info
TMVA::DataSetInfo& TMVA::MethodCategory::CreateCategoryDSI(const TCut& theCut,
                                                           const TString& theVariables,
                                                           const TString& theTitle)
{
   TString dsiName=theTitle+"_dsi";
   DataSetInfo& primaryDSI = DataInfo();
   DataSetInfo* categoryDSI = new DataSetInfo(dsiName);

   fDataSetManager->AddDataSetInfo(*categoryDSI);

   // Targets and spectators are taken over unchanged.
   for (auto& targetInfo : primaryDSI.GetTargetInfos()) {
      categoryDSI->AddTarget(targetInfo);
   }
   for (auto& spectatorInfo : primaryDSI.GetSpectatorInfos()) {
      categoryDSI->AddSpectator(spectatorInfo);
   }

   std::vector<TString> requested = gTools().SplitString(theVariables,':' );
   std::vector<UInt_t> varMap;

   for (std::vector<TString>::iterator reqIt = requested.begin(); reqIt != requested.end(); ++reqIt) {
      Bool_t matched = kFALSE;
      UInt_t position = 0;

      // First pass: regular input variables.
      for (auto& varInfo : primaryDSI.GetVariableInfos()) {
         if (*reqIt == varInfo.GetLabel()) {
            categoryDSI->AddVariable(varInfo);
            varMap.push_back(position);
            matched = kTRUE;
         }
         ++position;
      }

      // Second pass: spectators; position keeps counting past the variables.
      for (auto& specInfo : primaryDSI.GetSpectatorInfos()) {
         if (*reqIt == specInfo.GetLabel()) {
            categoryDSI->AddVariable(specInfo);
            varMap.push_back(position);
            matched = kTRUE;
         }
         ++position;
      }

      if (!matched) {
         Log() << kFATAL <<"The variable " << reqIt->Data() << " was not found and could not be added " << Endl;
      }
   }

   // An empty variable list means: use every variable of the primary data set info.
   if (theVariables=="") {
      const UInt_t nPrimaryVars = primaryDSI.GetVariableInfos().size();
      for (UInt_t iVar = 0; iVar < nPrimaryVars; ++iVar) {
         categoryDSI->AddVariable(primaryDSI.GetVariableInfos()[iVar]);
         varMap.push_back(iVar);
      }
   }

   fVarMaps.push_back(varMap);

   // Replicate the classes, adding the category cut on top of each class cut.
   const UInt_t nClasses = primaryDSI.GetNClasses();
   for (UInt_t iClass = 0; iClass < nClasses; ++iClass) {
      TString className = primaryDSI.GetClassInfo(iClass)->GetName();
      categoryDSI->AddClass(className);
      categoryDSI->SetCut(primaryDSI.GetCut(iClass),className);
      categoryDSI->AddCut(theCut,className);
      categoryDSI->SetWeightExpression(primaryDSI.GetWeightExpression(iClass),className);
   }

   categoryDSI->SetSplitOptions(primaryDSI.GetSplitOptions());
   categoryDSI->SetRootDir(primaryDSI.GetRootDir());
   TString norm(primaryDSI.GetNormalization().Data());
   categoryDSI->SetNormalization(norm);

   return *categoryDSI;
}
/// Initialisation hook; intentionally empty for MethodCategory.
void TMVA::MethodCategory::Init()
{
}
void TMVA::MethodCategory::InitCircularTree(const DataSetInfo& dsi)
{
delete fCatTree;
std::vector<VariableInfo>::const_iterator viIt;
const std::vector<VariableInfo>& vars = dsi.GetVariableInfos();
const std::vector<VariableInfo>& specs = dsi.GetSpectatorInfos();
Bool_t hasAllExternalLinks = kTRUE;
for (viIt = vars.begin(); viIt != vars.end(); ++viIt)
if( viIt->GetExternalLink() == 0 ) {
hasAllExternalLinks = kFALSE;
break;
}
for (viIt = specs.begin(); viIt != specs.end(); ++viIt)
if( viIt->GetExternalLink() == 0 ) {
hasAllExternalLinks = kFALSE;
break;
}
if(!hasAllExternalLinks) return;
{
TDirectory::TContext ctxt(nullptr);
fCatTree = new TTree(Form("Circ%s",GetMethodName().Data()),"Circlar Tree for categorization");
fCatTree->SetCircular(1);
}
for (viIt = vars.begin(); viIt != vars.end(); ++viIt) {
const VariableInfo& vi = *viIt;
fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
}
for (viIt = specs.begin(); viIt != specs.end(); ++viIt) {
const VariableInfo& vi = *viIt;
if(vi.GetVarType()=='C') continue;
fCatTree->Branch(vi.GetExpression(),(Float_t*)vi.GetExternalLink(), TString(vi.GetExpression())+TString("/F"));
}
for(UInt_t cat=0; cat!=fCategoryCuts.size(); ++cat) {
fCatFormulas.push_back(new TTreeFormula(Form("Category_%i",cat), fCategoryCuts[cat].GetTitle(), fCatTree));
}
}
void TMVA::MethodCategory::Train()
{
const Int_t MinNoTrainingEvents = 10;
Types::EAnalysisType analysisType = GetAnalysisType();
Log() << kINFO << "Train all sub-classifiers for "
<< (analysisType == Types::kRegression ? "Regression" : "Classification") << " ..." << Endl;
if (fMethods.empty()) {
Log() << kINFO << "...nothing found to train" << Endl;
return;
}
std::vector<IMethod*>::iterator itrMethod;
for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); ++itrMethod ) {
MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
if(!mva) continue;
mva->SetAnalysisType( analysisType );
if (!mva->HasAnalysisType( analysisType,
mva->DataInfo().GetNClasses(),
mva->DataInfo().GetNTargets() ) ) {
Log() << kWARNING << "Method " << mva->GetMethodTypeName() << " is not capable of handling " ;
if (analysisType == Types::kRegression)
Log() << "regression with " << mva->DataInfo().GetNTargets() << " targets." << Endl;
else
Log() << "classification with " << mva->DataInfo().GetNClasses() << " classes." << Endl;
itrMethod = fMethods.erase( itrMethod );
continue;
}
if (mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
Log() << kINFO << "Train method: " << mva->GetMethodName() << " for "
<< (analysisType == Types::kRegression ? "Regression" : "Classification") << Endl;
mva->TrainMethod();
Log() << kINFO << "Training finished" << Endl;
} else {
Log() << kWARNING << "Method " << mva->GetMethodName()
<< " not trained (training tree has less entries ["
<< mva->Data()->GetNTrainingEvents()
<< "] than required [" << MinNoTrainingEvents << "]" << Endl;
Log() << kERROR << " w/o training/test events for that category, I better stop here and let you fix " << Endl;
Log() << kFATAL << "that one first, otherwise things get too messy later ... " << Endl;
}
}
if (analysisType != Types::kRegression) {
Log() << kINFO << "Begin ranking of input variables..." << Endl;
for (itrMethod = fMethods.begin(); itrMethod != fMethods.end(); itrMethod++) {
MethodBase* mva = dynamic_cast<MethodBase*>(*itrMethod);
if (mva && mva->Data()->GetNTrainingEvents() >= MinNoTrainingEvents) {
const Ranking* ranking = (*itrMethod)->CreateRanking();
if (ranking != 0)
ranking->Print();
else
Log() << kINFO << "No variable ranking supplied by classifier: "
<< dynamic_cast<MethodBase*>(*itrMethod)->GetMethodName() << Endl;
}
}
}
}
void TMVA::MethodCategory::AddWeightsXMLTo( void* parent ) const
{
void* wght = gTools().AddChild(parent, "Weights");
gTools().AddAttr( wght, "NSubMethods", fMethods.size() );
void* submethod(0);
std::vector<IMethod*>::iterator itrMethod;
for (UInt_t i=0; i<fMethods.size(); i++) {
MethodBase* method = dynamic_cast<MethodBase*>(fMethods[i]);
submethod = gTools().AddChild(wght, "SubMethod");
gTools().AddAttr(submethod, "Index", i);
gTools().AddAttr(submethod, "Method", method->GetMethodTypeName() + "::" + method->GetMethodName());
gTools().AddAttr(submethod, "Cut", fCategoryCuts[i]);
gTools().AddAttr(submethod, "Variables", fVars[i]);
method->WriteStateToXML( submethod );
}
}
void TMVA::MethodCategory::ReadWeightsFromXML( void* wghtnode )
{
UInt_t nSubMethods;
TString fullMethodName;
TString methodType;
TString methodTitle;
TString theCutString;
TString theVariables;
Int_t titleLength;
gTools().ReadAttr( wghtnode, "NSubMethods", nSubMethods );
void* subMethodNode = gTools().GetChild(wghtnode);
Log() << kINFO << "Recreating sub-classifiers from XML-file " << Endl;
for (UInt_t i=0; i<nSubMethods; i++) {
gTools().ReadAttr( subMethodNode, "Method", fullMethodName );
gTools().ReadAttr( subMethodNode, "Cut", theCutString );
gTools().ReadAttr( subMethodNode, "Variables", theVariables );
methodType = fullMethodName(0,fullMethodName.Index("::"));
if (methodType.Contains(" ")) methodType = methodType(methodType.Last(' ')+1,methodType.Length());
titleLength = fullMethodName.Length()-fullMethodName.Index("::")-2;
methodTitle = fullMethodName(fullMethodName.Index("::")+2,titleLength);
DataSetInfo& dsi = CreateCategoryDSI(TCut(theCutString), theVariables, methodTitle);
MethodBase* method = dynamic_cast<MethodBase*>( ClassifierFactory::Instance().Create( methodType.Data(),
dsi, "none" ) );
if(method==0)
Log() << kFATAL << "Could not create sub-method " << method << " from XML." << Endl;
method->SetupMethod();
method->ReadStateFromXML(subMethodNode);
fMethods.push_back(method);
fCategoryCuts.push_back(TCut(theCutString));
fVars.push_back(theVariables);
DataSetInfo& primaryDSI = DataInfo();
UInt_t spectatorIdx = 10000;
UInt_t counter=0;
std::vector<VariableInfo>& spectators=primaryDSI.GetSpectatorInfos();
std::vector<VariableInfo>::iterator itrVarInfo;
TString specName= Form("%s_cat%i", GetName(),(int)fCategorySpecIdx.size()+1);
for (itrVarInfo = spectators.begin(); itrVarInfo != spectators.end(); ++itrVarInfo, ++counter) {
if((specName==itrVarInfo->GetLabel()) || (specName==itrVarInfo->GetExpression())) {
spectatorIdx=counter;
fCategorySpecIdx.push_back(spectatorIdx);
break;
}
}
subMethodNode = gTools().GetNextChild(subMethodNode);
}
InitCircularTree(DataInfo());
}
/// Option processing hook; intentionally empty for MethodCategory.
void TMVA::MethodCategory::ProcessOptions()
{
}
/// Print a short description of the Category method through the logger.
void TMVA::MethodCategory::GetHelpMessage() const
{
   // Body of the description, one logger line per entry.
   static const char* const kDescription[] = {
      "This method allows to define different categories of events. The",
      "categories are defined via cuts on the variables. For each",
      "category, a different classifier and set of variables can be",
      "specified. The categories which are defined for this method must",
      "be disjoint."
   };

   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   for (const char* line : kDescription) {
      Log() << line << Endl;
   }
}
/// The Category method provides no variable ranking of its own.
///
/// @return always a null pointer
const TMVA::Ranking* TMVA::MethodCategory::CreateRanking()
{
   return nullptr;
}
/// Decide whether the current event falls into the category of a sub-method.
///
/// When the circular tree exists, the category formula of @p methodIdx is
/// evaluated; otherwise the category spectator of @p ev is read. In both cases
/// the category is considered passed if the value exceeds 0.5.
///
/// @param ev         event whose spectator is inspected (only used without a tree)
/// @param methodIdx  index of the sub-method / category
/// @return kTRUE if the event lies inside the category
Bool_t TMVA::MethodCategory::PassesCut( const Event* ev, UInt_t methodIdx )
{
   if (!fCatTree) {
      // No tree available: use the pre-computed category spectator.
      if (methodIdx>=fCategorySpecIdx.size()) {
         Log() << kFATAL << "Unknown method index " << methodIdx << " maximum allowed index="
               << fCategorySpecIdx.size() << Endl;
      }
      const UInt_t specIdx = fCategorySpecIdx[methodIdx];
      const Float_t categoryFlag = ev->GetSpectator(specIdx);
      return (categoryFlag>0.5);
   }

   // Tree available: evaluate the category formula.
   if (methodIdx>=fCatFormulas.size()) {
      Log() << kFATAL << "Large method index " << methodIdx << ", number of category formulas = "
            << fCatFormulas.size() << Endl;
   }
   TTreeFormula* categoryFormula = fCatFormulas[methodIdx];
   return categoryFormula->EvalInstance(0) > 0.5;
}
/// Return the MVA value of the sub-classifier whose category contains the
/// current event.
///
/// @param err       forwarded to the selected sub-classifier
/// @param errUpper  forwarded to the selected sub-classifier
/// @return the sub-classifier's MVA value; 0 if no sub-method is registered,
///         if the event lies in no category (WARNING), if it lies in more than
///         one category (FATAL), or if the selected sub-method is not a
///         MethodBase (FATAL)
Double_t TMVA::MethodCategory::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   if (fMethods.empty()) return 0;

   UInt_t methodToUse = 0;
   const Event* ev = GetEvent();

   // Count the categories the event falls into, remembering the last match.
   Int_t suitableCutsN = 0;
   for (UInt_t i=0; i<fMethods.size(); ++i) {
      if (PassesCut(ev, i)) {
         ++suitableCutsN;
         methodToUse=i;
      }
   }

   if (suitableCutsN == 0) {
      Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
      return 0;
   }

   if (suitableCutsN > 1) {
      Log() << kFATAL << "The defined categories are not disjoint." << Endl;
      return 0;
   }

   // Check the cast before use, as GetRegressionValues() does.
   MethodBase* meth = dynamic_cast<MethodBase*>(fMethods[methodToUse]);
   if (!meth) {
      Log() << kFATAL << "method not found in Category method" << Endl;
      return 0;
   }

   // The sub-classifier sees the event through its own variable arrangement,
   // which is reset afterwards.
   ev->SetVariableArrangement(&fVarMaps[methodToUse]);
   Double_t mvaValue = meth->GetMvaValue(ev,err,errUpper);
   ev->SetVariableArrangement(0);

   return mvaValue;
}
const std::vector<Float_t> &TMVA::MethodCategory::GetRegressionValues()
{
if (fMethods.empty()) return MethodBase::GetRegressionValues();
UInt_t methodToUse = 0;
const Event* ev = GetEvent();
Int_t suitableCutsN = 0;
for (UInt_t i=0; i<fMethods.size(); ++i) {
if (PassesCut(ev, i)) {
++suitableCutsN;
methodToUse=i;
}
}
if (suitableCutsN == 0) {
Log() << kWARNING << "Event does not lie within the cut of any sub-classifier." << Endl;
return MethodBase::GetRegressionValues();
}
if (suitableCutsN > 1) {
Log() << kFATAL << "The defined categories are not disjoint." << Endl;
return MethodBase::GetRegressionValues();
}
MethodBase* meth = dynamic_cast<MethodBase*>(fMethods[methodToUse]);
if (!meth){
Log() << kFATAL << "method not found in Category Regression method" << Endl;
return MethodBase::GetRegressionValues();
}
return meth->GetRegressionValues(ev);
}