* *
* File and Version Information: *
* $Id: DecisionTree.cxx,v 1.4 2006/05/31 14:01:33 rdm Exp $
**********************************************************************************/
#include <algorithm>
#include "Riostream.h"
#include "TVirtualFitter.h"
#include "TMVA/DecisionTree.h"
#include "TMVA/DecisionTreeNode.h"
#include "TMVA/BinarySearchTree.h"
#include "TMVA/Tools.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/SdivSqrtSplusB.h"
using std::vector;
// ROOT ClassImp macro for TMVA::DecisionTree (presumably provides ROOT's
// class-dictionary/RTTI implementation boilerplate - see the ROOT docs).
ClassImp(TMVA::DecisionTree)
//_______________________________________________________________________
// Default constructor: no input variables known yet, the Gini index as
// separation criterion, and a minimum node size of zero.
// NOTE(review): fNCuts is set to -1 here, while TrainNode sizes its per-cut
// buffers from fNCuts - presumably a tree built with this constructor is
// not meant to be trained (e.g. only read back) - confirm.
// NOTE(review): the GiniIndex allocated here is not deleted by the
// destructor in this file (possible leak); the other constructor stores a
// caller-supplied pointer, so confirm ownership before adding a delete.
TMVA::DecisionTree::DecisionTree()
   : fNvars   ( 0 ),
     fNCuts   ( -1 ),
     fSepType ( new TMVA::GiniIndex() ),
     fMinSize ( 0 )
{
   // all state is set up in the initialiser list
}
//_______________________________________________________________________
// Constructor with an explicit separation criterion ('sepType'), the
// minimum number of events a node needs before it may be split ('minSize')
// and the number of cut positions scanned per variable ('nCuts').
// NOTE(review): 'sepType' is stored as a raw pointer and is not deleted by
// the destructor in this file - presumably the caller keeps ownership.
TMVA::DecisionTree::DecisionTree( TMVA::SeparationBase *sepType,
                                  Int_t minSize,
                                  Int_t nCuts )
   : fNvars   ( 0 ),
     fNCuts   ( nCuts ),
     fSepType ( sepType ),
     fMinSize ( minSize )
{
   // all state is set up in the initialiser list
}
//_______________________________________________________________________
// Destructor. Nothing is released in this file; in particular fSepType is
// not deleted.
// NOTE(review): the default constructor allocates fSepType with new, so that
// object may leak - confirm ownership before adding a delete here.
TMVA::DecisionTree::~DecisionTree()
{
   // intentionally empty
}
//_______________________________________________________________________
// Recursively grows the decision tree below 'node' from 'eventSample'.
//
// If 'node' is NULL a new root node is created and registered via SetRoot.
// For every node the weighted signal purity s/(s+b) (signal = type 1,
// background = type 0) and the separation index are stored. The node is
// split with the cut found by TrainNode when
//   - it holds more than fMinSize events, and
//   - purity * size lies strictly between fMinSize and (size - fMinSize).
// Otherwise, or when the best cut leaves one side empty (e.g. all events in
// the node have identical input values, so no cut can separate them), the
// node becomes a leaf: type +1 if its purity exceeds 0.5, else -1.
// Intermediate nodes get type 0.
//
// Returns fNNodes, the tree's running node count.
// Exits the program if 'eventSample' is empty.
Int_t TMVA::DecisionTree::BuildTree( vector<TMVA::Event*> & eventSample,
                                     TMVA::DecisionTreeNode *node )
{
   if (node==NULL) {
      node = new TMVA::DecisionTreeNode();
      fNNodes++;
      fSumOfWeights+=1.;
      this->SetRoot(node);
   }

   UInt_t nevents = eventSample.size();
   if (nevents > 0 ) fNvars = eventSample[0]->GetEventSize();
   else{
      cout << "--- TMVA::DecisionTree::BuildTree: Error, Eventsample Size == 0 " <<endl;
      exit(1);
   }

   // weighted sums of background (type 0) and signal (type 1) events;
   // events of any other type contribute to neither
   Double_t s=0, b=0;
   for (UInt_t i=0; i<eventSample.size(); i++){
      if (eventSample[i]->GetType()==0) b+= eventSample[i]->GetWeight();
      else if (eventSample[i]->GetType()==1) s+= eventSample[i]->GetWeight();
   }
   node->SetSoverSB(s/(s+b));
   node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));

   Bool_t isLeaf = kTRUE;
   if ( eventSample.size() > fMinSize &&
        node->GetSoverSB()*eventSample.size() > fMinSize &&
        node->GetSoverSB()*eventSample.size() < eventSample.size()-fMinSize ) {

      // determine the best cut; TrainNode stores selector, cut value and
      // cut type in 'node' (its returned gain is not needed here)
      this->TrainNode(eventSample, node);

      // partition the sample according to the chosen cut
      vector<TMVA::Event*> leftSample;  leftSample.reserve(nevents);
      vector<TMVA::Event*> rightSample; rightSample.reserve(nevents);
      Double_t nRight=0, nLeft=0;
      for (UInt_t ie=0; ie< nevents ; ie++){
         if (node->GoesRight(eventSample[ie])){
            rightSample.push_back(eventSample[ie]);
            nRight += eventSample[ie]->GetWeight();
         }
         else {
            leftSample.push_back(eventSample[ie]);
            nLeft += eventSample[ie]->GetWeight();
         }
      }

      if (leftSample.size() == 0 || rightSample.size() == 0) {
         // No cut separates these events (e.g. identical input values).
         // A daughter would receive the whole sample again, so the recursion
         // could make no progress; keep this node as a leaf instead of
         // aborting the whole process.
         cout << "--- TMVA::DecisionTree::BuildTree: Warning, all events went to the same branch"
              << " (left:" << leftSample.size()
              << " right:" << rightSample.size()
              << "); node is turned into a leaf" << endl;
      }
      else {
         TMVA::DecisionTreeNode *rightNode = new TMVA::DecisionTreeNode(node);
         fNNodes++;
         fSumOfWeights += 1.0;
         rightNode->SetNEvents(nRight);

         TMVA::DecisionTreeNode *leftNode = new TMVA::DecisionTreeNode(node);
         fNNodes++;
         fSumOfWeights += 1.0;
         leftNode->SetNEvents(nLeft);

         // intermediate node
         node->SetNodeType(0);
         node->SetLeft(leftNode);
         node->SetRight(rightNode);
         isLeaf = kFALSE;

         this->BuildTree(rightSample, rightNode);
         this->BuildTree(leftSample,  leftNode );
      }
   }

   if (isLeaf) {
      // leaf: classify by majority purity
      if (node->GetSoverSB() > 0.5) node->SetNodeType(1);
      else                          node->SetNodeType(-1);
   }
   return fNNodes;
}
//_______________________________________________________________________
// Finds the best single-variable cut for 'node' and stores it in the node.
//
// For each input variable, the range [min, max] spanned by 'eventSample' is
// scanned at fNCuts equidistant points (bin centres). At each point the
// weighted signal (type 1) and background (type 0) sums above the cut are
// computed, and the separation gain is evaluated with fSepType. The
// variable/cut with the largest gain wins. The node receives:
//   - the selector (variable index),
//   - the cut value,
//   - the cut type (kTRUE if the signal fraction above the cut exceeds the
//     background fraction),
//   - the gain.
//
// For samples of 30000 events or more, the sums above the cut come from
// per-variable BinarySearchTrees instead of a linear scan (fUseSearchTree is
// updated accordingly as a side effect).
//
// Returns the best separation gain found (stays -1 if nothing beats the
// initial value, e.g. when fNvars == 0).
// Exits the program if 'eventSample' is empty or fNCuts < 1.
Double_t TMVA::DecisionTree::TrainNode(vector<TMVA::Event*> & eventSample,
                                       TMVA::DecisionTreeNode *node)
{
   if (eventSample.size() == 0) {
      cout << "--- TMVA::DecisionTree::TrainNode: Error, Eventsample Size == 0 " << endl;
      exit(1);
   }
   // The per-cut buffers below are sized from fNCuts: a negative value (e.g.
   // the -1 of the default constructor) is an invalid vector size and 0
   // would make sep[pos] index an empty vector.
   if (fNCuts < 1) {
      cout << "--- TMVA::DecisionTree::TrainNode: Error, number of cuts must be >= 1, but is "
           << fNCuts << endl;
      exit(1);
   }

   vector<Double_t> xmin( fNvars );
   vector<Double_t> xmax( fNvars );

   Double_t separation = -1;
   Double_t cutValue   = -999;
   Int_t    mxVar      = -1;
   Bool_t   cutType    = kTRUE;
   Double_t nSelS, nSelB, nTotS, nTotB;
   TMVA::BinarySearchTree *sigBST=NULL;
   TMVA::BinarySearchTree *bkgBST=NULL;

   // the search-tree lookup is only used for large samples
   fUseSearchTree = kTRUE;
   if (eventSample.size() < 30000) fUseSearchTree = kFALSE;

   // range of every variable within this sample
   for (Int_t ivar=0; ivar < fNvars; ivar++){
      xmin[ivar]=xmax[ivar]=eventSample[0]->GetData(ivar);
   }
   for (UInt_t i=1;i<eventSample.size();i++){
      for (Int_t ivar=0; ivar < fNvars; ivar++){
         if (xmin[ivar]>eventSample[i]->GetData(ivar)) xmin[ivar]=eventSample[i]->GetData(ivar);
         if (xmax[ivar]<eventSample[i]->GetData(ivar)) xmax[ivar]=eventSample[i]->GetData(ivar);
      }
   }

   // Weighted signal/background totals of the node sample for the linear-scan
   // path: they depend neither on the variable nor on the cut, so they are
   // summed once (same order as before, hence identical values).
   Double_t sampleTotS=0, sampleTotB=0;
   if (!fUseSearchTree) {
      for (UInt_t i=0; i<eventSample.size(); i++){
         if (eventSample[i]->GetType()==1)      sampleTotS+=eventSample[i]->GetWeight();
         else if (eventSample[i]->GetType()==0) sampleTotB+=eventSample[i]->GetWeight();
      }
   }

   for (Int_t ivar=0; ivar < fNvars; ivar++){
      if (fUseSearchTree) {
         sigBST = new TMVA::BinarySearchTree();
         bkgBST = new TMVA::BinarySearchTree();
         vector<Int_t> theVars;
         theVars.push_back(ivar);
         sigBST->Fill( eventSample, theVars, 1 );
         bkgBST->Fill( eventSample, theVars, 0 );
      }

      Double_t istepSize =( xmax[ivar] - xmin[ivar] ) / Double_t(fNCuts);
      Int_t nCuts = fNCuts;
      vector<Double_t> cutValueTmp(nCuts);
      vector<Double_t> sep(nCuts);
      vector<Bool_t>   cutTypeTmp(nCuts);

      for (Int_t istep=0; istep<nCuts; istep++){
         // cut placed at the centre of step 'istep'
         cutValueTmp[istep]=xmin[ivar]+(Float_t(istep)+0.5)*istepSize;

         if (fUseSearchTree){
            TMVA::Volume volume(cutValueTmp[istep], xmax[ivar]);
            nSelS = sigBST->SearchVolume( &volume );
            nSelB = bkgBST->SearchVolume( &volume );
            nTotS = sigBST->GetSumOfWeights();
            nTotB = bkgBST->GetSumOfWeights();
         }else{
            nSelS=0; nSelB=0;
            nTotS=sampleTotS; nTotB=sampleTotB;
            for (UInt_t i=0; i<eventSample.size(); i++){
               if (eventSample[i]->GetData(ivar) > cutValueTmp[istep]) {
                  if (eventSample[i]->GetType()==1)      nSelS+=eventSample[i]->GetWeight();
                  else if (eventSample[i]->GetType()==0) nSelB+=eventSample[i]->GetWeight();
               }
            }
         }

         // kTRUE: signal is relatively enriched above the cut
         if (nSelS/nTotS > nSelB/nTotB) cutTypeTmp[istep]=kTRUE;
         else cutTypeTmp[istep]=kFALSE;
         sep[istep]= fSepType->GetSeparationGain(nSelS, nSelB, nTotS, nTotB);
      }

      Int_t pos = TMVA::Tools::GetIndexMaxElement(sep);
      if (separation < sep[pos]) {
         separation = sep[pos];
         cutValue   = cutValueTmp[pos];
         cutType    = cutTypeTmp[pos];
         mxVar      = ivar;
      }

      if (fUseSearchTree) {
         delete sigBST; sigBST = NULL;
         delete bkgBST; bkgBST = NULL;
      }
   }

   node->SetSelector(mxVar);
   node->SetCutValue(cutValue);
   node->SetCutType(cutType);
   node->SetSeparationGain(separation);

   return separation;
}
//_______________________________________________________________________
// Classifies event 'e' by descending from the root node. At every
// intermediate node (node type 0), GoesRight(e) decides which daughter to
// follow. Returns the signal purity (GetSoverSB) of the leaf that is reached.
// NOTE(review): assumes a tree has been built or set; a NULL root is not
// checked.
Double_t TMVA::DecisionTree::CheckEvent(TMVA::Event* e)
{
   TMVA::DecisionTreeNode *node =
      static_cast<TMVA::DecisionTreeNode*>( this->GetRoot() );

   // walk down until a leaf (node type != 0) is reached
   while ( node->GetNodeType() == 0 ) {
      node = node->GoesRight(e)
           ? static_cast<TMVA::DecisionTreeNode*>( node->GetRight() )
           : static_cast<TMVA::DecisionTreeNode*>( node->GetLeft()  );
   }

   return node->GetSoverSB();
}
//_______________________________________________________________________
// Returns the weighted signal purity sumsig / (sumsig + sumbkg) of
// 'eventSample' (type 1 = signal, type 0 = background), or -1 if the total
// weight is not positive.
// Exits the program if events of any other type carry a non-zero total
// weight.
// NOTE(review): the sample is taken by value (copied); the signature must
// match the declaration in the header, which is not visible here.
Double_t TMVA::DecisionTree::SamplePurity(vector<TMVA::Event*> eventSample)
{
   Double_t sumsig=0, sumbkg=0, sumother=0, sumtot=0;
   for (UInt_t ievt=0; ievt<eventSample.size(); ievt++) {
      if      (eventSample[ievt]->GetType()==0) sumbkg   += eventSample[ievt]->GetWeight();
      else if (eventSample[ievt]->GetType()==1) sumsig   += eventSample[ievt]->GetWeight();
      else                                      sumother += eventSample[ievt]->GetWeight();
      sumtot+=eventSample[ievt]->GetWeight();
   }

   // Weight not accounted for by signal or background means unknown event
   // types. Test that weight directly instead of comparing
   // sumtot != sumsig+sumbkg: those floating-point sums are accumulated in
   // different orders and can differ by rounding alone (e.g. weights
   // 0.3(s), 0.1(b), 0.2(s) give 0.6000000000000001 vs 0.6), which would
   // abort on a perfectly valid sample.
   if (sumother != 0){
      cout << "--- TMVA::DecisionTree::Purity Error! sumtot != sumsig+sumbkg"
           << sumtot << " " << sumsig << " " << sumbkg << endl;
      exit(1);
   }

   if (sumtot>0) return sumsig/(sumsig + sumbkg);
   else return -1;
}
ROOT page - Class index - Class Hierarchy - Top of the page
This page has been automatically generated. If you have any comments or suggestions about the page layout, send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.