Skip to content

Commit 1ef4854

Browse files
authored
More fine-grained control for custom sampling schemes (#419)
1 parent 299bfef commit 1ef4854

File tree

6 files changed

+232
-68
lines changed

6 files changed

+232
-68
lines changed

ema_workbench/em_framework/callbacks.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
import numpy as np
1919
import pandas as pd
2020

21+
from .outcomes import AbstractOutcome
2122
from ..util import ema_exceptions, get_module_logger
22-
from .parameters import BooleanParameter, CategoricalParameter, IntegerParameter
23+
from .parameters import BooleanParameter, CategoricalParameter, IntegerParameter, Parameter
2324
from .util import ProgressTrackingMixIn
2425

2526
#
@@ -157,13 +158,13 @@ class DefaultCallback(AbstractCallback):
157158

158159
def __init__(
159160
self,
160-
uncertainties,
161-
levers,
162-
outcomes,
163-
nr_experiments,
164-
reporting_interval=100,
165-
reporting_frequency=10,
166-
log_progress=False,
161+
uncertainties:list[Parameter],
162+
levers:list[Parameter],
163+
outcomes:list[AbstractOutcome],
164+
nr_experiments:int,
165+
reporting_interval:int=100,
166+
reporting_frequency:int=10,
167+
log_progress:bool=False,
167168
):
168169
"""Init.
169170

ema_workbench/em_framework/evaluators.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def perform_experiments(
337337
lever_sampling_kwargs:dict|None=None,
338338
callback:type[AbstractCallback]|None=None,
339339
return_callback:bool=False,
340-
combine:str="factorial",
340+
combine:Literal["factorial", "sample"]="factorial",
341341
log_progress:bool=False,
342342
**kwargs,
343343
) -> DefaultCallback:
@@ -382,7 +382,7 @@ def perform_experiments(
382382
383383
"""
384384
# TODO:: break up in to helper functions
385-
# unreadable in this form
385+
# unreadable in this form
386386

387387
if not scenarios and not policies:
388388
raise EMAError(
@@ -430,7 +430,7 @@ def perform_experiments(
430430
f"{nr_of_exp} experiments"
431431
)
432432
case _:
433-
raise ValueError(f'unknown value for combine, got {combine}, should be one of "zipover" or "factorial"')
433+
raise ValueError(f'unknown value for combine, got {combine}, should be one of "sample" or "factorial"')
434434

435435
callback = setup_callback(
436436
callback,

ema_workbench/em_framework/points.py

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
"""
66

77
import itertools
8-
import random
98
from collections import ChainMap
109
from collections.abc import Iterable
10+
from typing import Literal
1111

1212
from ema_workbench.em_framework.util import Counter, NamedDict, NamedObject, combine
1313
from ema_workbench.util import get_module_logger
@@ -19,7 +19,6 @@
1919
"Policy",
2020
"Scenario",
2121
"combine_cases_factorial",
22-
"combine_cases_sampling",
2322
"experiment_generator",
2423
]
2524
_logger = get_module_logger(__name__)
@@ -40,9 +39,10 @@ def __init__(self, name=None, unique_id=None, **kwargs):
4039
super().__init__(name, **kwargs)
4140
self.unique_id = unique_id
4241

43-
def __repr__(self): # noqa D105
42+
def __repr__(self): # noqa D105
4443
return f"Point({super().__repr__()})"
4544

45+
4646
class Policy(Point):
4747
"""Helper class representing a policy.
4848
@@ -66,7 +66,7 @@ def __init__(self, name=None, **kwargs):
6666
#
6767
# return [self[param.name] for param in parameters]
6868

69-
def __repr__(self): # noqa D105
69+
def __repr__(self): # noqa D105
7070
return f"Policy({super().__repr__()})"
7171

7272

@@ -88,7 +88,7 @@ class Scenario(Point):
8888
def __init__(self, name=None, **kwargs):
8989
super().__init__(name, unique_id=Scenario.id_counter(), **kwargs)
9090

91-
def __repr__(self): # noqa: D105
91+
def __repr__(self): # noqa: D105
9292
return f"Scenario({super().__repr__()})"
9393

9494

@@ -105,14 +105,21 @@ class Experiment(NamedObject):
105105
106106
"""
107107

108-
def __init__(self, name, model_name, policy, scenario, experiment_id):
108+
def __init__(
109+
self,
110+
name: str,
111+
model_name: str,
112+
policy: Policy,
113+
scenario: Scenario,
114+
experiment_id: int,
115+
):
109116
super().__init__(name)
110117
self.experiment_id = experiment_id
111118
self.policy = policy
112119
self.model_name = model_name
113120
self.scenario = scenario
114121

115-
def __repr__(self): # noqa: D105
122+
def __repr__(self): # noqa: D105
116123
return (
117124
f"Experiment(name={self.name!r}, model_name={self.model_name!r}, "
118125
f"policy={self.policy!r}, scenario={self.scenario!r}, "
@@ -161,36 +168,6 @@ def zip_cycle(*args):
161168
return itertools.islice(zip(*(itertools.cycle(a) for a in args)), max_len)
162169

163170

164-
def combine_cases_sampling(*point_collection):
165-
"""Helper function for combining cases sampling.
166-
167-
Combine collections of cases by iterating over the longest collection
168-
while sampling with replacement from the others.
169-
170-
Parameters
171-
----------
172-
point_collection : collection of collection of Point instances
173-
174-
Yields
175-
------
176-
Point
177-
178-
"""
179-
180-
# figure out the longest
181-
def exhaust_cases(cases):
182-
return list(cases)
183-
184-
point_collection = [exhaust_cases(case) for case in point_collection]
185-
longest_cases = max(point_collection, key=len)
186-
other_cases = [case for case in point_collection if case is not longest_cases]
187-
188-
for case in longest_cases:
189-
other = (random.choice(entry) for entry in other_cases)
190-
191-
yield Point(**ChainMap(case, *other))
192-
193-
194171
def combine_cases_factorial(*point_collections):
195172
"""Combine collections of cases in a full factorial manner.
196173
@@ -234,7 +211,12 @@ def combine_cases_factorial(*point_collections):
234211
# return combined_cases
235212

236213

237-
def experiment_generator(models:Iterable["AbstractModel"], scenarios:Iterable[Scenario], policies:Iterable[Policy], combine:str="factorial"):
214+
def experiment_generator(
215+
models: Iterable["AbstractModel"],
216+
scenarios: Iterable[Scenario],
217+
policies: Iterable[Policy],
218+
combine: Literal["factorial", "sample"] = "factorial",
219+
):
238220
"""Generator function which yields experiments.
239221
240222
Parameters

0 commit comments

Comments
 (0)