diff options
Diffstat (limited to 'backend/decky_loader/wsrouter.py')
| -rw-r--r-- | backend/decky_loader/wsrouter.py | 85 |
1 files changed, 57 insertions, 28 deletions
diff --git a/backend/decky_loader/wsrouter.py b/backend/decky_loader/wsrouter.py index 9c8fe424..2b4c3a3b 100644 --- a/backend/decky_loader/wsrouter.py +++ b/backend/decky_loader/wsrouter.py @@ -1,15 +1,18 @@ from logging import getLogger -from asyncio import Future +from asyncio import AbstractEventLoop, Future -from aiohttp import web, WSMsgType +from aiohttp import WSMsgType +from aiohttp.web import Application, WebSocketResponse, Request, Response, get from enum import Enum -from typing import Dict, Any, Callable +from typing import Dict from traceback import format_exc +from helpers import get_csrf_token + class MessageType(Enum): # Call-reply CALL = 0 @@ -23,7 +26,8 @@ class MessageType(Enum): # see wsrouter.ts for typings class WSRouter: - def __init__(self) -> None: + def __init__(self, loop: AbstractEventLoop, server_instance: Application) -> None: + self.loop = loop self.ws = None self.req_id = 0 self.routes = {} @@ -31,12 +35,25 @@ class WSRouter: # self.subscriptions: Dict[str, Callable[[Any]]] = {} self.logger = getLogger("WSRouter") - async def add_route(self, name, route): + server_instance.add_routes([ + get("/ws", self.handle) + ]) + + async def write(self, dta: Dict[str, any]): + await self.ws.send_json(dta) + + def add_route(self, name: str, route): self.routes[name] = route - async def handle(self, request): + def remove_route(self, name: str): + del self.routes[name] + + async def handle(self, request: Request): + # Auth is a query param as JS WebSocket doesn't support headers + if request.rel_url.query["auth"] != get_csrf_token(): + return Response(text='Forbidden', status='403') self.logger.debug('Websocket connection starting') - ws = web.WebSocketResponse() + ws = WebSocketResponse() await ws.prepare(request) self.logger.debug('Websocket connection ready') @@ -58,29 +75,29 @@ class WSRouter: # TODO DO NOT RELY ON THIS! break else: - match data.type: - case MessageType.CALL: + data = msg.json() + match data["type"]: + case MessageType.CALL.value: # do stuff with the message - data = msg.json() - if self.routes[data.route]: + if self.routes[data["route"]]: try: - res = await self.routes[data.route](*data.args) - await self.write({"type": MessageType.REPLY, "id": data.id, "result": res}) - self.logger.debug(f"Started PY call {data.route} ID {data.id}") + res = await self.routes[data["route"]](*data["args"]) + await self.write({"type": MessageType.REPLY.value, "id": data["id"], "result": res}) + self.logger.debug(f'Started PY call {data["route"]} ID {data["id"]}') except: - await self.write({"type": MessageType.ERROR, "id": data.id, "error": format_exc()}) + await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()}) else: - await self.write({"type": MessageType.ERROR, "id": data.id, "error": "Route does not exist."}) - case MessageType.REPLY: - if self.running_calls[data.id]: - self.running_calls[data.id].set_result(data.result) - del self.running_calls[data.id] - self.logger.debug(f"Resolved JS call {data.id} with value {str(data.result)}") - case MessageType.ERROR: - if self.running_calls[data.id]: - self.running_calls[data.id].set_exception(data.error) - del self.running_calls[data.id] - self.logger.debug(f"Errored JS call {data.id} with error {data.error}") + await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route does not exist."}) + case MessageType.REPLY.value: + if self.running_calls[data["id"]]: + self.running_calls[data["id"]].set_result(data["result"]) + del self.running_calls[data["id"]] + self.logger.debug(f'Resolved JS call {data["id"]} with value {str(data["result"])}') + case MessageType.ERROR.value: + if self.running_calls[data["id"]]: + self.running_calls[data["id"]].set_exception(data["error"]) + del self.running_calls[data["id"]] + self.logger.debug(f'Errored JS call {data["id"]} with error {data["error"]}') case _: self.logger.error("Unknown message type", data) @@ -94,5 +111,17 @@ class WSRouter: self.logger.debug('Websocket connection closed') return ws - async def write(self, dta: Dict[str, any]): - await self.ws.send_json(dta)
\ No newline at end of file + async def call(self, route: str, *args): + future = Future() + + self.req_id += 1 + + id = self.req_id + + self.running_calls[id] = future + + self.logger.debug(f'Calling JS method {route} with args {str(args)}') + + self.write({ "type": MessageType.CALL.value, "route": route, "args": args, "id": id }) + + return await future |
