
import ROOT
import torch

tree_name = "sig_tree"
file_name = str(ROOT.gROOT.GetTutorialDir()) + "/machine_learning/data/Higgs_data.root"

batch_size = 128

rdataframe = ROOT.RDataFrame(tree_name, file_name)

# Column used as the training label.
target = "Type"

# Build a data loader over the RDataFrame that produces batches of `batch_size`
# rows with `target` as the label column. shuffle=True and drop_remainder=True
# request shuffled rows and no partial final batch (see the RDataLoader docs).
# NOTE: the constructor returns a loader object, not the generators themselves.
dl = ROOT.Experimental.ML.RDataLoader(
    rdataframe,
    batch_size,
    target=target,
    shuffle=True,
    drop_remainder=True,
)

# Split the loader into a training and a validation generator. test_size=0.3
# is presumably the fraction of entries held out for validation. Batches are
# obtained as PyTorch tensors by iterating `.as_torch()` on each generator.
gen_train, gen_validation = dl.train_test_split(test_size=0.3)

# Get a list of the columns used as input features for training
input_columns = gen_train.train_columns
num_features = len(input_columns)


def calc_accuracy(targets, pred):
    """Return the fraction of rows whose rounded prediction equals the target.

    ``pred`` is rounded element-wise and compared with ``targets``. The number
    of matches is divided by the number of rows in ``pred`` (its first
    dimension), and the result is returned as a 0-dim tensor.

    NOTE(review): assumes ``targets`` and ``pred`` have the same shape, e.g.
    (batch, 1). Mismatched shapes would broadcast and inflate the count.
    """
    hits = pred.round() == targets
    n_rows = pred.size(0)
    return hits.sum() / n_rows


# Initialize PyTorch model: an MLP with three tanh-activated hidden layers of
# width 300 and a single sigmoid output, so predictions lie in (0, 1).
hidden_width = 300
layers = [torch.nn.Linear(num_features, hidden_width), torch.nn.Tanh()]
for _ in range(2):
    layers.extend([torch.nn.Linear(hidden_width, hidden_width), torch.nn.Tanh()])
layers.extend([torch.nn.Linear(hidden_width, 1), torch.nn.Sigmoid()])
model = torch.nn.Sequential(*layers)

# Mean squared error between the sigmoid outputs and the targets, minimized
# with SGD plus momentum over all model parameters.
loss_fn = torch.nn.MSELoss(reduction="mean")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

number_of_epochs = 2

for epoch in range(number_of_epochs):
    print("Epoch ", epoch)

    # #################################################################
    # # Training
    # #################################################################

    model.train()
    # Loop through the training set and update the model batch by batch.
    for x_train, y_train in gen_train.as_torch():
        # Make prediction and calculate loss
        pred = model(x_train)
        loss = loss_fn(pred, y_train)

        # Improve model. The optimizer was built from model.parameters(), so
        # optimizer.zero_grad() clears exactly the gradients backward() fills.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Calculate accuracy of this batch
        accuracy = calc_accuracy(y_train, pred)

        print(f"Training => accuracy: {accuracy}")

    # #################################################################
    # # Validation
    # #################################################################

    model.eval()
    # Gradients are not needed to evaluate the model; no_grad() skips building
    # the autograd graph, saving memory and time.
    with torch.no_grad():
        for x_val, y_val in gen_validation.as_torch():
            # Make prediction and calculate accuracy
            pred = model(x_val)
            accuracy = calc_accuracy(y_val, pred)

            print(f"Validation => accuracy: {accuracy}")
