From db1c3e033955d6c62045f70a8ba839778a372d13 Mon Sep 17 00:00:00 2001 From: Joshua Liebow-Feeser Date: Wed, 14 May 2025 10:48:30 -0700 Subject: [PATCH] [derive] TryFromBytes on unions w/o Immutable TODO: - Add tests in zerocopy-derive/tests for unions with UnsafeCells - Add tests (location TBD) for calling &mut methods on TryFromBytes on unions with UnsafeCells gherrit-pr-id: If86f182f8e3fda5d12d680c400c7a0a7f7095ce4 --- src/impls.rs | 100 ++++++++++-------- src/lib.rs | 3 + src/macros.rs | 8 ++ src/util/macros.rs | 77 ++++++++++---- zerocopy-derive/src/enum.rs | 6 +- zerocopy-derive/src/lib.rs | 84 +++++++++++---- zerocopy-derive/tests/union_from_bytes.rs | 10 +- zerocopy-derive/tests/union_from_zeros.rs | 10 +- zerocopy-derive/tests/union_try_from_bytes.rs | 12 +-- 9 files changed, 203 insertions(+), 107 deletions(-) diff --git a/src/impls.rs b/src/impls.rs index 9f03d4ec6d..e92c1bd344 100644 --- a/src/impls.rs +++ b/src/impls.rs @@ -25,7 +25,7 @@ use super::*; // // [1] https://doc.rust-lang.org/1.81.0/reference/type-layout.html#tuple-layout const _: () = unsafe { - unsafe_impl!((): Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes, Unaligned); + unsafe_impl!((): Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes, Unaligned); assert_unaligned!(()); }; @@ -60,25 +60,31 @@ const _: () = unsafe { // FIXME(#278): Once we've updated the trait docs to refer to `u8`s rather than // bits or bytes, update this comment, especially the reference to [1]. const _: () = unsafe { - unsafe_impl!(u8: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes, Unaligned); - unsafe_impl!(i8: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes, Unaligned); + unsafe_impl!(u8: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes, Unaligned); + unsafe_impl!(i8: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes, Unaligned); assert_unaligned!(u8, i8); - unsafe_impl!(u16: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(i16: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(u32: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(i32: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(u64: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(i64: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(u128: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(i128: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(usize: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(isize: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(f32: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(f64: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(u16: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(i16: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(u32: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(i32: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(u64: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(i64: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(u128: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(i128: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(usize: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(isize: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(f32: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(f64: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); #[cfg(feature = "float-nightly")] - unsafe_impl!(#[cfg_attr(doc_cfg, doc(cfg(feature = "float-nightly")))] f16: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes); + unsafe_impl!( + #[cfg_attr(doc_cfg, doc(cfg(feature = "float-nightly")))] + f16: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes + ); #[cfg(feature = "float-nightly")] - unsafe_impl!(#[cfg_attr(doc_cfg, doc(cfg(feature = "float-nightly")))] f128: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes); + unsafe_impl!( + #[cfg_attr(doc_cfg, doc(cfg(feature = "float-nightly")))] + f128: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes + ); }; // SAFETY: @@ -107,7 +113,7 @@ const _: () = unsafe { unsafe_impl!(=> TryFromBytes for bool; |byte| { let byte = byte.transmute::(); *byte.unaligned_as_ref() < 2 - }) + }; IS_IMMUTABLE = true) }; impl_size_eq!(bool, u8); @@ -137,7 +143,7 @@ const _: () = unsafe { let c = c.transmute::, invariant::Valid, _>(); let c = c.read_unaligned().into_inner(); char::from_u32(c).is_some() - }); + }; IS_IMMUTABLE = true); }; impl_size_eq!(char, Unalign); @@ -170,7 +176,7 @@ const _: () = unsafe { let c = c.transmute::<[u8], invariant::Valid, _>(); let c = c.unaligned_as_ref(); core::str::from_utf8(c).is_ok() - }) + }; IS_IMMUTABLE = true) }; // SAFETY: `str` and `[u8]` have the same layout [1]. @@ -210,7 +216,7 @@ macro_rules! unsafe_impl_try_from_bytes_for_nonzero { let n = n.transmute::, invariant::Valid, _>(); $nonzero::new(n.read_unaligned().into_inner()).is_some() - }); + }; IS_IMMUTABLE = true); )* } } @@ -296,19 +302,19 @@ const _: () = unsafe { // FIXME(https://github.com/rust-lang/rust/pull/104082): Cite documentation for // layout guarantees. const _: () = unsafe { - unsafe_impl!(Option: TryFromBytes, FromZeros, FromBytes, IntoBytes, Unaligned); - unsafe_impl!(Option: TryFromBytes, FromZeros, FromBytes, IntoBytes, Unaligned); + unsafe_impl!(Option: TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes, Unaligned); + unsafe_impl!(Option: TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes, Unaligned); assert_unaligned!(Option, Option); - unsafe_impl!(Option: TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(Option: TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(Option: TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(Option: TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(Option: TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(Option: TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(Option: TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(Option: TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(Option: TryFromBytes, FromZeros, FromBytes, IntoBytes); - unsafe_impl!(Option: TryFromBytes, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(Option: TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(Option: TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(Option: TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(Option: TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(Option: TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(Option: TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(Option: TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(Option: TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(Option: TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); + unsafe_impl!(Option: TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); }; // SAFETY: While it's not fully documented, the consensus is that `Box` does @@ -348,7 +354,7 @@ const _: () = unsafe { #[cfg(feature = "alloc")] unsafe_impl!( #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] - T => TryFromBytes for Option>; |c| pointer::is_zeroed(c) + T => TryFromBytes for Option>; |c| pointer::is_zeroed(c); IS_IMMUTABLE = true ); #[cfg(feature = "alloc")] unsafe_impl!( @@ -356,26 +362,26 @@ const _: () = unsafe { T => FromZeros for Option> ); unsafe_impl!( - T => TryFromBytes for Option<&'_ T>; |c| pointer::is_zeroed(c) + T => TryFromBytes for Option<&'_ T>; |c| pointer::is_zeroed(c); IS_IMMUTABLE = true ); unsafe_impl!(T => FromZeros for Option<&'_ T>); unsafe_impl!( - T => TryFromBytes for Option<&'_ mut T>; |c| pointer::is_zeroed(c) + T => TryFromBytes for Option<&'_ mut T>; |c| pointer::is_zeroed(c); IS_IMMUTABLE = true ); unsafe_impl!(T => FromZeros for Option<&'_ mut T>); unsafe_impl!( - T => TryFromBytes for Option>; |c| pointer::is_zeroed(c) + T => TryFromBytes for Option>; |c| pointer::is_zeroed(c); IS_IMMUTABLE = true ); unsafe_impl!(T => FromZeros for Option>); unsafe_impl_for_power_set!(A, B, C, D, E, F, G, H, I, J, K, L -> M => FromZeros for opt_fn!(...)); unsafe_impl_for_power_set!( A, B, C, D, E, F, G, H, I, J, K, L -> M => TryFromBytes for opt_fn!(...); - |c| pointer::is_zeroed(c) + |c| pointer::is_zeroed(c); IS_IMMUTABLE = true ); unsafe_impl_for_power_set!(A, B, C, D, E, F, G, H, I, J, K, L -> M => FromZeros for opt_extern_c_fn!(...)); unsafe_impl_for_power_set!( A, B, C, D, E, F, G, H, I, J, K, L -> M => TryFromBytes for opt_extern_c_fn!(...); - |c| pointer::is_zeroed(c) + |c| pointer::is_zeroed(c); IS_IMMUTABLE = true ); }; @@ -679,7 +685,7 @@ mod atomics { // [1] https://doc.rust-lang.org/1.81.0/std/marker/struct.PhantomData.html#layout-1 const _: () = unsafe { unsafe_impl!(T: ?Sized => Immutable for PhantomData); - unsafe_impl!(T: ?Sized => TryFromBytes for PhantomData); + unsafe_impl!(T: ?Sized => TryFromBytes for PhantomData; IS_IMMUTABLE = true); unsafe_impl!(T: ?Sized => FromZeros for PhantomData); unsafe_impl!(T: ?Sized => FromBytes for PhantomData); unsafe_impl!(T: ?Sized => IntoBytes for PhantomData); @@ -712,7 +718,7 @@ const _: () = unsafe { unsafe_impl!(T: Unaligned => Unaligned for Wrapping) } // SAFETY: `TryFromBytes` (with no validator), `FromZeros`, `FromBytes`: // `MaybeUninit` has no restrictions on its contents. const _: () = unsafe { - unsafe_impl!(T => TryFromBytes for CoreMaybeUninit); + unsafe_impl!(T => TryFromBytes for CoreMaybeUninit; IS_IMMUTABLE = T::IS_IMMUTABLE); unsafe_impl!(T => FromZeros for CoreMaybeUninit); unsafe_impl!(T => FromBytes for CoreMaybeUninit); }; @@ -808,6 +814,8 @@ unsafe impl TryFromBytes for UnsafeCell { { } + const IS_IMMUTABLE: bool = false; + #[inline] fn is_bit_valid(candidate: Maybe<'_, Self, A>) -> bool { // The only way to implement this function is using an exclusive-aliased @@ -864,7 +872,7 @@ const _: () = unsafe { // it explicitly warns that it's a possibility), and we have not // violated any safety invariants that we must fix before returning. <[T] as TryFromBytes>::is_bit_valid(c.as_slice()) - }); + }; IS_IMMUTABLE = T::IS_IMMUTABLE); unsafe_impl!(const N: usize, T: FromZeros => FromZeros for [T; N]); unsafe_impl!(const N: usize, T: FromBytes => FromBytes for [T; N]); unsafe_impl!(const N: usize, T: IntoBytes => IntoBytes for [T; N]); @@ -893,7 +901,7 @@ const _: () = unsafe { // we have not violated any safety invariants that we must fix before // returning. c.iter().all(::is_bit_valid) - }); + }; IS_IMMUTABLE = T::IS_IMMUTABLE); unsafe_impl!(T: FromZeros => FromZeros for [T]); unsafe_impl!(T: FromBytes => FromBytes for [T]); unsafe_impl!(T: IntoBytes => IntoBytes for [T]); @@ -919,9 +927,9 @@ const _: () = unsafe { const _: () = unsafe { unsafe_impl!(T: ?Sized => Immutable for *const T); unsafe_impl!(T: ?Sized => Immutable for *mut T); - unsafe_impl!(T => TryFromBytes for *const T; |c| pointer::is_zeroed(c)); + unsafe_impl!(T => TryFromBytes for *const T; |c| pointer::is_zeroed(c); IS_IMMUTABLE = true); unsafe_impl!(T => FromZeros for *const T); - unsafe_impl!(T => TryFromBytes for *mut T; |c| pointer::is_zeroed(c)); + unsafe_impl!(T => TryFromBytes for *mut T; |c| pointer::is_zeroed(c); IS_IMMUTABLE = true); unsafe_impl!(T => FromZeros for *mut T); }; @@ -1032,7 +1040,7 @@ mod simd { impl_known_layout!($($typ),*); // SAFETY: See comment on module definition for justification. const _: () = unsafe { - $( unsafe_impl!($typ: Immutable, TryFromBytes, FromZeros, FromBytes, IntoBytes); )* + $( unsafe_impl!($typ: Immutable, TryFromBytes; IS_IMMUTABLE = true, FromZeros, FromBytes, IntoBytes); )* }; } }; diff --git a/src/lib.rs b/src/lib.rs index 5791574336..9b62f34882 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1448,6 +1448,9 @@ pub unsafe trait TryFromBytes { where Self: Sized; + #[doc(hidden)] + const IS_IMMUTABLE: bool; + /// Does a given memory range contain a valid instance of `Self`? /// /// # Safety diff --git a/src/macros.rs b/src/macros.rs index da99a742d4..a3f174590c 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -780,6 +780,10 @@ macro_rules! cryptocorrosion_derive_traits { $($field_ty: $crate::FromBytes,)* )? { + const IS_IMMUTABLE: bool = true $( + && <$field_ty as $crate::TryFromBytes>::IS_IMMUTABLE + )*; + fn is_bit_valid(_c: $crate::Maybe<'_, Self, A>) -> bool where A: $crate::pointer::invariant::Reference @@ -923,6 +927,10 @@ macro_rules! cryptocorrosion_derive_traits { $field_ty: $crate::FromBytes, )* { + const IS_IMMUTABLE: bool = true $( + && <$field_ty as $crate::TryFromBytes>::IS_IMMUTABLE + )*; + fn is_bit_valid(_c: $crate::Maybe<'_, Self, A>) -> bool where A: $crate::pointer::invariant::Reference diff --git a/src/util/macros.rs b/src/util/macros.rs index e70a9af52b..788d2122bc 100644 --- a/src/util/macros.rs +++ b/src/util/macros.rs @@ -21,13 +21,17 @@ /// must only return `true` if its argument refers to a valid `$ty`. macro_rules! unsafe_impl { // Implement `$trait` for `$ty` with no bounds. - ($(#[$attr:meta])* $ty:ty: $trait:ident $(; |$candidate:ident| $is_bit_valid:expr)?) => {{ + ( + $(#[$attr:meta])* $ty:ty: $trait:ident + $(;)? $(|$candidate:ident| $is_bit_valid:expr)? + $(; IS_IMMUTABLE = $is_immutable:expr)? + ) => {{ crate::util::macros::__unsafe(); $(#[$attr])* // SAFETY: The caller promises that this is sound. unsafe impl $trait for $ty { - unsafe_impl!(@method $trait $(; |$candidate| $is_bit_valid)?); + unsafe_impl!(@items $trait $(; |$candidate| $is_bit_valid)? $(; IS_IMMUTABLE = $is_immutable)?); } }}; @@ -49,16 +53,16 @@ macro_rules! unsafe_impl { // 1. Pack the attributes into a single token tree fragment we can match over. // 2. Expand the traits. // 3. Unpack and expand the attributes. - ($(#[$attrs:meta])* $ty:ty: $($traits:ident),*) => { - unsafe_impl!(@impl_traits_with_packed_attrs { $(#[$attrs])* } $ty: $($traits),*) + ($(#[$attrs:meta])* $ty:ty: $($traits:ident $(; IS_IMMUTABLE = $is_immutable:expr)?),*) => { + unsafe_impl!(@impl_traits_with_packed_attrs { $(#[$attrs])* } $ty: $($traits $(; IS_IMMUTABLE = $is_immutable)?),*) }; - (@impl_traits_with_packed_attrs $attrs:tt $ty:ty: $($traits:ident),*) => {{ - $( unsafe_impl!(@unpack_attrs $attrs $ty: $traits); )* + (@impl_traits_with_packed_attrs $attrs:tt $ty:ty: $($traits:ident $(; IS_IMMUTABLE = $is_immutable:expr)?),*) => {{ + $( unsafe_impl!(@unpack_attrs $attrs $ty: $traits $(; IS_IMMUTABLE = $is_immutable)?); )* }}; - (@unpack_attrs { $(#[$attrs:meta])* } $ty:ty: $traits:ident) => { - unsafe_impl!($(#[$attrs])* $ty: $traits); + (@unpack_attrs { $(#[$attrs:meta])* } $ty:ty: $traits:ident $(; IS_IMMUTABLE = $is_immutable:expr)?) => { + unsafe_impl!($(#[$attrs])* $ty: $traits $(; IS_IMMUTABLE = $is_immutable)?); }; // This arm is identical to the following one, except it contains a @@ -90,26 +94,30 @@ macro_rules! unsafe_impl { $(#[$attr:meta])* const $constname:ident : $constty:ident $(,)? $($tyvar:ident $(: $(? $optbound:ident $(+)?)* $($bound:ident $(+)?)* )?),* - => $trait:ident for $ty:ty $(; |$candidate:ident| $is_bit_valid:expr)? + => $trait:ident for $ty:ty + $(;)? $(|$candidate:ident| $is_bit_valid:expr)? + $(; IS_IMMUTABLE = $is_immutable:expr)? ) => { unsafe_impl!( @inner $(#[$attr])* @const $constname: $constty, $($tyvar $(: $(? $optbound +)* + $($bound +)*)?,)* - => $trait for $ty $(; |$candidate| $is_bit_valid)? + => $trait for $ty $(; |$candidate| $is_bit_valid)? $(; IS_IMMUTABLE = $is_immutable)? ); }; ( $(#[$attr:meta])* $($tyvar:ident $(: $(? $optbound:ident $(+)?)* $($bound:ident $(+)?)* )?),* - => $trait:ident for $ty:ty $(; |$candidate:ident| $is_bit_valid:expr)? + => $trait:ident for $ty:ty + $(;)? $(|$candidate:ident| $is_bit_valid:expr)? + $(; IS_IMMUTABLE = $is_immutable:expr)? ) => {{ unsafe_impl!( @inner $(#[$attr])* $($tyvar $(: $(? $optbound +)* + $($bound +)*)?,)* - => $trait for $ty $(; |$candidate| $is_bit_valid)? + => $trait for $ty $(; |$candidate| $is_bit_valid)? $(; IS_IMMUTABLE = $is_immutable)? ); }}; ( @@ -117,7 +125,9 @@ macro_rules! unsafe_impl { $(#[$attr:meta])* $(@const $constname:ident : $constty:ident,)* $($tyvar:ident $(: $(? $optbound:ident +)* + $($bound:ident +)* )?,)* - => $trait:ident for $ty:ty $(; |$candidate:ident| $is_bit_valid:expr)? + => $trait:ident for $ty:ty + $(;)? $(|$candidate:ident| $is_bit_valid:expr)? + $(; IS_IMMUTABLE = $is_immutable:expr)? ) => {{ crate::util::macros::__unsafe(); @@ -125,32 +135,42 @@ macro_rules! unsafe_impl { #[allow(non_local_definitions)] // SAFETY: The caller promises that this is sound. unsafe impl<$($tyvar $(: $(? $optbound +)* $($bound +)*)?),* $(, const $constname: $constty,)*> $trait for $ty { - unsafe_impl!(@method $trait $(; |$candidate| $is_bit_valid)?); + unsafe_impl!(@items $trait $(; |$candidate| $is_bit_valid)? $(; IS_IMMUTABLE = $is_immutable)?); } }}; - (@method TryFromBytes ; |$candidate:ident| $is_bit_valid:expr) => { + (@items TryFromBytes ; |$candidate:ident| $is_bit_valid:expr ; IS_IMMUTABLE = $is_immutable:expr) => { #[allow(clippy::missing_inline_in_public_items, dead_code)] #[cfg_attr(all(coverage_nightly, __ZEROCOPY_INTERNAL_USE_ONLY_NIGHTLY_FEATURES_IN_TESTS), coverage(off))] fn only_derive_is_allowed_to_implement_this_trait() {} + // TODO: THIS IS UNSOUND. We are just using it to unblock ourselves. + const IS_IMMUTABLE: bool = $is_immutable; + #[inline] fn is_bit_valid($candidate: Maybe<'_, Self, AA>) -> bool { $is_bit_valid } }; - (@method TryFromBytes) => { + (@items TryFromBytes ; IS_IMMUTABLE = $is_immutable:expr) => { #[allow(clippy::missing_inline_in_public_items)] #[cfg_attr(all(coverage_nightly, __ZEROCOPY_INTERNAL_USE_ONLY_NIGHTLY_FEATURES_IN_TESTS), coverage(off))] fn only_derive_is_allowed_to_implement_this_trait() {} + + // TODO: THIS IS UNSOUND. We are just using it to unblock ourselves. + const IS_IMMUTABLE: bool = $is_immutable; + #[inline(always)] fn is_bit_valid(_: Maybe<'_, Self, AA>) -> bool { true } }; - (@method $trait:ident) => { + (@items TryFromBytes; |$_candidate:ident| $_is_bit_valid:expr) => { + compile_error!("Must provide `IS_IMMUTABLE` for `TryFromBytes`"); + }; + (@items $trait:ident) => { #[allow(clippy::missing_inline_in_public_items, dead_code)] #[cfg_attr(all(coverage_nightly, __ZEROCOPY_INTERNAL_USE_ONLY_NIGHTLY_FEATURES_IN_TESTS), coverage(off))] fn only_derive_is_allowed_to_implement_this_trait() {} }; - (@method $trait:ident; |$_candidate:ident| $_is_bit_valid:expr) => { + (@items $trait:ident; |$_candidate:ident| $_is_bit_valid:expr ; IS_IMMUTABLE = $_is_immutable:expr) => { compile_error!("Can't provide `is_bit_valid` impl for trait other than `TryFromBytes`"); }; } @@ -217,6 +237,10 @@ macro_rules! impl_for_transmute_from { $(<$tyvar:ident $(: $(? $optbound:ident $(+)?)* $($bound:ident $(+)?)* )?>)? TryFromBytes for $ty:ty [UnsafeCell<$repr:ty>] ) => { + // TODO: THIS IS UNSOUND because we don't know that `$ty` and `$repr` + // have the same interior mutability. + const IS_IMMUTABLE: bool = <$repr as TryFromBytes>::IS_IMMUTABLE; + #[inline] fn is_bit_valid(candidate: Maybe<'_, Self, A>) -> bool { let c: Maybe<'_, Self, crate::pointer::invariant::Exclusive> = candidate.into_exclusive_or_pme(); @@ -232,6 +256,10 @@ macro_rules! impl_for_transmute_from { $(<$tyvar:ident $(: $(? $optbound:ident $(+)?)* $($bound:ident $(+)?)* )?>)? TryFromBytes for $ty:ty [<$repr:ty>] ) => { + // TODO: THIS IS UNSOUND because we don't know that `$ty` and `$repr` + // have the same interior mutability. + const IS_IMMUTABLE: bool = <$repr as TryFromBytes>::IS_IMMUTABLE; + #[inline] fn is_bit_valid(candidate: Maybe<'_, Self, A>) -> bool { // SAFETY: This macro ensures that `$repr` and `Self` have the same @@ -270,33 +298,40 @@ macro_rules! impl_for_transmute_from { macro_rules! unsafe_impl_for_power_set { ( $first:ident $(, $rest:ident)* $(-> $ret:ident)? => $trait:ident for $macro:ident!(...) - $(; |$candidate:ident| $is_bit_valid:expr)? + $(;)? $(|$candidate:ident| $is_bit_valid:expr)? + $(; IS_IMMUTABLE = $is_immutable:expr)? ) => { unsafe_impl_for_power_set!( $($rest),* $(-> $ret)? => $trait for $macro!(...) $(; |$candidate| $is_bit_valid)? + $(; IS_IMMUTABLE = $is_immutable)? ); unsafe_impl_for_power_set!( @impl $first $(, $rest)* $(-> $ret)? => $trait for $macro!(...) $(; |$candidate| $is_bit_valid)? + $(; IS_IMMUTABLE = $is_immutable)? ); }; ( $(-> $ret:ident)? => $trait:ident for $macro:ident!(...) - $(; |$candidate:ident| $is_bit_valid:expr)? + $(;)? $(|$candidate:ident| $is_bit_valid:expr)? + $(; IS_IMMUTABLE = $is_immutable:expr)? ) => { unsafe_impl_for_power_set!( @impl $(-> $ret)? => $trait for $macro!(...) $(; |$candidate| $is_bit_valid)? + $(; IS_IMMUTABLE = $is_immutable)? ); }; ( @impl $($vars:ident),* $(-> $ret:ident)? => $trait:ident for $macro:ident!(...) - $(; |$candidate:ident| $is_bit_valid:expr)? + $(;)? $(|$candidate:ident| $is_bit_valid:expr)? + $(; IS_IMMUTABLE = $is_immutable:expr)? ) => { unsafe_impl!( $($vars,)* $($ret)? => $trait for $macro!($($vars),* $(-> $ret)?) $(; |$candidate| $is_bit_valid)? + $(; IS_IMMUTABLE = $is_immutable)? ); }; } diff --git a/zerocopy-derive/src/enum.rs b/zerocopy-derive/src/enum.rs index 264cd0a042..c1d1853817 100644 --- a/zerocopy-derive/src/enum.rs +++ b/zerocopy-derive/src/enum.rs @@ -8,7 +8,7 @@ use proc_macro2::{Span, TokenStream}; use quote::quote; -use syn::{parse_quote, DataEnum, Error, Fields, Generics, Ident, Path}; +use syn::{parse_quote, DataEnum, DeriveInput, Error, Fields, Generics, Ident, Path}; use crate::{derive_try_from_bytes_inner, repr::EnumRepr, Trait}; @@ -210,6 +210,7 @@ fn generate_variants_union(generics: &Generics, data: &DataEnum) -> TokenStream /// - `repr(int)`: /// - `repr(C, int)`: pub(crate) fn derive_is_bit_valid( + ast: &DeriveInput, enum_ident: &Ident, repr: &EnumRepr, generics: &Generics, @@ -280,7 +281,10 @@ pub(crate) fn derive_is_bit_valid( } }); + let is_immutable = crate::gen_is_immutable(ast, zerocopy_crate); Ok(quote! { + #is_immutable + // SAFETY: We use `is_bit_valid` to validate that the bit pattern of the // enum's tag corresponds to one of the enum's discriminants. Then, we // check the bit validity of each field of the corresponding variant. diff --git a/zerocopy-derive/src/lib.rs b/zerocopy-derive/src/lib.rs index 9f8568e790..96520a1a0e 100644 --- a/zerocopy-derive/src/lib.rs +++ b/zerocopy-derive/src/lib.rs @@ -745,7 +745,11 @@ fn derive_try_from_bytes_struct( let fields = strct.fields(); let field_names = fields.iter().map(|(_vis, name, _ty)| name); let field_tys = fields.iter().map(|(_vis, _name, ty)| ty); + + let is_immutable = gen_is_immutable(ast, zerocopy_crate); quote!( + #is_immutable + // SAFETY: We use `is_bit_valid` to validate that each field is // bit-valid, and only return `true` if all of them are. The bit // validity of a struct is just the composition of the bit @@ -808,15 +812,15 @@ fn derive_try_from_bytes_union( top_level: Trait, zerocopy_crate: &Path, ) -> TokenStream { - // FIXME(#5): Remove the `Immutable` bound. - let field_type_trait_bounds = - FieldBounds::All(&[TraitBound::Slf, TraitBound::Other(Trait::Immutable)]); let extras = try_gen_trivial_is_bit_valid(ast, top_level, zerocopy_crate).unwrap_or_else(|| { let fields = unn.fields(); let field_names = fields.iter().map(|(_vis, name, _ty)| name); let field_tys = fields.iter().map(|(_vis, _name, ty)| ty); + let is_immutable = gen_is_immutable(ast, zerocopy_crate); quote!( + #is_immutable + // SAFETY: We use `is_bit_valid` to validate that any field is // bit-valid; we only return `true` if at least one of them is. The // bit validity of a union is not yet well defined in Rust, but it @@ -828,16 +832,35 @@ fn derive_try_from_bytes_union( where ___ZerocopyAliasing: #zerocopy_crate::pointer::invariant::Reference, { - use #zerocopy_crate::util::macro_util::core_reexport; + use #zerocopy_crate::{ + util::macro_util::core_reexport, + pointer::invariant::Aliasing + }; + + trait ConstAssert { + const IS_READ: bool; + } + + // TODO: What do we name `TryFromBytes`? + impl ConstAssert for (T, A) { + const IS_READ: bool = { + assert!(T::IS_IMMUTABLE || A::IS_EXCLUSIVE); + true + }; + } + + assert!(<(Self, ___ZerocopyAliasing) as ConstAssert>::IS_READ); false #(|| { // SAFETY: // - `project` is a field projection, and so it addresses a // subset of the bytes addressed by `slf` // - ..., and so it preserves provenance - // - Since `Self: Immutable` is enforced by - // `self_type_trait_bounds`, neither `*slf` nor the - // returned pointer's referent contain any `UnsafeCell`s + // - By the preceding assert, it is either the case that + // `Self: Immutable` or that `___ZerocopyAliasing` is + // `Exclusive`. In the former case, `Self: Immutable` + // ensures that `*slf` nor the returned pointer's + // referent contain any `UnsafeCell`s. let field_candidate = unsafe { let project = |slf: core_reexport::ptr::NonNull| { let slf = slf.as_ptr(); @@ -860,7 +883,7 @@ fn derive_try_from_bytes_union( } ) }); - ImplBlockBuilder::new(ast, unn, Trait::TryFromBytes, field_type_trait_bounds, zerocopy_crate) + ImplBlockBuilder::new(ast, unn, Trait::TryFromBytes, FieldBounds::ALL_SELF, zerocopy_crate) .inner_extras(extras) .build() } @@ -887,9 +910,9 @@ fn derive_try_from_bytes_enum( (Some(is_bit_valid), _) => is_bit_valid, // SAFETY: It would be sound for the enum to implement `FomBytes`, as // required by `gen_trivial_is_bit_valid_unchecked`. - (None, true) => unsafe { gen_trivial_is_bit_valid_unchecked(zerocopy_crate) }, + (None, true) => unsafe { gen_trivial_is_bit_valid_unchecked(ast, zerocopy_crate) }, (None, false) => { - r#enum::derive_is_bit_valid(&ast.ident, &repr, &ast.generics, enm, zerocopy_crate)? + r#enum::derive_is_bit_valid(ast, &ast.ident, &repr, &ast.generics, enm, zerocopy_crate)? } }; @@ -932,8 +955,12 @@ fn try_gen_trivial_is_bit_valid( // make this no longer true. To hedge against these, we include an explicit // `Self: FromBytes` check in the generated `is_bit_valid`, which is // bulletproof. + + let is_immutable = gen_is_immutable(ast, zerocopy_crate); if top_level == Trait::FromBytes && ast.generics.params.is_empty() { Some(quote!( + #is_immutable + // SAFETY: See inline. fn is_bit_valid<___ZerocopyAliasing>( _candidate: #zerocopy_crate::Maybe, @@ -963,6 +990,21 @@ fn try_gen_trivial_is_bit_valid( } } +fn gen_is_immutable(ast: &DeriveInput, zerocopy_crate: &Path) -> proc_macro2::TokenStream { + let fields = match &ast.data { + Data::Struct(strct) => strct.fields(), + Data::Enum(enm) => enm.fields(), + Data::Union(unn) => unn.fields(), + }; + let field_tys = fields.iter().map(|(_vis, _name, ty)| ty); + + quote!( + const IS_IMMUTABLE: bool = true #( + && <#field_tys as #zerocopy_crate::TryFromBytes>::IS_IMMUTABLE + )*; + ) +} + /// Generates a `TryFromBytes::is_bit_valid` instance that unconditionally /// returns true. /// @@ -974,8 +1016,14 @@ fn try_gen_trivial_is_bit_valid( /// /// The caller must ensure that all initialized bit patterns are valid for /// `Self`. -unsafe fn gen_trivial_is_bit_valid_unchecked(zerocopy_crate: &Path) -> proc_macro2::TokenStream { +unsafe fn gen_trivial_is_bit_valid_unchecked( + ast: &DeriveInput, + zerocopy_crate: &Path, +) -> proc_macro2::TokenStream { + let is_immutable = gen_is_immutable(ast, zerocopy_crate); quote!( + #is_immutable + // SAFETY: The caller of `gen_trivial_is_bit_valid_unchecked` has // promised that all initialized bit patterns are valid for `Self`. fn is_bit_valid<___ZerocopyAliasing>( @@ -1146,12 +1194,7 @@ fn derive_from_zeros_union( unn: &DataUnion, zerocopy_crate: &Path, ) -> TokenStream { - // FIXME(#5): Remove the `Immutable` bound. It's only necessary for - // compatibility with `derive(TryFromBytes)` on unions; not for soundness. - let field_type_trait_bounds = - FieldBounds::All(&[TraitBound::Slf, TraitBound::Other(Trait::Immutable)]); - ImplBlockBuilder::new(ast, unn, Trait::FromZeros, field_type_trait_bounds, zerocopy_crate) - .build() + ImplBlockBuilder::new(ast, unn, Trait::FromZeros, FieldBounds::ALL_SELF, zerocopy_crate).build() } /// A struct is `FromBytes` if: @@ -1223,12 +1266,7 @@ fn derive_from_bytes_union( unn: &DataUnion, zerocopy_crate: &Path, ) -> TokenStream { - // FIXME(#5): Remove the `Immutable` bound. It's only necessary for - // compatibility with `derive(TryFromBytes)` on unions; not for soundness. - let field_type_trait_bounds = - FieldBounds::All(&[TraitBound::Slf, TraitBound::Other(Trait::Immutable)]); - ImplBlockBuilder::new(ast, unn, Trait::FromBytes, field_type_trait_bounds, zerocopy_crate) - .build() + ImplBlockBuilder::new(ast, unn, Trait::FromBytes, FieldBounds::ALL_SELF, zerocopy_crate).build() } fn derive_into_bytes_struct( diff --git a/zerocopy-derive/tests/union_from_bytes.rs b/zerocopy-derive/tests/union_from_bytes.rs index f8482248b1..f393292daa 100644 --- a/zerocopy-derive/tests/union_from_bytes.rs +++ b/zerocopy-derive/tests/union_from_bytes.rs @@ -15,7 +15,7 @@ include!("include.rs"); // A union is `imp::FromBytes` if: // - all fields are `imp::FromBytes` -#[derive(Clone, Copy, imp::Immutable, imp::FromBytes)] +#[derive(Clone, Copy, imp::FromBytes)] union Zst { a: (), } @@ -23,7 +23,7 @@ union Zst { util_assert_impl_all!(Zst: imp::FromBytes); test_trivial_is_bit_valid!(Zst => test_zst_trivial_is_bit_valid); -#[derive(imp::Immutable, imp::FromBytes)] +#[derive(imp::FromBytes)] union One { a: u8, } @@ -31,7 +31,7 @@ union One { util_assert_impl_all!(One: imp::FromBytes); test_trivial_is_bit_valid!(One => test_one_trivial_is_bit_valid); -#[derive(imp::Immutable, imp::FromBytes)] +#[derive(imp::FromBytes)] union Two { a: u8, b: Zst, @@ -40,7 +40,7 @@ union Two { util_assert_impl_all!(Two: imp::FromBytes); test_trivial_is_bit_valid!(Two => test_two_trivial_is_bit_valid); -#[derive(imp::Immutable, imp::FromBytes)] +#[derive(imp::FromBytes)] union TypeParams<'a, T: imp::Copy, I: imp::Iterator> where I::Item: imp::Copy, @@ -58,7 +58,7 @@ test_trivial_is_bit_valid!(TypeParams<'static, (), imp::IntoIter<()>> => test_ty // Deriving `imp::FromBytes` should work if the union has bounded parameters. -#[derive(imp::Immutable, imp::FromBytes)] +#[derive(imp::FromBytes)] #[repr(C)] union WithParams<'a: 'b, 'b: 'a, T: 'a + 'b + imp::FromBytes, const N: usize> where diff --git a/zerocopy-derive/tests/union_from_zeros.rs b/zerocopy-derive/tests/union_from_zeros.rs index 4f5b8e17be..12234d5ccf 100644 --- a/zerocopy-derive/tests/union_from_zeros.rs +++ b/zerocopy-derive/tests/union_from_zeros.rs @@ -15,21 +15,21 @@ include!("include.rs"); // A union is `imp::FromZeros` if: // - all fields are `imp::FromZeros` -#[derive(Clone, Copy, imp::Immutable, imp::FromZeros)] +#[derive(Clone, Copy, imp::FromZeros)] union Zst { a: (), } util_assert_impl_all!(Zst: imp::FromZeros); -#[derive(imp::Immutable, imp::FromZeros)] +#[derive(imp::FromZeros)] union One { a: bool, } util_assert_impl_all!(One: imp::FromZeros); -#[derive(imp::Immutable, imp::FromZeros)] +#[derive(imp::FromZeros)] union Two { a: bool, b: Zst, @@ -37,7 +37,7 @@ union Two { util_assert_impl_all!(Two: imp::FromZeros); -#[derive(imp::Immutable, imp::FromZeros)] +#[derive(imp::FromZeros)] union TypeParams<'a, T: imp::Copy, I: imp::Iterator> where I::Item: imp::Copy, @@ -54,7 +54,7 @@ util_assert_impl_all!(TypeParams<'static, (), imp::IntoIter<()>>: imp::FromZeros // Deriving `imp::FromZeros` should work if the union has bounded parameters. -#[derive(imp::Immutable, imp::FromZeros)] +#[derive(imp::FromZeros)] #[repr(C)] union WithParams<'a: 'b, 'b: 'a, T: 'a + 'b + imp::FromZeros, const N: usize> where diff --git a/zerocopy-derive/tests/union_try_from_bytes.rs b/zerocopy-derive/tests/union_try_from_bytes.rs index 80bae235ba..1060057e19 100644 --- a/zerocopy-derive/tests/union_try_from_bytes.rs +++ b/zerocopy-derive/tests/union_try_from_bytes.rs @@ -15,7 +15,7 @@ include!("include.rs"); // A struct is `imp::TryFromBytes` if: // - any of its fields are `imp::TryFromBytes` -#[derive(imp::Immutable, imp::TryFromBytes)] +#[derive(imp::TryFromBytes)] union One { a: u8, } @@ -33,7 +33,7 @@ fn one() { assert!(is_bit_valid); } -#[derive(imp::Immutable, imp::TryFromBytes)] +#[derive(imp::TryFromBytes)] #[repr(C)] union Two { a: bool, @@ -82,7 +82,7 @@ fn two_bad() { assert!(!is_bit_valid); } -#[derive(imp::Immutable, imp::TryFromBytes)] +#[derive(imp::TryFromBytes)] #[repr(C)] union BoolAndZst { a: bool, @@ -140,7 +140,7 @@ fn test_maybe_from_bytes() { imp::assert!(!is_bit_valid); } -#[derive(imp::Immutable, imp::TryFromBytes)] +#[derive(imp::TryFromBytes)] #[repr(C)] union TypeParams<'a, T: imp::Copy, I: imp::Iterator> where @@ -160,7 +160,7 @@ util_assert_impl_all!(TypeParams<'static, [util::AU16; 2], imp::IntoIter<()>>: i // Deriving `imp::TryFromBytes` should work if the union has bounded parameters. -#[derive(imp::Immutable, imp::TryFromBytes)] +#[derive(imp::TryFromBytes)] #[repr(C)] union WithParams<'a: 'b, 'b: 'a, T: 'a + 'b + imp::TryFromBytes, const N: usize> where @@ -174,7 +174,7 @@ where util_assert_impl_all!(WithParams<'static, 'static, u8, 42>: imp::TryFromBytes); -#[derive(Clone, Copy, imp::TryFromBytes, imp::Immutable)] +#[derive(Clone, Copy, imp::TryFromBytes)] struct A; #[derive(imp::TryFromBytes)]