diff options
Diffstat (limited to 'backend')
| -rw-r--r-- | backend/decky_loader/loader.py | 16 | ||||
| -rw-r--r-- | backend/decky_loader/main.py | 4 | ||||
| -rw-r--r-- | backend/decky_loader/utilities.py | 6 | ||||
| -rw-r--r-- | backend/decky_loader/wsrouter.py | 85 |
4 files changed, 76 insertions, 35 deletions
diff --git a/backend/decky_loader/loader.py b/backend/decky_loader/loader.py index 85b5de58..8721ea05 100644 --- a/backend/decky_loader/loader.py +++ b/backend/decky_loader/loader.py @@ -70,7 +70,7 @@ class FileChangeHandler(RegexMatchingEventHandler): self.maybe_reload(src_path) class Loader: - def __init__(self, server_instance: PluginManager, plugin_path: str, loop: AbstractEventLoop, live_reload: bool = False) -> None: + def __init__(self, server_instance: PluginManager, ws: WSRouter, plugin_path: str, loop: AbstractEventLoop, live_reload: bool = False) -> None: self.loop = loop self.logger = getLogger("Loader") self.plugin_path = plugin_path @@ -88,10 +88,7 @@ class Loader: self.observer.start() self.loop.create_task(self.enable_reload_wait()) - self.ws = WSRouter() - - server_instance.web_app.add_routes([ - web.get("/ws", self.ws.handle), + server_instance.add_routes([ web.get("/frontend/{path:.*}", self.handle_frontend_assets), web.get("/locales/{path:.*}", self.handle_frontend_locales), web.get("/plugins", self.get_plugins), @@ -101,6 +98,15 @@ class Loader: web.post("/plugins/{plugin_name}/reload", self.handle_backend_reload_request) ]) + ws.add_route("test", self.test_method) + + async def test_method(): + await sleep(2) + + return { + "test data": True + } + async def enable_reload_wait(self): if self.live_reload: await sleep(10) diff --git a/backend/decky_loader/main.py b/backend/decky_loader/main.py index fae30574..e33f0a9b 100644 --- a/backend/decky_loader/main.py +++ b/backend/decky_loader/main.py @@ -1,6 +1,7 @@ # Change PyInstaller files permissions import sys from typing import Dict +from wsrouter import WSRouter from .localplatform.localplatform import (chmod, chown, service_stop, service_start, ON_WINDOWS, get_log_level, get_live_reload, get_server_port, get_server_host, get_chown_plugin_path, @@ -63,7 +64,8 @@ class PluginManager: allow_credentials=True ) }) - self.plugin_loader = Loader(self, plugin_path, self.loop, get_live_reload()) + self.ws = WSRouter(self.loop, self.web_app) + self.plugin_loader = Loader(self, self.ws, plugin_path, self.loop, get_live_reload()) self.settings = SettingsManager("loader", path.join(get_privileged_path(), "settings")) self.plugin_browser = PluginBrowser(plugin_path, self.plugin_loader.plugins, self.plugin_loader, self.settings) self.utilities = Utilities(self) diff --git a/backend/decky_loader/utilities.py b/backend/decky_loader/utilities.py index f04ed371..20280c24 100644 --- a/backend/decky_loader/utilities.py +++ b/backend/decky_loader/utilities.py @@ -63,7 +63,11 @@ class Utilities: web.post("/methods/{method_name}", self._handle_server_method_call) ]) - async def _handle_server_method_call(self, request: web.Request): + context.ws.add_route("utilities/ping", self.ping) + context.ws.add_route("utilities/settings/get", self.get_setting) + context.ws.add_route("utilities/settings/set", self.set_setting) + + async def _handle_server_method_call(self, request): method_name = request.match_info["method_name"] try: args = await request.json() diff --git a/backend/decky_loader/wsrouter.py b/backend/decky_loader/wsrouter.py index 9c8fe424..2b4c3a3b 100644 --- a/backend/decky_loader/wsrouter.py +++ b/backend/decky_loader/wsrouter.py @@ -1,15 +1,18 @@ from logging import getLogger -from asyncio import Future +from asyncio import AbstractEventLoop, Future -from aiohttp import web, WSMsgType +from aiohttp import WSMsgType +from aiohttp.web import Application, WebSocketResponse, Request, Response, get from enum import Enum -from typing import Dict, Any, Callable +from typing import Dict from traceback import format_exc +from helpers import get_csrf_token + class MessageType(Enum): # Call-reply CALL = 0 @@ -23,7 +26,8 @@ class MessageType(Enum): # see wsrouter.ts for typings class WSRouter: - def __init__(self) -> None: + def __init__(self, loop: AbstractEventLoop, server_instance: Application) -> None: + self.loop = loop self.ws = None self.req_id = 0 self.routes = {} @@ -31,12 +35,25 @@ class WSRouter: # self.subscriptions: Dict[str, Callable[[Any]]] = {} self.logger = getLogger("WSRouter") - async def add_route(self, name, route): + server_instance.add_routes([ + get("/ws", self.handle) + ]) + + async def write(self, dta: Dict[str, any]): + await self.ws.send_json(dta) + + def add_route(self, name: str, route): self.routes[name] = route - async def handle(self, request): + def remove_route(self, name: str): + del self.routes[name] + + async def handle(self, request: Request): + # Auth is a query param as JS WebSocket doesn't support headers + if request.rel_url.query["auth"] != get_csrf_token(): + return Response(text='Forbidden', status='403') self.logger.debug('Websocket connection starting') - ws = web.WebSocketResponse() + ws = WebSocketResponse() await ws.prepare(request) self.logger.debug('Websocket connection ready') @@ -58,29 +75,29 @@ class WSRouter: # TODO DO NOT RELY ON THIS! break else: - match data.type: - case MessageType.CALL: + data = msg.json() + match data["type"]: + case MessageType.CALL.value: # do stuff with the message - data = msg.json() - if self.routes[data.route]: + if self.routes[data["route"]]: try: - res = await self.routes[data.route](*data.args) - await self.write({"type": MessageType.REPLY, "id": data.id, "result": res}) - self.logger.debug(f"Started PY call {data.route} ID {data.id}") + res = await self.routes[data["route"]](*data["args"]) + await self.write({"type": MessageType.REPLY.value, "id": data["id"], "result": res}) + self.logger.debug(f'Started PY call {data["route"]} ID {data["id"]}') except: - await self.write({"type": MessageType.ERROR, "id": data.id, "error": format_exc()}) + await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()}) else: - await self.write({"type": MessageType.ERROR, "id": data.id, "error": "Route does not exist."}) - case MessageType.REPLY: - if self.running_calls[data.id]: - self.running_calls[data.id].set_result(data.result) - del self.running_calls[data.id] - self.logger.debug(f"Resolved JS call {data.id} with value {str(data.result)}") - case MessageType.ERROR: - if self.running_calls[data.id]: - self.running_calls[data.id].set_exception(data.error) - del self.running_calls[data.id] - self.logger.debug(f"Errored JS call {data.id} with error {data.error}") + await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route does not exist."}) + case MessageType.REPLY.value: + if self.running_calls[data["id"]]: + self.running_calls[data["id"]].set_result(data["result"]) + del self.running_calls[data["id"]] + self.logger.debug(f'Resolved JS call {data["id"]} with value {str(data["result"])}') + case MessageType.ERROR.value: + if self.running_calls[data["id"]]: + self.running_calls[data["id"]].set_exception(data["error"]) + del self.running_calls[data["id"]] + self.logger.debug(f'Errored JS call {data["id"]} with error {data["error"]}') case _: self.logger.error("Unknown message type", data) @@ -94,5 +111,17 @@ class WSRouter: self.logger.debug('Websocket connection closed') return ws - async def write(self, dta: Dict[str, any]): - await self.ws.send_json(dta)
\ No newline at end of file + async def call(self, route: str, *args): + future = Future() + + self.req_id += 1 + + id = self.req_id + + self.running_calls[id] = future + + self.logger.debug(f'Calling JS method {route} with args {str(args)}') + + self.write({ "type": MessageType.CALL.value, "route": route, "args": args, "id": id }) + + return await future |
