Skip to content

Commit df871fb

Browse files
committed
Auto merge of #115796 - cjgillot:const-prop-rvalue, r=oli-obk
Generate aggregate constants in DataflowConstProp.
2 parents 151256b + 9c85dfa commit df871fb

23 files changed

+785
-158
lines changed

compiler/rustc_mir_transform/src/dataflow_const_prop.rs

+172-18
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,23 @@
22
//!
33
//! Currently, this pass only propagates scalar values.
44
5-
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
5+
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable};
66
use rustc_data_structures::fx::FxHashMap;
77
use rustc_hir::def::DefKind;
88
use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult, Scalar};
99
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
1010
use rustc_middle::mir::*;
11-
use rustc_middle::ty::layout::TyAndLayout;
11+
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
1212
use rustc_middle::ty::{self, Ty, TyCtxt};
1313
use rustc_mir_dataflow::value_analysis::{
1414
Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace,
1515
};
1616
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor};
1717
use rustc_span::def_id::DefId;
1818
use rustc_span::DUMMY_SP;
19-
use rustc_target::abi::{FieldIdx, VariantIdx};
19+
use rustc_target::abi::{Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT};
2020

21+
use crate::const_prop::throw_machine_stop_str;
2122
use crate::MirPass;
2223

2324
// These constants are somewhat random guesses and have not been optimized.
@@ -553,16 +554,151 @@ impl<'tcx, 'locals> Collector<'tcx, 'locals> {
553554

554555
fn try_make_constant(
555556
&self,
557+
ecx: &mut InterpCx<'tcx, 'tcx, DummyMachine>,
556558
place: Place<'tcx>,
557559
state: &State<FlatSet<Scalar>>,
558560
map: &Map,
559561
) -> Option<Const<'tcx>> {
560-
let FlatSet::Elem(Scalar::Int(value)) = state.get(place.as_ref(), &map) else {
561-
return None;
562-
};
563562
let ty = place.ty(self.local_decls, self.patch.tcx).ty;
564-
Some(Const::Val(ConstValue::Scalar(value.into()), ty))
563+
let layout = ecx.layout_of(ty).ok()?;
564+
565+
if layout.is_zst() {
566+
return Some(Const::zero_sized(ty));
567+
}
568+
569+
if layout.is_unsized() {
570+
return None;
571+
}
572+
573+
let place = map.find(place.as_ref())?;
574+
if layout.abi.is_scalar()
575+
&& let Some(value) = propagatable_scalar(place, state, map)
576+
{
577+
return Some(Const::Val(ConstValue::Scalar(value), ty));
578+
}
579+
580+
if matches!(layout.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) {
581+
let alloc_id = ecx
582+
.intern_with_temp_alloc(layout, |ecx, dest| {
583+
try_write_constant(ecx, dest, place, ty, state, map)
584+
})
585+
.ok()?;
586+
return Some(Const::Val(ConstValue::Indirect { alloc_id, offset: Size::ZERO }, ty));
587+
}
588+
589+
None
590+
}
591+
}
592+
593+
fn propagatable_scalar(
594+
place: PlaceIndex,
595+
state: &State<FlatSet<Scalar>>,
596+
map: &Map,
597+
) -> Option<Scalar> {
598+
if let FlatSet::Elem(value) = state.get_idx(place, map) && value.try_to_int().is_ok() {
599+
// Do not attempt to propagate pointers, as we may fail to preserve their identity.
600+
Some(value)
601+
} else {
602+
None
603+
}
604+
}
605+
606+
#[instrument(level = "trace", skip(ecx, state, map))]
607+
fn try_write_constant<'tcx>(
608+
ecx: &mut InterpCx<'_, 'tcx, DummyMachine>,
609+
dest: &PlaceTy<'tcx>,
610+
place: PlaceIndex,
611+
ty: Ty<'tcx>,
612+
state: &State<FlatSet<Scalar>>,
613+
map: &Map,
614+
) -> InterpResult<'tcx> {
615+
let layout = ecx.layout_of(ty)?;
616+
617+
// Fast path for ZSTs.
618+
if layout.is_zst() {
619+
return Ok(());
620+
}
621+
622+
// Fast path for scalars.
623+
if layout.abi.is_scalar()
624+
&& let Some(value) = propagatable_scalar(place, state, map)
625+
{
626+
return ecx.write_immediate(Immediate::Scalar(value), dest);
627+
}
628+
629+
match ty.kind() {
630+
// ZSTs. Nothing to do.
631+
ty::FnDef(..) => {}
632+
633+
// Those are scalars, must be handled above.
634+
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => throw_machine_stop_str!("primitive type with provenance"),
635+
636+
ty::Tuple(elem_tys) => {
637+
for (i, elem) in elem_tys.iter().enumerate() {
638+
let Some(field) = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i))) else {
639+
throw_machine_stop_str!("missing field in tuple")
640+
};
641+
let field_dest = ecx.project_field(dest, i)?;
642+
try_write_constant(ecx, &field_dest, field, elem, state, map)?;
643+
}
644+
}
645+
646+
ty::Adt(def, args) => {
647+
if def.is_union() {
648+
throw_machine_stop_str!("cannot propagate unions")
649+
}
650+
651+
let (variant_idx, variant_def, variant_place, variant_dest) = if def.is_enum() {
652+
let Some(discr) = map.apply(place, TrackElem::Discriminant) else {
653+
throw_machine_stop_str!("missing discriminant for enum")
654+
};
655+
let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else {
656+
throw_machine_stop_str!("discriminant with provenance")
657+
};
658+
let discr_bits = discr.assert_bits(discr.size());
659+
let Some((variant, _)) = def.discriminants(*ecx.tcx).find(|(_, var)| discr_bits == var.val) else {
660+
throw_machine_stop_str!("illegal discriminant for enum")
661+
};
662+
let Some(variant_place) = map.apply(place, TrackElem::Variant(variant)) else {
663+
throw_machine_stop_str!("missing variant for enum")
664+
};
665+
let variant_dest = ecx.project_downcast(dest, variant)?;
666+
(variant, def.variant(variant), variant_place, variant_dest)
667+
} else {
668+
(FIRST_VARIANT, def.non_enum_variant(), place, dest.clone())
669+
};
670+
671+
for (i, field) in variant_def.fields.iter_enumerated() {
672+
let ty = field.ty(*ecx.tcx, args);
673+
let Some(field) = map.apply(variant_place, TrackElem::Field(i)) else {
674+
throw_machine_stop_str!("missing field in ADT")
675+
};
676+
let field_dest = ecx.project_field(&variant_dest, i.as_usize())?;
677+
try_write_constant(ecx, &field_dest, field, ty, state, map)?;
678+
}
679+
ecx.write_discriminant(variant_idx, dest)?;
680+
}
681+
682+
// Unsupported for now.
683+
ty::Array(_, _)
684+
685+
// Do not attempt to support indirection in constants.
686+
| ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_)
687+
688+
| ty::Never
689+
| ty::Foreign(..)
690+
| ty::Alias(..)
691+
| ty::Param(_)
692+
| ty::Bound(..)
693+
| ty::Placeholder(..)
694+
| ty::Closure(..)
695+
| ty::Coroutine(..)
696+
| ty::Dynamic(..) => throw_machine_stop_str!("unsupported type"),
697+
698+
ty::Error(_) | ty::Infer(..) | ty::CoroutineWitness(..) => bug!(),
565699
}
700+
701+
Ok(())
566702
}
567703

