@@ -424,9 +424,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
424
424
if let Some ( trait_def_id) = trait_def_id {
425
425
let found_kind = match closure_kind {
426
426
hir:: ClosureKind :: Closure => self . tcx . fn_trait_kind_from_def_id ( trait_def_id) ,
427
- hir:: ClosureKind :: CoroutineClosure ( hir:: CoroutineDesugaring :: Async ) => {
428
- self . tcx . async_fn_trait_kind_from_def_id ( trait_def_id)
429
- }
427
+ hir:: ClosureKind :: CoroutineClosure ( hir:: CoroutineDesugaring :: Async ) => self
428
+ . tcx
429
+ . async_fn_trait_kind_from_def_id ( trait_def_id)
430
+ . or_else ( || self . tcx . fn_trait_kind_from_def_id ( trait_def_id) ) ,
430
431
_ => None ,
431
432
} ;
432
433
@@ -470,14 +471,37 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
470
471
// for closures and async closures, respectively.
471
472
match closure_kind {
472
473
hir:: ClosureKind :: Closure
473
- if self . tcx . fn_trait_kind_from_def_id ( trait_def_id) . is_some ( ) => { }
474
+ if self . tcx . fn_trait_kind_from_def_id ( trait_def_id) . is_some ( ) =>
475
+ {
476
+ self . extract_sig_from_projection ( cause_span, projection)
477
+ }
478
+ hir:: ClosureKind :: CoroutineClosure ( hir:: CoroutineDesugaring :: Async )
479
+ if self . tcx . async_fn_trait_kind_from_def_id ( trait_def_id) . is_some ( ) =>
480
+ {
481
+ self . extract_sig_from_projection ( cause_span, projection)
482
+ }
483
+ // It's possible we've passed the closure to a (somewhat out-of-fashion)
484
+ // `F: FnOnce() -> Fut, Fut: Future<Output = T>` style bound. Let's still
485
+ // guide inference here, since it's beneficial for the user.
474
486
hir:: ClosureKind :: CoroutineClosure ( hir:: CoroutineDesugaring :: Async )
475
- if self . tcx . async_fn_trait_kind_from_def_id ( trait_def_id) . is_some ( ) => { }
476
- _ => return None ,
487
+ if self . tcx . fn_trait_kind_from_def_id ( trait_def_id) . is_some ( ) =>
488
+ {
489
+ self . extract_sig_from_projection_and_future_bound ( cause_span, projection)
490
+ }
491
+ _ => None ,
477
492
}
493
+ }
494
+
495
+ /// Given an `FnOnce::Output` or `AsyncFn::Output` projection, extract the args
496
+ /// and return type to infer a [`ty::PolyFnSig`] for the closure.
497
+ fn extract_sig_from_projection (
498
+ & self ,
499
+ cause_span : Option < Span > ,
500
+ projection : ty:: PolyProjectionPredicate < ' tcx > ,
501
+ ) -> Option < ExpectedSig < ' tcx > > {
502
+ let projection = self . resolve_vars_if_possible ( projection) ;
478
503
479
504
let arg_param_ty = projection. skip_binder ( ) . projection_term . args . type_at ( 1 ) ;
480
- let arg_param_ty = self . resolve_vars_if_possible ( arg_param_ty) ;
481
505
debug ! ( ?arg_param_ty) ;
482
506
483
507
let ty:: Tuple ( input_tys) = * arg_param_ty. kind ( ) else {
@@ -486,7 +510,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
486
510
487
511
// Since this is a return parameter type it is safe to unwrap.
488
512
let ret_param_ty = projection. skip_binder ( ) . term . expect_type ( ) ;
489
- let ret_param_ty = self . resolve_vars_if_possible ( ret_param_ty) ;
490
513
debug ! ( ?ret_param_ty) ;
491
514
492
515
let sig = projection. rebind ( self . tcx . mk_fn_sig (
@@ -500,6 +523,69 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
500
523
Some ( ExpectedSig { cause_span, sig } )
501
524
}
502
525
526
+ /// When an async closure is passed to a function that has a "two-part" `Fn`
527
+ /// and `Future` trait bound, like:
528
+ ///
529
+ /// ```rust
530
+ /// use std::future::Future;
531
+ ///
532
+ /// fn not_exactly_an_async_closure<F, Fut>(_f: F)
533
+ /// where
534
+ /// F: FnOnce(String, u32) -> Fut,
535
+ /// Fut: Future<Output = i32>,
536
+ /// {}
537
+ /// ```
538
+ ///
539
+ /// The we want to be able to extract the signature to guide inference in the async
540
+ /// closure. We will have two projection predicates registered in this case. First,
541
+ /// we identify the `FnOnce<Args, Output = ?Fut>` bound, and if the output type is
542
+ /// an inference variable `?Fut`, we check if that is bounded by a `Future<Output = Ty>`
543
+ /// projection.
544
+ fn extract_sig_from_projection_and_future_bound (
545
+ & self ,
546
+ cause_span : Option < Span > ,
547
+ projection : ty:: PolyProjectionPredicate < ' tcx > ,
548
+ ) -> Option < ExpectedSig < ' tcx > > {
549
+ let projection = self . resolve_vars_if_possible ( projection) ;
550
+
551
+ let arg_param_ty = projection. skip_binder ( ) . projection_term . args . type_at ( 1 ) ;
552
+ debug ! ( ?arg_param_ty) ;
553
+
554
+ let ty:: Tuple ( input_tys) = * arg_param_ty. kind ( ) else {
555
+ return None ;
556
+ } ;
557
+
558
+ // If the return type is a type variable, look for bounds on it.
559
+ // We could theoretically support other kinds of return types here,
560
+ // but none of them would be useful, since async closures return
561
+ // concrete anonymous future types, and their futures are not coerced
562
+ // into any other type within the body of the async closure.
563
+ let ty:: Infer ( ty:: TyVar ( return_vid) ) = * projection. skip_binder ( ) . term . expect_type ( ) . kind ( )
564
+ else {
565
+ return None ;
566
+ } ;
567
+
568
+ // FIXME: We may want to elaborate here, though I assume this will be exceedingly rare.
569
+ for bound in self . obligations_for_self_ty ( return_vid) {
570
+ if let Some ( ret_projection) = bound. predicate . as_projection_clause ( )
571
+ && let Some ( ret_projection) = ret_projection. no_bound_vars ( )
572
+ && self . tcx . is_lang_item ( ret_projection. def_id ( ) , LangItem :: FutureOutput )
573
+ {
574
+ let sig = projection. rebind ( self . tcx . mk_fn_sig (
575
+ input_tys,
576
+ ret_projection. term . expect_type ( ) ,
577
+ false ,
578
+ hir:: Safety :: Safe ,
579
+ Abi :: Rust ,
580
+ ) ) ;
581
+
582
+ return Some ( ExpectedSig { cause_span, sig } ) ;
583
+ }
584
+ }
585
+
586
+ None
587
+ }
588
+
503
589
fn sig_of_closure (
504
590
& self ,
505
591
expr_def_id : LocalDefId ,
0 commit comments