summaryrefslogtreecommitdiff
path: root/backend/localsocket.py
blob: f38fe5e7e06182523ee04668b115cb02274fa8d1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import asyncio, time
from typing import Awaitable, Callable
import random

from .localplatform import ON_WINDOWS

# Stream buffer limit in bytes (1 MiB), passed as `limit=` to every server and client stream below.
BUFFER_LIMIT = 1 << 20

class UnixSocket:
    def __init__(self, on_new_message: Callable[[str], Awaitable[str|None]]):
        '''
        on_new_message takes 1 string argument.
        It's return value gets used, if not None, to write data to the socket.
        Method should be async
        '''
        self.socket_addr = f"/tmp/plugin_socket_{time.time()}"
        self.on_new_message = on_new_message
        self.socket = None
        self.reader = None
        self.writer = None

    async def setup_server(self):
        self.socket = await asyncio.start_unix_server(self._listen_for_method_call, path=self.socket_addr, limit=BUFFER_LIMIT)
    
    async def _open_socket_if_not_exists(self):
        if not self.reader:
            retries = 0
            while retries < 10:
                try:
                    self.reader, self.writer = await asyncio.open_unix_connection(self.socket_addr, limit=BUFFER_LIMIT)
                    return True
                except:
                    await asyncio.sleep(2)
                    retries += 1
            return False
        else:
            return True

    async def get_socket_connection(self):
        if not await self._open_socket_if_not_exists():
            return None, None
        
        return self.reader, self.writer
    
    async def close_socket_connection(self):
        if self.writer != None:
            self.writer.close()

        self.reader = None

    async def read_single_line(self) -> str|None:
        reader, _ = await self.get_socket_connection()

        try:
            assert reader
        except AssertionError:
            return

        return await self._read_single_line(reader)

    async def write_single_line(self, message : str):
        _, writer = await self.get_socket_connection()

        try:
            assert writer
        except AssertionError:
            return

        await self._write_single_line(writer, message)

    async def _read_single_line(self, reader: asyncio.StreamReader) -> str:
        line = bytearray()
        while True:
            try:
                line.extend(await reader.readuntil())
            except asyncio.LimitOverrunError:
                line.extend(await reader.read(reader._limit)) # type: ignore
                continue
            except asyncio.IncompleteReadError as err:
                line.extend(err.partial)
                break
            else:
                break

        return line.decode("utf-8")
    
    async def _write_single_line(self, writer: asyncio.StreamWriter, message : str):
        if not message.endswith("\n"):
            message += "\n"

        writer.write(message.encode("utf-8"))
        await writer.drain()

    async def _listen_for_method_call(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        while True:
            line = await self._read_single_line(reader)

            try:
                res = await self.on_new_message(line)
            except Exception:
                return

            if res != None:
                await self._write_single_line(writer, res)
            
class PortSocket (UnixSocket):
    def __init__(self, on_new_message: Callable[[str], Awaitable[str|None]]):
        '''
        Loopback-TCP variant of UnixSocket (the LocalSocket base on Windows).

        on_new_message takes 1 string argument.
        Its return value gets used, if not None, to write data to the socket.
        Method should be async
        '''
        super().__init__(on_new_message)
        self.host = "127.0.0.1"
        # Random port in [40000, 60000). NOTE(review): the port is not checked for
        # availability; presumably setup_server() fails if it is already in use.
        self.port = random.randrange(40000, 60000)

    async def setup_server(self):
        '''Start a TCP server on (self.host, self.port); each connection is served by _listen_for_method_call.'''
        self.socket = await asyncio.start_server(self._listen_for_method_call, host=self.host, port=self.port, limit=BUFFER_LIMIT)

    async def _open_socket_if_not_exists(self):
        '''
        Ensure the client-side TCP connection is open.

        Makes up to 10 connection attempts, waiting 2 seconds between attempts.
        Returns True if a connection is (now) open, False if every attempt failed.
        '''
        if self.reader is not None:
            return True

        attempts, delay = 10, 2
        for attempt in range(attempts):
            try:
                self.reader, self.writer = await asyncio.open_connection(host=self.host, port=self.port, limit=BUFFER_LIMIT)
                return True
            except Exception:
                # The server may not be listening yet, so retry. Catching Exception
                # (not a bare except) lets task cancellation propagate.
                if attempt + 1 < attempts:
                    await asyncio.sleep(delay)
        return False

if not ON_WINDOWS:
    # Everywhere except Windows, LocalSocket talks over a Unix domain socket.
    class LocalSocket (UnixSocket):
        pass
else:
    # On Windows, LocalSocket falls back to a loopback TCP socket on a random port.
    class LocalSocket (PortSocket):  # type: ignore
        pass