diff options
Diffstat (limited to 'backend/localsocket.py')
| -rw-r--r-- | backend/localsocket.py | 37 |
1 files changed, 22 insertions, 15 deletions
diff --git a/backend/localsocket.py b/backend/localsocket.py index ef0e3933..3659da03 100644 --- a/backend/localsocket.py +++ b/backend/localsocket.py @@ -1,10 +1,13 @@ -import asyncio, time, random +import asyncio, time +from typing import Awaitable, Callable +import random + from localplatform import ON_WINDOWS BUFFER_LIMIT = 2 ** 20 # 1 MiB class UnixSocket: - def __init__(self, on_new_message): + def __init__(self, on_new_message: Callable[[str], Awaitable[str|None]]): ''' on_new_message takes 1 string argument. It's return value gets used, if not None, to write data to the socket. @@ -46,28 +49,32 @@ class UnixSocket: self.reader = None async def read_single_line(self) -> str|None: - reader, writer = await self.get_socket_connection() + reader, _ = await self.get_socket_connection() - if self.reader == None: - return None + try: + assert reader + except AssertionError: + return return await self._read_single_line(reader) async def write_single_line(self, message : str): - reader, writer = await self.get_socket_connection() + _, writer = await self.get_socket_connection() - if self.writer == None: - return; + try: + assert writer + except AssertionError: + return await self._write_single_line(writer, message) - async def _read_single_line(self, reader) -> str: + async def _read_single_line(self, reader: asyncio.StreamReader) -> str: line = bytearray() while True: try: line.extend(await reader.readuntil()) except asyncio.LimitOverrunError: - line.extend(await reader.read(reader._limit)) + line.extend(await reader.read(reader._limit)) # type: ignore continue except asyncio.IncompleteReadError as err: line.extend(err.partial) @@ -77,27 +84,27 @@ class UnixSocket: return line.decode("utf-8") - async def _write_single_line(self, writer, message : str): + async def _write_single_line(self, writer: asyncio.StreamWriter, message : str): if not message.endswith("\n"): message += "\n" writer.write(message.encode("utf-8")) await writer.drain() - async def _listen_for_method_call(self, reader, writer): + async def _listen_for_method_call(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): while True: line = await self._read_single_line(reader) try: res = await self.on_new_message(line) - except Exception as e: + except Exception: return if res != None: await self._write_single_line(writer, res) class PortSocket (UnixSocket): - def __init__(self, on_new_message): + def __init__(self, on_new_message: Callable[[str], Awaitable[str|None]]): ''' on_new_message takes 1 string argument. It's return value gets used, if not None, to write data to the socket. @@ -125,7 +132,7 @@ class PortSocket (UnixSocket): return True if ON_WINDOWS: - class LocalSocket (PortSocket): + class LocalSocket (PortSocket): # type: ignore pass else: class LocalSocket (UnixSocket): |
