Merge pull request #1442 from mwhudson/main

Revert "Merge pull request #1435 from mwhudson/snapdapi"
This commit is contained in:
Michael Hudson-Doyle 2022-10-05 14:47:58 +13:00 committed by GitHub
commit cc33a668a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 42 additions and 265 deletions

View File

@ -29,7 +29,6 @@ from subiquity.common.types import (
AnyStep, AnyStep,
ApplicationState, ApplicationState,
ApplicationStatus, ApplicationStatus,
Change,
Disk, Disk,
ErrorReportRef, ErrorReportRef,
GuidedChoice, GuidedChoice,
@ -139,7 +138,7 @@ class API:
"""Start the update and return the change id.""" """Start the update and return the change id."""
class progress: class progress:
def GET(change_id: str) -> Change: ... def GET(change_id: str) -> dict: ...
class keyboard: class keyboard:
def GET() -> KeyboardSetup: ... def GET() -> KeyboardSetup: ...

View File

@ -20,7 +20,7 @@
import datetime import datetime
import enum import enum
import shlex import shlex
from typing import Any, Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import attr import attr
@ -621,42 +621,3 @@ class WSLConfigurationAdvanced:
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class WSLSetupOptions: class WSLSetupOptions:
install_language_support_packages: bool = attr.ib(default=True) install_language_support_packages: bool = attr.ib(default=True)
class TaskStatus(enum.Enum):
DO = "Do"
DOING = "Doing"
DONE = "Done"
ABORT = "Abort"
UNDO = "Undo"
UNDOING = "Undoing"
HOLD = "Hold"
ERROR = "Error"
@attr.s(auto_attribs=True)
class TaskProgress:
label: str = ''
done: int = 0
total: int = 0
@attr.s(auto_attribs=True)
class Task:
id: str
kind: str
summary: str
status: TaskStatus
progress: TaskProgress = TaskProgress()
@attr.s(auto_attribs=True)
class Change:
id: str
kind: str
summary: str
status: TaskStatus
tasks: List[Task]
ready: bool
err: Optional[str] = None
data: Any = None

View File

@ -27,7 +27,6 @@ from subiquitycore.context import with_context
from subiquity.common.apidef import API from subiquity.common.apidef import API
from subiquity.common.types import ( from subiquity.common.types import (
Change,
RefreshCheckState, RefreshCheckState,
RefreshStatus, RefreshStatus,
) )
@ -38,12 +37,6 @@ from subiquity.common.snap import (
from subiquity.server.controller import ( from subiquity.server.controller import (
SubiquityController, SubiquityController,
) )
from subiquity.server.snapdapi import (
SnapAction,
SnapActionRequest,
TaskStatus,
post_and_wait,
)
from subiquity.server.types import InstallerChannels from subiquity.server.types import InstallerChannels
@ -106,34 +99,33 @@ class RefreshController(SubiquityController):
change_id = await self.start_update(context=context) change_id = await self.start_update(context=context)
while True: while True:
change = await self.get_progress(change_id) change = await self.get_progress(change_id)
if change.status not in [ if change['status'] not in ['Do', 'Doing', 'Done']:
TaskStatus.DO, TaskStatus.DOING, TaskStatus.DONE]: raise Exception(f"update failed: {change['status']}")
raise Exception(f"update failed: {change.status}")
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
@with_context() @with_context()
async def configure_snapd(self, context): async def configure_snapd(self, context):
with context.child("get_details") as subcontext: with context.child("get_details") as subcontext:
try: try:
r = await self.app.snapdapi.v2.snaps[self.snap_name].GET() r = await self.app.snapd.get(
'v2/snaps/{snap_name}'.format(
snap_name=self.snap_name))
except requests.exceptions.RequestException: except requests.exceptions.RequestException:
log.exception("getting snap details") log.exception("getting snap details")
return return
self.status.current_snap_version = r.version self.status.current_snap_version = r['result']['version']
for k in 'channel', 'revision', 'version': for k in 'channel', 'revision', 'version':
self.app.note_data_for_apport( self.app.note_data_for_apport(
"Snap" + k.title(), getattr(r, k)) "Snap" + k.title(), r['result'][k])
subcontext.description = "current version of snap is: %r" % ( subcontext.description = "current version of snap is: %r" % (
self.status.current_snap_version) self.status.current_snap_version)
channel = self.get_refresh_channel() channel = self.get_refresh_channel()
desc = "switching {} to {}".format(self.snap_name, channel) desc = "switching {} to {}".format(self.snap_name, channel)
with context.child("switching", desc) as subcontext: with context.child("switching", desc) as subcontext:
try: try:
await post_and_wait( await self.app.snapd.post_and_wait(
self.app.snapdapi, 'v2/snaps/{}'.format(self.snap_name),
self.app.snapdapi.v2.snaps[self.snap_name].POST, {'action': 'switch', 'channel': channel})
SnapActionRequest(
action=SnapAction.SWITCH, channel=channel))
except requests.exceptions.RequestException: except requests.exceptions.RequestException:
log.exception("switching channels") log.exception("switching channels")
return return
@ -180,17 +172,17 @@ class RefreshController(SubiquityController):
self.status.availability = RefreshCheckState.UNAVAILABLE self.status.availability = RefreshCheckState.UNAVAILABLE
return return
try: try:
result = await self.app.snapdapi.v2.find.GET(select='refresh') result = await self.app.snapd.get('v2/find', select='refresh')
except requests.exceptions.RequestException: except requests.exceptions.RequestException:
log.exception("checking for snap update failed") log.exception("checking for snap update failed")
context.description = "checking for snap update failed" context.description = "checking for snap update failed"
self.status.availability = RefreshCheckState.UNKNOWN self.status.availability = RefreshCheckState.UNKNOWN
return return
log.debug("check_for_update received %s", result) log.debug("check_for_update received %s", result)
for snap in result: for snap in result["result"]:
if snap.name != self.snap_name: if snap["name"] != self.snap_name:
continue continue
self.status.new_snap_version = snap.version self.status.new_snap_version = snap["version"]
# In certain circumstances, the version of the snap that is # In certain circumstances, the version of the snap that is
# reported by snapd is older than the one currently running. In # reported by snapd is older than the one currently running. In
# this scenario, we do not want to suggest an update that would # this scenario, we do not want to suggest an update that would
@ -221,14 +213,16 @@ class RefreshController(SubiquityController):
@with_context() @with_context()
async def start_update(self, context): async def start_update(self, context):
change_id = await self.app.snapdapi.v2.snaps[self.snap_name].POST( change = await self.app.snapd.post(
SnapActionRequest(action=SnapAction.REFRESH, ignore_running=True)) 'v2/snaps/{}'.format(self.snap_name),
context.description = "change id: {}".format(change_id) {'action': 'refresh', 'ignore-running': True})
return change_id context.description = "change id: {}".format(change)
return change
async def get_progress(self, change_id: str) -> Change: async def get_progress(self, change):
change = await self.app.snapdapi.v2.changes[change_id].GET() result = await self.app.snapd.get('v2/changes/{}'.format(change))
if change.status == TaskStatus.DONE: change = result['result']
if change['status'] == 'Done':
# Clearly if we got here we didn't get restarted by # Clearly if we got here we didn't get restarted by
# snapd/systemctl (dry-run mode) # snapd/systemctl (dry-run mode)
self.app.restart() self.app.restart()
@ -242,5 +236,5 @@ class RefreshController(SubiquityController):
async def POST(self, context) -> str: async def POST(self, context) -> str:
return await self.start_update(context=context) return await self.start_update(context=context)
async def progress_GET(self, change_id: str) -> Change: async def progress_GET(self, change_id: str) -> dict:
return await self.get_progress(change_id) return await self.get_progress(change_id)

View File

@ -76,7 +76,6 @@ from subiquity.server.geoip import (
) )
from subiquity.server.errors import ErrorController from subiquity.server.errors import ErrorController
from subiquity.server.runner import get_command_runner from subiquity.server.runner import get_command_runner
from subiquity.server.snapdapi import make_api_client
from subiquity.server.types import InstallerChannels from subiquity.server.types import InstallerChannels
from subiquitycore.snapd import ( from subiquitycore.snapd import (
AsyncSnapd, AsyncSnapd,
@ -314,11 +313,9 @@ class SubiquityServer(Application):
"examples", "snaps"), "examples", "snaps"),
self.scale_factor, opts.output_base) self.scale_factor, opts.output_base)
self.snapd = AsyncSnapd(connection) self.snapd = AsyncSnapd(connection)
self.snapdapi = make_api_client(self.snapd)
elif os.path.exists(self.snapd_socket_path): elif os.path.exists(self.snapd_socket_path):
connection = SnapdConnection(self.root, self.snapd_socket_path) connection = SnapdConnection(self.root, self.snapd_socket_path)
self.snapd = AsyncSnapd(connection) self.snapd = AsyncSnapd(connection)
self.snapdapi = make_api_client(self.snapd)
else: else:
log.info("no snapd socket found. Snap support is disabled") log.info("no snapd socket found. Snap support is disabled")
self.snapd = None self.snapd = None

