summaryrefslogtreecommitdiff
path: root/backend/plugin/python_plugin.py
blob: c283308826bac81d9fabf25aab2a699f0e28f784 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import codecs
import inspect
import json
import multiprocessing
import os
import shutil
import sys
import uuid
from asyncio import (Protocol, get_event_loop, new_event_loop, set_event_loop,
                     sleep)
from importlib.util import module_from_spec, spec_from_file_location
from posixpath import join
from signal import SIGINT, signal
from tempfile import mkdtemp

from plugin_protocol import PluginProtocolServer

# Plugin children are started with ``multiprocessing.Process(target=self._init)``
# on a PythonPlugin instance that holds loader-side event-loop objects (e.g. the
# unix server returned by create_unix_server). The "fork" start method lets the
# child inherit that state instead of having to pickle it.
# NOTE(review): set_start_method raises RuntimeError if a start method was
# already set in this process — confirm nothing sets it before this import.
multiprocessing.set_start_method("fork")

# only useable by the python backend
class PluginProtocolClient(Protocol):
    """Child-process side of the plugin RPC channel.

    Reads JSON messages from the loader's unix socket, runs ``method_call``
    requests through ``backend.execute_method`` and writes a
    ``method_response`` message back on the same transport.
    """

    def __init__(self, backend, logger) -> None:
        super().__init__()
        self.backend = backend
        self.logger = logger
        self.transport = None
        # Messages are framed only by their JSON syntax (respond_message below
        # writes no delimiter; the server is assumed to do the same). A single
        # stream read may hold part of a message or several messages, so
        # received text is buffered and parsed one document at a time.
        self._buffer = ""
        self._utf8_decoder = codecs.getincrementaldecoder("utf-8")()
        self._json_decoder = json.JSONDecoder()
        # Strong references to running handler tasks: the event loop only keeps
        # weak references, so an unreferenced task could be collected mid-run.
        self._tasks = set()

    def connection_made(self, transport):
        self.transport = transport
        # Start each connection with an empty parse state.
        self._buffer = ""
        self._utf8_decoder.reset()

    def data_received(self, data: bytes) -> None:
        """Buffer *data* and dispatch every complete JSON message it finishes."""
        # The incremental decoder holds back a UTF-8 sequence split across reads.
        self._buffer += self._utf8_decoder.decode(data)
        while True:
            pending = self._buffer.lstrip()
            if not pending:
                self._buffer = ""
                return
            try:
                message, end = self._json_decoder.raw_decode(pending)
            except json.JSONDecodeError:
                # Treated as an incomplete document: keep it and wait for more data.
                # NOTE(review): genuinely malformed input also stays buffered here.
                self._buffer = pending
                return
            self._buffer = pending[end:]
            self._dispatch(message)

    def _dispatch(self, message) -> None:
        """Log one decoded message and schedule a handler for ``method_call``."""
        message_id = str(uuid.UUID(message["id"]))
        message_type = message["type"]
        payload = message["payload"]

        self.logger.debug(f"received {message_id} {message_type} {payload}")
        if message_type == "method_call":
            task = get_event_loop().create_task(
                self.handle_method_call(message_id, payload["name"], payload["args"])
            )
            self._tasks.add(task)
            task.add_done_callback(self._tasks.discard)

    async def handle_method_call(self, message_id, method_name, method_args):
        """Run *method_name* on the backend and send a ``method_response``.

        The response payload is ``{"success": True, "result": <value>}`` on
        success, otherwise ``{"success": False, "result": <error text>}``.
        """
        try:
            outcome = self.backend.execute_method(method_name, method_args)
        except AttributeError:
            # Only the lookup/invocation step maps to "missing method"; an
            # AttributeError raised while an async method body runs is reported
            # with its own message below.
            self.respond_message(message_id, "method_response", dict(success = False, result = f"plugin does not expose a method called {method_name}"))
            return
        except Exception as e:
            self.respond_message(message_id, "method_response", dict(success = False, result = str(e)))
            return

        try:
            # async plugin methods return a coroutine; sync ones return the value.
            result = await outcome if inspect.isawaitable(outcome) else outcome
            self.respond_message(message_id, "method_response", dict(success = True, result = result))
        except Exception as e:
            # Also covers results that json.dumps cannot serialize.
            self.respond_message(message_id, "method_response", dict(success = False, result = str(e)))

    def respond_message(self, message_id, message_type, payload):
        """Serialize a message as JSON and write it to the transport as UTF-8."""
        self.logger.debug(f"sending {message_id} {message_type} {payload}")
        message = json.dumps(dict(id = str(message_id), type = message_type, payload = payload))
        self.transport.write(message.encode('utf-8'))