568704
impl<'mir, 'tcx>
@@ -580,8 +716,13 @@ impl<'mir, 'tcx>
580716
) {
581717
match &statement.kind {
582718
StatementKind::Assign(box (_, rvalue)) => {
583-
OperandCollector { state, visitor: self, map: &results.analysis.0.map }
584-
.visit_rvalue(rvalue, location);
719+
OperandCollector {
720+
state,
721+
visitor: self,
722+
ecx: &mut results.analysis.0.ecx,
723+
map: &results.analysis.0.map,
724+
}
725+
.visit_rvalue(rvalue, location);
585726
}
586727
_ => (),
587728
}
@@ -599,7 +740,12 @@ impl<'mir, 'tcx>
599740
// Don't overwrite the assignment if it already uses a constant (to keep the span).
600741
}
601742
StatementKind::Assign(box (place, _)) => {
602-
if let Some(value) = self.try_make_constant(place, state, &results.analysis.0.map) {
743+
if let Some(value) = self.try_make_constant(
744+
&mut results.analysis.0.ecx,
745+
place,
746+
state,
747+
&results.analysis.0.map,
748+
) {
603749
self.patch.assignments.insert(location, value);
604750
}
605751
}
@@ -614,8 +760,13 @@ impl<'mir, 'tcx>
614760
terminator: &'mir Terminator<'tcx>,
615761
location: Location,
616762
) {
617-
OperandCollector { state, visitor: self, map: &results.analysis.0.map }
618-
.visit_terminator(terminator, location);
763+
OperandCollector {
764+
state,
765+
visitor: self,
766+
ecx: &mut results.analysis.0.ecx,
767+
map: &results.analysis.0.map,
768+
}
769+
.visit_terminator(terminator, location);
619770
}
620771
}
621772

