-
Notifications
You must be signed in to change notification settings - Fork 173
Expand file tree
/
Copy pathsystem.py
More file actions
367 lines (325 loc) · 17.3 KB
/
system.py
File metadata and controls
367 lines (325 loc) · 17.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
"""
open-loop (directed acyclic graphs) and closed-loop (directed cyclic graphs) systems components
Minimum viable product
1, system class for open-loop rollout of autonomous nn.Module class
2, system class for open-loop rollout of non-autonomous nn.Module class
3, system class for closed-loop rollout of simple DPC with neural policy and nonautonomous dynamics class (e.g. SSM, psl, ...)
Notes on simple implementation:
Time delay can be handled inside nodes simply or with more complexity
Sporadically sampled data can be handled prior with interpolation
Different time scales can be handled with nested systems
Networked systems seem like a natural fit here
"""
import os
import pydot
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from neuromancer.constraint import Variable
class Node(nn.Module):
    """
    Lightweight wrapper that plugs an arbitrary callable into a (possibly cyclic)
    computational graph. ``input_keys`` and ``output_keys`` name the entries of a
    shared intermediate data dictionary that this node reads from and writes to.
    """
    def __init__(self, callable, input_keys, output_keys, name=None):
        """
        :param callable: Input: All input arguments are assumed to be torch.Tensors (batchsize, dim)
                         Output: All outputs are assumed to be torch.Tensors (batchsize, dim)
        :param input_keys: (list of str or Variable) keys for gathering inputs from the intermediary data dictionary
        :param output_keys: (list of str or Variable) keys for sending outputs to other nodes through the intermediary data dictionary
        :param name: (str) Unique node identifier
        """
        super().__init__()
        # Normalize Variable objects down to their underlying string keys.
        self.input_keys = [k.key if isinstance(k, Variable) else k for k in input_keys]
        self.output_keys = [k.key if isinstance(k, Variable) else k for k in output_keys]
        self.callable = callable
        self.name = name
    def forward(self, data):
        """
        Gather inputs by key, run the wrapped callable, and map its outputs back to keys.
        :param data: (dict {str: Tensor}) source of tensors named by input_keys
        :return: (dict {str: Tensor}) callable outputs keyed by output_keys
        """
        args = (data[key] for key in self.input_keys)
        result = self.callable(*args)
        if not isinstance(result, tuple):
            result = (result,)
        return dict(zip(self.output_keys, result))
    def freeze(self):
        """
        Disable gradient updates for every parameter of the wrapped callable.
        """
        for p in self.callable.parameters():
            p.requires_grad = False
    def unfreeze(self):
        """
        Re-enable gradient updates for every parameter of the wrapped callable.
        """
        for p in self.callable.parameters():
            p.requires_grad = True
    def __repr__(self):
        return f"{self.name}({', '.join(self.input_keys)}) -> {', '.join(self.output_keys)}"
class MovingHorizon(nn.Module):
    """
    The MovingHorizon class buffers single time step inputs for time-delay modeling
    from the past ndelay steps. It wraps a module expecting 3-d input
    (ndelay, batch, dim) so the module can be driven one step at a time.
    """
    def __init__(self, module, ndelay=1, history=None):
        """
        :param module: nn.Module which takes a 3-d input dict and returns a 2-d output dict
        :param ndelay: (int) Time-delay horizon
        :param history: (dict {str: list of Tensors}) Optional initialization of the history
                        buffer from previous measurements; one key per module input_key,
                        values are lists of 2-d tensors
        """
        super().__init__()
        self.module = module
        self.ndelay = ndelay
        self.input_keys = module.input_keys
        self.output_keys = module.output_keys
        # One buffer list per input key; optionally seeded from prior measurements.
        if history is None:
            history = {key: [] for key in self.input_keys}
        self.history = history
    def forward(self, input):
        """
        Append the current step to the history buffer and hand the last ndelay
        steps to the wrapped module. On the very first call the single step is
        replicated ndelay times to fill the buffer.
        :param input: (dict: str: 2-d tensor (batch, dim)) single-step tensor inputs
        :return: (dict: str: 3-d Tensor (ndelay, batch, dim)) wrapped module outputs
        """
        for key in self.input_keys:
            buffer = self.history[key]
            buffer.append(input[key])
            if len(buffer) == 1:
                # First call: replicate the initial step to fill the delay window.
                buffer *= self.ndelay
        stacked = {key: torch.stack(self.history[key][-self.ndelay:])
                   for key in self.input_keys}
        return self.module(stacked)
