@@ -6984,15 +6984,186 @@ namespace ts {
6984
6984
}
6985
6985
else {
6986
6986
const declaration = signature.declaration;
6987
- signature.resolvedTypePredicate = declaration && declaration.type && declaration.type.kind === SyntaxKind.TypePredicate ?
6988
- createTypePredicateFromTypePredicateNode(declaration.type as TypePredicateNode) :
6989
- noTypePredicate;
6987
+
6988
+ if (declaration && declaration.type && declaration.type.kind === SyntaxKind.TypePredicate) {
6989
+ signature.resolvedTypePredicate = createTypePredicateFromTypePredicateNode(declaration.type as TypePredicateNode);
6990
+ }
6991
+ else if (declaration && !declaration.type && declaration.kind === SyntaxKind.ArrowFunction) {
6992
+ signature.resolvedTypePredicate = inferTypePredicateOfArrowSignature(signature) || noTypePredicate;
6993
+ }
6994
+ else {
6995
+ signature.resolvedTypePredicate = noTypePredicate;
6996
+ }
6990
6997
}
6991
6998
Debug.assert(!!signature.resolvedTypePredicate);
6992
6999
}
6993
7000
return signature.resolvedTypePredicate === noTypePredicate ? undefined : signature.resolvedTypePredicate;
6994
7001
}
6995
7002
7003
+ function inferTypePredicateOfArrowSignature(signature: Signature): IdentifierTypePredicate | undefined {
7004
+ Debug.assert(signature.declaration.kind === SyntaxKind.ArrowFunction);
7005
+
7006
+ const arrow = signature.declaration as ArrowFunction;
7007
+
7008
+ // not inferring for blocks is an arbitrary but reasonable choice
7009
+ if (!(getReturnTypeOfSignature(signature).flags & TypeFlags.Boolean) || arrow.body.kind === SyntaxKind.Block) {
7010
+ return undefined;
7011
+ }
7012
+
7013
+ const params = signature.parameters;
7014
+ const paramDecls = params.map(p => getDeclarationOfKind<ParameterDeclaration>(p, SyntaxKind.Parameter));
7015
+
7016
+ return inferTypePredicateFromExpression(arrow.body, /*negated*/ false);
7017
+
7018
+ function inferTypePredicateFromExpression(expr: Expression, negated: boolean): IdentifierTypePredicate | undefined {
7019
+ switch (expr.kind) {
7020
+ case SyntaxKind.CallExpression:
7021
+ return inferTypePredicateFromCallExpression(expr as CallExpression, negated);
7022
+ case SyntaxKind.ParenthesizedExpression:
7023
+ return inferTypePredicateFromExpression((<ParenthesizedExpression>expr).expression, negated);
7024
+ case SyntaxKind.PrefixUnaryExpression:
7025
+ return (<PrefixUnaryExpression>expr).operator === SyntaxKind.ExclamationToken ?
7026
+ inferTypePredicateFromExpression((<PrefixUnaryExpression>expr).operand, !negated) :
7027
+ undefined;
7028
+ case SyntaxKind.BinaryExpression:
7029
+ return inferTypePredicateFromBinaryExpression(expr as BinaryExpression, negated);
7030
+ default:
7031
+ return undefined;
7032
+ }
7033
+ }
7034
+
7035
+ function inferTypePredicateFromCallExpression(expr: CallExpression, negated: boolean) {
7036
+ if (negated) {
7037
+ // We can't deny the antecedent
7038
+ return undefined;
7039
+ }
7040
+
7041
+ const signature = getResolvedSignature(expr);
7042
+ const typePredicate = getTypePredicateOfSignature(signature);
7043
+
7044
+ if (!typePredicate || !isIdentifierTypePredicate(typePredicate)) {
7045
+ return undefined;
7046
+ }
7047
+
7048
+ const argument = expr.arguments[typePredicate.parameterIndex];
7049
+ const paramIndex = argument ? findIndex(paramDecls, p => isMatchingReference(p.name, argument)) : -1;
7050
+
7051
+ if (paramIndex >= 0) {
7052
+ return createIdentifierTypePredicate(params[paramIndex].escapedName as string, paramIndex, typePredicate.type);
7053
+ }
7054
+ else {
7055
+ return undefined;
7056
+ }
7057
+ }
7058
+
7059
+ function inferTypePredicateFromBinaryExpression(expr: BinaryExpression, negated: boolean): IdentifierTypePredicate | undefined {
7060
+ switch (expr.operatorToken.kind) {
7061
+ case SyntaxKind.EqualsEqualsEqualsToken:
7062
+ case SyntaxKind.ExclamationEqualsEqualsToken:
7063
+ const equality = negated !== (expr.operatorToken.kind === SyntaxKind.EqualsEqualsEqualsToken);
7064
+ const leftType = getTypeOfExpression(expr.left);
7065
+ const rightType = getTypeOfExpression(expr.right);
7066
+ const isLeftConstant = isStaticallyKnownConstant(expr.left);
7067
+ if (isLeftConstant === isStaticallyKnownConstant(expr.right)) {
7068
+ return undefined;
7069
+ }
7070
+
7071
+ const literalType = isLeftConstant ? leftType : rightType;
7072
+ const subject = isLeftConstant ? expr.right : expr.left;
7073
+
7074
+ if (subject.kind === SyntaxKind.PropertyAccessExpression) {
7075
+ const propAccess = subject as PropertyAccessExpression;
7076
+ const paramIndex = findIndex(paramDecls, p => isMatchingReference(p.name, propAccess.expression));
7077
+
7078
+ if (paramIndex < 0) {
7079
+ return undefined;
7080
+ }
7081
+
7082
+ const param = params[paramIndex];
7083
+ const paramType = getTypeOfVariableOrParameterOrProperty(param);
7084
+
7085
+ if (isDiscriminantProperty(paramType, propAccess.name.escapedText)) {
7086
+ const filter = (t: Type) => equality === isTypeComparableTo(literalType, t);
7087
+ const impliedType = narrowTypeByDiscriminant(paramType, propAccess, t0 => filterType(t0, filter));
7088
+
7089
+ return createIdentifierTypePredicate(param.escapedName as string, paramIndex, impliedType);
7090
+ }
7091
+
7092
+ return undefined;
7093
+ }
7094
+ else {
7095
+ return undefined;
7096
+ }
7097
+ case SyntaxKind.AmpersandAmpersandToken:
7098
+ case SyntaxKind.BarBarToken:
7099
+ const conjunctive = negated !== (expr.operatorToken.kind === SyntaxKind.AmpersandAmpersandToken);
7100
+ const left = inferTypePredicateFromExpression(expr.left, negated);
7101
+ const right = inferTypePredicateFromExpression(expr.right, negated);
7102
+
7103
+ if (left && right && left.parameterIndex === right.parameterIndex) {
7104
+ if (conjunctive) {
7105
+ const impliedType = filterType(left.type, t => isTypeComparableTo(right.type, t));
7106
+ return createIdentifierTypePredicate(left.parameterName, left.parameterIndex, impliedType);
7107
+ }
7108
+ else {
7109
+ const impliedType = getUnionType([left.type, right.type], UnionReduction.None);
7110
+ return createIdentifierTypePredicate(left.parameterName, left.parameterIndex, impliedType);
7111
+ }
7112
+ }
7113
+ else if (conjunctive && (!left || !right)) {
7114
+ return left || right;
7115
+ }
7116
+ else {
7117
+ return undefined;
7118
+ }
7119
+ case SyntaxKind.InstanceOfKeyword:
7120
+ const paramIndex = findIndex(paramDecls, p => isMatchingReference(p.name, expr.left));
7121
+
7122
+ if (paramIndex >= 0) {
7123
+ const param = params[paramIndex];
7124
+ const paramType = getTypeOfVariableOrParameterOrProperty(param);
7125
+
7126
+ const constructorType = getTypeOfExpression(expr.right);
7127
+ const instanceType = getTypeOfPropertyOfType(constructorType, "prototype" as __String);
7128
+ const filter = (t: Type) => negated !== isTypeComparableTo(instanceType, t);
7129
+ const impliedType = filterType(paramType, filter);
7130
+ return createIdentifierTypePredicate(param.escapedName as string, paramIndex, impliedType);
7131
+ }
7132
+ else {
7133
+ return undefined;
7134
+ }
7135
+ default:
7136
+ return undefined;
7137
+ }
7138
+ }
7139
+
7140
+ function isStaticallyKnownConstant(expr: Expression): boolean {
7141
+ const isLiteral = expr.kind === SyntaxKind.TrueKeyword
7142
+ || expr.kind === SyntaxKind.FalseKeyword
7143
+ || expr.kind === SyntaxKind.NumericLiteral
7144
+ || expr.kind === SyntaxKind.StringLiteral;
7145
+
7146
+ if (isLiteral) {
7147
+ return true;
7148
+ }
7149
+ else if (isEntityNameExpression(expr)) {
7150
+ const resolved = resolveEntityName(expr, SymbolFlags.Value, /*ignoreErrors*/ true);
7151
+ // TODO const foo = 'Should this also be considered?';
7152
+ return !!(resolved && (resolved.flags & SymbolFlags.EnumMember));
7153
+ }
7154
+ else {
7155
+ return false;
7156
+ }
7157
+ }
7158
+
7159
+ function narrowTypeByDiscriminant(type: Type, propAccess: PropertyAccessExpression, narrowType: (t: Type) => Type): Type {
7160
+ const propName = propAccess.name.escapedText;
7161
+ const propType = getTypeOfPropertyOfType(type, propName);
7162
+ const narrowedPropType = propType && narrowType(propType);
7163
+ return propType === narrowedPropType ? type : filterType(type, t => isTypeComparableTo(getTypeOfPropertyOfType(t, propName), narrowedPropType));
7164
+ }
7165
+ }
7166
+
6996
7167
function getReturnTypeOfSignature(signature: Signature): Type {
6997
7168
if (!signature.resolvedReturnType) {
6998
7169
if (!pushTypeResolution(signature, TypeSystemPropertyName.ResolvedReturnType)) {
0 commit comments