#include <iostream>
#include <algorithm>
#include <vector>
#include "TMath.h"
#include "TMVA/DecisionTree.h"
#include "TMVA/DecisionTreeNode.h"
#include "TMVA/BinarySearchTree.h"
#include "TMVA/Tools.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/SdivSqrtSplusB.h"
#include "TMVA/Event.h"
using std::vector;
#define USE_HELGESCODE 1 // the other one is Dougs implementation of the TrainNode
#define USE_HELGE_V1 0 // out loop is over NVAR in TrainNode, inner loop is Eventloop
ClassImp(TMVA::DecisionTree)
// Default constructor.
// Creates an empty tree: no variables, fNCuts = -1, no separation criterion,
// no quality index, cost-complexity pruning selected, randomisation off.
// A private TRandom2 generator is created with seed 0; it is released in the
// destructor.
TMVA::DecisionTree::DecisionTree( void )
   : BinaryTree(),
     fNvars          (0),
     fNCuts          (-1),
     fSepType        (NULL),
     fMinSize        (0),
     fPruneMethod    (kCostComplexityPruning),
     fRandomisedTree (kFALSE),
     fUseNvars       (0),
     fMyTrandom      (new TRandom2(0)),
     fQualityIndex   (NULL)
{
   fLogger.SetSource( "DecisionTree" );
}
// Constructor from an already existing node structure.
// The node n becomes the root of this tree and every node reachable from it
// gets this tree registered as its parent tree (SetParentTreeInNodes).
// All steering members get the same defaults as in the default constructor,
// except that no random generator is created here (fMyTrandom stays NULL);
// fRandomisedTree is kFALSE, so the generator is not needed unless
// randomisation is switched on later.
TMVA::DecisionTree::DecisionTree( DecisionTreeNode* n )
   : BinaryTree(),
     fNvars          (0),
     fNCuts          (-1),
     fSepType        (NULL),
     fMinSize        (0),
     fPruneMethod    (kCostComplexityPruning),
     fRandomisedTree (kFALSE),
     fUseNvars       (0),
     fMyTrandom      (NULL),
     fQualityIndex   (NULL)
{
   // set the logger source once, before anything below may emit a message
   fLogger.SetSource( "DecisionTree" );
   this->SetRoot( n );
   this->SetParentTreeInNodes();
}
// Constructor used to set up a tree that is to be trained.
//   sepType        separation criterion stored in fSepType
//   minSize        stored in fMinSize
//   nCuts          stored in fNCuts
//   qtype          quality index stored in fQualityIndex
//   randomisedTree stored in fRandomisedTree
//   useNvars       stored in fUseNvars
//   iSeed          seed handed to the private TRandom2 generator
// The pruning method defaults to cost-complexity pruning.
TMVA::DecisionTree::DecisionTree( TMVA::SeparationBase *sepType,Int_t minSize,
                                  Int_t nCuts, TMVA::SeparationBase *qtype,
                                  Bool_t randomisedTree, Int_t useNvars, Int_t iSeed )
   : BinaryTree(),
     fNvars          (0),
     fNCuts          (nCuts),
     fSepType        (sepType),
     fMinSize        (minSize),
     fPruneMethod    (kCostComplexityPruning),
     fRandomisedTree (randomisedTree),
     fUseNvars       (useNvars),
     fMyTrandom      (new TRandom2(iSeed)),
     fQualityIndex   (qtype)
{
   fLogger.SetSource( "DecisionTree" );
}
// Copy constructor.
// Copies the steering members of d and copy-constructs a new root node from
// d's root; the separation criterion and quality index pointers are shared
// with d (not cloned).
// The copy gets its own, fresh random generator (seed 0) rather than a NULL
// pointer, so that a copied randomised tree can still be trained; the
// generator state of d is NOT duplicated.
// If d has no root node the copy is left without a root as well instead of
// dereferencing a NULL pointer.
TMVA::DecisionTree::DecisionTree( const DecisionTree &d )
   : BinaryTree(),
     fNvars          (d.fNvars),
     fNCuts          (d.fNCuts),
     fSepType        (d.fSepType),
     fMinSize        (d.fMinSize),
     fPruneMethod    (d.fPruneMethod),
     fRandomisedTree (d.fRandomisedTree),
     fUseNvars       (d.fUseNvars),
     fMyTrandom      (new TRandom2(0)),
     fQualityIndex   (d.fQualityIndex)
{
   // name the logger first so that messages issued below carry the source
   fLogger.SetSource( "DecisionTree" );
   if (d.GetRoot() != NULL) {
      this->SetRoot( new DecisionTreeNode ( *((DecisionTreeNode*)(d.GetRoot())) ) );
      this->SetParentTreeInNodes();
   }
   fNNodes = d.fNNodes;
}
// Destructor: releases the private random generator (if one was created).
TMVA::DecisionTree::~DecisionTree( void )
{
   // deleting a NULL pointer is a no-op, so no explicit check is required
   delete fMyTrandom;
}
// Walks the (sub)tree below n and registers this tree as parent tree in
// every node; along the way the total tree depth is raised to the largest
// node depth encountered.
// Calling it with n == NULL starts at the root node. A missing root, or a
// node with exactly one daughter, is reported at kFATAL level.
void TMVA::DecisionTree::SetParentTreeInNodes( DecisionTreeNode *n )
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) {
         fLogger << kFATAL << "SetParentTreeNodes: started with undefined ROOT node" <<Endl;
         return ;
      }
   }

   DecisionTreeNode* const leftKid  = this->GetLeftDaughter(n);
   DecisionTreeNode* const rightKid = this->GetRightDaughter(n);
   const Bool_t hasLeft  = (leftKid  != NULL);
   const Bool_t hasRight = (rightKid != NULL);

   // a node must have either both daughters or none
   if (hasLeft != hasRight) {
      fLogger << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   }

   // daughters first, then the node itself
   if (hasLeft) {
      this->SetParentTreeInNodes( leftKid );
      this->SetParentTreeInNodes( rightKid );
   }

   n->SetParentTree(this);
   if (n->GetDepth() > this->GetTotalTreeDepth()) {
      this->SetTotalTreeDepth(n->GetDepth());
   }
}
// Recursively builds the decision tree from the given event sample.
//   eventSample  events that end up in the node currently being built
//   node         node to build; NULL means "start a new tree": a root node
//                is created and the node counter is reset to 1
// Returns the number of nodes counted so far (fNNodes).
// For the node the weighted and unweighted signal/background sums and the
// separation index are stored. If the sample holds at least 2*fMinSize
// events, TrainNode is asked for the best cut; a separation gain of exactly 0
// turns the node into a leaf, otherwise the sample is divided with
// node->GoesRight() and both daughters are built recursively. Leaves get node
// type 1 if their purity is above 0.5 and -1 otherwise.
Int_t TMVA::DecisionTree::BuildTree( vector<TMVA::Event*> & eventSample,
TMVA::DecisionTreeNode *node)
{
// first call: create the root node ('s' position, depth 0)
if (node==NULL) {
node = new TMVA::DecisionTreeNode();
fNNodes = 1;
this->SetRoot(node);
this->GetRoot()->SetPos('s');
this->GetRoot()->SetDepth(0);
this->GetRoot()->SetParentTree(this);
}
UInt_t nevents = eventSample.size();
// the number of variables is taken from the first event of the sample
if (nevents > 0 ) {
fNvars = eventSample[0]->GetNVars();
fVariableImportance.resize(fNvars);
}
else fLogger << kFATAL << ":<BuildTree> eventsample Size == 0 " << Endl;
// s, b: summed event weights; suw, buw: plain event counts
Double_t s=0, b=0;
Double_t suw=0, buw=0;
for (UInt_t i=0; i<eventSample.size(); i++){
if (eventSample[i]->IsSignal()){
s += eventSample[i]->GetWeight();
suw += 1;
}
else {
b += eventSample[i]->GetWeight();
buw += 1;
}
}
// a negative weight sum is only warned about; the background events and their
// weights are then dumped to standard output for inspection
if (s+b < 0){
fLogger << kWARNING << " One of the Decision Tree nodes has negative total number of signal or background events. (Nsig="<<s<<" Nbkg="<<b<<" Probaby you use a Monte Carlo with negative weights. That should in principle be fine as long as on average you end up with something positive. For this you have to make sure that the minimul number of (unweighted) events demanded for a tree node (currently you use: nEventsMin="<<fMinSize<<", you can set this via the BDT option string when booking the classifier) is large enough to allow for reasonable averaging!!!"<<Endl
<< " If this does not help.. maybe you want to try the option: NoNegWeightsInTraining which ignores events with negative weight in the training. " << Endl;
double nBkg=0.;
for (UInt_t i=0; i<eventSample.size(); i++){
if (! eventSample[i]->IsSignal()){
nBkg += eventSample[i]->GetWeight();
cout << "Event "<< i<< " has (original) weight: " << eventSample[i]->GetWeight()/eventSample[i]->GetBoostWeight()
<< " boostWeight: " << eventSample[i]->GetBoostWeight() << endl;
}
}
cout << " that gives in total: " << nBkg<<endl;
}
node->SetNSigEvents(s);
node->SetNBkgEvents(b);
node->SetNSigEvents_unweighted(suw);
node->SetNBkgEvents_unweighted(buw);
// only the root sets its own totals here; for daughters they are set by the
// mother when the daughters are created further below
if (node == this->GetRoot()) {
node->SetNEvents(s+b);
node->SetNEvents_unweighted(suw+buw);
}
node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));
// a split is only attempted if the sample holds at least 2*fMinSize events
if ( eventSample.size() >= 2*fMinSize){
Double_t separationGain;
separationGain = this->TrainNode(eventSample, node);
// TrainNode returning exactly 0 means no acceptable cut: make this a leaf
if (separationGain == 0) {
if (node->GetPurity() > 0.5) node->SetNodeType(1);
else node->SetNodeType(-1);
if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
}
else {
// divide the sample according to the cut stored in the node
vector<TMVA::Event*> leftSample; leftSample.reserve(nevents);
vector<TMVA::Event*> rightSample; rightSample.reserve(nevents);
Double_t nRight=0, nLeft=0;
for (UInt_t ie=0; ie< nevents ; ie++){
if (node->GoesRight(*eventSample[ie])){
rightSample.push_back(eventSample[ie]);
nRight += eventSample[ie]->GetWeight();
}
else {
leftSample.push_back(eventSample[ie]);
nLeft += eventSample[ie]->GetWeight();
}
}
if (leftSample.size() == 0 || rightSample.size() == 0) {
fLogger << kFATAL << "<TrainNode> all events went to the same branch" << Endl
<< "--- Hence new node == old node ... check" << Endl
<< "--- left:" << leftSample.size()
<< " right:" << rightSample.size() << Endl
<< "--- this should never happen, please write a bug report to Helge.Voss@cern.ch"
<< Endl;
}
// create both daughters, record their weighted/unweighted event totals
TMVA::DecisionTreeNode *rightNode = new TMVA::DecisionTreeNode(node,'r');
fNNodes++;
rightNode->SetNEvents(nRight);
rightNode->SetNEvents_unweighted(rightSample.size());
TMVA::DecisionTreeNode *leftNode = new TMVA::DecisionTreeNode(node,'l');
fNNodes++;
leftNode->SetNEvents(nLeft);
leftNode->SetNEvents_unweighted(leftSample.size());
// node type 0 marks an intermediate (non-leaf) node
node->SetNodeType(0);
node->SetLeft(leftNode);
node->SetRight(rightNode);
// recurse: right subtree first, then left subtree
this->BuildTree(rightSample, rightNode);
this->BuildTree(leftSample, leftNode );
}
}
else{
// too few events to split: make this a leaf
if (node->GetPurity() > 0.5) node->SetNodeType(1);
else node->SetNodeType(-1);
if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
}
return fNNodes;
}
// Sends every event of the sample down the existing tree, starting at the
// root, so that the per-node event counters are filled (see FillEvent).
void TMVA::DecisionTree::FillTree( vector<TMVA::Event*> & eventSample )
{
   typedef vector<TMVA::Event*>::iterator EventIter;
   for (EventIter evIt = eventSample.begin(); evIt != eventSample.end(); ++evIt) {
      // NULL as node argument means "start at the root"
      this->FillEvent( **evIt, NULL );
   }
}
// Books one event in the given node and passes it on down the tree.
//   event  the event to book
//   node   node to start at; NULL means the root node
// In every node visited the weighted and unweighted totals and the
// signal or background counters are incremented and the separation index is
// recomputed from the node's weighted signal/background sums. For
// intermediate nodes (node type 0) the event is then handed to the daughter
// selected by node->GoesRight(event).
void TMVA::DecisionTree::FillEvent( TMVA::Event & event,
                                    TMVA::DecisionTreeNode *node )
{
   if (node == NULL) {
      node = (TMVA::DecisionTreeNode*)this->GetRoot();
   }

   node->IncrementNEvents( event.GetWeight() );
   node->IncrementNEvents_unweighted( );

   if (event.IsSignal()) {
      node->IncrementNSigEvents( event.GetWeight() );
      node->IncrementNSigEvents_unweighted( );
   }
   else {
      node->IncrementNBkgEvents( event.GetWeight() );
      // background events must raise the unweighted BACKGROUND counter
      // (the signal counter was incremented here before, which corrupted
      // both unweighted counts)
      node->IncrementNBkgEvents_unweighted( );
   }

   node->SetSeparationIndex(fSepType->GetSeparationIndex(node->GetNSigEvents(),
                                                         node->GetNBkgEvents()));

   // intermediate node: continue in the matching daughter
   if (node->GetNodeType() == 0) {
      if (node->GoesRight(event))
         this->FillEvent(event,(TMVA::DecisionTreeNode*)(node->GetRight())) ;
      else
         this->FillEvent(event,(TMVA::DecisionTreeNode*)(node->GetLeft())) ;
   }
}
void TMVA::DecisionTree::ClearTree()
{
if (this->GetRoot()!=NULL)
((DecisionTreeNode*)(this->GetRoot()))->ClearNodeAndAllDaughters();
}
// Removes splits that do not change the classification: after cleaning both
// subtrees of an intermediate node, the node is pruned if the product of the
// node types of its two daughters is positive (both daughters carry the same
// non-zero type).
//   node  subtree to clean; NULL means the root node
void TMVA::DecisionTree::CleanTree( DecisionTreeNode *node )
{
   if (node == NULL) {
      node = (DecisionTreeNode *)this->GetRoot();
   }

   // only intermediate nodes (type 0) have something to clean
   if (node->GetNodeType() != 0) return;

   DecisionTreeNode* const lower  = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode* const higher = (DecisionTreeNode*)node->GetRight();

   this->CleanTree(lower);
   this->CleanTree(higher);

   if (lower->GetNodeType() * higher->GetNodeType() > 0) {
      this->PruneNode(node);
   }
}
// Prunes the tree with the method selected in fPruneMethod (expected error,
// cost complexity or minimal cost complexity pruning) and recounts the
// nodes afterwards. Any other method value is reported at kFATAL level.
void TMVA::DecisionTree::PruneTree()
{
   switch (fPruneMethod) {
   case kExpectedErrorPruning:
      this->PruneTreeEEP((DecisionTreeNode *)this->GetRoot());
      break;
   case kCostComplexityPruning:
      this->PruneTreeCC();
      break;
   case kMCC:
      this->PruneTreeMCC();
      break;
   default:
      fLogger << kFATAL << "Selected pruning method not yet implemented "
              << Endl;
      break;
   }
   this->CountNodes();
}
// Expected error pruning, applied recursively from the bottom up: an
// intermediate node is turned into a leaf if the error of its subtree
// (GetSubTreeError) is not smaller than the error the node would have as a
// leaf (GetNodeError).
void TMVA::DecisionTree::PruneTreeEEP( DecisionTreeNode *node )
{
   // leaves cannot be pruned any further
   if (node->GetNodeType() != 0) return;

   DecisionTreeNode* const lower  = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode* const higher = (DecisionTreeNode*)node->GetRight();

   this->PruneTreeEEP(lower);
   this->PruneTreeEEP(higher);

   if (this->GetSubTreeError(node) >= this->GetNodeError(node)) {
      this->PruneNode(node);
   }
}
void TMVA::DecisionTree::PruneTreeCC()
{
Double_t currentCC = this->GetCostComplexity(fPruneStrength);
Double_t nextCC = this->GetCostComplexityIfNextPruneStep(fPruneStrength);
while (currentCC > nextCC && this->GetNNodes() > 3 ){
this->PruneNode( this->FindCCPruneCandidate() );
currentCC = this->GetCostComplexity(fPruneStrength);
nextCC = this->GetCostComplexityIfNextPruneStep(fPruneStrength);
}
return;
}
// Minimal cost complexity pruning: repeatedly prunes the node with the
// smallest link strength as long as that strength is below fPruneStrength
// and the tree has more than 3 nodes.
// The link-strength map is refilled after every prune step, so the loop
// condition is always evaluated on the strength of the node that would be
// pruned NEXT. (Previously the condition looked at the stale entry of the
// node just removed, which pruned one node more than the threshold allows.)
// An empty map, e.g. a tree without any split, ends the loop right away.
void TMVA::DecisionTree::PruneTreeMCC()
{
   this->FillLinkStrengthMap();

   while (!fLinkStrengthMap.empty() && this->GetNNodes() > 3) {
      // the map is ordered by link strength: begin() is the weakest link
      if ( !(fLinkStrengthMap.begin()->first < fPruneStrength) ) break;

      this->PruneNode( fLinkStrengthMap.begin()->second );

      // the pruned subtree is gone: rebuild the map before looking at it again
      this->FillLinkStrengthMap();
   }
}
// Rebuilds the link-strength map for the whole tree and returns the node
// stored in its first entry, i.e. the one with the smallest link strength.
// The map is not checked for being empty.
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetWeakestLink()
{
   this->FillLinkStrengthMap();
   TMVA::DecisionTreeNode* const weakest = fLinkStrengthMap.begin()->second;
   return weakest;
}
// Fills fLinkStrengthMap with one entry per intermediate node of the subtree
// below n. The key (link strength) of a node is
//    ( cost of the node as a leaf - cost of its subtree ) / ( nodes in subtree - 1 )
// Called with n == NULL, the map is cleared first and the walk starts at the
// root node; a missing root is reported at kFATAL level.
void TMVA::DecisionTree::FillLinkStrengthMap( TMVA::DecisionTreeNode *n )
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      fLinkStrengthMap.clear();
      if (n == NULL) {
         fLogger << kFATAL << "FillLinkStrengthMap: started with undefined ROOT node" <<Endl;
         return ;
      }
   }

   DecisionTreeNode* const leftKid  = this->GetLeftDaughter(n);
   DecisionTreeNode* const rightKid = this->GetRightDaughter(n);

   // daughters are entered before their mother
   if (leftKid  != NULL) this->FillLinkStrengthMap( leftKid );
   if (rightKid != NULL) this->FillLinkStrengthMap( rightKid );

   // only nodes with two daughters carry a link strength
   if (leftKid == NULL || rightKid == NULL) return;

   Double_t alpha = ( this->MisClassificationCostOfNode(n) -
                      this->MisClassificationCostOfSubTree(n) ) /
                    (n->CountMeAndAllDaughters() - 1);
   fLinkStrengthMap.insert( std::make_pair( alpha, n ) );
}
// Returns (1 - purity) times the number of events of node n.
Double_t TMVA::DecisionTree::MisClassificationCostOfNode( TMVA::DecisionTreeNode *n )
{
   return n->GetNEvents() * (1 - n->GetPurity());
}
// Returns the sum of MisClassificationCostOfNode over all leaf nodes of the
// subtree below n. n == NULL means the root node; a missing root is
// reported at kFATAL level and 0 is returned.
Double_t TMVA::DecisionTree::MisClassificationCostOfSubTree( TMVA::DecisionTreeNode *n )
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) {
         fLogger << kFATAL << "MisClassificationCostOfSubTree: started with undefined ROOT node" <<Endl;
         return 0.;
      }
   }

   DecisionTreeNode* const leftKid  = this->GetLeftDaughter(n);
   DecisionTreeNode* const rightKid = this->GetRightDaughter(n);

   // a leaf contributes its own cost
   if (leftKid == NULL && rightKid == NULL) {
      return this->MisClassificationCostOfNode(n);
   }

   // otherwise add up whatever daughters exist
   Double_t subTreeCost = 0;
   if (leftKid  != NULL) subTreeCost += this->MisClassificationCostOfSubTree( leftKid );
   if (rightKid != NULL) subTreeCost += this->MisClassificationCostOfSubTree( rightKid );
   return subTreeCost;
}
// Returns the number of leaf nodes (nodes without any daughter) in the
// subtree below n. n == NULL means the root node; a missing root is reported
// at kFATAL level and 0 is returned.
UInt_t TMVA::DecisionTree::CountLeafNodes( TMVA::DecisionTreeNode *n )
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) {
         fLogger << kFATAL << "CountLeafNodes: started with undefined ROOT node" <<Endl;
         return 0;
      }
   }

   DecisionTreeNode* const leftKid  = this->GetLeftDaughter(n);
   DecisionTreeNode* const rightKid = this->GetRightDaughter(n);

   // no daughters at all: this node is a leaf itself
   if (leftKid == NULL && rightKid == NULL) return 1;

   UInt_t nLeaves = 0;
   if (leftKid  != NULL) nLeaves += this->CountLeafNodes( leftKid );
   if (rightKid != NULL) nLeaves += this->CountLeafNodes( rightKid );
   return nLeaves;
}
// Returns the cost complexity of the current tree:
//    sum over the entries of the quality map of
//       (unweighted number of events in the node) * (map key)
//    + alpha * (number of entries)
// The quality map is refilled on every call.
Double_t TMVA::DecisionTree::GetCostComplexity( Double_t alpha )
{
   this->FillQualityMap();

   Double_t weightedQuality = 0.;
   Int_t    nEntries        = 0;
   std::multimap<Double_t, TMVA::DecisionTreeNode* >::iterator entry;
   for (entry = fQualityMap.begin(); entry != fQualityMap.end(); ++entry, ++nEntries) {
      const Double_t nSig = entry->second->GetNSigEvents_unweighted();
      const Double_t nBkg = entry->second->GetNBkgEvents_unweighted();
      weightedQuality += (nSig+nBkg) * entry->first ;
   }
   return weightedQuality+alpha * nEntries;
}
// Returns the cost complexity the tree WOULD have after the next cost
// complexity prune step, i.e. after turning the first node of the quality
// gain map into a leaf:
//    - every quality map entry whose parent is not that node contributes
//      (unweighted events) * (map key),
//    - the node itself contributes (unweighted events) * separation index
//      computed with fQualityIndex,
//    - alpha is added once per contributing node.
// Both maps are refilled on every call. If either map is empty an error is
// logged and 0 is returned.
// NOTE: the messages are issued with the logger message type kERROR. The
// previously used name kError is ROOT's integer error-level constant, which
// the logger would stream as a plain number instead of a severity
// (assumes the TMVA message type enumerator kERROR is visible here, as
// kWARNING and kFATAL are - TODO confirm against MsgLogger.h).
Double_t TMVA::DecisionTree::GetCostComplexityIfNextPruneStep( Double_t alpha )
{
   Double_t cc=0.;
   this->FillQualityMap();
   this->FillQualityGainMap();

   if (fQualityMap.size() == 0 ){
      fLogger << kERROR << "The Quality Map in the BDT-Pruning is empty.. maybe your Tree has "
              << " absolutely no splits ?? e.g.. minimun number of events for node splitting"
              << " being larger than the number of events available ??? " << Endl;
   }
   else if (fQualityGainMap.size() == 0 ){
      fLogger << kERROR << "The QualityGain Map in the BDT-Pruning is empty.. This can happen"
              << " if your Tree has absolutely no splits ?? e.g.. minimun number of events for"
              << " node splitting being larger than the number of events available ??? " << Endl;
   }
   else {
      // the node that the next prune step would turn into a leaf
      TMVA::DecisionTreeNode* const candidate = fQualityGainMap.begin()->second;

      Int_t count=0;
      std::multimap<Double_t, TMVA::DecisionTreeNode* >::iterator it=fQualityMap.begin();
      for (;it!=fQualityMap.end(); ++it){
         // daughters of the candidate would disappear, so they are skipped
         if (it->second->GetParent() != candidate ) {
            Double_t s=it->second->GetNSigEvents_unweighted();
            Double_t b=it->second->GetNBkgEvents_unweighted();
            cc += (s+b) * it->first ;
            count++;
         }
      }

      // the candidate itself enters as a new leaf
      Double_t s=candidate->GetNSigEvents_unweighted();
      Double_t b=candidate->GetNBkgEvents_unweighted();
      cc += (s+b) * fQualityIndex->GetSeparationIndex(s,b);
      count++;

      cc+=alpha*count;
   }
   return cc;
}
// Fills fQualityGainMap with every node of the subtree below n whose two
// daughters both have no daughters of their own. The key is the separation
// gain that fQualityIndex computes from the unweighted signal/background
// counts of the right daughter and of the node itself.
// Called with n == NULL, the map is cleared first and the walk starts at the
// root node; a missing root is reported at kFATAL level.
void TMVA::DecisionTree::FillQualityGainMap( DecisionTreeNode* n )
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      fQualityGainMap.clear();
      if (n == NULL) {
         fLogger << kFATAL << "FillQualityGainMap: started with undefined ROOT node" <<Endl;
         return ;
      }
   }

   DecisionTreeNode* const leftKid  = this->GetLeftDaughter(n);
   DecisionTreeNode* const rightKid = this->GetRightDaughter(n);

   if (leftKid  != NULL) this->FillQualityGainMap( leftKid );
   if (rightKid != NULL) this->FillQualityGainMap( rightKid );

   // nothing to enter unless both daughters exist ...
   if (leftKid == NULL || rightKid == NULL) return;

   // ... and neither of them has daughters itself
   const Bool_t leftIsLeaf  = (leftKid->GetLeft()  == NULL) && (leftKid->GetRight()  == NULL);
   const Bool_t rightIsLeaf = (rightKid->GetLeft() == NULL) && (rightKid->GetRight() == NULL);
   if (!leftIsLeaf || !rightIsLeaf) return;

   fQualityGainMap.insert(std::pair<const Double_t, TMVA::DecisionTreeNode* >
                          ( fQualityIndex->GetSeparationGain (rightKid->GetNSigEvents_unweighted(),
                                                              rightKid->GetNBkgEvents_unweighted(),
                                                              n->GetNSigEvents_unweighted(),
                                                              n->GetNBkgEvents_unweighted()),
                            n));
}
// Fills fQualityMap with every leaf node (no daughters) of the subtree below
// n. The key is the separation index that fQualityIndex computes from the
// unweighted signal and background counts of the leaf.
// Called with n == NULL, the map is cleared first and the walk starts at the
// root node; a missing root is reported at kFATAL level.
void TMVA::DecisionTree::FillQualityMap( DecisionTreeNode* n )
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      fQualityMap.clear();
      if (n == NULL) {
         fLogger << kFATAL << "FillQualityMap: started with undefined ROOT node" <<Endl;
         return ;
      }
   }

   DecisionTreeNode* const leftKid  = this->GetLeftDaughter(n);
   DecisionTreeNode* const rightKid = this->GetRightDaughter(n);

   if (leftKid  != NULL) this->FillQualityMap( leftKid );
   if (rightKid != NULL) this->FillQualityMap( rightKid );

   // only leaves are entered
   if (leftKid != NULL || rightKid != NULL) return;

   fQualityMap.insert(std::pair<const Double_t, TMVA::DecisionTreeNode* >
                      ( fQualityIndex->GetSeparationIndex (n->GetNSigEvents_unweighted(),
                                                           n->GetNBkgEvents_unweighted()),
                        n));
}
// Walks through the subtree below n and checks its structure: every node
// must have either two daughters or none; a node with exactly one daughter
// is reported at kFATAL level. n == NULL means the root node; a missing root
// is reported at kFATAL level as well.
void TMVA::DecisionTree::DescendTree( DecisionTreeNode *n )
{
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) {
         fLogger << kFATAL << "DescendTree: started with undefined ROOT node" <<Endl;
         return ;
      }
   }

   DecisionTreeNode* const leftKid  = this->GetLeftDaughter(n);
   DecisionTreeNode* const rightKid = this->GetRightDaughter(n);

   // leaf: nothing below to visit
   if (leftKid == NULL && rightKid == NULL) return;

   // exactly one daughter missing
   if (leftKid == NULL || rightKid == NULL) {
      fLogger << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   }

   this->DescendTree( leftKid );
   this->DescendTree( rightKid );
}
// Rebuilds the quality gain map for the whole tree and returns the node
// stored in its first entry (smallest key). The map is not checked for
// being empty.
TMVA::DecisionTreeNode* TMVA::DecisionTree::FindCCPruneCandidate()
{
   this->FillQualityGainMap();
   TMVA::DecisionTreeNode* const candidate = fQualityGainMap.begin()->second;
   return candidate;
}
// Turns the given node into a leaf: its daughter links are cleared, selector,
// separation index and separation gain are set to -1, the node type becomes
// 1 (purity above 0.5) or -1, and the two former daughters are deleted via
// DeleteNode. The node count is refreshed at the end.
void TMVA::DecisionTree::PruneNode( DecisionTreeNode *node )
{
   // remember the daughters before the links are cut
   DecisionTreeNode* const oldLeft  = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode* const oldRight = (DecisionTreeNode*)node->GetRight();

   node->SetRight(NULL);
   node->SetLeft(NULL);
   node->SetSelector(-1);
   node->SetSeparationIndex(-1);
   node->SetSeparationGain(-1);
   node->SetNodeType( (node->GetPurity() > 0.5) ? 1 : -1 );

   this->DeleteNode(oldLeft);
   this->DeleteNode(oldRight);
   this->CountNodes();
}
// Returns the error estimate of a node when treated as a leaf:
//    f  = larger of purity and (1 - purity)
//    df = sqrt( f*(1-f) / number of events in the node )
//    error = min( 1, 1 - (f - fPruneStrength*df) )
Double_t TMVA::DecisionTree::GetNodeError( DecisionTreeNode *node )
{
   const Double_t nEvts = node->GetNEvents();

   // fraction of events on the majority side of the node
   const Double_t f = (node->GetPurity() > 0.5) ? node->GetPurity()
                                                : (1-node->GetPurity());

   const Double_t df = TMath::Sqrt(f*(1-f)/nEvts );
   return std::min(1.,(1 - (f-fPruneStrength*df) ));
}
// Returns the error of the subtree below node: for a leaf this is
// GetNodeError(node); for an intermediate node (type 0) it is the average of
// the subtree errors of both daughters, weighted with their event numbers
// and divided by the event number of the node.
Double_t TMVA::DecisionTree::GetSubTreeError( DecisionTreeNode *node )
{
   if (node->GetNodeType() != 0) {
      return this->GetNodeError(node);
   }

   DecisionTreeNode* const lower  = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode* const higher = (DecisionTreeNode*)node->GetRight();

   return (lower->GetNEvents()  * this->GetSubTreeError(lower) +
           higher->GetNEvents() * this->GetSubTreeError(higher)) /
          node->GetNEvents();
}
// Returns the left daughter of n, cast to DecisionTreeNode* (NULL if the node
// has no left daughter). n itself must not be NULL.
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetLeftDaughter( DecisionTreeNode *n )
{
   DecisionTreeNode* const daughter = (DecisionTreeNode*) n->GetLeft();
   return daughter;
}
// Returns the right daughter of n, cast to DecisionTreeNode* (NULL if the
// node has no right daughter). n itself must not be NULL.
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetRightDaughter( DecisionTreeNode *n )
{
   DecisionTreeNode* const daughter = (DecisionTreeNode*) n->GetRight();
   return daughter;
}
// Returns the node reached from the root by following 'depth' steps encoded
// in 'sequence': bit i of sequence (least significant bit = first step)
// selects the right daughter when set and the left daughter otherwise.
// Returns NULL if the path leaves the tree (a daughter on the way does not
// exist) or if the tree has no root.
// The bit mask is built in ULong_t: the former 'int' shift (1 << i) was
// undefined for i >= 31 although sequence can hold more bits. Steps beyond
// the width of ULong_t cannot be encoded in sequence and are taken to the
// left.
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetNode( ULong_t sequence, UInt_t depth )
{
   const UInt_t nBits = 8 * sizeof(ULong_t);

   DecisionTreeNode* current = (DecisionTreeNode*) this->GetRoot();
   for (UInt_t i = 0; i < depth && current != NULL; i++) {
      const ULong_t mask = (i < nBits) ? (ULong_t(1) << i) : ULong_t(0);
      if (mask & sequence) current = this->GetRightDaughter(current);
      else                 current = this->GetLeftDaughter(current);
   }
   return current;
}
// Determines, for each of the fNvars variables, the smallest and the largest
// value found in the event sample.
//   eventSample  events to scan; element 0 is read unconditionally
//   xmin, xmax   output, indexed by variable; must already hold fNvars entries
void TMVA::DecisionTree::FindMinAndMax( vector<TMVA::Event*> & eventSample,
                                        vector<Double_t> & xmin,
                                        vector<Double_t> & xmax )
{
   const UInt_t nEv = eventSample.size();

   // start both bounds from the first event
   for (Int_t ivar = 0; ivar < fNvars; ++ivar) {
      xmin[ivar] = xmax[ivar] = eventSample[0]->GetVal(ivar);
   }

   // widen the bounds with all remaining events
   for (UInt_t iev = 1; iev < nEv; ++iev) {
      TMVA::Event* const ev = eventSample[iev];
      for (Int_t ivar = 0; ivar < fNvars; ++ivar) {
         const Double_t value = ev->GetVal(ivar);
         if (xmin[ivar] > value) xmin[ivar] = value;
         if (xmax[ivar] < value) xmax[ivar] = value;
      }
   }
}
// Fills cut_points[0 .. num_gridpoints-1] with equidistant values inside
// [xmin, xmax]: the interval is divided into num_gridpoints bins and each cut
// point sits at a bin centre, obtained by repeatedly adding the bin width to
// the first centre. cut_points must already hold num_gridpoints entries.
void TMVA::DecisionTree::SetCutPoints( vector<Double_t> & cut_points,
                                       Double_t xmin,
                                       Double_t xmax,
                                       Int_t num_gridpoints )
{
   const Double_t binWidth = (xmax - xmin)/num_gridpoints;
   Double_t centre = xmin + binWidth/2;
   for (Int_t ibin = 0; ibin < num_gridpoints; ++ibin, centre += binWidth) {
      cut_points[ibin] = centre;
   }
}
// Finds the best cut for the given node.
//   eventSample  events that arrived in this node
//   node         node that receives selector, cut value, cut type and gain
// Returns the separation gain of the chosen cut, or 0 if no cut leaves at
// least fMinSize (unweighted) events on both sides.
//
// For every variable in use, fNCuts equidistant cut values between the
// smallest and the largest value found in the sample are scanned; for each
// of them the weighted and unweighted signal/background sums of the events
// ABOVE the cut are accumulated and the separation gain is taken from
// fSepType. The variable importance of the winning variable is increased by
// (gain * total weight)^2.
//
// For randomised trees only a random subset of the variables is scanned.
// If fUseNvars is 0 it is derived from fNvars first. Exactly
// min(fUseNvars, fNvars) distinct variables are drawn; the former counting
// loop never reset its counter and therefore stopped after fewer variables
// than requested.
Double_t TMVA::DecisionTree::TrainNode( vector<TMVA::Event*> & eventSample,
                                        TMVA::DecisionTreeNode *node )
{
   const UInt_t nevents = eventSample.size();

   // without events or without a positive number of cuts there is nothing to
   // scan (and the grid below could not even be allocated)
   if (nevents == 0 || fNCuts <= 0) {
      fLogger << kFATAL << "<TrainNode> called with an empty event sample or fNCuts <= 0"
              << " (nevents=" << nevents << ", fNCuts=" << fNCuts << ")" << Endl;
      return 0;
   }

   // value range of each variable; plain local vectors, so nothing has to be
   // deleted by hand on any exit path
   vector<Double_t> xmin( fNvars );
   vector<Double_t> xmax( fNvars );

   Double_t separationGain = -1, sepTmp;
   Double_t cutValue=-999;
   Int_t    mxVar=-1, cutIndex=0;
   Bool_t   cutType=kTRUE;
   Double_t nTotS, nTotB;
   Int_t    nTotS_unWeighted, nTotB_unWeighted;

   for (int ivar=0; ivar < fNvars; ivar++){
      xmin[ivar]=xmax[ivar]=eventSample[0]->GetVal(ivar);
   }
   for (UInt_t iev=1;iev<nevents;iev++){
      for (Int_t ivar=0; ivar < fNvars; ivar++){
         Double_t eventData = eventSample[iev]->GetVal(ivar);
         if (xmin[ivar]>eventData) xmin[ivar]=eventData;
         if (xmax[ivar]<eventData) xmax[ivar]=eventData;
      }
   }

   // per variable and per cut: sums of the events above the cut
   vector< vector<Double_t> > nSelS (fNvars);
   vector< vector<Double_t> > nSelB (fNvars);
   vector< vector<Int_t> >    nSelS_unWeighted (fNvars);
   vector< vector<Int_t> >    nSelB_unWeighted (fNvars);
   vector< vector<Double_t> > cutValues(fNvars);

   // choose the variables to scan
   vector<Bool_t> useVariable(fNvars, kFALSE);
   if (fRandomisedTree) {
      if (fUseNvars==0) {
         // no number given: use a fraction of the variables
         if (fNvars < 12) fUseNvars = TMath::Max(2,Int_t( Float_t(fNvars) / 2.5 ));
         else if (fNvars < 40) fUseNvars = Int_t( Float_t(fNvars) / 5 );
         else fUseNvars = Int_t( Float_t(fNvars) / 10 );
      }
      // trees built through a constructor that does not create a generator
      // still need one here
      if (fMyTrandom == NULL) fMyTrandom = new TRandom2(0);

      // never wait for more distinct variables than exist
      const Int_t nWanted = (fUseNvars < fNvars) ? fUseNvars : fNvars;
      Int_t nSelectedVars = 0;
      while ( nSelectedVars < nWanted ){
         Int_t pick = Int_t( fMyTrandom->Rndm()*fNvars );
         if (pick >= fNvars) pick = fNvars - 1; // in case Rndm() ever returns exactly 1
         if (!useVariable[pick]) {
            // count each variable only once, when it is selected for the first time
            useVariable[pick] = kTRUE;
            nSelectedVars++;
         }
      }
   } else {
      for (int ivar=0; ivar < fNvars; ivar++) useVariable[ivar] = kTRUE;
   }

   // set up the grid of cut values: centres of fNCuts equal bins in [xmin,xmax]
   for (int ivar=0; ivar < fNvars; ivar++){
      if ( useVariable[ivar] ) {
         cutValues[ivar].resize(fNCuts);
         nSelS[ivar].resize(fNCuts);
         nSelB[ivar].resize(fNCuts);
         nSelS_unWeighted[ivar].resize(fNCuts);
         nSelB_unWeighted[ivar].resize(fNCuts);
         Double_t istepSize =( xmax[ivar] - xmin[ivar] ) / Double_t(fNCuts);
         for (Int_t icut=0; icut<fNCuts; icut++){
            cutValues[ivar][icut]=xmin[ivar]+(Float_t(icut)+0.5)*istepSize;
         }
      }
   }

   // single pass over the events: totals, and per-cut sums of events above the cut
   nTotS=0; nTotB=0;
   nTotS_unWeighted=0; nTotB_unWeighted=0;
   for (UInt_t iev=0; iev<nevents; iev++){
      Int_t eventType = eventSample[iev]->Type();
      Double_t eventWeight = eventSample[iev]->GetWeight();
      if (eventType==1){
         nTotS+=eventWeight;
         nTotS_unWeighted++;
      }
      else {
         nTotB+=eventWeight;
         nTotB_unWeighted++;
      }
      for (int ivar=0; ivar < fNvars; ivar++){
         if ( useVariable[ivar] ) {
            Double_t eventData = eventSample[iev]->GetVal(ivar);
            for (Int_t icut=0; icut<fNCuts; icut++){
               if (eventData > cutValues[ivar][icut]){
                  if (eventType==1) {
                     nSelS[ivar][icut]+=eventWeight;
                     nSelS_unWeighted[ivar][icut]++;
                  }
                  else {
                     nSelB[ivar][icut]+=eventWeight;
                     nSelB_unWeighted[ivar][icut]++;
                  }
               }
            }
         }
      }
   }

   // pick the cut with the largest separation gain among those that keep at
   // least fMinSize unweighted events on each side
   for (int ivar=0; ivar < fNvars; ivar++) {
      if ( useVariable[ivar] ){
         for (Int_t icut=0; icut<fNCuts; icut++){
            if ( (nSelS_unWeighted[ivar][icut] + nSelB_unWeighted[ivar][icut]) >= fMinSize &&
                 (( nTotS_unWeighted+nTotB_unWeighted)-
                  (nSelS_unWeighted[ivar][icut] + nSelB_unWeighted[ivar][icut])) >= fMinSize) {
               sepTmp = fSepType->GetSeparationGain(nSelS[ivar][icut], nSelB[ivar][icut], nTotS, nTotB);
               if (separationGain < sepTmp) {
                  separationGain = sepTmp;
                  mxVar = ivar;
                  cutIndex = icut;
               }
            }
         }
      }
   }

   if (mxVar >= 0) {
      // cut type: kTRUE if the signal fraction above the cut exceeds the
      // background fraction above the cut
      if (nSelS[mxVar][cutIndex]/nTotS > nSelB[mxVar][cutIndex]/nTotB) cutType=kTRUE;
      else cutType=kFALSE;
      cutValue = cutValues[mxVar][cutIndex];
      node->SetSelector((UInt_t)mxVar);
      node->SetCutValue(cutValue);
      node->SetCutType(cutType);
      node->SetSeparationGain(separationGain);
      fVariableImportance[mxVar] += separationGain*separationGain * (nTotS+nTotB) * (nTotS+nTotB) ;
   }
   else {
      // no admissible cut found: tell the caller to make this node a leaf
      separationGain = 0;
   }
   return separationGain;
}
// Sends the event down the tree, at every intermediate node (type 0) going
// right or left according to GoesRight(e), until a leaf is reached.
// Returns the node type of that leaf if UseYesNoLeaf is set, and the leaf's
// purity otherwise. The tree must have a root node.
Double_t TMVA::DecisionTree::CheckEvent( const TMVA::Event & e, Bool_t UseYesNoLeaf )
{
   TMVA::DecisionTreeNode *leaf = (TMVA::DecisionTreeNode*)this->GetRoot();
   while (leaf->GetNodeType() == 0) {
      leaf = leaf->GoesRight(e) ? (TMVA::DecisionTreeNode*)leaf->GetRight()
                                : (TMVA::DecisionTreeNode*)leaf->GetLeft();
   }

   if (UseYesNoLeaf) {
      return Double_t ( leaf->GetNodeType() );
   }
   return leaf->GetPurity();
}
// Returns the purity of an event sample: the summed weight of events with
// Type()==1 divided by the summed weight of events with Type()==1 or
// Type()==0. Returns -1 if the total weight is not positive.
// An event whose type is neither 0 nor 1 makes the total differ from
// signal+background, which is reported at kFATAL level. That consistency
// check uses a tolerance relative to the summed absolute weights: the sums
// are accumulated in different orders, so an exact floating point comparison
// could raise a spurious kFATAL from rounding alone.
Double_t TMVA::DecisionTree::SamplePurity( vector<TMVA::Event*> eventSample )
{
   Double_t sumsig=0, sumbkg=0, sumtot=0;
   Double_t sumabs=0; // scale for the rounding tolerance below
   for (UInt_t ievt=0; ievt<eventSample.size(); ievt++) {
      const Double_t weight = eventSample[ievt]->GetWeight();
      if (eventSample[ievt]->Type()==0) sumbkg+=weight;
      if (eventSample[ievt]->Type()==1) sumsig+=weight;
      sumtot+=weight;
      sumabs+=TMath::Abs(weight);
   }

   // anything beyond rounding noise means an event was neither signal nor background
   if (TMath::Abs( sumtot - (sumsig+sumbkg) ) > 1e-9 * sumabs){
      fLogger << kFATAL << "<SamplePurity> sumtot != sumsig+sumbkg"
              << sumtot << " " << sumsig << " " << sumbkg << Endl;
   }

   if (sumtot>0) return sumsig/(sumsig + sumbkg);
   else return -1;
}
// Returns the relative importance of all fNvars variables: each entry of
// fVariableImportance divided by the sum of all entries. If that sum is not
// larger than the double precision epsilon, all entries are returned as 0.
vector< Double_t > TMVA::DecisionTree::GetVariableImportance()
{
   vector<Double_t> relativeImportance(fNvars);

   Double_t total=0;
   for (int ivar=0; ivar< fNvars; ivar++) {
      relativeImportance[ivar] = fVariableImportance[ivar];
      total += fVariableImportance[ivar];
   }

   // decide once whether a normalisation is possible at all
   const Bool_t canNormalise = (total > std::numeric_limits<double>::epsilon());
   for (int ivar=0; ivar< fNvars; ivar++) {
      if (canNormalise) relativeImportance[ivar] /= total;
      else              relativeImportance[ivar] = 0;
   }
   return relativeImportance;
}
// Returns the relative importance of variable ivar (see the overload without
// argument). An index outside [0, fNvars) is reported at kFATAL level and -1
// is returned.
Double_t TMVA::DecisionTree::GetVariableImportance( Int_t ivar )
{
   vector<Double_t> relativeImportance = this->GetVariableImportance();

   if (ivar < 0 || ivar >= fNvars) {
      fLogger << kFATAL << "<GetVariableImportance>" << Endl
              << "--- ivar = " << ivar << " is out of range " << Endl;
      return -1;
   }
   return relativeImportance[ivar];
}
Last change: Wed Jun 25 08:48:09 2008
Last generated: 2008-06-25 08:48
This page has been automatically generated. If you have any comments or suggestions about the page layout, send an e-mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.