diff --git a/subiquity/common/apidef.py b/subiquity/common/apidef.py index 85821a3d..867df2a8 100644 --- a/subiquity/common/apidef.py +++ b/subiquity/common/apidef.py @@ -29,6 +29,7 @@ from subiquity.common.types import ( AnyStep, ApplicationState, ApplicationStatus, + Change, Disk, ErrorReportRef, GuidedChoice, @@ -138,7 +139,7 @@ class API: """Start the update and return the change id.""" class progress: - def GET(change_id: str) -> dict: ... + def GET(change_id: str) -> Change: ... class keyboard: def GET() -> KeyboardSetup: ... diff --git a/subiquity/common/types.py b/subiquity/common/types.py index f617c546..4e89dae3 100644 --- a/subiquity/common/types.py +++ b/subiquity/common/types.py @@ -20,7 +20,7 @@ import datetime import enum import shlex -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import attr @@ -621,3 +621,42 @@ class WSLConfigurationAdvanced: @attr.s(auto_attribs=True) class WSLSetupOptions: install_language_support_packages: bool = attr.ib(default=True) + + +class TaskStatus(enum.Enum): + DO = "Do" + DOING = "Doing" + DONE = "Done" + ABORT = "Abort" + UNDO = "Undo" + UNDOING = "Undoing" + HOLD = "Hold" + ERROR = "Error" + + +@attr.s(auto_attribs=True) +class TaskProgress: + label: str = '' + done: int = 0 + total: int = 0 + + +@attr.s(auto_attribs=True) +class Task: + id: str + kind: str + summary: str + status: TaskStatus + progress: TaskProgress = TaskProgress() + + +@attr.s(auto_attribs=True) +class Change: + id: str + kind: str + summary: str + status: TaskStatus + tasks: List[Task] + ready: bool + err: Optional[str] = None + data: Any = None diff --git a/subiquity/server/controllers/refresh.py b/subiquity/server/controllers/refresh.py index 5ef5693f..ae8022c5 100644 --- a/subiquity/server/controllers/refresh.py +++ b/subiquity/server/controllers/refresh.py @@ -27,6 +27,7 @@ from subiquitycore.context import with_context from subiquity.common.apidef import API from subiquity.common.types import ( + Change, RefreshCheckState, RefreshStatus, ) @@ -37,6 +38,12 @@ from subiquity.common.snap import ( from subiquity.server.controller import ( SubiquityController, ) +from subiquity.server.snapdapi import ( + SnapAction, + SnapActionRequest, + TaskStatus, + post_and_wait, + ) from subiquity.server.types import InstallerChannels @@ -99,33 +106,34 @@ class RefreshController(SubiquityController): change_id = await self.start_update(context=context) while True: change = await self.get_progress(change_id) - if change['status'] not in ['Do', 'Doing', 'Done']: - raise Exception(f"update failed: {change['status']}") + if change.status not in [ + TaskStatus.DO, TaskStatus.DOING, TaskStatus.DONE]: + raise Exception(f"update failed: {change.status}") await asyncio.sleep(0.1) @with_context() async def configure_snapd(self, context): with context.child("get_details") as subcontext: try: - r = await self.app.snapd.get( - 'v2/snaps/{snap_name}'.format( - snap_name=self.snap_name)) + r = await self.app.snapdapi.v2.snaps[self.snap_name].GET() except requests.exceptions.RequestException: log.exception("getting snap details") return - self.status.current_snap_version = r['result']['version'] + self.status.current_snap_version = r.version for k in 'channel', 'revision', 'version': self.app.note_data_for_apport( - "Snap" + k.title(), r['result'][k]) + "Snap" + k.title(), getattr(r, k)) subcontext.description = "current version of snap is: %r" % ( self.status.current_snap_version) channel = self.get_refresh_channel() desc = "switching {} to {}".format(self.snap_name, channel) with context.child("switching", desc) as subcontext: try: - await self.app.snapd.post_and_wait( - 'v2/snaps/{}'.format(self.snap_name), - {'action': 'switch', 'channel': channel}) + await post_and_wait( + self.app.snapdapi, + self.app.snapdapi.v2.snaps[self.snap_name].POST, + SnapActionRequest( + action=SnapAction.SWITCH, channel=channel)) except requests.exceptions.RequestException: log.exception("switching channels") return @@ -172,17 +180,17 @@ class RefreshController(SubiquityController): self.status.availability = RefreshCheckState.UNAVAILABLE return try: - result = await self.app.snapd.get('v2/find', select='refresh') + result = await self.app.snapdapi.v2.find.GET(select='refresh') except requests.exceptions.RequestException: log.exception("checking for snap update failed") context.description = "checking for snap update failed" self.status.availability = RefreshCheckState.UNKNOWN return log.debug("check_for_update received %s", result) - for snap in result["result"]: - if snap["name"] != self.snap_name: + for snap in result: + if snap.name != self.snap_name: continue - self.status.new_snap_version = snap["version"] + self.status.new_snap_version = snap.version # In certain circumstances, the version of the snap that is # reported by snapd is older than the one currently running. In # this scenario, we do not want to suggest an update that would @@ -213,16 +221,14 @@ class RefreshController(SubiquityController): @with_context() async def start_update(self, context): - change = await self.app.snapd.post( - 'v2/snaps/{}'.format(self.snap_name), - {'action': 'refresh', 'ignore-running': True}) - context.description = "change id: {}".format(change) - return change + change_id = await self.app.snapdapi.v2.snaps[self.snap_name].POST( + SnapActionRequest(action=SnapAction.REFRESH, ignore_running=True)) + context.description = "change id: {}".format(change_id) + return change_id - async def get_progress(self, change): - result = await self.app.snapd.get('v2/changes/{}'.format(change)) - change = result['result'] - if change['status'] == 'Done': + async def get_progress(self, change_id: str) -> Change: + change = await self.app.snapdapi.v2.changes[change_id].GET() + if change.status == TaskStatus.DONE: # Clearly if we got here we didn't get restarted by # snapd/systemctl (dry-run mode) self.app.restart() @@ -236,5 +242,5 @@ class RefreshController(SubiquityController): async def POST(self, context) -> str: return await self.start_update(context=context) - async def progress_GET(self, change_id: str) -> dict: + async def progress_GET(self, change_id: str) -> Change: return await self.get_progress(change_id) diff --git a/subiquity/server/server.py b/subiquity/server/server.py index 916422b4..85abd580 100644 --- a/subiquity/server/server.py +++ b/subiquity/server/server.py @@ -76,6 +76,7 @@ from subiquity.server.geoip import ( ) from subiquity.server.errors import ErrorController from subiquity.server.runner import get_command_runner +from subiquity.server.snapdapi import make_api_client from subiquity.server.types import InstallerChannels from subiquitycore.snapd import ( AsyncSnapd, @@ -313,9 +314,11 @@ class SubiquityServer(Application): "examples", "snaps"), self.scale_factor, opts.output_base) self.snapd = AsyncSnapd(connection) + self.snapdapi = make_api_client(self.snapd) elif os.path.exists(self.snapd_socket_path): connection = SnapdConnection(self.root, self.snapd_socket_path) self.snapd = AsyncSnapd(connection) + self.snapdapi = make_api_client(self.snapd) else: log.info("no snapd socket found. Snap support is disabled") self.snapd = None diff --git a/subiquity/server/snapdapi.py b/subiquity/server/snapdapi.py new file mode 100644 index 00000000..f82e2adc --- /dev/null +++ b/subiquity/server/snapdapi.py @@ -0,0 +1,169 @@ +# Copyright 2022 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import aiohttp +import asyncio +import contextlib +import enum +import logging +from typing import List, Optional + +from subiquity.common.api.client import make_client +from subiquity.common.api.defs import api, path_parameter, Payload +from subiquity.common.serialize import named_field, Serializer +from subiquity.common.types import Change, TaskStatus + +import attr + + +log = logging.getLogger('subiquity.server.snapdapi') + +RFC3339 = '%Y-%m-%dT%H:%M:%S.%fZ' + + +def date_field(name=None, default=attr.NOTHING): + metadata = {'time_fmt': RFC3339} + if name is not None: + metadata.update(named_field(name).metadata) + return attr.ib(metadata=metadata, default=default) + + +ChangeID = str + + +class SnapStatus(enum.Enum): + ACTIVE = 'active' + AVAILABLE = 'available' + + +@attr.s(auto_attribs=True) +class Publisher: + id: str + username: str + display_name: str = named_field('display-name') + + +@attr.s(auto_attribs=True) +class Snap: + id: str + name: str + status: SnapStatus + version: str + revision: str + channel: str + publisher: Optional[Publisher] = None + + +class SnapAction(enum.Enum): + REFRESH = 'refresh' + SWITCH = 'switch' + + +@attr.s(auto_attribs=True) +class SnapActionRequest: + action: SnapAction + channel: str = '' + ignore_running: bool = named_field('ignore-running', False) + + +class ResponseType: + SYNC = 'sync' + ASYNC = 'async' + ERROR = 'error' + + +@attr.s(auto_attribs=True) +class Response: + type: str + status_code: int = named_field("status-code") + status: str + + +@api +class SnapdAPI: + serialize_query_args = False + + class v2: + class changes: + @path_parameter + class change_id: + def GET() -> Change: ... + + class snaps: + @path_parameter + class snap_name: + def GET() -> Snap: ... + def POST(action: Payload[SnapActionRequest]) -> ChangeID: ... + + class find: + def GET(name: str = '', select: str = '') -> List[Snap]: ... + + +class _FakeResponse: + def __init__(self, data): + self.data = data + + def raise_for_status(self): + pass + + async def json(self): + return self.data + + +class _FakeError: + def __init__(self, data): + self.data = data + + def raise_for_status(self): + raise aiohttp.ClientError(self.data['result']['message']) + + +def make_api_client(async_snapd): + # subiquity.common.api.client is designed around how to make requests + # with aiohttp's client code, not the AsyncSnapd API but with a bit of + # effort it can be contorted into shape. Clearly it would be better to + # use aiohttp to talk to snapd but that would require porting across + # the fake implementation used in dry-run mode. + + @contextlib.asynccontextmanager + async def make_request(method, path, *, params, json): + if method == "GET": + content = await async_snapd.get(path[1:], **params) + else: + content = await async_snapd.post(path[1:], json, **params) + response = serializer.deserialize(Response, content) + if response.type == ResponseType.SYNC: + content = content['result'] + elif response.type == ResponseType.ASYNC: + content = content['change'] + elif response.type == ResponseType.ERROR: + yield _FakeError() + yield _FakeResponse(content) + + serializer = Serializer( + ignore_unknown_fields=True, serialize_enums_by='value') + + return make_client(SnapdAPI, make_request, serializer=serializer) + + +async def post_and_wait(client, meth, *args, **kw): + change_id = await meth(*args, **kw) + log.debug('post_and_wait %s', change_id) + + while True: + result = await client.v2.changes[change_id].GET() + if result.status == TaskStatus.DONE: + return result.data + await asyncio.sleep(0.1) diff --git a/subiquity/ui/views/refresh.py b/subiquity/ui/views/refresh.py index 6d0cbc59..4293b7d4 100644 --- a/subiquity/ui/views/refresh.py +++ b/subiquity/ui/views/refresh.py @@ -32,7 +32,7 @@ from subiquitycore.ui.container import Columns, ListBox from subiquitycore.ui.spinner import Spinner from subiquitycore.ui.utils import button_pile, Color, screen -from subiquity.common.types import RefreshCheckState +from subiquity.common.types import RefreshCheckState, TaskStatus log = logging.getLogger('subiquity.ui.views.refresh') @@ -77,20 +77,20 @@ class TaskProgress(WidgetWrap): super().__init__(cols) def update(self, task): - progress = task['progress'] - done = progress['done'] - total = progress['total'] + progress = task.progress + done = progress.done + total = progress.total if total > 1: if self.mode == "spinning": bar = TaskProgressBar() self._w = bar else: bar = self._w - bar.label = task['summary'] + bar.label = task.summary bar.done = total bar.current = done else: - self.label.set_text(task['summary']) + self.label.set_text(task.summary) self.spinner.spin() @@ -253,13 +253,17 @@ class RefreshView(BaseView): return while True: change = await self.controller.get_progress(change_id) - if change['status'] == 'Done': + if change.status == TaskStatus.DONE: # Clearly if we got here we didn't get restarted by # snapd/systemctl (dry-run mode or logged in via SSH) self.controller.app.restart(remove_last_screen=False) return - if change['status'] not in ['Do', 'Doing']: - self.update_failed(change.get('err', "Unknown error")) + if change.status not in (TaskStatus.DO, TaskStatus.DOING): + if change.err: + err = change.err + else: + err = "Unknown error" + self.update_failed(err) return self.update_progress(change) await asyncio.sleep(0.1) @@ -284,14 +288,14 @@ class RefreshView(BaseView): self._w = screen(rows, buttons, excerpt=_(self.update_failed_excerpt)) def update_progress(self, change): - for task in change['tasks']: - tid = task['id'] - if task['status'] == "Done": + for task in change.tasks: + tid = task.id + if task.status == TaskStatus.DONE: bar = self.task_to_bar.get(tid) if bar is not None: self.lb_tasks.base_widget.body.remove(bar) del self.task_to_bar[tid] - if task['status'] == "Doing": + if task.status == TaskStatus.DOING: if tid not in self.task_to_bar: self.task_to_bar[tid] = bar = TaskProgress() self.lb_tasks.base_widget.body.append(bar) diff --git a/subiquitycore/snapd.py b/subiquitycore/snapd.py index 069a9ba9..70f29c45 100644 --- a/subiquitycore/snapd.py +++ b/subiquitycore/snapd.py @@ -152,7 +152,8 @@ class FakeSnapdConnection: "Don't know how to fake POST response to {}".format((path, args))) def get(self, path, **args): - time.sleep(1/self.scale_factor) + if 'change' not in path: + time.sleep(1/self.scale_factor) filename = path.replace('/', '-') if args: filename += '-' + urlencode(sorted(args.items())) @@ -184,10 +185,10 @@ class AsyncSnapd: response = await run_in_thread( partial(self.connection.post, path, body, **args)) response.raise_for_status() - return response.json()['change'] + return response.json() async def post_and_wait(self, path, body, **args): - change = await self.post(path, body, **args) + change = (await self.post(path, body, **args))['change'] change_path = 'v2/changes/{}'.format(change) while True: result = await self.get(change_path)