Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
BDT_Reg.cxx
Go to the documentation of this file.
1#include "TMVA/BDT_Reg.h"
2#include <iostream>
3#include <fstream>
4
5#include "RQ_OBJECT.h"
6
7#include "TROOT.h"
8#include "TCanvas.h"
9#include "TLine.h"
10#include "TFile.h"
11#include "TColor.h"
12#include "TPaveText.h"
13#include "TObjString.h"
14#include "TControlBar.h"
15
16#include "TGWindow.h"
17#include "TGButton.h"
18#include "TGLabel.h"
19#include "TGNumberEntry.h"
20
21#include "TMVA/DecisionTree.h"
22#include "TMVA/Tools.h"
23#include "TXMLEngine.h"
24
// Control bars created by TMVA::BDT_Reg() (one pushed back per call).
// TMVA::BDTReg_DeleteTBar(i) deletes entry i and resets it to 0; the entry
// itself is not erased from the vector.
25std::vector<TControlBar*> TMVA::BDTReg_Global__cbar;
26
28
33
// NOTE(review): the function header (original line 34) is missing from this
// listing. From its position and body this is presumably the Redraw() slot
// that the "Draw" button is connected to in the constructor — confirm against
// BDT_Reg.h.
// Forwards to UpdateCanvases().
35{
36 UpdateCanvases();
37}
38
// NOTE(review): the function header (original line 39) is missing from this
// listing. This is presumably the Close() slot wired to the "Close" button.
// Destroys this dialog object.
// NOTE(review): this runs from inside a button's Clicked() signal. Confirm
// that the destructor does not synchronously delete the emitting button
// (TODO check the destructor, which is not shown here).
40{
41 delete this;
42}
43
// StatDialogBDTReg constructor. Per the documentation index its signature is
//   StatDialogBDTReg(TString dataset, const TGWindow *p, TString wfile,
//                    TString methName="BDT", Int_t itree=0)
// What the shown code does:
//  - stores dataset, weight-file name, method name and initial tree index;
//  - registers this instance in the static fThis pointer;
//  - reads the number of trees from the weight file (GetNtrees);
//  - builds the dialog: a label showing the valid tree range [0, fNtrees-1],
//    a number entry (fInput) whose ValueSet signal calls SetItree(), and
//    "Draw"/"Close" buttons wired to Redraw() and Close().
// NOTE(review): this listing comes from a source-browser export that dropped
// several original lines: the signature (44) and lines 66, 69, 71-72, 74, 76,
// 79, 82, 84 and 88-90. Those presumably create and lay out fMain, fInput and
// fButtons. The code below is incomplete as shown.
45 : fMain( 0 ),
46 fItree(itree),
47 fNtrees(0),
48 fCanvas(0),
49 fDataset(dataset),
50 fInput(0),
51 fButtons(0),
52 fDrawButton(0),
53 fCloseButton(0),
54 fWfile( wfile ),
55 fMethName( methName )
56{
57 UInt_t totalWidth = 500;
58 UInt_t totalHeight = 200;
59
60 fThis = this;
61
62 // read number of decision trees from weight file
63 GetNtrees();
64
65 // main frame
67
// label showing the valid range of tree indices
68 TGLabel *sigLab = new TGLabel( fMain, TString::Format( "Regression tree [%i-%i]",0,fNtrees-1 ) );
70
73 fInput->Resize(100,24);
75
77
78 fCloseButton = new TGTextButton(fButtons,"&Close");
80
81 fDrawButton = new TGTextButton(fButtons,"&Draw");
83
85
86 fMain->SetWindowName("Regression tree");
87 fMain->SetWMPosition(0,0);
91
// changing the number in the entry only updates the selected tree index (SetItree)
92 fInput->Connect("ValueSet(Long_t)","TMVA::StatDialogBDTReg",this, "SetItree()");
93
94 // The direct connection below stays disabled: it reportedly printed an error message (the slot did not seem to exist), and the dialog works fine without it.
95 // fDrawButton->Connect("Clicked()","TGNumberEntry",fInput, "ValueSet(Long_t)");
96 fDrawButton->Connect("Clicked()", "TMVA::StatDialogBDTReg", this, "Redraw()");
97
98 fCloseButton->Connect("Clicked()", "TMVA::StatDialogBDTReg", this, "Close()");
99}
100
// NOTE(review): the function header (original line 101) is missing from this
// listing. Since the Redraw() slot forwards to UpdateCanvases(), this is
// presumably UpdateCanvases() — confirm.
// Redraws the tree whose index is currently stored in fItree.
102{
103 DrawTree( fItree );
104}
105
107{
108 if(!fWfile.EndsWith(".xml") ){
109 std::ifstream fin( fWfile );
110 if (!fin.good( )) { // file not found --> Error
111 std::cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << std::endl;
112 return;
113 }
114
115 TString dummy = "";
116
117 // read total number of trees, and check whether requested tree is in range
118 Int_t nc = 0;
119 while (!dummy.Contains("NTrees")) {
120 fin >> dummy;
121 nc++;
122 if (nc > 200) {
123 std::cout << std::endl;
124 std::cout << "*** Huge problem: could not locate term \"NTrees\" in BDT weight file: "
125 << fWfile << std::endl;
126 std::cout << "==> panic abort (please contact the TMVA authors)" << std::endl;
127 std::cout << std::endl;
128 exit(1);
129 }
130 }
131 fin >> dummy;
132 fNtrees = dummy.ReplaceAll("\"","").Atoi();
133 fin.close();
134 }
135 else{
136 void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
137 void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
138 void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
139 while(ch){
140 TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
141 if(nodeName=="Weights") {
142 TMVA::gTools().ReadAttr( ch, "NTrees", fNtrees );
143 break;
144 }
145 ch = TMVA::gTools().xmlengine().GetNext(ch);
146 }
147 }
148 std::cout << "--- Found " << fNtrees << " decision trees in weight file" << std::endl;
149
150}
151
152////////////////////////////////////////////////////////////////////////////////
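153/// Recursively draws the node and its daughters on the current canvas: each node becomes a TPaveText box (response +- RMS, plus the cut for intermediate nodes), linked to its daughters by TLine segments.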
154///
155
// DrawNode: draws node `n` centred at NDC position (x, y). It then recurses
// into the left/right daughters at (x - xscale, y - yscale) and
// (x + xscale, y - yscale), halving the horizontal scale.
// Each node is drawn as a TPaveText showing "R=<response> +- <RMS>".
// Intermediate nodes (node type 0) additionally show "<variable> > <cut>" or
// "<variable> < <cut>", with the variable name taken from vars[GetSelector()].
// Signature per the documentation index:
//   void DrawNode(TMVA::DecisionTreeNode *n, Double_t x, Double_t y,
//                 Double_t xscale, Double_t yscale, TString *vars)
// NOTE(review): this listing is missing the signature (original lines
// 156-158) and line 161. Line 161 must declare `ysize`, which is used below
// but not declared in the lines shown.
159{
160 Float_t xsize=xscale*1.5;
// NOTE(review): boxes wider than 0.15 are clamped to 0.1, not 0.15 — intent unclear
162 if (xsize>0.15) xsize=0.1;
163 if (n->GetLeft() != NULL){
164 TLine *a1 = new TLine(x-xscale/4,y-ysize,x-xscale,y-ysize*2);
165 a1->SetLineWidth(2);
166 a1->Draw();
167 DrawNode((TMVA::DecisionTreeNode*) n->GetLeft(), x-xscale, y-yscale, xscale/2, yscale, vars);
168 }
169 if (n->GetRight() != NULL){
170 TLine *a1 = new TLine(x+xscale/4,y-ysize,x+xscale,y-ysize*2);
171 a1->SetLineWidth(2);
172 a1->Draw();
173 DrawNode((TMVA::DecisionTreeNode*) n->GetRight(), x+xscale, y-yscale, xscale/2, yscale, vars );
174 }
175
176 // TPaveText *t = new TPaveText(x-xscale/2,y-yscale/2,x+xscale/2,y+yscale/2, "NDC");
177 TPaveText *t = new TPaveText(x-xsize,y-ysize,x+xsize,y+ysize, "NDC");
178
179 t->SetBorderSize(1);
180
181 t->SetFillStyle(1001);
// colours by node type: 1 -> signal colours, -1 -> background colours, 0 -> intermediate colours
182 if (n->GetNodeType() == 1) { t->SetFillColor( getSigColorF() ); t->SetTextColor( getSigColorT() ); }
183 else if (n->GetNodeType() == -1) { t->SetFillColor( getBkgColorF() ); t->SetTextColor( getBkgColorT() ); }
184 else if (n->GetNodeType() == 0) { t->SetFillColor( getIntColorF() ); t->SetTextColor( getIntColorT() ); }
185
// NOTE(review): 25 bytes cut the label short once response/RMS need many digits
// (snprintf truncates safely, but the displayed text is incomplete)
186 char buffer[25];
187 // sprintf( buffer, "N=%f", n->GetNEvents() );
188 // t->AddText(buffer);
189 snprintf( buffer, 25, "R=%4.1f +- %4.1f", n->GetResponse(),n->GetRMS() );
190 t->AddText(buffer);
191
// intermediate nodes also show their cut; GetCutType() true prints '>', false prints '<'
192 if (n->GetNodeType() == 0){
193 if (n->GetCutType()){
194 t->AddText(vars[n->GetSelector()] + ">" + TString::Format("%5.3g",n->GetCutValue()));
195 }else{
196 t->AddText(vars[n->GetSelector()] + "<" + TString::Format("%5.3g",n->GetCutValue()));
197 }
198 }
199
200 t->Draw();
201
202 return;
203}
204
// ReadTree: reads decision tree number `itree` from fWfile into the tree `d`.
// It also fills `vars` with a new[]-allocated array of nVars+1 variable
// names, the last entry being "FisherCrit".
// Returns d on success. If the text weight file cannot be opened, or itree is
// >= fNtrees, it deletes d and returns 0; `vars` is not assigned on those
// paths. On success the caller owns both the returned tree and `vars`.
// Signature per the documentation index:
//   TMVA::DecisionTree * ReadTree(TString *&vars, Int_t itree)
// NOTE(review): this listing is missing original line 205 (signature),
// line 208 (presumably the allocation of `d`) and line 276 (presumably the
// advance of `varnode` to the next variable node). Without line 276 every
// variable would read the first node's "Expression".
206{
207 std::cout << "--- Reading Tree " << itree << " from weight file: " << fWfile << std::endl;
209
210
211 if(!fWfile.EndsWith(".xml") ){
212
213 std::ifstream fin( fWfile );
214 if (!fin.good( )) { // file not found --> Error
215 std::cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << std::endl;
216 delete d;
217 d = nullptr;
218 return 0;
219 }
220 TString dummy = "";
221
222 if (itree >= fNtrees) {
223 std::cout << "*** ERROR: requested decision tree: " << itree
224 << ", but number of trained trees only: " << fNtrees << std::endl;
225 delete d;
226 d = nullptr;
227 return 0;
228 }
229
230 // skip the file header up to the "#VAR" marker
// NOTE(review): nothing checks fin's state here, so this never terminates if "#VAR" is absent
231 while (!dummy.Contains("#VAR")) fin >> dummy;
232 fin >> dummy >> dummy >> dummy; // the rest of header line
233
234 // number of variables
235 Int_t nVars;
236 fin >> dummy >> nVars;
237
238 // variable names; the four tokens after each name are skipped (presumably the min/max ranges)
239 vars = new TString[nVars+1];
240 for (Int_t i = 0; i < nVars; i++) fin >> vars[i] >> dummy >> dummy >> dummy >> dummy;
241 vars[nVars]="FisherCrit";
242
243 char buffer[20];
244 char line[256];
245 snprintf(buffer, 20, "Tree %d",itree);
246
// advance line by line to the "Tree <itree>" header
// NOTE(review): like the "#VAR" loop above, this does not stop at end of file
247 while (!dummy.Contains(buffer)) {
248 fin.getline(line,256);
249 dummy = TString(line);
250 }
251
252 d->Read(fin);
253
254 fin.close();
255 }
256 else{
257 if (itree >= fNtrees) {
258 std::cout << "*** ERROR: requested decision tree: " << itree
259 << ", but number of trained trees only: " << fNtrees << std::endl;
260 delete d;
261 d = nullptr;
262 return 0;
263 }
264 Int_t nVars;
// NOTE(review): the ParseFile result is not null-checked, and the document is never released (no FreeDoc)
265 void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
266 void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
267 void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
// walk the top-level nodes: collect the variable names from <Variables>, stop at <Weights>
268 while(ch){
269 TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
270 if(nodeName=="Variables"){
271 TMVA::gTools().ReadAttr( ch, "NVar", nVars);
272 vars = new TString[nVars+1];
273 void* varnode = TMVA::gTools().xmlengine().GetChild(ch);
274 for (Int_t i = 0; i < nVars; i++){
275 TMVA::gTools().ReadAttr( varnode, "Expression", vars[i]);
277 }
278 vars[nVars]="FisherCrit";
279 }
280 if(nodeName=="Weights") break;
281 ch = TMVA::gTools().xmlengine().GetNext(ch);
282 }
// step into <Weights> and skip to the itree-th child (the requested tree)
// NOTE(review): if no <Weights> node exists, ch is null at this point
283 ch = TMVA::gTools().xmlengine().GetChild(ch);
284 for (int i=0; i<itree; i++) ch = TMVA::gTools().xmlengine().GetNext(ch);
285 d->ReadXML(ch);
286 }
287 return d;
288}
289
290////////////////////////////////////////////////////////////////////////////////
291
293{
294 TString *vars;
295
296 TMVA::DecisionTree* d = ReadTree( vars, itree );
297 if (d == 0) return;
298
299 UInt_t depth = d->GetTotalTreeDepth();
300 Double_t ystep = 1.0/(depth + 1.0);
301
302 std::cout << "--- Tree depth: " << depth << std::endl;
303
304 TStyle* TMVAStyle = gROOT->GetStyle("Plain"); // our style is based on Plain
305 Int_t canvasColor = TMVAStyle->GetCanvasColor(); // backup
306
307 TString cbuffer = TString::Format( "Reading weight file: %s", fWfile.Data() );
308 TString tbuffer = TString::Format( "Regression Tree no.: %d", itree );
309 if (!fCanvas) fCanvas = new TCanvas( "c1", cbuffer, 200, 0, 1000, 600 );
310 else fCanvas->Clear();
311 fCanvas->Draw();
312 DrawNode( (TMVA::DecisionTreeNode*)d->GetRoot(), 0.5, 1.-0.5*ystep, 0.25, ystep ,vars);
313
314 // make the legend
315 Double_t yup=0.99;
316 Double_t ydown=yup-ystep/2.5;
317 Double_t dy= ystep/2.5 * 0.2;
318
319 TPaveText *whichTree = new TPaveText(0.85,ydown,0.98,yup, "NDC");
320 whichTree->SetBorderSize(1);
321 whichTree->SetFillStyle(1001);
322 whichTree->SetFillColor( TColor::GetColor( "#ffff33" ) );
323 whichTree->AddText( tbuffer );
324 whichTree->Draw();
325
326 TPaveText *intermediate = new TPaveText(0.02,ydown,0.15,yup, "NDC");
327 intermediate->SetBorderSize(1);
328 intermediate->SetFillStyle(1001);
329 intermediate->SetFillColor( getIntColorF() );
330 intermediate->AddText("Intermediate Nodes");
331 intermediate->SetTextColor( getIntColorT() );
332 intermediate->Draw();
333
334 ydown = ydown - ystep/2.5 -dy;
335 yup = yup - ystep/2.5 -dy;
336 TPaveText *signalleaf = new TPaveText(0.02,ydown ,0.15,yup, "NDC");
337 signalleaf->SetBorderSize(1);
338 signalleaf->SetFillStyle(1001);
339 signalleaf->SetFillColor( getSigColorF() );
340 signalleaf->AddText("Leaf Nodes");
341 signalleaf->SetTextColor( getSigColorT() );
342 signalleaf->Draw();
343 /*
344 ydown = ydown - ystep/2.5 -dy;
345 yup = yup - ystep/2.5 -dy;
346 TPaveText *backgroundleaf = new TPaveText(0.02,ydown,0.15,yup, "NDC");
347 backgroundleaf->SetBorderSize(1);
348 backgroundleaf->SetFillStyle(1001);
349 backgroundleaf->SetFillColor( kBkgColorF );
350
351 backgroundleaf->AddText("Backgr. Leaf Nodes");
352 backgroundleaf->SetTextColor( kBkgColorT );
353 backgroundleaf->Draw();
354 */
355 fCanvas->Update();
356 TString fname = fDataset + TString::Format("/plots/%s_%i", fMethName.Data(), itree);
357 std::cout << "--- Creating image: " << fname << std::endl;
358 TMVAGlob::imgconv( fCanvas, fname );
359
360 TMVAStyle->SetCanvasColor( canvasColor );
361}
362
363// ========================================================================================
364
365// intermediate GUI
// BDT_Reg(dataset, fin): opens (or reuses) the TMVA output file `fin` and
// walks every sub-directory of <dataset>/Method_BDT. For each one it reads
// the "TrainingPath" and "WeightFileName" strings. It then shows a vertical
// TControlBar with one button per weight file; each button runs
// TMVA::BDT_Reg("<dataset>",0,"<weight file>","<method name>"). The control
// bar is appended to BDTReg_Global__cbar.
366void TMVA::BDT_Reg(TString dataset, const TString& fin )
367{
368 // --- read the available BDT weight files
369
370 // destroy all open canvases
// NOTE(review): original line 371 (the call that does this) is missing from this listing
372
373 // checks if file with name "fin" is already open, and if not opens one
374 TFile* file = TMVAGlob::OpenFile( fin );
375
// NOTE(review): neither `file` nor the result of GetDirectory(dataset) is null-checked.
// A missing dataset directory crashes here instead of reaching the error message below.
376 TDirectory* dir = file->GetDirectory(dataset.Data())->GetDirectory( "Method_BDT" );
377 if (!dir) {
378 std::cout << "*** Error in macro \"BDT_Reg.C\": cannot find directory \"Method_BDT\" in file: " << fin << std::endl;
379 return;
380 }
381
382 // read all directories
383 TIter next( dir->GetListOfKeys() );
384 TKey *key(0);
385 std::vector<TString> methname;
386 std::vector<TString> path;
387 std::vector<TString> wfile;
388 while ((key = (TKey*)next())) {
389 TDirectory* mdir = dir->GetDirectory( key->GetName() );
390 if (!mdir) {
391 std::cout << "*** Error in macro \"BDT_Reg.C\": cannot find sub-directory: " << key->GetName()
392 << " in directory: " << dir->GetName() << std::endl;
393 return;
394 }
395
396 // retrieve weight file name and path
397 TObjString* strPath = (TObjString*)mdir->Get( "TrainingPath" );
398 TObjString* strWFile = (TObjString*)mdir->Get( "WeightFileName" );
399 if (!strPath || !strWFile) {
400 std::cout << "*** Error in macro \"BDT_Reg.C\": could not find TObjStrings \"TrainingPath\" and/or \"WeightFileName\" *** " << std::endl;
401 std::cout << "*** Maybe you are using TMVA >= 3.8.15 with an older training target file ? *** " << std::endl;
402 return;
403 }
404
405 methname.push_back( key->GetName() );
406 path .push_back( strPath->GetString() );
407 wfile .push_back( strWFile->GetString() );
408 }
409
410 // create the control bar
411 TControlBar* cbar = new TControlBar( "vertical", "Choose weight file:", 50, 50 );
412 BDTReg_Global__cbar.push_back(cbar);
413
414 for (UInt_t im=0; im<path.size(); im++) {
415 TString fname = path[im];
// NOTE(review): an empty TrainingPath would make this index position -1
416 if (fname[fname.Length()-1] != '/') fname += "/";
417 fname += wfile[im];
// button command: the dataset/itree/wfile/methName overload of BDT_Reg, starting at tree 0
418 TString macro = TString::Format( "TMVA::BDT_Reg(\"%s\",0,\"%s\",\"%s\")",dataset.Data(), fname.Data(), methname[im].Data() );
419 cbar->AddButton( fname, macro, "Plot decision trees from this weight file", "button" );
420 }
421
422 // set the style
423 cbar->SetTextColor("blue");
424
425 // draw
426 cbar->Show();
427}
428
// BDTReg_DeleteTBar(i): deletes control bar number i from
// BDTReg_Global__cbar and resets that slot to 0.
// NOTE(review): original lines 432-433 (which, per the comment below, destroy
// the open canvases) are missing from this listing.
// NOTE(review): `i` is not range-checked; an index outside the vector is
// undefined behaviour.
429void TMVA::BDTReg_DeleteTBar(int i)
430{
431 // destroy all open canvases
434
435 delete BDTReg_Global__cbar[i];
436 BDTReg_Global__cbar[i] = 0;
437}
438
439// input: - itree: number of the tree to draw
440// - wfile: weight file from which the tree is read (an empty name selects <dataset>/weights/TMVARegression_BDT.weights.xml)
// Overload invoked by the control-bar buttons as
// BDT_Reg(dataset, itree, wfile, methName).
// NOTE(review): the signature line (original 441) is missing from this
// listing, as are lines 444-445 and 459.
// Steps:
//  - falls back to <dataset>/weights/TMVARegression_BDT.weights.xml when
//    wfile is empty;
//  - checks that a non-XML weight file can be opened;
//  - creates a StatDialogBDTReg, draws tree itree and raises the dialog.
442{
443 // destroy possibly existing dialog windows and/or canvases
446 if(wfile=="")
447 wfile = dataset+"/weights/TMVARegression_BDT.weights.xml";
448
449 // quick check that the weight file exists (done here for non-XML files only)
450 if(!wfile.EndsWith(".xml") ){
451 std::ifstream fin( wfile );
452 if (!fin.good( )) { // file not found --> Error
453 std::cout << "*** ERROR: Weight file: " << wfile << " does not exist" << std::endl;
454 return;
455 }
456 }
// NOTE(review): leftover debug output ("test1", no trailing newline); presumably should be removed
457 std::cout << "test1";
458 // set style and remove existing canvas'
460
// the dialog is heap-allocated and not deleted here
// (its Close() slot presumably deletes it — confirm)
461 StatDialogBDTReg* gGui = new StatDialogBDTReg(dataset, gClient->GetRoot(), wfile, methName, itree );
462
463 gGui->DrawTree( itree );
464
465 gGui->RaiseDialog();
466}
467
@ kVerticalFrame
Definition GuiTypes.h:381
@ kMainFrame
Definition GuiTypes.h:380
#define d(i)
Definition RSha256.hxx:102
bool Bool_t
Definition RtypesCore.h:63
int Int_t
Definition RtypesCore.h:45
unsigned int UInt_t
Definition RtypesCore.h:46
float Float_t
Definition RtypesCore.h:57
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
#define gClient
Definition TGClient.h:157
@ kLHintsRight
Definition TGLayout.h:26
@ kLHintsLeft
Definition TGLayout.h:24
@ kLHintsBottom
Definition TGLayout.h:29
@ kLHintsTop
Definition TGLayout.h:27
winID h TVirtualViewer3D TVirtualGLPainter p
#define gROOT
Definition TROOT.h:406
#define snprintf
Definition civetweb.c:1540
virtual void SetFillColor(Color_t fcolor)
Set the fill area color.
Definition TAttFill.h:38
virtual void SetFillStyle(Style_t fstyle)
Set the fill area style.
Definition TAttFill.h:40
virtual void SetTextColor(Color_t tcolor=1)
Set the text color.
Definition TAttText.h:46
The Canvas class.
Definition TCanvas.h:23
static Int_t GetColor(const char *hexcolor)
Static method returning color number for color specified by hex color string of form: "#rrggbb",...
Definition TColor.cxx:1924
A Control Bar is a fully user configurable tool which provides fast access to frequently used operati...
Definition TControlBar.h:26
TDirectory * GetDirectory(const char *apath, Bool_t printError=false, const char *funcname="GetDirectory") override
Find a directory named "apath".
Describe directory structure in memory.
Definition TDirectory.h:45
virtual TDirectory * GetDirectory(const char *namecycle, Bool_t printError=false, const char *funcname="GetDirectory")
Find a directory using apath.
virtual TList * GetListOfKeys() const
Definition TDirectory.h:223
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:53
TGDimension GetDefaultSize() const override
std::cout << fWidth << "x" << fHeight << std::endl;
Definition TGFrame.h:316
virtual void AddFrame(TGFrame *f, TGLayoutHints *l=nullptr)
Add frame to the composite frame using the specified layout hints.
Definition TGFrame.cxx:1117
void MapSubwindows() override
Map all sub windows that are part of the composite frame.
Definition TGFrame.cxx:1164
void Resize(UInt_t w=0, UInt_t h=0) override
Resize the frame.
Definition TGFrame.cxx:605
void MapWindow() override
map window
Definition TGFrame.h:204
A composite frame that layout their children in horizontal way.
Definition TGFrame.h:385
This class handles GUI labels.
Definition TGLabel.h:24
This class describes layout hints used by the layout classes.
Definition TGLayout.h:50
Defines top level windows that interact with the system Window Manager.
Definition TGFrame.h:397
void SetWMPosition(Int_t x, Int_t y)
Give the window manager a window position hint.
Definition TGFrame.cxx:1881
void SetWindowName(const char *name=nullptr) override
Set window name. This is typically done via the window manager.
Definition TGFrame.cxx:1788
TGNumberEntry is a number entry input widget with up/down buttons.
virtual void SetLimits(ELimit limits=TGNumberFormat::kNELNoLimits, Double_t min=0, Double_t max=1)
virtual Double_t GetNumber() const
@ kNELLimitMinMax
Both lower and upper limits.
Yield an action as soon as it is clicked.
Definition TGButton.h:142
ROOT GUI Window base class.
Definition TGWindow.h:23
Book space in a file, create I/O buffers, to fill them, (un)compress them.
Definition TKey.h:28
Use the TLine constructor to create a simple line.
Definition TLine.h:22
Implementation of a Decision Tree.
TMVA::DecisionTree * ReadTree(TString *&vars, Int_t itree)
Definition BDT_Reg.cxx:205
TGHorizontalFrame * fButtons
Definition BDT_Reg.h:68
static StatDialogBDTReg * fThis
Definition BDT_Reg.h:95
StatDialogBDTReg(TString dataset, const TGWindow *p, TString wfile, TString methName="BDT", Int_t itree=0)
Definition BDT_Reg.cxx:44
void DrawTree(Int_t itree)
Definition BDT_Reg.cxx:292
TGNumberEntry * fInput
Definition BDT_Reg.h:66
TGMainFrame * fMain
Definition BDT_Reg.h:60
TGTextButton * fDrawButton
Definition BDT_Reg.h:69
static void Delete()
Definition BDT_Reg.h:86
TGTextButton * fCloseButton
Definition BDT_Reg.h:70
void DrawNode(TMVA::DecisionTreeNode *n, Double_t x, Double_t y, Double_t xscale, Double_t yscale, TString *vars)
Recursively draws the node and its daughters on the current canvas.
Definition BDT_Reg.cxx:156
TXMLEngine & xmlengine()
Definition Tools.h:262
void ReadAttr(void *node, const char *, T &value)
read attribute from xml
Definition Tools.h:329
std::vector< TControlBar * > BDTReg_Global__cbar
Definition BDT_Reg.h:101
const char * GetName() const override
Returns name of object.
Definition TNamed.h:47
Collectable string class.
Definition TObjString.h:28
A Pave (see TPave) with text, lines or/and boxes inside.
Definition TPaveText.h:21
virtual TText * AddText(Double_t x1, Double_t y1, const char *label)
Add a new Text line to this pavetext at given coordinates.
void Draw(Option_t *option="") override
Draw this pavetext with its current attributes.
virtual void SetBorderSize(Int_t bordersize=4)
Sets the border size of the TPave box and shadow.
Definition TPave.h:77
Bool_t Connect(const char *signal, const char *receiver_class, void *receiver, const char *slot)
Non-static method is used to connect from the signal of this object to the receiver slot.
Definition TQObject.cxx:869
Basic string class.
Definition TString.h:139
Int_t Atoi() const
Return integer value of string.
Definition TString.cxx:1988
TString & ReplaceAll(const TString &s1, const TString &s2)
Definition TString.h:704
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2378
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition TString.h:632
TStyle objects may be created to define special styles.
Definition TStyle.h:29
XMLNodePointer_t GetChild(XMLNodePointer_t xmlnode, Bool_t realnode=kTRUE)
returns first child of xmlnode
XMLNodePointer_t DocGetRootElement(XMLDocPointer_t xmldoc)
returns root node of document
XMLDocPointer_t ParseFile(const char *filename, Int_t maxbuf=100000)
Parses content of file and tries to produce xml structures.
XMLNodePointer_t GetNext(XMLNodePointer_t xmlnode, Bool_t realnode=kTRUE)
return next to xmlnode node if realnode==kTRUE, any special nodes in between will be skipped
TLine * line
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
void Initialize(Bool_t useTMVAStyle=kTRUE)
Definition tmvaglob.cxx:176
void DestroyCanvases()
Definition tmvaglob.cxx:166
TFile * OpenFile(const TString &fin)
Definition tmvaglob.cxx:192
void imgconv(TCanvas *c, const TString &fname)
Definition tmvaglob.cxx:212
Int_t getSigColorT()
Definition BDT.h:40
void BDTReg_DeleteTBar(int i)
Int_t getIntColorT()
Definition BDT.h:42
Int_t getIntColorF()
Definition BDT.h:37
Tools & gTools()
Int_t getSigColorF()
Definition BDT.h:35
Int_t getBkgColorT()
Definition BDT.h:41
Int_t getBkgColorF()
Definition BDT.h:36
void BDT_Reg(TString dataset, const TString &fin="TMVAReg.root")