#ifndef ROOT_TMVA_DecisionTree
#define ROOT_TMVA_DecisionTree
#ifndef ROOT_TH2
#include "TH2.h"
#endif
#ifndef ROOT_TMVA_Types
#include "TMVA/Types.h"
#endif
#ifndef ROOT_TMVA_DecisionTreeNode
#include "TMVA/DecisionTreeNode.h"
#endif
#ifndef ROOT_TMVA_BinaryTree
#include "TMVA/BinaryTree.h"
#endif
#ifndef ROOT_TMVA_BinarySearchTree
#include "TMVA/BinarySearchTree.h"
#endif
#ifndef ROOT_TMVA_SeparationBase
#include "TMVA/SeparationBase.h"
#endif
#ifndef ROOT_TMVA_RegressionVariance
#include "TMVA/RegressionVariance.h"
#endif
class TRandom3;
namespace TMVA {
class Event;
// Decision tree derived from BinaryTree whose nodes are DecisionTreeNode
// objects (see CreateNode).
//
// The class declares the interface for growing a tree from an event sample
// (BuildTree / TrainNode*), filling and checking events, pruning, and
// querying variable importance. Only the short inline accessors are defined
// in this header; all other members are merely declared here and implemented
// elsewhere, so their exact behaviour must be looked up in the implementation.
class DecisionTree : public BinaryTree {

private:

   // Default value of the iSeed constructor argument. Declared ahead of the
   // public section because the constructor declaration below refers to it.
   static const Int_t fgRandomSeed;

public:

   // Container type used for event samples throughout the interface
   // (a vector of non-const Event pointers).
   typedef std::vector<TMVA::Event*> EventList;

   // Default constructor.
   DecisionTree( void );

   // Configuring constructor. The argument names mirror data members of this
   // class (sepType/fSepType, nCuts/fNCuts, useNvars/fUseNvars, ...);
   // presumably each argument initialises its namesake -- confirm against
   // the implementation.
   DecisionTree( SeparationBase *sepType, Int_t minSize,
                 Int_t nCuts,
                 Bool_t randomisedTree=kFALSE, Int_t useNvars=0,
                 UInt_t nNodesMax=999999, UInt_t nMaxDepth=9999999,
                 Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
                 Int_t treeID = 0);

   // Copy constructor.
   DecisionTree (const DecisionTree &d);

   virtual ~DecisionTree( void );

   // Factory hooks: return a freshly allocated DecisionTreeNode /
   // default-constructed DecisionTree. The UInt_t argument of CreateNode is
   // ignored. The objects are created with new; the caller is presumably
   // responsible for deleting them.
   virtual Node * CreateNode(UInt_t) const { return new DecisionTreeNode(); }
   virtual BinaryTree* CreateTree() const { return new DecisionTree(); }

   // Returns the fixed string "DecisionTree".
   virtual const char* ClassName() const { return "DecisionTree"; }

   // --- tree building -----------------------------------------------------

   // Builds the tree from eventSample, optionally starting at 'node'
   // (NULL by default). The meaning of the UInt_t return value is defined by
   // the implementation.
   UInt_t BuildTree( const EventList & eventSample,
                     DecisionTreeNode *node = NULL);

   // Trains a single node on eventSample. TrainNode simply forwards to
   // TrainNodeFast; TrainNodeFull is the alternative implementation.
   Float_t TrainNode( const EventList & eventSample, DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); }
   Float_t TrainNodeFast( const EventList & eventSample, DecisionTreeNode *node );
   Float_t TrainNodeFull( const EventList & eventSample, DecisionTreeNode *node );

   // --- filling and evaluating events -------------------------------------

   void FillTree( EventList & eventSample);
   void FillEvent( TMVA::Event & event,
                   TMVA::DecisionTreeNode *node );

   // Evaluates the tree for an event; UseYesNoLeaf (default kFALSE) selects
   // the kind of response returned -- see the implementation for details.
   Double_t CheckEvent( const TMVA::Event & , Bool_t UseYesNoLeaf = kFALSE ) const;

   // Returns the node associated with event e (presumably the leaf the event
   // ends up in -- confirm against the implementation).
   TMVA::DecisionTreeNode* GetEventNode(const TMVA::Event & e) const;

   // Variable importance, either for all variables at once or for the single
   // variable with index ivar.
   std::vector< Double_t > GetVariableImportance();
   Double_t GetVariableImportance(UInt_t ivar);

   void ClearTree();

   // --- pruning -----------------------------------------------------------

   // Available pruning strategies: kExpectedErrorPruning (= 0),
   // kCostComplexityPruning (= 1) and kNoPruning (= 2).
   enum EPruneMethod { kExpectedErrorPruning=0, kCostComplexityPruning, kNoPruning };

