Merge pull request #1443 from mwhudson/snapdapi-again
rich snapdapi again
This commit is contained in:
commit
631fbcad6a
|
@ -29,6 +29,7 @@ from subiquity.common.types import (
|
|||
AnyStep,
|
||||
ApplicationState,
|
||||
ApplicationStatus,
|
||||
Change,
|
||||
Disk,
|
||||
ErrorReportRef,
|
||||
GuidedChoice,
|
||||
|
@ -138,7 +139,7 @@ class API:
|
|||
"""Start the update and return the change id."""
|
||||
|
||||
class progress:
|
||||
def GET(change_id: str) -> dict: ...
|
||||
def GET(change_id: str) -> Change: ...
|
||||
|
||||
class keyboard:
|
||||
def GET() -> KeyboardSetup: ...
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
import datetime
|
||||
import enum
|
||||
import shlex
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -621,3 +621,42 @@ class WSLConfigurationAdvanced:
|
|||
@attr.s(auto_attribs=True)
|
||||
class WSLSetupOptions:
|
||||
install_language_support_packages: bool = attr.ib(default=True)
|
||||
|
||||
|
||||
class TaskStatus(enum.Enum):
|
||||
DO = "Do"
|
||||
DOING = "Doing"
|
||||
DONE = "Done"
|
||||
ABORT = "Abort"
|
||||
UNDO = "Undo"
|
||||
UNDOING = "Undoing"
|
||||
HOLD = "Hold"
|
||||
ERROR = "Error"
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskProgress:
|
||||
label: str = ''
|
||||
done: int = 0
|
||||
total: int = 0
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class Task:
|
||||
id: str
|
||||
kind: str
|
||||
summary: str
|
||||
status: TaskStatus
|
||||
progress: TaskProgress = TaskProgress()
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class Change:
|
||||
id: str
|
||||
kind: str
|
||||
summary: str
|
||||
status: TaskStatus
|
||||
tasks: List[Task]
|
||||
ready: bool
|
||||
err: Optional[str] = None
|
||||
data: Any = None
|
||||
|
|
|
@ -27,6 +27,7 @@ from subiquitycore.context import with_context
|
|||
|
||||
from subiquity.common.apidef import API
|
||||
from subiquity.common.types import (
|
||||
Change,
|
||||
RefreshCheckState,
|
||||
RefreshStatus,
|
||||
)
|
||||
|
@ -37,6 +38,12 @@ from subiquity.common.snap import (
|
|||
from subiquity.server.controller import (
|
||||
SubiquityController,
|
||||
)
|
||||
from subiquity.server.snapdapi import (
|
||||
SnapAction,
|
||||
SnapActionRequest,
|
||||
TaskStatus,
|
||||
post_and_wait,
|
||||
)
|
||||
from subiquity.server.types import InstallerChannels
|
||||
|
||||
|
||||
|
@ -99,33 +106,34 @@ class RefreshController(SubiquityController):
|
|||
change_id = await self.start_update(context=context)
|
||||
while True:
|
||||
change = await self.get_progress(change_id)
|
||||
if change['status'] not in ['Do', 'Doing', 'Done']:
|
||||
raise Exception(f"update failed: {change['status']}")
|
||||
if change.status not in [
|
||||
TaskStatus.DO, TaskStatus.DOING, TaskStatus.DONE]:
|
||||
raise Exception(f"update failed: {change.status}")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@with_context()
|
||||
async def configure_snapd(self, context):
|
||||
with context.child("get_details") as subcontext:
|
||||
try:
|
||||
r = await self.app.snapd.get(
|
||||
'v2/snaps/{snap_name}'.format(
|
||||
snap_name=self.snap_name))
|
||||
r = await self.app.snapdapi.v2.snaps[self.snap_name].GET()
|
||||
except requests.exceptions.RequestException:
|
||||
log.exception("getting snap details")
|
||||
return
|
||||
self.status.current_snap_version = r['result']['version']
|
||||
self.status.current_snap_version = r.version
|
||||
for k in 'channel', 'revision', 'version':
|
||||
self.app.note_data_for_apport(
|
||||
"Snap" + k.title(), r['result'][k])
|
||||
"Snap" + k.title(), getattr(r, k))
|
||||
subcontext.description = "current version of snap is: %r" % (
|
||||
self.status.current_snap_version)
|
||||
channel = self.get_refresh_channel()
|
||||
desc = "switching {} to {}".format(self.snap_name, channel)
|
||||
with context.child("switching", desc) as subcontext:
|
||||
try:
|
||||
await self.app.snapd.post_and_wait(
|
||||
'v2/snaps/{}'.format(self.snap_name),
|
||||
{'action': 'switch', 'channel': channel})
|
||||
await post_and_wait(
|
||||
self.app.snapdapi,
|
||||
self.app.snapdapi.v2.snaps[self.snap_name].POST,
|
||||
SnapActionRequest(
|
||||
action=SnapAction.SWITCH, channel=channel))
|
||||
except requests.exceptions.RequestException:
|
||||
log.exception("switching channels")
|
||||
return
|
||||
|
@ -172,17 +180,17 @@ class RefreshController(SubiquityController):
|
|||
self.status.availability = RefreshCheckState.UNAVAILABLE
|
||||
return
|
||||
try:
|
||||
result = await self.app.snapd.get('v2/find', select='refresh')
|
||||
result = await self.app.snapdapi.v2.find.GET(select='refresh')
|
||||
except requests.exceptions.RequestException:
|
||||
log.exception("checking for snap update failed")
|
||||
context.description = "checking for snap update failed"
|
||||
self.status.availability = RefreshCheckState.UNKNOWN
|
||||
return
|
||||
log.debug("check_for_update received %s", result)
|
||||
for snap in result["result"]:
|
||||
if snap["name"] != self.snap_name:
|
||||
for snap in result:
|
||||
if snap.name != self.snap_name:
|
||||
continue
|
||||
self.status.new_snap_version = snap["version"]
|
||||
self.status.new_snap_version = snap.version
|
||||
# In certain circumstances, the version of the snap that is
|
||||
# reported by snapd is older than the one currently running. In
|
||||
# this scenario, we do not want to suggest an update that would
|
||||
|
@ -213,16 +221,14 @@ class RefreshController(SubiquityController):
|
|||
|
||||
@with_context()
|
||||
async def start_update(self, context):
|
||||
change = await self.app.snapd.post(
|
||||
'v2/snaps/{}'.format(self.snap_name),
|
||||
{'action': 'refresh', 'ignore-running': True})
|
||||
context.description = "change id: {}".format(change)
|
||||
return change
|
||||
change_id = await self.app.snapdapi.v2.snaps[self.snap_name].POST(
|
||||
SnapActionRequest(action=SnapAction.REFRESH, ignore_running=True))
|
||||
context.description = "change id: {}".format(change_id)
|
||||
return change_id
|
||||
|
||||
async def get_progress(self, change):
|
||||
result = await self.app.snapd.get('v2/changes/{}'.format(change))
|
||||
change = result['result']
|
||||
if change['status'] == 'Done':
|
||||
async def get_progress(self, change_id: str) -> Change:
|
||||
change = await self.app.snapdapi.v2.changes[change_id].GET()
|
||||
if change.status == TaskStatus.DONE:
|
||||
# Clearly if we got here we didn't get restarted by
|
||||
# snapd/systemctl (dry-run mode)
|
||||
self.app.restart()
|
||||
|
@ -236,5 +242,5 @@ class RefreshController(SubiquityController):
|
|||
async def POST(self, context) -> str:
|
||||
return await self.start_update(context=context)
|
||||
|
||||
async def progress_GET(self, change_id: str) -> dict:
|
||||
async def progress_GET(self, change_id: str) -> Change:
|
||||
return await self.get_progress(change_id)
|
||||
|
|
|
@ -76,6 +76,7 @@ from subiquity.server.geoip import (
|
|||
)
|
||||
from subiquity.server.errors import ErrorController
|
||||
from subiquity.server.runner import get_command_runner
|
||||
from subiquity.server.snapdapi import make_api_client
|
||||
from subiquity.server.types import InstallerChannels
|
||||
from subiquitycore.snapd import (
|
||||
AsyncSnapd,
|
||||
|
@ -313,9 +314,11 @@ class SubiquityServer(Application):
|
|||
"examples", "snaps"),
|
||||
self.scale_factor, opts.output_base)
|
||||
self.snapd = AsyncSnapd(connection)
|
||||
self.snapdapi = make_api_client(self.snapd)
|
||||
elif os.path.exists(self.snapd_socket_path):
|
||||
connection = SnapdConnection(self.root, self.snapd_socket_path)
|
||||
self.snapd = AsyncSnapd(connection)
|
||||
self.snapdapi = make_api_client(self.snapd)
|
||||
else:
|
||||
log.info("no snapd socket found. Snap support is disabled")
|
||||
self.snapd = None
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
# Copyright 2022 Canonical, Ltd.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import contextlib
|
||||
import enum
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from subiquity.common.api.client import make_client
|
||||
from subiquity.common.api.defs import api, path_parameter, Payload
|
||||
from subiquity.common.serialize import named_field, Serializer
|
||||
from subiquity.common.types import Change, TaskStatus
|
||||
|
||||
import attr
|
||||
|
||||
|
||||
log = logging.getLogger('subiquity.server.snapdapi')
|
||||
|
||||
RFC3339 = '%Y-%m-%dT%H:%M:%S.%fZ'
|
||||
|
||||
|
||||
def date_field(name=None, default=attr.NOTHING):
|
||||
metadata = {'time_fmt': RFC3339}
|
||||
if name is not None:
|
||||
metadata.update(named_field(name).metadata)
|
||||
return attr.ib(metadata=metadata, default=default)
|
||||
|
||||
|
||||
ChangeID = str
|
||||
|
||||
|
||||
class SnapStatus(enum.Enum):
|
||||
ACTIVE = 'active'
|
||||
AVAILABLE = 'available'
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class Publisher:
|
||||
id: str
|
||||
username: str
|
||||
display_name: str = named_field('display-name')
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class Snap:
|
||||
id: str
|
||||
name: str
|
||||
status: SnapStatus
|
||||
version: str
|
||||
revision: str
|
||||
channel: str
|
||||
publisher: Optional[Publisher] = None
|
||||
|
||||
|
||||
class SnapAction(enum.Enum):
|
||||
REFRESH = 'refresh'
|
||||
SWITCH = 'switch'
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class SnapActionRequest:
|
||||
action: SnapAction
|
||||
channel: str = ''
|
||||
ignore_running: bool = named_field('ignore-running', False)
|
||||
|
||||
|
||||
class ResponseType:
|
||||
SYNC = 'sync'
|
||||
ASYNC = 'async'
|
||||
ERROR = 'error'
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class Response:
|
||||
type: str
|
||||
status_code: int = named_field("status-code")
|
||||
status: str
|
||||
|
||||
|
||||
@api
|
||||
class SnapdAPI:
|
||||
serialize_query_args = False
|
||||
|
||||
class v2:
|
||||
class changes:
|
||||
@path_parameter
|
||||
class change_id:
|
||||
def GET() -> Change: ...
|
||||
|
||||
class snaps:
|
||||
@path_parameter
|
||||
class snap_name:
|
||||
def GET() -> Snap: ...
|
||||
def POST(action: Payload[SnapActionRequest]) -> ChangeID: ...
|
||||
|
||||
class find:
|
||||
def GET(name: str = '', select: str = '') -> List[Snap]: ...
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def raise_for_status(self):
|
||||
pass
|
||||
|
||||
async def json(self):
|
||||
return self.data
|
||||
|
||||
|
||||
class _FakeError:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def raise_for_status(self):
|
||||
raise aiohttp.ClientError(self.data['result']['message'])
|
||||
|
||||
|
||||
def make_api_client(async_snapd):
|
||||
# subiquity.common.api.client is designed around how to make requests
|
||||
# with aiohttp's client code, not the AsyncSnapd API but with a bit of
|
||||
# effort it can be contorted into shape. Clearly it would be better to
|
||||
# use aiohttp to talk to snapd but that would require porting across
|
||||
# the fake implementation used in dry-run mode.
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def make_request(method, path, *, params, json):
|
||||
if method == "GET":
|
||||
content = await async_snapd.get(path[1:], **params)
|
||||
else:
|
||||
content = await async_snapd.post(path[1:], json, **params)
|
||||
response = serializer.deserialize(Response, content)
|
||||
if response.type == ResponseType.SYNC:
|
||||
content = content['result']
|
||||
elif response.type == ResponseType.ASYNC:
|
||||
content = content['change']
|
||||
elif response.type == ResponseType.ERROR:
|
||||
yield _FakeError()
|
||||
yield _FakeResponse(content)
|
||||
|
||||
serializer = Serializer(
|
||||
ignore_unknown_fields=True, serialize_enums_by='value')
|
||||
|
||||
return make_client(SnapdAPI, make_request, serializer=serializer)
|
||||
|
||||
|
||||
async def post_and_wait(client, meth, *args, **kw):
|
||||
change_id = await meth(*args, **kw)
|
||||
log.debug('post_and_wait %s', change_id)
|
||||
|
||||
while True:
|
||||
result = await client.v2.changes[change_id].GET()
|
||||
if result.status == TaskStatus.DONE:
|
||||
return result.data
|
||||
await asyncio.sleep(0.1)
|
|
@ -32,7 +32,7 @@ from subiquitycore.ui.container import Columns, ListBox
|
|||
from subiquitycore.ui.spinner import Spinner
|
||||
from subiquitycore.ui.utils import button_pile, Color, screen
|
||||
|
||||
from subiquity.common.types import RefreshCheckState
|
||||
from subiquity.common.types import RefreshCheckState, TaskStatus
|
||||
|
||||
log = logging.getLogger('subiquity.ui.views.refresh')
|
||||
|
||||
|
@ -77,20 +77,20 @@ class TaskProgress(WidgetWrap):
|
|||
super().__init__(cols)
|
||||
|
||||
def update(self, task):
|
||||
progress = task['progress']
|
||||
done = progress['done']
|
||||
total = progress['total']
|
||||
progress = task.progress
|
||||
done = progress.done
|
||||
total = progress.total
|
||||
if total > 1:
|
||||
if self.mode == "spinning":
|
||||
bar = TaskProgressBar()
|
||||
self._w = bar
|
||||
else:
|
||||
bar = self._w
|
||||
bar.label = task['summary']
|
||||
bar.label = task.summary
|
||||
bar.done = total
|
||||
bar.current = done
|
||||
else:
|
||||
self.label.set_text(task['summary'])
|
||||
self.label.set_text(task.summary)
|
||||
self.spinner.spin()
|
||||
|
||||
|
||||
|
@ -253,13 +253,17 @@ class RefreshView(BaseView):
|
|||
return
|
||||
while True:
|
||||
change = await self.controller.get_progress(change_id)
|
||||
if change['status'] == 'Done':
|
||||
if change.status == TaskStatus.DONE:
|
||||
# Clearly if we got here we didn't get restarted by
|
||||
# snapd/systemctl (dry-run mode or logged in via SSH)
|
||||
self.controller.app.restart(remove_last_screen=False)
|
||||
return
|
||||
if change['status'] not in ['Do', 'Doing']:
|
||||
self.update_failed(change.get('err', "Unknown error"))
|
||||
if change.status not in (TaskStatus.DO, TaskStatus.DOING):
|
||||
if change.err:
|
||||
err = change.err
|
||||
else:
|
||||
err = "Unknown error"
|
||||
self.update_failed(err)
|
||||
return
|
||||
self.update_progress(change)
|
||||
await asyncio.sleep(0.1)
|
||||
|
@ -284,14 +288,14 @@ class RefreshView(BaseView):
|
|||
self._w = screen(rows, buttons, excerpt=_(self.update_failed_excerpt))
|
||||
|
||||
def update_progress(self, change):
|
||||
for task in change['tasks']:
|
||||
tid = task['id']
|
||||
if task['status'] == "Done":
|
||||
for task in change.tasks:
|
||||
tid = task.id
|
||||
if task.status == TaskStatus.DONE:
|
||||
bar = self.task_to_bar.get(tid)
|
||||
if bar is not None:
|
||||
self.lb_tasks.base_widget.body.remove(bar)
|
||||
del self.task_to_bar[tid]
|
||||
if task['status'] == "Doing":
|
||||
if task.status == TaskStatus.DOING:
|
||||
if tid not in self.task_to_bar:
|
||||
self.task_to_bar[tid] = bar = TaskProgress()
|
||||
self.lb_tasks.base_widget.body.append(bar)
|
||||
|
|
|
@ -152,7 +152,8 @@ class FakeSnapdConnection:
|
|||
"Don't know how to fake POST response to {}".format((path, args)))
|
||||
|
||||
def get(self, path, **args):
|
||||
time.sleep(1/self.scale_factor)
|
||||
if 'change' not in path:
|
||||
time.sleep(1/self.scale_factor)
|
||||
filename = path.replace('/', '-')
|
||||
if args:
|
||||
filename += '-' + urlencode(sorted(args.items()))
|
||||
|
@ -184,10 +185,10 @@ class AsyncSnapd:
|
|||
response = await run_in_thread(
|
||||
partial(self.connection.post, path, body, **args))
|
||||
response.raise_for_status()
|
||||
return response.json()['change']
|
||||
return response.json()
|
||||
|
||||
async def post_and_wait(self, path, body, **args):
|
||||
change = await self.post(path, body, **args)
|
||||
change = (await self.post(path, body, **args))['change']
|
||||
change_path = 'v2/changes/{}'.format(change)
|
||||
while True:
|
||||
result = await self.get(change_path)
|
||||
|
|
Loading…
Reference in New Issue