Skip to content

Commit b53180f

Browse files
committed
interpret: add sanity check in dyn upcast to double-check what codegen does
1 parent 4cd8dc6 commit b53180f

File tree

6 files changed

+103
-47
lines changed

6 files changed

+103
-47
lines changed

compiler/rustc_const_eval/src/interpret/cast.rs

+36-5
Original file line numberDiff line numberDiff line change
@@ -401,15 +401,46 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
401401
}
402402
(ty::Dynamic(data_a, _, ty::Dyn), ty::Dynamic(data_b, _, ty::Dyn)) => {
403403
let val = self.read_immediate(src)?;
404-
if data_a.principal() == data_b.principal() {
405-
// A NOP cast that doesn't actually change anything, should be allowed even with mismatching vtables.
406-
// (But currently mismatching vtables violate the validity invariant so UB is triggered anyway.)
407-
return self.write_immediate(*val, dest);
408-
}
404+
// Take apart the old pointer, and find the dynamic type.
409405
let (old_data, old_vptr) = val.to_scalar_pair();
410406
let old_data = old_data.to_pointer(self)?;
411407
let old_vptr = old_vptr.to_pointer(self)?;
412408
let ty = self.get_ptr_vtable_ty(old_vptr, Some(data_a))?;
409+
410+
// Sanity-check that `supertrait_vtable_slot` in this type's vtable indeed produces
411+
// our destination trait.
412+
if cfg!(debug_assertions) {
413+
let vptr_entry_idx =
414+
self.tcx.supertrait_vtable_slot((src_pointee_ty, dest_pointee_ty));
415+
let vtable_entries = self.vtable_entries(data_a.principal(), ty);
416+
if let Some(entry_idx) = vptr_entry_idx {
417+
let Some(&ty::VtblEntry::TraitVPtr(upcast_trait_ref)) =
418+
vtable_entries.get(entry_idx)
419+
else {
420+
span_bug!(
421+
self.cur_span(),
422+
"invalid vtable entry index in {} -> {} upcast",
423+
src_pointee_ty,
424+
dest_pointee_ty
425+
);
426+
};
427+
let erased_trait_ref = upcast_trait_ref
428+
.map_bound(|r| ty::ExistentialTraitRef::erase_self_ty(*self.tcx, r));
429+
assert!(
430+
data_b
431+
.principal()
432+
.is_some_and(|b| self.eq_in_param_env(erased_trait_ref, b))
433+
);
434+
} else {
435+
// In this case codegen would keep using the old vtable. We don't want to do
436+
// that as it has the wrong trait. The reason codegen can do this is that
437+
// one vtable is a prefix of the other, so we double-check that.
438+
let vtable_entries_b = self.vtable_entries(data_b.principal(), ty);
439+
assert!(&vtable_entries[..vtable_entries_b.len()] == vtable_entries_b);
440+
};
441+
}
442+
443+
// Get the destination trait vtable and return that.
413444
let new_vptr = self.get_vtable_ptr(ty, data_b.principal())?;
414445
self.write_immediate(Immediate::new_dyn_trait(old_data, new_vptr, self), dest)
415446
}

compiler/rustc_const_eval/src/interpret/eval_context.rs

+30
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@ use std::cell::Cell;
22
use std::{fmt, mem};
33

44
use either::{Either, Left, Right};
5+
use rustc_infer::infer::at::ToTrace;
6+
use rustc_infer::traits::ObligationCause;
7+
use rustc_trait_selection::traits::ObligationCtxt;
58
use tracing::{debug, info, info_span, instrument, trace};
69

710
use rustc_errors::DiagCtxtHandle;
811
use rustc_hir::{self as hir, def_id::DefId, definitions::DefPathData};
912
use rustc_index::IndexVec;
13+
use rustc_infer::infer::TyCtxtInferExt;
1014
use rustc_middle::mir;
1115
use rustc_middle::mir::interpret::{
1216
CtfeProvenance, ErrorHandled, InvalidMetaKind, ReportedErrorInfo,
@@ -640,6 +644,32 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
640644
}
641645
}
642646

