summaryrefslogtreecommitdiff
path: root/backend/decky_loader/wsrouter.py
diff options
context:
space:
mode:
Diffstat (limited to 'backend/decky_loader/wsrouter.py')
-rw-r--r--backend/decky_loader/wsrouter.py85
1 files changed, 57 insertions, 28 deletions
diff --git a/backend/decky_loader/wsrouter.py b/backend/decky_loader/wsrouter.py
index 9c8fe424..2b4c3a3b 100644
--- a/backend/decky_loader/wsrouter.py
+++ b/backend/decky_loader/wsrouter.py
@@ -1,15 +1,18 @@
from logging import getLogger
-from asyncio import Future
+from asyncio import AbstractEventLoop, Future
-from aiohttp import web, WSMsgType
+from aiohttp import WSMsgType
+from aiohttp.web import Application, WebSocketResponse, Request, Response, get
from enum import Enum
-from typing import Dict, Any, Callable
+from typing import Dict
from traceback import format_exc
+from helpers import get_csrf_token
+
class MessageType(Enum):
# Call-reply
CALL = 0
@@ -23,7 +26,8 @@ class MessageType(Enum):
# see wsrouter.ts for typings
class WSRouter:
- def __init__(self) -> None:
+ def __init__(self, loop: AbstractEventLoop, server_instance: Application) -> None:
+ self.loop = loop
self.ws = None
self.req_id = 0
self.routes = {}
@@ -31,12 +35,25 @@ class WSRouter:
# self.subscriptions: Dict[str, Callable[[Any]]] = {}
self.logger = getLogger("WSRouter")
- async def add_route(self, name, route):
+ server_instance.add_routes([
+ get("/ws", self.handle)
+ ])
+
+ async def write(self, dta: Dict[str, any]):
+ await self.ws.send_json(dta)
+
+ def add_route(self, name: str, route):
self.routes[name] = route
- async def handle(self, request):
+ def remove_route(self, name: str):
+ del self.routes[name]
+
+ async def handle(self, request: Request):
+ # Auth is a query param as JS WebSocket doesn't support headers
+ if request.rel_url.query["auth"] != get_csrf_token():
+ return Response(text='Forbidden', status='403')
self.logger.debug('Websocket connection starting')
- ws = web.WebSocketResponse()
+ ws = WebSocketResponse()
await ws.prepare(request)
self.logger.debug('Websocket connection ready')
@@ -58,29 +75,29 @@ class WSRouter:
# TODO DO NOT RELY ON THIS!
break
else:
- match data.type:
- case MessageType.CALL:
+ data = msg.json()
+ match data["type"]:
+ case MessageType.CALL.value:
# do stuff with the message
- data = msg.json()
- if self.routes[data.route]:
+ if self.routes[data["route"]]:
try:
- res = await self.routes[data.route](*data.args)
- await self.write({"type": MessageType.REPLY, "id": data.id, "result": res})
- self.logger.debug(f"Started PY call {data.route} ID {data.id}")
+ res = await self.routes[data["route"]](*data["args"])
+ await self.write({"type": MessageType.REPLY.value, "id": data["id"], "result": res})
+ self.logger.debug(f'Started PY call {data["route"]} ID {data["id"]}')
except:
- await self.write({"type": MessageType.ERROR, "id": data.id, "error": format_exc()})
+ await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()})
else:
- await self.write({"type": MessageType.ERROR, "id": data.id, "error": "Route does not exist."})
- case MessageType.REPLY:
- if self.running_calls[data.id]:
- self.running_calls[data.id].set_result(data.result)
- del self.running_calls[data.id]
- self.logger.debug(f"Resolved JS call {data.id} with value {str(data.result)}")
- case MessageType.ERROR:
- if self.running_calls[data.id]:
- self.running_calls[data.id].set_exception(data.error)
- del self.running_calls[data.id]
- self.logger.debug(f"Errored JS call {data.id} with error {data.error}")
+ await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route does not exist."})
+ case MessageType.REPLY.value:
+ if self.running_calls[data["id"]]:
+ self.running_calls[data["id"]].set_result(data["result"])
+ del self.running_calls[data["id"]]
+ self.logger.debug(f'Resolved JS call {data["id"]} with value {str(data["result"])}')
+ case MessageType.ERROR.value:
+ if self.running_calls[data["id"]]:
+ self.running_calls[data["id"]].set_exception(data["error"])
+ del self.running_calls[data["id"]]
+ self.logger.debug(f'Errored JS call {data["id"]} with error {data["error"]}')
case _:
self.logger.error("Unknown message type", data)
@@ -94,5 +111,17 @@ class WSRouter:
self.logger.debug('Websocket connection closed')
return ws
- async def write(self, dta: Dict[str, any]):
- await self.ws.send_json(dta) \ No newline at end of file
+ async def call(self, route: str, *args):
+ future = Future()
+
+ self.req_id += 1
+
+ id = self.req_id
+
+ self.running_calls[id] = future
+
+ self.logger.debug(f'Calling JS method {route} with args {str(args)}')
+
+ self.write({ "type": MessageType.CALL.value, "route": route, "args": args, "id": id })
+
+ return await future