Skip to content

Commit 5d00403

Browse files
committed
Auto merge of #119989 - lcnr:sub_relations-bye-bye, r=<try>
remove `sub_relations` from the `InferCtxt` see commit descriptions for the reasoning of each change. r? `@compiler-errors`
2 parents 73252d5 + d99a07d commit 5d00403

32 files changed

+277
-387
lines changed

compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -1493,10 +1493,13 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
14931493
if self.next_trait_solver()
14941494
&& let ty::Alias(..) = ty.kind()
14951495
{
1496-
match self
1496+
// We need to use a separate variable here as we otherwise the temporary for
1497+
// `self.fulfillment_cx.borrow_mut()` is alive in the `Err` branch, resulting
1498+
// in a reentrant borrow, causing an ICE.
1499+
let result = self
14971500
.at(&self.misc(sp), self.param_env)
1498-
.structurally_normalize(ty, &mut **self.fulfillment_cx.borrow_mut())
1499-
{
1501+
.structurally_normalize(ty, &mut **self.fulfillment_cx.borrow_mut());
1502+
match result {
15001503
Ok(normalized_ty) => normalized_ty,
15011504
Err(errors) => {
15021505
let guar = self.err_ctxt().report_fulfillment_errors(errors);

compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs

+7
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use rustc_hir as hir;
1111
use rustc_hir::def_id::{DefId, LocalDefId};
1212
use rustc_hir_analysis::astconv::AstConv;
1313
use rustc_infer::infer;
14+
use rustc_infer::infer::error_reporting::sub_relations::SubRelations;
1415
use rustc_infer::infer::error_reporting::TypeErrCtxt;
1516
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
1617
use rustc_middle::infer::unify_key::{ConstVariableOrigin, ConstVariableOriginKind};
@@ -155,8 +156,14 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
155156
///
156157
/// [`InferCtxt::err_ctxt`]: infer::InferCtxt::err_ctxt
157158
pub fn err_ctxt(&'a self) -> TypeErrCtxt<'a, 'tcx> {
159+
let mut sub_relations = SubRelations::default();
160+
sub_relations.add_constraints(
161+
self,
162+
self.fulfillment_cx.borrow_mut().pending_obligations().iter().map(|o| o.predicate),
163+
);
158164
TypeErrCtxt {
159165
infcx: &self.infcx,
166+
sub_relations: RefCell::new(sub_relations),
160167
typeck_results: Some(self.typeck_results.borrow()),
161168
fallback_has_occurred: self.fallback_has_occurred.get(),
162169
normalize_fn_sig: Box::new(|fn_sig| {

compiler/rustc_infer/src/infer/error_reporting/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ mod note_and_explain;
8787
mod suggest;
8888

8989
pub(crate) mod need_type_info;
90+
pub mod sub_relations;
9091
pub use need_type_info::TypeAnnotationNeeded;
9192

9293
pub mod nice_region_error;
@@ -122,6 +123,8 @@ fn escape_literal(s: &str) -> String {
122123
/// during the happy path.
123124
pub struct TypeErrCtxt<'a, 'tcx> {
124125
pub infcx: &'a InferCtxt<'tcx>,
126+
pub sub_relations: std::cell::RefCell<sub_relations::SubRelations>,
127+
125128
pub typeck_results: Option<std::cell::Ref<'a, ty::TypeckResults<'tcx>>>,
126129
pub fallback_has_occurred: bool,
127130

compiler/rustc_infer/src/infer/error_reporting/need_type_info.rs

+20-23
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
489489
parent_name,
490490
});
491491

492-
let args = if self.infcx.tcx.get_diagnostic_item(sym::iterator_collect_fn)
492+
let args = if self.tcx.get_diagnostic_item(sym::iterator_collect_fn)
493493
== Some(generics_def_id)
494494
{
495495
"Vec<_>".to_string()
@@ -697,7 +697,7 @@ struct InsertableGenericArgs<'tcx> {
697697
/// While doing so, the currently best spot is stored in `infer_source`.
698698
/// For details on how we rank spots, see [Self::source_cost]
699699
struct FindInferSourceVisitor<'a, 'tcx> {
700-
infcx: &'a InferCtxt<'tcx>,
700+
tecx: &'a TypeErrCtxt<'a, 'tcx>,
701701
typeck_results: &'a TypeckResults<'tcx>,
702702

703703
target: GenericArg<'tcx>,
@@ -709,12 +709,12 @@ struct FindInferSourceVisitor<'a, 'tcx> {
709709

710710
impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
711711
fn new(
712-
infcx: &'a InferCtxt<'tcx>,
712+
tecx: &'a TypeErrCtxt<'a, 'tcx>,
713713
typeck_results: &'a TypeckResults<'tcx>,
714714
target: GenericArg<'tcx>,
715715
) -> Self {
716716
FindInferSourceVisitor {
717-
infcx,
717+
tecx,
718718
typeck_results,
719719

720720
target,
@@ -765,7 +765,7 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
765765
}
766766

767767
// The sources are listed in order of preference here.
768-
let tcx = self.infcx.tcx;
768+
let tcx = self.tecx.tcx;
769769
let ctx = CostCtxt { tcx };
770770
match source.kind {
771771
InferSourceKind::LetBinding { ty, .. } => ctx.ty_cost(ty),
@@ -816,12 +816,12 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
816816

817817
fn node_args_opt(&self, hir_id: HirId) -> Option<GenericArgsRef<'tcx>> {
818818
let args = self.typeck_results.node_args_opt(hir_id);
819-
self.infcx.resolve_vars_if_possible(args)
819+
self.tecx.resolve_vars_if_possible(args)
820820
}
821821

822822
fn opt_node_type(&self, hir_id: HirId) -> Option<Ty<'tcx>> {
823823
let ty = self.typeck_results.node_type_opt(hir_id);
824-
self.infcx.resolve_vars_if_possible(ty)
824+
self.tecx.resolve_vars_if_possible(ty)
825825
}
826826

827827
// Check whether this generic argument is the inference variable we
@@ -836,20 +836,17 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
836836
use ty::{Infer, TyVar};
837837
match (inner_ty.kind(), target_ty.kind()) {
838838
(&Infer(TyVar(a_vid)), &Infer(TyVar(b_vid))) => {
839-
self.infcx.inner.borrow_mut().type_variables().sub_unified(a_vid, b_vid)
839+
self.tecx.sub_relations.borrow_mut().unified(self.tecx, a_vid, b_vid)
840840
}
841841
_ => false,
842842
}
843843
}
844844
(GenericArgKind::Const(inner_ct), GenericArgKind::Const(target_ct)) => {
845845
use ty::InferConst::*;
846846
match (inner_ct.kind(), target_ct.kind()) {
847-
(ty::ConstKind::Infer(Var(a_vid)), ty::ConstKind::Infer(Var(b_vid))) => self
848-
.infcx
849-
.inner
850-
.borrow_mut()
851-
.const_unification_table()
852-
.unioned(a_vid, b_vid),
847+
(ty::ConstKind::Infer(Var(a_vid)), ty::ConstKind::Infer(Var(b_vid))) => {
848+
self.tecx.inner.borrow_mut().const_unification_table().unioned(a_vid, b_vid)
849+
}
853850
_ => false,
854851
}
855852
}
@@ -901,7 +898,7 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
901898
&self,
902899
expr: &'tcx hir::Expr<'tcx>,
903900
) -> Box<dyn Iterator<Item = InsertableGenericArgs<'tcx>> + 'a> {
904-
let tcx = self.infcx.tcx;
901+
let tcx = self.tecx.tcx;
905902
match expr.kind {
906903
hir::ExprKind::Path(ref path) => {
907904
if let Some(args) = self.node_args_opt(expr.hir_id) {
@@ -964,7 +961,7 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
964961
path: &'tcx hir::Path<'tcx>,
965962
args: GenericArgsRef<'tcx>,
966963
) -> impl Iterator<Item = InsertableGenericArgs<'tcx>> + 'a {
967-
let tcx = self.infcx.tcx;
964+
let tcx = self.tecx.tcx;
968965
let have_turbofish = path.segments.iter().any(|segment| {
969966
segment.args.is_some_and(|args| args.args.iter().any(|arg| arg.is_ty_or_const()))
970967
});
@@ -1018,7 +1015,7 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
10181015
args: GenericArgsRef<'tcx>,
10191016
qpath: &'tcx hir::QPath<'tcx>,
10201017
) -> Box<dyn Iterator<Item = InsertableGenericArgs<'tcx>> + 'a> {
1021-
let tcx = self.infcx.tcx;
1018+
let tcx = self.tecx.tcx;
10221019
match qpath {
10231020
hir::QPath::Resolved(_self_ty, path) => {
10241021
Box::new(self.resolved_path_inferred_arg_iter(path, args))
@@ -1091,7 +1088,7 @@ impl<'a, 'tcx> Visitor<'tcx> for FindInferSourceVisitor<'a, 'tcx> {
10911088
type NestedFilter = nested_filter::OnlyBodies;
10921089

10931090
fn nested_visit_map(&mut self) -> Self::Map {
1094-
self.infcx.tcx.hir()
1091+
self.tecx.tcx.hir()
10951092
}
10961093

10971094
fn visit_local(&mut self, local: &'tcx Local<'tcx>) {
@@ -1147,7 +1144,7 @@ impl<'a, 'tcx> Visitor<'tcx> for FindInferSourceVisitor<'a, 'tcx> {
11471144

11481145
#[instrument(level = "debug", skip(self))]
11491146
fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
1150-
let tcx = self.infcx.tcx;
1147+
let tcx = self.tecx.tcx;
11511148
match expr.kind {
11521149
// When encountering `func(arg)` first look into `arg` and then `func`,
11531150
// as `arg` is "more specific".
@@ -1178,7 +1175,7 @@ impl<'a, 'tcx> Visitor<'tcx> for FindInferSourceVisitor<'a, 'tcx> {
11781175
if generics.parent.is_none() && generics.has_self {
11791176
argument_index += 1;
11801177
}
1181-
let args = self.infcx.resolve_vars_if_possible(args);
1178+
let args = self.tecx.resolve_vars_if_possible(args);
11821179
let generic_args =
11831180
&generics.own_args_no_defaults(tcx, args)[generics.own_counts().lifetimes..];
11841181
let span = match expr.kind {
@@ -1208,7 +1205,7 @@ impl<'a, 'tcx> Visitor<'tcx> for FindInferSourceVisitor<'a, 'tcx> {
12081205
{
12091206
let output = args.as_closure().sig().output().skip_binder();
12101207
if self.generic_arg_contains_target(output.into()) {
1211-
let body = self.infcx.tcx.hir().body(body);
1208+
let body = self.tecx.tcx.hir().body(body);
12121209
let should_wrap_expr = if matches!(body.value.kind, ExprKind::Block(..)) {
12131210
None
12141211
} else {
@@ -1236,12 +1233,12 @@ impl<'a, 'tcx> Visitor<'tcx> for FindInferSourceVisitor<'a, 'tcx> {
12361233
&& let Some(args) = self.node_args_opt(expr.hir_id)
12371234
&& args.iter().any(|arg| self.generic_arg_contains_target(arg))
12381235
&& let Some(def_id) = self.typeck_results.type_dependent_def_id(expr.hir_id)
1239-
&& self.infcx.tcx.trait_of_item(def_id).is_some()
1236+
&& self.tecx.tcx.trait_of_item(def_id).is_some()
12401237
&& !has_impl_trait(def_id)
12411238
{
12421239
let successor =
12431240
method_args.get(0).map_or_else(|| (")", span.hi()), |arg| (", ", arg.span.lo()));
1244-
let args = self.infcx.resolve_vars_if_possible(args);
1241+
let args = self.tecx.resolve_vars_if_possible(args);
12451242
self.update_infer_source(InferSource {
12461243
span: path.ident.span,
12471244
kind: InferSourceKind::FullyQualifiedMethodCall {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
use rustc_data_structures::fx::FxHashMap;
2+
use rustc_data_structures::undo_log::NoUndo;
3+
use rustc_data_structures::unify as ut;
4+
use rustc_middle::ty;
5+
6+
use crate::infer::InferCtxt;
7+
8+
#[derive(Debug, Copy, Clone, PartialEq)]
9+
struct SubId(u32);
10+
impl ut::UnifyKey for SubId {
11+
type Value = ();
12+
#[inline]
13+
fn index(&self) -> u32 {
14+
self.0
15+
}
16+
#[inline]
17+
fn from_index(i: u32) -> SubId {
18+
SubId(i)
19+
}
20+
fn tag() -> &'static str {
21+
"SubId"
22+
}
23+
}
24+
25+
/// When reporting ambiguity errors, we sometimes want to
26+
/// treat all inference vars which are subtypes of each
27+
/// others as if they are equal. For this case we compute
28+
/// the transitive closure of our subtype obligations here.
29+
#[derive(Default)]
30+
pub struct SubRelations {
31+
map: FxHashMap<ty::TyVid, SubId>,
32+
table: ut::UnificationTableStorage<SubId>,
33+
}
34+
35+
impl SubRelations {
36+
fn get_id<'tcx>(&mut self, infcx: &InferCtxt<'tcx>, vid: ty::TyVid) -> SubId {
37+
let root_vid = infcx.root_var(vid);
38+
*self.map.entry(root_vid).or_insert_with(|| self.table.with_log(&mut NoUndo).new_key(()))
39+
}
40+
41+
pub fn add_constraints<'tcx>(
42+
&mut self,
43+
infcx: &InferCtxt<'tcx>,
44+
obls: impl IntoIterator<Item = ty::Predicate<'tcx>>,
45+
) {
46+
for p in obls {
47+
let (a, b) = match p.kind().skip_binder() {
48+
ty::PredicateKind::Subtype(ty::SubtypePredicate { a_is_expected: _, a, b }) => {
49+
(a, b)
50+
}
51+
ty::PredicateKind::Coerce(ty::CoercePredicate { a, b }) => (a, b),
52+
_ => continue,
53+
};
54+
55+
match (a.kind(), b.kind()) {
56+
(&ty::Infer(ty::TyVar(a_vid)), &ty::Infer(ty::TyVar(b_vid))) => {
57+
let a = self.get_id(infcx, a_vid);
58+
let b = self.get_id(infcx, b_vid);
59+
self.table.with_log(&mut NoUndo).unify_var_var(a, b).unwrap();
60+
}
61+
_ => continue,
62+
}
63+
}
64+
}
65+
66+
pub fn unified<'tcx>(&mut self, infcx: &InferCtxt<'tcx>, a: ty::TyVid, b: ty::TyVid) -> bool {
67+
let a = self.get_id(infcx, a);
68+
let b = self.get_id(infcx, b);
69+
self.table.with_log(&mut NoUndo).unioned(a, b)
70+
}
71+
}

0 commit comments

Comments
 (0)