55"""
66
77import itertools
8- import random
98from collections import ChainMap
109from collections .abc import Iterable
10+ from typing import Literal
1111
1212from ema_workbench .em_framework .util import Counter , NamedDict , NamedObject , combine
1313from ema_workbench .util import get_module_logger
1919 "Policy" ,
2020 "Scenario" ,
2121 "combine_cases_factorial" ,
22- "combine_cases_sampling" ,
2322 "experiment_generator" ,
2423]
2524_logger = get_module_logger (__name__ )
@@ -40,9 +39,10 @@ def __init__(self, name=None, unique_id=None, **kwargs):
4039 super ().__init__ (name , ** kwargs )
4140 self .unique_id = unique_id
4241
43- def __repr__ (self ): # noqa D105
42+ def __repr__ (self ): # noqa D105
4443 return f"Point({ super ().__repr__ ()} )"
4544
45+
4646class Policy (Point ):
4747 """Helper class representing a policy.
4848
@@ -66,7 +66,7 @@ def __init__(self, name=None, **kwargs):
6666 #
6767 # return [self[param.name] for param in parameters]
6868
69- def __repr__ (self ): # noqa D105
69+ def __repr__ (self ): # noqa D105
7070 return f"Policy({ super ().__repr__ ()} )"
7171
7272
@@ -88,7 +88,7 @@ class Scenario(Point):
8888 def __init__ (self , name = None , ** kwargs ):
8989 super ().__init__ (name , unique_id = Scenario .id_counter (), ** kwargs )
9090
91- def __repr__ (self ): # noqa: D105
91+ def __repr__ (self ): # noqa: D105
9292 return f"Scenario({ super ().__repr__ ()} )"
9393
9494
@@ -105,14 +105,21 @@ class Experiment(NamedObject):
105105
106106 """
107107
108- def __init__ (self , name , model_name , policy , scenario , experiment_id ):
108+ def __init__ (
109+ self ,
110+ name : str ,
111+ model_name : str ,
112+ policy : Policy ,
113+ scenario : Scenario ,
114+ experiment_id : int ,
115+ ):
109116 super ().__init__ (name )
110117 self .experiment_id = experiment_id
111118 self .policy = policy
112119 self .model_name = model_name
113120 self .scenario = scenario
114121
115- def __repr__ (self ): # noqa: D105
122+ def __repr__ (self ): # noqa: D105
116123 return (
117124 f"Experiment(name={ self .name !r} , model_name={ self .model_name !r} , "
118125 f"policy={ self .policy !r} , scenario={ self .scenario !r} , "
@@ -161,36 +168,6 @@ def zip_cycle(*args):
161168 return itertools .islice (zip (* (itertools .cycle (a ) for a in args )), max_len )
162169
163170
164- def combine_cases_sampling (* point_collection ):
165- """Helper function for combining cases sampling.
166-
167- Combine collections of cases by iterating over the longest collection
168- while sampling with replacement from the others.
169-
170- Parameters
171- ----------
172- point_collection : collection of collection of Point instances
173-
174- Yields
175- ------
176- Point
177-
178- """
179-
180- # figure out the longest
181- def exhaust_cases (cases ):
182- return list (cases )
183-
184- point_collection = [exhaust_cases (case ) for case in point_collection ]
185- longest_cases = max (point_collection , key = len )
186- other_cases = [case for case in point_collection if case is not longest_cases ]
187-
188- for case in longest_cases :
189- other = (random .choice (entry ) for entry in other_cases )
190-
191- yield Point (** ChainMap (case , * other ))
192-
193-
194171def combine_cases_factorial (* point_collections ):
195172 """Combine collections of cases in a full factorial manner.
196173
@@ -234,7 +211,12 @@ def combine_cases_factorial(*point_collections):
234211# return combined_cases
235212
236213
237- def experiment_generator (models :Iterable ["AbstractModel" ], scenarios :Iterable [Scenario ], policies :Iterable [Policy ], combine :str = "factorial" ):
214+ def experiment_generator (
215+ models : Iterable ["AbstractModel" ],
216+ scenarios : Iterable [Scenario ],
217+ policies : Iterable [Policy ],
218+ combine : Literal ["factorial" , "sample" ] = "factorial" ,
219+ ):
238220 """Generator function which yields experiments.
239221
240222 Parameters
0 commit comments