import ROOT
import tensorflow as tf
# Input: the signal tree of the Higgs dataset shipped with the ROOT tutorials.
tree_name = "sig_tree"
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"

# RDataLoader reads from an RDataFrame. Build it explicitly from the tree and
# file above (previously `rdataframe` was used without ever being defined).
rdataframe = ROOT.RDataFrame(tree_name, file_name)

batch_size = 128
# NOTE(review): presumably the number of batches RDataLoader keeps buffered in
# memory at once -- confirm against the RDataLoader documentation.
approx_batches_in_memory = 50
# Column(s) holding the label the model is trained to predict.
target = ["Type"]

dl = ROOT.Experimental.ML.RDataLoader(
    rdataframe,
    batch_size,
    approx_batches_in_memory,
    target=target,
    shuffle=True,
    # Drop a final incomplete batch so every batch has exactly `batch_size` rows.
    drop_remainder=True,
)

# Split into training and validation loaders; presumably 30% goes to
# validation, per the `test_size` argument -- confirm against the API.
ds_train, ds_valid = dl.train_test_split(test_size=0.3)

num_of_epochs = 2
# Keras iterates a single dataset across all epochs, so repeat it once per
# epoch; otherwise it would be exhausted after the first epoch.
ds_train_repeated = ds_train.as_tensorflow().repeat(num_of_epochs)
ds_valid_repeated = ds_valid.as_tensorflow().repeat(num_of_epochs)

# Because the datasets are repeated, model.fit must be told how many batches
# make up one epoch. Without this, the first "epoch" would drain every
# repetition.
train_batches_per_epoch = ds_train.num_batches
validation_batches_per_epoch = ds_valid.num_batches

# Columns fed to the network as inputs; their count sets the input width.
input_columns = ds_train.train_columns
num_features = len(input_columns)

# Fully connected binary classifier: three 300-unit tanh hidden layers and a
# single sigmoid output unit.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=(num_features,)),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(1, activation=tf.nn.sigmoid),
    ]
)

# Binary cross-entropy matches the single sigmoid output unit.
loss_fn = tf.keras.losses.BinaryCrossentropy()
model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])

model.fit(
    ds_train_repeated,
    steps_per_epoch=train_batches_per_epoch,
    validation_data=ds_valid_repeated,
    validation_steps=validation_batches_per_epoch,
    epochs=num_of_epochs,
)
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
ROOT's RDataFrame offers a modern, high-level interface for analysis of data stored in TTree, CSV and other data formats, in C++ or Python.