Merge pull request #1895 from ogayot/pr/ssh-dissociate-form-submission
Dissociate SSH key import from form submission
This commit is contained in:
commit
ce938e6d03
|
@ -23,6 +23,7 @@ Identity:
|
|||
hostname: ubuntu-server
|
||||
# ubuntu
|
||||
password: '$6$wdAcoXrU039hKYPd$508Qvbe7ObUnxoj15DRCkzC3qO7edjH0VV7BPNRDYK4QR8ofJaEEF2heacn0QgD.f8pO8SNp83XNdWG6tocBM1'
|
||||
SSH:
|
||||
# We have a specific dry-run behaviour for heracles, making the call to
|
||||
# ssh-import-id succeed unconditionally.
|
||||
ssh-import-id: lp:heracles
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
import logging
|
||||
|
||||
from subiquity.client.controller import SubiquityTuiController
|
||||
from subiquity.common.types import SSHData, SSHFetchIdResponse, SSHFetchIdStatus
|
||||
from subiquity.ui.views.ssh import SSHView
|
||||
from subiquity.common.types import SSHFetchIdResponse, SSHFetchIdStatus
|
||||
from subiquity.ui.views.ssh import IMPORT_KEY_LABEL, ConfirmSSHKeys, SSHView
|
||||
from subiquitycore.async_helpers import schedule_task
|
||||
from subiquitycore.context import with_context
|
||||
|
||||
|
@ -45,18 +45,43 @@ class SSHController(SubiquityTuiController):
|
|||
ssh_data = await self.endpoint.GET()
|
||||
return SSHView(self, ssh_data)
|
||||
|
||||
def run_answers(self):
|
||||
async def run_answers(self):
|
||||
import subiquitycore.testing.view_helpers as view_helpers
|
||||
|
||||
form = self.app.ui.body.form
|
||||
form.install_server.value = self.answers.get("install_server", False)
|
||||
form.pwauth.value = self.answers.get("pwauth", True)
|
||||
|
||||
for key in self.answers.get("authorized_keys", []):
|
||||
# We don't have GUI support for this.
|
||||
self.app.ui.body.add_key_to_table(key)
|
||||
|
||||
if "ssh-import-id" in self.answers:
|
||||
import_id = self.answers["ssh-import-id"]
|
||||
ssh = SSHData(install_server=True, authorized_keys=[], allow_pw=True)
|
||||
self.fetch_ssh_keys(ssh_import_id=import_id, ssh_data=ssh)
|
||||
else:
|
||||
ssh = SSHData(
|
||||
install_server=self.answers.get("install_server", False),
|
||||
authorized_keys=self.answers.get("authorized_keys", []),
|
||||
allow_pw=self.answers.get("pwauth", True),
|
||||
view_helpers.click(
|
||||
view_helpers.find_button_matching(self.app.ui.body, IMPORT_KEY_LABEL)
|
||||
)
|
||||
self.done(ssh)
|
||||
service, username = self.answers["ssh-import-id"].split(":", maxsplit=1)
|
||||
if service not in ("gh", "lp"):
|
||||
raise ValueError(
|
||||
f"invalid service {service} - only gh and lp are supported"
|
||||
)
|
||||
|
||||
import_form = self.app.ui.body._w.stretchy.form
|
||||
|
||||
view_helpers.enter_data(
|
||||
import_form, {"import_username": username, "service": service}
|
||||
)
|
||||
|
||||
import_form._click_done(None)
|
||||
|
||||
# Wait until the key gets fetched
|
||||
confirm_overlay = await view_helpers.wait_for_overlay(
|
||||
self.ui, ConfirmSSHKeys
|
||||
)
|
||||
|
||||
confirm_overlay.ok(None)
|
||||
|
||||
form._click_done(None)
|
||||
|
||||
def cancel(self):
|
||||
self.app.prev_screen()
|
||||
|
@ -67,45 +92,33 @@ class SSHController(SubiquityTuiController):
|
|||
self._fetch_task.cancel()
|
||||
|
||||
@with_context(name="ssh_import_id", description="{ssh_import_id}")
|
||||
async def _fetch_ssh_keys(self, *, context, ssh_import_id, ssh_data):
|
||||
with self.context.child("ssh_import_id", ssh_import_id):
|
||||
response: SSHFetchIdResponse = await self.endpoint.fetch_id.GET(
|
||||
ssh_import_id
|
||||
)
|
||||
async def _fetch_ssh_keys(self, *, context, ssh_import_id):
|
||||
response: SSHFetchIdResponse = await self.endpoint.fetch_id.GET(ssh_import_id)
|
||||
|
||||
if response.status == SSHFetchIdStatus.IMPORT_ERROR:
|
||||
if isinstance(self.ui.body, SSHView):
|
||||
self.ui.body.fetching_ssh_keys_failed(
|
||||
_("Importing keys failed:"), response.error
|
||||
)
|
||||
return
|
||||
elif response.status == SSHFetchIdStatus.FINGERPRINT_ERROR:
|
||||
if isinstance(self.ui.body, SSHView):
|
||||
self.ui.body.fetching_ssh_keys_failed(
|
||||
_("ssh-keygen failed to show fingerprint of downloaded keys:"),
|
||||
response.error,
|
||||
)
|
||||
return
|
||||
if response.status == SSHFetchIdStatus.IMPORT_ERROR:
|
||||
if isinstance(self.ui.body, SSHView):
|
||||
self.ui.body.fetching_ssh_keys_failed(
|
||||
_("Importing keys failed:"), response.error
|
||||
)
|
||||
return
|
||||
elif response.status == SSHFetchIdStatus.FINGERPRINT_ERROR:
|
||||
if isinstance(self.ui.body, SSHView):
|
||||
self.ui.body.fetching_ssh_keys_failed(
|
||||
_("ssh-keygen failed to show fingerprint of downloaded keys:"),
|
||||
response.error,
|
||||
)
|
||||
return
|
||||
|
||||
identities = response.identities
|
||||
identities = response.identities
|
||||
|
||||
if "ssh-import-id" in self.app.answers.get("Identity", {}):
|
||||
ssh_data.authorized_keys = [
|
||||
id_.to_authorized_key() for id_ in identities
|
||||
]
|
||||
self.done(ssh_data)
|
||||
else:
|
||||
if isinstance(self.ui.body, SSHView):
|
||||
self.ui.body.confirm_ssh_keys(ssh_data, ssh_import_id, identities)
|
||||
else:
|
||||
log.debug(
|
||||
"ui.body of unexpected instance: %s",
|
||||
type(self.ui.body).__name__,
|
||||
)
|
||||
if isinstance(self.ui.body, SSHView):
|
||||
self.ui.body.confirm_ssh_keys(ssh_import_id, identities)
|
||||
else:
|
||||
log.debug("ui.body of unexpected instance: %s", type(self.ui.body).__name__)
|
||||
|
||||
def fetch_ssh_keys(self, ssh_import_id, ssh_data):
|
||||
def fetch_ssh_keys(self, ssh_import_id):
|
||||
self._fetch_task = schedule_task(
|
||||
self._fetch_ssh_keys(ssh_import_id=ssh_import_id, ssh_data=ssh_data)
|
||||
self._fetch_ssh_keys(ssh_import_id=ssh_import_id)
|
||||
)
|
||||
|
||||
def done(self, result):
|
||||
|
|
|
@ -21,27 +21,30 @@ from urwid import LineBox, Text, connect_signal
|
|||
|
||||
from subiquity.common.types import SSHData, SSHIdentity
|
||||
from subiquity.ui.views.identity import UsernameField
|
||||
from subiquitycore.ui.buttons import cancel_btn, ok_btn
|
||||
from subiquitycore.ui.actionmenu import Action, ActionMenu
|
||||
from subiquitycore.ui.buttons import cancel_btn, done_btn, menu_btn, ok_btn
|
||||
from subiquitycore.ui.container import ListBox, Pile, WidgetWrap
|
||||
from subiquitycore.ui.form import BooleanField, ChoiceField, Form
|
||||
from subiquitycore.ui.form import BooleanField, ChoiceField, Form, Toggleable
|
||||
from subiquitycore.ui.spinner import Spinner
|
||||
from subiquitycore.ui.stretchy import Stretchy
|
||||
from subiquitycore.ui.utils import SomethingFailed, button_pile, screen
|
||||
from subiquitycore.ui.table import ColSpec, TablePile, TableRow
|
||||
from subiquitycore.ui.utils import (
|
||||
Color,
|
||||
Padding,
|
||||
SomethingFailed,
|
||||
button_pile,
|
||||
make_action_menu_row,
|
||||
screen,
|
||||
)
|
||||
from subiquitycore.view import BaseView
|
||||
|
||||
log = logging.getLogger("subiquity.ui.views.ssh")
|
||||
|
||||
|
||||
SSH_IMPORT_MAXLEN = 256 + 3 # account for lp: or gh:
|
||||
IMPORT_KEY_LABEL = _("Import SSH key")
|
||||
|
||||
_ssh_import_data = {
|
||||
None: {
|
||||
"caption": _("Import Username:"),
|
||||
"help": "",
|
||||
"valid_char": ".",
|
||||
"error_invalid_char": "",
|
||||
"regex": ".*",
|
||||
},
|
||||
"gh": {
|
||||
"caption": _("GitHub Username:"),
|
||||
"help": _("Enter your GitHub username."),
|
||||
|
@ -63,57 +66,33 @@ _ssh_import_data = {
|
|||
}
|
||||
|
||||
|
||||
class SSHForm(Form):
|
||||
install_server = BooleanField(_("Install OpenSSH server"))
|
||||
|
||||
ssh_import_id = ChoiceField(
|
||||
class SSHImportForm(Form):
|
||||
service = ChoiceField(
|
||||
_("Import SSH identity:"),
|
||||
choices=[
|
||||
(_("No"), True, None),
|
||||
(_("from GitHub"), True, "gh"),
|
||||
(_("from Launchpad"), True, "lp"),
|
||||
],
|
||||
help=_("You can import your SSH keys from GitHub or Launchpad."),
|
||||
)
|
||||
|
||||
import_username = UsernameField(_ssh_import_data[None]["caption"])
|
||||
|
||||
pwauth = BooleanField(_("Allow password authentication over SSH"))
|
||||
|
||||
cancel_label = _("Back")
|
||||
|
||||
def __init__(self, initial):
|
||||
super().__init__(initial=initial)
|
||||
connect_signal(self.install_server.widget, "change", self._toggle_server)
|
||||
self._toggle_server(None, self.install_server.value)
|
||||
|
||||
def _toggle_server(self, sender, new_value):
|
||||
if new_value:
|
||||
self.ssh_import_id.enabled = True
|
||||
self.import_username.enabled = self.ssh_import_id.value is not None
|
||||
self.pwauth.enabled = self.ssh_import_id.value is not None
|
||||
else:
|
||||
self.ssh_import_id.enabled = False
|
||||
self.import_username.enabled = False
|
||||
self.pwauth.enabled = False
|
||||
import_username = UsernameField(_ssh_import_data["gh"]["caption"])
|
||||
|
||||
# validation of the import username does not read from
|
||||
# ssh_import_id.value because it is sometimes done from the
|
||||
# 'select' signal of the import id selector, which is called
|
||||
# before the import id selector's value has actually changed. so
|
||||
# service.value because it is sometimes done from the
|
||||
# 'select' signal of the service selector, which is called
|
||||
# before the service selector's value has actually changed. so
|
||||
# the signal handler stuffs the value here before doing
|
||||
# validation (yes, this is a hack).
|
||||
ssh_import_id_value = None
|
||||
service_value = None
|
||||
|
||||
def validate_import_username(self):
|
||||
if self.ssh_import_id_value is None:
|
||||
return
|
||||
username = self.import_username.value
|
||||
if len(username) == 0:
|
||||
return _("This field must not be blank.")
|
||||
if len(username) > SSH_IMPORT_MAXLEN:
|
||||
return _("SSH id too long, must be < ") + str(SSH_IMPORT_MAXLEN)
|
||||
if self.ssh_import_id_value == "lp":
|
||||
if self.service_value == "lp":
|
||||
lp_regex = r"^[a-z0-9][a-z0-9\+\.\-]*$"
|
||||
if not re.match(lp_regex, self.import_username.value):
|
||||
return _(
|
||||
|
@ -123,7 +102,7 @@ class SSHForm(Form):
|
|||
"the first character."
|
||||
""
|
||||
)
|
||||
elif self.ssh_import_id_value == "gh":
|
||||
elif self.service_value == "gh":
|
||||
if not re.match(r"^[a-zA-Z0-9\-]+$", username):
|
||||
return _(
|
||||
"A GitHub username may only contain alphanumeric "
|
||||
|
@ -132,6 +111,71 @@ class SSHForm(Form):
|
|||
)
|
||||
|
||||
|
||||
class SSHImportStretchy(Stretchy):
|
||||
def __init__(self, parent):
|
||||
self.parent = parent
|
||||
self.form = SSHImportForm(initial={})
|
||||
|
||||
connect_signal(self.form, "submit", lambda unused: self.done())
|
||||
connect_signal(self.form, "cancel", lambda unused: self.cancel())
|
||||
connect_signal(
|
||||
self.form.service.widget, "select", self._import_service_selected
|
||||
)
|
||||
|
||||
self._import_service_selected(
|
||||
sender=None, service=self.form.service.widget.value
|
||||
)
|
||||
|
||||
rows = self.form.as_rows()
|
||||
title = _("Import SSH key")
|
||||
super().__init__(title, [Pile(rows), Text(""), self.form.buttons], 0, 0)
|
||||
|
||||
def cancel(self):
|
||||
self.parent.remove_overlay(self)
|
||||
|
||||
def done(self):
|
||||
ssh_import_id = self.form.service.value + ":" + self.form.import_username.value
|
||||
fsk = FetchingSSHKeys(self.parent)
|
||||
self.parent.remove_overlay(self)
|
||||
self.parent.show_overlay(fsk, width=fsk.width, min_width=None)
|
||||
self.parent.controller.fetch_ssh_keys(ssh_import_id=ssh_import_id)
|
||||
|
||||
def _import_service_selected(self, sender, service: str):
|
||||
iu = self.form.import_username
|
||||
data = _ssh_import_data[service]
|
||||
iu.help = _(data["help"])
|
||||
iu.caption = _(data["caption"])
|
||||
iu.widget.valid_char_pat = data["valid_char"]
|
||||
iu.widget.error_invalid_char = _(data["error_invalid_char"])
|
||||
self.form.service_value = service
|
||||
if iu.value != "":
|
||||
iu.validate()
|
||||
|
||||
|
||||
class SSHShowKeyStretchy(Stretchy):
|
||||
def __init__(self, parent, key: str) -> None:
|
||||
self.parent = parent
|
||||
widgets = [
|
||||
Text(key),
|
||||
Text(""),
|
||||
button_pile([done_btn(_("Close"), on_press=self.close)]),
|
||||
]
|
||||
|
||||
title = _("SSH key")
|
||||
super().__init__(title, widgets, 0, 2)
|
||||
|
||||
def close(self, button=None) -> None:
|
||||
self.parent.remove_overlay()
|
||||
|
||||
|
||||
class SSHForm(Form):
|
||||
install_server = BooleanField(_("Install OpenSSH server"))
|
||||
|
||||
pwauth = BooleanField(_("Allow password authentication over SSH"))
|
||||
|
||||
cancel_label = _("Back")
|
||||
|
||||
|
||||
class FetchingSSHKeys(WidgetWrap):
|
||||
def __init__(self, parent):
|
||||
self.parent = parent
|
||||
|
@ -155,14 +199,13 @@ class FetchingSSHKeys(WidgetWrap):
|
|||
)
|
||||
|
||||
def cancel(self, sender):
|
||||
self.parent.controller._fetch_cancel()
|
||||
self.parent.remove_overlay()
|
||||
self.parent.controller._fetch_cancel()
|
||||
|
||||
|
||||
class ConfirmSSHKeys(Stretchy):
|
||||
def __init__(self, parent, ssh_data, identities: List[SSHIdentity]):
|
||||
def __init__(self, parent, identities: List[SSHIdentity]):
|
||||
self.parent = parent
|
||||
self.ssh_data = ssh_data
|
||||
self.identities: List[SSHIdentity] = identities
|
||||
|
||||
ok = ok_btn(label=_("Yes"), on_press=self.ok)
|
||||
|
@ -200,10 +243,11 @@ class ConfirmSSHKeys(Stretchy):
|
|||
self.parent.remove_overlay()
|
||||
|
||||
def ok(self, sender):
|
||||
self.ssh_data.authorized_keys = [
|
||||
id_.to_authorized_key() for id_ in self.identities
|
||||
]
|
||||
self.parent.controller.done(self.ssh_data)
|
||||
for identity in self.identities:
|
||||
self.parent.add_key_to_table(identity.to_authorized_key())
|
||||
self.parent.refresh_keys_table()
|
||||
|
||||
self.parent.remove_overlay()
|
||||
|
||||
|
||||
class SSHView(BaseView):
|
||||
|
@ -222,76 +266,156 @@ class SSHView(BaseView):
|
|||
}
|
||||
|
||||
self.form = SSHForm(initial=initial)
|
||||
self.keys = ssh_data.authorized_keys
|
||||
|
||||
connect_signal(
|
||||
self.form.ssh_import_id.widget, "select", self._select_ssh_import_id
|
||||
self._import_key_btn = Toggleable(
|
||||
menu_btn(
|
||||
label=IMPORT_KEY_LABEL,
|
||||
on_press=lambda unused: self.show_import_key_overlay(),
|
||||
)
|
||||
)
|
||||
bp = button_pile([self._import_key_btn])
|
||||
bp.align = "left"
|
||||
|
||||
colspecs = {
|
||||
0: ColSpec(rpad=1),
|
||||
1: ColSpec(can_shrink=True),
|
||||
2: ColSpec(rpad=1),
|
||||
3: ColSpec(rpad=1),
|
||||
}
|
||||
self.keys_table = TablePile([], colspecs=colspecs)
|
||||
self.refresh_keys_table()
|
||||
|
||||
rows = self.form.as_rows() + [
|
||||
Text(""),
|
||||
bp,
|
||||
Text(""),
|
||||
Text(_("AUTHORIZED KEYS")),
|
||||
Text(""),
|
||||
self.keys_table,
|
||||
]
|
||||
|
||||
connect_signal(self.form, "submit", self.done)
|
||||
connect_signal(self.form, "cancel", self.cancel)
|
||||
connect_signal(self.form.install_server.widget, "change", self._toggle_server)
|
||||
|
||||
self._toggle_server(None, self.form.install_server.value)
|
||||
|
||||
self.form_rows = ListBox(self.form.as_rows())
|
||||
super().__init__(
|
||||
screen(
|
||||
self.form_rows,
|
||||
ListBox(rows),
|
||||
self.form.buttons,
|
||||
excerpt=_(self.excerpt),
|
||||
focus_buttons=False,
|
||||
)
|
||||
)
|
||||
|
||||
def _select_ssh_import_id(self, sender, val):
|
||||
iu = self.form.import_username
|
||||
data = _ssh_import_data[val]
|
||||
iu.help = _(data["help"])
|
||||
iu.caption = _(data["caption"])
|
||||
iu.widget.valid_char_pat = data["valid_char"]
|
||||
iu.widget.error_invalid_char = _(data["error_invalid_char"])
|
||||
iu.enabled = val is not None
|
||||
self.form.pwauth.enabled = val is not None
|
||||
# The logic here is a little tortured but the idea is that if
|
||||
# the users switches from not importing a key to importing
|
||||
# one, untick pwauth (but don't fiddle with the users choice
|
||||
# if just switching between import sources), and conversely if
|
||||
# the user is not importing a key then the box has to be
|
||||
# ticked (and disabled).
|
||||
if (val is None) != (self.form.ssh_import_id.value is None):
|
||||
self.form.pwauth.value = val is None
|
||||
if val is not None:
|
||||
self.form_rows.base_widget.body.focus += 2
|
||||
self.form.ssh_import_id_value = val
|
||||
if iu.value != "" or val is None:
|
||||
iu.validate()
|
||||
|
||||
def done(self, sender):
|
||||
log.debug("User input: {}".format(self.form.as_data()))
|
||||
ssh_data = SSHData(
|
||||
install_server=self.form.install_server.value,
|
||||
allow_pw=self.form.pwauth.value,
|
||||
authorized_keys=self.keys,
|
||||
)
|
||||
|
||||
# if user specifed a value, allow user to validate fingerprint
|
||||
if self.form.ssh_import_id.value:
|
||||
ssh_import_id = (
|
||||
self.form.ssh_import_id.value + ":" + self.form.import_username.value
|
||||
)
|
||||
fsk = FetchingSSHKeys(self)
|
||||
self.show_overlay(fsk, width=fsk.width, min_width=None)
|
||||
self.controller.fetch_ssh_keys(
|
||||
ssh_import_id=ssh_import_id, ssh_data=ssh_data
|
||||
)
|
||||
else:
|
||||
self.controller.done(ssh_data)
|
||||
self.controller.done(ssh_data)
|
||||
|
||||
def cancel(self, result=None):
|
||||
self.controller.cancel()
|
||||
|
||||
def confirm_ssh_keys(self, ssh_data, ssh_import_id, identities: List[SSHIdentity]):
|
||||
def show_import_key_overlay(self):
|
||||
self.show_stretchy_overlay(SSHImportStretchy(self))
|
||||
|
||||
def confirm_ssh_keys(self, ssh_import_id, identities: List[SSHIdentity]):
|
||||
self.remove_overlay()
|
||||
self.show_stretchy_overlay(ConfirmSSHKeys(self, ssh_data, identities))
|
||||
self.show_stretchy_overlay(ConfirmSSHKeys(self, identities))
|
||||
|
||||
def fetching_ssh_keys_failed(self, msg, stderr):
|
||||
# FIXME in answers-based runs, the overlay does not exist so we pass
|
||||
# not_found_ok=True.
|
||||
self.remove_overlay(not_found_ok=True)
|
||||
self.remove_overlay()
|
||||
self.show_stretchy_overlay(SomethingFailed(self, msg, stderr))
|
||||
|
||||
def add_key_to_table(self, key: str) -> None:
|
||||
"""Add the specified key to the list of authorized keys. When adding
|
||||
the first one, we also disable password authentication (but give the
|
||||
user the ability to re-enable it)"""
|
||||
self.keys.append(key)
|
||||
if len(self.keys) == 1:
|
||||
self.form.pwauth.value = False
|
||||
if self.form.install_server:
|
||||
self.form.pwauth.enabled = True
|
||||
|
||||
def remove_key_from_table(self, key: str) -> None:
|
||||
"""Remove the specified key from the list of authorized keys. When
|
||||
removing the last one, we also re-enable password authentication (and
|
||||
disable the checkbox)."""
|
||||
self.keys.remove(key)
|
||||
if not self.keys:
|
||||
self.form.pwauth.value = True
|
||||
self.form.pwauth.enabled = False
|
||||
|
||||
def show_key_overlay(self, key: str):
|
||||
self.show_stretchy_overlay(SSHShowKeyStretchy(self, key))
|
||||
|
||||
def refresh_keys_table(self):
|
||||
rows: List[TableRow] = []
|
||||
|
||||
if not self.keys:
|
||||
rows = [
|
||||
TableRow(
|
||||
[
|
||||
(
|
||||
4,
|
||||
Padding.push_2(
|
||||
Color.info_minor(Text(_("No authorized key")))
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
]
|
||||
for key in self.keys:
|
||||
menu = ActionMenu(
|
||||
[
|
||||
Action(
|
||||
label=_("Delete"),
|
||||
enabled=True,
|
||||
value="delete",
|
||||
opens_dialog=False,
|
||||
),
|
||||
Action(
|
||||
label=_("Show"),
|
||||
enabled=True,
|
||||
value="show",
|
||||
opens_dialog=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
rows.append(
|
||||
make_action_menu_row(
|
||||
[
|
||||
Text("["),
|
||||
Text(key, wrap="ellipsis"),
|
||||
menu,
|
||||
Text("]"),
|
||||
],
|
||||
menu,
|
||||
)
|
||||
)
|
||||
connect_signal(menu, "action", self._action, key)
|
||||
|
||||
self.keys_table.set_contents(rows)
|
||||
|
||||
def _action(self, sender, value, key):
|
||||
if value == "delete":
|
||||
self.remove_key_from_table(key)
|
||||
self.refresh_keys_table()
|
||||
elif value == "show":
|
||||
self.show_key_overlay(key)
|
||||
|
||||
def _toggle_server(self, sender, installed: bool):
|
||||
self._import_key_btn.enabled = installed
|
||||
|
||||
if installed:
|
||||
self.form.pwauth.enabled = bool(self.keys)
|
||||
else:
|
||||
self.form.pwauth.enabled = False
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
import asyncio
|
||||
import re
|
||||
from typing import Type
|
||||
|
||||
import urwid
|
||||
|
||||
from subiquitycore.ui.stretchy import StretchyOverlay
|
||||
from subiquitycore.ui.frame import SubiquityCoreUI
|
||||
from subiquitycore.ui.stretchy import Stretchy, StretchyOverlay
|
||||
|
||||
|
||||
def find_with_pred(w, pred, return_path=False):
|
||||
|
@ -86,3 +89,25 @@ def enter_data(form, data):
|
|||
bf = getattr(form, k)
|
||||
assert bf.enabled, "%s is not enabled" % (k,)
|
||||
bf.value = v
|
||||
|
||||
|
||||
async def wait_for_overlay(
|
||||
ui: SubiquityCoreUI, overlay_type: Type[Stretchy], *, timeout=None, step=0.01
|
||||
) -> Stretchy:
|
||||
"""Wait until an overlay of the specified type gets displayed on the
|
||||
screen and return it. If timeout is hit before the overlay is displayed, an
|
||||
asyncio.TimeoutError will be raised. When timeout is set to None, we will
|
||||
wait forever."""
|
||||
if timeout is not None:
|
||||
task = wait_for_overlay(ui, overlay_type)
|
||||
return await asyncio.wait_for(task, timeout=timeout)
|
||||
|
||||
while True:
|
||||
try:
|
||||
stretchy = ui.body._w.stretchy
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
if isinstance(stretchy, overlay_type):
|
||||
return stretchy
|
||||
await asyncio.sleep(step)
|
||||
|
|
Loading…
Reference in New Issue