1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
|
from asyncio import AbstractEventLoop, Future
from enum import Enum
from logging import getLogger
from traceback import format_exc
from typing import Any, Dict, Optional

from aiohttp import WSMsgType
from aiohttp.web import Application, WebSocketResponse, Request, Response, get

from helpers import get_csrf_token
class MessageType(Enum):
    """Numeric tags carried in the ``"type"`` field of a WebSocket message."""

    # Call/reply protocol.
    CALL = 0
    REPLY = 1
    ERROR = 2

    # Pub/sub protocol -- reserved, not enabled yet:
    # SUBSCRIBE = 3
    # UNSUBSCRIBE = 4
    # PUBLISH = 5
# see wsrouter.ts for typings
class WSRouter:
def __init__(self, loop: AbstractEventLoop, server_instance: Application) -> None:
self.loop = loop
self.ws = None
self.req_id = 0
self.routes = {}
self.running_calls: Dict[int, Future] = {}
# self.subscriptions: Dict[str, Callable[[Any]]] = {}
self.logger = getLogger("WSRouter")
server_instance.add_routes([
get("/ws", self.handle)
])
async def write(self, dta: Dict[str, any]):
await self.ws.send_json(dta)
def add_route(self, name: str, route):
self.routes[name] = route
def remove_route(self, name: str):
del self.routes[name]
async def handle(self, request: Request):
# Auth is a query param as JS WebSocket doesn't support headers
if request.rel_url.query["auth"] != get_csrf_token():
return Response(text='Forbidden', status='403')
self.logger.debug('Websocket connection starting')
ws = WebSocketResponse()
await ws.prepare(request)
self.logger.debug('Websocket connection ready')
if self.ws != None:
try:
await self.ws.close()
except:
pass
self.ws = None
self.ws = ws
try:
async for msg in ws:
self.logger.debug(msg)
if msg.type == WSMsgType.TEXT:
self.logger.debug(msg.data)
if msg.data == 'close':
# TODO DO NOT RELY ON THIS!
break
else:
data = msg.json()
match data["type"]:
case MessageType.CALL.value:
# do stuff with the message
if self.routes[data["route"]]:
try:
res = await self.routes[data["route"]](*data["args"])
await self.write({"type": MessageType.REPLY.value, "id": data["id"], "result": res})
self.logger.debug(f'Started PY call {data["route"]} ID {data["id"]}')
except:
await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()})
else:
await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route does not exist."})
case MessageType.REPLY.value:
if self.running_calls[data["id"]]:
self.running_calls[data["id"]].set_result(data["result"])
del self.running_calls[data["id"]]
self.logger.debug(f'Resolved JS call {data["id"]} with value {str(data["result"])}')
case MessageType.ERROR.value:
if self.running_calls[data["id"]]:
self.running_calls[data["id"]].set_exception(data["error"])
del self.running_calls[data["id"]]
self.logger.debug(f'Errored JS call {data["id"]} with error {data["error"]}')
case _:
self.logger.error("Unknown message type", data)
finally:
try:
await ws.close()
self.ws = None
except:
pass
self.logger.debug('Websocket connection closed')
return ws
async def call(self, route: str, *args):
future = Future()
self.req_id += 1
id = self.req_id
self.running_calls[id] = future
self.logger.debug(f'Calling JS method {route} with args {str(args)}')
self.write({ "type": MessageType.CALL.value, "route": route, "args": args, "id": id })
return await future
|