Skip to content

Commit 97dfcb9

Browse files
committed
feat(ffi): add OwnedString class
A number of FFI functions return strings that are owned by the library and must be freed manually. Unfortunately, returning a `str` from a function would result in a memory leak, as the Python runtime would not know to free the string. This commit adds an `OwnedString` class that wraps a `str` and a function that frees the string. The `__del__` method of the class calls the free function, ensuring that the string is freed when the object is garbage collected. Signed-off-by: JP-Ellis <[email protected]>
1 parent db6374a commit 97dfcb9

File tree

2 files changed

+104
-24
lines changed

2 files changed

+104
-24
lines changed

pact/v3/ffi.py

Lines changed: 88 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
if TYPE_CHECKING:
9393
import cffi
9494
from pathlib import Path
95+
from typing_extensions import Self
9596

9697
# The follow types are classes defined in the Rust code. Ultimately, a Python
9798
# alternative should be implemented, but for now, the follow lines only serve
@@ -613,6 +614,75 @@ def raise_exception(self) -> None:
613614
raise RuntimeError(self.text)
614615

615616

617+
class OwnedString(str):
618+
"""
619+
A string that owns its own memory.
620+
621+
This is used to ensure that the memory is freed when the string is
622+
destroyed.
623+
624+
As this is subclassed from `str`, it can be used in place of a normal string
625+
in most cases.
626+
"""
627+
628+
def __new__(cls, ptr: cffi.FFI.CData) -> Self:
629+
"""
630+
Create a new Owned String.
631+
632+
As this is a subclass of the immutable type `str`, we need to override
633+
the `__new__` method to ensure that the string is initialised correctly.
634+
"""
635+
s = ffi.string(ptr)
636+
return super().__new__(cls, s if isinstance(s, str) else s.decode("utf-8"))
637+
638+
def __init__(self, ptr: cffi.FFI.CData) -> None:
639+
"""
640+
Initialise a new Owned String.
641+
642+
Args:
643+
ptr:
644+
CFFI data structure.
645+
"""
646+
self._ptr = ptr
647+
s = ffi.string(ptr)
648+
self._string = s if isinstance(s, str) else s.decode("utf-8")
649+
650+
def __str__(self) -> str:
651+
"""
652+
String representation of the Owned String.
653+
"""
654+
return self._string
655+
656+
def __repr__(self) -> str:
657+
"""
658+
Debugging string representation of the Owned String.
659+
"""
660+
return f"<OwnedString: {self._string!r}, ptr={self._ptr!r}>"
661+
662+
def __del__(self) -> None:
663+
"""
664+
Destructor for the Owned String.
665+
"""
666+
string_delete(self)
667+
668+
def __eq__(self, other: object) -> bool:
669+
"""
670+
Equality comparison.
671+
672+
Args:
673+
other:
674+
The object to compare to.
675+
676+
Returns:
677+
Whether the two objects are equal.
678+
"""
679+
if isinstance(other, OwnedString):
680+
return self._ptr == other._ptr
681+
if isinstance(other, str):
682+
return self._string == other
683+
return super().__eq__(other)
684+
685+
616686
def version() -> str:
617687
"""
618688
Return the version of the pact_ffi library.
@@ -3000,7 +3070,7 @@ def message_delete(message: Message) -> None:
30003070
raise NotImplementedError
30013071

30023072

3003-
def message_get_contents(message: Message) -> str:
3073+
def message_get_contents(message: Message) -> OwnedString | None:
30043074
"""
30053075
Get the contents of a `Message` in string form.
30063076
@@ -3112,7 +3182,7 @@ def message_set_contents_bin(
31123182
raise NotImplementedError
31133183

31143184

3115-
def message_get_description(message: Message) -> str:
3185+
def message_get_description(message: Message) -> OwnedString:
31163186
r"""
31173187
Get a copy of the description.
31183188
@@ -4196,20 +4266,14 @@ def sync_message_get_provider_state_iter(
41964266
raise NotImplementedError
41974267

41984268

4199-
def string_delete(string: str) -> None:
4269+
def string_delete(string: OwnedString) -> None:
42004270
"""
42014271
Delete a string previously returned by this FFI.
42024272
42034273
[Rust
42044274
`pactffi_string_delete`](https://docs.rs/pact_ffi/0.4.9/pact_ffi/?search=pactffi_string_delete)
4205-
4206-
It is explicitly allowed to pass a null pointer to this function; in that
4207-
case the function will do nothing.
4208-
4209-
# Safety Passing an invalid pointer, or one that was not returned by a FFI
4210-
function can result in undefined behaviour.
42114275
"""
4212-
raise NotImplementedError
4276+
lib.pactffi_string_delete(string._ptr)
42134277

42144278

42154279
def create_mock_server(pact_str: str, addr_str: str, *, tls: bool) -> int:
@@ -4253,7 +4317,7 @@ def create_mock_server(pact_str: str, addr_str: str, *, tls: bool) -> int:
42534317
raise NotImplementedError
42544318

42554319

4256-
def get_tls_ca_certificate() -> str:
4320+
def get_tls_ca_certificate() -> OwnedString:
42574321
"""
42584322
Fetch the CA Certificate used to generate the self-signed certificate.
42594323
@@ -4267,7 +4331,7 @@ def get_tls_ca_certificate() -> str:
42674331
42684332
An empty string indicates an error reading the pem file.
42694333
"""
4270-
raise NotImplementedError
4334+
return OwnedString(lib.pactffi_get_tls_ca_certificate())
42714335

42724336

42734337
def create_mock_server_for_pact(pact: PactHandle, addr_str: str, *, tls: bool) -> int:
@@ -5624,7 +5688,7 @@ def message_with_metadata(message_handle: MessageHandle, key: str, value: str) -
56245688
raise NotImplementedError
56255689

56265690

5627-
def message_reify(message_handle: MessageHandle) -> str:
5691+
def message_reify(message_handle: MessageHandle) -> OwnedString:
56285692
"""
56295693
Reifies the given message.
56305694
@@ -6320,7 +6384,7 @@ def verifier_cli_args() -> str:
63206384
raise NotImplementedError
63216385

63226386

6323-
def verifier_logs(handle: VerifierHandle) -> str:
6387+
def verifier_logs(handle: VerifierHandle) -> OwnedString:
63246388
"""
63256389
Extracts the logs for the verification run.
63266390
@@ -6337,7 +6401,7 @@ def verifier_logs(handle: VerifierHandle) -> str:
63376401
raise NotImplementedError
63386402

63396403

6340-
def verifier_logs_for_provider(provider_name: str) -> str:
6404+
def verifier_logs_for_provider(provider_name: str) -> OwnedString:
63416405
"""
63426406
Extracts the logs for the verification run for the provider name.
63436407
@@ -6354,7 +6418,7 @@ def verifier_logs_for_provider(provider_name: str) -> str:
63546418
raise NotImplementedError
63556419

63566420

6357-
def verifier_output(handle: VerifierHandle, strip_ansi: int) -> str:
6421+
def verifier_output(handle: VerifierHandle, strip_ansi: int) -> OwnedString:
63586422
"""
63596423
Extracts the standard output for the verification run.
63606424
@@ -6373,7 +6437,7 @@ def verifier_output(handle: VerifierHandle, strip_ansi: int) -> str:
63736437
raise NotImplementedError
63746438

63756439

6376-
def verifier_json(handle: VerifierHandle) -> str:
6440+
def verifier_json(handle: VerifierHandle) -> OwnedString:
63776441
"""
63786442
Extracts the verification result as a JSON document.
63796443
@@ -6498,7 +6562,7 @@ def matches_string_value(
64986562
expected_value: str,
64996563
actual_value: str,
65006564
cascaded: int,
6501-
) -> str:
6565+
) -> OwnedString:
65026566
"""
65036567
Determines if the string value matches the given matching rule.
65046568
@@ -6529,7 +6593,7 @@ def matches_u64_value(
65296593
expected_value: int,
65306594
actual_value: int,
65316595
cascaded: int,
6532-
) -> str:
6596+
) -> OwnedString:
65336597
"""
65346598
Determines if the unsigned integer value matches the given matching rule.
65356599
@@ -6559,7 +6623,7 @@ def matches_i64_value(
65596623
expected_value: int,
65606624
actual_value: int,
65616625
cascaded: int,
6562-
) -> str:
6626+
) -> OwnedString:
65636627
"""
65646628
Determines if the signed integer value matches the given matching rule.
65656629
@@ -6589,7 +6653,7 @@ def matches_f64_value(
65896653
expected_value: float,
65906654
actual_value: float,
65916655
cascaded: int,
6592-
) -> str:
6656+
) -> OwnedString:
65936657
"""
65946658
Determines if the floating point value matches the given matching rule.
65956659
@@ -6619,7 +6683,7 @@ def matches_bool_value(
66196683
expected_value: int,
66206684
actual_value: int,
66216685
cascaded: int,
6622-
) -> str:
6686+
) -> OwnedString:
66236687
"""
66246688
Determines if the boolean value matches the given matching rule.
66256689
@@ -6651,7 +6715,7 @@ def matches_binary_value( # noqa: PLR0913
66516715
actual_value: str,
66526716
actual_value_len: int,
66536717
cascaded: int,
6654-
) -> str:
6718+
) -> OwnedString:
66556719
"""
66566720
Determines if the binary value matches the given matching rule.
66576721
@@ -6686,7 +6750,7 @@ def matches_json_value(
66866750
expected_value: str,
66876751
actual_value: str,
66886752
cascaded: int,
6689-
) -> str:
6753+
) -> OwnedString:
66906754
"""
66916755
Determines if the JSON value matches the given matching rule.
66926756

tests/v3/test_ffi.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,19 @@ def test_get_error_message() -> None:
5151
ret: int = ffi.lib.pactffi_validate_datetime(invalid_utf8, invalid_utf8)
5252
assert ret == 2
5353
assert ffi.get_error_message() == "error parsing value as UTF-8"
54+
55+
56+
def test_owned_string() -> None:
57+
string = ffi.get_tls_ca_certificate()
58+
assert isinstance(string, str)
59+
assert len(string) > 0
60+
assert str(string) == string
61+
assert repr(string).startswith("<OwnedString: ")
62+
assert repr(string).endswith(">")
63+
assert string.startswith("-----BEGIN CERTIFICATE-----")
64+
assert string.endswith(
65+
(
66+
"-----END CERTIFICATE-----\n",
67+
"-----END CERTIFICATE-----\r\n",
68+
)
69+
)

0 commit comments

Comments
 (0)