@@ -670,6 +821,7 @@ impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> {
670821
struct OperandCollector<'tcx, 'map, 'locals, 'a> {
671822
state: &'a State<FlatSet<Scalar>>,
672823
visitor: &'a mut Collector<'tcx, 'locals>,
824+
ecx: &'map mut InterpCx<'tcx, 'tcx, DummyMachine>,
673825
map: &'map Map,
674826
}
675827

@@ -682,15 +834,17 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
682834
location: Location,
683835
) {
684836
if let PlaceElem::Index(local) = elem
685-
&& let Some(value) = self.visitor.try_make_constant(local.into(), self.state, self.map)
837+
&& let Some(value) = self.visitor.try_make_constant(self.ecx, local.into(), self.state, self.map)
686838
{
687839
self.visitor.patch.before_effect.insert((location, local.into()), value);
688840
}
689841
}
690842

691843
fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) {
692844
if let Some(place) = operand.place() {
693-
if let Some(value) = self.visitor.try_make_constant(place, self.state, self.map) {
845+
if let Some(value) =
846+
self.visitor.try_make_constant(self.ecx, place, self.state, self.map)
847+
{
694848
self.visitor.patch.before_effect.insert((location, place), value);
695849
} else if !place.projection.is_empty() {
696850
// Try to propagate into `Index` projections.
@@ -713,7 +867,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
713867
}
714868

715869
fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
716-
unimplemented!()
870+
false
717871
}
718872

719873
fn before_access_global(
@@ -725,13 +879,13 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
725879
is_write: bool,
726880
) -> InterpResult<'tcx> {
727881
if is_write {
728-
crate::const_prop::throw_machine_stop_str!("can't write to global");
882+
throw_machine_stop_str!("can't write to global");
729883
}
730884

731885
// If the static allocation is mutable, then we can't const prop it as its content
732886
// might be different at runtime.
733887
if alloc.inner().mutability.is_mut() {
734-
crate::const_prop::throw_machine_stop_str!("can't access mutable globals in ConstProp");
888+
throw_machine_stop_str!("can't access mutable globals in ConstProp");
735889
}
736890

737891
Ok(())
@@ -781,7 +935,7 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
781935
_left: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
782936
_right: &rustc_const_eval::interpret::ImmTy<'tcx, Self::Provenance>,
783937
) -> interpret::InterpResult<'tcx, (ImmTy<'tcx, Self::Provenance>, bool)> {
784-
crate::const_prop::throw_machine_stop_str!("can't do pointer arithmetic");
938+
throw_machine_stop_str!("can't do pointer arithmetic");
785939
}
786940

