Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
MethodRSVM.cxx
Go to the documentation of this file.
1// @(#)root/tmva/rmva $Id$
2// Author: Omar Zapata,Lorenzo Moneta, Sergei Gleyzer 2015
3
4/**********************************************************************************
5 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis *
6 * Package: TMVA *
7 * Class : MethodRSVM *
8 * Web : http://oproject.org *
9 * *
10 * Description: *
11 * Support Vector Machines *
12 * *
13 * *
14 * Redistribution and use in source and binary forms, with or without *
15 * modification, are permitted according to the terms listed in LICENSE *
16 * (see tmva/doc/LICENSE) *
17 * *
18 **********************************************************************************/
19
20#include <iomanip>
21
22#include "TMath.h"
23#include "Riostream.h"
24#include "TMatrix.h"
25#include "TMatrixD.h"
26#include "TVectorD.h"
27
29#include "TMVA/MethodRSVM.h"
30#include "TMVA/Tools.h"
31#include "TMVA/Config.h"
32#include "TMVA/Ranking.h"
33#include "TMVA/Types.h"
34#include "TMVA/PDF.h"
36
37#include "TMVA/Results.h"
38#include "TMVA/Timer.h"
39
40using namespace TMVA;
41
43
45//creating an Instance
47
48
// Standard (training) constructor of MethodRSVM: forwards the job and method
// identifiers to RMethodBase, creates the imports of the R functions `svm`,
// `predict` and `as.factor`, and sets the default of every booking option.
// NOTE(review): this listing is missing original source line 50 (the start of
// the signature; the initializer uses `jobName`, and the reference section
// lists `MethodRSVM(const TString &jobName, const TString &methodTitle,
// DataSetInfo &theData, const TString &theOption="")`) as well as original
// lines 75 and 77, which, following the option order of DeclareOptions(),
// presumably set the defaults of fShrinking and fProbability — confirm
// against the upstream source.
49//_______________________________________________________________________
51 const TString &methodTitle,
52 DataSetInfo &dsi,
53 const TString &theOption) :
54 RMethodBase(jobName, Types::kRSVM, methodTitle, dsi, theOption),
55 fMvaCounter(0),
56 svm("svm"),
57 predict("predict"),
58 asfactor("as.factor"),
59 fModel(NULL)
60{
61 // standard constructor for the RSVM
62 //Booking options
63 fScale = kTRUE;
64 fType = "C-classification";
65 fKernel = "radial";
66 fDegree = 3;
67
// Default gamma follows the "1/(data dimension)" rule stated in the Gamma
// option help.
// NOTE(review): if fDfTrain has no columns yet when this constructor runs,
// GetNcols() returns 0 and this evaluates 1.0 / 0 (+inf) — confirm fDfTrain
// is already filled here, or guard the zero-column case.
68 fGamma = (fDfTrain.GetNcols() == 1) ? 1.0 : (1.0 / fDfTrain.GetNcols());
69 fCoef0 = 0;
70 fCost = 1;
71 fNu = 0.5;
72 fCacheSize = 40;
73 fTolerance = 0.001;
74 fEpsilon = 0.1;
76 fCross = 0;
78 fFitted = kTRUE;
79}
80
// Constructor taking a DataSetInfo and a weight-file name, both forwarded to
// RMethodBase; it creates the same R function imports and option defaults as
// the training constructor.
// NOTE(review): original source lines 104 and 106 are missing from this
// listing; following the option order of DeclareOptions() they presumably
// set the defaults of fShrinking and fProbability — confirm upstream.
81//_______________________________________________________________________
82MethodRSVM::MethodRSVM(DataSetInfo &theData, const TString &theWeightFile)
83 : RMethodBase(Types::kRSVM, theData, theWeightFile),
84 fMvaCounter(0),
85 svm("svm"),
86 predict("predict"),
87 asfactor("as.factor"),
88 fModel(NULL)
89{
90 // standard constructor for the RSVM
91 //Booking options
92 fScale = kTRUE;
93 fType = "C-classification";
94 fKernel = "radial";
95 fDegree = 3;
96
// Default gamma follows the "1/(data dimension)" rule stated in the Gamma
// option help.
// NOTE(review): if fDfTrain has no columns when this constructor runs,
// GetNcols() returns 0 and this evaluates 1.0 / 0 (+inf) — confirm, or guard
// the zero-column case.
97 fGamma = (fDfTrain.GetNcols() == 1) ? 1.0 : (1.0 / fDfTrain.GetNcols());
98 fCoef0 = 0;
99 fCost = 1;
100 fNu = 0.5;
101 fCacheSize = 40;
102 fTolerance = 0.001;
103 fEpsilon = 0.1;
105 fCross = 0;
107 fFitted = kTRUE;
108}
109
110
111//_______________________________________________________________________
113{
114 if (fModel) delete fModel;
115}
116
117//_______________________________________________________________________
119{
120 if (type == Types::kClassification && numberClasses == 2) return kTRUE;
121 return kFALSE;
122}
123
124
125//_______________________________________________________________________
127{
128 if (!IsModuleLoaded) {
129 Error("Init", "R's package e1071 can not be loaded.");
130 Log() << kFATAL << " R's package e1071 can not be loaded."
131 << Endl;
132 return;
133 }
134}
135
// Trains an e1071 `svm` model in the embedded R session on fDfTrain with the
// current booking options, keeps it in fModel and, when model persistence is
// enabled, saves it as R variable `RSVMModel` to
// <weight file dir>/<method name>.RData.
// NOTE(review): this listing is missing the signature line (original source
// line 136) and original line 152, an argument between `x` and `scale` of the
// svm() call — presumably the class labels (e.g. `y` built from fFactorTrain);
// confirm against upstream.
137{
138 if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<Train> Data() has zero events" << Endl;
139 //SVM require a named vector
// Class weights are the numbers of background and signal training events.
// NOTE(review): e1071 uses class.weights to scale the cost per class, so
// weighting each class by its own event count emphasises the larger class
// rather than balancing the classes — confirm this is intended.
140 ROOT::R::TRDataFrame ClassWeightsTrain;
141 ClassWeightsTrain["background"] = Data()->GetNEvtBkgdTrain();
142 ClassWeightsTrain["signal"] = Data()->GetNEvtSigTrain();
143
144 Log() << kINFO
145 << " Probability is " << fProbability
146 << " Tolerance is " << fTolerance
147 << " Type is " << fType
148 << Endl;
149
150
151 SEXP Model = svm(ROOT::R::Label["x"] = fDfTrain, \
153 ROOT::R::Label["scale"] = fScale, \
154 ROOT::R::Label["type"] = fType, \
155 ROOT::R::Label["kernel"] = fKernel, \
156 ROOT::R::Label["degree"] = fDegree, \
157 ROOT::R::Label["gamma"] = fGamma, \
158 ROOT::R::Label["coef0"] = fCoef0, \
159 ROOT::R::Label["cost"] = fCost, \
160 ROOT::R::Label["nu"] = fNu, \
161 ROOT::R::Label["class.weights"] = ClassWeightsTrain, \
162 ROOT::R::Label["cachesize"] = fCacheSize, \
163 ROOT::R::Label["tolerance"] = fTolerance, \
164 ROOT::R::Label["epsilon"] = fEpsilon, \
165 ROOT::R::Label["shrinking"] = fShrinking, \
166 ROOT::R::Label["cross"] = fCross, \
167 ROOT::R::Label["probability"] = fProbability, \
168 ROOT::R::Label["fitted"] = fFitted);
// NOTE(review): a previously held fModel is overwritten without being deleted.
169 fModel = new ROOT::R::TRObject(Model);
// The path is pasted verbatim into R code; a quote character in the weight
// file directory or method name would break the save() command.
170 if (IsModelPersistence())
171 {
172 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
173 Log() << Endl;
174 Log() << gTools().Color("bold") << "--- Saving State File In:" << gTools().Color("reset") << path << Endl;
175 Log() << Endl;
176 r["RSVMModel"] << Model;
177 r << "save(RSVMModel,file='" + path + "')";
178 }
179}
180
// Registers the user-configurable booking options; each option mirrors the
// e1071 `svm` argument of the same name that Train() passes on.
// NOTE(review): the signature line (original source line 182) is missing from
// this listing. The multi-line help strings use backslash-newline
// continuation inside string literals, so the next line's leading whitespace
// becomes part of the help text; adjacent string literals would avoid that.
181//_______________________________________________________________________
183{
184 DeclareOptionRef(fScale, "Scale", "A logical vector indicating the variables to be scaled. If\
185 ‘scale’ is of length 1, the value is recycled as many times \
186 as needed. Per default, data are scaled internally (both ‘x’\
187 and ‘y’ variables) to zero mean and unit variance. The center \
188 and scale values are returned and used for later predictions.");
189 DeclareOptionRef(fType, "Type", "‘svm’ can be used as a classification machine, as a \
190 regression machine, or for novelty detection. Depending of\
191 whether ‘y’ is a factor or not, the default setting for\
192 ‘type’ is ‘C-classification’ or ‘eps-regression’,\
193 respectively, but may be overwritten by setting an explicit value.\
194 Valid options are:\
195 - ‘C-classification’\
196 - ‘nu-classification’\
197 - ‘one-classification’ (for novelty detection)\
198 - ‘eps-regression’\
199 - ‘nu-regression’");
200 DeclareOptionRef(fKernel, "Kernel", "the kernel used in training and predicting. You might\
201 consider changing some of the following parameters, depending on the kernel type.\
202 linear: u'*v\
203 polynomial: (gamma*u'*v + coef0)^degree\
204 radial basis: exp(-gamma*|u-v|^2)\
205 sigmoid: tanh(gamma*u'*v + coef0)");
206 DeclareOptionRef(fDegree, "Degree", "parameter needed for kernel of type ‘polynomial’ (default: 3)");
207 DeclareOptionRef(fGamma, "Gamma", "parameter needed for all kernels except ‘linear’ (default:1/(data dimension))");
208 DeclareOptionRef(fCoef0, "Coef0", "parameter needed for kernels of type ‘polynomial’ and ‘sigmoid’ (default: 0)");
209 DeclareOptionRef(fCost, "Cost", "cost of constraints violation (default: 1)-it is the ‘C’-constant of the regularization term in the Lagrange formulation.");
210 DeclareOptionRef(fNu, "Nu", "parameter needed for ‘nu-classification’, ‘nu-regression’,and ‘one-classification’");
211 DeclareOptionRef(fCacheSize, "CacheSize", "cache memory in MB (default 40)");
212 DeclareOptionRef(fTolerance, "Tolerance", "tolerance of termination criterion (default: 0.001)");
213 DeclareOptionRef(fEpsilon, "Epsilon", "epsilon in the insensitive-loss function (default: 0.1)");
// NOTE(review): the constructors in this listing show no assignment to
// fShrinking or fProbability (their lines appear to be missing); the values
// they hold here become the option defaults — confirm they are initialized.
214 DeclareOptionRef(fShrinking, "Shrinking", "option whether to use the shrinking-heuristics (default:‘TRUE’)");
215 DeclareOptionRef(fCross, "Cross", "if a integer value k>0 is specified, a k-fold cross validation on the training data is performed to assess the quality of the model: the accuracy rate for classification and the Mean Squared Error for regression");
216 DeclareOptionRef(fProbability, "Probability", "logical indicating whether the model should allow for probability predictions");
217 DeclareOptionRef(fFitted, "Fitted", "logical indicating whether the fitted values should be computed and included in the model or not (default: ‘TRUE’)");
218
219}
220
// Copies every booking option into the embedded R session as a variable
// named RMVA.RSVM.<Option>.
// NOTE(review): the signature line (original source line 222) is missing from
// this listing. Train() passes the C++ members directly to svm(), so these R
// variables are not read anywhere in this file — presumably they are kept for
// inspection from R; confirm.
221//_______________________________________________________________________
223{
224 r["RMVA.RSVM.Scale"] = fScale;
225 r["RMVA.RSVM.Type"] = fType;
226 r["RMVA.RSVM.Kernel"] = fKernel;
227 r["RMVA.RSVM.Degree"] = fDegree;
228 r["RMVA.RSVM.Gamma"] = fGamma;
229 r["RMVA.RSVM.Coef0"] = fCoef0;
230 r["RMVA.RSVM.Cost"] = fCost;
231 r["RMVA.RSVM.Nu"] = fNu;
232 r["RMVA.RSVM.CacheSize"] = fCacheSize;
233 r["RMVA.RSVM.Tolerance"] = fTolerance;
234 r["RMVA.RSVM.Epsilon"] = fEpsilon;
235 r["RMVA.RSVM.Shrinking"] = fShrinking;
236 r["RMVA.RSVM.Cross"] = fCross;
237 r["RMVA.RSVM.Probability"] = fProbability;
238 r["RMVA.RSVM.Fitted"] = fFitted;
239
240}
241
// Logs the start of the classification test.
// NOTE(review): this listing is missing the signature (original source line
// 243; the reference section lists `virtual void TestClassification()`) and
// original line 247, presumably the call that actually runs the test (e.g.
// the base-class implementation). As shown, the body only logs — confirm.
242//_______________________________________________________________________
244{
245 Log() << kINFO << "Testing Classification RSVM METHOD " << Endl;
246
248}
249
250
// Evaluates the trained model on the current event: builds a one-row R data
// frame (one named column per input variable), calls e1071's `predict` with
// decision values requested, and returns the first decision value. No error
// estimate is computed (NoErrorCalc).
// NOTE(review): the signature line (original source line 252; the reference
// section gives `Double_t GetMvaValue(Double_t *errLower=nullptr, Double_t *errUpper=nullptr)`)
// and original line 263 are missing from this listing.
251//_______________________________________________________________________
253{
254 NoErrorCalc(errLower, errUpper);
255 Double_t mvaValue;
256 const TMVA::Event *ev = GetEvent();
257 const UInt_t nvar = DataInfo().GetNVariables();
// `fDfEvent` is a local despite its member-style `f` prefix.
// GetListOfVariables() returns the names by value, so the loop below rebuilds
// that vector once per variable.
258 ROOT::R::TRDataFrame fDfEvent;
259 for (UInt_t i = 0; i < nvar; i++) {
260 fDfEvent[DataInfo().GetListOfVariables()[i].Data()] = ev->GetValues()[i];
261 }
262 //if using persistence model
264
// NOTE(review): fModel is dereferenced unconditionally; it is null until the
// model has been trained or loaded.
265 ROOT::R::TRObject result = predict(*fModel, fDfEvent, ROOT::R::Label["decision.values"] = kTRUE, ROOT::R::Label["probability"] = kTRUE);
266 TVectorD values = result.GetAttribute("decision.values");
// NOTE(review): GetMvaValues() returns the signal probability when the model
// provides probabilities, whereas this always returns the decision value, so
// per-event and batch evaluation can disagree.
267 mvaValue = values[0]; // first decision value (not a probability)
268 return mvaValue;
269}
270
271////////////////////////////////////////////////////////////////////////////////
272/// get all the MVA values for the events of the current Data type
273std::vector<Double_t> MethodRSVM::GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress)
274{
275 Long64_t nEvents = Data()->GetNEvents();
276 if (firstEvt > lastEvt || lastEvt > nEvents) lastEvt = nEvents;
277 if (firstEvt < 0) firstEvt = 0;
278
279 nEvents = lastEvt-firstEvt;
280
281 UInt_t nvars = Data()->GetNVariables();
282
283 // use timer
284 Timer timer( nEvents, GetName(), kTRUE );
285 if (logProgress)
286 Log() << kINFO<<Form("Dataset[%s] : ",DataInfo().GetName())<< "Evaluation of " << GetMethodName() << " on "
287 << (Data()->GetCurrentType()==Types::kTraining?"training":"testing") << " sample (" << nEvents << " events)" << Endl;
288
289
290 // fill R DATA FRAME with events data
291 std::vector<std::vector<Float_t> > inputData(nvars);
292 for (UInt_t i = 0; i < nvars; i++) {
293 inputData[i] = std::vector<Float_t>(nEvents);
294 }
295
296 for (Int_t ievt=firstEvt; ievt<lastEvt; ievt++) {
297 Data()->SetCurrentEvent(ievt);
298 const TMVA::Event *e = Data()->GetEvent();
299 assert(nvars == e->GetNVariables());
300 for (UInt_t i = 0; i < nvars; i++) {
301 inputData[i][ievt] = e->GetValue(i);
302 }
303 // if (ievt%100 == 0)
304 // std::cout << "Event " << ievt << " type" << DataInfo().IsSignal(e) << " : " << pValue[ievt*nvars] << " " << pValue[ievt*nvars+1] << " " << pValue[ievt*nvars+2] << std::endl;
305 }
306
307 ROOT::R::TRDataFrame evtData;
308 for (UInt_t i = 0; i < nvars; i++) {
309 evtData[DataInfo().GetListOfVariables()[i].Data()] = inputData[i];
310 }
311 //if using persistence model
313
314 std::vector<Double_t> mvaValues(nEvents);
315
316
317 ROOT::R::TRObject result = predict(*fModel, evtData, ROOT::R::Label["decision.values"] = kTRUE, ROOT::R::Label["probability"] = kTRUE);
318
319 r["result"] << result;
320 r << "v2 <- attr(result, \"probabilities\") ";
321 int probSize = 0;
322 r["length(v2)"] >> probSize;
323 //r << "print(v2)";
324 if (probSize > 0) {
325 std::vector<Double_t> probValues = result.GetAttribute("probabilities");
326 // probabilities are for both cases
327 assert(probValues.size() == 2*mvaValues.size());
328 for (int i = 0; i < nEvents; ++i)
329 // R stores vector column-wise (as in Fortran)
330 // and signal probabilities are the second column
331 mvaValues[i] = probValues[nEvents+i];
332
333 }
334 // use decision values
335 else {
336 Log() << kINFO << " : Probabilities are not available. Use decision values instead !" << Endl;
337 //std::cout << "examine the result " << std::endl;
338 std::vector<Double_t> probValues = result.GetAttribute("decision.values");
339 mvaValues = probValues;
340 // std::cout << "decision values " << values1.size() << std::endl;
341 // for ( auto & v : values1) std::cout << v << " ";
342 // std::cout << std::endl;
343 }
344
345
346 if (logProgress) {
347 Log() << kINFO <<Form("Dataset[%s] : ",DataInfo().GetName())<< "Elapsed time for evaluation of " << nEvents << " events: "
348 << timer.GetElapsedTime() << " " << Endl;
349 }
350
351 return mvaValues;
352
353}
354
// Loads the model written by Train() (same <weight file dir>/<method name>.RData
// path and R variable `RSVMModel`) into the R session and wraps it in a new
// TRObject stored in fModel.
// NOTE(review): the signature line (original source line 356) and original
// line 358 are missing from this listing. A previously held fModel is
// overwritten without being deleted, so repeated calls leak.
355//_______________________________________________________________________
357{
359 TString path = GetWeightFileDir() + "/" + GetName() + ".RData";
360 Log() << Endl;
361 Log() << gTools().Color("bold") << "--- Loading State File From:" << gTools().Color("reset") << path << Endl;
362 Log() << Endl;
// The path is pasted verbatim into R code; a quote character in it would
// break the load() command.
363 r << "load('" + path + "')";
364 SEXP Model;
365 r["RSVMModel"] >> Model;
366 fModel = new ROOT::R::TRObject(Model);
367
368}
369
370//_______________________________________________________________________
372{
373// get help message text
374//
375// typical length of text line:
376// "|--------------------------------------------------------------|"
377 Log() << Endl;
378 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
379 Log() << Endl;
380 Log() << "Decision Trees and Rule-Based Models " << Endl;
381 Log() << Endl;
382 Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
383 Log() << Endl;
384 Log() << Endl;
385 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
386 Log() << Endl;
387 Log() << "<None>" << Endl;
388}
389
#define REGISTER_METHOD(CLASS)
for example
#define e(i)
Definition RSha256.hxx:103
bool Bool_t
Definition RtypesCore.h:63
constexpr Bool_t kFALSE
Definition RtypesCore.h:94
long long Long64_t
Definition RtypesCore.h:69
constexpr Bool_t kTRUE
Definition RtypesCore.h:93
#define ClassImp(name)
Definition Rtypes.h:377
void Error(const char *location, const char *msgfmt,...)
Use this function in case an error occurred.
Definition TError.cxx:185
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t result
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
char * Form(const char *fmt,...)
Formats a string in a circular formatting buffer.
Definition TString.cxx:2489
This is a class to create DataFrames from ROOT to R.
Int_t GetNcols()
Method to get the number of columns.
static TRInterface & Instance()
static method to get an TRInterface instance reference
Bool_t Require(TString pkg)
Method to load an R's package.
This is a class to get ROOT's objects from R's objects.
Definition TRObject.h:70
OptionBase * DeclareOptionRef(T &ref, const TString &name, const TString &desc="")
MsgLogger & Log() const
Class that contains all the data information.
Definition DataSetInfo.h:62
UInt_t GetNVariables() const
std::vector< TString > GetListOfVariables() const
returns list of variables
Long64_t GetNEvtSigTrain()
return number of signal training events in dataset
Definition DataSet.cxx:443
const Event * GetEvent() const
returns event without transformations
Definition DataSet.cxx:202
Types::ETreeType GetCurrentType() const
Definition DataSet.h:194
Long64_t GetNEvents(Types::ETreeType type=Types::kMaxTreeType) const
Definition DataSet.h:206
UInt_t GetNVariables() const
access the number of variables through the datasetinfo
Definition DataSet.cxx:216
void SetCurrentEvent(Long64_t ievt) const
Definition DataSet.h:88
Long64_t GetNEvtBkgdTrain()
return number of background training events in dataset
Definition DataSet.cxx:451
std::vector< Float_t > & GetValues()
Definition Event.h:94
const char * GetName() const
Definition MethodBase.h:334
Bool_t IsModelPersistence() const
Definition MethodBase.h:383
const TString & GetWeightFileDir() const
Definition MethodBase.h:492
const TString & GetMethodName() const
Definition MethodBase.h:331
const Event * GetEvent() const
Definition MethodBase.h:751
DataSetInfo & DataInfo() const
Definition MethodBase.h:410
virtual void TestClassification()
initialization
void ReadStateFromFile()
Function to read options and weights from file.
void NoErrorCalc(Double_t *const err, Double_t *const errUpper)
DataSet * Data() const
Definition MethodBase.h:409
virtual std::vector< Double_t > GetMvaValues(Long64_t firstEvt=0, Long64_t lastEvt=-1, Bool_t logProgress=false)
get all the MVA values for the events of the current Data type
virtual void TestClassification()
initialization
ROOT::R::TRFunctionImport asfactor
Definition MethodRSVM.h:130
static Bool_t IsModuleLoaded
Definition MethodRSVM.h:127
Double_t GetMvaValue(Double_t *errLower=nullptr, Double_t *errUpper=nullptr)
ROOT::R::TRObject * fModel
Definition MethodRSVM.h:131
MethodRSVM(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="")
ROOT::R::TRFunctionImport svm
Definition MethodRSVM.h:128
void GetHelpMessage() const
ROOT::R::TRFunctionImport predict
Definition MethodRSVM.h:129
Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
std::vector< std::string > fFactorTrain
Definition RMethodBase.h:95
ROOT::R::TRInterface & r
Definition RMethodBase.h:52
ROOT::R::TRDataFrame fDfTrain
Definition RMethodBase.h:91
Timing information for training and evaluation of MVA methods.
Definition Timer.h:58
TString GetElapsedTime(Bool_t Scientific=kTRUE)
returns pretty string with elapsed time
Definition Timer.cxx:146
const TString & Color(const TString &)
human readable color strings
Definition Tools.cxx:828
Singleton class for Global types used by TMVA.
Definition Types.h:71
@ kClassification
Definition Types.h:127
@ kTraining
Definition Types.h:143
Basic string class.
Definition TString.h:139
const Rcpp::internal::NamedPlaceHolder & Label
create variable transformations
Tools & gTools()
MsgLogger & Endl(MsgLogger &ml)
Definition MsgLogger.h:148