summaryrefslogtreecommitdiff
path: root/backend
diff options
context:
space:
mode:
Diffstat (limited to 'backend')
-rw-r--r--backend/decky_loader/loader.py2
-rw-r--r--backend/decky_loader/utilities.py4
-rw-r--r--backend/decky_loader/wsrouter.py104
3 files changed, 60 insertions, 50 deletions
diff --git a/backend/decky_loader/loader.py b/backend/decky_loader/loader.py
index fa4949f2..7f81777f 100644
--- a/backend/decky_loader/loader.py
+++ b/backend/decky_loader/loader.py
@@ -88,7 +88,7 @@ class Loader:
self.observer.start()
self.loop.create_task(self.enable_reload_wait())
- server_instance.add_routes([
+ server_instance.web_app.add_routes([
web.get("/frontend/{path:.*}", self.handle_frontend_assets),
web.get("/locales/{path:.*}", self.handle_frontend_locales),
web.get("/plugins", self.get_plugins),
diff --git a/backend/decky_loader/utilities.py b/backend/decky_loader/utilities.py
index 2eea63ea..0e3e9fb0 100644
--- a/backend/decky_loader/utilities.py
+++ b/backend/decky_loader/utilities.py
@@ -166,7 +166,7 @@ class Utilities:
style.textContent = `{style}`;
}})()
""", False)
-
+ assert result is not None # TODO remove this once it has proper typings
if "exceptionDetails" in result["result"]:
raise result["result"]["exceptionDetails"]
@@ -233,7 +233,7 @@ class Utilities:
folders.append({"file": file, "filest": filest, "is_dir": True})
elif include_files:
# Handle requested extensions if present
- if len(include_ext) == 0 or 'all_files' in include_ext \
+ if include_ext == None or len(include_ext) == 0 or 'all_files' in include_ext \
or splitext(file.name)[1].lstrip('.') in include_ext:
if (is_hidden and include_hidden) or not is_hidden:
files.append({"file": file, "filest": filest, "is_dir": False})
diff --git a/backend/decky_loader/wsrouter.py b/backend/decky_loader/wsrouter.py
index 7a7b59c9..28e3e925 100644
--- a/backend/decky_loader/wsrouter.py
+++ b/backend/decky_loader/wsrouter.py
@@ -1,37 +1,51 @@
from logging import getLogger
-from asyncio import AbstractEventLoop, Future
+from asyncio import AbstractEventLoop, Future, create_task
-from aiohttp import WSMsgType
+from aiohttp import WSMsgType, WSMessage
from aiohttp.web import Application, WebSocketResponse, Request, Response, get
-from enum import Enum
+from enum import IntEnum
-from typing import Dict
+from typing import Callable, Dict, Any, cast, TypeVar, Type
+from dataclasses import dataclass
from traceback import format_exc
from helpers import get_csrf_token
-class MessageType(Enum):
- # Call-reply
+class MessageType(IntEnum):
+ ERROR = -1
+ # Call-reply, Frontend -> Backend
CALL = 0
REPLY = 1
- ERROR = 2
- # # Pub/sub
- # SUBSCRIBE = 3
- # UNSUBSCRIBE = 4
- # PUBLISH = 5
+ # Pub/Sub, Backend -> Frontend
+ EVENT = 3
+
+# WSMessage with slightly better typings
+class WSMessageExtra(WSMessage):
+ data: Any
+ type: WSMsgType
+@dataclass
+class Message:
+ data: Any
+ type: MessageType
+
+# @dataclass
+# class CallMessage
# see wsrouter.ts for typings
+DataType = TypeVar("DataType")
+
+Route = Callable[..., Future[Any]]
+
class WSRouter:
def __init__(self, loop: AbstractEventLoop, server_instance: Application) -> None:
self.loop = loop
- self.ws = None
- self.req_id = 0
- self.routes = {}
- self.running_calls: Dict[int, Future] = {}
+ self.ws: WebSocketResponse | None
+ self.instance_id = 0
+ self.routes: Dict[str, Route] = {}
# self.subscriptions: Dict[str, Callable[[Any]]] = {}
self.logger = getLogger("WSRouter")
@@ -39,22 +53,38 @@ class WSRouter:
get("/ws", self.handle)
])
- async def write(self, dta: Dict[str, any]):
- await self.ws.send_json(dta)
+ async def write(self, data: Dict[str, Any]):
+ if self.ws != None:
+ await self.ws.send_json(data)
+ else:
+ self.logger.warn("Dropping message as there is no connected socket: %s", data)
- def add_route(self, name: str, route):
+ def add_route(self, name: str, route: Route):
self.routes[name] = route
def remove_route(self, name: str):
del self.routes[name]
+ async def _call_route(self, route: str, args: ..., call_id: int):
+ instance_id = self.instance_id
+ res = await self.routes[route](*args)
+ if instance_id != self.instance_id:
+ try:
+ self.logger.warn("Ignoring %s reply from stale instance %d with args %s and response %s", route, instance_id, args, res)
+ except:
+ self.logger.warn("Ignoring %s reply from stale instance %d (failed to log event data)", route, instance_id)
+ finally:
+ return
+ await self.write({"type": MessageType.REPLY.value, "id": call_id, "result": res})
+
async def handle(self, request: Request):
# Auth is a query param as JS WebSocket doesn't support headers
if request.rel_url.query["auth"] != get_csrf_token():
- return Response(text='Forbidden', status='403')
+ return Response(text='Forbidden', status=403)
self.logger.debug('Websocket connection starting')
ws = WebSocketResponse()
await ws.prepare(request)
+ self.instance_id += 1
self.logger.debug('Websocket connection ready')
if self.ws != None:
@@ -68,6 +98,8 @@ class WSRouter:
try:
async for msg in ws:
+ msg = cast(WSMessageExtra, msg)
+
self.logger.debug(msg)
if msg.type == WSMsgType.TEXT:
self.logger.debug(msg.data)
@@ -81,25 +113,13 @@ class WSRouter:
# do stuff with the message
if data["route"] in self.routes:
try:
- res = await self.routes[data["route"]](*data["args"])
- await self.write({"type": MessageType.REPLY.value, "id": data["id"], "result": res})
self.logger.debug(f'Started PY call {data["route"]} ID {data["id"]}')
+ create_task(self._call_route(data["route"], data["args"], data["id"]))
except:
- await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()})
+ create_task(self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()}))
else:
# Dunno why but fstring doesnt work here
- await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route " + data["route"] + " does not exist."})
- case MessageType.REPLY.value:
- if self.running_calls[data["id"]]:
- self.running_calls[data["id"]].set_result(data["result"])
- del self.running_calls[data["id"]]
- self.logger.debug(f'Resolved JS call {data["id"]} with value {str(data["result"])}')
- case MessageType.ERROR.value:
- if self.running_calls[data["id"]]:
- self.running_calls[data["id"]].set_exception(data["error"])
- del self.running_calls[data["id"]]
- self.logger.debug(f'Errored JS call {data["id"]} with error {data["error"]}')
-
+ create_task(self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route " + data["route"] + " does not exist."}))
case _:
self.logger.error("Unknown message type", data)
finally:
@@ -112,17 +132,7 @@ class WSRouter:
self.logger.debug('Websocket connection closed')
return ws
- async def call(self, route: str, *args):
- future = Future()
-
- self.req_id += 1
-
- id = self.req_id
-
- self.running_calls[id] = future
-
- self.logger.debug(f'Calling JS method {route} with args {str(args)}')
-
- self.write({ "type": MessageType.CALL.value, "route": route, "args": args, "id": id })
+ async def emit(self, event: str, data: DataType | None = None, data_type: Type[DataType] = Any):
+ self.logger.debug('Firing frontend event %s with args %s', data)
- return await future
+ await self.write({ "type": MessageType.EVENT.value, "event": event, "data": data }) \ No newline at end of file