summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--backend/decky_loader/loader.py16
-rw-r--r--backend/decky_loader/main.py4
-rw-r--r--backend/decky_loader/utilities.py6
-rw-r--r--backend/decky_loader/wsrouter.py85
-rw-r--r--frontend/src/plugin-loader.tsx9
-rw-r--r--frontend/src/utils/settings.ts22
-rw-r--r--frontend/src/wsrouter.ts45
7 files changed, 113 insertions, 74 deletions
diff --git a/backend/decky_loader/loader.py b/backend/decky_loader/loader.py
index 85b5de58..8721ea05 100644
--- a/backend/decky_loader/loader.py
+++ b/backend/decky_loader/loader.py
@@ -70,7 +70,7 @@ class FileChangeHandler(RegexMatchingEventHandler):
self.maybe_reload(src_path)
class Loader:
- def __init__(self, server_instance: PluginManager, plugin_path: str, loop: AbstractEventLoop, live_reload: bool = False) -> None:
+ def __init__(self, server_instance: PluginManager, ws: WSRouter, plugin_path: str, loop: AbstractEventLoop, live_reload: bool = False) -> None:
self.loop = loop
self.logger = getLogger("Loader")
self.plugin_path = plugin_path
@@ -88,10 +88,7 @@ class Loader:
self.observer.start()
self.loop.create_task(self.enable_reload_wait())
- self.ws = WSRouter()
-
- server_instance.web_app.add_routes([
- web.get("/ws", self.ws.handle),
+ server_instance.add_routes([
web.get("/frontend/{path:.*}", self.handle_frontend_assets),
web.get("/locales/{path:.*}", self.handle_frontend_locales),
web.get("/plugins", self.get_plugins),
@@ -101,6 +98,15 @@ class Loader:
web.post("/plugins/{plugin_name}/reload", self.handle_backend_reload_request)
])
+ ws.add_route("test", self.test_method)
+
+ async def test_method():
+ await sleep(2)
+
+ return {
+ "test data": True
+ }
+
async def enable_reload_wait(self):
if self.live_reload:
await sleep(10)
diff --git a/backend/decky_loader/main.py b/backend/decky_loader/main.py
index fae30574..e33f0a9b 100644
--- a/backend/decky_loader/main.py
+++ b/backend/decky_loader/main.py
@@ -1,6 +1,7 @@
# Change PyInstaller files permissions
import sys
from typing import Dict
+from wsrouter import WSRouter
from .localplatform.localplatform import (chmod, chown, service_stop, service_start,
ON_WINDOWS, get_log_level, get_live_reload,
get_server_port, get_server_host, get_chown_plugin_path,
@@ -63,7 +64,8 @@ class PluginManager:
allow_credentials=True
)
})
- self.plugin_loader = Loader(self, plugin_path, self.loop, get_live_reload())
+ self.ws = WSRouter(self.loop, self.web_app)
+ self.plugin_loader = Loader(self, self.ws, plugin_path, self.loop, get_live_reload())
self.settings = SettingsManager("loader", path.join(get_privileged_path(), "settings"))
self.plugin_browser = PluginBrowser(plugin_path, self.plugin_loader.plugins, self.plugin_loader, self.settings)
self.utilities = Utilities(self)
diff --git a/backend/decky_loader/utilities.py b/backend/decky_loader/utilities.py
index f04ed371..20280c24 100644
--- a/backend/decky_loader/utilities.py
+++ b/backend/decky_loader/utilities.py
@@ -63,7 +63,11 @@ class Utilities:
web.post("/methods/{method_name}", self._handle_server_method_call)
])
- async def _handle_server_method_call(self, request: web.Request):
+ context.ws.add_route("utilities/ping", self.ping)
+ context.ws.add_route("utilities/settings/get", self.get_setting)
+ context.ws.add_route("utilities/settings/set", self.set_setting)
+
+ async def _handle_server_method_call(self, request):
method_name = request.match_info["method_name"]
try:
args = await request.json()
diff --git a/backend/decky_loader/wsrouter.py b/backend/decky_loader/wsrouter.py
index 9c8fe424..2b4c3a3b 100644
--- a/backend/decky_loader/wsrouter.py
+++ b/backend/decky_loader/wsrouter.py
@@ -1,15 +1,18 @@
from logging import getLogger
-from asyncio import Future
+from asyncio import AbstractEventLoop, Future
-from aiohttp import web, WSMsgType
+from aiohttp import WSMsgType
+from aiohttp.web import Application, WebSocketResponse, Request, Response, get
from enum import Enum
-from typing import Dict, Any, Callable
+from typing import Dict
from traceback import format_exc
+from helpers import get_csrf_token
+
class MessageType(Enum):
# Call-reply
CALL = 0
@@ -23,7 +26,8 @@ class MessageType(Enum):
# see wsrouter.ts for typings
class WSRouter:
- def __init__(self) -> None:
+ def __init__(self, loop: AbstractEventLoop, server_instance: Application) -> None:
+ self.loop = loop
self.ws = None
self.req_id = 0
self.routes = {}
@@ -31,12 +35,25 @@ class WSRouter:
# self.subscriptions: Dict[str, Callable[[Any]]] = {}
self.logger = getLogger("WSRouter")
- async def add_route(self, name, route):
+ server_instance.add_routes([
+ get("/ws", self.handle)
+ ])
+
+ async def write(self, dta: Dict[str, any]):
+ await self.ws.send_json(dta)
+
+ def add_route(self, name: str, route):
self.routes[name] = route
- async def handle(self, request):
+ def remove_route(self, name: str):
+ del self.routes[name]
+
+ async def handle(self, request: Request):
+ # Auth is a query param as JS WebSocket doesn't support headers
+ if request.rel_url.query["auth"] != get_csrf_token():
+ return Response(text='Forbidden', status='403')
self.logger.debug('Websocket connection starting')
- ws = web.WebSocketResponse()
+ ws = WebSocketResponse()
await ws.prepare(request)
self.logger.debug('Websocket connection ready')
@@ -58,29 +75,29 @@ class WSRouter:
# TODO DO NOT RELY ON THIS!
break
else:
- match data.type:
- case MessageType.CALL:
+ data = msg.json()
+ match data["type"]:
+ case MessageType.CALL.value:
# do stuff with the message
- data = msg.json()
- if self.routes[data.route]:
+ if self.routes[data["route"]]:
try:
- res = await self.routes[data.route](*data.args)
- await self.write({"type": MessageType.REPLY, "id": data.id, "result": res})
- self.logger.debug(f"Started PY call {data.route} ID {data.id}")
+ res = await self.routes[data["route"]](*data["args"])
+ await self.write({"type": MessageType.REPLY.value, "id": data["id"], "result": res})
+ self.logger.debug(f'Started PY call {data["route"]} ID {data["id"]}')
except:
- await self.write({"type": MessageType.ERROR, "id": data.id, "error": format_exc()})
+ await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()})
else:
- await self.write({"type": MessageType.ERROR, "id": data.id, "error": "Route does not exist."})
- case MessageType.REPLY:
- if self.running_calls[data.id]:
- self.running_calls[data.id].set_result(data.result)
- del self.running_calls[data.id]
- self.logger.debug(f"Resolved JS call {data.id} with value {str(data.result)}")
- case MessageType.ERROR:
- if self.running_calls[data.id]:
- self.running_calls[data.id].set_exception(data.error)
- del self.running_calls[data.id]
- self.logger.debug(f"Errored JS call {data.id} with error {data.error}")
+ await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route does not exist."})
+ case MessageType.REPLY.value:
+ if self.running_calls[data["id"]]:
+ self.running_calls[data["id"]].set_result(data["result"])
+ del self.running_calls[data["id"]]
+ self.logger.debug(f'Resolved JS call {data["id"]} with value {str(data["result"])}')
+ case MessageType.ERROR.value:
+ if self.running_calls[data["id"]]:
+ self.running_calls[data["id"]].set_exception(data["error"])
+ del self.running_calls[data["id"]]
+ self.logger.debug(f'Errored JS call {data["id"]} with error {data["error"]}')
case _:
self.logger.error("Unknown message type", data)
@@ -94,5 +111,17 @@ class WSRouter:
self.logger.debug('Websocket connection closed')
return ws
- async def write(self, dta: Dict[str, any]):
- await self.ws.send_json(dta) \ No newline at end of file
+ async def call(self, route: str, *args):
+ future = Future()
+
+ self.req_id += 1
+
+ id = self.req_id
+
+ self.running_calls[id] = future
+
+ self.logger.debug(f'Calling JS method {route} with args {str(args)}')
+
+ self.write({ "type": MessageType.CALL.value, "route": route, "args": args, "id": id })
+
+ return await future
diff --git a/frontend/src/plugin-loader.tsx b/frontend/src/plugin-loader.tsx
index e5f69f1f..86592016 100644
--- a/frontend/src/plugin-loader.tsx
+++ b/frontend/src/plugin-loader.tsx
@@ -34,6 +34,7 @@ import Toaster from './toaster';
import { VerInfo, callUpdaterMethod } from './updater';
import { getSetting, setSetting } from './utils/settings';
import TranslationHelper, { TranslationClass } from './utils/TranslationHelper';
+import { WSRouter } from './wsrouter';
const StorePage = lazy(() => import('./components/store/Store'));
const SettingsPage = lazy(() => import('./components/settings'));
@@ -48,6 +49,8 @@ class PluginLoader extends Logger {
public toaster: Toaster = new Toaster();
private deckyState: DeckyState = new DeckyState();
+ public ws: WSRouter = new WSRouter();
+
public hiddenPluginsService = new HiddenPluginsService(this.deckyState);
public notificationService = new NotificationService(this.deckyState);
@@ -102,9 +105,11 @@ class PluginLoader extends Logger {
initFilepickerPatches();
- this.getUserInfo();
+ this.ws.connect().then(() => {
+ this.getUserInfo();
- this.updateVersion();
+ this.updateVersion();
+ });
}
public async getUserInfo() {
diff --git a/frontend/src/utils/settings.ts b/frontend/src/utils/settings.ts
index cadfe935..d390d7ba 100644
--- a/frontend/src/utils/settings.ts
+++ b/frontend/src/utils/settings.ts
@@ -1,24 +1,8 @@
-interface GetSettingArgs<T> {
- key: string;
- default: T;
-}
-
-interface SetSettingArgs<T> {
- key: string;
- value: T;
-}
-
export async function getSetting<T>(key: string, def: T): Promise<T> {
- const res = (await window.DeckyPluginLoader.callServerMethod('get_setting', {
- key,
- default: def,
- } as GetSettingArgs<T>)) as { result: T };
- return res.result;
+ const res = await window.DeckyPluginLoader.ws.call<[string, T], T>('utilities/settings/get', key, def);
+ return res;
}
export async function setSetting<T>(key: string, value: T): Promise<void> {
- await window.DeckyPluginLoader.callServerMethod('set_setting', {
- key,
- value,
- } as SetSettingArgs<T>);
+ await window.DeckyPluginLoader.ws.call<[string, T], void>('utilities/settings/set', key, value);
}
diff --git a/frontend/src/wsrouter.ts b/frontend/src/wsrouter.ts
index b6437568..3a36b5b0 100644
--- a/frontend/src/wsrouter.ts
+++ b/frontend/src/wsrouter.ts
@@ -41,10 +41,11 @@ interface PromiseResolver<T> {
promise: Promise<T>;
}
-class WSRouter extends Logger {
+export class WSRouter extends Logger {
routes: Map<string, (...args: any) => any> = new Map();
runningCalls: Map<number, PromiseResolver<any>> = new Map();
ws?: WebSocket;
+ connectPromise?: Promise<void>;
// Used to map results and errors to calls
reqId: number = 0;
constructor() {
@@ -52,30 +53,35 @@ class WSRouter extends Logger {
}
connect() {
- this.ws = new WebSocket('ws://127.0.0.1:1337/ws');
-
- this.ws.addEventListener('message', this.onMessage.bind(this));
- this.ws.addEventListener('close', this.onError.bind(this));
- this.ws.addEventListener('message', this.onError.bind(this));
+ return (this.connectPromise = new Promise<void>((resolve) => {
+ // Auth is a query param as JS WebSocket doesn't support headers
+ this.ws = new WebSocket(`ws://127.0.0.1:1337/ws?auth=${window.deckyAuthToken}`);
+
+ this.ws.addEventListener('open', () => {
+ this.debug('WS Connected');
+ resolve();
+ delete this.connectPromise;
+ });
+ this.ws.addEventListener('message', this.onMessage.bind(this));
+ this.ws.addEventListener('close', this.onError.bind(this));
+ // this.ws.addEventListener('error', this.onError.bind(this));
+ }));
}
createPromiseResolver<T>(): PromiseResolver<T> {
- let resolver: PromiseResolver<T>;
+ let resolver: Partial<PromiseResolver<T>> = {};
const promise = new Promise<T>((resolve, reject) => {
- resolver = {
- promise,
- resolve,
- reject,
- };
- this.debug('Created new PromiseResolver');
+ resolver.resolve = resolve;
+ resolver.reject = reject;
});
- this.debug('Returning new PromiseResolver');
+ resolver.promise = promise;
// The promise will always run first
// @ts-expect-error 2454
return resolver;
}
- write(data: Message) {
+ async write(data: Message) {
+ if (this.connectPromise) await this.connectPromise;
this.ws?.send(JSON.stringify(data));
}
@@ -129,9 +135,9 @@ class WSRouter extends Logger {
} catch (e) {
this.error('Error parsing WebSocket message', e);
}
- this.call<[number, number], string>('methodName', 1, 2);
}
+ // this.call<[number, number], string>('methodName', 1, 2);
call<Args extends any[] = any[], Return = void>(route: string, ...args: Args): Promise<Return> {
const resolver = this.createPromiseResolver<Return>();
@@ -139,12 +145,15 @@ class WSRouter extends Logger {
this.runningCalls.set(id, resolver);
+ this.debug(`Calling PY method ${route} with args`, args);
+
this.write({ type: MessageType.CALL, route, args, id });
return resolver.promise;
}
- onError(error: any) {
- this.error('WS ERROR', error);
+ async onError(error: any) {
+ this.error('WS DISCONNECTED', error);
+ await this.connect();
}
}