#include <algorithm>
#include <iostream>
#include <limits>
#include <vector>
#include "TMath.h"
#include "TMVA/DecisionTree.h"
#include "TMVA/DecisionTreeNode.h"
#include "TMVA/BinarySearchTree.h"
#include "TMVA/Tools.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/SdivSqrtSplusB.h"
#include "TMVA/Event.h"
using std::vector;
#define USE_HELGESCODE 1    // the other one is Dougs implementation of the TrainNode
#define USE_HELGE_V1  0     // out loop is over NVAR in TrainNode, inner loop is Eventloop
ClassImp(TMVA::DecisionTree)
TMVA::DecisionTree::DecisionTree()
   : BinaryTree(),
     fNvars(0), fNCuts(-1), fSepType(NULL), fMinSize(0),
     fPruneMethod(kCostComplexityPruning), fDepth(0), fQualityIndex(NULL)
{
   // Default constructor: an empty tree without separation criterion or
   // quality index; cost-complexity pruning is the default pruning method.
   fLogger.SetSource( "DecisionTree" );
}
TMVA::DecisionTree::DecisionTree( DecisionTreeNode* n )
   : BinaryTree(),
     fNvars      (0),
     fNCuts      (-1),
     fSepType    (NULL),
     fMinSize    (0),
     fPruneMethod(kCostComplexityPruning),
     fDepth (0),
     fQualityIndex(NULL)
{
   // Construct a tree around an existing node structure: 'n' becomes the root,
   // and every node below it gets its parent-tree pointer set to this tree
   // (fDepth is raised to the deepest node found on the way).
   // The logger source is set once, before SetParentTreeInNodes can emit
   // messages (the second, redundant SetSource call has been dropped).
   fLogger.SetSource( "DecisionTree" );
   this->SetRoot( n );
   this->SetParentTreeInNodes();
}
TMVA::DecisionTree::DecisionTree( TMVA::SeparationBase *sepType, Int_t minSize,
                                  Int_t nCuts, TMVA::SeparationBase *qtype )
   : BinaryTree(),
     fNvars(0), fNCuts(nCuts), fSepType(sepType), fMinSize(minSize),
     fPruneMethod(kCostComplexityPruning), fDepth(0), fQualityIndex(qtype)
{
   // Construct an empty tree to be grown with BuildTree:
   //   sepType : separation criterion used to pick the best split of a node
   //   minSize : minimum number of (unweighted) events a daughter node may hold
   //   nCuts   : number of equidistant cut values scanned per input variable
   //   qtype   : separation measure used by the cost-complexity pruning
   // The separation objects are stored as given (neither copied nor owned here).
   fLogger.SetSource( "DecisionTree" );
}
TMVA::DecisionTree::DecisionTree( const DecisionTree &d):
   BinaryTree(),
   fNvars      (d.fNvars),
   fNCuts      (d.fNCuts),
   fSepType    (d.fSepType),
   fMinSize    (d.fMinSize),
   fPruneMethod(d.fPruneMethod),
   fDepth      (d.fDepth),
   fQualityIndex(d.fQualityIndex)
{
   // Copy constructor: clones the node structure of 'd' (through the
   // DecisionTreeNode copy constructor) and re-points every cloned node at this
   // tree. The separation objects (fSepType, fQualityIndex) are shared, not cloned.
   fLogger.SetSource( "DecisionTree" );

   // members not covered by the initialiser list: without the importance vector
   // the copy would report fNvars variables but hold no importance entries
   fPruneStrength      = d.fPruneStrength;
   fVariableImportance = d.fVariableImportance;

   // an empty source tree has no root to clone; dereferencing it would crash
   if (d.GetRoot() != NULL) {
      this->SetRoot( new DecisionTreeNode ( *((DecisionTreeNode*)(d.GetRoot())) ) );
      this->SetParentTreeInNodes();
   }
   fNNodes = d.fNNodes;
}
TMVA::DecisionTree::~DecisionTree()
{
   // Destructor: no explicit cleanup is done here; in particular the separation
   // objects fSepType and fQualityIndex are not deleted by this class.
   // NOTE(review): the nodes are presumably released by the BinaryTree base — confirm.
}
void TMVA::DecisionTree::SetParentTreeInNodes( DecisionTreeNode *n)
{
   // Walk the subtree below 'n' (the whole tree when n == NULL), daughters first,
   // setting each node's parent-tree pointer to this tree and raising fDepth to
   // the largest node depth encountered. A node with exactly one daughter is a
   // fatal inconsistency.
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) {
         fLogger << kFATAL << "SetParentTreeNodes: started with undefined ROOT node" <<Endl;
         return ;
      }
   }

   DecisionTreeNode *left  = this->GetLeftDaughter(n);
   DecisionTreeNode *right = this->GetRightDaughter(n);
   const Bool_t hasLeft  = (left  != NULL);
   const Bool_t hasRight = (right != NULL);

   if (hasLeft != hasRight) {
      fLogger << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   }
   if (hasLeft) {   // internal node: both daughters exist
      this->SetParentTreeInNodes( left );
      this->SetParentTreeInNodes( right );
   }

   n->SetParentTree(this);
   if (n->GetDepth() > fDepth) fDepth = n->GetDepth();
}
Int_t TMVA::DecisionTree::BuildTree( vector<TMVA::Event*> & eventSample,
                                     TMVA::DecisionTreeNode *node )
{
   // Recursively grow the tree from 'eventSample', starting at 'node'.
   // If node == NULL a new root node is created (position 's', depth 0),
   // fNNodes is reset to 1 and the whole tree is grown from it.
   // For each node the weighted (s, b) and unweighted (suw, buw) signal and
   // background sums and the fSepType separation index are stored. If the sample
   // holds at least 2*fMinSize events, TrainNode searches for the best cut:
   //  - zero separation gain (or too small a sample) turns the node into a leaf,
   //    node type +1 if purity > 0.5, otherwise -1;
   //  - otherwise the events are split with node->GoesRight into a right and a
   //    left sample, two daughters are created and grown recursively.
   // Returns the number of nodes created so far (fNNodes).
   if (node==NULL) {
      // no node given: start a fresh tree with a new root node
      node = new TMVA::DecisionTreeNode();
      fNNodes = 1;   
      this->SetRoot(node);
      // position code 's' for the root (daughters are created with 'r' / 'l' below)
      this->GetRoot()->SetPos('s');
      this->GetRoot()->SetDepth(0);
      this->GetRoot()->SetParentTree(this);
   }
   UInt_t nevents = eventSample.size();
   if (nevents > 0 ) {
      fNvars = eventSample[0]->GetNVars();
      fVariableImportance.resize(fNvars);
   }
   else fLogger << kFATAL << ":<BuildTree> eventsample Size == 0 " << Endl;
   // weighted (s, b) and unweighted (suw, buw) signal / background sums of this node
   Double_t s=0, b=0;
   Double_t suw=0, buw=0;
   for (UInt_t i=0; i<eventSample.size(); i++){
      if (eventSample[i]->IsSignal()){
         s += eventSample[i]->GetWeight();
         suw += 1;
      } 
      else {
         b += eventSample[i]->GetWeight();
         buw += 1;
      }
   }
   node->SetNSigEvents(s);
   node->SetNBkgEvents(b);
   node->SetNSigEvents_unweighted(suw);
   node->SetNBkgEvents_unweighted(buw);
   // only the root takes its total event count from here; daughters are given
   // theirs by the parent when they are created (SetNEvents calls further down)
   if (node == this->GetRoot()) {
      node->SetNEvents(s+b);
      node->SetNEvents_unweighted(suw+buw);
   }
   node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));
   // try to split only if both daughters could hold at least fMinSize events
   if ( eventSample.size() >= 2*fMinSize){
      Double_t separationGain;
      separationGain = this->TrainNode(eventSample, node);
      if (separationGain == 0) {
         // no useful cut found: make this node a leaf, classified by its purity
         if (node->GetPurity() > 0.5) node->SetNodeType(1);
         else node->SetNodeType(-1);
         if (node->GetDepth() > fDepth) fDepth = node->GetDepth();
      } 
      else {
         // distribute the events to the daughters according to the cut TrainNode chose
         vector<TMVA::Event*> leftSample; leftSample.reserve(nevents);
         vector<TMVA::Event*> rightSample; rightSample.reserve(nevents);
         Double_t nRight=0, nLeft=0;
         for (UInt_t ie=0; ie< nevents ; ie++){
            if (node->GoesRight(*eventSample[ie])){
               rightSample.push_back(eventSample[ie]);
               nRight += eventSample[ie]->GetWeight();
            }
            else {
               leftSample.push_back(eventSample[ie]);
               nLeft += eventSample[ie]->GetWeight();
            }
         }
         // sanity check: a cut with non-zero gain must send events to both sides
         if (leftSample.size() == 0 || rightSample.size() == 0) {
            fLogger << kFATAL << "<TrainNode> all events went to the same branch" << Endl
                    << "---                       Hence new node == old node ... check" << Endl
                    << "---                         left:" << leftSample.size()
                    << " right:" << rightSample.size() << Endl
                    << "--- this should never happen, please write a bug report to Helge.Voss@cern.ch"
                    << Endl;
         }
         // create the right daughter and hand it its weighted / unweighted event count
         TMVA::DecisionTreeNode *rightNode = new TMVA::DecisionTreeNode(node,'r');
         fNNodes++;
         rightNode->SetNEvents(nRight);
         rightNode->SetNEvents_unweighted(rightSample.size());
         // same for the left daughter
         TMVA::DecisionTreeNode *leftNode = new TMVA::DecisionTreeNode(node,'l');
         fNNodes++;
         leftNode->SetNEvents(nLeft);
         leftNode->SetNEvents_unweighted(leftSample.size());
         // node type 0 marks an internal node (FillEvent/CheckEvent descend through it)
         node->SetNodeType(0);
         node->SetLeft(leftNode);
         node->SetRight(rightNode);
         this->BuildTree(rightSample, rightNode);
         this->BuildTree(leftSample,  leftNode );
      }
   } 
   else{ 
      // too few events to split: make this node a leaf, classified by its purity
      if (node->GetPurity() > 0.5) node->SetNodeType(1);
      else node->SetNodeType(-1);
      if (node->GetDepth() > fDepth) fDepth = node->GetDepth();
   }
   
   return fNNodes;
}
void TMVA::DecisionTree::FillTree( vector<TMVA::Event*> & eventSample)
{
   // Send every event of the sample through the existing tree structure,
   // updating the counters of each node on its path (see FillEvent).
   for (vector<TMVA::Event*>::iterator it = eventSample.begin();
        it != eventSample.end(); ++it) {
      this->FillEvent( **it, NULL );
   }
}
void TMVA::DecisionTree::FillEvent( TMVA::Event & event,  
                                    TMVA::DecisionTreeNode *node  )
{
   // Add one event to the statistics of every node on its path, starting at
   // 'node' (the root if NULL). At each node:
   //  - the total counters and, depending on event.IsSignal(), either the signal
   //    or the background counters (weighted and unweighted) are incremented;
   //  - the separation index is recomputed from the weighted counts;
   //  - for an internal node (type 0) the event is passed on to the daughter
   //    selected by the node's cut.
   if (node == NULL) { 
      node = (TMVA::DecisionTreeNode*)this->GetRoot();
   }
   node->IncrementNEvents( event.GetWeight() );
   node->IncrementNEvents_unweighted( );

   if (event.IsSignal()){
      node->IncrementNSigEvents( event.GetWeight() );
      node->IncrementNSigEvents_unweighted( );
   } 
   else {
      node->IncrementNBkgEvents( event.GetWeight() );
      // background events must bump the *background* unweighted counter
      // (previously the signal counter was incremented here by mistake)
      node->SetNBkgEvents_unweighted( node->GetNBkgEvents_unweighted() + 1 );
   }
   node->SetSeparationIndex(fSepType->GetSeparationIndex(node->GetNSigEvents(),
                                                         node->GetNBkgEvents()));

   if (node->GetNodeType() == 0){ 
      if (node->GoesRight(event))
         this->FillEvent(event,(TMVA::DecisionTreeNode*)(node->GetRight())) ;
      else
         this->FillEvent(event,(TMVA::DecisionTreeNode*)(node->GetLeft())) ;
   }
}
void TMVA::DecisionTree::ClearTree()
{
   
   if (this->GetRoot()!=NULL) 
      ((DecisionTreeNode*)(this->GetRoot()))->ClearNodeAndAllDaughters();
}
void TMVA::DecisionTree::PruneTree()
{
   // Prune the tree with the algorithm selected by fPruneMethod:
   //   kExpectedErrorPruning  -> PruneTreeEEP (starting at the root)
   //   kCostComplexityPruning -> PruneTreeCC
   //   kMCC                   -> PruneTreeMCC
   // Any other method is a fatal error. The nodes are recounted afterwards.
   switch (fPruneMethod) {
   case kExpectedErrorPruning:
      this->PruneTreeEEP((DecisionTreeNode *)this->GetRoot());
      break;
   case kCostComplexityPruning:
      this->PruneTreeCC();
      break;
   case kMCC:
      this->PruneTreeMCC();
      break;
   default:
      fLogger << kFATAL << "Selected pruning method not yet implemented "
              << Endl;
      break;
   }
   this->CountNodes();
}
      
