10namespace Experimental {
24 senders = std::move(other.senders);
27 fName = std::move(other.fName);
43 senders = std::move(other.senders);
46 fName = std::move(other.fName);
67 for(
auto& it:graph_input_struct.
edges) {
69 senders.emplace_back(it.second);
74 std::time_t ttime = std::time(0);
75 std::tm* gmt_time = std::gmtime(&ttime);
90 fGC+=
"\n\nnamespace Edge_Update{\nstruct Session {\n";
104 fGC+=
"\n\nnamespace Node_Update{\nstruct Session {\n";
122 fGC+=
"\n\nnamespace Global_Update{\nstruct Session {\n";
148 fGC +=
"struct Session {\n";
149 fGC +=
"\n// Instantiating session objects for graph components\n";
150 fGC +=
"Edge_Update::Session edge_update;\n";
151 fGC +=
"Node_Update::Session node_update;\n";
152 fGC +=
"Global_Update::Session global_update;\n\n";
154 std::string e_num = std::to_string(
num_edges);
155 std::string n_num = std::to_string(
num_nodes);
156 std::string e_size_input = std::to_string(num_edge_features_input);
157 std::string n_size_input = std::to_string(num_node_features_input);
158 std::string g_size_input = std::to_string(num_global_features_input);
164 fGC +=
"std::vector<float> fEdgeUpdates = std::vector<float>(" + e_num +
"*" + e_size +
");\n";
165 fGC +=
"\n\nstd::vector<float> fNodeUpdates = std::vector<float>(" + n_num +
"*" + n_size +
");\n";
167 fGC +=
"\n// input vectors for edge update\n";
168 fGC +=
"std::vector<float> fEdgeInputs = std::vector<float>(" + e_num +
"*" + e_size_input +
");\n";
169 fGC +=
"std::vector<float> fRecNodeInputs = std::vector<float>(" + e_num +
"*" + n_size_input +
");\n";
170 fGC +=
"std::vector<float> fSndNodeInputs = std::vector<float>(" + e_num +
"*" + n_size_input +
");\n";
171 fGC +=
"std::vector<float> fGlobInputs = std::vector<float>(" + e_num +
"*" + g_size_input +
");\n\n";
173 fGC +=
"\n// input vectors for node update\n";
174 fGC +=
"std::vector<float> fNodeInputs = std::vector<float>(" + n_num +
"*" + n_size_input +
");\n";
175 fGC +=
"std::vector<float> fNodeEdgeAggregate = std::vector<float>(" + n_num +
"*" + n_size_input +
", 0);\n";
176 fGC +=
"std::vector<float> fNodeAggregateTemp;\n";
178 fGC +=
"\nvoid infer(TMVA::Experimental::SOFIE::GNN_Data& input_graph){\n";
181 fGC +=
"\n// --- Edge Update ---\n";
182 fGC +=
"for (int k = 0; k < " + e_num +
"; k++) { \n";
183 fGC +=
" std::copy(input_graph.edge_data.GetData() + k * " + e_size_input +
184 ", input_graph.edge_data.GetData() + (k + 1) * " + e_size_input +
185 ", fEdgeInputs.begin() + k * " + e_size_input +
");\n";
186 fGC +=
" std::copy(input_graph.node_data.GetData() + input_graph.receivers[k] * " + n_size_input +
187 ", input_graph.node_data.GetData() + (input_graph.receivers[k] + 1) * " + n_size_input +
188 ", fRecNodeInputs.begin() + k * " + n_size_input +
");\n";
189 fGC +=
" std::copy(input_graph.node_data.GetData() + input_graph.senders[k] * " + n_size_input +
190 ", input_graph.node_data.GetData() + (input_graph.senders[k] + 1) * " + n_size_input +
191 ", fSndNodeInputs.begin() + k * " + n_size_input +
");\n";
192 fGC +=
" std::copy(input_graph.global_data.GetData()";
193 fGC +=
", input_graph.global_data.GetData() + " + g_size_input +
194 ", fGlobInputs.begin() + k * " + g_size_input +
");\n";
197 fGC +=
"fEdgeUpdates = " +
edges_update_block->Generate({
"fEdgeInputs.data(), fRecNodeInputs.data(), fSndNodeInputs.data(), fGlobInputs.data()"}) +
"\n";
200 fGC +=
"\n// resize edge graph data since output feature size is not equal to input size\n";
201 fGC+=
"input_graph.edge_data = input_graph.edge_data.Resize({"+e_num+
", "+e_size+
"});\n";
204 fGC +=
"\nfor (int k = 0; k < " + e_num +
"; k++) { \n";
205 fGC +=
" std::copy(fEdgeUpdates.begin()+ k * " + e_size +
", fEdgeUpdates.begin()+ (k+1) * " + e_size +
206 ",input_graph.edge_data.GetData() + k * " + e_size+
");\n";
210 fGC +=
"\n\n// --- Node Update ---\n";
213 fGC +=
"for (int k = 0; k < " + n_num +
"; k++) { \n";
214 fGC +=
" std::copy(input_graph.node_data.GetData() + k * " + n_size_input +
215 ", input_graph.node_data.GetData() + (k + 1) * " + n_size_input +
216 ", fNodeInputs.begin() + k * " + n_size_input +
");\n";
219 fGC +=
"\nstd::fill(fNodeEdgeAggregate.begin(), fNodeEdgeAggregate.end(), 0.);\n";
223 fGC +=
"\n// resize global vector feature to number of nodes\n";
224 fGC +=
"fGlobInputs.resize( " + std::to_string(
num_nodes * num_global_features_input) +
");";
225 fGC +=
"for (size_t k = " + e_num +
"; k < " + n_num +
"; k++)";
226 fGC +=
" std::copy(fGlobInputs.begin(), fGlobInputs.begin() + " + g_size_input +
227 " , fGlobInputs.begin() + k * " + g_size_input +
");\n";
232 fGC +=
"\nfor (int j = 0; j < " + n_num +
"; j++) {\n";
234 fGC +=
" std::vector<float *> edgesData; edgesData.reserve(" + std::to_string(naprec) +
");\n";
236 fGC +=
" for (int k = 0; k < " + e_num +
"; k++) {\n";
237 fGC +=
" if (input_graph.receivers[k] == j) \n";
238 fGC +=
" edgesData.emplace_back(input_graph.edge_data.GetData() + k * " + e_size +
");\n";
241 fGC +=
" std::copy(fNodeAggregateTemp.begin(), fNodeAggregateTemp.end(), fNodeEdgeAggregate.begin() + " +
247 fGC+=
"fNodeUpdates = ";
248 fGC+=
nodes_update_block->Generate({
"fNodeEdgeAggregate.data()",
"fNodeInputs.data()",
"fGlobInputs.data()"});
252 fGC +=
"\n// resize node graph data since output feature size is not equal to input size\n";
253 fGC+=
"input_graph.node_data = input_graph.node_data.Resize({"+n_num+
", "+n_size+
"});\n";
256 fGC +=
"\nfor (int k = 0; k < " + n_num +
"; k++) { \n";
257 fGC +=
" std::copy(fNodeUpdates.begin()+ k * " + n_size +
", fNodeUpdates.begin() + (k+1) * " + n_size +
258 ",input_graph.node_data.GetData() + k * " + n_size+
");\n";
263 std::vector<std::string> Node_Global_Aggregate_String;
265 Node_Global_Aggregate_String.emplace_back(
"input_graph.node_data.GetData()+"+std::to_string(k*
num_node_features));
268 std::vector<std::string> Edge_Global_Aggregate_String;
270 Edge_Global_Aggregate_String.emplace_back(
"input_graph.edge_data.GetData()+"+std::to_string(k*
num_edge_features));
273 fGC +=
"\n// --- Global Update ---\n";
274 fGC+=
"std::vector<float> Edge_Global_Aggregate = ";
278 fGC+=
"std::vector<float> Node_Global_Aggregate = ";
283 fGC +=
"std::vector<float> Global_Data = ";
284 fGC +=
globals_update_block->Generate({
"Edge_Global_Aggregate.data()",
"Node_Global_Aggregate.data()",
"input_graph.global_data.GetData()"});
286 fGC +=
"\n// resize global graph data since output feature size is not equal to input size\n";
287 fGC+=
"input_graph.global_data = input_graph.global_data.Resize({"+g_size+
"});\n";
289 fGC +=
"\nstd::copy(Global_Data.begin(), Global_Data.end(), input_graph.global_data.GetData());";
293 fGC += (
"} //TMVA_SOFIE_" +
fName +
"\n");
294 fGC +=
"\n#endif // TMVA_SOFIE_" + hgname +
"\n";
void GenerateHeaderInfo(std::string &hgname)
std::size_t num_edge_features
std::unique_ptr< RFunction_Aggregate > node_global_agg_block
std::unique_ptr< RFunction_Update > globals_update_block
RModel_GNN & operator=(RModel_GNN &&other)
std::unique_ptr< RFunction_Update > edges_update_block
std::unique_ptr< RFunction_Aggregate > edge_global_agg_block
std::unique_ptr< RFunction_Aggregate > edge_node_agg_block
std::size_t num_global_features
std::vector< int > senders
std::vector< int > receivers
std::size_t num_node_features
std::unique_ptr< RFunction_Update > nodes_update_block
create variable transformations
std::unique_ptr< RFunction_Aggregate > node_global_agg_block
std::unique_ptr< RFunction_Update > globals_update_block
std::unique_ptr< RFunction_Update > nodes_update_block
std::unique_ptr< RFunction_Aggregate > edge_node_agg_block
std::unique_ptr< RFunction_Aggregate > edge_global_agg_block
std::size_t num_node_features
std::vector< std::pair< int, int > > edges
std::unique_ptr< RFunction_Update > edges_update_block
std::size_t num_global_features
std::size_t num_edge_features