Advanced generic call argument inference #2521

@ibraheemdev

We currently only use the call expression's type context (e.g. the declared type of the variable it is assigned to) when inferring the arguments of a generic call. This means we cannot yet solve calls whose arguments depend on one another, e.g.,

```python
def lst[T](x: T) -> list[T]:
    return [x]

def f[T](x: T, y: list[T]) -> T:
    return x

def _(x: int, y: int | str):
    _: int | str = f(y, lst(x))
    f(y, lst(x))  # Argument to function `f` is incorrect: Expected `list[int | str]`, found `list[int]`
```
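The error message shows that `lst(x)` was inferred on its own as `list[int]` before `T` was solved. Below is a minimal toy sketch of that ordering. It assumes a made-up model in which a type is a union of atomic names stored as a `frozenset`, so subtyping is set inclusion. `Ty` and `solve_eagerly` are hypothetical names, not ty internals.

```python
from typing import Optional

# Toy model (not ty's representation): a type is a union of atomic type
# names stored as a frozenset, so `A <: B` is just `A <= B`.
Ty = frozenset


def solve_eagerly(first: Ty, lst_elem: Ty) -> Optional[Ty]:
    """Solve T for `f(first, lst(lst_elem))`, inferring `lst(...)` without context."""
    # `first` is passed to `x: T`, giving the lower bound `first <: T`.
    lower = first
    # With no context, `lst(lst_elem)` is inferred on its own as list[lst_elem].
    # `list` is invariant, so passing it to `y: list[T]` forces T == lst_elem.
    pinned = lst_elem
    # The pinned T must still accept the first argument.
    return pinned if lower <= pinned else None


INT = frozenset({"int"})
INT_STR = frozenset({"int", "str"})
print(solve_eagerly(INT_STR, INT))  # None: list[int] is not list[int | str]
print(solve_eagerly(INT, INT))      # frozenset({'int'})
```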

Pyright handles this first example fine, but mypy seems to give up. I suspect pyright uses a heuristic where it infers "simple" argument expressions first, i.e., arguments that cannot be influenced by type context. However, a heuristic like this cannot solve more advanced cases, e.g.,

```python
def lst[T](x: T) -> list[T]:
    return [x]

def f[T](x: T, y: list[T], z: list[T]) -> T:
    return x

def _(x: int, y: int | str, z: int | str | None):
    _: int | str | None = f(y, lst(x), lst(z))
    f(y, lst(x), lst(z))  # Argument to function `f` is incorrect: Expected `list[int | str | None]`, found `list[int]`
```
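The heuristic suspected above can be sketched in a toy model where a type is a union of atomic names stored as a `frozenset`, so subtyping is set inclusion. This is a hypothetical illustration, not pyright's actual algorithm. It assumes the first pass fixes `T` from the simple arguments, so later arguments cannot widen it. Under that assumption it solves the two-argument call but rejects `f(y, lst(x), lst(z))`, even though `T = int | str | None` would work.

```python
from typing import Optional

# Toy model (not pyright's or ty's representation): a type is a union of
# atomic type names stored as a frozenset, so `A <: B` is just `A <= B`.
Ty = frozenset


def solve_simple_first(simple: Ty, *lst_elems: Ty) -> Optional[Ty]:
    """Solve T for `f(simple, lst(e1), lst(e2), ...)` with the heuristic.

    Pass 1: the context-free argument `simple` (a plain name) fixes T.
    Pass 2: each `lst(e)` is inferred with context list[T], so lst's own
    type variable is set to T; this succeeds only if `e <: T`.
    """
    t = simple
    return t if all(elem <= t for elem in lst_elems) else None


INT = frozenset({"int"})
INT_STR = frozenset({"int", "str"})
INT_STR_NONE = frozenset({"int", "str", "None"})
print(sorted(solve_simple_first(INT_STR, INT)))        # ['int', 'str']
print(solve_simple_first(INT_STR, INT, INT_STR_NONE))  # None
```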

Another problem is that we use the call expression's type context even when `T` only appears in a covariant position of the return type (here `Sequence[T]`), which can lead to false positives, e.g.,

```python
from typing import Sequence

def lst[T](x: T) -> list[T]:
    return [x]

def f[T](x: T, y: list[T], z: list[T]) -> Sequence[T]:
    return [x]

def _(x: int, z: list[int]):
    _: Sequence[int] = f(x, lst(x), z)
    _: Sequence[int | str] = f(x, lst(x), z)  # Argument to function `f` is incorrect: Expected `list[int | str]`, found `list[int]`
```
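The sketch below illustrates the distinction. It uses a toy model in which a type is a union of atomic names stored as a `frozenset`, so subtyping is set inclusion. `Bounds` is a hypothetical name, not a ty type. Because `Sequence` is covariant, the annotation `Sequence[int | str]` only requires `T <: int | str`. Treating it as pinning `T`, which matches the false positive, conflicts with `z: list[int]`, which pins `T == int` through the invariant `list[T]`.

```python
from dataclasses import dataclass, field
from typing import List, Optional

# Toy model: a type is a union of atomic type names (a frozenset),
# so `A <: B` is just `A <= B`.
Ty = frozenset


@dataclass
class Bounds:
    """Constraints collected on a single type variable T."""
    lower: Ty = frozenset()        # T must accept these types
    upper: Optional[Ty] = None     # T must fit inside this (None: no bound)
    exact: List[Ty] = field(default_factory=list)  # invariant uses pin T

    def solve(self) -> Optional[Ty]:
        pinned = set(self.exact)
        if len(pinned) > 1:
            return None            # two invariant positions disagree
        t = pinned.pop() if pinned else self.lower
        if not self.lower <= t:
            return None            # T is too narrow for an argument
        if self.upper is not None and not t <= self.upper:
            return None            # T does not fit the covariant context
        return t


INT, INT_STR = frozenset({"int"}), frozenset({"int", "str"})
# `_: Sequence[int | str] = f(x, lst(x), z)` with x: int, z: list[int].
# lower: x and lst(x) need int <: T; exact: z pins T == int.
context_as_equality = Bounds(lower=INT, exact=[INT, INT_STR])
context_as_upper_bound = Bounds(lower=INT, upper=INT_STR, exact=[INT])
print(context_as_equality.solve())     # None
print(context_as_upper_bound.solve())  # frozenset({'int'})
```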

Another example with TypedDict:

```python
from typing import TypedDict

class A(TypedDict):
    a: int
    b: int

def lst[T](x: T) -> list[T]:
    return [x]

def f[T](x: T, y: list[T]) -> T:
    return x

def _(a: A):
    _: A = f(a, lst({ "a": 1, "b": 2 }))
    f(a, lst({ "a": 1, "b": 2 }))  # Argument to function `f` is incorrect: Expected `list[A | dict[Unknown | str, Unknown | int]]`, found `list[dict[Unknown | str, Unknown | int]]`
```
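The dict literal has exactly `A`'s keys with `int` values, so `T = A` solves the call if `lst(...)` is inferred with `list[A]` as context. The helper below, `matches_typed_dict`, is hypothetical. It is a rough runtime approximation of that structural check, assumes every key is required, and only handles plain class annotations like `int`. It is not a type checker's real `TypedDict` assignability rule.

```python
from typing import TypedDict, get_type_hints


class A(TypedDict):
    a: int
    b: int


def matches_typed_dict(td: type, literal: dict) -> bool:
    """Rough check: same required keys, each value an instance of its declared class.

    Assumes all keys are required and annotations are plain classes such
    as `int`; unions, generics, and NotRequired keys are not handled.
    """
    hints = get_type_hints(td)
    if set(literal) != set(td.__required_keys__):
        return False
    return all(isinstance(literal[key], hints[key]) for key in literal)


print(matches_typed_dict(A, {"a": 1, "b": 2}))  # True
print(matches_typed_dict(A, {"a": 1}))          # False
```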

We may be able to solve this problem completely by lazily inferring arguments as constraint sets, and unifying them after all arguments and the type context have been inferred (cc @dcreager). Alternatively, we could implement a simpler heuristic similar to pyright's.
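The sketch below shows what unifying constraint sets could look like. It uses a toy model in which a type is a union of atomic names stored as a `frozenset`, so the join of two types is set union. `solve` is a hypothetical name, as are the variables `U1` and `U2`, which stand for the type variables of the two `lst(...)` calls. The sketch only handles lower bounds and invariant equalities, not upper bounds from return-type context. Applied to `f(y, lst(x), lst(z))`, it finds the `T = int | str | None` solution that the simple-arguments-first heuristic described above cannot.

```python
from typing import Dict, List, Tuple

# Toy model: a type is a union of atomic type names (a frozenset),
# so the join of two types is set union.
Ty = frozenset


def solve(lower: List[Tuple[str, Ty]], equal: List[Tuple[str, str]]) -> Dict[str, Ty]:
    """Solve all `ty <: var` and `var == var` constraints together."""
    parent: Dict[str, str] = {}

    def find(var: str) -> str:
        # Union-find root lookup; registers variables not seen before.
        parent.setdefault(var, var)
        while parent[var] != var:
            var = parent[var]
        return var

    # Invariant positions (list[U] passed where list[T] is expected) merge U and T.
    for a, b in equal:
        parent[find(a)] = find(b)

    # Each merged group is solved as the join of all its lower bounds.
    joined: Dict[str, Ty] = {}
    for var, ty in lower:
        root = find(var)
        joined[root] = joined.get(root, frozenset()) | ty
    return {var: joined.get(find(var), frozenset()) for var in list(parent)}


INT = frozenset({"int"})
INT_STR = frozenset({"int", "str"})
INT_STR_NONE = frozenset({"int", "str", "None"})
# f(y, lst(x), lst(z)) with y: int | str, x: int, z: int | str | None
solution = solve(
    lower=[("T", INT_STR), ("U1", INT), ("U2", INT_STR_NONE)],
    equal=[("U1", "T"), ("U2", "T")],
)
print(sorted(solution["T"]))  # ['None', 'int', 'str']
```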

Note that some of these false positives are currently hidden by our unioning of inferred types with `Unknown` (visible in the `dict[Unknown | str, Unknown | int]` type above), and may start appearing more often once #1240 is resolved.

Labels: bidirectional inference (Inference of types that takes into account the context of a declared type or expected type), generics (Bugs or features relating to ty's generics implementation)
