Skip to content

Commit 4edb1b4

Browse files
committed
Added support for new assert_type call, which is being added to Python 3.11 and typing_extensions.
1 parent d78f737 commit 4edb1b4

6 files changed

Lines changed: 96 additions & 0 deletions

File tree

packages/pyright-internal/src/analyzer/typeEvaluator.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6738,6 +6738,9 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
67386738
} else if (isFunction(baseTypeResult.type) && baseTypeResult.type.details.builtInName === 'reveal_type') {
67396739
// Handle the "typing.reveal_type" call.
67406740
returnResult = getTypeFromRevealType(node, expectedType);
6741+
} else if (isFunction(baseTypeResult.type) && baseTypeResult.type.details.builtInName === 'assert_type') {
6742+
// Handle the "typing.assert_type" call.
6743+
returnResult = getTypeFromAssertType(node, expectedType);
67416744
} else if (
67426745
isAnyOrUnknown(baseTypeResult.type) &&
67436746
node.leftExpression.nodeType === ParseNodeType.Name &&
@@ -6806,6 +6809,38 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
68066809
return returnResult;
68076810
}
68086811

6812+
function getTypeFromAssertType(node: CallNode, expectedType: Type | undefined): TypeResult {
6813+
if (
6814+
node.arguments.length !== 2 ||
6815+
node.arguments[0].argumentCategory !== ArgumentCategory.Simple ||
6816+
node.arguments[0].name !== undefined ||
6817+
node.arguments[0].argumentCategory !== ArgumentCategory.Simple ||
6818+
node.arguments[1].name !== undefined
6819+
) {
6820+
addError(Localizer.Diagnostic.assertTypeArgs(), node);
6821+
return { node, type: UnknownType.create() };
6822+
}
6823+
6824+
const arg0TypeResult = getTypeOfExpression(node.arguments[0].valueExpression, expectedType);
6825+
if (arg0TypeResult.isIncomplete) {
6826+
return { node, type: UnknownType.create(), isIncomplete: true };
6827+
}
6828+
6829+
const assertedType = convertToInstance(getTypeForArgumentExpectingType(node.arguments[1]).type);
6830+
6831+
if (!isTypeSame(assertedType, arg0TypeResult.type)) {
6832+
addError(
6833+
Localizer.Diagnostic.assertTypeTypeMismatch().format({
6834+
expected: printType(assertedType),
6835+
received: printType(arg0TypeResult.type),
6836+
}),
6837+
node.arguments[0].valueExpression
6838+
);
6839+
}
6840+
6841+
return { node, type: arg0TypeResult.type };
6842+
}
6843+
68096844
function getTypeFromRevealType(node: CallNode, expectedType: Type | undefined): TypeResult {
68106845
let arg0Value: ExpressionNode | undefined;
68116846
let expectedRevealTypeNode: ExpressionNode | undefined;

packages/pyright-internal/src/localization/localize.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,11 @@ export namespace Localizer {
208208
export const argTypePartiallyUnknown = () => getRawString('Diagnostic.argTypePartiallyUnknown');
209209
export const argTypeUnknown = () => getRawString('Diagnostic.argTypeUnknown');
210210
export const assertAlwaysTrue = () => getRawString('Diagnostic.assertAlwaysTrue');
211+
export const assertTypeArgs = () => getRawString('Diagnostic.assertTypeArgs');
212+
export const assertTypeTypeMismatch = () =>
213+
new ParameterizedString<{ expected: string; received: string }>(
214+
getRawString('Diagnostic.assertTypeTypeMismatch')
215+
);
211216
export const assignmentExprContext = () => getRawString('Diagnostic.assignmentExprContext');
212217
export const assignmentExprComprehension = () =>
213218
new ParameterizedString<{ name: string }>(getRawString('Diagnostic.assignmentExprComprehension'));

packages/pyright-internal/src/localization/package.nls.en-us.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
"argTypePartiallyUnknown": "Argument type is partially unknown",
2121
"argTypeUnknown": "Argument type is unknown",
2222
"assertAlwaysTrue": "Assert expression always evaluates to true",
23+
"assertTypeArgs": "\"assert_type\" expects two positional arguments",
24+
"assertTypeTypeMismatch": "\"assert_type\" mismatch: expected \"{expected}\" but received \"{received}\"",
2325
"assignmentExprContext": "Assignment expression must be within module, function or lambda",
2426
"assignmentExprComprehension": "Assignment expression target \"{name}\" cannot use same name as comprehension for target",
2527
"assignmentInProtocol": "Instance or class variables within a Protocol class must be explicitly declared within the class body",
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# This sample tests the assert_type call.
2+
3+
from typing import Any, Literal
4+
from typing_extensions import assert_type
5+
6+
def func1():
7+
# This should generate an error.
8+
assert_type()
9+
10+
# This should generate an error.
11+
assert_type(1)
12+
13+
# This should generate an error.
14+
assert_type(1, 2, 3)
15+
16+
# This should generate an error.
17+
assert_type(*[])
18+
19+
20+
def func2(x: int, y: int | str):
21+
assert_type(x, int)
22+
23+
# This should generate an error.
24+
assert_type(x, str)
25+
26+
# This should generate an error.
27+
assert_type(x, Any)
28+
29+
x = 3
30+
assert_type(x, Literal[3])
31+
32+
# This should generate an error.
33+
assert_type(x, int)
34+
35+
assert_type(y, int | str)
36+
assert_type(y, str | int)
37+
38+
# This should generate an error.
39+
assert_type(y, str)
40+
41+
# This should generate an error.
42+
assert_type(y, None)
43+
44+
# This should generate two errors.
45+
assert_type(y, 3)
46+

packages/pyright-internal/src/tests/typeEvaluator2.test.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,12 @@ test('RevealedType1', () => {
332332
TestUtils.validateResults(analysisResults, 2, 0, 7);
333333
});
334334

335+
test('AssertType1', () => {
336+
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['assertType1.py']);
337+
338+
TestUtils.validateResults(analysisResults, 11);
339+
});
340+
335341
test('NameBindings1', () => {
336342
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['nameBindings1.py']);
337343

packages/pyright-internal/typeshed-fallback/stdlib/typing_extensions.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,5 +166,7 @@ def dataclass_transform(
166166

167167
# Types not yet implemented in typing_extensions library
168168

169+
def assert_type(val: _T, typ: Any, /) -> _T: ...
170+
169171
# Proposed extension to PEP 647
170172
StrictTypeGuard: _SpecialForm = ...

0 commit comments

Comments
 (0)