summaryrefslogtreecommitdiff
path: root/backend/decky_loader/localplatform/localsocket.py
blob: b25b275a573ad7af7c1b0d589898ba0fa0aebee5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import asyncio, time
from typing import Any, Callable, Coroutine
import random

from .localplatform import ON_WINDOWS

# 1 MiB: passed as the StreamReader ``limit`` for every server/connection below.
BUFFER_LIMIT = 1024 * 1024

class UnixSocket:
    def __init__(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]):
        '''
        on_new_message takes 1 string argument.
        It's return value gets used, if not None, to write data to the socket.
        Method should be async
        '''
        self.socket_addr = f"/tmp/plugin_socket_{time.time()}"
        self.on_new_message = on_new_message
        self.socket = None
        self.reader = None
        self.writer = None
        self.server_writer = None
        self.open_lock = asyncio.Lock()

    async def setup_server(self):
        try:
            self.socket = await asyncio.start_unix_server(self._listen_for_method_call, path=self.socket_addr, limit=BUFFER_LIMIT)
        except asyncio.CancelledError:
            await self.close_socket_connection()
            raise

    async def _open_socket_if_not_exists(self):
        if not self.reader:
            retries = 0
            while retries < 10:
                try:
                    self.reader, self.writer = await asyncio.open_unix_connection(self.socket_addr, limit=BUFFER_LIMIT)
                    return True
                except:
                    await asyncio.sleep(2)
                    retries += 1
            return False
        else:
            return True

    async def get_socket_connection(self):
        async with self.open_lock:
            if not await self._open_socket_if_not_exists():
                return None, None
            
            return self.reader, self.writer
    
    async def close_socket_connection(self):
        if self.writer != None:
            self.writer.close()

        self.reader = None

        if self.socket:
            self.socket.close()
            await self.socket.wait_closed()

    async def read_single_line(self) -> str|None:
        reader, _ = await self.get_socket_connection()

        try:
            assert reader
        except AssertionError:
            return

        return await self._read_single_line(reader)

    async def write_single_line(self, message : str):
        _, writer = await self.get_socket_connection()

        try:
            assert writer
        except AssertionError:
            return

        await self._write_single_line(writer, message)

    async def _read_single_line(self, reader: asyncio.StreamReader) -> str:
        line = bytearray()
        while True:
            try:
                line.extend(await reader.readuntil())
            except asyncio.LimitOverrunError:
                line.extend(await reader.read(reader._limit)) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
                continue
            except asyncio.IncompleteReadError as err:
                line.extend(err.partial)
                break
            except asyncio.CancelledError:
                break
            else:
                break

        return line.decode("utf-8")
    
    async def _write_single_line(self, writer: asyncio.StreamWriter, message : str):
        if not message.endswith("\n"):
            message += "\n"

        writer.write(message.encode("utf-8"))
        await writer.drain()
    
    async def write_single_line_server(self, message: str):
        if self.server_writer is None:
            return
        await self._write_single_line(self.server_writer, message)

    async def _listen_for_method_call(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        self.server_writer = writer
        while True:

            def _(task: asyncio.Task[str|None]):
                res = task.result()
                if res is not None:
                    asyncio.create_task(self._write_single_line(writer, res))

            line = await self._read_single_line(reader)
            asyncio.create_task(self.on_new_message(line)).add_done_callback(_)
            
class PortSocket (UnixSocket):
    """Variant of UnixSocket that communicates over TCP on the loopback interface.

    Line framing and message dispatch are inherited from UnixSocket. Only the
    server setup and the client connection differ: they use 127.0.0.1 and
    ``self.port`` instead of a socket file.
    """
    def __init__(self, on_new_message: Callable[[str], Coroutine[Any, Any, Any]]):
        '''
        on_new_message takes 1 string argument.
        Its return value gets used, if not None, to write data to the socket.
        Method should be async
        '''
        super().__init__(on_new_message)
        self.host = "127.0.0.1"
        # Random port in [40000, 60000). Availability is not checked here, so
        # setup_server may fail if the port is already in use.
        self.port = random.randrange(40000, 60000)

    async def setup_server(self):
        """Start a TCP server on ``host:port``; connections are served by
        ``_listen_for_method_call``. If setup is cancelled, close and re-raise."""
        try:
            self.socket = await asyncio.start_server(self._listen_for_method_call, host=self.host, port=self.port, limit=BUFFER_LIMIT)
        except asyncio.CancelledError:
            await self.close_socket_connection()
            raise

    async def _open_socket_if_not_exists(self):
        """Ensure a TCP client connection exists.

        Makes up to 10 attempts, 2 seconds apart.

        Returns:
            True if connected (or already connected), False if every attempt
            failed.
        """
        if self.reader is not None:
            return True
        attempts = 10
        for attempt in range(attempts):
            try:
                self.reader, self.writer = await asyncio.open_connection(host=self.host, port=self.port, limit=BUFFER_LIMIT)
                return True
            except Exception:
                # The server may not be listening yet; retry. Catch Exception
                # (not a bare except) so cancellation and KeyboardInterrupt
                # propagate. Skip the sleep after the final attempt.
                if attempt < attempts - 1:
                    await asyncio.sleep(2)
        return False

# Choose the transport for this platform: a Unix domain socket by default,
# loopback TCP (PortSocket) when running on Windows.
if not ON_WINDOWS:
    class LocalSocket (UnixSocket):
        """Platform-selected local socket (Unix domain socket variant)."""
else:
    class LocalSocket (PortSocket):  # type: ignore
        """Platform-selected local socket (loopback TCP variant, used on Windows)."""