Merge pull request #1895 from ogayot/pr/ssh-dissociate-form-submission

Dissociate SSH key import from form submission
This commit is contained in:
Olivier Gayot 2024-02-13 19:20:34 +01:00 committed by GitHub
commit ce938e6d03
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 304 additions and 141 deletions

View File

@ -23,6 +23,7 @@ Identity:
hostname: ubuntu-server hostname: ubuntu-server
# ubuntu # ubuntu
password: '$6$wdAcoXrU039hKYPd$508Qvbe7ObUnxoj15DRCkzC3qO7edjH0VV7BPNRDYK4QR8ofJaEEF2heacn0QgD.f8pO8SNp83XNdWG6tocBM1' password: '$6$wdAcoXrU039hKYPd$508Qvbe7ObUnxoj15DRCkzC3qO7edjH0VV7BPNRDYK4QR8ofJaEEF2heacn0QgD.f8pO8SNp83XNdWG6tocBM1'
SSH:
# We have a specific dry-run behaviour for heracles, making the call to # We have a specific dry-run behaviour for heracles, making the call to
# ssh-import-id succeed unconditionally. # ssh-import-id succeed unconditionally.
ssh-import-id: lp:heracles ssh-import-id: lp:heracles

View File

@ -16,8 +16,8 @@
import logging import logging
from subiquity.client.controller import SubiquityTuiController from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import SSHData, SSHFetchIdResponse, SSHFetchIdStatus from subiquity.common.types import SSHFetchIdResponse, SSHFetchIdStatus
from subiquity.ui.views.ssh import SSHView from subiquity.ui.views.ssh import IMPORT_KEY_LABEL, ConfirmSSHKeys, SSHView
from subiquitycore.async_helpers import schedule_task from subiquitycore.async_helpers import schedule_task
from subiquitycore.context import with_context from subiquitycore.context import with_context
@ -45,18 +45,43 @@ class SSHController(SubiquityTuiController):
ssh_data = await self.endpoint.GET() ssh_data = await self.endpoint.GET()
return SSHView(self, ssh_data) return SSHView(self, ssh_data)
def run_answers(self): async def run_answers(self):
import subiquitycore.testing.view_helpers as view_helpers
form = self.app.ui.body.form
form.install_server.value = self.answers.get("install_server", False)
form.pwauth.value = self.answers.get("pwauth", True)
for key in self.answers.get("authorized_keys", []):
# We don't have GUI support for this.
self.app.ui.body.add_key_to_table(key)
if "ssh-import-id" in self.answers: if "ssh-import-id" in self.answers:
import_id = self.answers["ssh-import-id"] view_helpers.click(
ssh = SSHData(install_server=True, authorized_keys=[], allow_pw=True) view_helpers.find_button_matching(self.app.ui.body, IMPORT_KEY_LABEL)
self.fetch_ssh_keys(ssh_import_id=import_id, ssh_data=ssh)
else:
ssh = SSHData(
install_server=self.answers.get("install_server", False),
authorized_keys=self.answers.get("authorized_keys", []),
allow_pw=self.answers.get("pwauth", True),
) )
self.done(ssh) service, username = self.answers["ssh-import-id"].split(":", maxsplit=1)
if service not in ("gh", "lp"):
raise ValueError(
f"invalid service {service} - only gh and lp are supported"
)
import_form = self.app.ui.body._w.stretchy.form
view_helpers.enter_data(
import_form, {"import_username": username, "service": service}
)
import_form._click_done(None)
# Wait until the key gets fetched
confirm_overlay = await view_helpers.wait_for_overlay(
self.ui, ConfirmSSHKeys
)
confirm_overlay.ok(None)
form._click_done(None)
def cancel(self): def cancel(self):
self.app.prev_screen() self.app.prev_screen()
@ -67,45 +92,33 @@ class SSHController(SubiquityTuiController):
self._fetch_task.cancel() self._fetch_task.cancel()
@with_context(name="ssh_import_id", description="{ssh_import_id}") @with_context(name="ssh_import_id", description="{ssh_import_id}")
async def _fetch_ssh_keys(self, *, context, ssh_import_id, ssh_data): async def _fetch_ssh_keys(self, *, context, ssh_import_id):
with self.context.child("ssh_import_id", ssh_import_id): response: SSHFetchIdResponse = await self.endpoint.fetch_id.GET(ssh_import_id)
response: SSHFetchIdResponse = await self.endpoint.fetch_id.GET(
ssh_import_id
)
if response.status == SSHFetchIdStatus.IMPORT_ERROR: if response.status == SSHFetchIdStatus.IMPORT_ERROR:
if isinstance(self.ui.body, SSHView): if isinstance(self.ui.body, SSHView):
self.ui.body.fetching_ssh_keys_failed( self.ui.body.fetching_ssh_keys_failed(
_("Importing keys failed:"), response.error _("Importing keys failed:"), response.error
) )
return return
elif response.status == SSHFetchIdStatus.FINGERPRINT_ERROR: elif response.status == SSHFetchIdStatus.FINGERPRINT_ERROR:
if isinstance(self.ui.body, SSHView): if isinstance(self.ui.body, SSHView):
self.ui.body.fetching_ssh_keys_failed( self.ui.body.fetching_ssh_keys_failed(
_("ssh-keygen failed to show fingerprint of downloaded keys:"), _("ssh-keygen failed to show fingerprint of downloaded keys:"),
response.error, response.error,
) )
return return
identities = response.identities identities = response.identities
if "ssh-import-id" in self.app.answers.get("Identity", {}): if isinstance(self.ui.body, SSHView):
ssh_data.authorized_keys = [ self.ui.body.confirm_ssh_keys(ssh_import_id, identities)
id_.to_authorized_key() for id_ in identities else:
] log.debug("ui.body of unexpected instance: %s", type(self.ui.body).__name__)
self.done(ssh_data)
else:
if isinstance(self.ui.body, SSHView):
self.ui.body.confirm_ssh_keys(ssh_data, ssh_import_id, identities)
else:
log.debug(
"ui.body of unexpected instance: %s",
type(self.ui.body).__name__,
)
def fetch_ssh_keys(self, ssh_import_id, ssh_data): def fetch_ssh_keys(self, ssh_import_id):
self._fetch_task = schedule_task( self._fetch_task = schedule_task(
self._fetch_ssh_keys(ssh_import_id=ssh_import_id, ssh_data=ssh_data) self._fetch_ssh_keys(ssh_import_id=ssh_import_id)
) )
def done(self, result): def done(self, result):

