1010
1111from dascore .constants import PatchType
1212from dascore .exceptions import ParameterError
13- from dascore .units import get_quantity_str
13+ from dascore .units import get_quantity_str , maybe_convert_percent_to_fraction
1414from dascore .utils .patch import patch_function
1515from dascore .utils .plotting import (
1616 _format_time_axis ,
@@ -51,23 +51,25 @@ def _get_scale(scale, scale_type, data):
5151 Calculate the color bar scale limits based on scale and scale_type.
5252 """
5353 _validate_scale_type (scale_type )
54-
54+ # This ensures we have a list of the previous scale parameters.
55+ scale = maybe_convert_percent_to_fraction (scale )
5556 match (scale , scale_type ):
5657 # Case 1: Single value with relative scaling
5758 # Scale is symmetric around the mean, using fraction of dynamic range
58- case (scale , "relative" ) if isinstance (scale , float | int ):
59+ case (scale , "relative" ) if len (scale ) == 1 :
60+ scale = scale [0 ]
5961 mod = 0.5 * (np .nanmax (data ) - np .nanmin (data ))
6062 mean = np .nanmean (data )
6163 scale = np .asarray ([mean - scale * mod , mean + scale * mod ])
6264 # Case 2: No scale specified with relative scaling
6365 # Use Tukey's fence (C*IQR, C is normally 1.5) to exclude outliers.
6466 # This prevents a few extreme values from obscuring the majority of the
6567 # data at the cost of a slight performance penalty.
66- case (None , "relative" ):
67- q2 , q3 = np .nanpercentile (data , [25 , 75 ])
68+ case ([] , "relative" ):
69+ q1 , q3 = np .nanpercentile (data , [25 , 75 ])
6870 dmin , dmax = np .nanmin (data ), np .nanmax (data )
69- diff = q3 - q2 # Interquartile range (IQR)
70- q_lower = np .nanmax ([q2 - diff * IQR_FENCE_MULTIPLIER , dmin ])
71+ diff = q3 - q1 # Interquartile range (IQR)
72+ q_lower = np .nanmax ([q1 - diff * IQR_FENCE_MULTIPLIER , dmin ])
7173 q_upper = np .nanmin ([q3 + diff * IQR_FENCE_MULTIPLIER , dmax ])
7274 scale = np .asarray ([q_lower , q_upper ])
7375 return scale
@@ -88,8 +90,8 @@ def _get_scale(scale, scale_type, data):
8890 # Map [0, 1] to [data_min, data_max]
8991 scale = dmin + scale * data_range
9092 # Case 4: Absolute scaling
91- case (scale , "absolute" ) if isinstance (scale , int | float ) :
92- scale = np .array ([- abs (scale ), abs (scale )])
93+ case (scale , "absolute" ) if len (scale ) == 1 :
94+ scale = np .array ([- abs (scale [ 0 ] ), abs (scale [ 0 ] )])
9395 # Case 5: Absolute scaling with sequence: no match needed.
9496
9597 # Scale values are used directly as colorbar limits
@@ -168,13 +170,16 @@ def waterfall(
168170 --------
169171 >>> # Plot with default scaling (uses 1.5*IQR fence to exclude outliers)
170172 >>> import dascore as dc
173+ >>> from dascore.units import percent
171174 >>> patch = dc.get_example_patch("example_event_1").normalize("time")
172175 >>> _ = patch.viz.waterfall()
173176 >>>
174- >>> # Use relative scaling with a float to saturate at 10% of dynamic range
175- >>> # This centers the colorbar around the mean and extends ± 10% of the
176- >>> # data's dynamic range in each direction
177+ >>> # Use relative scaling with a tuple to show a specific fraction
178+ >>> # of data range. Scale values of (0.1, 0.9) map to 10% and 90%
179+ >>> # of the [data_min, data_max] range data's dynamic range
177180 >>> _ = patch.viz.waterfall(scale=0.1, scale_type="relative")
181+ >>> # Likewise, percent units can be used for additional clarity
182+ >>> _ = patch.viz.waterfall(scale=10*percent, scale_type="absolute")
178183 >>>
179184 >>> # Use relative scaling with a tuple to show the middle 80% of data range
180185 >>> # Scale values of (0.1, 0.9) map to 10th and 90th percentile of data
@@ -218,15 +223,13 @@ def waterfall(
218223 """
219224 # Validate inputs
220225 patch = _validate_patch_dims (patch )
221-
222226 # Setup axes and data
223227 ax = _get_ax (ax )
224228 cmap = _get_cmap (cmap )
225229 data = np .log10 (np .absolute (patch .data )) if log else patch .data
226230 dims = patch .dims
227231 dims_r = tuple (reversed (dims ))
228232 coords = {dim : patch .coords .get_array (dim ) for dim in dims }
229-
230233 # Plot using imshow and set colorbar limits
231234 extents = _get_extents (dims_r , coords )
232235 scale = _get_scale (scale , scale_type , data )
@@ -244,14 +247,11 @@ def waterfall(
244247 interpolation_stage = "data" ,
245248 )
246249 im .set_clim (scale )
247-
248250 # Format axis labels and handle time-like dimensions
249251 _format_axis_labels (ax , patch , dims_r )
250-
251252 # Add colorbar if requested
252253 if cmap is not None :
253254 _add_colorbar (ax , im , patch , log )
254-
255255 if show :
256256 plt .show ()
257257 return ax
0 commit comments