Skip to content

Commit 0e11472

Browse files
Freax13fu5ha
andauthored
add support for deriving NoUninit on enums with fields (#292)
* add support for deriving NoUninit on enums with fields We check for padding like we do for structs except that we also consider the enum discriminant when calculating the unpadded size. * improve support for #[repr(C)] enums Unfortunately, the integer discriminant used for #[repr(C)] doesn't have a name (it's not always core::ffi::c_int), but we can use some compile-time tricks to get the integer type. Use that instead of hard- coding core::ffi::c_int and add support for deriving NoUninit for #[repr(C)] enums. * simplify get_enum_discriminant for enums w/o fields * add more comments Co-authored-by: Gray Olson <[email protected]> * update requirements for NoUninit * link type layout section in NoUninit docs * small wording change on no uninit docs * inline type alias This makes the generated code a bit harder to read, but also has the advantage of not unnecessarily adding a type to the global namespace for CheckedBitPattern. --------- Co-authored-by: Gray Olson <[email protected]>
1 parent ebb1326 commit 0e11472

6 files changed

Lines changed: 256 additions & 32 deletions

File tree

derive/src/lib.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,12 @@ pub fn derive_zeroable(
260260
///
261261
/// If applied to an enum:
262262
/// - The enum must be explicit `#[repr(Int)]`, `#[repr(C)]`, or both
263-
/// - All variants must be fieldless
263+
/// - If the enum has fields:
264+
/// - All fields must implement `NoUninit`
265+
/// - All variants must not contain any padding bytes
266+
/// - All variants must be of the the same size
267+
/// - There must be no padding bytes between the discriminant and any of the
268+
/// variant fields
264269
/// - The enum must contain no generic parameters
265270
#[proc_macro_derive(NoUninit, attributes(bytemuck))]
266271
pub fn derive_no_uninit(

derive/src/traits.rs

Lines changed: 178 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ impl Derivable for Pod {
8080
match &input.data {
8181
Data::Struct(_) => {
8282
let assert_no_padding = if !completly_packed {
83-
Some(generate_assert_no_padding(input)?)
83+
Some(generate_assert_no_padding(input, None)?)
8484
} else {
8585
None
8686
};
@@ -237,10 +237,18 @@ impl Derivable for NoUninit {
237237
Repr::C | Repr::Transparent => Ok(()),
238238
_ => bail!("NoUninit requires the struct to be #[repr(C)] or #[repr(transparent)]"),
239239
},
240-
Data::Enum(_) => if repr.repr.is_integer() {
241-
Ok(())
242-
} else {
243-
bail!("NoUninit requires the enum to be an explicit #[repr(Int)]")
240+
Data::Enum(DataEnum { variants,.. }) => {
241+
if !enum_has_fields(variants.iter()) {
242+
if matches!(repr.repr, Repr::C | Repr::Integer(_)) {
243+
Ok(())
244+
} else {
245+
bail!("NoUninit requires the enum to be #[repr(C)] or #[repr(Int)]")
246+
}
247+
} else if matches!(repr.repr, Repr::Rust) {
248+
bail!("NoUninit requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout")
249+
} else {
250+
Ok(())
251+
}
244252
},
245253
Data::Union(_) => bail!("NoUninit can only be derived on enums and structs")
246254
}
@@ -255,7 +263,7 @@ impl Derivable for NoUninit {
255263

256264
match &input.data {
257265
Data::Struct(DataStruct { .. }) => {
258-
let assert_no_padding = generate_assert_no_padding(&input)?;
266+
let assert_no_padding = generate_assert_no_padding(&input, None)?;
259267
let assert_fields_are_no_padding = generate_fields_are_trait(
260268
&input,
261269
None,
@@ -268,8 +276,61 @@ impl Derivable for NoUninit {
268276
))
269277
}
270278
Data::Enum(DataEnum { variants, .. }) => {
271-
if variants.iter().any(|variant| !variant.fields.is_empty()) {
272-
bail!("Only fieldless enums are supported for NoUninit")
279+
if enum_has_fields(variants.iter()) {
280+
// There are two different C representations for enums with fields:
281+
// There's `#[repr(C)]`/`[repr(C, int)]` and `#[repr(int)]`.
282+
// `#[repr(C)]` is equivalent to a struct containing the discriminant
283+
// and a union of structs representing each variant's fields.
284+
// `#[repr(C)]` is equivalent to a union containing structs of the
285+
// discriminant and the fields.
286+
//
287+
// See https://doc.rust-lang.org/reference/type-layout.html#r-layout.repr.c.adt
288+
// and https://doc.rust-lang.org/reference/type-layout.html#r-layout.repr.primitive.adt
289+
//
290+
// In practice the only difference between the two is whether and
291+
// where padding bytes are placed. For `#[repr(C)]` enums, the first
292+
// enum fields of all variants start at the same location (the first
293+
// byte in the union). For `#[repr(int)]` enums, the structs
294+
// representing each variant are layed out individually and padding
295+
// does not depend on other variants, but only on the size of the
296+
// discriminant and the alignment of the first field. The location of
297+
// the first field might differ between variants, potentially
298+
// resulting in less padding or padding placed later in the enum.
299+
//
300+
// The `NoUninit` derive macro asserts that no padding exists by
301+
// removing all padding with `#[repr(packed)]` and checking that this
302+
// doesn't change the size. Since the location and presence of
303+
// padding bytes is the only difference between the two
304+
// representations and we're removing all padding bytes, the resuling
305+
// layout would identical for both representations. This means that
306+
// we can just pick one of the representations and don't have to
307+
// implement desugaring for both. We chose to implement the
308+
// desugaring for `#[repr(int)]`.
309+
310+
let enum_discriminant = generate_enum_discriminant(input)?;
311+
let variant_assertions = variants
312+
.iter()
313+
.map(|variant| {
314+
let assert_no_padding =
315+
generate_assert_no_padding(&input, Some(variant))?;
316+
let assert_fields_are_no_padding = generate_fields_are_trait(
317+
&input,
318+
Some(variant),
319+
Self::ident(input, crate_name)?,
320+
)?;
321+
322+
Ok(quote!(
323+
#assert_no_padding
324+
#assert_fields_are_no_padding
325+
))
326+
})
327+
.collect::<Result<Vec<_>>>()?;
328+
Ok(quote! {
329+
const _: () = {
330+
#enum_discriminant
331+
#(#variant_assertions)*
332+
};
333+
})
273334
} else {
274335
Ok(quote!())
275336
}
@@ -301,10 +362,10 @@ impl Derivable for CheckedBitPattern {
301362
},
302363
Data::Enum(DataEnum { variants,.. }) => {
303364
if !enum_has_fields(variants.iter()){
304-
if repr.repr.is_integer() {
365+
if matches!(repr.repr, Repr::C | Repr::Integer(_)) {
305366
Ok(())
306367
} else {
307-
bail!("CheckedBitPattern requires the enum to be an explicit #[repr(Int)]")
368+
bail!("CheckedBitPattern requires the enum to be #[repr(C)] or #[repr(Int)]")
308369
}
309370
} else if matches!(repr.repr, Repr::Rust) {
310371
bail!("CheckedBitPattern requires an explicit repr annotation because `repr(Rust)` doesn't have a specified type layout")
@@ -648,12 +709,15 @@ fn generate_checked_bit_pattern_enum(
648709
if enum_has_fields(variants.iter()) {
649710
generate_checked_bit_pattern_enum_with_fields(input, variants, crate_name)
650711
} else {
651-
generate_checked_bit_pattern_enum_without_fields(input, variants)
712+
generate_checked_bit_pattern_enum_without_fields(
713+
input, variants, crate_name,
714+
)
652715
}
653716
}
654717

655718
fn generate_checked_bit_pattern_enum_without_fields(
656719
input: &DeriveInput, variants: &Punctuated<Variant, Token![,]>,
720+
crate_name: &TokenStream,
657721
) -> Result<(TokenStream, TokenStream)> {
658722
let span = input.span();
659723
let mut variants_with_discriminant =
@@ -696,10 +760,9 @@ fn generate_checked_bit_pattern_enum_without_fields(
696760
quote!(matches!(*bits, #first #(| #rest )*))
697761
};
698762

699-
let repr = get_repr(&input.attrs)?;
700-
let integer = repr.repr.as_integer().unwrap(); // should be checked in attr check already
763+
let (integer, defs) = get_enum_discriminant(input, crate_name)?;
701764
Ok((
702-
quote!(),
765+
quote!(#defs),
703766
quote! {
704767
type Bits = #integer;
705768

@@ -721,12 +784,8 @@ fn generate_checked_bit_pattern_enum_with_fields(
721784

722785
match representation.repr {
723786
Repr::Rust => unreachable!(),
724-
repr @ (Repr::C | Repr::CWithDiscriminant(_)) => {
725-
let integer = match repr {
726-
Repr::C => quote!(::core::ffi::c_int),
727-
Repr::CWithDiscriminant(integer) => quote!(#integer),
728-
_ => unreachable!(),
729-
};
787+
Repr::C | Repr::CWithDiscriminant(_) => {
788+
let (integer, defs) = get_enum_discriminant(input, crate_name)?;
730789
let input_ident = &input.ident;
731790

732791
let bits_repr = Representation { repr: Repr::C, ..representation };
@@ -796,6 +855,8 @@ fn generate_checked_bit_pattern_enum_with_fields(
796855

797856
Ok((
798857
quote! {
858+
#defs
859+
799860
#[doc = #GENERATED_TYPE_DOCUMENTATION]
800861
#[derive(::core::clone::Clone, ::core::marker::Copy, #crate_name::AnyBitPattern)]
801862
#bits_repr
@@ -981,14 +1042,28 @@ fn generate_checked_bit_pattern_enum_with_fields(
9811042
}
9821043
}
9831044

984-
/// Check that a struct has no padding by asserting that the size of the struct
985-
/// is equal to the sum of the size of it's fields
986-
fn generate_assert_no_padding(input: &DeriveInput) -> Result<TokenStream> {
1045+
/// Check that a struct or enum has no padding by asserting that the size of
1046+
/// the type is equal to the sum of the size of it's fields and discriminant
1047+
/// (for enums, this must be asserted for each variant).
1048+
fn generate_assert_no_padding(
1049+
input: &DeriveInput, enum_variant: Option<&Variant>,
1050+
) -> Result<TokenStream> {
9871051
let struct_type = &input.ident;
988-
let enum_variant = None; // `no padding` check is not supported for `enum`s yet.
9891052
let fields = get_fields(input, enum_variant)?;
9901053

991-
let mut field_types = get_field_types(&fields);
1054+
// If the type is an enum, determine the type of its discriminant.
1055+
let enum_discriminant = if matches!(input.data, Data::Enum(_)) {
1056+
let ident =
1057+
Ident::new(&format!("{}Discriminant", input.ident), input.ident.span());
1058+
Some(ident.into_token_stream())
1059+
} else {
1060+
None
1061+
};
1062+
1063+
// Prepend the type of the discriminant to the types of the fields.
1064+
let mut field_types = enum_discriminant
1065+
.into_iter()
1066+
.chain(get_field_types(&fields).map(ToTokens::to_token_stream));
9921067
let size_sum = if let Some(first) = field_types.next() {
9931068
let size_first = quote!(::core::mem::size_of::<#first>());
9941069
let size_rest = quote!(#( + ::core::mem::size_of::<#field_types>() )*);
@@ -1024,6 +1099,84 @@ fn generate_fields_are_trait(
10241099
})
10251100
}
10261101

1102+
/// Get the type of an enum's discriminant.
1103+
///
1104+
/// For `repr(int)` and `repr(C, int)` enums, this will return the known bare
1105+
/// integer type specified.
1106+
///
1107+
/// For `repr(C)` enums, this will extract the underlying size chosen by rustc.
1108+
/// It will return a token stream which is a type expression that evaluates to
1109+
/// a primitive integer type of this size, using our `EnumTagIntegerBytes`
1110+
/// trait.
1111+
///
1112+
/// For fieldless `repr(C)` enums, we can feed the size of the enum directly
1113+
/// into the trait.
1114+
///
1115+
/// For `repr(C)` enums with fields, we generate a new fieldless `repr(C)` enum
1116+
/// with the same variants, then use that in the calculation. This is the
1117+
/// specified behavior, see https://doc.rust-lang.org/stable/reference/type-layout.html#reprc-enums-with-fields
1118+
///
1119+
/// Returns a tuple of (type ident, auxiliary definitions)
1120+
fn get_enum_discriminant(
1121+
input: &DeriveInput, crate_name: &TokenStream,
1122+
) -> Result<(TokenStream, TokenStream)> {
1123+
let repr = get_repr(&input.attrs)?;
1124+
match repr.repr {
1125+
Repr::C => {
1126+
let e = if let Data::Enum(e) = &input.data { e } else { unreachable!() };
1127+
if enum_has_fields(e.variants.iter()) {
1128+
// If the enum has fields, we must first isolate the discriminant by
1129+
// removing all the fields.
1130+
let enum_discriminant = generate_enum_discriminant(input)?;
1131+
let discriminant_ident = Ident::new(
1132+
&format!("{}Discriminant", input.ident),
1133+
input.ident.span(),
1134+
);
1135+
Ok((
1136+
quote!(<[::core::primitive::u8; ::core::mem::size_of::<#discriminant_ident>()] as #crate_name::derive::EnumTagIntegerBytes>::Integer),
1137+
quote! {
1138+
#enum_discriminant
1139+
},
1140+
))
1141+
} else {
1142+
// If the enum doesn't have fields, we can just use it directly.
1143+
let ident = &input.ident;
1144+
Ok((
1145+
quote!(<[::core::primitive::u8; ::core::mem::size_of::<#ident>()] as #crate_name::derive::EnumTagIntegerBytes>::Integer),
1146+
quote!(),
1147+
))
1148+
}
1149+
}
1150+
Repr::Integer(integer) | Repr::CWithDiscriminant(integer) => {
1151+
Ok((quote!(#integer), quote!()))
1152+
}
1153+
_ => unreachable!(),
1154+
}
1155+
}
1156+
1157+
fn generate_enum_discriminant(input: &DeriveInput) -> Result<TokenStream> {
1158+
let e = if let Data::Enum(e) = &input.data { e } else { unreachable!() };
1159+
let repr = get_repr(&input.attrs)?;
1160+
let repr = match repr.repr {
1161+
Repr::C => quote!(#[repr(C)]),
1162+
Repr::Integer(int) | Repr::CWithDiscriminant(int) => quote!(#[repr(#int)]),
1163+
Repr::Rust | Repr::Transparent => unreachable!(),
1164+
};
1165+
let ident =
1166+
Ident::new(&format!("{}Discriminant", input.ident), input.ident.span());
1167+
let variants = e.variants.iter().cloned().map(|mut e| {
1168+
e.fields = Fields::Unit;
1169+
e
1170+
});
1171+
Ok(quote! {
1172+
#repr
1173+
#[allow(dead_code)]
1174+
enum #ident {
1175+
#(#variants,)*
1176+
}
1177+
})
1178+
}
1179+
10271180
fn get_ident_from_stream(tokens: TokenStream) -> Option<Ident> {
10281181
match tokens.into_iter().next() {
10291182
Some(TokenTree::Group(group)) => get_ident_from_stream(group.stream()),
@@ -1139,10 +1292,6 @@ enum Repr {
11391292
}
11401293

11411294
impl Repr {
1142-
fn is_integer(&self) -> bool {
1143-
matches!(self, Self::Integer(..))
1144-
}
1145-
11461295
fn as_integer(&self) -> Option<IntegerRepr> {
11471296
if let Self::Integer(v) = self {
11481297
Some(*v)

derive/tests/basic.rs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,34 @@ struct CheckedBitPatternStruct {
201201
b: CheckedBitPatternEnumNonContiguous,
202202
}
203203

204+
#[derive(Debug, Copy, Clone, NoUninit)]
205+
#[repr(C)]
206+
enum NoUninitEnum {
207+
A,
208+
B,
209+
}
210+
211+
#[derive(Debug, Copy, Clone, NoUninit)]
212+
#[repr(C)]
213+
enum NoUninitEnumWithFields {
214+
A(u32, u32),
215+
B(u16, u16, u16, u16),
216+
}
217+
218+
#[derive(Debug, Copy, Clone, NoUninit)]
219+
#[repr(C, u16)]
220+
enum NoUninitEnumWithFieldsAndCAndDiscriminant {
221+
A(u16, u16),
222+
B(u8, u8, u8, u8),
223+
}
224+
225+
#[derive(Debug, Clone, Copy, NoUninit)]
226+
#[repr(u16)]
227+
enum NoUninitEnumWithFieldsAndDiscriminant {
228+
A(u16, u16),
229+
B(u8, u8, u8, u8),
230+
}
231+
204232
#[derive(Debug, Copy, Clone, AnyBitPattern, PartialEq, Eq)]
205233
#[repr(C)]
206234
struct AnyBitPatternTest<A: AnyBitPattern, B: AnyBitPattern> {
@@ -221,6 +249,13 @@ struct CheckedBitPatternPackedStruct {
221249
b: u16,
222250
}
223251

252+
#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)]
253+
#[repr(C)]
254+
enum CheckedBitPatternCDefaultDiscriminantEnum {
255+
A,
256+
B,
257+
}
258+
224259
#[derive(Debug, Clone, Copy, CheckedBitPattern, PartialEq, Eq)]
225260
#[repr(C)]
226261
enum CheckedBitPatternCDefaultDiscriminantEnumWithFields {

src/derive.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//! This module contains some helpers for the derive macros.
2+
3+
/// A trait that can be used to convert the type of a byte array to an integer
4+
/// type of the same size.
5+
pub trait EnumTagIntegerBytes {
6+
type Integer;
7+
}
8+
9+
macro_rules! enum_tag_integer_impls {
10+
($($ty:ty),*) => {
11+
$(
12+
impl EnumTagIntegerBytes for [u8; core::mem::size_of::<$ty>()] {
13+
type Integer = $ty;
14+
}
15+
)*
16+
};
17+
}
18+
19+
enum_tag_integer_impls!(u8, u16, u32, u64, u128);

src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,13 @@ mod offset_of;
229229
mod transparent;
230230
pub use transparent::*;
231231

232+
// This module is just an implementation detail for the derive macros. It needs
233+
// to be public to be usable from the macros, but it shouldn't be considered
234+
// part of bytemuck's public API.
235+
#[cfg(feature = "derive")]
236+
#[doc(hidden)]
237+
pub mod derive;
238+
232239
#[cfg(feature = "derive")]
233240
#[cfg_attr(feature = "nightly_docs", doc(cfg(feature = "derive")))]
234241
pub use bytemuck_derive::{

0 commit comments

Comments
 (0)