View File

@ -21,27 +21,30 @@ from urwid import LineBox, Text, connect_signal
from subiquity.common.types import SSHData, SSHIdentity from subiquity.common.types import SSHData, SSHIdentity
from subiquity.ui.views.identity import UsernameField from subiquity.ui.views.identity import UsernameField
from subiquitycore.ui.buttons import cancel_btn, ok_btn from subiquitycore.ui.actionmenu import Action, ActionMenu
from subiquitycore.ui.buttons import cancel_btn, done_btn, menu_btn, ok_btn
from subiquitycore.ui.container import ListBox, Pile, WidgetWrap from subiquitycore.ui.container import ListBox, Pile, WidgetWrap
from subiquitycore.ui.form import BooleanField, ChoiceField, Form from subiquitycore.ui.form import BooleanField, ChoiceField, Form, Toggleable
from subiquitycore.ui.spinner import Spinner from subiquitycore.ui.spinner import Spinner
from subiquitycore.ui.stretchy import Stretchy from subiquitycore.ui.stretchy import Stretchy
from subiquitycore.ui.utils import SomethingFailed, button_pile, screen from subiquitycore.ui.table import ColSpec, TablePile, TableRow
from subiquitycore.ui.utils import (
Color,
Padding,
SomethingFailed,
button_pile,
make_action_menu_row,
screen,
)
from subiquitycore.view import BaseView from subiquitycore.view import BaseView
log = logging.getLogger("subiquity.ui.views.ssh") log = logging.getLogger("subiquity.ui.views.ssh")
SSH_IMPORT_MAXLEN = 256 + 3 # account for lp: or gh: SSH_IMPORT_MAXLEN = 256 + 3 # account for lp: or gh:
IMPORT_KEY_LABEL = _("Import SSH key")
_ssh_import_data = { _ssh_import_data = {
None: {
"caption": _("Import Username:"),
"help": "",
"valid_char": ".",
"error_invalid_char": "",
"regex": ".*",
},
"gh": { "gh": {
"caption": _("GitHub Username:"), "caption": _("GitHub Username:"),
"help": _("Enter your GitHub username."), "help": _("Enter your GitHub username."),
@ -63,57 +66,33 @@ _ssh_import_data = {
} }
class SSHForm(Form): class SSHImportForm(Form):
install_server = BooleanField(_("Install OpenSSH server")) service = ChoiceField(
ssh_import_id = ChoiceField(
_("Import SSH identity:"), _("Import SSH identity:"),
choices=[ choices=[
(_("No"), True, None),
(_("from GitHub"), True, "gh"), (_("from GitHub"), True, "gh"),
(_("from Launchpad"), True, "lp"), (_("from Launchpad"), True, "lp"),
], ],
help=_("You can import your SSH keys from GitHub or Launchpad."), help=_("You can import your SSH keys from GitHub or Launchpad."),
) )
import_username = UsernameField(_ssh_import_data[None]["caption"]) import_username = UsernameField(_ssh_import_data["gh"]["caption"])
pwauth = BooleanField(_("Allow password authentication over SSH"))
cancel_label = _("Back")
def __init__(self, initial):
super().__init__(initial=initial)
connect_signal(self.install_server.widget, "change", self._toggle_server)
self._toggle_server(None, self.install_server.value)
def _toggle_server(self, sender, new_value):
if new_value:
self.ssh_import_id.enabled = True
self.import_username.enabled = self.ssh_import_id.value is not None
self.pwauth.enabled = self.ssh_import_id.value is not None
else:
self.ssh_import_id.enabled = False
self.import_username.enabled = False
self.pwauth.enabled = False
# validation of the import username does not read from # validation of the import username does not read from
# ssh_import_id.value because it is sometimes done from the # service.value because it is sometimes done from the
# 'select' signal of the import id selector, which is called # 'select' signal of the service selector, which is called
# before the import id selector's value has actually changed. so # before the service selector's value has actually changed. so
# the signal handler stuffs the value here before doing # the signal handler stuffs the value here before doing
# validation (yes, this is a hack). # validation (yes, this is a hack).
ssh_import_id_value = None service_value = None
def validate_import_username(self): def validate_import_username(self):
if self.ssh_import_id_value is None:
return
username = self.import_username.value username = self.import_username.value
if len(username) == 0: if len(username) == 0:
return _("This field must not be blank.") return _("This field must not be blank.")
if len(username) > SSH_IMPORT_MAXLEN: if len(username) > SSH_IMPORT_MAXLEN:
return _("SSH id too long, must be < ") + str(SSH_IMPORT_MAXLEN) return _("SSH id too long, must be < ") + str(SSH_IMPORT_MAXLEN)
if self.ssh_import_id_value == "lp": if self.service_value == "lp":
lp_regex = r"^[a-z0-9][a-z0-9\+\.\-]*$" lp_regex = r"^[a-z0-9][a-z0-9\+\.\-]*$"
if not re.match(lp_regex, self.import_username.value): if not re.match(lp_regex, self.import_username.value):
return _( return _(
@ -123,7 +102,7 @@ class SSHForm(Form):
"the first character." "the first character."
"" ""
) )
elif self.ssh_import_id_value == "gh": elif self.service_value == "gh":
if not re.match(r"^[a-zA-Z0-9\-]+$", username): if not re.match(r"^[a-zA-Z0-9\-]+$", username):
return _( return _(
"A GitHub username may only contain alphanumeric " "A GitHub username may only contain alphanumeric "
@ -132,6 +111,71 @@ class SSHForm(Form):
) )
class SSHImportStretchy(Stretchy):
def __init__(self, parent):
self.parent = parent
self.form = SSHImportForm(initial={})
connect_signal(self.form, "submit", lambda unused: self.done())
connect_signal(self.form, "cancel", lambda unused: self.cancel())
connect_signal(
self.form.service.widget, "select", self._import_service_selected
)
self._import_service_selected(
sender=None, service=self.form.service.widget.value
)
rows = self.form.as_rows()
title = _("Import SSH key")
super().__init__(title, [Pile(rows), Text(""), self.form.buttons], 0, 0)
def cancel(self):
self.parent.remove_overlay(self)
def done(self):
ssh_import_id = self.form.service.value + ":" + self.form.import_username.value
fsk = FetchingSSHKeys(self.parent)
self.parent.remove_overlay(self)
self.parent.show_overlay(fsk, width=fsk.width, min_width=None)
self.parent.controller.fetch_ssh_keys(ssh_import_id=ssh_import_id)
def _import_service_selected(self, sender, service: str):
iu = self.form.import_username
data = _ssh_import_data[service]
iu.help = _(data["help"])
iu.caption = _(data["caption"])
iu.widget.valid_char_pat = data["valid_char"]
iu.widget.error_invalid_char = _(data["error_invalid_char"])
self.form.service_value = service
if iu.value != "":
iu.validate()
class SSHShowKeyStretchy(Stretchy):
def __init__(self, parent, key: str) -> None:
self.parent = parent
widgets = [
Text(key),
Text(""),
button_pile([done_btn(_("Close"), on_press=self.close)]),
]
title = _("SSH key")
super().__init__(title, widgets, 0, 2)
def close(self, button=None) -> None:
self.parent.remove_overlay()
class SSHForm(Form):
install_server = BooleanField(_("Install OpenSSH server"))
pwauth = BooleanField(_("Allow password authentication over SSH"))
cancel_label = _("Back")
class FetchingSSHKeys(WidgetWrap): class FetchingSSHKeys(WidgetWrap):
def __init__(self, parent): def __init__(self, parent):
self.parent = parent self.parent = parent
@ -155,14 +199,13 @@ class FetchingSSHKeys(WidgetWrap):
) )
def cancel(self, sender): def cancel(self, sender):
self.parent.controller._fetch_cancel()
self.parent.remove_overlay() self.parent.remove_overlay()
self.parent.controller._fetch_cancel()
class ConfirmSSHKeys(Stretchy): class ConfirmSSHKeys(Stretchy):
def __init__(self, parent, ssh_data, identities: List[SSHIdentity]): def __init__(self, parent, identities: List[SSHIdentity]):
self.parent = parent self.parent = parent
self.ssh_data = ssh_data
self.identities: List[SSHIdentity] = identities self.identities: List[SSHIdentity] = identities
ok = ok_btn(label=_("Yes"), on_press=self.ok) ok = ok_btn(label=_("Yes"), on_press=self.ok)
@ -200,10 +243,11 @@ class ConfirmSSHKeys(Stretchy):
self.parent.remove_overlay() self.parent.remove_overlay()
def ok(self, sender): def ok(self, sender):
self.ssh_data.authorized_keys = [ for identity in self.identities:
id_.to_authorized_key() for id_ in self.identities self.parent.add_key_to_table(identity.to_authorized_key())
] self.parent.refresh_keys_table()
self.parent.controller.done(self.ssh_data)
self.parent.remove_overlay()
class SSHView(BaseView): class SSHView(BaseView):
@ -222,76 +266,156 @@ class SSHView(BaseView):
} }
self.form = SSHForm(initial=initial) self.form = SSHForm(initial=initial)
self.keys = ssh_data.authorized_keys
connect_signal( self._import_key_btn = Toggleable(
self.form.ssh_import_id.widget, "select", self._select_ssh_import_id menu_btn(
label=IMPORT_KEY_LABEL,
on_press=lambda unused: self.show_import_key_overlay(),
)
) )
bp = button_pile([self._import_key_btn])
bp.align = "left"
colspecs = {
0: ColSpec(rpad=1),
1: ColSpec(can_shrink=True),
2: ColSpec(rpad=1),
3: ColSpec(rpad=1),
}
self.keys_table = TablePile([], colspecs=colspecs)
self.refresh_keys_table()
rows = self.form.as_rows() + [
Text(""),
bp,
Text(""),
Text(_("AUTHORIZED KEYS")),
Text(""),
self.keys_table,
]
connect_signal(self.form, "submit", self.done) connect_signal(self.form, "submit", self.done)
connect_signal(self.form, "cancel", self.cancel) connect_signal(self.form, "cancel", self.cancel)
connect_signal(self.form.install_server.widget, "change", self._toggle_server)
self._toggle_server(None, self.form.install_server.value)
self.form_rows = ListBox(self.form.as_rows())
super().__init__( super().__init__(
screen( screen(
self.form_rows, ListBox(rows),
self.form.buttons, self.form.buttons,
excerpt=_(self.excerpt), excerpt=_(self.excerpt),
focus_buttons=False, focus_buttons=False,
) )
) )
def _select_ssh_import_id(self, sender, val):
iu = self.form.import_username
data = _ssh_import_data[val]
iu.help = _(data["help"])
iu.caption = _(data["caption"])
iu.widget.valid_char_pat = data["valid_char"]
iu.widget.error_invalid_char = _(data["error_invalid_char"])
iu.enabled = val is not None
self.form.pwauth.enabled = val is not None
# The logic here is a little tortured but the idea is that if
# the users switches from not importing a key to importing
# one, untick pwauth (but don't fiddle with the users choice
# if just switching between import sources), and conversely if
# the user is not importing a key then the box has to be
# ticked (and disabled).
if (val is None) != (self.form.ssh_import_id.value is None):
self.form.pwauth.value = val is None
if val is not None:
self.form_rows.base_widget.body.focus += 2
self.form.ssh_import_id_value = val
if iu.value != "" or val is None:
iu.validate()
def done(self, sender): def done(self, sender):
log.debug("User input: {}".format(self.form.as_data())) log.debug("User input: {}".format(self.form.as_data()))
ssh_data = SSHData( ssh_data = SSHData(
install_server=self.form.install_server.value, install_server=self.form.install_server.value,
allow_pw=self.form.pwauth.value, allow_pw=self.form.pwauth.value,
authorized_keys=self.keys,
) )
# if user specifed a value, allow user to validate fingerprint self.controller.done(ssh_data)
if self.form.ssh_import_id.value:
ssh_import_id = (
self.form.ssh_import_id.value + ":" + self.form.import_username.value
)
fsk = FetchingSSHKeys(self)
self.show_overlay(fsk, width=fsk.width, min_width=None)
self.controller.fetch_ssh_keys(
ssh_import_id=ssh_import_id, ssh_data=ssh_data
)
else:
self.controller.done(ssh_data)
def cancel(self, result=None): def cancel(self, result=None):
self.controller.cancel() self.controller.cancel()
def confirm_ssh_keys(self, ssh_data, ssh_import_id, identities: List[SSHIdentity]): def show_import_key_overlay(self):
self.show_stretchy_overlay(SSHImportStretchy(self))
def confirm_ssh_keys(self, ssh_import_id, identities: List[SSHIdentity]):
self.remove_overlay() self.remove_overlay()
self.show_stretchy_overlay(ConfirmSSHKeys(self, ssh_data, identities)) self.show_stretchy_overlay(ConfirmSSHKeys(self, identities))
def fetching_ssh_keys_failed(self, msg, stderr): def fetching_ssh_keys_failed(self, msg, stderr):
# FIXME in answers-based runs, the overlay does not exist so we pass self.remove_overlay()
# not_found_ok=True.
self.remove_overlay(not_found_ok=True)
self.show_stretchy_overlay(SomethingFailed(self, msg, stderr)) self.show_stretchy_overlay(SomethingFailed(self, msg, stderr))
def add_key_to_table(self, key: str) -> None:
"""Add the specified key to the list of authorized keys. When adding
the first one, we also disable password authentication (but give the
user the ability to re-enable it)"""
self.keys.append(key)
if len(self.keys) == 1:
self.form.pwauth.value = False
if self.form.install_server:
self.form.pwauth.enabled = True
def remove_key_from_table(self, key: str) -> None:
"""Remove the specified key from the list of authorized keys. When
removing the last one, we also re-enable password authentication (and
disable the checkbox)."""
self.keys.remove(key)
if not self.keys:
self.form.pwauth.value = True
self.form.pwauth.enabled = False
def show_key_overlay(self, key: str):
self.show_stretchy_overlay(SSHShowKeyStretchy(self, key))
def refresh_keys_table(self):
rows: List[TableRow] = []
if not self.keys:
rows = [
TableRow(
[
(
4,
Padding.push_2(
Color.info_minor(Text(_("No authorized key")))
),
)
]
)
]
for key in self.keys:
menu = ActionMenu(
[
Action(
label=_("Delete"),
enabled=True,
value="delete",
opens_dialog=False,
),
Action(
label=_("Show"),
enabled=True,
value="show",
opens_dialog=True,
),
]
)
rows.append(
make_action_menu_row(
[
Text("["),
Text(key, wrap="ellipsis"),
menu,
Text("]"),
],
menu,
)
)
connect_signal(menu, "action", self._action, key)
self.keys_table.set_contents(rows)
def _action(self, sender, value, key):
if value == "delete":
self.remove_key_from_table(key)
self.refresh_keys_table()
elif value == "show":
self.show_key_overlay(key)
def _toggle_server(self, sender, installed: bool):
self._import_key_btn.enabled = installed
if installed:
self.form.pwauth.enabled = bool(self.keys)
else:
self.form.pwauth.enabled = False

