# Copyright 2019 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import asyncio
import glob
import json
import logging
import os
import time
from functools import partial
from urllib.parse import quote_plus, urlencode

import requests_unixsocket

from subiquitycore.async_helpers import run_in_thread
from subiquitycore.utils import run_command

log = logging.getLogger("subiquitycore.snapd")

# Every method of the connection classes in this module blocks.  Do not
# call them from the main thread!  AsyncSnapd (at the bottom of the
# file) wraps them with run_in_thread for use from coroutines.


class SnapdChangeError(Exception):
    """Raised when a snapd change that is being waited on reports "Error".

    Attributes:
        change: the change id as returned by the initial POST.
        err: the value of the change's "err" field, or None if absent.
    """

    def __init__(self, change, err):
        super().__init__("snapd change {} failed: {}".format(change, err))
        self.change = change
        self.err = err


class SnapdConnection:
    """Blocking HTTP client for snapd over a unix domain socket."""

    def __init__(self, root, sock):
        """Remember the filesystem root and build the base URL for `sock`."""
        self.root = root
        # requests_unixsocket expects the socket path percent-encoded in
        # the host part of an http+unix:// URL.
        self.url_base = "http+unix://{}/".format(quote_plus(sock))

    def get(self, path, **args):
        """GET `path`; keyword arguments are sent as the query string.

        Returns the response object of the request (60 second timeout).
        """
        if args:
            path += "?" + urlencode(args)
        with requests_unixsocket.Session() as session:
            return session.get(self.url_base + path, timeout=60)

    def post(self, path, body, **args):
        """POST `body`, JSON-encoded, to `path`.

        Keyword arguments are sent as the query string.  Returns the
        response object of the request (60 second timeout).
        """
        if args:
            path += "?" + urlencode(args)
        with requests_unixsocket.Session() as session:
            return session.post(
                self.url_base + path, data=json.dumps(body), timeout=60
            )

    def configure_proxy(self, proxy):
        """Write a systemd drop-in for snapd with the proxy settings.

        The drop-in content comes from proxy.proxy_systemd_dropin() and
        is written below self.root.  When the root is "/", systemd is
        reloaded and snapd.service restarted; for any other root only a
        two second sleep is run in their place.
        """
        log.debug("restarting snapd to pick up proxy config")
        dropin_dir = os.path.join(self.root, "etc/systemd/system/snapd.service.d")
        os.makedirs(dropin_dir, exist_ok=True)
        with open(os.path.join(dropin_dir, "snap_proxy.conf"), "w") as fp:
            fp.write(proxy.proxy_systemd_dropin())
        if self.root == "/":
            cmds = [
                ["systemctl", "daemon-reload"],
                ["systemctl", "restart", "snapd.service"],
            ]
        else:
            cmds = [["sleep", "2"]]
        for cmd in cmds:
            run_command(cmd)


class _FakeFileResponse:
    """Response stand-in whose JSON payload is read from a file."""

    def __init__(self, path):
        self.path = path

    def raise_for_status(self):
        """Do nothing: a canned response never reports an HTTP error."""
        pass

    def json(self):
        """Load and return the JSON document stored at self.path."""
        with open(self.path) as fp:
            return json.load(fp)


class _FakeMemoryResponse:
    """Response stand-in whose JSON payload is an in-memory object."""

    def __init__(self, data):
        self.data = data

    def raise_for_status(self):
        """Do nothing: a canned response never reports an HTTP error."""
        pass

    def json(self):
        """Return the stored payload unchanged."""
        return self.data


class ResponseSet:
    """Responses for an endpoint that returns different data each time.

    Motivating example is v2/changes/$change_id."""

    def __init__(self, files):
        self.files = files
        self.index = 0

    def next(self):
        """Return the current recorded response and advance the cursor.

        The cursor moves by SUBIQUITY_REPLAY_TIMESCALE (default 1)
        entries per call.  It never moves past the final entry, so
        once the recording is exhausted the last response is returned
        again on every further call.
        """
        current = self.files[self.index]
        step = int(os.environ.get("SUBIQUITY_REPLAY_TIMESCALE", 1))
        # Make sure we return the last response even when we skip most
        # of them, and keep returning it if we are asked again after
        # the recording has run out (this used to raise IndexError when
        # the timescale was 1).
        self.index = min(self.index + step, len(self.files) - 1)
        return _FakeFileResponse(current)


