#ifndef ROOT_TMVA_DecisionTreeNode
#define ROOT_TMVA_DecisionTreeNode
#ifndef ROOT_TMVA_Node
#include "TMVA/Node.h"
#endif
#ifndef ROOT_TMVA_Version
#include "TMVA/Version.h"
#endif
#include <iostream>
#include <vector>
#include <map>
namespace TMVA {
// Plain data holder for the per-node bookkeeping that DecisionTreeNode keeps
// behind its fTrainInfo pointer. All members are public and are read/written
// directly by the inline accessors of DecisionTreeNode.
class DTNodeTrainingInfo
{
public:
   // Default constructor: every counter starts at 0, except fNEvents,
   // fSeparationIndex and fSeparationGain, which start at -1.
   DTNodeTrainingInfo()
      : fSampleMin(),
        fSampleMax(),
        fNodeR(0),
        fSubTreeR(0),
        fAlpha(0),
        fG(0),
        fNTerminal(0),
        fNB(0),
        fNS(0),
        fSumTarget(0),
        fSumTarget2(0),
        fCC(0),
        fNSigEvents(0),
        fNBkgEvents(0),
        fNEvents(-1),
        fNSigEvents_unweighted(0),
        fNBkgEvents_unweighted(0),
        fNEvents_unweighted(0),
        fNSigEvents_unboosted(0),
        fNBkgEvents_unboosted(0),
        fNEvents_unboosted(0),
        fSeparationIndex(-1),
        fSeparationGain(-1)
   {
   }

   std::vector< Float_t > fSampleMin;   // per-variable minimum (see DecisionTreeNode::Get/SetSampleMin)
   std::vector< Float_t > fSampleMax;   // per-variable maximum (see DecisionTreeNode::Get/SetSampleMax)
   Double_t fNodeR;                     // accessed via Get/SetNodeR
   Double_t fSubTreeR;                  // accessed via Get/SetSubTreeR
   Double_t fAlpha;                     // accessed via Get/SetAlpha
   Double_t fG;                         // accessed via Get/SetAlphaMinSubtree
   Int_t    fNTerminal;                 // accessed via Get/SetNTerminal
   Double_t fNB;                        // accessed via Get/SetNBValidation
   Double_t fNS;                        // accessed via Get/SetNSValidation
   Float_t  fSumTarget;                 // running sum, see AddToSumTarget
   Float_t  fSumTarget2;                // running sum, see AddToSumTarget2
   Double_t fCC;                        // accessed via Get/SetCC
   Float_t  fNSigEvents;                // signal count (presumably weighted - confirm with callers)
   Float_t  fNBkgEvents;                // background count (presumably weighted - confirm with callers)
   Float_t  fNEvents;                   // total count; -1 until set
   Float_t  fNSigEvents_unweighted;     // signal count, incremented by 1 per call
   Float_t  fNBkgEvents_unweighted;     // background count, incremented by 1 per call
   Float_t  fNEvents_unweighted;        // total count, incremented by 1 per call
   Float_t  fNSigEvents_unboosted;      // signal count, "unboosted" variant
   Float_t  fNBkgEvents_unboosted;      // background count, "unboosted" variant
   Float_t  fNEvents_unboosted;         // total count, "unboosted" variant
   Float_t  fSeparationIndex;           // -1 until set
   Float_t  fSeparationGain;            // -1 until set