647+
/// Check if the two things are equal in the current param_env, using an infctx to get proper
648+
/// equality checks.
649+
pub(super) fn eq_in_param_env<T>(&self, a: T, b: T) -> bool
650+
where
651+
T: PartialEq + TypeFoldable<TyCtxt<'tcx>> + ToTrace<'tcx>,
652+
{
653+
// Fast path: compare directly.
654+
if a == b {
655+
return true;
656+
}
657+
// Slow path: spin up an inference context to check if these traits are sufficiently equal.
658+
let infcx = self.tcx.infer_ctxt().build();
659+
let ocx = ObligationCtxt::new(&infcx);
660+
let cause = ObligationCause::dummy_with_span(self.cur_span());
661+
// equate the two trait refs after normalization
662+
let a = ocx.normalize(&cause, self.param_env, a);
663+
let b = ocx.normalize(&cause, self.param_env, b);
664+
if ocx.eq(&cause, self.param_env, a, b).is_ok() {
665+
if ocx.select_all_or_error().is_empty() {
666+
// All good.
667+
return true;
668+
}
669+
}
670+
return false;
671+
}
672+
643673
/// Walks up the callstack from the intrinsic's callsite, searching for the first callsite in a
644674
/// frame which is not `#[track_caller]`. This matches the `caller_location` intrinsic,
645675
/// and is primarily intended for the panic machinery.

compiler/rustc_const_eval/src/interpret/terminator.rs

+6-11
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use std::borrow::Cow;
22

33
use either::Either;
4-
use rustc_middle::ty::TyCtxt;
54
use tracing::trace;
65

76
use rustc_middle::{
@@ -867,7 +866,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
867866
};
868867

869868
// Obtain the underlying trait we are working on, and the adjusted receiver argument.
870-
let (dyn_trait, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) =
869+
let (trait_, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) =
871870
receiver_place.layout.ty.kind()
872871
{
873872
let recv = self.unpack_dyn_star(&receiver_place, data)?;
@@ -898,20 +897,16 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
898897
(receiver_trait.principal(), dyn_ty, receiver_place.ptr())
899898
};
900899

901-
// Now determine the actual method to call. We can do that in two different ways and
902-
// compare them to ensure everything fits.
903-
let vtable_entries = if let Some(dyn_trait) = dyn_trait {
904-
let trait_ref = dyn_trait.with_self_ty(*self.tcx, dyn_ty);
905-
let trait_ref = self.tcx.erase_regions(trait_ref);
906-
self.tcx.vtable_entries(trait_ref)
907-
} else {
908-
TyCtxt::COMMON_VTABLE_ENTRIES
909-
};
900+
// Now determine the actual method to call. Usually we use the easy way of just
901+
// looking up the method at index `idx`.
902+
let vtable_entries = self.vtable_entries(trait_, dyn_ty);
910903
let Some(ty::VtblEntry::Method(fn_inst)) = vtable_entries.get(idx).copied() else {
911904
// FIXME(fee1-dead) these could be variants of the UB info enum instead of this
912905
throw_ub_custom!(fluent::const_eval_dyn_call_not_a_method);
913906
};
914907
trace!("Virtual call dispatches to {fn_inst:#?}");
908+
// We can also do the lookup based on `def_id` and `dyn_ty`, and check that that
909+
// produces the same result.
915910
if cfg!(debug_assertions) {
916911
let tcx = *self.tcx;
917912

compiler/rustc_const_eval/src/interpret/traits.rs

+23-25
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
use rustc_infer::infer::TyCtxtInferExt;
2-
use rustc_infer::traits::ObligationCause;
31
use rustc_middle::mir::interpret::{InterpResult, Pointer};
42
use rustc_middle::ty::layout::LayoutOf;
5-
use rustc_middle::ty::{self, Ty};
3+
use rustc_middle::ty::{self, Ty, TyCtxt, VtblEntry};
64
use rustc_target::abi::{Align, Size};
7-
use rustc_trait_selection::traits::ObligationCtxt;
85
use tracing::trace;
96

107
use super::util::ensure_monomorphic_enough;
@@ -47,35 +44,36 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
4744
Ok((layout.size, layout.align.abi))
4845
}
4946