class System(nn.Module):
    """
    Simple implementation for arbitrary cyclic computation over a collection of
    Node objects. Nodes communicate through a shared dictionary of tensors; the
    system rolls the node computations out over nsteps time steps and records
    every output along the time dimension.
    """
    def __init__(self, nodes, name=None, nstep_key='X', init_func=None, nsteps=None):
        """
        :param nodes: (list of Node objects)
        :param name: (str) Unique identifier for system class.
        :param nstep_key: (str) Key is used to infer number of rollout steps from input_data
        :param init_func: (callable(input_dict) -> input_dict) This function is used to set initial conditions of the system
        :param nsteps: (int) prediction horizon (rollout steps) length
        """
        super().__init__()
        self.nstep_key = nstep_key
        self.nsteps = nsteps
        self.nodes, self.name = nn.ModuleList(nodes), name
        if init_func is not None:
            self.init = init_func  # override the default no-op init hook
        self.input_keys = set().union(*[c.input_keys for c in nodes])
        self.output_keys = set().union(*[c.output_keys for c in nodes])
        self.system_graph = self.graph()
    def graph(self):
        """
        Build a pydot digraph visualizing the data flow between the dataset,
        the nodes, and the system output. Also refreshes self.input_keys and
        self.output_keys from the node connectivity as a side effect.
        :return: (pydot.Dot) graph of the system
        """
        self._check_unique_names()
        graph = pydot.Dot("problem", graph_type="digraph", splines="spline", rankdir="LR")
        graph.add_node(pydot.Node("in", label="dataset", color='skyblue',
                                  style='filled', shape="box"))
        sim_loop = pydot.Cluster('sim_loop', color='cornsilk',
                                 style='filled', label='system')
        input_keys = []
        output_keys = []
        nonames = 1
        for node in self.nodes:
            input_keys += node.input_keys
            output_keys += node.output_keys
            # assign default names to anonymous nodes so they can appear in the graph
            if node.name is None or node.name == '':
                node.name = f'node_{nonames}'
                nonames += 1
            sim_loop.add_node(pydot.Node(node.name, label=node.name,
                                         color='lavender',
                                         style='filled',
                                         shape="box"))
        graph.add_node(pydot.Node('out', label='out', color='skyblue', style='filled', shape='box'))
        graph.add_subgraph(sim_loop)
        # build node connections in reverse order
        reverse_order_nodes = self.nodes[::-1]
        for idx_dst, dst in enumerate(reverse_order_nodes):
            src_nodes = reverse_order_nodes[1+idx_dst:]
            unique_common_keys = set()
            for idx_src, src in enumerate(src_nodes):
                common_keys = set(src.output_keys) & set(dst.input_keys)
                for key in common_keys:
                    # only draw one edge per key into each destination node
                    if key not in unique_common_keys:
                        graph.add_edge(pydot.Edge(src.name, dst.name, label=key))
                        unique_common_keys.add(key)
        # build I/O and node loop connections
        loop_keys = []
        init_keys = []
        previous_output_keys = []
        for idx_node, node in enumerate(self.nodes):
            node_loop_keys = set(node.input_keys) & set(node.output_keys)
            loop_keys += node_loop_keys
            # keys not produced by any earlier node must come from the dataset
            init_keys += set(node.input_keys) - set(previous_output_keys)
            previous_output_keys += node.output_keys
            # build single node recurrent connections
            for key in node_loop_keys:
                graph.add_edge(pydot.Edge(node.name, node.name, label=key))
            # build connections to the dataset
            for key in set(node.input_keys) & set(init_keys):
                graph.add_edge(pydot.Edge("in", node.name, label=key))
            # build feedback connections for init nodes
            feedback_src_nodes = reverse_order_nodes[:-1-idx_node]
            if len(set(node.input_keys) & set(loop_keys) & set(init_keys)) > 0:
                for key in node.input_keys:
                    for src in feedback_src_nodes:
                        if key in src.output_keys and key not in previous_output_keys:
                            graph.add_edge(pydot.Edge(src.name, node.name, label=key))
                            break
        # build connections to the output of the system in a reversed order
        previous_output_keys = []
        for node in self.nodes[::-1]:
            for key in (set(node.output_keys) - set(previous_output_keys)):
                graph.add_edge(pydot.Edge(node.name, 'out', label=key))
            previous_output_keys += node.output_keys
        self.input_keys = list(set(init_keys))
        self.output_keys = list(set(output_keys))
        return graph
    def show(self, figname=None):
        """
        Render the system graph. With figname, write to that file (extension
        selects the format: svg, png, or jpg); otherwise display it via
        matplotlib using a temporary png file.
        :param figname: (str or None) output file path
        """
        graph = self.graph()
        if figname is not None:
            plot_func = {'svg': graph.write_svg,
                         'png': graph.write_png,
                         'jpg': graph.write_jpg}
            ext = figname.split('.')[-1]
            plot_func[ext](figname)
        else:
            graph.write_png('system_graph.png')
            img = mpimg.imread('system_graph.png')
            os.remove('system_graph.png')
            plt.figure()
            fig = plt.imshow(img, aspect='equal')
            fig.axes.get_xaxis().set_visible(False)
            fig.axes.get_yaxis().set_visible(False)
            plt.show()
    def _check_unique_names(self):
        """
        Assert that no two explicitly named nodes share a name. Unnamed nodes
        (name None or '') are skipped here because graph() assigns them unique
        default names afterward.
        """
        # NOTE: comparing against the de-duplicated set actually detects clashes;
        # comparing two list lengths (as before) could never fail.
        named = [node.name for node in self.nodes if node.name]
        assert len(set(named)) == len(named), \
            "All system nodes must have unique names " \
            "to construct a computational graph."
    def cat(self, data3d, data2d):
        """
        Concatenates data2d contents to corresponding entries in data3d
        along the time dimension (dim=1).
        :param data3d: (dict {str: Tensor}) Input to a node, shapes (batch, time, dim)
        :param data2d: (dict {str: Tensor}) Output of a node, shapes (batch, dim)
        :return: (dict: {str: Tensor})
        """
        for k in data2d:
            if k not in data3d:
                data3d[k] = data2d[k][:, None, :]
            else:
                data3d[k] = torch.cat([data3d[k], data2d[k][:, None, :]], dim=1)
        return data3d
    def init(self, data):
        """
        Default no-op initialization hook; replaced by init_func when provided.
        Any nodes in the graph that are start nodes will need some data initialized,
        e.g. an x0 entry in the input dict. Keys for source nodes have to be in the data.
        :param data: (dict: {str: Tensor}) Tensor shapes in dictionary are assumed to be (batch, time, dim)
        :return: (dict: {str: Tensor})
        """
        return data
    def forward(self, input_dict):
        """
        :param input_dict: (dict: {str: Tensor}) Tensor shapes in dictionary are assumed to be (batch, time, dim).
                           An init function should be written to assure that any 2-d or 1-d tensors have 3 dims.
        :return: (dict: {str: Tensor}) data with outputs of nstep rollout of Node interactions
        """
        data = input_dict.copy()
        nsteps = self.nsteps if self.nsteps is not None else data[self.nstep_key].shape[1]  # Infer number of rollout steps
        data = self.init(data)  # Set initial conditions of the system
        for i in range(nsteps):
            for node in self.nodes:
                indata = {k: data[k][:, i] for k in node.input_keys}  # collect what the compute node needs from data nodes
                outdata = node(indata)  # compute
                data = self.cat(data, outdata)  # feed the data nodes
        return data  # return recorded system measurements
    def freeze(self):
        """
        Freezes the parameters of all nodes in the system
        """
        for node in self.nodes:
            node.freeze()
    def unfreeze(self):
        """
        Unfreezes the parameters of all nodes in the system
        """
        for node in self.nodes:
            node.unfreeze()
