Skip to content

Commit fa6df11

Browse files
committed
updates
1 parent 8974e17 commit fa6df11

File tree

13 files changed

+167
-164
lines changed

13 files changed

+167
-164
lines changed

ema_workbench/analysis/cart.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def __init__(
106106
self._stats = None
107107

108108
@property
109-
def boxes(self):
109+
def boxes(self) -> list[pd.DataFrame]:
110110
"""Return a list with the box limits for each terminal leaf.
111111
112112
Returns
@@ -185,7 +185,7 @@ def recurse(left, right, child, lineage=None):
185185
return self._boxes
186186

187187
@property
188-
def stats(self):
188+
def stats(self) -> list[dict]:
189189
"""Returns list with the scenario discovery statistics for each terminal leaf.
190190
191191
Returns
@@ -206,7 +206,7 @@ def stats(self):
206206
self._stats.append(boxstats)
207207
return self._stats
208208

209-
def _binary_stats(self, box, box_init):
209+
def _binary_stats(self, box: pd.DataFrame, box_init: pd.DataFrame) -> dict:
210210
indices = sdutil._in_box(self.x, box)
211211

212212
y_in_box = self.y[indices]
@@ -220,7 +220,7 @@ def _binary_stats(self, box, box_init):
220220
}
221221
return boxstats
222222

223-
def _regression_stats(self, box, box_init):
223+
def _regression_stats(self, box: pd.DataFrame, box_init: pd.DataFrame) -> dict:
224224
indices = sdutil._in_box(self.x, box)
225225

226226
y_in_box = self.y[indices]
@@ -232,7 +232,7 @@ def _regression_stats(self, box, box_init):
232232
}
233233
return boxstats
234234

235-
def _classification_stats(self, box, box_init):
235+
def _classification_stats(self, box: pd.DataFrame, box_init: pd.DataFrame) -> dict:
236236
indices = sdutil._in_box(self.x, box)
237237

238238
y_in_box = self.y[indices]
@@ -261,7 +261,7 @@ def _classification_stats(self, box, box_init):
261261
sdutil.RuleInductionType.CLASSIFICATION: _classification_stats,
262262
}
263263

264-
def build_tree(self):
264+
def build_tree(self) -> None:
265265
"""Train CART on the data."""
266266
min_samples = int(self.mass_min * self.x.shape[0])
267267

@@ -271,7 +271,7 @@ def build_tree(self):
271271
self.clf = tree.DecisionTreeClassifier(min_samples_leaf=min_samples)
272272
self.clf.fit(self._x, self.y)
273273

274-
def show_tree(self, mplfig=True, format="png"):
274+
def show_tree(self, mplfig: bool = True, format: str = "png"):
275275
"""Return a png (defaults) or svg of the tree.
276276
277277
On Windows, graphviz needs to be installed with conda.

ema_workbench/analysis/pairs_plotting.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
from __future__ import annotations
44

5+
import matplotlib.axes
56
import matplotlib.cm as cm
67
import matplotlib.gridspec as gridspec
78
import matplotlib.pyplot as plt
89
import numpy as np
10+
import pandas as pd
911

1012
from ..util import get_module_logger
1113

@@ -19,15 +21,15 @@
1921

2022

2123
def pairs_lines(
22-
experiments,
23-
outcomes,
24-
outcomes_to_show=None,
25-
group_by=None,
24+
experiments: pd.DataFrame,
25+
outcomes: dict[str, np.ndarray],
26+
outcomes_to_show: list[str] | None = None,
27+
group_by: str | None = None,
2628
grouping_specifiers=None,
27-
ylabels=None,
28-
legend=True,
29+
ylabels: dict[str, str] | None = None,
30+
legend: bool = True,
2931
**kwargs,
30-
):
32+
) -> tuple[plt.Figure, dict[str, matplotlib.axes.Axes]]:
3133
"""Generate a pairs lines multiplot.
3234
3335
It shows the behavior of two outcomes over time against
@@ -149,18 +151,18 @@ def simple_pairs_lines(ax, y_data, x_data, color):
149151

150152

