diff --git a/subiquity/common/apidef.py b/subiquity/common/apidef.py index 867df2a8..85821a3d 100644 --- a/subiquity/common/apidef.py +++ b/subiquity/common/apidef.py @@ -29,7 +29,6 @@ from subiquity.common.types import ( AnyStep, ApplicationState, ApplicationStatus, - Change, Disk, ErrorReportRef, GuidedChoice, @@ -139,7 +138,7 @@ class API: """Start the update and return the change id.""" class progress: - def GET(change_id: str) -> Change: ... + def GET(change_id: str) -> dict: ... class keyboard: def GET() -> KeyboardSetup: ... diff --git a/subiquity/common/types.py b/subiquity/common/types.py index 4e89dae3..f617c546 100644 --- a/subiquity/common/types.py +++ b/subiquity/common/types.py @@ -20,7 +20,7 @@ import datetime import enum import shlex -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union import attr @@ -621,42 +621,3 @@ class WSLConfigurationAdvanced: @attr.s(auto_attribs=True) class WSLSetupOptions: install_language_support_packages: bool = attr.ib(default=True) - - -class TaskStatus(enum.Enum): - DO = "Do" - DOING = "Doing" - DONE = "Done" - ABORT = "Abort" - UNDO = "Undo" - UNDOING = "Undoing" - HOLD = "Hold" - ERROR = "Error" - - -@attr.s(auto_attribs=True) -class TaskProgress: - label: str = '' - done: int = 0 - total: int = 0 - - -@attr.s(auto_attribs=True) -class Task: - id: str - kind: str - summary: str - status: TaskStatus - progress: TaskProgress = TaskProgress() - - -@attr.s(auto_attribs=True) -class Change: - id: str - kind: str - summary: str - status: TaskStatus - tasks: List[Task] - ready: bool - err: Optional[str] = None - data: Any = None diff --git a/subiquity/server/controllers/refresh.py b/subiquity/server/controllers/refresh.py index ae8022c5..5ef5693f 100644 --- a/subiquity/server/controllers/refresh.py +++ b/subiquity/server/controllers/refresh.py @@ -27,7 +27,6 @@ from subiquitycore.context import with_context from subiquity.common.apidef import API from subiquity.common.types import ( - Change, RefreshCheckState, RefreshStatus, ) @@ -38,12 +37,6 @@ from subiquity.common.snap import ( from subiquity.server.controller import ( SubiquityController, ) -from subiquity.server.snapdapi import ( - SnapAction, - SnapActionRequest, - TaskStatus, - post_and_wait, - ) from subiquity.server.types import InstallerChannels @@ -106,34 +99,33 @@ class RefreshController(SubiquityController): change_id = await self.start_update(context=context) while True: change = await self.get_progress(change_id) - if change.status not in [ - TaskStatus.DO, TaskStatus.DOING, TaskStatus.DONE]: - raise Exception(f"update failed: {change.status}") + if change['status'] not in ['Do', 'Doing', 'Done']: + raise Exception(f"update failed: {change['status']}") await asyncio.sleep(0.1) @with_context() async def configure_snapd(self, context): with context.child("get_details") as subcontext: try: - r = await self.app.snapdapi.v2.snaps[self.snap_name].GET() + r = await self.app.snapd.get( + 'v2/snaps/{snap_name}'.format( + snap_name=self.snap_name)) except requests.exceptions.RequestException: log.exception("getting snap details") return - self.status.current_snap_version = r.version + self.status.current_snap_version = r['result']['version'] for k in 'channel', 'revision', 'version': self.app.note_data_for_apport( - "Snap" + k.title(), getattr(r, k)) + "Snap" + k.title(), r['result'][k]) subcontext.description = "current version of snap is: %r" % ( self.status.current_snap_version) channel = self.get_refresh_channel() desc = "switching {} to {}".format(self.snap_name, channel) with context.child("switching", desc) as subcontext: try: - await post_and_wait( - self.app.snapdapi, - self.app.snapdapi.v2.snaps[self.snap_name].POST, - SnapActionRequest( - action=SnapAction.SWITCH, channel=channel)) + await self.app.snapd.post_and_wait( + 'v2/snaps/{}'.format(self.snap_name), + {'action': 'switch', 'channel': channel}) except requests.exceptions.RequestException: log.exception("switching channels") return @@ -180,17 +172,17 @@ class RefreshController(SubiquityController): self.status.availability = RefreshCheckState.UNAVAILABLE return try: - result = await self.app.snapdapi.v2.find.GET(select='refresh') + result = await self.app.snapd.get('v2/find', select='refresh') except requests.exceptions.RequestException: log.exception("checking for snap update failed") context.description = "checking for snap update failed" self.status.availability = RefreshCheckState.UNKNOWN return log.debug("check_for_update received %s", result) - for snap in result: - if snap.name != self.snap_name: + for snap in result["result"]: + if snap["name"] != self.snap_name: continue - self.status.new_snap_version = snap.version + self.status.new_snap_version = snap["version"] # In certain circumstances, the version of the snap that is # reported by snapd is older than the one currently running. In # this scenario, we do not want to suggest an update that would @@ -221,14 +213,16 @@ class RefreshController(SubiquityController): @with_context() async def start_update(self, context): - change_id = await self.app.snapdapi.v2.snaps[self.snap_name].POST( - SnapActionRequest(action=SnapAction.REFRESH, ignore_running=True)) - context.description = "change id: {}".format(change_id) - return change_id + change = await self.app.snapd.post( + 'v2/snaps/{}'.format(self.snap_name), + {'action': 'refresh', 'ignore-running': True}) + context.description = "change id: {}".format(change) + return change - async def get_progress(self, change_id: str) -> Change: - change = await self.app.snapdapi.v2.changes[change_id].GET() - if change.status == TaskStatus.DONE: + async def get_progress(self, change): + result = await self.app.snapd.get('v2/changes/{}'.format(change)) + change = result['result'] + if change['status'] == 'Done': # Clearly if we got here we didn't get restarted by # snapd/systemctl (dry-run mode) self.app.restart() @@ -242,5 +236,5 @@ class RefreshController(SubiquityController): async def POST(self, context) -> str: return await self.start_update(context=context) - async def progress_GET(self, change_id: str) -> Change: + async def progress_GET(self, change_id: str) -> dict: return await self.get_progress(change_id) diff --git a/subiquity/server/server.py b/subiquity/server/server.py index 85abd580..916422b4 100644 --- a/subiquity/server/server.py +++ b/subiquity/server/server.py @@ -76,7 +76,6 @@ from subiquity.server.geoip import ( ) from subiquity.server.errors import ErrorController from subiquity.server.runner import get_command_runner -from subiquity.server.snapdapi import make_api_client from subiquity.server.types import InstallerChannels from subiquitycore.snapd import ( AsyncSnapd, @@ -314,11 +313,9 @@ class SubiquityServer(Application): "examples", "snaps"), self.scale_factor, opts.output_base) self.snapd = AsyncSnapd(connection) - self.snapdapi = make_api_client(self.snapd) elif os.path.exists(self.snapd_socket_path): connection = SnapdConnection(self.root, self.snapd_socket_path) self.snapd = AsyncSnapd(connection) - self.snapdapi = make_api_client(self.snapd) else: log.info("no snapd socket found. Snap support is disabled") self.snapd = None diff --git a/subiquity/server/snapdapi.py b/subiquity/server/snapdapi.py deleted file mode 100644 index a4a1197d..00000000 --- a/subiquity/server/snapdapi.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright 2022 Canonical, Ltd. -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . - -import aiohttp -import asyncio -import contextlib -import enum -import logging -from typing import List - -from subiquity.common.api.client import make_client -from subiquity.common.api.defs import api, path_parameter, Payload -from subiquity.common.serialize import named_field, Serializer -from subiquity.common.types import Change, TaskStatus - -import attr - - -log = logging.getLogger('subiquity.server.snapdapi') - -RFC3339 = '%Y-%m-%dT%H:%M:%S.%fZ' - - -def date_field(name=None, default=attr.NOTHING): - metadata = {'time_fmt': RFC3339} - if name is not None: - metadata.update(named_field(name).metadata) - return attr.ib(metadata=metadata, default=default) - - -ChangeID = str - - -class SnapStatus(enum.Enum): - ACTIVE = 'active' - AVAILABLE = 'available' - - -@attr.s(auto_attribs=True) -class Publisher: - id: str - username: str - display_name: str = named_field('display-name') - - -@attr.s(auto_attribs=True) -class Snap: - id: str - name: str - status: SnapStatus - publisher: Publisher - version: str - revision: str - channel: str - - -class SnapAction(enum.Enum): - REFRESH = 'refresh' - SWITCH = 'switch' - - -@attr.s(auto_attribs=True) -class SnapActionRequest: - action: SnapAction - channel: str = '' - ignore_running: bool = named_field('ignore-running', False) - - -class ResponseType: - SYNC = 'sync' - ASYNC = 'async' - ERROR = 'error' - - -@attr.s(auto_attribs=True) -class Response: - type: str - status_code: int = named_field("status-code") - status: str - - -@api -class SnapdAPI: - serialize_query_args = False - - class v2: - class changes: - @path_parameter - class change_id: - def GET() -> Change: ... - - class snaps: - @path_parameter - class snap_name: - def GET() -> Snap: ... - def POST(action: Payload[SnapActionRequest]) -> ChangeID: ... - - class find: - def GET(name: str = '', select: str = '') -> List[Snap]: ... - - -class _FakeResponse: - def __init__(self, data): - self.data = data - - def raise_for_status(self): - pass - - async def json(self): - return self.data - - -class _FakeError: - def __init__(self, data): - self.data = data - - def raise_for_status(self): - raise aiohttp.ClientError(self.data['result']['message']) - - -def make_api_client(async_snapd): - # subiquity.common.api.client is designed around how to make requests - # with aiohttp's client code, not the AsyncSnapd API but with a bit of - # effort it can be contorted into shape. Clearly it would be better to - # use aiohttp to talk to snapd but that would require porting across - # the fake implementation used in dry-run mode. - - @contextlib.asynccontextmanager - async def make_request(method, path, *, params, json): - if method == "GET": - content = await async_snapd.get(path[1:], **params) - else: - content = await async_snapd.post(path[1:], json, **params) - response = serializer.deserialize(Response, content) - if response.type == ResponseType.SYNC: - content = content['result'] - elif response.type == ResponseType.ASYNC: - content = content['change'] - elif response.type == ResponseType.ERROR: - yield _FakeError() - yield _FakeResponse(content) - - serializer = Serializer( - ignore_unknown_fields=True, serialize_enums_by='value') - - return make_client(SnapdAPI, make_request, serializer=serializer) - - -async def post_and_wait(client, meth, *args, **kw): - change_id = await meth(*args, **kw) - log.debug('post_and_wait %s', change_id) - - while True: - result = await client.v2.changes[change_id].GET() - if result.status == TaskStatus.DONE: - return result.data - await asyncio.sleep(0.1) diff --git a/subiquity/ui/views/refresh.py b/subiquity/ui/views/refresh.py index 4293b7d4..6d0cbc59 100644 --- a/subiquity/ui/views/refresh.py +++ b/subiquity/ui/views/refresh.py @@ -32,7 +32,7 @@ from subiquitycore.ui.container import Columns, ListBox from subiquitycore.ui.spinner import Spinner from subiquitycore.ui.utils import button_pile, Color, screen -from subiquity.common.types import RefreshCheckState, TaskStatus +from subiquity.common.types import RefreshCheckState log = logging.getLogger('subiquity.ui.views.refresh') @@ -77,20 +77,20 @@ class TaskProgress(WidgetWrap): super().__init__(cols) def update(self, task): - progress = task.progress - done = progress.done - total = progress.total + progress = task['progress'] + done = progress['done'] + total = progress['total'] if total > 1: if self.mode == "spinning": bar = TaskProgressBar() self._w = bar else: bar = self._w - bar.label = task.summary + bar.label = task['summary'] bar.done = total bar.current = done else: - self.label.set_text(task.summary) + self.label.set_text(task['summary']) self.spinner.spin() @@ -253,17 +253,13 @@ class RefreshView(BaseView): return while True: change = await self.controller.get_progress(change_id) - if change.status == TaskStatus.DONE: + if change['status'] == 'Done': # Clearly if we got here we didn't get restarted by # snapd/systemctl (dry-run mode or logged in via SSH) self.controller.app.restart(remove_last_screen=False) return - if change.status not in (TaskStatus.DO, TaskStatus.DOING): - if change.err: - err = change.err - else: - err = "Unknown error" - self.update_failed(err) + if change['status'] not in ['Do', 'Doing']: + self.update_failed(change.get('err', "Unknown error")) return self.update_progress(change) await asyncio.sleep(0.1) @@ -288,14 +284,14 @@ class RefreshView(BaseView): self._w = screen(rows, buttons, excerpt=_(self.update_failed_excerpt)) def update_progress(self, change): - for task in change.tasks: - tid = task.id - if task.status == TaskStatus.DONE: + for task in change['tasks']: + tid = task['id'] + if task['status'] == "Done": bar = self.task_to_bar.get(tid) if bar is not None: self.lb_tasks.base_widget.body.remove(bar) del self.task_to_bar[tid] - if task.status == TaskStatus.DOING: + if task['status'] == "Doing": if tid not in self.task_to_bar: self.task_to_bar[tid] = bar = TaskProgress() self.lb_tasks.base_widget.body.append(bar) diff --git a/subiquitycore/snapd.py b/subiquitycore/snapd.py index 70f29c45..069a9ba9 100644 --- a/subiquitycore/snapd.py +++ b/subiquitycore/snapd.py @@ -152,8 +152,7 @@ class FakeSnapdConnection: "Don't know how to fake POST response to {}".format((path, args))) def get(self, path, **args): - if 'change' not in path: - time.sleep(1/self.scale_factor) + time.sleep(1/self.scale_factor) filename = path.replace('/', '-') if args: filename += '-' + urlencode(sorted(args.items())) @@ -185,10 +184,10 @@ class AsyncSnapd: response = await run_in_thread( partial(self.connection.post, path, body, **args)) response.raise_for_status() - return response.json() + return response.json()['change'] async def post_and_wait(self, path, body, **args): - change = (await self.post(path, body, **args))['change'] + change = await self.post(path, body, **args) change_path = 'v2/changes/{}'.format(change) while True: result = await self.get(change_path)