class SystemPreview(System):
    """
    System class with preview (lookahead) of future known variables: selected
    nodes receive a flattened window of current + future values of a variable
    instead of a single time step.
    """
    def __init__(self, nodes, preview_keys_map=None, preview_length=None, pad_mode='circular', pad_constant=0.0,
                 name=None, nstep_key='X', init_func=None, nsteps=None):
        """
        :param nodes: (list of Node objects)
        :param preview_keys_map: (dict {str: list of str}) key: variable name to be previewed; value: names of nodes which expect the preview of this variable
        :param preview_length: (dict {str: int}) key: variable name; value: length of the future preview window (defaults to nsteps for each previewed variable)
        :param pad_mode: (str) Options - 'replicate', 'circular', 'constant' (default value is 0), 'reflect'; more info at https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        :param pad_constant: (float) if pad_mode is 'constant' this specifies the value of padded samples
        :param name: (str) Unique identifier for system class.
        :param nstep_key: (str) Key is used to infer number of rollout steps from input_data
        :param init_func: (callable(input_dict) -> input_dict) This function is used to set initial conditions of the system
        :param nsteps: (int) prediction horizon (rollout steps) length
        """
        # Delegate all common setup (nodes, name, rollout config, graph build) to the
        # base class instead of duplicating it here and constructing the graph twice.
        super().__init__(nodes=nodes, name=name, nstep_key=nstep_key,
                         init_func=init_func, nsteps=nsteps)
        # Default to a fresh dict per instance (avoids a shared mutable default argument).
        self.preview_keys_map = {} if preview_keys_map is None else preview_keys_map
        self.pad_mode = pad_mode
        self.pad_constant = pad_constant if self.pad_mode == 'constant' else None
        self.preview_length = (preview_length if preview_length is not None
                               else {k: nsteps for k in self.preview_keys_map})
    def get_data_with_preview(self, input_dict, var_name, iteration):
        """
        Extracts a temporal slice of data for a given variable with a preview window.
        This function returns the data segment starting from the current timestep
        (`iteration`) and extends for `preview_length` timesteps into the future.
        If there are not enough timesteps available in the input tensor (e.g., near
        the end of the sequence), the data is padded (with constant, replicate,
        reflect, or circular modes, depending on `self.pad_mode` and `self.pad_constant`).
        :param input_dict: (dict: {str: Tensor}) Dictionary of tensors with shape (batch, time, dim).
        :param var_name: (str) Key identifying the variable in `input_dict`.
        :param iteration: (int) Current timestep of the rollout.
        :return: (Tensor) Data slice of shape (batch, 1+preview_length, dim).
                 Includes the current timestep and future preview steps.
        """
        window = 1 + self.preview_length[var_name]
        data = input_dict[var_name][:, iteration:iteration + window, :]  # slice input data with future window
        if data.shape[1] < window:  # not enough future samples: pad the full sequence, then re-slice
            padded = nn.functional.pad(input_dict[var_name],
                                       (0, 0, 0, self.preview_length[var_name]),
                                       mode=self.pad_mode, value=self.pad_constant)
            data = padded[:, iteration:iteration + window, :]
        return data
    def forward(self, input_dict):
        """
        :param input_dict: (dict: {str: Tensor}) Tensor shapes in dictionary are assumed to be (batch, time, dim).
                           An init function should be written to assure that any 2-d or 1-d tensors have 3 dims.
        :return: (dict: {str: Tensor}) data with outputs of nstep rollout of Node interactions
        """
        data = input_dict.copy()
        nsteps = self.nsteps if self.nsteps is not None else data[self.nstep_key].shape[1]  # Infer number of rollout steps
        data = self.init(data)  # Set initial conditions of the system
        for i in range(nsteps):
            for node in self.nodes:
                indata = {}
                for k in node.input_keys:
                    if k in self.preview_keys_map and node.name in self.preview_keys_map[k]:
                        # Fetch current + future samples and flatten (batch, time, dim) to
                        # (batch, time*dim), e.g. [batch, [r1_t0, r2_t0, r1_t1, r2_t1, ...]]
                        preview = self.get_data_with_preview(input_dict=data, var_name=k, iteration=i)
                        indata[k] = preview.reshape(data[k].size(0), -1)
                    else:
                        indata[k] = data[k][:, i]  # current timestep only, e.g. [batch, [r1_ti, r2_ti, ...]]
                outdata = node(indata)  # compute
                data = self.cat(data, outdata)  # feed the data nodes
        return data  # return recorded system measurements