Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RModel_GNN.hxx
Go to the documentation of this file.
1#ifndef TMVA_SOFIE_RMODEL_GNN
2#define TMVA_SOFIE_RMODEL_GNN
3
4#include <ctime>
5
7#include "TMVA/RModel.hxx"
8#include "TMVA/RFunction.hxx"
9
10namespace TMVA {
11namespace Experimental {
12namespace SOFIE {
13
// Forward declarations: GNN_Init and RModel_GNN below refer to these types
// through std::unique_ptr members.
// NOTE(review): std::unique_ptr<T> needs T to be a complete type wherever its
// destructor is instantiated; presumably the full definitions are reachable via
// the included TMVA headers - verify.
14class RFunction_Update;
15class RFunction_Aggregate;
16
// GNN_Init: input description handed to the RModel_GNN constructor (this page's
// symbol index lists `RModel_GNN(GNN_Init &graph_input_struct)`). It owns the
// update/aggregate function blocks and carries the node count, the edge list,
// feature counts and a `filename` string (how `filename` is used is not shown).
//
// NOTE(review): this listing is a Doxygen-rendered extraction. Original source
// lines 21, 38, 43, 46-47, 50-51, 54-55, 65, 68, 72-73 and 76-77 are absent, so
// the default constructor, one field after num_edge_features, both template
// signatures, and most `case` labels / `reset` calls are not visible. Restore
// them from the upstream header before changing any logic here.
// NOTE(review): the members below use std::unique_ptr, std::vector, std::pair,
// std::string and std::runtime_error, while the visible includes are only
// <ctime>, TMVA/RModel.hxx and TMVA/RFunction.hxx (line 6 is not shown); this
// relies on transitive includes. Consider including <memory>, <vector>,
// <utility>, <string> and <stdexcept> directly.
17struct GNN_Init {
18
19 // Explicitly define default constructor so cppyy doesn't attempt
20 // aggregate initialization.
22
23 // update blocks
   // Owned via std::unique_ptr; because unique_ptr is not copyable, the
   // implicitly-generated copy constructor of GNN_Init is deleted.
24 std::unique_ptr<RFunction_Update> edges_update_block;
25 std::unique_ptr<RFunction_Update> nodes_update_block;
26 std::unique_ptr<RFunction_Update> globals_update_block;
27
28 // aggregation blocks
29 std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
30 std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
31 std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
32
   // NOTE(review): the std::size_t fields have no default member initializers;
   // unless the (not shown) constructor sets them, they are indeterminate until
   // the caller assigns them.
33 std::size_t num_nodes;
   // Presumably (sender, receiver) node-index pairs - TODO confirm the order.
34 std::vector<std::pair<int, int>> edges;
35
36 std::size_t num_node_features;
37 std::size_t num_edge_features;
39
40 std::string filename;
41
   // createUpdateFunction (signature on missing line 43; the page index lists
   // `void createUpdateFunction(T &updateFunction)`): switches on
   // updateFunction.GetFunctionTarget(); a target that matches no case throws
   // std::runtime_error. The case labels and bodies are not visible here;
   // presumably each stores a heap copy of updateFunction in the matching
   // *_update_block, as createAggregateFunction does below - TODO confirm.
42 template <typename T>
44 {
45 switch (updateFunction.GetFunctionTarget()) {
48 break;
49 }
52 break;
53 }
56 break;
57 }
58 default: {
59 throw std::runtime_error("TMVA SOFIE: Invalid Update function supplied for creating GNN function block.");
60 }
61 }
62 }
63
   // createAggregateFunction (signature on missing line 65; the page index
   // lists `void createAggregateFunction(T &aggFunction, FunctionRelation relation)`):
   // switches on `relation`. The only visible case body replaces
   // edge_node_agg_block with a new heap-allocated copy `new T(aggFunction)`
   // (unique_ptr::reset destroys any previously held object). A relation that
   // matches no case throws std::runtime_error.
   // NOTE(review): `edge_node_agg_block = std::make_unique<T>(aggFunction);` is
   // the idiomatic equivalent of reset(new T(...)).
64 template <typename T>
66 {
67 switch (relation) {
69 edge_node_agg_block.reset(new T(aggFunction));
70 break;
71 }
74 break;
75 }
78 break;
79 }
80 default: {
81 throw std::runtime_error("TMVA SOFIE: Invalid Aggregate function supplied for creating GNN function block.");
82 }
83 }
84 }
85};
86
// RModel_GNN: holds private members that mirror GNN_Init's function blocks and
// size fields, and exposes Generate().
// NOTE(review): this is a Doxygen-rendered extraction. The class head (original
// lines 87-88) and lines 105 and 108 are missing. The page's symbol index lists
// a constructor `RModel_GNN(GNN_Init &graph_input_struct)`, presumably the
// declaration on line 108; line 105 presumably mirrors the field missing after
// num_edge_features in GNN_Init - confirm against the upstream header.
88
89private:
90 // update function for edges, nodes & global attributes
91 std::unique_ptr<RFunction_Update> edges_update_block;
92 std::unique_ptr<RFunction_Update> nodes_update_block;
93 std::unique_ptr<RFunction_Update> globals_update_block;
94
95 // aggregation function for edges, nodes & global attributes
96 std::unique_ptr<RFunction_Aggregate> edge_node_agg_block;
97 std::unique_ptr<RFunction_Aggregate> edge_global_agg_block;
98 std::unique_ptr<RFunction_Aggregate> node_global_agg_block;
99
   // GNN_Init carries the edge list itself; this class keeps only an edge
   // count. Presumably the constructor derives it from GNN_Init::edges - TODO
   // confirm.
100 std::size_t num_nodes; // maximum number of nodes
101 std::size_t num_edges; // maximum number of edges
102
103 std::size_t num_node_features;
104 std::size_t num_edge_features;
106
107public:
109
   // Declared `final` without `virtual`, so it must override a virtual function
   // of a base class (not visible here) and cannot be overridden further.
   // Presumably emits the generated inference code for this model - TODO confirm.
110 void Generate() final;
111};
112
113} // namespace SOFIE
114} // namespace Experimental
115} // namespace TMVA
116
117#endif // TMVA_SOFIE_RMODEL_GNN
std::unique_ptr< RFunction_Aggregate > node_global_agg_block
std::unique_ptr< RFunction_Update > globals_update_block
std::unique_ptr< RFunction_Update > edges_update_block
std::unique_ptr< RFunction_Aggregate > edge_global_agg_block
std::unique_ptr< RFunction_Aggregate > edge_node_agg_block
RModel_GNN(GNN_Init &graph_input_struct)
std::unique_ptr< RFunction_Update > nodes_update_block
create variable transformations
std::vector< std::pair< int, int > > edges
std::unique_ptr< RFunction_Aggregate > node_global_agg_block
std::unique_ptr< RFunction_Update > globals_update_block
std::unique_ptr< RFunction_Update > nodes_update_block
std::unique_ptr< RFunction_Aggregate > edge_node_agg_block
std::unique_ptr< RFunction_Aggregate > edge_global_agg_block
void createAggregateFunction(T &aggFunction, FunctionRelation relation)
std::unique_ptr< RFunction_Update > edges_update_block
void createUpdateFunction(T &updateFunction)