void TMVA::DecisionTree::PruneTreeEEP(DecisionTreeNode *node)
{
   // Expected-error pruning, bottom-up: first prune both daughter subtrees, then
   // replace the subtree below 'node' by a leaf if its expected error
   // (GetSubTreeError) is not smaller than that of the node alone (GetNodeError).
   if (node->GetNodeType() != 0) return;   // leaves have nothing to prune

   DecisionTreeNode *leftSub  = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode *rightSub = (DecisionTreeNode*)node->GetRight();
   this->PruneTreeEEP(leftSub);
   this->PruneTreeEEP(rightSub);

   const Bool_t subTreeNotBetter = this->GetSubTreeError(node) >= this->GetNodeError(node);
   if (subTreeNotBetter) this->PruneNode(node);
}
void TMVA::DecisionTree::PruneTreeCC()
{
   
   
   
   
   
   
   
   
   
   
   Double_t currentCC = this->GetCostComplexity(fPruneStrength);
   Double_t nextCC    = this->GetCostComplexityIfNextPruneStep(fPruneStrength);
   while (currentCC > nextCC &&  this->GetNNodes() > 3 ){
      this->PruneNode( this->FindCCPruneCandidate() );
      currentCC = this->GetCostComplexity(fPruneStrength);
      nextCC    = this->GetCostComplexityIfNextPruneStep(fPruneStrength);
   }
   return;
}
void TMVA::DecisionTree::PruneTreeMCC()
{
   // Minimal cost-complexity pruning: repeatedly remove the weakest link, i.e. the
   // internal node with the smallest link strength alpha (see FillLinkStrengthMap),
   // while that alpha is below fPruneStrength and the tree has more than 3 nodes.
   //
   // The map is rebuilt right after every prune, so the stopping test always looks
   // at the node that would be pruned next. Testing a map filled before the prune
   // (as done before) compared the alpha of the node just removed, so one node with
   // alpha >= fPruneStrength could still be pruned. The map is also rebuilt before
   // stale node pointers in it could be used, and an empty map (no internal nodes)
   // ends the loop instead of dereferencing begin().
   this->FillLinkStrengthMap();
   while (!fLinkStrengthMap.empty() &&
          fLinkStrengthMap.begin()->first < fPruneStrength &&
          this->GetNNodes() > 3 ){
      this->PruneNode( fLinkStrengthMap.begin()->second );
      this->FillLinkStrengthMap();
   }
}
TMVA::DecisionTreeNode*  TMVA::DecisionTree::GetWeakestLink() 
{
   // Rebuild the link-strength map and return the internal node with the smallest
   // link strength alpha (the next minimal-cost-complexity pruning candidate).
   // Returns NULL when the tree has no internal node: calling begin()->second on
   // an empty map would be undefined behaviour.
   this->FillLinkStrengthMap();
   if (fLinkStrengthMap.empty()) return NULL;
   return fLinkStrengthMap.begin()->second;
}
void TMVA::DecisionTree::FillLinkStrengthMap(TMVA::DecisionTreeNode *n) 
{
   // Fill fLinkStrengthMap, keyed by link strength, with every internal node of the
   // subtree below 'n' (the whole tree, after clearing the map, when n == NULL):
   //   alpha = (R(node) - R(subtree)) / (nodes in subtree - 1)
   // where R is the misclassification cost (MisClassificationCostOfNode for the
   // node as a leaf, MisClassificationCostOfSubTree for its leaves).
   // Daughters are processed before their parent.
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      fLinkStrengthMap.clear();
      if (n == NULL) {
         fLogger << kFATAL << "FillLinkStrengthMap: started with undefined ROOT node" <<Endl;
         return ;
      }
   }

   DecisionTreeNode *left  = this->GetLeftDaughter(n);
   DecisionTreeNode *right = this->GetRightDaughter(n);
   if (left  != NULL) this->FillLinkStrengthMap(left);
   if (right != NULL) this->FillLinkStrengthMap(right);

   // only nodes with two daughters are "links" that can be cut
   if (left == NULL || right == NULL) return;

   const Double_t costAsLeaf  = this->MisClassificationCostOfNode(n);
   const Double_t costSubTree = this->MisClassificationCostOfSubTree(n);
   const Double_t alpha = (costAsLeaf - costSubTree) / (n->CountMeAndAllDaughters() - 1);
   fLinkStrengthMap.insert(pair<const Double_t, TMVA::DecisionTreeNode* >( alpha, n ));
}
Double_t TMVA::DecisionTree::MisClassificationCostOfNode(TMVA::DecisionTreeNode *n)
{
   // Misclassification cost of 'n' treated as a leaf: (weighted) number of events
   // times the fraction of events NOT belonging to the class the leaf is labelled
   // with. Leaves are labelled by majority (type +1 if purity > 0.5, else -1, as in
   // BuildTree/PruneNode), so the misclassified fraction is min(purity, 1 - purity).
   // Using 1 - purity unconditionally (as before) counted the correctly classified
   // events for background-majority nodes.
   const Double_t purity = n->GetPurity();
   const Double_t misclassifiedFraction = (purity > 0.5) ? (1 - purity) : purity;
   return misclassifiedFraction * n->GetNEvents();
}
Double_t TMVA::DecisionTree::MisClassificationCostOfSubTree(TMVA::DecisionTreeNode *n)
{
   // Misclassification cost of the subtree below 'n' (the whole tree when n == NULL):
   // the sum of MisClassificationCostOfNode over all leaves of that subtree.
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) {
         fLogger << kFATAL << "MisClassificationCostOfSubTree: started with undefined ROOT node" <<Endl;
         return 0.;
      }
   }

   DecisionTreeNode *left  = this->GetLeftDaughter(n);
   DecisionTreeNode *right = this->GetRightDaughter(n);
   if (left == NULL && right == NULL) return this->MisClassificationCostOfNode(n);

   Double_t cost = 0;
   if (left  != NULL) cost += this->MisClassificationCostOfSubTree(left);
   if (right != NULL) cost += this->MisClassificationCostOfSubTree(right);
   return cost;
}
UInt_t TMVA::DecisionTree::CountLeafNodes(TMVA::DecisionTreeNode *n)
{
   // Number of leaves (nodes without daughters) in the subtree below 'n'
   // (the whole tree when n == NULL).
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) {
         fLogger << kFATAL << "CountLeafNodes: started with undefined ROOT node" <<Endl;
         return 0;
      }
   }

   DecisionTreeNode *left  = this->GetLeftDaughter(n);
   DecisionTreeNode *right = this->GetRightDaughter(n);
   if (left == NULL && right == NULL) return 1;

   UInt_t nLeaves = 0;
   if (left  != NULL) nLeaves += this->CountLeafNodes(left);
   if (right != NULL) nLeaves += this->CountLeafNodes(right);
   return nLeaves;
}
Double_t TMVA::DecisionTree::GetCostComplexity(Double_t alpha) 
{
   // Cost complexity of the current tree:
   //   CC = sum over leaves of ( N_unweighted(leaf) * Q(leaf) )  +  alpha * (number of leaves)
   // where Q is the fQualityIndex separation index stored as key of fQualityMap.
   this->FillQualityMap();

   Double_t weightedQuality = 0.;
   Int_t    nLeaves = 0;
   for (multimap<Double_t, TMVA::DecisionTreeNode* >::const_iterator leaf = fQualityMap.begin();
        leaf != fQualityMap.end(); ++leaf, ++nLeaves) {
      const Double_t nSig = leaf->second->GetNSigEvents_unweighted();
      const Double_t nBkg = leaf->second->GetNBkgEvents_unweighted();
      weightedQuality += (nSig + nBkg) * leaf->first;
   }
   return weightedQuality + alpha * nLeaves;
}
Double_t TMVA::DecisionTree::GetCostComplexityIfNextPruneStep(Double_t alpha) 
{
   // Cost complexity the tree would have after the next cost-complexity prune step,
   // i.e. after turning the node with the smallest quality gain
   // (fQualityGainMap.begin(), see FillQualityGainMap) into a leaf:
   //  - leaves whose parent is that node are left out of the sum,
   //  - the node itself enters as a leaf, with Q computed from its own counts.
   // Returns 0 (after logging an error) if either map is empty.
   this->FillQualityMap();
   this->FillQualityGainMap();

   if (fQualityMap.size() == 0 ){
      fLogger << kError << "The Quality Map in the BDT-Pruning is empty.. maybe your Tree has "
              << " absolutely no splits ?? e.g.. minimun number of events for node splitting"
              << " being larger than the number of events available ??? " << Endl;
      return 0.;
   }
   if (fQualityGainMap.size() == 0 ){
      fLogger << kError << "The QualityGain Map in the BDT-Pruning is empty.. This can happen"
              << " if your Tree has absolutely no splits ?? e.g.. minimun number of events for"
              << " node splitting being larger than the number of events available ??? " << Endl;
      return 0.;
   }

   TMVA::DecisionTreeNode *candidate = fQualityGainMap.begin()->second;

   Double_t cc      = 0.;
   Int_t    nLeaves = 0;
   for (multimap<Double_t, TMVA::DecisionTreeNode* >::iterator leaf = fQualityMap.begin();
        leaf != fQualityMap.end(); ++leaf) {
      if (leaf->second->GetParent() == candidate) continue;   // absorbed into the candidate
      const Double_t nSig = leaf->second->GetNSigEvents_unweighted();
      const Double_t nBkg = leaf->second->GetNBkgEvents_unweighted();
      cc += (nSig + nBkg) * leaf->first;
      ++nLeaves;
   }

   // the pruned candidate becomes one new leaf
   const Double_t candSig = candidate->GetNSigEvents_unweighted();
   const Double_t candBkg = candidate->GetNBkgEvents_unweighted();
   cc += (candSig + candBkg) * fQualityIndex->GetSeparationIndex(candSig, candBkg);
   ++nLeaves;

   return cc + alpha * nLeaves;
}
void TMVA::DecisionTree::FillQualityGainMap(DecisionTreeNode* n )
{
   // Fill fQualityGainMap, keyed by the fQualityIndex separation gain, with every
   // node of the subtree below 'n' (the whole tree, after clearing the map, when
   // n == NULL) whose two daughters are both leaves, i.e. the candidates for the
   // next cost-complexity prune step.
   // NOTE(review): the gain is evaluated from the right daughter's unweighted
   // counts against the parent's; presumably GetSeparationGain derives the other
   // side from the difference — confirm in SeparationBase.
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      fQualityGainMap.clear();
      if (n == NULL) {
         fLogger << kFATAL << "FillQualityGainMap: started with undefined ROOT node" <<Endl;
         return ;
      }
   }

   DecisionTreeNode *left  = this->GetLeftDaughter(n);
   DecisionTreeNode *right = this->GetRightDaughter(n);
   if (left  != NULL) this->FillQualityGainMap(left);
   if (right != NULL) this->FillQualityGainMap(right);

   if (left == NULL || right == NULL) return;

   const Bool_t leftIsLeaf  = (left->GetLeft()  == NULL) && (left->GetRight()  == NULL);
   const Bool_t rightIsLeaf = (right->GetLeft() == NULL) && (right->GetRight() == NULL);
   if (!(leftIsLeaf && rightIsLeaf)) return;

   const Double_t gain = fQualityIndex->GetSeparationGain( right->GetNSigEvents_unweighted(),
                                                           right->GetNBkgEvents_unweighted(),
                                                           n->GetNSigEvents_unweighted(),
                                                           n->GetNBkgEvents_unweighted() );
   fQualityGainMap.insert(pair<const Double_t, TMVA::DecisionTreeNode* >( gain, n ));
}
void TMVA::DecisionTree::FillQualityMap(DecisionTreeNode* n )
{
   // Fill fQualityMap with every leaf of the subtree below 'n' (the whole tree,
   // after clearing the map, when n == NULL), keyed by the fQualityIndex
   // separation index of the leaf's unweighted signal/background counts.
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      fQualityMap.clear();
      if (n == NULL) {
         fLogger << kFATAL << "FillQualityMap: started with undefined ROOT node" <<Endl;
         return ;
      }
   }

   DecisionTreeNode *left  = this->GetLeftDaughter(n);
   DecisionTreeNode *right = this->GetRightDaughter(n);
   if (left != NULL || right != NULL) {   // not a leaf: descend only
      if (left  != NULL) this->FillQualityMap(left);
      if (right != NULL) this->FillQualityMap(right);
      return;
   }

   const Double_t quality = fQualityIndex->GetSeparationIndex( n->GetNSigEvents_unweighted(),
                                                               n->GetNBkgEvents_unweighted() );
   fQualityMap.insert(pair<const Double_t, TMVA::DecisionTreeNode* >( quality, n ));
}
   
