from os.path import isfile
from subprocess import call, check_call

from ROOT import TMVA, TFile, TCut
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD
# Files read and written by this script. The Keras model file is written by
# create_model() and read back by the PyKeras booking in run(), and the data
# file name is both the download target and the input opened in run(), so
# each name is defined exactly once here.
MODEL_FILE = 'modelClassification.h5'
TRAINED_MODEL_FILE = 'trainedModelClassification.h5'
DATA_FILE = 'tmva_class_example.root'
# `curl -O` saves under the URL's basename, so the URL must end in DATA_FILE.
DATA_URL = 'https://root.cern/files/' + DATA_FILE
OUTPUT_FILE = 'TMVA_Classification_Keras.root'


def create_model():
    """Build, compile and save the Keras network later trained by TMVA.

    The network is a 64-unit ReLU hidden layer followed by a 2-unit softmax
    output (one unit per class: signal and background). It is only compiled
    here; TMVA loads it from MODEL_FILE and trains it inside run().
    """
    model = Sequential()
    # input_dim must match the number of input variables registered with the
    # DataLoader in run() (one per branch of the signal tree).
    model.add(Dense(64, activation='relu', input_dim=4))
    model.add(Dense(2, activation='softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer=SGD(learning_rate=0.01),
                  weighted_metrics=['accuracy'])

    model.save(MODEL_FILE)


def run():
    """Train, test and evaluate a Fisher discriminant and the PyKeras model.

    Reads the signal tree 'TreeS' and background tree 'TreeB' from DATA_FILE,
    registers every branch of the signal tree as an input variable, and
    writes the TMVA results to OUTPUT_FILE (recreated on every call).
    Requires MODEL_FILE to exist already (see create_model()).
    """
    with TFile.Open(OUTPUT_FILE, 'RECREATE') as output, \
            TFile.Open(DATA_FILE) as data:
        # NOTE(review): the Factory job name 'TMVAClassification' and the
        # DataLoader name 'dataset' are assumed; confirm against the intended
        # output layout.
        factory = TMVA.Factory(
            'TMVAClassification', output,
            '!V:!Silent:Color:DrawProgressBar:Transformations=D,G:AnalysisType=Classification')

        signal = data.Get('TreeS')
        background = data.Get('TreeB')

        dataloader = TMVA.DataLoader('dataset')
        for branch in signal.GetListOfBranches():
            dataloader.AddVariable(branch.GetName())

        dataloader.AddSignalTree(signal, 1.0)
        dataloader.AddBackgroundTree(background, 1.0)
        # An empty TCut applies no event selection before the train/test split.
        dataloader.PrepareTrainingAndTestTree(
            TCut(''),
            'nTrain_Signal=4000:nTrain_Background=4000:SplitMode=Random:NormMode=NumEvents:!V')

        factory.BookMethod(dataloader, TMVA.Types.kFisher, 'Fisher',
                           '!H:!V:Fisher:VarTransform=D,G')
        factory.BookMethod(
            dataloader, TMVA.Types.kPyKeras, 'PyKeras',
            'H:!V:VarTransform=D,G'
            f':FilenameModel={MODEL_FILE}'
            f':FilenameTrainedModel={TRAINED_MODEL_FILE}'
            ':NumEpochs=20:BatchSize=32')

        factory.TrainAllMethods()
        factory.TestAllMethods()
        factory.EvaluateAllMethods()


if __name__ == "__main__":
    # PyKeras runs Python code from inside TMVA; initialise the embedded
    # Python interpreter before any method is booked.
    TMVA.PyMethodBase.PyInitialize()

    # The model file must exist before run() books the PyKeras method.
    create_model()

    if not isfile(DATA_FILE):
        # check_call raises CalledProcessError when curl fails. --fail makes
        # curl exit non-zero on HTTP errors instead of saving the server's
        # error page as DATA_FILE, which isfile() would otherwise accept on
        # every later run.
        check_call(['curl', '--fail', '-L', '-O', DATA_URL])

    run()
# TCut: a specialized string object used for TTree selections.
# TFile::Open (static): create or open a file.
#   static TFile *Open(const char *name, Option_t *option = "",
#                      const char *ftitle = "",
#                      Int_t compress = ROOT::RCompressionSetting::EDefaults::kUseCompiledDefault,
#                      Int_t netopt = 0)
# TMVA::Factory: the main MVA steering class.
# TMVA::PyMethodBase::PyInitialize() (static): initialize the Python interpreter.