#include <algorithm>
#include <exception>
#include <iomanip>
#include "TMVA/MsgLogger.h"
#include "TMVA/DecisionTreeNode.h"
#include "TMVA/Tools.h"
#include "TMVA/Event.h"
using std::string;
// ROOT ClassImp macro for TMVA::DecisionTreeNode.
ClassImp(TMVA::DecisionTreeNode)
// Logger shared by all nodes. It is created lazily by the first constructor call;
// nothing in this file deletes it.
TMVA::MsgLogger* TMVA::DecisionTreeNode::fgLogger = 0;
// While true, newly constructed nodes allocate a DTNodeTrainingInfo for their
// training-only bookkeeping; while false, fTrainInfo is left null.
bool TMVA::DecisionTreeNode::fgIsTraining = false;
TMVA::DecisionTreeNode::DecisionTreeNode()
   : TMVA::Node(),
     fCutValue( 0 ),
     fCutType( kTRUE ),
     fSelector( -1 ),
     fResponse( -99 ),
     fRMS( 0 ),
     fNodeType( -99 ),
     fPurity( -99 ),
     fIsTerminalNode( kFALSE )
{
   // Default constructor. Cut and response members start at placeholder values
   // (selector -1, response / node type / purity -99).
   // Training bookkeeping is allocated only while fgIsTraining is set.
   if (fgLogger == 0) fgLogger = new TMVA::MsgLogger( "DecisionTreeNode" );
   fTrainInfo = fgIsTraining ? new DTNodeTrainingInfo() : 0;
}
TMVA::DecisionTreeNode::DecisionTreeNode(TMVA::Node* p, char pos)
   : TMVA::Node(p, pos),
     fCutValue( 0 ),
     fCutType( kTRUE ),
     fSelector( -1 ),
     fResponse( -99 ),
     fRMS( 0 ),
     fNodeType( -99 ),
     fPurity( -99 ),
     fIsTerminalNode( kFALSE )
{
   // Construct a node with parent p at position pos (both forwarded to the
   // Node base). Cut and response members start at the same placeholder
   // values as in the default constructor.
   // Training bookkeeping is allocated only while fgIsTraining is set.
   if (fgLogger == 0) fgLogger = new TMVA::MsgLogger( "DecisionTreeNode" );
   fTrainInfo = fgIsTraining ? new DTNodeTrainingInfo() : 0;
}
TMVA::DecisionTreeNode::DecisionTreeNode(const TMVA::DecisionTreeNode &n,
                                         DecisionTreeNode* parent)
   : TMVA::Node(n),
     fCutValue( n.fCutValue ),
     fCutType ( n.fCutType ),
     fSelector( n.fSelector ),
     fResponse( n.fResponse ),
     fRMS     ( n.fRMS),
     fNodeType( n.fNodeType ),
     fPurity  ( n.fPurity),
     fIsTerminalNode( n.fIsTerminalNode )
{
   // Deep copy of node n and, recursively, of its whole subtree. The new node
   // is attached to 'parent'; its daughters are fresh copies of n's daughters
   // whose parent is this node.
   if (!fgLogger) fgLogger = new TMVA::MsgLogger( "DecisionTreeNode" );

   // The cut may be a linear (Fisher) combination of the input variables.
   // GoesRight() evaluates these coefficients, so they must be copied too,
   // otherwise the copy would route events differently from the original.
   fFisherCoeff = n.fFisherCoeff;

   this->SetParent( parent );
   if (n.GetLeft() == 0) this->SetLeft(NULL);
   else this->SetLeft( new DecisionTreeNode( *static_cast<const DecisionTreeNode*>(n.GetLeft()), this ) );
   if (n.GetRight() == 0) this->SetRight(NULL);
   else this->SetRight( new DecisionTreeNode( *static_cast<const DecisionTreeNode*>(n.GetRight()), this ) );

   if (fgIsTraining) {
      // n may have been built outside training mode and carry no training
      // info; start from empty training info instead of dereferencing null.
      fTrainInfo = (n.fTrainInfo != 0) ? new DTNodeTrainingInfo(*(n.fTrainInfo))
                                       : new DTNodeTrainingInfo();
   }
   else {
      fTrainInfo = 0;
   }
}
TMVA::DecisionTreeNode::~DecisionTreeNode()
{
   // Release the training bookkeeping owned by this node. Deleting a null
   // pointer is a no-op, which covers nodes built outside training mode.
   delete fTrainInfo;
}
Bool_t TMVA::DecisionTreeNode::GoesRight(const TMVA::Event & e) const
{
   // Decide whether event e is sent to the right daughter of this node.
   //
   // Without Fisher coefficients the cut is applied to the single variable
   // GetSelector(). Otherwise it is applied to the linear combination
   //    c_last + sum_{i < last} c_i * x_i
   // where the last coefficient acts as a constant offset.
   // When fCutType == kTRUE an event above the cut goes right; any other cut
   // type inverts the decision.
   Bool_t aboveCut;
   if (GetNFisherCoeff() != 0) {
      Double_t discriminant = this->GetFisherCoeff(fFisherCoeff.size()-1);
      for (UInt_t iVar = 0; iVar < fFisherCoeff.size()-1; ++iVar) {
         discriminant += this->GetFisherCoeff(iVar) * (e.GetValue(iVar));
      }
      aboveCut = discriminant > this->GetCutValue();
   }
   else {
      aboveCut = (e.GetValue(this->GetSelector()) > this->GetCutValue());
   }
   return (fCutType == kTRUE) ? aboveCut : !aboveCut;
}
Bool_t TMVA::DecisionTreeNode::GoesLeft(const TMVA::Event & e) const
{
   // An event goes to the left daughter exactly when it does not go right.
   return !this->GoesRight(e);
}
void TMVA::DecisionTreeNode::SetPurity( void )
{
if ( ( this->GetNSigEvents() + this->GetNBkgEvents() ) > 0 ) {
fPurity = this->GetNSigEvents() / ( this->GetNSigEvents() + this->GetNBkgEvents());
}
else {
*fgLogger << kINFO << "Zero events in purity calcuation , return purity=0.5" << Endl;
this->Print(*fgLogger);
fPurity = 0.5;
}
return;
}
void TMVA::DecisionTreeNode::Print(ostream& os) const
{
   // Print a human-readable description of this node to os. The output covers:
   //  - depth, Fisher coefficients, cut variable/value/type;
   //  - signal/background/total event counts (weighted and unweighted);
   //  - separation index and gain, and node type;
   //  - the addresses of this node, its parent and its daughters.
   //
   // The stream precision is set to 6 for the dump and restored afterwards,
   // so the caller's stream formatting is not changed as a side effect.
   const std::streamsize oldPrecision = os.precision();

   os << "< *** " << std::endl;
   os << " d: " << this->GetDepth()
      << std::setprecision(6)
      << " NCoef: " << this->GetNFisherCoeff();
   for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) { os << " fC" << i << ": " << this->GetFisherCoeff(i); }
   os << " ivar: "  << this->GetSelector()
      << " cut: "   << this->GetCutValue()
      << " cType: " << this->GetCutType()
      << " s: "     << this->GetNSigEvents()
      << " b: "     << this->GetNBkgEvents()
      << " nEv: "   << this->GetNEvents()
      << " suw: "   << this->GetNSigEvents_unweighted()
      << " buw: "   << this->GetNBkgEvents_unweighted()
      << " nEvuw: " << this->GetNEvents_unweighted()
      << " sepI: "  << this->GetSeparationIndex()
      << " sepG: "  << this->GetSeparationGain()
      << " nType: " << this->GetNodeType()
      << std::endl;

   // Print addresses as pointers. Converting them to 'long' truncates (or is
   // ill-formed) on platforms where long is narrower than a pointer, e.g. 64-bit Windows.
   os << "My address is " << static_cast<const void*>(this) << ", ";
   if (this->GetParent() != NULL) os << " parent at addr: " << static_cast<const void*>(this->GetParent());
   if (this->GetLeft()   != NULL) os << " left daughter at addr: " << static_cast<const void*>(this->GetLeft());
   if (this->GetRight()  != NULL) os << " right daughter at addr: " << static_cast<const void*>(this->GetRight());
   os << " **** > " << std::endl;

   os.precision(oldPrecision);
}
void TMVA::DecisionTreeNode::PrintRec(ostream& os) const
{
   // Write one line per node for this node and then, recursively, its left
   // and right subtrees (pre-order). CC values above 1e13 are written as 100000.

   // depth, position and the (optional) linear-cut coefficients
   os << this->GetDepth()
      << std::setprecision(6)
      << " " << this->GetPos()
      << "NCoef: " << this->GetNFisherCoeff();
   const Int_t nCoef = (Int_t) this->GetNFisherCoeff();
   for (Int_t iCoef = 0; iCoef < nCoef; ++iCoef) {
      os << "fC" << iCoef << ": " << this->GetFisherCoeff(iCoef);
   }

   // cut definition, event statistics and response
   os << " ivar: "  << this->GetSelector()
      << " cut: "   << this->GetCutValue()
      << " cType: " << this->GetCutType()
      << " s: "     << this->GetNSigEvents()
      << " b: "     << this->GetNBkgEvents()
      << " nEv: "   << this->GetNEvents()
      << " suw: "   << this->GetNSigEvents_unweighted()
      << " buw: "   << this->GetNBkgEvents_unweighted()
      << " nEvuw: " << this->GetNEvents_unweighted()
      << " sepI: "  << this->GetSeparationIndex()
      << " sepG: "  << this->GetSeparationGain()
      << " res: "   << this->GetResponse()
      << " rms: "   << this->GetRMS()
      << " nType: " << this->GetNodeType();

   os << " CC: ";
   if (this->GetCC() > 10000000000000.) os << 100000.;
   else                                 os << this->GetCC();
   os << std::endl;

   // pre-order recursion into the daughters
   if (this->GetLeft()  != NULL) this->GetLeft()->PrintRec(os);
   if (this->GetRight() != NULL) this->GetRight()->PrintRec(os);
}
Bool_t TMVA::DecisionTreeNode::ReadDataRecord( istream& is, UInt_t tmva_Version_Code )
{
   // Read one node record from a plain-text stream.
   //
   // A record starts with the node depth and its position character; a depth of
   // -1 marks the end of the tree. The rest of the record is a sequence of
   // "label value" pairs. For tmva_Version_Code >= TMVA_VERSION(4,0,0) the record
   // also carries the response (after the separation gain) and the CC value
   // (after the node type).
   //
   // Returns kFALSE at the end-of-tree marker or when no depth can be read, and
   // kTRUE after a record has been read. A truncated or malformed record is
   // reported as fatal instead of silently using uninitialised values.
   string tmp;
   Float_t cutVal(0), cutType(0), nsig(0), nbkg(0), nEv(0);
   Float_t nsig_unweighted(0), nbkg_unweighted(0), nEv_unweighted(0);
   Float_t separationIndex(0), separationGain(0), response(-99), cc(0);
   Int_t depth(0), ivar(0), nodeType(0);
   ULong_t lseq(0);
   char pos(0);

   // end-of-tree marker, or nothing left to read
   if ( !(is >> depth) || depth==-1 ) { return kFALSE; }
   is >> pos ;
   this->SetDepth(depth);
   this->SetPos(pos);

   // Both format versions share these fields; the newer one adds two more.
   const Bool_t hasExtraFields = !(tmva_Version_Code < TMVA_VERSION(4,0,0));
   is >> tmp >> lseq
      >> tmp >> ivar
      >> tmp >> cutVal
      >> tmp >> cutType
      >> tmp >> nsig
      >> tmp >> nbkg
      >> tmp >> nEv
      >> tmp >> nsig_unweighted
      >> tmp >> nbkg_unweighted
      >> tmp >> nEv_unweighted
      >> tmp >> separationIndex
      >> tmp >> separationGain;
   if (hasExtraFields) is >> tmp >> response;
   is >> tmp >> nodeType;
   if (hasExtraFields) is >> tmp >> cc;

   if (is.fail()) {
      *fgLogger << kFATAL << "Error while reading decision tree node record at depth "
                << depth << Endl;
      return kFALSE;
   }

   this->SetSelector((UInt_t)ivar);
   this->SetCutValue(cutVal);
   this->SetCutType(cutType);
   this->SetNodeType(nodeType);
   // the event statistics are only kept when training info is present
   if (fTrainInfo){
      this->SetNSigEvents(nsig);
      this->SetNBkgEvents(nbkg);
      this->SetNEvents(nEv);
      this->SetNSigEvents_unweighted(nsig_unweighted);
      this->SetNBkgEvents_unweighted(nbkg_unweighted);
      this->SetNEvents_unweighted(nEv_unweighted);
      this->SetSeparationIndex(separationIndex);
      this->SetSeparationGain(separationGain);
      this->SetPurity();
      this->SetCC(cc);
   }
   return kTRUE;
}
void TMVA::DecisionTreeNode::ClearNodeAndAllDaughters()
{
SetNSigEvents(0);
SetNBkgEvents(0);
SetNEvents(0);
SetNSigEvents_unweighted(0);
SetNBkgEvents_unweighted(0);
SetNEvents_unweighted(0);
SetSeparationIndex(-1);
SetSeparationGain(-1);
SetPurity();
if (this->GetLeft() != NULL) ((DecisionTreeNode*)(this->GetLeft()))->ClearNodeAndAllDaughters();
if (this->GetRight() != NULL) ((DecisionTreeNode*)(this->GetRight()))->ClearNodeAndAllDaughters();
}
void TMVA::DecisionTreeNode::ResetValidationData( ) {
   // Zero this node's validation counters and target sums. Then recurse into
   // the daughters, but only when both of them exist.
   SetNBValidation( 0.0 );
   SetNSValidation( 0.0 );
   SetSumTarget( 0 );
   SetSumTarget2( 0 );

   if (GetLeft() == NULL || GetRight() == NULL) return;
   GetLeft()->ResetValidationData();
   GetRight()->ResetValidationData();
}
void TMVA::DecisionTreeNode::PrintPrune( ostream& os ) const {
   // Dump the pruning quantities of this node, one labelled value per line,
   // preceded by a separator line.
   os << "----------------------"                << std::endl;
   os << "|~T_t| " << GetNTerminal()             << std::endl;
   os << "R(t): "  << GetNodeR()                 << std::endl;
   os << "R(T_t): " << GetSubTreeR()             << std::endl;
   os << "g(t): "  << GetAlpha()                 << std::endl;
   os << "G(t): "  << GetAlphaMinSubtree()       << std::endl;
}
void TMVA::DecisionTreeNode::PrintRecPrune( ostream& os ) const {
   // Pre-order dump of the pruning quantities of this node and its subtree.
   // The recursion only descends when both daughters exist.
   this->PrintPrune(os);
   if (this->GetLeft() == NULL || this->GetRight() == NULL) return;
   static_cast<const DecisionTreeNode*>(this->GetLeft())->PrintRecPrune(os);
   static_cast<const DecisionTreeNode*>(this->GetRight())->PrintRecPrune(os);
}
void TMVA::DecisionTreeNode::SetCC(Double_t cc)
{
   // Store the CC value in the training info. Calling this without training
   // info is a fatal usage error.
   if (!fTrainInfo) {
      *fgLogger << kFATAL << "call to SetCC without trainingInfo" << Endl;
      return;
   }
   fTrainInfo->fCC = cc;
}
Float_t TMVA::DecisionTreeNode::GetSampleMin(UInt_t ivar) const {
   // Return the stored minimum of variable ivar (see SetSampleMin).
   // A missing training info or an out-of-range ivar is reported as fatal;
   // -9999 is returned if the logger does not stop execution.
   const Bool_t available = (fTrainInfo != 0) && (ivar < fTrainInfo->fSampleMin.size());
   if (available) return fTrainInfo->fSampleMin[ivar];
   *fgLogger << kFATAL << "You asked for Min of the event sample in node for variable "
             << ivar << " that is out of range" << Endl;
   return -9999;
}
Float_t TMVA::DecisionTreeNode::GetSampleMax(UInt_t ivar) const {
   // Return the stored maximum of variable ivar (see SetSampleMax).
   // A missing training info or an out-of-range ivar is reported as fatal;
   // 9999 is returned if the logger does not stop execution.
   // The range check must use fSampleMax itself. The min and max vectors are
   // resized independently by SetSampleMin / SetSampleMax, so checking
   // fSampleMin's size could index past the end of fSampleMax.
   if (fTrainInfo && ivar < fTrainInfo->fSampleMax.size()) return fTrainInfo->fSampleMax[ivar];
   *fgLogger << kFATAL << "You asked for Max of the event sample in node for variable "
             << ivar << " that is out of range" << Endl;
   return 9999;
}
void TMVA::DecisionTreeNode::SetSampleMin(UInt_t ivar, Float_t xmin){
   // Record xmin as the minimum of variable ivar, growing the storage as needed.
   // This is silently ignored when the node has no training info.
   if (!fTrainInfo) return;
   if (fTrainInfo->fSampleMin.size() <= ivar) fTrainInfo->fSampleMin.resize(ivar+1);
   fTrainInfo->fSampleMin[ivar] = xmin;
}
void TMVA::DecisionTreeNode::SetSampleMax(UInt_t ivar, Float_t xmax){
   // Record xmax as the maximum of variable ivar, growing the storage as needed.
   // This is silently ignored when the node has no training info.
   if (fTrainInfo != 0) {
      if (fTrainInfo->fSampleMax.size() <= ivar) {
         fTrainInfo->fSampleMax.resize(ivar+1);
      }
      fTrainInfo->fSampleMax[ivar] = xmax;
   }
}
void TMVA::DecisionTreeNode::ReadAttributes(void* node, UInt_t )
{
Float_t tempNSigEvents,tempNBkgEvents;
Int_t nCoef;
if (gTools().HasAttr(node, "NCoef")){
gTools().ReadAttr(node, "NCoef", nCoef );
this->SetNFisherCoeff(nCoef);
Double_t tmp;
for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) {
gTools().ReadAttr(node, Form("fC%d",i), tmp );
this->SetFisherCoeff(i,tmp);
}
}else{
this->SetNFisherCoeff(0);
}
gTools().ReadAttr(node, "IVar", fSelector );
gTools().ReadAttr(node, "Cut", fCutValue );
gTools().ReadAttr(node, "cType", fCutType );
if (gTools().HasAttr(node,"res")) gTools().ReadAttr(node, "res", fResponse);
if (gTools().HasAttr(node,"rms")) gTools().ReadAttr(node, "rms", fRMS);
if( gTools().HasAttr(node, "purity") ) {
gTools().ReadAttr(node, "purity",fPurity );
} else {
gTools().ReadAttr(node, "nS", tempNSigEvents );
gTools().ReadAttr(node, "nB", tempNBkgEvents );
fPurity = tempNSigEvents / (tempNSigEvents + tempNBkgEvents);
}
gTools().ReadAttr(node, "nType", fNodeType );
}
void TMVA::DecisionTreeNode::AddAttributesToNode(void* node) const
{
gTools().AddAttr(node, "NCoef", GetNFisherCoeff());
for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++)
gTools().AddAttr(node, Form("fC%d",i), this->GetFisherCoeff(i));
gTools().AddAttr(node, "IVar", GetSelector());
gTools().AddAttr(node, "Cut", GetCutValue());
gTools().AddAttr(node, "cType", GetCutType());
gTools().AddAttr(node, "res", GetResponse());
gTools().AddAttr(node, "rms", GetRMS());
gTools().AddAttr(node, "purity",GetPurity());
gTools().AddAttr(node, "nType", GetNodeType());
}
void TMVA::DecisionTreeNode::SetFisherCoeff(Int_t ivar, Double_t coeff)
{
   // Store coefficient number ivar of the linear (Fisher) cut, growing the
   // coefficient vector when ivar lies beyond its current end.
   const Int_t requiredSize = ivar + 1;
   if (requiredSize > (Int_t) fFisherCoeff.size()) fFisherCoeff.resize(requiredSize);
   fFisherCoeff[ivar] = coeff;
}
void TMVA::DecisionTreeNode::AddContentToNode( std::stringstream& ) const
{
   // No-op: this class writes no node content here. The node's fields are
   // written as attributes by AddAttributesToNode().
}
void TMVA::DecisionTreeNode::ReadContent( std::stringstream& )
{
   // No-op: nothing is read from node content. The node is restored from its
   // attributes by ReadAttributes().
}