Skip to content

Commit 4c761c4

Browse files
committed
pyln-client/gossmap: more fixes, make mypy happier.
Mainly fixing type annotations, but some real fixes: 1. GossmapHalfchannel.from_str() should be a classmethod. 2. update_channel had weird, unusable default values (fields can't be NULL, since we use it below). Signed-off-by: Rusty Russell <[email protected]>
1 parent ef38730 commit 4c761c4

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

contrib/pyln-client/pyln/client/gossmap.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
node_announcement,
66
gossip_store_channel_amount)
77
from pyln.proto import ShortChannelId, PublicKey
8-
from typing import Any, Dict, List, Optional
8+
from typing import Any, Dict, List, Optional, Union, cast
99

1010
import io
1111
import struct
@@ -33,7 +33,7 @@ def __init__(self, buf: bytes):
3333

3434
class GossmapHalfchannel(object):
3535
"""One direction of a GossmapChannel."""
36-
def __init__(self, channel: GossmapChannel, direction: int,
36+
def __init__(self, channel: 'GossmapChannel', direction: int,
3737
timestamp: int, cltv_expiry_delta: int,
3838
htlc_minimum_msat: int, htlc_maximum_msat: int,
3939
fee_base_msat: int, fee_proportional_millionths: int):
@@ -71,12 +71,13 @@ def __hash__(self):
7171
def __repr__(self):
7272
return "GossmapNodeId[{}]".format(self.nodeid.hex())
7373

74-
def from_str(self, s: str):
74+
@classmethod
75+
def from_str(cls, s: str):
7576
if s.startswith('0x'):
7677
s = s[2:]
7778
if len(s) != 67:
7879
raise ValueError(f"{s} is not a valid hexstring of a node_id")
79-
return GossmapNodeId(bytes.fromhex(s))
80+
return cls(bytes.fromhex(s))
8081

8182

8283
class GossmapChannel(object):
@@ -97,14 +98,14 @@ def __init__(self,
9798
self.updates_fields: List[Optional[Dict[str, Any]]] = [None, None]
9899
self.updates_offset: List[Optional[int]] = [None, None]
99100
self.satoshis = None
100-
self.half_channels: List[GossmapHalfchannel] = [None, None]
101+
self.half_channels: List[Optional[GossmapHalfchannel]] = [None, None]
101102

102103
def update_channel(self,
103104
direction: int,
104-
fields: List[Optional[Dict[str, Any]]] = [None, None],
105-
off: List[Optional[int]] = [None, None]):
105+
fields: Dict[str, Any],
106+
off: int):
106107
self.updates_fields[direction] = fields
107-
self.updates_offset = off
108+
self.updates_offset[direction] = off
108109

109110
half = GossmapHalfchannel(self, direction,
110111
fields['timestamp'],
@@ -132,8 +133,8 @@ class GossmapNode(object):
132133
"""
133134
def __init__(self, node_id: GossmapNodeId):
134135
self.announce_fields: Optional[Dict[str, Any]] = None
135-
self.announce_offset = None
136-
self.channels = []
136+
self.announce_offset: Optional[int] = None
137+
self.channels: List[GossmapChannel] = []
137138
self.node_id = node_id
138139

139140
def __repr__(self):
@@ -148,10 +149,10 @@ def __init__(self, store_filename: str = "gossip_store"):
148149
self.store_buf = bytes()
149150
self.nodes: Dict[GossmapNodeId, GossmapNode] = {}
150151
self.channels: Dict[ShortChannelId, GossmapChannel] = {}
151-
self._last_scid: str = None
152+
self._last_scid: Optional[str] = None
152153
version = self.store_file.read(1)
153154
if version[0] != GOSSIP_STORE_VERSION:
154-
raise ValueError("Invalid gossip store version {}".format(version))
155+
raise ValueError("Invalid gossip store version {}".format(int(version)))
155156
self.bytes_read = 1
156157
self.refresh()
157158

@@ -205,11 +206,11 @@ def get_channel(self, short_channel_id: ShortChannelId):
205206
short_channel_id = ShortChannelId.from_str(short_channel_id)
206207
return self.channels.get(short_channel_id)
207208

208-
def get_node(self, node_id: GossmapNodeId):
209+
def get_node(self, node_id: Union[GossmapNodeId, str]):
209210
""" Resolves a node by its public key node_id """
210-
if type(node_id) == str:
211+
if isinstance(node_id, str):
211212
node_id = GossmapNodeId.from_str(node_id)
212-
return self.nodes.get(node_id)
213+
return self.nodes.get(cast(GossmapNodeId, node_id))
213214

214215
def update_channel(self, rec: bytes, off: int):
215216
fields = channel_update.read(io.BytesIO(rec[2:]), {})

contrib/pyln-proto/pyln/proto/message/message.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str
310310
f.fieldtype.write(io_out, val, otherfields)
311311

312312
def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]:
313-
vals = {}
313+
vals: Dict[str, Any] = {}
314314
for field in self.fields:
315315
val = field.fieldtype.read(io_in, vals)
316316
if val is None:

0 commit comments

Comments
 (0)