Skip to content

Commit f2ef88b

Browse files
Consolidate logic around resolving built-in coroutine trait impls
1 parent 32ec40c commit f2ef88b

File tree

5 files changed

+56
-58
lines changed

5 files changed

+56
-58
lines changed

compiler/rustc_hir/src/lang_items.rs

+3
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,11 @@ language_item_table! {
213213
Iterator, sym::iterator, iterator_trait, Target::Trait, GenericRequirement::Exact(0);
214214
Future, sym::future_trait, future_trait, Target::Trait, GenericRequirement::Exact(0);
215215
AsyncIterator, sym::async_iterator, async_iterator_trait, Target::Trait, GenericRequirement::Exact(0);
216+
216217
CoroutineState, sym::coroutine_state, coroutine_state, Target::Enum, GenericRequirement::None;
217218
Coroutine, sym::coroutine, coroutine_trait, Target::Trait, GenericRequirement::Minimum(1);
219+
CoroutineResume, sym::coroutine_resume, coroutine_resume, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
220+
218221
Unpin, sym::unpin, unpin_trait, Target::Trait, GenericRequirement::None;
219222
Pin, sym::pin, pin_type, Target::Struct, GenericRequirement::None;
220223

compiler/rustc_middle/src/ty/instance.rs

+50
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::ty::print::{FmtPrinter, Printer};
33
use crate::ty::{self, Ty, TyCtxt, TypeFoldable, TypeSuperFoldable};
44
use crate::ty::{EarlyBinder, GenericArgs, GenericArgsRef, TypeVisitableExt};
55
use rustc_errors::ErrorGuaranteed;
6+
use rustc_hir as hir;
67
use rustc_hir::def::Namespace;
78
use rustc_hir::def_id::{CrateNum, DefId};
89
use rustc_hir::lang_items::LangItem;
@@ -11,6 +12,7 @@ use rustc_macros::HashStable;
1112
use rustc_middle::ty::normalize_erasing_regions::NormalizationError;
1213
use rustc_span::Symbol;
1314

15+
use std::assert_matches::assert_matches;
1416
use std::fmt;
1517

1618
/// A monomorphized `InstanceDef`.
@@ -572,6 +574,54 @@ impl<'tcx> Instance<'tcx> {
572574
Some(Instance { def, args })
573575
}
574576

577+
pub fn try_resolve_item_for_coroutine(
578+
tcx: TyCtxt<'tcx>,
579+
trait_item_id: DefId,
580+
trait_id: DefId,
581+
rcvr_args: ty::GenericArgsRef<'tcx>,
582+
) -> Option<Instance<'tcx>> {
583+
let ty::Coroutine(coroutine_def_id, args) = *rcvr_args.type_at(0).kind() else {
584+
return None;
585+
};
586+
let coroutine_kind = tcx.coroutine_kind(coroutine_def_id).unwrap();
587+
588+
let lang_items = tcx.lang_items();
589+
let coroutine_callable_item = if Some(trait_id) == lang_items.future_trait() {
590+
assert_matches!(
591+
coroutine_kind,
592+
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)
593+
);
594+
hir::LangItem::FuturePoll
595+
} else if Some(trait_id) == lang_items.iterator_trait() {
596+
assert_matches!(
597+
coroutine_kind,
598+
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)
599+
);
600+
hir::LangItem::IteratorNext
601+
} else if Some(trait_id) == lang_items.async_iterator_trait() {
602+
assert_matches!(
603+
coroutine_kind,
604+
hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _)
605+
);
606+
hir::LangItem::AsyncIteratorPollNext
607+
} else if Some(trait_id) == lang_items.coroutine_trait() {
608+
assert_matches!(coroutine_kind, hir::CoroutineKind::Coroutine(_));
609+
hir::LangItem::CoroutineResume
610+
} else {
611+
return None;
612+
};
613+
614+
if tcx.lang_items().get(coroutine_callable_item) == Some(trait_item_id) {
615+
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args: args })
616+
} else {
617+
// All other methods should be defaulted methods of the built-in trait.
618+
// This is important for `Iterator`'s combinators, but also useful for
619+
// adding future default methods to `Future`, for instance.
620+
debug_assert!(tcx.defaultness(trait_item_id).has_value());
621+
Some(Instance::new(trait_item_id, rcvr_args))
622+
}
623+
}
624+
575625
/// Depending on the kind of `InstanceDef`, the MIR body associated with an
576626
/// instance is expressed in terms of the generic parameters of `self.def_id()`, and in other
577627
/// cases the MIR body is expressed in terms of the types found in the substitution array.

