Skip to content

Commit 20b602b

Browse files
Added PyTorch Geometric functionality.
1 parent 4e8c577 commit 20b602b

26 files changed

+2901
-0
lines changed

bindings/pyroot/pythonizations/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ if(tmva)
6565
ROOT/_pythonization/_tmva/_rtensor.py
6666
ROOT/_pythonization/_tmva/_tree_inference.py
6767
ROOT/_pythonization/_tmva/_utils.py)
68+
list(APPEND PYROOT_EXTRA_PY3_SOURCE
69+
ROOT/_pythonization/_tmva/_torchgnn.py)
6870
endif()
6971

7072
if(PYTHON_VERSION_STRING_Development_Main VERSION_GREATER_EQUAL 3.8 AND dataframe)

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
from ._rbdt import Compute, pythonize_rbdt
2424

25+
from ._torchgnn import RModel_TorchGNN
26+
2527
if sys.version_info >= (3, 8):
2628
from ._batchgenerator import (
2729
CreateNumPyGenerators,
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""
2+
Helper functions for Python TorchGNN.
3+
4+
Author: Stefan van Berkum
5+
"""
6+
7+
from .. import pythonization
8+
from cppyy.gbl.std import vector, map
9+
10+
11+
class RModel_TorchGNN():
    def ExtractParameters(self, model):
        """Copy the parameters of a PyTorch model into this ROOT model.

        The parameterized module names used on the ROOT side must match the
        keys of the PyTorch state dictionary (which are derived from the
        attribute names of the PyTorch module). For example:
            Torch: self.linear_1 = torch.nn.Linear(5, 20)
            ROOT:  model.AddModule(ROOT.TMVA.Experimental.SOFIE.RModule_Linear('X',
                   5, 20), 'linear_1')

        :param model: The PyTorch model.
        """
        # Convert the Python state dict into a C++
        # std::map<std::string, std::vector<float>> and hand it to the
        # C++ loader.
        state = map[str, vector[float]]()
        for name, tensor in model.state_dict().items():
            # Flatten each tensor on the CPU into a plain list of floats.
            state[name] = tensor.cpu().numpy().flatten().tolist()
        self.LoadParameters(state)
31+
32+
33+
@pythonization("RModel_TorchGNN", ns="TMVA::Experimental::SOFIE")
def pythonize_torchgnn_extractparameters(klass):
    # Attach the Python-side parameter-extraction helper to the C++ class.
    klass.ExtractParameters = RModel_TorchGNN.ExtractParameters

tmva/sofie/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,24 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofie
4444
TMVA/ROperator_Erf.hxx
4545
TMVA/SOFIE_common.hxx
4646
TMVA/SOFIEHelpers.hxx
47+
48+
TMVA/TorchGNN/modules/RModule_Add.hxx
49+
TMVA/TorchGNN/modules/RModule_Cat.hxx
50+
TMVA/TorchGNN/modules/RModule_GCNConv.hxx
51+
TMVA/TorchGNN/modules/RModule_GlobalMeanPool.hxx
52+
TMVA/TorchGNN/modules/RModule_Input.hxx
53+
TMVA/TorchGNN/modules/RModule_Linear.hxx
54+
TMVA/TorchGNN/modules/RModule_ReLU.hxx
55+
TMVA/TorchGNN/modules/RModule_Reshape.hxx
56+
TMVA/TorchGNN/modules/RModule_Softmax.hxx
57+
TMVA/TorchGNN/modules/RModule.hxx
58+
59+
TMVA/TorchGNN/RModel_TorchGNN.hxx
4760
SOURCES
4861
src/RModel.cxx
4962
src/SOFIE_common.cxx
63+
64+
src/TorchGNN/RModel_TorchGNN.cxx
5065
DEPENDENCIES
5166
TMVA
5267
)

tmva/sofie/inc/LinkDef.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,7 @@
1414
#pragma link C++ struct TMVA::Experimental::SOFIE::TensorInfo+;
1515
#pragma link C++ struct TMVA::Experimental::SOFIE::InputTensorInfo+;
1616
#pragma link C++ struct TMVA::Experimental::SOFIE::Dim+;
17+
#pragma link C++ class TMVA::Experimental::SOFIE::RModule+;
18+
#pragma link C++ class TMVA::Experimental::SOFIE::RModel_TorchGNN+;
1719

1820
#endif
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
// @(#)root/tmva/sofie:$Id$
// Author: Stefan van Berkum

/**
 * Header file for PyTorch Geometric models.
 *
 * Models are created by the user and parameters can then be loaded into each layer.
 *
 * IMPORTANT: Changes to the format (e.g., namespaces) may affect the emit
 * defined in RModel_TorchGNN.cxx (save).
 */

#ifndef TMVA_SOFIE_RMODEL_TORCHGNN_H_
#define TMVA_SOFIE_RMODEL_TORCHGNN_H_

#include "TMVA/TorchGNN/modules/RModule.hxx"
#include "TMVA/TorchGNN/modules/RModule_Input.hxx"
#include <algorithm>
#include <ctime>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

namespace TMVA {
namespace Experimental {
namespace SOFIE {

class RModel_TorchGNN {
    public:
        /** Model constructor without inputs. */
        RModel_TorchGNN() {}

        /**
         * Model constructor with manual input names.
         *
         * Each shape is validated: dimensions must be positive, except for
         * at most one wildcard (-1) per input.
         *
         * @param input_names Vector of input names.
         * @param input_shapes Vector of input shapes. Each element may contain
         * at most one wildcard (-1).
         * @throws std::invalid_argument If a shape has a zero dimension, a
         * negative non-wildcard dimension, or more than one wildcard.
         */
        RModel_TorchGNN(std::vector<std::string> input_names, std::vector<std::vector<int>> input_shapes) {
            fInputs = input_names;
            fShapes = input_shapes;

            // Generate one input layer per declared input.
            for (std::size_t i = 0; i < input_names.size(); i++) {
                // Check shape.
                if (std::any_of(input_shapes[i].begin(), input_shapes[i].end(), [](int j){return j == 0;})) {
                    throw std::invalid_argument("Invalid input shape for input " + input_names[i] + ". Dimension cannot be zero.");
                }
                if (std::any_of(input_shapes[i].begin(), input_shapes[i].end(), [](int j){return j < -1;})) {
                    throw std::invalid_argument("Invalid input shape for input " + input_names[i] + ". Shape cannot have negative entries (except for the wildcard dimension).");
                }
                if (std::count(input_shapes[i].begin(), input_shapes[i].end(), -1) > 1) {
                    throw std::invalid_argument("Invalid input shape for input " + input_names[i] + ". Shape may have at most one wildcard.");
                }
                AddModule(RModule_Input(input_shapes[i]), input_names[i]);
            }
        }

        /**
         * Add a module to the module list.
         *
         * Duplicate names get a numeric discriminator appended (e.g., the
         * second unnamed GCNConv becomes GCNConv_1, the third GCNConv_2).
         *
         * @param module Module to add.
         * @param name Module name. Defaults to the module type with a count
         * value (e.g., GCNConv_1).
         */
        template<typename T>
        void AddModule(T module, std::string name="") {
            std::string new_name = (name == "") ? std::string(module.GetOperation()) : name;
            if (fModuleCounts[new_name] > 0) {
                // Duplicate name. Increment the count of the BASE name before
                // appending the discriminator: incrementing after the rename
                // (as in fModuleCounts[new_name]++ on the suffixed string)
                // would leave the base count at 1 forever, so every later
                // duplicate would be renamed to the same "<name>_1" and
                // silently overwrite the earlier module in fModuleMap.
                int count = fModuleCounts[new_name]++;
                new_name += "_" + std::to_string(count);
                fModuleCounts[new_name]++;  // Reserve the suffixed name too.

                if (name != "") {
                    // Warn the user that their explicit name was changed.
                    std::cout << "WARNING: Module with duplicate name \"" << name << "\" renamed to \"" << new_name << "\"." << std::endl;
                }
            } else {
                // First module of its kind.
                fModuleCounts[new_name] = 1;
            }
            module.SetName(new_name);

            // Initialize the module (resolves references to earlier modules).
            module.Initialize(fModules, fModuleMap);

            // Add module to the module list and register its index.
            fModules.push_back(std::make_shared<T>(module));
            fModuleMap[std::string(module.GetName())] = fModuleCount;
            fModuleCount++;
        }

        /**
         * Run the forward function.
         *
         * The arguments are bound, in order, to the input modules created by
         * the constructor, then every module is executed in insertion order.
         *
         * @param args Any number of input arguments (one per declared input).
         * @returns The output of the last layer.
         */
        template<class... Types>
        std::vector<float> Forward(Types... args) {
            auto input = std::make_tuple(args...);

            // Instantiate input layers. The comma fold guarantees
            // left-to-right evaluation, so k indexes inputs in order.
            int k = 0;
            std::apply(
                [&](auto&... in) {
                    ((std::dynamic_pointer_cast<RModule_Input>(fModules[k++]) -> SetParams(in)), ...);
                }, input);

            // Loop through and execute modules.
            for (std::shared_ptr<RModule> module: fModules) {
                module -> Execute();
            }

            // Return a copy of the last layer's output.
            return fModules.back() -> GetOutput();
        }

        /**
         * Load parameters from PyTorch state dictionary for all modules.
         *
         * @param state_dict The state dictionary.
         */
        void LoadParameters(std::map<std::string, std::vector<float>> state_dict) {
            for (std::shared_ptr<RModule> module: fModules) {
                module -> LoadParameters(state_dict);
            }
        }

        /**
         * Load saved parameters for all modules.
         */
        void LoadParameters() {
            for (std::shared_ptr<RModule> module: fModules) {
                module -> LoadParameters();
            }
        }

        /**
         * Save the model as standalone inference code.
         *
         * @param path Path to save location.
         * @param name Model name.
         * @param overwrite True if any existing directory should be
         * overwritten. Defaults to false.
         */
        void Save(std::string path, std::string name, bool overwrite=false);
    private:
        /**
         * Get a timestamp.
         *
         * NOTE(review): std::localtime is not thread-safe (shared static
         * buffer) — acceptable for one-off code generation, confirm if Save
         * is ever called concurrently.
         *
         * @returns The timestamp in string format.
         */
        static std::string GetTimestamp() {
            time_t rawtime;
            struct tm * timeinfo;
            char timestamp [80];
            time(&rawtime);
            timeinfo = localtime(&rawtime);
            strftime(timestamp, 80, "Timestamp: %d-%m-%Y %T.", timeinfo);
            return timestamp;
        }

        /**
         * Write the methods to create a self-contained package.
         *
         * @param dir Directory to save to.
         * @param name Model name.
         * @param timestamp Timestamp.
         */
        void WriteMethods(std::string dir, std::string name, std::string timestamp);

        /**
         * Write the model to a file.
         *
         * @param dir Directory to save to.
         * @param name Model name.
         * @param timestamp Timestamp.
         */
        void WriteModel(std::string dir, std::string name, std::string timestamp);

        /**
         * Write the CMakeLists file.
         *
         * @param dir Directory to save to.
         * @param name Model name.
         * @param timestamp Timestamp.
         */
        void WriteCMakeLists(std::string dir, std::string name, std::string timestamp);

        std::vector<std::string> fInputs;  // Vector of input names.
        std::vector<std::vector<int>> fShapes;  // Vector of input shapes.
        std::map<std::string, int> fModuleCounts;  // Map from module name to number of occurrences.
        std::vector<std::shared_ptr<RModule>> fModules;  // Vector containing the modules.
        std::map<std::string, int> fModuleMap;  // Map from module name to module index (in modules).
        int fModuleCount = 0;  // Number of modules.
};

} // SOFIE.
} // Experimental.
} // TMVA.

#endif // TMVA_SOFIE_RMODEL_TORCHGNN_H_

0 commit comments

Comments
 (0)