47+
pub(super) fn vtable_entries(
48+
&self,
49+
trait_: Option<ty::PolyExistentialTraitRef<'tcx>>,
50+
dyn_ty: Ty<'tcx>,
51+
) -> &'tcx [VtblEntry<'tcx>] {
52+
if let Some(trait_) = trait_ {
53+
let trait_ref = trait_.with_self_ty(*self.tcx, dyn_ty);
54+
let trait_ref = self.tcx.erase_regions(trait_ref);
55+
self.tcx.vtable_entries(trait_ref)
56+
} else {
57+
TyCtxt::COMMON_VTABLE_ENTRIES
58+
}
59+
}
60+
5061
/// Check that the given vtable trait is valid for a pointer/reference/place with the given
5162
/// expected trait type.
5263
pub(super) fn check_vtable_for_type(
5364
&self,
5465
vtable_trait: Option<ty::PolyExistentialTraitRef<'tcx>>,
5566
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
5667
) -> InterpResult<'tcx> {
57-
// Fast path: if they are equal, it's all fine.
58-
if expected_trait.principal() == vtable_trait {
59-
return Ok(());
60-
}
61-
if let (Some(expected_trait), Some(vtable_trait)) =
62-
(expected_trait.principal(), vtable_trait)
63-
{
64-
// Slow path: spin up an inference context to check if these traits are sufficiently equal.
65-
let infcx = self.tcx.infer_ctxt().build();
66-
let ocx = ObligationCtxt::new(&infcx);
67-
let cause = ObligationCause::dummy_with_span(self.cur_span());
68-
// equate the two trait refs after normalization
69-
let expected_trait = ocx.normalize(&cause, self.param_env, expected_trait);
70-
let vtable_trait = ocx.normalize(&cause, self.param_env, vtable_trait);
71-
if ocx.eq(&cause, self.param_env, expected_trait, vtable_trait).is_ok() {
72-
if ocx.select_all_or_error().is_empty() {
73-
// All good.
74-
return Ok(());
75-
}
76-
}
68+
let eq = match (expected_trait.principal(), vtable_trait) {
69+
(Some(a), Some(b)) => self.eq_in_param_env(a, b),
70+
(None, None) => true,
71+
_ => false,
72+
};
73+
if !eq {
74+
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
7775
}
78-
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
76+
Ok(())
7977
}
8078

8179
/// Turn a place with a `dyn Trait` type into a place with the actual dynamic type.

src/tools/miri/tests/fail/dyn-upcast-trait-mismatch.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,10 @@ impl Baz for i32 {
5959
}
6060

6161
fn main() {
62-
let baz: &dyn Baz = &1;
63-
let baz_fake: *const dyn Bar = unsafe { std::mem::transmute(baz) };
64-
let _err = baz_fake as *const dyn Foo;
65-
//~^ERROR: using vtable for trait `Baz` but trait `Bar` was expected
62+
unsafe {
63+
let baz: &dyn Baz = &1;
64+
let baz_fake: *const dyn Bar = std::mem::transmute(baz);
65+
let _err = baz_fake as *const dyn Foo;
66+
//~^ERROR: using vtable for trait `Baz` but trait `Bar` was expected
67+
}
6668
}

src/tools/miri/tests/fail/dyn-upcast-trait-mismatch.stderr

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
error: Undefined Behavior: using vtable for trait `Baz` but trait `Bar` was expected
22
--> $DIR/dyn-upcast-trait-mismatch.rs:LL:CC
33
|
4-
LL | let _err = baz_fake as *const dyn Foo;
5-
| ^^^^^^^^ using vtable for trait `Baz` but trait `Bar` was expected
4+
LL | let _err = baz_fake as *const dyn Foo;
5+
| ^^^^^^^^ using vtable for trait `Baz` but trait `Bar` was expected
66
|
77
= help: this indicates a bug in the program: it performed an invalid operation, and caused Undefined Behavior
88
= help: see https://doc.rust-lang.org/nightly/reference/behavior-considered-undefined.html for further information

0 commit comments

Comments
 (0)