   // Stores the pruning method; called without argument it selects
   // kCostComplexityPruning.
   void SetPruneMethod( EPruneMethod m = kCostComplexityPruning ) { fPruneMethod = m; }

   // Prunes the tree, optionally with a validation sample (NULL by default).
   Double_t PruneTree( EventList* validationSample = NULL );

   // Accessors for the pruning strength stored in fPruneStrength.
   void SetPruneStrength( Double_t p ) { fPruneStrength = p; }
   Double_t GetPruneStrength( ) const { return fPruneStrength; }

   // Helpers operating on a validation sample / pruned tree; semantics are
   // defined by the implementation.
   void ApplyValidationSample( const EventList* validationSample ) const;
   Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = NULL, Int_t mode=0 ) const;
   void CheckEventWithPrunedTree( const TMVA::Event& ) const;
   Float_t GetSumWeights( const EventList* validationSample ) const;

   // Accessors for the node purity limit. Note that the value is stored as
   // Float_t, so a Double_t passed to the setter is narrowed to Float_t.
   void SetNodePurityLimit( Double_t p ) { fNodePurityLimit = static_cast<Float_t>(p); }
   Float_t GetNodePurityLimit( ) const { return fNodePurityLimit; }

   // --- tree navigation and maintenance -----------------------------------

   void DescendTree( DecisionTreeNode *n = NULL );
   void SetParentTreeInNodes( DecisionTreeNode *n = NULL );

   DecisionTreeNode* GetLeftDaughter( DecisionTreeNode *n );
   DecisionTreeNode* GetRightDaughter( DecisionTreeNode *n );

   // Looks up a node from a (sequence, depth) pair; the encoding of
   // 'sequence' is defined by the implementation.
   DecisionTreeNode* GetNode( ULong_t sequence, UInt_t depth );

   UInt_t CleanTree(DecisionTreeNode *node=NULL);

   void PruneNode(TMVA::DecisionTreeNode *node);
   void PruneNodeInPlace( TMVA::DecisionTreeNode* node );

   UInt_t CountLeafNodes(TMVA::DecisionTreeNode *n = NULL);

   // --- simple properties -------------------------------------------------

   // Accessors for the tree identifier stored in fTreeID. The getter is
   // const so that it can be used on const trees, like the other getters.
   void SetTreeID(Int_t treeID) { fTreeID = treeID; }
   Int_t GetTreeID() const { return fTreeID; }

   // True when the analysis type is Types::kRegression.
   Bool_t DoRegression() const { return fAnalysisType == Types::kRegression; }

   // Accessors for the analysis type stored in fAnalysisType. The getter is
   // const so that it can be used on const trees, like DoRegression().
   void SetAnalysisType (Types::EAnalysisType t) { fAnalysisType = t; }
   Types::EAnalysisType GetAnalysisType ( void ) const { return fAnalysisType; }

private:

   // NOTE(review): takes the event list by value, i.e. the vector of
   // pointers is copied on every call. Changing this requires touching the
   // out-of-line definition as well, so it is left as is here.
   Float_t SamplePurity(EventList eventSample);

   // Data members. The short descriptions below are inferred from the member
   // names and the matching constructor arguments / accessors; confirm
   // against the implementation before relying on them. The declaration
   // order is significant (object layout, initialisation order) and must
   // not be changed.
   UInt_t fNvars;                 // number of variables
   Int_t fNCuts;                  // cf. constructor argument nCuts
   SeparationBase *fSepType;      // cf. constructor argument sepType
   RegressionVariance *fRegType;  // regression counterpart of fSepType (presumably)

   Float_t fMinSize;              // cf. constructor argument minSize
   Float_t fMinSepGain;           // minimum separation gain (presumably)

   Bool_t fUseSearchTree;         // presumably toggles use of a BinarySearchTree

   Double_t fPruneStrength;       // see Set/GetPruneStrength
   EPruneMethod fPruneMethod;     // see SetPruneMethod
   Float_t fNodePurityLimit;      // see Set/GetNodePurityLimit

   Bool_t fRandomisedTree;        // cf. constructor argument randomisedTree
   Int_t fUseNvars;               // cf. constructor argument useNvars
   TRandom3 *fMyTrandom;          // random number generator (ownership: see implementation)

   std::vector< Double_t > fVariableImportance; // see GetVariableImportance

   UInt_t fNNodesMax;             // cf. constructor argument nNodesMax
   UInt_t fMaxDepth;              // cf. constructor argument nMaxDepth

   static const Int_t fgDebugLevel = 0; // compile-time debug level, fixed to 0

   Int_t fTreeID;                 // see Set/GetTreeID
   Types::EAnalysisType fAnalysisType; // see Set/GetAnalysisType, DoRegression

   ClassDef(DecisionTree,0)
};
}
#endif