// @(#)root/tmva $Id: TMVA_DecisionTree.cxx,v 1.2 2006/05/09 08:37:06 brun Exp $
// Author: Andreas Hoecker, Helge Voss, Kai Voss

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : TMVA_DecisionTree                                                     *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation of a Decision Tree                                         *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
 *      Xavier Prudent  <prudent@lapp.in2p3.fr>  - LAPP, France                   *
 *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-KP Heidelberg, Germany     *
 *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland,                                                        *
 *      U. of Victoria, Canada,                                                   *
 *      MPI-KP Heidelberg, Germany,                                               *
 *      LAPP, Annecy, France                                                      *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://mva.sourceforge.net/license.txt)                                       *
 *                                                                                *
 * File and Version Information:                                                  *
 * $Id: TMVA_DecisionTree.cxx,v 1.2 2006/05/09 08:37:06 brun Exp $
 **********************************************************************************/

//_______________________________________________________________________
//
// Implementation of a Decision Tree
//
//_______________________________________________________________________

#include <iostream>
#include <algorithm>

#include "TVirtualFitter.h"

#include "TMVA_DecisionTree.h"
#include "TMVA_DecisionTreeNode.h"
#include "TMVA_BinarySearchTree.h"

#include "TMVA_Tools.h"

#include "TMVA_GiniIndex.h"
#include "TMVA_CrossEntropy.h"
#include "TMVA_MisClassificationError.h"
#include "TMVA_SdivSqrtSplusB.h"

using std::vector;

ClassImp(TMVA_DecisionTree)

//_______________________________________________________________________
 TMVA_DecisionTree::TMVA_DecisionTree( void )
{
  // Default constructor.
  // Uses the Gini index as separation criterion, no minimum node size and a
  // minimum separation gain of 0.0003 for a node to be split.
  // (Upper/lower purity thresholds existed in an earlier version and are
  // currently disabled.)
  //
  // NOTE(review): fNCuts is set to -1, i.e. no cut positions are configured
  // by this constructor -- confirm it is set before training.
  // NOTE(review): fSepType is allocated with new here; the destructor in this
  // file does not delete it -- check ownership.
  fSepType    = new TMVA_GiniIndex();
  fMinSepGain = 0.0003;
  fMinSize    = 0;
  fNCuts      = -1;
  fNvars      = 0;
}

//_______________________________________________________________________
 TMVA_DecisionTree::TMVA_DecisionTree( TMVA_SeparationBase *sepType,Int_t minSize, Double_t mnsep,
                                      Int_t nCuts)
{
  // Constructor with explicit training settings.
  //   sepType : separation criterion used to rank candidate cuts
  //             (the pointer is stored as given, not copied)
  //   minSize : minimum sample size a node must exceed to be considered
  //             for splitting
  //   mnsep   : minimum separation gain a cut must exceed to split a node
  //   nCuts   : number of equidistant cut positions scanned per variable
  fMinSepGain = mnsep;
  fMinSize    = minSize;
  fNCuts      = nCuts;
  fSepType    = sepType;
  fNvars      = 0;
}

//_______________________________________________________________________
 TMVA_DecisionTree::~TMVA_DecisionTree( void )
{
  // Destructor: releases nothing itself.
  // NOTE(review): fSepType (created with new by the default constructor) and
  // the nodes created with new in BuildTree are not deleted here -- verify
  // whether the base class or the caller takes care of them.
}

