Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModel_GNN.cxx
Go to the documentation of this file.
1#include <algorithm>
2#include <cctype>
3#include <fstream>
4#include <limits>
5
6#include "TMVA/RModel_GNN.hxx"
7#include "TMVA/RFunction.hxx"
8
9namespace TMVA {
10namespace Experimental {
11namespace SOFIE {
12
14 edges_update_block = std::move(graph_input_struct.edges_update_block);
15 nodes_update_block = std::move(graph_input_struct.nodes_update_block);
16 globals_update_block = std::move(graph_input_struct.globals_update_block);
17
18 edge_node_agg_block = std::move(graph_input_struct.edge_node_agg_block);
19 edge_global_agg_block = std::move(graph_input_struct.edge_global_agg_block);
20 node_global_agg_block = std::move(graph_input_struct.node_global_agg_block);
21
22 num_nodes = graph_input_struct.num_nodes;
23 num_edges = graph_input_struct.edges.size();
24 num_node_features = graph_input_struct.num_node_features;
25 num_edge_features = graph_input_struct.num_edge_features;
26 num_global_features = graph_input_struct.num_global_features;
27
29 fName = fFileName.substr(0, fFileName.rfind("."));
30
31 std::time_t ttime = std::time(0);
32 std::tm* gmt_time = std::gmtime(&ttime);
33 fParseTime = std::asctime(gmt_time);
34}
35
37 std::string hgname;
39
40 std::ofstream f;
41 f.open(fName+".dat");
42 f.close();
43
44 // Generating Infer function definition for Edge Update function
45 long next_pos;
46 //size_t block_size = num_edges;
47 fGC+="\n\nnamespace Edge_Update{\nstruct Session {\n";
48 // there are 4 input tensors for edge updates: {edges, receiver nodes, sender nodes, globals }
49 std::vector<std::vector<Dim>> update_input_edges(4);
54 edges_update_block->Initialize();
55 edges_update_block->AddInputTensors(update_input_edges);
56 fGC+=edges_update_block->GenerateModel(fName);
57 next_pos = edges_update_block->GetFunctionBlock()->WriteInitializedTensorsToFile(fName+".dat");
58 fGC+="};\n}\n";
59
60 // the number of output edges features can be smaller, so we need to correct here
62 auto edges_update_output_shape = edges_update_block->GetFunctionBlock()->GetDynamicTensorShape(edges_update_block->GetFunctionBlock()->GetOutputTensorNames()[0]);
65 }
66
67 fGC+="\n\nnamespace Node_Update{\nstruct Session {\n";
68 // Generating Infer function definition for Node Update function
69 // num_node_features is the output one
70
71 //block_size = num_nodes;
72 // there are 3 input tensors for node updates: {received edges, nodes, globals }
73 std::vector<std::vector<Dim>> update_input_nodes(3);
77 nodes_update_block->Initialize();
78 nodes_update_block->AddInputTensors(update_input_nodes);
79 fGC+=nodes_update_block->GenerateModel(fName,next_pos);
80 next_pos = nodes_update_block->GetFunctionBlock()->WriteInitializedTensorsToFile(fName+".dat");
81 fGC+="};\n}\n";
82
83 // we need to correct the output number of node features
85 auto nodes_update_output_shape = nodes_update_block->GetFunctionBlock()->GetDynamicTensorShape(nodes_update_block->GetFunctionBlock()->GetOutputTensorNames()[0]);
88 }
89
90 fGC+="\n\nnamespace Global_Update{\nstruct Session {\n";
91 // Generating Infer function definition for Global Update function
92 std::vector<std::vector<std::size_t>> update_input_globals = {{1, num_edge_features},{1, num_node_features},{1, num_global_features}};
93 globals_update_block->Initialize();
95 fGC+=globals_update_block->GenerateModel(fName,next_pos);
96 next_pos = globals_update_block->GetFunctionBlock()->WriteInitializedTensorsToFile(fName+".dat");
97 fGC+="};\n}\n";
98
99 // correct for difference in global size (check shape[1] of output of the globals update)
101 if(globals_update_block->GetFunctionBlock()->GetTensorShape(globals_update_block->GetFunctionBlock()->GetOutputTensorNames()[0])[1] != num_global_features) {
102 num_global_features = globals_update_block->GetFunctionBlock()->GetTensorShape(globals_update_block->GetFunctionBlock()->GetOutputTensorNames()[0])[1];
103 }
104
105 fGC+=edge_node_agg_block->GenerateModel();
106
107 if(edge_node_agg_block->GetFunctionType() != edge_global_agg_block->GetFunctionType()) {
108 fGC+=edge_global_agg_block->GenerateModel();
109 }
110 if((edge_node_agg_block->GetFunctionType() != node_global_agg_block->GetFunctionType()) && (edge_global_agg_block->GetFunctionType() != node_global_agg_block->GetFunctionType())) {
111 fGC+=node_global_agg_block->GenerateModel();
112 }
113 fGC+="\n\n";
114
115 // computing inplace on input graph
116 fGC += "struct Session {\n";
117 fGC += "\n// Instantiating session objects for graph components\n";
118 fGC += "Edge_Update::Session edge_update;\n";
119 fGC += "Node_Update::Session node_update;\n";
120 fGC += "Global_Update::Session global_update;\n\n";
121
122 std::string e_num = std::to_string(num_edges);
123 std::string n_num = std::to_string(num_nodes);
124 std::string e_size_input = std::to_string(num_edge_features_input);
125 std::string n_size_input = std::to_string(num_node_features_input);
126 std::string g_size_input = std::to_string(num_global_features_input);
127 std::string e_size = std::to_string(num_edge_features);
128 std::string n_size = std::to_string(num_node_features);
129 std::string g_size = std::to_string(num_global_features);
130
131 // create temp vector for edge and node updates
132 fGC += "std::vector<float> fEdgeUpdates = std::vector<float>(" + e_num + "*" + e_size + ");\n";
133 fGC += "\n\nstd::vector<float> fNodeUpdates = std::vector<float>(" + n_num + "*" + n_size + ");\n";
134
135 fGC += "\n// input vectors for edge update\n";
136 fGC += "std::vector<float> fEdgeInputs = std::vector<float>(" + e_num + "*" + e_size_input + ");\n";
137 fGC += "std::vector<float> fRecNodeInputs = std::vector<float>(" + e_num + "*" + n_size_input + ");\n";
138 fGC += "std::vector<float> fSndNodeInputs = std::vector<float>(" + e_num + "*" + n_size_input + ");\n";
139 fGC += "std::vector<float> fGlobInputs = std::vector<float>(" + e_num + "*" + g_size_input + ");\n\n";
140
141 fGC += "\n// input vectors for node update\n";
142 fGC += "std::vector<float> fNodeInputs = std::vector<float>(" + n_num + "*" + n_size_input + ");\n";
143 fGC += "std::vector<float> fNodeEdgeAggregate = std::vector<float>(" + n_num + "*" + n_size_input + ", 0);\n";
144 fGC += "std::vector<float> fNodeAggregateTemp;\n";
145
146 fGC += "\nvoid infer(TMVA::Experimental::SOFIE::GNN_Data& input_graph){\n";
147
148 // computing updated edge attributes
149 fGC += "\n// --- Edge Update ---\n";
150 fGC += "size_t n_edges = input_graph.edge_data.GetShape()[0];\n";
151 fGC += "if (n_edges > " + e_num + ")\n";
152 fGC += " throw std::runtime_error(\"Number of input edges larger than " + e_num + "\" );\n\n";
153 fGC += "auto receivers = input_graph.edge_index.GetData();\n";
154 fGC += "auto senders = input_graph.edge_index.GetData() + n_edges;\n";
155
156 fGC += "for (size_t k = 0; k < n_edges; k++) { \n";
157 fGC += " std::copy(input_graph.edge_data.GetData() + k * " + e_size_input +
158 ", input_graph.edge_data.GetData() + (k + 1) * " + e_size_input +
159 ", fEdgeInputs.begin() + k * " + e_size_input + ");\n";
160 fGC += " std::copy(input_graph.node_data.GetData() + receivers[k] * " + n_size_input +
161 ", input_graph.node_data.GetData() + (receivers[k] + 1) * " + n_size_input +
162 ", fRecNodeInputs.begin() + k * " + n_size_input + ");\n";
163 fGC += " std::copy(input_graph.node_data.GetData() + senders[k] * " + n_size_input +
164 ", input_graph.node_data.GetData() + (senders[k] + 1) * " + n_size_input +
165 ", fSndNodeInputs.begin() + k * " + n_size_input + ");\n";
166 fGC += " std::copy(input_graph.global_data.GetData()";
167 fGC += ", input_graph.global_data.GetData() + " + g_size_input +
168 ", fGlobInputs.begin() + k * " + g_size_input + ");\n";
169 fGC += "}\n";
170
171 fGC += "fEdgeUpdates = " + edges_update_block->Generate({"n_edges","fEdgeInputs.data(), fRecNodeInputs.data(), fSndNodeInputs.data(), fGlobInputs.data()"}) + "\n";
172
174 fGC += "\n// resize edge graph data since output feature size is not equal to input size\n";
175 fGC+="input_graph.edge_data = input_graph.edge_data.Resize({n_edges, "+e_size+"});\n";
176 }
177 // copy output
178 fGC += "\nfor (size_t k = 0; k < n_edges; k++) { \n";
179 fGC += " std::copy(fEdgeUpdates.begin()+ k * " + e_size + ", fEdgeUpdates.begin()+ (k+1) * " + e_size +
180 ",input_graph.edge_data.GetData() + k * " + e_size + ");\n";
181 fGC += "}\n";
182 fGC += "\n";
183
184 fGC += "\n\n// --- Node Update ---\n";
185 fGC += "size_t n_nodes = input_graph.node_data.GetShape()[0];\n";
186 // computing updated node attributes
187 fGC += "for (size_t k = 0; k < n_nodes; k++) { \n";
188 fGC += " std::copy(input_graph.node_data.GetData() + k * " + n_size_input +
189 ", input_graph.node_data.GetData() + (k + 1) * " + n_size_input +
190 ", fNodeInputs.begin() + k * " + n_size_input + ");\n";
191 fGC += "}\n";
192 // reset initial aggregate edge vector to zero
193 fGC += "\nstd::fill(fNodeEdgeAggregate.begin(), fNodeEdgeAggregate.end(), 0.);\n";
194 // fGlobInputs is size { n_edges, n_globals}. It needs to be here { n_nodes, n_globals}
195 // if number of nodes is larger than edges we need to resize it and copy values
196
197 fGC += "\n// resize global vector feature to number of nodes if needed\n";
198 fGC += "if (n_nodes > n_edges) {\n";
199 fGC += " fGlobInputs.resize( n_nodes * " + std::to_string(num_global_features_input) + ");\n";
200 fGC += " for (size_t k = n_edges; k < n_nodes; k++)\n";
201 fGC += " std::copy(fGlobInputs.begin(), fGlobInputs.begin() + " + g_size_input +
202 " , fGlobInputs.begin() + k * " + g_size_input + ");\n";
203 fGC += "}\n";
204
205 // loop on nodes and aggregate incoming edges
206 fGC += "\n// aggregate edges going to a node\n";
207 fGC += "for (size_t j = 0; j < n_nodes; j++) {\n";
208 // approximate number of receivers/node to allocate vector
209 fGC += " std::vector<float *> edgesData; edgesData.reserve( int(n_edges/n_nodes) +1);\n";
210 // loop on edges
211 fGC += " for (size_t k = 0; k < n_edges; k++) {\n";
212 fGC += " if (receivers[k] == j) \n";
213 fGC += " edgesData.emplace_back(input_graph.edge_data.GetData() + k * " + e_size + ");\n";
214 fGC += " }\n";
215 fGC += " fNodeAggregateTemp = " + edge_node_agg_block->Generate(num_edge_features, "edgesData") + ";\n";
216 fGC += " std::copy(fNodeAggregateTemp.begin(), fNodeAggregateTemp.end(), fNodeEdgeAggregate.begin() + " +
217 e_size + " * j);\n";
218 fGC += "}\n"; // end node loop
219
220
221 fGC+="\n";
222 fGC+="fNodeUpdates = ";
223 fGC+=nodes_update_block->Generate({"n_nodes","fNodeEdgeAggregate.data()","fNodeInputs.data()","fGlobInputs.data()"}); // computing updated node attributes
224 fGC+="\n";
225
227 fGC += "\n// resize node graph data since output feature size is not equal to input size\n";
228 fGC+="input_graph.node_data = input_graph.node_data.Resize({n_nodes, " + n_size + "});\n";
229 }
230 // copy output
231 fGC += "\nfor (size_t k = 0; k < n_nodes; k++) { \n";
232 fGC += " std::copy(fNodeUpdates.begin()+ k * " + n_size + ", fNodeUpdates.begin() + (k+1) * " + n_size +
233 ",input_graph.node_data.GetData() + k * " + n_size+ ");\n";
234 fGC += "}\n";
235 fGC += "\n";
236
237 // aggregating edges & nodes for global update
238 fGC += "std::vector<float *> allEdgesData; allEdgesData.reserve(n_edges);\n";
239 fGC += "for (size_t k = 0; k < n_edges; k++) {\n";
240 fGC += " allEdgesData.emplace_back(input_graph.edge_data.GetData() + k * " + e_size + ");\n";
241 fGC += "}\n";
242 fGC += "std::vector<float *> allNodesData; allNodesData.reserve(n_nodes);\n";
243 fGC += "for (size_t k = 0; k < n_nodes; k++) {\n";
244 fGC += " allNodesData.emplace_back(input_graph.node_data.GetData() + k * " + n_size + ");\n";
245 fGC += "}\n";
246
247
248 fGC += "\n// --- Global Update ---\n";
249 fGC+="std::vector<float> Edge_Global_Aggregate = ";
250 fGC+=edge_global_agg_block->Generate(num_edge_features, "allEdgesData"); // aggregating edge attributes globally
251 fGC+=";\n";
252
253 fGC+="std::vector<float> Node_Global_Aggregate = ";
254 fGC+=node_global_agg_block->Generate(num_node_features, "allNodesData"); // aggregating node attributes globally
255 fGC+=";\n";
256
257 // computing updated global attributes
258 fGC += "std::vector<float> Global_Data = ";
259 fGC += globals_update_block->Generate({"Edge_Global_Aggregate.data()","Node_Global_Aggregate.data()", "input_graph.global_data.GetData()"});
261 fGC += "\n// resize global graph data since output feature size is not equal to input size\n";
262 fGC+="input_graph.global_data = input_graph.global_data.Resize({"+g_size+"});\n";
263 }
264 fGC += "\nstd::copy(Global_Data.begin(), Global_Data.end(), input_graph.global_data.GetData());";
265 fGC+="\n}\n";
266 fGC+="};\n";
267
268 fGC += ("} //TMVA_SOFIE_" + fName + "\n");
269 fGC += "\n#endif // TMVA_SOFIE_" + hgname + "\n";
270}
271
272}//SOFIE
273}//Experimental
274}//TMVA
#define f(i)
Definition RSha256.hxx:104
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
void GenerateHeaderInfo(std::string &hgname)
std::unique_ptr< RFunction_Aggregate > node_global_agg_block
std::unique_ptr< RFunction_Update > globals_update_block
std::unique_ptr< RFunction_Update > edges_update_block
std::unique_ptr< RFunction_Aggregate > edge_global_agg_block
std::unique_ptr< RFunction_Aggregate > edge_node_agg_block
RModel_GNN(GNN_Init &graph_input_struct)
std::unique_ptr< RFunction_Update > nodes_update_block
create variable transformations