Skip to content

Commit dc64103

Browse files
committed
Auto merge of #117703 - compiler-errors:recursive-async, r=lcnr
Support async recursive calls (as long as they have indirection) Before #101692, we stored coroutine witness types directly inside of the coroutine. That means that a coroutine could not contain itself (as a witness field) without creating a cycle in the type representation of the coroutine, which we detected with the `OpaqueTypeExpander`, which is used to detect cycles when expanding opaque types after that are inferred to contain themselves. After `-Zdrop-tracking-mir` was stabilized, we no longer store these generator witness fields directly, but instead behind a def-id based query. That means there is no technical obstacle in the compiler preventing coroutines from containing themselves per se, other than the fact that for a coroutine to have a non-infinite layout, it must contain itself wrapped in a layer of allocation indirection (like a `Box`). This means that it should be valid for this code to work: ``` async fn async_fibonacci(i: u32) -> u32 { if i == 0 || i == 1 { i } else { Box::pin(async_fibonacci(i - 1)).await + Box::pin(async_fibonacci(i - 2)).await } } ``` Whereas previously, you'd need to coerce the future to `Pin<Box<dyn Future<Output = ...>>` before `await`ing it, to prevent the async's desugared coroutine from containing itself across as await point. This PR does two things: 1. Only report an error if an opaque expansion cycle is detected *not* through coroutine witness fields. * Instead, if we find an opaque cycle through coroutine witness fields, we compute the layout of the coroutine. If that results in a cycle error, we report it as a recursive async fn. 4. Reworks the way we report layout errors having to do with coroutines, to make up for the diagnostic regressions introduced by (1.). We actually do even better now, pointing out the call sites of the recursion!
2 parents 387e7a5 + 9a75603 commit dc64103

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+394
-243
lines changed

compiler/rustc_error_codes/src/error_codes/E0733.md