//_______________________________________________________________________
 void TMVA_DecisionTree::BuildTree( vector<TMVA_Event*> & eventSample,
                                   TMVA_DecisionTreeNode *node )
{
  // Recursively grow the tree on "eventSample", starting at "node".
  //
  // If "node" is NULL a new root node is created and registered with
  // SetRoot(). For the node, the weighted signal purity s/(s+b)
  // (signal = type 1, background = type 0) and the separation index are
  // stored. The node is split, with the cut chosen by TrainNode(), only if
  //   - the sample has more than fMinSize events,
  //   - purity*N lies strictly between fMinSize and N-fMinSize, and
  //   - the separation gain of the chosen cut exceeds fMinSepGain.
  // The sample is then divided with GoesRight(), both daughters are built
  // recursively and the node gets type 0 (intermediate). Otherwise the node
  // becomes a leaf: type 1 if purity > 0.5, type -1 otherwise.
  //
  // Calls exit(1) if the sample is empty or if the chosen cut sends all
  // events into the same branch.

  if (node==NULL) {
    //start with the root node
    node = new TMVA_DecisionTreeNode();
    fNNodes++;
    fSumOfWeights+=1.;
    this->SetRoot(node);
  }

  const UInt_t nevents = eventSample.size();
  if (nevents == 0) {
    cout << "--- TMVA_DecisionTree::BuildTree:  Error, Eventsample Size == 0 " <<endl;
    exit(1);
  }
  fNvars = eventSample[0]->GetEventSize();

  // weighted signal and background content of this node
  Double_t s=0, b=0;
  for (UInt_t i=0; i<nevents; i++){
    if (eventSample[i]->GetType()==0) b+= eventSample[i]->GetWeight();
    else if (eventSample[i]->GetType()==1) s+= eventSample[i]->GetWeight();
  }
  // Avoid 0/0 (e.g. all weights zero): a NaN purity made every test below
  // fail silently; 0 leads explicitly to the same result (background leaf).
  node->SetSoverSB( (s+b != 0) ? s/(s+b) : 0. );
  node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));

  // Size tests are done in floating point. Comparing the unsigned sample
  // size with the signed Int_t fMinSize converted fMinSize to unsigned, so a
  // negative value turned every node into a leaf. A negative fMinSize is
  // treated as 0 (no minimum).
  const Double_t purity  = node->GetSoverSB();
  const Double_t nEvt    = Double_t(nevents);
  const Double_t minSize = (fMinSize > 0) ? Double_t(fMinSize) : 0.;
  const Double_t nSigEst = purity*nEvt;   // purity scaled to the unweighted event count

  Bool_t isLeaf = kTRUE;
  if (nEvt > minSize && nSigEst > minSize && nSigEst < nEvt - minSize) {

    this->TrainNode(eventSample, node);
    if (node->GetSeparationGain() > fMinSepGain) {
      isLeaf = kFALSE;

      vector<TMVA_Event*> leftSample;  leftSample.reserve(nevents);
      vector<TMVA_Event*> rightSample; rightSample.reserve(nevents);
      Double_t nRight=0, nLeft=0;
      for (UInt_t ie=0; ie<nevents; ie++){
        if (node->GoesRight(eventSample[ie])){
          rightSample.push_back(eventSample[ie]);
          nRight += eventSample[ie]->GetWeight();
        }
        else {
          leftSample.push_back(eventSample[ie]);
          nLeft += eventSample[ie]->GetWeight();
        }
      }

      // sanity check: a useful cut must populate both branches
      if (leftSample.size() == 0 || rightSample.size() == 0) {
        cout << "--- DecisionTree::TrainNode Error:  all events went to the same branch\n";
        cout << "---                         Hence new node == old node ... check\n";
        cout << "---                         left:" << leftSample.size()
             << " right:" << rightSample.size() << endl;
        cout << "--- this should never happen, please write a bug report to Helge.Voss@cern.ch"
             << endl;
        exit(1);
      }

      // continue building daughter nodes for the left and the right eventsample
      TMVA_DecisionTreeNode *rightNode = new TMVA_DecisionTreeNode(node);
      fNNodes++;
      fSumOfWeights += 1.0;
      rightNode->SetNEvents(nRight);
      TMVA_DecisionTreeNode *leftNode = new TMVA_DecisionTreeNode(node);
      fNNodes++;
      fSumOfWeights += 1.0;
      leftNode->SetNEvents(nLeft);

      node->SetNodeType(0);
      node->SetLeft(leftNode);
      node->SetRight(rightNode);
      this->BuildTree(rightSample, rightNode);
      this->BuildTree(leftSample,  leftNode );
    }
  }

  if (isLeaf) {
    // leaf node: classified by its weighted purity
    if (purity > 0.5) node->SetNodeType(1);
    else              node->SetNodeType(-1);
  }
}