View File

@ -1,169 +0,0 @@
# Copyright 2022 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import aiohttp
import asyncio
import contextlib
import enum
import logging
from typing import List
from subiquity.common.api.client import make_client
from subiquity.common.api.defs import api, path_parameter, Payload
from subiquity.common.serialize import named_field, Serializer
from subiquity.common.types import Change, TaskStatus
import attr
log = logging.getLogger('subiquity.server.snapdapi')
RFC3339 = '%Y-%m-%dT%H:%M:%S.%fZ'
def date_field(name=None, default=attr.NOTHING):
metadata = {'time_fmt': RFC3339}
if name is not None:
metadata.update(named_field(name).metadata)
return attr.ib(metadata=metadata, default=default)
ChangeID = str
class SnapStatus(enum.Enum):
ACTIVE = 'active'
AVAILABLE = 'available'
@attr.s(auto_attribs=True)
class Publisher:
id: str
username: str
display_name: str = named_field('display-name')
@attr.s(auto_attribs=True)
class Snap:
id: str
name: str
status: SnapStatus
publisher: Publisher
version: str
revision: str
channel: str
class SnapAction(enum.Enum):
REFRESH = 'refresh'
SWITCH = 'switch'
@attr.s(auto_attribs=True)
class SnapActionRequest:
action: SnapAction
channel: str = ''
ignore_running: bool = named_field('ignore-running', False)
class ResponseType:
SYNC = 'sync'
ASYNC = 'async'
ERROR = 'error'
@attr.s(auto_attribs=True)
class Response:
type: str
status_code: int = named_field("status-code")
status: str
@api
class SnapdAPI:
serialize_query_args = False
class v2:
class changes:
@path_parameter
class change_id:
def GET() -> Change: ...
class snaps:
@path_parameter
class snap_name:
def GET() -> Snap: ...
def POST(action: Payload[SnapActionRequest]) -> ChangeID: ...
class find:
def GET(name: str = '', select: str = '') -> List[Snap]: ...
class _FakeResponse:
def __init__(self, data):
self.data = data
def raise_for_status(self):
pass
async def json(self):
return self.data
class _FakeError:
def __init__(self, data):
self.data = data
def raise_for_status(self):
raise aiohttp.ClientError(self.data['result']['message'])
def make_api_client(async_snapd):
# subiquity.common.api.client is designed around how to make requests
# with aiohttp's client code, not the AsyncSnapd API but with a bit of
# effort it can be contorted into shape. Clearly it would be better to
# use aiohttp to talk to snapd but that would require porting across
# the fake implementation used in dry-run mode.
@contextlib.asynccontextmanager
async def make_request(method, path, *, params, json):
if method == "GET":
content = await async_snapd.get(path[1:], **params)
else:
content = await async_snapd.post(path[1:], json, **params)
response = serializer.deserialize(Response, content)
if response.type == ResponseType.SYNC:
content = content['result']
elif response.type == ResponseType.ASYNC:
content = content['change']
elif response.type == ResponseType.ERROR:
yield _FakeError()
yield _FakeResponse(content)
serializer = Serializer(
ignore_unknown_fields=True, serialize_enums_by='value')
return make_client(SnapdAPI, make_request, serializer=serializer)
async def post_and_wait(client, meth, *args, **kw):
change_id = await meth(*args, **kw)
log.debug('post_and_wait %s', change_id)
while True:
result = await client.v2.changes[change_id].GET()
if result.status == TaskStatus.DONE:
return result.data
await asyncio.sleep(0.1)