   // Copy constructor.
   //
   // The sample ranges, the target sums and fCC are explicitly reset instead
   // of being copied; that is kept exactly as written.
   // NOTE(review): the resets look deliberate (they are spelled out), but
   // confirm against the callers that copy nodes.
   //
   // The three *_unboosted counters used to be missing from the initializer
   // list, which left them with indeterminate values in the copy. They are
   // now copied from the source object like the other event counters.
   DTNodeTrainingInfo(const DTNodeTrainingInfo& n)
      : fSampleMin(),
        fSampleMax(),
        fNodeR(n.fNodeR),
        fSubTreeR(n.fSubTreeR),
        fAlpha(n.fAlpha),
        fG(n.fG),
        fNTerminal(n.fNTerminal),
        fNB(n.fNB),
        fNS(n.fNS),
        fSumTarget(0),
        fSumTarget2(0),
        fCC(0),
        fNSigEvents(n.fNSigEvents),
        fNBkgEvents(n.fNBkgEvents),
        fNEvents(n.fNEvents),
        fNSigEvents_unweighted(n.fNSigEvents_unweighted),
        fNBkgEvents_unweighted(n.fNBkgEvents_unweighted),
        fNEvents_unweighted(n.fNEvents_unweighted),
        fNSigEvents_unboosted(n.fNSigEvents_unboosted),
        fNBkgEvents_unboosted(n.fNBkgEvents_unboosted),
        fNEvents_unboosted(n.fNEvents_unboosted),
        fSeparationIndex(n.fSeparationIndex),
        fSeparationGain(n.fSeparationGain)
   {
   }
};
class Event;
class MsgLogger;
class DecisionTreeNode: public Node {
public:
DecisionTreeNode ();
DecisionTreeNode (Node* p, char pos);
DecisionTreeNode (const DecisionTreeNode &n, DecisionTreeNode* parent = NULL);
virtual ~DecisionTreeNode();
virtual Node* CreateNode() const { return new DecisionTreeNode(); }
inline void SetNFisherCoeff(Int_t nvars){fFisherCoeff.resize(nvars);}
inline UInt_t GetNFisherCoeff() const { return fFisherCoeff.size();}
void SetFisherCoeff(Int_t ivar, Double_t coeff);
Double_t GetFisherCoeff(Int_t ivar) const {return fFisherCoeff.at(ivar);}
virtual Bool_t GoesRight( const Event & ) const;
virtual Bool_t GoesLeft ( const Event & ) const;
void SetSelector( Short_t i) { fSelector = i; }
Short_t GetSelector() const { return fSelector; }
void SetCutValue ( Float_t c ) { fCutValue = c; }
Float_t GetCutValue ( void ) const { return fCutValue; }
void SetCutType( Bool_t t ) { fCutType = t; }
Bool_t GetCutType( void ) const { return fCutType; }
void SetNodeType( Int_t t ) { fNodeType = t;}
Int_t GetNodeType( void ) const { return fNodeType; }
Float_t GetPurity( void ) const { return fPurity;}
void SetPurity( void );
void SetResponse( Float_t r ) { fResponse = r;}
Float_t GetResponse( void ) const { return fResponse;}
void SetRMS( Float_t r ) { fRMS = r;}
Float_t GetRMS( void ) const { return fRMS;}
void SetNSigEvents( Float_t s ) { fTrainInfo->fNSigEvents = s; }
void SetNBkgEvents( Float_t b ) { fTrainInfo->fNBkgEvents = b; }
void SetNEvents( Float_t nev ){ fTrainInfo->fNEvents =nev ; }
void SetNSigEvents_unweighted( Float_t s ) { fTrainInfo->fNSigEvents_unweighted = s; }
void SetNBkgEvents_unweighted( Float_t b ) { fTrainInfo->fNBkgEvents_unweighted = b; }
void SetNEvents_unweighted( Float_t nev ){ fTrainInfo->fNEvents_unweighted =nev ; }
void SetNSigEvents_unboosted( Float_t s ) { fTrainInfo->fNSigEvents_unboosted = s; }
void SetNBkgEvents_unboosted( Float_t b ) { fTrainInfo->fNBkgEvents_unboosted = b; }
void SetNEvents_unboosted( Float_t nev ){ fTrainInfo->fNEvents_unboosted =nev ; }
void IncrementNSigEvents( Float_t s ) { fTrainInfo->fNSigEvents += s; }
void IncrementNBkgEvents( Float_t b ) { fTrainInfo->fNBkgEvents += b; }
void IncrementNEvents( Float_t nev ){ fTrainInfo->fNEvents +=nev ; }
void IncrementNSigEvents_unweighted( ) { fTrainInfo->fNSigEvents_unweighted += 1; }
void IncrementNBkgEvents_unweighted( ) { fTrainInfo->fNBkgEvents_unweighted += 1; }
void IncrementNEvents_unweighted( ){ fTrainInfo->fNEvents_unweighted +=1 ; }
Float_t GetNSigEvents( void ) const { return fTrainInfo->fNSigEvents; }
Float_t GetNBkgEvents( void ) const { return fTrainInfo->fNBkgEvents; }
Float_t GetNEvents( void ) const { return fTrainInfo->fNEvents; }
Float_t GetNSigEvents_unweighted( void ) const { return fTrainInfo->fNSigEvents_unweighted; }
Float_t GetNBkgEvents_unweighted( void ) const { return fTrainInfo->fNBkgEvents_unweighted; }
Float_t GetNEvents_unweighted( void ) const { return fTrainInfo->fNEvents_unweighted; }
Float_t GetNSigEvents_unboosted( void ) const { return fTrainInfo->fNSigEvents_unboosted; }
Float_t GetNBkgEvents_unboosted( void ) const { return fTrainInfo->fNBkgEvents_unboosted; }
Float_t GetNEvents_unboosted( void ) const { return fTrainInfo->fNEvents_unboosted; }
void SetSeparationIndex( Float_t sep ){ fTrainInfo->fSeparationIndex =sep ; }
Float_t GetSeparationIndex( void ) const { return fTrainInfo->fSeparationIndex; }
void SetSeparationGain( Float_t sep ){ fTrainInfo->fSeparationGain =sep ; }
Float_t GetSeparationGain( void ) const { return fTrainInfo->fSeparationGain; }
virtual void Print( std::ostream& os ) const;
virtual void PrintRec( std::ostream& os ) const;
virtual void AddAttributesToNode(void* node) const;
virtual void AddContentToNode(std::stringstream& s) const;
void ClearNodeAndAllDaughters();
inline virtual DecisionTreeNode* GetLeft( ) const { return dynamic_cast<DecisionTreeNode*>(fLeft); }
inline virtual DecisionTreeNode* GetRight( ) const { return dynamic_cast<DecisionTreeNode*>(fRight); }
inline virtual DecisionTreeNode* GetParent( ) const { return dynamic_cast<DecisionTreeNode*>(fParent); }
inline virtual void SetLeft (Node* l) { fLeft = dynamic_cast<DecisionTreeNode*>(l);}
inline virtual void SetRight (Node* r) { fRight = dynamic_cast<DecisionTreeNode*>(r);}
inline virtual void SetParent(Node* p) { fParent = dynamic_cast<DecisionTreeNode*>(p);}
inline void SetNodeR( Double_t r ) { fTrainInfo->fNodeR = r; }
inline Double_t GetNodeR( ) const { return fTrainInfo->fNodeR; }
inline void SetSubTreeR( Double_t r ) { fTrainInfo->fSubTreeR = r; }
inline Double_t GetSubTreeR( ) const { return fTrainInfo->fSubTreeR; }
inline void SetAlpha( Double_t alpha ) { fTrainInfo->fAlpha = alpha; }
inline Double_t GetAlpha( ) const { return fTrainInfo->fAlpha; }
inline void SetAlphaMinSubtree( Double_t g ) { fTrainInfo->fG = g; }
inline Double_t GetAlphaMinSubtree( ) const { return fTrainInfo->fG; }
inline void SetNTerminal( Int_t n ) { fTrainInfo->fNTerminal = n; }
inline Int_t GetNTerminal( ) const { return fTrainInfo->fNTerminal; }
inline void SetNBValidation( Double_t b ) { fTrainInfo->fNB = b; }
inline void SetNSValidation( Double_t s ) { fTrainInfo->fNS = s; }
inline Double_t GetNBValidation( ) const { return fTrainInfo->fNB; }
inline Double_t GetNSValidation( ) const { return fTrainInfo->fNS; }
inline void SetSumTarget(Float_t t) {fTrainInfo->fSumTarget = t; }
inline void SetSumTarget2(Float_t t2){fTrainInfo->fSumTarget2 = t2; }
inline void AddToSumTarget(Float_t t) {fTrainInfo->fSumTarget += t; }
inline void AddToSumTarget2(Float_t t2){fTrainInfo->fSumTarget2 += t2; }
inline Float_t GetSumTarget() const {return fTrainInfo? fTrainInfo->fSumTarget : -9999;}
inline Float_t GetSumTarget2() const {return fTrainInfo? fTrainInfo->fSumTarget2: -9999;}
void ResetValidationData( );
inline Bool_t IsTerminal() const { return fIsTerminalNode; }
inline void SetTerminal( Bool_t s = kTRUE ) { fIsTerminalNode = s; }
void PrintPrune( std::ostream& os ) const ;
void PrintRecPrune( std::ostream& os ) const;
void SetCC(Double_t cc);
Double_t GetCC() const {return (fTrainInfo? fTrainInfo->fCC : -1.);}
Float_t GetSampleMin(UInt_t ivar) const;
Float_t GetSampleMax(UInt_t ivar) const;
void SetSampleMin(UInt_t ivar, Float_t xmin);
void SetSampleMax(UInt_t ivar, Float_t xmax);
static bool fgIsTraining;
static UInt_t fgTmva_Version_Code;
protected:
static MsgLogger& Log();
std::vector<Double_t> fFisherCoeff;
Float_t fCutValue;
Bool_t fCutType;
Short_t fSelector;
Float_t fResponse;
Float_t fRMS;
Int_t fNodeType;
Float_t fPurity;
Bool_t fIsTerminalNode;
mutable DTNodeTrainingInfo* fTrainInfo;
private:
virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
virtual Bool_t ReadDataRecord( std::istream& is, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
virtual void ReadContent(std::stringstream& s);
ClassDef(DecisionTreeNode,0)
};
}
#endif