summaryrefslogtreecommitdiff
path: root/backend/decky_loader/plugin/sandboxed_plugin.py
blob: 23575900f9381f72f6643c4749a0e0e60a3a98bd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import sys
from os import path, environ
from importlib.util import module_from_spec, spec_from_file_location
from json import dumps, loads
from logging import getLogger
from traceback import format_exc
from asyncio import (get_event_loop, new_event_loop,
                     set_event_loop)
from setproctitle import setproctitle, setthreadtitle

from .messages import SocketResponseDict, SocketMessageType
from ..localplatform.localsocket import LocalSocket
from ..localplatform.localplatform import setgid, setuid, get_username, get_home_path
from ..enums import UserType
from .. import helpers, settings, injector # pyright: ignore [reportUnusedImport]

from typing import List, TypeVar, Any

# Generic type variable. Nothing else in this module references it; presumably it is
# kept for importers elsewhere in the package — verify before removing.
DataType = TypeVar("DataType")

class SandboxedPlugin:
    def __init__(self,
                 name: str,
                 passive: bool,
                 flags: List[str],
                 file: str,
                 plugin_directory: str,
                 plugin_path: str,
                 version: str|None,
                 author: str,
                 api_version: int) -> None:
        self.name = name
        self.passive = passive
        self.flags = flags
        self.file = file
        self.plugin_path = plugin_path
        self.plugin_directory = plugin_directory
        self.version = version
        self.author = author
        self.api_version = api_version

        self.log = getLogger("sandboxed_plugin")

    def initialize(self, socket: LocalSocket):
        self._socket = socket

        try:
            setproctitle(f"{self.name} ({self.file})")
            setthreadtitle(self.name)

            set_event_loop(new_event_loop())
            if self.passive:
                return
            setgid(UserType.ROOT if "root" in self.flags else UserType.HOST_USER)
            setuid(UserType.ROOT if "root" in self.flags else UserType.HOST_USER)
            # export a bunch of environment variables to help plugin developers
            environ["HOME"] = get_home_path(UserType.ROOT if "root" in self.flags else UserType.HOST_USER)
            environ["USER"] = "root" if "root" in self.flags else get_username()
            environ["DECKY_VERSION"] = helpers.get_loader_version()
            environ["DECKY_USER"] = get_username()
            environ["DECKY_USER_HOME"] = helpers.get_home_path()
            environ["DECKY_HOME"] = helpers.get_homebrew_path()
            environ["DECKY_PLUGIN_SETTINGS_DIR"] = path.join(environ["DECKY_HOME"], "settings", self.plugin_directory)
            environ["DECKY_PLUGIN_RUNTIME_DIR"] = path.join(environ["DECKY_HOME"], "data", self.plugin_directory)
            environ["DECKY_PLUGIN_LOG_DIR"] = path.join(environ["DECKY_HOME"], "logs", self.plugin_directory)
            environ["DECKY_PLUGIN_DIR"] = path.join(self.plugin_path, self.plugin_directory)
            environ["DECKY_PLUGIN_NAME"] = self.name
            if self.version:
                environ["DECKY_PLUGIN_VERSION"] = self.version
            environ["DECKY_PLUGIN_AUTHOR"] = self.author

            # append the plugin's `py_modules` to the recognized python paths
            sys.path.append(path.join(environ["DECKY_PLUGIN_DIR"], "py_modules"))

            #TODO: FIX IN A LESS CURSED WAY
            keys = [key for key in sys.modules if key.startswith("decky_loader.")]
            for key in keys:
                sys.modules[key.replace("decky_loader.", "")] = sys.modules[key]
            
            from .imports import decky
            async def emit(event: str, *args: Any) -> None:
                await self._socket.write_single_line_server(dumps({
                    "type": SocketMessageType.EVENT,
                    "event": event,
                    "args": args
                }))
            # copy the docstring over so we don't have to duplicate it
            emit.__doc__ = decky.emit.__doc__
            decky.emit = emit
            sys.modules["decky"] = decky
            # provided for compatibility
            sys.modules["decky_plugin"] = decky

            spec = spec_from_file_location("_", self.file)
            assert spec is not None
            module = module_from_spec(spec)
            assert spec.loader is not None
            spec.loader.exec_module(module)
            # TODO fix self weirdness once plugin.json versioning is done. need this before WS release!
            if self.api_version > 0:
                self.Plugin = module.Plugin()
            else:
                self.Plugin = module.Plugin

            if hasattr(self.Plugin, "_migration"):
                if self.api_version > 0:
                    get_event_loop().run_until_complete(self.Plugin._migration())
                else:
                    get_event_loop().run_until_complete(self.Plugin._migration(self.Plugin))
            if hasattr(self.Plugin, "_main"):
                if self.api_version > 0:
                    get_event_loop().create_task(self.Plugin._main())
                else:
                    get_event_loop().create_task(self.Plugin._main(self.Plugin))
            get_event_loop().create_task(socket.setup_server(self.on_new_message))
        except:
            self.log.error("Failed to start " + self.name + "!\n" + format_exc())
            sys.exit(0)
        try:
            get_event_loop().run_forever()
        except SystemExit:
            pass
        except:
            self.log.error("Loop exited for " + self.name + "!\n" + format_exc())
        finally:
            get_event_loop().close()

    async def _unload(self):
        try:
            self.log.info("Attempting to unload with plugin " + self.name + "'s \"_unload\" function.\n")
            if hasattr(self.Plugin, "_unload"):
                if self.api_version > 0:
                    await self.Plugin._unload()
                else:
                    await self.Plugin._unload(self.Plugin)
                self.log.info("Unloaded " + self.name + "\n")
            else:
                self.log.info("Could not find \"_unload\" in " + self.name + "'s main.py" + "\n")
        except:
            self.log.error("Failed to unload " + self.name + "!\n" + format_exc())
            pass

    async def _uninstall(self):
        try:
            self.log.info("Attempting to uninstall with plugin " + self.name + "'s \"_uninstall\" function.\n")
            if hasattr(self.Plugin, "_uninstall"):
                if self.api_version > 0:
                    await self.Plugin._uninstall()
                else:
                    await self.Plugin._uninstall(self.Plugin)
                self.log.info("Uninstalled " + self.name + "\n")
            else:
                self.log.info("Could not find \"_uninstall\" in " + self.name + "'s main.py" + "\n")
        except:
            self.log.error("Failed to uninstall " + self.name + "!\n" + format_exc())
            pass

    async def on_new_message(self, message : str) -> str|None:
        data = loads(message)

        if "stop" in data:
            self.log.info(f"Calling Loader unload function for {self.name}.")
            await self._unload()

            if data.get('uninstall'):
                self.log.info("Calling Loader uninstall function.")
                await self._uninstall()

            self.log.debug("Stopping event loop")

            loop = get_event_loop()
            loop.call_soon_threadsafe(loop.stop)
            sys.exit(0)

        d: SocketResponseDict = {"type": SocketMessageType.RESPONSE, "res": None, "success": True, "id": data["id"]}
        try:
            if data.get("legacy"):
                if self.api_version > 0:
                    raise Exception("Legacy methods may not be used on api_version > 0")
                # Legacy kwargs
                d["res"] = await getattr(self.Plugin, data["method"])(self.Plugin, **data["args"])
            else:
                if self.api_version < 1 :
                    raise Exception("api_version 1 or newer is required to call methods with index-based arguments")
                # New args
                d["res"] = await getattr(self.Plugin, data["method"])(*data["args"])
        except Exception as e:
            d["res"] = str(e)
            d["success"] = False
        finally:
            return dumps(d, ensure_ascii=False)