Skip to content

Commit 2e6ab85

Browse files
authored
Bugfix ModelCreator for required model parameters and user adjusted model parameters (#2780)
* bugfix _check_model_params in ModelCreator * undo debug info in error message
1 parent 42f607b commit 2e6ab85

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

mesa/visualization/solara_viz.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -553,18 +553,18 @@ def ModelCreator(
553553
model_parameters = solara.use_reactive(model_parameters)
554554

555555
solara.use_effect(
556-
lambda: _check_model_params(model.value.__class__.__init__, fixed_params),
556+
lambda: _check_model_params(model.value.__class__.__init__, user_params),
557557
[model.value],
558558
)
559-
user_params, fixed_params = split_model_params(user_params)
559+
user_adjust_params, fixed_params = split_model_params(user_params)
560560

561561
# Use solara.use_effect to run the initialization code only once
562562
solara.use_effect(
563563
# set model_parameters to the default values for all parameters
564564
lambda: model_parameters.set(
565565
{
566566
**fixed_params,
567-
**{k: v.get("value") for k, v in user_params.items()},
567+
**{k: v.get("value") for k, v in user_adjust_params.items()},
568568
}
569569
),
570570
[],
@@ -574,7 +574,7 @@ def ModelCreator(
574574
def on_change(name, value):
575575
model_parameters.value = {**model_parameters.value, name: value}
576576

577-
UserInputs(user_params, on_change=on_change)
577+
UserInputs(user_adjust_params, on_change=on_change)
578578

579579

580580
def _check_model_params(init_func, model_params):

tests/test_solara_viz.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from mesa.visualization.components.altair_components import make_altair_space
1414
from mesa.visualization.components.matplotlib_components import make_mpl_space_component
1515
from mesa.visualization.solara_viz import (
16+
ModelCreator,
1617
Slider,
1718
SolaraViz,
1819
UserInputs,
@@ -273,6 +274,36 @@ def __init__(self, **kwargs):
273274
_check_model_params(ModelWithOnlyRequired.__init__, {})
274275

275276

277+
def test_model_creator(): # noqa: D103
278+
class ModelWithRequiredParam:
279+
def __init__(self, param1):
280+
pass
281+
282+
solara.render(
283+
ModelCreator(
284+
solara.reactive(ModelWithRequiredParam(param1="mock")),
285+
user_params={"param1": 1},
286+
),
287+
handle_error=False,
288+
)
289+
290+
solara.render(
291+
ModelCreator(
292+
solara.reactive(ModelWithRequiredParam(param1="mock")),
293+
user_params={"param1": Slider("Param1", 10, 10, 100, 1)},
294+
),
295+
handle_error=False,
296+
)
297+
298+
with pytest.raises(ValueError, match="Missing required model parameter"):
299+
solara.render(
300+
ModelCreator(
301+
solara.reactive(ModelWithRequiredParam(param1="mock")), user_params={}
302+
),
303+
handle_error=False,
304+
)
305+
306+
276307
# test that _check_model_params raises ValueError when *args are present
277308
def test_check_model_params_with_args_only():
278309
"""Test that _check_model_params raises ValueError when *args are present."""

0 commit comments

Comments
 (0)