View File

@ -32,7 +32,7 @@ from subiquitycore.ui.container import Columns, ListBox
from subiquitycore.ui.spinner import Spinner from subiquitycore.ui.spinner import Spinner
from subiquitycore.ui.utils import button_pile, Color, screen from subiquitycore.ui.utils import button_pile, Color, screen
from subiquity.common.types import RefreshCheckState, TaskStatus from subiquity.common.types import RefreshCheckState
log = logging.getLogger('subiquity.ui.views.refresh') log = logging.getLogger('subiquity.ui.views.refresh')
@ -77,20 +77,20 @@ class TaskProgress(WidgetWrap):
super().__init__(cols) super().__init__(cols)
def update(self, task): def update(self, task):
progress = task.progress progress = task['progress']
done = progress.done done = progress['done']
total = progress.total total = progress['total']
if total > 1: if total > 1:
if self.mode == "spinning": if self.mode == "spinning":
bar = TaskProgressBar() bar = TaskProgressBar()
self._w = bar self._w = bar
else: else:
bar = self._w bar = self._w
bar.label = task.summary bar.label = task['summary']
bar.done = total bar.done = total
bar.current = done bar.current = done
else: else:
self.label.set_text(task.summary) self.label.set_text(task['summary'])
self.spinner.spin() self.spinner.spin()
@ -253,17 +253,13 @@ class RefreshView(BaseView):
return return
while True: while True:
change = await self.controller.get_progress(change_id) change = await self.controller.get_progress(change_id)
if change.status == TaskStatus.DONE: if change['status'] == 'Done':
# Clearly if we got here we didn't get restarted by # Clearly if we got here we didn't get restarted by
# snapd/systemctl (dry-run mode or logged in via SSH) # snapd/systemctl (dry-run mode or logged in via SSH)
self.controller.app.restart(remove_last_screen=False) self.controller.app.restart(remove_last_screen=False)
return return
if change.status not in (TaskStatus.DO, TaskStatus.DOING): if change['status'] not in ['Do', 'Doing']:
if change.err: self.update_failed(change.get('err', "Unknown error"))
err = change.err
else:
err = "Unknown error"
self.update_failed(err)
return return
self.update_progress(change) self.update_progress(change)
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
@ -288,14 +284,14 @@ class RefreshView(BaseView):
self._w = screen(rows, buttons, excerpt=_(self.update_failed_excerpt)) self._w = screen(rows, buttons, excerpt=_(self.update_failed_excerpt))
def update_progress(self, change): def update_progress(self, change):
for task in change.tasks: for task in change['tasks']:
tid = task.id tid = task['id']
if task.status == TaskStatus.DONE: if task['status'] == "Done":
bar = self.task_to_bar.get(tid) bar = self.task_to_bar.get(tid)
if bar is not None: if bar is not None:
self.lb_tasks.base_widget.body.remove(bar) self.lb_tasks.base_widget.body.remove(bar)
del self.task_to_bar[tid] del self.task_to_bar[tid]
if task.status == TaskStatus.DOING: if task['status'] == "Doing":
if tid not in self.task_to_bar: if tid not in self.task_to_bar:
self.task_to_bar[tid] = bar = TaskProgress() self.task_to_bar[tid] = bar = TaskProgress()
self.lb_tasks.base_widget.body.append(bar) self.lb_tasks.base_widget.body.append(bar)

View File

@ -152,8 +152,7 @@ class FakeSnapdConnection:
"Don't know how to fake POST response to {}".format((path, args))) "Don't know how to fake POST response to {}".format((path, args)))
def get(self, path, **args): def get(self, path, **args):
if 'change' not in path: time.sleep(1/self.scale_factor)
time.sleep(1/self.scale_factor)
filename = path.replace('/', '-') filename = path.replace('/', '-')
if args: if args:
filename += '-' + urlencode(sorted(args.items())) filename += '-' + urlencode(sorted(args.items()))
@ -185,10 +184,10 @@ class AsyncSnapd:
response = await run_in_thread( response = await run_in_thread(
partial(self.connection.post, path, body, **args)) partial(self.connection.post, path, body, **args))
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()['change']
async def post_and_wait(self, path, body, **args): async def post_and_wait(self, path, body, **args):
change = (await self.post(path, body, **args))['change'] change = await self.post(path, body, **args)
change_path = 'v2/changes/{}'.format(change) change_path = 'v2/changes/{}'.format(change)
while True: while True:
result = await self.get(change_path) result = await self.get(change_path)