diff --git a/examples/answers/answers.yaml b/examples/answers/answers.yaml index 6ea54f58..45695025 100644 --- a/examples/answers/answers.yaml +++ b/examples/answers/answers.yaml @@ -23,6 +23,7 @@ Identity: hostname: ubuntu-server # ubuntu password: '$6$wdAcoXrU039hKYPd$508Qvbe7ObUnxoj15DRCkzC3qO7edjH0VV7BPNRDYK4QR8ofJaEEF2heacn0QgD.f8pO8SNp83XNdWG6tocBM1' +SSH: # We have a specific dry-run behaviour for heracles, making the call to # ssh-import-id succeed unconditionally. ssh-import-id: lp:heracles diff --git a/subiquity/client/controllers/ssh.py b/subiquity/client/controllers/ssh.py index 832f8c0a..6158dce6 100644 --- a/subiquity/client/controllers/ssh.py +++ b/subiquity/client/controllers/ssh.py @@ -16,8 +16,8 @@ import logging from subiquity.client.controller import SubiquityTuiController -from subiquity.common.types import SSHData, SSHFetchIdResponse, SSHFetchIdStatus -from subiquity.ui.views.ssh import SSHView +from subiquity.common.types import SSHFetchIdResponse, SSHFetchIdStatus +from subiquity.ui.views.ssh import IMPORT_KEY_LABEL, ConfirmSSHKeys, SSHView from subiquitycore.async_helpers import schedule_task from subiquitycore.context import with_context @@ -45,18 +45,43 @@ class SSHController(SubiquityTuiController): ssh_data = await self.endpoint.GET() return SSHView(self, ssh_data) - def run_answers(self): + async def run_answers(self): + import subiquitycore.testing.view_helpers as view_helpers + + form = self.app.ui.body.form + form.install_server.value = self.answers.get("install_server", False) + form.pwauth.value = self.answers.get("pwauth", True) + + for key in self.answers.get("authorized_keys", []): + # We don't have GUI support for this. + self.app.ui.body.add_key_to_table(key) + if "ssh-import-id" in self.answers: - import_id = self.answers["ssh-import-id"] - ssh = SSHData(install_server=True, authorized_keys=[], allow_pw=True) - self.fetch_ssh_keys(ssh_import_id=import_id, ssh_data=ssh) - else: - ssh = SSHData( - install_server=self.answers.get("install_server", False), - authorized_keys=self.answers.get("authorized_keys", []), - allow_pw=self.answers.get("pwauth", True), + view_helpers.click( + view_helpers.find_button_matching(self.app.ui.body, IMPORT_KEY_LABEL) ) - self.done(ssh) + service, username = self.answers["ssh-import-id"].split(":", maxsplit=1) + if service not in ("gh", "lp"): + raise ValueError( + f"invalid service {service} - only gh and lp are supported" + ) + + import_form = self.app.ui.body._w.stretchy.form + + view_helpers.enter_data( + import_form, {"import_username": username, "service": service} + ) + + import_form._click_done(None) + + # Wait until the key gets fetched + confirm_overlay = await view_helpers.wait_for_overlay( + self.ui, ConfirmSSHKeys + ) + + confirm_overlay.ok(None) + + form._click_done(None) def cancel(self): self.app.prev_screen() @@ -67,45 +92,33 @@ class SSHController(SubiquityTuiController): self._fetch_task.cancel() @with_context(name="ssh_import_id", description="{ssh_import_id}") - async def _fetch_ssh_keys(self, *, context, ssh_import_id, ssh_data): - with self.context.child("ssh_import_id", ssh_import_id): - response: SSHFetchIdResponse = await self.endpoint.fetch_id.GET( - ssh_import_id - ) + async def _fetch_ssh_keys(self, *, context, ssh_import_id): + response: SSHFetchIdResponse = await self.endpoint.fetch_id.GET(ssh_import_id) - if response.status == SSHFetchIdStatus.IMPORT_ERROR: - if isinstance(self.ui.body, SSHView): - self.ui.body.fetching_ssh_keys_failed( - _("Importing keys failed:"), response.error - ) - return - elif response.status == SSHFetchIdStatus.FINGERPRINT_ERROR: - if isinstance(self.ui.body, SSHView): - self.ui.body.fetching_ssh_keys_failed( - _("ssh-keygen failed to show fingerprint of downloaded keys:"), - response.error, - ) - return + if response.status == SSHFetchIdStatus.IMPORT_ERROR: + if isinstance(self.ui.body, SSHView): + self.ui.body.fetching_ssh_keys_failed( + _("Importing keys failed:"), response.error + ) + return + elif response.status == SSHFetchIdStatus.FINGERPRINT_ERROR: + if isinstance(self.ui.body, SSHView): + self.ui.body.fetching_ssh_keys_failed( + _("ssh-keygen failed to show fingerprint of downloaded keys:"), + response.error, + ) + return - identities = response.identities + identities = response.identities - if "ssh-import-id" in self.app.answers.get("Identity", {}): - ssh_data.authorized_keys = [ - id_.to_authorized_key() for id_ in identities - ] - self.done(ssh_data) - else: - if isinstance(self.ui.body, SSHView): - self.ui.body.confirm_ssh_keys(ssh_data, ssh_import_id, identities) - else: - log.debug( - "ui.body of unexpected instance: %s", - type(self.ui.body).__name__, - ) + if isinstance(self.ui.body, SSHView): + self.ui.body.confirm_ssh_keys(ssh_import_id, identities) + else: + log.debug("ui.body of unexpected instance: %s", type(self.ui.body).__name__) - def fetch_ssh_keys(self, ssh_import_id, ssh_data): + def fetch_ssh_keys(self, ssh_import_id): self._fetch_task = schedule_task( - self._fetch_ssh_keys(ssh_import_id=ssh_import_id, ssh_data=ssh_data) + self._fetch_ssh_keys(ssh_import_id=ssh_import_id) ) def done(self, result): diff --git a/subiquity/ui/views/ssh.py b/subiquity/ui/views/ssh.py index ee0582c1..90f1f9b3 100644 --- a/subiquity/ui/views/ssh.py +++ b/subiquity/ui/views/ssh.py @@ -21,27 +21,30 @@ from urwid import LineBox, Text, connect_signal from subiquity.common.types import SSHData, SSHIdentity from subiquity.ui.views.identity import UsernameField -from subiquitycore.ui.buttons import cancel_btn, ok_btn +from subiquitycore.ui.actionmenu import Action, ActionMenu +from subiquitycore.ui.buttons import cancel_btn, done_btn, menu_btn, ok_btn from subiquitycore.ui.container import ListBox, Pile, WidgetWrap -from subiquitycore.ui.form import BooleanField, ChoiceField, Form +from subiquitycore.ui.form import BooleanField, ChoiceField, Form, Toggleable from subiquitycore.ui.spinner import Spinner from subiquitycore.ui.stretchy import Stretchy -from subiquitycore.ui.utils import SomethingFailed, button_pile, screen +from subiquitycore.ui.table import ColSpec, TablePile, TableRow +from subiquitycore.ui.utils import ( + Color, + Padding, + SomethingFailed, + button_pile, + make_action_menu_row, + screen, +) from subiquitycore.view import BaseView log = logging.getLogger("subiquity.ui.views.ssh") SSH_IMPORT_MAXLEN = 256 + 3 # account for lp: or gh: +IMPORT_KEY_LABEL = _("Import SSH key") _ssh_import_data = { - None: { - "caption": _("Import Username:"), - "help": "", - "valid_char": ".", - "error_invalid_char": "", - "regex": ".*", - }, "gh": { "caption": _("GitHub Username:"), "help": _("Enter your GitHub username."), @@ -63,57 +66,33 @@ _ssh_import_data = { } -class SSHForm(Form): - install_server = BooleanField(_("Install OpenSSH server")) - - ssh_import_id = ChoiceField( +class SSHImportForm(Form): + service = ChoiceField( _("Import SSH identity:"), choices=[ - (_("No"), True, None), (_("from GitHub"), True, "gh"), (_("from Launchpad"), True, "lp"), ], help=_("You can import your SSH keys from GitHub or Launchpad."), ) - import_username = UsernameField(_ssh_import_data[None]["caption"]) - - pwauth = BooleanField(_("Allow password authentication over SSH")) - - cancel_label = _("Back") - - def __init__(self, initial): - super().__init__(initial=initial) - connect_signal(self.install_server.widget, "change", self._toggle_server) - self._toggle_server(None, self.install_server.value) - - def _toggle_server(self, sender, new_value): - if new_value: - self.ssh_import_id.enabled = True - self.import_username.enabled = self.ssh_import_id.value is not None - self.pwauth.enabled = self.ssh_import_id.value is not None - else: - self.ssh_import_id.enabled = False - self.import_username.enabled = False - self.pwauth.enabled = False + import_username = UsernameField(_ssh_import_data["gh"]["caption"]) # validation of the import username does not read from - # ssh_import_id.value because it is sometimes done from the - # 'select' signal of the import id selector, which is called - # before the import id selector's value has actually changed. so + # service.value because it is sometimes done from the + # 'select' signal of the service selector, which is called + # before the service selector's value has actually changed. so # the signal handler stuffs the value here before doing # validation (yes, this is a hack). - ssh_import_id_value = None + service_value = None def validate_import_username(self): - if self.ssh_import_id_value is None: - return username = self.import_username.value if len(username) == 0: return _("This field must not be blank.") if len(username) > SSH_IMPORT_MAXLEN: return _("SSH id too long, must be < ") + str(SSH_IMPORT_MAXLEN) - if self.ssh_import_id_value == "lp": + if self.service_value == "lp": lp_regex = r"^[a-z0-9][a-z0-9\+\.\-]*$" if not re.match(lp_regex, self.import_username.value): return _( @@ -123,7 +102,7 @@ class SSHForm(Form): "the first character." "" ) - elif self.ssh_import_id_value == "gh": + elif self.service_value == "gh": if not re.match(r"^[a-zA-Z0-9\-]+$", username): return _( "A GitHub username may only contain alphanumeric " @@ -132,6 +111,71 @@ class SSHForm(Form): ) +class SSHImportStretchy(Stretchy): + def __init__(self, parent): + self.parent = parent + self.form = SSHImportForm(initial={}) + + connect_signal(self.form, "submit", lambda unused: self.done()) + connect_signal(self.form, "cancel", lambda unused: self.cancel()) + connect_signal( + self.form.service.widget, "select", self._import_service_selected + ) + + self._import_service_selected( + sender=None, service=self.form.service.widget.value + ) + + rows = self.form.as_rows() + title = _("Import SSH key") + super().__init__(title, [Pile(rows), Text(""), self.form.buttons], 0, 0) + + def cancel(self): + self.parent.remove_overlay(self) + + def done(self): + ssh_import_id = self.form.service.value + ":" + self.form.import_username.value + fsk = FetchingSSHKeys(self.parent) + self.parent.remove_overlay(self) + self.parent.show_overlay(fsk, width=fsk.width, min_width=None) + self.parent.controller.fetch_ssh_keys(ssh_import_id=ssh_import_id) + + def _import_service_selected(self, sender, service: str): + iu = self.form.import_username + data = _ssh_import_data[service] + iu.help = _(data["help"]) + iu.caption = _(data["caption"]) + iu.widget.valid_char_pat = data["valid_char"] + iu.widget.error_invalid_char = _(data["error_invalid_char"]) + self.form.service_value = service + if iu.value != "": + iu.validate() + + +class SSHShowKeyStretchy(Stretchy): + def __init__(self, parent, key: str) -> None: + self.parent = parent + widgets = [ + Text(key), + Text(""), + button_pile([done_btn(_("Close"), on_press=self.close)]), + ] + + title = _("SSH key") + super().__init__(title, widgets, 0, 2) + + def close(self, button=None) -> None: + self.parent.remove_overlay() + + +class SSHForm(Form): + install_server = BooleanField(_("Install OpenSSH server")) + + pwauth = BooleanField(_("Allow password authentication over SSH")) + + cancel_label = _("Back") + + class FetchingSSHKeys(WidgetWrap): def __init__(self, parent): self.parent = parent @@ -155,14 +199,13 @@ class FetchingSSHKeys(WidgetWrap): ) def cancel(self, sender): - self.parent.controller._fetch_cancel() self.parent.remove_overlay() + self.parent.controller._fetch_cancel() class ConfirmSSHKeys(Stretchy): - def __init__(self, parent, ssh_data, identities: List[SSHIdentity]): + def __init__(self, parent, identities: List[SSHIdentity]): self.parent = parent - self.ssh_data = ssh_data self.identities: List[SSHIdentity] = identities ok = ok_btn(label=_("Yes"), on_press=self.ok) @@ -200,10 +243,11 @@ class ConfirmSSHKeys(Stretchy): self.parent.remove_overlay() def ok(self, sender): - self.ssh_data.authorized_keys = [ - id_.to_authorized_key() for id_ in self.identities - ] - self.parent.controller.done(self.ssh_data) + for identity in self.identities: + self.parent.add_key_to_table(identity.to_authorized_key()) + self.parent.refresh_keys_table() + + self.parent.remove_overlay() class SSHView(BaseView): @@ -222,76 +266,156 @@ class SSHView(BaseView): } self.form = SSHForm(initial=initial) + self.keys = ssh_data.authorized_keys - connect_signal( - self.form.ssh_import_id.widget, "select", self._select_ssh_import_id + self._import_key_btn = Toggleable( + menu_btn( + label=IMPORT_KEY_LABEL, + on_press=lambda unused: self.show_import_key_overlay(), + ) ) + bp = button_pile([self._import_key_btn]) + bp.align = "left" + + colspecs = { + 0: ColSpec(rpad=1), + 1: ColSpec(can_shrink=True), + 2: ColSpec(rpad=1), + 3: ColSpec(rpad=1), + } + self.keys_table = TablePile([], colspecs=colspecs) + self.refresh_keys_table() + + rows = self.form.as_rows() + [ + Text(""), + bp, + Text(""), + Text(_("AUTHORIZED KEYS")), + Text(""), + self.keys_table, + ] connect_signal(self.form, "submit", self.done) connect_signal(self.form, "cancel", self.cancel) + connect_signal(self.form.install_server.widget, "change", self._toggle_server) + + self._toggle_server(None, self.form.install_server.value) - self.form_rows = ListBox(self.form.as_rows()) super().__init__( screen( - self.form_rows, + ListBox(rows), self.form.buttons, excerpt=_(self.excerpt), focus_buttons=False, ) ) - def _select_ssh_import_id(self, sender, val): - iu = self.form.import_username - data = _ssh_import_data[val] - iu.help = _(data["help"]) - iu.caption = _(data["caption"]) - iu.widget.valid_char_pat = data["valid_char"] - iu.widget.error_invalid_char = _(data["error_invalid_char"]) - iu.enabled = val is not None - self.form.pwauth.enabled = val is not None - # The logic here is a little tortured but the idea is that if - # the users switches from not importing a key to importing - # one, untick pwauth (but don't fiddle with the users choice - # if just switching between import sources), and conversely if - # the user is not importing a key then the box has to be - # ticked (and disabled). - if (val is None) != (self.form.ssh_import_id.value is None): - self.form.pwauth.value = val is None - if val is not None: - self.form_rows.base_widget.body.focus += 2 - self.form.ssh_import_id_value = val - if iu.value != "" or val is None: - iu.validate() - def done(self, sender): log.debug("User input: {}".format(self.form.as_data())) ssh_data = SSHData( install_server=self.form.install_server.value, allow_pw=self.form.pwauth.value, + authorized_keys=self.keys, ) - # if user specifed a value, allow user to validate fingerprint - if self.form.ssh_import_id.value: - ssh_import_id = ( - self.form.ssh_import_id.value + ":" + self.form.import_username.value - ) - fsk = FetchingSSHKeys(self) - self.show_overlay(fsk, width=fsk.width, min_width=None) - self.controller.fetch_ssh_keys( - ssh_import_id=ssh_import_id, ssh_data=ssh_data - ) - else: - self.controller.done(ssh_data) + self.controller.done(ssh_data) def cancel(self, result=None): self.controller.cancel() - def confirm_ssh_keys(self, ssh_data, ssh_import_id, identities: List[SSHIdentity]): + def show_import_key_overlay(self): + self.show_stretchy_overlay(SSHImportStretchy(self)) + + def confirm_ssh_keys(self, ssh_import_id, identities: List[SSHIdentity]): self.remove_overlay() - self.show_stretchy_overlay(ConfirmSSHKeys(self, ssh_data, identities)) + self.show_stretchy_overlay(ConfirmSSHKeys(self, identities)) def fetching_ssh_keys_failed(self, msg, stderr): - # FIXME in answers-based runs, the overlay does not exist so we pass - # not_found_ok=True. - self.remove_overlay(not_found_ok=True) + self.remove_overlay() self.show_stretchy_overlay(SomethingFailed(self, msg, stderr)) + + def add_key_to_table(self, key: str) -> None: + """Add the specified key to the list of authorized keys. When adding + the first one, we also disable password authentication (but give the + user the ability to re-enable it)""" + self.keys.append(key) + if len(self.keys) == 1: + self.form.pwauth.value = False + if self.form.install_server: + self.form.pwauth.enabled = True + + def remove_key_from_table(self, key: str) -> None: + """Remove the specified key from the list of authorized keys. When + removing the last one, we also re-enable password authentication (and + disable the checkbox).""" + self.keys.remove(key) + if not self.keys: + self.form.pwauth.value = True + self.form.pwauth.enabled = False + + def show_key_overlay(self, key: str): + self.show_stretchy_overlay(SSHShowKeyStretchy(self, key)) + + def refresh_keys_table(self): + rows: List[TableRow] = [] + + if not self.keys: + rows = [ + TableRow( + [ + ( + 4, + Padding.push_2( + Color.info_minor(Text(_("No authorized key"))) + ), + ) + ] + ) + ] + for key in self.keys: + menu = ActionMenu( + [ + Action( + label=_("Delete"), + enabled=True, + value="delete", + opens_dialog=False, + ), + Action( + label=_("Show"), + enabled=True, + value="show", + opens_dialog=True, + ), + ] + ) + + rows.append( + make_action_menu_row( + [ + Text("["), + Text(key, wrap="ellipsis"), + menu, + Text("]"), + ], + menu, + ) + ) + connect_signal(menu, "action", self._action, key) + + self.keys_table.set_contents(rows) + + def _action(self, sender, value, key): + if value == "delete": + self.remove_key_from_table(key) + self.refresh_keys_table() + elif value == "show": + self.show_key_overlay(key) + + def _toggle_server(self, sender, installed: bool): + self._import_key_btn.enabled = installed + + if installed: + self.form.pwauth.enabled = bool(self.keys) + else: + self.form.pwauth.enabled = False diff --git a/subiquitycore/testing/view_helpers.py b/subiquitycore/testing/view_helpers.py index 1e87c0ff..b32bb0dc 100644 --- a/subiquitycore/testing/view_helpers.py +++ b/subiquitycore/testing/view_helpers.py @@ -1,8 +1,11 @@ +import asyncio import re +from typing import Type import urwid -from subiquitycore.ui.stretchy import StretchyOverlay +from subiquitycore.ui.frame import SubiquityCoreUI +from subiquitycore.ui.stretchy import Stretchy, StretchyOverlay def find_with_pred(w, pred, return_path=False): @@ -86,3 +89,25 @@ def enter_data(form, data): bf = getattr(form, k) assert bf.enabled, "%s is not enabled" % (k,) bf.value = v + + +async def wait_for_overlay( + ui: SubiquityCoreUI, overlay_type: Type[Stretchy], *, timeout=None, step=0.01 +) -> Stretchy: + """Wait until an overlay of the specified type gets displayed on the + screen and return it. If timeout is hit before the overlay is displayed, an + asyncio.TimeoutError will be raised. When timeout is set to None, we will + wait forever.""" + if timeout is not None: + task = wait_for_overlay(ui, overlay_type) + return await asyncio.wait_for(task, timeout=timeout) + + while True: + try: + stretchy = ui.body._w.stretchy + except AttributeError: + pass + else: + if isinstance(stretchy, overlay_type): + return stretchy + await asyncio.sleep(step)