Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModel_GNN.cxx
Go to the documentation of this file.
1#include <algorithm>
2#include <cctype>
3#include <fstream>
4#include <limits>
5
6#include "TMVA/RModel_GNN.hxx"
7#include "TMVA/RFunction.hxx"
8
9namespace TMVA {
10namespace Experimental {
11namespace SOFIE {
12
14 edges_update_block = std::move(other.edges_update_block);
15 nodes_update_block = std::move(other.nodes_update_block);
16 globals_update_block = std::move(other.globals_update_block);
17
18 edge_node_agg_block = std::move(other.edge_node_agg_block);
19 edge_global_agg_block = std::move(other.edge_global_agg_block);
20 node_global_agg_block = std::move(other.node_global_agg_block);
21
22 num_nodes = std::move(other.num_nodes);
23 num_edges = std::move(other.num_edges);
24 senders = std::move(other.senders);
25 receivers = std::move(other.receivers);
26
27 fName = std::move(other.fName);
28 fFileName = std::move(other.fFileName);
29 fParseTime = std::move(other.fParseTime);
30}
31
33 edges_update_block = std::move(other.edges_update_block);
34 nodes_update_block = std::move(other.nodes_update_block);
35 globals_update_block = std::move(other.globals_update_block);
36
37 edge_node_agg_block = std::move(other.edge_node_agg_block);
38 edge_global_agg_block = std::move(other.edge_global_agg_block);
39 node_global_agg_block = std::move(other.node_global_agg_block);
40
41 num_nodes = std::move(other.num_nodes);
42 num_edges = std::move(other.num_edges);
43 senders = std::move(other.senders);
44 receivers = std::move(other.receivers);
45
46 fName = std::move(other.fName);
47 fFileName = std::move(other.fFileName);
48 fParseTime = std::move(other.fParseTime);
49
50 return *this;
51}
52
53RModel_GNN::RModel_GNN(GNN_Init& graph_input_struct) {
54 edges_update_block = std::move(graph_input_struct.edges_update_block);
55 nodes_update_block = std::move(graph_input_struct.nodes_update_block);
56 globals_update_block = std::move(graph_input_struct.globals_update_block);
57
58 edge_node_agg_block = std::move(graph_input_struct.edge_node_agg_block);
59 edge_global_agg_block = std::move(graph_input_struct.edge_global_agg_block);
60 node_global_agg_block = std::move(graph_input_struct.node_global_agg_block);
61
62 num_nodes = graph_input_struct.num_nodes;
63 num_edges = graph_input_struct.edges.size();
64 num_node_features = graph_input_struct.num_node_features;
65 num_edge_features = graph_input_struct.num_edge_features;
66 num_global_features = graph_input_struct.num_global_features;
67 for(auto& it:graph_input_struct.edges) {
68 receivers.emplace_back(it.first);
69 senders.emplace_back(it.second);
70 }
71 fFileName = graph_input_struct.filename;
72 fName = fFileName.substr(0, fFileName.rfind("."));
73
74 std::time_t ttime = std::time(0);
75 std::tm* gmt_time = std::gmtime(&ttime);
76 fParseTime = std::asctime(gmt_time);
77}
78
80 std::string hgname;
81 GenerateHeaderInfo(hgname);
82
83 std::ofstream f;
84 f.open(fName+".dat");
85 f.close();
86
87 // Generating Infer function definition for Edge Update function
88 long next_pos;
89 size_t block_size = num_edges;
90 fGC+="\n\nnamespace Edge_Update{\nstruct Session {\n";
91 std::vector<std::vector<std::size_t>> Update_Input_edges = {{block_size, num_edge_features},{block_size, num_node_features},{block_size, num_node_features},{block_size, num_global_features}};
92 edges_update_block->Initialize();
93 edges_update_block->AddInputTensors(Update_Input_edges);
94 fGC+=edges_update_block->GenerateModel(fName);
95 next_pos = edges_update_block->GetFunctionBlock()->WriteInitializedTensorsToFile(fName+".dat");
96 fGC+="};\n}\n";
97
98 // the number of output edges features can be smaller, so we need to correct here
99 auto num_edge_features_input = num_edge_features;
100 if(edges_update_block->GetFunctionBlock()->GetTensorShape(edges_update_block->GetFunctionBlock()->GetOutputTensorNames()[0])[1] != num_edge_features) {
101 num_edge_features = edges_update_block->GetFunctionBlock()->GetTensorShape(edges_update_block->GetFunctionBlock()->GetOutputTensorNames()[0])[1];
102 }
103
104 fGC+="\n\nnamespace Node_Update{\nstruct Session {\n";
105 // Generating Infer function definition for Node Update function
106 // num_node_features is the output one
107
108 block_size = num_nodes;
109 std::vector<std::vector<std::size_t>> Update_Input_nodes = {{block_size, num_edge_features},{block_size, num_node_features},{block_size, num_global_features}};
110 nodes_update_block->Initialize();
111 nodes_update_block->AddInputTensors(Update_Input_nodes);
112 fGC+=nodes_update_block->GenerateModel(fName,next_pos);
113 next_pos = nodes_update_block->GetFunctionBlock()->WriteInitializedTensorsToFile(fName+".dat");
114 fGC+="};\n}\n";
115
116 // we need to correct the output number of node features
117 auto num_node_features_input = num_node_features;
118 if(nodes_update_block->GetFunctionBlock()->GetTensorShape(nodes_update_block->GetFunctionBlock()->GetOutputTensorNames()[0])[1] != num_node_features) {
119 num_node_features = nodes_update_block->GetFunctionBlock()->GetTensorShape(nodes_update_block->GetFunctionBlock()->GetOutputTensorNames()[0])[1];
120 }
121
122 fGC+="\n\nnamespace Global_Update{\nstruct Session {\n";
123 // Generating Infer function definition for Global Update function
124 std::vector<std::vector<std::size_t>> Update_Input_globals = {{1, num_edge_features},{1, num_node_features},{1, num_global_features}};
125 globals_update_block->Initialize();
126 globals_update_block->AddInputTensors(Update_Input_globals);
127 fGC+=globals_update_block->GenerateModel(fName,next_pos);
128 next_pos = globals_update_block->GetFunctionBlock()->WriteInitializedTensorsToFile(fName+".dat");
129 fGC+="};\n}\n";
130
131 // correct for difference in global size (check shape[1] of output og globals update)
132 auto num_global_features_input = num_global_features;
133 if(globals_update_block->GetFunctionBlock()->GetTensorShape(globals_update_block->GetFunctionBlock()->GetOutputTensorNames()[0])[1] != num_global_features) {
134 num_global_features = globals_update_block->GetFunctionBlock()->GetTensorShape(globals_update_block->GetFunctionBlock()->GetOutputTensorNames()[0])[1];
135 }
136
137 fGC+=edge_node_agg_block->GenerateModel();
138
139 if(edge_node_agg_block->GetFunctionType() != edge_global_agg_block->GetFunctionType()) {
140 fGC+=edge_global_agg_block->GenerateModel();
141 }
142 if((edge_node_agg_block->GetFunctionType() != node_global_agg_block->GetFunctionType()) && (edge_global_agg_block->GetFunctionType() != node_global_agg_block->GetFunctionType())) {
143 fGC+=node_global_agg_block->GenerateModel();
144 }
145 fGC+="\n\n";
146
147 // computing inplace on input graph
148 fGC += "struct Session {\n";
149 fGC += "\n// Instantiating session objects for graph components\n";
150 fGC += "Edge_Update::Session edge_update;\n";
151 fGC += "Node_Update::Session node_update;\n";
152 fGC += "Global_Update::Session global_update;\n\n";
153
154 std::string e_num = std::to_string(num_edges);
155 std::string n_num = std::to_string(num_nodes);
156 std::string e_size_input = std::to_string(num_edge_features_input);
157 std::string n_size_input = std::to_string(num_node_features_input);
158 std::string g_size_input = std::to_string(num_global_features_input);
159 std::string e_size = std::to_string(num_edge_features);
160 std::string n_size = std::to_string(num_node_features);
161 std::string g_size = std::to_string(num_global_features);
162
163 // create temp vector for edge and node updates
164 fGC += "std::vector<float> fEdgeUpdates = std::vector<float>(" + e_num + "*" + e_size + ");\n";
165 fGC += "\n\nstd::vector<float> fNodeUpdates = std::vector<float>(" + n_num + "*" + n_size + ");\n";
166
167 fGC += "\n// input vectors for edge update\n";
168 fGC += "std::vector<float> fEdgeInputs = std::vector<float>(" + e_num + "*" + e_size_input + ");\n";
169 fGC += "std::vector<float> fRecNodeInputs = std::vector<float>(" + e_num + "*" + n_size_input + ");\n";
170 fGC += "std::vector<float> fSndNodeInputs = std::vector<float>(" + e_num + "*" + n_size_input + ");\n";
171 fGC += "std::vector<float> fGlobInputs = std::vector<float>(" + e_num + "*" + g_size_input + ");\n\n";
172
173 fGC += "\n// input vectors for node update\n";
174 fGC += "std::vector<float> fNodeInputs = std::vector<float>(" + n_num + "*" + n_size_input + ");\n";
175 fGC += "std::vector<float> fNodeEdgeAggregate = std::vector<float>(" + n_num + "*" + n_size_input + ", 0);\n";
176 fGC += "std::vector<float> fNodeAggregateTemp;\n";
177
178 fGC += "\nvoid infer(TMVA::Experimental::SOFIE::GNN_Data& input_graph){\n";
179
180 // computing updated edge attributes
181 fGC += "\n// --- Edge Update ---\n";
182 fGC += "for (int k = 0; k < " + e_num + "; k++) { \n";
183 fGC += " std::copy(input_graph.edge_data.GetData() + k * " + e_size_input +
184 ", input_graph.edge_data.GetData() + (k + 1) * " + e_size_input +
185 ", fEdgeInputs.begin() + k * " + e_size_input + ");\n";
186 fGC += " std::copy(input_graph.node_data.GetData() + input_graph.receivers[k] * " + n_size_input +
187 ", input_graph.node_data.GetData() + (input_graph.receivers[k] + 1) * " + n_size_input +
188 ", fRecNodeInputs.begin() + k * " + n_size_input + ");\n";
189 fGC += " std::copy(input_graph.node_data.GetData() + input_graph.senders[k] * " + n_size_input +
190 ", input_graph.node_data.GetData() + (input_graph.senders[k] + 1) * " + n_size_input +
191 ", fSndNodeInputs.begin() + k * " + n_size_input + ");\n";
192 fGC += " std::copy(input_graph.global_data.GetData()";
193 fGC += ", input_graph.global_data.GetData() + " + g_size_input +
194 ", fGlobInputs.begin() + k * " + g_size_input + ");\n";
195 fGC += "}\n";
196
197 fGC += "fEdgeUpdates = " + edges_update_block->Generate({"fEdgeInputs.data(), fRecNodeInputs.data(), fSndNodeInputs.data(), fGlobInputs.data()"}) + "\n";
198
199 if(num_edge_features != num_edge_features_input) {
200 fGC += "\n// resize edge graph data since output feature size is not equal to input size\n";
201 fGC+="input_graph.edge_data = input_graph.edge_data.Resize({"+e_num+", "+e_size+"});\n";
202 }
203 // copy output
204 fGC += "\nfor (int k = 0; k < " + e_num + "; k++) { \n";
205 fGC += " std::copy(fEdgeUpdates.begin()+ k * " + e_size + ", fEdgeUpdates.begin()+ (k+1) * " + e_size +
206 ",input_graph.edge_data.GetData() + k * " + e_size+ ");\n";
207 fGC += "}\n";
208 fGC += "\n";
209
210 fGC += "\n\n// --- Node Update ---\n";
211
212 // computing updated edge attributes
213 fGC += "for (int k = 0; k < " + n_num + "; k++) { \n";
214 fGC += " std::copy(input_graph.node_data.GetData() + k * " + n_size_input +
215 ", input_graph.node_data.GetData() + (k + 1) * " + n_size_input +
216 ", fNodeInputs.begin() + k * " + n_size_input + ");\n";
217 fGC += "}\n";
218 // reset initial aggregate edge vector to zero
219 fGC += "\nstd::fill(fNodeEdgeAggregate.begin(), fNodeEdgeAggregate.end(), 0.);\n";
220 // fGlobInputs is size { nedges, ngloblas}. It needs to be here { nnodes, nglobals}
221 // if number of nodes is larger than edges we need to resize it and copy values
222 if (num_nodes > num_edges) {
223 fGC += "\n// resize global vector feature to number of nodes\n";
224 fGC += "fGlobInputs.resize( " + std::to_string(num_nodes * num_global_features_input) + ");";
225 fGC += "for (size_t k = " + e_num + "; k < " + n_num + "; k++)";
226 fGC += " std::copy(fGlobInputs.begin(), fGlobInputs.begin() + " + g_size_input +
227 " , fGlobInputs.begin() + k * " + g_size_input + ");\n";
228 }
229
230 // aggregating edge if it's a receiver node and then updating corresponding node
231 // loop on nodes
232 fGC += "\nfor (int j = 0; j < " + n_num + "; j++) {\n";
233 int naprec = int(num_edges/num_nodes) +1; // approximate number of receivers/node
234 fGC += " std::vector<float *> edgesData; edgesData.reserve(" + std::to_string(naprec) + ");\n";
235 // loop on edges
236 fGC += " for (int k = 0; k < " + e_num + "; k++) {\n";
237 fGC += " if (input_graph.receivers[k] == j) \n";
238 fGC += " edgesData.emplace_back(input_graph.edge_data.GetData() + k * " + e_size + ");\n";
239 fGC += " }\n";
240 fGC += " fNodeAggregateTemp = " + edge_node_agg_block->Generate(num_edge_features, "edgesData") + ";\n";
241 fGC += " std::copy(fNodeAggregateTemp.begin(), fNodeAggregateTemp.end(), fNodeEdgeAggregate.begin() + " +
242 e_size + " * j);\n";
243 fGC += "}\n"; // end node loop
244
245
246 fGC+="\n";
247 fGC+="fNodeUpdates = ";
248 fGC+=nodes_update_block->Generate({"fNodeEdgeAggregate.data()","fNodeInputs.data()","fGlobInputs.data()"}); // computing updated node attributes
249 fGC+="\n";
250
251 if(num_node_features != num_node_features_input) {
252 fGC += "\n// resize node graph data since output feature size is not equal to input size\n";
253 fGC+="input_graph.node_data = input_graph.node_data.Resize({"+n_num+", "+n_size+"});\n";
254 }
255 // copy output
256 fGC += "\nfor (int k = 0; k < " + n_num + "; k++) { \n";
257 fGC += " std::copy(fNodeUpdates.begin()+ k * " + n_size + ", fNodeUpdates.begin() + (k+1) * " + n_size +
258 ",input_graph.node_data.GetData() + k * " + n_size+ ");\n";
259 fGC += "}\n";
260 fGC += "\n";
261
262 // aggregating edges & nodes for global update
263 std::vector<std::string> Node_Global_Aggregate_String;
264 for(int k=0; k<num_nodes; ++k) {
265 Node_Global_Aggregate_String.emplace_back("input_graph.node_data.GetData()+"+std::to_string(k*num_node_features));
266 }
267
268 std::vector<std::string> Edge_Global_Aggregate_String;
269 for(int k=0; k<num_edges; ++k) {
270 Edge_Global_Aggregate_String.emplace_back("input_graph.edge_data.GetData()+"+std::to_string(k*num_edge_features));
271 }
272
273 fGC += "\n// --- Global Update ---\n";
274 fGC+="std::vector<float> Edge_Global_Aggregate = ";
275 fGC+=edge_global_agg_block->Generate(num_edge_features, Edge_Global_Aggregate_String); // aggregating edge attributes globally
276 fGC+="\n";
277
278 fGC+="std::vector<float> Node_Global_Aggregate = ";
279 fGC+=node_global_agg_block->Generate(num_node_features, Node_Global_Aggregate_String); // aggregating node attributes globally
280 fGC+="\n";
281
282 // computing updated global attributes
283 fGC += "std::vector<float> Global_Data = ";
284 fGC += globals_update_block->Generate({"Edge_Global_Aggregate.data()","Node_Global_Aggregate.data()", "input_graph.global_data.GetData()"});
285 if(num_global_features != num_global_features_input) {
286 fGC += "\n// resize global graph data since output feature size is not equal to input size\n";
287 fGC+="input_graph.global_data = input_graph.global_data.Resize({"+g_size+"});\n";
288 }
289 fGC += "\nstd::copy(Global_Data.begin(), Global_Data.end(), input_graph.global_data.GetData());";
290 fGC+="\n}\n";
291 fGC+="};\n";
292
293 fGC += ("} //TMVA_SOFIE_" + fName + "\n");
294 fGC += "\n#endif // TMVA_SOFIE_" + hgname + "\n";
295}
296
297}//SOFIE
298}//Experimental
299}//TMVA
#define f(i)
Definition RSha256.hxx:104
void GenerateHeaderInfo(std::string &hgname)
std::unique_ptr< RFunction_Aggregate > node_global_agg_block
std::unique_ptr< RFunction_Update > globals_update_block
RModel_GNN & operator=(RModel_GNN &&other)
std::unique_ptr< RFunction_Update > edges_update_block
std::unique_ptr< RFunction_Aggregate > edge_global_agg_block
std::unique_ptr< RFunction_Aggregate > edge_node_agg_block
std::unique_ptr< RFunction_Update > nodes_update_block
create variable transformations
std::unique_ptr< RFunction_Aggregate > node_global_agg_block
std::unique_ptr< RFunction_Update > globals_update_block
std::unique_ptr< RFunction_Update > nodes_update_block
std::unique_ptr< RFunction_Aggregate > edge_node_agg_block
std::unique_ptr< RFunction_Aggregate > edge_global_agg_block
std::vector< std::pair< int, int > > edges
std::unique_ptr< RFunction_Update > edges_update_block