ROOT Reference Guide

TMVA_SOFIE_PyTorch.C
Go to the documentation of this file.
1/// \file
2/// \ingroup tutorial_ml
3/// \notebook -nodraw
4/// This macro provides a simple example for the parsing of PyTorch .pt file
5/// into RModel object and further generating the .hxx header files for inference.
6///
7/// \macro_code
8/// \macro_output
9/// \author Sanjiban Sengupta
10
11using namespace TMVA::Experimental;
12
14import torch\n\
15import torch.nn as nn\n\
16\n\
17model = nn.Sequential(\n\
18 nn.Linear(32,16),\n\
19 nn.ReLU(),\n\
20 nn.Linear(16,8),\n\
21 nn.ReLU()\n\
22 )\n\
23\n\
24criterion = nn.MSELoss()\n\
25optimizer = torch.optim.SGD(model.parameters(),lr=0.01)\n\
26\n\
27x=torch.randn(2,32)\n\
28y=torch.randn(2,8)\n\
29\n\
30for i in range(500):\n\
31 y_pred = model(x)\n\
32 loss = criterion(y_pred,y)\n\
33 optimizer.zero_grad()\n\
34 loss.backward()\n\
35 optimizer.step()\n\
36\n\
37model.eval()\n\
38m = torch.jit.script(model)\n\
39torch.jit.save(m,'PyTorchModel.pt')\n";
40
41
43
44 // Running the Python script to generate PyTorch .pt file
45
46 TMacro m;
47 m.AddLine(pythonSrc);
48 m.SaveSource("make_pytorch_model.py");
49 gSystem->Exec("python3 make_pytorch_model.py");
50
51 // Parsing a PyTorch model requires the shape and data-type of input tensor
52 // Data-type of input tensor defaults to Float if not specified
53 std::vector<size_t> inputTensorShapeSequential{2, 32};
54 std::vector<std::vector<size_t>> inputShapesSequential{inputTensorShapeSequential};
55
56 // Parsing the saved PyTorch .pt file into RModel object
57 SOFIE::RModel model = SOFIE::PyTorch::Parse("PyTorchModel.pt", inputShapesSequential);
58
59 // Generating inference code
60 model.Generate();
61 model.OutputGenerated("PyTorchModel.hxx");
62
63 // Printing required input tensors
64 std::cout << "\n\n";
66
67 // Printing initialized tensors (weights)
68 std::cout << "\n\n";
70
71 // Printing intermediate tensors
72 std::cout << "\n\n";
74
75 // Checking if tensor already exist in model
76 std::cout << "\n\nTensor \"0weight\" already exist: " << std::boolalpha << model.CheckIfTensorAlreadyExist("0weight")
77 << "\n\n";
78 std::vector<size_t> tensorShape = model.GetTensorShape("0weight");
79 std::cout << "Shape of tensor \"0weight\": ";
80 for (auto &it : tensorShape) {
81 std::cout << it << ",";
82 }
83 std::cout<<"\n\nData type of tensor \"0weight\": ";
85 std::cout<<SOFIE::ConvertTypeToString(tensorType);
86
87 //Printing generated inference code
88 std::cout<<"\n\n";
89 model.PrintGenerated();
90}
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
R__EXTERN TSystem * gSystem
Definition TSystem.h:572
std::vector< size_t > GetTensorShape(const std::string &name) const
Definition RModel.cxx:29
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:122
void OutputGenerated(std::string filename="", bool append=false)
Definition RModel.cxx:1430
void Generate(std::underlying_type_t< Options > options, int batchSize=-1, long pos=0, bool verbose=false)
Definition RModel.cxx:1062
ETensorType GetTensorType(std::string name) const
Definition RModel.cxx:90
Class supporting a collection of lines with C++ code.
Definition TMacro.h:31
Basic string class.
Definition TString.h:138
virtual Int_t Exec(const char *shellcmd)
Execute a command.
Definition TSystem.cxx:651
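TMarker m
Definition textangle.C:8
(Spurious auto-generated cross-reference: the `m` in this macro is a local TMacro, unrelated to the TMarker defined in textangle.C.)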