Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
network.cxx
Go to the documentation of this file.
1#include "TMVA/network.h"
2#include <iostream>
3
4#include "TArrow.h"
5#include "TEllipse.h"
6#include "TPaveLabel.h"
7#include "TCanvas.h"
8#include "TH2F.h"
9#include "TFile.h"
10#include "TString.h"
11#include "TDirectory.h"
12#include "TKey.h"
13#include "TText.h"
14
15using std::cout;
16using std::endl;
17
18
19// This macro draws a graphical layout of a neural network generated by MethodMLP
20// @author: Matt Jachowski, jachowski@stanford.edu
21
23
25
26
28
// Draw the layout of an MLP neural network on a canvas `c` (created on a line that is only
// partly visible in this listing).
// Every TH2F key in directory `d` whose histogram name contains `hName` is taken as one
// weight matrix between two consecutive layers, so numHists matching histograms describe
// numHists+1 layers. A first pass over the keys counts the matching histograms and finds
// the largest absolute bin content (maxWeight); a second pass draws each matrix with
// draw_layer(), which divides every weight by maxWeight.
// Parameters (declared defaults: hName="weights_hist", movieMode=kFALSE, epoch=""):
//   dataset   - dataset name; forwarded to draw_layer() and used to build the plot path
//   f         - input file; not referenced in the visible lines (NOTE(review): presumably
//               stored in the global Network_GFile on a line missing here -- confirm)
//   d         - directory holding the weight histograms; its name goes into canvastitle
//   hName     - substring that selects the weight histograms
//   movieMode - NOTE(review): the visible code tests the global MovieMode, not this
//               parameter; presumably it is copied into MovieMode on a missing line
//   epoch     - text drawn as "Epoch: <epoch>" when MovieMode is set
// The TMVA style's canvas colour and title fill colour/text colour/border size are
// overridden while drawing and restored before returning.
29void TMVA::draw_network(TString dataset, TFile* f, TDirectory* d, const TString& hName,
30 Bool_t movieMode , const TString& epoch )
31{
34
37
38 // create canvas
39 TStyle* TMVAStyle = gROOT->GetStyle("TMVA"); // the TMVA style
40 Int_t canvasColor = TMVAStyle->GetCanvasColor(); // backup
41 TMVAStyle->SetCanvasColor( c_DarkBackground );
42
43 Int_t titleFillColor = TMVAStyle->GetTitleFillColor();
44 Int_t titleTextColor = TMVAStyle->GetTitleTextColor();
45 Int_t borderSize = TMVAStyle->GetTitleBorderSize();
46
47 TMVAStyle->SetTitleFillColor( c_DarkBackground );
48 TMVAStyle->SetTitleTextColor( TColor::GetColor( "#FFFFFF" ) );
49 TMVAStyle->SetTitleBorderSize( 0 );
50
// icanvas persists across calls, so each new canvas is placed at a shifted position.
51 static Int_t icanvas = -1;
52 Int_t ixc = 100 + (icanvas)*40;
53 Int_t iyc = 0 + (icanvas+1)*20;
54 if (MovieMode) ixc = iyc = 0;
// NOTE(review): the visible tail of the canvas constructor below passes `0 + (icanvas+1)*20`
// instead of iyc, so resetting iyc in movie mode appears to have no effect -- confirm
// against the full line.
56 TString canvastitle = TString::Format("Neural Network Layout for: %s", d->GetName());
58 ixc, 0 + (icanvas+1)*20, 1000, 650 );
59 icanvas++;
60 TIter next = d->GetListOfKeys();
61 TKey *key( 0 );
62 Int_t numHists = 0;
63
64 // first pass: loop over all TH2F histograms with hName in their name
65 next.Reset();
// NOTE(review): maxWeight is declared/initialised on a line missing from this listing.
67 // find max weight
68 while ((key = (TKey*)next())) {
69
70 TClass *cl = gROOT->GetClass(key->GetClassName());
// NOTE(review): TClass::GetClass may return nullptr for an unknown class name; cl is
// dereferenced without a check here and again in the second pass.
71 if (!cl->InheritsFrom("TH2F"))
72 {
73 continue;
74 }else
75 {
// NOTE(review): leftover debug output, printed for every TH2F key.
76 std::cout<<key->GetClassName()<<"----"<<cl->InheritsFrom("TH2F")<<"----"<<hName<<std::endl;
77 }
78
79 TH2F* h = (TH2F*)key->ReadObj();
80 if (!h) {
// NOTE(review): exit() terminates the whole process (including an interactive ROOT
// session); the same applies to the second pass below.
81 cout << "Big troubles in \"draw_network\" (1)" << endl;
82 exit(1);
83 }
84 std::cout<<h->GetName()<<"----"<<hName<<std::endl;
85 if (TString(h->GetName()).Contains( hName )){
86 numHists++;
87
88 Int_t n1 = h->GetNbinsX();
89 Int_t n2 = h->GetNbinsY();
90 for (Int_t i = 0; i < n1; i++) {
91 for (Int_t j = 0; j < n2; j++) {
92 Double_t weight = TMath::Abs(h->GetBinContent(i+1, j+1));
93 if (maxWeight < weight) maxWeight = weight;
94 }
95 }
96 }
97 }
98 if (numHists == 0) {
// Only reported: the second pass uses the same name filter, so with no matching
// histogram it never calls draw_layer().
99 cout << "Error: could not find histograms" << endl;
100 //exit(1);
101 }
102
103 // draw network
104 next.Reset();
105 //cout << "check4a" << endl;
106
107 Int_t count = 0;
108 while ((key = (TKey*)next())) {
109 //cout << "check4b" << endl;
110
111 TClass *cl = gROOT->GetClass(key->GetClassName());
112 if (!cl->InheritsFrom("TH2F")) continue;
113 //cout << "check4c" << endl;
114
115 TH2F* h = (TH2F*)key->ReadObj();
116 //cout << (h->GetName()) << endl;
117 if (!h) {
118 cout << "Big troubles in \"draw_network\" (2)" << endl;
119 exit(1);
120 }
121 //cout << (h->GetName()) << endl;
122 if (TString(h->GetName()).Contains( hName )) {
// count is the 0-based index of this weight matrix; numHists matrices connect
// numHists+1 layers.
123 //cout << (h->GetName()) << endl;
124 draw_layer(dataset,c, h, count++, numHists+1, maxWeight);
125 }
126 //cout << "check4d" << endl;
127
128 }
130
131 // add epoch
132 if (MovieMode) {
133 TText t;
134 t.SetTextSize( 0.04 );
135 t.SetTextColor( 0 );
136 t.SetTextAlign( 31 );
137 t.DrawTextNDC( 1 - c->GetRightMargin(), 1 - c->GetTopMargin() - 0.033,
138 TString::Format( "Epoch: %s", epoch.Data() ) );
139 }
140
141 // ============================================================
143 // ============================================================
144
145 c->Update();
146 if (MovieMode) {
147 // save to file
// The frame goes to movieplots/<hName minus its last 5 characters, with any
// "epochmonitoring___" removed>.gif; the canvas is then cleared and deleted.
148 TString dirname = "movieplots";
149 TString foutname = dirname + "/" + hName;
150 foutname.Resize( foutname.Length()-5 );
151 foutname.ReplaceAll("epochmonitoring___","");
152 foutname += ".gif";
153
154 cout << "storing file: " << foutname << endl;
155 c->Print(foutname);
156 c->Clear();
157 delete c;
158 }
159 else {
// NOTE(review): the statement that uses fname is missing from this listing.
160 TString fname = dataset+"/plots/network";
162 }
163
164 // reset global style changes so that it does not affect other plots
165 TMVAStyle->SetCanvasColor ( canvasColor );
166 TMVAStyle->SetTitleFillColor ( titleFillColor );
167 TMVAStyle->SetTitleTextColor ( titleTextColor );
168 TMVAStyle->SetTitleBorderSize( borderSize );
169
170}
171
173{
// Draw one caption per layer: "Layer <i>" for i = 0..nLayers-2 and "Output layer" for the
// last one. Each caption is a borderless, solid-filled TPaveLabel centred at the same x
// position formula that draw_layer() uses for a layer's neurons, and it uses the current
// gStyle title fill and text colours.
// NOTE(review): effWidth, margY and y2 are defined on lines missing from this listing.
174 const Double_t LABEL_HEIGHT = 0.032;
175 const Double_t LABEL_WIDTH = 0.20;
179
180 for (Int_t i = 0; i < nLayers; i++) {
181 TString label = TString::Format("Layer %i", i);
182 if (i == nLayers-1) label = "Output layer";
183 Double_t cx = i*(1.0-LABEL_WIDTH)/nLayers+1.0/(2.0*nLayers)+LABEL_WIDTH;
184 Double_t x1 = cx-0.8*effWidth/2.0;
185 Double_t x2 = cx+0.8*effWidth/2.0;
186 Double_t y1 = margY;
188
189 TPaveLabel *p = new TPaveLabel(x1, y1, x2, y2, label, "br");
190 p->SetFillColor(gStyle->GetTitleFillColor());
191 p->SetTextColor(gStyle->GetTitleTextColor());
192 p->SetFillStyle(1001);
193 p->SetBorderSize( 0 );
194 p->Draw();
195 }
196}
197
200{
// Draw, for each of the nInputs first-layer neurons, a right-aligned "<name> :" label at
// x = margX + width and y = cy[i] - effHeight + 0.018. Entries 0..nInputs-2 use
// varNames[i]; the last entry is labelled "Bias node" and drawn in colour #AFDCEC.
// rad and layerWidth are not referenced in the visible lines.
// NOTE(review): varNames, width, effHeight and input are defined on lines missing from this
// listing; varNames is presumably get_var_names(dataset, nInputs), whose new[]-allocated
// array matches the delete[] at the end -- confirm. A null varNames calls exit(1), which
// terminates the whole process.
201 const Double_t LABEL_HEIGHT = 0.04;
202 const Double_t LABEL_WIDTH = 0.20;
204 Double_t margX = 0.01;
206
208 if (varNames == 0) exit(1);
209
211
212 for (Int_t i = 0; i < nInputs; i++) {
213 if (i != nInputs-1) input = varNames[i];
214 else input = "Bias node";
215 Double_t x = margX + width;
216 Double_t y = cy[i] - effHeight;
217
218 TText t;
220 t.SetTextAlign(31);
222 if (i == nInputs-1) t.SetTextColor( TColor::GetColor( "#AFDCEC" ) );
223 t.DrawText( x, y+0.018, input + " :");
224 }
225
226 delete[] varNames;
227}
228
230{
// Return the input-variable names for `dataset`, read from the global Network_GFile.
// The first of the six InputVariables_* directories that exists under the dataset directory
// is scanned. For every cycle-1 TH1 key whose name contains "__S", "__r" or "Regression"
// (but not "target"), the histogram title is stored in key order, stopping once nVars
// names have been collected.
// Returns a new[]-allocated array of nVars TStrings that the caller must delete[], or 0 if
// no input-variable directory is found. A warning is printed unless exactly nVars-1 names
// were found (the last slot is expected to correspond to the bias node).
// Side effect: the found directory becomes the current directory (dir->cd()).
231 const TString directories[6] = { "InputVariables_NoTransform",
232 "InputVariables_DecorrTransform",
233 "InputVariables_PCATransform",
234 "InputVariables_Id",
235 "InputVariables_Norm",
236 "InputVariables_Deco"};
237
238 TDirectory* dir = 0;
239 for (Int_t i=0; i<6; i++) {
// NOTE(review): GetDirectory(dataset) is not checked for nullptr before ->Get().
240 dir = (TDirectory*)Network_GFile->GetDirectory(dataset.Data())->Get( directories[i] );
241 if (dir != 0) break;
242 }
243 if (dir==0) {
244 cout << "*** Big troubles in macro \"network.cxx\": could not find directory for input variables, "
245 << "and hence could not determine variable names --> abort" << endl;
246 return 0;
247 }
248 dir->cd();
249
250 TString* vars = new TString[nVars];
251 Int_t ivar = 0;
252
253 // loop over all objects in directory
254 TIter next(dir->GetListOfKeys());
255 TKey* key = 0;
256 while ((key = (TKey*)next())) {
257 if (key->GetCycle() != 1) continue;
258
259 if (!TString(key->GetName()).Contains("__S") &&
260 !TString(key->GetName()).Contains("__r") &&
261 !TString(key->GetName()).Contains("Regression"))
262 continue;
263 if (TString(key->GetName()).Contains("target"))
264 continue;
265
266 // make sure, that we only look at histograms
267 TClass *cl = gROOT->GetClass(key->GetClassName());
268 if (!cl->InheritsFrom("TH1")) continue;
269 TH1 *sig = (TH1*)key->ReadObj();
270 TString hname = sig->GetTitle();
271
272 vars[ivar] = hname; ivar++;
273
// Stop once every slot is filled (ivar == nVars), which keeps vars[ivar] in bounds.
274 if (ivar > nVars-1) break;
275 }
276
// NOTE(review): the test compares against nVars-1, but the message prints nVars as the
// expected count.
277 if (ivar != nVars-1) { // bias layer and targets are also in nVars counts
278 cout << "*** Troubles in \"network.cxx\": did not reproduce correct number of "
279 << "input variables: " << ivar << " != " << nVars << endl;
280 }
281
282 return vars;
283}
284
287{
// Overlay an activation-function icon centred at (cx, cy). Index 0 loads
// "sigmoid-small.png" and index 1 loads "line-small.png", both via TMVAGlob::findImage.
// Any other index only prints a message. If no image is available, the function returns
// without drawing. Otherwise the icon is drawn, without a fixed aspect ratio, in a new TPad
// whose half-width/half-height are 70% of radx/rady, and canvas c is then made the current
// pad again.
// NOTE(review): `activation` is declared on a line missing from this listing; the default
// branch relies on it having been initialised to NULL -- confirm.
289
290 switch (whichActivation) {
291 case 0:
292 activation = TMVA::TMVAGlob::findImage("sigmoid-small.png");
293 break;
294 case 1:
295 activation = TMVA::TMVAGlob::findImage("line-small.png");
296 break;
297 default:
298 cout << "Activation index " << whichActivation << " is not known." << endl;
299 cout << "You messed up or you need to modify network.cxx to introduce a new "
300 << "activation function (and image) corresponding to this index" << endl;
301 }
302
303 if (activation == NULL) {
304 cout << "Could not create an image... exit" << endl;
305 return;
306 }
307
308 activation->SetConstRatio(kFALSE);
309
310 radx *= 0.7;
311 rady *= 0.7;
// The pad name encodes the centre coordinates, so pads at different positions get
// different names.
312 TString name = TString::Format("activation%f%f", cx, cy);
313 TPad* p = new TPad(name, name, cx-radx, cy-rady, cx+radx, cy+rady);
314
315 p->Draw();
316 p->cd();
317
318 activation->Draw();
319 c->cd();
320}
321
324{
// Draw the part of the network described by weight histogram h: source layer iHist, with
// nNeurons1 = h's X bins, and target layer iHist+1, with nNeurons2 = h's Y bins.
// Layer x positions split the width to the right of LABEL_WIDTH into nLayers columns.
// Neurons are spread vertically above LABEL_HEIGHT, and neuron i is stored at index n-i-1,
// so index 0 is the topmost neuron.
// Source-layer ellipses, activation icons and input labels are drawn only when iHist == 0.
// The target layer is always drawn, so across consecutive calls each layer is drawn once.
// Activation icons are skipped for layers with more than MAX_NEURONS_NICE neurons.
// Finally, every synapse i->j is drawn with weight h(i+1, j+1) / maxWeight.
// NOTE(review): several lines (e.g. the effRad conditions, the ellipse construction and the
// whichActivation declarations) are missing from this listing.
325 const Double_t MAX_NEURONS_NICE = 12;
326 const Double_t LABEL_HEIGHT = 0.03;
327 const Double_t LABEL_WIDTH = 0.20;
328 Double_t ratio = ((Double_t)(c->GetWindowHeight())) / c->GetWindowWidth();
329 Double_t rad, cx1, *cy1, cx2, *cy2;
330
331 // this is the smallest radius that will still display the activation images
332 rad = 0.04*650/c->GetWindowHeight();
333
334 Int_t nNeurons1 = h->GetNbinsX();
335 cx1 = iHist*(1.0-LABEL_WIDTH)/nLayers + 1.0/(2.0*nLayers) + LABEL_WIDTH;
336 cy1 = new Double_t[nNeurons1];
337
338 Int_t nNeurons2 = h->GetNbinsY();
339 cx2 = (iHist+1)*(1.0-LABEL_WIDTH)/nLayers + 1.0/(2.0*nLayers) + LABEL_WIDTH;
340 cy2 = new Double_t[nNeurons2];
341
342 Double_t effRad1 = rad;
344 effRad1 = 0.8*(1.0-LABEL_HEIGHT)/(2.0*nNeurons1);
345
346 for (Int_t i = 0; i < nNeurons1; i++) {
347 cy1[nNeurons1-i-1] = i*(1.0-LABEL_HEIGHT)/nNeurons1 + 1.0/(2.0*nNeurons1) + LABEL_HEIGHT;
348
349 if (iHist == 0) {
350
352 effRad1*ratio, effRad1, 0, 360, 0 );
353 ellipse->SetFillColor(TColor::GetColor( "#fffffd" ));
354 ellipse->SetFillStyle(1001);
355 ellipse->Draw();
356
// i == 0 is the lowest neuron (cy1[nNeurons1-1]); draw_input_labels labels that entry
// "Bias node".
357 if (i == 0) ellipse->SetLineColor(9);
358
359 if (nNeurons1 > MAX_NEURONS_NICE) continue;
360
// Inside this iHist == 0 branch the condition below is always true, so every
// first-layer neuron gets whichActivation = 1.
362 if (iHist==0 || iHist==nLayers-1 || i==0) whichActivation = 1;
364 rad*ratio, rad, whichActivation);
365 }
366 }
367
368 if (iHist == 0) draw_input_labels(dataset,nNeurons1, cy1, rad, (1.0-LABEL_WIDTH)/nLayers);
369
370 Double_t effRad2 = rad;
372 effRad2 = 0.8*(1.0-LABEL_HEIGHT)/(2.0*nNeurons2);
373
374 for (Int_t i = 0; i < nNeurons2; i++) {
375 cy2[nNeurons2-i-1] = i*(1.0-LABEL_HEIGHT)/nNeurons2 + 1.0/(2.0*nNeurons2) + LABEL_HEIGHT;
376
378 new TEllipse(cx2, cy2[nNeurons2-i-1], effRad2*ratio, effRad2, 0, 360, 0);
379 ellipse->SetFillColor(TColor::GetColor( "#fffffd" ));
380 ellipse->SetFillStyle(1001);
381 ellipse->Draw();
382
383 if (i == 0 && nNeurons2 > 1) ellipse->SetLineColor(9);
384
385 if (nNeurons2 > MAX_NEURONS_NICE) continue;
386
// NOTE(review): iHist+1==0 can never hold for iHist >= 0.
388 if (iHist+1==0 || iHist+1==nLayers-1 || i==0) whichActivation = 1;
389 draw_activation(c, cx2, cy2[nNeurons2-i-1], rad*ratio, rad, whichActivation);
390 }
391
// NOTE(review): if maxWeight is 0 (all weights zero), these ratios are 0/0 = NaN, which
// draw_synapse's `weightNormed == 0` test does not catch.
392 for (Int_t i = 0; i < nNeurons1; i++) {
393 for (Int_t j = 0; j < nNeurons2; j++) {
394 draw_synapse(cx1, cy1[i], cx2, cy2[j], effRad1*ratio, effRad2*ratio,
395 h->GetBinContent(i+1, j+1)/maxWeight);
396 }
397 }
398
399 delete [] cy1;
400 delete [] cy2;
401}
402
405{
// Draw one synapse as an arrow (tip size TIP_SIZE) from (cx1+rad1, cy1) to (cx2-rad2, cy2).
// When rad1/rad2 are the neurons' x radii, this runs from the right edge of the source
// neuron to the left edge of the target neuron.
// weightNormed is expected in [-1, 1]; draw_layer passes weight / max |weight|.
//   line width  = |weightNormed| * MAX_WEIGHT, rounded to an integer
//   line colour = colour index interpolated linearly from MIN_COLOR (at -1) to
//                 MAX_COLOR (at +1), rounded to an integer
// Synapses with weightNormed exactly 0 are not drawn.
406 const Double_t TIP_SIZE = 0.01;
407 const Double_t MAX_WEIGHT = 8;
408 const Double_t MAX_COLOR = 100; // red
409 const Double_t MIN_COLOR = 60; // blue
410
411 if (weightNormed == 0) return;
412
413 // gStyle->SetPalette(100, NULL);
414
415 TArrow *arrow = new TArrow(cx1+rad1, cy1, cx2-rad2, cy2, TIP_SIZE, ">");
416 arrow->SetFillColor(1);
417 arrow->SetFillStyle(1001);
418 arrow->SetLineWidth((Int_t)(TMath::Abs(weightNormed)*MAX_WEIGHT+0.5));
419 arrow->SetLineColor((Int_t)((weightNormed+1.0)/2.0*(MAX_COLOR-MIN_COLOR)+MIN_COLOR+0.5));
420 arrow->Draw();
421}
422
423// input: - fin: input file (result from TMVA; declared default "TMVA.root"),
424// - useTMVAStyle: use of TMVA plotting TStyle (declared default kTRUE)
// - dataset: name of the dataset directory inside fin that holds the Method_MLP* directories
// Entry point: for every Method_MLP* directory under `dataset`, and for every sub-directory
// inside it, make that sub-directory current and call draw_network() on it.
426{
427 // set style and remove existing canvases
429
430 // checks if file with name "fin" is already open, and if not opens one
431 TFile* file = TMVAGlob::OpenFile( fin );
// NOTE(review): GetDirectory(dataset) is not checked for nullptr.
432 TIter next(file->GetDirectory(dataset.Data())->GetListOfKeys());
433 TKey *key(0);
434 while( (key = (TKey*)next()) ) {
435 if (!TString(key->GetName()).BeginsWith("Method_MLP")) continue;
436 if( ! gROOT->GetClass(key->GetClassName())->InheritsFrom("TDirectory") ) continue;
437
// NOTE(review): ReadObj() is called twice for the same key (here and for mDir below).
438 cout << "--- Found directory: " << ((TDirectory*)key->ReadObj())->GetName() << endl;
439
440 TDirectory* mDir = (TDirectory*)key->ReadObj();
441
442 TIter keyIt(mDir->GetListOfKeys());
443 TKey *titkey;
444 while((titkey = (TKey*)keyIt())) {
445 if( ! gROOT->GetClass(titkey->GetClassName())->InheritsFrom("TDirectory") ) continue;
446
447 TDirectory* dir = (TDirectory *)titkey->ReadObj();
448 dir->cd();
// NOTE(review): ni is computed on lines missing from this listing; returning here also
// skips every remaining Method_MLP directory.
451 if (ni==0) {
452 cout << "No titles found for Method_MLP" << endl;
453 return;
454 }
455 draw_network(dataset, file, dir );
456 }
457 }
458
459 return;
460}
461
#define d(i)
Definition RSha256.hxx:102
#define f(i)
Definition RSha256.hxx:104
#define c(i)
Definition RSha256.hxx:101
#define h(i)
Definition RSha256.hxx:106
constexpr Bool_t kFALSE
Definition RtypesCore.h:94
double Double_t
Definition RtypesCore.h:59
constexpr Bool_t kTRUE
Definition RtypesCore.h:93
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
#define MAX_COLOR
Definition TGHtml.cxx:1670
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void input
Option_t Option_t TPoint TPoint const char x2
Option_t Option_t TPoint TPoint const char x1
Option_t Option_t TPoint TPoint const char y2
Option_t Option_t width
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t height
Option_t Option_t TPoint TPoint const char y1
char name[80]
Definition TGX11.cxx:110
#define gROOT
Definition TROOT.h:406
R__EXTERN TStyle * gStyle
Definition TStyle.h:436
Draw all kinds of Arrows.
Definition TArrow.h:29
virtual void SetTextAlign(Short_t align=11)
Set the text alignment.
Definition TAttText.h:44
virtual void SetTextColor(Color_t tcolor=1)
Set the text color.
Definition TAttText.h:46
virtual void SetTextSize(Float_t tsize=1)
Set the text size.
Definition TAttText.h:49
The Canvas class.
Definition TCanvas.h:23
TClass instances represent classes, structs and namespaces in the ROOT type system.
Definition TClass.h:84
Bool_t InheritsFrom(const char *cl) const override
Return kTRUE if this class inherits from a class with name "classname".
Definition TClass.cxx:4971
static TClass * GetClass(const char *name, Bool_t load=kTRUE, Bool_t silent=kFALSE)
Static method returning pointer to TClass of the specified class name.
Definition TClass.cxx:3069
static Int_t GetColor(const char *hexcolor)
Static method returning color number for color specified by hex color string of form: "#rrggbb",...
Definition TColor.cxx:1924
TDirectory * GetDirectory(const char *apath, Bool_t printError=false, const char *funcname="GetDirectory") override
Find a directory named "apath".
Describe directory structure in memory.
Definition TDirectory.h:45
virtual Bool_t cd()
Change current directory to "this" directory.
virtual TList * GetListOfKeys() const
Definition TDirectory.h:223
Draw Ellipses.
Definition TEllipse.h:23
A ROOT file is an on-disk file, usually with extension .root, that stores objects in a file-system-li...
Definition TFile.h:53
TH1 is the base class of all histogram classes in ROOT.
Definition TH1.h:59
2-D histogram with a float per channel (see TH1 documentation)
Definition TH2.h:303
An abstract interface to image processing library.
Definition TImage.h:29
void Reset()
Book space in a file, create I/O buffers, to fill them, (un)compress them.
Definition TKey.h:28
virtual const char * GetClassName() const
Definition TKey.h:75
Short_t GetCycle() const
Return cycle number associated to this key.
Definition TKey.cxx:577
virtual TObject * ReadObj()
To read a TObject* from the file.
Definition TKey.cxx:758
A doubly linked list.
Definition TList.h:38
const char * GetName() const override
Returns name of object.
Definition TNamed.h:47
const char * GetTitle() const override
Returns title of object.
Definition TNamed.h:48
The most important graphics class in the ROOT system.
Definition TPad.h:28
A Pave (see TPave) with a text centered in the Pave.
Definition TPaveLabel.h:20
Basic string class.
Definition TString.h:139
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
Definition TString.h:623
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
Definition TString.cxx:2378
Bool_t Contains(const char *pat, ECaseCompare cmp=kExact) const
Definition TString.h:632
TStyle objects may be created to define special styles.
Definition TStyle.h:29
Color_t GetTitleFillColor() const
Definition TStyle.h:271
Color_t GetTitleTextColor() const
Definition TStyle.h:272
Base class for several text objects.
Definition TText.h:22
virtual TText * DrawText(Double_t x, Double_t y, const char *text)
Draw this text with new coordinates.
Definition TText.cxx:176
virtual TText * DrawTextNDC(Double_t x, Double_t y, const char *text)
Draw this text with new coordinates in NDC.
Definition TText.cxx:202
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
double ratio(double numerator, double denominator)
Definition MathFuncs.h:102
UInt_t GetListOfTitles(TDirectory *rfdir, TList &titles)
Definition tmvaglob.cxx:643
void Initialize(Bool_t useTMVAStyle=kTRUE)
Definition tmvaglob.cxx:176
void plot_logo(Float_t v_scale=1.0, Float_t skew=1.0)
Definition tmvaglob.cxx:270
TFile * OpenFile(const TString &fin)
Definition tmvaglob.cxx:192
TImage * findImage(const char *imageName)
Definition tmvaglob.cxx:252
void imgconv(TCanvas *c, const TString &fname)
Definition tmvaglob.cxx:212
TString * get_var_names(TString dataset, Int_t nVars)
void draw_layer(TString dataset, TCanvas *c, TH2F *h, Int_t iHist, Int_t nLayers, Double_t maxWeight)
void draw_activation(TCanvas *c, Double_t cx, Double_t cy, Double_t radx, Double_t rady, Int_t whichActivation)
void draw_network(TString dataset, TFile *f, TDirectory *d, const TString &hName="weights_hist", Bool_t movieMode=kFALSE, const TString &epoch="")
void draw_input_labels(TString dataset, Int_t nInputs, Double_t *cy, Double_t rad, Double_t layerWidth)
void draw_layer_labels(Int_t nLayers)
void draw_synapse(Double_t cx1, Double_t cy1, Double_t cx2, Double_t cy2, Double_t rad1, Double_t rad2, Double_t weightNormed)
void network(TString dataset, TString fin="TMVA.root", Bool_t useTMVAStyle=kTRUE)
Short_t Abs(Short_t d)
Returns the absolute value of parameter Short_t d.
Definition TMathBase.h:123
Bool_t MovieMode
Definition network.cxx:27
TFile * Network_GFile
Definition network.cxx:22
static Int_t c_DarkBackground
Definition network.cxx:24