View File

@ -1,8 +1,11 @@
import asyncio
import re import re
from typing import Type
import urwid import urwid
from subiquitycore.ui.stretchy import StretchyOverlay from subiquitycore.ui.frame import SubiquityCoreUI
from subiquitycore.ui.stretchy import Stretchy, StretchyOverlay
def find_with_pred(w, pred, return_path=False): def find_with_pred(w, pred, return_path=False):
@ -86,3 +89,25 @@ def enter_data(form, data):
bf = getattr(form, k) bf = getattr(form, k)
assert bf.enabled, "%s is not enabled" % (k,) assert bf.enabled, "%s is not enabled" % (k,)
bf.value = v bf.value = v
async def wait_for_overlay(
ui: SubiquityCoreUI, overlay_type: Type[Stretchy], *, timeout=None, step=0.01
) -> Stretchy:
"""Wait until an overlay of the specified type gets displayed on the
screen and return it. If timeout is hit before the overlay is displayed, an
asyncio.TimeoutError will be raised. When timeout is set to None, we will
wait forever."""
if timeout is not None:
task = wait_for_overlay(ui, overlay_type)
return await asyncio.wait_for(task, timeout=timeout)
while True:
try:
stretchy = ui.body._w.stretchy
except AttributeError:
pass
else:
if isinstance(stretchy, overlay_type):
return stretchy
await asyncio.sleep(step)