Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodRSNNS.cxx
Go to the documentation of this file.
1// @(#)root/tmva/rmva $Id$
2// Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3
4
5/**********************************************************************************
6 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
7 * Package: TMVA *
8 * Class : MethodRSNNS *
9 * Web : http://oproject.org *
10 * *
11 * Description: *
12 * Neural Networks in R using the Stuttgart Neural Network Simulator *
13 * *
14 * *
15 * Redistribution and use in source and binary forms, with or without *
16 * modification, are permitted according to the terms listed in LICENSE *
17 * (see tmva/doc/LICENSE) *
18 * *
19 **********************************************************************************/
20
21#include <iomanip>
22
23#include "TMath.h"
24#include "Riostream.h"
25#include "TMatrix.h"
26#include "TMatrixD.h"
27#include "TVectorD.h"
28
30#include "TMVA/MethodRSNNS.h"
31#include "TMVA/Tools.h"
32#include "TMVA/Config.h"
33#include "TMVA/Ranking.h"
34#include "TMVA/Types.h"
35#include "TMVA/PDF.h"
37
38#include "TMVA/Results.h"
39#include "TMVA/Timer.h"
40
41using namespace TMVA;
42
44
45
46//creating an Instance
48
49//_______________________________________________________________________
// Standard constructor.
// NOTE(review): the signature line (50) and parameter line 52 (presumably
// `DataSetInfo &dsi`, which is forwarded to RMethodBase below) are missing
// from this listing.
// methodTitle is stored as the network type; any value other than "RMLP"
// is reported with kFATAL and the constructor returns early. Otherwise the
// assignments below set the default values of the RSNNS mlp() options that
// DeclareOptions() exposes.
51 const TString &methodTitle,
53 const TString &theOption) :
54 RMethodBase(jobName, Types::kRSNNS, methodTitle, dsi, theOption),
55 fMvaCounter(0),
56 predict("predict"),
57 mlp("mlp"),
58 asfactor("as.factor"),
59 fModel(NULL)
60{
61 fNetType = methodTitle;
62 if (fNetType != "RMLP") {
// NOTE(review): the message has a typo ("Unknow") and, unlike the other
// constructor, no " = " between the text and the method name.
63 Log() << kFATAL << " Unknow Method" + fNetType
64 << Endl;
65 return;
66 }
67
68 // standard constructor for the RSNNS
69 //RSNNS Options for all NN methods
70 fSize = "c(5)";
71 fMaxit = 100;
72
73 fInitFunc = "Randomize_Weights";
74 fInitFuncParams = "c(-0.3,0.3)"; //the maximum number of parameters is 5, see RSNNS::getSnnsRFunctionTable() type 6
75
76 fLearnFunc = "Std_Backpropagation"; //
77 fLearnFuncParams = "c(0.2,0)";
78
79 fUpdateFunc = "Topological_Order";
80 fUpdateFuncParams = "c(0)";
81
82 fHiddenActFunc = "Act_Logistic";
// NOTE(review): source lines 83-84 are missing from this listing. The other
// constructor sets fLinOut = kFALSE at this point; confirm that fLinOut and
// fShufflePatterns are also initialized here.
85 fPruneFunc = "NULL";
86 fPruneFuncParams = "NULL";
87
88}
89
90//_______________________________________________________________________
// Second constructor.
// NOTE(review): its signature and the start of the initializer list (source
// lines 91-92) are missing from this listing. Presumably this is the
// constructor used when reading from a weight file; confirm.
// It always uses the "RMLP" network type and assigns the same default mlp()
// option values as the standard constructor.
93 fMvaCounter(0),
94 predict("predict"),
95 mlp("mlp"),
96 asfactor("as.factor"),
97 fModel(NULL)
98
99{
100 fNetType = "RMLP"; // hard-coded: GetMethodName() was found to return "MethodBase" rather than "RMLP" here, so it is not used
// NOTE(review): fNetType was just set to "RMLP", so this check can never fire.
101 if (fNetType != "RMLP") {
102 Log() << kFATAL << " Unknow Method = " + fNetType
103 << Endl;
104 return;
105 }
106
107 // standard constructor for the RSNNS
108 //RSNNS Options for all NN methods
109 fSize = "c(5)";
110 fMaxit = 100;
111
112 fInitFunc = "Randomize_Weights";
113 fInitFuncParams = "c(-0.3,0.3)"; //the maximum number of parameters is 5, see RSNNS::getSnnsRFunctionTable() type 6
114
115 fLearnFunc = "Std_Backpropagation"; //
116 fLearnFuncParams = "c(0.2,0)";
117
118 fUpdateFunc = "Topological_Order";
119 fUpdateFuncParams = "c(0)";
120
121 fHiddenActFunc = "Act_Logistic";
// NOTE(review): source line 122 is missing from this listing (presumably the
// fShufflePatterns default; confirm).
123 fLinOut = kFALSE;
124 fPruneFunc = "NULL";
125 fPruneFuncParams = "NULL";
126}
127
128
129//_______________________________________________________________________
131{
132 if (fModel) delete fModel;
133}
134
135//_______________________________________________________________________
141
142
143//_______________________________________________________________________
145{
146 if (!IsModuleLoaded) {
147 Error("Init", "R's package RSNNS can not be loaded.");
148 Log() << kFATAL << " R's package RSNNS can not be loaded."
149 << Endl;
150 return;
151 }
152 //factors creations
153 //RSNNS mlp require a numeric factor then background=0 signal=1 from fFactorTrain/fFactorTest
154 UInt_t size = fFactorTrain.size();
155 fFactorNumeric.resize(size);
156
157 for (UInt_t i = 0; i < size; i++) {
158 if (fFactorTrain[i] == "signal") fFactorNumeric[i] = 1;
159 else fFactorNumeric[i] = 0;
160 }
161}
162
164{
// Train the RMLP network by calling RSNNS mlp() in R with the options
// declared in DeclareOptions(). If model persistence is enabled, the
// resulting R model is saved as "RMLPModel" in
// GetWeightFileDir()/GetName().RData.
// NOTE(review): source lines 167 (declaration of PruneFunc), 171-172 (head
// of the mlp(...) call that yields `Model`) and 186 are missing from this
// listing.
165 if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
166 if (fNetType == "RMLP") {
// The prune function is passed to R either as R's NULL or as the name in
// fPruneFunc wrapped in single quotes and evaluated by R.
168 if (fPruneFunc == "NULL") PruneFunc = r.Eval("NULL");
169 else PruneFunc = r.Eval(Form("'%s'", fPruneFunc.Data()));
170
// Options stored as R expressions (Size, *FuncParams) are evaluated by R
// before being passed, e.g. "c(5)" becomes a numeric vector.
173 ROOT::R::Label["size"] = r.Eval(fSize),
174 ROOT::R::Label["maxit"] = fMaxit,
175 ROOT::R::Label["initFunc"] = fInitFunc,
176 ROOT::R::Label["initFuncParams"] = r.Eval(fInitFuncParams),
177 ROOT::R::Label["learnFunc"] = fLearnFunc,
178 ROOT::R::Label["learnFuncParams"] = r.Eval(fLearnFuncParams),
179 ROOT::R::Label["updateFunc"] = fUpdateFunc,
180 ROOT::R::Label["updateFuncParams"] = r.Eval(fUpdateFuncParams),
181 ROOT::R::Label["hiddenActFunc"] = fHiddenActFunc,
182 ROOT::R::Label["shufflePatterns"] = fShufflePatterns,
// NOTE(review): RSNNS mlp() names this argument `linOut`. "libOut" does not
// match it, so R would presumably route it into `...`, and the LinOut option
// would have no effect. This likely should be ROOT::R::Label["linOut"];
// confirm against the RSNNS documentation.
183 ROOT::R::Label["libOut"] = fLinOut,
184 ROOT::R::Label["pruneFunc"] = PruneFunc,
185 ROOT::R::Label["pruneFuncParams"] = r.Eval(fPruneFuncParams));
187 //if model persistence is enabled, save the model with R serialization (.RData file).
188 if (IsModelPersistence())
189 {
190 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
191 Log() << Endl;
192 Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
193 Log() << Endl;
194 r["RMLPModel"] << Model;
195 r << "save(RMLPModel,file='" + path + "')";
196 }
197 }
198}
199
200//_______________________________________________________________________
202{
203 //RSNNS Options for all NN methods
204// TVectorF fSize;//number of units in the hidden layer(s)
205 DeclareOptionRef(fSize, "Size", "number of units in the hidden layer(s)");
206 DeclareOptionRef(fMaxit, "Maxit", "Maximum of iterations to learn");
207
208 DeclareOptionRef(fInitFunc, "InitFunc", "the initialization function to use");
209 DeclareOptionRef(fInitFuncParams, "InitFuncParams", "the parameters for the initialization function");
210
211 DeclareOptionRef(fLearnFunc, "LearnFunc", "the learning function to use");
212 DeclareOptionRef(fLearnFuncParams, "LearnFuncParams", "the parameters for the learning function");
213
214 DeclareOptionRef(fUpdateFunc, "UpdateFunc", "the update function to use");
215 DeclareOptionRef(fUpdateFuncParams, "UpdateFuncParams", "the parameters for the update function");
216
217 DeclareOptionRef(fHiddenActFunc, "HiddenActFunc", "the activation function of all hidden units");
218 DeclareOptionRef(fShufflePatterns, "ShufflePatterns", "should the patterns be shuffled?");
219 DeclareOptionRef(fLinOut, "LinOut", "sets the activation function of the output units to linear or logistic");
220
221 DeclareOptionRef(fPruneFunc, "PruneFunc", "the prune function to use");
222 DeclareOptionRef(fPruneFuncParams, "PruneFuncParams", "the parameters for the pruning function. Unlike the\
223 other functions, these have to be given in a named list. See\
224 the pruning demos for further explanation.the update function to use");
225
226}
227
228//_______________________________________________________________________
230{
231 if (fMaxit <= 0) {
232 Log() << kERROR << " fMaxit <=0... that does not work !! "
233 << " I set it to 50 .. just so that the program does not crash"
234 << Endl;
235 fMaxit = 1;
236 }
237 // standard constructor for the RSNNS
238 //RSNNS Options for all NN methods
239
240}
241
242//_______________________________________________________________________
// Log which network type is being tested.
// NOTE(review): source line 247 is missing from this listing. Presumably it
// forwards to the base-class TestClassification(); confirm.
244{
245 Log() << kINFO << "Testing Classification " << fNetType << " METHOD " << Endl;
246
248}
249
250
251//_______________________________________________________________________
// Return the MVA value for the current event.
// The event's input variables are copied into the R data frame fDfEvent, one
// column per variable name. R's predict(..., type = "prob") is then applied
// to the model in fModel, and the first element of the result is returned as
// the signal probability.
// NOTE(review): source lines 254-255, 258 and 263 are missing from this
// listing. Presumably they declare mvaValue/fDfEvent, handle the error
// arguments and reload a persisted model; confirm.
253{
256 const TMVA::Event *ev = GetEvent();
257 const UInt_t nvar = DataInfo().GetNVariables();
// NOTE(review): GetListOfVariables() returns a std::vector<TString> by value,
// so the whole list is copied on every iteration. Hoisting it out of the loop
// would avoid this, which matters because this runs once per event.
259 for (UInt_t i = 0; i < nvar; i++) {
260 fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
261 }
262 //if using persistence model
264
// fModel must already hold a trained or loaded model here; it is dereferenced
// without a null check.
265 TVectorD result = predict(*fModel, fDfEvent, ROOT::R::Label["type"] = "prob");
266 mvaValue = result[0]; //returning signal prob
267 return mvaValue;
268}
269
270////////////////////////////////////////////////////////////////////////////////
271/// get all the MVA values for the events of the current Data type
// lastEvt is reset to the data-set size when it exceeds it or precedes
// firstEvt (so the default -1 means "all events"); negative firstEvt is
// clamped to 0. The input variables of events [firstEvt, lastEvt) are
// gathered column-wise and stored in the R data frame evtData. The MVA
// values are then taken from `result`.
// NOTE(review): source lines 296, 306, 311 and 314 are missing from this
// listing. Presumably they select the current event, declare evtData, reload
// a persisted model and compute `result` with predict(); confirm.
273{
274 Long64_t nEvents = Data()->GetNEvents();
275 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
276 if (firstEvt < 0) firstEvt = 0;
277
278 nEvents = lastEvt-firstEvt;
279
280 UInt_t nvars = Data()->GetNVariables();
281
282 // use timer
283 Timer timer( nEvents, GetName(), kTRUE );
284 if (logProgress)
285 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
286 << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
287
288
289 // fill R DATA FRAME with events data
// One column vector per input variable, each with nEvents entries.
290 std::vector<std::vector<Float_t> > inputData(nvars);
291 for (UInt_t i = 0; i < nvars; i++) {
292 inputData[i] = std::vector<Float_t>(nEvents);
293 }
294
// NOTE(review): ievt is Int_t while firstEvt/lastEvt are Long64_t. A Long64_t
// counter would avoid truncation for very large samples.
295 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
297 const TMVA::Event *e = Data()->GetEvent();
298 assert(nvars == e->GetNVariables());
299 for (UInt_t i = 0; i < nvars; i++) {
// NOTE(review): BUG. inputData[i] holds nEvents = lastEvt - firstEvt entries,
// but ievt starts at firstEvt, so this writes out of bounds whenever
// firstEvt > 0. The index should be ievt - firstEvt.
300 inputData[i][ievt] = e->GetValue(i);
301 }
302 // if (ievt%100 == 0)
303 // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
304 }
305
// NOTE(review): GetListOfVariables() returns its vector by value, so it is
// copied on every iteration here as well.
307 for (UInt_t i = 0; i < nvars; i++) {
308 evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
309 }
310 //if using persistence model
312
// NOTE(review): this allocation is immediately replaced by the assignment
// from result.As<...>() below.
313 std::vector<Double_t> mvaValues(nEvents);
315 //std::vector<Double_t> probValues(2*nEvents);
316 mvaValues = result.As<std::vector<Double_t>>();
317 // assert(probValues.size() == 2*mvaValues.size());
318 // std::copy(probValues.begin()+nEvents, probValues.end(), mvaValues.begin() );
319
320 if (logProgress) {
321 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
322 << timer.GetElapsedTime() << " " << Endl;
323 }
324
325 return mvaValues;
326
327}
328
329
330//_______________________________________________________________________
332{
333 ROOT::R::TRInterface::Instance().Require("RSNNS");
334 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
335 Log() << Endl;
336 Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
337 Log() << Endl;
338 r << "load('" + path + "')";
339 SEXP Model;
340 r["RMLPModel"] >> Model;
341 fModel = new ROOT::R::TRObject(Model);
342
343}
344
345
346//_______________________________________________________________________
348{
349// get help message text
350//
351// typical length of text line:
352// "|--------------------------------------------------------------|"
353 Log() << Endl;
354 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
355 Log() << Endl;
356 Log() << "Decision Trees and Rule-Based Models " << Endl;
357 Log() << Endl;
358 Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
359 Log() << Endl;
360 Log() << Endl;
361 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
362 Log() << Endl;
363 Log() << "<None>" << Endl;
364}
365
#define REGISTER_METHOD(CLASS)
for example
#define e(i)
Definition RSha256.hxx:103
size_t size(const MatrixT &matrix)
retrieve the size of a square matrix
bool Bool_t
Boolean (0=false, 1=true) (bool)
Definition RtypesCore.h:77
constexpr Bool_t kFALSE
Definition RtypesCore.h:108
long long Long64_t
Portable signed long integer 8 bytes.
Definition RtypesCore.h:83
constexpr Bool_t kTRUE
Definition RtypesCore.h:107
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
void Error(const char *location, const char *msgfmt,...)
Use this function in case an error occurred.
Definition TError.cxx:208
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2495
This is a class to create DataFrames from ROOT to R.
static TRInterface & Instance()
static method to get an TRInterface instance reference
Int_t Eval(const TString &code, TRObject &ans)
Method to eval R code and you get the result in a reference to TRObject.
This is a class to get ROOT's objects from R's objects.
Definition TRObject.h:70
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
MsgLogger & Log() const
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNVariables() const
std::vector< TString > GetListOfVariables() const
returns list of variables
const Event * GetEvent() const
returns event without transformations
Definition DataSet.cxx:202
Types::ETreeType GetCurrentType() const
Definition DataSet.h:194
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition DataSet.cxx:216
void SetCurrentEvent(Long64_t ievt) const
Definition DataSet.h:88
const char * GetName() const override
Definition MethodBase.h:334
Bool_t IsModelPersistence() const
Definition MethodBase.h:383
const TString & GetWeightFileDir() const
Definition MethodBase.h:492
const TString & GetMethodName() const
Definition MethodBase.h:331
const Event * GetEvent() const
Definition MethodBase.h:751
DataSetInfo & DataInfo() const
Definition MethodBase.h:410
virtual void TestClassification()
initialization
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
DataSet * Data() const
Definition MethodBase.h:409
Double_t GetMvaValue(Double_t *errLower=nullptr, Double_t *errUpper=nullptr)
static Bool_t IsModuleLoaded
void GetHelpMessage() const
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
ROOT::R::TRFunctionImport predict
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
TString fUpdateFuncParams
TString fLearnFuncParams
Definition MethodRSNNS.h:97
virtual void TestClassification()
initialization
TString fPruneFuncParams
MethodRSNNS(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
TString fInitFuncParams
Definition MethodRSNNS.h:94
std::vector< UInt_t > fFactorNumeric
ROOT::R::TRFunctionImport mlp
ROOT::R::TRObject * fModel
std::vector< std::string > fFactorTrain
Definition RMethodBase.h:95
ROOT::R::TRInterface & r
Definition RMethodBase.h:52
ROOT::R::TRDataFrame fDfTrain
Definition RMethodBase.h:91
Timing information for training and evaluation of MVA methods.
Definition Timer.h:58
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kClassification
Definition Types.h:127
@ kTraining
Definition Types.h:143
Basic string class.
Definition TString.h:138
const char * Data() const
Definition TString.h:384
const Rcpp::internal::NamedPlaceHolder & Label
create variable transformations
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148