ROOT Reference Guide
 
Loading...
Searching...
No Matches
RFunction.cxx
Go to the documentation of this file.
1#include "TMVA/RModel.hxx"
2#include "TMVA/RFunction.hxx"
3
4
5namespace TMVA {
6namespace Experimental {
7namespace SOFIE {
8
9
10
12 switch(target) {
14 fFuncName = "edge_update";
15 break;
16 }
18 fFuncName = "node_update";
19 break;
20 }
22 fFuncName = "global_update";
23 break;
24 }
25 default:
26 throw std::runtime_error("Invalid target for Update function");
27 }
29 function_block = std::make_unique<RModel>(fFuncName);
30
33 fInputTensors = {"edge","receiver","sender","global"};
35 fInputTensors = {"edge","node","global"};
36 }
37
40 fInputTensors = {"edge"};
41 } else if(fTarget == FunctionTarget::NODES) {
42 fInputTensors = {"node"};
43 } else {
44 fInputTensors = {"global"};
45 }
46 }
47}
48
49void RFunction_Update::AddInputTensors(const std::vector<std::vector<std::size_t>>& fInputShape) {
50 for(long unsigned int i=0; i<fInputShape.size(); ++i) {
51 function_block->AddInputTensorInfo(fInputTensors[i],ETensorType::FLOAT, fInputShape[i]);
52 function_block->AddInputTensorName(fInputTensors[i]);
53 }
54}
55
56std::string RFunction_Update::GenerateModel(const std::string& filename, long read_pos, long block_size) {
57 function_block->SetFilename(filename);
58 // use batch size as block size in RModel::generate
59 function_block->Generate(Options::kGNNComponent,block_size,read_pos);
60 std::string modelGenerationString;
61 modelGenerationString = "\n//--------- GNN_Update_Function---"+fFuncName+"\n"+function_block->ReturnGenerated();
62 return modelGenerationString;
63}
64
65std::string RFunction_Update::Generate(const std::vector<std::string>& inputPtrs) {
66 std::string inferFunc = fFuncName+".infer(";
67 for(auto&it : inputPtrs) {
68 inferFunc+=it;
69 inferFunc+=",";
70 }
71 inferFunc.pop_back();
72 inferFunc+=");";
73 return inferFunc;
74}
75
76// passing as input a vector of strings for each input tensor
77std::string RFunction_Aggregate::Generate(std::size_t num_features, const std::vector<std::string>& inputTensors) {
78 std::string inferFunc = fFuncName+"("+std::to_string(num_features)+",{";
79 for(auto&it : inputTensors) {
80 inferFunc+=it;
81 inferFunc+=",";
82 }
83 inferFunc.pop_back();
84 inferFunc+="});";
85 return inferFunc;
86}
87
88// here passing directly the name of the vector containing the input tensor
89std::string RFunction_Aggregate::Generate(std::size_t num_features, const std::string & inputTensors) {
90 std::string inferFunc = fFuncName + "(" +std::to_string(num_features) + "," + inputTensors + ")";
91 return inferFunc;
92}
93
94
95
96
97}
98}
99}
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char filename
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t Float_t Float_t Int_t Int_t UInt_t UInt_t Rectangle_t Int_t Int_t Window_t TString Int_t GCValues_t GetPrimarySelectionOwner GetDisplay GetScreen GetColormap GetNativeEvent const char const char dpyName wid window const char font_name cursor keysym reg const char only_if_exist regb h Point_t winding char text const char depth char const char Int_t count const char ColorStruct_t color const char Pixmap_t Pixmap_t PictureAttributes_t attr const char char ret_data h unsigned char height h Atom_t Int_t ULong_t ULong_t unsigned char prop_list Atom_t Atom_t target
std::string Generate(std::size_t num_features, const std::vector< std::string > &inputTensors)
Definition RFunction.cxx:77
void AddInputTensors(const std::vector< std::vector< std::size_t > > &fInputShape)
Definition RFunction.cxx:49
std::shared_ptr< RModel > function_block
Definition RFunction.hxx:35
std::string Generate(const std::vector< std::string > &inputPtrs)
Definition RFunction.cxx:65
std::string GenerateModel(const std::string &filename, long read_pos=0, long block_size=1)
Definition RFunction.cxx:56
std::vector< std::string > fInputTensors
Definition RFunction.hxx:38
create variable transformations
static int gType