Skip to content

Commit f111f25

Browse files
authored
Unrolled build for rust-lang#127482
Rollup merge of rust-lang#127482 - compiler-errors:closure-two-par-sig-inference, r=oli-obk Infer async closure signature from (old-style) two-part `Fn` + `Future` bounds When an async closure is passed to a function that has a "two-part" `Fn` and `Future` trait bound, like: ```rust use std::future::Future; fn not_exactly_an_async_closure(_f: F) where F: FnOnce(String) -> Fut, Fut: Future<Output = ()>, {} ``` The we want to be able to extract the signature to guide inference in the async closure, like: ```rust not_exactly_an_async_closure(async |string| { for x in string.split('\n') { ... } //~^ We need to know that the type of `string` is `String` to call methods on it. }) ``` Closure signature inference will see two bounds: `<?F as FnOnce<Args>>::Output = ?Fut`, `<?Fut as Future>::Output = String`. We need to extract the signature by looking through both projections. ### Why? I expect the ecosystem's move onto `async Fn` trait bounds (which are not affected by this PR, and already do signature inference fine) to be slow. In the mean time, I don't see major overhead to supporting this "old–style" of trait bounds that were used to model async closures. r? oli-obk Fixes rust-lang#127468 Fixes rust-lang#127425
2 parents 32e6926 + f4f678f commit f111f25

File tree

2 files changed

+121
-8
lines changed

2 files changed

+121
-8
lines changed

compiler/rustc_hir_typeck/src/closure.rs

+94-8
Original file line numberDiff line numberDiff line change
@@ -424,9 +424,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
424424
if let Some(trait_def_id) = trait_def_id {
425425
let found_kind = match closure_kind {
426426
hir::ClosureKind::Closure => self.tcx.fn_trait_kind_from_def_id(trait_def_id),
427-
hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async) => {
428-
self.tcx.async_fn_trait_kind_from_def_id(trait_def_id)
429-
}
427+
hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async) => self
428+
.tcx
429+
.async_fn_trait_kind_from_def_id(trait_def_id)
430+
.or_else(|| self.tcx.fn_trait_kind_from_def_id(trait_def_id)),
430431
_ => None,
431432
};
432433

@@ -470,14 +471,37 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
470471
// for closures and async closures, respectively.
471472
match closure_kind {
472473
hir::ClosureKind::Closure
473-
if self.tcx.fn_trait_kind_from_def_id(trait_def_id).is_some() => {}
474+
if self.tcx.fn_trait_kind_from_def_id(trait_def_id).is_some() =>
475+
{
476+
self.extract_sig_from_projection(cause_span, projection)
477+
}
478+
hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async)
479+
if self.tcx.async_fn_trait_kind_from_def_id(trait_def_id).is_some() =>
480+
{
481+
self.extract_sig_from_projection(cause_span, projection)
482+
}
483+
// It's possible we've passed the closure to a (somewhat out-of-fashion)
484+
// `F: FnOnce() -> Fut, Fut: Future<Output = T>` style bound. Let's still
485+
// guide inference here, since it's beneficial for the user.
474486
hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async)
475-
if self.tcx.async_fn_trait_kind_from_def_id(trait_def_id).is_some() => {}
476-
_ => return None,
487+
if self.tcx.fn_trait_kind_from_def_id(trait_def_id).is_some() =>
488+
{
489+
self.extract_sig_from_projection_and_future_bound(cause_span, projection)
490+
}
491+
_ => None,
477492
}
493+
}
494+
495+
/// Given an `FnOnce::Output` or `AsyncFn::Output` projection, extract the args
496+
/// and return type to infer a [`ty::PolyFnSig`] for the closure.
497+
fn extract_sig_from_projection(
498+
&self,
499+
cause_span: Option<Span>,
500+
projection: ty::PolyProjectionPredicate<'tcx>,
501+
) -> Option<ExpectedSig<'tcx>> {
502+
let projection = self.resolve_vars_if_possible(projection);
478503

479504
let arg_param_ty = projection.skip_binder().projection_term.args.type_at(1);
480-
let arg_param_ty = self.resolve_vars_if_possible(arg_param_ty);
481505
debug!(?arg_param_ty);
482506