class PythonPlugin:
    """Loader-side handle for one Python plugin.

    ``start()`` opens a unix-socket server (served by ``PluginProtocolServer``)
    and forks a child process (``_init``) that imports the plugin file and
    connects back to that socket with a ``PluginProtocolClient``. Calls made via
    ``call_method`` in the loader are forwarded over the socket and executed in
    the child by ``execute_method``. When the child exits it is restarted after
    10 seconds.
    """

    def __init__(self, plugin_directory, file_name, flags, logger) -> None:
        self.client = PluginProtocolClient(self, logger)
        self.server = PluginProtocolServer(self)
        self.connection = None
        self.process = None
        self.unix_socket_path = None

        self.plugin_directory = plugin_directory
        self.file_name = file_name
        self.flags = flags
        self.logger = logger

        # Strong references to fire-and-forget tasks; the event loop keeps only
        # weak references to tasks.
        self._tasks = set()

    def _spawn_task(self, coro):
        """Schedule *coro* on the current event loop and keep it referenced until done."""
        task = get_event_loop().create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    def _init(self):
        """Entry point of the forked child: load the plugin and serve its methods."""
        self.logger.debug(f"child process Initializing")
        signal(SIGINT, lambda s, f: sys.exit(0))
        # Tasks copied from the parent belong to the parent's loop; start fresh.
        self._tasks = set()

        set_event_loop(new_event_loop())
        # TODO: both processes can access the socket
        # setuid(0 if "root" in self.flags else 1000)
        spec = spec_from_file_location("_", join(self.plugin_directory, self.file_name))
        module = module_from_spec(spec)
        spec.loader.exec_module(module)
        self.Plugin = module.Plugin

        if hasattr(self.Plugin, "_main"):
            self.logger.debug("Found _main, calling it")
            # The plugin class itself is passed as ``self``.
            self._spawn_task(self.Plugin._main(self.Plugin))

        self._spawn_task(self._connect())
        get_event_loop().run_forever()

    async def _connect(self):
        """Connect the child's protocol client to the loader's unix socket."""
        self.logger.debug(f"connecting to unix server on {self.unix_socket_path}")
        await get_event_loop().create_unix_connection(lambda: self.client, path=self.unix_socket_path)

    def _remove_socket_dir(self):
        """Delete the temporary directory that held the previous socket, if any."""
        if self.unix_socket_path:
            shutil.rmtree(os.path.dirname(self.unix_socket_path), ignore_errors=True)

    async def start(self):
        """(Re)start the socket server and the plugin child process."""
        if self.connection:
            self.connection.close()
            # Each start uses a fresh mkdtemp directory; drop the old one so
            # restarts do not leak temp directories and socket files.
            self._remove_socket_dir()

        self.unix_socket_path = PythonPlugin.generate_socket_path()
        self.logger.debug(f"starting unix server on {self.unix_socket_path}")
        self.connection = await get_event_loop().create_unix_server(lambda: self.server, path=self.unix_socket_path)

        self.process = multiprocessing.Process(target=self._init)
        self.process.start()
        self._spawn_task(self.process_loop())

    async def process_loop(self):
        """Wait for the child to exit, then restart the plugin after 10 seconds."""
        # Process.join blocks, so run it in the default executor to keep the loop free.
        await get_event_loop().run_in_executor(None, self.process.join)
        self.logger.info("backend process was killed - restarting in 10 seconds")
        await sleep(10)
        await self.start()

    # called on the server/loader process
    async def call_method(self, method_name, method_args):
        """Forward a method call to the child; fail fast if it is not running."""
        if self.process is None or not self.process.is_alive():
            return dict(success = False, result = "Process not alive")

        return await self.server.call_method(method_name, method_args)

    # called on the client
    def execute_method(self, method_name, method_args):
        """Call ``Plugin.<method_name>`` with the class as ``self`` and keyword args.

        Returns whatever the method returns (a coroutine for ``async def``
        methods). A missing method raises AttributeError from ``getattr``.
        """
        return getattr(self.Plugin, method_name)(self.Plugin, **method_args)

    @staticmethod
    def generate_socket_path():
        """Create a private temp directory and return a socket path inside it."""
        tmp_dir = mkdtemp(prefix="decky-plugin-")
        # NOTE(review): hard-coded uid/gid 1000 — presumably the desktop user.
        os.chown(tmp_dir, 1000, 1000)
        return join(tmp_dir, "socket")