//_______________________________________________________________________
 Double_t TMVA_DecisionTree::TrainNode(vector<TMVA_Event*> & eventSample,
                                  TMVA_DecisionTreeNode *node)
{
  // Find the best single-variable cut for "node" and store it in the node.
  //
  // At each node ONE of the input variables is chosen: the one giving the
  // best separation between signal (type 1) and background (type 0) on the
  // events entering the node. For every variable, fNCuts equidistant cut
  // values inside [xmin, xmax] of this sample are scanned. Events above the
  // cut would go to the right branch and the rest to the left. The figure of
  // merit is fSepType->GetSeparationGain().
  // For samples with at least 30000 events the selected weights are taken
  // from binary search trees (one per class, rebuilt per variable);
  // otherwise they come from a direct loop over the events.
  //
  // The node receives the selector (variable index, -1 if none), the cut
  // minimum/maximum, the cut type and the separation gain. The gain is also
  // returned. If the sample is empty or fNCuts < 1, no cut is scanned and the
  // node gets selector -1 and separation gain -1.

  Double_t separation = -1;
  Double_t cutMin=-999, cutMax=-999;
  Int_t    mxVar=-1;
  Bool_t   cutType=kTRUE;

  const UInt_t nevents = eventSample.size();
  const Int_t  nCuts   = fNCuts;

  fUseSearchTree = (nevents >= 30000) ? kTRUE : kFALSE;

  // Guard against degenerate input. The range computation reads
  // eventSample[0]. A non-positive fNCuts (the default constructor sets -1)
  // would give empty per-cut vectors, so sep[pos] would be out of range, or,
  // once converted to an unsigned size, an enormous allocation.
  if (nevents > 0 && nCuts > 0) {

    // range of every input variable over this sample
    vector<Double_t> xmin(fNvars), xmax(fNvars);
    for (Int_t ivar=0; ivar < fNvars; ivar++)
      xmin[ivar] = xmax[ivar] = eventSample[0]->GetData(ivar);
    for (UInt_t i=1; i<nevents; i++) {
      for (Int_t ivar=0; ivar < fNvars; ivar++) {
        if (xmin[ivar] > eventSample[i]->GetData(ivar)) xmin[ivar] = eventSample[i]->GetData(ivar);
        if (xmax[ivar] < eventSample[i]->GetData(ivar)) xmax[ivar] = eventSample[i]->GetData(ivar);
      }
    }

    // per-cut scratch buffers; every entry is rewritten for each variable
    vector<Double_t> cutMinTmp(nCuts), cutMaxTmp(nCuts), sep(nCuts);
    vector<Bool_t>   cutTypeTmp(nCuts);

    for (Int_t ivar=0; ivar < fNvars; ivar++) {
      TMVA_BinarySearchTree *sigBST = NULL;
      TMVA_BinarySearchTree *bkgBST = NULL;
      if (fUseSearchTree) {
        sigBST = new TMVA_BinarySearchTree();
        bkgBST = new TMVA_BinarySearchTree();
        vector<Int_t> theVars;
        theVars.push_back(ivar);
        Int_t dummy = 0;   // not used afterwards; initialised so it is never read uninitialised
        sigBST->Fill( eventSample, theVars, dummy, 1 );
        bkgBST->Fill( eventSample, theVars, dummy, 0 );
      }

      // scan the possible cut values for this variable
      const Double_t stepSize = ( xmax[ivar] - xmin[ivar] ) / Double_t(nCuts);
      for (Int_t istep=0; istep<nCuts; istep++) {
        cutMinTmp[istep] = xmin[ivar] + (Double_t(istep)+0.5)*stepSize;
        cutMaxTmp[istep] = xmax[ivar];

        Double_t nSelS=0, nSelB=0, nTotS=0, nTotB=0;
        if (fUseSearchTree) {
          TMVA_Volume volume(cutMinTmp[istep], cutMaxTmp[istep]);
          nSelS = sigBST->SearchVolume( &volume );
          nSelB = bkgBST->SearchVolume( &volume );
          nTotS = sigBST->GetSumOfWeights();
          nTotB = bkgBST->GetSumOfWeights();
        } else {
          for (UInt_t i=0; i<nevents; i++) {
            if (eventSample[i]->GetType()==1) {
              nTotS += eventSample[i]->GetWeight();
              if (eventSample[i]->GetData(ivar) > cutMinTmp[istep]) nSelS += eventSample[i]->GetWeight();
            } else if (eventSample[i]->GetType()==0) {
              nTotB += eventSample[i]->GetWeight();
              if (eventSample[i]->GetData(ivar) > cutMinTmp[istep]) nSelB += eventSample[i]->GetWeight();
            }
          }
        }

        // nSelS and nSelB would go to the right branch,
        // (nTotS - nSelS) + (nTotB - nSelB) to the left branch.
        // The cut type records whether the selected side is signal-enriched.
        cutTypeTmp[istep] = (nSelS/nTotS > nSelB/nTotB) ? kTRUE : kFALSE;
        sep[istep] = fSepType->GetSeparationGain(nSelS, nSelB, nTotS, nTotB);
      }

      // std::max_element would yield an iterator; an integer index is needed here
      const Int_t pos = TMVA_Tools::GetIndexMaxElement(sep);

      // keep the variable that gives the maximum separation so far
      if (separation < sep[pos]) {
        separation = sep[pos];
        cutMin  = cutMinTmp[pos];
        cutMax  = cutMaxTmp[pos];
        cutType = cutTypeTmp[pos];
        mxVar   = ivar;
      }

      delete sigBST;   // deleting NULL is a no-op
      delete bkgBST;
    }
  }

  node->SetSelector(mxVar);
  node->SetCutMin(cutMin);
  node->SetCutMax(cutMax);
  node->SetCutType(cutType);
  node->SetSeparationGain(separation);

  return separation;
}

