|
1 | | -import typing |
2 | | -import typing_extensions |
3 | 1 | from .annotations import type_from_value |
4 | 2 | from .error_code import ErrorCode |
5 | | -from .extensions import reveal_type |
| 3 | +from .extensions import assert_type, reveal_type |
6 | 4 | from .format_strings import parse_format_string |
7 | 5 | from .predicates import IsAssignablePredicate |
8 | 6 | from .safe import safe_hasattr, safe_isinstance, safe_issubclass |
|
52 | 50 | concrete_values_from_iterable, |
53 | 51 | kv_pairs_from_mapping, |
54 | 52 | make_weak, |
| 53 | + unannotate, |
55 | 54 | unite_values, |
56 | 55 | flatten_values, |
57 | 56 | replace_known_sequence_value, |
|
66 | 65 | import inspect |
67 | 66 | import warnings |
68 | 67 | from types import FunctionType |
| 68 | +import typing |
| 69 | +import typing_extensions |
69 | 70 | from typing import ( |
70 | 71 | Sequence, |
71 | 72 | TypeVar, |
@@ -1042,6 +1043,20 @@ def _cast_impl(ctx: CallContext) -> Value: |
1042 | 1043 | return type_from_value(typ, visitor=ctx.visitor, node=ctx.node) |
1043 | 1044 |
|
1044 | 1045 |
|
| 1046 | +def _assert_type_impl(ctx: CallContext) -> Value: |
| 1047 | + # TODO maybe we should walk over the whole value and remove Annotated. |
| 1048 | + val = unannotate(ctx.vars["val"]) |
| 1049 | + typ = ctx.vars["typ"] |
| 1050 | + expected_type = type_from_value(typ, visitor=ctx.visitor, node=ctx.node) |
| 1051 | + if val != expected_type: |
| 1052 | + ctx.show_error( |
| 1053 | + f"Type is {val} (expected {expected_type})", |
| 1054 | + error_code=ErrorCode.inference_failure, |
| 1055 | + arg="obj", |
| 1056 | + ) |
| 1057 | + return val |
| 1058 | + |
| 1059 | + |
1045 | 1060 | def _subclasses_impl(ctx: CallContext) -> Value: |
1046 | 1061 | """Overridden because typeshed types make it (T) => List[T] instead.""" |
1047 | 1062 | self_obj = ctx.vars["self"] |
@@ -1423,7 +1438,18 @@ def get_default_argspecs() -> Dict[object, Signature]: |
1423 | 1438 | callable=str.format, |
1424 | 1439 | ), |
1425 | 1440 | Signature.make( |
1426 | | - [SigParameter("typ"), SigParameter("val")], callable=cast, impl=_cast_impl |
| 1441 | + [SigParameter("typ", _POS_ONLY), SigParameter("val", _POS_ONLY)], |
| 1442 | + callable=cast, |
| 1443 | + impl=_cast_impl, |
| 1444 | + ), |
| 1445 | + Signature.make( |
| 1446 | + [ |
| 1447 | + SigParameter("val", _POS_ONLY, annotation=TypeVarValue(T)), |
| 1448 | + SigParameter("typ", _POS_ONLY), |
| 1449 | + ], |
| 1450 | + TypeVarValue(T), |
| 1451 | + callable=assert_type, |
| 1452 | + impl=_assert_type_impl, |
1427 | 1453 | ), |
1428 | 1454 | # workaround for https://github.com/python/typeshed/pull/3501 |
1429 | 1455 | Signature.make( |
@@ -1566,4 +1592,20 @@ def get_default_argspecs() -> Dict[object, Signature]: |
1566 | 1592 | callable=reveal_type_func, |
1567 | 1593 | ) |
1568 | 1594 | signatures.append(sig) |
| 1595 | + # Anticipating that this will be added to the stdlib |
| 1596 | + try: |
| 1597 | + assert_type_func = getattr(mod, "assert_type") |
| 1598 | + except AttributeError: |
| 1599 | + pass |
| 1600 | + else: |
| 1601 | + sig = Signature.make( |
| 1602 | + [ |
| 1603 | + SigParameter("val", _POS_ONLY, annotation=TypeVarValue(T)), |
| 1604 | + SigParameter("typ", _POS_ONLY), |
| 1605 | + ], |
| 1606 | + TypeVarValue(T), |
| 1607 | + callable=assert_type_func, |
| 1608 | + impl=_assert_type_impl, |
| 1609 | + ) |
| 1610 | + signatures.append(sig) |
1569 | 1611 | return {sig.callable: sig for sig in signatures} |
0 commit comments