diff options
Diffstat (limited to 'backend/plugin/python_plugin.py')
| -rw-r--r-- | backend/plugin/python_plugin.py | 117 |
1 files changed, 117 insertions, 0 deletions
diff --git a/backend/plugin/python_plugin.py b/backend/plugin/python_plugin.py new file mode 100644 index 00000000..c2833088 --- /dev/null +++ b/backend/plugin/python_plugin.py @@ -0,0 +1,117 @@ +import json +import multiprocessing +import os +import uuid +from asyncio import (Protocol, get_event_loop, new_event_loop, set_event_loop, + sleep) +from importlib.util import module_from_spec, spec_from_file_location +from posixpath import join +from signal import SIGINT, signal +from tempfile import mkdtemp + +from plugin_protocol import PluginProtocolServer + +multiprocessing.set_start_method("fork") + +# only useable by the python backend +class PluginProtocolClient(Protocol): + def __init__(self, backend, logger) -> None: + super().__init__() + self.backend = backend + self.logger = logger + + def connection_made(self, transport): + self.transport = transport + + def data_received(self, data: bytes) -> None: + message = json.loads(data.decode("utf-8")) + message_id = str(uuid.UUID(message["id"])) + message_type = message["type"] + payload = message["payload"] + + self.logger.debug(f"received {message_id} {message_type} {payload}") + if message_type == "method_call": + get_event_loop().create_task(self.handle_method_call(message_id, payload["name"], payload["args"])) + + async def handle_method_call(self, message_id, method_name, method_args): + try: + result = await self.backend.execute_method(method_name, method_args) + self.respond_message(message_id, "method_response", dict(success = True, result = result)) + except AttributeError as e: + self.respond_message(message_id, "method_response", dict(success = False, result = f"plugin does not expose a method called {method_name}")) + except Exception as e: + self.respond_message(message_id, "method_response", dict(success = False, result = str(e))) + + def respond_message(self, message_id, message_type, payload): + self.logger.debug(f"sending {message_id} {message_type} {payload}") + message = json.dumps(dict(id = str(message_id), type = message_type, payload = payload)) + self.transport.write(message.encode('utf-8')) + + +class PythonPlugin: + def __init__(self, plugin_directory, file_name, flags, logger) -> None: + self.client = PluginProtocolClient(self, logger) + self.server = PluginProtocolServer(self) + self.connection = None + + self.plugin_directory = plugin_directory + self.file_name = file_name + self.flags = flags + self.logger = logger + + def _init(self): + self.logger.debug(f"child process Initializing") + signal(SIGINT, lambda s, f: exit(0)) + + set_event_loop(new_event_loop()) + # TODO: both processes can access the socket + # setuid(0 if "root" in self.flags else 1000) + spec = spec_from_file_location("_", join(self.plugin_directory, self.file_name)) + module = module_from_spec(spec) + spec.loader.exec_module(module) + self.Plugin = module.Plugin + + if hasattr(self.Plugin, "_main"): + self.logger.debug("Found _main, calling it") + get_event_loop().create_task(self.Plugin._main(self.Plugin)) + + get_event_loop().create_task(self._connect()) + get_event_loop().run_forever() + + async def _connect(self): + self.logger.debug(f"connecting to unix server on {self.unix_socket_path}") + await get_event_loop().create_unix_connection(lambda: self.client, path=self.unix_socket_path) + + async def start(self): + if self.connection: + self.connection.close() + + self.unix_socket_path = PythonPlugin.generate_socket_path() + self.logger.debug(f"starting unix server on {self.unix_socket_path}") + self.connection = await get_event_loop().create_unix_server(lambda: self.server, path=self.unix_socket_path) + + self.process = multiprocessing.Process(target=self._init) + self.process.start() + get_event_loop().create_task(self.process_loop()) + + async def process_loop(self): + await get_event_loop().run_in_executor(None, self.process.join) + self.logger.info("backend process was killed - restarting in 10 seconds") + await sleep(10) + await self.start() + + # called on the server/loader process + async def call_method(self, method_name, method_args): + if not self.process.is_alive(): + return dict(success = False, result = "Process not alive") + + return await self.server.call_method(method_name, method_args) + + # called on the client + def execute_method(self, method_name, method_args): + return getattr(self.Plugin, method_name)(self.Plugin, **method_args) + + def generate_socket_path(): + tmp_dir = mkdtemp("decky-plugin") + os.chown(tmp_dir, 1000, 1000) + return join(tmp_dir, "socket") |
