diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 5e514d5c1..a30a89f8f 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -609,7 +609,12 @@ class Message(ABC): Calls :meth:`__bool__`. """ - _serialized_on_wire: bool + serialized_on_wire: bool + """ + If this message was or should be serialized on the wire. This can be used to detect + presence (e.g. optional wrapper message) and is used internally during + parsing/serialization. + """ _unknown_fields: bytes _group_current: Dict[str, str] @@ -634,7 +639,7 @@ def __post_init__(self) -> None: group_current[meta.group] = field_name # Now that all the defaults are set, reset it! - self.__dict__["_serialized_on_wire"] = not all_sentinel + self.__dict__["serialized_on_wire"] = not all_sentinel self.__dict__["_unknown_fields"] = b"" self.__dict__["_group_current"] = group_current @@ -694,9 +699,9 @@ def __getattribute__(self, name: str) -> Any: return value def __setattr__(self, attr: str, value: Any) -> None: - if attr != "_serialized_on_wire": + if attr != "serialized_on_wire": # Track when a field has been set. - self.__dict__["_serialized_on_wire"] = True + self.__dict__["serialized_on_wire"] = True if hasattr(self, "_group_current"): # __post_init__ had already run if attr in self._betterproto.oneof_group_by_field: @@ -756,7 +761,7 @@ def __bytes__(self) -> bytes: # Empty messages can still be sent on the wire if they were # set (or received empty). - serialize_empty = isinstance(value, Message) and value._serialized_on_wire + serialize_empty = isinstance(value, Message) and value.serialized_on_wire include_default_value_for_oneof = self._include_default_value_for_oneof( field_name=field_name, meta=meta @@ -924,7 +929,7 @@ def _postprocess_single( value = _get_wrapper(meta.wraps)().parse(value).value else: value = cls().parse(value) - value._serialized_on_wire = True + value.serialized_on_wire = True elif meta.proto_type == TYPE_MAP: value = self._betterproto.cls_by_field[field_name]().parse(value) @@ -953,7 +958,7 @@ def parse(self: T, data: bytes) -> T: The initialized message. """ # Got some data over the wire - self._serialized_on_wire = True + self.serialized_on_wire = True proto_meta = self._betterproto for parsed in parse_fields(data): field_name = proto_meta.field_name_by_number.get(parsed.number) @@ -1089,7 +1094,7 @@ def to_dict( if include_default_values: output[cased_name] = value elif ( - value._serialized_on_wire + value.serialized_on_wire or include_default_values or self._include_default_value_for_oneof( field_name=field_name, meta=meta @@ -1170,7 +1175,7 @@ def from_dict(self: T, value: Dict[str, Any]) -> T: :class:`Message` The initialized message. """ - self._serialized_on_wire = True + self.serialized_on_wire = True for key in value: field_name = safe_snake_case(key) meta = self._betterproto.meta_by_field_name.get(field_name) @@ -1280,34 +1285,19 @@ def from_json(self: T, value: Union[str, bytes]) -> T: """ return self.from_dict(json.loads(value)) + def which_one_of(self, group_name: str) -> Tuple[str, Optional[Any]]: + """ + Return the name and value of a message's one-of field group. -def serialized_on_wire(message: Message) -> bool: - """ - If this message was or should be serialized on the wire. This can be used to detect - presence (e.g. optional wrapper message) and is used internally during - parsing/serialization. - - Returns - -------- - :class:`bool` - Whether this message was or should be serialized on the wire. - """ - return message._serialized_on_wire - - -def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]: - """ - Return the name and value of a message's one-of field group. - - Returns - -------- - Tuple[:class:`str`, Any] - The field name and the value for that field. - """ - field_name = message._group_current.get(group_name) - if not field_name: - return "", None - return field_name, getattr(message, field_name) + Returns + -------- + Tuple[:class:`str`, Any] + The field name and the value for that field. + """ + field_name = self._group_current.get(group_name) + if not field_name: + return "", None + return field_name, getattr(self, field_name) # Circular import workaround: google.protobuf depends on base classes defined above. diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 840140043..8b7324abf 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -32,7 +32,6 @@ import builtins import betterproto -from betterproto import which_one_of from betterproto.casing import sanitize_name from betterproto.compile.importing import ( get_type_reference, @@ -114,6 +113,9 @@ FieldDescriptorProtoType.TYPE_SINT32, # 17 FieldDescriptorProtoType.TYPE_SINT64, # 18 ) +UNSAFE_FIELD_NAMES = frozenset(dir(betterproto.Message)) | frozenset( + betterproto.Message.__annotations__ +) def monkey_patch_oneof_index(): @@ -355,7 +357,7 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool: us to tell whether it was set, via the which_one_of interface. """ - return which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index" + return proto_field_obj.which_one_of("oneof_index")[0] == "oneof_index" @dataclass @@ -501,7 +503,11 @@ def packed(self) -> bool: @property def py_name(self) -> str: """Pythonized name.""" - return pythonize_field_name(self.proto_name) + unsafe_name = pythonize_field_name(self.proto_name) + # rename fields in case they clash with things defined in Message + if unsafe_name in UNSAFE_FIELD_NAMES: + return f"{unsafe_name}_" + return unsafe_name @property def proto_name(self) -> str: diff --git a/tests/inputs/oneof/test_oneof.py b/tests/inputs/oneof/test_oneof.py index d1267659f..c7f7a6a3f 100644 --- a/tests/inputs/oneof/test_oneof.py +++ b/tests/inputs/oneof/test_oneof.py @@ -1,4 +1,3 @@ -import betterproto from tests.output_betterproto.oneof import Test from tests.util import get_test_case_json_data @@ -6,10 +5,10 @@ def test_which_count(): message = Test() message.from_json(get_test_case_json_data("oneof")[0].json) - assert betterproto.which_one_of(message, "foo") == ("pitied", 100) + assert message.which_one_of("foo") == ("pitied", 100) def test_which_name(): message = Test() message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json) - assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T") + assert message.which_one_of("foo") == ("pitier", "Mr. T") diff --git a/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py b/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py index 0c928cb89..f010cb01e 100644 --- a/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py +++ b/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py @@ -1,7 +1,6 @@ import pytest import datetime -import betterproto from tests.output_betterproto.oneof_default_value_serialization import ( Test, Message, @@ -10,9 +9,9 @@ def assert_round_trip_serialization_works(message: Test) -> None: - assert betterproto.which_one_of(message, "value_type") == betterproto.which_one_of( - Test().from_json(message.to_json()), "value_type" - ) + assert message.which_one_of("value_type") == Test().from_json( + message.to_json() + ).which_one_of("value_type") def test_oneof_default_value_serialization_works_for_all_values(): @@ -49,8 +48,8 @@ def test_oneof_default_value_serialization_works_for_all_values(): def test_oneof_no_default_values_passed(): message = Test() assert ( - betterproto.which_one_of(message, "value_type") - == betterproto.which_one_of(Test().from_json(message.to_json()), "value_type") + message.which_one_of("value_type") + == Test().from_json(message.to_json()).which_one_of("value_type") == ("", None) ) @@ -65,8 +64,8 @@ def test_oneof_nested_oneof_messages_are_serialized_with_defaults(): ) ) assert ( - betterproto.which_one_of(message, "value_type") - == betterproto.which_one_of(Test().from_json(message.to_json()), "value_type") + message.which_one_of("value_type") + == Test().from_json(message.to_json()).which_one_of("value_type") == ( "wrapped_nested_message_value", NestedMessage(id=0, wrapped_message_value=Message(value=0)), diff --git a/tests/inputs/oneof_enum/test_oneof_enum.py b/tests/inputs/oneof_enum/test_oneof_enum.py index 7e287d4a4..b71c3d377 100644 --- a/tests/inputs/oneof_enum/test_oneof_enum.py +++ b/tests/inputs/oneof_enum/test_oneof_enum.py @@ -1,6 +1,5 @@ import pytest -import betterproto from tests.output_betterproto.oneof_enum import ( Move, Signal, @@ -22,7 +21,7 @@ def test_which_one_of_returns_enum_with_default_value(): x=0, y=0 ) # Proto3 will default this as there is no null assert message.signal == Signal.PASS - assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS) + assert message.which_one_of("action") == ("signal", Signal.PASS) def test_which_one_of_returns_enum_with_non_default_value(): @@ -37,7 +36,7 @@ def test_which_one_of_returns_enum_with_non_default_value(): x=0, y=0 ) # Proto3 will default this as there is no null assert message.signal == Signal.RESIGN - assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN) + assert message.which_one_of("action") == ("signal", Signal.RESIGN) def test_which_one_of_returns_second_field_when_set(): @@ -45,4 +44,4 @@ def test_which_one_of_returns_second_field_when_set(): message.from_json(get_test_case_json_data("oneof_enum")[0].json) assert message.move == Move(x=2, y=3) assert message.signal == Signal.PASS - assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) + assert message.which_one_of("action") == ("move", Move(x=2, y=3)) diff --git a/tests/inputs/rename/rename.proto b/tests/inputs/rename/rename.proto new file mode 100644 index 000000000..c796e442f --- /dev/null +++ b/tests/inputs/rename/rename.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +// The fields that have overlapping names with betterproto.Message will be renamed. +message Test { + bool parse = 1; + bool serialized_on_wire = 2; + bool from_json = 3; + int32 this = 4; +} diff --git a/tests/inputs/rename/test_rename.py b/tests/inputs/rename/test_rename.py new file mode 100644 index 000000000..5ee5b8ba6 --- /dev/null +++ b/tests/inputs/rename/test_rename.py @@ -0,0 +1,12 @@ +from dataclasses import fields + +from tests.output_betterproto.rename import Test + + +def test_renamed_fields(): + assert {field.name for field in fields(Test)} == { + "parse_", + "serialized_on_wire_", + "from_json_", + "this", + } diff --git a/tests/test_features.py b/tests/test_features.py index b82528eea..d44f32e1c 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -16,23 +16,23 @@ class Foo(betterproto.Message): # Unset by default foo = Foo() - assert betterproto.serialized_on_wire(foo.bar) is False + assert foo.bar.serialized_on_wire is False # Serialized after setting something foo.bar.baz = 1 - assert betterproto.serialized_on_wire(foo.bar) is True + assert foo.bar.serialized_on_wire is True # Still has it after setting the default value foo.bar.baz = 0 - assert betterproto.serialized_on_wire(foo.bar) is True + assert foo.bar.serialized_on_wire is True # Manual override (don't do this) - foo.bar._serialized_on_wire = False - assert betterproto.serialized_on_wire(foo.bar) is False + foo.bar.serialized_on_wire = False + assert foo.bar.serialized_on_wire is False # Can manually set it but defaults to false foo.bar = Bar() - assert betterproto.serialized_on_wire(foo.bar) is False + assert foo.bar.serialized_on_wire is False @dataclass class WithCollections(betterproto.Message): @@ -43,15 +43,15 @@ class WithCollections(betterproto.Message): # Is always set from parse, even if all collections are empty with_collections_empty = WithCollections().parse(bytes(WithCollections())) - assert betterproto.serialized_on_wire(with_collections_empty) == True + assert with_collections_empty.serialized_on_wire == True with_collections_list = WithCollections().parse( bytes(WithCollections(test_list=["a", "b", "c"])) ) - assert betterproto.serialized_on_wire(with_collections_list) == True + assert with_collections_list.serialized_on_wire == True with_collections_map = WithCollections().parse( bytes(WithCollections(test_map={"a": "b", "c": "d"})) ) - assert betterproto.serialized_on_wire(with_collections_map) == True + assert with_collections_map.serialized_on_wire == True def test_class_init(): @@ -123,35 +123,35 @@ class Foo(betterproto.Message): foo = Foo() - assert betterproto.which_one_of(foo, "group1")[0] == "" + assert foo.which_one_of("group1")[0] == "" foo.bar = 1 foo.baz = "test" # Other oneof fields should now be unset assert foo.bar == 0 - assert betterproto.which_one_of(foo, "group1")[0] == "baz" + assert foo.which_one_of("group1")[0] == "baz" foo.sub.val = 1 - assert betterproto.serialized_on_wire(foo.sub) + assert foo.sub.serialized_on_wire foo.abc = "test" # Group 1 shouldn't be touched, group 2 should have reset assert foo.sub.val == 0 - assert betterproto.serialized_on_wire(foo.sub) is False - assert betterproto.which_one_of(foo, "group2")[0] == "abc" + assert foo.sub.serialized_on_wire is False + assert foo.which_one_of("group2")[0] == "abc" # Zero value should always serialize for one-of foo = Foo(bar=0) - assert betterproto.which_one_of(foo, "group1")[0] == "bar" + assert foo.which_one_of("group1")[0] == "bar" assert bytes(foo) == b"\x08\x00" # Round trip should also work foo2 = Foo().parse(bytes(foo)) - assert betterproto.which_one_of(foo2, "group1")[0] == "bar" + assert foo2.which_one_of("group1")[0] == "bar" assert foo.bar == 0 - assert betterproto.which_one_of(foo2, "group2")[0] == "" + assert foo2.which_one_of("group2")[0] == "" def test_json_casing(): @@ -307,29 +307,29 @@ def _round_trip_serialization(foo: Foo) -> Foo: assert bytes(foo1) == b"\x08\x00" assert ( - betterproto.which_one_of(foo1, "group1") - == betterproto.which_one_of(_round_trip_serialization(foo1), "group1") + foo1.which_one_of("group1") + == _round_trip_serialization(foo1).which_one_of("group1") == ("bar", 0) ) assert bytes(foo2) == b"\x12\x00" # Baz is just an empty string assert ( - betterproto.which_one_of(foo2, "group1") - == betterproto.which_one_of(_round_trip_serialization(foo2), "group1") + foo2.which_one_of("group1") + == _round_trip_serialization(foo2).which_one_of("group1") == ("baz", "") ) assert bytes(foo3) == b"\x1a\x00" assert ( - betterproto.which_one_of(foo3, "group1") - == betterproto.which_one_of(_round_trip_serialization(foo3), "group1") + foo3.which_one_of("group1") + == _round_trip_serialization(foo3).which_one_of("group1") == ("qux", Empty()) ) assert bytes(foo4) == b"" assert ( - betterproto.which_one_of(foo4, "group1") - == betterproto.which_one_of(_round_trip_serialization(foo4), "group1") + foo4.which_one_of("group1") + == _round_trip_serialization(foo4).which_one_of("group1") == ("", None) )