Example of getting batches of events from a ROOT dataset into a basic TensorFlow workflow.
import tensorflow as tf
import ROOT
# Input data: the TTree named `tree_name` inside the ROOT file at `file_name`
# (read over HTTP here).
tree_name = "sig_tree"
file_name = "http://root.cern/files/Higgs_data.root"

# Number of events per batch handed to Keras, and number of events loaded
# from the file at a time.
# NOTE(review): chunk semantics inferred from the parameter name — confirm
# against the CreateTFDatasets documentation.
batch_size = 128
chunk_size = 5_000

# Column used as the training label. The other columns presumably become the
# model inputs (see `train_columns` below).
target = "Type"

# Build two tf.data-compatible datasets that stream batches from the ROOT
# file: one for training and one for validation. validation_split=0.3
# reserves part of the events for validation (presumably 30%).
ds_train, ds_valid = ROOT.TMVA.Experimental.CreateTFDatasets(
    tree_name,
    file_name,
    batch_size,
    chunk_size,
    validation_split=0.3,
    target=target,
)

# Names of the feature columns fed to the network. Their count fixes the
# width of the model's input layer.
input_columns = ds_train.train_columns
num_features = len(input_columns)
# Binary classifier: fully connected tanh layers of 300 units followed by a
# single sigmoid unit. Only the first layer declares the input width, which
# comes from the number of feature columns.
first_hidden = tf.keras.layers.Dense(
    300, activation=tf.nn.tanh, input_shape=(num_features,)
)
further_hidden = [
    tf.keras.layers.Dense(300, activation=tf.nn.tanh) for _ in range(2)
]
output_unit = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)
model = tf.keras.Sequential([first_hidden, *further_hidden, output_unit])

# Binary cross-entropy pairs with the single sigmoid output. Accuracy is
# tracked alongside the loss.
loss_fn = tf.keras.losses.BinaryCrossentropy()
model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])

# Keras consumes the streaming datasets directly. Two passes over the
# training data, each evaluated on the validation dataset.
model.fit(ds_train, validation_data=ds_valid, epochs=2)
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len
Epoch 1/2
1/Unknown - 8s 8s/step - loss: 0.4445 - accuracy: 0.9766
11/Unknown - 8s 5ms/step - loss: 0.0461 - accuracy: 0.9979
24/Unknown - 8s 5ms/step - loss: 0.0211 - accuracy: 0.9990
39/Unknown - 8s 4ms/step - loss: 0.0130 - accuracy: 0.9994
53/Unknown - 8s 4ms/step - loss: 0.0096 - accuracy: 0.9996
54/54 [==============================] - 9s 14ms/step - loss: 0.0094 - accuracy: 0.9996 - val_loss: 2.5342e-07 - val_accuracy: 1.0000
Epoch 2/2
1/54 [..............................] - ETA: 2s - loss: 3.4435e-07 - accuracy: 1.0000
16/54 [=======>......................] - ETA: 0s - loss: 2.5484e-07 - accuracy: 1.0000
32/54 [================>.............] - ETA: 0s - loss: 2.4553e-07 - accuracy: 1.0000
47/54 [=========================>....] - ETA: 0s - loss: 2.4822e-07 - accuracy: 1.0000
54/54 [==============================] - 0s 5ms/step - loss: 2.5165e-07 - accuracy: 1.0000 - val_loss: 2.4065e-07 - val_accuracy: 1.0000
- Author: Dante Niewenhuis
Definition in file RBatchGenerator_TensorFlow.py.