Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 90 additions & 17 deletions src/idl_gen_rust.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ std::string AddUnwrapIfRequired(std::string s, bool required) {
}
}

bool IsBitFlagsEnum(const EnumDef &enum_def) {
return enum_def.attributes.Lookup("bit_flags");
}
bool IsBitFlagsEnum(const FieldDef &field) {
EnumDef* ed = field.value.type.enum_def;
return ed && IsBitFlagsEnum(*ed);
}

namespace rust {

class RustGenerator : public BaseGenerator {
Expand Down Expand Up @@ -233,9 +241,17 @@ class RustGenerator : public BaseGenerator {

assert(!cur_name_space_);

bool import_bitflags = false;
for (auto it = parser_.enums_.vec.begin(); it != parser_.enums_.vec.end();
++it) {
if (IsBitFlagsEnum(**it)) {
import_bitflags = true;
break;
}
}
// Generate imports for the global scope in case no namespace is used
// in the schema file.
GenNamespaceImports(0);
GenNamespaceImports(0, import_bitflags);
code_ += "";

// Generate all code in their namespaces, once, because Rust does not
Expand Down Expand Up @@ -515,9 +531,25 @@ class RustGenerator : public BaseGenerator {

std::string GetEnumValUse(const EnumDef &enum_def,
const EnumVal &enum_val) const {
return Name(enum_def) + "::" + Name(enum_val);
const std::string val = IsBitFlagsEnum(enum_def) ?
MakeUpper(MakeSnakeCase(Name(enum_val))) : Name(enum_val);
return Name(enum_def) + "::" + val;
}


void ForAllEnumValues(const EnumDef &enum_def,
std::function<void(const EnumVal&)> cb) {
for (auto it = enum_def.Vals().begin(); it != enum_def.Vals().end(); ++it) {
const auto &ev = **it;
code_.SetValue("VARIANT", Name(ev));
code_.SetValue("SSC_VARIANT", MakeUpper(MakeSnakeCase(Name(ev))));
code_.SetValue("VALUE", enum_def.ToString(ev));
cb(ev);
}
}
void ForAllEnumValues(const EnumDef &enum_def, std::function<void()> cb) {
ForAllEnumValues(enum_def, [&](const EnumVal& unused) { cb(); });
}
// Generate an enum declaration,
// an enum string lookup table,
// an enum match function,
Expand All @@ -533,6 +565,52 @@ class RustGenerator : public BaseGenerator {
code_.SetValue("ENUM_MIN_BASE_VALUE", enum_def.ToString(*minv));
code_.SetValue("ENUM_MAX_BASE_VALUE", enum_def.ToString(*maxv));

if (IsBitFlagsEnum(enum_def)) {
// Defer to the convenient and canonical bitflags crate.
code_ += "bitflags::bitflags! {";
GenComment(enum_def.doc_comment);
code_ += " pub struct {{ENUM_NAME}}: {{BASE_TYPE}} {";
ForAllEnumValues(enum_def, [&]{
code_ += " const {{SSC_VARIANT}} = {{VALUE}};";
});
code_ += " }";
code_ += "}";
code_ += "";
// Generate Follow and Push so we can serialize and stuff.
code_ += "impl<'a> flatbuffers::Follow<'a> for {{ENUM_NAME}} {";
code_ += " type Inner = Self;";
code_ += " #[inline]";
code_ += " fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {";
code_ += " let bits = flatbuffers::read_scalar_at::<{{BASE_TYPE}}>(buf, loc);";
code_ += " unsafe { Self::from_bits_unchecked(bits) }";
code_ += " }";
code_ += "}";
code_ += "";
code_ += "impl flatbuffers::Push for {{ENUM_NAME}} {";
code_ += " type Output = {{ENUM_NAME}};";
code_ += " #[inline]";
code_ += " fn push(&self, dst: &mut [u8], _rest: &[u8]) {";
code_ += " flatbuffers::emplace_scalar::<{{BASE_TYPE}}>"
"(dst, self.bits());";
code_ += " }";
code_ += "}";
code_ += "";
code_ += "impl flatbuffers::EndianScalar for {{ENUM_NAME}} {";
code_ += " #[inline]";
code_ += " fn to_little_endian(self) -> Self {";
code_ += " let bits = {{BASE_TYPE}}::to_le(self.bits());";
code_ += " unsafe { Self::from_bits_unchecked(bits) }";
code_ += " }";
code_ += " #[inline]";
code_ += " fn from_little_endian(self) -> Self {";
code_ += " let bits = {{BASE_TYPE}}::from_le(self.bits());";
code_ += " unsafe { Self::from_bits_unchecked(bits) }";
code_ += " }";
code_ += "}";
code_ += "";
return;
}

GenComment(enum_def.doc_comment);
code_ +=
"#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]";
Expand All @@ -542,26 +620,21 @@ class RustGenerator : public BaseGenerator {
code_ += " pub const ENUM_MIN: {{BASE_TYPE}} = {{ENUM_MIN_BASE_VALUE}};";
code_ += " pub const ENUM_MAX: {{BASE_TYPE}} = {{ENUM_MAX_BASE_VALUE}};";

for (auto it = enum_def.Vals().begin(); it != enum_def.Vals().end(); ++it) {
const auto &ev = **it;
ForAllEnumValues(enum_def, [&](const EnumVal &ev){
GenComment(ev.doc_comment, " ");
code_.SetValue("VARIANT", Name(ev));
code_.SetValue("VALUE", enum_def.ToString(ev));
code_ += " pub const {{VARIANT}}: Self = Self({{VALUE}});";
}
});
code_ += " pub const ENUM_VALUES: &'static [Self] = &[";
for (auto it = enum_def.Vals().begin(); it != enum_def.Vals().end(); ++it) {
code_.SetValue("VARIANT", Name(**it));
ForAllEnumValues(enum_def, [&]{
code_ += " Self::{{VARIANT}},";
}
});
code_ += " ];";
code_ += " /// Returns the variant's name or \"\" if unknown.";
code_ += " pub fn variant_name(self) -> &'static str {";
code_ += " match self {";
for (auto it = enum_def.Vals().begin(); it != enum_def.Vals().end(); ++it) {
code_.SetValue("VARIANT", Name(**it));
ForAllEnumValues(enum_def, [&]{
code_ += " Self::{{VARIANT}} => \"{{VARIANT}}\",";
}
});
code_ += " _ => \"\",";
code_ += " }";
code_ += " }";
Expand Down Expand Up @@ -1004,9 +1077,8 @@ class RustGenerator : public BaseGenerator {
}
case ftUnionKey:
case ftEnumKey: {
const auto underlying_typname = GetTypeBasic(type); //<- never used
const auto typname = WrapInNameSpace(*type.enum_def);
const auto default_value = GetDefaultScalarValue(field);
const std::string typname = WrapInNameSpace(*type.enum_def);
const std::string default_value = GetDefaultScalarValue(field);
if (field.optional) {
return "self._tab.get::<" + typname + ">(" + offset_name + ", None)";
} else {
Expand Down Expand Up @@ -1745,7 +1817,7 @@ class RustGenerator : public BaseGenerator {
code_ += "";
}

void GenNamespaceImports(const int white_spaces) {
void GenNamespaceImports(const int white_spaces, bool bitflags=false) {
if (white_spaces == 0) {
code_ += "#![allow(unused_imports, dead_code)]";
}
Expand All @@ -1766,6 +1838,7 @@ class RustGenerator : public BaseGenerator {
code_ += indent + "use std::cmp::Ordering;";
code_ += "";
code_ += indent + "extern crate flatbuffers;";
if (bitflags) code_ += indent + "extern crate bitflags;";
code_ += indent + "use self::flatbuffers::EndianScalar;";
}

Expand Down
72 changes: 20 additions & 52 deletions tests/monster_test_generated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::mem;
use std::cmp::Ordering;

extern crate flatbuffers;
extern crate bitflags;
use self::flatbuffers::EndianScalar;

#[allow(unused_imports, dead_code)]
Expand Down Expand Up @@ -176,78 +177,45 @@ pub mod example {
extern crate flatbuffers;
use self::flatbuffers::EndianScalar;

bitflags::bitflags! {
/// Composite components of Monster color.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Color(pub u8);
#[allow(non_upper_case_globals)]
impl Color {
pub const ENUM_MIN: u8 = 1;
pub const ENUM_MAX: u8 = 8;
pub const Red: Self = Self(1);
/// \brief color Green
/// Green is bit_flag with value (1u << 1)
pub const Green: Self = Self(2);
/// \brief color Blue (1u << 3)
pub const Blue: Self = Self(8);
pub const ENUM_VALUES: &'static [Self] = &[
Self::Red,
Self::Green,
Self::Blue,
];
/// Returns the variant's name or "" if unknown.
pub fn variant_name(self) -> &'static str {
match self {
Self::Red => "Red",
Self::Green => "Green",
Self::Blue => "Blue",
_ => "",
}
}
}
impl std::fmt::Debug for Color {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let name = self.variant_name();
if name.is_empty() {
f.write_fmt(format_args!("<UNKNOWN {:?}>", self.0))
} else {
f.write_str(name)
}
pub struct Color: u8 {
const RED = 1;
const GREEN = 2;
const BLUE = 8;
}
}

impl<'a> flatbuffers::Follow<'a> for Color {
type Inner = Self;
#[inline]
fn follow(buf: &'a [u8], loc: usize) -> Self::Inner {
Self(flatbuffers::read_scalar_at::<u8>(buf, loc))
let bits = flatbuffers::read_scalar_at::<u8>(buf, loc);
unsafe { Self::from_bits_unchecked(bits) }
}
}

impl flatbuffers::Push for Color {
type Output = Color;
#[inline]
fn push(&self, dst: &mut [u8], _rest: &[u8]) {
flatbuffers::emplace_scalar::<u8>(dst, self.0);
flatbuffers::emplace_scalar::<u8>(dst, self.bits());
}
}

impl flatbuffers::EndianScalar for Color {
#[inline]
fn to_little_endian(self) -> Self {
Self(u8::to_le(self.0))
let bits = u8::to_le(self.bits());
unsafe { Self::from_bits_unchecked(bits) }
}
#[inline]
fn from_little_endian(self) -> Self {
Self(u8::from_le(self.0))
let bits = u8::from_le(self.bits());
unsafe { Self::from_bits_unchecked(bits) }
}
}

#[allow(non_camel_case_types)]
pub const ENUM_VALUES_COLOR: [Color; 3] = [
Color::Red,
Color::Green,
Color::Blue
];

#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Race(pub i8);
#[allow(non_upper_case_globals)]
Expand Down Expand Up @@ -811,7 +779,7 @@ impl<'a> TestSimpleTableWithEnum<'a> {

#[inline]
pub fn color(&self) -> Color {
self._tab.get::<Color>(TestSimpleTableWithEnum::VT_COLOR, Some(Color::Green)).unwrap()
self._tab.get::<Color>(TestSimpleTableWithEnum::VT_COLOR, Some(Color::GREEN)).unwrap()
}
}

Expand All @@ -822,7 +790,7 @@ impl<'a> Default for TestSimpleTableWithEnumArgs {
#[inline]
fn default() -> Self {
TestSimpleTableWithEnumArgs {
color: Color::Green,
color: Color::GREEN,
}
}
}
Expand All @@ -833,7 +801,7 @@ pub struct TestSimpleTableWithEnumBuilder<'a: 'b, 'b> {
impl<'a: 'b, 'b> TestSimpleTableWithEnumBuilder<'a, 'b> {
#[inline]
pub fn add_color(&mut self, color: Color) {
self.fbb_.push_slot::<Color>(TestSimpleTableWithEnum::VT_COLOR, color, Color::Green);
self.fbb_.push_slot::<Color>(TestSimpleTableWithEnum::VT_COLOR, color, Color::GREEN);
}
#[inline]
pub fn new(_fbb: &'b mut flatbuffers::FlatBufferBuilder<'a>) -> TestSimpleTableWithEnumBuilder<'a, 'b> {
Expand Down Expand Up @@ -1204,7 +1172,7 @@ impl<'a> Monster<'a> {
}
#[inline]
pub fn color(&self) -> Color {
self._tab.get::<Color>(Monster::VT_COLOR, Some(Color::Blue)).unwrap()
self._tab.get::<Color>(Monster::VT_COLOR, Some(Color::BLUE)).unwrap()
}
#[inline]
pub fn test_type(&self) -> Any {
Expand Down Expand Up @@ -1536,7 +1504,7 @@ impl<'a> Default for MonsterArgs<'a> {
hp: 100,
name: None, // required field
inventory: None,
color: Color::Blue,
color: Color::BLUE,
test_type: Any::NONE,
test: None,
test4: None,
Expand Down Expand Up @@ -1609,7 +1577,7 @@ impl<'a: 'b, 'b> MonsterBuilder<'a, 'b> {
}
#[inline]
pub fn add_color(&mut self, color: Color) {
self.fbb_.push_slot::<Color>(Monster::VT_COLOR, color, Color::Blue);
self.fbb_.push_slot::<Color>(Monster::VT_COLOR, color, Color::BLUE);
}
#[inline]
pub fn add_test_type(&mut self, test_type: Any) {
Expand Down
1 change: 1 addition & 0 deletions tests/rust_usage_test/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ flexbuffers = { path = "../../rust/flexbuffers" }
serde_derive = "1.0"
serde = "1.0"
serde_bytes = "0.11"
bitflags = "1.2"

[[bin]]
name = "monster_example"
Expand Down
4 changes: 2 additions & 2 deletions tests/rust_usage_test/bin/flatbuffers_alloc_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ fn create_serialized_example_with_generated_code(builder: &mut flatbuffers::Flat
2.0,
3.0,
3.0,
my_game::example::Color::Green,
my_game::example::Color::GREEN,
&my_game::example::Test::new(5i16, 6i8),
);

Expand Down Expand Up @@ -144,7 +144,7 @@ fn main() {
assert!((pos.y() - 2.0f32).abs() < std::f32::EPSILON);
assert!((pos.z() - 3.0f32).abs() < std::f32::EPSILON);
assert!((pos.test1() - 3.0f64).abs() < std::f64::EPSILON);
assert_eq!(pos.test2(), my_game::example::Color::Green);
assert_eq!(pos.test2(), my_game::example::Color::GREEN);
let pos_test3 = pos.test3();
assert_eq!(pos_test3.a(), 5i16);
assert_eq!(pos_test3.b(), 6i8);
Expand Down
Loading