1515from dataclasses import astuple , dataclass
1616from typing import (
1717 Any ,
18+ cast ,
1819 Dict ,
1920 List ,
2021 Mapping ,
@@ -217,7 +218,7 @@ def _plot_colorbar(
217218 )
218219 position = self ._config ['colorbar_position' ]
219220 orien = 'vertical' if position in ('left' , 'right' ) else 'horizontal'
220- colorbar = ax .figure .colorbar (
221+ colorbar = cast ( plt . Figure , ax .figure ) .colorbar (
221222 mappable , colorbar_ax , ax , orientation = orien , ** self ._config .get ("colorbar_options" , {})
222223 )
223224 colorbar_ax .tick_params (axis = 'y' , direction = 'out' )
@@ -230,15 +231,15 @@ def _write_annotations(
230231 ax : plt .Axes ,
231232 ) -> None :
232233 """Writes annotations to the center of cells. Internal."""
233- for (center , annotation ), facecolor in zip (centers_and_annot , collection .get_facecolors ()):
234+ for (center , annotation ), facecolor in zip (centers_and_annot , collection .get_facecolor ()):
234235 # Calculate the center of the cell, assuming that it is a square
235236 # centered at (x=col, y=row).
236237 if not annotation :
237238 continue
238239 x , y = center
239- face_luminance = vis_utils .relative_luminance (facecolor )
240+ face_luminance = vis_utils .relative_luminance (facecolor ) # type: ignore
240241 text_color = 'black' if face_luminance > 0.4 else 'white'
241- text_kwargs = dict (color = text_color , ha = "center" , va = "center" )
242+ text_kwargs : Dict [ str , Any ] = dict (color = text_color , ha = "center" , va = "center" )
242243 text_kwargs .update (self ._config .get ('annotation_text_kwargs' , {}))
243244 ax .text (x , y , annotation , ** text_kwargs )
244245
@@ -295,6 +296,7 @@ def plot(
295296 show_plot = not ax
296297 if not ax :
297298 fig , ax = plt .subplots (figsize = (8 , 8 ))
299+ ax = cast (plt .Axes , ax )
298300 original_config = copy .deepcopy (self ._config )
299301 self .update_config (** kwargs )
300302 collection = self ._plot_on_axis (ax )
@@ -381,6 +383,7 @@ def plot(
381383 show_plot = not ax
382384 if not ax :
383385 fig , ax = plt .subplots (figsize = (8 , 8 ))
386+ ax = cast (plt .Axes , ax )
384387 original_config = copy .deepcopy (self ._config )
385388 self .update_config (** kwargs )
386389 qubits = set ([q for qubits in self ._value_map .keys () for q in qubits ])
0 commit comments