class FakeSnapdConnection:
    """Drop-in for SnapdConnection that answers from canned data."""

    def __init__(self, snap_data_dir, scale_factor, output_base):
        """Store where canned GET data lives and how much to speed up sleeps.

        snap_data_dir: directory searched by get() for recorded responses.
        scale_factor: divisor applied to the simulated delays.
        output_base: directory prefix under which post() creates the
            "updating" marker file.
        """
        self.snap_data_dir = snap_data_dir
        self.scale_factor = scale_factor
        # Maps a flattened request name to the ResponseSet replaying it.
        self.response_sets = {}
        self.output_base = output_base

    def configure_proxy(self, proxy):
        """Simulate the delay of a snapd restart; `proxy` is not used."""
        log.debug("pretending to restart snapd to pick up proxy config")
        time.sleep(2 / self.scale_factor)

    def post(self, path, body, **args):
        """Return a canned "async" response for the POSTs we know about.

        Raises Exception for any path/body combination that has no
        canned answer.
        """
        if path == "v2/snaps/subiquity" and body["action"] == "refresh":
            # The post-refresh hook does this in the real world.
            update_marker_file = self.output_base + "/run/subiquity/updating"
            with open(update_marker_file, "w"):
                pass
            return _FakeMemoryResponse(
                {
                    "type": "async",
                    "change": "7",
                    "status-code": 200,
                    "status": "OK",
                }
            )
        change = None
        if path == "v2/snaps/subiquity" and body["action"] == "switch":
            change = "8"
        if path.startswith("v2/systems/") and body["action"] == "install":
            # The system label is the third component of the path.
            system = path.split("/")[2]
            step = body["step"]
            if step == "finish":
                if system == "finish-fail":
                    change = "15"
                else:
                    change = "5"
            elif step == "setup-storage-encryption":
                change = "6"
        if change is not None:
            return _FakeMemoryResponse(
                {
                    "type": "async",
                    "change": change,
                    "status-code": 200,
                    "status": "Accepted",
                }
            )
        raise Exception(
            "Don't know how to fake POST response to {}".format((path, args))
        )

    def get(self, path, **args):
        """Return a canned response for a GET of `path` with `args`.

        The request is flattened into a name: "/" in the path becomes
        "-", and the url-encoded sorted args are appended after another
        "-".  The answer is then taken, in order of preference, from an
        already active ResponseSet for that name, from "<name>.json" in
        snap_data_dir, or from a new ResponseSet over the sorted
        "*.json" files of the directory "<name>" in snap_data_dir.

        Raises Exception when none of these exist.
        """
        if "change" not in path:
            # Simulated latency; skipped for paths containing "change".
            time.sleep(1 / self.scale_factor)
        filename = path.replace("/", "-")
        if args:
            filename += "-" + urlencode(sorted(args.items()))
        if filename in self.response_sets:
            return self.response_sets[filename].next()
        filepath = os.path.join(self.snap_data_dir, filename)
        if os.path.exists(filepath + ".json"):
            return _FakeFileResponse(filepath + ".json")
        if os.path.isdir(filepath):
            files = sorted(glob.glob(os.path.join(filepath, "*.json")))
            rs = self.response_sets[filename] = ResponseSet(files)
            return rs.next()
        raise Exception(
            "Don't know how to fake GET response to {}".format((path, args))
        )


def get_fake_connection(scale_factor=1000, output_base=None):
    """Build a FakeSnapdConnection reading from examples/snaps.

    The examples directory is resolved relative to the directory two
    levels above this file.  When output_base is None it defaults to
    ".subiquity" in that same directory.
    """
    proj_dir = os.path.dirname(os.path.dirname(__file__))
    if output_base is None:
        output_base = os.path.join(proj_dir, ".subiquity")
    return FakeSnapdConnection(
        os.path.join(proj_dir, "examples", "snaps"), scale_factor, output_base
    )


class AsyncSnapd:
    """Coroutine front end for a blocking snapd connection object."""

    def __init__(self, connection):
        self.connection = connection

    async def get(self, path, **args):
        """Run connection.get in a thread and return the decoded JSON.

        raise_for_status() is called on the response first, so whatever
        it raises propagates to the caller.
        """
        response = await run_in_thread(partial(self.connection.get, path, **args))
        response.raise_for_status()
        return response.json()

    async def post(self, path, body, **args):
        """Run connection.post in a thread and return the decoded JSON.

        raise_for_status() is called on the response first, so whatever
        it raises propagates to the caller.
        """
        response = await run_in_thread(
            partial(self.connection.post, path, body, **args)
        )
        response.raise_for_status()
        return response.json()

    async def post_and_wait(self, path, body, **args):
        """POST, then poll the resulting change until it has finished.

        The change is polled every 0.1 seconds via v2/changes/<id>.

        Returns the "result" object of the final poll once its status
        is "Done".

        Raises SnapdChangeError if the change reports status "Error";
        without this check a failed change would be polled forever.
        """
        change = (await self.post(path, body, **args))["change"]
        change_path = "v2/changes/{}".format(change)
        while True:
            result = (await self.get(change_path))["result"]
            status = result["status"]
            if status == "Done":
                return result
            if status == "Error":
                raise SnapdChangeError(change, result.get("err"))
            await asyncio.sleep(0.1)