787941
fn expose_ptr(

tests/mir-opt/const_debuginfo.main.ConstDebugInfo.diff

+7-2
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141
+ debug ((f: (bool, bool, u32)).2: u32) => const 123_u32;
4242
let _10: std::option::Option<u16>;
4343
scope 7 {
44-
debug o => _10;
44+
- debug o => _10;
45+
+ debug o => const Option::<u16>::Some(99_u16);
4546
let _17: u32;
4647
let _18: u32;
4748
scope 8 {
@@ -81,7 +82,7 @@
8182
_15 = const false;
8283
_16 = const 123_u32;
8384
StorageLive(_10);
84-
_10 = Option::<u16>::Some(const 99_u16);
85+
_10 = const Option::<u16>::Some(99_u16);
8586
_17 = const 32_u32;
8687
_18 = const 32_u32;
8788
StorageLive(_11);
@@ -97,3 +98,7 @@
9798
}
9899
}
99100

101+
ALLOC0 (size: 4, align: 2) {
102+
01 00 63 00 │ ..c.
103+
}
104+

tests/mir-opt/dataflow-const-prop/checked.main.DataflowConstProp.panic-abort.diff

+10-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
- _6 = CheckedAdd(_4, _5);
4444
- assert(!move (_6.1: bool), "attempt to compute `{} + {}`, which would overflow", move _4, move _5) -> [success: bb1, unwind unreachable];
4545
+ _5 = const 2_i32;
46-
+ _6 = CheckedAdd(const 1_i32, const 2_i32);
46+
+ _6 = const (3_i32, false);
4747
+ assert(!const false, "attempt to compute `{} + {}`, which would overflow", const 1_i32, const 2_i32) -> [success: bb1, unwind unreachable];
4848
}
4949

@@ -60,7 +60,7 @@
6060
- _10 = CheckedAdd(_9, const 1_i32);
6161
- assert(!move (_10.1: bool), "attempt to compute `{} + {}`, which would overflow", move _9, const 1_i32) -> [success: bb2, unwind unreachable];
6262
+ _9 = const i32::MAX;
63-
+ _10 = CheckedAdd(const i32::MAX, const 1_i32);
63+
+ _10 = const (i32::MIN, true);
6464
+ assert(!const true, "attempt to compute `{} + {}`, which would overflow", const i32::MAX, const 1_i32) -> [success: bb2, unwind unreachable];
6565
}
6666

@@ -76,5 +76,13 @@
7676
StorageDead(_1);
7777
return;
7878
}
79+
+ }
80+
+
81+
+ ALLOC0 (size: 8, align: 4) {
82+
+ 00 00 00 80 01 __ __ __ │ .....░░░
83+
+ }
84+
+
85+
+ ALLOC1 (size: 8, align: 4) {
86+
+ 03 00 00 00 00 __ __ __ │ .....░░░
7987
}
8088

tests/mir-opt/dataflow-const-prop/checked.main.DataflowConstProp.panic-unwind.diff

+10-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
- _6 = CheckedAdd(_4, _5);
4444
- assert(!move (_6.1: bool), "attempt to compute `{} + {}`, which would overflow", move _4, move _5) -> [success: bb1, unwind continue];
4545
+ _5 = const 2_i32;
46-
+ _6 = CheckedAdd(const 1_i32, const 2_i32);
46+
+ _6 = const (3_i32, false);
4747
+ assert(!const false, "attempt to compute `{} + {}`, which would overflow", const 1_i32, const 2_i32) -> [success: bb1, unwind continue];
4848
}
4949

@@ -60,7 +60,7 @@
6060
- _10 = CheckedAdd(_9, const 1_i32);
6161
- assert(!move (_10.1: bool), "attempt to compute `{} + {}`, which would overflow", move _9, const 1_i32) -> [success: bb2, unwind continue];
6262
+ _9 = const i32::MAX;
63-
+ _10 = CheckedAdd(const i32::MAX, const 1_i32);
63+
+ _10 = const (i32::MIN, true);
6464
+ assert(!const true, "attempt to compute `{} + {}`, which would overflow", const i32::MAX, const 1_i32) -> [success: bb2, unwind continue];
6565
}
6666

@@ -76,5 +76,13 @@
7676
StorageDead(_1);
7777
return;
7878
}
79+
+ }
80+
+
81+
+ ALLOC0 (size: 8, align: 4) {
82+
+ 00 00 00 80 01 __ __ __ │ .....░░░
83+
+ }
84+
+
85+
+ ALLOC1 (size: 8, align: 4) {
86+
+ 03 00 00 00 00 __ __ __ │ .....░░░
7987
}
8088

tests/mir-opt/dataflow-const-prop/checked.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// skip-filecheck
2-
// EMIT_MIR_FOR_EACH_PANIC_STRATEGY
32
// unit-test: DataflowConstProp
43
// compile-flags: -Coverflow-checks=on
4+
// EMIT_MIR_FOR_EACH_PANIC_STRATEGY
55

66
// EMIT_MIR checked.main.DataflowConstProp.diff
77
#[allow(arithmetic_overflow)]

0 commit comments

Comments
 (0)