use a ContextVar to know when a request is part of making a view

we can do this now that we are based on core20
This commit is contained in:
Michael Hudson-Doyle 2021-10-05 16:52:04 +13:00 committed by Dan Bungert
parent 755436ea45
commit eb35a2c69e
17 changed files with 48 additions and 33 deletions

View File

@ -14,6 +14,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import contextvars
import inspect
import logging
import os
@ -135,9 +136,17 @@ class SubiquityClient(TuiApplication):
self.our_tty = "not a tty"
self.conn = aiohttp.UnixConnector(self.opts.socket)
self.in_make_view_cvar = contextvars.ContextVar(
'in_make_view', default=False)
def header_func():
if self.in_make_view_cvar.get():
return {'x-make-view-request': 'yes'}
else:
return None
self.client = make_client_for_conn(API, self.conn, self.resp_hook)
self.client1 = make_client_for_conn(
API, self.conn, self.resp_hook, headers={'x-first-request': 'yes'})
self.error_reporter = ErrorReporter(
self.context.child("ErrorReporter"), self.opts.dry_run, self.root,
@ -484,7 +493,11 @@ class SubiquityClient(TuiApplication):
await coro
async def make_view_for_controller(self, new):
view = await super().make_view_for_controller(new)
tok = self.in_make_view_cvar.set(True)
try:
view = await super().make_view_for_controller(new)
finally:
self.in_make_view_cvar.reset(tok)
if new.answers:
self.aio_loop.create_task(self._start_answers_for_view(new, view))
with open(self.state_path('last-screen'), 'w') as fp:

View File

@ -35,4 +35,3 @@ class SubiquityTuiController(TuiController):
self.answers = app.answers.get(self.name, {})
if self.endpoint_name is not None:
self.endpoint = getattr(self.app.client, self.endpoint_name)
self.endpoint1 = getattr(self.app.client1, self.endpoint_name)

View File

@ -55,7 +55,7 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
self.answers.setdefault('manual', [])
async def make_ui(self):
status = await self.endpoint1.guided.GET()
status = await self.endpoint.guided.GET()
if status.status == ProbeStatus.PROBING:
self.app.aio_loop.create_task(self._wait_for_probing())
return SlowProbing(self)

View File

@ -27,7 +27,7 @@ class IdentityController(SubiquityTuiController):
endpoint_name = 'identity'
async def make_ui(self):
data = await self.endpoint1.GET()
data = await self.endpoint.GET()
return IdentityView(self, data)
def run_answers(self):

View File

@ -28,7 +28,7 @@ class KeyboardController(SubiquityTuiController):
endpoint_name = 'keyboard'
async def make_ui(self):
setup = await self.endpoint1.GET()
setup = await self.endpoint.GET()
return KeyboardView(self, setup)
async def run_answers(self):

View File

@ -26,7 +26,7 @@ class MirrorController(SubiquityTuiController):
endpoint_name = 'mirror'
async def make_ui(self):
mirror = await self.endpoint1.GET()
mirror = await self.endpoint.GET()
return MirrorView(self, mirror)
def run_answers(self):

View File

@ -107,7 +107,7 @@ class NetworkController(SubiquityTuiController, NetworkAnswersMixin):
shutil.rmtree(self.tdir)
async def make_ui(self):
network_status = await self.endpoint1.GET()
network_status = await self.endpoint.GET()
self.view = NetworkView(
self, network_status.devices,
network_status.wlan_support_install_state.name)

View File

@ -26,7 +26,7 @@ class ProxyController(SubiquityTuiController):
endpoint_name = 'proxy'
async def make_ui(self):
proxy = await self.endpoint1.GET()
proxy = await self.endpoint.GET()
return ProxyView(self, proxy)
def run_answers(self):

View File

@ -54,7 +54,7 @@ class RefreshController(SubiquityTuiController):
if self.app.updated:
raise Skip()
show = False
self.status = await self.endpoint1.GET()
self.status = await self.endpoint.GET()
if index == 1:
if self.status.availability == RefreshCheckState.AVAILABLE:
show = True

View File

@ -26,7 +26,7 @@ class SerialController(SubiquityTuiController):
async def make_ui(self):
if not self.app.opts.run_on_serial:
raise Skip()
ssh_info = await self.app.client1.meta.ssh_info.GET()
ssh_info = await self.app.client.meta.ssh_info.GET()
return SerialView(self, ssh_info)
def run_answers(self):

View File

@ -37,7 +37,7 @@ class SnapListController(SubiquityTuiController):
endpoint_name = 'snaplist'
async def make_ui(self):
data = await self.endpoint1.GET()
data = await self.endpoint.GET()
if data.status == SnapCheckState.FAILED:
# If loading snaps failed or the network is disabled, skip the
# screen.

View File

@ -26,7 +26,7 @@ class SourceController(SubiquityTuiController):
endpoint_name = 'source'
async def make_ui(self):
sources = await self.endpoint1.GET()
sources = await self.endpoint.GET()
return SourceView(self, sources.sources, sources.current_id)
def run_answers(self):

View File

@ -47,7 +47,7 @@ class SSHController(SubiquityTuiController):
'ssh-import-id']
async def make_ui(self):
ssh_data = await self.endpoint1.GET()
ssh_data = await self.endpoint.GET()
return SSHView(self, ssh_data)
def run_answers(self):