//_______________________________________________________________________
 Int_t TMVA_DecisionTree::CheckEvent(TMVA_Event* e)
{
  // Classify event "e": starting at the root, follow GoesRight() through
  // every intermediate node (node type 0) and return the type of the leaf
  // that is reached (BuildTree assigns 1 to leaves with purity > 0.5 and -1
  // otherwise).
  // NOTE(review): assumes the tree has been built, i.e. the root and all
  // children of intermediate nodes are non-NULL.
  TMVA_DecisionTreeNode *node = (TMVA_DecisionTreeNode*)this->GetRoot();
  for (;;) {
    const Int_t type = node->GetNodeType();
    if (type != 0) return type;
    node = node->GoesRight(e) ? (TMVA_DecisionTreeNode*)node->GetRight()
                              : (TMVA_DecisionTreeNode*)node->GetLeft();
  }
}

//_______________________________________________________________________
 Double_t  TMVA_DecisionTree::SamplePurity(vector<TMVA_Event*> eventSample)
{
  // Weighted signal purity of the sample:
  //   sum(w, type 1) / ( sum(w, type 1) + sum(w, type 0) ).
  // Returns -1 if the total weight is not positive.
  // Calls exit(1) if events of any other type carry a non-zero total weight.
  Double_t sumsig=0, sumbkg=0, sumother=0, sumtot=0;
  for (UInt_t ievt=0; ievt<eventSample.size(); ievt++) {
    const Double_t w = eventSample[ievt]->GetWeight();
    if      (eventSample[ievt]->GetType()==0) sumbkg   += w;
    else if (eventSample[ievt]->GetType()==1) sumsig   += w;
    else                                      sumother += w;
    sumtot += w;
  }
  // Sanity check: every weighted event must be signal or background.
  // The check uses the weight of unclassified events. Comparing sumtot with
  // sumsig+sumbkg via != adds the same weights in a different order, so
  // rounding could make the two differ and cause a spurious abort.
  if (sumother != 0) {
    cout << "--- TMVA_DecisionTree::Purity Error! sumtot != sumsig+sumbkg"
         << sumtot << " " << sumsig << " " << sumbkg << endl;
    exit(1);
  }
  if (sumtot>0) return sumsig/(sumsig + sumbkg);
  else return -1;
}




// ROOT page - Class index - Class Hierarchy - Top of the page
//
// This page has been automatically generated. If you have any comments or suggestions about the page layout, send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.