Skip to content

Commit 299bfef

Browse files
authored
remove threshold and threshold type from PRIM, add explicit support for Prim for regression (i.e., bump hunting) (#418)
1 parent 6323372 commit 299bfef

22 files changed

+1239
-1093
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ jobs:
4444
- run: pip install uv
4545
- name: Install dependencies
4646
run: |
47-
uv pip install .[dev,cov] --system ${{ matrix.pip-pre }}
47+
uv pip install .[dev,cov,graph] --system ${{ matrix.pip-pre }}
4848
- name: Install MPI and mpi4py
4949
if: matrix.test-mpi == true
5050
run: |

ema_workbench/analysis/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
__all__ = [
88
"CART",
99
"Density",
10-
"DiagKind",
1110
"Logit",
1211
"PlotType",
1312
"Prim",
@@ -24,13 +23,11 @@
2423
"pca_preprocess",
2524
"run_constrained_prim",
2625
"set_fig_to_bw",
27-
"setup_cart",
28-
"setup_prim"
2926
]
3027

3128
from . import pairs_plotting
3229
from .b_and_w_plotting import set_fig_to_bw
33-
from .cart import CART, setup_cart
30+
from .cart import CART
3431
from .feature_scoring import (
3532
get_ex_feature_scores,
3633
get_feature_scores_all,
@@ -40,6 +37,5 @@
4037
from .logistic_regression import Logit
4138
from .plotting import envelopes, kde_over_time, lines, multiple_densities
4239
from .plotting_util import Density, PlotType
43-
from .prim import Prim, pca_preprocess, run_constrained_prim, setup_prim
44-
from .prim_util import DiagKind
40+
from .prim import Prim, pca_preprocess, run_constrained_prim
4541
from .scenario_discovery_util import RuleInductionType

ema_workbench/analysis/cart.py

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -25,48 +25,10 @@
2525
# .. codeauthor:: jhkwakkel <j.h.kwakkel (at) tudelft (dot) nl>
2626

2727

28-
__all__ = ["CART", "setup_cart"]
28+
__all__ = ["CART"]
2929
_logger = get_module_logger(__name__)
3030

3131

32-
def setup_cart(results, classify, incl_unc=None, mass_min=0.05):
33-
"""Helper function for performing cart in combination with data generated by the workbench.
34-
35-
Parameters
36-
----------
37-
results : tuple of DataFrame and dict with numpy arrays
38-
the return from :meth:`perform_experiments`.
39-
classify : string, function or callable
40-
either a string denoting the outcome of interest to
41-
use or a function.
42-
incl_unc : list of strings, optional
43-
mass_min : float, optional
44-
45-
46-
Raises:
47-
------
48-
TypeError
49-
if classify is not a string or a callable.
50-
51-
"""
52-
x, outcomes = results
53-
54-
if incl_unc is not None:
55-
drop_names = set(x.columns.values.tolist()) - set(incl_unc)
56-
x = x.drop(drop_names, axis=1)
57-
58-
if isinstance(classify, str):
59-
y = outcomes[classify]
60-
mode = sdutil.RuleInductionType.REGRESSION
61-
elif callable(classify):
62-
y = classify(outcomes)
63-
mode = sdutil.RuleInductionType.BINARY
64-
else:
65-
raise TypeError(f"Unknown type for classify: {type(classify)}")
66-
67-
return CART(x, y, mass_min, mode=mode)
68-
69-
7032
class CART(sdutil.OutputFormatterMixin):
7133
"""CART algorithm.
7234

ema_workbench/analysis/dimensional_stacking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def create_pivot_plot(
411411

412412
x_y_concat = pd.concat([discretized_x, ooi], axis=1)
413413
pvt = pd.pivot_table(
414-
x_y_concat, values=ooi_label, index=rows, columns=columns, dropna=False
414+
x_y_concat, values=ooi_label, index=rows, columns=columns, dropna=False, observed=False
415415
)
416416

417417
fig = plot_pivot_table(pvt, plot_labels=labels, plot_cats=categories)

ema_workbench/analysis/logistic_regression.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,9 @@ class Logit:
136136

137137
# TODO:: peeling trajectory is a misnomer, requires fix to CurEntry
138138

139-
coverage = CurEntry("coverage")
140-
density = CurEntry("density")
141-
res_dim = CurEntry("res_dim")
139+
coverage = CurEntry(float)
140+
density = CurEntry(float)
141+
res_dim = CurEntry(int)
142142

143143
sep = "!?!"
144144

0 commit comments

Comments
 (0)