Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RBatchGenerator_TensorFlow.py
Go to the documentation of this file.
### \file
### \ingroup tutorial_tmva
### \notebook -nodraw
###
### Example of getting batches of events from a ROOT dataset into a basic
### TensorFlow workflow.
###
### \macro_code
### \macro_output
### \author Dante Niewenhuis

import tensorflow as tf
import ROOT

# Input data: name of the TTree and the location of the ROOT file holding it.
# The file is given as an HTTP URL, so running this tutorial needs network
# access.
tree_name = "sig_tree"
file_name = "http://root.cern/files/Higgs_data.root"

# Number of events per batch delivered to TensorFlow.
batch_size = 128
# Number of events per chunk read from the file (passed straight through to
# CreateTFDatasets; see its documentation for how chunks relate to batches).
chunk_size = 5_000

# Column used as the training label.
target = "Type"

# Returns two tf.data.Dataset objects: one yielding training batches and one
# yielding validation batches. validation_split sets the fraction of events
# reserved for validation.
ds_train, ds_valid = ROOT.TMVA.Experimental.CreateTFDatasets(
    tree_name,
    file_name,
    batch_size,
    chunk_size,
    validation_split=0.3,
    target=target,
)

# Get the list of columns used as training inputs; its length fixes the width
# of the network's input layer below.
input_columns = ds_train.train_columns
num_features = len(input_columns)

##############################################################################
# AI example
##############################################################################

# Number of passes over the training dataset.
num_of_epochs = 2

# Define the TensorFlow model: a fully connected binary classifier with three
# hidden tanh layers of 300 units and a single sigmoid output unit.
model = tf.keras.Sequential(
    [
        # Declare the input width with an explicit Input layer; passing
        # `input_shape` to the first Dense layer is deprecated in Keras 3.
        tf.keras.Input(shape=(num_features,)),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(300, activation=tf.nn.tanh),
        tf.keras.layers.Dense(1, activation=tf.nn.sigmoid),
    ]
)
# Binary cross-entropy pairs with the single sigmoid output above.
loss_fn = tf.keras.losses.BinaryCrossentropy()
model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])

# Train the model, evaluating on the validation dataset at the end of each
# epoch.
model.fit(ds_train, validation_data=ds_valid, epochs=num_of_epochs)
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len