#ifndef ROOT_TMVA_DecisionTree
#define ROOT_TMVA_DecisionTree
#ifndef ROOT_TH2
#include "TH2.h"
#endif
#ifndef ROOT_TMVA_Types
#include "TMVA/Types.h"
#endif
#ifndef ROOT_TMVA_DecisionTreeNode
#include "TMVA/DecisionTreeNode.h"
#endif
#ifndef ROOT_TMVA_BinaryTree
#include "TMVA/BinaryTree.h"
#endif
#ifndef ROOT_TMVA_BinarySearchTree
#include "TMVA/BinarySearchTree.h"
#endif
#ifndef ROOT_TMVA_SeparationBase
#include "TMVA/SeparationBase.h"
#endif
#ifndef ROOT_TMVA_RegressionVariance
#include "TMVA/RegressionVariance.h"
#endif
#include "TMVA/DataSetInfo.h"
class TRandom3;
namespace TMVA {
class Event;
/// Decision tree whose nodes are TMVA::DecisionTreeNode objects, layered on
/// top of TMVA::BinaryTree. Declares building/training from event samples,
/// per-event evaluation, variable-importance queries and tree pruning; the
/// non-inline members are defined out of line.
///
/// NOTE(review): a copy constructor and a virtual destructor are declared,
/// but no copy-assignment operator. Because the class holds raw pointers
/// (fSepType, fRegType, fMyTrandom, fDataSetInfo), the implicitly generated
/// operator= copies those pointers shallowly. If the out-of-line destructor
/// deletes any of them, assigning one DecisionTree to another would cause a
/// double delete -- confirm against the implementation before relying on
/// assignment.
class DecisionTree : public BinaryTree {

private:

   /// Default value of the constructor's iSeed argument (defined out of line).
   static const Int_t fgRandomSeed;

public:

   /// Container of mutable event pointers (taken by FillTree).
   typedef std::vector<TMVA::Event*> EventList;
   /// Container of const event pointers (taken by building, training and pruning).
   typedef std::vector<const TMVA::Event*> EventConstList;

   /// Default constructor; CreateTree() uses it to make fresh instances.
   DecisionTree( void );

   /// Constructs a tree with an explicit training configuration.
   /// NOTE(review): the arguments presumably initialise the like-named data
   /// members below -- confirm against the constructor definition.
   /// @param sepType         separation-index object (SeparationBase*)
   /// @param minSize         minimum node size
   /// @param nCuts           number of cut values (presumably per variable)
   /// @param (unnamed)       optional DataSetInfo*, default NULL
   /// @param cls             class index, default 0
   /// @param randomisedTree  default kFALSE
   /// @param useNvars        default 0
   /// @param usePoissonNvars default kFALSE
   /// @param nMaxDepth       maximum tree depth, default 9999999
   /// @param iSeed           random seed, default fgRandomSeed
   /// @param purityLimit     node purity limit, default 0.5
   /// @param treeID          tree identifier, default 0
   DecisionTree( SeparationBase *sepType, Float_t minSize,
                 Int_t nCuts, DataSetInfo* = NULL,
                 UInt_t cls =0,
                 Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
                 UInt_t nMaxDepth=9999999,
                 Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
                 Int_t treeID = 0);

   /// Copy constructor (defined out of line).
   DecisionTree (const DecisionTree &d);

   virtual ~DecisionTree( void );

   /// Root node downcast to DecisionTreeNode*. The dynamic_cast yields NULL
   /// when fRoot is NULL or is not a DecisionTreeNode.
   virtual DecisionTreeNode* GetRoot() const { return dynamic_cast<TMVA::DecisionTreeNode*>(fRoot); }

   /// Allocates a new default-constructed DecisionTreeNode; the UInt_t
   /// argument is ignored. The caller receives a heap object.
   virtual DecisionTreeNode * CreateNode(UInt_t) const { return new DecisionTreeNode(); }

   /// Allocates a new default-constructed DecisionTree on the heap.
   virtual BinaryTree* CreateTree() const { return new DecisionTree(); }

   /// Factory reading a tree from an XML node (opaque void*), interpreted
   /// according to tmva_Version_Code (defaults to the current TMVA version).
   static DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);

   /// Returns the literal "DecisionTree".
   virtual const char* ClassName() const { return "DecisionTree"; }

   /// Builds the (sub)tree from eventSample, starting at node (default NULL,
   /// presumably meaning "start at the root").
   UInt_t BuildTree( const EventConstList & eventSample,
                     DecisionTreeNode *node = NULL);

