# URL of the ROOT file that holds the Higgs example dataset used below.
file_name = "http://root.cern/files/Higgs_data.root"
# Builds the training and validation batch generators; the loops below
# iterate them as (features, target) pairs (presumably PyTorch tensors —
# confirm against the ROOT TMVA documentation).
# NOTE(review): documentation-extraction residue. The leading "25" is a fused
# listing line number, and listing lines 26-34 (the rest of this call,
# including its arguments and closing parenthesis) were dropped, so this
# statement does not parse. Restore the full call from upstream before running.
25gen_train, gen_validation = ROOT.TMVA.Experimental.CreatePyTorchGenerators(
# Columns the training generator uses as model inputs; their count becomes
# the input width of the network's first layer.
input_columns = gen_train.train_columns
num_features = len(input_columns)
def calc_accuracy(targets, pred):
    """Return the fraction of rounded predictions that equal the targets.

    ``pred`` is rounded with ``torch.round`` (round half to even), so for
    outputs in [0, 1] values above 0.5 count as 1 and values below as 0.

    Both tensors are flattened before comparing. Without this, a
    column-shaped ``targets`` such as ``(N, 1)`` compared against a flat
    ``pred`` of shape ``(N,)`` would silently broadcast to an ``(N, N)``
    comparison and produce a meaningless "accuracy" (possibly > 1).

    Args:
        targets: tensor of true labels.
        pred: tensor of model outputs with the same number of elements.

    Returns:
        A 0-d floating-point tensor holding the accuracy; empty input
        yields NaN (0 / 0).
    """
    flat_targets = targets.reshape(-1)
    flat_pred = pred.reshape(-1)
    correct = torch.sum(flat_targets == flat_pred.round())
    return correct / flat_pred.numel()
# Fully connected network: num_features -> 300 -> 300 -> 300 -> 1.
# NOTE(review): documentation-extraction residue. Leading numbers are fused
# listing line numbers, and listing lines 46, 48, 50, 52 and 53 were dropped
# (the next visible line is 54). The Sequential(...) call is therefore
# missing its closing parenthesis and whatever layers sat between these
# Linear layers. Restore from upstream before running.
44model = torch.nn.Sequential(
45 torch.nn.Linear(num_features, 300),
47 torch.nn.Linear(300, 300),
49 torch.nn.Linear(300, 300),
51 torch.nn.Linear(300, 1),
# Mean-squared-error loss averaged over the batch, minimized with SGD
# (learning rate 0.01, momentum 0.9) over all model parameters.
loss_fn = torch.nn.MSELoss(reduction="mean")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Train on every batch yielded by the training generator.
for x_train, y_train in gen_train:
    # Forward pass. view(-1) flattens the output to 1-D (assumes y_train is
    # 1-D as well — confirm against the generator's target shape).
    pred = model(x_train).view(-1)
    loss = loss_fn(pred, y_train)

    # Backward pass and parameter update. Without these calls the loss is
    # computed but the optimizer never changes the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Accuracy of this batch's (pre-update) predictions.
    accuracy = calc_accuracy(y_train, pred)
    print(f"Training => accuracy: {accuracy}")
# Evaluate the trained model on the validation batches.
# eval() switches layers such as dropout/batch-norm (if any are present) to
# inference behaviour; no_grad() skips building the autograd graph, since
# no backward pass happens here.
model.eval()
with torch.no_grad():
    for x_val, y_val in gen_validation:
        pred = model(x_val).view(-1)
        accuracy = calc_accuracy(y_val, pred)
        print(f"Validation => accuracy: {accuracy}")
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t Atom_t Time_t UChar_t len