From eb35a2c69ebdb94be8fdea7df6b95bca7b3731b4 Mon Sep 17 00:00:00 2001 From: Michael Hudson-Doyle Date: Tue, 5 Oct 2021 16:52:04 +1300 Subject: [PATCH] use a ContextVar to know when a request is part of making a view we can do this now that we are based on core20 --- subiquity/client/client.py | 19 +++++++++++-- subiquity/client/controller.py | 1 - subiquity/client/controllers/filesystem.py | 2 +- subiquity/client/controllers/identity.py | 2 +- subiquity/client/controllers/keyboard.py | 2 +- subiquity/client/controllers/mirror.py | 2 +- subiquity/client/controllers/network.py | 2 +- subiquity/client/controllers/proxy.py | 2 +- subiquity/client/controllers/refresh.py | 2 +- subiquity/client/controllers/serial.py | 2 +- subiquity/client/controllers/snaplist.py | 2 +- subiquity/client/controllers/source.py | 2 +- subiquity/client/controllers/ssh.py | 2 +- subiquity/client/controllers/welcome.py | 2 +- subiquity/client/controllers/zdev.py | 2 +- subiquity/common/api/client.py | 33 ++++++++++++---------- subiquity/server/server.py | 2 +- 17 files changed, 48 insertions(+), 33 deletions(-) diff --git a/subiquity/client/client.py b/subiquity/client/client.py index 095e8a38..bc43a8df 100644 --- a/subiquity/client/client.py +++ b/subiquity/client/client.py @@ -14,6 +14,7 @@ # along with this program. If not, see . import asyncio +import contextvars import inspect import logging import os @@ -135,9 +136,17 @@ class SubiquityClient(TuiApplication): self.our_tty = "not a tty" self.conn = aiohttp.UnixConnector(self.opts.socket) + + self.in_make_view_cvar = contextvars.ContextVar( + 'in_make_view', default=False) + + def header_func(): + if self.in_make_view_cvar.get(): + return {'x-make-view-request': 'yes'} + else: + return None + self.client = make_client_for_conn(API, self.conn, self.resp_hook) - self.client1 = make_client_for_conn( - API, self.conn, self.resp_hook, headers={'x-first-request': 'yes'}) self.error_reporter = ErrorReporter( self.context.child("ErrorReporter"), self.opts.dry_run, self.root, @@ -484,7 +493,11 @@ class SubiquityClient(TuiApplication): await coro async def make_view_for_controller(self, new): - view = await super().make_view_for_controller(new) + tok = self.in_make_view_cvar.set(True) + try: + view = await super().make_view_for_controller(new) + finally: + self.in_make_view_cvar.reset(tok) if new.answers: self.aio_loop.create_task(self._start_answers_for_view(new, view)) with open(self.state_path('last-screen'), 'w') as fp: diff --git a/subiquity/client/controller.py b/subiquity/client/controller.py index 9192bc36..1740ef05 100644 --- a/subiquity/client/controller.py +++ b/subiquity/client/controller.py @@ -35,4 +35,3 @@ class SubiquityTuiController(TuiController): self.answers = app.answers.get(self.name, {}) if self.endpoint_name is not None: self.endpoint = getattr(self.app.client, self.endpoint_name) - self.endpoint1 = getattr(self.app.client1, self.endpoint_name) diff --git a/subiquity/client/controllers/filesystem.py b/subiquity/client/controllers/filesystem.py index a61c7386..33f494de 100644 --- a/subiquity/client/controllers/filesystem.py +++ b/subiquity/client/controllers/filesystem.py @@ -55,7 +55,7 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator): self.answers.setdefault('manual', []) async def make_ui(self): - status = await self.endpoint1.guided.GET() + status = await self.endpoint.guided.GET() if status.status == ProbeStatus.PROBING: self.app.aio_loop.create_task(self._wait_for_probing()) return SlowProbing(self) diff --git a/subiquity/client/controllers/identity.py b/subiquity/client/controllers/identity.py index 501ce0cf..b0ccda7b 100644 --- a/subiquity/client/controllers/identity.py +++ b/subiquity/client/controllers/identity.py @@ -27,7 +27,7 @@ class IdentityController(SubiquityTuiController): endpoint_name = 'identity' async def make_ui(self): - data = await self.endpoint1.GET() + data = await self.endpoint.GET() return IdentityView(self, data) def run_answers(self): diff --git a/subiquity/client/controllers/keyboard.py b/subiquity/client/controllers/keyboard.py index 3bff896d..cc6d8f6f 100644 --- a/subiquity/client/controllers/keyboard.py +++ b/subiquity/client/controllers/keyboard.py @@ -28,7 +28,7 @@ class KeyboardController(SubiquityTuiController): endpoint_name = 'keyboard' async def make_ui(self): - setup = await self.endpoint1.GET() + setup = await self.endpoint.GET() return KeyboardView(self, setup) async def run_answers(self): diff --git a/subiquity/client/controllers/mirror.py b/subiquity/client/controllers/mirror.py index d976789a..3626c3f5 100644 --- a/subiquity/client/controllers/mirror.py +++ b/subiquity/client/controllers/mirror.py @@ -26,7 +26,7 @@ class MirrorController(SubiquityTuiController): endpoint_name = 'mirror' async def make_ui(self): - mirror = await self.endpoint1.GET() + mirror = await self.endpoint.GET() return MirrorView(self, mirror) def run_answers(self): diff --git a/subiquity/client/controllers/network.py b/subiquity/client/controllers/network.py index 0abeb48f..77d946e9 100644 --- a/subiquity/client/controllers/network.py +++ b/subiquity/client/controllers/network.py @@ -107,7 +107,7 @@ class NetworkController(SubiquityTuiController, NetworkAnswersMixin): shutil.rmtree(self.tdir) async def make_ui(self): - network_status = await self.endpoint1.GET() + network_status = await self.endpoint.GET() self.view = NetworkView( self, network_status.devices, network_status.wlan_support_install_state.name) diff --git a/subiquity/client/controllers/proxy.py b/subiquity/client/controllers/proxy.py index 02cc02c7..0223b04a 100644 --- a/subiquity/client/controllers/proxy.py +++ b/subiquity/client/controllers/proxy.py @@ -26,7 +26,7 @@ class ProxyController(SubiquityTuiController): endpoint_name = 'proxy' async def make_ui(self): - proxy = await self.endpoint1.GET() + proxy = await self.endpoint.GET() return ProxyView(self, proxy) def run_answers(self): diff --git a/subiquity/client/controllers/refresh.py b/subiquity/client/controllers/refresh.py index 3933e6e4..05214a9e 100644 --- a/subiquity/client/controllers/refresh.py +++ b/subiquity/client/controllers/refresh.py @@ -54,7 +54,7 @@ class RefreshController(SubiquityTuiController): if self.app.updated: raise Skip() show = False - self.status = await self.endpoint1.GET() + self.status = await self.endpoint.GET() if index == 1: if self.status.availability == RefreshCheckState.AVAILABLE: show = True diff --git a/subiquity/client/controllers/serial.py b/subiquity/client/controllers/serial.py index a8e33ba3..039d23d5 100644 --- a/subiquity/client/controllers/serial.py +++ b/subiquity/client/controllers/serial.py @@ -26,7 +26,7 @@ class SerialController(SubiquityTuiController): async def make_ui(self): if not self.app.opts.run_on_serial: raise Skip() - ssh_info = await self.app.client1.meta.ssh_info.GET() + ssh_info = await self.app.client.meta.ssh_info.GET() return SerialView(self, ssh_info) def run_answers(self): diff --git a/subiquity/client/controllers/snaplist.py b/subiquity/client/controllers/snaplist.py index e0d14ab3..32161304 100644 --- a/subiquity/client/controllers/snaplist.py +++ b/subiquity/client/controllers/snaplist.py @@ -37,7 +37,7 @@ class SnapListController(SubiquityTuiController): endpoint_name = 'snaplist' async def make_ui(self): - data = await self.endpoint1.GET() + data = await self.endpoint.GET() if data.status == SnapCheckState.FAILED: # If loading snaps failed or the network is disabled, skip the # screen. diff --git a/subiquity/client/controllers/source.py b/subiquity/client/controllers/source.py index d2b549db..c0ce37a7 100644 --- a/subiquity/client/controllers/source.py +++ b/subiquity/client/controllers/source.py @@ -26,7 +26,7 @@ class SourceController(SubiquityTuiController): endpoint_name = 'source' async def make_ui(self): - sources = await self.endpoint1.GET() + sources = await self.endpoint.GET() return SourceView(self, sources.sources, sources.current_id) def run_answers(self): diff --git a/subiquity/client/controllers/ssh.py b/subiquity/client/controllers/ssh.py index f5450580..3eae0117 100644 --- a/subiquity/client/controllers/ssh.py +++ b/subiquity/client/controllers/ssh.py @@ -47,7 +47,7 @@ class SSHController(SubiquityTuiController): 'ssh-import-id'] async def make_ui(self): - ssh_data = await self.endpoint1.GET() + ssh_data = await self.endpoint.GET() return SSHView(self, ssh_data) def run_answers(self): diff --git a/subiquity/client/controllers/welcome.py b/subiquity/client/controllers/welcome.py index 10a4056a..ff600174 100644 --- a/subiquity/client/controllers/welcome.py +++ b/subiquity/client/controllers/welcome.py @@ -30,7 +30,7 @@ class WelcomeController(SubiquityTuiController): async def make_ui(self): if not self.app.rich_mode: raise Skip() - language = await self.endpoint1.GET() + language = await self.endpoint.GET() i18n.switch_language(language) self.serial = self.app.opts.run_on_serial return WelcomeView(self, language, self.serial) diff --git a/subiquity/client/controllers/zdev.py b/subiquity/client/controllers/zdev.py index 7ae1923f..92c7d590 100644 --- a/subiquity/client/controllers/zdev.py +++ b/subiquity/client/controllers/zdev.py @@ -27,7 +27,7 @@ class ZdevController(SubiquityTuiController): endpoint_name = 'zdev' async def make_ui(self): - infos = await self.endpoint1.GET() + infos = await self.endpoint.GET() return ZdevView(self, infos) def run_answers(self): diff --git a/subiquity/common/api/client.py b/subiquity/common/api/client.py index 5db1f230..cca1f442 100644 --- a/subiquity/common/api/client.py +++ b/subiquity/common/api/client.py @@ -68,22 +68,25 @@ def make_client(endpoint_cls, make_request, serializer=None): def make_client_for_conn( endpoint_cls, conn, resp_hook=lambda r: r, serializer=None, - headers={}): + header_func=None): + session = aiohttp.ClientSession( + connector=conn, connector_owner=False) + @contextlib38.asynccontextmanager async def make_request(method, path, *, params, json): - async with aiohttp.ClientSession( - connector=conn, connector_owner=False, - headers=headers) as session: - # session.request needs a full URL with scheme and host - # even though that's in some ways a bit silly with a unix - # socket, so we just hardcode something here (I guess the - # "a" gets sent along to the server in the Host: header - # and the server could in principle do something like - # virtual host based selection but well....) - url = 'http://a' + path - async with session.request( - method, url, json=json, params=params, - timeout=0) as response: - yield resp_hook(response) + # session.request needs a full URL with scheme and host even though + # that's in some ways a bit silly with a unix socket, so we just + # hardcode something here (I guess the "a" gets sent along to the + # server in the Host: header and the server could in principle do + # something like virtual host based selection but well....) + url = 'http://a' + path + if header_func is not None: + headers = header_func() + else: + headers = None + async with session.request( + method, url, json=json, params=params, + headers=headers, timeout=0) as response: + yield resp_hook(response) return make_client(endpoint_cls, make_request, serializer) diff --git a/subiquity/server/server.py b/subiquity/server/server.py index 6c6c6786..126c5d00 100644 --- a/subiquity/server/server.py +++ b/subiquity/server/server.py @@ -408,7 +408,7 @@ class SubiquityServer(Application): if not controller.interactive(): override_status = 'skip' elif (self.state == ApplicationState.NEEDS_CONFIRMATION and - request.headers.get('x-first-request') == 'yes'): + request.headers.get('x-make-view-request') == 'yes'): if self.base_model.is_postinstall_only(controller.model_name): override_status = 'confirm' if override_status is not None: