Skip to content

Commit 982c35f

Browse files
committed
Infer type predicate for simple arrow functions
1 parent 5dab24a commit 982c35f

File tree

1 file changed

+174
-3
lines changed

1 file changed

+174
-3
lines changed

src/compiler/checker.ts

+174-3
Original file line numberDiff line numberDiff line change
@@ -6984,15 +6984,186 @@ namespace ts {
69846984
}
69856985
else {
69866986
const declaration = signature.declaration;
6987-
signature.resolvedTypePredicate = declaration && declaration.type && declaration.type.kind === SyntaxKind.TypePredicate ?
6988-
createTypePredicateFromTypePredicateNode(declaration.type as TypePredicateNode) :
6989-
noTypePredicate;
6987+
6988+
if (declaration && declaration.type && declaration.type.kind === SyntaxKind.TypePredicate) {
6989+
signature.resolvedTypePredicate = createTypePredicateFromTypePredicateNode(declaration.type as TypePredicateNode);
6990+
}
6991+
else if (declaration && !declaration.type && declaration.kind === SyntaxKind.ArrowFunction) {
6992+
signature.resolvedTypePredicate = inferTypePredicateOfArrowSignature(signature) || noTypePredicate;
6993+
}
6994+
else {
6995+
signature.resolvedTypePredicate = noTypePredicate;
6996+
}
69906997
}
69916998
Debug.assert(!!signature.resolvedTypePredicate);
69926999
}
69937000
return signature.resolvedTypePredicate === noTypePredicate ? undefined : signature.resolvedTypePredicate;
69947001
}
69957002

