ROOT Reference Guide
 
RBatchGenerator_NumPy.py
### \file
### \ingroup tutorial_ml
### \notebook -nodraw
### Example of getting batches of events from a ROOT dataset as Python
### generators of numpy arrays.
###
### \macro_code
### \macro_output
### \author Dante Niewenhuis

import ROOT

tree_name = "sig_tree"
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"

batch_size = 128  # number of events in each batch returned by the generators
chunk_size = 5000  # number of events loaded into memory at a time
block_size = 400

rdataframe = ROOT.RDataFrame(tree_name, file_name)

target = "Type"  # column returned as the label (y) instead of as an input feature

num_of_epochs = 2

gen_train, gen_validation = ROOT.TMVA.Experimental.CreateNumPyGenerators(
    rdataframe,
    batch_size,
    chunk_size,
    block_size,
    target=target,
    validation_split=0.3,  # reserve 30% of the events for validation
    shuffle=True,
    drop_remainder=True,  # skip a final batch that holds fewer than batch_size events
)

for epoch in range(num_of_epochs):
    print(f"Epoch {epoch + 1}")

    # Loop through the training set
    for i, (x_train, y_train) in enumerate(gen_train):
        print(f"Training batch {i + 1} => x: {x_train.shape}, y: {y_train.shape}")

    # Loop through the validation set
    for i, (x_validation, y_validation) in enumerate(gen_validation):
        print(f"Validation batch {i + 1} => x: {x_validation.shape}, y: {y_validation.shape}")