483507
let ty::Tuple(input_tys) = *arg_param_ty.kind() else {
@@ -486,7 +510,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
486510

487511
// Since this is a return parameter type it is safe to unwrap.
488512
let ret_param_ty = projection.skip_binder().term.expect_type();
489-
let ret_param_ty = self.resolve_vars_if_possible(ret_param_ty);
490513
debug!(?ret_param_ty);
491514

492515
let sig = projection.rebind(self.tcx.mk_fn_sig(
@@ -500,6 +523,69 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
500523
Some(ExpectedSig { cause_span, sig })
501524
}
502525

526+
/// When an async closure is passed to a function that has a "two-part" `Fn`
527+
/// and `Future` trait bound, like:
528+
///
529+
/// ```rust
530+
/// use std::future::Future;
531+
///
532+
/// fn not_exactly_an_async_closure<F, Fut>(_f: F)
533+
/// where
534+
/// F: FnOnce(String, u32) -> Fut,
535+
/// Fut: Future<Output = i32>,
536+
/// {}
537+
/// ```
538+
///
539+
/// The we want to be able to extract the signature to guide inference in the async
540+
/// closure. We will have two projection predicates registered in this case. First,
541+
/// we identify the `FnOnce<Args, Output = ?Fut>` bound, and if the output type is
542+
/// an inference variable `?Fut`, we check if that is bounded by a `Future<Output = Ty>`
543+
/// projection.
544+
fn extract_sig_from_projection_and_future_bound(
545+
&self,
546+
cause_span: Option<Span>,
547+
projection: ty::PolyProjectionPredicate<'tcx>,
548+
) -> Option<ExpectedSig<'tcx>> {
549+
let projection = self.resolve_vars_if_possible(projection);
550+
551+
let arg_param_ty = projection.skip_binder().projection_term.args.type_at(1);
552+
debug!(?arg_param_ty);
553+
554+
let ty::Tuple(input_tys) = *arg_param_ty.kind() else {
555+
return None;
556+
};
557+
558+
// If the return type is a type variable, look for bounds on it.
559+
// We could theoretically support other kinds of return types here,
560+
// but none of them would be useful, since async closures return
561+
// concrete anonymous future types, and their futures are not coerced
562+
// into any other type within the body of the async closure.
563+
let ty::Infer(ty::TyVar(return_vid)) = *projection.skip_binder().term.expect_type().kind()
564+
else {
565+
return None;
566+
};
567+
568+
// FIXME: We may want to elaborate here, though I assume this will be exceedingly rare.
569+
for bound in self.obligations_for_self_ty(return_vid) {
570+
if let Some(ret_projection) = bound.predicate.as_projection_clause()
571+
&& let Some(ret_projection) = ret_projection.no_bound_vars()
572+
&& self.tcx.is_lang_item(ret_projection.def_id(), LangItem::FutureOutput)
573+
{
574+
let sig = projection.rebind(self.tcx.mk_fn_sig(
575+
input_tys,
576+
ret_projection.term.expect_type(),
577+
false,
578+
hir::Safety::Safe,
579+
Abi::Rust,
580+
));
581+
582+
return Some(ExpectedSig { cause_span, sig });
583+
}
584+
}
585+
586+
None
587+
}
588+
503589
fn sig_of_closure(
504590
&self,
505591
expr_def_id: LocalDefId,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//@ edition: 2021
2+
//@ check-pass
3+
//@ revisions: current next
4+
//@ ignore-compare-mode-next-solver (explicit revisions)
5+
//@[next] compile-flags: -Znext-solver
6+
7+
#![feature(async_closure)]
8+
9+
use std::future::Future;
10+
use std::any::Any;
11+
12+
struct Struct;
13+
impl Struct {
14+
fn method(&self) {}
15+
}
16+
17+
fn fake_async_closure<F, Fut>(_: F)
18+
where
19+
F: Fn(Struct) -> Fut,
20+
Fut: Future<Output = ()>,
21+
{}
22+
23+
fn main() {
24+
fake_async_closure(async |s| {
25+
s.method();
26+
})
27+
}

0 commit comments

Comments
 (0)