Skip to content

Commit 3dbb5b6

Browse files
committed
Allow all Agent subclasses
1 parent 86b184d commit 3dbb5b6

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

mesa/datacollection.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -256,15 +256,14 @@ def get_reports(agent):
256256
else:
257257
from mesa import Agent
258258

259-
# Check if agent_type is an Agent subclass
260259
if issubclass(agent_type, Agent):
261-
raise NotImplementedError(
262-
f"Agent type {agent_type} is not in model.agent_types. We might implement using superclasses in the future. For now, use one of {agent_types}."
263-
)
260+
agents = [
261+
agent for agent in model.agents if isinstance(agent, agent_type)
262+
]
264263
else:
265264
# Raise error if agent_type is not in model.agent_types
266265
raise ValueError(
267-
f"Agent type {agent_type} is not recognized as an Agent type in the model. Use one of {agent_types}."
266+
f"Agent type {agent_type} is not recognized as an Agent type in the model or Agent subclass. Use an Agent (sub)class, like {agent_types}."
268267
)
269268

270269
agenttype_records = map(get_reports, agents)

tests/test_datacollector.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,14 +366,22 @@ def test_agenttype_reporter_multiple_types(self):
366366
self.assertNotIn("type_b_val", agent_a_data.columns)
367367
self.assertNotIn("type_a_val", agent_b_data.columns)
368368

369-
def test_agenttype_reporter_not_in_model(self):
370-
"""Test NotImplementedError is raised when agent type is not in model.agents_by_type."""
369+
def test_agenttype_superclass_reporter(self):
370+
"""Test adding a reporter for a superclass of an agent type."""
371371
model = MockModelWithAgentTypes()
372-
# MockAgent is a legit Agent subclass, but it is not in model.agents_by_type
373372
model.datacollector._new_agenttype_reporter(MockAgent, "val", lambda a: a.val)
374-
with self.assertRaises(NotImplementedError):
373+
model.datacollector._new_agenttype_reporter(Agent, "val", lambda a: a.val)
374+
for _ in range(3):
375375
model.step()
376376

377+
super_data = model.datacollector.get_agenttype_vars_dataframe(MockAgent)
378+
agent_data = model.datacollector.get_agenttype_vars_dataframe(Agent)
379+
self.assertIn("val", super_data.columns)
380+
self.assertIn("val", agent_data.columns)
381+
self.assertEqual(len(super_data), 30) # 10 agents * 3 steps
382+
self.assertEqual(len(agent_data), 30)
383+
self.assertTrue(super_data.equals(agent_data))
384+
377385

378386
if __name__ == "__main__":
379387
unittest.main()

0 commit comments

Comments
 (0)