use a ContextVar to know when a request is part of making a view
we can do this now that we are based on core20
This commit is contained in:
parent
755436ea45
commit
eb35a2c69e
|
@ -14,6 +14,7 @@
|
|||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
|
@ -135,9 +136,17 @@ class SubiquityClient(TuiApplication):
|
|||
self.our_tty = "not a tty"
|
||||
|
||||
self.conn = aiohttp.UnixConnector(self.opts.socket)
|
||||
|
||||
self.in_make_view_cvar = contextvars.ContextVar(
|
||||
'in_make_view', default=False)
|
||||
|
||||
def header_func():
|
||||
if self.in_make_view_cvar.get():
|
||||
return {'x-make-view-request': 'yes'}
|
||||
else:
|
||||
return None
|
||||
|
||||
self.client = make_client_for_conn(API, self.conn, self.resp_hook)
|
||||
self.client1 = make_client_for_conn(
|
||||
API, self.conn, self.resp_hook, headers={'x-first-request': 'yes'})
|
||||
|
||||
self.error_reporter = ErrorReporter(
|
||||
self.context.child("ErrorReporter"), self.opts.dry_run, self.root,
|
||||
|
@ -484,7 +493,11 @@ class SubiquityClient(TuiApplication):
|
|||
await coro
|
||||
|
||||
async def make_view_for_controller(self, new):
|
||||
view = await super().make_view_for_controller(new)
|
||||
tok = self.in_make_view_cvar.set(True)
|
||||
try:
|
||||
view = await super().make_view_for_controller(new)
|
||||
finally:
|
||||
self.in_make_view_cvar.reset(tok)
|
||||
if new.answers:
|
||||
self.aio_loop.create_task(self._start_answers_for_view(new, view))
|
||||
with open(self.state_path('last-screen'), 'w') as fp:
|
||||
|
|
|
@ -35,4 +35,3 @@ class SubiquityTuiController(TuiController):
|
|||
self.answers = app.answers.get(self.name, {})
|
||||
if self.endpoint_name is not None:
|
||||
self.endpoint = getattr(self.app.client, self.endpoint_name)
|
||||
self.endpoint1 = getattr(self.app.client1, self.endpoint_name)
|
||||
|
|
|
@ -55,7 +55,7 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
|
|||
self.answers.setdefault('manual', [])
|
||||
|
||||
async def make_ui(self):
|
||||
status = await self.endpoint1.guided.GET()
|
||||
status = await self.endpoint.guided.GET()
|
||||
if status.status == ProbeStatus.PROBING:
|
||||
self.app.aio_loop.create_task(self._wait_for_probing())
|
||||
return SlowProbing(self)
|
||||
|
|
|
@ -27,7 +27,7 @@ class IdentityController(SubiquityTuiController):
|
|||
endpoint_name = 'identity'
|
||||
|
||||
async def make_ui(self):
|
||||
data = await self.endpoint1.GET()
|
||||
data = await self.endpoint.GET()
|
||||
return IdentityView(self, data)
|
||||
|
||||
def run_answers(self):
|
||||
|
|
|
@ -28,7 +28,7 @@ class KeyboardController(SubiquityTuiController):
|
|||
endpoint_name = 'keyboard'
|
||||
|
||||
async def make_ui(self):
|
||||
setup = await self.endpoint1.GET()
|
||||
setup = await self.endpoint.GET()
|
||||
return KeyboardView(self, setup)
|
||||
|
||||
async def run_answers(self):
|
||||
|
|
|
@ -26,7 +26,7 @@ class MirrorController(SubiquityTuiController):
|
|||
endpoint_name = 'mirror'
|
||||
|
||||
async def make_ui(self):
|
||||
mirror = await self.endpoint1.GET()
|
||||
mirror = await self.endpoint.GET()
|
||||
return MirrorView(self, mirror)
|
||||
|
||||
def run_answers(self):
|
||||
|
|
|
@ -107,7 +107,7 @@ class NetworkController(SubiquityTuiController, NetworkAnswersMixin):
|
|||
shutil.rmtree(self.tdir)
|
||||
|
||||
async def make_ui(self):
|
||||
network_status = await self.endpoint1.GET()
|
||||
network_status = await self.endpoint.GET()
|
||||
self.view = NetworkView(
|
||||
self, network_status.devices,
|
||||
network_status.wlan_support_install_state.name)
|
||||
|
|
|
@ -26,7 +26,7 @@ class ProxyController(SubiquityTuiController):
|
|||
endpoint_name = 'proxy'
|
||||
|
||||
async def make_ui(self):
|
||||
proxy = await self.endpoint1.GET()
|
||||
proxy = await self.endpoint.GET()
|
||||
return ProxyView(self, proxy)
|
||||
|
||||
def run_answers(self):
|
||||
|
|
|
@ -54,7 +54,7 @@ class RefreshController(SubiquityTuiController):
|
|||
if self.app.updated:
|
||||
raise Skip()
|
||||
show = False
|
||||
self.status = await self.endpoint1.GET()
|
||||
self.status = await self.endpoint.GET()
|
||||
if index == 1:
|
||||
if self.status.availability == RefreshCheckState.AVAILABLE:
|
||||
show = True
|
||||
|
|
|
@ -26,7 +26,7 @@ class SerialController(SubiquityTuiController):
|
|||
async def make_ui(self):
|
||||
if not self.app.opts.run_on_serial:
|
||||
raise Skip()
|
||||
ssh_info = await self.app.client1.meta.ssh_info.GET()
|
||||
ssh_info = await self.app.client.meta.ssh_info.GET()
|
||||
return SerialView(self, ssh_info)
|
||||
|
||||
def run_answers(self):
|
||||
|
|
|
@ -37,7 +37,7 @@ class SnapListController(SubiquityTuiController):
|
|||
endpoint_name = 'snaplist'
|
||||
|
||||
async def make_ui(self):
|
||||
data = await self.endpoint1.GET()
|
||||
data = await self.endpoint.GET()
|
||||
if data.status == SnapCheckState.FAILED:
|
||||
# If loading snaps failed or the network is disabled, skip the
|
||||
# screen.
|
||||
|
|
|
@ -26,7 +26,7 @@ class SourceController(SubiquityTuiController):
|
|||
endpoint_name = 'source'
|
||||
|
||||
async def make_ui(self):
|
||||
sources = await self.endpoint1.GET()
|
||||
sources = await self.endpoint.GET()
|
||||
return SourceView(self, sources.sources, sources.current_id)
|
||||
|
||||
def run_answers(self):
|
||||
|
|
|
@ -47,7 +47,7 @@ class SSHController(SubiquityTuiController):
|
|||
'ssh-import-id']
|
||||
|
||||
async def make_ui(self):
|
||||
ssh_data = await self.endpoint1.GET()
|
||||
ssh_data = await self.endpoint.GET()
|
||||
return SSHView(self, ssh_data)
|
||||
|
||||
def run_answers(self):
|
||||
|
|
|
@ -30,7 +30,7 @@ class WelcomeController(SubiquityTuiController):
|
|||
async def make_ui(self):
|
||||
if not self.app.rich_mode:
|
||||
raise Skip()
|
||||
language = await self.endpoint1.GET()
|
||||
language = await self.endpoint.GET()
|
||||
i18n.switch_language(language)
|
||||
self.serial = self.app.opts.run_on_serial
|
||||
return WelcomeView(self, language, self.serial)
|
||||
|
|
|
@ -27,7 +27,7 @@ class ZdevController(SubiquityTuiController):
|
|||
endpoint_name = 'zdev'
|
||||
|
||||
async def make_ui(self):
|
||||
infos = await self.endpoint1.GET()
|
||||
infos = await self.endpoint.GET()
|
||||
return ZdevView(self, infos)
|
||||
|
||||
def run_answers(self):
|
||||
|
|
|
@ -68,22 +68,25 @@ def make_client(endpoint_cls, make_request, serializer=None):
|
|||
|
||||
def make_client_for_conn(
|
||||
endpoint_cls, conn, resp_hook=lambda r: r, serializer=None,
|
||||
headers={}):
|
||||
header_func=None):
|
||||
session = aiohttp.ClientSession(
|
||||
connector=conn, connector_owner=False)
|
||||
|
||||
@contextlib38.asynccontextmanager
|
||||
async def make_request(method, path, *, params, json):
|
||||
async with aiohttp.ClientSession(
|
||||
connector=conn, connector_owner=False,
|
||||
headers=headers) as session:
|
||||
# session.request needs a full URL with scheme and host
|
||||
# even though that's in some ways a bit silly with a unix
|
||||
# socket, so we just hardcode something here (I guess the
|
||||
# "a" gets sent along to the server in the Host: header
|
||||
# and the server could in principle do something like
|
||||
# virtual host based selection but well....)
|
||||
url = 'http://a' + path
|
||||
async with session.request(
|
||||
method, url, json=json, params=params,
|
||||
timeout=0) as response:
|
||||
yield resp_hook(response)
|
||||
# session.request needs a full URL with scheme and host even though
|
||||
# that's in some ways a bit silly with a unix socket, so we just
|
||||
# hardcode something here (I guess the "a" gets sent along to the
|
||||
# server in the Host: header and the server could in principle do
|
||||
# something like virtual host based selection but well....)
|
||||
url = 'http://a' + path
|
||||
if header_func is not None:
|
||||
headers = header_func()
|
||||
else:
|
||||
headers = None
|
||||
async with session.request(
|
||||
method, url, json=json, params=params,
|
||||
headers=headers, timeout=0) as response:
|
||||
yield resp_hook(response)
|
||||
|
||||
return make_client(endpoint_cls, make_request, serializer)
|
||||
|
|
|
@ -408,7 +408,7 @@ class SubiquityServer(Application):
|
|||
if not controller.interactive():
|
||||
override_status = 'skip'
|
||||
elif (self.state == ApplicationState.NEEDS_CONFIRMATION and
|
||||
request.headers.get('x-first-request') == 'yes'):
|
||||
request.headers.get('x-make-view-request') == 'yes'):
|
||||
if self.base_model.is_postinstall_only(controller.model_name):
|
||||
override_status = 'confirm'
|
||||
if override_status is not None:
|
||||
|
|
Loading…
Reference in New Issue