summaryrefslogtreecommitdiff
path: root/backend/decky_loader/wsrouter.py
blob: 7a7b59c925125745db3dd393cacc6bdfbd5a513a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from asyncio import AbstractEventLoop, Future, get_running_loop
from enum import Enum
from logging import getLogger
from traceback import format_exc
from typing import Any, Callable, Dict, Optional

from aiohttp import WSMsgType
from aiohttp.web import Application, WebSocketResponse, Request, Response, get

from helpers import get_csrf_token

class MessageType(Enum):
    """Kind of a message exchanged over the router's WebSocket.

    The integer value of a member is what is sent in each JSON message's
    ``type`` field.
    """

    # --- call / reply exchange ---
    CALL = 0   # ask the other side to run a named route
    REPLY = 1  # successful answer to an earlier CALL, matched by its id
    ERROR = 2  # failed answer to an earlier CALL, matched by its id

    # --- reserved for a future pub/sub extension (not implemented) ---
    # SUBSCRIBE = 3
    # UNSUBSCRIBE = 4
    # PUBLISH = 5

# see wsrouter.ts for typings

class WSRouter:
    def __init__(self, loop: AbstractEventLoop, server_instance: Application) -> None:
        self.loop = loop
        self.ws = None
        self.req_id = 0
        self.routes = {}
        self.running_calls: Dict[int, Future] = {}
        # self.subscriptions: Dict[str, Callable[[Any]]] = {}
        self.logger = getLogger("WSRouter")

        server_instance.add_routes([
            get("/ws", self.handle)
        ])

    async def write(self, dta: Dict[str, any]):
        await self.ws.send_json(dta)

    def add_route(self, name: str, route):
        self.routes[name] = route

    def remove_route(self, name: str):
        del self.routes[name]

    async def handle(self, request: Request):
        # Auth is a query param as JS WebSocket doesn't support headers
        if request.rel_url.query["auth"] != get_csrf_token():
            return Response(text='Forbidden', status='403') 
        self.logger.debug('Websocket connection starting')
        ws = WebSocketResponse()
        await ws.prepare(request)
        self.logger.debug('Websocket connection ready')

        if self.ws != None:
            try:
                await self.ws.close()
            except:
                pass
            self.ws = None

        self.ws = ws
        
        try:
            async for msg in ws:
                self.logger.debug(msg)
                if msg.type == WSMsgType.TEXT:
                    self.logger.debug(msg.data)
                    if msg.data == 'close':
                        # TODO DO NOT RELY ON THIS!
                        break
                    else:
                        data = msg.json()
                        match data["type"]:
                            case MessageType.CALL.value:
                                # do stuff with the message
                                if data["route"] in self.routes:
                                    try:
                                        res = await self.routes[data["route"]](*data["args"])
                                        await self.write({"type": MessageType.REPLY.value, "id": data["id"], "result": res})
                                        self.logger.debug(f'Started PY call {data["route"]} ID {data["id"]}')
                                    except:
                                        await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": format_exc()})
                                else:
                                    # Dunno why but fstring doesnt work here
                                    await self.write({"type": MessageType.ERROR.value, "id": data["id"], "error": "Route " + data["route"] + " does not exist."})
                            case MessageType.REPLY.value:
                                if self.running_calls[data["id"]]:
                                    self.running_calls[data["id"]].set_result(data["result"])
                                    del self.running_calls[data["id"]]
                                    self.logger.debug(f'Resolved JS call {data["id"]} with value {str(data["result"])}')
                            case MessageType.ERROR.value:
                                if self.running_calls[data["id"]]:
                                    self.running_calls[data["id"]].set_exception(data["error"])
                                    del self.running_calls[data["id"]]
                                    self.logger.debug(f'Errored JS call {data["id"]} with error {data["error"]}')

                            case _:
                                self.logger.error("Unknown message type", data)
        finally:
            try:
                await ws.close()
                self.ws = None
            except:
                pass

        self.logger.debug('Websocket connection closed')
        return ws

    async def call(self, route: str, *args):
        future = Future()

        self.req_id += 1

        id = self.req_id

        self.running_calls[id] = future

        self.logger.debug(f'Calling JS method {route} with args {str(args)}')

        self.write({ "type": MessageType.CALL.value, "route": route, "args": args, "id": id })

        return await future