7003+
function inferTypePredicateOfArrowSignature(signature: Signature): IdentifierTypePredicate | undefined {
7004+
Debug.assert(signature.declaration.kind === SyntaxKind.ArrowFunction);
7005+
7006+
const arrow = signature.declaration as ArrowFunction;
7007+
7008+
// not inferring for blocks is an arbitrary but reasonable choice
7009+
if (!(getReturnTypeOfSignature(signature).flags & TypeFlags.Boolean) || arrow.body.kind === SyntaxKind.Block) {
7010+
return undefined;
7011+
}
7012+
7013+
const params = signature.parameters;
7014+
const paramDecls = params.map(p => getDeclarationOfKind<ParameterDeclaration>(p, SyntaxKind.Parameter));
7015+
7016+
return inferTypePredicateFromExpression(arrow.body, /*negated*/ false);
7017+
7018+
function inferTypePredicateFromExpression(expr: Expression, negated: boolean): IdentifierTypePredicate | undefined {
7019+
switch (expr.kind) {
7020+
case SyntaxKind.CallExpression:
7021+
return inferTypePredicateFromCallExpression(expr as CallExpression, negated);
7022+
case SyntaxKind.ParenthesizedExpression:
7023+
return inferTypePredicateFromExpression((<ParenthesizedExpression>expr).expression, negated);
7024+
case SyntaxKind.PrefixUnaryExpression:
7025+
return (<PrefixUnaryExpression>expr).operator === SyntaxKind.ExclamationToken ?
7026+
inferTypePredicateFromExpression((<PrefixUnaryExpression>expr).operand, !negated) :
7027+
undefined;
7028+
case SyntaxKind.BinaryExpression:
7029+
return inferTypePredicateFromBinaryExpression(expr as BinaryExpression, negated);
7030+
default:
7031+
return undefined;
7032+
}
7033+
}
7034+
7035+
function inferTypePredicateFromCallExpression(expr: CallExpression, negated: boolean) {
7036+
if (negated) {
7037+
// We can't deny the antecedent
7038+
return undefined;
7039+
}
7040+
7041+
const signature = getResolvedSignature(expr);
7042+
const typePredicate = getTypePredicateOfSignature(signature);
7043+
7044+
if (!typePredicate || !isIdentifierTypePredicate(typePredicate)) {
7045+
return undefined;
7046+
}
7047+
7048+
const argument = expr.arguments[typePredicate.parameterIndex];
7049+
const paramIndex = argument ? findIndex(paramDecls, p => isMatchingReference(p.name, argument)) : -1;
7050+
7051+
if (paramIndex >= 0) {
7052+
return createIdentifierTypePredicate(params[paramIndex].escapedName as string, paramIndex, typePredicate.type);
7053+
}
7054+
else {
7055+
return undefined;
7056+
}
7057+
}
7058+
7059+
function inferTypePredicateFromBinaryExpression(expr: BinaryExpression, negated: boolean): IdentifierTypePredicate | undefined {
7060+
switch (expr.operatorToken.kind) {
7061+
case SyntaxKind.EqualsEqualsEqualsToken:
7062+
case SyntaxKind.ExclamationEqualsEqualsToken:
7063+
const equality = negated !== (expr.operatorToken.kind === SyntaxKind.EqualsEqualsEqualsToken);
7064+
const leftType = getTypeOfExpression(expr.left);
7065+
const rightType = getTypeOfExpression(expr.right);
7066+
const isLeftConstant = isStaticallyKnownConstant(expr.left);
7067+
if (isLeftConstant === isStaticallyKnownConstant(expr.right)) {
7068+
return undefined;
7069+
}
7070+
7071+
const literalType = isLeftConstant ? leftType : rightType;
7072+
const subject = isLeftConstant ? expr.right : expr.left;
7073+
7074+
if (subject.kind === SyntaxKind.PropertyAccessExpression) {
7075+
const propAccess = subject as PropertyAccessExpression;
7076+
const paramIndex = findIndex(paramDecls, p => isMatchingReference(p.name, propAccess.expression));
7077+
7078+
if (paramIndex < 0) {
7079+
return undefined;
7080+
}
7081+
7082+
const param = params[paramIndex];
7083+
const paramType = getTypeOfVariableOrParameterOrProperty(param);
7084+
7085+
if (isDiscriminantProperty(paramType, propAccess.name.escapedText)) {
7086+
const filter = (t: Type) => equality === isTypeComparableTo(literalType, t);
7087+
const impliedType = narrowTypeByDiscriminant(paramType, propAccess, t0 => filterType(t0, filter));
7088+
7089+
return createIdentifierTypePredicate(param.escapedName as string, paramIndex, impliedType);
7090+
}
7091+
7092+
return undefined;
7093+
}
7094+
else {
7095+
return undefined;
7096+
}
7097+
case SyntaxKind.AmpersandAmpersandToken:
7098+
case SyntaxKind.BarBarToken:
7099+
const conjunctive = negated !== (expr.operatorToken.kind === SyntaxKind.AmpersandAmpersandToken);
7100+
const left = inferTypePredicateFromExpression(expr.left, negated);
7101+
const right = inferTypePredicateFromExpression(expr.right, negated);
7102+
7103+
if (left && right && left.parameterIndex === right.parameterIndex) {
7104+
if (conjunctive) {
7105+
const impliedType = filterType(left.type, t => isTypeComparableTo(right.type, t));
7106+
return createIdentifierTypePredicate(left.parameterName, left.parameterIndex, impliedType);
7107+
}
7108+
else {
7109+
const impliedType = getUnionType([left.type, right.type], UnionReduction.None);
7110+
return createIdentifierTypePredicate(left.parameterName, left.parameterIndex, impliedType);
7111+
}
7112+
}
7113+
else if (conjunctive && (!left || !right)) {
7114+
return left || right;
7115+
}
7116+
else {
7117+
return undefined;
7118+
}
7119+
case SyntaxKind.InstanceOfKeyword:
7120+
const paramIndex = findIndex(paramDecls, p => isMatchingReference(p.name, expr.left));
7121+
7122+
if (paramIndex >= 0) {
7123+
const param = params[paramIndex];
7124+
const paramType = getTypeOfVariableOrParameterOrProperty(param);
7125+
7126+
const constructorType = getTypeOfExpression(expr.right);
7127+
const instanceType = getTypeOfPropertyOfType(constructorType, "prototype" as __String);
7128+
const filter = (t: Type) => negated !== isTypeComparableTo(instanceType, t);
7129+
const impliedType = filterType(paramType, filter);
7130+
return createIdentifierTypePredicate(param.escapedName as string, paramIndex, impliedType);
7131+
}
7132+
else {
7133+
return undefined;
7134+
}
7135+
default:
7136+
return undefined;
7137+
}
7138+
}
7139+
7140+
function isStaticallyKnownConstant(expr: Expression): boolean {
7141+
const isLiteral = expr.kind === SyntaxKind.TrueKeyword
7142+
|| expr.kind === SyntaxKind.FalseKeyword
7143+
|| expr.kind === SyntaxKind.NumericLiteral
7144+
|| expr.kind === SyntaxKind.StringLiteral;
7145+
7146+
if (isLiteral) {
7147+
return true;
7148+
}
7149+
else if (isEntityNameExpression(expr)) {
7150+
const resolved = resolveEntityName(expr, SymbolFlags.Value, /*ignoreErrors*/ true);
7151+
// TODO const foo = 'Should this also be considered?';
7152+
return !!(resolved && (resolved.flags & SymbolFlags.EnumMember));
7153+
}
7154+
else {
7155+
return false;
7156+
}
7157+
}
7158+
7159+
function narrowTypeByDiscriminant(type: Type, propAccess: PropertyAccessExpression, narrowType: (t: Type) => Type): Type {
7160+
const propName = propAccess.name.escapedText;
7161+
const propType = getTypeOfPropertyOfType(type, propName);
7162+
const narrowedPropType = propType && narrowType(propType);
7163+
return propType === narrowedPropType ? type : filterType(type, t => isTypeComparableTo(getTypeOfPropertyOfType(t, propName), narrowedPropType));
7164+
}
7165+
}
7166+
69967167
function getReturnTypeOfSignature(signature: Signature): Type {
69977168
if (!signature.resolvedReturnType) {
69987169
if (!pushTypeResolution(signature, TypeSystemPropertyName.ResolvedReturnType)) {

0 commit comments

Comments
 (0)