90 fLocalTrainingTree(0),
92 fValidationFraction(0.5),
101 const TString& theWeightFile) :
104 fLocalTrainingTree(0),
106 fValidationFraction(0.5),
107 fLearningMethod(
"" )
134 if (fMLP)
delete fMLP;
144 while (layerSpec.
Length()>0) {
146 if (layerSpec.
First(
',')<0) {
151 sToAdd = layerSpec(0,layerSpec.
First(
','));
152 layerSpec = layerSpec(layerSpec.
First(
',')+1,layerSpec.
Length());
156 nNodes += atoi(sToAdd);
157 fHiddenLayer =
TString::Format(
"%s%i:", (
const char*)fHiddenLayer, nNodes );
161 std::vector<TString>::iterator itrVar = (*fInputVars).begin();
162 std::vector<TString>::iterator itrVarEnd = (*fInputVars).end();
163 fMLPBuildOptions =
"";
164 for (; itrVar != itrVarEnd; ++itrVar) {
167 fMLPBuildOptions += myVar;
168 fMLPBuildOptions +=
",";
170 fMLPBuildOptions.
Chop();
173 fMLPBuildOptions += fHiddenLayer;
174 fMLPBuildOptions +=
"type";
176 Log() << kINFO <<
"Use " << fNcycles <<
" training cycles" <<
Endl;
177 Log() << kINFO <<
"Use configuration (nodes per hidden layer): " << fHiddenLayer <<
Endl;
197 DeclareOptionRef( fNcycles = 200,
"NCycles",
"Number of training cycles" );
198 DeclareOptionRef( fLayerSpec =
"N,N-1",
"HiddenLayers",
"Specification of hidden layer architecture (N stands for number of variables; any integers may also be used)" );
200 DeclareOptionRef( fValidationFraction = 0.5,
"ValidationFraction",
201 "Fraction of events in training tree used for cross validation" );
203 DeclareOptionRef( fLearningMethod =
"Stochastic",
"LearningMethod",
"Learning method" );
204 AddPreDefVal(
TString(
"Stochastic") );
205 AddPreDefVal(
TString(
"Batch") );
206 AddPreDefVal(
TString(
"SteepestDescent") );
207 AddPreDefVal(
TString(
"RibierePolak") );
208 AddPreDefVal(
TString(
"FletcherReeves") );
209 AddPreDefVal(
TString(
"BFGS") );
217 CreateMLPOptions(fLayerSpec);
219 if (IgnoreEventsWithNegWeightsInTraining()) {
220 Log() << kFATAL <<
"Mechanism to ignore events with negative weights in training not available for method"
221 << GetMethodTypeName()
222 <<
" --> please remove \"IgnoreNegWeightsInTraining\" option from booking string."
232 const Event* ev = GetEvent();
235 for (
UInt_t ivar = 0; ivar<Data()->GetNVariables(); ivar++) {
241 NoErrorCalc(err, errUpper);
264 const Long_t basketsize = 128000;
267 TTree *localTrainingTree =
new TTree(
"TMLPtrain",
"Local training tree for TMlpANN" );
268 localTrainingTree->
Branch(
"type", &
type,
"type/I", basketsize );
269 localTrainingTree->
Branch(
"weight", &weight,
"weight/F", basketsize );
271 for (
UInt_t ivar=0; ivar<GetNvar(); ivar++) {
272 TString myVar = GetInternalVarName(ivar);
274 localTrainingTree->
Branch( myVar.
Data(), &vArr[ivar], myTyp.
Data(), basketsize );
277 for (
UInt_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
278 const Event *ev = GetEvent(ievt);
279 for (
UInt_t i=0; i<GetNvar(); i++) {
282 type = DataInfo().IsSignal( ev ) ? 1 : 0;
284 localTrainingTree->
Fill();
292 trainList += 1.0-fValidationFraction;
294 trainList += (
Int_t)Data()->GetNEvtSigTrain();
295 trainList +=
" || (Entry$>";
296 trainList += (
Int_t)Data()->GetNEvtSigTrain();
297 trainList +=
" && Entry$<";
298 trainList += (
Int_t)(Data()->GetNEvtSigTrain() + (1.0 - fValidationFraction)*Data()->GetNEvtBkgdTrain());
303 Log() << kHEADER <<
"Requirement for training events: \"" << trainList <<
"\"" <<
Endl;
304 Log() << kINFO <<
"Requirement for validation events: \"" << testList <<
"\"" <<
Endl;
309 if (fMLP) {
delete fMLP; fMLP =
nullptr; }
314 fMLP->SetEventWeight(
"weight" );
319 fLearningMethod.ToLower();
327 Log() << kFATAL <<
"Unknown Learning Method: \"" << fLearningMethod <<
"\"" <<
Endl;
329 fMLP->SetLearningMethod( learningMethod );
332 fMLP->Train(fNcycles,
"" );
336 delete localTrainingTree;
348 gTools().
AddAttr( arch,
"BuildOptions", fMLPBuildOptions.Data() );
351 const TString tmpfile=GetWeightFileDir()+
"/TMlp.nn.weights.temp";
352 fMLP->DumpWeights( tmpfile.
Data() );
353 std::ifstream inf( tmpfile.
Data() );
357 while (inf.getline(temp,256)) {
363 dummy = dummy(0,dummy.
First(
' '));
368 data += (dummy +
" ");
385 const TString fname = GetWeightFileDir()+
"/TMlp.nn.weights.temp";
386 std::ofstream fout( fname.
Data() );
387 double temp1=0,temp2=0;
390 std::stringstream content(nodecontent);
391 if (strcmp(
gTools().GetName(ch),
"input")==0) {
392 fout <<
"#input normalization" << std::endl;
393 while ((content >> temp1) &&(content >> temp2)) {
394 fout << temp1 <<
" " << temp2 << std::endl;
397 if (strcmp(
gTools().GetName(ch),
"output")==0) {
398 fout <<
"#output normalization" << std::endl;
399 while ((content >> temp1) &&(content >> temp2)) {
400 fout << temp1 <<
" " << temp2 << std::endl;
403 if (strcmp(
gTools().GetName(ch),
"neurons")==0) {
404 fout <<
"#neurons weights" << std::endl;
405 while (content >> temp1) {
406 fout << temp1 << std::endl;
409 if (strcmp(
gTools().GetName(ch),
"synapses")==0) {
410 fout <<
"#synapses weights" ;
411 while (content >> temp1) {
412 fout << std::endl << temp1 ;
425 TTree * dummyTree =
new TTree(
"dummy",
"Empty dummy tree", 1);
426 for (
UInt_t ivar = 0; ivar<Data()->GetNVariables(); ivar++) {
427 TString vn = DataInfo().GetVariableInfo(ivar).GetInternalName();
433 if (fMLP) {
delete fMLP; fMLP =
nullptr; }
435 fMLP->LoadWeights( fname );
445 std::ofstream fout(
"./TMlp.nn.weights.temp" );
446 fout << istr.rdbuf();
450 Log() << kINFO <<
"Load TMLP weights into " << fMLP <<
Endl;
455 TTree * dummyTree =
new TTree(
"dummy",
"Empty dummy tree", 1);
456 for (
UInt_t ivar = 0; ivar<Data()->GetNVariables(); ivar++) {
457 TString vn = DataInfo().GetVariableInfo(ivar).GetLabel();
463 if (fMLP) {
delete fMLP; fMLP =
nullptr; }
466 fMLP->LoadWeights(
"./TMlp.nn.weights.temp" );
480 if (theClassFileName ==
"")
481 classFileName = GetWeightFileDir() +
"/" + GetJobName() +
"_" + GetMethodName() +
".class";
483 classFileName = theClassFileName;
486 Log() << kINFO <<
"Creating specific (TMultiLayerPerceptron) standalone response class: " << classFileName <<
Endl;
487 fMLP->Export( classFileName.
Data() );
509 Log() <<
"This feed-forward multilayer perceptron neural network is the " <<
Endl;
510 Log() <<
"standard implementation distributed with ROOT (class TMultiLayerPerceptron)." <<
Endl;
512 Log() <<
"Detailed information is available here:" <<
Endl;
513 if (
gConfig().WriteOptionsReference()) {
514 Log() <<
"<a href=\"http://root.cern.ch/root/html/TMultiLayerPerceptron.html\">";
515 Log() <<
"http://root.cern.ch/root/html/TMultiLayerPerceptron.html</a>" <<
Endl;
517 else Log() <<
"http://root.cern.ch/root/html/TMultiLayerPerceptron.html" <<
Endl;
#define REGISTER_METHOD(CLASS)
for example
const Bool_t EnforceNormalization__
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void data
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t type
Class that contains all the data information.
Float_t GetValue(UInt_t ivar) const
return value of i'th variable
Double_t GetWeight() const
return the event weight - depending on whether the flag IgnoreNegWeightsInTraining is set or not.
Virtual base class for all MVA methods.
This is the TMVA TMultiLayerPerceptron interface class.
void ReadWeightsFromStream(std::istream &istr)
read weights from stream since the MLP can not read from the stream, we 1st: write the weights to tem...
void Init(void)
default initialisations
virtual Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets)
TMlpANN can handle classification with 2 classes.
void Train(void)
performs TMlpANN training; available learning methods:
Double_t GetMvaValue(Double_t *err=nullptr, Double_t *errUpper=nullptr)
calculate the value of the neural net for the current event
void DeclareOptions()
define the options (their key words) that can be set in the option string
void CreateMLPOptions(TString)
translates options from option string into TMlpANN language
void ReadWeightsFromXML(void *wghtnode)
rebuild temporary textfile from xml weightfile and load this file into MLP
MethodTMlpANN(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption="3000:N-1:N-2")
standard constructor
void ProcessOptions()
builds the neural network as specified by the user
void MakeClassSpecific(std::ostream &, const TString &) const
write specific classifier response; nothing to do here - all taken care of by TMultiLayerPerceptron
void AddWeightsXMLTo(void *parent) const
write weights to xml file
void MakeClass(const TString &classFileName=TString("")) const
create reader class for classifier -> overwrites base class function create specific class for TMulti...
virtual ~MethodTMlpANN(void)
destructor
void GetHelpMessage() const
get help message text
Singleton class for Global types used by TMVA.
This class describes a neural network.
TSubString Strip(EStripType s=kTrailing, char c=' ') const
Return a substring of self stripped at beginning and/or end.
Ssiz_t First(char c) const
Find first occurrence of a character c.
const char * Data() const
TString & ReplaceAll(const TString &s1, const TString &s2)
Bool_t BeginsWith(const char *s, ECaseCompare cmp=kExact) const
TString & Remove(Ssiz_t pos)
static TString Format(const char *fmt,...)
Static method which formats a string using a printf style format descriptor and return a TString.
A TTree represents a columnar dataset.
virtual Int_t Fill()
Fill all branches.
TBranch * Branch(const char *name, T *obj, Int_t bufsize=32000, Int_t splitlevel=99)
Add a new branch, and infer the data type from the type of obj being passed.
create variable transformations
MsgLogger & Endl(MsgLogger &ml)