   /// Trains a single node; simply forwards to TrainNodeFast().
   Double_t TrainNode( const EventConstList & eventSample, DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); }

   /// Node-training variants (defined out of line).
   Double_t TrainNodeFast( const EventConstList & eventSample, DecisionTreeNode *node );
   Double_t TrainNodeFull( const EventConstList & eventSample, DecisionTreeNode *node );

   /// Selects a (random) subset of variables; results are written through
   /// useVariable, variableMap and nVars.
   void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars);

   /// Computes Fisher coefficients over nFisherVars variables selected via
   /// mapVarInFisher.
   std::vector<Double_t> GetFisherCoefficients(const EventConstList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher);

   /// Passes events through an existing tree (FillTree for a sample,
   /// FillEvent for one event starting at node).
   void FillTree( const EventList & eventSample);
   void FillEvent( const TMVA::Event & event,
                   TMVA::DecisionTreeNode *node );

   /// Tree response for one event; UseYesNoLeaf defaults to kFALSE.
   Double_t CheckEvent( const TMVA::Event * , Bool_t UseYesNoLeaf = kFALSE ) const;

   /// Node reached by event e when descending the tree.
   TMVA::DecisionTreeNode* GetEventNode(const TMVA::Event & e) const;

   /// Variable importance for all variables, or for variable ivar.
   std::vector< Double_t > GetVariableImportance();
   Double_t GetVariableImportance(UInt_t ivar);

   void ClearTree();

   /// Available pruning algorithms.
   enum EPruneMethod { kExpectedErrorPruning=0, kCostComplexityPruning, kNoPruning };

   /// Sets fPruneMethod (defaults to kCostComplexityPruning).
   void SetPruneMethod( EPruneMethod m = kCostComplexityPruning ) { fPruneMethod = m; }

   /// Prunes the tree, optionally against a validation sample (default NULL).
   Double_t PruneTree( const EventConstList* validationSample = NULL );

   /// Accessors for fPruneStrength.
   void SetPruneStrength( Double_t p ) { fPruneStrength = p; }
   Double_t GetPruneStrength( ) const { return fPruneStrength; }

   /// Validation-sample helpers used around pruning (defined out of line).
   void ApplyValidationSample( const EventConstList* validationSample ) const;
   Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = NULL, Int_t mode=0 ) const;
   void CheckEventWithPrunedTree( const TMVA::Event* ) const;
   Double_t GetSumWeights( const EventConstList* validationSample ) const;

   /// Accessors for fNodePurityLimit.
   void SetNodePurityLimit( Double_t p ) { fNodePurityLimit = p; }
   Double_t GetNodePurityLimit( ) const { return fNodePurityLimit; }

   /// Tree-walking helpers; n defaults to NULL (presumably "the root").
   void DescendTree( Node *n = NULL );
   void SetParentTreeInNodes( Node *n = NULL );

   /// Looks up a node by its sequence code and depth.
   Node* GetNode( ULong_t sequence, UInt_t depth );

   UInt_t CleanTree(DecisionTreeNode *node=NULL);

   /// Pruning of an individual node (two variants, defined out of line).
   void PruneNode(TMVA::DecisionTreeNode *node);
   void PruneNodeInPlace( TMVA::DecisionTreeNode* node );

   /// Node count before pruning, computed lazily: while the cached value is
   /// still 0, it is (re)initialised from GetNNodes() and then returned.
   Int_t GetNNodesBeforePruning()
   {
      if (fNNodesBeforePruning == 0) fNNodesBeforePruning = GetNNodes();
      return fNNodesBeforePruning;
   }

   /// Counts leaf nodes below n (default NULL).
   UInt_t CountLeafNodes(TMVA::Node *n = NULL);

   /// Accessors for fTreeID. The getter is const so it can be called on a
   /// const tree, like the other getters of this class.
   void SetTreeID(Int_t treeID){ fTreeID = treeID; }
   Int_t GetTreeID() const { return fTreeID; }

   /// True iff the analysis type is Types::kRegression.
   Bool_t DoRegression() const { return fAnalysisType == Types::kRegression; }

   /// Accessors for fAnalysisType (getter const for consistency with DoRegression()).
   void SetAnalysisType (Types::EAnalysisType t) { fAnalysisType = t;}
   Types::EAnalysisType GetAnalysisType ( void ) const { return fAnalysisType;}

   /// Configuration setters; the Bool_t ones default to kTRUE.
   inline void SetUseFisherCuts(Bool_t t=kTRUE) { fUseFisherCuts = t;}
   inline void SetMinLinCorrForFisher(Double_t min){fMinLinCorrForFisher = min;}
   inline void SetUseExclusiveVars(Bool_t t=kTRUE){fUseExclusiveVars = t;}
   /// Stores n into the unsigned fNvars; a negative n would wrap around.
   inline void SetNVars(Int_t n){fNvars = n;}

private:

   /// Purity of an event sample (the sample is taken by value).
   Double_t SamplePurity(EventList eventSample);

   // Data members. Declaration order is kept as-is: out-of-line constructor
   // initialiser lists and the ROOT dictionary may depend on it.
   UInt_t                fNvars;               // number of variables (see SetNVars)
   Int_t                 fNCuts;               // number of cut values
   Bool_t                fUseFisherCuts;       // see SetUseFisherCuts
   Double_t              fMinLinCorrForFisher; // see SetMinLinCorrForFisher
   Bool_t                fUseExclusiveVars;    // see SetUseExclusiveVars
   SeparationBase       *fSepType;             // separation-index object (raw pointer)
   RegressionVariance   *fRegType;             // regression separation object (raw pointer)
   Double_t              fMinSize;
   Double_t              fMinNodeSize;
   Double_t              fMinSepGain;
   Bool_t                fUseSearchTree;
   Double_t              fPruneStrength;       // see Set/GetPruneStrength
   EPruneMethod          fPruneMethod;         // see SetPruneMethod
   Int_t                 fNNodesBeforePruning; // lazily filled by GetNNodesBeforePruning (0 = not yet)
   Double_t              fNodePurityLimit;     // see Set/GetNodePurityLimit
   Bool_t                fRandomisedTree;
   Int_t                 fUseNvars;
   Bool_t                fUsePoissonNvars;
   TRandom3             *fMyTrandom;           // random generator (raw pointer)
   std::vector< Double_t > fVariableImportance;
   UInt_t                fMaxDepth;
   UInt_t                fSigClass;
   static const Int_t    fgDebugLevel = 0;
   Int_t                 fTreeID;              // see Set/GetTreeID
   Types::EAnalysisType  fAnalysisType;        // see Set/GetAnalysisType
   DataSetInfo*          fDataSetInfo;         // raw pointer; may be NULL (constructor default)

   // ROOT class macro; class version 0 (by ROOT convention, no I/O streaming).
   ClassDef(DecisionTree,0)
};
}
#endif