summaryrefslogtreecommitdiff
path: root/backend
diff options
context:
space:
mode:
Diffstat (limited to 'backend')
-rw-r--r--backend/decky_loader/loader.py16
-rw-r--r--backend/decky_loader/main.py4
-rw-r--r--backend/decky_loader/utilities.py6
-rw-r--r--backend/decky_loader/wsrouter.py85
4 files changed, 76 insertions, 35 deletions
diff --git a/backend/decky_loader/loader.py b/backend/decky_loader/loader.py
index 85b5de58..8721ea05 100644
--- a/backend/decky_loader/loader.py
+++ b/backend/decky_loader/loader.py
@@ -70,7 +70,7 @@ class FileChangeHandler(RegexMatchingEventHandler):
self.maybe_reload(src_path)
class Loader:
- def __init__(self, server_instance: PluginManager, plugin_path: str, loop: AbstractEventLoop, live_reload: bool = False) -> None:
+ def __init__(self, server_instance: PluginManager, ws: WSRouter, plugin_path: str, loop: AbstractEventLoop, live_reload: bool = False) -> None:
self.loop = loop
self.logger = getLogger("Loader")
self.plugin_path = plugin_path
@@ -88,10 +88,7 @@ class Loader:
self.observer.start()
self.loop.create_task(self.enable_reload_wait())
- self.ws = WSRouter()
-
- server_instance.web_app.add_routes([
- web.get("/ws", self.ws.handle),
+ server_instance.add_routes([
web.get("/frontend/{path:.*}", self.handle_frontend_assets),
web.get("/locales/{path:.*}", self.handle_frontend_locales),
web.get("/plugins", self.get_plugins),
@@ -101,6 +98,15 @@ class Loader:
web.post("/plugins/{plugin_name}/reload", self.handle_backend_reload_request)
])
+ ws.add_route("test", self.test_method)
+
+ async def test_method():
+ await sleep(2)
+
+ return {
+ "test data": True
+ }
+
async def enable_reload_wait(self):
if self.live_reload:
await sleep(10)
diff --git a/backend/decky_loader/main.py b/backend/decky_loader/main.py
index fae30574..e33f0a9b 100644
--- a/backend/decky_loader/main.py
+++ b/backend/decky_loader/main.py
@@ -1,6 +1,7 @@
# Change PyInstaller files permissions
import sys
from typing import Dict
+from wsrouter import WSRouter
from .localplatform.localplatform import (chmod, chown, service_stop, service_start,
ON_WINDOWS, get_log_level, get_live_reload,
get_server_port, get_server_host, get_chown_plugin_path,
@@ -63,7 +64,8 @@ class PluginManager:
allow_credentials=True
)
})
- self.plugin_loader = Loader(self, plugin_path, self.loop, get_live_reload())
+ self.ws = WSRouter(self.loop, self.web_app)
+ self.plugin_loader = Loader(self, self.ws, plugin_path, self.loop, get_live_reload())
self.settings = SettingsManager("loader", path.join(get_privileged_path(), "settings"))
self.plugin_browser = PluginBrowser(plugin_path, self.plugin_loader.plugins, self.plugin_loader, self.settings)
self.utilities = Utilities(self)
diff --git a/backend/decky_loader/utilities.py b/backend/decky_loader/utilities.py
index f04ed371..20280c24 100644
--- a/backend/decky_loader/utilities.py
+++ b/backend/decky_loader/utilities.py
@@ -63,7 +63,11 @@ class Utilities:
web.post("/methods/{method_name}", self._handle_server_method_call)
])
- async def _handle_server_method_call(self, request: web.Request):
+ context.ws.add_route("utilities/ping", self.ping)
+ context.ws.add_route("utilities/settings/get", self.get_setting)
+ context.ws.add_route("utilities/settings/set", self.set_setting)
+
+ async def _handle_server_method_call(self, request):
method_name = request.match_info["method_name"]
try:
args = await request.json()
diff --git a/backend/decky_loader/wsrouter.py b/backend/decky_loader/wsrouter.py
index 9c8fe424..2b4c3a3b 100644
--- a/backend/decky_loader/wsrouter.py
+++ b/backend/decky_loader/wsrouter.py
@@ -1,15 +1,18 @@
from logging import getLogger
-from asyncio import Future
+from asyncio import AbstractEventLoop, Future
-from aiohttp import web, WSMsgType
+from aiohttp import WSMsgType
+from aiohttp.web import Application, WebSocketResponse, Request, Response, get
from enum import Enum
-from typing import Dict, Any, Callable
+from typing import Dict
from traceback import format_exc
+from helpers import get_csrf_token
+
class MessageType(Enum):
# Call-reply
CALL = 0
@@ -23,7 +26,8 @@ class MessageType(Enum):
# see wsrouter.ts for typings
class WSRouter:
- def __init__(self) -> None:
+ def __init__(self, loop: AbstractEventLoop, server_instance: Application) -> None:
+ self.loop = loop
self.ws = None
self.req_id = 0
self.routes = {}
@@ -31,12 +35,25 @@ class WSRouter:
# self.subscriptions: Dict[str, Callable[[Any]]] = {}
self.logger = getLogger("WSRouter")
- async def add_route(self, name, route):
+ server_instance.add_routes([
+ get("/ws", self.handle)
+ ])
+
+ async def write(self, dta: Dict[str, any]):
+ await self.ws.send_json(dta)
+
+ def add_route(self, name: str, route):
self.routes[name] = route
- async def handle(self, request):
+ def remove_route(self, name: str):
+ del self.routes[name]
+
+ async def handle(self, request: Request):
+ # Auth is a query param as JS WebSocket doesn't support headers
+ if request.rel_url.query["auth"] != get_csrf_token():
+ return Response(text='Forbidden', status='403')
self.logger.debug('Websocket connection starting')
- ws = web.WebSocketResponse()
+ ws = WebSocketResponse()
await ws.prepare(request)
self.logger.debug('Websocket connection ready')
@@ -58,29 +75,29 @@ class WSRouter:
# TODO DO NOT RELY ON THIS!
break
else:
- match data.type:
- case MessageType.CALL:
+ data = msg.json()
+ match data["type"]:
+ case MessageType.CALL.value:
# do stuff with the message
- data = msg.json()
- if self.routes[data.route]:
+ if self.routes[data["route"]]:
try:
- res = await self.routes[data.route](*data.args)
- await self.write({"type": MessageType.REPLY, "id": data.id, "result": res})
- self.logger.debug(f"Started PY call {data.route} ID {data.id}")
+ res = await self.routes[data["route"]](*data["args"])
+ await self.write({"type": MessageType.REPLY.value, "id": data["id"], "result": res})
+ self.logger.debug(f'Started PY call {data["route"]} ID {data["id"]}')
except:
- await self.write({"type": MessageType.ERROR, "id": data.id, "error": format_exc()})
+ await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()})
else:
- await self.write({"type": MessageType.ERROR, "id": data.id, "error": "Route does not exist."})
- case MessageType.REPLY:
- if self.running_calls[data.id]:
- self.running_calls[data.id].set_result(data.result)
- del self.running_calls[data.id]
- self.logger.debug(f"Resolved JS call {data.id} with value {str(data.result)}")
- case MessageType.ERROR:
- if self.running_calls[data.id]:
- self.running_calls[data.id].set_exception(data.error)
- del self.running_calls[data.id]
- self.logger.debug(f"Errored JS call {data.id} with error {data.error}")
+ await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route does not exist."})
+ case MessageType.REPLY.value:
+ if self.running_calls[data["id"]]:
+ self.running_calls[data["id"]].set_result(data["result"])
+ del self.running_calls[data["id"]]
+ self.logger.debug(f'Resolved JS call {data["id"]} with value {str(data["result"])}')
+ case MessageType.ERROR.value:
+ if self.running_calls[data["id"]]:
+ self.running_calls[data["id"]].set_exception(data["error"])
+ del self.running_calls[data["id"]]
+ self.logger.debug(f'Errored JS call {data["id"]} with error {data["error"]}')
case _:
self.logger.error("Unknown message type", data)
@@ -94,5 +111,17 @@ class WSRouter:
self.logger.debug('Websocket connection closed')
return ws
- async def write(self, dta: Dict[str, any]):
- await self.ws.send_json(dta) \ No newline at end of file
+ async def call(self, route: str, *args):
+ future = Future()
+
+ self.req_id += 1
+
+ id = self.req_id
+
+ self.running_calls[id] = future
+
+ self.logger.debug(f'Calling JS method {route} with args {str(args)}')
+
+ self.write({ "type": MessageType.CALL.value, "route": route, "args": args, "id": id })
+
+ return await future