151153
def pairs_density(
152-
experiments,
153-
outcomes,
154-
outcomes_to_show=None,
155-
group_by=None,
154+
experiments: pd.DataFrame,
155+
outcomes: dict[str, np.ndarray],
156+
outcomes_to_show: list[str] | None = None,
157+
group_by: str | None = None,
156158
grouping_specifiers=None,
157-
ylabels=None,
158-
point_in_time=-1,
159-
log=True,
160-
gridsize=50,
161-
colormap="coolwarm",
162-
filter_scalar=True,
163-
):
159+
ylabels: dict[str, str] | None = None,
160+
point_in_time: int = -1,
161+
log: bool = True,
162+
gridsize: int = 50,
163+
colormap: str = "coolwarm",
164+
filter_scalar: bool = True,
165+
) -> tuple[plt.Figure, dict[str, matplotlib.axes.Axes]]:
164166
"""Generate a pairs hexbin density multiplot.
165167
166168
In case of time-series data, the end states are used.
@@ -394,17 +396,17 @@ def simple_pairs_density(
394396

395397

396398
def pairs_scatter(
397-
experiments,
398-
outcomes,
399-
outcomes_to_show=None,
400-
group_by=None,
399+
experiments: pd.DataFrame,
400+
outcomes: dict[str, np.ndarray],
401+
outcomes_to_show: list[str] | None = None,
402+
group_by: str | None = None,
401403
grouping_specifiers=None,
402-
ylabels=None,
403-
legend=True,
404-
point_in_time=-1,
405-
filter_scalar=False,
404+
ylabels: dict[str, str] | None = None,
405+
legend: bool = True,
406+
point_in_time: int = -1,
407+
filter_scalar: bool = False,
406408
**kwargs,
407-
):
409+
) -> tuple[plt.Figure, dict[str, matplotlib.axes.Axes]]:
408410
"""Generate a pairs scatter multiplot.
409411
410412
In case of time-series data, the end states are used.

ema_workbench/analysis/parcoords.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
__all__ = ["ParallelAxes", "get_limits"]
1717

1818

19-
def setup_parallel_plot(labels, minima, maxima, formatter=None, fs=14, rot=90):
19+
def setup_parallel_plot(labels: list[str], minima: pd.Series, maxima: pd.Series, formatter: dict[str, str] | None = None, fs: int = 14, rot: float = 90) -> tuple[plt.Figure, list[plt.Axes], dict]:
2020
"""Helper function for setting up the parallel axes plot.
2121
2222
Parameters
@@ -102,7 +102,7 @@ def setup_parallel_plot(labels, minima, maxima, formatter=None, fs=14, rot=90):
102102
return fig, axes, tick_labels
103103

104104

105-
def get_limits(data):
105+
def get_limits(data: pd.DataFrame) -> pd.DataFrame:
106106
"""Helper function to get limits of a FataFrame that can serve as input to ParallelAxis.
107107
108108
Parameters
@@ -167,7 +167,7 @@ class ParallelAxes:
167167
168168
"""
169169

170-
def __init__(self, limits, formatter=None, fontsize=14, rot=90):
170+
def __init__(self, limits: pd.DataFrame, formatter: dict[str, str] | None = None, fontsize: int = 14, rot: float = 90):
171171
"""Init.
172172
173173
Parameters
@@ -218,7 +218,7 @@ def __init__(self, limits, formatter=None, fontsize=14, rot=90):
218218
plt.tight_layout(h_pad=0, w_pad=0)
219219
plt.subplots_adjust(wspace=0)
220220

221-
def plot(self, data, color=None, label=None, **kwargs):
221+
def plot(self, data: pd.DataFrame | pd.Series, color=None, label: str | None = None, **kwargs) -> None:
222222
"""Plot data on parallel axes.
223223
224224
Parameters
@@ -259,7 +259,7 @@ def plot(self, data, color=None, label=None, **kwargs):
259259
# plot the data
260260
self._plot(normalized_data, color=color, **kwargs)
261261

262-
def legend(self):
262+
def legend(self) -> None:
263263
"""Add a legend to the figure."""
264264
artists = []
265265
labels = []
@@ -301,7 +301,7 @@ def _plot(self, data, **kwargs):
301301
if label_j in self.flipped_axes:
302302
self._update_plot_data(ax, 1, lines=lines)
303303

304-
def invert_axis(self, axis):
304+
def invert_axis(self, axis: str | list[str]) -> None:
305305
"""Flip direction for specified axis.
306306
307307
Parameters

ema_workbench/analysis/plotting.py

