Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
TMVA_SOFIE_PyTorch.C File Reference

Detailed Description

View in nbviewer Open in SWAN
This macro provides a simple example of parsing a PyTorch .pt file into an RModel object and then generating the .hxx header file for inference.

using namespace TMVA::Experimental;

// Python script that builds a small nn.Sequential model (Linear(32,16) -> ReLU ->
// Linear(16,8) -> ReLU), fits it to random data for 500 SGD steps and saves the
// TorchScript-compiled model as PyTorchModel.pt.
//
// NOTE: each line is joined to the string with a trailing backslash, so any leading
// whitespace on a line becomes part of the Python source. The body of the `for`
// loop MUST stay indented; otherwise Python rejects the script with an
// IndentationError and no .pt file is produced.
TString pythonSrc = "\
import torch\n\
import torch.nn as nn\n\
\n\
model = nn.Sequential(\n\
    nn.Linear(32,16),\n\
    nn.ReLU(),\n\
    nn.Linear(16,8),\n\
    nn.ReLU()\n\
)\n\
\n\
criterion = nn.MSELoss()\n\
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)\n\
\n\
x=torch.randn(2,32)\n\
y=torch.randn(2,8)\n\
\n\
for i in range(500):\n\
    y_pred = model(x)\n\
    loss = criterion(y_pred,y)\n\
    optimizer.zero_grad()\n\
    loss.backward()\n\
    optimizer.step()\n\
\n\
model.eval()\n\
m = torch.jit.script(model)\n\
torch.jit.save(m,'PyTorchModel.pt')\n";
void TMVA_SOFIE_PyTorch(){
//Running the Python script to generate PyTorch .pt file
m.AddLine(pythonSrc);
m.SaveSource("make_pytorch_model.py");
gSystem->Exec(TMVA::Python_Executable() + " make_pytorch_model.py");
//Parsing a PyTorch model requires the shape and data-type of input tensor
//Data-type of input tensor defaults to Float if not specified
std::vector<size_t> inputTensorShapeSequential{2,32};
std::vector<std::vector<size_t>> inputShapesSequential{inputTensorShapeSequential};
//Parsing the saved PyTorch .pt file into RModel object
SOFIE::RModel model = SOFIE::PyTorch::Parse("PyTorchModel.pt",inputShapesSequential);
//Generating inference code
model.Generate();
model.OutputGenerated("PyTorchModel.hxx");
//Printing required input tensors
std::cout<<"\n\n";
//Printing initialized tensors (weights)
std::cout<<"\n\n";
//Printing intermediate tensors
std::cout<<"\n\n";
//Checking if tensor already exist in model
std::cout<<"\n\nTensor \"0weight\" already exist: "<<std::boolalpha<<model.CheckIfTensorAlreadyExist("0weight")<<"\n\n";
std::vector<size_t> tensorShape = model.GetTensorShape("0weight");
std::cout<<"Shape of tensor \"0weight\": ";
for(auto& it:tensorShape){
std::cout<<it<<",";
}
std::cout<<"\n\nData type of tensor \"0weight\": ";
SOFIE::ETensorType tensorType = model.GetTensorType("0weight");
std::cout<<SOFIE::ConvertTypeToString(tensorType);
//Printing generated inference code
std::cout<<"\n\n";
model.PrintGenerated();
}
R__EXTERN TSystem * gSystem
Definition TSystem.h:555
const ETensorType & GetTensorType(std::string name)
Definition RModel.cxx:91
bool CheckIfTensorAlreadyExist(std::string tensor_name)
Definition RModel.cxx:116
void OutputGenerated(std::string filename="", bool append=false)
Definition RModel.cxx:920
const std::vector< size_t > & GetTensorShape(std::string name)
Definition RModel.cxx:56
void Generate(std::underlying_type_t< Options > options, int batchSize=-1, long pos=0)
Definition RModel.cxx:529
static void PyInitialize()
Initialize Python interpreter.
Class supporting a collection of lines with C++ code.
Definition TMacro.h:31
Basic string class.
Definition TString.h:139
virtual Int_t Exec(const char *shellcmd)
Execute a command.
Definition TSystem.cxx:653
TString Python_Executable()
Function to find current Python executable used by ROOT If "Python3" is installed,...
TMacro m
TMacro object whose AddLine()/SaveSource() calls write the Python script make_pytorch_model.py.
Model requires following inputs:
Fully Specified Tensor name: input1 type: float shape: [2,32]
Model initialized the following tensors:
Tensor name: "2bias" type: float shape: [8]
Tensor name: "0weight" type: float shape: [16,32]
Tensor name: "2weight" type: float shape: [8,16]
Tensor name: "0bias" type: float shape: [16]
Model specify the following intermediate tensors:
Tensor name: "result3" type: float shape: [2,8]
Tensor name: "2biasbcast" type: float shape: [2,8]
Tensor name: "input2" type: float shape: [2,8]
Tensor name: "input0" type: float shape: [2,16]
Tensor name: "result" type: float shape: [2,16]
Tensor name: "0biasbcast" type: float shape: [2,16]
Tensor "0weight" already exist: true
Shape of tensor "0weight": 16,32,
Data type of tensor "0weight": float
//Code generated automatically by TMVA for Inference of Model file [PyTorchModel.pt] at [Wed Apr 17 10:21:55 2024]
#ifndef ROOT_TMVA_SOFIE_PYTORCHMODEL
#define ROOT_TMVA_SOFIE_PYTORCHMODEL
#include <algorithm>
#include <vector>
#include "TMVA/SOFIE_common.hxx"
#include <fstream>
namespace TMVA_SOFIE_PyTorchModel{
// Declarations of the Fortran BLAS routines called by the generated operators.
// Only declarations live in this header; the definitions must come from a BLAS
// library linked into the final program. (sgemv_ is declared but not used below.)
namespace BLAS{
extern "C" void sgemv_(const char * trans, const int * m, const int * n, const float * alpha, const float * A,
const int * lda, const float * X, const int * incx, const float * beta, const float * Y, const int * incy);
extern "C" void sgemm_(const char * transa, const char * transb, const int * m, const int * n, const int * k,
const float * alpha, const float * A, const int * lda, const float * B, const int * ldb,
const float * beta, float * C, const int * ldc);
}//BLAS
// Inference session for the parsed model: Gemm(0) -> ReLU -> Gemm(2) -> ReLU.
// The session owns every weight and intermediate buffer. The constructor loads the
// weights from a text file and prepares the broadcast biases; infer() runs the forward pass.
// NOTE(review): each tensor_* raw pointer caches the .data() of the vector declared just
// before it. The implicitly generated copy constructor copies those pointers verbatim, so a
// copied Session would still read/write the ORIGINAL object's buffers (and dangle once it is
// destroyed). Avoid copying Session objects.
struct Session {
// --- initialized tensors (weights and biases); sizes are the element counts of the parsed shapes
// 2bias: [8]
std::vector<float> fTensor_2bias = std::vector<float>(8);
float * tensor_2bias = fTensor_2bias.data();
// 0weight: [16,32] -> 512 elements
std::vector<float> fTensor_0weight = std::vector<float>(512);
float * tensor_0weight = fTensor_0weight.data();
// 2weight: [8,16] -> 128 elements
std::vector<float> fTensor_2weight = std::vector<float>(128);
float * tensor_2weight = fTensor_2weight.data();
// 0bias: [16]
std::vector<float> fTensor_0bias = std::vector<float>(16);
float * tensor_0bias = fTensor_0bias.data();
//--- declare and allocate the intermediate tensors
// result3: [2,8] final output (after the second ReLU)
std::vector<float> fTensor_result3 = std::vector<float>(16);
float * tensor_result3 = fTensor_result3.data();
// 2biasbcast: 2bias broadcast to [2,8]
std::vector<float> fTensor_2biasbcast = std::vector<float>(16);
float * tensor_2biasbcast = fTensor_2biasbcast.data();
// input2: [2,8] output of the second Gemm
std::vector<float> fTensor_input2 = std::vector<float>(16);
float * tensor_input2 = fTensor_input2.data();
// input0: [2,16] output of the first Gemm
std::vector<float> fTensor_input0 = std::vector<float>(32);
float * tensor_input0 = fTensor_input0.data();
// result: [2,16] output of the first ReLU
std::vector<float> fTensor_result = std::vector<float>(32);
float * tensor_result = fTensor_result.data();
// 0biasbcast: 0bias broadcast to [2,16]
std::vector<float> fTensor_0biasbcast = std::vector<float>(32);
float * tensor_0biasbcast = fTensor_0biasbcast.data();
// Loads the weights from `filename` (default "PyTorchModel.dat").
// Expected text format, repeated for each tensor in this fixed order
// (tensor_2bias, tensor_0weight, tensor_2weight, tensor_0bias):
//   <tensor name> <element count> <value_0> ... <value_{count-1}>
// with all tokens whitespace-separated.
// Throws std::runtime_error if the file cannot be opened, or if a tensor name or
// element count differs from what this header expects.
// NOTE(review): the float reads themselves are not checked for failure; for example, a
// file truncated inside the last tensor's values goes unnoticed.
Session(std::string filename ="PyTorchModel.dat") {
//--- reading weights from file
std::ifstream f;
f.open(filename);
if (!f.is_open()) {
throw std::runtime_error("tmva-sofie failed to open file for input weights");
}
std::string tensor_name;
size_t length;
f >> tensor_name >> length;
if (tensor_name != "tensor_2bias" ) {
std::string err_msg = "TMVA-SOFIE failed to read the correct tensor name; expected name is tensor_2bias , read " + tensor_name;
throw std::runtime_error(err_msg);
}
if (length != 8) {
std::string err_msg = "TMVA-SOFIE failed to read the correct tensor size; expected size is 8 , read " + std::to_string(length) ;
throw std::runtime_error(err_msg);
}
for (size_t i = 0; i < length; ++i)
f >> tensor_2bias[i];
f >> tensor_name >> length;
if (tensor_name != "tensor_0weight" ) {
std::string err_msg = "TMVA-SOFIE failed to read the correct tensor name; expected name is tensor_0weight , read " + tensor_name;
throw std::runtime_error(err_msg);
}
if (length != 512) {
std::string err_msg = "TMVA-SOFIE failed to read the correct tensor size; expected size is 512 , read " + std::to_string(length) ;
throw std::runtime_error(err_msg);
}
for (size_t i = 0; i < length; ++i)
f >> tensor_0weight[i];
f >> tensor_name >> length;
if (tensor_name != "tensor_2weight" ) {
std::string err_msg = "TMVA-SOFIE failed to read the correct tensor name; expected name is tensor_2weight , read " + tensor_name;
throw std::runtime_error(err_msg);
}
if (length != 128) {
std::string err_msg = "TMVA-SOFIE failed to read the correct tensor size; expected size is 128 , read " + std::to_string(length) ;
throw std::runtime_error(err_msg);
}
for (size_t i = 0; i < length; ++i)
f >> tensor_2weight[i];
f >> tensor_name >> length;
if (tensor_name != "tensor_0bias" ) {
std::string err_msg = "TMVA-SOFIE failed to read the correct tensor name; expected name is tensor_0bias , read " + tensor_name;
throw std::runtime_error(err_msg);
}
if (length != 16) {
std::string err_msg = "TMVA-SOFIE failed to read the correct tensor size; expected size is 16 , read " + std::to_string(length) ;
throw std::runtime_error(err_msg);
}
for (size_t i = 0; i < length; ++i)
f >> tensor_0bias[i];
f.close();
//---- allocate the intermediate dynamic tensors
//--- broadcast bias tensor 0bias for Gemm op
// UnidirectionalBroadcast returns a new[]-allocated [2,16] array; copy it into the
// member buffer, then release it with delete[].
{
float * data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_0bias,{ 16 }, { 2 , 16 });
std::copy(data, data + 32, tensor_0biasbcast);
delete [] data;
}
//--- broadcast bias tensor 2bias for Gemm op
// Same pattern as above, for the [2,8] bias of the second Gemm.
{
float * data = TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<float>(tensor_2bias,{ 8 }, { 2 , 8 });
std::copy(data, data + 16, tensor_2biasbcast);
delete [] data;
}
}
// Runs the forward pass on one input batch.
// tensor_input1: assumed to point to 2*32 = 64 floats laid out row-major as [2,32]
//   (matching op_0_m = 2 and op_0_k = 32 below); the caller must guarantee the size.
// Returns a copy of the [2,8] output (16 floats).
// NOTE(review): infer() writes the session's member intermediate buffers, so concurrent
// calls on the same Session would race.
std::vector<float> infer(float* tensor_input1){
//--------- Gemm
// Row-major input0[2,16] = input1[2,32] * 0weight[16,32]^T + 0bias.
// Column-major BLAS produces the row-major result by computing the transposed product:
// m/n and the A/B operands are swapped, and the weight is passed with 't'. The broadcast
// bias is copied into the output first, so beta = 1 adds it.
char op_0_transA = 'n';
char op_0_transB = 't';
int op_0_m = 2;
int op_0_n = 16;
int op_0_k = 32;
float op_0_alpha = 1;
float op_0_beta = 1;
int op_0_lda = 32;
int op_0_ldb = 32;
std::copy(tensor_0biasbcast, tensor_0biasbcast + 32, tensor_input0);
BLAS::sgemm_(&op_0_transB, &op_0_transA, &op_0_n, &op_0_m, &op_0_k, &op_0_alpha, tensor_0weight, &op_0_ldb, tensor_input1, &op_0_lda, &op_0_beta, tensor_input0, &op_0_n);
//------ RELU
// element-wise max(x, 0) over all 2*16 elements
for (int id = 0; id < 32 ; id++){
tensor_result[id] = ((tensor_input0[id] > 0 )? tensor_input0[id] : 0);
}
//--------- Gemm
// Row-major input2[2,8] = result[2,16] * 2weight[8,16]^T + 2bias (same BLAS trick as above).
char op_2_transA = 'n';
char op_2_transB = 't';
int op_2_m = 2;
int op_2_n = 8;
int op_2_k = 16;
float op_2_alpha = 1;
float op_2_beta = 1;
int op_2_lda = 16;
int op_2_ldb = 16;
std::copy(tensor_2biasbcast, tensor_2biasbcast + 16, tensor_input2);
BLAS::sgemm_(&op_2_transB, &op_2_transA, &op_2_n, &op_2_m, &op_2_k, &op_2_alpha, tensor_2weight, &op_2_ldb, tensor_result, &op_2_lda, &op_2_beta, tensor_input2, &op_2_n);
//------ RELU
// element-wise max(x, 0) over all 2*8 elements
for (int id = 0; id < 16 ; id++){
tensor_result3[id] = ((tensor_input2[id] > 0 )? tensor_input2[id] : 0);
}
// Returned by value: the caller gets its own copy of the output buffer.
return fTensor_result3;
}
};
} //TMVA_SOFIE_PyTorchModel
#endif // ROOT_TMVA_SOFIE_PYTORCHMODEL
Author
Sanjiban Sengupta

Definition in file TMVA_SOFIE_PyTorch.C.