Skip to content

Commit 4aa02fd

Browse files
committed
Convert make_space into solara components
1 parent 47637a7 commit 4aa02fd

File tree

1 file changed

+40
-23
lines changed

1 file changed

+40
-23
lines changed

mesa/experimental/jupyter_viz.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,7 @@ def make_model():
4848

4949
with solara.GridFixed(columns=2):
5050
# 4. Space
51-
if space_drawer is None:
52-
make_space(model, agent_portrayal)
53-
else:
54-
space_drawer(model, agent_portrayal)
51+
SpaceView(space_drawer, model, agent_portrayal)
5552
# 5. Plots
5653
for measure in measures:
5754
if callable(measure):
@@ -182,15 +179,16 @@ def make_user_input(user_input, name, options):
182179
raise ValueError(f"{input_type} is not a supported input type")
183180

184181

185-
def make_space(model, agent_portrayal):
186-
def portray(g):
182+
@solara.component
183+
def GridView(grid, agent_portrayal):
184+
def portray(grid):
187185
x = []
188186
y = []
189-
s = [] # size
190-
c = [] # color
191-
for i in range(g.width):
192-
for j in range(g.height):
193-
content = g._grid[i][j]
187+
sizes = []
188+
colors = []
189+
for i in range(grid.width):
190+
for j in range(grid.height):
191+
content = grid._grid[i][j]
194192
if not content:
195193
continue
196194
if not hasattr(content, "__iter__"):
@@ -201,35 +199,54 @@ def portray(g):
201199
x.append(i)
202200
y.append(j)
203201
if "size" in data:
204-
s.append(data["size"])
202+
sizes.append(data["size"])
205203
if "color" in data:
206-
c.append(data["color"])
204+
colors.append(data["color"])
207205
out = {"x": x, "y": y}
208-
if len(s) > 0:
209-
out["s"] = s
210-
if len(c) > 0:
211-
out["c"] = c
206+
if len(sizes) > 0:
207+
out["s"] = sizes
208+
if len(colors) > 0:
209+
out["c"] = colors
212210
return out
213211

214212
space_fig = Figure()
215213
space_ax = space_fig.subplots()
216-
if isinstance(model.grid, mesa.space.NetworkGrid):
217-
_draw_network_grid(model, space_ax, agent_portrayal)
218-
else:
219-
space_ax.scatter(**portray(model.grid))
214+
space_ax.scatter(**portray(grid))
220215
space_ax.set_axis_off()
221216
solara.FigureMatplotlib(space_fig)
222217

223218

224-
def _draw_network_grid(model, space_ax, agent_portrayal):
225-
graph = model.grid.G
219+
@solara.component
220+
def SpaceView(
221+
space_drawer,
222+
model,
223+
agent_portrayal,
224+
):
225+
if space_drawer is not None:
226+
return space_drawer(model, agent_portrayal)
227+
228+
if isinstance(model.grid, mesa.space.NetworkGrid):
229+
return NetworkSpace(model.grid.G, agent_portrayal)
230+
231+
if isinstance(model.grid, mesa.space._Grid):
232+
return GridView(model.grid, agent_portrayal)
233+
234+
raise ValueError(f"Unsupported space type: {type(model.grid)}")
235+
236+
237+
@solara.component
238+
def NetworkSpace(graph, agent_portrayal):
239+
space_fig = Figure()
240+
space_ax = space_fig.subplots()
226241
pos = nx.spring_layout(graph, seed=0)
227242
nx.draw(
228243
graph,
229244
ax=space_ax,
230245
pos=pos,
231246
**agent_portrayal(graph),
232247
)
248+
space_ax.set_axis_off()
249+
solara.FigureMatplotlib(space_fig)
233250

234251

235252
def make_plot(model, measure):

0 commit comments

Comments
 (0)