
Subclasses in dict values not inferred as subclasses #2985

@adamjstewart


Summary

The following code:

from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection

MetricCollection({
    'mse': MeanSquaredError(),
    'mae': MeanAbsoluteError(),
})

fails in ty 0.0.21 (but not prior versions):

> ty check test.py
error[invalid-argument-type]: Argument to bound method `__init__` is incorrect
 --> test.py:3:18
  |
1 |   from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
2 |
3 |   MetricCollection({
  |  __________________^
4 | |     'mse': MeanSquaredError(),
5 | |     'mae': MeanAbsoluteError(),
6 | | })
  | |_^ Expected `Metric | MetricCollection | Sequence[Metric | MetricCollection] | dict[str, Metric | MetricCollection]`, found `dict[str, MeanSquaredError | MeanAbsoluteError]`
  |
info: Method defined here
   --> /Users/Adam/spack/var/spack/environments/default/.spack-env/._view/gd2vtbwe4fojggua4emapinyxnqhryg4/lib/python3.14/site-packages/torchmetrics/collections.py:199:9
    |
197 |       __jit_unused_properties__: ClassVar[list[str]] = ["metric_state"]
198 |
199 |       def __init__(
    |           ^^^^^^^^
200 |           self,
201 | /         metrics: Union[
202 | |             Metric,
203 | |             "MetricCollection",
204 | |             Sequence[Union[Metric, "MetricCollection"]],
205 | |             dict[str, Union[Metric, "MetricCollection"]],
206 | |         ],
    | |_________- Parameter declared here
207 |           *additional_metrics: Metric,
208 |           prefix: Optional[str] = None,
    |
info: rule `invalid-argument-type` is enabled by default

Found 1 diagnostic
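To check whether the failure depends on torchmetrics itself, the shape of the `metrics` parameter shown above can be mirrored with small stand-in classes. Every class below is a hypothetical stub, not the real torchmetrics class. This sketch assumes, without having been confirmed against ty 0.0.21, that it exercises the same inference path as the original snippet.

```python
from collections.abc import Sequence
from typing import Union


class Metric:
    """Hypothetical stand-in for torchmetrics.Metric."""


class MeanSquaredError(Metric):
    """Hypothetical stub subclass of Metric."""


class MeanAbsoluteError(Metric):
    """Hypothetical stub subclass of Metric."""


class MetricCollection:
    """Hypothetical stub whose `metrics` parameter copies the union from collections.py."""

    def __init__(
        self,
        metrics: Union[
            Metric,
            "MetricCollection",
            Sequence[Union[Metric, "MetricCollection"]],
            dict[str, Union[Metric, "MetricCollection"]],
        ],
    ) -> None:
        self.metrics = metrics


# The dict literal is passed inline, so a checker can use the declared
# parameter type as context when inferring the literal's value type.
collection = MetricCollection({
    "mse": MeanSquaredError(),
    "mae": MeanAbsoluteError(),
})
```

If ty reports the same `invalid-argument-type` error here, the regression is independent of torchmetrics.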

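However, MeanSquaredError and MeanAbsoluteError are subclasses of Metric, so I believe this should be valid: the dict literal is passed directly to a parameter whose declared type includes dict[str, Metric | MetricCollection], so its value type should be inferred from that context. For now, the easiest workaround is to store the dict in a variable with an explicit type hint.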
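A minimal sketch of that workaround, using hypothetical stub classes and a hypothetical `build` function in place of torchmetrics. Annotating the variable fixes the dict's value type to the declared base type, so the argument matches the parameter exactly. The annotation matters because `dict` is invariant in its value type: a dict inferred as `dict[str, MeanSquaredError | MeanAbsoluteError]` is not assignable to `dict[str, Metric]`, even though both classes subclass `Metric`. How an unannotated variable is inferred varies by checker, so only the inline literal (or an annotated variable) reliably gets the declared type.

```python
class Metric:
    """Hypothetical stand-in for torchmetrics.Metric."""


class MeanSquaredError(Metric):
    """Hypothetical stub subclass of Metric."""


class MeanAbsoluteError(Metric):
    """Hypothetical stub subclass of Metric."""


def build(metrics: dict[str, Metric]) -> list[str]:
    """Hypothetical consumer with a dict-typed parameter; returns the sorted keys."""
    return sorted(metrics)


# Workaround: the explicit annotation makes the value type Metric,
# not the union of the concrete subclasses.
metrics: dict[str, Metric] = {
    "mse": MeanSquaredError(),
    "mae": MeanAbsoluteError(),
}
keys = build(metrics)
```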

Version

ty 0.0.21


Labels: bidirectional inference (Inference of types that takes into account the context of a declared type or expected type)
