#include <algorithm>
#include "Riostream.h"
#include "TRandom3.h"
#include "TMath.h"
#include "TObjString.h"
#include "TMVA/ClassifierFactory.h"
#include "TMVA/MethodDT.h"
#include "TMVA/Tools.h"
#include "TMVA/Timer.h"
#include "TMVA/Ranking.h"
#include "TMVA/SdivSqrtSplusB.h"
#include "TMVA/BinarySearchTree.h"
#include "TMVA/SeparationBase.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/MethodBoost.h"
#include "TMVA/CCPruner.h"
using std::vector;
REGISTER_METHOD(DT)
ClassImp(TMVA::MethodDT)
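// standard constructor: the decision tree itself is only built later, in
// Train(); here all data members are merely zero-initialised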
TMVA::MethodDT::MethodDT( const TString& jobName,
                          const TString& methodTitle,
                          DataSetInfo& theData,
                          const TString& theOption,
                          TDirectory* theTargetDir ) :
   TMVA::MethodBase( jobName, Types::kDT, methodTitle, theData, theOption, theTargetDir )
   , fTree(0)
   , fNodeMinEvents(0)
   , fNCuts(0)
   , fUseYesNoLeaf(kFALSE)
   , fNodePurityLimit(0)
   , fNNodesMax(0)
   , fMaxDepth(0)
   , fErrorFraction(0)
   , fPruneStrength(0)
   , fPruneMethod(DecisionTree::kNoPruning)
   , fAutomatic(kFALSE)
   , fRandomisedTrees(kFALSE)
   , fUseNvars(0)
   , fPruneBeforeBoost(kFALSE)
   , fDeltaPruneStrength(0)
{
}
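// constructor used when booking the method from an existing weight file
// (no training is performed in this case)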
TMVA::MethodDT::MethodDT( DataSetInfo& dsi,
                          const TString& theWeightFile,
                          TDirectory* theTargetDir ) :
   TMVA::MethodBase( Types::kDT, dsi, theWeightFile, theTargetDir )
   , fTree(0)
   , fNodeMinEvents(0)
   , fNCuts(0)
   , fUseYesNoLeaf(kFALSE)
   , fNodePurityLimit(0)
   , fNNodesMax(0)
   , fMaxDepth(0)
   , fErrorFraction(0)
   , fPruneStrength(0)
   , fPruneMethod(DecisionTree::kNoPruning)
   , fAutomatic(kFALSE)
   , fRandomisedTrees(kFALSE)
   , fUseNvars(0)
   , fPruneBeforeBoost(kFALSE)
   , fDeltaPruneStrength(0)
{
}
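// the DT method currently handles only two-class classification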
Bool_t TMVA::MethodDT::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t )
{
if( type == Types::kClassification && numberClasses == 2 ) return kTRUE;
return kFALSE;
}
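// declare the options (keywords) that can be set in the option string,
// among them: UseRandomisedTrees, UseNvars, UseYesNoLeaf, NodePurityLimit,
// PruneBeforeBoost, SeparationType, nEventsMin, nCuts, PruneStrength,
// PruneMethod, NNodesMax and MaxDepth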
void TMVA::MethodDT::DeclareOptions()
{
DeclareOptionRef(fRandomisedTrees,"UseRandomisedTrees","Pick a random subset of the variables at each node split (to be used together with *bagging*)");
DeclareOptionRef(fUseNvars,"UseNvars","Number of variables used if randomised Tree option is chosen");
DeclareOptionRef(fUseYesNoLeaf=kTRUE, "UseYesNoLeaf",
"Use Sig or Bkg node type or the ratio S/B as classification in the leaf node");
DeclareOptionRef(fNodePurityLimit=0.5, "NodePurityLimit", "In boosting/pruning, nodes with purity > NodePurityLimit are signal; background otherwise.");
DeclareOptionRef(fPruneBeforeBoost=kFALSE, "PruneBeforeBoost",
"Whether to perform the prune process right after the training or after the boosting");
DeclareOptionRef(fSepTypeS="GiniIndex", "SeparationType", "Separation criterion for node splitting");
AddPreDefVal(TString("MisClassificationError"));
AddPreDefVal(TString("GiniIndex"));
AddPreDefVal(TString("CrossEntropy"));
AddPreDefVal(TString("SDivSqrtSPlusB"));
DeclareOptionRef(fNodeMinEvents, "nEventsMin", "Minimum number of events in a leaf node (default: max(20, N_train/(Nvar^2)/10) ) ");
DeclareOptionRef(fNCuts, "nCuts", "Number of steps during node cut optimisation");
DeclareOptionRef(fPruneStrength, "PruneStrength", "Pruning strength (negative value == automatic adjustment)");
DeclareOptionRef(fPruneMethodS, "PruneMethod", "Pruning method: NoPruning (switched off), ExpectedError or CostComplexity");
AddPreDefVal(TString("NoPruning"));
AddPreDefVal(TString("ExpectedError"));
AddPreDefVal(TString("CostComplexity"));
DeclareOptionRef(fNNodesMax=100000,"NNodesMax","Max number of nodes in tree");
if (DoRegression()) {
   DeclareOptionRef(fMaxDepth=50,"MaxDepth","Max depth of the decision tree allowed");
}
else {
   DeclareOptionRef(fMaxDepth=3,"MaxDepth","Max depth of the decision tree allowed");
}
}
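// For reference, a minimal booking sketch using some of the options declared
// above. This is only a sketch: the exact TMVA::Factory interface depends on
// the ROOT/TMVA version in use, and "factory" and "outputFile" are
// hypothetical names, not part of this class:
//
//    TFile* outputFile = TFile::Open( "TMVA.root", "RECREATE" );
//    TMVA::Factory factory( "TMVAClassification", outputFile );
//    factory.BookMethod( TMVA::Types::kDT, "DT",
//                        "SeparationType=GiniIndex:nCuts=20:"
//                        "PruneMethod=CostComplexity:PruneStrength=-1" );

// decode the option string and configure the method accordingly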
void TMVA::MethodDT::ProcessOptions()
{
fSepTypeS.ToLower();
if (fSepTypeS == "misclassificationerror") fSepType = new MisClassificationError();
else if (fSepTypeS == "giniindex") fSepType = new GiniIndex();
else if (fSepTypeS == "crossentropy") fSepType = new CrossEntropy();
else if (fSepTypeS == "sdivsqrtsplusb") fSepType = new SdivSqrtSplusB();
else {
Log() << kINFO << GetOptions() << Endl;
Log() << kFATAL << "<ProcessOptions> unknown Separation Index option called" << Endl;
}
fPruneMethodS.ToLower();
if (fPruneMethodS == "expectederror" ) fPruneMethod = DecisionTree::kExpectedErrorPruning;
else if (fPruneMethodS == "costcomplexity" ) fPruneMethod = DecisionTree::kCostComplexityPruning;
else if (fPruneMethodS == "nopruning" ) fPruneMethod = DecisionTree::kNoPruning;
else {
Log() << kINFO << GetOptions() << Endl;
Log() << kFATAL << "<ProcessOptions> unknown PruneMethod option called" << Endl;
}
if (fPruneStrength < 0) fAutomatic = kTRUE;
else fAutomatic = kFALSE;
if (fAutomatic && fPruneMethod == DecisionTree::kExpectedErrorPruning){
   Log() << kFATAL
         << "Sorry, automatic pruning strength determination is not implemented yet for ExpectedErrorPruning" << Endl;
}
if (this->Data()->HasNegativeEventWeights()){
   Log() << kINFO << " You are using a Monte Carlo sample that also contains negative event weights. "
         << "That should in principle be fine, as long as on average you end up with "
         << "something positive. For this you have to make sure that the minimal number "
         << "of (unweighted) events required in a tree node (currently nEventsMin="
         << fNodeMinEvents << ", which you can set via the option string when booking the "
         << "classifier) is large enough to allow for reasonable averaging! "
         << "If this does not help, you may want to try the option NoNegWeightsInTraining, "
         << "which ignores events with negative weights in the training." << Endl
         << Endl << "Note: you will get a WARNING message during the training if this should ever happen." << Endl;
}
if (fRandomisedTrees){
   Log() << kINFO << " Randomised trees should use *bagging* as the *boost* method. Did you set this in MethodBoost? Here I can only enforce *no pruning*." << Endl;
   fPruneMethod = DecisionTree::kNoPruning;
}
}
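// common initialisation with sensible default values for this method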
void TMVA::MethodDT::Init( void )
{
fNodeMinEvents = TMath::Max( 20, int( Data()->GetNTrainingEvents() / (10*GetNvar()*GetNvar())) );
fNCuts = 20;
fPruneMethod = DecisionTree::kNoPruning;
fPruneStrength = 5;
fDeltaPruneStrength=0.1;
fRandomisedTrees= kFALSE;
fUseNvars = GetNvar();
SetSignalReferenceCut( 0 );
if (fAnalysisType == Types::kClassification || fAnalysisType == Types::kMulticlass ) {
   fMaxDepth = 3;
}
else {
   fMaxDepth = 50;
}
}
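// destructor: delete the decision tree owned by this method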
TMVA::MethodDT::~MethodDT( void )
{
delete fTree;
}
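// training: grow a single decision tree from the training events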
void TMVA::MethodDT::Train( void )
{
TMVA::DecisionTreeNode::fgIsTraining=true;
fTree = new DecisionTree( fSepType, fNodeMinEvents, fNCuts, 0,
fRandomisedTrees, fUseNvars, fNNodesMax, fMaxDepth,0 );
if (fRandomisedTrees) Log() << kWARNING << " Randomised trees do not work properly yet in this framework,"
                            << " as there is no mechanism to give each tree its own random seed; all trees"
                            << " would therefore come out identical, which defeats their purpose." << Endl;
fTree->SetAnalysisType( GetAnalysisType() );
fTree->BuildTree(GetEventCollection(Types::kTraining));
TMVA::DecisionTreeNode::fgIsTraining=false;
}
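// hook called by MethodBoost at the various stages of the boosting process:
// books the monitoring histograms, triggers the pruning at the configured
// stage (before or after the boosting, see PruneBeforeBoost) and prints a
// short summary at the end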
Bool_t TMVA::MethodDT::MonitorBoost( MethodBoost* booster )
{
Int_t methodIndex = booster->GetMethodIndex();
if (booster->GetBoostStage() == Types::kBoostProcBegin)
{
booster->AddMonitoringHist(new TH1I("NodesBeforePruning","nodes before pruning",booster->GetBoostNum(),0,booster->GetBoostNum()));
booster->AddMonitoringHist(new TH1I("NodesAfterPruning","nodes after pruning",booster->GetBoostNum(),0,booster->GetBoostNum()));
booster->AddMonitoringHist(new TH1D("PruneStrength","prune strength",booster->GetBoostNum(),0,booster->GetBoostNum()));
}
if (booster->GetBoostStage() == Types::kBeforeTraining)
{
if (methodIndex == 0)
{
booster->GetMonitoringHist(2)->SetXTitle("#tree");
booster->GetMonitoringHist(2)->SetYTitle("PruneStrength");
if (fAutomatic)
{
Data()->DivideTrainingSet(2);
Data()->MoveTrainingBlock(1,Types::kValidation,kTRUE);
}
}
}
else if (booster->GetBoostStage() == Types::kBeforeBoosting)
   booster->GetMonitoringHist(0)->SetBinContent(methodIndex+1,fTree->GetNNodes());
if (booster->GetBoostStage() == ((fPruneBeforeBoost)?Types::kBeforeBoosting:Types::kBoostValidation)
&& !(fPruneMethod == DecisionTree::kNoPruning)) {
if (methodIndex==0 && fPruneBeforeBoost == kFALSE)
Log() << kINFO << "Pruning "<< booster->GetBoostNum() << " Decision Trees ... patience please" << Endl;
if (fAutomatic && methodIndex > 0) {
MethodDT* mdt = dynamic_cast<MethodDT*>(booster->GetPreviousMethod());
if(mdt)
fPruneStrength = mdt->GetPruneStrength();
}
booster->GetMonitoringHist(0)->SetBinContent(methodIndex+1,fTree->GetNNodes());
booster->GetMonitoringHist(2)->SetBinContent(methodIndex+1,PruneTree(methodIndex));
booster->GetMonitoringHist(1)->SetBinContent(methodIndex+1,fTree->GetNNodes());
}
else if (booster->GetBoostStage() != Types::kBoostProcEnd)
return kFALSE;
if (booster->GetBoostStage() == Types::kBoostProcEnd)
{
if (fPruneMethod == DecisionTree::kNoPruning) {
Log() << kINFO << "<Train> average number of nodes (w/o pruning) : "
<< booster->GetMonitoringHist(0)->GetMean() << Endl;
}
else
{
Log() << kINFO << "<Train> average number of nodes before/after pruning : "
<< booster->GetMonitoringHist(0)->GetMean() << " / "
<< booster->GetMonitoringHist(1)->GetMean()
<< Endl;
}
}
return kTRUE;
}
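// prune the decision tree, either with an automatically optimised pruning
// strength (implemented for cost-complexity pruning only) or with the
// user-supplied one; returns the pruning strength that was actually used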
Double_t TMVA::MethodDT::PruneTree(const Int_t methodIndex)
{
if (fAutomatic && fPruneMethod == DecisionTree::kCostComplexityPruning) {
CCPruner* pruneTool = new CCPruner(fTree, this->Data() , fSepType);
pruneTool->Optimize();
std::vector<DecisionTreeNode*> nodes = pruneTool->GetOptimalPruneSequence();
fPruneStrength = pruneTool->GetOptimalPruneStrength();
for(UInt_t i = 0; i < nodes.size(); i++)
fTree->PruneNode(nodes[i]);
delete pruneTool;
}
else if (fAutomatic && fPruneMethod != DecisionTree::kCostComplexityPruning){
   // automatic pruning-strength determination is implemented only for
   // cost-complexity pruning (enforced in ProcessOptions), so there is
   // nothing to be done here
   (void) methodIndex; // keep the parameter formally used
}
else {
fTree->SetPruneStrength(fPruneStrength);
fTree->PruneTree();
}
return fPruneStrength;
}
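// return the weighted fraction of correctly classified events
// in the validation sample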
Double_t TMVA::MethodDT::TestTreeQuality( DecisionTree *dt )
{
Data()->SetCurrentType(Types::kValidation);
Double_t SumCorrect=0,SumWrong=0;
for (Long64_t ievt=0; ievt<Data()->GetNEvents(); ievt++)
{
Event * ev = Data()->GetEvent(ievt);
if ((dt->CheckEvent(*ev) > dt->GetNodePurityLimit() ) == DataInfo().IsSignal(ev)) SumCorrect+=ev->GetWeight();
else SumWrong+=ev->GetWeight();
}
Data()->SetCurrentType(Types::kTraining);
return SumCorrect / (SumCorrect + SumWrong);
}
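// write the tree to the XML weight file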
void TMVA::MethodDT::AddWeightsXMLTo( void* parent ) const
{
fTree->AddXMLTo(parent);
}
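// read the tree back from the XML weight file, replacing any existing one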
void TMVA::MethodDT::ReadWeightsFromXML( void* wghtnode)
{
if(fTree)
delete fTree;
fTree = new DecisionTree();
fTree->ReadXML(wghtnode,GetTrainingTMVAVersionCode());
}
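// read the tree from a plain-text weight stream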
void TMVA::MethodDT::ReadWeightsFromStream( istream& istr )
{
delete fTree;
fTree = new DecisionTree();
fTree->Read(istr);
}
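// return the MVA value of the current event: the response of the tree,
// i.e. the leaf type or its purity, depending on the UseYesNoLeaf option;
// no error estimate is provided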
Double_t TMVA::MethodDT::GetMvaValue( Double_t* err, Double_t* errUpper )
{
NoErrorCalc(err, errUpper);
return fTree->CheckEvent(*GetEvent(),fUseYesNoLeaf);
}
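// no method-specific help message is implemented yet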
void TMVA::MethodDT::GetHelpMessage() const
{
}
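// no variable ranking is implemented for this method: return a null pointer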
const TMVA::Ranking* TMVA::MethodDT::CreateRanking()
{
return 0;
}