View File

@ -30,7 +30,7 @@ class WelcomeController(SubiquityTuiController):
async def make_ui(self):
if not self.app.rich_mode:
raise Skip()
language = await self.endpoint1.GET()
language = await self.endpoint.GET()
i18n.switch_language(language)
self.serial = self.app.opts.run_on_serial
return WelcomeView(self, language, self.serial)

View File

@ -27,7 +27,7 @@ class ZdevController(SubiquityTuiController):
endpoint_name = 'zdev'
async def make_ui(self):
infos = await self.endpoint1.GET()
infos = await self.endpoint.GET()
return ZdevView(self, infos)
def run_answers(self):

View File

@ -68,22 +68,25 @@ def make_client(endpoint_cls, make_request, serializer=None):
def make_client_for_conn(
endpoint_cls, conn, resp_hook=lambda r: r, serializer=None,
headers={}):
header_func=None):
session = aiohttp.ClientSession(
connector=conn, connector_owner=False)
@contextlib38.asynccontextmanager
async def make_request(method, path, *, params, json):
async with aiohttp.ClientSession(
connector=conn, connector_owner=False,
headers=headers) as session:
# session.request needs a full URL with scheme and host
# even though that's in some ways a bit silly with a unix
# socket, so we just hardcode something here (I guess the
# "a" gets sent along to the server in the Host: header
# and the server could in principle do something like
# virtual host based selection but well....)
url = 'http://a' + path
async with session.request(
method, url, json=json, params=params,
timeout=0) as response:
yield resp_hook(response)
# session.request needs a full URL with scheme and host even though
# that's in some ways a bit silly with a unix socket, so we just
# hardcode something here (I guess the "a" gets sent along to the
# server in the Host: header and the server could in principle do
# something like virtual host based selection but well....)
url = 'http://a' + path
if header_func is not None:
headers = header_func()
else:
headers = None
async with session.request(
method, url, json=json, params=params,
headers=headers, timeout=0) as response:
yield resp_hook(response)
return make_client(endpoint_cls, make_request, serializer)

View File

@ -408,7 +408,7 @@ class SubiquityServer(Application):
if not controller.interactive():
override_status = 'skip'
elif (self.state == ApplicationState.NEEDS_CONFIRMATION and
request.headers.get('x-first-request') == 'yes'):
request.headers.get('x-make-view-request') == 'yes'):
if self.base_model.is_postinstall_only(controller.model_name):
override_status = 'confirm'
if override_status is not None: