diff options
Diffstat (limited to 'backend/decky_loader/utilities.py')
| -rw-r--r-- | backend/decky_loader/utilities.py | 94 |
1 files changed, 91 insertions, 3 deletions
diff --git a/backend/decky_loader/utilities.py b/backend/decky_loader/utilities.py index 174b7cb0..d7d16f04 100644 --- a/backend/decky_loader/utilities.py +++ b/backend/decky_loader/utilities.py @@ -1,6 +1,8 @@ from __future__ import annotations from os import stat_result import uuid +from urllib.parse import unquote +from json.decoder import JSONDecodeError from os.path import splitext import re from traceback import format_exc @@ -8,6 +10,7 @@ from stat import FILE_ATTRIBUTE_HIDDEN # type: ignore from asyncio import StreamReader, StreamWriter, start_server, gather, open_connection from aiohttp import ClientSession +from aiohttp.web import Request, StreamResponse, Response, json_response, post from typing import TYPE_CHECKING, Callable, Coroutine, Dict, Any, List, TypedDict from logging import getLogger @@ -26,12 +29,17 @@ class FilePickerObj(TypedDict): filest: stat_result is_dir: bool +decky_header_regex = re.compile("X-Decky-(.*)") +extra_header_regex = re.compile("X-Decky-Header-(.*)") + +excluded_default_headers = ["Host", "Origin", "Sec-Fetch-Site", "Sec-Fetch-Mode", "Sec-Fetch-Dest"] + class Utilities: def __init__(self, context: PluginManager) -> None: self.context = context self.legacy_util_methods: Dict[str, Callable[..., Coroutine[Any, Any, Any]]] = { "ping": self.ping, - "http_request": self.http_request, + "http_request": self.http_request_legacy, "install_plugin": self.install_plugin, "install_plugins": self.install_plugins, "cancel_plugin_install": self.cancel_plugin_install, @@ -76,9 +84,33 @@ class Utilities: context.ws.add_route("utilities/enable_rdt", self.enable_rdt) context.ws.add_route("utilities/get_tab_id", self.get_tab_id) context.ws.add_route("utilities/get_user_info", self.get_user_info) - context.ws.add_route("utilities/http_request", self.http_request) + context.ws.add_route("utilities/http_request", self.http_request_legacy) context.ws.add_route("utilities/_call_legacy_utility", self._call_legacy_utility) + context.web_app.add_routes([ + post("/methods/{method_name}", self._handle_legacy_server_method_call) + ]) + + for method in ('GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'): + context.web_app.router.add_route(method, "/fetch", self.http_request) + + + async def _handle_legacy_server_method_call(self, request: Request) -> Response: + method_name = request.match_info["method_name"] + try: + args = await request.json() + except JSONDecodeError: + args = {} + res = {} + try: + r = await self.legacy_util_methods[method_name](**args) + res["result"] = r + res["success"] = True + except Exception as e: + res["result"] = str(e) + res["success"] = False + return json_response(res) + async def _call_legacy_utility(self, method_name: str, kwargs: Dict[Any, Any]) -> Any: self.logger.debug(f"Calling utility {method_name} with legacy kwargs"); res: Dict[Any, Any] = {} @@ -114,7 +146,63 @@ class Utilities: async def uninstall_plugin(self, name: str): return await self.context.plugin_browser.uninstall_plugin(name) - async def http_request(self, method: str, url: str, extra_opts: Any = {}): + # Loosely based on https://gist.github.com/mosquito/4dbfacd51e751827cda7ec9761273e95#file-proxy-py + async def http_request(self, req: Request) -> StreamResponse: + if req.headers.get('X-Decky-Auth', '') != helpers.get_csrf_token() and req.query.get('auth', '') != helpers.get_csrf_token(): + return Response(text='Forbidden', status=403) + + url = req.headers["X-Decky-Fetch-URL"] if "X-Decky-Fetch-URL" in req.headers else unquote(req.query.get('fetch_url', '')) + self.logger.info(f"Preparing {req.method} request to {url}") + + headers = dict(req.headers) + + headers["User-Agent"] = helpers.user_agent + + for excluded_header in excluded_default_headers: + self.logger.debug(f"Excluding default header {excluded_header}") + if excluded_header in headers: + del headers[excluded_header] + + if "X-Decky-Fetch-Excluded-Headers" in req.headers: + for excluded_header in req.headers["X-Decky-Fetch-Excluded-Headers"].split(", "): + self.logger.debug(f"Excluding header {excluded_header}") + if excluded_header in headers: + del headers[excluded_header] + + for header in req.headers: + match = extra_header_regex.search(header) + if match: + header_name = match.group(1) + header_value = req.headers[header] + self.logger.debug(f"Adding extra header {header_name}: {header_value}") + headers[header_name] = header_value + + for header in list(headers.keys()): + match = decky_header_regex.search(header) + if match: + self.logger.debug(f"Removing decky header {header} from request") + del headers[header] + + self.logger.debug(f"Final request headers: {headers}") + + body = await req.read() # TODO can this also be streamed? + + async with ClientSession() as web: + async with web.request(req.method, url, headers=headers, data=body, ssl=helpers.get_ssl_context()) as web_res: + res = StreamResponse(headers=web_res.headers, status=web_res.status) + if web_res.headers.get('Transfer-Encoding', '').lower() == 'chunked': + res.enable_chunked_encoding() + + await res.prepare(req) + self.logger.debug(f"Starting stream for {url}") + async for data in web_res.content.iter_any(): + await res.write(data) + if data: + await res.drain() + self.logger.debug(f"Finished stream for {url}") + return res + + async def http_request_legacy(self, method: str, url: str, extra_opts: Any = {}): async with ClientSession() as web: res = await web.request(method, url, ssl=helpers.get_ssl_context(), **extra_opts) text = await res.text() |