+11-15
Original file line numberDiff line numberDiff line change
@@ -10,35 +10,31 @@ async fn foo(n: usize) {
1010
}
1111
```
1212

13-
To perform async recursion, the `async fn` needs to be desugared such that the
14-
`Future` is explicit in the return type:
13+
The recursive invocation can be boxed:
1514

16-
```edition2018,compile_fail,E0720
17-
use std::future::Future;
18-
fn foo_desugared(n: usize) -> impl Future<Output = ()> {
19-
async move {
20-
if n > 0 {
21-
foo_desugared(n - 1).await;
22-
}
15+
```edition2018
16+
async fn foo(n: usize) {
17+
if n > 0 {
18+
Box::pin(foo(n - 1)).await;
2319
}
2420
}
2521
```
2622

27-
Finally, the future is wrapped in a pinned box:
23+
The `Box<...>` ensures that the result is of known size, and the pin is
24+
required to keep it in the same place in memory.
25+
26+
Alternatively, the body can be boxed:
2827

2928
```edition2018
3029
use std::future::Future;
3130
use std::pin::Pin;
32-
fn foo_recursive(n: usize) -> Pin<Box<dyn Future<Output = ()>>> {
31+
fn foo(n: usize) -> Pin<Box<dyn Future<Output = ()>>> {
3332
Box::pin(async move {
3433
if n > 0 {
35-
foo_recursive(n - 1).await;
34+
foo(n - 1).await;
3635
}
3736
})
3837
}
3938
```
4039

41-
The `Box<...>` ensures that the result is of known size, and the pin is
42-
required to keep it in the same place in memory.
43-
4440
[`async`]: https://doc.rust-lang.org/std/keyword.async.html

compiler/rustc_hir/src/hir.rs

+6
Original file line numberDiff line numberDiff line change
@@ -1361,6 +1361,12 @@ impl CoroutineKind {
13611361
}
13621362
}
13631363

1364+
impl CoroutineKind {
1365+
pub fn is_fn_like(self) -> bool {
1366+
matches!(self, CoroutineKind::Desugared(_, CoroutineSource::Fn))
1367+
}
1368+
}
1369+
13641370
impl fmt::Display for CoroutineKind {
13651371
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13661372
match self {

compiler/rustc_hir_analysis/src/check/check.rs

+29-23
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use rustc_middle::middle::stability::EvalResult;
1717
use rustc_middle::traits::{DefiningAnchor, ObligationCauseCode};
1818
use rustc_middle::ty::fold::BottomUpFolder;
1919
use rustc_middle::ty::layout::{LayoutError, MAX_SIMD_LANES};
20-
use rustc_middle::ty::util::{Discr, IntTypeExt};
20+
use rustc_middle::ty::util::{Discr, InspectCoroutineFields, IntTypeExt};
2121
use rustc_middle::ty::GenericArgKind;
2222
use rustc_middle::ty::{
2323
AdtDef, ParamEnv, RegionKind, TypeSuperVisitable, TypeVisitable, TypeVisitableExt,
@@ -213,13 +213,12 @@ fn check_opaque(tcx: TyCtxt<'_>, def_id: LocalDefId) {
213213
return;
214214
}
215215

216-
let args = GenericArgs::identity_for_item(tcx, item.owner_id);
217216
let span = tcx.def_span(item.owner_id.def_id);
218217

219218
if tcx.type_of(item.owner_id.def_id).instantiate_identity().references_error() {
220219
return;
221220
}
222-
if check_opaque_for_cycles(tcx, item.owner_id.def_id, args, span, origin).is_err() {
221+
if check_opaque_for_cycles(tcx, item.owner_id.def_id, span).is_err() {
223222
return;
224223
}
225224

@@ -230,19 +229,36 @@ fn check_opaque(tcx: TyCtxt<'_>, def_id: LocalDefId) {
230229
pub(super) fn check_opaque_for_cycles<'tcx>(
231230
tcx: TyCtxt<'tcx>,
232231
def_id: LocalDefId,
233-
args: GenericArgsRef<'tcx>,
234232
span: Span,
235-
origin: &hir::OpaqueTyOrigin,
236233
) -> Result<(), ErrorGuaranteed> {
237-
if tcx.try_expand_impl_trait_type(def_id.to_def_id(), args).is_err() {
238-
let reported = match origin {
239-
hir::OpaqueTyOrigin::AsyncFn(..) => async_opaque_type_cycle_error(tcx, span),
240-
_ => opaque_type_cycle_error(tcx, def_id, span),
241-
};
242-
Err(reported)
243-
} else {
244-
Ok(())
234+
let args = GenericArgs::identity_for_item(tcx, def_id);
235+
236+
// First, try to look at any opaque expansion cycles, considering coroutine fields
237+
// (even though these aren't necessarily true errors).
238+
if tcx
239+
.try_expand_impl_trait_type(def_id.to_def_id(), args, InspectCoroutineFields::Yes)
240+
.is_err()
241+
{
242+
// Look for true opaque expansion cycles, but ignore coroutines.
243+
// This will give us any true errors. Coroutines are only problematic
244+
// if they cause layout computation errors.
245+
if tcx
246+
.try_expand_impl_trait_type(def_id.to_def_id(), args, InspectCoroutineFields::No)
247+
.is_err()
248+
{
249+
let reported = opaque_type_cycle_error(tcx, def_id, span);
250+
return Err(reported);
251+
}
252+
253+
// And also look for cycle errors in the layout of coroutines.
254+
if let Err(&LayoutError::Cycle(guar)) =
255+
tcx.layout_of(tcx.param_env(def_id).and(Ty::new_opaque(tcx, def_id.to_def_id(), args)))
256+
{
257+
return Err(guar);
258+
}
245259
}
260+
261+
Ok(())
246262
}
247263

248264
/// Check that the concrete type behind `impl Trait` actually implements `Trait`.
@@ -1300,16 +1316,6 @@ pub(super) fn check_type_params_are_used<'tcx>(
13001316
}
13011317
}
13021318

1303-
fn async_opaque_type_cycle_error(tcx: TyCtxt<'_>, span: Span) -> ErrorGuaranteed {
1304-
struct_span_err!(tcx.dcx(), span, E0733, "recursion in an `async fn` requires boxing")
1305-
.span_label_mv(span, "recursive `async fn`")
1306-
.note_mv("a recursive `async fn` must be rewritten to return a boxed `dyn Future`")
1307-
.note_mv(
1308-
"consider using the `async_recursion` crate: https://crates.io/crates/async_recursion",
1309-
)
1310-
.emit()
1311-
}
1312-
13131319
/// Emit an error for recursive opaque types.
13141320
///
13151321
/// If this is a return `impl Trait`, find the item's return expressions and point at them. For

compiler/rustc_middle/src/query/keys.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ pub trait Key: Sized {
4040
None
4141
}
4242

43-
fn ty_adt_id(&self) -> Option<DefId> {
43+
fn ty_def_id(&self) -> Option<DefId> {
4444
None
4545
}
4646
}
@@ -406,9 +406,10 @@ impl<'tcx> Key for Ty<'tcx> {
406406
DUMMY_SP
407407
}
408408

409-
fn ty_adt_id(&self) -> Option<DefId> {
410-
match self.kind() {
409+
fn ty_def_id(&self) -> Option<DefId> {
410+
match *self.kind() {
411411
ty::Adt(adt, _) => Some(adt.did()),
412+
ty::Coroutine(def_id, ..) => Some(def_id),
412413
_ => None,
413414
}
414415
}
@@ -452,6 +453,10 @@ impl<'tcx, T: Key> Key for ty::ParamEnvAnd<'tcx, T> {
452453
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
453454
self.value.default_span(tcx)
454455
}
456+
457+
fn ty_def_id(&self) -> Option<DefId> {
458+
self.value.ty_def_id()
459+
}
455460
}
456461

457462
impl Key for Symbol {
@@ -550,7 +555,7 @@ impl<'tcx> Key for (ValidityRequirement, ty::ParamEnvAnd<'tcx, Ty<'tcx>>) {
550555
DUMMY_SP
551556
}
552557

553-
fn ty_adt_id(&self) -> Option<DefId> {
558+
fn ty_def_id(&self) -> Option<DefId> {
554559
match self.1.value.kind() {
555560
ty::Adt(adt, _) => Some(adt.did()),
556561
_ => None,

compiler/rustc_middle/src/query/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,8 @@ rustc_queries! {
13871387
) -> Result<ty::layout::TyAndLayout<'tcx>, &'tcx ty::layout::LayoutError<'tcx>> {
13881388
depth_limit
13891389
desc { "computing layout of `{}`", key.value }
1390+
// we emit our own error during query cycle handling
1391+
cycle_delay_bug
13901392
}
13911393

13921394
/// Compute a `FnAbi` suitable for indirect calls, i.e. to `fn` pointers.

compiler/rustc_middle/src/query/plumbing.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ pub struct DynamicQuery<'tcx, C: QueryCache> {
5353
fn(tcx: TyCtxt<'tcx>, key: &C::Key, index: SerializedDepNodeIndex) -> bool,
5454
pub hash_result: HashResult<C::Value>,
5555
pub value_from_cycle_error:
56-
fn(tcx: TyCtxt<'tcx>, cycle: &[QueryInfo], guar: ErrorGuaranteed) -> C::Value,
56+
fn(tcx: TyCtxt<'tcx>, cycle_error: &CycleError, guar: ErrorGuaranteed) -> C::Value,
5757
pub format_value: fn(&C::Value) -> String,
5858
}
5959

compiler/rustc_middle/src/ty/util.rs

+51-12
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,7 @@ impl<'tcx> TyCtxt<'tcx> {
702702
self,
703703
def_id: DefId,
704704
args: GenericArgsRef<'tcx>,
705+
inspect_coroutine_fields: InspectCoroutineFields,
705706
) -> Result<Ty<'tcx>, Ty<'tcx>> {
706707
let mut visitor = OpaqueTypeExpander {
707708
seen_opaque_tys: FxHashSet::default(),
@@ -712,6 +713,7 @@ impl<'tcx> TyCtxt<'tcx> {
712713
check_recursion: true,
713714
expand_coroutines: true,
714715
tcx: self,
716+
inspect_coroutine_fields,
715717
};
716718

717719
let expanded_type = visitor.expand_opaque_ty(def_id, args).unwrap();
@@ -729,16 +731,43 @@ impl<'tcx> TyCtxt<'tcx> {
729731
DefKind::AssocFn if self.associated_item(def_id).fn_has_self_parameter => "method",
730732
DefKind::Closure if let Some(coroutine_kind) = self.coroutine_kind(def_id) => {
731733
match coroutine_kind {
732-
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _) => {
733-
"async closure"
734-
}
735-
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _) => {
736-
"async gen closure"
737-
}
734+
hir::CoroutineKind::Desugared(
735+
hir::CoroutineDesugaring::Async,
736+
hir::CoroutineSource::Fn,
737+
) => "async fn",
738+
hir::CoroutineKind::Desugared(
739+
hir::CoroutineDesugaring::Async,
740+
hir::CoroutineSource::Block,
741+
) => "async block",
742+
hir::CoroutineKind::Desugared(
743+
hir::CoroutineDesugaring::Async,
744+
hir::CoroutineSource::Closure,
745+
) => "async closure",
746+
hir::CoroutineKind::Desugared(
747+
hir::CoroutineDesugaring::AsyncGen,
748+
hir::CoroutineSource::Fn,
749+
) => "async gen fn",
750+
hir::CoroutineKind::Desugared(
751+
hir::CoroutineDesugaring::AsyncGen,
752+
hir::CoroutineSource::Block,
753+
) => "async gen block",
754+
hir::CoroutineKind::Desugared(
755+
hir::CoroutineDesugaring::AsyncGen,
756+
hir::CoroutineSource::Closure,
757+
) => "async gen closure",
758+
hir::CoroutineKind::Desugared(
759+
hir::CoroutineDesugaring::Gen,
760+
hir::CoroutineSource::Fn,
761+
) => "gen fn",
762+
hir::CoroutineKind::Desugared(
763+
hir::CoroutineDesugaring::Gen,
764+
hir::CoroutineSource::Block,
765+
) => "gen block",
766+
hir::CoroutineKind::Desugared(
767+
hir::CoroutineDesugaring::Gen,
768+
hir::CoroutineSource::Closure,
769+
) => "gen closure",
738770
hir::CoroutineKind::Coroutine(_) => "coroutine",
739-
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) => {
740-
"gen closure"
741-
}
742771
}
743772
}
744773
_ => def_kind.descr(def_id),
@@ -865,6 +894,13 @@ struct OpaqueTypeExpander<'tcx> {
865894
/// recursion, and 'false' otherwise to avoid unnecessary work.
866895
check_recursion: bool,
867896
tcx: TyCtxt<'tcx>,
897+
inspect_coroutine_fields: InspectCoroutineFields,
898+
}
899+
900+
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
901+
pub enum InspectCoroutineFields {
902+
No,
903+
Yes,
868904
}
869905

870906
impl<'tcx> OpaqueTypeExpander<'tcx> {
@@ -906,9 +942,11 @@ impl<'tcx> OpaqueTypeExpander<'tcx> {
906942
let expanded_ty = match self.expanded_cache.get(&(def_id, args)) {
907943
Some(expanded_ty) => *expanded_ty,
908944
None => {
909-
for bty in self.tcx.coroutine_hidden_types(def_id) {
910-
let hidden_ty = bty.instantiate(self.tcx, args);
911-
self.fold_ty(hidden_ty);
945+
if matches!(self.inspect_coroutine_fields, InspectCoroutineFields::Yes) {
946+
for bty in self.tcx.coroutine_hidden_types(def_id) {
947+
let hidden_ty = bty.instantiate(self.tcx, args);
948+
self.fold_ty(hidden_ty);
949+
}
912950
}
913951
let expanded_ty = Ty::new_coroutine_witness(self.tcx, def_id, args);
914952
self.expanded_cache.insert((def_id, args), expanded_ty);
@@ -1486,6 +1524,7 @@ pub fn reveal_opaque_types_in_bounds<'tcx>(
14861524
check_recursion: false,
14871525
expand_coroutines: false,
14881526
tcx,
1527+
inspect_coroutine_fields: InspectCoroutineFields::No,
14891528
};
14901529
val.fold_with(&mut visitor)
14911530
}

0 commit comments

Comments
 (0)