Lines changed: 43 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import matplotlib.pyplot as plt
88
import numpy as np
9+
import pandas as pd
910
from matplotlib.patches import ConnectionPatch
1011

1112
from ..util import EMAError, get_module_logger
@@ -36,18 +37,18 @@
3637

3738

3839
def envelopes(
39-
experiments,
40-
outcomes,
41-
outcomes_to_show=None,
42-
group_by=None,
40+
experiments: pd.DataFrame,
41+
outcomes: dict[str, np.ndarray],
42+
outcomes_to_show: str | list[str] | None = None,
43+
group_by: str | None = None,
4344
grouping_specifiers=None,
44-
density=None,
45-
fill=False,
46-
legend=True,
47-
titles=None,
48-
ylabels=None,
49-
log=False,
50-
):
45+
density: Density | None = None,
46+
fill: bool = False,
47+
legend: bool = True,
48+
titles: dict[str, str] | None = None,
49+
ylabels: dict[str, str] | None = None,
50+
log: bool = False,
51+
) -> tuple[plt.Figure, dict[str, plt.Axes]]:
5152
"""Make envelop plots.
5253
5354
An envelope shows over time the minimum and maximum value for a set
@@ -260,19 +261,19 @@ def single_envelope(outcomes, outcome_to_plot, time, density, ax, ax_d, fill, lo
260261

261262

262263
def lines(
263-
experiments,
264-
outcomes,
265-
outcomes_to_show=None,
266-
group_by=None,
264+
experiments: pd.DataFrame,
265+
outcomes: dict[str, np.ndarray],
266+
outcomes_to_show: str | list[str] | None = None,
267+
group_by: str | None = None,
267268
grouping_specifiers=None,
268-
density="",
269-
legend=True,
270-
titles=None,
271-
ylabels=None,
272-
experiments_to_show=None,
273-
show_envelope=False,
274-
log=False,
275-
):
269+
density: Density | str = "",
270+
legend: bool = True,
271+
titles: dict[str, str] | None = None,
272+
ylabels: dict[str, str] | None = None,
273+
experiments_to_show: np.ndarray | None = None,
274+
show_envelope: bool = False,
275+
log: bool = False,
276+
) -> tuple[plt.Figure, dict[str, plt.Axes]]:
276277
"""Visualize results from experiments as line plots.
277278
278279
It is thus to be used in case of time
@@ -601,13 +602,13 @@ def simple_lines(outcomes, outcome_to_plot, time, density, ax, ax_d, log):
601602

602603

603604
def kde_over_time(
604-
experiments,
605-
outcomes,
606-
outcomes_to_show=None,
607-
group_by=None,
605+
experiments: pd.DataFrame,
606+
outcomes: dict[str, np.ndarray],
607+
outcomes_to_show: str | list[str] | None = None,
608+
group_by: str | None = None,
608609
grouping_specifiers=None,
609-
colormap="viridis",
610-
log=True,
610+
colormap: str = "viridis",
611+
log: bool = True,
611612
):
612613
"""Plot a KDE over time. The KDE is visualized through a heatmap.
613614
@@ -679,21 +680,21 @@ def kde_over_time(
679680

680681

681682
def multiple_densities(
682-
experiments,
683-
outcomes,
684-
points_in_time=None,
685-
outcomes_to_show=None,
686-
group_by=None,
683+
experiments: pd.DataFrame,
684+
outcomes: dict[str, np.ndarray],
685+
points_in_time: list[float] | None = None,
686+
outcomes_to_show: str | list[str] | None = None,
687+
group_by: str | None = None,
687688
grouping_specifiers=None,
688-
density=Density.KDE,
689-
legend=True,
690-
titles=None,
691-
ylabels=None,
692-
experiments_to_show=None,
693-
plot_type=PlotType.ENVELOPE,
694-
log=False,
689+
density: Density = Density.KDE,
690+
legend: bool = True,
691+
titles: dict[str, str] | None = None,
692+
ylabels: dict[str, str] | None = None,
693+
experiments_to_show: np.ndarray | None = None,
694+
plot_type: PlotType = PlotType.ENVELOPE,
695+
log: bool = False,
695696
**kwargs,
696-
):
697+
) -> tuple[list[plt.Figure], dict[str, dict[str, plt.Axes]]]:
697698
"""Make an envelope plot with multiple density plots over the run time.
698699
699700
Parameters

0 commit comments

Comments
 (0)