Skip to content

Commit 87e6134

Browse files
committed
Improved handling of boolean variables and added a dedicated unit test for it
1 parent ddefa34 commit 87e6134

File tree

3 files changed

+38
-12
lines changed

3 files changed

+38
-12
lines changed

ema_workbench/analysis/scenario_discovery_util.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -221,26 +221,24 @@ def _in_box(x, boxlim):
221221
category dtype
222222
223223
"""
224-
x_numbered = x.select_dtypes(np.number)
225-
boxlim_numbered = boxlim.select_dtypes(np.number)
224+
number_like = [np.number, np.bool]
225+
x_numbered = x.select_dtypes(number_like)
226+
boxlim_numbered = boxlim.select_dtypes(number_like)
226227
logical = (boxlim_numbered.loc[0, :].values <= x_numbered.values) & (
227228
x_numbered.values <= boxlim_numbered.loc[1, :].values
228229
)
229230
logical = logical.all(axis=1)
230231

231232
# TODO: how to speed this up
232-
for column, values in x.select_dtypes(exclude=np.number).items():
233+
234+
for column, values in x.select_dtypes(exclude=number_like).items():
233235
entries = boxlim.loc[0, column]
234-
if values.dtype == np.dtype(np.bool):
235-
l = x[column] == entries
236-
logical = logical & l
237-
else:
238-
not_present = set(values.cat.categories.values) - entries
236+
not_present = set(values.cat.categories.values) - entries
239237

240-
if not_present:
241-
# what other options do we have here....
242-
l = pd.isnull(x[column].cat.remove_categories(list(entries))) # noqa: E741
243-
logical = l & logical
238+
if not_present:
239+
# what other options do we have here....
240+
l = pd.isnull(x[column].cat.remove_categories(list(entries))) # noqa: E741
241+
logical = l & logical
244242
return logical
245243

246244

ema_workbench/examples/sd_cart_wcm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
results = load_results(fn)
2020
x, outcomes = results
2121

22+
x = x.drop(["scenario", "model", "policy"], axis=1)
23+
2224
ooi = "throughput_Rotterdam"
2325
outcome = outcomes[ooi] / default_flow
2426
y = outcome < 1

test/test_analysis/test_scenario_discovery_util.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,32 @@ def test_in_box(self):
8989
result = x.loc[logical]
9090
self.assertTrue(np.all(correct_result == result))
9191

92+
x = pd.DataFrame(
93+
[
94+
(0.1, 0, "a", True),
95+
(1.1, 1, "a", True),
96+
(2.1, 2, "b", True),
97+
(3.1, 3, "b", True),
98+
(4.1, 4, "c", False),
99+
(5.1, 5, "c", False),
100+
(6.1, 6, "d", False),
101+
(7.1, 7, "d", False),
102+
(8.1, 8, "e", False),
103+
(9.1, 9, "e", False),
104+
],
105+
columns=["a", "b", "c", "d"],
106+
)
107+
boxlim = pd.DataFrame(
108+
[(1.2, 0, {"a", "b", "c"}, True), (8.0, 7, {"a", "b", "c"}, True)], columns=["a", "b", "c", "d"]
109+
)
110+
x["c"] = x["c"].astype("category")
111+
112+
correct_result = x.loc[[2, 3], :]
113+
logical = sdutil._in_box(x, boxlim)
114+
result = x.loc[logical]
115+
self.assertTrue(np.all(correct_result == result))
116+
117+
92118
def test_make_box(self):
93119
x = pd.DataFrame([(0, 1, 2), (2, 5, 6), (3, 2, 1)], columns=["a", "b", "c"])
94120

0 commit comments

Comments
 (0)