From d2f33ae795f1ac1df00acbf9a8b2e1d092a205bd Mon Sep 17 00:00:00 2001 From: James Hilton-Balfe <50501825+Gobot1234@users.noreply.github.com> Date: Tue, 18 Jan 2022 19:01:10 +0000 Subject: [PATCH 1/8] Rename fields if they intersect with betterproto.Message --- src/betterproto/plugin/models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 840140043..d2d742ccd 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -501,7 +501,11 @@ def packed(self) -> bool: @property def py_name(self) -> str: """Pythonized name.""" - return pythonize_field_name(self.proto_name) + unsafe_name = pythonize_field_name(self.proto_name) + # rename fields in case they clash with things defined in Message + if unsafe_name in dir(betterproto.Message): + return f"{unsafe_name}_" + return unsafe_name @property def proto_name(self) -> str: From dc292168e7466874d5935b0a2b218a48153f014a Mon Sep 17 00:00:00 2001 From: James Hilton-Balfe <50501825+Gobot1234@users.noreply.github.com> Date: Tue, 18 Jan 2022 19:03:22 +0000 Subject: [PATCH 2/8] I meant to stage all of those --- src/betterproto/__init__.py | 62 +++++++++++++----------------- tests/inputs/rename/rename.proto | 7 ++++ tests/inputs/rename/test_rename.py | 12 ++++++ 3 files changed, 45 insertions(+), 36 deletions(-) create mode 100644 tests/inputs/rename/rename.proto create mode 100644 tests/inputs/rename/test_rename.py diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 5e514d5c1..a30a89f8f 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -609,7 +609,12 @@ class Message(ABC): Calls :meth:`__bool__`. """ - _serialized_on_wire: bool + serialized_on_wire: bool + """ + If this message was or should be serialized on the wire. This can be used to detect + presence (e.g. optional wrapper message) and is used internally during + parsing/serialization. + """ _unknown_fields: bytes _group_current: Dict[str, str] @@ -634,7 +639,7 @@ def __post_init__(self) -> None: group_current[meta.group] = field_name # Now that all the defaults are set, reset it! - self.__dict__["_serialized_on_wire"] = not all_sentinel + self.__dict__["serialized_on_wire"] = not all_sentinel self.__dict__["_unknown_fields"] = b"" self.__dict__["_group_current"] = group_current @@ -694,9 +699,9 @@ def __getattribute__(self, name: str) -> Any: return value def __setattr__(self, attr: str, value: Any) -> None: - if attr != "_serialized_on_wire": + if attr != "serialized_on_wire": # Track when a field has been set. - self.__dict__["_serialized_on_wire"] = True + self.__dict__["serialized_on_wire"] = True if hasattr(self, "_group_current"): # __post_init__ had already run if attr in self._betterproto.oneof_group_by_field: @@ -756,7 +761,7 @@ def __bytes__(self) -> bytes: # Empty messages can still be sent on the wire if they were # set (or received empty). - serialize_empty = isinstance(value, Message) and value._serialized_on_wire + serialize_empty = isinstance(value, Message) and value.serialized_on_wire include_default_value_for_oneof = self._include_default_value_for_oneof( field_name=field_name, meta=meta @@ -924,7 +929,7 @@ def _postprocess_single( value = _get_wrapper(meta.wraps)().parse(value).value else: value = cls().parse(value) - value._serialized_on_wire = True + value.serialized_on_wire = True elif meta.proto_type == TYPE_MAP: value = self._betterproto.cls_by_field[field_name]().parse(value) @@ -953,7 +958,7 @@ def parse(self: T, data: bytes) -> T: The initialized message. """ # Got some data over the wire - self._serialized_on_wire = True + self.serialized_on_wire = True proto_meta = self._betterproto for parsed in parse_fields(data): field_name = proto_meta.field_name_by_number.get(parsed.number) @@ -1089,7 +1094,7 @@ def to_dict( if include_default_values: output[cased_name] = value elif ( - value._serialized_on_wire + value.serialized_on_wire or include_default_values or self._include_default_value_for_oneof( field_name=field_name, meta=meta @@ -1170,7 +1175,7 @@ def from_dict(self: T, value: Dict[str, Any]) -> T: :class:`Message` The initialized message. """ - self._serialized_on_wire = True + self.serialized_on_wire = True for key in value: field_name = safe_snake_case(key) meta = self._betterproto.meta_by_field_name.get(field_name) @@ -1280,34 +1285,19 @@ def from_json(self: T, value: Union[str, bytes]) -> T: """ return self.from_dict(json.loads(value)) + def which_one_of(self, group_name: str) -> Tuple[str, Optional[Any]]: + """ + Return the name and value of a message's one-of field group. -def serialized_on_wire(message: Message) -> bool: - """ - If this message was or should be serialized on the wire. This can be used to detect - presence (e.g. optional wrapper message) and is used internally during - parsing/serialization. - - Returns - -------- - :class:`bool` - Whether this message was or should be serialized on the wire. - """ - return message._serialized_on_wire - - -def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]: - """ - Return the name and value of a message's one-of field group. - - Returns - -------- - Tuple[:class:`str`, Any] - The field name and the value for that field. - """ - field_name = message._group_current.get(group_name) - if not field_name: - return "", None - return field_name, getattr(message, field_name) + Returns + -------- + Tuple[:class:`str`, Any] + The field name and the value for that field. + """ + field_name = self._group_current.get(group_name) + if not field_name: + return "", None + return field_name, getattr(self, field_name) # Circular import workaround: google.protobuf depends on base classes defined above. diff --git a/tests/inputs/rename/rename.proto b/tests/inputs/rename/rename.proto new file mode 100644 index 000000000..9e02e6320 --- /dev/null +++ b/tests/inputs/rename/rename.proto @@ -0,0 +1,7 @@ +// The fields that have overlapping names with betterproto.Message will be renamed. +message Renamed { + optional bool parse = 1; + optional bool serialized_on_wire = 2; + optional bool from_json = 3; + optional int32 this = 4; +} diff --git a/tests/inputs/rename/test_rename.py b/tests/inputs/rename/test_rename.py new file mode 100644 index 000000000..7c117e162 --- /dev/null +++ b/tests/inputs/rename/test_rename.py @@ -0,0 +1,12 @@ +from dataclasses import fields + +from tests.output_betterproto.rename import Renamed + + +def test_renamed_fields(): + assert {field.name for field in fields(Renamed)} == { + "parse_", + "serialized_on_wire_", + "from_json_", + "this", + } From bbb2650eeae71b1b0973a356630d06b0ea2414d8 Mon Sep 17 00:00:00 2001 From: James Hilton-Balfe <50501825+Gobot1234@users.noreply.github.com> Date: Tue, 18 Jan 2022 23:24:11 +0000 Subject: [PATCH 3/8] Fix which_one_of being removed --- src/betterproto/plugin/models.py | 3 +- tests/inputs/oneof/test_oneof.py | 5 ++-- .../test_oneof_default_value_serialization.py | 15 +++++----- tests/inputs/oneof_enum/test_oneof_enum.py | 7 ++--- tests/test_features.py | 28 +++++++++---------- 5 files changed, 27 insertions(+), 31 deletions(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index d2d742ccd..5c7d1c352 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -32,7 +32,6 @@ import builtins import betterproto -from betterproto import which_one_of from betterproto.casing import sanitize_name from betterproto.compile.importing import ( get_type_reference, @@ -355,7 +354,7 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool: us to tell whether it was set, via the which_one_of interface. """ - return which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index" + return proto_field_obj.which_one_of("oneof_index")[0] == "oneof_index" @dataclass diff --git a/tests/inputs/oneof/test_oneof.py b/tests/inputs/oneof/test_oneof.py index d1267659f..c7f7a6a3f 100644 --- a/tests/inputs/oneof/test_oneof.py +++ b/tests/inputs/oneof/test_oneof.py @@ -1,4 +1,3 @@ -import betterproto from tests.output_betterproto.oneof import Test from tests.util import get_test_case_json_data @@ -6,10 +5,10 @@ def test_which_count(): message = Test() message.from_json(get_test_case_json_data("oneof")[0].json) - assert betterproto.which_one_of(message, "foo") == ("pitied", 100) + assert message.which_one_of("foo") == ("pitied", 100) def test_which_name(): message = Test() message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json) - assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T") + assert message.which_one_of("foo") == ("pitier", "Mr. T") diff --git a/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py b/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py index 0c928cb89..f010cb01e 100644 --- a/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py +++ b/tests/inputs/oneof_default_value_serialization/test_oneof_default_value_serialization.py @@ -1,7 +1,6 @@ import pytest import datetime -import betterproto from tests.output_betterproto.oneof_default_value_serialization import ( Test, Message, @@ -10,9 +9,9 @@ def assert_round_trip_serialization_works(message: Test) -> None: - assert betterproto.which_one_of(message, "value_type") == betterproto.which_one_of( - Test().from_json(message.to_json()), "value_type" - ) + assert message.which_one_of("value_type") == Test().from_json( + message.to_json() + ).which_one_of("value_type") def test_oneof_default_value_serialization_works_for_all_values(): @@ -49,8 +48,8 @@ def test_oneof_default_value_serialization_works_for_all_values(): def test_oneof_no_default_values_passed(): message = Test() assert ( - betterproto.which_one_of(message, "value_type") - == betterproto.which_one_of(Test().from_json(message.to_json()), "value_type") + message.which_one_of("value_type") + == Test().from_json(message.to_json()).which_one_of("value_type") == ("", None) ) @@ -65,8 +64,8 @@ def test_oneof_nested_oneof_messages_are_serialized_with_defaults(): ) ) assert ( - betterproto.which_one_of(message, "value_type") - == betterproto.which_one_of(Test().from_json(message.to_json()), "value_type") + message.which_one_of("value_type") + == Test().from_json(message.to_json()).which_one_of("value_type") == ( "wrapped_nested_message_value", NestedMessage(id=0, wrapped_message_value=Message(value=0)), diff --git a/tests/inputs/oneof_enum/test_oneof_enum.py b/tests/inputs/oneof_enum/test_oneof_enum.py index 7e287d4a4..b71c3d377 100644 --- a/tests/inputs/oneof_enum/test_oneof_enum.py +++ b/tests/inputs/oneof_enum/test_oneof_enum.py @@ -1,6 +1,5 @@ import pytest -import betterproto from tests.output_betterproto.oneof_enum import ( Move, Signal, @@ -22,7 +21,7 @@ def test_which_one_of_returns_enum_with_default_value(): x=0, y=0 ) # Proto3 will default this as there is no null assert message.signal == Signal.PASS - assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS) + assert message.which_one_of("action") == ("signal", Signal.PASS) def test_which_one_of_returns_enum_with_non_default_value(): @@ -37,7 +36,7 @@ def test_which_one_of_returns_enum_with_non_default_value(): x=0, y=0 ) # Proto3 will default this as there is no null assert message.signal == Signal.RESIGN - assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN) + assert message.which_one_of("action") == ("signal", Signal.RESIGN) def test_which_one_of_returns_second_field_when_set(): @@ -45,4 +44,4 @@ def test_which_one_of_returns_second_field_when_set(): message.from_json(get_test_case_json_data("oneof_enum")[0].json) assert message.move == Move(x=2, y=3) assert message.signal == Signal.PASS - assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) + assert message.which_one_of("action") == ("move", Move(x=2, y=3)) diff --git a/tests/test_features.py b/tests/test_features.py index b82528eea..5d349a98c 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -123,14 +123,14 @@ class Foo(betterproto.Message): foo = Foo() - assert betterproto.which_one_of(foo, "group1")[0] == "" + assert foo.which_one_of("group1")[0] == "" foo.bar = 1 foo.baz = "test" # Other oneof fields should now be unset assert foo.bar == 0 - assert betterproto.which_one_of(foo, "group1")[0] == "baz" + assert foo.which_one_of("group1")[0] == "baz" foo.sub.val = 1 assert betterproto.serialized_on_wire(foo.sub) @@ -140,18 +140,18 @@ class Foo(betterproto.Message): # Group 1 shouldn't be touched, group 2 should have reset assert foo.sub.val == 0 assert betterproto.serialized_on_wire(foo.sub) is False - assert betterproto.which_one_of(foo, "group2")[0] == "abc" + assert foo.which_one_of("group2")[0] == "abc" # Zero value should always serialize for one-of foo = Foo(bar=0) - assert betterproto.which_one_of(foo, "group1")[0] == "bar" + assert foo.which_one_of("group1")[0] == "bar" assert bytes(foo) == b"\x08\x00" # Round trip should also work foo2 = Foo().parse(bytes(foo)) - assert betterproto.which_one_of(foo2, "group1")[0] == "bar" + assert foo2.which_one_of("group1")[0] == "bar" assert foo.bar == 0 - assert betterproto.which_one_of(foo2, "group2")[0] == "" + assert foo2.which_one_of("group2")[0] == "" def test_json_casing(): @@ -307,29 +307,29 @@ def _round_trip_serialization(foo: Foo) -> Foo: assert bytes(foo1) == b"\x08\x00" assert ( - betterproto.which_one_of(foo1, "group1") - == betterproto.which_one_of(_round_trip_serialization(foo1), "group1") + foo1.which_one_of("group1") + == _round_trip_serialization(foo1).which_one_of("group1") == ("bar", 0) ) assert bytes(foo2) == b"\x12\x00" # Baz is just an empty string assert ( - betterproto.which_one_of(foo2, "group1") - == betterproto.which_one_of(_round_trip_serialization(foo2), "group1") + foo2.which_one_of("group1") + == _round_trip_serialization(foo2).which_one_of("group1") == ("baz", "") ) assert bytes(foo3) == b"\x1a\x00" assert ( - betterproto.which_one_of(foo3, "group1") - == betterproto.which_one_of(_round_trip_serialization(foo3), "group1") + foo3.which_one_of("group1") + == _round_trip_serialization(foo3).which_one_of("group1") == ("qux", Empty()) ) assert bytes(foo4) == b"" assert ( - betterproto.which_one_of(foo4, "group1") - == betterproto.which_one_of(_round_trip_serialization(foo4), "group1") + foo4.which_one_of("group1") + == _round_trip_serialization(foo4).which_one_of("group1") == ("", None) ) From 135362c31f6d1eff68915288160c6f192cd593b1 Mon Sep 17 00:00:00 2001 From: James Hilton-Balfe <50501825+Gobot1234@users.noreply.github.com> Date: Tue, 18 Jan 2022 23:29:43 +0000 Subject: [PATCH 4/8] Fix clashes in __annotations__ --- src/betterproto/plugin/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 5c7d1c352..f3433fc00 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -113,6 +113,7 @@ FieldDescriptorProtoType.TYPE_SINT32, # 17 FieldDescriptorProtoType.TYPE_SINT64, # 18 ) +UNSAFE_FIELD_NAMES = frozenset(dir(betterproto.Message)) | frozenset(betterproto.Message.__annotations__) def monkey_patch_oneof_index(): @@ -502,7 +503,7 @@ def py_name(self) -> str: """Pythonized name.""" unsafe_name = pythonize_field_name(self.proto_name) # rename fields in case they clash with things defined in Message - if unsafe_name in dir(betterproto.Message): + if unsafe_name in UNSAFE_FIELD_NAMES: return f"{unsafe_name}_" return unsafe_name From 4b8acda2fdfbff054f504c2e6a0d2b718e60809f Mon Sep 17 00:00:00 2001 From: James Hilton-Balfe <50501825+Gobot1234@users.noreply.github.com> Date: Tue, 18 Jan 2022 23:35:16 +0000 Subject: [PATCH 5/8] Fix more CI --- src/betterproto/__init__.py | 1 + tests/inputs/rename/rename.proto | 2 +- tests/inputs/rename/test_rename.py | 4 ++-- tests/test_features.py | 22 +++++++++++----------- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index a30a89f8f..d8880c02d 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -1300,6 +1300,7 @@ def which_one_of(self, group_name: str) -> Tuple[str, Optional[Any]]: return field_name, getattr(self, field_name) +serialized_on_wire = None # Circular import workaround: google.protobuf depends on base classes defined above. from .lib.google.protobuf import ( # noqa BoolValue, diff --git a/tests/inputs/rename/rename.proto b/tests/inputs/rename/rename.proto index 9e02e6320..cc70898bb 100644 --- a/tests/inputs/rename/rename.proto +++ b/tests/inputs/rename/rename.proto @@ -1,5 +1,5 @@ // The fields that have overlapping names with betterproto.Message will be renamed. -message Renamed { +message Test { optional bool parse = 1; optional bool serialized_on_wire = 2; optional bool from_json = 3; diff --git a/tests/inputs/rename/test_rename.py b/tests/inputs/rename/test_rename.py index 7c117e162..5ee5b8ba6 100644 --- a/tests/inputs/rename/test_rename.py +++ b/tests/inputs/rename/test_rename.py @@ -1,10 +1,10 @@ from dataclasses import fields -from tests.output_betterproto.rename import Renamed +from tests.output_betterproto.rename import Test def test_renamed_fields(): - assert {field.name for field in fields(Renamed)} == { + assert {field.name for field in fields(Test)} == { "parse_", "serialized_on_wire_", "from_json_", diff --git a/tests/test_features.py b/tests/test_features.py index 5d349a98c..d44f32e1c 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -16,23 +16,23 @@ class Foo(betterproto.Message): # Unset by default foo = Foo() - assert betterproto.serialized_on_wire(foo.bar) is False + assert foo.bar.serialized_on_wire is False # Serialized after setting something foo.bar.baz = 1 - assert betterproto.serialized_on_wire(foo.bar) is True + assert foo.bar.serialized_on_wire is True # Still has it after setting the default value foo.bar.baz = 0 - assert betterproto.serialized_on_wire(foo.bar) is True + assert foo.bar.serialized_on_wire is True # Manual override (don't do this) - foo.bar._serialized_on_wire = False - assert betterproto.serialized_on_wire(foo.bar) is False + foo.bar.serialized_on_wire = False + assert foo.bar.serialized_on_wire is False # Can manually set it but defaults to false foo.bar = Bar() - assert betterproto.serialized_on_wire(foo.bar) is False + assert foo.bar.serialized_on_wire is False @dataclass class WithCollections(betterproto.Message): @@ -43,15 +43,15 @@ class WithCollections(betterproto.Message): # Is always set from parse, even if all collections are empty with_collections_empty = WithCollections().parse(bytes(WithCollections())) - assert betterproto.serialized_on_wire(with_collections_empty) == True + assert with_collections_empty.serialized_on_wire == True with_collections_list = WithCollections().parse( bytes(WithCollections(test_list=["a", "b", "c"])) ) - assert betterproto.serialized_on_wire(with_collections_list) == True + assert with_collections_list.serialized_on_wire == True with_collections_map = WithCollections().parse( bytes(WithCollections(test_map={"a": "b", "c": "d"})) ) - assert betterproto.serialized_on_wire(with_collections_map) == True + assert with_collections_map.serialized_on_wire == True def test_class_init(): @@ -133,13 +133,13 @@ class Foo(betterproto.Message): assert foo.which_one_of("group1")[0] == "baz" foo.sub.val = 1 - assert betterproto.serialized_on_wire(foo.sub) + assert foo.sub.serialized_on_wire foo.abc = "test" # Group 1 shouldn't be touched, group 2 should have reset assert foo.sub.val == 0 - assert betterproto.serialized_on_wire(foo.sub) is False + assert foo.sub.serialized_on_wire is False assert foo.which_one_of("group2")[0] == "abc" # Zero value should always serialize for one-of From 7525f4086bd0a2573d34b66d160ca6536e3a89f2 Mon Sep 17 00:00:00 2001 From: James Hilton-Balfe <50501825+Gobot1234@users.noreply.github.com> Date: Tue, 18 Jan 2022 23:36:08 +0000 Subject: [PATCH 6/8] Remove value --- src/betterproto/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index d8880c02d..a30a89f8f 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -1300,7 +1300,6 @@ def which_one_of(self, group_name: str) -> Tuple[str, Optional[Any]]: return field_name, getattr(self, field_name) -serialized_on_wire = None # Circular import workaround: google.protobuf depends on base classes defined above. from .lib.google.protobuf import ( # noqa BoolValue, From 959ca80a32e3ccd55bb0ada5a448ec21aecaed19 Mon Sep 17 00:00:00 2001 From: James Hilton-Balfe <50501825+Gobot1234@users.noreply.github.com> Date: Tue, 18 Jan 2022 23:39:58 +0000 Subject: [PATCH 7/8] Run black --- src/betterproto/plugin/models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index f3433fc00..8b7324abf 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -113,7 +113,9 @@ FieldDescriptorProtoType.TYPE_SINT32, # 17 FieldDescriptorProtoType.TYPE_SINT64, # 18 ) -UNSAFE_FIELD_NAMES = frozenset(dir(betterproto.Message)) | frozenset(betterproto.Message.__annotations__) +UNSAFE_FIELD_NAMES = frozenset(dir(betterproto.Message)) | frozenset( + betterproto.Message.__annotations__ +) def monkey_patch_oneof_index(): From 698cfd410cc1e6485a08d13e7b950d3a9541df22 Mon Sep 17 00:00:00 2001 From: James Hilton-Balfe <50501825+Gobot1234@users.noreply.github.com> Date: Tue, 18 Jan 2022 23:45:13 +0000 Subject: [PATCH 8/8] Make sure to use proto3 --- tests/inputs/rename/rename.proto | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/inputs/rename/rename.proto b/tests/inputs/rename/rename.proto index cc70898bb..c796e442f 100644 --- a/tests/inputs/rename/rename.proto +++ b/tests/inputs/rename/rename.proto @@ -1,7 +1,9 @@ +syntax = "proto3"; + // The fields that have overlapping names with betterproto.Message will be renamed. message Test { - optional bool parse = 1; - optional bool serialized_on_wire = 2; - optional bool from_json = 3; - optional int32 this = 4; + bool parse = 1; + bool serialized_on_wire = 2; + bool from_json = 3; + int32 this = 4; }