void TMVA::DecisionTree::DescendTree( DecisionTreeNode *n)
{
   // Walk the subtree below 'n' (the whole tree when n == NULL), left branch first,
   // and report a fatal error for any node with exactly one daughter.
   // The nodes themselves are not modified.
   if (n == NULL) {
      n = (DecisionTreeNode*) this->GetRoot();
      if (n == NULL) {
         fLogger << kFATAL << "DescendTree: started with undefined ROOT node" <<Endl;
         return ;
      }
   }

   DecisionTreeNode *left  = this->GetLeftDaughter(n);
   DecisionTreeNode *right = this->GetRightDaughter(n);
   if (left == NULL && right == NULL) return;   // leaf: nothing below

   if (left == NULL || right == NULL) {
      fLogger << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   }
   this->DescendTree( left );
   this->DescendTree( right );
}
TMVA::DecisionTreeNode* TMVA::DecisionTree::FindCCPruneCandidate()
{
   // Rebuild fQualityGainMap and return the node with the smallest quality gain,
   // i.e. the node whose two leaf daughters are merged by the next cost-complexity
   // prune step. Returns NULL if there is no such node (empty map) instead of
   // dereferencing begin() of an empty container.
   this->FillQualityGainMap();
   if (fQualityGainMap.empty()) return NULL;
   return fQualityGainMap.begin()->second;
}
void TMVA::DecisionTree::PruneNode(DecisionTreeNode *node)
{
   // Turn 'node' into a leaf:
   //  - detach both daughters and hand them to DeleteNode,
   //  - reset selector, separation index and separation gain to -1,
   //  - set the node type to +1 if purity > 0.5, otherwise -1,
   // then recount the nodes of the tree.
   DecisionTreeNode *leftSub  = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode *rightSub = (DecisionTreeNode*)node->GetRight();

   node->SetRight(NULL);
   node->SetLeft(NULL);
   node->SetSelector(-1);
   node->SetSeparationIndex(-1);
   node->SetSeparationGain(-1);
   node->SetNodeType( (node->GetPurity() > 0.5) ? 1 : -1 );

   this->DeleteNode(leftSub);
   this->DeleteNode(rightSub);
   this->CountNodes();
}
Double_t TMVA::DecisionTree::GetNodeError(DecisionTreeNode *node)
{
   // Pessimistic estimate of the error rate of 'node' if it were a leaf,
   // as used by expected-error pruning:
   //   f     = majority-class fraction = max(purity, 1 - purity)
   //   df    = sqrt( f (1 - f) / N ),   N = node->GetNEvents()
   //   error = min( 1, 1 - (f - fPruneStrength * df) )
   const Double_t nEvts  = node->GetNEvents();
   const Double_t purity = node->GetPurity();
   const Double_t f      = (purity > 0.5) ? purity : (1 - purity);
   const Double_t df     = TMath::Sqrt( f*(1-f)/nEvts );
   return std::min( 1., (1 - (f - fPruneStrength*df)) );
}
Double_t TMVA::DecisionTree::GetSubTreeError(DecisionTreeNode *node)
{
   // Expected error of the subtree below 'node': GetNodeError for a leaf, and for
   // an internal node (type 0) the event-count weighted average of the subtree
   // errors of its two daughters.
   if (node->GetNodeType() != 0) return this->GetNodeError(node);

   DecisionTreeNode *leftSub  = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode *rightSub = (DecisionTreeNode*)node->GetRight();
   const Double_t leftPart  = leftSub->GetNEvents()  * this->GetSubTreeError(leftSub);
   const Double_t rightPart = rightSub->GetNEvents() * this->GetSubTreeError(rightSub);
   return (leftPart + rightPart) / node->GetNEvents();
}
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetLeftDaughter( DecisionTreeNode *n)
{
   // Left daughter of 'n' as a DecisionTreeNode (NULL if there is none).
   return static_cast<DecisionTreeNode*>( n->GetLeft() );
}
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetRightDaughter( DecisionTreeNode *n)
{
   // Right daughter of 'n' as a DecisionTreeNode (NULL if there is none).
   return static_cast<DecisionTreeNode*>( n->GetRight() );
}
   
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetNode(ULong_t sequence, UInt_t depth)
{
   // Return the node reached by descending 'depth' levels from the root.
   // Bit i of 'sequence' selects the branch taken at level i: 1 = right, 0 = left.
   // Bits beyond the width of ULong_t cannot be encoded and count as 0 (left).
   // Returns NULL if the path runs out of the tree before 'depth' levels
   // (previously the next step dereferenced the NULL daughter).
   DecisionTreeNode* current = (DecisionTreeNode*) this->GetRoot();
   const UInt_t nBits = 8 * sizeof(ULong_t);

   for (UInt_t i = 0; i < depth && current != NULL; i++) {
      // shift the (unsigned long) sequence instead of the int literal 1:
      // "1 << i" overflows int and is undefined behaviour for i >= 31
      const Bool_t goRight = (i < nBits) && ((sequence >> i) & 1UL);
      if (goRight) current = this->GetRightDaughter(current);
      else         current = this->GetLeftDaughter(current);
   }
   return current;
}
void TMVA::DecisionTree::FindMinAndMax(vector<TMVA::Event*> & eventSample,
                                       vector<Double_t> & xmin,
                                       vector<Double_t> & xmax)
{
   // For each of the fNvars input variables, store the smallest and largest value
   // found in eventSample in xmin[ivar] / xmax[ivar]. Both vectors must already
   // hold at least fNvars entries.
   // An empty sample leaves xmin/xmax untouched (the first event used to be read
   // unconditionally, which is undefined for an empty vector).
   if (eventSample.empty()) return;

   for (Int_t ivar=0; ivar < fNvars; ivar++) {
      xmin[ivar] = xmax[ivar] = eventSample[0]->GetVal(ivar);
   }
   for (UInt_t i=1; i<eventSample.size(); i++) {
      for (Int_t ivar=0; ivar < fNvars; ivar++) {
         const Double_t val = eventSample[i]->GetVal(ivar);
         if (xmin[ivar] > val) xmin[ivar] = val;
         if (xmax[ivar] < val) xmax[ivar] = val;
      }
   }
}
void  TMVA::DecisionTree::SetCutPoints(vector<Double_t> & cut_points,
                                       Double_t xmin,
                                       Double_t xmax,
                                       Int_t num_gridpoints)
{
   // Place num_gridpoints cut values at the centres of num_gridpoints equal-width
   // bins spanning [xmin, xmax]:
   //   cut_points[j] = xmin + (j + 0.5) * (xmax - xmin) / num_gridpoints
   // cut_points must already hold at least num_gridpoints entries.
   // Each point is computed directly (same formula as TrainNode) rather than by
   // adding the step repeatedly, so rounding errors do not accumulate along the grid.
   if (num_gridpoints <= 0) return;
   const Double_t step = (xmax - xmin)/num_gridpoints;
   for (Int_t j=0; j < num_gridpoints; j++){
      cut_points[j] = xmin + (j + 0.5)*step;
   }
}
#if USE_HELGESCODE==1
Double_t TMVA::DecisionTree::TrainNode(vector<TMVA::Event*> & eventSample,
                                       TMVA::DecisionTreeNode *node)
{
   // Find the best cut for 'node' from the events in eventSample:
   //  - per input variable, fNCuts cut values are placed at the centres of fNCuts
   //    equal-width bins spanning the variable's range in the sample;
   //  - per cut, the weighted and unweighted signal/background counts above the
   //    cut are accumulated;
   //  - cuts leaving fewer than fMinSize unweighted events on either side are
   //    skipped; the remaining cut with the largest fSepType separation gain wins.
   // The winner (variable, cut value, cut type = kTRUE if the region above the cut
   // keeps a larger fraction of the signal than of the background, and gain) is
   // stored in 'node', and fVariableImportance of that variable is increased by
   // gain^2 * (total weight)^2.
   // Returns the best separation gain, or 0 if no admissible cut was found.
   // NOTE(review): assumes a non-empty eventSample (eventSample[0] is read) and fNCuts > 0.

   // Per-variable range of the sample. Automatic vectors replace the former
   // new/delete pair, so the memory is released on every exit path (including a
   // thrown exception).
   vector<Double_t> xmin( fNvars );
   vector<Double_t> xmax( fNvars );
   Double_t separationGain = -1, sepTmp;
   Double_t cutValue=-999;
   Int_t mxVar=-1, cutIndex=0;
   Bool_t cutType=kTRUE;
   Double_t  nTotS, nTotB;
   Int_t     nTotS_unWeighted, nTotB_unWeighted; 
   UInt_t nevents = eventSample.size();

   // range of every variable in this sample
   for (int ivar=0; ivar < fNvars; ivar++){
      xmin[ivar]=xmax[ivar]=eventSample[0]->GetVal(ivar);
   }
   for (UInt_t iev=1;iev<nevents;iev++){
      for (Int_t ivar=0; ivar < fNvars; ivar++){
         Double_t eventData = eventSample[iev]->GetVal(ivar); 
         if (xmin[ivar]>eventData) xmin[ivar]=eventData;
         if (xmax[ivar]<eventData) xmax[ivar]=eventData;
      }
   }

   // per variable and cut: weighted / unweighted signal and background counts
   // of the events above the cut, and the cut values themselves
   vector< vector<Double_t> > nSelS (fNvars);
   vector< vector<Double_t> > nSelB (fNvars);
   vector< vector<Int_t> >    nSelS_unWeighted (fNvars);
   vector< vector<Int_t> >    nSelB_unWeighted (fNvars);
   vector< vector<Double_t> > cutValues(fNvars);
   for (int ivar=0; ivar < fNvars; ivar++){
      cutValues[ivar].resize(fNCuts);
      nSelS[ivar].resize(fNCuts);
      nSelB[ivar].resize(fNCuts);
      nSelS_unWeighted[ivar].resize(fNCuts);
      nSelB_unWeighted[ivar].resize(fNCuts);
      // cut values at the bin centres of fNCuts equal-width bins
      Double_t istepSize =( xmax[ivar] - xmin[ivar] ) / Double_t(fNCuts);
      for (Int_t icut=0; icut<fNCuts; icut++){
         cutValues[ivar][icut]=xmin[ivar]+(Float_t(icut)+0.5)*istepSize;
      }
   }

#if USE_HELGE_V1==1
   // variant: outer loop over variables, inner loop over events
   nTotS=0; nTotB=0;
   nTotS_unWeighted=0; nTotB_unWeighted=0;   
   for (int ivar=0; ivar < fNvars; ivar++){
      for (UInt_t iev=0; iev<nevents; iev++){
         Double_t eventData  = eventSample[iev]->GetVal(ivar); 
         Int_t    eventType  = eventSample[iev]->Type(); 
         Double_t eventWeight= eventSample[iev]->GetWeight(); 
         if (ivar==0){   // totals are accumulated once, on the first variable
            if (eventType==1){
               nTotS+=eventWeight;
               nTotS_unWeighted++;
            }
            else {
               nTotB+=eventWeight;
               nTotB_unWeighted++;
            }
         }
         for (Int_t icut=0; icut<fNCuts; icut++){
            if (eventData > cutValues[ivar][icut]){
               if (eventType==1) {
                  nSelS[ivar][icut]+=eventWeight;
                  nSelS_unWeighted[ivar][icut]++;
               } 
               else {
                  nSelB[ivar][icut]+=eventWeight;
                  nSelB_unWeighted[ivar][icut]++;
               }
            }
         }
      }
   }
#else 
   // default: outer loop over events, inner loop over variables
   nTotS=0; nTotB=0;
   nTotS_unWeighted=0; nTotB_unWeighted=0;   
   for (UInt_t iev=0; iev<nevents; iev++){
      Int_t eventType = eventSample[iev]->Type();
      Double_t eventWeight =  eventSample[iev]->GetWeight(); 
      if (eventType==1){
         nTotS+=eventWeight;
         nTotS_unWeighted++;
      }
      else {
         nTotB+=eventWeight;
         nTotB_unWeighted++;
      }
      for (int ivar=0; ivar < fNvars; ivar++){
         Double_t eventData = eventSample[iev]->GetVal(ivar); 
         for (Int_t icut=0; icut<fNCuts; icut++){
            if (eventData > cutValues[ivar][icut]){
               if (eventType==1) {
                  nSelS[ivar][icut]+=eventWeight;
                  nSelS_unWeighted[ivar][icut]++;
               } 
               else {
                  nSelB[ivar][icut]+=eventWeight;
                  nSelB_unWeighted[ivar][icut]++;
               }
            }
         }
      }
   }
#endif

   // pick the admissible cut with the largest separation gain
   for (int ivar=0; ivar < fNvars; ivar++) {
      for (Int_t icut=0; icut<fNCuts; icut++){
         // both sides of the cut must keep at least fMinSize unweighted events
         if ( (nSelS_unWeighted[ivar][icut] +  nSelB_unWeighted[ivar][icut]) >= fMinSize &&
              (( nTotS_unWeighted+nTotB_unWeighted)- 
               (nSelS_unWeighted[ivar][icut] +  nSelB_unWeighted[ivar][icut])) >= fMinSize) {
            sepTmp = fSepType->GetSeparationGain(nSelS[ivar][icut], nSelB[ivar][icut], nTotS, nTotB);
            if (separationGain < sepTmp) {
               separationGain = sepTmp;
               mxVar = ivar;
               cutIndex = icut;
            }
         }
      }
   }

   if (mxVar >= 0) { 
      // kTRUE: the region above the cut keeps the larger share of the signal
      if (nSelS[mxVar][cutIndex]/nTotS > nSelB[mxVar][cutIndex]/nTotB) cutType=kTRUE;
      else cutType=kFALSE;
      cutValue = cutValues[mxVar][cutIndex];

      node->SetSelector((UInt_t)mxVar);
      node->SetCutValue(cutValue);
      node->SetCutType(cutType);
      node->SetSeparationGain(separationGain);

      fVariableImportance[mxVar] += separationGain*separationGain * (nTotS+nTotB) * (nTotS+nTotB) ;
   }
   else {
      separationGain = 0;   // no admissible cut: the caller turns the node into a leaf
   }
   return separationGain;
}
#else 
Double_t TMVA::DecisionTree::TrainNode(vector<TMVA::Event*> & eventSample,
                                       TMVA::DecisionTreeNode *node)
{
   // Alternative node training (compiled only when USE_HELGESCODE != 1):
   // scan fNCuts cut values per variable (SetCutPoints) over the range found by
   // FindMinAndMax and store the cut with the largest fSepType separation gain in
   // 'node'. Unlike the default implementation it does not enforce fMinSize.
   // Returns the best gain, or 0 if no cut could be evaluated.
   vector<Double_t> xmin ( fNvars );
   vector<Double_t> xmax ( fNvars );
   Double_t separationGain = -1;
   Double_t cutValue=-999;
   Int_t mxVar=-1;
   Bool_t cutType=kTRUE;
   Double_t nTotS=0., nTotB=0.;
   UInt_t num_events = eventSample.size();

   // per variable and cut: weighted signal / background above the cut
   vector<vector<Double_t> > signal_counts (fNvars);
   vector<vector<Double_t> > background_counts (fNvars);
   vector<vector<Double_t> > cut_points (fNvars);

   this->FindMinAndMax(eventSample, xmin, xmax);

   for (Int_t i=0; i < fNvars; i++){
      signal_counts[i].resize(fNCuts);
      background_counts[i].resize(fNCuts);
      cut_points[i].resize(fNCuts);
      this->SetCutPoints(cut_points[i], xmin[i], xmax[i], fNCuts);
   }

   for (UInt_t event=0; event < num_events; event++){
      // same event accessors as the default implementation (Type() / GetVal())
      Int_t event_type = eventSample[event]->Type();
      Double_t event_weight = eventSample[event]->GetWeight();

      if (event_type == 1) nTotS += event_weight;
      else                 nTotB += event_weight;

      for (Int_t variable = 0; variable < fNvars; variable++){
         Double_t event_val = eventSample[event]->GetVal(variable);
         for (Int_t cut=0; cut < fNCuts; cut++){
            if (event_val > cut_points[variable][cut]){
               if (event_type == 1) signal_counts[variable][cut]     += event_weight;
               else                 background_counts[variable][cut] += event_weight;
            }
         }
      }
   }

   for (Int_t var = 0; var < fNvars; var++){
      for (Int_t cut=0; cut < fNCuts; cut++){
         Double_t cur_sep = fSepType->GetSeparationGain(signal_counts[var][cut],
                                                        background_counts[var][cut],
                                                        nTotS, nTotB);
         if (separationGain < cur_sep) {
            separationGain = cur_sep;
            cutValue=cut_points[var][cut];
            // the cut type must come from the counts of *this* cut (the old code
            // compared two never-updated zeros and therefore always chose kFALSE)
            cutType= (signal_counts[var][cut]/nTotS > background_counts[var][cut]/nTotB) ? kTRUE : kFALSE;
            mxVar = var;
         }
      }
   }

   // nothing evaluated (e.g. fNvars == 0 or fNCuts <= 0): do not index
   // fVariableImportance with -1; same convention as the default implementation
   if (mxVar < 0) return 0;

   node->SetSelector(mxVar);
   node->SetCutValue(cutValue);
   node->SetCutType(cutType);
   node->SetSeparationGain(separationGain);
   fVariableImportance[mxVar] += separationGain*separationGain * (nTotS+nTotB)* (nTotS+nTotB);

   return separationGain;
}
#endif
Double_t TMVA::DecisionTree::CheckEvent(const TMVA::Event & e, Bool_t UseYesNoLeaf)
{
   // Classify event 'e': follow the node cuts from the root through all internal
   // nodes (type 0) down to a leaf. Returns the leaf's node type when
   // UseYesNoLeaf is set, otherwise the leaf's purity.
   TMVA::DecisionTreeNode *leaf = (TMVA::DecisionTreeNode*)this->GetRoot();
   for (; leaf->GetNodeType() == 0; ) {
      leaf = (TMVA::DecisionTreeNode*)( leaf->GoesRight(e) ? leaf->GetRight()
                                                            : leaf->GetLeft() );
   }
   if (!UseYesNoLeaf) return leaf->GetPurity();
   return Double_t( leaf->GetNodeType() );
}
Double_t  TMVA::DecisionTree::SamplePurity(vector<TMVA::Event*> eventSample)
{
   // Weighted signal purity of the sample, sum(w_sig) / (sum(w_sig) + sum(w_bkg)),
   // where signal events have Type() == 1 and background events Type() == 0.
   // Returns -1 if the total weight is not positive.
   // An event of any other type is a fatal error.
   Double_t sumsig=0, sumbkg=0, sumtot=0;
   Bool_t foreignType = kFALSE;
   for (UInt_t ievt=0; ievt<eventSample.size(); ievt++) {
      const Int_t    type   = eventSample[ievt]->Type();
      const Double_t weight = eventSample[ievt]->GetWeight();
      if      (type==0) sumbkg += weight;
      else if (type==1) sumsig += weight;
      else              foreignType = kTRUE;
      sumtot += weight;
   }

   // Check the classification directly instead of testing sumtot != sumsig+sumbkg:
   // floating-point addition is not associative, so the interleaved running total
   // can differ from the two partial sums by rounding alone, which used to trigger
   // a spurious fatal error.
   if (foreignType){
      fLogger << kFATAL << "<SamplePurity> sumtot != sumsig+sumbkg"
              << sumtot << " " << sumsig << " " << sumbkg << Endl;
   }
   if (sumtot>0) return sumsig/(sumsig + sumbkg);
   else return -1;
}
vector< Double_t >  TMVA::DecisionTree::GetVariableImportance()
{
   // Relative importance of each of the fNvars input variables: the entries of
   // fVariableImportance (accumulated in TrainNode) normalised to sum to one.
   // If the sum is not larger than double epsilon, all entries are returned as 0.
   // fVariableImportance is only sized in BuildTree, so entries it does not hold
   // are treated as 0 instead of being read out of bounds.
   vector<Double_t> relativeImportance(fNvars);
   Double_t  sum=0;
   for (Int_t i=0; i < fNvars && UInt_t(i) < fVariableImportance.size(); i++) {
      sum += fVariableImportance[i];
      relativeImportance[i] = fVariableImportance[i];
   }
   for (Int_t i=0; i < fNvars; i++) {
      if (sum > std::numeric_limits<double>::epsilon())
         relativeImportance[i] /= sum;
      else 
         relativeImportance[i] = 0;
   }
   return relativeImportance;
}
Double_t  TMVA::DecisionTree::GetVariableImportance(Int_t ivar)
{
   // Relative importance of variable 'ivar' (normalised as in GetVariableImportance()).
   // An index outside [0, fNvars) is a fatal error; -1 is returned in that case.
   vector<Double_t> allImportances = this->GetVariableImportance();
   const Bool_t inRange = (ivar >= 0) && (ivar < fNvars);
   if (!inRange) {
      fLogger << kFATAL << "<GetVariableImportance>" << Endl
              << "---                     ivar = " << ivar << " is out of range " << Endl;
      return -1;
   }
   return allImportances[ivar];
}
TH2D*  TMVA::DecisionTree::DrawTree(TString hname)
{
   // Book a 2D histogram named 'hname' (caller takes the returned pointer) and fill
   // it with the tree layout via DrawNode: the root sits at x = 0, y = 2*fDepth,
   // and each level is drawn two units lower, weighted by the node event counts.
   ULong_t nbins = 1;          // 2^fDepth
   UInt_t  level = 0;
   while (level < fDepth) {
      nbins *= 2;
      ++level;
   }
   const Double_t xmax = 2*fDepth + 0.5;
   const Double_t xmin = -xmax;

   TH2D* h = new TH2D(hname, "", 2*nbins+1, xmin, xmax,
                      2*fDepth+2, -0.5, 2*fDepth+0.5);
   this->DrawNode( h, (DecisionTreeNode*)this->GetRoot(), 2*fDepth, 0, Double_t(fDepth) );
   return h;
}
      
void TMVA::DecisionTree::DrawNode( TH2D* h,  DecisionTreeNode *n, 
                                   Double_t y, Double_t x, Double_t scale)
{
   // Fill 'h' at (x, y) with the event count of 'n' after recursively drawing its
   // daughters two rows lower, shifted left/right by 'scale' (halved per level).
   const Double_t yBelow    = y - 2;
   const Double_t nextScale = scale/2.;

   DecisionTreeNode *left  = this->GetLeftDaughter(n);
   DecisionTreeNode *right = this->GetRightDaughter(n);
   if (left  != NULL) this->DrawNode( h, left,  yBelow, x - scale, nextScale );
   if (right != NULL) this->DrawNode( h, right, yBelow, x + scale, nextScale );

   h->Fill(x, y, n->GetNEvents());
}
// This page has been automatically generated. If you have any comments or suggestions about the page layout send a mail to ROOT support, or contact the developers with any questions or problems regarding ROOT.