compiler/rustc_span/src/symbol.rs

+1
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,7 @@ symbols! {
600600
core_panic_macro,
601601
coroutine,
602602
coroutine_clone,
603+
coroutine_resume,
603604
coroutine_state,
604605
coroutines,
605606
cosf32,

compiler/rustc_ty_utils/src/instance.rs

+1-58
Original file line numberDiff line numberDiff line change
@@ -245,63 +245,6 @@ fn resolve_associated_item<'tcx>(
245245
span: tcx.def_span(trait_item_id),
246246
})
247247
}
248-
} else if Some(trait_ref.def_id) == lang_items.future_trait() {
249-
let ty::Coroutine(coroutine_def_id, args) = *rcvr_args.type_at(0).kind() else {
250-
bug!()
251-
};
252-
if Some(trait_item_id) == tcx.lang_items().future_poll_fn() {
253-
// `Future::poll` is generated by the compiler.
254-
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args: args })
255-
} else {
256-
// All other methods are default methods of the `Future` trait.
257-
// (this assumes that `ImplSource::Builtin` is only used for methods on `Future`)
258-
debug_assert!(tcx.defaultness(trait_item_id).has_value());
259-
Some(Instance::new(trait_item_id, rcvr_args))
260-
}
261-
} else if Some(trait_ref.def_id) == lang_items.iterator_trait() {
262-
let ty::Coroutine(coroutine_def_id, args) = *rcvr_args.type_at(0).kind() else {
263-
bug!()
264-
};
265-
if Some(trait_item_id) == tcx.lang_items().next_fn() {
266-
// `Iterator::next` is generated by the compiler.
267-
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
268-
} else {
269-
// All other methods are default methods of the `Iterator` trait.
270-
// (this assumes that `ImplSource::Builtin` is only used for methods on `Iterator`)
271-
debug_assert!(tcx.defaultness(trait_item_id).has_value());
272-
Some(Instance::new(trait_item_id, rcvr_args))
273-
}
274-
} else if Some(trait_ref.def_id) == lang_items.async_iterator_trait() {
275-
let ty::Coroutine(coroutine_def_id, args) = *rcvr_args.type_at(0).kind() else {
276-
bug!()
277-
};
278-
279-
if cfg!(debug_assertions) && tcx.item_name(trait_item_id) != sym::poll_next {
280-
span_bug!(
281-
tcx.def_span(coroutine_def_id),
282-
"no definition for `{trait_ref}::{}` for built-in coroutine type",
283-
tcx.item_name(trait_item_id)
284-
)
285-
}
286-
287-
// `AsyncIterator::poll_next` is generated by the compiler.
288-
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
289-
} else if Some(trait_ref.def_id) == lang_items.coroutine_trait() {
290-
let ty::Coroutine(coroutine_def_id, args) = *rcvr_args.type_at(0).kind() else {
291-
bug!()
292-
};
293-
if cfg!(debug_assertions) && tcx.item_name(trait_item_id) != sym::resume {
294-
// For compiler developers who'd like to add new items to `Coroutine`,
295-
// you either need to generate a shim body, or perhaps return
296-
// `InstanceDef::Item` pointing to a trait default method body if
297-
// it is given a default implementation by the trait.
298-
span_bug!(
299-
tcx.def_span(coroutine_def_id),
300-
"no definition for `{trait_ref}::{}` for built-in coroutine type",
301-
tcx.item_name(trait_item_id)
302-
)
303-
}
304-
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
305248
} else if tcx.fn_trait_kind_from_def_id(trait_ref.def_id).is_some() {
306249
// FIXME: This doesn't check for malformed libcore that defines, e.g.,
307250
// `trait Fn { fn call_once(&self) { .. } }`. This is mostly for extension
@@ -334,7 +277,7 @@ fn resolve_associated_item<'tcx>(
334277
),
335278
}
336279
} else {
337-
None
280+
Instance::try_resolve_item_for_coroutine(tcx, trait_item_id, trait_id, rcvr_args)
338281
}
339282
}
340283
traits::ImplSource::Param(..)

library/core/src/ops/coroutine.rs

+1
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ pub trait Coroutine<R = ()> {
111111
/// been returned previously. While coroutine literals in the language are
112112
/// guaranteed to panic on resuming after `Complete`, this is not guaranteed
113113
/// for all implementations of the `Coroutine` trait.
114+
#[cfg_attr(not(bootstrap), lang = "coroutine_resume")]
114115
fn resume(self: Pin<&mut Self>, arg: R) -> CoroutineState<Self::Yield, Self::Return>;
115116
}
116117

0 commit comments

Comments
 (0)