|
1 | 1 | """ModbusProtocol layer.""" |
2 | | -# mypy: disable-error-code="name-defined,union-attr" |
3 | | -# needed because asyncio.Server is not defined (to mypy) in v3.8.16 |
4 | 2 | from __future__ import annotations |
5 | 3 |
|
6 | 4 | import asyncio |
7 | 5 | import dataclasses |
8 | 6 | import ssl |
| 7 | +from contextlib import suppress |
9 | 8 | from enum import Enum |
10 | 9 | from typing import Any, Callable, Coroutine |
11 | 10 |
|
@@ -124,7 +123,7 @@ def __init__( |
124 | 123 | self.is_server = is_server |
125 | 124 | self.is_closing = False |
126 | 125 |
|
127 | | - self.transport: asyncio.BaseModbusProtocol | asyncio.Server = None |
| 126 | + self.transport: asyncio.BaseTransport = None |
128 | 127 | self.loop: asyncio.AbstractEventLoop = None |
129 | 128 | self.recv_buffer: bytes = b"" |
130 | 129 | self.call_create: Callable[[], Coroutine[Any, Any, Any]] = lambda: None |
@@ -258,7 +257,7 @@ async def transport_listen(self) -> bool: |
258 | 257 | # ---------------------------------- # |
259 | 258 | # ModbusProtocol asyncio standard methods # |
260 | 259 | # ---------------------------------- # |
261 | | - def connection_made(self, transport: asyncio.BaseModbusProtocol): |
| 260 | + def connection_made(self, transport: asyncio.BaseTransport): |
262 | 261 | """Call from asyncio, when a connection is made. |
263 | 262 |
|
264 | 263 | :param transport: socket etc. representing the connection. |
@@ -298,10 +297,23 @@ def datagram_received(self, data: bytes, addr: tuple): |
298 | 297 | self.sent_buffer = b"" |
299 | 298 | if not data: |
300 | 299 | return |
301 | | - Log.debug("recv: {} addr={}", data, ":hex", addr) |
| 300 | + Log.debug( |
| 301 | + "recv: {} old_data: {} addr={}", |
| 302 | + data, |
| 303 | + ":hex", |
| 304 | + self.recv_buffer, |
| 305 | + ":hex", |
| 306 | + addr, |
| 307 | + ) |
302 | 308 | self.recv_buffer += data |
303 | 309 | cut = self.callback_data(self.recv_buffer, addr=addr) |
304 | 310 | self.recv_buffer = self.recv_buffer[cut:] |
| 311 | + if self.recv_buffer: |
| 312 | + Log.debug( |
| 313 | + "recv, unused data waiting for next packet: {}", |
| 314 | + self.recv_buffer, |
| 315 | + ":hex", |
| 316 | + ) |
305 | 317 |
|
306 | 318 | def eof_received(self): |
307 | 319 | """Accept other end terminates connection.""" |
@@ -342,11 +354,11 @@ def transport_send(self, data: bytes, addr: tuple = None) -> None: |
342 | 354 | self.sent_buffer = data |
343 | 355 | if self.comm_params.comm_type == CommType.UDP: |
344 | 356 | if addr: |
345 | | - self.transport.sendto(data, addr=addr) |
| 357 | + self.transport.sendto(data, addr=addr) # type: ignore[attr-defined] |
346 | 358 | else: |
347 | | - self.transport.sendto(data) |
| 359 | + self.transport.sendto(data) # type: ignore[attr-defined] |
348 | 360 | else: |
349 | | - self.transport.write(data) |
| 361 | + self.transport.write(data) # type: ignore[attr-defined] |
350 | 362 |
|
351 | 363 | def transport_close(self, intern: bool = False, reconnect: bool = False) -> None: |
352 | 364 | """Close connection. |
@@ -392,26 +404,11 @@ async def create_nullmodem(self, port): |
392 | 404 | """Bypass create_ and use null modem""" |
393 | 405 | if self.is_server: |
394 | 406 | # Listener object |
395 | | - self.transport = NullModem(self) |
396 | | - NullModem.listener_new_connection[port] = self.handle_new_connection |
| 407 | + self.transport = NullModem.set_listener(port, self) |
397 | 408 | return self.transport, self |
398 | 409 |
|
399 | 410 | # connect object |
400 | | - client_protocol = self.handle_new_connection() |
401 | | - try: |
402 | | - server_protocol = NullModem.listener_new_connection[port]() |
403 | | - except KeyError as exc: |
404 | | - raise asyncio.TimeoutError( |
405 | | - f"No listener on port {self.comm_params.port} for connect" |
406 | | - ) from exc |
407 | | - |
408 | | - client_transport = NullModem(client_protocol) |
409 | | - server_transport = NullModem(server_protocol) |
410 | | - client_transport.other_transport = server_transport |
411 | | - server_transport.other_transport = client_transport |
412 | | - client_protocol.connection_made(client_transport) |
413 | | - server_protocol.connection_made(server_transport) |
414 | | - return client_transport, client_protocol |
| 411 | + return NullModem.set_connection(port, self) |
415 | 412 |
|
416 | 413 | def handle_new_connection(self): |
417 | 414 | """Handle incoming connect.""" |
@@ -468,46 +465,117 @@ class NullModem(asyncio.DatagramTransport, asyncio.Transport): |
468 | 465 | (Allowing tests to be shortcut without actual network calls) |
469 | 466 | """ |
470 | 467 |
|
471 | | - listener_new_connection: dict[int, ModbusProtocol] = {} |
| 468 | + listeners: dict[int, ModbusProtocol] = {} |
| 469 | + connections: dict[NullModem, int] = {} |
472 | 470 |
|
473 | | - def __init__(self, protocol: ModbusProtocol): |
| 471 | + def __init__(self, protocol: ModbusProtocol, listen: int = None) -> None: |
474 | 472 | """Create half part of null modem""" |
475 | 473 | asyncio.DatagramTransport.__init__(self) |
476 | 474 | asyncio.Transport.__init__(self) |
477 | | - self.other: NullModem = None |
478 | | - self.protocol: ModbusProtocol | asyncio.BaseProtocol = protocol |
| 475 | + self.protocol: ModbusProtocol = protocol |
479 | 476 | self.serving: asyncio.Future = asyncio.Future() |
480 | | - self.other_transport: NullModem = None |
| 477 | + self.other_modem: NullModem = None |
| 478 | + self.listen = listen |
| 479 | + self.manipulator: Callable[[bytes], list[bytes]] = None |
| 480 | + self._is_closing = False |
| 481 | + |
| 482 | + # -------------------------- # |
| 483 | + # external nullmodem methods # |
| 484 | + # -------------------------- # |
| 485 | + @classmethod |
| 486 | + def set_listener(cls, port: int, parent: ModbusProtocol) -> NullModem: |
| 487 | + """Register listener.""" |
| 488 | + if port in cls.listeners: |
| 489 | + raise AssertionError(f"Port {port} already listening !") |
| 490 | + cls.listeners[port] = parent |
| 491 | + return NullModem(parent, listen=port) |
| 492 | + |
| 493 | + @classmethod |
| 494 | + def set_connection( |
| 495 | + cls, port: int, parent: ModbusProtocol |
| 496 | + ) -> tuple[NullModem, ModbusProtocol]: |
| 497 | + """Connect to listener.""" |
| 498 | + if port not in cls.listeners: |
| 499 | + raise asyncio.TimeoutError(f"Port {port} not being listened on !") |
| 500 | + |
| 501 | + client_protocol = parent.handle_new_connection() |
| 502 | + server_protocol = NullModem.listeners[port].handle_new_connection() |
| 503 | + client_transport = NullModem(client_protocol) |
| 504 | + server_transport = NullModem(server_protocol) |
| 505 | + cls.connections[client_transport] = port |
| 506 | + cls.connections[server_transport] = -port |
| 507 | + client_transport.other_modem = server_transport |
| 508 | + server_transport.other_modem = client_transport |
| 509 | + client_protocol.connection_made(client_transport) |
| 510 | + server_protocol.connection_made(server_transport) |
| 511 | + return client_transport, client_protocol |
| 512 | + |
| 513 | + def set_manipulator(self, function: Callable[[bytes], list[bytes]]) -> None: |
| 514 | + """Register a manipulator.""" |
| 515 | + self.manipulator = function |
| 516 | + |
| 517 | + @classmethod |
| 518 | + def is_dirty(cls): |
| 519 | + """Check if everything is closed.""" |
| 520 | + dirty = False |
| 521 | + if cls.connections: |
| 522 | + Log.error( |
| 523 | + "NullModem_FATAL missing close on port {} connect()", |
| 524 | + [str(key) for key in cls.connections.values()], |
| 525 | + ) |
| 526 | + dirty = True |
| 527 | + if cls.listeners: |
| 528 | + Log.error( |
| 529 | + "NullModem_FATAL missing close on port {} listen()", |
| 530 | + [str(value) for value in cls.listeners], |
| 531 | + ) |
| 532 | + dirty = True |
| 533 | + return dirty |
481 | 534 |
|
482 | 535 | # ---------------- # |
483 | 536 | # external methods # |
484 | 537 | # ---------------- # |
485 | 538 |
|
486 | | - def close(self): |
| 539 | + def close(self) -> None: |
487 | 540 | """Close null modem""" |
| 541 | + if self._is_closing: |
| 542 | + return |
| 543 | + self._is_closing = True |
488 | 544 | if not self.serving.done(): |
489 | 545 | self.serving.set_result(True) |
490 | | - if self.other_transport: |
491 | | - self.other_transport.other_transport = None |
492 | | - self.other_transport.protocol.connection_lost(None) |
493 | | - self.other_transport = None |
| 546 | + if self.listen: |
| 547 | + del self.listeners[self.listen] |
| 548 | + return |
| 549 | + if self.connections: |
| 550 | + with suppress(KeyError): |
| 551 | + del self.connections[self] |
| 552 | + if self.other_modem: |
| 553 | + self.other_modem.other_modem = None |
| 554 | + self.other_modem.close() |
| 555 | + self.other_modem = None |
| 556 | + if self.protocol: |
494 | 557 | self.protocol.connection_lost(None) |
495 | 558 |
|
496 | | - def sendto(self, data: bytes, _addr: Any = None): |
| 559 | + def sendto(self, data: bytes, _addr: Any = None) -> None: |
497 | 560 | """Send datagrame""" |
498 | | - return self.write(data) |
| 561 | + self.write(data) |
499 | 562 |
|
500 | | - def write(self, data: bytes): |
| 563 | + def write(self, data: bytes) -> None: |
501 | 564 | """Send data""" |
502 | | - self.other_transport.protocol.data_received(data) |
| 565 | + if not self.manipulator: |
| 566 | + self.other_modem.protocol.data_received(data) |
| 567 | + return |
| 568 | + data_manipulated = self.manipulator(data) |
| 569 | + for part in data_manipulated: |
| 570 | + self.other_modem.protocol.data_received(part) |
503 | 571 |
|
504 | | - async def serve_forever(self): |
| 572 | + async def serve_forever(self) -> None: |
505 | 573 | """Serve forever""" |
506 | 574 | await self.serving |
507 | 575 |
|
508 | | - # ---------------- # |
509 | | - # Abstract methods # |
510 | | - # ---------------- # |
| 576 | + # ------------- # |
| 577 | + # Dummy methods # |
| 578 | + # ------------- # |
511 | 579 | def abort(self) -> None: |
512 | 580 | """Abort connection.""" |
513 | 581 | self.close() |
@@ -536,11 +604,10 @@ def get_protocol(self) -> ModbusProtocol | asyncio.BaseProtocol: |
536 | 604 |
|
537 | 605 | def set_protocol(self, protocol: asyncio.BaseProtocol) -> None: |
538 | 606 | """Set current protocol.""" |
539 | | - self.protocol = protocol |
540 | 607 |
|
541 | 608 | def is_closing(self) -> bool: |
542 | 609 | """Return true if closing""" |
543 | | - return False |
| 610 | + return self._is_closing |
544 | 611 |
|
545 | 612 | def is_reading(self) -> bool: |
546 | 613 | """Return true if read is active.""" |
|
0 commit comments