make a helpful async wrapper around snapd

This commit is contained in:
Michael Hudson-Doyle 2019-12-03 12:28:05 +13:00
parent 3ba3b04e09
commit 67be814dc3
4 changed files with 48 additions and 62 deletions

View File

@ -13,9 +13,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import enum
from functools import partial
import logging
import os
@ -25,7 +23,6 @@ from subiquitycore.controller import BaseController
from subiquitycore.core import Skip
from subiquity.async_helpers import (
run_in_thread,
schedule_task,
)
@ -74,14 +71,11 @@ class RefreshController(BaseController):
async def configure_snapd(self):
try:
response = await run_in_thread(
self.app.snapd_connection.get,
r = await self.app.snapd.get(
'v2/snaps/{snap_name}'.format(snap_name=self.snap_name))
response.raise_for_status()
except requests.exceptions.RequestException:
log.exception("getting snap details")
return
r = response.json()
self.current_snap_version = r['result']['version']
for k in 'channel', 'revision', 'version':
self.app.note_data_for_apport(
@ -92,27 +86,12 @@ class RefreshController(BaseController):
channel = self.get_refresh_channel()
log.debug("switching %s to %s", self.snap_name, channel)
try:
response = await run_in_thread(
self.app.snapd_connection.post,
await self.app.snapd.post_and_wait(
'v2/snaps/{}'.format(self.snap_name),
{'action': 'switch', 'channel': channel})
response.raise_for_status()
except requests.exceptions.RequestException:
log.exception("switching channels")
return
change = response.json()["change"]
while True:
try:
response = await run_in_thread(
self.app.snapd_connection.get,
'v2/changes/{}'.format(change))
response.raise_for_status()
except requests.exceptions.RequestException:
log.exception("checking switch")
return
if response.json()["result"]["status"] == "Done":
break
await asyncio.sleep(0.1)
log.debug("snap switching completed")
self.switch_state = SwitchState.SWITCHED
self._maybe_check_for_update()
@ -170,12 +149,7 @@ class RefreshController(BaseController):
async def check_for_update(self):
try:
response = await run_in_thread(
partial(
self.app.snapd_connection.get,
'v2/find',
select='refresh'))
response.raise_for_status()
result = await self.app.snapd.get('v2/find', select='refresh')
except requests.exceptions.RequestException as e:
log.exception("checking for update")
self.check_error = e
@ -186,7 +160,6 @@ class RefreshController(BaseController):
# ones!
if self.check_state.is_definite():
return
result = response.json()
log.debug("_check_result %s", result)
for snap in result["result"]:
if snap["name"] == self.snap_name:
@ -208,35 +181,29 @@ class RefreshController(BaseController):
async def _start_update(self, callback):
try:
response = await run_in_thread(
self.app.snapd_connection.post,
change = await self.app.snapd.post(
'v2/snaps/{}'.format(self.snap_name),
{'action': 'refresh'})
response.raise_for_status()
except requests.exceptions.RequestException as e:
log.exception("requesting update")
self.update_state = CheckState.FAILED
self.update_failure = e
return
result = response.json()
log.debug("refresh requested: %s", result)
callback(result['change'])
log.debug("refresh requested: %s", change)
callback(change)
def get_progress(self, change, callback):
schedule_task(self._get_progress(change, callback))
async def _get_progress(self, change, callback):
try:
response = await run_in_thread(
self.app.snapd_connection.get,
result = await self.app.snapd.get(
'v2/changes/{}'.format(change))
response.raise_for_status()
except requests.exceptions.RequestException as e:
log.exception("checking for progress")
self.update_state = CheckState.FAILED
self.update_failure = e
return
result = response.json()
callback(result['result'])
def start_ui(self, index=1):

View File

@ -13,7 +13,6 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from functools import partial
import logging
import requests.exceptions
@ -22,7 +21,6 @@ from subiquitycore.controller import BaseController
from subiquitycore.core import Skip
from subiquity.async_helpers import (
run_in_thread,
schedule_task,
)
from subiquity.models.snaplist import SnapSelection
@ -33,16 +31,15 @@ log = logging.getLogger('subiquity.controllers.snaplist')
class SnapdSnapInfoLoader:
def __init__(self, model, app, connection, store_section):
def __init__(self, model, snapd, store_section):
self.model = model
self.app = app
self.store_section = store_section
self._running = False
self.snap_list_fetched = False
self.failed = False
self.connection = connection
self.snapd = snapd
self.pending_info_snaps = []
self.ongoing = {} # {snap:[callbacks]}
@ -54,12 +51,8 @@ class SnapdSnapInfoLoader:
async def _start(self):
self.ongoing[None] = []
try:
response = await run_in_thread(
partial(
self.connection.get,
'v2/find',
section=self.store_section))
response.raise_for_status()
result = await self.snapd.get(
'v2/find', section=self.store_section)
except requests.exceptions.RequestException:
log.exception("loading list of snaps failed")
self.failed = True
@ -67,7 +60,7 @@ class SnapdSnapInfoLoader:
return
if not self._running:
return
self.model.load_find_data(response.json())
self.model.load_find_data(result)
self.snap_list_fetched = True
self.pending_snaps = self.model.get_snap_list()
log.debug("fetched list of %s snaps", len(self.model.get_snap_list()))
@ -84,19 +77,14 @@ class SnapdSnapInfoLoader:
async def _fetch_info_for_snap(self, snap):
log.debug('starting fetch for %s', snap.name)
try:
response = await run_in_thread(
partial(
self.connection.get,
'v2/find',
name=snap.name))
response.raise_for_status()
data = await self.snapd.get(
'v2/find', name=snap.name)
except requests.exceptions.RequestException:
log.exception("loading snap info failed")
# XXX something better here?
return
if not self._running:
return
data = response.json()
log.debug('got data for %s', snap.name)
self.model.load_info_data(data)
for cb in self.ongoing.pop(snap):
@ -131,7 +119,7 @@ class SnapListController(BaseController):
def _make_loader(self):
return SnapdSnapInfoLoader(
self.model, self.app, self.app.snapd_connection,
self.model, self.app.snapd,
self.opts.snap_section)
def __init__(self, app):

View File

@ -29,6 +29,7 @@ from subiquity.controllers.error import (
)
from subiquity.models.subiquity import SubiquityModel
from subiquity.snapd import (
AsyncSnapd,
FakeSnapdConnection,
SnapdConnection,
)
@ -98,7 +99,7 @@ class Subiquity(Application):
"examples", "snaps"))
else:
connection = SnapdConnection(self.root, self.snapd_socket_path)
self.snapd_connection = connection
self.snapd = AsyncSnapd(connection)
self.signal.connect_signals([
('network-proxy-set', self._proxy_set),
('network-change', self._network_change),
@ -129,7 +130,7 @@ class Subiquity(Application):
def _proxy_set(self):
self.run_in_bg(
lambda: self.snapd_connection.configure_proxy(
lambda: self.snapd.connection.configure_proxy(
self.base_model.proxy),
lambda fut: (
fut.result(), self.signal.emit_signal('snapd-network-change')),

View File

@ -13,6 +13,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
from functools import partial
import glob
import json
import logging
@ -23,6 +25,7 @@ from urllib.parse import (
urlencode,
)
from subiquity.async_helpers import run_in_thread
from subiquitycore.utils import run_command
import requests_unixsocket
@ -157,3 +160,30 @@ class FakeSnapdConnection:
return rs.next()
raise Exception(
"Don't know how to fake GET response to {}".format((path, args)))
class AsyncSnapd:
def __init__(self, connection):
self.connection = connection
async def get(self, path, **args):
response = await run_in_thread(
partial(self.connection.get, path, **args))
response.raise_for_status()
return response.json()
async def post(self, path, body, **args):
response = await run_in_thread(
partial(self.connection.post, path, body, **args))
response.raise_for_status()
return response.json()['change']
async def post_and_wait(self, path, body, **args):
change = await self.post(path, body, **args)
change_path = 'v2/changes/{}'.format(change)
while True:
result = await self.get(change_path)
if result["result"]["status"] == "Done":
break
await asyncio.sleep(0.1)