From 34d40643ad78af72421fbdf222c29994c8852b03 Mon Sep 17 00:00:00 2001 From: Dan Bungert Date: Tue, 25 Jul 2023 15:26:25 -0600 Subject: [PATCH] format with black + isort --- console_conf/__init__.py | 3 +- console_conf/cmd/tui.py | 92 +- console_conf/cmd/write_login_details.py | 14 +- console_conf/controllers/__init__.py | 22 +- console_conf/controllers/chooser.py | 18 +- console_conf/controllers/identity.py | 104 +- .../controllers/tests/test_chooser.py | 84 +- console_conf/controllers/welcome.py | 5 +- console_conf/core.py | 9 +- console_conf/models/console_conf.py | 2 +- console_conf/models/identity.py | 6 +- console_conf/models/systems.py | 42 +- console_conf/models/tests/test_systems.py | 262 +- console_conf/ui/views/__init__.py | 8 +- console_conf/ui/views/chooser.py | 182 +- console_conf/ui/views/identity.py | 23 +- console_conf/ui/views/login.py | 38 +- console_conf/ui/views/tests/test_chooser.py | 12 +- subiquity/__init__.py | 3 +- subiquity/__main__.py | 3 +- subiquity/client/client.py | 244 +- subiquity/client/controller.py | 5 +- subiquity/client/controllers/__init__.py | 35 +- subiquity/client/controllers/drivers.py | 28 +- subiquity/client/controllers/filesystem.py | 168 +- subiquity/client/controllers/identity.py | 24 +- subiquity/client/controllers/keyboard.py | 14 +- subiquity/client/controllers/mirror.py | 30 +- subiquity/client/controllers/network.py | 78 +- subiquity/client/controllers/progress.py | 31 +- subiquity/client/controllers/proxy.py | 9 +- subiquity/client/controllers/refresh.py | 25 +- subiquity/client/controllers/serial.py | 10 +- subiquity/client/controllers/snaplist.py | 33 +- subiquity/client/controllers/source.py | 23 +- subiquity/client/controllers/ssh.py | 75 +- subiquity/client/controllers/ubuntu_pro.py | 92 +- subiquity/client/controllers/welcome.py | 15 +- subiquity/client/controllers/zdev.py | 6 +- subiquity/client/keycodes.py | 17 +- subiquity/cmd/common.py | 16 +- subiquity/cmd/schema.py | 14 +- subiquity/cmd/server.py | 167 +- subiquity/cmd/tui.py | 112 +- subiquity/common/api/client.py | 41 +- subiquity/common/api/defs.py | 47 +- subiquity/common/api/server.py | 92 +- subiquity/common/api/tests/test_client.py | 82 +- subiquity/common/api/tests/test_endtoend.py | 139 +- subiquity/common/api/tests/test_server.py | 160 +- subiquity/common/apidef.py | 412 +-- subiquity/common/errorreport.py | 151 +- subiquity/common/filesystem/actions.py | 125 +- subiquity/common/filesystem/boot.py | 165 +- subiquity/common/filesystem/gaps.py | 71 +- subiquity/common/filesystem/labels.py | 46 +- subiquity/common/filesystem/manipulator.py | 174 +- subiquity/common/filesystem/sizes.py | 43 +- .../common/filesystem/tests/test_actions.py | 122 +- .../common/filesystem/tests/test_gaps.py | 567 ++-- .../common/filesystem/tests/test_labels.py | 78 +- .../filesystem/tests/test_manipulator.py | 420 ++- .../common/filesystem/tests/test_sizes.py | 62 +- subiquity/common/resources.py | 12 +- subiquity/common/serialize.py | 96 +- subiquity/common/tests/test_serialization.py | 151 +- subiquity/common/tests/test_types.py | 9 +- subiquity/common/types.py | 174 +- subiquity/journald.py | 3 +- subiquity/lockfile.py | 6 +- subiquity/models/ad.py | 5 +- subiquity/models/codecs.py | 2 +- subiquity/models/drivers.py | 2 +- subiquity/models/filesystem.py | 676 ++--- subiquity/models/identity.py | 16 +- subiquity/models/kernel.py | 11 +- subiquity/models/keyboard.py | 93 +- subiquity/models/locale.py | 21 +- subiquity/models/mirror.py | 118 +- subiquity/models/network.py | 55 +- subiquity/models/oem.py | 2 +- subiquity/models/proxy.py | 36 +- subiquity/models/snaplist.py | 64 +- subiquity/models/source.py | 43 +- subiquity/models/ssh.py | 5 +- subiquity/models/subiquity.py | 306 +-- subiquity/models/tests/test_filesystem.py | 1430 ++++++----- subiquity/models/tests/test_keyboard.py | 55 +- subiquity/models/tests/test_locale.py | 8 +- subiquity/models/tests/test_mirror.py | 130 +- subiquity/models/tests/test_network.py | 7 +- subiquity/models/tests/test_source.py | 90 +- subiquity/models/tests/test_subiquity.py | 246 +- subiquity/models/timezone.py | 27 +- subiquity/models/ubuntu_pro.py | 3 +- subiquity/models/updates.py | 6 +- subiquity/server/ad_joiner.py | 116 +- subiquity/server/apt.py | 174 +- subiquity/server/contract_selection.py | 74 +- subiquity/server/controller.py | 33 +- subiquity/server/controllers/__init__.py | 64 +- subiquity/server/controllers/ad.py | 143 +- subiquity/server/controllers/cmdlist.py | 55 +- subiquity/server/controllers/codecs.py | 11 +- subiquity/server/controllers/debconf.py | 5 +- subiquity/server/controllers/drivers.py | 42 +- subiquity/server/controllers/filesystem.py | 667 ++--- subiquity/server/controllers/identity.py | 60 +- subiquity/server/controllers/install.py | 436 ++-- subiquity/server/controllers/integrity.py | 26 +- subiquity/server/controllers/kernel.py | 27 +- subiquity/server/controllers/keyboard.py | 222 +- subiquity/server/controllers/locale.py | 18 +- subiquity/server/controllers/mirror.py | 210 +- subiquity/server/controllers/network.py | 209 +- subiquity/server/controllers/oem.py | 72 +- subiquity/server/controllers/package.py | 7 +- subiquity/server/controllers/proxy.py | 20 +- subiquity/server/controllers/refresh.py | 111 +- subiquity/server/controllers/reporting.py | 48 +- subiquity/server/controllers/shutdown.py | 56 +- subiquity/server/controllers/snaplist.py | 94 +- subiquity/server/controllers/source.py | 76 +- subiquity/server/controllers/ssh.py | 80 +- subiquity/server/controllers/tests/test_ad.py | 116 +- .../controllers/tests/test_filesystem.py | 911 ++++--- .../server/controllers/tests/test_install.py | 292 ++- .../controllers/tests/test_integrity.py | 5 +- .../server/controllers/tests/test_kernel.py | 59 +- .../server/controllers/tests/test_keyboard.py | 58 +- .../server/controllers/tests/test_mirror.py | 130 +- .../server/controllers/tests/test_refresh.py | 61 +- .../server/controllers/tests/test_snaplist.py | 8 +- .../server/controllers/tests/test_source.py | 27 +- .../server/controllers/tests/test_ssh.py | 52 +- .../controllers/tests/test_ubuntu_pro.py | 4 +- .../server/controllers/tests/test_userdata.py | 23 +- subiquity/server/controllers/timezone.py | 48 +- subiquity/server/controllers/ubuntu_pro.py | 91 +- subiquity/server/controllers/updates.py | 12 +- subiquity/server/controllers/userdata.py | 10 +- subiquity/server/controllers/zdev.py | 25 +- subiquity/server/curtin.py | 134 +- subiquity/server/dryrun.py | 39 +- subiquity/server/errors.py | 3 +- subiquity/server/geoip.py | 37 +- subiquity/server/kernel.py | 37 +- subiquity/server/mounter.py | 72 +- subiquity/server/pkghelper.py | 36 +- subiquity/server/runner.py | 79 +- subiquity/server/server.py | 342 ++- subiquity/server/snapdapi.py | 165 +- subiquity/server/ssh.py | 35 +- subiquity/server/tests/test_apt.py | 122 +- .../server/tests/test_contract_selection.py | 60 +- subiquity/server/tests/test_controller.py | 4 +- subiquity/server/tests/test_geoip.py | 25 +- subiquity/server/tests/test_kernel.py | 57 +- subiquity/server/tests/test_mounter.py | 117 +- subiquity/server/tests/test_runner.py | 105 +- subiquity/server/tests/test_server.py | 100 +- subiquity/server/tests/test_ssh.py | 52 +- .../server/tests/test_ubuntu_advantage.py | 132 +- subiquity/server/tests/test_ubuntu_drivers.py | 229 +- subiquity/server/types.py | 14 +- subiquity/server/ubuntu_advantage.py | 139 +- subiquity/server/ubuntu_drivers.py | 118 +- subiquity/tests/api/test_api.py | 2288 +++++++++-------- subiquity/tests/fakes.py | 9 +- subiquity/tests/test_timezonecontroller.py | 63 +- subiquity/ui/frame.py | 4 +- subiquity/ui/mount.py | 65 +- subiquity/ui/views/__init__.py | 20 +- subiquity/ui/views/drivers.py | 117 +- subiquity/ui/views/error.py | 313 ++- subiquity/ui/views/filesystem/__init__.py | 5 +- subiquity/ui/views/filesystem/compound.py | 101 +- subiquity/ui/views/filesystem/delete.py | 121 +- subiquity/ui/views/filesystem/disk_info.py | 33 +- subiquity/ui/views/filesystem/filesystem.py | 294 ++- subiquity/ui/views/filesystem/guided.py | 224 +- subiquity/ui/views/filesystem/helpers.py | 37 +- subiquity/ui/views/filesystem/lvm.py | 153 +- subiquity/ui/views/filesystem/partition.py | 378 +-- subiquity/ui/views/filesystem/probing.py | 66 +- subiquity/ui/views/filesystem/raid.py | 149 +- .../views/filesystem/tests/test_filesystem.py | 26 +- .../ui/views/filesystem/tests/test_lvm.py | 60 +- .../views/filesystem/tests/test_partition.py | 154 +- .../ui/views/filesystem/tests/test_raid.py | 62 +- subiquity/ui/views/help.py | 415 +-- subiquity/ui/views/identity.py | 137 +- subiquity/ui/views/installprogress.py | 124 +- subiquity/ui/views/keyboard.py | 314 +-- subiquity/ui/views/mirror.py | 183 +- subiquity/ui/views/proxy.py | 35 +- subiquity/ui/views/refresh.py | 121 +- subiquity/ui/views/serial.py | 14 +- subiquity/ui/views/snaplist.py | 319 +-- subiquity/ui/views/source.py | 55 +- subiquity/ui/views/ssh.py | 220 +- subiquity/ui/views/tests/test_identity.py | 74 +- .../ui/views/tests/test_installprogress.py | 4 +- subiquity/ui/views/tests/test_welcome.py | 19 +- subiquity/ui/views/ubuntu_pro.py | 573 +++-- subiquity/ui/views/welcome.py | 46 +- subiquity/ui/views/zdev.py | 136 +- subiquitycore/__init__.py | 3 +- subiquitycore/async_helpers.py | 8 +- subiquitycore/context.py | 14 +- subiquitycore/controller.py | 4 +- subiquitycore/controllers/network.py | 257 +- .../controllers/tests/test_network.py | 167 +- subiquitycore/controllerset.py | 3 +- subiquitycore/core.py | 37 +- subiquitycore/file_util.py | 15 +- subiquitycore/i18n.py | 31 +- subiquitycore/log.py | 8 +- subiquitycore/lsb_release.py | 2 +- subiquitycore/models/__init__.py | 3 +- subiquitycore/models/network.py | 269 +- subiquitycore/models/tests/test_network.py | 14 +- subiquitycore/netplan.py | 48 +- subiquitycore/palette.py | 172 +- subiquitycore/prober.py | 33 +- subiquitycore/pubsub.py | 3 +- subiquitycore/screen.py | 38 +- subiquitycore/snapd.py | 124 +- subiquitycore/ssh.py | 52 +- subiquitycore/testing/view_helpers.py | 21 +- subiquitycore/tests/__init__.py | 10 +- subiquitycore/tests/parameterized.py | 3 + subiquitycore/tests/test_async_helpers.py | 14 +- subiquitycore/tests/test_file_util.py | 10 +- subiquitycore/tests/test_lsb_release.py | 2 +- subiquitycore/tests/test_netplan.py | 33 +- subiquitycore/tests/test_prober.py | 7 +- subiquitycore/tests/test_pubsub.py | 6 +- subiquitycore/tests/test_utils.py | 44 +- subiquitycore/tests/test_view.py | 3 +- subiquitycore/tests/util.py | 2 +- subiquitycore/tui.py | 99 +- subiquitycore/tuicontroller.py | 15 +- subiquitycore/ui/actionmenu.py | 47 +- subiquitycore/ui/anchors.py | 45 +- subiquitycore/ui/buttons.py | 8 +- subiquitycore/ui/confirmation.py | 35 +- subiquitycore/ui/container.py | 175 +- subiquitycore/ui/form.py | 138 +- subiquitycore/ui/frame.py | 25 +- subiquitycore/ui/interactive.py | 38 +- subiquitycore/ui/selector.py | 91 +- subiquitycore/ui/spinner.py | 46 +- subiquitycore/ui/stretchy.py | 38 +- subiquitycore/ui/table.py | 154 +- subiquitycore/ui/tests/test_form.py | 32 +- subiquitycore/ui/tests/test_table.py | 59 +- subiquitycore/ui/utils.py | 200 +- subiquitycore/ui/views/network.py | 321 +-- .../network_configure_manual_interface.py | 227 +- .../views/network_configure_wlan_interface.py | 65 +- ...test_network_configure_manual_interface.py | 78 +- subiquitycore/ui/width.py | 12 +- subiquitycore/utils.py | 202 +- subiquitycore/view.py | 86 +- system_setup/__init__.py | 3 +- system_setup/__main__.py | 3 +- system_setup/client/client.py | 27 +- system_setup/client/controllers/__init__.py | 23 +- system_setup/client/controllers/identity.py | 13 +- system_setup/client/controllers/summary.py | 21 +- .../client/controllers/wslconfadvanced.py | 20 +- .../client/controllers/wslconfbase.py | 19 +- .../client/controllers/wslsetupoptions.py | 10 +- system_setup/cmd/schema.py | 10 +- system_setup/cmd/server.py | 94 +- system_setup/cmd/tui.py | 92 +- system_setup/common/wsl_conf.py | 69 +- system_setup/common/wsl_utils.py | 12 +- system_setup/models/system_setup.py | 22 +- system_setup/models/wslconfadvanced.py | 6 +- system_setup/models/wslconfbase.py | 6 +- system_setup/models/wslsetupoptions.py | 6 +- system_setup/server/controllers/__init__.py | 35 +- system_setup/server/controllers/configure.py | 182 +- system_setup/server/controllers/identity.py | 47 +- system_setup/server/controllers/locale.py | 10 +- system_setup/server/controllers/shutdown.py | 15 +- .../server/controllers/wslconfadvanced.py | 40 +- .../server/controllers/wslconfbase.py | 35 +- .../server/controllers/wslsetupoptions.py | 22 +- system_setup/server/server.py | 40 +- system_setup/ui/views/__init__.py | 12 +- system_setup/ui/views/identity.py | 56 +- system_setup/ui/views/summary.py | 61 +- system_setup/ui/views/wslconfadvanced.py | 103 +- system_setup/ui/views/wslconfbase.py | 128 +- system_setup/ui/views/wslsetupoptions.py | 26 +- 298 files changed, 14900 insertions(+), 14044 deletions(-) diff --git a/console_conf/__init__.py b/console_conf/__init__.py index 348b4009..b5e8ddae 100644 --- a/console_conf/__init__.py +++ b/console_conf/__init__.py @@ -16,6 +16,7 @@ """ Console-Conf """ from subiquitycore import i18n + __all__ = [ - 'i18n', + "i18n", ] diff --git a/console_conf/cmd/tui.py b/console_conf/cmd/tui.py index 649c3b26..ae002426 100755 --- a/console_conf/cmd/tui.py +++ b/console_conf/cmd/tui.py @@ -16,42 +16,63 @@ import argparse import asyncio -import sys -import os import logging -from subiquitycore.log import setup_logger -from subiquitycore import __version__ as VERSION +import os +import sys + from console_conf.core import ConsoleConf, RecoveryChooser +from subiquitycore import __version__ as VERSION +from subiquitycore.log import setup_logger def parse_options(argv): parser = argparse.ArgumentParser( - description=( - 'console-conf - Pre-Ownership Configuration for Ubuntu Core'), - prog='console-conf') - parser.add_argument('--dry-run', action='store_true', - dest='dry_run', - help='menu-only, do not call installer function') - parser.add_argument('--serial', action='store_true', - dest='run_on_serial', - help='Run the installer over serial console.') - parser.add_argument('--ascii', action='store_true', - dest='ascii', - help='Run the installer in ascii mode.') - parser.add_argument('--machine-config', metavar='CONFIG', - dest='machine_config', - type=argparse.FileType(), - help="Don't Probe. Use probe data file") - parser.add_argument('--screens', action='append', dest='screens', - default=[]) - parser.add_argument('--answers') - parser.add_argument('--recovery-chooser-mode', action='store_true', - dest='chooser_systems', - help=('Run as a recovery chooser interacting with the ' - 'calling process over stdin/stdout streams')) - parser.add_argument('--output-base', action='store', dest='output_base', - default='.subiquity', - help='in dryrun, control basedir of files') + description=("console-conf - Pre-Ownership Configuration for Ubuntu Core"), + prog="console-conf", + ) + parser.add_argument( + "--dry-run", + action="store_true", + dest="dry_run", + help="menu-only, do not call installer function", + ) + parser.add_argument( + "--serial", + action="store_true", + dest="run_on_serial", + help="Run the installer over serial console.", + ) + parser.add_argument( + "--ascii", + action="store_true", + dest="ascii", + help="Run the installer in ascii mode.", + ) + parser.add_argument( + "--machine-config", + metavar="CONFIG", + dest="machine_config", + type=argparse.FileType(), + help="Don't Probe. Use probe data file", + ) + parser.add_argument("--screens", action="append", dest="screens", default=[]) + parser.add_argument("--answers") + parser.add_argument( + "--recovery-chooser-mode", + action="store_true", + dest="chooser_systems", + help=( + "Run as a recovery chooser interacting with the " + "calling process over stdin/stdout streams" + ), + ) + parser.add_argument( + "--output-base", + action="store", + dest="output_base", + default=".subiquity", + help="in dryrun, control basedir of files", + ) return parser.parse_args(argv) @@ -64,7 +85,7 @@ def main(): if opts.dry_run: LOGDIR = opts.output_base setup_logger(dir=LOGDIR) - logger = logging.getLogger('console_conf') + logger = logging.getLogger("console_conf") logger.info("Starting console-conf v{}".format(VERSION)) logger.info("Arguments passed: {}".format(sys.argv)) @@ -73,8 +94,7 @@ def main(): # when running as a chooser, the stdin/stdout streams are set up by # the process that runs us, attempt to restore the tty in/out by # looking at stderr - chooser_input, chooser_output = restore_std_streams_from( - sys.stderr) + chooser_input, chooser_output = restore_std_streams_from(sys.stderr) interface = RecoveryChooser(opts, chooser_input, chooser_output) else: interface = ConsoleConf(opts) @@ -91,10 +111,10 @@ def restore_std_streams_from(from_file): tty = os.ttyname(from_file.fileno()) # we have tty now chooser_input, chooser_output = sys.stdin, sys.stdout - sys.stdin = open(tty, 'r') - sys.stdout = open(tty, 'w') + sys.stdin = open(tty, "r") + sys.stdout = open(tty, "w") return chooser_input, chooser_output -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/console_conf/cmd/write_login_details.py b/console_conf/cmd/write_login_details.py index 1c0b7085..3cb0c498 100755 --- a/console_conf/cmd/write_login_details.py +++ b/console_conf/cmd/write_login_details.py @@ -17,20 +17,18 @@ import logging import sys -from subiquitycore.log import setup_logger -from subiquitycore import __version__ as VERSION - from console_conf.controllers.identity import write_login_details_standalone +from subiquitycore import __version__ as VERSION +from subiquitycore.log import setup_logger def main(): - logger = setup_logger(dir='/var/log/console-conf') - logger = logging.getLogger('console_conf') - logger.info( - "Starting console-conf-write-login-details v{}".format(VERSION)) + logger = setup_logger(dir="/var/log/console-conf") + logger = logging.getLogger("console_conf") + logger.info("Starting console-conf-write-login-details v{}".format(VERSION)) logger.info("Arguments passed: {}".format(sys.argv)) return write_login_details_standalone() -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/console_conf/controllers/__init__.py b/console_conf/controllers/__init__.py index 175c588f..f45a8534 100644 --- a/console_conf/controllers/__init__.py +++ b/console_conf/controllers/__init__.py @@ -15,19 +15,17 @@ """ console-conf controllers """ -from .identity import IdentityController from subiquitycore.controllers.network import NetworkController -from .welcome import WelcomeController, RecoveryChooserWelcomeController -from .chooser import ( - RecoveryChooserController, - RecoveryChooserConfirmController - ) + +from .chooser import RecoveryChooserConfirmController, RecoveryChooserController +from .identity import IdentityController +from .welcome import RecoveryChooserWelcomeController, WelcomeController __all__ = [ - 'IdentityController', - 'NetworkController', - 'WelcomeController', - 'RecoveryChooserWelcomeController', - 'RecoveryChooserController', - 'RecoveryChooserConfirmController', + "IdentityController", + "NetworkController", + "WelcomeController", + "RecoveryChooserWelcomeController", + "RecoveryChooserController", + "RecoveryChooserConfirmController", ] diff --git a/console_conf/controllers/chooser.py b/console_conf/controllers/chooser.py index c2618996..d4f7cf6a 100644 --- a/console_conf/controllers/chooser.py +++ b/console_conf/controllers/chooser.py @@ -15,18 +15,16 @@ import logging from console_conf.ui.views import ( - ChooserView, - ChooserCurrentSystemView, ChooserConfirmView, - ) - + ChooserCurrentSystemView, + ChooserView, +) from subiquitycore.tuicontroller import TuiController log = logging.getLogger("console_conf.controllers.chooser") class RecoveryChooserBaseController(TuiController): - def __init__(self, app): super().__init__(app) self.model = app.base_model @@ -40,7 +38,6 @@ class RecoveryChooserBaseController(TuiController): class RecoveryChooserController(RecoveryChooserBaseController): - def __init__(self, app): super().__init__(app) self._model_view, self._all_view = self._make_views() @@ -64,9 +61,9 @@ class RecoveryChooserController(RecoveryChooserBaseController): if self.model.current and self.model.current.actions: # only when we have a current system and it has actions available more = len(self.model.systems) > 1 - current_view = ChooserCurrentSystemView(self, - self.model.current, - has_more=more) + current_view = ChooserCurrentSystemView( + self, self.model.current, has_more=more + ) all_view = ChooserView(self, self.model.systems) return current_view, all_view @@ -80,8 +77,7 @@ class RecoveryChooserController(RecoveryChooserBaseController): self.ui.set_body(self._all_view) def back(self): - if self._current_view == self._all_view and \ - self._model_view is not None: + if self._current_view == self._all_view and self._model_view is not None: # back in the all-systems screen goes back to the current model # screen, provided we have one self._current_view = self._model_view diff --git a/console_conf/controllers/identity.py b/console_conf/controllers/identity.py index 245eb42e..47a45095 100644 --- a/console_conf/controllers/identity.py +++ b/console_conf/controllers/identity.py @@ -20,18 +20,17 @@ import pwd import shlex import sys -from subiquitycore.ssh import host_key_info, get_ips_standalone +from console_conf.ui.views import IdentityView, LoginView from subiquitycore.snapd import SnapdConnection +from subiquitycore.ssh import get_ips_standalone, host_key_info from subiquitycore.tuicontroller import TuiController from subiquitycore.utils import disable_console_conf, run_command -from console_conf.ui.views import IdentityView, LoginView - -log = logging.getLogger('console_conf.controllers.identity') +log = logging.getLogger("console_conf.controllers.identity") def get_core_version(): - """ For a ubuntu-core system, return its version or None """ + """For a ubuntu-core system, return its version or None""" path = "/usr/lib/os-release" try: @@ -53,33 +52,33 @@ def get_core_version(): def get_managed(): - """ Check if device is managed """ - con = SnapdConnection('', '/run/snapd.socket') - return con.get('v2/system-info').json()['result']['managed'] + """Check if device is managed""" + con = SnapdConnection("", "/run/snapd.socket") + return con.get("v2/system-info").json()["result"]["managed"] def get_realname(username): try: info = pwd.getpwnam(username) except KeyError: - return '' - return info.pw_gecos.split(',', 1)[0] + return "" + return info.pw_gecos.split(",", 1)[0] def get_device_owner(): - """ Get device owner, if any """ - con = SnapdConnection('', '/run/snapd.socket') - for user in con.get('v2/users').json()['result']: - if 'username' not in user: + """Get device owner, if any""" + con = SnapdConnection("", "/run/snapd.socket") + for user in con.get("v2/users").json()["result"]: + if "username" not in user: continue - username = user['username'] - homedir = '/home/' + username + username = user["username"] + homedir = "/home/" + username if os.path.isdir(homedir): return { - 'username': username, - 'realname': get_realname(username), - 'homedir': homedir, - } + "username": username, + "realname": get_realname(username), + "homedir": homedir, + } return None @@ -110,15 +109,22 @@ def write_login_details(fp, username, ips): tty_name = os.ttyname(0)[5:] # strip off the /dev/ version = get_core_version() or "16" if len(ips) == 0: - fp.write(login_details_tmpl_no_ip.format( - sshcommands=sshcommands, tty_name=tty_name, version=version)) + fp.write( + login_details_tmpl_no_ip.format( + sshcommands=sshcommands, tty_name=tty_name, version=version + ) + ) else: first_ip = ips[0] - fp.write(login_details_tmpl.format(sshcommands=sshcommands, - host_key_info=host_key_info(), - tty_name=tty_name, - first_ip=first_ip, - version=version)) + fp.write( + login_details_tmpl.format( + sshcommands=sshcommands, + host_key_info=host_key_info(), + tty_name=tty_name, + first_ip=first_ip, + version=version, + ) + ) def write_login_details_standalone(): @@ -131,18 +137,16 @@ def write_login_details_standalone(): else: tty_name = os.ttyname(0)[5:] version = get_core_version() or "16" - print(login_details_tmpl_no_ip.format( - tty_name=tty_name, version=version)) + print(login_details_tmpl_no_ip.format(tty_name=tty_name, version=version)) return 2 if owner is None: - print("device managed without user @ {}".format(', '.join(ips))) + print("device managed without user @ {}".format(", ".join(ips))) else: - write_login_details(sys.stdout, owner['username'], ips) + write_login_details(sys.stdout, owner["username"], ips) return 0 class IdentityController(TuiController): - def __init__(self, app): super().__init__(app) self.model = app.base_model.identity @@ -152,11 +156,9 @@ class IdentityController(TuiController): device_owner = get_device_owner() if device_owner: self.model.add_user(device_owner) - key_file = os.path.join(device_owner['homedir'], - ".ssh/authorized_keys") - cp = run_command(['ssh-keygen', '-lf', key_file]) - self.model.user.fingerprints = ( - cp.stdout.replace('\r', '').splitlines()) + key_file = os.path.join(device_owner["homedir"], ".ssh/authorized_keys") + cp = run_command(["ssh-keygen", "-lf", key_file]) + self.model.user.fingerprints = cp.stdout.replace("\r", "").splitlines() return self.make_login_view() else: return IdentityView(self.model, self) @@ -164,35 +166,35 @@ class IdentityController(TuiController): def identity_done(self, email): if self.opts.dry_run: result = { - 'realname': email, - 'username': email, - } + "realname": email, + "username": email, + } self.model.add_user(result) - login_details_path = self.opts.output_base + '/login-details.txt' + login_details_path = self.opts.output_base + "/login-details.txt" else: self.app.urwid_loop.draw_screen() - cp = run_command( - ["snap", "create-user", "--sudoer", "--json", email]) + cp = run_command(["snap", "create-user", "--sudoer", "--json", email]) if cp.returncode != 0: if isinstance(self.ui.body, IdentityView): self.ui.body.snap_create_user_failed( - "Creating user failed:", cp.stderr) + "Creating user failed:", cp.stderr + ) return else: data = json.loads(cp.stdout) result = { - 'realname': email, - 'username': data['username'], - } - os.makedirs('/run/console-conf', exist_ok=True) - login_details_path = '/run/console-conf/login-details.txt' + "realname": email, + "username": data["username"], + } + os.makedirs("/run/console-conf", exist_ok=True) + login_details_path = "/run/console-conf/login-details.txt" self.model.add_user(result) ips = [] net_model = self.app.base_model.network for dev in net_model.get_all_netdevs(): ips.extend(dev.actual_global_ip_addresses) - with open(login_details_path, 'w') as fp: - write_login_details(fp, result['username'], ips) + with open(login_details_path, "w") as fp: + write_login_details(fp, result["username"], ips) self.login() def cancel(self): diff --git a/console_conf/controllers/tests/test_chooser.py b/console_conf/controllers/tests/test_chooser.py index acad6db6..1349b483 100644 --- a/console_conf/controllers/tests/test_chooser.py +++ b/console_conf/controllers/tests/test_chooser.py @@ -15,16 +15,12 @@ import unittest from unittest import mock -from subiquitycore.tests.mocks import make_app from console_conf.controllers.chooser import ( - RecoveryChooserController, RecoveryChooserConfirmController, - ) -from console_conf.models.systems import ( - RecoverySystemsModel, - SelectedSystemAction, - ) - + RecoveryChooserController, +) +from console_conf.models.systems import RecoverySystemsModel, SelectedSystemAction +from subiquitycore.tests.mocks import make_app model1_non_current = { "current": False, @@ -43,7 +39,7 @@ model1_non_current = { "actions": [ {"title": "action 1", "mode": "action1"}, {"title": "action 2", "mode": "action2"}, - ] + ], } model2_current = { @@ -62,21 +58,19 @@ model2_current = { }, "actions": [ {"title": "action 1", "mode": "action1"}, - ] + ], } def make_model(): - return RecoverySystemsModel.from_systems([model1_non_current, - model2_current]) + return RecoverySystemsModel.from_systems([model1_non_current, model2_current]) class TestChooserConfirmController(unittest.TestCase): - def test_back(self): app = make_app() c = RecoveryChooserConfirmController(app) - c.model = mock.Mock(selection='selection') + c.model = mock.Mock(selection="selection") c.back() app.respond.assert_not_called() app.exit.assert_not_called() @@ -86,12 +80,12 @@ class TestChooserConfirmController(unittest.TestCase): def test_confirm(self): app = make_app() c = RecoveryChooserConfirmController(app) - c.model = mock.Mock(selection='selection') + c.model = mock.Mock(selection="selection") c.confirm() - app.respond.assert_called_with('selection') + app.respond.assert_called_with("selection") app.exit.assert_called() - @mock.patch('console_conf.controllers.chooser.ChooserConfirmView') + @mock.patch("console_conf.controllers.chooser.ChooserConfirmView") def test_confirm_view(self, ccv): app = make_app() c = RecoveryChooserConfirmController(app) @@ -102,24 +96,25 @@ class TestChooserConfirmController(unittest.TestCase): class TestChooserController(unittest.TestCase): - def test_select(self): app = make_app(model=make_model()) c = RecoveryChooserController(app) c.select(c.model.systems[0], c.model.systems[0].actions[0]) - exp = SelectedSystemAction(system=c.model.systems[0], - action=c.model.systems[0].actions[0]) + exp = SelectedSystemAction( + system=c.model.systems[0], action=c.model.systems[0].actions[0] + ) self.assertEqual(c.model.selection, exp) app.next_screen.assert_called() app.respond.assert_not_called() app.exit.assert_not_called() - @mock.patch('console_conf.controllers.chooser.ChooserCurrentSystemView', - return_value='current') - @mock.patch('console_conf.controllers.chooser.ChooserView', - return_value='all') + @mock.patch( + "console_conf.controllers.chooser.ChooserCurrentSystemView", + return_value="current", + ) + @mock.patch("console_conf.controllers.chooser.ChooserView", return_value="all") def test_current_ui_first(self, cv, ccsv): app = make_app(model=make_model()) c = RecoveryChooserController(app) @@ -128,15 +123,16 @@ class TestChooserController(unittest.TestCase): # as well as all systems view cv.assert_called_with(c, c.model.systems) v = c.make_ui() - self.assertEqual(v, 'current') + self.assertEqual(v, "current") # user selects more options and the view is replaced c.more_options() - c.ui.set_body.assert_called_with('all') + c.ui.set_body.assert_called_with("all") - @mock.patch('console_conf.controllers.chooser.ChooserCurrentSystemView', - return_value='current') - @mock.patch('console_conf.controllers.chooser.ChooserView', - return_value='all') + @mock.patch( + "console_conf.controllers.chooser.ChooserCurrentSystemView", + return_value="current", + ) + @mock.patch("console_conf.controllers.chooser.ChooserView", return_value="all") def test_current_current_all_there_and_back(self, cv, ccsv): app = make_app(model=make_model()) c = RecoveryChooserController(app) @@ -145,21 +141,22 @@ class TestChooserController(unittest.TestCase): cv.assert_called_with(c, c.model.systems) v = c.make_ui() - self.assertEqual(v, 'current') + self.assertEqual(v, "current") # user selects more options and the view is replaced c.more_options() - c.ui.set_body.assert_called_with('all') + c.ui.set_body.assert_called_with("all") # go back now c.back() - c.ui.set_body.assert_called_with('current') + c.ui.set_body.assert_called_with("current") # nothing c.back() c.ui.set_body.not_called() - @mock.patch('console_conf.controllers.chooser.ChooserCurrentSystemView', - return_value='current') - @mock.patch('console_conf.controllers.chooser.ChooserView', - return_value='all') + @mock.patch( + "console_conf.controllers.chooser.ChooserCurrentSystemView", + return_value="current", + ) + @mock.patch("console_conf.controllers.chooser.ChooserView", return_value="all") def test_only_one_and_current(self, cv, ccsv): model = RecoverySystemsModel.from_systems([model2_current]) app = make_app(model=model) @@ -169,15 +166,16 @@ class TestChooserController(unittest.TestCase): cv.assert_called_with(c, c.model.systems) v = c.make_ui() - self.assertEqual(v, 'current') + self.assertEqual(v, "current") # going back does nothing c.back() c.ui.set_body.not_called() - @mock.patch('console_conf.controllers.chooser.ChooserCurrentSystemView', - return_value='current') - @mock.patch('console_conf.controllers.chooser.ChooserView', - return_value='all') + @mock.patch( + "console_conf.controllers.chooser.ChooserCurrentSystemView", + return_value="current", + ) + @mock.patch("console_conf.controllers.chooser.ChooserView", return_value="all") def test_all_systems_first_no_current(self, cv, ccsv): model = RecoverySystemsModel.from_systems([model1_non_current]) app = make_app(model=model) @@ -192,4 +190,4 @@ class TestChooserController(unittest.TestCase): ccsv.assert_not_called() v = c.make_ui() - self.assertEqual(v, 'all') + self.assertEqual(v, "all") diff --git a/console_conf/controllers/welcome.py b/console_conf/controllers/welcome.py index 085c6ded..c9e68e96 100644 --- a/console_conf/controllers/welcome.py +++ b/console_conf/controllers/welcome.py @@ -13,13 +13,11 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from console_conf.ui.views import WelcomeView, ChooserWelcomeView - +from console_conf.ui.views import ChooserWelcomeView, WelcomeView from subiquitycore.tuicontroller import TuiController class WelcomeController(TuiController): - welcome_view = WelcomeView def make_ui(self): @@ -34,7 +32,6 @@ class WelcomeController(TuiController): class RecoveryChooserWelcomeController(WelcomeController): - welcome_view = ChooserWelcomeView def __init__(self, app): diff --git a/console_conf/core.py b/console_conf/core.py index 51840d0d..78e1af7c 100644 --- a/console_conf/core.py +++ b/console_conf/core.py @@ -15,18 +15,17 @@ import logging -from subiquitycore.prober import Prober -from subiquitycore.tui import TuiApplication - from console_conf.models.console_conf import ConsoleConfModel from console_conf.models.systems import RecoverySystemsModel +from subiquitycore.prober import Prober +from subiquitycore.tui import TuiApplication log = logging.getLogger("console_conf.core") class ConsoleConf(TuiApplication): - from console_conf import controllers as controllers_mod + project = "console_conf" make_model = ConsoleConfModel @@ -43,8 +42,8 @@ class ConsoleConf(TuiApplication): class RecoveryChooser(TuiApplication): - from console_conf import controllers as controllers_mod + project = "console_conf" controllers = [ diff --git a/console_conf/models/console_conf.py b/console_conf/models/console_conf.py index 0810d68a..e78f3982 100644 --- a/console_conf/models/console_conf.py +++ b/console_conf/models/console_conf.py @@ -1,5 +1,5 @@ - from subiquitycore.models.network import NetworkModel + from .identity import IdentityModel diff --git a/console_conf/models/identity.py b/console_conf/models/identity.py index 48743446..ac53a00c 100644 --- a/console_conf/models/identity.py +++ b/console_conf/models/identity.py @@ -17,8 +17,7 @@ import logging import attr - -log = logging.getLogger('console_conf.models.identity') +log = logging.getLogger("console_conf.models.identity") @attr.s @@ -30,8 +29,7 @@ class User(object): class IdentityModel(object): - """ Model representing user identity - """ + """Model representing user identity""" def __init__(self): self._user = None diff --git a/console_conf/models/systems.py b/console_conf/models/systems.py index 501cc109..a6c486ba 100644 --- a/console_conf/models/systems.py +++ b/console_conf/models/systems.py @@ -13,8 +13,8 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import logging import json +import logging import attr import jsonschema @@ -113,9 +113,7 @@ class RecoverySystemsModel: m = syst["model"] b = syst["brand"] model = SystemModel( - model=m["model"], - brand_id=m["brand-id"], - display_name=m["display-name"] + model=m["model"], brand_id=m["brand-id"], display_name=m["display-name"] ) brand = Brand( ID=b["id"], @@ -180,11 +178,13 @@ class RecoverySystem: actions = attr.ib() def __eq__(self, other): - return self.current == other.current and \ - self.label == other.label and \ - self.model == other.model and \ - self.brand == other.brand and \ - self.actions == other.actions + return ( + self.current == other.current + and self.label == other.label + and self.model == other.model + and self.brand == other.brand + and self.actions == other.actions + ) @attr.s @@ -195,10 +195,12 @@ class Brand: validation = attr.ib() def __eq__(self, other): - return self.ID == other.ID and \ - self.username == other.username and \ - self.display_name == other.display_name and \ - self.validation == other.validation + return ( + self.ID == other.ID + and self.username == other.username + and self.display_name == other.display_name + and self.validation == other.validation + ) @attr.s @@ -208,9 +210,11 @@ class SystemModel: display_name = attr.ib() def __eq__(self, other): - return self.model == other.model and \ - self.brand_id == other.brand_id and \ - self.display_name == other.display_name + return ( + self.model == other.model + and self.brand_id == other.brand_id + and self.display_name == other.display_name + ) @attr.s @@ -219,8 +223,7 @@ class SystemAction: mode = attr.ib() def __eq__(self, other): - return self.title == other.title and \ - self.mode == other.mode + return self.title == other.title and self.mode == other.mode @attr.s @@ -229,5 +232,4 @@ class SelectedSystemAction: action = attr.ib() def __eq__(self, other): - return self.system == other.system and \ - self.action == other.action + return self.system == other.system and self.action == other.action diff --git a/console_conf/models/tests/test_systems.py b/console_conf/models/tests/test_systems.py index ba94663d..3b929bf8 100644 --- a/console_conf/models/tests/test_systems.py +++ b/console_conf/models/tests/test_systems.py @@ -13,23 +13,23 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import unittest import json -import jsonschema +import unittest from io import StringIO +import jsonschema + from console_conf.models.systems import ( - RecoverySystemsModel, - RecoverySystem, Brand, - SystemModel, - SystemAction, + RecoverySystem, + RecoverySystemsModel, SelectedSystemAction, - ) + SystemAction, + SystemModel, +) class RecoverySystemsModelTests(unittest.TestCase): - reference = { "systems": [ { @@ -49,7 +49,7 @@ class RecoverySystemsModelTests(unittest.TestCase): "actions": [ {"title": "reinstall", "mode": "install"}, {"title": "recover", "mode": "recover"}, - ] + ], }, { "label": "other", @@ -66,46 +66,53 @@ class RecoverySystemsModelTests(unittest.TestCase): }, "actions": [ {"title": "reinstall", "mode": "install"}, - ] - } - + ], + }, ] } def test_from_systems_stream_happy(self): raw = json.dumps(self.reference) systems = RecoverySystemsModel.from_systems_stream(StringIO(raw)) - exp = RecoverySystemsModel([ - RecoverySystem( - current=True, - label="1234", - model=SystemModel( - model="core20-amd64", - brand_id="brand-id", - display_name="Core 20 AMD64 system"), - brand=Brand( - ID="brand-id", - username="brand-username", - display_name="this is my brand", - validation="verified"), - actions=[SystemAction(title="reinstall", mode="install"), - SystemAction(title="recover", mode="recover")] + exp = RecoverySystemsModel( + [ + RecoverySystem( + current=True, + label="1234", + model=SystemModel( + model="core20-amd64", + brand_id="brand-id", + display_name="Core 20 AMD64 system", + ), + brand=Brand( + ID="brand-id", + username="brand-username", + display_name="this is my brand", + validation="verified", + ), + actions=[ + SystemAction(title="reinstall", mode="install"), + SystemAction(title="recover", mode="recover"), + ], ), - RecoverySystem( - current=False, - label="other", - model=SystemModel( - model="my-brand-box", - brand_id="other-brand-id", - display_name="Funky box"), - brand=Brand( - ID="other-brand-id", - username="other-brand", - display_name="my brand", - validation="unproven"), - actions=[SystemAction(title="reinstall", mode="install")] + RecoverySystem( + current=False, + label="other", + model=SystemModel( + model="my-brand-box", + brand_id="other-brand-id", + display_name="Funky box", + ), + brand=Brand( + ID="other-brand-id", + username="other-brand", + display_name="my brand", + validation="unproven", + ), + actions=[SystemAction(title="reinstall", mode="install")], ), - ]) + ] + ) self.assertEqual(systems.systems, exp.systems) self.assertEqual(systems.current, exp.systems[0]) @@ -114,101 +121,111 @@ class RecoverySystemsModelTests(unittest.TestCase): RecoverySystemsModel.from_systems_stream(StringIO("{}")) def test_from_systems_stream_invalid_missing_system_label(self): - raw = json.dumps({ - "systems": [ - { - "brand": { - "id": "brand-id", - "username": "brand-username", - "display-name": "this is my brand", - "validation": "verified", + raw = json.dumps( + { + "systems": [ + { + "brand": { + "id": "brand-id", + "username": "brand-username", + "display-name": "this is my brand", + "validation": "verified", + }, + "model": { + "model": "core20-amd64", + "brand-id": "brand-id", + "display-name": "Core 20 AMD64 system", + }, + "actions": [ + {"title": "reinstall", "mode": "install"}, + {"title": "recover", "mode": "recover"}, + ], }, - "model": { - "model": "core20-amd64", - "brand-id": "brand-id", - "display-name": "Core 20 AMD64 system", - }, - "actions": [ - {"title": "reinstall", "mode": "install"}, - {"title": "recover", "mode": "recover"}, - ] - }, - ] - }) + ] + } + ) with self.assertRaises(jsonschema.ValidationError): RecoverySystemsModel.from_systems_stream(StringIO(raw)) def test_from_systems_stream_invalid_missing_brand(self): - raw = json.dumps({ - "systems": [ - { - "label": "1234", - "model": { - "model": "core20-amd64", - "brand-id": "brand-id", - "display-name": "Core 20 AMD64 system", + raw = json.dumps( + { + "systems": [ + { + "label": "1234", + "model": { + "model": "core20-amd64", + "brand-id": "brand-id", + "display-name": "Core 20 AMD64 system", + }, + "actions": [ + {"title": "reinstall", "mode": "install"}, + {"title": "recover", "mode": "recover"}, + ], }, - "actions": [ - {"title": "reinstall", "mode": "install"}, - {"title": "recover", "mode": "recover"}, - ] - }, - ] - }) + ] + } + ) with self.assertRaises(jsonschema.ValidationError): RecoverySystemsModel.from_systems_stream(StringIO(raw)) def test_from_systems_stream_invalid_missing_model(self): - raw = json.dumps({ - "systems": [ - { - "label": "1234", - "brand": { - "id": "brand-id", - "username": "brand-username", - "display-name": "this is my brand", - "validation": "verified", + raw = json.dumps( + { + "systems": [ + { + "label": "1234", + "brand": { + "id": "brand-id", + "username": "brand-username", + "display-name": "this is my brand", + "validation": "verified", + }, + "actions": [ + {"title": "reinstall", "mode": "install"}, + {"title": "recover", "mode": "recover"}, + ], }, - "actions": [ - {"title": "reinstall", "mode": "install"}, - {"title": "recover", "mode": "recover"}, - ] - }, - ] - }) + ] + } + ) with self.assertRaises(jsonschema.ValidationError): RecoverySystemsModel.from_systems_stream(StringIO(raw)) def test_from_systems_stream_valid_no_actions(self): - raw = json.dumps({ - "systems": [ - { - "label": "1234", - "model": { - "model": "core20-amd64", - "brand-id": "brand-id", - "display-name": "Core 20 AMD64 system", + raw = json.dumps( + { + "systems": [ + { + "label": "1234", + "model": { + "model": "core20-amd64", + "brand-id": "brand-id", + "display-name": "Core 20 AMD64 system", + }, + "brand": { + "id": "brand-id", + "username": "brand-username", + "display-name": "this is my brand", + "validation": "verified", + }, + "actions": [], }, - "brand": { - "id": "brand-id", - "username": "brand-username", - "display-name": "this is my brand", - "validation": "verified", - }, - "actions": [], - }, - ] - }) + ] + } + ) RecoverySystemsModel.from_systems_stream(StringIO(raw)) def test_selection(self): raw = json.dumps(self.reference) model = RecoverySystemsModel.from_systems_stream(StringIO(raw)) model.select(model.systems[1], model.systems[1].actions[0]) - self.assertEqual(model.selection, - SelectedSystemAction( - system=model.systems[1], - action=model.systems[1].actions[0])) + self.assertEqual( + model.selection, + SelectedSystemAction( + system=model.systems[1], action=model.systems[1].actions[0] + ), + ) model.unselect() self.assertIsNone(model.selection) @@ -221,13 +238,16 @@ class RecoverySystemsModelTests(unittest.TestCase): stream = StringIO() RecoverySystemsModel.to_response_stream(model.selection, stream) fromjson = json.loads(stream.getvalue()) - self.assertEqual(fromjson, { - "label": "other", - "action": { - "mode": "install", - "title": "reinstall", + self.assertEqual( + fromjson, + { + "label": "other", + "action": { + "mode": "install", + "title": "reinstall", }, - }) + }, + ) def test_no_current(self): reference = { @@ -249,7 +269,7 @@ class RecoverySystemsModelTests(unittest.TestCase): "actions": [ {"title": "reinstall", "mode": "install"}, {"title": "recover", "mode": "recover"}, - ] + ], }, ], } diff --git a/console_conf/ui/views/__init__.py b/console_conf/ui/views/__init__.py index 09399bab..a5532abb 100644 --- a/console_conf/ui/views/__init__.py +++ b/console_conf/ui/views/__init__.py @@ -15,14 +15,10 @@ """ ConsoleConf UI Views """ +from .chooser import ChooserConfirmView, ChooserCurrentSystemView, ChooserView from .identity import IdentityView from .login import LoginView -from .welcome import WelcomeView, ChooserWelcomeView -from .chooser import ( - ChooserView, - ChooserCurrentSystemView, - ChooserConfirmView - ) +from .welcome import ChooserWelcomeView, WelcomeView __all__ = [ "IdentityView", diff --git a/console_conf/ui/views/chooser.py b/console_conf/ui/views/chooser.py index 244bf550..6475b256 100644 --- a/console_conf/ui/views/chooser.py +++ b/console_conf/ui/views/chooser.py @@ -20,29 +20,14 @@ Chooser provides a view with recovery chooser actions. """ import logging -from urwid import ( - connect_signal, - Text, - ) -from subiquitycore.ui.buttons import ( - danger_btn, - forward_btn, - back_btn, - ) -from subiquitycore.ui.actionmenu import ( - Action, - ActionMenu, - ) -from subiquitycore.ui.container import Pile, ListBox -from subiquitycore.ui.utils import ( - button_pile, - screen, - make_action_menu_row, - Color, - ) -from subiquitycore.ui.table import TableRow, TablePile -from subiquitycore.view import BaseView +from urwid import Text, connect_signal +from subiquitycore.ui.actionmenu import Action, ActionMenu +from subiquitycore.ui.buttons import back_btn, danger_btn, forward_btn +from subiquitycore.ui.container import ListBox, Pile +from subiquitycore.ui.table import TablePile, TableRow +from subiquitycore.ui.utils import Color, button_pile, make_action_menu_row, screen +from subiquitycore.view import BaseView log = logging.getLogger("console_conf.ui.views.chooser") @@ -69,28 +54,31 @@ class ChooserCurrentSystemView(ChooserBaseView): def __init__(self, controller, current, has_more=False): self.controller = controller - log.debug('more systems available: %s', has_more) - log.debug('current system: %s', current) + log.debug("more systems available: %s", has_more) + log.debug("current system: %s", current) actions = [] for action in sorted(current.actions, key=by_preferred_action_type): - actions.append(forward_btn(label=action.title.capitalize(), - on_press=self._current_system_action, - user_arg=(current, action))) + actions.append( + forward_btn( + label=action.title.capitalize(), + on_press=self._current_system_action, + user_arg=(current, action), + ) + ) if has_more: # add a button to show the other systems actions.append(Text("")) - actions.append(forward_btn(label="Show all available systems", - on_press=self._more_options)) + actions.append( + forward_btn( + label="Show all available systems", on_press=self._more_options + ) + ) lb = ListBox(actions) - super().__init__(current, - screen( - lb, - narrow_rows=True, - excerpt=self.excerpt)) + super().__init__(current, screen(lb, narrow_rows=True, excerpt=self.excerpt)) def _current_system_action(self, sender, arg): current, action = arg @@ -104,43 +92,54 @@ class ChooserCurrentSystemView(ChooserBaseView): class ChooserView(ChooserBaseView): - excerpt = ("Select one of available recovery systems and a desired " - "action to execute.") + excerpt = ( + "Select one of available recovery systems and a desired " "action to execute." + ) def __init__(self, controller, systems): self.controller = controller - heading_table = TablePile([ - TableRow([ - Color.info_minor(Text(header)) for header in [ - "LABEL", "MODEL", "PUBLISHER", "" + heading_table = TablePile( + [ + TableRow( + [ + Color.info_minor(Text(header)) + for header in ["LABEL", "MODEL", "PUBLISHER", ""] ] - ]) + ) ], - spacing=2) + spacing=2, + ) trows = [] - systems = sorted(systems, - key=lambda s: (s.brand.display_name, - s.model.display_name, - s.current, - s.label)) + systems = sorted( + systems, + key=lambda s: ( + s.brand.display_name, + s.model.display_name, + s.current, + s.label, + ), + ) for s in systems: actions = [] - log.debug('actions: %s', s.actions) + log.debug("actions: %s", s.actions) for act in sorted(s.actions, key=by_preferred_action_type): - actions.append(Action(label=act.title.capitalize(), - value=act, - enabled=True)) + actions.append( + Action(label=act.title.capitalize(), value=act, enabled=True) + ) menu = ActionMenu(actions) - connect_signal(menu, 'action', self._system_action, s) - srow = make_action_menu_row([ - Text(s.label), - Text(s.model.display_name), - Text(s.brand.display_name), - Text("(installed)" if s.current else ""), + connect_signal(menu, "action", self._system_action, s) + srow = make_action_menu_row( + [ + Text(s.label), + Text(s.model.display_name), + Text(s.brand.display_name), + Text("(installed)" if s.current else ""), + menu, + ], menu, - ], menu) + ) trows.append(srow) systems_table = TablePile(trows, spacing=2) @@ -154,12 +153,15 @@ class ChooserView(ChooserBaseView): # back to options of current system buttons.append(back_btn("BACK", on_press=self.back)) - super().__init__(controller.model.current, - screen( - rows=rows, - buttons=button_pile(buttons), - focus_buttons=False, - excerpt=self.excerpt)) + super().__init__( + controller.model.current, + screen( + rows=rows, + buttons=button_pile(buttons), + focus_buttons=False, + excerpt=self.excerpt, + ), + ) def _system_action(self, sender, action, system): self.controller.select(system, action) @@ -169,22 +171,24 @@ class ChooserView(ChooserBaseView): class ChooserConfirmView(ChooserBaseView): - canned_summary = { "run": "Continue running the system without any changes.", - "recover": ("You have requested to reboot the system into recovery " - "mode."), - "install": ("You are about to {action_lower} the system version " - "{version} for {model} from {publisher}.\n\n" - "This will remove all existing user data on the " - "device.\n\n" - "The system will reboot in the process."), + "recover": ("You have requested to reboot the system into recovery " "mode."), + "install": ( + "You are about to {action_lower} the system version " + "{version} for {model} from {publisher}.\n\n" + "This will remove all existing user data on the " + "device.\n\n" + "The system will reboot in the process." + ), } - default_summary = ("You are about to execute action \"{action}\" using " - "system version {version} for device {model} from " - "{publisher}.\n\n" - "Make sure you understand the consequences of " - "performing this action.") + default_summary = ( + 'You are about to execute action "{action}" using ' + "system version {version} for device {model} from " + "{publisher}.\n\n" + "Make sure you understand the consequences of " + "performing this action." + ) def __init__(self, controller, selection): self.controller = controller @@ -193,21 +197,21 @@ class ChooserConfirmView(ChooserBaseView): danger_btn("CONFIRM", on_press=self.confirm), back_btn("BACK", on_press=self.back), ] - fmt = self.canned_summary.get(selection.action.mode, - self.default_summary) - summary = fmt.format(action=selection.action.title, - action_lower=selection.action.title.lower(), - model=selection.system.model.display_name, - publisher=selection.system.brand.display_name, - version=selection.system.label) + fmt = self.canned_summary.get(selection.action.mode, self.default_summary) + summary = fmt.format( + action=selection.action.title, + action_lower=selection.action.title.lower(), + model=selection.system.model.display_name, + publisher=selection.system.brand.display_name, + version=selection.system.label, + ) rows = [ Text(summary), ] - super().__init__(controller.model.current, - screen( - rows=rows, - buttons=button_pile(buttons), - focus_buttons=False)) + super().__init__( + controller.model.current, + screen(rows=rows, buttons=button_pile(buttons), focus_buttons=False), + ) def confirm(self, result): self.controller.confirm() diff --git a/console_conf/ui/views/identity.py b/console_conf/ui/views/identity.py index bb5fddbf..9287bd8c 100644 --- a/console_conf/ui/views/identity.py +++ b/console_conf/ui/views/identity.py @@ -14,20 +14,18 @@ # along with this program. If not, see . import logging + from urwid import connect_signal -from subiquitycore.view import BaseView +from subiquitycore.ui.form import EmailField, Form from subiquitycore.ui.utils import SomethingFailed -from subiquitycore.ui.form import ( - Form, - EmailField, -) - +from subiquitycore.view import BaseView log = logging.getLogger("console_conf.ui.views.identity") -sso_help = ("If you do not have an account, visit " - "https://login.ubuntu.com to create one.") +sso_help = ( + "If you do not have an account, visit " "https://login.ubuntu.com to create one." +) class IdentityForm(Form): @@ -45,11 +43,12 @@ class IdentityView(BaseView): self.controller = controller self.form = IdentityForm() - connect_signal(self.form, 'submit', self.done) - connect_signal(self.form, 'cancel', self.cancel) + connect_signal(self.form, "submit", self.done) + connect_signal(self.form, "cancel", self.cancel) - super().__init__(self.form.as_screen(focus_buttons=False, - excerpt=_(self.excerpt))) + super().__init__( + self.form.as_screen(focus_buttons=False, excerpt=_(self.excerpt)) + ) def cancel(self, button=None): self.controller.cancel() diff --git a/console_conf/ui/views/login.py b/console_conf/ui/views/login.py index 67e9c712..e2de4d17 100644 --- a/console_conf/ui/views/login.py +++ b/console_conf/ui/views/login.py @@ -24,7 +24,7 @@ from urwid import Text from subiquitycore.ui.buttons import done_btn from subiquitycore.ui.container import ListBox, Pile -from subiquitycore.ui.utils import button_pile, Padding +from subiquitycore.ui.utils import Padding, button_pile from subiquitycore.view import BaseView log = logging.getLogger("console_conf.ui.views.login") @@ -41,15 +41,23 @@ class LoginView(BaseView): self.items = [] super().__init__( - Pile([ - ('pack', Text("")), - Padding.center_79(ListBox(self._build_model_inputs())), - ('pack', Pile([ - ('pack', Text("")), - button_pile(self._build_buttons()), - ('pack', Text("")), - ])), - ])) + Pile( + [ + ("pack", Text("")), + Padding.center_79(ListBox(self._build_model_inputs())), + ( + "pack", + Pile( + [ + ("pack", Text("")), + button_pile(self._build_buttons()), + ("pack", Text("")), + ] + ), + ), + ] + ) + ) def _build_buttons(self): return [ @@ -68,19 +76,19 @@ class LoginView(BaseView): sl.append(Text("no owner")) return sl - local_tpl = ( - "This device is registered to {realname}.") + local_tpl = "This device is registered to {realname}." remote_tpl = ( "\n\nRemote access was enabled via authentication with SSO user" " <{username}>.\nPublic SSH keys were added to the device " "for remote access.\n\n{realname} can connect remotely to this " - "device via SSH:") + "device via SSH:" + ) sl = [] login_info = { - 'realname': user.realname, - 'username': user.username, + "realname": user.realname, + "username": user.username, } login_text = local_tpl.format(**login_info) login_text += remote_tpl.format(**login_info) diff --git a/console_conf/ui/views/tests/test_chooser.py b/console_conf/ui/views/tests/test_chooser.py index 3b7b8037..95069edf 100644 --- a/console_conf/ui/views/tests/test_chooser.py +++ b/console_conf/ui/views/tests/test_chooser.py @@ -14,16 +14,11 @@ # along with this program. If not, see . import unittest -from console_conf.models.systems import ( - SystemAction, - ) -from console_conf.ui.views.chooser import ( - by_preferred_action_type, - ) +from console_conf.models.systems import SystemAction +from console_conf.ui.views.chooser import by_preferred_action_type class TestSorting(unittest.TestCase): - def test_simple(self): data = [ SystemAction(title="b-other", mode="unknown"), @@ -39,5 +34,4 @@ class TestSorting(unittest.TestCase): SystemAction(title="a-other", mode="some-other"), SystemAction(title="b-other", mode="unknown"), ] - self.assertSequenceEqual(sorted(data, key=by_preferred_action_type), - exp) + self.assertSequenceEqual(sorted(data, key=by_preferred_action_type), exp) diff --git a/subiquity/__init__.py b/subiquity/__init__.py index 3f6b8ffc..b4f4987a 100644 --- a/subiquity/__init__.py +++ b/subiquity/__init__.py @@ -16,8 +16,9 @@ """ Subiquity """ from subiquitycore import i18n + __all__ = [ - 'i18n', + "i18n", ] __version__ = "0.0.5" diff --git a/subiquity/__main__.py b/subiquity/__main__.py index 9ada9ef3..2e18d698 100644 --- a/subiquity/__main__.py +++ b/subiquity/__main__.py @@ -1,5 +1,6 @@ import sys -if __name__ == '__main__': +if __name__ == "__main__": from subiquity.cmd.tui import main + sys.exit(main()) diff --git a/subiquity/client/client.py b/subiquity/client/client.py index 9c8eeab3..6da1285c 100644 --- a/subiquity/client/client.py +++ b/subiquity/client/client.py @@ -25,46 +25,28 @@ from typing import Callable, Dict, List, Optional, Union import aiohttp -from subiquitycore.async_helpers import ( - run_bg_task, - run_in_thread, - ) -from subiquitycore.screen import is_linux_tty -from subiquitycore.tuicontroller import Skip -from subiquitycore.tui import TuiApplication -from subiquitycore.utils import orig_environ -from subiquitycore.view import BaseView - from subiquity.client.controller import Confirm -from subiquity.client.keycodes import ( - NoOpKeycodesFilter, - KeyCodesFilter, - ) +from subiquity.client.keycodes import KeyCodesFilter, NoOpKeycodesFilter from subiquity.common.api.client import make_client_for_conn from subiquity.common.apidef import API -from subiquity.common.errorreport import ( - ErrorReporter, - ) +from subiquity.common.errorreport import ErrorReporter from subiquity.common.serialize import from_json -from subiquity.common.types import ( - ApplicationState, - ErrorReportKind, - ErrorReportRef, - ) +from subiquity.common.types import ApplicationState, ErrorReportKind, ErrorReportRef from subiquity.journald import journald_listen from subiquity.server.server import POSTINSTALL_MODEL_NAMES from subiquity.ui.frame import SubiquityUI from subiquity.ui.views.error import ErrorReportStretchy from subiquity.ui.views.help import HelpMenu, ssh_help_texts -from subiquity.ui.views.installprogress import ( - InstallConfirmation, - ) -from subiquity.ui.views.welcome import ( - CloudInitFail, - ) +from subiquity.ui.views.installprogress import InstallConfirmation +from subiquity.ui.views.welcome import CloudInitFail +from subiquitycore.async_helpers import run_bg_task, run_in_thread +from subiquitycore.screen import is_linux_tty +from subiquitycore.tui import TuiApplication +from subiquitycore.tuicontroller import Skip +from subiquitycore.utils import orig_environ +from subiquitycore.view import BaseView - -log = logging.getLogger('subiquity.client.client') +log = logging.getLogger("subiquity.client.client") class Abort(Exception): @@ -72,7 +54,8 @@ class Abort(Exception): self.error_report_ref = error_report_ref -DEBUG_SHELL_INTRO = _("""\ +DEBUG_SHELL_INTRO = _( + """\ Installer shell session activated. This shell session is running inside the installer environment. You @@ -81,18 +64,19 @@ example by typing Control-D or 'exit'. Be aware that this is an ephemeral environment. Changes to this environment will not survive a reboot. If the install has started, the -installed system will be mounted at /target.""") +installed system will be mounted at /target.""" +) class SubiquityClient(TuiApplication): - - snapd_socket_path: Optional[str] = '/run/snapd.socket' + snapd_socket_path: Optional[str] = "/run/snapd.socket" variant: Optional[str] = None - cmdline = ['snap', 'run', 'subiquity'] - dryrun_cmdline_module = 'subiquity.cmd.tui' + cmdline = ["snap", "run", "subiquity"] + dryrun_cmdline_module = "subiquity.cmd.tui" from subiquity.client import controllers as controllers_mod + project = "subiquity" def make_model(self): @@ -119,7 +103,7 @@ class SubiquityClient(TuiApplication): "Drivers", "SnapList", "Progress", - ] + ] variant_to_controllers: Dict[str, List[str]] = {} @@ -142,11 +126,11 @@ class SubiquityClient(TuiApplication): except OSError: self.our_tty = "not a tty" - self.in_make_view_cvar = contextvars.ContextVar( - 'in_make_view', default=False) + self.in_make_view_cvar = contextvars.ContextVar("in_make_view", default=False) self.error_reporter = ErrorReporter( - self.context.child("ErrorReporter"), self.opts.dry_run, self.root) + self.context.child("ErrorReporter"), self.opts.dry_run, self.root + ) self.note_data_for_apport("SnapUpdated", str(self.updated)) self.note_data_for_apport("UsingAnswers", str(bool(self.answers))) @@ -169,12 +153,9 @@ class SubiquityClient(TuiApplication): def restart(self, remove_last_screen=True, restart_server=False): log.debug(f"restart {remove_last_screen} {restart_server}") if self.fg_proc is not None: - log.debug( - "killing foreground process %s before restarting", - self.fg_proc) + log.debug("killing foreground process %s before restarting", self.fg_proc) self.restarting = True - run_bg_task( - self._kill_fg_proc(remove_last_screen, restart_server)) + run_bg_task(self._kill_fg_proc(remove_last_screen, restart_server)) return if remove_last_screen: self._remove_last_screen() @@ -187,50 +168,56 @@ class SubiquityClient(TuiApplication): self.urwid_loop.stop() cmdline = self.cmdline if self.opts.dry_run: - cmdline = [ - sys.executable, '-m', self.dryrun_cmdline_module, - ] + sys.argv[1:] + ['--socket', self.opts.socket] + cmdline = ( + [ + sys.executable, + "-m", + self.dryrun_cmdline_module, + ] + + sys.argv[1:] + + ["--socket", self.opts.socket] + ) if self.opts.server_pid is not None: - cmdline.extend(['--server-pid', self.opts.server_pid]) + cmdline.extend(["--server-pid", self.opts.server_pid]) log.debug("restarting %r", cmdline) os.execvpe(cmdline[0], cmdline, orig_environ(os.environ)) def resp_hook(self, response): headers = response.headers - if 'x-updated' in headers: + if "x-updated" in headers: if self.server_updated is None: - self.server_updated = headers['x-updated'] - elif self.server_updated != headers['x-updated']: + self.server_updated = headers["x-updated"] + elif self.server_updated != headers["x-updated"]: self.restart(remove_last_screen=False) raise Abort - status = headers.get('x-status') - if status == 'skip': + status = headers.get("x-status") + if status == "skip": raise Skip - elif status == 'confirm': + elif status == "confirm": raise Confirm - if headers.get('x-error-report') is not None: - ref = from_json(ErrorReportRef, headers['x-error-report']) + if headers.get("x-error-report") is not None: + ref = from_json(ErrorReportRef, headers["x-error-report"]) raise Abort(ref) try: response.raise_for_status() except aiohttp.ClientError: report = self.error_reporter.make_apport_report( ErrorReportKind.SERVER_REQUEST_FAIL, - "request to {}".format(response.url.path)) + "request to {}".format(response.url.path), + ) raise Abort(report.ref()) return response async def noninteractive_confirmation(self): await asyncio.sleep(1) - yes = _('yes') - no = _('no') + yes = _("yes") + no = _("no") answer = no print(_("Confirmation is required to continue.")) print(_("Add 'autoinstall' to your kernel command line to avoid this")) print() - prompt = "\n\n{} ({}|{})".format( - _("Continue with autoinstall?"), yes, no) + prompt = "\n\n{} ({}|{})".format(_("Continue with autoinstall?"), yes, no) while answer != yes: print(prompt) answer = await run_in_thread(input) @@ -260,7 +247,8 @@ class SubiquityClient(TuiApplication): if app_state == ApplicationState.NEEDS_CONFIRMATION: if confirm_task is None: confirm_task = asyncio.create_task( - self.noninteractive_confirmation()) + self.noninteractive_confirmation() + ) elif confirm_task is not None: confirm_task.cancel() confirm_task = None @@ -272,21 +260,20 @@ class SubiquityClient(TuiApplication): app_status = await self._status_get(app_state) def subiquity_event_noninteractive(self, event): - if event['SUBIQUITY_EVENT_TYPE'] == 'start': - print('start: ' + event["MESSAGE"]) - elif event['SUBIQUITY_EVENT_TYPE'] == 'finish': - print('finish: ' + event["MESSAGE"]) + if event["SUBIQUITY_EVENT_TYPE"] == "start": + print("start: " + event["MESSAGE"]) + elif event["SUBIQUITY_EVENT_TYPE"] == "finish": + print("finish: " + event["MESSAGE"]) async def connect(self): - def p(s): - print(s, end='', flush=True) + print(s, end="", flush=True) async def spin(message): - p(message + '... ') + p(message + "... ") while True: - for t in ['-', '\\', '|', '/']: - p('\x08' + t) + for t in ["-", "\\", "|", "/"]: + p("\x08" + t) await asyncio.sleep(0.5) async def spinning_wait(message, task): @@ -295,16 +282,18 @@ class SubiquityClient(TuiApplication): return await task finally: spinner.cancel() - p('\x08 \n') + p("\x08 \n") status = await spinning_wait("connecting", self._status_get()) - journald_listen([status.echo_syslog_id], lambda e: print(e['MESSAGE'])) + journald_listen([status.echo_syslog_id], lambda e: print(e["MESSAGE"])) if status.state == ApplicationState.STARTING_UP: status = await spinning_wait( - "starting up", self._status_get(cur=status.state)) + "starting up", self._status_get(cur=status.state) + ) if status.state == ApplicationState.CLOUD_INIT_WAIT: status = await spinning_wait( - "waiting for cloud-init", self._status_get(cur=status.state)) + "waiting for cloud-init", self._status_get(cur=status.state) + ) if status.state == ApplicationState.EARLY_COMMANDS: print("running early commands") status = await self._status_get(cur=status.state) @@ -316,12 +305,13 @@ class SubiquityClient(TuiApplication): def header_func(): if self.in_make_view_cvar.get(): - return {'x-make-view-request': 'yes'} + return {"x-make-view-request": "yes"} else: return None - self.client = make_client_for_conn(API, conn, self.resp_hook, - header_func=header_func) + self.client = make_client_for_conn( + API, conn, self.resp_hook, header_func=header_func + ) self.error_reporter.client = self.client status = await self.connect() @@ -332,11 +322,14 @@ class SubiquityClient(TuiApplication): texts = ssh_help_texts(ssh_info) for line in texts: import urwid + if isinstance(line, urwid.Widget): - line = '\n'.join([ - line.decode('utf-8').rstrip() - for line in line.render((1000,)).text - ]) + line = "\n".join( + [ + line.decode("utf-8").rstrip() + for line in line.render((1000,)).text + ] + ) print(line) return @@ -355,11 +348,11 @@ class SubiquityClient(TuiApplication): # the progress page if hasattr(self.controllers, "Progress"): journald_listen( - [status.event_syslog_id], - self.controllers.Progress.event) + [status.event_syslog_id], self.controllers.Progress.event + ) journald_listen( - [status.log_syslog_id], - self.controllers.Progress.log_line) + [status.log_syslog_id], self.controllers.Progress.log_line + ) if not status.cloud_init_ok: self.add_global_overlay(CloudInitFail(self)) self.error_reporter.load_reports() @@ -376,18 +369,16 @@ class SubiquityClient(TuiApplication): # does not matter as the settings will soon be clobbered but # for a non-interactive one we need to clear things up or the # prompting for confirmation will be confusing. - os.system('stty sane') + os.system("stty sane") journald_listen( - [status.event_syslog_id], - self.subiquity_event_noninteractive, - seek=True) - run_bg_task( - self.noninteractive_watch_app_state(status)) + [status.event_syslog_id], self.subiquity_event_noninteractive, seek=True + ) + run_bg_task(self.noninteractive_watch_app_state(status)) def _exception_handler(self, loop, context): - exc = context.get('exception') + exc = context.get("exception") if self.restarting: - log.debug('ignoring %s %s during restart', exc, type(exc)) + log.debug("ignoring %s %s during restart", exc, type(exc)) return if isinstance(exc, Abort): if self.interactive: @@ -405,8 +396,8 @@ class SubiquityClient(TuiApplication): print("generating crash report") try: report = self.make_apport_report( - ErrorReportKind.UI, "Installer UI", interrupt=False, - wait=True) + ErrorReportKind.UI, "Installer UI", interrupt=False, wait=True + ) if report is not None: print("report saved to {path}".format(path=report.path)) except Exception: @@ -426,7 +417,7 @@ class SubiquityClient(TuiApplication): # the server up to a second to exit, and then we signal it. pid = int(self.opts.server_pid) - print(f'giving the server [{pid}] up to a second to exit') + print(f"giving the server [{pid}] up to a second to exit") for unused in range(10): try: if os.waitpid(pid, os.WNOHANG) != (0, 0): @@ -435,9 +426,9 @@ class SubiquityClient(TuiApplication): # If we attached to an existing server process, # waitpid will fail. pass - await asyncio.sleep(.1) + await asyncio.sleep(0.1) else: - print('killing server {}'.format(pid)) + print("killing server {}".format(pid)) os.kill(pid, 2) os.waitpid(pid, 0) @@ -448,21 +439,19 @@ class SubiquityClient(TuiApplication): if source.id == source_selection.current_id: current = source break - if current is not None and current.variant != 'server': + if current is not None and current.variant != "server": # If using server to install desktop, mark the controllers # the TUI client does not currently have interfaces for as # configured. needed = POSTINSTALL_MODEL_NAMES.for_variant(current.variant) for c in self.controllers.instances: - if getattr(c, 'endpoint_name', None) is not None: + if getattr(c, "endpoint_name", None) is not None: needed.discard(c.endpoint_name) if needed: - log.info( - "marking additional endpoints as configured: %s", - needed) + log.info("marking additional endpoints as configured: %s", needed) await self.client.meta.mark_configured.POST(list(needed)) # TODO: remove this when TUI gets an Active Directory screen: - await self.client.meta.mark_configured.POST(['active_directory']) + await self.client.meta.mark_configured.POST(["active_directory"]) await self.client.meta.confirm.POST(self.our_tty) def add_global_overlay(self, overlay): @@ -477,7 +466,7 @@ class SubiquityClient(TuiApplication): self.ui.body.remove_overlay(overlay) def _remove_last_screen(self): - last_screen = self.state_path('last-screen') + last_screen = self.state_path("last-screen") if os.path.exists(last_screen): os.unlink(last_screen) @@ -488,7 +477,7 @@ class SubiquityClient(TuiApplication): def select_initial_screen(self): last_screen = None if self.updated: - state_path = self.state_path('last-screen') + state_path = self.state_path("last-screen") if os.path.exists(state_path): with open(state_path) as fp: last_screen = fp.read().strip() @@ -502,7 +491,7 @@ class SubiquityClient(TuiApplication): async def _select_initial_screen(self, index): endpoint_names = [] for c in self.controllers.instances[:index]: - if getattr(c, 'endpoint_name', None) is not None: + if getattr(c, "endpoint_name", None) is not None: endpoint_names.append(c.endpoint_name) if endpoint_names: await self.client.meta.mark_configured.POST(endpoint_names) @@ -521,11 +510,12 @@ class SubiquityClient(TuiApplication): log.debug("showing InstallConfirmation over %s", self.ui.body) overlay = InstallConfirmation(self) self.add_global_overlay(overlay) - if self.answers.get('filesystem-confirmed', False): + if self.answers.get("filesystem-confirmed", False): overlay.ok(None) async def _start_answers_for_view( - self, controller, view: Union[BaseView, Callable[[], BaseView]]): + self, controller, view: Union[BaseView, Callable[[], BaseView]] + ): def noop(): return view @@ -551,19 +541,19 @@ class SubiquityClient(TuiApplication): self.in_make_view_cvar.reset(tok) if new.answers: run_bg_task(self._start_answers_for_view(new, view)) - with open(self.state_path('last-screen'), 'w') as fp: + with open(self.state_path("last-screen"), "w") as fp: fp.write(new.name) return view def show_progress(self): - if hasattr(self.controllers, 'Progress'): + if hasattr(self.controllers, "Progress"): self.ui.set_body(self.controllers.Progress.progress_view) def unhandled_input(self, key): - if key == 'f1': + if key == "f1": if not self.ui.right_icon.current_help: self.ui.right_icon.open_pop_up() - elif key in ['ctrl z', 'f2']: + elif key in ["ctrl z", "f2"]: self.debug_shell() elif self.opts.dry_run: self.unhandled_input_dry_run(key) @@ -571,22 +561,22 @@ class SubiquityClient(TuiApplication): super().unhandled_input(key) def unhandled_input_dry_run(self, key): - if key in ['ctrl e', 'ctrl r']: - interrupt = key == 'ctrl e' + if key in ["ctrl e", "ctrl r"]: + interrupt = key == "ctrl e" try: - 1/0 + 1 / 0 except ZeroDivisionError: self.make_apport_report( - ErrorReportKind.UNKNOWN, "example", interrupt=interrupt) - elif key == 'ctrl u': - 1/0 - elif key == 'ctrl b': + ErrorReportKind.UNKNOWN, "example", interrupt=interrupt + ) + elif key == "ctrl u": + 1 / 0 + elif key == "ctrl b": run_bg_task(self.client.dry_run.crash.GET()) else: super().unhandled_input(key) def debug_shell(self, after_hook=None): - def _before(): os.system("clear") print(DEBUG_SHELL_INTRO) @@ -594,7 +584,8 @@ class SubiquityClient(TuiApplication): env = orig_environ(os.environ) cmd = ["bash"] self.run_command_in_foreground( - cmd, env=env, before_hook=_before, after_hook=after_hook, cwd='/') + cmd, env=env, before_hook=_before, after_hook=after_hook, cwd="/" + ) def note_file_for_apport(self, key, path): self.error_reporter.note_file_for_apport(key, path) @@ -603,8 +594,7 @@ class SubiquityClient(TuiApplication): self.error_reporter.note_data_for_apport(key, value) def make_apport_report(self, kind, thing, *, interrupt, wait=False, **kw): - report = self.error_reporter.make_apport_report( - kind, thing, wait=wait, **kw) + report = self.error_reporter.make_apport_report(kind, thing, wait=wait, **kw) if report is not None and interrupt: self.show_error_report(report.ref()) @@ -614,7 +604,7 @@ class SubiquityClient(TuiApplication): def show_error_report(self, error_ref): log.debug("show_error_report %r", error_ref.base) if isinstance(self.ui.body, BaseView): - w = getattr(self.ui.body._w, 'stretchy', None) + w = getattr(self.ui.body._w, "stretchy", None) if isinstance(w, ErrorReportStretchy): # Don't show an error if already looking at one. return diff --git a/subiquity/client/controller.py b/subiquity/client/controller.py index cbff6798..8d4d5005 100644 --- a/subiquity/client/controller.py +++ b/subiquity/client/controller.py @@ -16,9 +16,7 @@ import logging from typing import Optional -from subiquitycore.tuicontroller import ( - TuiController, - ) +from subiquitycore.tuicontroller import TuiController log = logging.getLogger("subiquity.client.controller") @@ -28,7 +26,6 @@ class Confirm(Exception): class SubiquityTuiController(TuiController): - endpoint_name: Optional[str] = None def __init__(self, app): diff --git a/subiquity/client/controllers/__init__.py b/subiquity/client/controllers/__init__.py index 47abe8e6..f3434d7b 100644 --- a/subiquity/client/controllers/__init__.py +++ b/subiquity/client/controllers/__init__.py @@ -14,6 +14,7 @@ # along with this program. If not, see . from subiquitycore.tuicontroller import RepeatedController + from .drivers import DriversController from .filesystem import FilesystemController from .identity import IdentityController @@ -33,21 +34,21 @@ from .zdev import ZdevController # see SubiquityClient.controllers for another list __all__ = [ - 'DriversController', - 'FilesystemController', - 'IdentityController', - 'KeyboardController', - 'MirrorController', - 'NetworkController', - 'ProgressController', - 'ProxyController', - 'RefreshController', - 'RepeatedController', - 'SerialController', - 'SnapListController', - 'SourceController', - 'SSHController', - 'UbuntuProController', - 'WelcomeController', - 'ZdevController', + "DriversController", + "FilesystemController", + "IdentityController", + "KeyboardController", + "MirrorController", + "NetworkController", + "ProgressController", + "ProxyController", + "RefreshController", + "RepeatedController", + "SerialController", + "SnapListController", + "SourceController", + "SSHController", + "UbuntuProController", + "WelcomeController", + "ZdevController", ] diff --git a/subiquity/client/controllers/drivers.py b/subiquity/client/controllers/drivers.py index 339f5193..8cbd7581 100644 --- a/subiquity/client/controllers/drivers.py +++ b/subiquity/client/controllers/drivers.py @@ -17,18 +17,16 @@ import asyncio import logging from typing import List +from subiquity.client.controller import SubiquityTuiController +from subiquity.common.types import DriversPayload, DriversResponse +from subiquity.ui.views.drivers import DriversView, DriversViewStatus from subiquitycore.tuicontroller import Skip -from subiquity.common.types import DriversPayload, DriversResponse -from subiquity.client.controller import SubiquityTuiController -from subiquity.ui.views.drivers import DriversView, DriversViewStatus - -log = logging.getLogger('subiquity.client.controllers.drivers') +log = logging.getLogger("subiquity.client.controllers.drivers") class DriversController(SubiquityTuiController): - - endpoint_name = 'drivers' + endpoint_name = "drivers" async def make_ui(self) -> DriversView: response: DriversResponse = await self.endpoint.GET() @@ -37,8 +35,9 @@ class DriversController(SubiquityTuiController): await self.endpoint.POST(DriversPayload(install=False)) raise Skip - return DriversView(self, response.drivers, - response.install, response.local_only) + return DriversView( + self, response.drivers, response.install, response.local_only + ) async def _wait_drivers(self) -> List[str]: response: DriversResponse = await self.endpoint.GET(wait=True) @@ -46,12 +45,10 @@ class DriversController(SubiquityTuiController): return response.drivers async def run_answers(self): - if 'install' not in self.answers: + if "install" not in self.answers: return - from subiquitycore.testing.view_helpers import ( - click, - ) + from subiquitycore.testing.view_helpers import click view = self.app.ui.body while view.status == DriversViewStatus.WAITING: @@ -60,7 +57,7 @@ class DriversController(SubiquityTuiController): click(view.cont_btn.base_widget) return - view.form.install.value = self.answers['install'] + view.form.install.value = self.answers["install"] click(view.form.done_btn.base_widget) @@ -69,5 +66,4 @@ class DriversController(SubiquityTuiController): def done(self, install: bool) -> None: log.debug("DriversController.done next_screen install=%s", install) - self.app.next_screen( - self.endpoint.POST(DriversPayload(install=install))) + self.app.next_screen(self.endpoint.POST(DriversPayload(install=install))) diff --git a/subiquity/client/controllers/filesystem.py b/subiquity/client/controllers/filesystem.py index 6a3eaf70..fed65685 100644 --- a/subiquity/client/controllers/filesystem.py +++ b/subiquity/client/controllers/filesystem.py @@ -17,50 +17,37 @@ import asyncio import logging from typing import Callable, Optional -from subiquitycore.async_helpers import run_bg_task -from subiquitycore.lsb_release import lsb_release -from subiquitycore.view import BaseView - from subiquity.client.controller import SubiquityTuiController from subiquity.common.filesystem import gaps from subiquity.common.filesystem.manipulator import FilesystemManipulator from subiquity.common.types import ( - ProbeStatus, GuidedCapability, GuidedChoiceV2, GuidedStorageResponseV2, GuidedStorageTargetManual, + ProbeStatus, StorageResponseV2, - ) -from subiquity.models.filesystem import ( - Bootloader, - FilesystemModel, - raidlevels_by_value, - ) -from subiquity.ui.views import ( - FilesystemView, - GuidedDiskSelectionView, - ) -from subiquity.ui.views.filesystem.probing import ( - SlowProbing, - ProbingFailed, - ) - +) +from subiquity.models.filesystem import Bootloader, FilesystemModel, raidlevels_by_value +from subiquity.ui.views import FilesystemView, GuidedDiskSelectionView +from subiquity.ui.views.filesystem.probing import ProbingFailed, SlowProbing +from subiquitycore.async_helpers import run_bg_task +from subiquitycore.lsb_release import lsb_release +from subiquitycore.view import BaseView log = logging.getLogger("subiquity.client.controllers.filesystem") class FilesystemController(SubiquityTuiController, FilesystemManipulator): - - endpoint_name = 'storage' + endpoint_name = "storage" def __init__(self, app): super().__init__(app) self.model = None - self.answers.setdefault('guided', False) - self.answers.setdefault('tpm-default', False) - self.answers.setdefault('guided-index', 0) - self.answers.setdefault('manual', []) + self.answers.setdefault("guided", False) + self.answers.setdefault("tpm-default", False) + self.answers.setdefault("guided-index", 0) + self.answers.setdefault("manual", []) self.current_view: Optional[BaseView] = None async def make_ui(self) -> Callable[[], BaseView]: @@ -90,23 +77,19 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator): if isinstance(self.ui.body, SlowProbing): self.ui.set_body(self.current_view) else: - log.debug("not refreshing the display. Current display is %r", - self.ui.body) + log.debug("not refreshing the display. Current display is %r", self.ui.body) async def make_guided_ui( - self, - status: GuidedStorageResponseV2, - ) -> GuidedDiskSelectionView: + self, + status: GuidedStorageResponseV2, + ) -> GuidedDiskSelectionView: if status.status == ProbeStatus.FAILED: self.app.show_error_report(status.error_report) return ProbingFailed(self, status.error_report) - response: StorageResponseV2 = await self.endpoint.v2.GET( - include_raid=True) + response: StorageResponseV2 = await self.endpoint.v2.GET(include_raid=True) - disk_by_id = { - disk.id: disk for disk in response.disks - } + disk_by_id = {disk.id: disk for disk in response.disks} if status.error_report: self.app.show_error_report(status.error_report) @@ -118,48 +101,46 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator): while not isinstance(self.ui.body, GuidedDiskSelectionView): await asyncio.sleep(0.1) - if self.answers['tpm-default']: - self.app.answers['filesystem-confirmed'] = True + if self.answers["tpm-default"]: + self.app.answers["filesystem-confirmed"] = True self.ui.body.done(self.ui.body.form) - if self.answers['guided']: + if self.answers["guided"]: targets = self.ui.body.form.targets - if 'guided-index' in self.answers: - target = targets[self.answers['guided-index']] - elif 'guided-label' in self.answers: - label = self.answers['guided-label'] + if "guided-index" in self.answers: + target = targets[self.answers["guided-index"]] + elif "guided-label" in self.answers: + label = self.answers["guided-label"] disk_by_id = self.ui.body.form.disk_by_id - [target] = [ - t - for t in targets - if disk_by_id[t.disk_id].label == label - ] - method = self.answers.get('guided-method') + [target] = [t for t in targets if disk_by_id[t.disk_id].label == label] + method = self.answers.get("guided-method") value = { - 'disk': target, - 'use_lvm': method == "lvm", - } - passphrase = self.answers.get('guided-passphrase') + "disk": target, + "use_lvm": method == "lvm", + } + passphrase = self.answers.get("guided-passphrase") if passphrase is not None: - value['lvm_options'] = { - 'encrypt': True, - 'luks_options': { - 'passphrase': passphrase, - 'confirm_passphrase': passphrase, - } - } + value["lvm_options"] = { + "encrypt": True, + "luks_options": { + "passphrase": passphrase, + "confirm_passphrase": passphrase, + }, + } self.ui.body.form.guided_choice.value = value self.ui.body.done(None) - self.app.answers['filesystem-confirmed'] = True + self.app.answers["filesystem-confirmed"] = True while not isinstance(self.ui.body, FilesystemView): await asyncio.sleep(0.1) self.finish() - elif self.answers['manual']: + elif self.answers["manual"]: await self._guided_choice( GuidedChoiceV2( target=GuidedStorageTargetManual(), - capability=GuidedCapability.MANUAL)) - await self._run_actions(self.answers['manual']) - self.answers['manual'] = [] + capability=GuidedCapability.MANUAL, + ) + ) + await self._run_actions(self.answers["manual"]) + self.answers["manual"] = [] def _action_get(self, id): dev_spec = id[0].split() @@ -168,7 +149,7 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator): if dev_spec[1] == "index": dev = self.model.all_disks()[int(dev_spec[2])] elif dev_spec[1] == "serial": - dev = self.model._one(type='disk', serial=dev_spec[2]) + dev = self.model._one(type="disk", serial=dev_spec[2]) elif dev_spec[0] == "raid": if dev_spec[1] == "name": for r in self.model.all_raids(): @@ -192,16 +173,13 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator): raise Exception("could not resolve {}".format(id)) def _action_clean_devices_raid(self, devices): - r = { - self._action_get(d): v - for d, v in zip(devices[::2], devices[1::2]) - } + r = {self._action_get(d): v for d, v in zip(devices[::2], devices[1::2])} for d in r: assert d.ok_for_raid return r def _action_clean_devices_vg(self, devices): - r = {self._action_get(d): 'active' for d in devices} + r = {self._action_get(d): "active" for d in devices} for d in r: assert d.ok_for_lvm_vg return r @@ -210,12 +188,13 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator): return raidlevels_by_value[level] async def _answers_action(self, action): - from subiquitycore.ui.stretchy import StretchyOverlay from subiquity.ui.views.filesystem.delete import ConfirmDeleteStretchy + from subiquitycore.ui.stretchy import StretchyOverlay + log.debug("_answers_action %r", action) - if 'obj' in action: - obj = self._action_get(action['obj']) - action_name = action['action'] + if "obj" in action: + obj = self._action_get(action["obj"]) + action_name = action["action"] if action_name == "MAKE_BOOT": action_name = "TOGGLE_BOOT" if action_name == "CREATE_LV": @@ -223,8 +202,8 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator): if action_name == "PARTITION": obj = gaps.largest_gap(obj) meth = getattr( - self.ui.body.avail_list, - "_{}_{}".format(obj.type, action_name)) + self.ui.body.avail_list, "_{}_{}".format(obj.type, action_name) + ) meth(obj) yield body = self.ui.body._w @@ -235,34 +214,35 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator): body.stretchy.done() else: async for _ in self._enter_form_data( - body.stretchy.form, - action['data'], - action.get("submit", True)): + body.stretchy.form, action["data"], action.get("submit", True) + ): pass - elif action['action'] == 'create-raid': + elif action["action"] == "create-raid": self.ui.body.create_raid() yield body = self.ui.body._w async for _ in self._enter_form_data( - body.stretchy.form, - action['data'], - action.get("submit", True), - clean_suffix='raid'): + body.stretchy.form, + action["data"], + action.get("submit", True), + clean_suffix="raid", + ): pass - elif action['action'] == 'create-vg': + elif action["action"] == "create-vg": self.ui.body.create_vg() yield body = self.ui.body._w async for _ in self._enter_form_data( - body.stretchy.form, - action['data'], - action.get("submit", True), - clean_suffix='vg'): + body.stretchy.form, + action["data"], + action.get("submit", True), + clean_suffix="vg", + ): pass - elif action['action'] == 'done': + elif action["action"] == "done": if not self.ui.body.done_btn.enabled: raise Exception("answers did not provide complete fs config") - self.app.answers['filesystem-confirmed'] = True + self.app.answers["filesystem-confirmed"] = True self.finish() else: raise Exception("could not process action {}".format(action)) @@ -278,8 +258,8 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator): if self.model.bootloader == Bootloader.PREP: self.supports_resilient_boot = False else: - release = lsb_release(dry_run=self.app.opts.dry_run)['release'] - self.supports_resilient_boot = release >= '20.04' + release = lsb_release(dry_run=self.app.opts.dry_run)["release"] + self.supports_resilient_boot = release >= "20.04" self.ui.set_body(FilesystemView(self.model, self)) def guided_choice(self, choice): diff --git a/subiquity/client/controllers/identity.py b/subiquity/client/controllers/identity.py index 8a6c720d..5cca279d 100644 --- a/subiquity/client/controllers/identity.py +++ b/subiquity/client/controllers/identity.py @@ -19,34 +19,34 @@ from subiquity.client.controller import SubiquityTuiController from subiquity.common.types import IdentityData from subiquity.ui.views import IdentityView -log = logging.getLogger('subiquity.client.controllers.identity') +log = logging.getLogger("subiquity.client.controllers.identity") class IdentityController(SubiquityTuiController): - - endpoint_name = 'identity' + endpoint_name = "identity" async def make_ui(self): data = await self.endpoint.GET() return IdentityView(self, data) def run_answers(self): - if all(elem in self.answers for elem in - ['realname', 'username', 'password', 'hostname']): + if all( + elem in self.answers + for elem in ["realname", "username", "password", "hostname"] + ): identity = IdentityData( - realname=self.answers['realname'], - username=self.answers['username'], - hostname=self.answers['hostname'], - crypted_password=self.answers['password']) + realname=self.answers["realname"], + username=self.answers["username"], + hostname=self.answers["hostname"], + crypted_password=self.answers["password"], + ) self.done(identity) def cancel(self): self.app.prev_screen() def done(self, identity_data): - log.debug( - "IdentityController.done next_screen user_spec=%s", - identity_data) + log.debug("IdentityController.done next_screen user_spec=%s", identity_data) self.app.next_screen(self.endpoint.POST(identity_data)) async def validate_username(self, username): diff --git a/subiquity/client/controllers/keyboard.py b/subiquity/client/controllers/keyboard.py index cc6d8f6f..b2ba37d0 100644 --- a/subiquity/client/controllers/keyboard.py +++ b/subiquity/client/controllers/keyboard.py @@ -20,21 +20,20 @@ from subiquity.client.controller import SubiquityTuiController from subiquity.common.types import KeyboardSetting from subiquity.ui.views import KeyboardView -log = logging.getLogger('subiquity.client.controllers.keyboard') +log = logging.getLogger("subiquity.client.controllers.keyboard") class KeyboardController(SubiquityTuiController): - - endpoint_name = 'keyboard' + endpoint_name = "keyboard" async def make_ui(self): setup = await self.endpoint.GET() return KeyboardView(self, setup) async def run_answers(self): - if 'layout' in self.answers: - layout = self.answers['layout'] - variant = self.answers.get('variant', '') + if "layout" in self.answers: + layout = self.answers["layout"] + variant = self.answers.get("variant", "") await self.apply(KeyboardSetting(layout=layout, variant=variant)) self.done() @@ -43,7 +42,8 @@ class KeyboardController(SubiquityTuiController): async def needs_toggle(self, setting): return await self.endpoint.needs_toggle.GET( - layout_code=setting.layout, variant_code=setting.variant) + layout_code=setting.layout, variant_code=setting.variant + ) async def apply(self, setting): await self.endpoint.POST(setting) diff --git a/subiquity/client/controllers/mirror.py b/subiquity/client/controllers/mirror.py index 64adac34..26082660 100644 --- a/subiquity/client/controllers/mirror.py +++ b/subiquity/client/controllers/mirror.py @@ -16,22 +16,16 @@ import asyncio import logging +from subiquity.client.controller import SubiquityTuiController +from subiquity.common.types import MirrorCheckStatus, MirrorGet, MirrorPost +from subiquity.ui.views.mirror import MirrorView from subiquitycore.tuicontroller import Skip -from subiquity.common.types import ( - MirrorCheckStatus, - MirrorGet, - MirrorPost, - ) -from subiquity.client.controller import SubiquityTuiController -from subiquity.ui.views.mirror import MirrorView - -log = logging.getLogger('subiquity.client.controllers.mirror') +log = logging.getLogger("subiquity.client.controllers.mirror") class MirrorController(SubiquityTuiController): - - endpoint_name = 'mirror' + endpoint_name = "mirror" async def make_ui(self): mirror_response: MirrorGet = await self.endpoint.GET() @@ -57,19 +51,18 @@ class MirrorController(SubiquityTuiController): async def run_answers(self): async def wait_mirror_check() -> None: - """ Wait until the mirror check has finished running. """ + """Wait until the mirror check has finished running.""" while True: last_status = self.app.ui.body.last_status if last_status not in [None, MirrorCheckStatus.RUNNING]: return - await asyncio.sleep(.1) + await asyncio.sleep(0.1) - if 'mirror' in self.answers: - self.app.ui.body.form.url.value = self.answers['mirror'] + if "mirror" in self.answers: + self.app.ui.body.form.url.value = self.answers["mirror"] await wait_mirror_check() self.app.ui.body.form._click_done(None) - elif 'country-code' in self.answers \ - or 'accept-default' in self.answers: + elif "country-code" in self.answers or "accept-default" in self.answers: await wait_mirror_check() self.app.ui.body.form._click_done(None) @@ -78,5 +71,4 @@ class MirrorController(SubiquityTuiController): def done(self, mirror): log.debug("MirrorController.done next_screen mirror=%s", mirror) - self.app.next_screen(self.endpoint.POST( - MirrorPost(elected=mirror))) + self.app.next_screen(self.endpoint.POST(MirrorPost(elected=mirror))) diff --git a/subiquity/client/controllers/network.py b/subiquity/client/controllers/network.py index 852f9f50..c6319af1 100644 --- a/subiquity/client/controllers/network.py +++ b/subiquity/client/controllers/network.py @@ -21,6 +21,10 @@ from typing import Optional from aiohttp import web +from subiquity.client.controller import SubiquityTuiController +from subiquity.common.api.server import make_server_at_path +from subiquity.common.apidef import LinkAction, NetEventAPI +from subiquity.common.types import ErrorReportKind, PackageInstallState from subiquitycore.async_helpers import run_bg_task from subiquitycore.controllers.network import NetworkAnswersMixin from subiquitycore.models.network import ( @@ -28,23 +32,14 @@ from subiquitycore.models.network import ( NetDevInfo, StaticConfig, WLANConfig, - ) +) from subiquitycore.ui.views.network import NetworkView -from subiquity.client.controller import SubiquityTuiController -from subiquity.common.api.server import make_server_at_path -from subiquity.common.apidef import LinkAction, NetEventAPI -from subiquity.common.types import ( - ErrorReportKind, - PackageInstallState, - ) - -log = logging.getLogger('subiquity.client.controllers.network') +log = logging.getLogger("subiquity.client.controllers.network") class NetworkController(SubiquityTuiController, NetworkAnswersMixin): - - endpoint_name = 'network' + endpoint_name = "network" def __init__(self, app): super().__init__(app) @@ -53,18 +48,18 @@ class NetworkController(SubiquityTuiController, NetworkAnswersMixin): @web.middleware async def middleware(self, request, handler): resp = await handler(request) - if resp.get('exception'): - exc = resp['exception'] - log.debug( - 'request to {} crashed'.format(request.raw_path), exc_info=exc) + if resp.get("exception"): + exc = resp["exception"] + log.debug("request to {} crashed".format(request.raw_path), exc_info=exc) self.app.make_apport_report( ErrorReportKind.NETWORK_CLIENT_FAIL, "request to {}".format(request.raw_path), - exc=exc, interrupt=True) + exc=exc, + interrupt=True, + ) return resp - async def update_link_POST(self, act: LinkAction, - info: NetDevInfo) -> None: + async def update_link_POST(self, act: LinkAction, info: NetDevInfo) -> None: if self.view is None: return if act == LinkAction.NEW: @@ -91,15 +86,17 @@ class NetworkController(SubiquityTuiController, NetworkAnswersMixin): self.view.show_network_error(stage) async def wlan_support_install_finished_POST( - self, state: PackageInstallState) -> None: + self, state: PackageInstallState + ) -> None: if self.view: self.view.update_for_wlan_support_install_state(state.name) async def subscribe(self): self.tdir = tempfile.mkdtemp() - self.sock_path = os.path.join(self.tdir, 'socket') + self.sock_path = os.path.join(self.tdir, "socket") self.site = await make_server_at_path( - self.sock_path, NetEventAPI, self, middlewares=[self.middleware]) + self.sock_path, NetEventAPI, self, middlewares=[self.middleware] + ) await self.endpoint.subscription.PUT(self.sock_path) async def unsubscribe(self): @@ -110,8 +107,8 @@ class NetworkController(SubiquityTuiController, NetworkAnswersMixin): async def make_ui(self): network_status = await self.endpoint.GET() self.view = NetworkView( - self, network_status.devices, - network_status.wlan_support_install_state.name) + self, network_status.devices, network_status.wlan_support_install_state.name + ) await self.subscribe() return self.view @@ -126,40 +123,37 @@ class NetworkController(SubiquityTuiController, NetworkAnswersMixin): def done(self): self.app.next_screen(self.endpoint.POST()) - def set_static_config(self, dev_name: str, ip_version: int, - static_config: StaticConfig) -> None: + def set_static_config( + self, dev_name: str, ip_version: int, static_config: StaticConfig + ) -> None: run_bg_task( - self.endpoint.set_static_config.POST( - dev_name, ip_version, static_config)) + self.endpoint.set_static_config.POST(dev_name, ip_version, static_config) + ) def enable_dhcp(self, dev_name, ip_version: int) -> None: - run_bg_task( - self.endpoint.enable_dhcp.POST(dev_name, ip_version)) + run_bg_task(self.endpoint.enable_dhcp.POST(dev_name, ip_version)) def disable_network(self, dev_name: str, ip_version: int) -> None: - run_bg_task( - self.endpoint.disable.POST(dev_name, ip_version)) + run_bg_task(self.endpoint.disable.POST(dev_name, ip_version)) def add_vlan(self, dev_name: str, vlan_id: int): - run_bg_task( - self.endpoint.vlan.PUT(dev_name, vlan_id)) + run_bg_task(self.endpoint.vlan.PUT(dev_name, vlan_id)) def set_wlan(self, dev_name: str, wlan: WLANConfig) -> None: - run_bg_task( - self.endpoint.set_wlan.POST(dev_name, wlan)) + run_bg_task(self.endpoint.set_wlan.POST(dev_name, wlan)) def start_scan(self, dev_name: str) -> None: - run_bg_task( - self.endpoint.start_scan.POST(dev_name)) + run_bg_task(self.endpoint.start_scan.POST(dev_name)) def delete_link(self, dev_name: str): run_bg_task(self.endpoint.delete.POST(dev_name)) - def add_or_update_bond(self, existing_name: Optional[str], - new_name: str, new_info: BondConfig) -> None: + def add_or_update_bond( + self, existing_name: Optional[str], new_name: str, new_info: BondConfig + ) -> None: run_bg_task( - self.endpoint.add_or_edit_bond.POST( - existing_name, new_name, new_info)) + self.endpoint.add_or_edit_bond.POST(existing_name, new_name, new_info) + ) async def get_info_for_netdev(self, dev_name: str) -> str: return await self.endpoint.info.GET(dev_name) diff --git a/subiquity/client/controllers/progress.py b/subiquity/client/controllers/progress.py index ab0129cf..3e3a7705 100644 --- a/subiquity/client/controllers/progress.py +++ b/subiquity/client/controllers/progress.py @@ -18,25 +18,16 @@ import logging import aiohttp +from subiquity.client.controller import SubiquityTuiController +from subiquity.common.types import ApplicationState, ShutdownMode +from subiquity.ui.views.installprogress import InstallRunning, ProgressView from subiquitycore.async_helpers import run_bg_task from subiquitycore.context import with_context -from subiquity.client.controller import SubiquityTuiController -from subiquity.common.types import ( - ApplicationState, - ShutdownMode, - ) -from subiquity.ui.views.installprogress import ( - InstallRunning, - ProgressView, - ) - - log = logging.getLogger("subiquity.client.controllers.progress") class ProgressController(SubiquityTuiController): - def __init__(self, app): super().__init__(app) self.progress_view = ProgressView(self) @@ -49,13 +40,13 @@ class ProgressController(SubiquityTuiController): self.progress_view.event_start( event["SUBIQUITY_CONTEXT_ID"], event.get("SUBIQUITY_CONTEXT_PARENT_ID"), - event["MESSAGE"]) + event["MESSAGE"], + ) elif event["SUBIQUITY_EVENT_TYPE"] == "finish": - self.progress_view.event_finish( - event["SUBIQUITY_CONTEXT_ID"]) + self.progress_view.event_finish(event["SUBIQUITY_CONTEXT_ID"]) def log_line(self, event): - log_line = event['MESSAGE'] + log_line = event["MESSAGE"] self.progress_view.add_log_line(log_line) def cancel(self): @@ -79,8 +70,7 @@ class ProgressController(SubiquityTuiController): install_running = None while True: try: - app_status = await self.app.client.meta.status.GET( - cur=self.app_state) + app_status = await self.app.client.meta.status.GET(cur=self.app_state) except aiohttp.ClientError: await asyncio.sleep(1) continue @@ -103,7 +93,8 @@ class ProgressController(SubiquityTuiController): if self.app_state == ApplicationState.RUNNING: if app_status.confirming_tty != self.app.our_tty: install_running = InstallRunning( - self.app, app_status.confirming_tty) + self.app, app_status.confirming_tty + ) self.app.add_global_overlay(install_running) else: if install_running is not None: @@ -111,7 +102,7 @@ class ProgressController(SubiquityTuiController): install_running = None if self.app_state == ApplicationState.DONE: - if self.answers.get('reboot', False): + if self.answers.get("reboot", False): self.click_reboot() def make_ui(self): diff --git a/subiquity/client/controllers/proxy.py b/subiquity/client/controllers/proxy.py index 0223b04a..18bd8c10 100644 --- a/subiquity/client/controllers/proxy.py +++ b/subiquity/client/controllers/proxy.py @@ -18,20 +18,19 @@ import logging from subiquity.client.controller import SubiquityTuiController from subiquity.ui.views.proxy import ProxyView -log = logging.getLogger('subiquity.client.controllers.proxy') +log = logging.getLogger("subiquity.client.controllers.proxy") class ProxyController(SubiquityTuiController): - - endpoint_name = 'proxy' + endpoint_name = "proxy" async def make_ui(self): proxy = await self.endpoint.GET() return ProxyView(self, proxy) def run_answers(self): - if 'proxy' in self.answers: - self.done(self.answers['proxy']) + if "proxy" in self.answers: + self.done(self.answers["proxy"]) def cancel(self): self.app.prev_screen() diff --git a/subiquity/client/controllers/refresh.py b/subiquity/client/controllers/refresh.py index 05214a9e..6c6a1bfa 100644 --- a/subiquity/client/controllers/refresh.py +++ b/subiquity/client/controllers/refresh.py @@ -18,25 +18,16 @@ import logging import aiohttp -from subiquitycore.tuicontroller import ( - Skip, - ) - -from subiquity.common.types import ( - RefreshCheckState, - ) -from subiquity.client.controller import ( - SubiquityTuiController, - ) +from subiquity.client.controller import SubiquityTuiController +from subiquity.common.types import RefreshCheckState from subiquity.ui.views.refresh import RefreshView +from subiquitycore.tuicontroller import Skip - -log = logging.getLogger('subiquity.client.controllers.refresh') +log = logging.getLogger("subiquity.client.controllers.refresh") class RefreshController(SubiquityTuiController): - - endpoint_name = 'refresh' + endpoint_name = "refresh" def __init__(self, app): super().__init__(app) @@ -61,8 +52,10 @@ class RefreshController(SubiquityTuiController): self.offered_first_time = True elif index == 2: if not self.offered_first_time: - if self.status.availability in [RefreshCheckState.UNKNOWN, - RefreshCheckState.AVAILABLE]: + if self.status.availability in [ + RefreshCheckState.UNKNOWN, + RefreshCheckState.AVAILABLE, + ]: show = True else: raise AssertionError("unexpected index {}".format(index)) diff --git a/subiquity/client/controllers/serial.py b/subiquity/client/controllers/serial.py index 039d23d5..d110b66d 100644 --- a/subiquity/client/controllers/serial.py +++ b/subiquity/client/controllers/serial.py @@ -19,7 +19,7 @@ from subiquity.client.controller import SubiquityTuiController from subiquity.ui.views.serial import SerialView from subiquitycore.tuicontroller import Skip -log = logging.getLogger('subiquity.client.controllers.serial') +log = logging.getLogger("subiquity.client.controllers.serial") class SerialController(SubiquityTuiController): @@ -30,11 +30,11 @@ class SerialController(SubiquityTuiController): return SerialView(self, ssh_info) def run_answers(self): - if 'rich' in self.answers: - if self.answers['rich']: - self.ui.body.rich_btn.base_widget._emit('click') + if "rich" in self.answers: + if self.answers["rich"]: + self.ui.body.rich_btn.base_widget._emit("click") else: - self.ui.body.basic_btn.base_widget._emit('click') + self.ui.body.basic_btn.base_widget._emit("click") def done(self, rich): log.debug("SerialController.done rich %s next_screen", rich) diff --git a/subiquity/client/controllers/snaplist.py b/subiquity/client/controllers/snaplist.py index e4666871..257e06d9 100644 --- a/subiquity/client/controllers/snaplist.py +++ b/subiquity/client/controllers/snaplist.py @@ -16,48 +16,37 @@ import logging from typing import List -from subiquitycore.tuicontroller import ( - Skip, - ) - -from subiquity.client.controller import ( - SubiquityTuiController, - ) -from subiquity.common.types import ( - SnapCheckState, - SnapSelection, - ) +from subiquity.client.controller import SubiquityTuiController +from subiquity.common.types import SnapCheckState, SnapSelection from subiquity.ui.views.snaplist import SnapListView +from subiquitycore.tuicontroller import Skip -log = logging.getLogger('subiquity.client.controllers.snaplist') +log = logging.getLogger("subiquity.client.controllers.snaplist") class SnapListController(SubiquityTuiController): - - endpoint_name = 'snaplist' + endpoint_name = "snaplist" async def make_ui(self): data = await self.endpoint.GET() if data.status == SnapCheckState.FAILED: # If loading snaps failed or network is disabled, skip the screen. - log.debug('snaplist GET failed, mark done') + log.debug("snaplist GET failed, mark done") await self.endpoint.POST([]) raise Skip return SnapListView(self, data) def run_answers(self): - if 'snaps' in self.answers: + if "snaps" in self.answers: selections = [] - for snap_name, selection in self.answers['snaps'].items(): - if 'is_classic' in selection: - selection['classic'] = selection.pop('is_classic') + for snap_name, selection in self.answers["snaps"].items(): + if "is_classic" in selection: + selection["classic"] = selection.pop("is_classic") selections.append(SnapSelection(name=snap_name, **selection)) self.done(selections) def done(self, selections: List[SnapSelection]): - log.debug( - "SnapListController.done next_screen snaps_to_install=%s", - selections) + log.debug("SnapListController.done next_screen snaps_to_install=%s", selections) self.app.next_screen(self.endpoint.POST(selections)) def cancel(self, sender=None): diff --git a/subiquity/client/controllers/source.py b/subiquity/client/controllers/source.py index 8d3d17de..d6c88e39 100644 --- a/subiquity/client/controllers/source.py +++ b/subiquity/client/controllers/source.py @@ -18,26 +18,24 @@ import logging from subiquity.client.controller import SubiquityTuiController from subiquity.ui.views.source import SourceView -log = logging.getLogger('subiquity.client.controllers.source') +log = logging.getLogger("subiquity.client.controllers.source") class SourceController(SubiquityTuiController): - - endpoint_name = 'source' + endpoint_name = "source" async def make_ui(self): sources = await self.endpoint.GET() - return SourceView(self, - sources.sources, - sources.current_id, - sources.search_drivers) + return SourceView( + self, sources.sources, sources.current_id, sources.search_drivers + ) def run_answers(self): form = self.app.ui.body.form if "search_drivers" in self.answers: form.search_drivers.value = self.answers["search_drivers"] - if 'source' in self.answers: - wanted_id = self.answers['source'] + if "source" in self.answers: + wanted_id = self.answers["source"] for bf in form._fields: if bf is form.search_drivers: continue @@ -48,6 +46,9 @@ class SourceController(SubiquityTuiController): self.app.prev_screen() def done(self, source_id, search_drivers: bool): - log.debug("SourceController.done source_id=%s, search_drivers=%s", - source_id, search_drivers) + log.debug( + "SourceController.done source_id=%s, search_drivers=%s", + source_id, + search_drivers, + ) self.app.next_screen(self.endpoint.POST(source_id, search_drivers)) diff --git a/subiquity/client/controllers/ssh.py b/subiquity/client/controllers/ssh.py index 47c05c60..a6365912 100644 --- a/subiquity/client/controllers/ssh.py +++ b/subiquity/client/controllers/ssh.py @@ -15,18 +15,13 @@ import logging +from subiquity.client.controller import SubiquityTuiController +from subiquity.common.types import SSHData, SSHFetchIdResponse, SSHFetchIdStatus +from subiquity.ui.views.ssh import SSHView from subiquitycore.async_helpers import schedule_task from subiquitycore.context import with_context -from subiquity.client.controller import SubiquityTuiController -from subiquity.common.types import ( - SSHData, - SSHFetchIdResponse, - SSHFetchIdStatus, - ) -from subiquity.ui.views.ssh import SSHView - -log = logging.getLogger('subiquity.client.controllers.ssh') +log = logging.getLogger("subiquity.client.controllers.ssh") class FetchSSHKeysFailure(Exception): @@ -36,35 +31,31 @@ class FetchSSHKeysFailure(Exception): class SSHController(SubiquityTuiController): - - endpoint_name = 'ssh' + endpoint_name = "ssh" def __init__(self, app): super().__init__(app) self._fetch_task = None if not self.answers: - identity_answers = self.app.answers.get('Identity', {}) - if 'ssh-import-id' in identity_answers: - self.answers['ssh-import-id'] = identity_answers[ - 'ssh-import-id'] + identity_answers = self.app.answers.get("Identity", {}) + if "ssh-import-id" in identity_answers: + self.answers["ssh-import-id"] = identity_answers["ssh-import-id"] async def make_ui(self): ssh_data = await self.endpoint.GET() return SSHView(self, ssh_data) def run_answers(self): - if 'ssh-import-id' in self.answers: - import_id = self.answers['ssh-import-id'] - ssh = SSHData( - install_server=True, - authorized_keys=[], - allow_pw=True) + if "ssh-import-id" in self.answers: + import_id = self.answers["ssh-import-id"] + ssh = SSHData(install_server=True, authorized_keys=[], allow_pw=True) self.fetch_ssh_keys(ssh_import_id=import_id, ssh_data=ssh) else: ssh = SSHData( install_server=self.answers.get("install_server", False), authorized_keys=self.answers.get("authorized_keys", []), - allow_pw=self.answers.get("pwauth", True)) + allow_pw=self.answers.get("pwauth", True), + ) self.done(ssh) def cancel(self): @@ -75,44 +66,50 @@ class SSHController(SubiquityTuiController): return self._fetch_task.cancel() - @with_context( - name="ssh_import_id", description="{ssh_import_id}") + @with_context(name="ssh_import_id", description="{ssh_import_id}") async def _fetch_ssh_keys(self, *, context, ssh_import_id, ssh_data): with self.context.child("ssh_import_id", ssh_import_id): - response: SSHFetchIdResponse = await \ - self.endpoint.fetch_id.GET(ssh_import_id) + response: SSHFetchIdResponse = await self.endpoint.fetch_id.GET( + ssh_import_id + ) if response.status == SSHFetchIdStatus.IMPORT_ERROR: if isinstance(self.ui.body, SSHView): self.ui.body.fetching_ssh_keys_failed( - _("Importing keys failed:"), response.error) + _("Importing keys failed:"), response.error + ) return elif response.status == SSHFetchIdStatus.FINGERPRINT_ERROR: if isinstance(self.ui.body, SSHView): self.ui.body.fetching_ssh_keys_failed( - _("ssh-keygen failed to show fingerprint of" - " downloaded keys:"), - response.error) + _( + "ssh-keygen failed to show fingerprint of" + " downloaded keys:" + ), + response.error, + ) return identities = response.identities - if 'ssh-import-id' in self.app.answers.get("Identity", {}): - ssh_data.authorized_keys = \ - [id_.to_authorized_key() for id_ in identities] + if "ssh-import-id" in self.app.answers.get("Identity", {}): + ssh_data.authorized_keys = [ + id_.to_authorized_key() for id_ in identities + ] self.done(ssh_data) else: if isinstance(self.ui.body, SSHView): - self.ui.body.confirm_ssh_keys( - ssh_data, ssh_import_id, identities) + self.ui.body.confirm_ssh_keys(ssh_data, ssh_import_id, identities) else: - log.debug("ui.body of unexpected instance: %s", - type(self.ui.body).__name__) + log.debug( + "ui.body of unexpected instance: %s", + type(self.ui.body).__name__, + ) def fetch_ssh_keys(self, ssh_import_id, ssh_data): self._fetch_task = schedule_task( - self._fetch_ssh_keys( - ssh_import_id=ssh_import_id, ssh_data=ssh_data)) + self._fetch_ssh_keys(ssh_import_id=ssh_import_id, ssh_data=ssh_data) + ) def done(self, result): log.debug("SSHController.done next_screen result=%s", result) diff --git a/subiquity/client/controllers/ubuntu_pro.py b/subiquity/client/controllers/ubuntu_pro.py index d2adfbb2..d0d71cac 100644 --- a/subiquity/client/controllers/ubuntu_pro.py +++ b/subiquity/client/controllers/ubuntu_pro.py @@ -22,23 +22,21 @@ from typing import Callable, Optional import aiohttp from urwid import Widget -from subiquitycore.async_helpers import schedule_task - from subiquity.client.controller import SubiquityTuiController +from subiquity.common.types import UbuntuProCheckTokenStatus as TokenStatus from subiquity.common.types import ( UbuntuProInfo, UbuntuProResponse, - UbuntuProCheckTokenStatus as TokenStatus, UbuntuProSubscription, UPCSWaitStatus, - ) +) from subiquity.ui.views.ubuntu_pro import ( - UbuntuProView, - UpgradeYesNoForm, - UpgradeModeForm, TokenAddedWidget, - ) - + UbuntuProView, + UpgradeModeForm, + UpgradeYesNoForm, +) +from subiquitycore.async_helpers import schedule_task from subiquitycore.lsb_release import lsb_release from subiquitycore.tuicontroller import Skip @@ -46,12 +44,12 @@ log = logging.getLogger("subiquity.client.controllers.ubuntu_pro") class UbuntuProController(SubiquityTuiController): - """ Client-side controller for Ubuntu Pro configuration. """ + """Client-side controller for Ubuntu Pro configuration.""" endpoint_name = "ubuntu_pro" def __init__(self, app) -> None: - """ Initializer for the client-side UA controller. """ + """Initializer for the client-side UA controller.""" self._check_task: Optional[asyncio.Future] = None self._magic_attach_task: Optional[asyncio.Future] = None @@ -60,7 +58,7 @@ class UbuntuProController(SubiquityTuiController): super().__init__(app) async def make_ui(self) -> UbuntuProView: - """ Generate the UI, based on the data provided by the model. """ + """Generate the UI, based on the data provided by the model.""" dry_run: bool = self.app.opts.dry_run @@ -70,13 +68,13 @@ class UbuntuProController(SubiquityTuiController): raise Skip("Not running LTS version") ubuntu_pro_info: UbuntuProResponse = await self.endpoint.GET() - return UbuntuProView(self, - token=ubuntu_pro_info.token, - has_network=ubuntu_pro_info.has_network) + return UbuntuProView( + self, token=ubuntu_pro_info.token, has_network=ubuntu_pro_info.has_network + ) async def run_answers(self) -> None: - """ Interact with the UI to go through the pre-attach process if - requested. """ + """Interact with the UI to go through the pre-attach process if + requested.""" if "token" not in self.answers: return @@ -100,8 +98,7 @@ class UbuntuProController(SubiquityTuiController): click(find_button_matching(view, UpgradeYesNoForm.ok_label)) def run_token_screen(token: str) -> None: - keypress(view.upgrade_mode_form.with_contract_token.widget, - key="enter") + keypress(view.upgrade_mode_form.with_contract_token.widget, key="enter") data = {"with_contract_token_subform": {"token": token}} # TODO: add this point, it would be good to trigger the validation # code for the token field. @@ -117,13 +114,12 @@ class UbuntuProController(SubiquityTuiController): # Wait until the "Token added successfully" overlay is shown. while not find_with_pred(view, is_token_added_overlay): - await asyncio.sleep(.2) + await asyncio.sleep(0.2) click(find_button_matching(view, TokenAddedWidget.done_label)) def run_subscription_screen() -> None: - click(find_button_matching(view._w, - UbuntuProView.subscription_done_label)) + click(find_button_matching(view._w, UbuntuProView.subscription_done_label)) if not self.answers["token"]: run_yes_no_screen(skip=True) @@ -134,11 +130,14 @@ class UbuntuProController(SubiquityTuiController): await run_token_added_overlay() run_subscription_screen() - def check_token(self, token: str, - on_success: Callable[[UbuntuProSubscription], None], - on_failure: Callable[[TokenStatus], None], - ) -> None: - """ Asynchronously check the token passed as an argument. """ + def check_token( + self, + token: str, + on_success: Callable[[UbuntuProSubscription], None], + on_failure: Callable[[TokenStatus], None], + ) -> None: + """Asynchronously check the token passed as an argument.""" + async def inner() -> None: answer = await self.endpoint.check_token.GET(token) if answer.status == TokenStatus.VALID_TOKEN: @@ -150,29 +149,32 @@ class UbuntuProController(SubiquityTuiController): self._check_task = schedule_task(inner()) def cancel_check_token(self) -> None: - """ Cancel the asynchronous token check (if started). """ + """Cancel the asynchronous token check (if started).""" if self._check_task is not None: self._check_task.cancel() - def contract_selection_initiate( - self, on_initiated: Callable[[str], None]) -> None: - """ Initiate the contract selection asynchronously. Calls on_initiated - when the contract-selection has initiated. """ + def contract_selection_initiate(self, on_initiated: Callable[[str], None]) -> None: + """Initiate the contract selection asynchronously. Calls on_initiated + when the contract-selection has initiated.""" + async def inner() -> None: answer = await self.endpoint.contract_selection.initiate.POST() self.cs_initiated = True on_initiated(answer.user_code) + self._magic_attach_task = schedule_task(inner()) def contract_selection_wait( - self, - on_contract_selected: Callable[[str], None], - on_timeout: Callable[[], None]) -> None: - """ Asynchronously wait for the contract selection to finish. + self, + on_contract_selected: Callable[[str], None], + on_timeout: Callable[[], None], + ) -> None: + """Asynchronously wait for the contract selection to finish. Calls on_contract_selected with the contract-token once the contract gets selected Otherwise, calls on_timeout if no contract got selected within the - allowed time frame. """ + allowed time frame.""" + async def inner() -> None: try: answer = await self.endpoint.contract_selection.wait.GET() @@ -187,9 +189,9 @@ class UbuntuProController(SubiquityTuiController): self._monitor_task = schedule_task(inner()) - def contract_selection_cancel( - self, on_cancelled: Callable[[], None]) -> None: - """ Cancel the asynchronous contract selection (if started). """ + def contract_selection_cancel(self, on_cancelled: Callable[[], None]) -> None: + """Cancel the asynchronous contract selection (if started).""" + async def inner() -> None: await self.endpoint.contract_selection.cancel.POST() self.cs_initiated = False @@ -203,12 +205,10 @@ class UbuntuProController(SubiquityTuiController): self.app.prev_screen() def done(self, token: str) -> None: - """ Submit the token and move on to the next screen. """ - self.app.next_screen( - self.endpoint.POST(UbuntuProInfo(token=token)) - ) + """Submit the token and move on to the next screen.""" + self.app.next_screen(self.endpoint.POST(UbuntuProInfo(token=token))) def next_screen(self) -> None: - """ Move on to the next screen. Assume the token should not be - submitted (or has already been submitted). """ + """Move on to the next screen. Assume the token should not be + submitted (or has already been submitted).""" self.app.next_screen() diff --git a/subiquity/client/controllers/welcome.py b/subiquity/client/controllers/welcome.py index 208663e3..1313bcac 100644 --- a/subiquity/client/controllers/welcome.py +++ b/subiquity/client/controllers/welcome.py @@ -15,17 +15,16 @@ import logging -from subiquitycore import i18n -from subiquitycore.tuicontroller import Skip from subiquity.client.controller import SubiquityTuiController from subiquity.ui.views.welcome import WelcomeView +from subiquitycore import i18n +from subiquitycore.tuicontroller import Skip -log = logging.getLogger('subiquity.client.controllers.welcome') +log = logging.getLogger("subiquity.client.controllers.welcome") class WelcomeController(SubiquityTuiController): - - endpoint_name = 'locale' + endpoint_name = "locale" async def make_ui(self): if not self.app.rich_mode: @@ -36,11 +35,11 @@ class WelcomeController(SubiquityTuiController): return WelcomeView(self, language, self.serial) def run_answers(self): - if 'lang' in self.answers: - self.done((self.answers['lang'], "")) + if "lang" in self.answers: + self.done((self.answers["lang"], "")) def done(self, lang): - """ Completes this controller. lang must be a tuple of strings + """Completes this controller. lang must be a tuple of strings containing the language code and its native representation respectively. """ diff --git a/subiquity/client/controllers/zdev.py b/subiquity/client/controllers/zdev.py index c2b4fe4e..821b30d0 100644 --- a/subiquity/client/controllers/zdev.py +++ b/subiquity/client/controllers/zdev.py @@ -18,20 +18,18 @@ import logging from subiquity.client.controller import SubiquityTuiController from subiquity.ui.views import ZdevView - log = logging.getLogger("subiquity.client.controllers.zdev") class ZdevController(SubiquityTuiController): - - endpoint_name = 'zdev' + endpoint_name = "zdev" async def make_ui(self): infos = await self.endpoint.GET() return ZdevView(self, infos) def run_answers(self): - if 'accept-default' in self.answers: + if "accept-default" in self.answers: self.done() def cancel(self): diff --git a/subiquity/client/keycodes.py b/subiquity/client/keycodes.py index ee7a79fe..585f96e6 100644 --- a/subiquity/client/keycodes.py +++ b/subiquity/client/keycodes.py @@ -19,7 +19,7 @@ import os import struct import sys -log = logging.getLogger('subiquity.client.keycodes') +log = logging.getLogger("subiquity.client.keycodes") # /usr/include/linux/kd.h K_RAW = 0x00 @@ -46,7 +46,7 @@ class KeyCodesFilter: """ def __init__(self): - self._fd = os.open("/proc/self/fd/"+str(sys.stdin.fileno()), os.O_RDWR) + self._fd = os.open("/proc/self/fd/" + str(sys.stdin.fileno()), os.O_RDWR) self.filtering = False def enter_keycodes_mode(self): @@ -56,7 +56,7 @@ class KeyCodesFilter: # well). o = bytearray(4) fcntl.ioctl(self._fd, KDGKBMODE, o) - self._old_mode = struct.unpack('i', o)[0] + self._old_mode = struct.unpack("i", o)[0] # Set the keyboard mode to K_MEDIUMRAW, which causes the keyboard # driver in the kernel to pass us keycodes. fcntl.ioctl(self._fd, KDSKBMODE, K_MEDIUMRAW) @@ -76,17 +76,16 @@ class KeyCodesFilter: while i < len(codes): # This is straight from showkeys.c. if codes[i] & 0x80: - p = 'release ' + p = "release " else: - p = 'press ' - if i + 2 < n and (codes[i] & 0x7f) == 0: + p = "press " + if i + 2 < n and (codes[i] & 0x7F) == 0: if (codes[i + 1] & 0x80) != 0: if (codes[i + 2] & 0x80) != 0: - kc = (((codes[i + 1] & 0x7f) << 7) | - (codes[i + 2] & 0x7f)) + kc = ((codes[i + 1] & 0x7F) << 7) | (codes[i + 2] & 0x7F) i += 3 else: - kc = codes[i] & 0x7f + kc = codes[i] & 0x7F i += 1 r.append(p + str(kc)) return r diff --git a/subiquity/cmd/common.py b/subiquity/cmd/common.py index ea052a17..062647d7 100644 --- a/subiquity/cmd/common.py +++ b/subiquity/cmd/common.py @@ -28,11 +28,13 @@ def setup_environment(): locale.setlocale(locale.LC_CTYPE, "C.UTF-8") # Prefer utils from $SUBIQUITY_ROOT, over system-wide - root = os.environ.get('SUBIQUITY_ROOT') + root = os.environ.get("SUBIQUITY_ROOT") if root: - os.environ['PATH'] = os.pathsep.join([ - os.path.join(root, 'bin'), - os.path.join(root, 'usr', 'bin'), - os.environ['PATH'], - ]) - os.environ["APPORT_DATA_DIR"] = os.path.join(root, 'share/apport') + os.environ["PATH"] = os.pathsep.join( + [ + os.path.join(root, "bin"), + os.path.join(root, "usr", "bin"), + os.environ["PATH"], + ] + ) + os.environ["APPORT_DATA_DIR"] = os.path.join(root, "share/apport") diff --git a/subiquity/cmd/schema.py b/subiquity/cmd/schema.py index e12b4819..29d8f98a 100644 --- a/subiquity/cmd/schema.py +++ b/subiquity/cmd/schema.py @@ -29,16 +29,16 @@ from subiquity.server.server import SubiquityServer def make_schema(app): schema = copy.deepcopy(app.base_schema) for controller in app.controllers.instances: - ckey = getattr(controller, 'autoinstall_key', None) + ckey = getattr(controller, "autoinstall_key", None) if ckey is None: continue cschema = getattr(controller, "autoinstall_schema", None) if cschema is None: continue - schema['properties'][ckey] = cschema + schema["properties"][ckey] = cschema - ckey_alias = getattr(controller, 'autoinstall_key_alias', None) + ckey_alias = getattr(controller, "autoinstall_key_alias", None) if ckey_alias is None: continue @@ -46,15 +46,15 @@ def make_schema(app): cschema["deprecated"] = True cschema["description"] = f"Compatibility only - use {ckey} instead" - schema['properties'][ckey_alias] = cschema + schema["properties"][ckey_alias] = cschema return schema def make_app(): parser = make_server_args_parser() - opts, unknown = parser.parse_known_args(['--dry-run']) - app = SubiquityServer(opts, '') + opts, unknown = parser.parse_known_args(["--dry-run"]) + app = SubiquityServer(opts, "") # This is needed because the ubuntu-pro server controller accesses dr_cfg # in the initializer. app.dr_cfg = DRConfig() @@ -72,5 +72,5 @@ def main(): asyncio.run(run_with_loop()) -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/subiquity/cmd/server.py b/subiquity/cmd/server.py index 08ff3786..9b89aea8 100644 --- a/subiquity/cmd/server.py +++ b/subiquity/cmd/server.py @@ -25,10 +25,7 @@ import attr from subiquitycore.log import setup_logger -from .common import ( - LOGDIR, - setup_environment, - ) +from .common import LOGDIR, setup_environment @attr.s(auto_attribs=True) @@ -41,8 +38,8 @@ class CommandLineParams: def from_cmdline(cls, cmdline): r = cls(cmdline) for tok in shlex.split(cmdline): - if '=' in tok: - k, v = tok.split('=', 1) + if "=" in tok: + k, v = tok.split("=", 1) r._values[k] = v else: r._tokens.add(tok) @@ -57,74 +54,103 @@ class CommandLineParams: def make_server_args_parser(): parser = argparse.ArgumentParser( - description='SUbiquity - Ubiquity for Servers', - prog='subiquity') - parser.add_argument('--dry-run', action='store_true', - dest='dry_run', - help='menu-only, do not call installer function') - parser.add_argument('--dry-run-config', type=argparse.FileType()) - parser.add_argument('--socket') - parser.add_argument('--machine-config', metavar='CONFIG', - dest='machine_config', - type=argparse.FileType(), - help="Don't Probe. Use probe data file") - parser.add_argument('--bootloader', - choices=['none', 'bios', 'prep', 'uefi'], - help='Override style of bootloader to use') + description="SUbiquity - Ubiquity for Servers", prog="subiquity" + ) parser.add_argument( - '--autoinstall', action='store', - help=('Path to autoinstall file. Empty value disables autoinstall. ' - 'By default tries to load /autoinstall.yaml, ' - 'or autoinstall data from cloud-init.')) - with open('/proc/cmdline') as fp: + "--dry-run", + action="store_true", + dest="dry_run", + help="menu-only, do not call installer function", + ) + parser.add_argument("--dry-run-config", type=argparse.FileType()) + parser.add_argument("--socket") + parser.add_argument( + "--machine-config", + metavar="CONFIG", + dest="machine_config", + type=argparse.FileType(), + help="Don't Probe. Use probe data file", + ) + parser.add_argument( + "--bootloader", + choices=["none", "bios", "prep", "uefi"], + help="Override style of bootloader to use", + ) + parser.add_argument( + "--autoinstall", + action="store", + help=( + "Path to autoinstall file. Empty value disables autoinstall. " + "By default tries to load /autoinstall.yaml, " + "or autoinstall data from cloud-init." + ), + ) + with open("/proc/cmdline") as fp: cmdline = fp.read() parser.add_argument( - '--kernel-cmdline', action='store', default=cmdline, - type=CommandLineParams.from_cmdline) + "--kernel-cmdline", + action="store", + default=cmdline, + type=CommandLineParams.from_cmdline, + ) parser.add_argument( - '--snaps-from-examples', action='store_const', const=True, + "--snaps-from-examples", + action="store_const", + const=True, dest="snaps_from_examples", - help=("Load snap details from examples/snaps instead of store. " - "Default in dry-run mode. " - "See examples/snaps/README.md for more.")) + help=( + "Load snap details from examples/snaps instead of store. " + "Default in dry-run mode. " + "See examples/snaps/README.md for more." + ), + ) parser.add_argument( - '--no-snaps-from-examples', action='store_const', const=False, + "--no-snaps-from-examples", + action="store_const", + const=False, dest="snaps_from_examples", - help=("Load snap details from store instead of examples. " - "Default in when not in dry-run mode. " - "See examples/snaps/README.md for more.")) + help=( + "Load snap details from store instead of examples. " + "Default in when not in dry-run mode. " + "See examples/snaps/README.md for more." + ), + ) parser.add_argument( - '--snap-section', action='store', default='server', - help=("Show snaps from this section of the store in the snap " - "list screen.")) + "--snap-section", + action="store", + default="server", + help=("Show snaps from this section of the store in the snap " "list screen."), + ) + parser.add_argument("--source-catalog", dest="source_catalog", action="store") parser.add_argument( - '--source-catalog', dest='source_catalog', action='store') + "--output-base", + action="store", + dest="output_base", + default=".subiquity", + help="in dryrun, control basedir of files", + ) + parser.add_argument("--storage-version", action="store", type=int) + parser.add_argument("--use-os-prober", action="store_true", default=False) parser.add_argument( - '--output-base', action='store', dest='output_base', - default='.subiquity', - help='in dryrun, control basedir of files') - parser.add_argument( - '--storage-version', action='store', type=int) - parser.add_argument( - '--use-os-prober', action='store_true', default=False) - parser.add_argument( - '--postinst-hooks-dir', default='/etc/subiquity/postinst.d', - type=pathlib.Path) + "--postinst-hooks-dir", default="/etc/subiquity/postinst.d", type=pathlib.Path + ) return parser def main(): - print('starting server') + print("starting server") setup_environment() # setup_environment sets $APPORT_DATA_DIR which must be set before # apport is imported, which is done by this import: - from subiquity.server.server import SubiquityServer from subiquity.server.dryrun import DRConfig + from subiquity.server.server import SubiquityServer + parser = make_server_args_parser() opts = parser.parse_args(sys.argv[1:]) if opts.storage_version is None: - opts.storage_version = int(opts.kernel_cmdline.get( - 'subiquity-storage-version', 1)) + opts.storage_version = int( + opts.kernel_cmdline.get("subiquity-storage-version", 1) + ) logdir = LOGDIR if opts.dry_run: if opts.dry_run_config: @@ -139,25 +165,26 @@ def main(): dr_cfg = None if opts.socket is None: if opts.dry_run: - opts.socket = opts.output_base + '/socket' + opts.socket = opts.output_base + "/socket" else: - opts.socket = '/run/subiquity/socket' + opts.socket = "/run/subiquity/socket" os.makedirs(os.path.dirname(opts.socket), exist_ok=True) block_log_dir = os.path.join(logdir, "block") os.makedirs(block_log_dir, exist_ok=True) - handler = logging.FileHandler(os.path.join(block_log_dir, 'discover.log')) - handler.setLevel('DEBUG') + handler = logging.FileHandler(os.path.join(block_log_dir, "discover.log")) + handler.setLevel("DEBUG") handler.setFormatter( - logging.Formatter("%(asctime)s %(name)s:%(lineno)d %(message)s")) - logging.getLogger('probert').addHandler(handler) - handler.addFilter(lambda rec: rec.name != 'probert.network') - logging.getLogger('curtin').addHandler(handler) - logging.getLogger('block-discover').addHandler(handler) + logging.Formatter("%(asctime)s %(name)s:%(lineno)d %(message)s") + ) + logging.getLogger("probert").addHandler(handler) + handler.addFilter(lambda rec: rec.name != "probert.network") + logging.getLogger("curtin").addHandler(handler) + logging.getLogger("block-discover").addHandler(handler) - logfiles = setup_logger(dir=logdir, base='subiquity-server') + logfiles = setup_logger(dir=logdir, base="subiquity-server") - logger = logging.getLogger('subiquity') + logger = logging.getLogger("subiquity") version = os.environ.get("SNAP_REVISION", "unknown") snap = os.environ.get("SNAP", "unknown") logger.info(f"Starting Subiquity server revision {version} of snap {snap}") @@ -168,18 +195,16 @@ def main(): async def run_with_loop(): server = SubiquityServer(opts, block_log_dir) server.dr_cfg = dr_cfg - server.note_file_for_apport( - "InstallerServerLog", logfiles['debug']) - server.note_file_for_apport( - "InstallerServerLogInfo", logfiles['info']) + server.note_file_for_apport("InstallerServerLog", logfiles["debug"]) + server.note_file_for_apport("InstallerServerLogInfo", logfiles["info"]) server.note_file_for_apport( "UdiLog", - os.path.realpath( - "/var/log/installer/ubuntu_desktop_installer.log")) + os.path.realpath("/var/log/installer/ubuntu_desktop_installer.log"), + ) await server.run() asyncio.run(run_with_loop()) -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/subiquity/cmd/tui.py b/subiquity/cmd/tui.py index f9d20c82..405efe87 100755 --- a/subiquity/cmd/tui.py +++ b/subiquity/cmd/tui.py @@ -16,53 +16,65 @@ import argparse import asyncio +import fcntl import logging import os -import fcntl import subprocess import sys from subiquitycore.log import setup_logger -from .common import ( - LOGDIR, - setup_environment, - ) +from .common import LOGDIR, setup_environment from .server import make_server_args_parser def make_client_args_parser(): parser = argparse.ArgumentParser( - description='Subiquity - Ubiquity for Servers', - prog='subiquity') + description="Subiquity - Ubiquity for Servers", prog="subiquity" + ) try: ascii_default = os.ttyname(0) == "/dev/ttysclp0" except OSError: ascii_default = False parser.set_defaults(ascii=ascii_default) - parser.add_argument('--dry-run', action='store_true', - dest='dry_run', - help='menu-only, do not call installer function') - parser.add_argument('--socket') - parser.add_argument('--serial', action='store_true', - dest='run_on_serial', - help='Run the installer over serial console.') - parser.add_argument('--ssh', action='store_true', - dest='ssh', - help='Print ssh login details') - parser.add_argument('--ascii', action='store_true', - dest='ascii', - help='Run the installer in ascii mode.') - parser.add_argument('--unicode', action='store_false', - dest='ascii', - help='Run the installer in unicode mode.') - parser.add_argument('--screens', action='append', dest='screens', - default=[]) - parser.add_argument('--answers') - parser.add_argument('--server-pid') - parser.add_argument('--output-base', action='store', dest='output_base', - default='.subiquity', - help='in dryrun, control basedir of files') + parser.add_argument( + "--dry-run", + action="store_true", + dest="dry_run", + help="menu-only, do not call installer function", + ) + parser.add_argument("--socket") + parser.add_argument( + "--serial", + action="store_true", + dest="run_on_serial", + help="Run the installer over serial console.", + ) + parser.add_argument( + "--ssh", action="store_true", dest="ssh", help="Print ssh login details" + ) + parser.add_argument( + "--ascii", + action="store_true", + dest="ascii", + help="Run the installer in ascii mode.", + ) + parser.add_argument( + "--unicode", + action="store_false", + dest="ascii", + help="Run the installer in unicode mode.", + ) + parser.add_argument("--screens", action="append", dest="screens", default=[]) + parser.add_argument("--answers") + parser.add_argument("--server-pid") + parser.add_argument( + "--output-base", + action="store", + dest="output_base", + default=".subiquity", + help="in dryrun, control basedir of files", + ) return parser @@ -74,27 +86,28 @@ def main(): # setup_environment sets $APPORT_DATA_DIR which must be set before # apport is imported, which is done by this import: from subiquity.client.client import SubiquityClient + parser = make_client_args_parser() args = sys.argv[1:] - if '--dry-run' in args: + if "--dry-run" in args: opts, unknown = parser.parse_known_args(args) if opts.socket is None: base = opts.output_base os.makedirs(base, exist_ok=True) - if os.path.exists(f'{base}/run/subiquity/server-state'): - os.unlink(f'{base}/run/subiquity/server-state') - sock_path = base + '/socket' + if os.path.exists(f"{base}/run/subiquity/server-state"): + os.unlink(f"{base}/run/subiquity/server-state") + sock_path = base + "/socket" opts.socket = sock_path - server_args = ['--dry-run', '--socket=' + sock_path] + unknown - server_args.extend(('--output-base', base)) + server_args = ["--dry-run", "--socket=" + sock_path] + unknown + server_args.extend(("--output-base", base)) server_parser = make_server_args_parser() server_parser.parse_args(server_args) # just to check - server_stdout = open(os.path.join(base, 'server-stdout'), 'w') - server_stderr = open(os.path.join(base, 'server-stderr'), 'w') - server_cmd = [sys.executable, '-m', 'subiquity.cmd.server'] + \ - server_args + server_stdout = open(os.path.join(base, "server-stdout"), "w") + server_stderr = open(os.path.join(base, "server-stderr"), "w") + server_cmd = [sys.executable, "-m", "subiquity.cmd.server"] + server_args server_proc = subprocess.Popen( - server_cmd, stdout=server_stdout, stderr=server_stderr) + server_cmd, stdout=server_stdout, stderr=server_stderr + ) opts.server_pid = str(server_proc.pid) print("running server pid {}".format(server_proc.pid)) elif opts.server_pid is not None: @@ -104,14 +117,14 @@ def main(): else: opts = parser.parse_args(args) if opts.socket is None: - opts.socket = '/run/subiquity/socket' + opts.socket = "/run/subiquity/socket" os.makedirs(os.path.basename(opts.socket), exist_ok=True) logdir = LOGDIR if opts.dry_run: logdir = opts.output_base - logfiles = setup_logger(dir=logdir, base='subiquity-client') + logfiles = setup_logger(dir=logdir, base="subiquity-client") - logger = logging.getLogger('subiquity') + logger = logging.getLogger("subiquity") version = os.environ.get("SNAP_REVISION", "unknown") snap = os.environ.get("SNAP", "unknown") logger.info(f"Starting Subiquity TUI revision {version} of snap {snap}") @@ -127,21 +140,18 @@ def main(): try: fcntl.flock(opts.answers, fcntl.LOCK_EX | fcntl.LOCK_NB) except OSError: - logger.exception( - 'Failed to lock auto answers file, proceding without it.') + logger.exception("Failed to lock auto answers file, proceding without it.") opts.answers.close() opts.answers = None async def run_with_loop(): subiquity_interface = SubiquityClient(opts) - subiquity_interface.note_file_for_apport( - "InstallerLog", logfiles['debug']) - subiquity_interface.note_file_for_apport( - "InstallerLogInfo", logfiles['info']) + subiquity_interface.note_file_for_apport("InstallerLog", logfiles["debug"]) + subiquity_interface.note_file_for_apport("InstallerLogInfo", logfiles["info"]) await subiquity_interface.run() asyncio.run(run_with_loop()) -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/subiquity/common/api/client.py b/subiquity/common/api/client.py index 2c38f89e..4c0dfdff 100644 --- a/subiquity/common/api/client.py +++ b/subiquity/common/api/client.py @@ -19,6 +19,7 @@ import inspect import aiohttp from subiquity.common.serialize import Serializer + from .defs import Payload @@ -27,7 +28,7 @@ def _wrap(make_request, path, meth, serializer, serialize_query_args): meth_params = sig.parameters payload_arg = None for name, param in meth_params.items(): - if getattr(param.annotation, '__origin__', None) is Payload: + if getattr(param.annotation, "__origin__", None) is Payload: payload_arg = name payload_ann = param.annotation.__args__[0] r_ann = sig.return_annotation @@ -41,14 +42,14 @@ def _wrap(make_request, path, meth, serializer, serialize_query_args): data = serializer.serialize(payload_ann, value) else: if serialize_query_args: - value = serializer.to_json( - meth_params[arg_name].annotation, value) + value = serializer.to_json(meth_params[arg_name].annotation, value) query_args[arg_name] = value async with make_request( - meth.__name__, path.format(**self.path_args), - json=data, params=query_args) as resp: + meth.__name__, path.format(**self.path_args), json=data, params=query_args + ) as resp: resp.raise_for_status() return serializer.deserialize(r_ann, await resp.json()) + return impl @@ -73,19 +74,24 @@ def make_client_cls(endpoint_cls, make_request, serializer=None): if serializer is None: serializer = Serializer() - ns = {'__init__': client_init} + ns = {"__init__": client_init} for k, v in endpoint_cls.__dict__.items(): if isinstance(v, type): - if getattr(v, '__parameter__', False): - ns['__getitem__'] = make_getitem(v, make_request, serializer) + if getattr(v, "__parameter__", False): + ns["__getitem__"] = make_getitem(v, make_request, serializer) else: ns[k] = make_client(v, make_request, serializer) elif callable(v): - ns[k] = _wrap(make_request, endpoint_cls.fullpath, v, serializer, - endpoint_cls.serialize_query_args) + ns[k] = _wrap( + make_request, + endpoint_cls.fullpath, + v, + serializer, + endpoint_cls.serialize_query_args, + ) - return type('ClientFor({})'.format(endpoint_cls.__name__), (object,), ns) + return type("ClientFor({})".format(endpoint_cls.__name__), (object,), ns) def make_client(endpoint_cls, make_request, serializer=None, path_args=None): @@ -93,10 +99,9 @@ def make_client(endpoint_cls, make_request, serializer=None, path_args=None): def make_client_for_conn( - endpoint_cls, conn, resp_hook=lambda r: r, serializer=None, - header_func=None): - session = aiohttp.ClientSession( - connector=conn, connector_owner=False) + endpoint_cls, conn, resp_hook=lambda r: r, serializer=None, header_func=None +): + session = aiohttp.ClientSession(connector=conn, connector_owner=False) @contextlib.asynccontextmanager async def make_request(method, path, *, params, json): @@ -105,14 +110,14 @@ def make_client_for_conn( # hardcode something here (I guess the "a" gets sent along to the # server in the Host: header and the server could in principle do # something like virtual host based selection but well....) - url = 'http://a' + path + url = "http://a" + path if header_func is not None: headers = header_func() else: headers = None async with session.request( - method, url, json=json, params=params, - headers=headers, timeout=0) as response: + method, url, json=json, params=params, headers=headers, timeout=0 + ) as response: yield resp_hook(response) return make_client(endpoint_cls, make_request, serializer) diff --git a/subiquity/common/api/defs.py b/subiquity/common/api/defs.py index 47fc4285..38af0e67 100644 --- a/subiquity/common/api/defs.py +++ b/subiquity/common/api/defs.py @@ -27,8 +27,10 @@ class InvalidQueryArgs(InvalidAPIDefinition): self.param = param def __str__(self): - return (f"{self.callable.__qualname__} does not serialize query " - f"arguments but has non-str parameter '{self.param}'") + return ( + f"{self.callable.__qualname__} does not serialize query " + f"arguments but has non-str parameter '{self.param}'" + ) class MultiplePathParameters(InvalidAPIDefinition): @@ -38,45 +40,50 @@ class MultiplePathParameters(InvalidAPIDefinition): self.param2 = param2 def __str__(self): - return (f"{self.cls.__name__} has multiple path parameters " - f"{self.param1!r} and {self.param2!r}") + return ( + f"{self.cls.__name__} has multiple path parameters " + f"{self.param1!r} and {self.param2!r}" + ) -def api(cls, prefix_names=(), prefix_path=(), path_params=(), - serialize_query_args=True): - if hasattr(cls, 'serialize_query_args'): +def api( + cls, prefix_names=(), prefix_path=(), path_params=(), serialize_query_args=True +): + if hasattr(cls, "serialize_query_args"): serialize_query_args = cls.serialize_query_args else: cls.serialize_query_args = serialize_query_args - cls.fullpath = '/' + '/'.join(prefix_path) + cls.fullpath = "/" + "/".join(prefix_path) cls.fullname = prefix_names seen_path_param = None for k, v in cls.__dict__.items(): if isinstance(v, type): v.__shortname__ = k - v.__name__ = cls.__name__ + '.' + k + v.__name__ = cls.__name__ + "." + k path_part = k path_param = () - if getattr(v, '__parameter__', False): + if getattr(v, "__parameter__", False): if seen_path_param: raise MultiplePathParameters(cls, seen_path_param, k) seen_path_param = k - path_part = '{' + path_part + '}' + path_part = "{" + path_part + "}" path_param = (k,) api( v, prefix_names + (k,), prefix_path + (path_part,), path_params + path_param, - serialize_query_args) + serialize_query_args, + ) if callable(v): - v.__qualname__ = cls.__name__ + '.' + k + v.__qualname__ = cls.__name__ + "." + k if not cls.serialize_query_args: params = inspect.signature(v).parameters for param_name, param in params.items(): - if param.annotation is not str and \ - getattr(param.annotation, '__origin__', None) is not \ - Payload: + if ( + param.annotation is not str + and getattr(param.annotation, "__origin__", None) is not Payload + ): raise InvalidQueryArgs(v, param) v.__path_params__ = path_params return cls @@ -96,8 +103,12 @@ def path_parameter(cls): def simple_endpoint(typ): class endpoint: - def GET() -> typ: ... - def POST(data: Payload[typ]): ... + def GET() -> typ: + ... + + def POST(data: Payload[typ]): + ... + return endpoint diff --git a/subiquity/common/api/server.py b/subiquity/common/api/server.py index 5a4ea2a8..09914245 100644 --- a/subiquity/common/api/server.py +++ b/subiquity/common/api/server.py @@ -24,7 +24,7 @@ from subiquity.common.serialize import Serializer from .defs import Payload -log = logging.getLogger('subiquity.common.api.server') +log = logging.getLogger("subiquity.common.api.server") class BindError(Exception): @@ -47,24 +47,26 @@ class SignatureMisatchError(BindError): self.actual = actual def __str__(self): - return (f"implementation of {self.methname} has wrong signature, " - f"should be {self.expected} but is {self.actual}") + return ( + f"implementation of {self.methname} has wrong signature, " + f"should be {self.expected} but is {self.actual}" + ) def trim(text): if text is None: - return '' + return "" elif len(text) > 80: - return text[:77] + '...' + return text[:77] + "..." else: return text async def check_controllers_started(definition, controller, request): - if not hasattr(controller, 'app'): + if not hasattr(controller, "app"): return - if getattr(definition, 'allowed_before_start', False): + if getattr(definition, "allowed_before_start", False): return # Most endpoints should not be responding to requests @@ -76,13 +78,14 @@ async def check_controllers_started(definition, controller, request): if controller.app.controllers_have_started.is_set(): return - log.debug(f'{request.path} waiting on start') + log.debug(f"{request.path} waiting on start") await controller.app.controllers_have_started.wait() - log.debug(f'{request.path} resuming') + log.debug(f"{request.path} resuming") -def _make_handler(controller, definition, implementation, serializer, - serialize_query_args): +def _make_handler( + controller, definition, implementation, serializer, serialize_query_args +): def_sig = inspect.signature(definition) def_ret_ann = def_sig.return_annotation def_params = def_sig.parameters @@ -99,45 +102,46 @@ def _make_handler(controller, definition, implementation, serializer, for param_name in definition.__path_params__: check_def_params.append( inspect.Parameter( - param_name, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=str)) + param_name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str + ) + ) for param_name, param in def_params.items(): - if param_name in ('request', 'context'): + if param_name in ("request", "context"): raise Exception( "api method {} cannot have parameter called request or " - "context".format(definition)) - if getattr(param.annotation, '__origin__', None) is Payload: + "context".format(definition) + ) + if getattr(param.annotation, "__origin__", None) is Payload: data_arg = param_name data_annotation = param.annotation.__args__[0] check_def_params.append(param.replace(annotation=data_annotation)) else: - query_args_anns.append( - (param_name, param.annotation, param.default)) + query_args_anns.append((param_name, param.annotation, param.default)) check_def_params.append(param) check_impl_params = [ - p for p in impl_params.values() - if p.name not in ('context', 'request') - ] + p for p in impl_params.values() if p.name not in ("context", "request") + ] check_impl_sig = impl_sig.replace(parameters=check_impl_params) check_def_sig = def_sig.replace(parameters=check_def_params) if check_impl_sig != check_def_sig: raise SignatureMisatchError( - definition.__qualname__, check_def_sig, check_impl_sig) + definition.__qualname__, check_def_sig, check_impl_sig + ) async def handler(request): context = controller.context.child(implementation.__name__) with context: - context.set('request', request) + context.set("request", request) args = {} try: if data_annotation is not None: args[data_arg] = serializer.from_json( - data_annotation, await request.text()) + data_annotation, await request.text() + ) for arg, ann, default in query_args_anns: if arg in request.query: v = request.query[arg] @@ -146,34 +150,35 @@ def _make_handler(controller, definition, implementation, serializer, elif default != inspect._empty: v = default else: - raise TypeError( - 'missing required argument "{}"'.format(arg)) + raise TypeError('missing required argument "{}"'.format(arg)) args[arg] = v for param_name in definition.__path_params__: args[param_name] = request.match_info[param_name] - if 'context' in impl_params: - args['context'] = context - if 'request' in impl_params: - args['request'] = request - await check_controllers_started( - definition, controller, request) + if "context" in impl_params: + args["context"] = context + if "request" in impl_params: + args["request"] = request + await check_controllers_started(definition, controller, request) result = await implementation(**args) resp = web.json_response( serializer.serialize(def_ret_ann, result), - headers={'x-status': 'ok'}) + headers={"x-status": "ok"}, + ) except Exception as exc: tb = traceback.TracebackException.from_exception(exc) resp = web.Response( text="".join(tb.format()), status=500, headers={ - 'x-status': 'error', - 'x-error-type': type(exc).__name__, - 'x-error-msg': str(exc), - }) - resp['exception'] = exc - context.description = '{} {}'.format(resp.status, trim(resp.text)) + "x-status": "error", + "x-error-type": type(exc).__name__, + "x-error-msg": str(exc), + }, + ) + resp["exception"] = exc + context.description = "{} {}".format(resp.status, trim(resp.text)) return resp + handler.controller = controller return handler @@ -181,7 +186,7 @@ def _make_handler(controller, definition, implementation, serializer, async def controller_for_request(request): match_info = await request.app.router.resolve(request) - return getattr(match_info.handler, 'controller', None) + return getattr(match_info.handler, "controller", None) def bind(router, endpoint, controller, serializer=None, _depth=None): @@ -203,8 +208,9 @@ def bind(router, endpoint, controller, serializer=None, _depth=None): method=method, path=endpoint.fullpath, handler=_make_handler( - controller, v, impl, serializer, - endpoint.serialize_query_args)) + controller, v, impl, serializer, endpoint.serialize_query_args + ), + ) async def make_server_at_path(socket_path, endpoint, controller, **kw): diff --git a/subiquity/common/api/tests/test_client.py b/subiquity/common/api/tests/test_client.py index cff4e2f6..35d85043 100644 --- a/subiquity/common/api/tests/test_client.py +++ b/subiquity/common/api/tests/test_client.py @@ -17,12 +17,7 @@ import contextlib import unittest from subiquity.common.api.client import make_client -from subiquity.common.api.defs import ( - api, - InvalidQueryArgs, - path_parameter, - Payload, - ) +from subiquity.common.api.defs import InvalidQueryArgs, Payload, api, path_parameter def extract(c): @@ -46,20 +41,21 @@ class FakeResponse: class TestClient(unittest.TestCase): - def test_simple(self): - @api class API: class endpoint: - def GET() -> str: ... - def POST(data: Payload[str]) -> None: ... + def GET() -> str: + ... + + def POST(data: Payload[str]) -> None: + ... @contextlib.asynccontextmanager async def make_request(method, path, *, params, json): requests.append((method, path, params, json)) if method == "GET": - v = 'value' + v = "value" else: v = None yield FakeResponse(v) @@ -68,85 +64,95 @@ class TestClient(unittest.TestCase): requests = [] r = extract(client.endpoint.GET()) - self.assertEqual(r, 'value') - self.assertEqual(requests, [("GET", '/endpoint', {}, None)]) + self.assertEqual(r, "value") + self.assertEqual(requests, [("GET", "/endpoint", {}, None)]) requests = [] - r = extract(client.endpoint.POST('value')) + r = extract(client.endpoint.POST("value")) self.assertEqual(r, None) - self.assertEqual( - requests, [("POST", '/endpoint', {}, 'value')]) + self.assertEqual(requests, [("POST", "/endpoint", {}, "value")]) def test_args(self): - @api class API: - def GET(arg: str) -> str: ... + def GET(arg: str) -> str: + ... @contextlib.asynccontextmanager async def make_request(method, path, *, params, json): requests.append((method, path, params, json)) - yield FakeResponse(params['arg']) + yield FakeResponse(params["arg"]) client = make_client(API, make_request) requests = [] - r = extract(client.GET(arg='v')) + r = extract(client.GET(arg="v")) self.assertEqual(r, '"v"') - self.assertEqual(requests, [("GET", '/', {'arg': '"v"'}, None)]) + self.assertEqual(requests, [("GET", "/", {"arg": '"v"'}, None)]) def test_path_params(self): - @api class API: @path_parameter class param: - def GET(arg: str) -> str: ... + def GET(arg: str) -> str: + ... @contextlib.asynccontextmanager async def make_request(method, path, *, params, json): requests.append((method, path, params, json)) - yield FakeResponse(params['arg']) + yield FakeResponse(params["arg"]) client = make_client(API, make_request) requests = [] - r = extract(client['foo'].GET(arg='v')) + r = extract(client["foo"].GET(arg="v")) self.assertEqual(r, '"v"') - self.assertEqual(requests, [("GET", '/foo', {'arg': '"v"'}, None)]) + self.assertEqual(requests, [("GET", "/foo", {"arg": '"v"'}, None)]) def test_serialize_query_args(self): @api class API: serialize_query_args = False - def GET(arg: str) -> str: ... + + def GET(arg: str) -> str: + ... class meth: - def GET(arg: str, payload: Payload[int]) -> str: ... + def GET(arg: str, payload: Payload[int]) -> str: + ... class more: serialize_query_args = True - def GET(arg: str) -> str: ... + + def GET(arg: str) -> str: + ... @contextlib.asynccontextmanager async def make_request(method, path, *, params, json): requests.append((method, path, params, json)) - yield FakeResponse(params['arg']) + yield FakeResponse(params["arg"]) client = make_client(API, make_request) requests = [] - extract(client.GET(arg='v')) - extract(client.meth.GET(arg='v', payload=1)) - extract(client.meth.more.GET(arg='v')) - self.assertEqual(requests, [ - ("GET", '/', {'arg': 'v'}, None), - ("GET", '/meth', {'arg': 'v'}, 1), - ("GET", '/meth/more', {'arg': '"v"'}, None)]) + extract(client.GET(arg="v")) + extract(client.meth.GET(arg="v", payload=1)) + extract(client.meth.more.GET(arg="v")) + self.assertEqual( + requests, + [ + ("GET", "/", {"arg": "v"}, None), + ("GET", "/meth", {"arg": "v"}, 1), + ("GET", "/meth/more", {"arg": '"v"'}, None), + ], + ) class API2: serialize_query_args = False - def GET(arg: int) -> str: ... + + def GET(arg: int) -> str: + ... with self.assertRaises(InvalidQueryArgs): api(API2) diff --git a/subiquity/common/api/tests/test_endtoend.py b/subiquity/common/api/tests/test_endtoend.py index c43f76e7..1a24ee5b 100644 --- a/subiquity/common/api/tests/test_endtoend.py +++ b/subiquity/common/api/tests/test_endtoend.py @@ -13,82 +13,77 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import attr import contextlib import functools import unittest import aiohttp +import attr from aiohttp import web from subiquity.common.api.client import make_client from subiquity.common.api.defs import ( - api, MultiplePathParameters, - path_parameter, Payload, - ) + api, + path_parameter, +) -from .test_server import ( - makeTestClient, - ControllerBase, - ) +from .test_server import ControllerBase, makeTestClient def make_request(client, method, path, *, params, json): - return client.request( - method, path, params=params, json=json) + return client.request(method, path, params=params, json=json) @contextlib.asynccontextmanager -async def makeE2EClient(api, impl, - *, middlewares=(), make_request=make_request): - async with makeTestClient( - api, impl, middlewares=middlewares) as client: +async def makeE2EClient(api, impl, *, middlewares=(), make_request=make_request): + async with makeTestClient(api, impl, middlewares=middlewares) as client: mr = functools.partial(make_request, client) yield make_client(api, mr) class TestEndToEnd(unittest.IsolatedAsyncioTestCase): - async def test_simple(self): @api class API: - def GET() -> str: ... + def GET() -> str: + ... class Impl(ControllerBase): async def GET(self) -> str: - return 'value' + return "value" async with makeE2EClient(API, Impl()) as client: - self.assertEqual(await client.GET(), 'value') + self.assertEqual(await client.GET(), "value") async def test_nested(self): @api class API: class endpoint: class nested: - def GET() -> str: ... + def GET() -> str: + ... class Impl(ControllerBase): async def endpoint_nested_GET(self) -> str: - return 'value' + return "value" async with makeE2EClient(API, Impl()) as client: - self.assertEqual(await client.endpoint.nested.GET(), 'value') + self.assertEqual(await client.endpoint.nested.GET(), "value") async def test_args(self): @api class API: - def GET(arg1: str, arg2: str) -> str: ... + def GET(arg1: str, arg2: str) -> str: + ... class Impl(ControllerBase): async def GET(self, arg1: str, arg2: str) -> str: - return '{}+{}'.format(arg1, arg2) + return "{}+{}".format(arg1, arg2) async with makeE2EClient(API, Impl()) as client: - self.assertEqual( - await client.GET(arg1="A", arg2="B"), 'A+B') + self.assertEqual(await client.GET(arg1="A", arg2="B"), "A+B") async def test_path_args(self): @api @@ -97,26 +92,29 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase): class param1: @path_parameter class param2: - def GET(arg1: str, arg2: str) -> str: ... + def GET(arg1: str, arg2: str) -> str: + ... class Impl(ControllerBase): - async def param1_param2_GET(self, param1: str, param2: str, - arg1: str, arg2: str) -> str: - return '{}+{}+{}+{}'.format(param1, param2, arg1, arg2) + async def param1_param2_GET( + self, param1: str, param2: str, arg1: str, arg2: str + ) -> str: + return "{}+{}+{}+{}".format(param1, param2, arg1, arg2) async with makeE2EClient(API, Impl()) as client: - self.assertEqual( - await client["1"]["2"].GET(arg1="A", arg2="B"), '1+2+A+B') + self.assertEqual(await client["1"]["2"].GET(arg1="A", arg2="B"), "1+2+A+B") def test_only_one_path_parameter(self): class API: @path_parameter class param1: - def GET(arg: str) -> str: ... + def GET(arg: str) -> str: + ... @path_parameter class param2: - def GET(arg: str) -> str: ... + def GET(arg: str) -> str: + ... with self.assertRaises(MultiplePathParameters): api(API) @@ -124,35 +122,32 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase): async def test_defaults(self): @api class API: - def GET(arg1: str = "arg1", arg2: str = "arg2") -> str: ... + def GET(arg1: str = "arg1", arg2: str = "arg2") -> str: + ... class Impl(ControllerBase): async def GET(self, arg1: str = "arg1", arg2: str = "arg2") -> str: - return '{}+{}'.format(arg1, arg2) + return "{}+{}".format(arg1, arg2) async with makeE2EClient(API, Impl()) as client: - self.assertEqual( - await client.GET(arg1="A", arg2="B"), 'A+B') - self.assertEqual( - await client.GET(arg1="A"), 'A+arg2') - self.assertEqual( - await client.GET(arg2="B"), 'arg1+B') + self.assertEqual(await client.GET(arg1="A", arg2="B"), "A+B") + self.assertEqual(await client.GET(arg1="A"), "A+arg2") + self.assertEqual(await client.GET(arg2="B"), "arg1+B") async def test_post(self): @api class API: - def POST(data: Payload[dict]) -> str: ... + def POST(data: Payload[dict]) -> str: + ... class Impl(ControllerBase): async def POST(self, data: dict) -> str: - return data['key'] + return data["key"] async with makeE2EClient(API, Impl()) as client: - self.assertEqual( - await client.POST({'key': 'value'}), 'value') + self.assertEqual(await client.POST({"key": "value"}), "value") async def test_typed(self): - @attr.s(auto_attribs=True) class In: val: int @@ -164,11 +159,12 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase): @api class API: class doubler: - def POST(data: In) -> Out: ... + def POST(data: In) -> Out: + ... class Impl(ControllerBase): async def doubler_POST(self, data: In) -> Out: - return Out(doubled=data.val*2) + return Out(doubled=data.val * 2) async with makeE2EClient(API, Impl()) as client: out = await client.doubler.POST(In(3)) @@ -177,17 +173,16 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase): async def test_middleware(self): @api class API: - def GET() -> int: ... + def GET() -> int: + ... class Impl(ControllerBase): async def GET(self) -> int: - return 1/0 + return 1 / 0 @web.middleware async def middleware(request, handler): - return web.Response( - status=200, - headers={'x-status': 'skip'}) + return web.Response(status=200, headers={"x-status": "skip"}) class Skip(Exception): pass @@ -195,15 +190,15 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase): @contextlib.asynccontextmanager async def custom_make_request(client, method, path, *, params, json): async with make_request( - client, method, path, params=params, json=json) as resp: - if resp.headers.get('x-status') == 'skip': + client, method, path, params=params, json=json + ) as resp: + if resp.headers.get("x-status") == "skip": raise Skip yield resp async with makeE2EClient( - API, Impl(), - middlewares=[middleware], - make_request=custom_make_request) as client: + API, Impl(), middlewares=[middleware], make_request=custom_make_request + ) as client: with self.assertRaises(Skip): await client.GET() @@ -211,10 +206,12 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase): @api class API: class good: - def GET(x: int) -> int: ... + def GET(x: int) -> int: + ... class bad: - def GET(x: int) -> int: ... + def GET(x: int) -> int: + ... class Impl(ControllerBase): async def good_GET(self, x: int) -> int: @@ -233,23 +230,25 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase): @api class API: class good: - def GET(x: int) -> int: ... + def GET(x: int) -> int: + ... class bad: - def GET(x: int) -> int: ... + def GET(x: int) -> int: + ... class Impl(ControllerBase): async def good_GET(self, x: int) -> int: return x + 1 async def bad_GET(self, x: int) -> int: - 1/0 + 1 / 0 @web.middleware async def middleware(request, handler): resp = await handler(request) - if resp.get('exception'): - resp.headers['x-status'] = 'ERROR' + if resp.get("exception"): + resp.headers["x-status"] = "ERROR" return resp class Abort(Exception): @@ -258,15 +257,15 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase): @contextlib.asynccontextmanager async def custom_make_request(client, method, path, *, params, json): async with make_request( - client, method, path, params=params, json=json) as resp: - if resp.headers.get('x-status') == 'ERROR': + client, method, path, params=params, json=json + ) as resp: + if resp.headers.get("x-status") == "ERROR": raise Abort yield resp async with makeE2EClient( - API, Impl(), - middlewares=[middleware], - make_request=custom_make_request) as client: + API, Impl(), middlewares=[middleware], make_request=custom_make_request + ) as client: r = await client.good.GET(2) self.assertEqual(r, 3) with self.assertRaises(Abort): diff --git a/subiquity/common/api/tests/test_server.py b/subiquity/common/api/tests/test_server.py index 27167b13..460541f6 100644 --- a/subiquity/common/api/tests/test_server.py +++ b/subiquity/common/api/tests/test_server.py @@ -17,38 +17,30 @@ import contextlib import unittest from unittest import mock -from aiohttp.test_utils import TestClient, TestServer from aiohttp import web +from aiohttp.test_utils import TestClient, TestServer -from subiquitycore.context import Context - -from subiquity.common.api.defs import ( - allowed_before_start, - api, - path_parameter, - Payload, - ) +from subiquity.common.api.defs import Payload, allowed_before_start, api, path_parameter from subiquity.common.api.server import ( - bind, - controller_for_request, MissingImplementationError, SignatureMisatchError, - ) + bind, + controller_for_request, +) +from subiquitycore.context import Context class TestApp: - def report_start_event(self, context, description): pass def report_finish_event(self, context, description, result): pass - project = 'test' + project = "test" class ControllerBase: - def __init__(self): self.context = Context.new(TestApp()) self.app = mock.Mock() @@ -64,109 +56,109 @@ async def makeTestClient(api, impl, middlewares=()): class TestBind(unittest.IsolatedAsyncioTestCase): - async def assertResponse(self, coro, value): resp = await coro - if resp.headers.get('x-error-msg'): + if resp.headers.get("x-error-msg"): raise Exception( - resp.headers['x-error-type'] + ': ' + - resp.headers['x-error-msg']) + resp.headers["x-error-type"] + ": " + resp.headers["x-error-msg"] + ) self.assertEqual(resp.status, 200) self.assertEqual(await resp.json(), value) async def test_simple(self): @api class API: - def GET() -> str: ... + def GET() -> str: + ... class Impl(ControllerBase): async def GET(self) -> str: - return 'value' + return "value" async with makeTestClient(API, Impl()) as client: - await self.assertResponse( - client.get("/"), 'value') + await self.assertResponse(client.get("/"), "value") async def test_nested(self): @api class API: class endpoint: class nested: - def get(): ... + def get(): + ... class Impl(ControllerBase): async def nested_get(self, request, context): - return 'nested' + return "nested" async with makeTestClient(API.endpoint, Impl()) as client: - await self.assertResponse( - client.get("/endpoint/nested"), 'nested') + await self.assertResponse(client.get("/endpoint/nested"), "nested") async def test_args(self): @api class API: - def GET(arg: str): ... + def GET(arg: str): + ... class Impl(ControllerBase): async def GET(self, arg: str): return arg async with makeTestClient(API, Impl()) as client: - await self.assertResponse( - client.get('/?arg="whut"'), 'whut') + await self.assertResponse(client.get('/?arg="whut"'), "whut") async def test_missing_argument(self): @api class API: - def GET(arg: str): ... + def GET(arg: str): + ... class Impl(ControllerBase): async def GET(self, arg: str): return arg async with makeTestClient(API, Impl()) as client: - resp = await client.get('/') + resp = await client.get("/") self.assertEqual(resp.status, 500) - self.assertEqual(resp.headers['x-status'], 'error') - self.assertEqual(resp.headers['x-error-type'], 'TypeError') + self.assertEqual(resp.headers["x-status"], "error") + self.assertEqual(resp.headers["x-error-type"], "TypeError") self.assertEqual( - resp.headers['x-error-msg'], - 'missing required argument "arg"') + resp.headers["x-error-msg"], 'missing required argument "arg"' + ) async def test_error(self): @api class API: - def GET(): ... + def GET(): + ... class Impl(ControllerBase): async def GET(self): - return 1/0 + return 1 / 0 async with makeTestClient(API, Impl()) as client: - resp = await client.get('/') + resp = await client.get("/") self.assertEqual(resp.status, 500) - self.assertEqual(resp.headers['x-status'], 'error') - self.assertEqual( - resp.headers['x-error-type'], 'ZeroDivisionError') + self.assertEqual(resp.headers["x-status"], "error") + self.assertEqual(resp.headers["x-error-type"], "ZeroDivisionError") async def test_post(self): @api class API: - def POST(data: Payload[str]) -> str: ... + def POST(data: Payload[str]) -> str: + ... class Impl(ControllerBase): async def POST(self, data: str) -> str: return data async with makeTestClient(API, Impl()) as client: - await self.assertResponse( - client.post("/", json='value'), - 'value') + await self.assertResponse(client.post("/", json="value"), "value") def test_missing_method(self): @api class API: - def GET(arg: str): ... + def GET(arg: str): + ... class Impl(ControllerBase): pass @@ -179,7 +171,8 @@ class TestBind(unittest.IsolatedAsyncioTestCase): def test_signature_checking(self): @api class API: - def GET(arg: str): ... + def GET(arg: str): + ... class Impl(ControllerBase): async def GET(self, arg: int): @@ -191,31 +184,28 @@ class TestBind(unittest.IsolatedAsyncioTestCase): self.assertEqual(cm.exception.methname, "API.GET") async def test_middleware(self): - @web.middleware async def middleware(request, handler): resp = await handler(request) - exc = resp.get('exception') - if resp.get('exception') is not None: - resp.headers['x-error-type'] = type(exc).__name__ + exc = resp.get("exception") + if resp.get("exception") is not None: + resp.headers["x-error-type"] = type(exc).__name__ return resp @api class API: - def GET() -> str: ... + def GET() -> str: + ... class Impl(ControllerBase): async def GET(self) -> str: - return 1/0 + return 1 / 0 - async with makeTestClient( - API, Impl(), middlewares=[middleware]) as client: + async with makeTestClient(API, Impl(), middlewares=[middleware]) as client: resp = await client.get("/") - self.assertEqual( - resp.headers['x-error-type'], 'ZeroDivisionError') + self.assertEqual(resp.headers["x-error-type"], "ZeroDivisionError") async def test_controller_for_request(self): - seen_controller = None @web.middleware @@ -227,18 +217,18 @@ class TestBind(unittest.IsolatedAsyncioTestCase): @api class API: class meth: - def GET() -> str: ... + def GET() -> str: + ... class Impl(ControllerBase): async def GET(self) -> str: - return '' + return "" impl = Impl() - async with makeTestClient( - API.meth, impl, middlewares=[middleware]) as client: + async with makeTestClient(API.meth, impl, middlewares=[middleware]) as client: resp = await client.get("/meth") - self.assertEqual(await resp.json(), '') + self.assertEqual(await resp.json(), "") self.assertIs(impl, seen_controller) @@ -246,14 +236,19 @@ class TestBind(unittest.IsolatedAsyncioTestCase): @api class API: serialize_query_args = False - def GET(arg: str) -> str: ... + + def GET(arg: str) -> str: + ... class meth: - def GET(arg: str) -> str: ... + def GET(arg: str) -> str: + ... class more: serialize_query_args = True - def GET(arg: str) -> str: ... + + def GET(arg: str) -> str: + ... class Impl(ControllerBase): async def GET(self, arg: str) -> str: @@ -266,67 +261,64 @@ class TestBind(unittest.IsolatedAsyncioTestCase): return arg async with makeTestClient(API, Impl()) as client: - await self.assertResponse( - client.get('/?arg=whut'), 'whut') - await self.assertResponse( - client.get('/meth?arg=whut'), 'whut') - await self.assertResponse( - client.get('/meth/more?arg="whut"'), 'whut') + await self.assertResponse(client.get("/?arg=whut"), "whut") + await self.assertResponse(client.get("/meth?arg=whut"), "whut") + await self.assertResponse(client.get('/meth/more?arg="whut"'), "whut") async def test_path_parameters(self): @api class API: @path_parameter class param: - def GET(arg: int): ... + def GET(arg: int): + ... class Impl(ControllerBase): async def param_GET(self, param: str, arg: int): return param + str(arg) async with makeTestClient(API, Impl()) as client: - await self.assertResponse( - client.get('/value?arg=2'), 'value2') + await self.assertResponse(client.get("/value?arg=2"), "value2") async def test_early_connect_ok(self): @api class API: class can_be_used_early: @allowed_before_start - def GET(): ... + def GET(): + ... class Impl(ControllerBase): def __init__(self): super().__init__() self.app.controllers_have_started = mock.AsyncMock() - self.app.controllers_have_started.is_set = mock.Mock( - return_value=False) + self.app.controllers_have_started.is_set = mock.Mock(return_value=False) async def can_be_used_early_GET(self): pass impl = Impl() async with makeTestClient(API, impl) as client: - await client.get('/can_be_used_early') + await client.get("/can_be_used_early") impl.app.controllers_have_started.wait.assert_not_called() async def test_early_connect_wait(self): @api class API: class must_not_be_used_early: - def GET(): ... + def GET(): + ... class Impl(ControllerBase): def __init__(self): super().__init__() self.app.controllers_have_started = mock.AsyncMock() - self.app.controllers_have_started.is_set = mock.Mock( - return_value=False) + self.app.controllers_have_started.is_set = mock.Mock(return_value=False) async def must_not_be_used_early_GET(self): pass impl = Impl() async with makeTestClient(API, impl) as client: - await client.get('/must_not_be_used_early') + await client.get("/must_not_be_used_early") impl.app.controllers_have_started.wait.assert_called_once() diff --git a/subiquity/common/apidef.py b/subiquity/common/apidef.py index 88e457d6..c1570270 100644 --- a/subiquity/common/apidef.py +++ b/subiquity/common/apidef.py @@ -16,79 +16,79 @@ import enum from typing import List, Optional -from subiquitycore.models.network import ( - BondConfig, - NetDevInfo, - StaticConfig, - WLANConfig, - ) - from subiquity.common.api.defs import ( + Payload, allowed_before_start, api, - Payload, simple_endpoint, - ) +) from subiquity.common.types import ( - AdConnectionInfo, AdAdminNameValidation, + AdConnectionInfo, AdDomainNameValidation, + AddPartitionV2, AdJoinResult, AdPasswordValidation, - AddPartitionV2, AnyStep, ApplicationState, ApplicationStatus, CasperMd5Results, Change, + CodecsData, Disk, + DriversPayload, + DriversResponse, ErrorReportRef, GuidedChoiceV2, GuidedStorageResponseV2, + IdentityData, KeyboardSetting, KeyboardSetup, - IdentityData, - NetworkStatus, - MirrorSelectionFallback, + LiveSessionSSHInfo, + MirrorCheckResponse, MirrorGet, MirrorPost, MirrorPostResponse, - MirrorCheckResponse, + MirrorSelectionFallback, ModifyPartitionV2, + NetworkStatus, + OEMResponse, + PackageInstallState, ReformatDisk, RefreshStatus, ShutdownMode, - CodecsData, - DriversResponse, - DriversPayload, - OEMResponse, SnapInfo, SnapListResponse, SnapSelection, SourceSelectionAndSetting, SSHData, SSHFetchIdResponse, - LiveSessionSSHInfo, StorageResponse, StorageResponseV2, TimeZoneInfo, + UbuntuProCheckTokenAnswer, UbuntuProInfo, UbuntuProResponse, - UbuntuProCheckTokenAnswer, UPCSInitiateResponse, UPCSWaitResponse, UsernameValidation, - PackageInstallState, - ZdevInfo, - WSLConfigurationBase, WSLConfigurationAdvanced, + WSLConfigurationBase, WSLSetupOptions, - ) + ZdevInfo, +) +from subiquitycore.models.network import ( + BondConfig, + NetDevInfo, + StaticConfig, + WLANConfig, +) @api class API: """The API offered by the subiquity installer process.""" + locale = simple_endpoint(str) proxy = simple_endpoint(str) updates = simple_endpoint(str) @@ -99,8 +99,7 @@ class API: class meta: class status: @allowed_before_start - def GET(cur: Optional[ApplicationState] = None) \ - -> ApplicationStatus: + def GET(cur: Optional[ApplicationState] = None) -> ApplicationStatus: """Get the installer state.""" class mark_configured: @@ -111,6 +110,7 @@ class API: def POST(variant: str) -> None: """Choose the install variant - desktop/server/wsl_setup/wsl_reconfigure""" + def GET() -> str: """Get the install variant - desktop/server/wsl_setup/wsl_reconfigure""" @@ -124,10 +124,12 @@ class API: """Restart the server process.""" class ssh_info: - def GET() -> Optional[LiveSessionSSHInfo]: ... + def GET() -> Optional[LiveSessionSSHInfo]: + ... class free_only: - def GET() -> bool: ... + def GET() -> bool: + ... def POST(enable: bool) -> None: """Enable or disable free-only mode. Currently only controlls @@ -135,7 +137,8 @@ class API: confirmation of filesystem changes""" class interactive_sections: - def GET() -> Optional[List[str]]: ... + def GET() -> Optional[List[str]]: + ... class errors: class wait: @@ -159,38 +162,55 @@ class API: """Start the update and return the change id.""" class progress: - def GET(change_id: str) -> Change: ... + def GET(change_id: str) -> Change: + ... class keyboard: - def GET() -> KeyboardSetup: ... - def POST(data: Payload[KeyboardSetting]): ... + def GET() -> KeyboardSetup: + ... + + def POST(data: Payload[KeyboardSetting]): + ... class needs_toggle: - def GET(layout_code: str, variant_code: str) -> bool: ... + def GET(layout_code: str, variant_code: str) -> bool: + ... class steps: - def GET(index: Optional[str]) -> AnyStep: ... + def GET(index: Optional[str]) -> AnyStep: + ... class input_source: - def POST(data: Payload[KeyboardSetting], - user: Optional[str] = None) -> None: ... + def POST( + data: Payload[KeyboardSetting], user: Optional[str] = None + ) -> None: + ... class source: - def GET() -> SourceSelectionAndSetting: ... - def POST(source_id: str, search_drivers: bool = False) -> None: ... + def GET() -> SourceSelectionAndSetting: + ... + + def POST(source_id: str, search_drivers: bool = False) -> None: + ... class zdev: - def GET() -> List[ZdevInfo]: ... + def GET() -> List[ZdevInfo]: + ... class chzdev: - def POST(action: str, zdev: ZdevInfo) -> List[ZdevInfo]: ... + def POST(action: str, zdev: ZdevInfo) -> List[ZdevInfo]: + ... class network: - def GET() -> NetworkStatus: ... - def POST() -> None: ... + def GET() -> NetworkStatus: + ... + + def POST() -> None: + ... class has_network: - def GET() -> bool: ... + def GET() -> bool: + ... class global_addresses: def GET() -> List[str]: @@ -201,8 +221,12 @@ class API: The socket must serve the NetEventAPI below. """ - def PUT(socket_path: str) -> None: ... - def DELETE(socket_path: str) -> None: ... + + def PUT(socket_path: str) -> None: + ... + + def DELETE(socket_path: str) -> None: + ... # These methods could definitely be more RESTish, like maybe a # GET request to /network/interfaces/$name should return netplan @@ -237,77 +261,100 @@ class API: # So methods on nics get an extra dev_name: str parameter) class set_static_config: - def POST(dev_name: str, ip_version: int, - static_config: Payload[StaticConfig]) -> None: ... + def POST( + dev_name: str, ip_version: int, static_config: Payload[StaticConfig] + ) -> None: + ... class enable_dhcp: - def POST(dev_name: str, ip_version: int) -> None: ... + def POST(dev_name: str, ip_version: int) -> None: + ... class disable: - def POST(dev_name: str, ip_version: int) -> None: ... + def POST(dev_name: str, ip_version: int) -> None: + ... class vlan: - def PUT(dev_name: str, vlan_id: int) -> None: ... + def PUT(dev_name: str, vlan_id: int) -> None: + ... class add_or_edit_bond: - def POST(existing_name: Optional[str], new_name: str, - bond_config: Payload[BondConfig]) -> None: ... + def POST( + existing_name: Optional[str], + new_name: str, + bond_config: Payload[BondConfig], + ) -> None: + ... class start_scan: - def POST(dev_name: str) -> None: ... + def POST(dev_name: str) -> None: + ... class set_wlan: - def POST(dev_name: str, wlan: WLANConfig) -> None: ... + def POST(dev_name: str, wlan: WLANConfig) -> None: + ... class delete: - def POST(dev_name: str) -> None: ... + def POST(dev_name: str) -> None: + ... class info: - def GET(dev_name: str) -> str: ... + def GET(dev_name: str) -> str: + ... class storage: class guided: - def POST(data: Payload[GuidedChoiceV2]) \ - -> StorageResponse: + def POST(data: Payload[GuidedChoiceV2]) -> StorageResponse: pass - def GET(wait: bool = False, use_cached_result: bool = False) \ - -> StorageResponse: ... + def GET(wait: bool = False, use_cached_result: bool = False) -> StorageResponse: + ... - def POST(config: Payload[list]): ... + def POST(config: Payload[list]): + ... class dry_run_wait_probe: """This endpoint only works in dry-run mode.""" - def POST() -> None: ... + + def POST() -> None: + ... class reset: - def POST() -> StorageResponse: ... + def POST() -> StorageResponse: + ... class has_rst: def GET() -> bool: pass class has_bitlocker: - def GET() -> List[Disk]: ... + def GET() -> List[Disk]: + ... class v2: def GET( - wait: bool = False, - include_raid: bool = False, - ) -> StorageResponseV2: ... + wait: bool = False, + include_raid: bool = False, + ) -> StorageResponseV2: + ... - def POST() -> StorageResponseV2: ... + def POST() -> StorageResponseV2: + ... class orig_config: - def GET() -> StorageResponseV2: ... + def GET() -> StorageResponseV2: + ... class guided: - def GET(wait: bool = False) -> GuidedStorageResponseV2: ... - def POST(data: Payload[GuidedChoiceV2]) \ - -> GuidedStorageResponseV2: ... + def GET(wait: bool = False) -> GuidedStorageResponseV2: + ... + + def POST(data: Payload[GuidedChoiceV2]) -> GuidedStorageResponseV2: + ... class reset: - def POST() -> StorageResponseV2: ... + def POST() -> StorageResponseV2: + ... class ensure_transaction: """This call will ensure that a transaction is initiated. @@ -316,163 +363,218 @@ class API: A transaction will also be initiated by any v2_storage POST request that modifies the partitioning configuration (e.g., add_partition, edit_partition, ...) but ensure_transaction can - be called early if desired. """ - def POST() -> None: ... + be called early if desired.""" + + def POST() -> None: + ... class reformat_disk: - def POST(data: Payload[ReformatDisk]) \ - -> StorageResponseV2: ... + def POST(data: Payload[ReformatDisk]) -> StorageResponseV2: + ... class add_boot_partition: """Mark a given disk as bootable, which may cause a partition to be added to the disk. It is an error to call this for a disk for which can_be_boot_device is False.""" - def POST(disk_id: str) -> StorageResponseV2: ... + + def POST(disk_id: str) -> StorageResponseV2: + ... class add_partition: """required field format and mount, optional field size - default behavior expands partition to fill disk if size not - supplied or -1. - Other partition fields are ignored. - adding a partion when there is not yet a boot partition can - result in the boot partition being added automatically - see - add_boot_partition for more control over this. - format=None means an unformatted partition + default behavior expands partition to fill disk if size not + supplied or -1. + Other partition fields are ignored. + adding a partion when there is not yet a boot partition can + result in the boot partition being added automatically - see + add_boot_partition for more control over this. + format=None means an unformatted partition """ - def POST(data: Payload[AddPartitionV2]) \ - -> StorageResponseV2: ... + + def POST(data: Payload[AddPartitionV2]) -> StorageResponseV2: + ... class delete_partition: """required field number - It is an error to modify other Partition fields. + It is an error to modify other Partition fields. """ - def POST(data: Payload[ModifyPartitionV2]) \ - -> StorageResponseV2: ... + + def POST(data: Payload[ModifyPartitionV2]) -> StorageResponseV2: + ... class edit_partition: """required field number - optional fields wipe, mount, format, size - It is an error to do wipe=null and change the format. - It is an error to modify other Partition fields. + optional fields wipe, mount, format, size + It is an error to do wipe=null and change the format. + It is an error to modify other Partition fields. """ - def POST(data: Payload[ModifyPartitionV2]) \ - -> StorageResponseV2: ... + + def POST(data: Payload[ModifyPartitionV2]) -> StorageResponseV2: + ... class codecs: - def GET() -> CodecsData: ... - def POST(data: Payload[CodecsData]) -> None: ... + def GET() -> CodecsData: + ... + + def POST(data: Payload[CodecsData]) -> None: + ... class drivers: - def GET(wait: bool = False) -> DriversResponse: ... - def POST(data: Payload[DriversPayload]) -> None: ... + def GET(wait: bool = False) -> DriversResponse: + ... + + def POST(data: Payload[DriversPayload]) -> None: + ... class oem: - def GET(wait: bool = False) -> OEMResponse: ... + def GET(wait: bool = False) -> OEMResponse: + ... class snaplist: - def GET(wait: bool = False) -> SnapListResponse: ... - def POST(data: Payload[List[SnapSelection]]): ... + def GET(wait: bool = False) -> SnapListResponse: + ... + + def POST(data: Payload[List[SnapSelection]]): + ... class snap_info: - def GET(snap_name: str) -> SnapInfo: ... + def GET(snap_name: str) -> SnapInfo: + ... class timezone: - def GET() -> TimeZoneInfo: ... - def POST(tz: str): ... + def GET() -> TimeZoneInfo: + ... + + def POST(tz: str): + ... class shutdown: - def POST(mode: ShutdownMode, immediate: bool = False): ... + def POST(mode: ShutdownMode, immediate: bool = False): + ... class mirror: - def GET() -> MirrorGet: ... + def GET() -> MirrorGet: + ... - def POST(data: Payload[Optional[MirrorPost]]) \ - -> MirrorPostResponse: ... + def POST(data: Payload[Optional[MirrorPost]]) -> MirrorPostResponse: + ... class disable_components: - def GET() -> List[str]: ... - def POST(data: Payload[List[str]]): ... + def GET() -> List[str]: + ... + + def POST(data: Payload[List[str]]): + ... class check_mirror: class start: - def POST(cancel_ongoing: bool = False) -> None: ... + def POST(cancel_ongoing: bool = False) -> None: + ... class progress: - def GET() -> Optional[MirrorCheckResponse]: ... + def GET() -> Optional[MirrorCheckResponse]: + ... class abort: - def POST() -> None: ... + def POST() -> None: + ... fallback = simple_endpoint(MirrorSelectionFallback) class ubuntu_pro: - def GET() -> UbuntuProResponse: ... - def POST(data: Payload[UbuntuProInfo]) -> None: ... + def GET() -> UbuntuProResponse: + ... + + def POST(data: Payload[UbuntuProInfo]) -> None: + ... class skip: - def POST() -> None: ... + def POST() -> None: + ... class check_token: - def GET(token: Payload[str]) \ - -> UbuntuProCheckTokenAnswer: ... + def GET(token: Payload[str]) -> UbuntuProCheckTokenAnswer: + ... class contract_selection: class initiate: - def POST() -> UPCSInitiateResponse: ... + def POST() -> UPCSInitiateResponse: + ... class wait: - def GET() -> UPCSWaitResponse: ... + def GET() -> UPCSWaitResponse: + ... class cancel: - def POST() -> None: ... + def POST() -> None: + ... class identity: - def GET() -> IdentityData: ... - def POST(data: Payload[IdentityData]): ... + def GET() -> IdentityData: + ... + + def POST(data: Payload[IdentityData]): + ... class validate_username: - def GET(username: str) -> UsernameValidation: ... + def GET(username: str) -> UsernameValidation: + ... class ssh: - def GET() -> SSHData: ... - def POST(data: Payload[SSHData]) -> None: ... + def GET() -> SSHData: + ... + + def POST(data: Payload[SSHData]) -> None: + ... class fetch_id: - def GET(user_id: str) -> SSHFetchIdResponse: ... + def GET(user_id: str) -> SSHFetchIdResponse: + ... class integrity: - def GET() -> CasperMd5Results: ... + def GET() -> CasperMd5Results: + ... class active_directory: - def GET() -> Optional[AdConnectionInfo]: ... + def GET() -> Optional[AdConnectionInfo]: + ... + # POST expects input validated by the check methods below: - def POST(data: Payload[AdConnectionInfo]) -> None: ... + def POST(data: Payload[AdConnectionInfo]) -> None: + ... class has_support: - """ Whether the live system supports Active Directory or not. - Network status is not considered. - Clients should call this before showing the AD page. """ - def GET() -> bool: ... + """Whether the live system supports Active Directory or not. + Network status is not considered. + Clients should call this before showing the AD page.""" + + def GET() -> bool: + ... class check_domain_name: - """ Applies basic validation to the candidate domain name, - without calling the domain network. """ - def POST(domain_name: Payload[str]) \ - -> List[AdDomainNameValidation]: ... + """Applies basic validation to the candidate domain name, + without calling the domain network.""" + + def POST(domain_name: Payload[str]) -> List[AdDomainNameValidation]: + ... class ping_domain_controller: - """ Attempts to contact the controller of the provided domain. """ - def POST(domain_name: Payload[str]) \ - -> AdDomainNameValidation: ... + """Attempts to contact the controller of the provided domain.""" + + def POST(domain_name: Payload[str]) -> AdDomainNameValidation: + ... class check_admin_name: - def POST(admin_name: Payload[str]) -> AdAdminNameValidation: ... + def POST(admin_name: Payload[str]) -> AdAdminNameValidation: + ... class check_password: - def POST(password: Payload[str]) -> AdPasswordValidation: ... + def POST(password: Payload[str]) -> AdPasswordValidation: + ... class join_result: - def GET(wait: bool = True) -> AdJoinResult: ... + def GET(wait: bool = True) -> AdJoinResult: + ... class LinkAction(enum.Enum): @@ -484,19 +586,25 @@ class LinkAction(enum.Enum): @api class NetEventAPI: class wlan_support_install_finished: - def POST(state: PackageInstallState) -> None: ... + def POST(state: PackageInstallState) -> None: + ... class update_link: - def POST(act: LinkAction, info: Payload[NetDevInfo]) -> None: ... + def POST(act: LinkAction, info: Payload[NetDevInfo]) -> None: + ... class route_watch: - def POST(has_default_route: bool) -> None: ... + def POST(has_default_route: bool) -> None: + ... class apply_starting: - def POST() -> None: ... + def POST() -> None: + ... class apply_stopping: - def POST() -> None: ... + def POST() -> None: + ... class apply_error: - def POST(stage: str) -> None: ... + def POST(stage: str) -> None: + ... diff --git a/subiquity/common/errorreport.py b/subiquity/common/errorreport.py index 442cb4fa..8e6096ba 100644 --- a/subiquity/common/errorreport.py +++ b/subiquity/common/errorreport.py @@ -27,33 +27,20 @@ from typing import Iterable, Set import apport import apport.crashdb import apport.hookutils - import attr - import bson - import requests - import urwid -from subiquitycore.async_helpers import ( - run_in_thread, - schedule_task, - ) +from subiquity.common.types import ErrorReportKind, ErrorReportRef, ErrorReportState +from subiquitycore.async_helpers import run_in_thread, schedule_task -from subiquity.common.types import ( - ErrorReportKind, - ErrorReportRef, - ErrorReportState, - ) - - -log = logging.getLogger('subiquity.common.errorreport') +log = logging.getLogger("subiquity.common.errorreport") @attr.s(eq=False) class Upload(metaclass=urwid.MetaSignals): - signals = ['progress'] + signals = ["progress"] bytes_to_send = attr.ib() bytes_sent = attr.ib(default=0) @@ -68,13 +55,13 @@ class Upload(metaclass=urwid.MetaSignals): def _progress(self): os.read(self.pipe_r, 4096) - urwid.emit_signal(self, 'progress') + urwid.emit_signal(self, "progress") def _bg_update(self, sent, to_send=None): self.bytes_sent = sent if to_send is not None: self.bytes_to_send = to_send - os.write(self.pipe_w, b'x') + os.write(self.pipe_w, b"x") def stop(self): asyncio.get_running_loop().remove_reader(self.pipe_r) @@ -84,7 +71,6 @@ class Upload(metaclass=urwid.MetaSignals): @attr.s(eq=False) class ErrorReport(metaclass=urwid.MetaSignals): - signals = ["changed"] reporter = attr.ib() @@ -101,17 +87,19 @@ class ErrorReport(metaclass=urwid.MetaSignals): @classmethod def new(cls, reporter, kind): base = "{:.9f}.{}".format(time.time(), kind.name.lower()) - crash_file = open( - os.path.join(reporter.crash_directory, base + ".crash"), - 'wb') + crash_file = open(os.path.join(reporter.crash_directory, base + ".crash"), "wb") - pr = apport.Report('Bug') - pr['CrashDB'] = repr(reporter.crashdb_spec) + pr = apport.Report("Bug") + pr["CrashDB"] = repr(reporter.crashdb_spec) r = cls( - reporter=reporter, base=base, pr=pr, file=crash_file, + reporter=reporter, + base=base, + pr=pr, + file=crash_file, state=ErrorReportState.INCOMPLETE, - context=reporter.context.child(base)) + context=reporter.context.child(base), + ) r.set_meta("kind", kind.name) return r @@ -119,11 +107,15 @@ class ErrorReport(metaclass=urwid.MetaSignals): def from_file(cls, reporter, fpath): base = os.path.splitext(os.path.basename(fpath))[0] report = cls( - reporter, base, pr=apport.Report(date='???'), - state=ErrorReportState.LOADING, file=open(fpath, 'rb'), - context=reporter.context.child(base)) + reporter, + base, + pr=apport.Report(date="???"), + state=ErrorReportState.LOADING, + file=open(fpath, "rb"), + context=reporter.context.child(base), + ) try: - fp = open(report.meta_path, 'r') + fp = open(report.meta_path, "r") except FileNotFoundError: pass else: @@ -140,26 +132,27 @@ class ErrorReport(metaclass=urwid.MetaSignals): if not self.reporter.dry_run: self.pr.add_hooks_info(None) apport.hookutils.attach_hardware(self.pr) - self.pr['Syslog'] = apport.hookutils.recent_syslog(re.compile('.')) - snap_name = os.environ.get('SNAP_NAME', '') - if snap_name != '': + self.pr["Syslog"] = apport.hookutils.recent_syslog(re.compile(".")) + snap_name = os.environ.get("SNAP_NAME", "") + if snap_name != "": self.add_tags([snap_name]) # Because apport-cli will in general be run on a different # machine, we make some slightly obscure alterations to the report # to make this go better. # apport-cli gets upset if neither of these are present. - self.pr['Package'] = 'subiquity ' + os.environ.get( - "SNAP_REVISION", "SNAP_REVISION") - self.pr['SourcePackage'] = 'subiquity' + self.pr["Package"] = "subiquity " + os.environ.get( + "SNAP_REVISION", "SNAP_REVISION" + ) + self.pr["SourcePackage"] = "subiquity" # If ExecutableTimestamp is present, apport-cli will try to check # that ExecutablePath hasn't changed. But it won't be there. - del self.pr['ExecutableTimestamp'] + del self.pr["ExecutableTimestamp"] # apport-cli gets upset at the probert C extensions it sees in # here. /proc/maps is very unlikely to be interesting for us # anyway. - del self.pr['ProcMaps'] + del self.pr["ProcMaps"] self.pr.write(self._file) async def add_info(): @@ -175,6 +168,7 @@ class ErrorReport(metaclass=urwid.MetaSignals): self._file.close() self._file = None urwid.emit_signal(self, "changed") + if wait: with self._context.child("add_info") as context: _bg_add_info() @@ -210,19 +204,17 @@ class ErrorReport(metaclass=urwid.MetaSignals): if uploader.cancelled: log.debug("upload for %s cancelled", self.base) return - yield data[i:i+chunk_size] + yield data[i : i + chunk_size] uploader._bg_update(uploader.bytes_sent + chunk_size) def _bg_upload(): - for_upload = { - "Kind": self.kind.value - } + for_upload = {"Kind": self.kind.value} for k, v in self.pr.items(): if len(v) < 1024 or k in { - "InstallerLogInfo", - "Traceback", - "ProcCpuinfoMinimal", - }: + "InstallerLogInfo", + "Traceback", + "ProcCpuinfoMinimal", + }: for_upload[k] = v else: log.debug("dropping %s of length %s", k, len(v)) @@ -236,9 +228,10 @@ class ErrorReport(metaclass=urwid.MetaSignals): data = bson.BSON().encode(for_upload) self.uploader._bg_update(0, len(data)) headers = { - 'user-agent': 'subiquity/{}'.format( - os.environ.get("SNAP_VERSION", "SNAP_VERSION")), - } + "user-agent": "subiquity/{}".format( + os.environ.get("SNAP_VERSION", "SNAP_VERSION") + ), + } response = requests.post(url, data=chunk(data), headers=headers) response.raise_for_status() return response.text.split()[0] @@ -254,28 +247,27 @@ class ErrorReport(metaclass=urwid.MetaSignals): context.description = oops_id uploader.stop() self.uploader = None - urwid.emit_signal(self, 'changed') + urwid.emit_signal(self, "changed") - urwid.emit_signal(self, 'changed') + urwid.emit_signal(self, "changed") uploader.start() schedule_task(upload()) def _path_with_ext(self, ext): - return os.path.join( - self.reporter.crash_directory, self.base + '.' + ext) + return os.path.join(self.reporter.crash_directory, self.base + "." + ext) @property def meta_path(self): - return self._path_with_ext('meta') + return self._path_with_ext("meta") @property def path(self): - return self._path_with_ext('crash') + return self._path_with_ext("crash") def set_meta(self, key, value): self.meta[key] = value - with open(self.meta_path, 'w') as fp: + with open(self.meta_path, "w") as fp: json.dump(self.meta, fp, indent=4) def mark_seen(self): @@ -300,9 +292,8 @@ class ErrorReport(metaclass=urwid.MetaSignals): """Return fs-label, path-on-fs to report.""" # Not sure if this is more or less sane than shelling out to # findmnt(1). - looking_for = os.path.abspath( - os.path.normpath(self.reporter.crash_directory)) - for line in open('/proc/self/mountinfo').readlines(): + looking_for = os.path.abspath(os.path.normpath(self.reporter.crash_directory)) + for line in open("/proc/self/mountinfo").readlines(): parts = line.strip().split() if os.path.normpath(parts[4]) == looking_for: devname = parts[9] @@ -310,19 +301,19 @@ class ErrorReport(metaclass=urwid.MetaSignals): break else: if self.reporter.dry_run: - path = ('install-logs/2019-11-06.0/crash/' + - self.base + - '.crash') + path = "install-logs/2019-11-06.0/crash/" + self.base + ".crash" return "casper-rw", path return None, None import pyudev + c = pyudev.Context() - devs = list(c.list_devices( - subsystem='block', DEVNAME=os.path.realpath(devname))) + devs = list( + c.list_devices(subsystem="block", DEVNAME=os.path.realpath(devname)) + ) if not devs: return None, None - label = devs[0].get('ID_FS_LABEL_ENC', '') - return label, root[1:] + '/' + self.base + '.crash' + label = devs[0].get("ID_FS_LABEL_ENC", "") + return label, root[1:] + "/" + self.base + ".crash" def ref(self): return ErrorReportRef( @@ -331,7 +322,7 @@ class ErrorReport(metaclass=urwid.MetaSignals): kind=self.kind, seen=self.seen, oops_id=self.oops_id, - ) + ) # with core24 these tag methods can be dropped for equivalent methods # that will be on the report object @@ -349,22 +340,21 @@ class ErrorReport(metaclass=urwid.MetaSignals): class ErrorReporter(object): - def __init__(self, context, dry_run, root, client=None): self.context = context self.dry_run = dry_run - self.crash_directory = os.path.join(root, 'var/crash') + self.crash_directory = os.path.join(root, "var/crash") self.client = client self.reports = [] self._reports_by_base = {} self._reports_by_exception = {} self.crashdb_spec = { - 'impl': 'launchpad', - 'project': 'subiquity', - } + "impl": "launchpad", + "project": "subiquity", + } if dry_run: - self.crashdb_spec['launchpad_instance'] = 'staging' + self.crashdb_spec["launchpad_instance"] = "staging" self._apport_data = [] self._apport_files = [] @@ -398,7 +388,7 @@ class ErrorReporter(object): return self._reports_by_exception.get(exc) def make_apport_report(self, kind, thing, *, wait=False, exc=None, **kw): - if not self.dry_run and not os.path.exists('/cdrom/.disk/info'): + if not self.dry_run and not os.path.exists("/cdrom/.disk/info"): return None log.debug("generating crash report") @@ -415,16 +405,14 @@ class ErrorReporter(object): if exc is None: exc = sys.exc_info()[1] if exc is not None: - report.pr["Title"] = "{} crashed with {}".format( - thing, type(exc).__name__) + report.pr["Title"] = "{} crashed with {}".format(thing, type(exc).__name__) tb = traceback.TracebackException.from_exception(exc) - report.pr['Traceback'] = "".join(tb.format()) + report.pr["Traceback"] = "".join(tb.format()) self._reports_by_exception[exc] = report else: report.pr["Title"] = thing - log.info( - "saving crash report %r to %s", report.pr["Title"], report.path) + log.info("saving crash report %r to %s", report.pr["Title"], report.path) apport_files = self._apport_files[:] apport_data = self._apport_data.copy() @@ -456,8 +444,7 @@ class ErrorReporter(object): await self.client.errors.wait.GET(error_ref) - path = os.path.join( - self.crash_directory, error_ref.base + '.crash') + path = os.path.join(self.crash_directory, error_ref.base + ".crash") report = ErrorReport.from_file(self, path) self.reports.insert(0, report) self._reports_by_base[error_ref.base] = report diff --git a/subiquity/common/filesystem/actions.py b/subiquity/common/filesystem/actions.py index ce4b9689..38ed88ae 100644 --- a/subiquity/common/filesystem/actions.py +++ b/subiquity/common/filesystem/actions.py @@ -27,8 +27,7 @@ from subiquity.models.filesystem import ( Partition, Raid, raidlevels_by_value, - ) - +) _checkers = {} @@ -37,8 +36,9 @@ def make_checker(action): @functools.singledispatch def impl(device): raise NotImplementedError( - "checker for %s on %s not implemented" % ( - action, device)) + "checker for %s on %s not implemented" % (action, device) + ) + _checkers[action] = impl return impl @@ -75,8 +75,7 @@ class DeviceAction(enum.Enum): @functools.singledispatch def _supported_actions(device): - raise NotImplementedError( - "_supported_actions({}) not defined".format(device)) + raise NotImplementedError("_supported_actions({}) not defined".format(device)) @_supported_actions.register(Disk) @@ -86,7 +85,7 @@ def _disk_actions(disk): DeviceAction.REFORMAT, DeviceAction.FORMAT, DeviceAction.REMOVE, - ] + ] if disk._m.bootloader != Bootloader.NONE: actions.append(DeviceAction.TOGGLE_BOOT) return actions @@ -98,7 +97,7 @@ def _part_actions(part): DeviceAction.EDIT, DeviceAction.REMOVE, DeviceAction.DELETE, - ] + ] @_supported_actions.register(Raid) @@ -107,7 +106,7 @@ def _raid_actions(raid): return [ DeviceAction.EDIT, DeviceAction.DELETE, - ] + ] else: actions = [ DeviceAction.EDIT, @@ -115,9 +114,9 @@ def _raid_actions(raid): DeviceAction.REMOVE, DeviceAction.DELETE, DeviceAction.REFORMAT, - ] + ] if raid._m.bootloader == Bootloader.UEFI: - if raid.container and raid.container.metadata == 'imsm': + if raid.container and raid.container.metadata == "imsm": actions.append(DeviceAction.TOGGLE_BOOT) return actions @@ -127,7 +126,7 @@ def _vg_actions(vg): return [ DeviceAction.EDIT, DeviceAction.DELETE, - ] + ] @_supported_actions.register(LVM_LogicalVolume) @@ -135,14 +134,14 @@ def _lv_actions(lv): return [ DeviceAction.EDIT, DeviceAction.DELETE, - ] + ] @_supported_actions.register(gaps.Gap) def _gap_actions(lv): return [ DeviceAction.PARTITION, - ] + ] _can_info = make_checker(DeviceAction.INFO) @@ -161,11 +160,10 @@ def _can_edit_generic(device): if cd is None: return True return _( - "Cannot edit {selflabel} as it is part of the {cdtype} " - "{cdname}.").format( - selflabel=labels.label(device), - cdtype=labels.desc(cd), - cdname=labels.label(cd)) + "Cannot edit {selflabel} as it is part of the {cdtype} " "{cdname}." + ).format( + selflabel=labels.label(device), cdtype=labels.desc(cd), cdname=labels.label(cd) + ) @_can_edit.register(Partition) @@ -183,9 +181,9 @@ def _can_edit_raid(raid): if raid.preserve: return _("Cannot edit pre-existing RAIDs.") elif len(raid._partitions) > 0: - return _( - "Cannot edit {raidlabel} because it has partitions.").format( - raidlabel=labels.label(raid)) + return _("Cannot edit {raidlabel} because it has partitions.").format( + raidlabel=labels.label(raid) + ) else: return _can_edit_generic(raid) @@ -195,10 +193,9 @@ def _can_edit_vg(vg): if vg.preserve: return _("Cannot edit pre-existing volume groups.") elif len(vg._partitions) > 0: - return _( - "Cannot edit {vglabel} because it has logical " - "volumes.").format( - vglabel=labels.label(vg)) + return _("Cannot edit {vglabel} because it has logical " "volumes.").format( + vglabel=labels.label(vg) + ) else: return _can_edit_generic(vg) @@ -251,11 +248,13 @@ def _can_remove_device(device): if cd is None: return False if cd.preserve: - return _("Cannot remove {selflabel} from pre-existing {cdtype} " - "{cdlabel}.").format( - selflabel=labels.label(device), - cdtype=labels.desc(cd), - cdlabel=labels.label(cd)) + return _( + "Cannot remove {selflabel} from pre-existing {cdtype} " "{cdlabel}." + ).format( + selflabel=labels.label(device), + cdtype=labels.desc(cd), + cdlabel=labels.label(cd), + ) if isinstance(cd, Raid): if device in cd.spare_devices: return True @@ -263,19 +262,23 @@ def _can_remove_device(device): if len(cd.devices) == min_devices: return _( "Removing {selflabel} would leave the {cdtype} {cdlabel} with " - "less than {min_devices} devices.").format( - selflabel=labels.label(device), - cdtype=labels.desc(cd), - cdlabel=labels.label(cd), - min_devices=min_devices) + "less than {min_devices} devices." + ).format( + selflabel=labels.label(device), + cdtype=labels.desc(cd), + cdlabel=labels.label(cd), + min_devices=min_devices, + ) elif isinstance(cd, LVM_VolGroup): if len(cd.devices) == 1: return _( "Removing {selflabel} would leave the {cdtype} {cdlabel} with " - "no devices.").format( - selflabel=labels.label(device), - cdtype=labels.desc(cd), - cdlabel=labels.label(cd)) + "no devices." + ).format( + selflabel=labels.label(device), + cdtype=labels.desc(cd), + cdlabel=labels.label(cd), + ) return True @@ -287,11 +290,10 @@ def _can_delete_generic(device): if cd is None: return True return _( - "Cannot delete {selflabel} as it is part of the {cdtype} " - "{cdname}.").format( - selflabel=labels.label(device), - cdtype=labels.desc(cd), - cdname=labels.label(cd)) + "Cannot delete {selflabel} as it is part of the {cdtype} " "{cdname}." + ).format( + selflabel=labels.label(device), cdtype=labels.desc(cd), cdname=labels.label(cd) + ) @_can_delete.register(Partition) @@ -299,8 +301,10 @@ def _can_delete_partition(partition): if partition._is_in_use: return False if partition.device._has_preexisting_partition(): - return _("Cannot delete a single partition from a device that " - "already has partitions.") + return _( + "Cannot delete a single partition from a device that " + "already has partitions." + ) if boot.is_bootloader_partition(partition): return _("Cannot delete required bootloader partition") return _can_delete_generic(partition) @@ -317,22 +321,21 @@ def _can_delete_raid_vg(device): cd = p.constructed_device() return _( "Cannot delete {devicelabel} as partition {partnum} is part " - "of the {cdtype} {cdname}.").format( - devicelabel=labels.label(device), - partnum=p.number, - cdtype=labels.desc(cd), - cdname=labels.label(cd), - ) + "of the {cdtype} {cdname}." + ).format( + devicelabel=labels.label(device), + partnum=p.number, + cdtype=labels.desc(cd), + cdname=labels.label(cd), + ) if mounted_partitions > 1: return _( - "Cannot delete {devicelabel} because it has {count} mounted " - "partitions.").format( - devicelabel=labels.label(device), - count=mounted_partitions) + "Cannot delete {devicelabel} because it has {count} mounted " "partitions." + ).format(devicelabel=labels.label(device), count=mounted_partitions) elif mounted_partitions == 1: return _( "Cannot delete {devicelabel} because it has 1 mounted partition." - ).format(devicelabel=labels.label(device)) + ).format(devicelabel=labels.label(device)) else: return _can_delete_generic(device) @@ -340,8 +343,10 @@ def _can_delete_raid_vg(device): @_can_delete.register(LVM_LogicalVolume) def _can_delete_lv(lv): if lv.volgroup._has_preexisting_partition(): - return _("Cannot delete a single logical volume from a volume " - "group that already has logical volumes.") + return _( + "Cannot delete a single logical volume from a volume " + "group that already has logical volumes." + ) return True diff --git a/subiquity/common/filesystem/boot.py b/subiquity/common/filesystem/boot.py index 463cad96..998d3754 100644 --- a/subiquity/common/filesystem/boot.py +++ b/subiquity/common/filesystem/boot.py @@ -21,16 +21,9 @@ from typing import Any, Optional import attr from subiquity.common.filesystem import gaps, sizes -from subiquity.models.filesystem import ( - align_up, - Disk, - Raid, - Bootloader, - Partition, - ) +from subiquity.models.filesystem import Bootloader, Disk, Partition, Raid, align_up - -log = logging.getLogger('subiquity.common.filesystem.boot') +log = logging.getLogger("subiquity.common.filesystem.boot") @functools.singledispatch @@ -55,7 +48,7 @@ def _is_boot_device_raid(raid): bl = raid._m.bootloader if bl != Bootloader.UEFI: return False - if not raid.container or raid.container.metadata != 'imsm': + if not raid.container or raid.container.metadata != "imsm": return False return any(p.grub_device for p in raid._partitions) @@ -85,8 +78,7 @@ class CreatePartPlan(MakeBootDevicePlan): args: dict = attr.ib(factory=dict) def apply(self, manipulator): - manipulator.create_partition( - self.gap.device, self.gap, self.spec, **self.args) + manipulator.create_partition(self.gap.device, self.gap, self.spec, **self.args) def _can_resize_part(inst, field, part): @@ -166,34 +158,39 @@ class MultiStepPlan(MakeBootDevicePlan): def get_boot_device_plan_bios(device) -> Optional[MakeBootDevicePlan]: - attr_plan = SetAttrPlan(device, 'grub_device', True) - if device.ptable == 'msdos': + attr_plan = SetAttrPlan(device, "grub_device", True) + if device.ptable == "msdos": return attr_plan pgs = gaps.parts_and_gaps(device) if len(pgs) > 0: if isinstance(pgs[0], Partition) and pgs[0].flag == "bios_grub": return attr_plan - gap = gaps.Gap(device=device, - offset=device.alignment_data().min_start_offset, - size=sizes.BIOS_GRUB_SIZE_BYTES) + gap = gaps.Gap( + device=device, + offset=device.alignment_data().min_start_offset, + size=sizes.BIOS_GRUB_SIZE_BYTES, + ) create_part_plan = CreatePartPlan( gap=gap, spec=dict(size=sizes.BIOS_GRUB_SIZE_BYTES, fstype=None, mount=None), - args=dict(flag='bios_grub')) + args=dict(flag="bios_grub"), + ) movable = [] for pg in pgs: if isinstance(pg, gaps.Gap): if pg.size >= sizes.BIOS_GRUB_SIZE_BYTES: - return MultiStepPlan(plans=[ - SlidePlan( - parts=movable, - offset_delta=sizes.BIOS_GRUB_SIZE_BYTES), - create_part_plan, - attr_plan, - ]) + return MultiStepPlan( + plans=[ + SlidePlan( + parts=movable, offset_delta=sizes.BIOS_GRUB_SIZE_BYTES + ), + create_part_plan, + attr_plan, + ] + ) else: return None elif pg.preserve: @@ -204,23 +201,21 @@ def get_boot_device_plan_bios(device) -> Optional[MakeBootDevicePlan]: if not movable: return None - largest_i, largest_part = max( - enumerate(movable), - key=lambda i_p: i_p[1].size) - return MultiStepPlan(plans=[ - ResizePlan( - part=largest_part, - size_delta=-sizes.BIOS_GRUB_SIZE_BYTES), - SlidePlan( - parts=movable[:largest_i+1], - offset_delta=sizes.BIOS_GRUB_SIZE_BYTES), - create_part_plan, - attr_plan, - ]) + largest_i, largest_part = max(enumerate(movable), key=lambda i_p: i_p[1].size) + return MultiStepPlan( + plans=[ + ResizePlan(part=largest_part, size_delta=-sizes.BIOS_GRUB_SIZE_BYTES), + SlidePlan( + parts=movable[: largest_i + 1], offset_delta=sizes.BIOS_GRUB_SIZE_BYTES + ), + create_part_plan, + attr_plan, + ] + ) def get_add_part_plan(device, *, spec, args, resize_partition=None): - size = spec['size'] + size = spec["size"] partitions = device.partitions() create_part_plan = CreatePartPlan(gap=None, spec=spec, args=args) @@ -238,70 +233,79 @@ def get_add_part_plan(device, *, spec, args, resize_partition=None): return None offset = resize_partition.offset + resize_partition.size - size - create_part_plan.gap = gaps.Gap( - device=device, offset=offset, size=size) - return MultiStepPlan(plans=[ - ResizePlan( - part=resize_partition, - size_delta=-size, - allow_resize_preserved=True, + create_part_plan.gap = gaps.Gap(device=device, offset=offset, size=size) + return MultiStepPlan( + plans=[ + ResizePlan( + part=resize_partition, + size_delta=-size, + allow_resize_preserved=True, ), - create_part_plan, - ]) + create_part_plan, + ] + ) else: - new_primaries = [p for p in partitions - if not p.preserve - if p.flag not in ('extended', 'logical')] + new_primaries = [ + p + for p in partitions + if not p.preserve + if p.flag not in ("extended", "logical") + ] if not new_primaries: return None largest_part = max(new_primaries, key=lambda p: p.size) if size > largest_part.size // 2: return None create_part_plan.gap = gaps.Gap( - device=device, offset=largest_part.offset, size=size) - return MultiStepPlan(plans=[ - ResizePlan( - part=largest_part, - size_delta=-size), - SlidePlan( - parts=[largest_part], - offset_delta=size), - create_part_plan, - ]) + device=device, offset=largest_part.offset, size=size + ) + return MultiStepPlan( + plans=[ + ResizePlan(part=largest_part, size_delta=-size), + SlidePlan(parts=[largest_part], offset_delta=size), + create_part_plan, + ] + ) def get_boot_device_plan_uefi(device, resize_partition): for part in device.partitions(): if is_esp(part): - plans = [SetAttrPlan(part, 'grub_device', True)] - if device._m._mount_for_path('/boot/efi') is None: + plans = [SetAttrPlan(part, "grub_device", True)] + if device._m._mount_for_path("/boot/efi") is None: plans.append(MountBootEfiPlan(part)) return MultiStepPlan(plans=plans) part_align = device.alignment_data().part_align size = align_up(sizes.get_efi_size(device.size), part_align) - spec = dict(size=size, fstype='fat32', mount=None) + spec = dict(size=size, fstype="fat32", mount=None) if device._m._mount_for_path("/boot/efi") is None: - spec['mount'] = '/boot/efi' + spec["mount"] = "/boot/efi" return get_add_part_plan( - device, spec=spec, args=dict(flag='boot', grub_device=True), - resize_partition=resize_partition) + device, + spec=spec, + args=dict(flag="boot", grub_device=True), + resize_partition=resize_partition, + ) def get_boot_device_plan_prep(device, resize_partition): for part in device.partitions(): if part.flag == "prep": - return MultiStepPlan(plans=[ - SetAttrPlan(part, 'grub_device', True), - SetAttrPlan(part, 'wipe', 'zero') - ]) + return MultiStepPlan( + plans=[ + SetAttrPlan(part, "grub_device", True), + SetAttrPlan(part, "wipe", "zero"), + ] + ) return get_add_part_plan( device, spec=dict(size=sizes.PREP_GRUB_SIZE_BYTES, fstype=None, mount=None), - args=dict(flag='prep', grub_device=True, wipe='zero'), - resize_partition=resize_partition) + args=dict(flag="prep", grub_device=True, wipe="zero"), + resize_partition=resize_partition, + ) def get_boot_device_plan(device, resize_partition=None): @@ -317,12 +321,11 @@ def get_boot_device_plan(device, resize_partition=None): return get_boot_device_plan_prep(device, resize_partition) if bl == Bootloader.NONE: return NoOpBootPlan() - raise Exception(f'unexpected bootloader {bl} here') + raise Exception(f"unexpected bootloader {bl} here") @functools.singledispatch -def can_be_boot_device(device, *, - resize_partition=None, with_reformatting=False): +def can_be_boot_device(device, *, resize_partition=None, with_reformatting=False): """Can `device` be made into a boot device? If with_reformatting=True, return true if the device can be made @@ -332,8 +335,7 @@ def can_be_boot_device(device, *, @can_be_boot_device.register(Disk) -def _can_be_boot_device_disk(disk, *, - resize_partition=None, with_reformatting=False): +def _can_be_boot_device_disk(disk, *, resize_partition=None, with_reformatting=False): if with_reformatting: disk = disk._reformatted() plan = get_boot_device_plan(disk, resize_partition=resize_partition) @@ -341,12 +343,11 @@ def _can_be_boot_device_disk(disk, *, @can_be_boot_device.register(Raid) -def _can_be_boot_device_raid(raid, *, - resize_partition=None, with_reformatting=False): +def _can_be_boot_device_raid(raid, *, resize_partition=None, with_reformatting=False): bl = raid._m.bootloader if bl != Bootloader.UEFI: return False - if not raid.container or raid.container.metadata != 'imsm': + if not raid.container or raid.container.metadata != "imsm": return False if with_reformatting: return True @@ -376,7 +377,7 @@ def _is_esp_partition(partition): if typecode is None: return False try: - return int(typecode, 0) == 0xef + return int(typecode, 0) == 0xEF except ValueError: # In case there was garbage in the udev entry... return False diff --git a/subiquity/common/filesystem/gaps.py b/subiquity/common/filesystem/gaps.py index fbaec196..736b4240 100644 --- a/subiquity/common/filesystem/gaps.py +++ b/subiquity/common/filesystem/gaps.py @@ -14,21 +14,21 @@ # along with this program. If not, see . import functools -from typing import Tuple, List +from typing import List, Tuple import attr from subiquity.common.types import GapUsable from subiquity.models.filesystem import ( - align_up, - align_down, - Disk, LVM_CHUNK_SIZE, + Disk, LVM_LogicalVolume, LVM_VolGroup, Partition, Raid, - ) + align_down, + align_up, +) # should also set on_setattr=None with attrs 20.1.0 @@ -40,11 +40,11 @@ class Gap: in_extended: bool = False usable: str = GapUsable.YES - type: str = 'gap' + type: str = "gap" @property def id(self): - return 'gap-' + self.device.id + return "gap-" + self.device.id @property def is_usable(self): @@ -55,21 +55,25 @@ class Gap: the supplied size. If size is equal to the gap size, the second gap is None. The original gap is unmodified.""" if size > self.size: - raise ValueError('requested size larger than gap') + raise ValueError("requested size larger than gap") if size == self.size: return (self, None) - first_gap = Gap(device=self.device, - offset=self.offset, - size=size, - in_extended=self.in_extended, - usable=self.usable) + first_gap = Gap( + device=self.device, + offset=self.offset, + size=size, + in_extended=self.in_extended, + usable=self.usable, + ) if self.in_extended: size += self.device.alignment_data().ebr_space - rest_gap = Gap(device=self.device, - offset=self.offset + size, - size=self.size - size, - in_extended=self.in_extended, - usable=self.usable) + rest_gap = Gap( + device=self.device, + offset=self.offset + size, + size=self.size - size, + in_extended=self.in_extended, + usable=self.usable, + ) return (first_gap, rest_gap) def within(self): @@ -134,11 +138,15 @@ def find_disk_gaps_v2(device, info=None): else: usable = GapUsable.TOO_MANY_PRIMARY_PARTS if end - start >= info.min_gap_size: - result.append(Gap(device=device, - offset=start, - size=end - start, - in_extended=in_extended, - usable=usable)) + result.append( + Gap( + device=device, + offset=start, + size=end - start, + in_extended=in_extended, + usable=usable, + ) + ) prev_end = info.min_start_offset @@ -158,8 +166,7 @@ def find_disk_gaps_v2(device, info=None): gap_start = au(prev_end) if extended_end is not None: - gap_start = min( - extended_end, au(gap_start + info.ebr_space)) + gap_start = min(extended_end, au(gap_start + info.ebr_space)) if extended_end is not None and gap_end >= extended_end: maybe_add_gap(gap_start, ad(extended_end), True) @@ -247,7 +254,7 @@ def largest_gap_size(device, in_extended=None): @functools.singledispatch def movable_trailing_partitions_and_gap_size(partition): - """ For a given partition (or LVM logical volume), return the total, + """For a given partition (or LVM logical volume), return the total, potentially available, free space immediately following the partition. By potentially available, we mean that to claim that much free space, some other partitions might need to be moved. @@ -259,13 +266,14 @@ def movable_trailing_partitions_and_gap_size(partition): @movable_trailing_partitions_and_gap_size.register -def _movable_trailing_partitions_and_gap_size_partition(partition: Partition) \ - -> Tuple[List[Partition], int]: +def _movable_trailing_partitions_and_gap_size_partition( + partition: Partition, +) -> Tuple[List[Partition], int]: pgs = parts_and_gaps(partition.device) part_idx = pgs.index(partition) trailing_partitions = [] in_extended = partition.is_logical - for pg in pgs[part_idx + 1:]: + for pg in pgs[part_idx + 1 :]: if isinstance(pg, Partition): if pg.preserve: break @@ -281,8 +289,9 @@ def _movable_trailing_partitions_and_gap_size_partition(partition: Partition) \ @movable_trailing_partitions_and_gap_size.register -def _movable_trailing_partitions_and_gap_size_lvm(volume: LVM_LogicalVolume) \ - -> Tuple[List[LVM_LogicalVolume], int]: +def _movable_trailing_partitions_and_gap_size_lvm( + volume: LVM_LogicalVolume, +) -> Tuple[List[LVM_LogicalVolume], int]: # In a Volume Group, there is no need to move partitions around, one can # always use the remaining space. diff --git a/subiquity/common/filesystem/labels.py b/subiquity/common/filesystem/labels.py index 03faa2cd..17cecba7 100644 --- a/subiquity/common/filesystem/labels.py +++ b/subiquity/common/filesystem/labels.py @@ -18,18 +18,18 @@ import functools from subiquity.common import types from subiquity.common.filesystem import boot, gaps from subiquity.models.filesystem import ( + ZFS, Disk, LVM_LogicalVolume, LVM_VolGroup, Partition, Raid, - ZFS, ZPool, - ) +) def _annotations_generic(device): - preserve = getattr(device, 'preserve', None) + preserve = getattr(device, "preserve", None) if preserve is None: return [] elif preserve: @@ -128,16 +128,15 @@ def _desc_partition(partition): @desc.register(Raid) def _desc_raid(raid): level = raid.raidlevel - if level.lower().startswith('raid'): + if level.lower().startswith("raid"): level = level[4:] if raid.container: raid_type = raid.container.metadata elif raid.metadata == "imsm": - raid_type = 'imsm' # maybe VROC? + raid_type = "imsm" # maybe VROC? else: raid_type = _("software") - return _("{type} RAID {level}").format( - type=raid_type, level=level) + return _("{type} RAID {level}").format(type=raid_type, level=level) @desc.register(LVM_VolGroup) @@ -199,7 +198,8 @@ def _label_partition(partition, *, short=False): return p + _("partition {number}").format(number=partition.number) else: return p + _("partition {number} of {device}").format( - number=partition.number, device=label(partition.device)) + number=partition.number, device=label(partition.device) + ) @label.register(gaps.Gap) @@ -215,10 +215,9 @@ def _usage_labels_generic(device, *, exclude_final_unused=False): if cd is not None: return [ _("{component_name} of {desc} {name}").format( - component_name=cd.component_name, - desc=desc(cd), - name=cd.name), - ] + component_name=cd.component_name, desc=desc(cd), name=cd.name + ), + ] fs = device.fs() if fs is not None: if fs.preserve: @@ -279,11 +278,12 @@ def _usage_labels_disk(disk): @usage_labels.register(Raid) def _usage_labels_raid(raid): - if raid.metadata == 'imsm' and raid._subvolumes: + if raid.metadata == "imsm" and raid._subvolumes: return [ - _('container for {devices}').format( - devices=', '.join([label(v) for v in raid._subvolumes])) - ] + _("container for {devices}").format( + devices=", ".join([label(v) for v in raid._subvolumes]) + ) + ] return _usage_labels_generic(raid) @@ -309,7 +309,7 @@ def _for_client_disk(disk, *, min_size=0): return types.Disk( id=disk.id, label=label(disk), - path=getattr(disk, 'path', None), + path=getattr(disk, "path", None), type=desc(disk), size=disk.size, ptable=disk.ptable, @@ -319,9 +319,10 @@ def _for_client_disk(disk, *, min_size=0): boot_device=boot.is_boot_device(disk), can_be_boot_device=boot.can_be_boot_device(disk), ok_for_guided=disk.size >= min_size, - model=getattr(disk, 'model', None), - vendor=getattr(disk, 'vendor', None), - has_in_use_partition=disk._has_in_use_partition) + model=getattr(disk, "model", None), + vendor=getattr(disk, "vendor", None), + has_in_use_partition=disk._has_in_use_partition, + ) @for_client.register(Partition) @@ -341,7 +342,8 @@ def _for_client_partition(partition, *, min_size=0): estimated_min_size=partition.estimated_min_size, mount=partition.mount, format=partition.format, - is_in_use=partition._is_in_use) + is_in_use=partition._is_in_use, + ) @for_client.register(gaps.Gap) @@ -357,7 +359,7 @@ def _for_client_zpool(zpool, *, min_size=0): zfses=[for_client(zfs) for zfs in zpool.zfses], pool_properties=zpool.pool_properties, fs_properties=zpool.fs_properties, - ) + ) @for_client.register(ZFS) diff --git a/subiquity/common/filesystem/manipulator.py b/subiquity/common/filesystem/manipulator.py index 1ba8e373..7d33e4a1 100644 --- a/subiquity/common/filesystem/manipulator.py +++ b/subiquity/common/filesystem/manipulator.py @@ -19,20 +19,16 @@ from curtin.block import get_resize_fstypes from subiquity.common.filesystem import boot, gaps from subiquity.common.types import Bootloader -from subiquity.models.filesystem import ( - align_up, - Partition, - ) +from subiquity.models.filesystem import Partition, align_up -log = logging.getLogger('subiquity.common.filesystem.manipulator') +log = logging.getLogger("subiquity.common.filesystem.manipulator") class FilesystemManipulator: - def create_mount(self, fs, spec): - if spec.get('mount') is None: + if spec.get("mount") is None: return - mount = self.model.add_mount(fs, spec['mount']) + mount = self.model.add_mount(fs, spec["mount"]) if self.model.needs_bootloader_partition(): vol = fs.volume if vol.type == "partition" and boot.can_be_boot_device(vol.device): @@ -45,20 +41,20 @@ class FilesystemManipulator: self.model.remove_mount(mount) def create_filesystem(self, volume, spec): - if spec.get('fstype') is None: + if spec.get("fstype") is None: # prep partitions are always wiped (and never have a filesystem) - if getattr(volume, 'flag', None) != 'prep': - volume.wipe = spec.get('wipe', volume.wipe) + if getattr(volume, "flag", None) != "prep": + volume.wipe = spec.get("wipe", volume.wipe) fstype = volume.original_fstype() if fstype is None: return None else: - fstype = spec['fstype'] + fstype = spec["fstype"] # partition editing routines are expected to provide the # appropriate wipe value. partition additions may not, give them a # basic wipe value consistent with older behavior until we prove # that everyone does so correctly. - volume.wipe = spec.get('wipe', 'superblock') + volume.wipe = spec.get("wipe", "superblock") preserve = volume.wipe is None fs = self.model.add_filesystem(volume, fstype, preserve) if isinstance(volume, Partition): @@ -66,9 +62,9 @@ class FilesystemManipulator: volume.flag = "swap" elif volume.flag == "swap": volume.flag = "" - if spec.get('fstype') == "swap": + if spec.get("fstype") == "swap": self.model.add_mount(fs, "") - elif spec.get('fstype') is None and spec.get('use_swap'): + elif spec.get("fstype") is None and spec.get("use_swap"): self.model.add_mount(fs, "") else: self.create_mount(fs, spec) @@ -79,35 +75,39 @@ class FilesystemManipulator: return self.delete_mount(fs.mount()) self.model.remove_filesystem(fs) + delete_format = delete_filesystem def create_partition(self, device, gap, spec, **kw): - flag = kw.pop('flag', None) + flag = kw.pop("flag", None) if gap.in_extended: - if flag not in (None, 'logical'): - log.debug(f'overriding flag {flag} ' - 'due to being in an extended partition') - flag = 'logical' + if flag not in (None, "logical"): + log.debug( + f"overriding flag {flag} " "due to being in an extended partition" + ) + flag = "logical" part = self.model.add_partition( - device, size=gap.size, offset=gap.offset, flag=flag, **kw) + device, size=gap.size, offset=gap.offset, flag=flag, **kw + ) self.create_filesystem(part, spec) return part def delete_partition(self, part, override_preserve=False): - if not override_preserve and part.device.preserve and \ - self.model.storage_version < 2: + if ( + not override_preserve + and part.device.preserve + and self.model.storage_version < 2 + ): raise Exception("cannot delete partitions from preserved disks") self.clear(part) self.model.remove_partition(part) def create_raid(self, spec): - for d in spec['devices'] | spec['spare_devices']: + for d in spec["devices"] | spec["spare_devices"]: self.clear(d) raid = self.model.add_raid( - spec['name'], - spec['level'].value, - spec['devices'], - spec['spare_devices']) + spec["name"], spec["level"].value, spec["devices"], spec["spare_devices"] + ) return raid def delete_raid(self, raid): @@ -119,69 +119,74 @@ class FilesystemManipulator: for p in list(raid.partitions()): self.delete_partition(p, True) for d in set(raid.devices) | set(raid.spare_devices): - d.wipe = 'superblock' + d.wipe = "superblock" self.model.remove_raid(raid) def create_volgroup(self, spec): devices = set() - key = spec.get('passphrase') - for device in spec['devices']: + key = spec.get("passphrase") + for device in spec["devices"]: self.clear(device) if key: device = self.model.add_dm_crypt(device, key) devices.add(device) - return self.model.add_volgroup(name=spec['name'], devices=devices) + return self.model.add_volgroup(name=spec["name"], devices=devices) + create_lvm_volgroup = create_volgroup def delete_volgroup(self, vg): for lv in list(vg.partitions()): self.delete_logical_volume(lv) for d in vg.devices: - d.wipe = 'superblock' + d.wipe = "superblock" if d.type == "dm_crypt": self.model.remove_dm_crypt(d) self.model.remove_volgroup(vg) + delete_lvm_volgroup = delete_volgroup def create_logical_volume(self, vg, spec): - lv = self.model.add_logical_volume( - vg=vg, - name=spec['name'], - size=spec['size']) + lv = self.model.add_logical_volume(vg=vg, name=spec["name"], size=spec["size"]) self.create_filesystem(lv, spec) return lv + create_lvm_partition = create_logical_volume def delete_logical_volume(self, lv): self.clear(lv) self.model.remove_logical_volume(lv) + delete_lvm_partition = delete_logical_volume def create_zpool(self, device, pool, mountpoint): fs_properties = dict( - acltype='posixacl', - relatime='on', - canmount='on', - compression='gzip', - devices='off', - xattr='sa', + acltype="posixacl", + relatime="on", + canmount="on", + compression="gzip", + devices="off", + xattr="sa", ) pool_properties = dict(ashift=12) self.model.add_zpool( - device, pool, mountpoint, - fs_properties=fs_properties, pool_properties=pool_properties) + device, + pool, + mountpoint, + fs_properties=fs_properties, + pool_properties=pool_properties, + ) def delete(self, obj): if obj is None: return - getattr(self, 'delete_' + obj.type)(obj) + getattr(self, "delete_" + obj.type)(obj) def clear(self, obj, wipe=None): if obj.type == "disk": obj.preserve = False if wipe is None: - wipe = 'superblock' + wipe = "superblock" obj.wipe = wipe for subobj in obj.fs(), obj.constructed_device(): self.delete(subobj) @@ -202,14 +207,14 @@ class FilesystemManipulator: return True def partition_disk_handler(self, disk, spec, *, partition=None, gap=None): - log.debug('partition_disk_handler: %s %s %s %s', - disk, spec, partition, gap) + log.debug("partition_disk_handler: %s %s %s %s", disk, spec, partition, gap) if partition is not None: - if 'size' in spec and spec['size'] != partition.size: - trailing, gap_size = \ - gaps.movable_trailing_partitions_and_gap_size(partition) - new_size = align_up(spec['size']) + if "size" in spec and spec["size"] != partition.size: + trailing, gap_size = gaps.movable_trailing_partitions_and_gap_size( + partition + ) + new_size = align_up(spec["size"]) size_change = new_size - partition.size if size_change > gap_size: raise Exception("partition size too large") @@ -226,12 +231,12 @@ class FilesystemManipulator: if len(disk.partitions()) == 0: if disk.type == "disk": disk.preserve = False - disk.wipe = 'superblock-recursive' + disk.wipe = "superblock-recursive" elif disk.type == "raid": - disk.wipe = 'superblock-recursive' + disk.wipe = "superblock-recursive" needs_boot = self.model.needs_bootloader_partition() - log.debug('model needs a bootloader partition? {}'.format(needs_boot)) + log.debug("model needs a bootloader partition? {}".format(needs_boot)) can_be_boot = boot.can_be_boot_device(disk) if needs_boot and len(disk.partitions()) == 0 and can_be_boot: self.add_boot_disk(disk) @@ -241,13 +246,13 @@ class FilesystemManipulator: # 1) with len(partitions()) == 0 there could only have been 1 gap # 2) having just done add_boot_disk(), the gap is no longer valid. gap = gaps.largest_gap(disk) - if spec['size'] > gap.size: + if spec["size"] > gap.size: log.debug( - "Adjusting request down from %s to %s", - spec['size'], gap.size) - spec['size'] = gap.size + "Adjusting request down from %s to %s", spec["size"], gap.size + ) + spec["size"] = gap.size - gap = gap.split(spec['size'])[0] + gap = gap.split(spec["size"])[0] self.create_partition(disk, gap, spec) log.debug("Successfully added partition") @@ -256,13 +261,13 @@ class FilesystemManipulator: # keep the partition name for compat with PartitionStretchy.handler lv = partition - log.debug('logical_volume_handler: %s %s %s', vg, lv, spec) + log.debug("logical_volume_handler: %s %s %s", vg, lv, spec) if lv is not None: - if 'name' in spec: - lv.name = spec['name'] - if 'size' in spec: - lv.size = align_up(spec['size']) + if "name" in spec: + lv.name = spec["name"] + if "size" in spec: + lv.size = align_up(spec["size"]) if gaps.largest_gap_size(vg) < 0: raise Exception("lv size too large") self.delete_filesystem(lv.fs()) @@ -272,7 +277,7 @@ class FilesystemManipulator: self.create_logical_volume(vg, spec) def add_format_handler(self, volume, spec): - log.debug('add_format_handler %s %s', volume, spec) + log.debug("add_format_handler %s %s", volume, spec) self.clear(volume) self.create_filesystem(volume, spec) @@ -281,48 +286,47 @@ class FilesystemManipulator: if existing is not None: for d in existing.devices | existing.spare_devices: d._constructed_device = None - for d in spec['devices'] | spec['spare_devices']: + for d in spec["devices"] | spec["spare_devices"]: self.clear(d) d._constructed_device = existing - existing.name = spec['name'] - existing.raidlevel = spec['level'].value - existing.devices = spec['devices'] - existing.spare_devices = spec['spare_devices'] + existing.name = spec["name"] + existing.raidlevel = spec["level"].value + existing.devices = spec["devices"] + existing.spare_devices = spec["spare_devices"] else: self.create_raid(spec) def volgroup_handler(self, existing, spec): if existing is not None: - key = spec.get('passphrase') + key = spec.get("passphrase") for d in existing.devices: if d.type == "dm_crypt": self.model.remove_dm_crypt(d) d = d.volume d._constructed_device = None devices = set() - for d in spec['devices']: + for d in spec["devices"]: self.clear(d) if key: d = self.model.add_dm_crypt(d, key) d._constructed_device = existing devices.add(d) - existing.name = spec['name'] + existing.name = spec["name"] existing.devices = devices else: self.create_volgroup(spec) def _mount_esp(self, part): if part.fs() is None: - self.model.add_filesystem(part, 'fat32') - self.model.add_mount(part.fs(), '/boot/efi') + self.model.add_filesystem(part, "fat32") + self.model.add_mount(part.fs(), "/boot/efi") def remove_boot_disk(self, boot_disk): if self.model.bootloader == Bootloader.BIOS: boot_disk.grub_device = False partitions = [ - p for p in boot_disk.partitions() - if boot.is_bootloader_partition(p) - ] + p for p in boot_disk.partitions() if boot.is_bootloader_partition(p) + ] remount = False if boot_disk.preserve: if self.model.bootloader == Bootloader.BIOS: @@ -339,7 +343,8 @@ class FilesystemManipulator: if not p.fs().preserve and p.original_fstype(): self.delete_filesystem(p.fs()) self.model.add_filesystem( - p, p.original_fstype(), preserve=True) + p, p.original_fstype(), preserve=True + ) else: full = gaps.largest_gap_size(boot_disk) == 0 tot_size = 0 @@ -349,11 +354,10 @@ class FilesystemManipulator: remount = True self.delete_partition(p) if full: - largest_part = max( - boot_disk.partitions(), key=lambda p: p.size) + largest_part = max(boot_disk.partitions(), key=lambda p: p.size) largest_part.size += tot_size if self.model.bootloader == Bootloader.UEFI and remount: - part = self.model._one(type='partition', grub_device=True) + part = self.model._one(type="partition", grub_device=True) if part: self._mount_esp(part) @@ -363,7 +367,7 @@ class FilesystemManipulator: self.remove_boot_disk(disk) plan = boot.get_boot_device_plan(new_boot_disk) if plan is None: - raise ValueError(f'No known plan to make {new_boot_disk} bootable') + raise ValueError(f"No known plan to make {new_boot_disk} bootable") plan.apply(self) if not new_boot_disk._has_preexisting_partition(): if new_boot_disk.type == "disk": diff --git a/subiquity/common/filesystem/sizes.py b/subiquity/common/filesystem/sizes.py index dabb2881..bdbe9b94 100644 --- a/subiquity/common/filesystem/sizes.py +++ b/subiquity/common/filesystem/sizes.py @@ -16,17 +16,10 @@ import math import attr - from curtin import swap -from subiquity.models.filesystem import ( - align_up, - align_down, - MiB, - GiB, - ) from subiquity.common.types import GuidedResizeValues - +from subiquity.models.filesystem import GiB, MiB, align_down, align_up BIOS_GRUB_SIZE_BYTES = 1 * MiB PREP_GRUB_SIZE_BYTES = 8 * MiB @@ -39,18 +32,11 @@ class PartitionScaleFactors: maximum: int -uefi_scale = PartitionScaleFactors( - minimum=538 * MiB, - priority=538, - maximum=1075 * MiB) +uefi_scale = PartitionScaleFactors(minimum=538 * MiB, priority=538, maximum=1075 * MiB) bootfs_scale = PartitionScaleFactors( - minimum=1792 * MiB, - priority=1024, - maximum=2048 * MiB) -rootfs_scale = PartitionScaleFactors( - minimum=900 * MiB, - priority=10000, - maximum=-1) + minimum=1792 * MiB, priority=1024, maximum=2048 * MiB +) +rootfs_scale = PartitionScaleFactors(minimum=900 * MiB, priority=10000, maximum=-1) def scale_partitions(all_factors, available_space): @@ -96,14 +82,15 @@ def get_bootfs_size(available_space): # decide to keep with the existing partition, or allocate to the new install # 4) Assume that the installs will grow proportionally to their minimum sizes, # and split according to the ratio of the minimum sizes -def calculate_guided_resize(part_min: int, part_size: int, install_min: int, - part_align: int = MiB) -> GuidedResizeValues: +def calculate_guided_resize( + part_min: int, part_size: int, install_min: int, part_align: int = MiB +) -> GuidedResizeValues: if part_min < 0: return None part_size = align_up(part_size, part_align) - other_room_to_grow = max(2 * GiB, math.ceil(.25 * part_min)) + other_room_to_grow = max(2 * GiB, math.ceil(0.25 * part_min)) padded_other_min = part_min + other_room_to_grow other_min = min(align_up(padded_other_min, part_align), part_size) @@ -117,8 +104,11 @@ def calculate_guided_resize(part_min: int, part_size: int, install_min: int, raw_recommended = math.ceil(resize_window * ratio) + other_min recommended = align_up(raw_recommended, part_align) return GuidedResizeValues( - install_max=plausible_free_space, - minimum=other_min, recommended=recommended, maximum=other_max) + install_max=plausible_free_space, + minimum=other_min, + recommended=recommended, + maximum=other_max, + ) # Factors for suggested minimum install size: @@ -142,10 +132,9 @@ def calculate_guided_resize(part_min: int, part_size: int, install_min: int, # 4) Swap - Ubuntu policy recommends swap, either in the form of a partition or # a swapfile. Curtin has an existing swap size recommendation at # curtin.swap.suggested_swapsize(). -def calculate_suggested_install_min(source_min: int, - part_align: int = MiB) -> int: +def calculate_suggested_install_min(source_min: int, part_align: int = MiB) -> int: room_for_boot = bootfs_scale.maximum - room_to_grow = max(2 * GiB, math.ceil(.8 * source_min)) + room_to_grow = max(2 * GiB, math.ceil(0.8 * source_min)) room_for_swap = swap.suggested_swapsize() total = source_min + room_for_boot + room_to_grow + room_for_swap return align_up(total, part_align) diff --git a/subiquity/common/filesystem/tests/test_actions.py b/subiquity/common/filesystem/tests/test_actions.py index 88bb2fb3..720827e7 100644 --- a/subiquity/common/filesystem/tests/test_actions.py +++ b/subiquity/common/filesystem/tests/test_actions.py @@ -15,10 +15,8 @@ import unittest -from subiquity.common.filesystem.actions import ( - DeviceAction, - ) from subiquity.common.filesystem import gaps +from subiquity.common.filesystem.actions import DeviceAction from subiquity.models.filesystem import Bootloader from subiquity.models.tests.test_filesystem import ( make_disk, @@ -32,11 +30,10 @@ from subiquity.models.tests.test_filesystem import ( make_partition, make_raid, make_vg, - ) +) class TestActions(unittest.TestCase): - def assertActionNotSupported(self, obj, action): self.assertNotIn(action, DeviceAction.supported(obj)) @@ -51,7 +48,7 @@ class TestActions(unittest.TestCase): def _test_remove_action(self, model, objects): self.assertActionNotPossible(objects[0], DeviceAction.REMOVE) - vg = model.add_volgroup('vg1', {objects[0], objects[1]}) + vg = model.add_volgroup("vg1", {objects[0], objects[1]}) self.assertActionPossible(objects[0], DeviceAction.REMOVE) # Cannot remove a device from a preexisting VG @@ -63,7 +60,7 @@ class TestActions(unittest.TestCase): vg.devices.remove(objects[1]) objects[1]._constructed_device = None self.assertActionNotPossible(objects[0], DeviceAction.REMOVE) - raid = model.add_raid('md0', 'raid1', set(objects[2:]), set()) + raid = model.add_raid("md0", "raid1", set(objects[2:]), set()) self.assertActionPossible(objects[2], DeviceAction.REMOVE) # Cannot remove a device from a preexisting RAID @@ -91,22 +88,22 @@ class TestActions(unittest.TestCase): self.assertActionNotPossible(disk1, DeviceAction.REFORMAT) disk1p1 = make_partition(model, disk1, preserve=False) self.assertActionPossible(disk1, DeviceAction.REFORMAT) - model.add_volgroup('vg0', {disk1p1}) + model.add_volgroup("vg0", {disk1p1}) self.assertActionNotPossible(disk1, DeviceAction.REFORMAT) disk2 = make_disk(model, preserve=True) self.assertActionNotPossible(disk2, DeviceAction.REFORMAT) disk2p1 = make_partition(model, disk2, preserve=True) self.assertActionPossible(disk2, DeviceAction.REFORMAT) - model.add_volgroup('vg1', {disk2p1}) + model.add_volgroup("vg1", {disk2p1}) self.assertActionNotPossible(disk2, DeviceAction.REFORMAT) disk3 = make_disk(model, preserve=False) - model.add_volgroup('vg2', {disk3}) + model.add_volgroup("vg2", {disk3}) self.assertActionNotPossible(disk3, DeviceAction.REFORMAT) disk4 = make_disk(model, preserve=True) - model.add_volgroup('vg2', {disk4}) + model.add_volgroup("vg2", {disk4}) self.assertActionNotPossible(disk4, DeviceAction.REFORMAT) def test_disk_action_PARTITION(self): @@ -119,7 +116,7 @@ class TestActions(unittest.TestCase): make_partition(model, disk) self.assertActionNotPossible(disk, DeviceAction.FORMAT) disk2 = make_disk(model) - model.add_volgroup('vg1', {disk2}) + model.add_volgroup("vg1", {disk2}) self.assertActionNotPossible(disk2, DeviceAction.FORMAT) def test_disk_action_REMOVE(self): @@ -138,18 +135,18 @@ class TestActions(unittest.TestCase): def test_disk_action_TOGGLE_BOOT_BIOS(self): model = make_model(Bootloader.BIOS) # Disks with msdos partition tables can always be the BIOS boot disk. - dos_disk = make_disk(model, ptable='msdos', preserve=True) + dos_disk = make_disk(model, ptable="msdos", preserve=True) self.assertActionPossible(dos_disk, DeviceAction.TOGGLE_BOOT) # Even if they have existing partitions make_partition( - model, dos_disk, size=gaps.largest_gap_size(dos_disk), - preserve=True) + model, dos_disk, size=gaps.largest_gap_size(dos_disk), preserve=True + ) self.assertActionPossible(dos_disk, DeviceAction.TOGGLE_BOOT) # (we never create dos partition tables so no need to test # preserve=False case). # GPT disks with new partition tables can always be the BIOS boot disk - gpt_disk = make_disk(model, ptable='gpt', preserve=False) + gpt_disk = make_disk(model, ptable="gpt", preserve=False) self.assertActionPossible(gpt_disk, DeviceAction.TOGGLE_BOOT) # Even if they are filled with partitions (we resize partitions to fit) make_partition(model, gpt_disk, size=gaps.largest_gap_size(dos_disk)) @@ -157,25 +154,37 @@ class TestActions(unittest.TestCase): # GPT disks with existing partition tables but no partitions can be the # BIOS boot disk (in general we ignore existing empty partition tables) - gpt_disk2 = make_disk(model, ptable='gpt', preserve=True) + gpt_disk2 = make_disk(model, ptable="gpt", preserve=True) self.assertActionPossible(gpt_disk2, DeviceAction.TOGGLE_BOOT) # If there is an existing *partition* though, it cannot be the boot # disk make_partition(model, gpt_disk2, preserve=True) self.assertActionNotPossible(gpt_disk2, DeviceAction.TOGGLE_BOOT) # Unless there is already a bios_grub partition we can reuse - gpt_disk3 = make_disk(model, ptable='gpt', preserve=True) - make_partition(model, gpt_disk3, flag="bios_grub", preserve=True, - offset=1 << 20, size=512 << 20) - make_partition(model, gpt_disk3, preserve=True, - offset=513 << 20, size=8192 << 20) + gpt_disk3 = make_disk(model, ptable="gpt", preserve=True) + make_partition( + model, + gpt_disk3, + flag="bios_grub", + preserve=True, + offset=1 << 20, + size=512 << 20, + ) + make_partition( + model, gpt_disk3, preserve=True, offset=513 << 20, size=8192 << 20 + ) self.assertActionPossible(gpt_disk3, DeviceAction.TOGGLE_BOOT) # Edge case city: the bios_grub partition has to be first - gpt_disk4 = make_disk(model, ptable='gpt', preserve=True) - make_partition(model, gpt_disk4, preserve=True, - offset=1 << 20, size=8192 << 20) - make_partition(model, gpt_disk4, flag="bios_grub", preserve=True, - offset=8193 << 20, size=512 << 20) + gpt_disk4 = make_disk(model, ptable="gpt", preserve=True) + make_partition(model, gpt_disk4, preserve=True, offset=1 << 20, size=8192 << 20) + make_partition( + model, + gpt_disk4, + flag="bios_grub", + preserve=True, + offset=8193 << 20, + size=512 << 20, + ) self.assertActionNotPossible(gpt_disk4, DeviceAction.TOGGLE_BOOT) def _test_TOGGLE_BOOT_boot_partition(self, bl, flag): @@ -188,20 +197,20 @@ class TestActions(unittest.TestCase): new_disk = make_disk(model, preserve=False) self.assertActionPossible(new_disk, DeviceAction.TOGGLE_BOOT) # Even if they are filled with partitions (we resize partitions to fit) - make_partition( - model, new_disk, size=gaps.largest_gap_size(new_disk)) + make_partition(model, new_disk, size=gaps.largest_gap_size(new_disk)) self.assertActionPossible(new_disk, DeviceAction.TOGGLE_BOOT) # A disk with an existing but empty partitions can also be the # UEFI/PREP boot disk. - old_disk = make_disk(model, preserve=True, ptable='gpt') + old_disk = make_disk(model, preserve=True, ptable="gpt") self.assertActionPossible(old_disk, DeviceAction.TOGGLE_BOOT) # If there is an existing partition though, it cannot. make_partition(model, old_disk, preserve=True) self.assertActionNotPossible(old_disk, DeviceAction.TOGGLE_BOOT) # If there is an existing ESP/PReP partition though, fine! - make_partition(model, old_disk, flag=flag, preserve=True, - offset=1 << 20, size=512 << 20) + make_partition( + model, old_disk, flag=flag, preserve=True, offset=1 << 20, size=512 << 20 + ) self.assertActionPossible(old_disk, DeviceAction.TOGGLE_BOOT) def test_disk_action_TOGGLE_BOOT_UEFI(self): @@ -217,7 +226,7 @@ class TestActions(unittest.TestCase): def test_partition_action_EDIT(self): model, part = make_model_and_partition() self.assertActionPossible(part, DeviceAction.EDIT) - model.add_volgroup('vg1', {part}) + model.add_volgroup("vg1", {part}) self.assertActionNotPossible(part, DeviceAction.EDIT) def test_partition_action_REFORMAT(self): @@ -243,20 +252,20 @@ class TestActions(unittest.TestCase): model = make_model() part1 = make_partition(model) self.assertActionPossible(part1, DeviceAction.DELETE) - fs = model.add_filesystem(part1, 'ext4') + fs = model.add_filesystem(part1, "ext4") self.assertActionPossible(part1, DeviceAction.DELETE) - model.add_mount(fs, '/') + model.add_mount(fs, "/") self.assertActionPossible(part1, DeviceAction.DELETE) part2 = make_partition(model) - model.add_volgroup('vg1', {part2}) + model.add_volgroup("vg1", {part2}) self.assertActionNotPossible(part2, DeviceAction.DELETE) - part = make_partition(make_model(Bootloader.BIOS), flag='bios_grub') + part = make_partition(make_model(Bootloader.BIOS), flag="bios_grub") self.assertActionNotPossible(part, DeviceAction.DELETE) - part = make_partition(make_model(Bootloader.UEFI), flag='boot') + part = make_partition(make_model(Bootloader.UEFI), flag="boot") self.assertActionNotPossible(part, DeviceAction.DELETE) - part = make_partition(make_model(Bootloader.PREP), flag='prep') + part = make_partition(make_model(Bootloader.PREP), flag="prep") self.assertActionNotPossible(part, DeviceAction.DELETE) # You cannot delete a partition from a disk that has @@ -277,7 +286,7 @@ class TestActions(unittest.TestCase): model = make_model() raid1 = make_raid(model) self.assertActionPossible(raid1, DeviceAction.EDIT) - model.add_volgroup('vg1', {raid1}) + model.add_volgroup("vg1", {raid1}) self.assertActionNotPossible(raid1, DeviceAction.EDIT) raid2 = make_raid(model) make_partition(model, raid2) @@ -294,23 +303,23 @@ class TestActions(unittest.TestCase): self.assertActionNotPossible(raid1, DeviceAction.REFORMAT) raid1p1 = make_partition(model, raid1) self.assertActionPossible(raid1, DeviceAction.REFORMAT) - model.add_volgroup('vg0', {raid1p1}) + model.add_volgroup("vg0", {raid1p1}) self.assertActionNotPossible(raid1, DeviceAction.REFORMAT) raid2 = make_raid(model, preserve=True) self.assertActionNotPossible(raid2, DeviceAction.REFORMAT) raid2p1 = make_partition(model, raid2, preserve=True) self.assertActionPossible(raid2, DeviceAction.REFORMAT) - model.add_volgroup('vg1', {raid2p1}) + model.add_volgroup("vg1", {raid2p1}) self.assertActionNotPossible(raid2, DeviceAction.REFORMAT) raid3 = make_raid(model) - model.add_volgroup('vg2', {raid3}) + model.add_volgroup("vg2", {raid3}) self.assertActionNotPossible(raid3, DeviceAction.REFORMAT) raid4 = make_raid(model) raid4.preserve = True - model.add_volgroup('vg2', {raid4}) + model.add_volgroup("vg2", {raid4}) self.assertActionNotPossible(raid4, DeviceAction.REFORMAT) def test_raid_action_PARTITION(self): @@ -323,7 +332,7 @@ class TestActions(unittest.TestCase): make_partition(model, raid) self.assertActionNotPossible(raid, DeviceAction.FORMAT) raid2 = make_raid(model) - model.add_volgroup('vg1', {raid2}) + model.add_volgroup("vg1", {raid2}) self.assertActionNotPossible(raid2, DeviceAction.FORMAT) def test_raid_action_REMOVE(self): @@ -338,20 +347,20 @@ class TestActions(unittest.TestCase): self.assertActionPossible(raid1, DeviceAction.DELETE) part = make_partition(model, raid1) self.assertActionPossible(raid1, DeviceAction.DELETE) - fs = model.add_filesystem(part, 'ext4') + fs = model.add_filesystem(part, "ext4") self.assertActionPossible(raid1, DeviceAction.DELETE) - model.add_mount(fs, '/') + model.add_mount(fs, "/") self.assertActionNotPossible(raid1, DeviceAction.DELETE) raid2 = make_raid(model) self.assertActionPossible(raid2, DeviceAction.DELETE) - fs = model.add_filesystem(raid2, 'ext4') + fs = model.add_filesystem(raid2, "ext4") self.assertActionPossible(raid2, DeviceAction.DELETE) - model.add_mount(fs, '/') + model.add_mount(fs, "/") self.assertActionPossible(raid2, DeviceAction.DELETE) raid2 = make_raid(model) - model.add_volgroup('vg0', {raid2}) + model.add_volgroup("vg0", {raid2}) self.assertActionNotPossible(raid2, DeviceAction.DELETE) def test_raid_action_TOGGLE_BOOT(self): @@ -365,7 +374,7 @@ class TestActions(unittest.TestCase): def test_vg_action_EDIT(self): model, vg = make_model_and_vg() self.assertActionPossible(vg, DeviceAction.EDIT) - model.add_logical_volume(vg, 'lv1', size=gaps.largest_gap_size(vg)) + model.add_logical_volume(vg, "lv1", size=gaps.largest_gap_size(vg)) self.assertActionNotPossible(vg, DeviceAction.EDIT) vg2 = make_vg(model) @@ -392,12 +401,11 @@ class TestActions(unittest.TestCase): model, vg = make_model_and_vg() self.assertActionPossible(vg, DeviceAction.DELETE) self.assertActionPossible(vg, DeviceAction.DELETE) - lv = model.add_logical_volume( - vg, 'lv0', size=gaps.largest_gap_size(vg)//2) + lv = model.add_logical_volume(vg, "lv0", size=gaps.largest_gap_size(vg) // 2) self.assertActionPossible(vg, DeviceAction.DELETE) - fs = model.add_filesystem(lv, 'ext4') + fs = model.add_filesystem(lv, "ext4") self.assertActionPossible(vg, DeviceAction.DELETE) - model.add_mount(fs, '/') + model.add_mount(fs, "/") self.assertActionNotPossible(vg, DeviceAction.DELETE) def test_vg_action_TOGGLE_BOOT(self): @@ -431,9 +439,9 @@ class TestActions(unittest.TestCase): def test_lv_action_DELETE(self): model, lv = make_model_and_lv() self.assertActionPossible(lv, DeviceAction.DELETE) - fs = model.add_filesystem(lv, 'ext4') + fs = model.add_filesystem(lv, "ext4") self.assertActionPossible(lv, DeviceAction.DELETE) - model.add_mount(fs, '/') + model.add_mount(fs, "/") self.assertActionPossible(lv, DeviceAction.DELETE) lv2 = make_lv(model) diff --git a/subiquity/common/filesystem/tests/test_gaps.py b/subiquity/common/filesystem/tests/test_gaps.py index 2ef18d44..6277dfd8 100644 --- a/subiquity/common/filesystem/tests/test_gaps.py +++ b/subiquity/common/filesystem/tests/test_gaps.py @@ -13,20 +13,20 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from functools import partial import unittest +from functools import partial from unittest import mock -from subiquitycore.tests.parameterized import parameterized - +from subiquity.common.filesystem import gaps +from subiquity.common.types import GapUsable from subiquity.models.filesystem import ( - Disk, LVM_CHUNK_SIZE, LVM_OVERHEAD, + Disk, MiB, PartitionAlignmentData, Raid, - ) +) from subiquity.models.tests.test_filesystem import ( make_disk, make_lv, @@ -35,22 +35,19 @@ from subiquity.models.tests.test_filesystem import ( make_model_and_lv, make_partition, make_vg, - ) - -from subiquity.common.filesystem import gaps -from subiquity.common.types import GapUsable +) +from subiquitycore.tests.parameterized import parameterized class GapTestCase(unittest.TestCase): def use_alignment_data(self, alignment_data): - m = mock.patch('subiquity.common.filesystem.gaps.parts_and_gaps') + m = mock.patch("subiquity.common.filesystem.gaps.parts_and_gaps") p = m.start() self.addCleanup(m.stop) - p.side_effect = partial( - gaps.find_disk_gaps_v2, info=alignment_data) + p.side_effect = partial(gaps.find_disk_gaps_v2, info=alignment_data) for cls in Disk, Raid: - md = mock.patch.object(cls, 'alignment_data') + md = mock.patch.object(cls, "alignment_data") p = md.start() self.addCleanup(md.stop) p.return_value = alignment_data @@ -86,14 +83,21 @@ class TestSplitGap(GapTestCase): @parameterized.expand([[1], [10]]) def test_split_in_extended(self, ebr_space): - self.use_alignment_data(PartitionAlignmentData( - part_align=1, min_gap_size=1, min_start_offset=0, - min_end_offset=0, primary_part_limit=4, ebr_space=ebr_space)) + self.use_alignment_data( + PartitionAlignmentData( + part_align=1, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + primary_part_limit=4, + ebr_space=ebr_space, + ) + ) m = make_model(storage_version=2) - d = make_disk(m, size=100, ptable='dos') + d = make_disk(m, size=100, ptable="dos") [gap] = gaps.parts_and_gaps(d) - make_partition(m, d, offset=gap.offset, size=gap.size, flag='extended') + make_partition(m, d, offset=gap.offset, size=gap.size, flag="extended") [p, g] = gaps.parts_and_gaps(d) self.assertTrue(g.in_extended) g1, g2 = g.split(10) @@ -205,125 +209,158 @@ class TestAfter(unittest.TestCase): class TestDiskGaps(unittest.TestCase): - def test_no_partition_gpt(self): size = 1 << 30 - d = make_disk(size=size, ptable='gpt') + d = make_disk(size=size, ptable="gpt") self.assertEqual( - gaps.find_disk_gaps_v2(d), - [gaps.Gap(d, MiB, size - 2*MiB, False)]) + gaps.find_disk_gaps_v2(d), [gaps.Gap(d, MiB, size - 2 * MiB, False)] + ) def test_no_partition_dos(self): size = 1 << 30 - d = make_disk(size=size, ptable='dos') + d = make_disk(size=size, ptable="dos") self.assertEqual( - gaps.find_disk_gaps_v2(d), - [gaps.Gap(d, MiB, size - MiB, False)]) + gaps.find_disk_gaps_v2(d), [gaps.Gap(d, MiB, size - MiB, False)] + ) def test_all_partition(self): info = PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=0, - min_end_offset=0, primary_part_limit=10) + part_align=10, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + primary_part_limit=10, + ) m, d = make_model_and_disk(size=100) p = make_partition(m, d, offset=0, size=100) - self.assertEqual( - gaps.find_disk_gaps_v2(d, info), - [p]) + self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p]) def test_all_partition_with_min_offsets(self): info = PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=10, - min_end_offset=10, primary_part_limit=10) + part_align=10, + min_gap_size=1, + min_start_offset=10, + min_end_offset=10, + primary_part_limit=10, + ) m, d = make_model_and_disk(size=100) p = make_partition(m, d, offset=10, size=80) - self.assertEqual( - gaps.find_disk_gaps_v2(d, info), - [p]) + self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p]) def test_half_partition(self): info = PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=0, - min_end_offset=0, primary_part_limit=10) + part_align=10, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + primary_part_limit=10, + ) m, d = make_model_and_disk(size=100) p = make_partition(m, d, offset=0, size=50) - self.assertEqual( - gaps.find_disk_gaps_v2(d, info), - [p, gaps.Gap(d, 50, 50)]) + self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p, gaps.Gap(d, 50, 50)]) def test_gap_in_middle(self): info = PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=0, - min_end_offset=0, primary_part_limit=10) + part_align=10, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + primary_part_limit=10, + ) m, d = make_model_and_disk(size=100) p1 = make_partition(m, d, offset=0, size=20) p2 = make_partition(m, d, offset=80, size=20) - self.assertEqual( - gaps.find_disk_gaps_v2(d, info), - [p1, gaps.Gap(d, 20, 60), p2]) + self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p1, gaps.Gap(d, 20, 60), p2]) def test_small_gap(self): info = PartitionAlignmentData( - part_align=10, min_gap_size=20, min_start_offset=0, - min_end_offset=0, primary_part_limit=10) + part_align=10, + min_gap_size=20, + min_start_offset=0, + min_end_offset=0, + primary_part_limit=10, + ) m, d = make_model_and_disk(size=100) p1 = make_partition(m, d, offset=0, size=40) p2 = make_partition(m, d, offset=50, size=50) - self.assertEqual( - gaps.find_disk_gaps_v2(d, info), - [p1, p2]) + self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p1, p2]) def test_align_gap(self): info = PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=0, - min_end_offset=0, primary_part_limit=10) + part_align=10, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + primary_part_limit=10, + ) m, d = make_model_and_disk(size=100) p1 = make_partition(m, d, offset=0, size=17) p2 = make_partition(m, d, offset=53, size=47) - self.assertEqual( - gaps.find_disk_gaps_v2(d, info), - [p1, gaps.Gap(d, 20, 30), p2]) + self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p1, gaps.Gap(d, 20, 30), p2]) def test_all_extended(self): info = PartitionAlignmentData( - part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0, - ebr_space=2, primary_part_limit=10) - m, d = make_model_and_disk(size=100, ptable='dos') - p = make_partition(m, d, offset=0, size=100, flag='extended') + part_align=5, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + ebr_space=2, + primary_part_limit=10, + ) + m, d = make_model_and_disk(size=100, ptable="dos") + p = make_partition(m, d, offset=0, size=100, flag="extended") self.assertEqual( gaps.find_disk_gaps_v2(d, info), [ p, gaps.Gap(d, 5, 95, True), - ]) + ], + ) def test_half_extended(self): info = PartitionAlignmentData( - part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0, - ebr_space=2, primary_part_limit=10) + part_align=5, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + ebr_space=2, + primary_part_limit=10, + ) m, d = make_model_and_disk(size=100) - p = make_partition(m, d, offset=0, size=50, flag='extended') + p = make_partition(m, d, offset=0, size=50, flag="extended") self.assertEqual( gaps.find_disk_gaps_v2(d, info), - [p, gaps.Gap(d, 5, 45, True), gaps.Gap(d, 50, 50, False)]) + [p, gaps.Gap(d, 5, 45, True), gaps.Gap(d, 50, 50, False)], + ) def test_half_extended_one_logical(self): info = PartitionAlignmentData( - part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0, - ebr_space=2, primary_part_limit=10) - m, d = make_model_and_disk(size=100, ptable='dos') - p1 = make_partition(m, d, offset=0, size=50, flag='extended') - p2 = make_partition(m, d, offset=5, size=45, flag='logical') + part_align=5, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + ebr_space=2, + primary_part_limit=10, + ) + m, d = make_model_and_disk(size=100, ptable="dos") + p1 = make_partition(m, d, offset=0, size=50, flag="extended") + p2 = make_partition(m, d, offset=5, size=45, flag="logical") self.assertEqual( - gaps.find_disk_gaps_v2(d, info), - [p1, p2, gaps.Gap(d, 50, 50, False)]) + gaps.find_disk_gaps_v2(d, info), [p1, p2, gaps.Gap(d, 50, 50, False)] + ) def test_half_extended_half_logical(self): info = PartitionAlignmentData( - part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0, - ebr_space=2, primary_part_limit=10) - m, d = make_model_and_disk(size=100, ptable='dos') - p1 = make_partition(m, d, offset=0, size=50, flag='extended') - p2 = make_partition(m, d, offset=5, size=25, flag='logical') + part_align=5, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + ebr_space=2, + primary_part_limit=10, + ) + m, d = make_model_and_disk(size=100, ptable="dos") + p1 = make_partition(m, d, offset=0, size=50, flag="extended") + p2 = make_partition(m, d, offset=5, size=25, flag="logical") self.assertEqual( gaps.find_disk_gaps_v2(d, info), [ @@ -331,159 +368,204 @@ class TestDiskGaps(unittest.TestCase): p2, gaps.Gap(d, 35, 15, True), gaps.Gap(d, 50, 50, False), - ]) + ], + ) def test_gap_before_primary(self): # 0----10---20---30---40---50---60---70---80---90---100 # [ g1 ][ p1 (primary) ] info = PartitionAlignmentData( - part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0, - ebr_space=1, primary_part_limit=10) - m, d = make_model_and_disk(size=100, ptable='dos') + part_align=5, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + ebr_space=1, + primary_part_limit=10, + ) + m, d = make_model_and_disk(size=100, ptable="dos") p1 = make_partition(m, d, offset=50, size=50) self.assertEqual( - gaps.find_disk_gaps_v2(d, info), - [gaps.Gap(d, 0, 50, False), p1]) + gaps.find_disk_gaps_v2(d, info), [gaps.Gap(d, 0, 50, False), p1] + ) def test_gap_in_extended_before_logical(self): # 0----10---20---30---40---50---60---70---80---90---100 # [ p1 (extended) ] # [ g1 ] [ p5 (logical) ] info = PartitionAlignmentData( - part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0, - ebr_space=1, primary_part_limit=10) - m, d = make_model_and_disk(size=100, ptable='dos') - p1 = make_partition(m, d, offset=0, size=100, flag='extended') - p5 = make_partition(m, d, offset=50, size=50, flag='logical') + part_align=5, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + ebr_space=1, + primary_part_limit=10, + ) + m, d = make_model_and_disk(size=100, ptable="dos") + p1 = make_partition(m, d, offset=0, size=100, flag="extended") + p5 = make_partition(m, d, offset=50, size=50, flag="logical") self.assertEqual( - gaps.find_disk_gaps_v2(d, info), - [p1, gaps.Gap(d, 5, 40, True), p5]) + gaps.find_disk_gaps_v2(d, info), [p1, gaps.Gap(d, 5, 40, True), p5] + ) def test_unusable_gap_primaries(self): info = PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=0, - min_end_offset=0, primary_part_limit=1) + part_align=10, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + primary_part_limit=1, + ) m, d = make_model_and_disk(size=100) p = make_partition(m, d, offset=0, size=90) - g = gaps.Gap(d, offset=90, size=10, - usable=GapUsable.TOO_MANY_PRIMARY_PARTS) - self.assertEqual( - gaps.find_disk_gaps_v2(d, info), - [p, g]) + g = gaps.Gap(d, offset=90, size=10, usable=GapUsable.TOO_MANY_PRIMARY_PARTS) + self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p, g]) def test_usable_gap_primaries(self): info = PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=0, - min_end_offset=0, primary_part_limit=2) + part_align=10, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + primary_part_limit=2, + ) m, d = make_model_and_disk(size=100) p = make_partition(m, d, offset=0, size=90) g = gaps.Gap(d, offset=90, size=10, usable=GapUsable.YES) - self.assertEqual( - gaps.find_disk_gaps_v2(d, info), - [p, g]) + self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p, g]) def test_usable_gap_extended(self): info = PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=0, - min_end_offset=0, primary_part_limit=1) - m, d = make_model_and_disk(size=100, ptable='dos') - p = make_partition(m, d, offset=0, size=100, flag='extended') - g = gaps.Gap(d, offset=0, size=100, - in_extended=True, usable=GapUsable.YES) - self.assertEqual( - gaps.find_disk_gaps_v2(d, info), - [p, g]) + part_align=10, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + primary_part_limit=1, + ) + m, d = make_model_and_disk(size=100, ptable="dos") + p = make_partition(m, d, offset=0, size=100, flag="extended") + g = gaps.Gap(d, offset=0, size=100, in_extended=True, usable=GapUsable.YES) + self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p, g]) class TestMovableTrailingPartitionsAndGapSize(GapTestCase): - def test_largest_non_extended(self): - self.use_alignment_data(PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=0, - min_end_offset=0, primary_part_limit=4)) + self.use_alignment_data( + PartitionAlignmentData( + part_align=10, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + primary_part_limit=4, + ) + ) # 0----10---20---30---40---50---60---70---80---90---100 # [ p1 ] [ p2 (extended) ] # [ g1 ][ g2 ] - m, d = make_model_and_disk(size=100, ptable='dos') + m, d = make_model_and_disk(size=100, ptable="dos") make_partition(m, d, offset=0, size=10) - make_partition(m, d, offset=30, size=70, flag='extended') - g = gaps.Gap(d, offset=10, size=20, - in_extended=False, usable=GapUsable.YES) + make_partition(m, d, offset=30, size=70, flag="extended") + g = gaps.Gap(d, offset=10, size=20, in_extended=False, usable=GapUsable.YES) self.assertEqual(g, gaps.largest_gap(d, in_extended=False)) def test_largest_yes_extended(self): - self.use_alignment_data(PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=0, - min_end_offset=0, primary_part_limit=4)) + self.use_alignment_data( + PartitionAlignmentData( + part_align=10, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + primary_part_limit=4, + ) + ) # 0----10---20---30---40---50---60---70---80---90---100 # [ p1 ] [ p2 (extended) ] # [ g1 ][ g2 ] - m, d = make_model_and_disk(size=100, ptable='dos') + m, d = make_model_and_disk(size=100, ptable="dos") make_partition(m, d, offset=0, size=10) - make_partition(m, d, offset=30, size=70, flag='extended') - g = gaps.Gap(d, offset=30, size=70, - in_extended=True, usable=GapUsable.YES) + make_partition(m, d, offset=30, size=70, flag="extended") + g = gaps.Gap(d, offset=30, size=70, in_extended=True, usable=GapUsable.YES) self.assertEqual(g, gaps.largest_gap(d, in_extended=True)) def test_largest_any_extended(self): - self.use_alignment_data(PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=0, - min_end_offset=0, primary_part_limit=4)) + self.use_alignment_data( + PartitionAlignmentData( + part_align=10, + min_gap_size=1, + min_start_offset=0, + min_end_offset=0, + primary_part_limit=4, + ) + ) # 0----10---20---30---40---50---60---70---80---90---100 # [ p1 ] [ p2 (extended) ] # [ g1 ][ g2 ] - m, d = make_model_and_disk(size=100, ptable='dos') + m, d = make_model_and_disk(size=100, ptable="dos") make_partition(m, d, offset=0, size=10) - make_partition(m, d, offset=30, size=70, flag='extended') - g = gaps.Gap(d, offset=30, size=70, - in_extended=True, usable=GapUsable.YES) + make_partition(m, d, offset=30, size=70, flag="extended") + g = gaps.Gap(d, offset=30, size=70, in_extended=True, usable=GapUsable.YES) self.assertEqual(g, gaps.largest_gap(d)) def test_no_next_gap(self): - self.use_alignment_data(PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=10, - min_end_offset=10, primary_part_limit=10)) + self.use_alignment_data( + PartitionAlignmentData( + part_align=10, + min_gap_size=1, + min_start_offset=10, + min_end_offset=10, + primary_part_limit=10, + ) + ) # 0----10---20---30---40---50---60---70---80---90---100 # #####[ p1 ]##### m, d = make_model_and_disk(size=100) p = make_partition(m, d, offset=10, size=80) - self.assertEqual( - ([], 0), - gaps.movable_trailing_partitions_and_gap_size(p)) + self.assertEqual(([], 0), gaps.movable_trailing_partitions_and_gap_size(p)) def test_immediately_trailing_gap(self): - self.use_alignment_data(PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=10, - min_end_offset=10, primary_part_limit=10)) + self.use_alignment_data( + PartitionAlignmentData( + part_align=10, + min_gap_size=1, + min_start_offset=10, + min_end_offset=10, + primary_part_limit=10, + ) + ) # 0----10---20---30---40---50---60---70---80---90---100 # #####[ p1 ] [ p2 ] ##### m, d = make_model_and_disk(size=100) p1 = make_partition(m, d, offset=10, size=20) p2 = make_partition(m, d, offset=50, size=20) - self.assertEqual( - ([], 20), - gaps.movable_trailing_partitions_and_gap_size(p1)) - self.assertEqual( - ([], 20), - gaps.movable_trailing_partitions_and_gap_size(p2)) + self.assertEqual(([], 20), gaps.movable_trailing_partitions_and_gap_size(p1)) + self.assertEqual(([], 20), gaps.movable_trailing_partitions_and_gap_size(p2)) def test_one_trailing_movable_partition_and_gap(self): - self.use_alignment_data(PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=10, - min_end_offset=10, primary_part_limit=10)) + self.use_alignment_data( + PartitionAlignmentData( + part_align=10, + min_gap_size=1, + min_start_offset=10, + min_end_offset=10, + primary_part_limit=10, + ) + ) # 0----10---20---30---40---50---60---70---80---90---100 # #####[ p1 ][ p2 ] ##### m, d = make_model_and_disk(size=100) p1 = make_partition(m, d, offset=10, size=40) p2 = make_partition(m, d, offset=50, size=10) - self.assertEqual( - ([p2], 30), - gaps.movable_trailing_partitions_and_gap_size(p1)) + self.assertEqual(([p2], 30), gaps.movable_trailing_partitions_and_gap_size(p1)) def test_two_trailing_movable_partitions_and_gap(self): - self.use_alignment_data(PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=10, - min_end_offset=10, primary_part_limit=10)) + self.use_alignment_data( + PartitionAlignmentData( + part_align=10, + min_gap_size=1, + min_start_offset=10, + min_end_offset=10, + primary_part_limit=10, + ) + ) # 0----10---20---30---40---50---60---70---80---90---100 # #####[ p1 ][ p2 ][ p3 ] ##### m, d = make_model_and_disk(size=100) @@ -491,122 +573,152 @@ class TestMovableTrailingPartitionsAndGapSize(GapTestCase): p2 = make_partition(m, d, offset=50, size=10) p3 = make_partition(m, d, offset=60, size=10) self.assertEqual( - ([p2, p3], 20), - gaps.movable_trailing_partitions_and_gap_size(p1)) + ([p2, p3], 20), gaps.movable_trailing_partitions_and_gap_size(p1) + ) def test_one_trailing_movable_partition_and_no_gap(self): - self.use_alignment_data(PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=10, - min_end_offset=10, primary_part_limit=10)) + self.use_alignment_data( + PartitionAlignmentData( + part_align=10, + min_gap_size=1, + min_start_offset=10, + min_end_offset=10, + primary_part_limit=10, + ) + ) # 0----10---20---30---40---50---60---70---80---90---100 # #####[ p1 ][ p2 ]##### m, d = make_model_and_disk(size=100) p1 = make_partition(m, d, offset=10, size=40) p2 = make_partition(m, d, offset=50, size=40) - self.assertEqual( - ([p2], 0), - gaps.movable_trailing_partitions_and_gap_size(p1)) + self.assertEqual(([p2], 0), gaps.movable_trailing_partitions_and_gap_size(p1)) def test_full_extended_partition_then_gap(self): - self.use_alignment_data(PartitionAlignmentData( - part_align=1, min_gap_size=1, min_start_offset=10, - min_end_offset=10, primary_part_limit=10, ebr_space=2)) + self.use_alignment_data( + PartitionAlignmentData( + part_align=1, + min_gap_size=1, + min_start_offset=10, + min_end_offset=10, + primary_part_limit=10, + ebr_space=2, + ) + ) # 0----10---20---30---40---50---60---70---80---90---100 # #####[ p1 (extended) ] ##### # ######[ p5 (logical) ] ##### - m, d = make_model_and_disk(size=100, ptable='dos') - make_partition(m, d, offset=10, size=40, flag='extended') - p5 = make_partition(m, d, offset=12, size=38, flag='logical') - self.assertEqual( - ([], 0), - gaps.movable_trailing_partitions_and_gap_size(p5)) + m, d = make_model_and_disk(size=100, ptable="dos") + make_partition(m, d, offset=10, size=40, flag="extended") + p5 = make_partition(m, d, offset=12, size=38, flag="logical") + self.assertEqual(([], 0), gaps.movable_trailing_partitions_and_gap_size(p5)) def test_full_extended_partition_then_part(self): - self.use_alignment_data(PartitionAlignmentData( - part_align=1, min_gap_size=1, min_start_offset=10, - min_end_offset=10, primary_part_limit=10, ebr_space=2)) + self.use_alignment_data( + PartitionAlignmentData( + part_align=1, + min_gap_size=1, + min_start_offset=10, + min_end_offset=10, + primary_part_limit=10, + ebr_space=2, + ) + ) # 0----10---20---30---40---50---60---70---80---90---100 # #####[ p1 (extended) ][ p2 ]##### # ######[ p5 (logical) ] ##### - m, d = make_model_and_disk(size=100, ptable='dos') - make_partition(m, d, offset=10, size=40, flag='extended') + m, d = make_model_and_disk(size=100, ptable="dos") + make_partition(m, d, offset=10, size=40, flag="extended") make_partition(m, d, offset=50, size=40) - p5 = make_partition(m, d, offset=12, size=38, flag='logical') - self.assertEqual( - ([], 0), - gaps.movable_trailing_partitions_and_gap_size(p5)) + p5 = make_partition(m, d, offset=12, size=38, flag="logical") + self.assertEqual(([], 0), gaps.movable_trailing_partitions_and_gap_size(p5)) def test_gap_in_extended_partition(self): - self.use_alignment_data(PartitionAlignmentData( - part_align=1, min_gap_size=1, min_start_offset=10, - min_end_offset=10, primary_part_limit=10, ebr_space=2)) + self.use_alignment_data( + PartitionAlignmentData( + part_align=1, + min_gap_size=1, + min_start_offset=10, + min_end_offset=10, + primary_part_limit=10, + ebr_space=2, + ) + ) # 0----10---20---30---40---50---60---70---80---90---100 # #####[ p1 (extended) ] ##### # ######[ p5 (logical)] ##### - m, d = make_model_and_disk(size=100, ptable='dos') - make_partition(m, d, offset=10, size=40, flag='extended') - p5 = make_partition(m, d, offset=12, size=30, flag='logical') - self.assertEqual( - ([], 6), - gaps.movable_trailing_partitions_and_gap_size(p5)) + m, d = make_model_and_disk(size=100, ptable="dos") + make_partition(m, d, offset=10, size=40, flag="extended") + p5 = make_partition(m, d, offset=12, size=30, flag="logical") + self.assertEqual(([], 6), gaps.movable_trailing_partitions_and_gap_size(p5)) def test_trailing_logical_partition_then_gap(self): - self.use_alignment_data(PartitionAlignmentData( - part_align=1, min_gap_size=1, min_start_offset=10, - min_end_offset=10, primary_part_limit=10, ebr_space=2)) + self.use_alignment_data( + PartitionAlignmentData( + part_align=1, + min_gap_size=1, + min_start_offset=10, + min_end_offset=10, + primary_part_limit=10, + ebr_space=2, + ) + ) # 0----10---20---30---40---50---60---70---80---90---100 # #####[ p1 (extended) ]##### # ######[ p5 (logical)] [ p6 (logical)] ##### - m, d = make_model_and_disk(size=100, ptable='dos') - make_partition(m, d, offset=10, size=80, flag='extended') - p5 = make_partition(m, d, offset=12, size=30, flag='logical') - p6 = make_partition(m, d, offset=44, size=30, flag='logical') - self.assertEqual( - ([p6], 14), - gaps.movable_trailing_partitions_and_gap_size(p5)) + m, d = make_model_and_disk(size=100, ptable="dos") + make_partition(m, d, offset=10, size=80, flag="extended") + p5 = make_partition(m, d, offset=12, size=30, flag="logical") + p6 = make_partition(m, d, offset=44, size=30, flag="logical") + self.assertEqual(([p6], 14), gaps.movable_trailing_partitions_and_gap_size(p5)) def test_trailing_logical_partition_then_no_gap(self): - self.use_alignment_data(PartitionAlignmentData( - part_align=1, min_gap_size=1, min_start_offset=10, - min_end_offset=10, primary_part_limit=10, ebr_space=2)) + self.use_alignment_data( + PartitionAlignmentData( + part_align=1, + min_gap_size=1, + min_start_offset=10, + min_end_offset=10, + primary_part_limit=10, + ebr_space=2, + ) + ) # 0----10---20---30---40---50---60---70---80---90---100 # #####[ p1 (extended) ]##### # ######[ p5 (logical)] [ p6 (logical) ] ##### - m, d = make_model_and_disk(size=100, ptable='dos') - make_partition(m, d, offset=10, size=80, flag='extended') - p5 = make_partition(m, d, offset=12, size=30, flag='logical') - p6 = make_partition(m, d, offset=44, size=44, flag='logical') - self.assertEqual( - ([p6], 0), - gaps.movable_trailing_partitions_and_gap_size(p5)) + m, d = make_model_and_disk(size=100, ptable="dos") + make_partition(m, d, offset=10, size=80, flag="extended") + p5 = make_partition(m, d, offset=12, size=30, flag="logical") + p6 = make_partition(m, d, offset=44, size=44, flag="logical") + self.assertEqual(([p6], 0), gaps.movable_trailing_partitions_and_gap_size(p5)) def test_trailing_preserved_partition(self): - self.use_alignment_data(PartitionAlignmentData( - part_align=10, min_gap_size=1, min_start_offset=10, - min_end_offset=10, primary_part_limit=10)) + self.use_alignment_data( + PartitionAlignmentData( + part_align=10, + min_gap_size=1, + min_start_offset=10, + min_end_offset=10, + primary_part_limit=10, + ) + ) # 0----10---20---30---40---50---60---70---80---90---100 # #####[ p1 ][ p2 p ] ##### m, d = make_model_and_disk(size=100) p1 = make_partition(m, d, offset=10, size=40) make_partition(m, d, offset=50, size=20, preserve=True) - self.assertEqual( - ([], 0), - gaps.movable_trailing_partitions_and_gap_size(p1)) + self.assertEqual(([], 0), gaps.movable_trailing_partitions_and_gap_size(p1)) def test_trailing_lvm_logical_volume_no_gap(self): model, lv = make_model_and_lv() - self.assertEqual( - ([], 0), - gaps.movable_trailing_partitions_and_gap_size(lv)) + self.assertEqual(([], 0), gaps.movable_trailing_partitions_and_gap_size(lv)) def test_trailing_lvm_logical_volume_with_gap(self): - model, disk = make_model_and_disk( - size=100 * LVM_CHUNK_SIZE + LVM_OVERHEAD) + model, disk = make_model_and_disk(size=100 * LVM_CHUNK_SIZE + LVM_OVERHEAD) vg = make_vg(model, pvs={disk}) lv = make_lv(model, vg=vg, size=40 * LVM_CHUNK_SIZE) self.assertEqual( - ([], LVM_CHUNK_SIZE * 60), - gaps.movable_trailing_partitions_and_gap_size(lv)) + ([], LVM_CHUNK_SIZE * 60), gaps.movable_trailing_partitions_and_gap_size(lv) + ) class TestLargestGaps(unittest.TestCase): @@ -667,6 +779,7 @@ class TestLargestGaps(unittest.TestCase): class TestUsable(unittest.TestCase): def test_strings(self): - self.assertEqual('YES', GapUsable.YES.name) - self.assertEqual('TOO_MANY_PRIMARY_PARTS', - GapUsable.TOO_MANY_PRIMARY_PARTS.name) + self.assertEqual("YES", GapUsable.YES.name) + self.assertEqual( + "TOO_MANY_PRIMARY_PARTS", GapUsable.TOO_MANY_PRIMARY_PARTS.name + ) diff --git a/subiquity/common/filesystem/tests/test_labels.py b/subiquity/common/filesystem/tests/test_labels.py index a24f7e4d..65174acc 100644 --- a/subiquity/common/filesystem/tests/test_labels.py +++ b/subiquity/common/filesystem/tests/test_labels.py @@ -16,22 +16,17 @@ import unittest -from subiquity.common.filesystem.labels import ( - annotations, - for_client, - usage_labels, - ) +from subiquity.common.filesystem.labels import annotations, for_client, usage_labels from subiquity.models.tests.test_filesystem import ( make_model, make_model_and_disk, make_model_and_partition, make_model_and_raid, make_partition, - ) +) class TestAnnotations(unittest.TestCase): - def test_disk_annotations(self): # disks never have annotations model, disk = make_model_and_disk() @@ -42,93 +37,88 @@ class TestAnnotations(unittest.TestCase): def test_partition_annotations(self): model = make_model() part = make_partition(model) - self.assertEqual(annotations(part), ['new']) + self.assertEqual(annotations(part), ["new"]) part.preserve = True - self.assertEqual(annotations(part), ['existing']) + self.assertEqual(annotations(part), ["existing"]) model = make_model() part = make_partition(model, flag="bios_grub") - self.assertEqual( - annotations(part), ['new', 'BIOS grub spacer']) + self.assertEqual(annotations(part), ["new", "BIOS grub spacer"]) part.preserve = True self.assertEqual( - annotations(part), - ['existing', 'unconfigured', 'BIOS grub spacer']) + annotations(part), ["existing", "unconfigured", "BIOS grub spacer"] + ) part.device.grub_device = True self.assertEqual( - annotations(part), - ['existing', 'configured', 'BIOS grub spacer']) + annotations(part), ["existing", "configured", "BIOS grub spacer"] + ) model = make_model() part = make_partition(model, flag="boot", grub_device=True) - self.assertEqual(annotations(part), ['new', 'backup ESP']) + self.assertEqual(annotations(part), ["new", "backup ESP"]) fs = model.add_filesystem(part, fstype="fat32") model.add_mount(fs, "/boot/efi") - self.assertEqual(annotations(part), ['new', 'primary ESP']) + self.assertEqual(annotations(part), ["new", "primary ESP"]) model = make_model() part = make_partition(model, flag="boot", preserve=True) - self.assertEqual(annotations(part), ['existing', 'unused ESP']) + self.assertEqual(annotations(part), ["existing", "unused ESP"]) part.grub_device = True - self.assertEqual(annotations(part), ['existing', 'backup ESP']) + self.assertEqual(annotations(part), ["existing", "backup ESP"]) fs = model.add_filesystem(part, fstype="fat32") model.add_mount(fs, "/boot/efi") - self.assertEqual(annotations(part), ['existing', 'primary ESP']) + self.assertEqual(annotations(part), ["existing", "primary ESP"]) model = make_model() part = make_partition(model, flag="prep", grub_device=True) - self.assertEqual(annotations(part), ['new', 'PReP']) + self.assertEqual(annotations(part), ["new", "PReP"]) model = make_model() part = make_partition(model, flag="prep", preserve=True) - self.assertEqual( - annotations(part), ['existing', 'PReP', 'unconfigured']) + self.assertEqual(annotations(part), ["existing", "PReP", "unconfigured"]) part.grub_device = True - self.assertEqual( - annotations(part), ['existing', 'PReP', 'configured']) + self.assertEqual(annotations(part), ["existing", "PReP", "configured"]) def test_vg_default_annotations(self): model, disk = make_model_and_disk() - vg = model.add_volgroup('vg-0', {disk}) - self.assertEqual(annotations(vg), ['new']) + vg = model.add_volgroup("vg-0", {disk}) + self.assertEqual(annotations(vg), ["new"]) vg.preserve = True - self.assertEqual(annotations(vg), ['existing']) + self.assertEqual(annotations(vg), ["existing"]) def test_vg_encrypted_annotations(self): model, disk = make_model_and_disk() - dm_crypt = model.add_dm_crypt(disk, key='passw0rd') - vg = model.add_volgroup('vg-0', {dm_crypt}) - self.assertEqual(annotations(vg), ['new', 'encrypted']) + dm_crypt = model.add_dm_crypt(disk, key="passw0rd") + vg = model.add_volgroup("vg-0", {dm_crypt}) + self.assertEqual(annotations(vg), ["new", "encrypted"]) class TestUsageLabels(unittest.TestCase): - def test_partition_usage_labels(self): model, partition = make_model_and_partition() self.assertEqual(usage_labels(partition), ["unused"]) - fs = model.add_filesystem(partition, 'ext4') + fs = model.add_filesystem(partition, "ext4") self.assertEqual( - usage_labels(partition), - ["to be formatted as ext4", "not mounted"]) + usage_labels(partition), ["to be formatted as ext4", "not mounted"] + ) model._orig_config = model._render_actions() fs.preserve = True partition.preserve = True self.assertEqual( - usage_labels(partition), - ["already formatted as ext4", "not mounted"]) + usage_labels(partition), ["already formatted as ext4", "not mounted"] + ) model.remove_filesystem(fs) - fs2 = model.add_filesystem(partition, 'ext4') + fs2 = model.add_filesystem(partition, "ext4") self.assertEqual( - usage_labels(partition), - ["to be reformatted as ext4", "not mounted"]) - model.add_mount(fs2, '/') + usage_labels(partition), ["to be reformatted as ext4", "not mounted"] + ) + model.add_mount(fs2, "/") self.assertEqual( - usage_labels(partition), - ["to be reformatted as ext4", "mounted at /"]) + usage_labels(partition), ["to be reformatted as ext4", "mounted at /"] + ) class TestForClient(unittest.TestCase): - def test_for_client_raid_parts(self): model, raid = make_model_and_raid() make_partition(model, raid) diff --git a/subiquity/common/filesystem/tests/test_manipulator.py b/subiquity/common/filesystem/tests/test_manipulator.py index 3eb7a328..64622dce 100644 --- a/subiquity/common/filesystem/tests/test_manipulator.py +++ b/subiquity/common/filesystem/tests/test_manipulator.py @@ -19,24 +19,17 @@ from unittest import mock import attr -from subiquitycore.tests.parameterized import parameterized - -from subiquity.common.filesystem.actions import ( - DeviceAction, - ) from subiquity.common.filesystem import boot, gaps, sizes +from subiquity.common.filesystem.actions import DeviceAction from subiquity.common.filesystem.manipulator import FilesystemManipulator +from subiquity.models.filesystem import Bootloader, MiB, Partition from subiquity.models.tests.test_filesystem import ( make_disk, + make_filesystem, make_model, make_partition, - make_filesystem, - ) -from subiquity.models.filesystem import ( - Bootloader, - MiB, - Partition, - ) +) +from subiquitycore.tests.parameterized import parameterized def make_manipulator(bootloader=None, storage_version=1): @@ -78,18 +71,16 @@ class Create: class TestFilesystemManipulator(unittest.TestCase): - def test_delete_encrypted_vg(self): manipulator, disk = make_manipulator_and_disk() spec = { - 'passphrase': 'passw0rd', - 'devices': {disk}, - 'name': 'vg0', - } + "passphrase": "passw0rd", + "devices": {disk}, + "name": "vg0", + } vg = manipulator.create_volgroup(spec) manipulator.delete_volgroup(vg) - dm_crypts = [ - a for a in manipulator.model._actions if a.type == 'dm_crypt'] + dm_crypts = [a for a in manipulator.model._actions if a.type == "dm_crypt"] self.assertEqual(dm_crypts, []) def test_wipe_existing_fs(self): @@ -98,18 +89,19 @@ class TestFilesystemManipulator(unittest.TestCase): # example: it's ext2, wipe and make it ext2 again manipulator, d = make_manipulator_and_disk() p = make_partition(manipulator.model, d, preserve=True) - fs = make_filesystem(manipulator.model, partition=p, fstype='ext2', - preserve=True) + fs = make_filesystem( + manipulator.model, partition=p, fstype="ext2", preserve=True + ) manipulator.model._actions.append(fs) manipulator.delete_filesystem(fs) - spec = {'wipe': 'random'} - with mock.patch.object(p, 'original_fstype') as original_fstype: - original_fstype.return_value = 'ext2' + spec = {"wipe": "random"} + with mock.patch.object(p, "original_fstype") as original_fstype: + original_fstype.return_value = "ext2" fs = manipulator.create_filesystem(p, spec) self.assertTrue(p.preserve) - self.assertEqual('random', p.wipe) + self.assertEqual("random", p.wipe) self.assertFalse(fs.preserve) - self.assertEqual('ext2', fs.fstype) + self.assertEqual("ext2", fs.fstype) def test_can_only_add_boot_once(self): # This is really testing model code but it's much easier to test with a @@ -122,7 +114,8 @@ class TestFilesystemManipulator(unittest.TestCase): self.assertFalse( DeviceAction.TOGGLE_BOOT.can(disk)[0], "add_boot_disk(disk) did not make _can_TOGGLE_BOOT false " - "with bootloader {}".format(bl)) + "with bootloader {}".format(bl), + ) def assertIsMountedAtBootEFI(self, device): efi_mnts = device._m._all(type="mount", path="/boot/efi") @@ -136,20 +129,23 @@ class TestFilesystemManipulator(unittest.TestCase): def add_existing_boot_partition(self, manipulator, disk): if manipulator.model.bootloader == Bootloader.BIOS: part = manipulator.model.add_partition( - disk, size=1 << 20, offset=0, flag="bios_grub") + disk, size=1 << 20, offset=0, flag="bios_grub" + ) elif manipulator.model.bootloader == Bootloader.UEFI: part = manipulator.model.add_partition( - disk, size=512 << 20, offset=0, flag="boot") + disk, size=512 << 20, offset=0, flag="boot" + ) elif manipulator.model.bootloader == Bootloader.PREP: part = manipulator.model.add_partition( - disk, size=8 << 20, offset=0, flag="prep") + disk, size=8 << 20, offset=0, flag="prep" + ) part.preserve = True return part def assertIsBootDisk(self, manipulator, disk): if manipulator.model.bootloader == Bootloader.BIOS: self.assertTrue(disk.grub_device) - if disk.ptable == 'gpt': + if disk.ptable == "gpt": self.assertEqual(disk.partitions()[0].flag, "bios_grub") elif manipulator.model.bootloader == Bootloader.UEFI: for part in disk.partitions(): @@ -159,7 +155,7 @@ class TestFilesystemManipulator(unittest.TestCase): elif manipulator.model.bootloader == Bootloader.PREP: for part in disk.partitions(): if part.flag == "prep" and part.grub_device: - self.assertEqual(part.wipe, 'zero') + self.assertEqual(part.wipe, "zero") return self.fail("{} is not a boot disk".format(disk)) @@ -186,7 +182,8 @@ class TestFilesystemManipulator(unittest.TestCase): disk2 = make_disk(manipulator.model, preserve=False) gap = gaps.largest_gap(disk2) disk2p1 = manipulator.model.add_partition( - disk2, size=gap.size, offset=gap.offset) + disk2, size=gap.size, offset=gap.offset + ) manipulator.add_boot_disk(disk1) self.assertIsBootDisk(manipulator, disk1) @@ -202,8 +199,7 @@ class TestFilesystemManipulator(unittest.TestCase): self.assertNotMounted(disk2.partitions()[0]) self.assertEqual(len(disk2.partitions()), 2) self.assertEqual(disk2.partitions()[1], disk2p1) - self.assertEqual( - disk2.partitions()[0].size + disk2p1.size, size_before) + self.assertEqual(disk2.partitions()[0].size + disk2p1.size, size_before) manipulator.remove_boot_disk(disk1) self.assertIsNotBootDisk(manipulator, disk1) @@ -228,7 +224,8 @@ class TestFilesystemManipulator(unittest.TestCase): disk2 = make_disk(manipulator.model, preserve=False) gap = gaps.largest_gap(disk2) disk2p1 = manipulator.model.add_partition( - disk2, size=gap.size, offset=gap.offset) + disk2, size=gap.size, offset=gap.offset + ) manipulator.add_boot_disk(disk1) self.assertIsBootDisk(manipulator, disk1) @@ -243,8 +240,7 @@ class TestFilesystemManipulator(unittest.TestCase): self.assertIsMountedAtBootEFI(disk2.partitions()[0]) self.assertEqual(len(disk2.partitions()), 2) self.assertEqual(disk2.partitions()[1], disk2p1) - self.assertEqual( - disk2.partitions()[0].size + disk2p1.size, size_before) + self.assertEqual(disk2.partitions()[0].size + disk2p1.size, size_before) def test_boot_disk_existing(self): for bl in Bootloader: @@ -272,13 +268,16 @@ class TestFilesystemManipulator(unittest.TestCase): manipulator = make_manipulator(Bootloader.UEFI) disk1 = make_disk(manipulator.model, preserve=True) disk1p1 = manipulator.model.add_partition( - disk1, size=512 << 20, offset=0, flag="boot") + disk1, size=512 << 20, offset=0, flag="boot" + ) disk1p1.preserve = True disk1p2 = manipulator.model.add_partition( - disk1, size=8192 << 20, offset=513 << 20) + disk1, size=8192 << 20, offset=513 << 20 + ) disk1p2.preserve = True manipulator.partition_disk_handler( - disk1, {'fstype': 'ext4', 'mount': '/'}, partition=disk1p2) + disk1, {"fstype": "ext4", "mount": "/"}, partition=disk1p2 + ) efi_mnt = manipulator.model._mount_for_path("/boot/efi") self.assertEqual(efi_mnt.device.volume, disk1p1) @@ -295,14 +294,15 @@ class TestFilesystemManipulator(unittest.TestCase): @contextlib.contextmanager def assertPartitionOperations(self, disk, *ops): - existing_parts = set(disk.partitions()) part_details = {} for op in ops: if isinstance(op, MoveResize): part_details[op.part] = ( - op.part.offset + op.offset, op.part.size + op.size) + op.part.offset + op.offset, + op.part.size + op.size, + ) try: yield @@ -314,8 +314,8 @@ class TestFilesystemManipulator(unittest.TestCase): if isinstance(op, MoveResize): new_parts.remove(op.part) self.assertEqual( - part_details[op.part], - (op.part.offset, op.part.size)) + part_details[op.part], (op.part.offset, op.part.size) + ) else: for part in created_parts: lhs = (part.offset, part.size) @@ -336,11 +336,12 @@ class TestFilesystemManipulator(unittest.TestCase): manipulator = make_manipulator(Bootloader.BIOS, version) disk = make_disk(manipulator.model, preserve=True) with self.assertPartitionOperations( - disk, - Create( - offset=disk.alignment_data().min_start_offset, - size=sizes.BIOS_GRUB_SIZE_BYTES), - ): + disk, + Create( + offset=disk.alignment_data().min_start_offset, + size=sizes.BIOS_GRUB_SIZE_BYTES, + ), + ): manipulator.add_boot_disk(disk) self.assertIsBootDisk(manipulator, disk) @@ -349,19 +350,20 @@ class TestFilesystemManipulator(unittest.TestCase): def test_add_boot_BIOS_full(self, version): manipulator = make_manipulator(Bootloader.BIOS, version) disk = make_disk(manipulator.model, preserve=True) - part = make_partition( - manipulator.model, disk, size=gaps.largest_gap_size(disk)) + part = make_partition(manipulator.model, disk, size=gaps.largest_gap_size(disk)) with self.assertPartitionOperations( - disk, - Create( - offset=disk.alignment_data().min_start_offset, - size=sizes.BIOS_GRUB_SIZE_BYTES), - MoveResize( - part=part, - offset=sizes.BIOS_GRUB_SIZE_BYTES, - size=-sizes.BIOS_GRUB_SIZE_BYTES), - ): + disk, + Create( + offset=disk.alignment_data().min_start_offset, + size=sizes.BIOS_GRUB_SIZE_BYTES, + ), + MoveResize( + part=part, + offset=sizes.BIOS_GRUB_SIZE_BYTES, + size=-sizes.BIOS_GRUB_SIZE_BYTES, + ), + ): manipulator.add_boot_disk(disk) self.assertIsBootDisk(manipulator, disk) @@ -371,15 +373,16 @@ class TestFilesystemManipulator(unittest.TestCase): manipulator = make_manipulator(Bootloader.BIOS, version) disk = make_disk(manipulator.model, preserve=True) part = make_partition( - manipulator.model, disk, size=gaps.largest_gap_size(disk)//2) + manipulator.model, disk, size=gaps.largest_gap_size(disk) // 2 + ) with self.assertPartitionOperations( - disk, - Create( - offset=disk.alignment_data().min_start_offset, - size=sizes.BIOS_GRUB_SIZE_BYTES), - Move( - part=part, offset=sizes.BIOS_GRUB_SIZE_BYTES), - ): + disk, + Create( + offset=disk.alignment_data().min_start_offset, + size=sizes.BIOS_GRUB_SIZE_BYTES, + ), + Move(part=part, offset=sizes.BIOS_GRUB_SIZE_BYTES), + ): manipulator.add_boot_disk(disk) self.assertIsBootDisk(manipulator, disk) @@ -388,23 +391,26 @@ class TestFilesystemManipulator(unittest.TestCase): manipulator = make_manipulator(Bootloader.BIOS) # 2002MiB so that the space available for partitioning (2000MiB) # divided by 4 is an whole number of megabytes. - disk = make_disk(manipulator.model, preserve=True, size=2002*MiB) + disk = make_disk(manipulator.model, preserve=True, size=2002 * MiB) part_smaller = make_partition( - manipulator.model, disk, size=gaps.largest_gap_size(disk)//4) + manipulator.model, disk, size=gaps.largest_gap_size(disk) // 4 + ) part_larger = make_partition( - manipulator.model, disk, size=gaps.largest_gap_size(disk)) + manipulator.model, disk, size=gaps.largest_gap_size(disk) + ) with self.assertPartitionOperations( - disk, - Create( - offset=disk.alignment_data().min_start_offset, - size=sizes.BIOS_GRUB_SIZE_BYTES), - Move( - part=part_smaller, offset=sizes.BIOS_GRUB_SIZE_BYTES), - MoveResize( - part=part_larger, - offset=sizes.BIOS_GRUB_SIZE_BYTES, - size=-sizes.BIOS_GRUB_SIZE_BYTES), - ): + disk, + Create( + offset=disk.alignment_data().min_start_offset, + size=sizes.BIOS_GRUB_SIZE_BYTES, + ), + Move(part=part_smaller, offset=sizes.BIOS_GRUB_SIZE_BYTES), + MoveResize( + part=part_larger, + offset=sizes.BIOS_GRUB_SIZE_BYTES, + size=-sizes.BIOS_GRUB_SIZE_BYTES, + ), + ): manipulator.add_boot_disk(disk) self.assertIsBootDisk(manipulator, disk) @@ -413,33 +419,33 @@ class TestFilesystemManipulator(unittest.TestCase): manipulator = make_manipulator(Bootloader.BIOS, version) disk = make_disk(manipulator.model, preserve=True) full_size = gaps.largest_gap_size(disk) - make_partition( - manipulator.model, disk, size=full_size, preserve=True) + make_partition(manipulator.model, disk, size=full_size, preserve=True) self.assertFalse(boot.can_be_boot_device(disk)) def test_no_add_boot_BIOS_preserved_v1(self): manipulator = make_manipulator(Bootloader.BIOS, 1) disk = make_disk(manipulator.model, preserve=True) - half_size = gaps.largest_gap_size(disk)//2 + half_size = gaps.largest_gap_size(disk) // 2 make_partition( - manipulator.model, disk, size=half_size, offset=half_size, - preserve=True) + manipulator.model, disk, size=half_size, offset=half_size, preserve=True + ) self.assertFalse(boot.can_be_boot_device(disk)) def test_add_boot_BIOS_preserved_v2(self): manipulator = make_manipulator(Bootloader.BIOS, 2) disk = make_disk(manipulator.model, preserve=True) - half_size = gaps.largest_gap_size(disk)//2 + half_size = gaps.largest_gap_size(disk) // 2 part = make_partition( - manipulator.model, disk, size=half_size, offset=half_size, - preserve=True) + manipulator.model, disk, size=half_size, offset=half_size, preserve=True + ) with self.assertPartitionOperations( - disk, - Create( - offset=disk.alignment_data().min_start_offset, - size=sizes.BIOS_GRUB_SIZE_BYTES), - Unchanged(part=part), - ): + disk, + Create( + offset=disk.alignment_data().min_start_offset, + size=sizes.BIOS_GRUB_SIZE_BYTES, + ), + Unchanged(part=part), + ): manipulator.add_boot_disk(disk) self.assertIsBootDisk(manipulator, disk) @@ -447,23 +453,25 @@ class TestFilesystemManipulator(unittest.TestCase): manipulator = make_manipulator(Bootloader.BIOS, 2) # 2002MiB so that the space available for partitioning (2000MiB) # divided by 4 is an whole number of megabytes. - disk = make_disk(manipulator.model, preserve=True, size=2002*MiB) + disk = make_disk(manipulator.model, preserve=True, size=2002 * MiB) avail = gaps.largest_gap_size(disk) - p1_new = make_partition( - manipulator.model, disk, size=avail//4) + p1_new = make_partition(manipulator.model, disk, size=avail // 4) p2_preserved = make_partition( - manipulator.model, disk, size=3*avail//4, preserve=True) + manipulator.model, disk, size=3 * avail // 4, preserve=True + ) with self.assertPartitionOperations( - disk, - Create( - offset=disk.alignment_data().min_start_offset, - size=sizes.BIOS_GRUB_SIZE_BYTES), - MoveResize( - part=p1_new, - offset=sizes.BIOS_GRUB_SIZE_BYTES, - size=-sizes.BIOS_GRUB_SIZE_BYTES), - Unchanged(part=p2_preserved), - ): + disk, + Create( + offset=disk.alignment_data().min_start_offset, + size=sizes.BIOS_GRUB_SIZE_BYTES, + ), + MoveResize( + part=p1_new, + offset=sizes.BIOS_GRUB_SIZE_BYTES, + size=-sizes.BIOS_GRUB_SIZE_BYTES, + ), + Unchanged(part=p2_preserved), + ): manipulator.add_boot_disk(disk) self.assertIsBootDisk(manipulator, disk) @@ -471,10 +479,8 @@ class TestFilesystemManipulator(unittest.TestCase): def test_no_add_boot_BIOS_preserved_at_start(self, version): manipulator = make_manipulator(Bootloader.BIOS, version) disk = make_disk(manipulator.model, preserve=True) - half_size = gaps.largest_gap_size(disk)//2 - make_partition( - manipulator.model, disk, size=half_size, - preserve=True) + half_size = gaps.largest_gap_size(disk) // 2 + make_partition(manipulator.model, disk, size=half_size, preserve=True) self.assertFalse(boot.can_be_boot_device(disk)) @parameterized.expand([(1,), (2,)]) @@ -482,20 +488,20 @@ class TestFilesystemManipulator(unittest.TestCase): # Showed up in VM testing - the ISO created cloud-localds shows up as a # potential install target, and weird things happen. manipulator = make_manipulator(Bootloader.BIOS, version) - disk = make_disk(manipulator.model, ptable='gpt', size=1 << 20) + disk = make_disk(manipulator.model, ptable="gpt", size=1 << 20) self.assertFalse(boot.can_be_boot_device(disk)) @parameterized.expand([(1,), (2,)]) def test_add_boot_BIOS_msdos(self, version): manipulator = make_manipulator(Bootloader.BIOS, version) - disk = make_disk(manipulator.model, preserve=True, ptable='msdos') + disk = make_disk(manipulator.model, preserve=True, ptable="msdos") part = make_partition( - manipulator.model, disk, size=gaps.largest_gap_size(disk), - preserve=True) + manipulator.model, disk, size=gaps.largest_gap_size(disk), preserve=True + ) with self.assertPartitionOperations( - disk, - Unchanged(part=part), - ): + disk, + Unchanged(part=part), + ): manipulator.add_boot_disk(disk) self.assertIsBootDisk(manipulator, disk) @@ -512,18 +518,20 @@ class TestFilesystemManipulator(unittest.TestCase): (Bootloader.UEFI, 1), (Bootloader.PREP, 1), (Bootloader.UEFI, 2), - (Bootloader.PREP, 2)] + (Bootloader.PREP, 2), + ] @parameterized.expand(ADD_BOOT_PARAMS) def test_add_boot_empty(self, bl, version): manipulator = make_manipulator(bl, version) disk = make_disk(manipulator.model, preserve=True) with self.assertPartitionOperations( - disk, - Create( - offset=disk.alignment_data().min_start_offset, - size=self.boot_size_for_disk(disk)), - ): + disk, + Create( + offset=disk.alignment_data().min_start_offset, + size=self.boot_size_for_disk(disk), + ), + ): manipulator.add_boot_disk(disk) self.assertIsBootDisk(manipulator, disk) @@ -531,19 +539,13 @@ class TestFilesystemManipulator(unittest.TestCase): def test_add_boot_full(self, bl, version): manipulator = make_manipulator(bl, version) disk = make_disk(manipulator.model, preserve=True) - part = make_partition( - manipulator.model, disk, size=gaps.largest_gap_size(disk)) + part = make_partition(manipulator.model, disk, size=gaps.largest_gap_size(disk)) size = self.boot_size_for_disk(disk) with self.assertPartitionOperations( - disk, - Create( - offset=disk.alignment_data().min_start_offset, - size=size), - MoveResize( - part=part, - offset=size, - size=-size), - ): + disk, + Create(offset=disk.alignment_data().min_start_offset, size=size), + MoveResize(part=part, offset=size, size=-size), + ): manipulator.add_boot_disk(disk) self.assertIsBootDisk(manipulator, disk) @@ -552,15 +554,14 @@ class TestFilesystemManipulator(unittest.TestCase): manipulator = make_manipulator(bl, version) disk = make_disk(manipulator.model, preserve=True) part = make_partition( - manipulator.model, disk, size=gaps.largest_gap_size(disk)//2) + manipulator.model, disk, size=gaps.largest_gap_size(disk) // 2 + ) size = self.boot_size_for_disk(disk) with self.assertPartitionOperations( - disk, - Unchanged(part=part), - Create( - offset=part.offset + part.size, - size=size), - ): + disk, + Unchanged(part=part), + Create(offset=part.offset + part.size, size=size), + ): manipulator.add_boot_disk(disk) self.assertIsBootDisk(manipulator, disk) @@ -569,23 +570,20 @@ class TestFilesystemManipulator(unittest.TestCase): manipulator = make_manipulator(bl, version) # 2002MiB so that the space available for partitioning (2000MiB) # divided by 4 is an whole number of megabytes. - disk = make_disk(manipulator.model, preserve=True, size=2002*MiB) + disk = make_disk(manipulator.model, preserve=True, size=2002 * MiB) part_smaller = make_partition( - manipulator.model, disk, size=gaps.largest_gap_size(disk)//4) + manipulator.model, disk, size=gaps.largest_gap_size(disk) // 4 + ) part_larger = make_partition( - manipulator.model, disk, size=gaps.largest_gap_size(disk)) + manipulator.model, disk, size=gaps.largest_gap_size(disk) + ) size = self.boot_size_for_disk(disk) with self.assertPartitionOperations( - disk, - Unchanged(part_smaller), - Create( - offset=part_smaller.offset + part_smaller.size, - size=size), - MoveResize( - part=part_larger, - offset=size, - size=-size), - ): + disk, + Unchanged(part_smaller), + Create(offset=part_smaller.offset + part_smaller.size, size=size), + MoveResize(part=part_larger, offset=size, size=-size), + ): manipulator.add_boot_disk(disk) self.assertIsBootDisk(manipulator, disk) @@ -594,29 +592,24 @@ class TestFilesystemManipulator(unittest.TestCase): manipulator = make_manipulator(bl, version) disk = make_disk(manipulator.model, preserve=True) full_size = gaps.largest_gap_size(disk) - make_partition( - manipulator.model, disk, size=full_size, preserve=True) + make_partition(manipulator.model, disk, size=full_size, preserve=True) self.assertFalse(boot.can_be_boot_device(disk)) @parameterized.expand(ADD_BOOT_PARAMS) def test_add_boot_half_preserved_partition(self, bl, version): manipulator = make_manipulator(bl, version) disk = make_disk(manipulator.model, preserve=True) - half_size = gaps.largest_gap_size(disk)//2 - part = make_partition( - manipulator.model, disk, size=half_size, - preserve=True) + half_size = gaps.largest_gap_size(disk) // 2 + part = make_partition(manipulator.model, disk, size=half_size, preserve=True) size = self.boot_size_for_disk(disk) if version == 1: self.assertFalse(boot.can_be_boot_device(disk)) else: with self.assertPartitionOperations( - disk, - Unchanged(part), - Create( - offset=part.offset + part.size, - size=size), - ): + disk, + Unchanged(part), + Create(offset=part.offset + part.size, size=size), + ): manipulator.add_boot_disk(disk) self.assertIsBootDisk(manipulator, disk) @@ -626,84 +619,77 @@ class TestFilesystemManipulator(unittest.TestCase): # situation in a v2 world. manipulator = make_manipulator(bl, 2) disk = make_disk(manipulator.model, preserve=True) - half_size = gaps.largest_gap_size(disk)//2 - old_part = make_partition( - manipulator.model, disk, size=half_size) - new_part = make_partition( - manipulator.model, disk, size=half_size) + half_size = gaps.largest_gap_size(disk) // 2 + old_part = make_partition(manipulator.model, disk, size=half_size) + new_part = make_partition(manipulator.model, disk, size=half_size) old_part.preserve = True size = self.boot_size_for_disk(disk) with self.assertPartitionOperations( - disk, - Unchanged(old_part), - Create( - offset=new_part.offset, - size=size), - MoveResize( - part=new_part, - offset=size, - size=-size), - ): + disk, + Unchanged(old_part), + Create(offset=new_part.offset, size=size), + MoveResize(part=new_part, offset=size, size=-size), + ): manipulator.add_boot_disk(disk) self.assertIsBootDisk(manipulator, disk) @parameterized.expand(ADD_BOOT_PARAMS) def test_add_boot_existing_boot_partition(self, bl, version): manipulator = make_manipulator(bl, version) - disk = make_disk(manipulator.model, preserve=True, size=2002*MiB) + disk = make_disk(manipulator.model, preserve=True, size=2002 * MiB) if bl == Bootloader.UEFI: - flag = 'boot' + flag = "boot" else: - flag = 'prep' + flag = "prep" boot_part = make_partition( - manipulator.model, disk, size=self.boot_size_for_disk(disk), - flag=flag) + manipulator.model, disk, size=self.boot_size_for_disk(disk), flag=flag + ) rest_part = make_partition(manipulator.model, disk) boot_part.preserve = rest_part.preserve = True with self.assertPartitionOperations( - disk, - Unchanged(boot_part), - Unchanged(rest_part), - ): + disk, + Unchanged(boot_part), + Unchanged(rest_part), + ): manipulator.add_boot_disk(disk) self.assertIsBootDisk(manipulator, disk) def test_logical_no_esp_in_gap(self): manipulator = make_manipulator(Bootloader.UEFI, storage_version=2) m = manipulator.model - d1 = make_disk(m, preserve=True, ptable='msdos') + d1 = make_disk(m, preserve=True, ptable="msdos") size = gaps.largest_gap_size(d1) - make_partition(m, d1, flag='extended', preserve=True, size=size) + make_partition(m, d1, flag="extended", preserve=True, size=size) size = gaps.largest_gap_size(d1) // 2 - make_partition(m, d1, flag='logical', preserve=True, size=size) + make_partition(m, d1, flag="logical", preserve=True, size=size) self.assertTrue(gaps.largest_gap_size(d1) > sizes.uefi_scale.maximum) self.assertFalse(boot.can_be_boot_device(d1)) def test_logical_no_esp_by_resize(self): manipulator = make_manipulator(Bootloader.UEFI, storage_version=2) m = manipulator.model - d1 = make_disk(m, preserve=True, ptable='msdos') + d1 = make_disk(m, preserve=True, ptable="msdos") size = gaps.largest_gap_size(d1) - make_partition(m, d1, flag='extended', preserve=True, size=size) + make_partition(m, d1, flag="extended", preserve=True, size=size) size = gaps.largest_gap_size(d1) - p5 = make_partition(m, d1, flag='logical', preserve=True, size=size) + p5 = make_partition(m, d1, flag="logical", preserve=True, size=size) self.assertFalse(boot.can_be_boot_device(d1, resize_partition=p5)) def test_logical_no_esp_in_new_part(self): manipulator = make_manipulator(Bootloader.UEFI, storage_version=2) m = manipulator.model - d1 = make_disk(m, preserve=True, ptable='msdos') + d1 = make_disk(m, preserve=True, ptable="msdos") size = gaps.largest_gap_size(d1) - make_partition(m, d1, flag='extended', size=size) + make_partition(m, d1, flag="extended", size=size) size = gaps.largest_gap_size(d1) - make_partition(m, d1, flag='logical', size=size) + make_partition(m, d1, flag="logical", size=size) self.assertFalse(boot.can_be_boot_device(d1)) @parameterized.expand([[Bootloader.UEFI], [Bootloader.PREP]]) def test_too_many_primaries(self, bl): manipulator = make_manipulator(bl, storage_version=2) m = manipulator.model - d1 = make_disk(m, preserve=True, ptable='msdos', size=100 << 30) + d1 = make_disk(m, preserve=True, ptable="msdos", size=100 << 30) tenth = 10 << 30 for i in range(4): make_partition(m, d1, preserve=True, size=tenth) @@ -715,12 +701,12 @@ class TestFilesystemManipulator(unittest.TestCase): def test_logical_when_in_extended(self): manipulator = make_manipulator(storage_version=2) m = manipulator.model - d1 = make_disk(m, preserve=True, ptable='msdos') + d1 = make_disk(m, preserve=True, ptable="msdos") size = gaps.largest_gap_size(d1) - make_partition(m, d1, flag='extended', size=size) + make_partition(m, d1, flag="extended", size=size) gap = gaps.largest_gap(d1) - p = manipulator.create_partition(d1, gap, spec={}, flag='primary') - self.assertEqual(p.flag, 'logical') + p = manipulator.create_partition(d1, gap, spec={}, flag="primary") + self.assertEqual(p.flag, "logical") class TestReformat(unittest.TestCase): @@ -733,19 +719,19 @@ class TestReformat(unittest.TestCase): self.assertEqual(None, disk.ptable) def test_reformat_keep_current(self): - disk = make_disk(self.manipulator.model, ptable='msdos') + disk = make_disk(self.manipulator.model, ptable="msdos") self.manipulator.reformat(disk) - self.assertEqual('msdos', disk.ptable) + self.assertEqual("msdos", disk.ptable) def test_reformat_to_gpt(self): disk = make_disk(self.manipulator.model, ptable=None) - self.manipulator.reformat(disk, 'gpt') - self.assertEqual('gpt', disk.ptable) + self.manipulator.reformat(disk, "gpt") + self.assertEqual("gpt", disk.ptable) def test_reformat_to_msdos(self): disk = make_disk(self.manipulator.model, ptable=None) - self.manipulator.reformat(disk, 'msdos') - self.assertEqual('msdos', disk.ptable) + self.manipulator.reformat(disk, "msdos") + self.assertEqual("msdos", disk.ptable) class TestCanResize(unittest.TestCase): @@ -760,11 +746,11 @@ class TestCanResize(unittest.TestCase): def test_resize_ext4(self): disk = make_disk(self.manipulator.model, ptable=None) part = make_partition(self.manipulator.model, disk, preserve=True) - make_filesystem(self.manipulator.model, partition=part, fstype='ext4') + make_filesystem(self.manipulator.model, partition=part, fstype="ext4") self.assertTrue(self.manipulator.can_resize_partition(part)) def test_resize_invalid(self): disk = make_disk(self.manipulator.model, ptable=None) part = make_partition(self.manipulator.model, disk, preserve=True) - make_filesystem(self.manipulator.model, partition=part, fstype='asdf') + make_filesystem(self.manipulator.model, partition=part, fstype="asdf") self.assertFalse(self.manipulator.can_resize_partition(part)) diff --git a/subiquity/common/filesystem/tests/test_sizes.py b/subiquity/common/filesystem/tests/test_sizes.py index 98912b35..84cbfde4 100644 --- a/subiquity/common/filesystem/tests/test_sizes.py +++ b/subiquity/common/filesystem/tests/test_sizes.py @@ -18,19 +18,17 @@ import unittest from unittest import mock from subiquity.common.filesystem.sizes import ( + PartitionScaleFactors, bootfs_scale, calculate_guided_resize, calculate_suggested_install_min, - get_efi_size, get_bootfs_size, - PartitionScaleFactors, + get_efi_size, scale_partitions, uefi_scale, - ) +) from subiquity.common.types import GuidedResizeValues -from subiquity.models.filesystem import ( - MiB, - ) +from subiquity.models.filesystem import MiB class TestPartitionSizeScaling(unittest.TestCase): @@ -87,46 +85,54 @@ class TestPartitionSizeScaling(unittest.TestCase): class TestCalculateGuidedResize(unittest.TestCase): def test_ignore_nonresizable(self): actual = calculate_guided_resize( - part_min=-1, part_size=100 << 30, install_min=10 << 30) + part_min=-1, part_size=100 << 30, install_min=10 << 30 + ) self.assertIsNone(actual) def test_too_small(self): actual = calculate_guided_resize( - part_min=95 << 30, part_size=100 << 30, install_min=10 << 30) + part_min=95 << 30, part_size=100 << 30, install_min=10 << 30 + ) self.assertIsNone(actual) def test_even_split(self): # 8 GiB * 1.25 == 10 GiB size = 10 << 30 actual = calculate_guided_resize( - part_min=8 << 30, part_size=100 << 30, install_min=size) + part_min=8 << 30, part_size=100 << 30, install_min=size + ) expected = GuidedResizeValues( - install_max=(100 << 30) - size, - minimum=size, recommended=50 << 30, maximum=(100 << 30) - size) + install_max=(100 << 30) - size, + minimum=size, + recommended=50 << 30, + maximum=(100 << 30) - size, + ) self.assertEqual(expected, actual) def test_weighted_split(self): actual = calculate_guided_resize( - part_min=40 << 30, part_size=240 << 30, install_min=10 << 30) + part_min=40 << 30, part_size=240 << 30, install_min=10 << 30 + ) expected = GuidedResizeValues( - install_max=190 << 30, - minimum=50 << 30, recommended=200 << 30, maximum=230 << 30) + install_max=190 << 30, + minimum=50 << 30, + recommended=200 << 30, + maximum=230 << 30, + ) self.assertEqual(expected, actual) def test_unaligned_size(self): - def assertAligned(val): if val % MiB != 0: self.fail( - "val of %s not aligned (%s) with delta %s" % ( - val, val % MiB, delta)) + "val of %s not aligned (%s) with delta %s" % (val, val % MiB, delta) + ) for i in range(1000): delta = random.randrange(MiB) actual = calculate_guided_resize( - part_min=40 << 30, - part_size=(240 << 30) + delta, - install_min=10 << 30) + part_min=40 << 30, part_size=(240 << 30) + delta, install_min=10 << 30 + ) assertAligned(actual.install_max) assertAligned(actual.minimum) assertAligned(actual.recommended) @@ -134,8 +140,8 @@ class TestCalculateGuidedResize(unittest.TestCase): class TestCalculateInstallMin(unittest.TestCase): - @mock.patch('subiquity.common.filesystem.sizes.swap.suggested_swapsize') - @mock.patch('subiquity.common.filesystem.sizes.bootfs_scale') + @mock.patch("subiquity.common.filesystem.sizes.swap.suggested_swapsize") + @mock.patch("subiquity.common.filesystem.sizes.bootfs_scale") def test_small_setups(self, bootfs_scale, swapsize): swapsize.return_value = 1 << 30 bootfs_scale.maximum = 1 << 30 @@ -143,24 +149,24 @@ class TestCalculateInstallMin(unittest.TestCase): # with a small source, we hit the default 2GiB padding self.assertEqual(5 << 30, calculate_suggested_install_min(source_min)) - @mock.patch('subiquity.common.filesystem.sizes.swap.suggested_swapsize') - @mock.patch('subiquity.common.filesystem.sizes.bootfs_scale') + @mock.patch("subiquity.common.filesystem.sizes.swap.suggested_swapsize") + @mock.patch("subiquity.common.filesystem.sizes.bootfs_scale") def test_small_setups_big_swap(self, bootfs_scale, swapsize): swapsize.return_value = 10 << 30 bootfs_scale.maximum = 1 << 30 source_min = 1 << 30 self.assertEqual(14 << 30, calculate_suggested_install_min(source_min)) - @mock.patch('subiquity.common.filesystem.sizes.swap.suggested_swapsize') - @mock.patch('subiquity.common.filesystem.sizes.bootfs_scale') + @mock.patch("subiquity.common.filesystem.sizes.swap.suggested_swapsize") + @mock.patch("subiquity.common.filesystem.sizes.bootfs_scale") def test_small_setups_big_boot(self, bootfs_scale, swapsize): swapsize.return_value = 1 << 30 bootfs_scale.maximum = 10 << 30 source_min = 1 << 30 self.assertEqual(14 << 30, calculate_suggested_install_min(source_min)) - @mock.patch('subiquity.common.filesystem.sizes.swap.suggested_swapsize') - @mock.patch('subiquity.common.filesystem.sizes.bootfs_scale') + @mock.patch("subiquity.common.filesystem.sizes.swap.suggested_swapsize") + @mock.patch("subiquity.common.filesystem.sizes.bootfs_scale") def test_big_source(self, bootfs_scale, swapsize): swapsize.return_value = 1 << 30 bootfs_scale.maximum = 2 << 30 diff --git a/subiquity/common/resources.py b/subiquity/common/resources.py index 75ec169e..0c43a632 100644 --- a/subiquity/common/resources.py +++ b/subiquity/common/resources.py @@ -20,7 +20,7 @@ import logging import os from typing import List -log = logging.getLogger('subiquity.common.resources') +log = logging.getLogger("subiquity.common.resources") def resource_path(relative_path): @@ -31,17 +31,17 @@ def get_users_and_groups(chroot_prefix=[]) -> List: # prevent import when calling just resource_path from subiquitycore.utils import run_command - users_and_groups_path = resource_path('users-and-groups') - groups = ['admin'] + users_and_groups_path = resource_path("users-and-groups") + groups = ["admin"] if os.path.exists(users_and_groups_path): with open(users_and_groups_path) as f: groups = f.read().split() - groups.append('sudo') + groups.append("sudo") - command = chroot_prefix + ['getent', 'group'] + command = chroot_prefix + ["getent", "group"] cp = run_command(command, check=True) target_groups = set() for line in cp.stdout.splitlines(): - target_groups.add(line.split(':')[0]) + target_groups.add(line.split(":")[0]) return list(target_groups.intersection(groups)) diff --git a/subiquity/common/serialize.py b/subiquity/common/serialize.py index 7d1f61eb..483b0989 100644 --- a/subiquity/common/serialize.py +++ b/subiquity/common/serialize.py @@ -15,19 +15,19 @@ import datetime import enum -import json import inspect +import json import typing import attr def named_field(name, default=attr.NOTHING): - return attr.ib(metadata={'name': name}, default=default) + return attr.ib(metadata={"name": name}, default=default) def _field_name(field): - return field.metadata.get('name', field.name) + return field.metadata.get("name", field.name) class SerializationError(Exception): @@ -39,7 +39,7 @@ class SerializationError(Exception): def __str__(self): p = self.path if not p: - p = 'top-level' + p = "top-level" return f"processing {self.obj}: at {p}, {self.message}" @@ -53,13 +53,12 @@ class SerializationContext: @classmethod def new(cls, obj, *, serializing): - return SerializationContext(obj, obj, '', {}, serializing) + return SerializationContext(obj, obj, "", {}, serializing) def child(self, path, cur, metadata=None): if metadata is None: metadata = self.metadata - return attr.evolve( - self, path=self.path + path, cur=cur, metadata=metadata) + return attr.evolve(self, path=self.path + path, cur=cur, metadata=metadata) def error(self, message): raise SerializationError(self.obj, self.path, message) @@ -74,9 +73,9 @@ class SerializationContext: class Serializer: - - def __init__(self, *, compact=False, ignore_unknown_fields=False, - serialize_enums_by="name"): + def __init__( + self, *, compact=False, ignore_unknown_fields=False, serialize_enums_by="name" + ): self.compact = compact self.ignore_unknown_fields = ignore_unknown_fields assert serialize_enums_by in ("value", "name") @@ -87,7 +86,7 @@ class Serializer: typing.List: self._walk_List, dict: self._walk_Dict, typing.Dict: self._walk_Dict, - } + } self.type_serializers = {} self.type_deserializers = {} for typ in int, str, bool, list, type(None): @@ -119,14 +118,14 @@ class Serializer: if self.compact: r.insert(0, a.__name__) else: - r['$type'] = a.__name__ + r["$type"] = a.__name__ return r context.error(f"type of {context.cur} not found in {args}") else: if self.compact: n = context.cur.pop(0) else: - n = context.cur.pop('$type') + n = context.cur.pop("$type") for a in args: if a.__name__ == n: return meth(a, context) @@ -135,9 +134,8 @@ class Serializer: def _walk_List(self, meth, args, context): return [ - meth(args[0], context.child(f'[{i}]', v)) - for i, v in enumerate(context.cur) - ] + meth(args[0], context.child(f"[{i}]", v)) for i, v in enumerate(context.cur) + ] def _walk_Dict(self, meth, args, context): k_ann, v_ann = args @@ -145,10 +143,13 @@ class Serializer: input_items = context.cur else: input_items = context.cur.items() - output_items = [[ - meth(k_ann, context.child(f'/{k}', k)), - meth(v_ann, context.child(f'[{k}]', v)) - ] for k, v in input_items] + output_items = [ + [ + meth(k_ann, context.child(f"/{k}", k)), + meth(v_ann, context.child(f"[{k}]", v)), + ] + for k, v in input_items + ] if context.serializing and k_ann is not str: return output_items return dict(output_items) @@ -156,12 +157,12 @@ class Serializer: def _serialize_dict(self, annotation, context): context.assert_type(annotation) for k in context.cur: - context.child(f'/{k}', k).assert_type(str) + context.child(f"/{k}", k).assert_type(str) return context.cur def _serialize_datetime(self, annotation, context): context.assert_type(annotation) - fmt = context.metadata.get('time_fmt') + fmt = context.metadata.get("time_fmt") if fmt is not None: return context.cur.strftime(fmt) else: @@ -170,15 +171,19 @@ class Serializer: def _serialize_attr(self, annotation, context): serialized = [] for field in attr.fields(annotation): - serialized.append(( - _field_name(field), - self._serialize( - field.type, - context.child( - f'.{field.name}', - getattr(context.cur, field.name), - field.metadata)), - )) + serialized.append( + ( + _field_name(field), + self._serialize( + field.type, + context.child( + f".{field.name}", + getattr(context.cur, field.name), + field.metadata, + ), + ), + ) + ) if self.compact: return [s[1] for s in serialized] else: @@ -196,7 +201,7 @@ class Serializer: return context.cur if attr.has(annotation): return self._serialize_attr(annotation, context) - origin = getattr(annotation, '__origin__', None) + origin = getattr(annotation, "__origin__", None) if origin is not None: args = annotation.__args__ return self.typing_walkers[origin](self._serialize, args, context) @@ -214,7 +219,7 @@ class Serializer: return self._serialize(annotation, context) def _deserialize_datetime(self, annotation, context): - fmt = context.metadata.get('time_fmt') + fmt = context.metadata.get("time_fmt") if fmt is None: context.error("cannot serialize datetime without format") return datetime.datetime.strptime(context.cur, fmt) @@ -224,19 +229,19 @@ class Serializer: context.assert_type(list) args = [] for field, value in zip(attr.fields(annotation), context.cur): - args.append(self._deserialize( - field.type, - context.child(f'[{field.name!r}]', value, field.metadata))) + args.append( + self._deserialize( + field.type, + context.child(f"[{field.name!r}]", value, field.metadata), + ) + ) return annotation(*args) else: context.assert_type(dict) args = {} - fields = { - _field_name(field): field for field in attr.fields(annotation) - } + fields = {_field_name(field): field for field in attr.fields(annotation)} for key, value in context.cur.items(): - if key not in fields and ( - key == '$type' or self.ignore_unknown_fields): + if key not in fields and (key == "$type" or self.ignore_unknown_fields): # Union types can contain a '$type' field that is not # actually one of the keys. This happens if a object is # serialized as part of a Union, sent to an API caller, @@ -245,8 +250,8 @@ class Serializer: continue field = fields[key] args[field.name] = self._deserialize( - field.type, - context.child(f'[{key!r}]', value, field.metadata)) + field.type, context.child(f"[{key!r}]", value, field.metadata) + ) return annotation(**args) def _deserialize_enum(self, annotation, context): @@ -263,10 +268,11 @@ class Serializer: return context.cur if attr.has(annotation): return self._deserialize_attr(annotation, context) - origin = getattr(annotation, '__origin__', None) + origin = getattr(annotation, "__origin__", None) if origin is not None: return self.typing_walkers[origin]( - self._deserialize, annotation.__args__, context) + self._deserialize, annotation.__args__, context + ) if isinstance(annotation, type) and issubclass(annotation, enum.Enum): return self._deserialize_enum(annotation, context) return self.type_deserializers[annotation](annotation, context) diff --git a/subiquity/common/tests/test_serialization.py b/subiquity/common/tests/test_serialization.py index f6c414af..3ada5e65 100644 --- a/subiquity/common/tests/test_serialization.py +++ b/subiquity/common/tests/test_serialization.py @@ -13,7 +13,6 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import attr import datetime import enum import inspect @@ -22,11 +21,9 @@ import string import typing import unittest -from subiquity.common.serialize import ( - named_field, - Serializer, - SerializationError, - ) +import attr + +from subiquity.common.serialize import SerializationError, Serializer, named_field @attr.s(auto_attribs=True) @@ -36,9 +33,7 @@ class Data: @staticmethod def make_random(): - return Data( - random.choice(string.ascii_letters), - random.randint(0, 1000)) + return Data(random.choice(string.ascii_letters), random.randint(0, 1000)) @attr.s(auto_attribs=True) @@ -50,8 +45,8 @@ class Container: def make_random(): return Container( data=Data.make_random(), - data_list=[Data.make_random() - for i in range(random.randint(10, 20))]) + data_list=[Data.make_random() for i in range(random.randint(10, 20))], + ) @attr.s(auto_attribs=True) @@ -67,27 +62,23 @@ class MyEnum(enum.Enum): class CommonSerializerTests: - simple_examples = [ (int, 1), (str, "v"), (list, [1]), (dict, {"2": 3}), (type(None), None), - ] + ] def assertSerializesTo(self, annotation, value, expected): - self.assertEqual( - self.serializer.serialize(annotation, value), expected) + self.assertEqual(self.serializer.serialize(annotation, value), expected) def assertDeserializesTo(self, annotation, value, expected): - self.assertEqual( - self.serializer.deserialize(annotation, value), expected) + self.assertEqual(self.serializer.deserialize(annotation, value), expected) def assertRoundtrips(self, annotation, value): serialized = self.serializer.to_json(annotation, value) - self.assertEqual( - self.serializer.from_json(annotation, serialized), value) + self.assertEqual(self.serializer.from_json(annotation, serialized), value) def assertSerialization(self, annotation, value, expected): self.assertSerializesTo(annotation, value, expected) @@ -114,8 +105,7 @@ class CommonSerializerTests: self.assertSerialization(typ, val, val) def test_non_string_key_dict(self): - self.assertRaises( - Exception, self.serializer.serialize, dict, {1: 2}) + self.assertRaises(Exception, self.serializer.serialize, dict, {1: 2}) def test_roundtrip_dict(self): ann = typing.Dict[int, str] @@ -141,7 +131,8 @@ class CommonSerializerTests: def test_enums_by_value(self): self.serializer = type(self.serializer)( - compact=self.serializer.compact, serialize_enums_by="value") + compact=self.serializer.compact, serialize_enums_by="value" + ) self.assertSerialization(MyEnum, MyEnum.name, "value") def test_serialize_any(self): @@ -151,19 +142,19 @@ class CommonSerializerTests: class TestSerializer(CommonSerializerTests, unittest.TestCase): - serializer = Serializer() def test_datetime(self): @attr.s class C: - d: datetime.datetime = attr.ib(metadata={'time_fmt': '%Y-%m-%d'}) + d: datetime.datetime = attr.ib(metadata={"time_fmt": "%Y-%m-%d"}) + c = C(datetime.datetime(2022, 1, 1)) self.assertSerialization(C, c, {"d": "2022-01-01"}) def test_serialize_attr(self): data = Data.make_random() - expected = {'field1': data.field1, 'field2': data.field2} + expected = {"field1": data.field1, "field2": data.field2} self.assertSerialization(Data, data, expected) def test_serialize_container(self): @@ -171,20 +162,20 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase): data2 = Data.make_random() container = Container(data1, [data2]) expected = { - 'data': {'field1': data1.field1, 'field2': data1.field2}, - 'data_list': [ - {'field1': data2.field1, 'field2': data2.field2}, - ], - } + "data": {"field1": data1.field1, "field2": data1.field2}, + "data_list": [ + {"field1": data2.field1, "field2": data2.field2}, + ], + } self.assertSerialization(Container, container, expected) def test_serialize_union(self): data = Data.make_random() expected = { - '$type': 'Data', - 'field1': data.field1, - 'field2': data.field2, - } + "$type": "Data", + "field1": data.field1, + "field2": data.field2, + } self.assertSerialization(typing.Union[Data, Container], data, expected) def test_arbitrary_types_may_have_type_field(self): @@ -193,18 +184,18 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase): # API entrypoint, one that isn't taking a Union, it must be cool with # the excess $type field. data = { - '$type': 'Data', - 'field1': '1', - 'field2': 2, + "$type": "Data", + "field1": "1", + "field2": 2, } - expected = Data(field1='1', field2=2) + expected = Data(field1="1", field2=2) self.assertDeserializesTo(Data, data, expected) def test_reject_unknown_fields_by_default(self): serializer = Serializer() data = Data.make_random() serialized = serializer.serialize(Data, data) - serialized['foobar'] = 'baz' + serialized["foobar"] = "baz" with self.assertRaises(KeyError): serializer.deserialize(Data, serialized) @@ -212,11 +203,10 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase): serializer = Serializer(ignore_unknown_fields=True) data = Data.make_random() serialized = serializer.serialize(Data, data) - serialized['foobar'] = 'baz' + serialized["foobar"] = "baz" self.assertEqual(serializer.deserialize(Data, serialized), data) def test_override_field_name(self): - @attr.s(auto_attribs=True) class Object: x: int @@ -224,10 +214,10 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase): z: int = named_field("field-z", 0) self.assertSerialization( - Object, Object(1, 2), {"x": 1, "field-y": 2, "field-z": 0}) + Object, Object(1, 2), {"x": 1, "field-y": 2, "field-z": 0} + ) def test_embedding(self): - @attr.s(auto_attribs=True) class Base1: x: str @@ -249,28 +239,28 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase): self.assertSerialization( Derived2, Derived2(b=Derived1(x="a", y=1), c=2), - {"b": {"x": "a", "y": 1}, "c": 2}) + {"b": {"x": "a", "y": 1}, "c": 2}, + ) def test_error_paths(self): with self.assertRaises(SerializationError) as catcher: self.serializer.serialize(str, 1) - self.assertEqual(catcher.exception.path, '') + self.assertEqual(catcher.exception.path, "") @attr.s(auto_attribs=True) class Type: - field1: str = named_field('field-1') + field1: str = named_field("field-1") field2: int with self.assertRaises(SerializationError) as catcher: self.serializer.serialize(Type, Data(2, 3)) - self.assertEqual(catcher.exception.path, '.field1') + self.assertEqual(catcher.exception.path, ".field1") with self.assertRaises(SerializationError) as catcher: - self.serializer.deserialize(Type, {'field-1': 1, 'field2': 2}) + self.serializer.deserialize(Type, {"field-1": 1, "field2": 2}) self.assertEqual(catcher.exception.path, "['field-1']") class TestCompactSerializer(CommonSerializerTests, unittest.TestCase): - serializer = Serializer(compact=True) def test_serialize_attr(self): @@ -285,75 +275,62 @@ class TestCompactSerializer(CommonSerializerTests, unittest.TestCase): expected = [ [data1.field1, data1.field2], [[data2.field1, data2.field2]], - ] + ] self.assertSerialization(Container, container, expected) def test_serialize_union(self): data = Data.make_random() - expected = ['Data', data.field1, data.field2] + expected = ["Data", data.field1, data.field2] self.assertSerialization(typing.Union[Data, Container], data, expected) class TestOptionalAndDefault(CommonSerializerTests, unittest.TestCase): - serializer = Serializer() def test_happy(self): data = { - 'int': 11, - 'optional_int': 12, - 'int_default': 3, - 'optional_int_default': 4 + "int": 11, + "optional_int": 12, + "int_default": 3, + "optional_int_default": 4, } expected = OptionalAndDefault(11, 12) self.assertDeserializesTo(OptionalAndDefault, data, expected) def test_skip_int(self): - data = { - 'optional_int': 12, - 'int_default': 13, - 'optional_int_default': 14 - } + data = {"optional_int": 12, "int_default": 13, "optional_int_default": 14} with self.assertRaises(TypeError): self.serializer.deserialize(OptionalAndDefault, data) def test_skip_optional_int(self): - data = { - 'int': 11, - 'int_default': 13, - 'optional_int_default': 14 - } + data = {"int": 11, "int_default": 13, "optional_int_default": 14} with self.assertRaises(TypeError): self.serializer.deserialize(OptionalAndDefault, data) def test_optional_int_none(self): data = { - 'int': 11, - 'optional_int': None, - 'int_default': 13, - 'optional_int_default': 14 + "int": 11, + "optional_int": None, + "int_default": 13, + "optional_int_default": 14, } expected = OptionalAndDefault(11, None, 13, 14) self.assertDeserializesTo(OptionalAndDefault, data, expected) def test_skip_int_default(self): - data = { - 'int': 11, - 'optional_int': 12, - 'optional_int_default': 14 - } + data = {"int": 11, "optional_int": 12, "optional_int_default": 14} expected = OptionalAndDefault(11, 12, 3, 14) self.assertDeserializesTo(OptionalAndDefault, data, expected) def test_skip_optional_int_default(self): data = { - 'int': 11, - 'optional_int': 12, - 'int_default': 13, + "int": 11, + "optional_int": 12, + "int_default": 13, } expected = OptionalAndDefault(11, 12, 13, 4) @@ -361,10 +338,10 @@ class TestOptionalAndDefault(CommonSerializerTests, unittest.TestCase): def test_optional_int_default_none(self): data = { - 'int': 11, - 'optional_int': 12, - 'int_default': 13, - 'optional_int_default': None, + "int": 11, + "optional_int": 12, + "int_default": 13, + "optional_int_default": None, } expected = OptionalAndDefault(11, 12, 13, None) @@ -372,11 +349,11 @@ class TestOptionalAndDefault(CommonSerializerTests, unittest.TestCase): def test_extra(self): data = { - 'int': 11, - 'optional_int': 12, - 'int_default': 13, - 'optional_int_default': 14, - 'extra': 15 + "int": 11, + "optional_int": 12, + "int_default": 13, + "optional_int_default": 14, + "extra": 15, } with self.assertRaises(KeyError): diff --git a/subiquity/common/tests/test_types.py b/subiquity/common/tests/test_types.py index 5974a7da..ae6930b0 100644 --- a/subiquity/common/tests/test_types.py +++ b/subiquity/common/tests/test_types.py @@ -15,19 +15,16 @@ import unittest -from subiquity.common.types import ( - GuidedCapability, - SizingPolicy, -) +from subiquity.common.types import GuidedCapability, SizingPolicy class TestSizingPolicy(unittest.TestCase): def test_all(self): - actual = SizingPolicy.from_string('all') + actual = SizingPolicy.from_string("all") self.assertEqual(SizingPolicy.ALL, actual) def test_scaled_size(self): - actual = SizingPolicy.from_string('scaled') + actual = SizingPolicy.from_string("scaled") self.assertEqual(SizingPolicy.SCALED, actual) def test_default(self): diff --git a/subiquity/common/types.py b/subiquity/common/types.py index 813e33c4..ac75673e 100644 --- a/subiquity/common/types.py +++ b/subiquity/common/types.py @@ -56,7 +56,7 @@ class ErrorReportRef: class ApplicationState(enum.Enum): - """ Represents the state of the application at a given time. """ + """Represents the state of the application at a given time.""" # States reported during the initial stages of the installation. STARTING_UP = enum.auto() @@ -125,8 +125,8 @@ class RefreshCheckState(enum.Enum): @attr.s(auto_attribs=True) class RefreshStatus: availability: RefreshCheckState - current_snap_version: str = '' - new_snap_version: str = '' + current_snap_version: str = "" + new_snap_version: str = "" @attr.s(auto_attribs=True) @@ -163,7 +163,7 @@ class KeyboardSetting: # Ideally, we would internally represent a keyboard setting as a # toggle + a list of [layout, variant]. layout: str - variant: str = '' + variant: str = "" toggle: Optional[str] = None @@ -216,7 +216,7 @@ class ZdevInfo: @classmethod def from_row(cls, row): - row = dict((k.split('=', 1) for k in shlex.split(row))) + row = dict((k.split("=", 1) for k in shlex.split(row))) for k, v in row.items(): if v == "yes": row[k] = True @@ -228,8 +228,8 @@ class ZdevInfo: @property def typeclass(self): - if self.type.startswith('zfcp'): - return 'zfcp' + if self.type.startswith("zfcp"): + return "zfcp" return self.type @@ -352,22 +352,25 @@ class GuidedCapability(enum.Enum): CORE_BOOT_PREFER_UNENCRYPTED = enum.auto() def is_lvm(self) -> bool: - return self in [GuidedCapability.LVM, - GuidedCapability.LVM_LUKS] + return self in [GuidedCapability.LVM, GuidedCapability.LVM_LUKS] def is_core_boot(self) -> bool: - return self in [GuidedCapability.CORE_BOOT_ENCRYPTED, - GuidedCapability.CORE_BOOT_UNENCRYPTED, - GuidedCapability.CORE_BOOT_PREFER_ENCRYPTED, - GuidedCapability.CORE_BOOT_PREFER_UNENCRYPTED] + return self in [ + GuidedCapability.CORE_BOOT_ENCRYPTED, + GuidedCapability.CORE_BOOT_UNENCRYPTED, + GuidedCapability.CORE_BOOT_PREFER_ENCRYPTED, + GuidedCapability.CORE_BOOT_PREFER_UNENCRYPTED, + ] def supports_manual_customization(self) -> bool: # After posting this capability to guided_POST, is it possible # for the user to customize the layout further? - return self in [GuidedCapability.MANUAL, - GuidedCapability.DIRECT, - GuidedCapability.LVM, - GuidedCapability.LVM_LUKS] + return self in [ + GuidedCapability.MANUAL, + GuidedCapability.DIRECT, + GuidedCapability.LVM, + GuidedCapability.LVM_LUKS, + ] def is_zfs(self) -> bool: return self in [GuidedCapability.ZFS] @@ -414,11 +417,11 @@ class SizingPolicy(enum.Enum): @classmethod def from_string(cls, value): - if value is None or value == 'scaled': + if value is None or value == "scaled": return cls.SCALED - if value == 'all': + if value == "all": return cls.ALL - raise Exception(f'Unknown SizingPolicy value {value}') + raise Exception(f"Unknown SizingPolicy value {value}") @attr.s(auto_attribs=True) @@ -450,14 +453,14 @@ class GuidedStorageTargetResize: @staticmethod def from_recommendations(part, resize_vals, allowed): return GuidedStorageTargetResize( - disk_id=part.device.id, - partition_number=part.number, - new_size=resize_vals.recommended, - minimum=resize_vals.minimum, - recommended=resize_vals.recommended, - maximum=resize_vals.maximum, - allowed=allowed, - ) + disk_id=part.device.id, + partition_number=part.number, + new_size=resize_vals.recommended, + minimum=resize_vals.minimum, + recommended=resize_vals.recommended, + maximum=resize_vals.maximum, + allowed=allowed, + ) @attr.s(auto_attribs=True) @@ -470,15 +473,16 @@ class GuidedStorageTargetUseGap: @attr.s(auto_attribs=True) class GuidedStorageTargetManual: - allowed: List[GuidedCapability] = attr.Factory( - lambda: [GuidedCapability.MANUAL]) + allowed: List[GuidedCapability] = attr.Factory(lambda: [GuidedCapability.MANUAL]) disallowed: List[GuidedDisallowedCapability] = attr.Factory(list) -GuidedStorageTarget = Union[GuidedStorageTargetReformat, - GuidedStorageTargetResize, - GuidedStorageTargetUseGap, - GuidedStorageTargetManual] +GuidedStorageTarget = Union[ + GuidedStorageTargetReformat, + GuidedStorageTargetResize, + GuidedStorageTargetUseGap, + GuidedStorageTargetManual, +] @attr.s(auto_attribs=True) @@ -519,10 +523,10 @@ class ReformatDisk: @attr.s(auto_attribs=True) class IdentityData: - realname: str = '' - username: str = '' - crypted_password: str = attr.ib(default='', repr=False) - hostname: str = '' + realname: str = "" + username: str = "" + crypted_password: str = attr.ib(default="", repr=False) + hostname: str = "" class UsernameValidation(enum.Enum): @@ -542,7 +546,8 @@ class SSHData: @attr.s(auto_attribs=True) class SSHIdentity: - """ Represents a SSH identity (public key + fingerprint). """ + """Represents a SSH identity (public key + fingerprint).""" + key_type: str key: str key_comment: str @@ -579,25 +584,26 @@ class ChannelSnapInfo: version: str size: int released_at: datetime.datetime = attr.ib( - metadata={'time_fmt': '%Y-%m-%dT%H:%M:%S.%fZ'}) + metadata={"time_fmt": "%Y-%m-%dT%H:%M:%S.%fZ"} + ) @attr.s(auto_attribs=True, eq=False) class SnapInfo: name: str - summary: str = '' - publisher: str = '' + summary: str = "" + publisher: str = "" verified: bool = False starred: bool = False - description: str = '' - confinement: str = '' - license: str = '' + description: str = "" + confinement: str = "" + license: str = "" channels: List[ChannelSnapInfo] = attr.Factory(list) @attr.s(auto_attribs=True) class DriversResponse: - """ Response to GET request to drivers. + """Response to GET request to drivers. :install: tells whether third-party drivers will be installed (if any is available). :drivers: tells what third-party drivers will be installed should we decide @@ -606,6 +612,7 @@ class DriversResponse: :local_only: tells if we are looking for drivers only from the ISO. :search_drivers: enables or disables drivers listing. """ + install: bool drivers: Optional[List[str]] local_only: bool @@ -654,7 +661,8 @@ class UbuntuProInfo: @attr.s(auto_attribs=True) class UbuntuProResponse: - """ Response to GET request to /ubuntu_pro """ + """Response to GET request to /ubuntu_pro""" + token: str = attr.ib(repr=False) has_network: bool @@ -668,7 +676,8 @@ class UbuntuProCheckTokenStatus(enum.Enum): @attr.s(auto_attribs=True) class UPCSInitiateResponse: - """ Response to Ubuntu Pro contract selection initiate request. """ + """Response to Ubuntu Pro contract selection initiate request.""" + user_code: str validity_seconds: int @@ -680,7 +689,8 @@ class UPCSWaitStatus(enum.Enum): @attr.s(auto_attribs=True) class UPCSWaitResponse: - """ Response to Ubuntu Pro contract selection wait request. """ + """Response to Ubuntu Pro contract selection wait request.""" + status: UPCSWaitStatus contract_token: Optional[str] @@ -715,19 +725,19 @@ class ShutdownMode(enum.Enum): @attr.s(auto_attribs=True) class WSLConfigurationBase: - automount_root: str = attr.ib(default='/mnt/') - automount_options: str = '' + automount_root: str = attr.ib(default="/mnt/") + automount_options: str = "" network_generatehosts: bool = attr.ib(default=True) network_generateresolvconf: bool = attr.ib(default=True) @attr.s(auto_attribs=True) class WSLConfigurationAdvanced: - automount_enabled: bool = attr.ib(default=True) - automount_mountfstab: bool = attr.ib(default=True) - interop_enabled: bool = attr.ib(default=True) + automount_enabled: bool = attr.ib(default=True) + automount_mountfstab: bool = attr.ib(default=True) + interop_enabled: bool = attr.ib(default=True) interop_appendwindowspath: bool = attr.ib(default=True) - systemd_enabled: bool = attr.ib(default=False) + systemd_enabled: bool = attr.ib(default=False) # Options that affect the setup experience itself, but won't reflect in the @@ -750,7 +760,7 @@ class TaskStatus(enum.Enum): @attr.s(auto_attribs=True) class TaskProgress: - label: str = '' + label: str = "" done: int = 0 total: int = 0 @@ -777,10 +787,10 @@ class Change: class CasperMd5Results(enum.Enum): - UNKNOWN = 'unknown' - FAIL = 'fail' - PASS = 'pass' - SKIP = 'skip' + UNKNOWN = "unknown" + FAIL = "fail" + PASS = "pass" + SKIP = "skip" class MirrorCheckStatus(enum.Enum): @@ -817,9 +827,9 @@ class MirrorGet: class MirrorSelectionFallback(enum.Enum): - ABORT = 'abort' - CONTINUE_ANYWAY = 'continue-anyway' - OFFLINE_INSTALL = 'offline-install' + ABORT = "abort" + CONTINUE_ANYWAY = "continue-anyway" + OFFLINE_INSTALL = "offline-install" @attr.s(auto_attribs=True) @@ -830,32 +840,32 @@ class AdConnectionInfo: class AdAdminNameValidation(enum.Enum): - OK = 'OK' - EMPTY = 'Empty' - INVALID_CHARS = 'Contains invalid characters' + OK = "OK" + EMPTY = "Empty" + INVALID_CHARS = "Contains invalid characters" class AdDomainNameValidation(enum.Enum): - OK = 'OK' - EMPTY = 'Empty' - TOO_LONG = 'Too long' - INVALID_CHARS = 'Contains invalid characters' - START_DOT = 'Starts with a dot' - END_DOT = 'Ends with a dot' - START_HYPHEN = 'Starts with a hyphen' - END_HYPHEN = 'Ends with a hyphen' - MULTIPLE_DOTS = 'Contains multiple dots' - REALM_NOT_FOUND = 'Could not find the domain controller' + OK = "OK" + EMPTY = "Empty" + TOO_LONG = "Too long" + INVALID_CHARS = "Contains invalid characters" + START_DOT = "Starts with a dot" + END_DOT = "Ends with a dot" + START_HYPHEN = "Starts with a hyphen" + END_HYPHEN = "Ends with a hyphen" + MULTIPLE_DOTS = "Contains multiple dots" + REALM_NOT_FOUND = "Could not find the domain controller" class AdPasswordValidation(enum.Enum): - OK = 'OK' - EMPTY = 'Empty' + OK = "OK" + EMPTY = "Empty" class AdJoinResult(enum.Enum): - OK = 'OK' - JOIN_ERROR = 'Failed to join' - EMPTY_HOSTNAME = 'Target hostname cannot be empty' - PAM_ERROR = 'Failed to update pam-auth' + OK = "OK" + JOIN_ERROR = "Failed to join" + EMPTY_HOSTNAME = "Target hostname cannot be empty" + PAM_ERROR = "Failed to update pam-auth" UNKNOWN = "Didn't attempt to join yet" diff --git a/subiquity/journald.py b/subiquity/journald.py index 15158c92..6a208791 100644 --- a/subiquity/journald.py +++ b/subiquity/journald.py @@ -23,7 +23,7 @@ def journald_listen(identifiers, callback, seek=False): reader = journal.Reader() args = [] for identifier in identifiers: - if '=' in identifier: + if "=" in identifier: args.append(identifier) else: args.append("SYSLOG_IDENTIFIER={}".format(identifier)) @@ -37,6 +37,7 @@ def journald_listen(identifiers, callback, seek=False): return for event in reader: callback(event) + loop = asyncio.get_running_loop() loop.add_reader(reader.fileno(), watch) return reader.fileno() diff --git a/subiquity/lockfile.py b/subiquity/lockfile.py index f6b75aeb..16419a00 100644 --- a/subiquity/lockfile.py +++ b/subiquity/lockfile.py @@ -18,11 +18,10 @@ import logging from subiquitycore.async_helpers import run_in_thread -log = logging.getLogger('subiquity.lockfile') +log = logging.getLogger("subiquity.lockfile") class _LockContext: - def __init__(self, lockfile, flags): self.lockfile = lockfile self.flags = flags @@ -49,10 +48,9 @@ class _LockContext: class Lockfile: - def __init__(self, path): self.path = path - self.fp = open(path, 'a+') + self.fp = open(path, "a+") def read_content(self): self.fp.seek(0) diff --git a/subiquity/models/ad.py b/subiquity/models/ad.py index d9c626cc..f727f1d6 100644 --- a/subiquity/models/ad.py +++ b/subiquity/models/ad.py @@ -18,11 +18,12 @@ from typing import Optional from subiquity.common.types import AdConnectionInfo -log = logging.getLogger('subiquity.models.ad') +log = logging.getLogger("subiquity.models.ad") class AdModel: - """ Models the Active Directory feature """ + """Models the Active Directory feature""" + def __init__(self) -> None: self.do_join = False self.conn_info: Optional[AdConnectionInfo] = None diff --git a/subiquity/models/codecs.py b/subiquity/models/codecs.py index 2f4f4b42..326369cb 100644 --- a/subiquity/models/codecs.py +++ b/subiquity/models/codecs.py @@ -15,7 +15,7 @@ import logging -log = logging.getLogger('subiquity.models.codecs') +log = logging.getLogger("subiquity.models.codecs") class CodecsModel: diff --git a/subiquity/models/drivers.py b/subiquity/models/drivers.py index 30167d9f..6756c883 100644 --- a/subiquity/models/drivers.py +++ b/subiquity/models/drivers.py @@ -15,7 +15,7 @@ import logging -log = logging.getLogger('subiquity.models.drivers') +log = logging.getLogger("subiquity.models.drivers") class DriversModel: diff --git a/subiquity/models/filesystem.py b/subiquity/models/filesystem.py index 587df49a..dd4e532e 100644 --- a/subiquity/models/filesystem.py +++ b/subiquity/models/filesystem.py @@ -13,8 +13,6 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from abc import ABC, abstractmethod -import attr import collections import copy import enum @@ -25,28 +23,28 @@ import os import pathlib import platform import tempfile +from abc import ABC, abstractmethod from typing import List, Optional, Set, Union +import attr import more_itertools - from curtin import storage_config from curtin.block import partition_kname from curtin.swap import can_use_swapfile from curtin.util import human2bytes - from probert.storage import StorageInfo from subiquity.common.types import Bootloader, OsProber -log = logging.getLogger('subiquity.models.filesystem') +log = logging.getLogger("subiquity.models.filesystem") MiB = 1024 * 1024 GiB = 1024 * 1024 * 1024 class NotFinalPartitionError(Exception): - """ Exception to raise when guessing the size of a partition that is not - the last one. """ + """Exception to raise when guessing the size of a partition that is not + the last one.""" def _set_backlinks(obj): @@ -61,7 +59,7 @@ def _set_backlinks(obj): obj.id = val obj._m._all_ids.add(obj.id) for field in attr.fields(type(obj)): - backlink = field.metadata.get('backlink') + backlink = field.metadata.get("backlink") if backlink is None: continue v = getattr(obj, field.name) @@ -81,7 +79,7 @@ def _set_backlinks(obj): def _remove_backlinks(obj): for field in attr.fields(type(obj)): - backlink = field.metadata.get('backlink') + backlink = field.metadata.get("backlink") if backlink is None: continue v = getattr(obj, field.name) @@ -110,15 +108,15 @@ def fsobj__repr(obj): v = getattr(obj, f.name) if v is f.default: continue - if f.metadata.get('ref', False): + if f.metadata.get("ref", False): v = v.id - elif f.metadata.get('reflist', False): + elif f.metadata.get("reflist", False): if isinstance(v, set): delims = "{}" else: delims = "[]" v = delims[0] + ", ".join(vv.id for vv in v) + delims[1] - elif f.metadata.get('redact', False): + elif f.metadata.get("redact", False): v = "" else: v = repr(v) @@ -135,24 +133,25 @@ def fsobj(typ): def wrapper(c): c.__attrs_post_init__ = _do_post_inits c._post_inits = [_set_backlinks] - class_post_init = getattr(c, '__post_init__', None) + class_post_init = getattr(c, "__post_init__", None) if class_post_init is not None: c._post_inits.append(class_post_init) c.type = attributes.const(typ) c.id = attr.ib(default=None) c._m = attr.ib(repr=None, default=None) - c.__annotations__['id'] = str - c.__annotations__['_m'] = "FilesystemModel" - c.__annotations__['type'] = str + c.__annotations__["id"] = str + c.__annotations__["_m"] = "FilesystemModel" + c.__annotations__["type"] = str c = attr.s(eq=False, repr=False, auto_attribs=True, kw_only=True)(c) c.__repr__ = fsobj__repr _type_to_cls[typ] = c return c + return wrapper def dependencies(obj): - if obj.type == 'disk': + if obj.type == "disk": dasd = obj.dasd() if dasd: yield dasd @@ -160,19 +159,19 @@ def dependencies(obj): v = getattr(obj, f.name) if not v: continue - elif f.metadata.get('ref', False): + elif f.metadata.get("ref", False): yield v - elif f.metadata.get('reflist', False): + elif f.metadata.get("reflist", False): yield from v def reverse_dependencies(obj): - if obj.type == 'dasd': - disk = obj._m._one(type='disk', device_id=obj.device_id) + if obj.type == "dasd": + disk = obj._m._one(type="disk", device_id=obj.device_id) if disk: yield disk for f in attr.fields(type(obj)): - if not f.metadata.get('is_backlink', False): + if not f.metadata.get("is_backlink", False): continue v = getattr(obj, f.name) if isinstance(v, (list, set)): @@ -198,20 +197,20 @@ class RaidLevel: raidlevels = [ # for translators: this is a description of a RAID level - RaidLevel(_("0 (striped)"), "raid0", 2, False), + RaidLevel(_("0 (striped)"), "raid0", 2, False), # for translators: this is a description of a RAID level - RaidLevel(_("1 (mirrored)"), "raid1", 2), - RaidLevel(_("5"), "raid5", 3), - RaidLevel(_("6"), "raid6", 4), - RaidLevel(_("10"), "raid10", 4), - RaidLevel(_("Container"), "container", 2), - ] + RaidLevel(_("1 (mirrored)"), "raid1", 2), + RaidLevel(_("5"), "raid5", 3), + RaidLevel(_("6"), "raid6", 4), + RaidLevel(_("10"), "raid10", 4), + RaidLevel(_("Container"), "container", 2), +] def _raidlevels_by_value(): r = {level.value: level for level in raidlevels} for n in 0, 1, 5, 6, 10: - r[str(n)] = r[n] = r["raid"+str(n)] + r[str(n)] = r[n] = r["raid" + str(n)] r["stripe"] = r["raid0"] r["mirror"] = r["raid1"] return r @@ -219,7 +218,7 @@ def _raidlevels_by_value(): raidlevels_by_value = _raidlevels_by_value() -HUMAN_UNITS = ['B', 'K', 'M', 'G', 'T', 'P'] +HUMAN_UNITS = ["B", "K", "M", "G", "T", "P"] def humanize_size(size): @@ -228,8 +227,8 @@ def humanize_size(size): p = int(math.floor(math.log(size, 2) / 10)) # We want to truncate the non-integral part, not round to nearest. s = "{:.17f}".format(size / 2 ** (10 * p)) - i = s.index('.') - s = s[:i + 4] + i = s.index(".") + s = s[: i + 4] return s + HUMAN_UNITS[int(p)] @@ -247,11 +246,12 @@ def dehumanize_size(size): else: suffix = None - parts = size.split('.') + parts = size.split(".") if len(parts) > 2: raise ValueError( # Attempting to convert input to a size - _("{input!r} is not valid input").format(input=size_in)) + _("{input!r} is not valid input").format(input=size_in) + ) elif len(parts) == 2: div = 10 ** len(parts[1]) size = parts[0] + parts[1] @@ -263,14 +263,17 @@ def dehumanize_size(size): except ValueError: raise ValueError( # Attempting to convert input to a size - _("{input!r} is not valid input").format(input=size_in)) + _("{input!r} is not valid input").format(input=size_in) + ) if suffix is not None: if suffix not in HUMAN_UNITS: raise ValueError( # Attempting to convert input to a size "unrecognized suffix {suffix!r} in {input!r}".format( - suffix=size_in[-1], input=size_in)) + suffix=size_in[-1], input=size_in + ) + ) mult = 2 ** (10 * HUMAN_UNITS.index(suffix)) else: mult = 1 @@ -307,24 +310,23 @@ def calculate_data_offset_bytes(devsize): devsize = align_down(devsize, DEFAULT_CHUNK) # conversion of choose_bm_space: - if devsize < 64*2: + if devsize < 64 * 2: bmspace = 0 - elif devsize - 64*2 >= 200*1024*1024*2: - bmspace = 128*2 - elif devsize - 4*2 > 8*1024*1024*2: - bmspace = 64*2 + elif devsize - 64 * 2 >= 200 * 1024 * 1024 * 2: + bmspace = 128 * 2 + elif devsize - 4 * 2 > 8 * 1024 * 1024 * 2: + bmspace = 64 * 2 else: - bmspace = 4*2 + bmspace = 4 * 2 # From the end of validate_geometry1, assuming metadata 1.2. - headroom = 128*1024*2 - while (headroom << 10) > devsize and headroom / 2 >= DEFAULT_CHUNK*2*2: + headroom = 128 * 1024 * 2 + while (headroom << 10) > devsize and headroom / 2 >= DEFAULT_CHUNK * 2 * 2: headroom >>= 1 - data_offset = 12*2 + bmspace + headroom - log.debug( - "get_raid_size: adjusting for %s sectors of overhead", data_offset) - data_offset = align_up(data_offset, 2*1024) + data_offset = 12 * 2 + bmspace + headroom + log.debug("get_raid_size: adjusting for %s sectors of overhead", data_offset) + data_offset = align_up(data_offset, 2 * 1024) # convert back to bytes return data_offset << 9 @@ -363,22 +365,20 @@ def get_raid_size(level, devices): # These are only defaults but curtin does not let you change/specify # them at this time. -LVM_OVERHEAD = (1 << 20) +LVM_OVERHEAD = 1 << 20 LVM_CHUNK_SIZE = 4 * (1 << 20) def get_lvm_size(devices, size_overrides={}): r = 0 for d in devices: - r += align_down( - size_overrides.get(d, d.size) - LVM_OVERHEAD, - LVM_CHUNK_SIZE) + r += align_down(size_overrides.get(d, d.size) - LVM_OVERHEAD, LVM_CHUNK_SIZE) return r def _conv_size(s): if isinstance(s, str): - if '%' in s: + if "%" in s: return s return int(human2bytes(s)) return s @@ -389,22 +389,21 @@ class attributes: @staticmethod def ref(*, backlink=None, default=attr.NOTHING): - metadata = {'ref': True} + metadata = {"ref": True} if backlink: - metadata['backlink'] = backlink + metadata["backlink"] = backlink return attr.ib(metadata=metadata, default=default) @staticmethod def reflist(*, backlink=None, default=attr.NOTHING): - metadata = {'reflist': True} + metadata = {"reflist": True} if backlink: - metadata['backlink'] = backlink + metadata["backlink"] = backlink return attr.ib(metadata=metadata, default=default) @staticmethod def backlink(*, default=None): - return attr.ib( - init=False, default=default, metadata={'is_backlink': True}) + return attr.ib(init=False, default=default, metadata={"is_backlink": True}) @staticmethod def const(value): @@ -416,35 +415,35 @@ class attributes: @staticmethod def ptable(): - def conv(val): if val == "dos": val = "msdos" return val + return attr.ib(default=None, converter=conv) @staticmethod def for_api(*, default=attr.NOTHING): - return attr.ib(default=default, metadata={'for_api': True}) + return attr.ib(default=default, metadata={"for_api": True}) def asdict(inst, *, for_api: bool): r = collections.OrderedDict() for field in attr.fields(type(inst)): metadata = field.metadata - if not for_api or not metadata.get('for_api', False): - if field.name.startswith('_'): + if not for_api or not metadata.get("for_api", False): + if field.name.startswith("_"): continue - name = field.name.lstrip('_') - m = getattr(inst, 'serialize_' + name, None) + name = field.name.lstrip("_") + m = getattr(inst, "serialize_" + name, None) if m: r.update(m()) else: v = getattr(inst, field.name) if v is not None: - if metadata.get('ref', False): + if metadata.get("ref", False): r[name] = v.id - elif metadata.get('reflist', False): + elif metadata.get("reflist", False): r[name] = [elem.id for elem in v] elif isinstance(v, StorageInfo): r[name] = {v.name: v.raw} @@ -477,11 +476,11 @@ class _Formattable(ABC): def original_fstype(self): for action in self._m._orig_config: - if action['type'] == 'format' and action['volume'] == self.id: - return action['fstype'] + if action["type"] == "format" and action["volume"] == self.id: + return action["fstype"] for action in self._m._orig_config: - if action['id'] == self.id and action.get('flag') == 'swap': - return 'swap' + if action["id"] == self.id and action.get("flag") == "swap": + return "swap" return None def constructed_device(self, skip_dm_crypt=True): @@ -531,8 +530,7 @@ class _Device(_Formattable, ABC): pass # [Partition] - _partitions: List["Partition"] = attributes.backlink( - default=attr.Factory(list)) + _partitions: List["Partition"] = attributes.backlink(default=attr.Factory(list)) def _reformatted(self): # Return a ephemeral copy of the device with as many partitions @@ -547,7 +545,7 @@ class _Device(_Formattable, ABC): def ptable_for_new_partition(self): if self.ptable is not None: return self.ptable - return 'gpt' + return "gpt" def partitions(self): return self._partitions @@ -591,9 +589,8 @@ class _Device(_Formattable, ABC): return False if self._fs is not None: return self._fs._available() - from subiquity.common.filesystem.gaps import ( - largest_gap_size, - ) + from subiquity.common.filesystem.gaps import largest_gap_size + if largest_gap_size(self) > 0: return True return any(p.available() for p in self._partitions) @@ -605,8 +602,11 @@ class _Device(_Formattable, ABC): return any(p.preserve for p in self._partitions) def renumber_logical_partitions(self, removed_partition): - parts = [p for p in self.partitions_by_number() - if p.is_logical and p.number > removed_partition.number] + parts = [ + p + for p in self.partitions_by_number() + if p.is_logical and p.number > removed_partition.number + ] next_num = removed_partition.number for part in parts: part.number = next_num @@ -650,54 +650,54 @@ class Disk(_Device): return self._m._partition_alignment_data[ptable] def info_for_display(self): - bus = self._info.raw.get('ID_BUS', None) - major = self._info.raw.get('MAJOR', None) - if bus is None and major == '253': - bus = 'virtio' + bus = self._info.raw.get("ID_BUS", None) + major = self._info.raw.get("MAJOR", None) + if bus is None and major == "253": + bus = "virtio" - devpath = self._info.raw.get('DEVPATH', self.path) + devpath = self._info.raw.get("DEVPATH", self.path) # XXX probert should be doing this!! - rotational = '1' + rotational = "1" try: dev = os.path.basename(devpath) - rfile = '/sys/class/block/{}/queue/rotational'.format(dev) - with open(rfile, 'r') as f: + rfile = "/sys/class/block/{}/queue/rotational".format(dev) + with open(rfile, "r") as f: rotational = f.read().strip() except (PermissionError, FileNotFoundError, IOError): - log.exception('WARNING: Failed to read file {}'.format(rfile)) + log.exception("WARNING: Failed to read file {}".format(rfile)) dinfo = { - 'bus': bus, - 'devname': self.path, - 'devpath': devpath, - 'model': self.model or 'unknown', - 'serial': self.serial or 'unknown', - 'wwn': self.wwn or 'unknown', - 'multipath': self.multipath or 'unknown', - 'size': self.size, - 'humansize': humanize_size(self.size), - 'vendor': self._info.vendor or 'unknown', - 'rotational': 'true' if rotational == '1' else 'false', + "bus": bus, + "devname": self.path, + "devpath": devpath, + "model": self.model or "unknown", + "serial": self.serial or "unknown", + "wwn": self.wwn or "unknown", + "multipath": self.multipath or "unknown", + "size": self.size, + "humansize": humanize_size(self.size), + "vendor": self._info.vendor or "unknown", + "rotational": "true" if rotational == "1" else "false", } return dinfo def ptable_for_new_partition(self): if self.ptable is not None: return self.ptable - dasd_config = self._m._probe_data.get('dasd', {}).get(self.path) + dasd_config = self._m._probe_data.get("dasd", {}).get(self.path) if dasd_config is not None: - if dasd_config['type'] == "FBA": - return 'msdos' + if dasd_config["type"] == "FBA": + return "msdos" else: - return 'vtoc' - return 'gpt' + return "vtoc" + return "gpt" @property def size(self): return self._info.size def dasd(self): - return self._m._one(type='dasd', device_id=self.device_id) + return self._m._one(type="dasd", device_id=self.device_id) @property def ok_for_raid(self): @@ -717,17 +717,17 @@ class Disk(_Device): @property def model(self): - return self._decode_id('ID_MODEL_ENC') + return self._decode_id("ID_MODEL_ENC") @property def vendor(self): - return self._decode_id('ID_VENDOR_ENC') + return self._decode_id("ID_VENDOR_ENC") def _decode_id(self, id): id = self._info.raw.get(id) if id is None: return None - return id.encode('utf-8').decode('unicode_escape').strip() + return id.encode("utf-8").decode("unicode_escape").strip() @fsobj("partition") @@ -754,9 +754,12 @@ class Partition(_Formattable): if self.number is not None: return - used_nums = {p.number for p in self.device._partitions - if p.number is not None - if p.is_logical == self.is_logical} + used_nums = { + p.number + for p in self.device._partitions + if p.number is not None + if p.is_logical == self.is_logical + } primary_limit = self.device.alignment_data().primary_part_limit if self.is_logical: possible_nums = range(primary_limit + 1, 129) @@ -766,10 +769,10 @@ class Partition(_Formattable): if num not in used_nums: self.number = num return - raise Exception('Exceeded number of available partitions') + raise Exception("Exceeded number of available partitions") def available(self): - if self.flag in ['bios_grub', 'prep'] or self.grub_device: + if self.flag in ["bios_grub", "prep"] or self.grub_device: return False if self._is_in_use: return False @@ -785,14 +788,15 @@ class Partition(_Formattable): @property def boot(self): from subiquity.common.filesystem import boot + return boot.is_bootloader_partition(self) @property def estimated_min_size(self): - fs_data = self._m._probe_data.get('filesystem', {}).get(self._path()) + fs_data = self._m._probe_data.get("filesystem", {}).get(self._path()) if fs_data is None: return -1 - val = fs_data.get('ESTIMATED_MIN_SIZE', -1) + val = fs_data.get("ESTIMATED_MIN_SIZE", -1) if val == 0: return self.device.alignment_data().part_align if val == -1: @@ -817,7 +821,7 @@ class Partition(_Formattable): @property def os(self): - os_data = self._m._probe_data.get('os', {}).get(self._path()) + os_data = self._m._probe_data.get("os", {}).get(self._path()) if not os_data: return None return OsProber(**os_data) @@ -843,7 +847,8 @@ class Raid(_Device): name: str raidlevel: str = attr.ib(converter=lambda x: raidlevels_by_value[x].value) devices: Set[Union[Disk, Partition, "Raid"]] = attributes.reflist( - backlink="_constructed_device", default=attr.Factory(set)) + backlink="_constructed_device", default=attr.Factory(set) + ) _info: Optional[StorageInfo] = attributes.for_api(default=None) _has_in_use_partition = False @@ -851,18 +856,18 @@ class Raid(_Device): # Surprisingly, the order of devices passed to mdadm --create # matters (see get_raid_size) so we sort devices here the same # way get_raid_size does. - return {'devices': [d.id for d in raid_device_sort(self.devices)]} + return {"devices": [d.id for d in raid_device_sort(self.devices)]} spare_devices: Set[Union[Disk, Partition, "Raid"]] = attributes.reflist( - backlink="_constructed_device", default=attr.Factory(set)) + backlink="_constructed_device", default=attr.Factory(set) + ) preserve: bool = False wipe: Optional[str] = None ptable: Optional[str] = attributes.ptable() metadata: Optional[str] = None _path: Optional[str] = None - container: Optional["Raid"] = attributes.ref( - backlink="_subvolumes", default=None) + container: Optional["Raid"] = attributes.ref(backlink="_subvolumes", default=None) _subvolumes: List["Raid"] = attributes.backlink(default=attr.Factory(list)) @property @@ -871,7 +876,7 @@ class Raid(_Device): return self._path # This is just here to make for_client(raid-with-partitions) work. It # might not be very accurate. - return '/dev/md/' + self.name + return "/dev/md/" + self.name @path.setter def path(self, value): @@ -892,7 +897,7 @@ class Raid(_Device): # For some reason, the overhead on RAID devices seems to be # higher (may be related to alignment of underlying # partitions) - return self.size - 2*GPT_OVERHEAD + return self.size - 2 * GPT_OVERHEAD def available(self): if self.raidlevel == "container": @@ -925,7 +930,8 @@ class Raid(_Device): class LVM_VolGroup(_Device): name: str devices: List[Union[Disk, Partition, Raid]] = attributes.reflist( - backlink="_constructed_device") + backlink="_constructed_device" + ) preserve: bool = False @@ -959,7 +965,7 @@ class LVM_LogicalVolume(_Formattable): if self.size is None: return {} else: - return {'size': "{}B".format(self.size)} + return {"size": "{}B".format(self.size)} def available(self): if self._constructed_device is not None: @@ -976,23 +982,22 @@ class LVM_LogicalVolume(_Formattable): ok_for_lvm_vg = False -LUKS_OVERHEAD = 16*(2**20) +LUKS_OVERHEAD = 16 * (2**20) @fsobj("dm_crypt") class DM_Crypt: volume: _Formattable = attributes.ref(backlink="_constructed_device") - key: Optional[str] = attr.ib(metadata={'redact': True}, default=None) + key: Optional[str] = attr.ib(metadata={"redact": True}, default=None) keyfile: Optional[str] = None path: Optional[str] = None def serialize_key(self): if self.key and not self.keyfile: - f = tempfile.NamedTemporaryFile( - prefix='luks-key-', mode='w', delete=False) + f = tempfile.NamedTemporaryFile(prefix="luks-key-", mode="w", delete=False) f.write(self.key) f.close() - return {'keyfile': f.name} + return {"keyfile": f.name} else: return {} @@ -1061,6 +1066,7 @@ class Mount: def can_delete(self): from subiquity.common.filesystem import boot + # Can't delete mount of /boot/efi or swap, anything else is fine. if not self.path: # swap mount @@ -1081,13 +1087,13 @@ def get_canmount(properties: Optional[dict], default: bool) -> bool: if properties is None: return default vals = { - 'on': True, - 'off': False, - 'noauto': False, + "on": True, + "off": False, + "noauto": False, True: True, False: False, } - result = vals.get(properties.get('canmount', default)) + result = vals.get(properties.get("canmount", default)) if result is None: raise ValueError('canmount must be one of "on", "off", or "noauto"') return result @@ -1096,7 +1102,8 @@ def get_canmount(properties: Optional[dict], default: bool) -> bool: @fsobj("zpool") class ZPool: vdevs: List[Union[Disk, Partition]] = attributes.reflist( - backlink="_constructed_device") + backlink="_constructed_device" + ) pool: str mountpoint: str @@ -1111,7 +1118,7 @@ class ZPool: @property def fstype(self): - return 'zfs' + return "zfs" @property def name(self): @@ -1137,7 +1144,7 @@ class ZFS: @property def fstype(self): - return 'zfs' + return "zfs" @property def canmount(self): @@ -1146,7 +1153,7 @@ class ZFS: @property def path(self): if self.canmount: - return self.properties.get('mountpoint', self.volume) + return self.properties.get("mountpoint", self.volume) else: return None @@ -1156,7 +1163,7 @@ ConstructedDevice = Union[Raid, LVM_VolGroup, ZPool] # A Mountlike is a literal Mount object, or one similar enough in behavior. # Mountlikes have a path property that may return None if that given object is # not actually mountable. -MountlikeNames = ('mount', 'zpool', 'zfs') +MountlikeNames = ("mount", "zpool", "zfs") def align_up(size, block_size=1 << 20): @@ -1198,36 +1205,38 @@ class ActionRenderMode(enum.Enum): class FilesystemModel(object): - target = None _partition_alignment_data = { - 'gpt': PartitionAlignmentData( + "gpt": PartitionAlignmentData( part_align=MiB, min_gap_size=MiB, - min_start_offset=GPT_OVERHEAD//2, - min_end_offset=GPT_OVERHEAD//2, - primary_part_limit=128), - 'msdos': PartitionAlignmentData( + min_start_offset=GPT_OVERHEAD // 2, + min_end_offset=GPT_OVERHEAD // 2, + primary_part_limit=128, + ), + "msdos": PartitionAlignmentData( part_align=MiB, min_gap_size=MiB, - min_start_offset=GPT_OVERHEAD//2, + min_start_offset=GPT_OVERHEAD // 2, min_end_offset=0, ebr_space=MiB, - primary_part_limit=4), + primary_part_limit=4, + ), # XXX check this one!! - 'vtoc': PartitionAlignmentData( + "vtoc": PartitionAlignmentData( part_align=MiB, min_gap_size=MiB, - min_start_offset=GPT_OVERHEAD//2, + min_start_offset=GPT_OVERHEAD // 2, min_end_offset=0, ebr_space=MiB, - primary_part_limit=3), - } + primary_part_limit=3, + ), + } @classmethod def is_mounted_filesystem(self, fstype): - if fstype in [None, 'swap']: + if fstype in [None, "swap"]: return False else: return True @@ -1235,7 +1244,7 @@ class FilesystemModel(object): def _probe_bootloader(self): # This will at some point change to return a list so that we can # configure BIOS _and_ UEFI on amd64 systems. - if os.path.exists('/sys/firmware/efi'): + if os.path.exists("/sys/firmware/efi"): return Bootloader.UEFI elif platform.machine().startswith("ppc64"): return Bootloader.PREP @@ -1274,38 +1283,40 @@ class FilesystemModel(object): return orig_model def process_probe_data(self): - self._orig_config = storage_config.extract_storage_config( - self._probe_data)["storage"]["config"] + self._orig_config = storage_config.extract_storage_config(self._probe_data)[ + "storage" + ]["config"] self._actions = self._actions_from_config( self._orig_config, - blockdevs=self._probe_data['blockdev'], - is_probe_data=True) + blockdevs=self._probe_data["blockdev"], + is_probe_data=True, + ) majmin_to_dev = {} for obj in self._actions: - if not hasattr(obj, '_info'): + if not hasattr(obj, "_info"): continue - major = obj._info.raw.get('MAJOR') - minor = obj._info.raw.get('MINOR') + major = obj._info.raw.get("MAJOR") + minor = obj._info.raw.get("MINOR") if major is None or minor is None: continue - majmin_to_dev[f'{major}:{minor}'] = obj + majmin_to_dev[f"{major}:{minor}"] = obj log.debug("majmin_to_dev %s", majmin_to_dev) - mounts = list(self._probe_data.get('mount', [])) + mounts = list(self._probe_data.get("mount", [])) while mounts: mount = mounts.pop(0) - mounts.extend(mount.get('children', [])) - if mount['target'].startswith(self.target): + mounts.extend(mount.get("children", [])) + if mount["target"].startswith(self.target): # Completely ignore mounts under /target, they are probably # leftovers from a previous install attempt. continue - if 'maj:min' not in mount: + if "maj:min" not in mount: continue - log.debug("considering mount of %s", mount['maj:min']) - obj = majmin_to_dev.get(mount['maj:min']) + log.debug("considering mount of %s", mount["maj:min"]) + obj = majmin_to_dev.get(mount["maj:min"]) if obj is None: continue obj._is_in_use = True @@ -1322,7 +1333,7 @@ class FilesystemModel(object): # the partition will show up as having a filesystem. Casper should # preferentially mount it as a partition though and if it looks like # that has happened, we ignore the filesystem on the drive itself. - for disk in self._all(type='disk'): + for disk in self._all(type="disk"): if disk._fs is None: continue if not disk._partitions: @@ -1330,81 +1341,76 @@ class FilesystemModel(object): p1 = disk._partitions[0] if p1._fs is None: continue - if disk._fs.fstype == p1._fs.fstype == 'iso9660': + if disk._fs.fstype == p1._fs.fstype == "iso9660": if p1._is_in_use and not disk._is_in_use: self.remove_filesystem(disk._fs) for o in self._actions: if o.type == "partition" and o.flag == "swap": if o._fs is None: - self._actions.append(Filesystem( - m=self, fstype="swap", volume=o, preserve=True)) + self._actions.append( + Filesystem(m=self, fstype="swap", volume=o, preserve=True) + ) def load_server_data(self, status): - log.debug('load_server_data %s', status) + log.debug("load_server_data %s", status) self._all_ids = set() self.storage_version = status.storage_version self._orig_config = status.orig_config self._probe_data = { - 'dasd': status.dasd, - } + "dasd": status.dasd, + } self._actions = self._actions_from_config( - status.config, - blockdevs=None, - is_probe_data=False) + status.config, blockdevs=None, is_probe_data=False + ) def _make_matchers(self, match): matchers = [] def _udev_val(disk, key): - return self._probe_data['blockdev'].get(disk.path, {}).get(key, '') + return self._probe_data["blockdev"].get(disk.path, {}).get(key, "") def match_serial(disk): - return fnmatch.fnmatchcase( - _udev_val(disk, "ID_SERIAL"), match['serial']) + return fnmatch.fnmatchcase(_udev_val(disk, "ID_SERIAL"), match["serial"]) def match_model(disk): - return fnmatch.fnmatchcase( - _udev_val(disk, "ID_MODEL"), match['model']) + return fnmatch.fnmatchcase(_udev_val(disk, "ID_MODEL"), match["model"]) def match_vendor(disk): - return fnmatch.fnmatchcase( - _udev_val(disk, "ID_VENDOR"), match['vendor']) + return fnmatch.fnmatchcase(_udev_val(disk, "ID_VENDOR"), match["vendor"]) def match_path(disk): - return fnmatch.fnmatchcase(disk.path, match['path']) + return fnmatch.fnmatchcase(disk.path, match["path"]) def match_id_path(disk): - return fnmatch.fnmatchcase( - _udev_val(disk, "ID_PATH"), match['id_path']) + return fnmatch.fnmatchcase(_udev_val(disk, "ID_PATH"), match["id_path"]) def match_devpath(disk): - return fnmatch.fnmatchcase( - _udev_val(disk, "DEVPATH"), match['devpath']) + return fnmatch.fnmatchcase(_udev_val(disk, "DEVPATH"), match["devpath"]) def match_ssd(disk): - is_ssd = disk.info_for_display()['rotational'] == 'false' - return is_ssd == match['ssd'] + is_ssd = disk.info_for_display()["rotational"] == "false" + return is_ssd == match["ssd"] def match_install_media(disk): return disk._has_in_use_partition - if match.get('install-media', False): + if match.get("install-media", False): matchers.append(match_install_media) - if 'serial' in match: + if "serial" in match: matchers.append(match_serial) - if 'model' in match: + if "model" in match: matchers.append(match_model) - if 'vendor' in match: + if "vendor" in match: matchers.append(match_vendor) - if 'path' in match: + if "path" in match: matchers.append(match_path) - if 'id_path' in match: + if "id_path" in match: matchers.append(match_id_path) - if 'devpath' in match: + if "devpath" in match: matchers.append(match_devpath) - if 'ssd' in match: + if "ssd" in match: matchers.append(match_ssd) return matchers @@ -1420,20 +1426,19 @@ class FilesystemModel(object): break else: candidates.append(candidate) - if 'size' in match or 'ssd' in match: - candidates = [ - c for c in candidates if not c._has_in_use_partition] - if match.get('size') == 'smallest': + if "size" in match or "ssd" in match: + candidates = [c for c in candidates if not c._has_in_use_partition] + if match.get("size") == "smallest": candidates.sort(key=lambda d: d.size) - if match.get('size') == 'largest': + if match.get("size") == "largest": candidates.sort(key=lambda d: d.size, reverse=True) if candidates: return candidates[0] return None def assign_omitted_offsets(self): - """ Assign offsets to partitions that do not already have one. - This method does nothing for storage version 1. """ + """Assign offsets to partitions that do not already have one. + This method does nothing for storage version 1.""" if self.storage_version != 2: return @@ -1451,9 +1456,9 @@ class FilesystemModel(object): return v - v % info.part_align # Extended is considered a primary partition too. - primary_parts, logical_parts = map(list, more_itertools.partition( - is_logical_partition, - disk.partitions())) + primary_parts, logical_parts = map( + list, more_itertools.partition(is_logical_partition, disk.partitions()) + ) prev_end = info.min_start_offset for part in primary_parts: @@ -1465,7 +1470,8 @@ class FilesystemModel(object): return extended_part = next( - filter(lambda x: x.flag == "extended", disk.partitions())) + filter(lambda x: x.flag == "extended", disk.partitions()) + ) prev_end = extended_part.offset for part in logical_parts: @@ -1476,30 +1482,29 @@ class FilesystemModel(object): def apply_autoinstall_config(self, ai_config): disks = self.all_disks() for action in ai_config: - if action['type'] == 'disk': + if action["type"] == "disk": disk = None - if 'serial' in action: - disk = self._one(type='disk', serial=action['serial']) - elif 'path' in action: - disk = self._one(type='disk', path=action['path']) + if "serial" in action: + disk = self._one(type="disk", serial=action["serial"]) + elif "path" in action: + disk = self._one(type="disk", path=action["path"]) else: - match = action.pop('match', {}) + match = action.pop("match", {}) disk = self.disk_for_match(disks, match) if disk is None: - action['match'] = match + action["match"] = match if disk is None: raise Exception("{} matched no disk".format(action)) if disk not in disks: raise Exception( - "{} matched {} which was already used".format( - action, disk)) + "{} matched {} which was already used".format(action, disk) + ) disks.remove(disk) - action['path'] = disk.path - action['serial'] = disk.serial + action["path"] = disk.path + action["serial"] = disk.serial self._actions = self._actions_from_config( - ai_config, - blockdevs=self._probe_data['blockdev'], - is_probe_data=False) + ai_config, blockdevs=self._probe_data["blockdev"], is_probe_data=False + ) self.assign_omitted_offsets() @@ -1513,32 +1518,39 @@ class FilesystemModel(object): # For extended partitions, we use this filter to create # a temporary copy of the disk that excludes all # logical partitions. - def filter_(x): return not is_logical_partition(x) + def filter_(x): + return not is_logical_partition(x) + else: - def filter_(x): return True + + def filter_(x): + return True + filtered_parent = copy.copy(parent) - filtered_parent._partitions = list(filter( - filter_, parent.partitions())) + filtered_parent._partitions = list( + filter(filter_, parent.partitions()) + ) if p is not filtered_parent.partitions()[-1]: raise NotFinalPartitionError( "{} has negative size but is not final partition " - "of {}".format(p, parent)) + "of {}".format(p, parent) + ) # Exclude the current partition itself so that its # incomplete size is not used as is. filtered_parent._partitions.remove(p) - from subiquity.common.filesystem.gaps import ( - largest_gap_size, - ) + from subiquity.common.filesystem.gaps import largest_gap_size + p.size = largest_gap_size( - filtered_parent, - in_extended=is_logical_partition(p)) + filtered_parent, in_extended=is_logical_partition(p) + ) elif isinstance(p.size, str): if p.size.endswith("%"): percentage = int(p.size[:-1]) p.size = align_down( - parent.available_for_partitions*percentage//100) + parent.available_for_partitions * percentage // 100 + ) else: p.size = dehumanize_size(p.size) @@ -1567,9 +1579,9 @@ class FilesystemModel(object): byid = {} objs = [] for action in config: - if is_probe_data and action['type'] == 'mount': + if is_probe_data and action["type"] == "mount": continue - c = _type_to_cls.get(action['type'], None) + c = _type_to_cls.get(action["type"], None) if c is None: # Ignore any action we do not know how to process yet # (e.g. bcache) @@ -1578,15 +1590,15 @@ class FilesystemModel(object): kw = {} field_names = set() for f in attr.fields(c): - n = f.name.lstrip('_') + n = f.name.lstrip("_") field_names.add(f.name) if n not in action: continue v = action[n] try: - if f.metadata.get('ref', False): + if f.metadata.get("ref", False): kw[n] = byid[v] - elif f.metadata.get('reflist', False): + elif f.metadata.get("reflist", False): kw[n] = [byid[id] for id in v] else: kw[n] = v @@ -1595,21 +1607,20 @@ class FilesystemModel(object): # ignored, we need to ignore the current action too # (e.g. a bcache's filesystem). continue - if '_info' in field_names: - if 'info' in kw: - kw['info'] = StorageInfo(kw['info']) - elif 'path' in kw: - path = kw['path'] - kw['info'] = StorageInfo({path: blockdevs[path]}) + if "_info" in field_names: + if "info" in kw: + kw["info"] = StorageInfo(kw["info"]) + elif "path" in kw: + path = kw["path"] + kw["info"] = StorageInfo({path: blockdevs[path]}) if is_probe_data: - kw['preserve'] = True - obj = byid[action['id']] = c(m=self, **kw) + kw["preserve"] = True + obj = byid[action["id"]] = c(m=self, **kw) objs.append(obj) return objs - def _render_actions(self, - mode: ActionRenderMode = ActionRenderMode.DEFAULT): + def _render_actions(self, mode: ActionRenderMode = ActionRenderMode.DEFAULT): # The curtin storage config has the constraint that an action must be # preceded by all the things that it depends on. We handle this by # repeatedly iterating over all actions and checking if we can emit @@ -1624,7 +1635,10 @@ class FilesystemModel(object): if isinstance(obj, Raid): log.debug( "FilesystemModel: estimated size of %s %s is %s", - obj.raidlevel, obj.name, obj.size) + obj.raidlevel, + obj.name, + obj.size, + ) r.append(asdict(obj, for_api=mode == ActionRenderMode.FOR_API)) emitted_ids.add(obj.id) @@ -1644,7 +1658,7 @@ class FilesystemModel(object): if dep.id not in emitted_ids: if dep not in work and dep not in next_work: next_work.append(dep) - if dep.type in ['disk', 'raid']: + if dep.type in ["disk", "raid"]: ensure_partitions(dep) return False if obj.type in MountlikeNames and obj.path is not None: @@ -1656,19 +1670,20 @@ class FilesystemModel(object): if mountpoints[parent] not in emitted_ids: log.debug( "cannot emit action to mount %s until that " - "for %s is emitted", obj.path, parent) + "for %s is emitted", + obj.path, + parent, + ) return False return True mountpoints = {m.path: m.id for m in self.all_mountlikes()} - log.debug('mountpoints %s', mountpoints) + log.debug("mountpoints %s", mountpoints) if mode == ActionRenderMode.FOR_API: work = list(self._actions) else: - work = [ - a for a in self._actions if not getattr(a, 'preserve', False) - ] + work = [a for a in self._actions if not getattr(a, "preserve", False)] while work: next_work = [] @@ -1685,57 +1700,55 @@ class FilesystemModel(object): work = next_work if mode == ActionRenderMode.DEVICES: - r = [act for act in r if act['type'] not in ('format', 'mount')] + r = [act for act in r if act["type"] not in ("format", "mount")] if mode == ActionRenderMode.FORMAT_MOUNT: - r = [act for act in r if act['type'] in ('format', 'mount')] + r = [act for act in r if act["type"] in ("format", "mount")] devices = [] for act in r: - if act['type'] == 'format': + if act["type"] == "format": device = { - 'type': 'device', - 'id': 'synth-device-{}'.format(len(devices)), - 'path': self._one(id=act['volume']).path, - } + "type": "device", + "id": "synth-device-{}".format(len(devices)), + "path": self._one(id=act["volume"]).path, + } devices.append(device) - act['volume'] = device['id'] + act["volume"] = device["id"] r = devices + r return r def render(self, mode: ActionRenderMode = ActionRenderMode.DEFAULT): config = { - 'storage': { - 'version': self.storage_version, - 'config': self._render_actions(mode=mode), - }, - } + "storage": { + "version": self.storage_version, + "config": self._render_actions(mode=mode), + }, + } if self.swap is not None: - config['swap'] = self.swap + config["swap"] = self.swap elif not self.should_add_swapfile(): - config['swap'] = {'size': 0} + config["swap"] = {"size": 0} if self.grub is not None: - config['grub'] = self.grub + config["grub"] = self.grub return config def load_probe_data(self, probe_data): - for devname, devdata in probe_data['blockdev'].items(): - if int(devdata['attrs']['size']) != 0: + for devname, devdata in probe_data["blockdev"].items(): + if int(devdata["attrs"]["size"]) != 0: continue # An unformatted (ECKD) dasd reports a size of 0 via e.g. blockdev # --getsize64. So figuring out how big it is requires a bit more # work. - data = probe_data.get('dasd', {}).get(devname) - if data is None or data['type'] != 'ECKD': + data = probe_data.get("dasd", {}).get(devname) + if data is None or data["type"] != "ECKD": continue - tracks_per_cylinder = data['tracks_per_cylinder'] - cylinders = data['cylinders'] + tracks_per_cylinder = data["tracks_per_cylinder"] + cylinders = data["cylinders"] blocksize = 4096 # hard coded for us! blocks_per_track = 12 # just a mystery fact that has to be known - size = \ - blocksize * blocks_per_track * tracks_per_cylinder * cylinders - log.debug( - "computing size on unformatted dasd from %s as %s", data, size) - devdata['attrs']['size'] = str(size) + size = blocksize * blocks_per_track * tracks_per_cylinder * cylinders + log.debug("computing size on unformatted dasd from %s as %s", data, size) + devdata["attrs"]["size"] = str(size) self._probe_data = probe_data self.reset() @@ -1757,7 +1770,7 @@ class FilesystemModel(object): return list(self._matcher(kw)) def all_mounts(self): - return self._all(type='mount') + return self._all(type="mount") def all_mountlikes(self): ret = [] @@ -1772,51 +1785,69 @@ class FilesystemModel(object): disks = [] compounds = [] for a in self._actions: - if a.type == 'disk': + if a.type == "disk": disks.append(a) elif isinstance(a, _Device): compounds.append(a) compounds.reverse() from subiquity.common.filesystem import labels + disks.sort(key=labels.label) return compounds + disks def all_disks(self): from subiquity.common.filesystem import labels - return sorted(self._all(type='disk'), key=labels.label) + + return sorted(self._all(type="disk"), key=labels.label) def all_raids(self): - return self._all(type='raid') + return self._all(type="raid") def all_volgroups(self): - return self._all(type='lvm_volgroup') + return self._all(type="lvm_volgroup") def _remove(self, obj): _remove_backlinks(obj) self._actions.remove(obj) - def add_partition(self, device, *, size, offset, flag="", wipe=None, - grub_device=None, partition_name=None, - check_alignment=True): + def add_partition( + self, + device, + *, + size, + offset, + flag="", + wipe=None, + grub_device=None, + partition_name=None, + check_alignment=True, + ): align = device.alignment_data().part_align if check_alignment: if offset % align != 0 or size % align != 0: raise Exception( - "size %s or offset %s not aligned to %s", - size, offset, align) + "size %s or offset %s not aligned to %s", size, offset, align + ) from subiquity.common.filesystem import boot + if device._fs is not None: raise Exception("%s is already formatted" % (device,)) p = Partition( - m=self, device=device, size=size, flag=flag, wipe=wipe, - grub_device=grub_device, offset=offset, - partition_name=partition_name) + m=self, + device=device, + size=size, + flag=flag, + wipe=wipe, + grub_device=grub_device, + offset=offset, + partition_name=partition_name, + ) if boot.is_bootloader_partition(p): device._partitions.insert(0, device._partitions.pop()) device.ptable = device.ptable_for_new_partition() dasd = device.dasd() if dasd is not None: - dasd.disk_layout = 'cdl' + dasd.disk_layout = "cdl" dasd.blocksize = 4096 dasd.preserve = False self._actions.append(p) @@ -1827,7 +1858,8 @@ class FilesystemModel(object): raise Exception("can only remove empty partition") from subiquity.common.filesystem.gaps import ( movable_trailing_partitions_and_gap_size, - ) + ) + for p2 in movable_trailing_partitions_and_gap_size(part)[0]: p2.offset -= part.size self._remove(part) @@ -1841,7 +1873,8 @@ class FilesystemModel(object): name=name, raidlevel=raidlevel, devices=devices, - spare_devices=spare_devices) + spare_devices=spare_devices, + ) self._actions.append(r) return r @@ -1884,14 +1917,15 @@ class FilesystemModel(object): log.debug("adding %s to %s", fstype, volume) if not volume.available: if not isinstance(volume, Partition): - if (volume.flag == 'prep' or ( - volume.flag == 'bios_grub' and fstype == 'fat32')): + if volume.flag == "prep" or ( + volume.flag == "bios_grub" and fstype == "fat32" + ): raise Exception("{} is not available".format(volume)) if volume._fs is not None: raise Exception(f"{volume} is already formatted") fs = Filesystem( - m=self, volume=volume, fstype=fstype, preserve=preserve, - label=label) + m=self, volume=volume, fstype=fstype, preserve=preserve, label=label + ) self._actions.append(fs) return fs @@ -1911,23 +1945,22 @@ class FilesystemModel(object): self._remove(mount) def needs_bootloader_partition(self): - '''true if no disk have a boot partition, and one is needed''' + """true if no disk have a boot partition, and one is needed""" # s390x has no such thing if self.bootloader == Bootloader.NONE: return False elif self.bootloader == Bootloader.BIOS: - return self._one(type='disk', grub_device=True) is None + return self._one(type="disk", grub_device=True) is None elif self.bootloader == Bootloader.UEFI: - for esp in self._all(type='partition', grub_device=True): + for esp in self._all(type="partition", grub_device=True): if esp.fs() and esp.fs().mount(): - if esp.fs().mount().path == '/boot/efi': + if esp.fs().mount().path == "/boot/efi": return False return True elif self.bootloader == Bootloader.PREP: - return self._one(type='partition', grub_device=True) is None + return self._one(type="partition", grub_device=True) is None else: - raise AssertionError( - "unknown bootloader type {}".format(self.bootloader)) + raise AssertionError("unknown bootloader type {}".format(self.bootloader)) def _mount_for_path(self, path): for typename in MountlikeNames: @@ -1937,30 +1970,31 @@ class FilesystemModel(object): return None def is_root_mounted(self): - return self._mount_for_path('/') is not None + return self._mount_for_path("/") is not None def can_install(self): - return (self.is_root_mounted() - and not self.needs_bootloader_partition()) + return self.is_root_mounted() and not self.needs_bootloader_partition() def should_add_swapfile(self): - mount = self._mount_for_path('/') + mount = self._mount_for_path("/") if mount is not None: - if not can_use_swapfile('/', mount.fstype): + if not can_use_swapfile("/", mount.fstype): return False - for swap in self._all(type='format', fstype='swap'): + for swap in self._all(type="format", fstype="swap"): if swap.mount(): return False return True - def add_zpool(self, device, pool, mountpoint, *, - fs_properties=None, pool_properties=None): + def add_zpool( + self, device, pool, mountpoint, *, fs_properties=None, pool_properties=None + ): zpool = ZPool( m=self, vdevs=[device], pool=pool, mountpoint=mountpoint, pool_properties=pool_properties, - fs_properties=fs_properties) + fs_properties=fs_properties, + ) self._actions.append(zpool) return zpool diff --git a/subiquity/models/identity.py b/subiquity/models/identity.py index 9dd4db2b..f86948f0 100644 --- a/subiquity/models/identity.py +++ b/subiquity/models/identity.py @@ -17,8 +17,7 @@ import logging import attr - -log = logging.getLogger('subiquity.models.identity') +log = logging.getLogger("subiquity.models.identity") @attr.s @@ -29,8 +28,7 @@ class User(object): class IdentityModel(object): - """ Model representing user identity - """ + """Model representing user identity""" def __init__(self): self._user = None @@ -39,11 +37,11 @@ class IdentityModel(object): def add_user(self, identity_data): self._hostname = identity_data.hostname d = {} - d['realname'] = identity_data.realname - d['username'] = identity_data.username - d['password'] = identity_data.crypted_password - if not d['realname']: - d['realname'] = identity_data.username + d["realname"] = identity_data.realname + d["username"] = identity_data.username + d["password"] = identity_data.crypted_password + if not d["realname"]: + d["realname"] = identity_data.username self._user = User(**d) @property diff --git a/subiquity/models/kernel.py b/subiquity/models/kernel.py index 6f17489c..12a7a34d 100644 --- a/subiquity/models/kernel.py +++ b/subiquity/models/kernel.py @@ -17,7 +17,6 @@ from typing import Optional class KernelModel: - # The name of the kernel metapackage that we intend to install. metapkg_name: Optional[str] = None # During the installation, if we detect that a different kernel version is @@ -42,10 +41,10 @@ class KernelModel: def render(self): if self.curthooks_no_install: - return {'kernel': None} + return {"kernel": None} return { - 'kernel': { - 'package': self.needed_kernel, - }, - } + "kernel": { + "package": self.needed_kernel, + }, + } diff --git a/subiquity/models/keyboard.py b/subiquity/models/keyboard.py index ab609639..eac3ecbb 100644 --- a/subiquity/models/keyboard.py +++ b/subiquity/models/keyboard.py @@ -14,17 +14,14 @@ # along with this program. If not, see . import logging -import re import os +import re import yaml from subiquity.common.resources import resource_path from subiquity.common.serialize import Serializer -from subiquity.common.types import ( - KeyboardLayout, - KeyboardSetting, -) +from subiquity.common.types import KeyboardLayout, KeyboardSetting log = logging.getLogger("subiquity.models.keyboard") @@ -44,12 +41,13 @@ BACKSPACE="guess" class InconsistentMultiLayoutError(ValueError): - """ Exception to raise when a multi layout has a different number of - layouts and variants. """ + """Exception to raise when a multi layout has a different number of + layouts and variants.""" + def __init__(self, layouts: str, variants: str) -> None: super().__init__( - f'inconsistent multi-layout: layouts="{layouts}"' - f' variants="{variants}"') + f'inconsistent multi-layout: layouts="{layouts}"' f' variants="{variants}"' + ) def from_config_file(config_file): @@ -57,10 +55,10 @@ def from_config_file(config_file): content = fp.read() def optval(opt, default): - match = re.search(r'(?m)^\s*%s=(.*)$' % (opt,), content) + match = re.search(r"(?m)^\s*%s=(.*)$" % (opt,), content) if match: r = match.group(1).strip('"') - if r != '': + if r != "": return r return default @@ -68,24 +66,22 @@ def from_config_file(config_file): XKBVARIANT = optval("XKBVARIANT", "") XKBOPTIONS = optval("XKBOPTIONS", "") toggle = None - for option in XKBOPTIONS.split(','): - if option.startswith('grp:'): + for option in XKBOPTIONS.split(","): + if option.startswith("grp:"): toggle = option[4:] return KeyboardSetting(layout=XKBLAYOUT, variant=XKBVARIANT, toggle=toggle) class KeyboardModel: - def __init__(self, root): - self.config_path = os.path.join( - root, 'etc', 'default', 'keyboard') + self.config_path = os.path.join(root, "etc", "default", "keyboard") self.layout_for_lang = self.load_layout_suggestions() if os.path.exists(self.config_path): self.default_setting = from_config_file(self.config_path) else: - self.default_setting = self.layout_for_lang['en_US.UTF-8'] + self.default_setting = self.layout_for_lang["en_US.UTF-8"] self.keyboard_list = KeyboardList() - self.keyboard_list.load_language('C') + self.keyboard_list.load_language("C") self._setting = None @property @@ -105,44 +101,50 @@ class KeyboardModel: if len(layout_tokens) != len(variant_tokens): raise InconsistentMultiLayoutError( - layouts=setting.layout, variants=setting.variant) + layouts=setting.layout, variants=setting.variant + ) for layout, variant in zip(layout_tokens, variant_tokens): kbd_layout = self.keyboard_list.layout_map.get(layout) if kbd_layout is None: raise ValueError(f'Unknown keyboard layout "{layout}"') - if not any(kbd_variant.code == variant - for kbd_variant in kbd_layout.variants): - raise ValueError(f'Unknown keyboard variant "{variant}" ' - f'for layout "{layout}"') + if not any( + kbd_variant.code == variant for kbd_variant in kbd_layout.variants + ): + raise ValueError( + f'Unknown keyboard variant "{variant}" ' f'for layout "{layout}"' + ) def render_config_file(self): options = "" if self.setting.toggle: options = "grp:" + self.setting.toggle return etc_default_keyboard_template.format( - layout=self.setting.layout, - variant=self.setting.variant, - options=options) + layout=self.setting.layout, variant=self.setting.variant, options=options + ) def render(self): return { - 'write_files': { - 'etc_default_keyboard': { - 'path': 'etc/default/keyboard', - 'content': self.render_config_file(), - 'permissions': 0o644, - }, + "write_files": { + "etc_default_keyboard": { + "path": "etc/default/keyboard", + "content": self.render_config_file(), + "permissions": 0o644, }, - 'curthooks_commands': { + }, + "curthooks_commands": { # The below command must be run after updating # etc/default/keyboard on the target so that the initramfs uses # the keyboard mapping selected by the user. See LP #1894009 - '002-setupcon-save-only': [ - 'curtin', 'in-target', '--', 'setupcon', '--save-only', - ], - }, - } + "002-setupcon-save-only": [ + "curtin", + "in-target", + "--", + "setupcon", + "--save-only", + ], + }, + } def setting_for_lang(self, lang): if self._setting is not None: @@ -154,7 +156,7 @@ class KeyboardModel: def load_layout_suggestions(self, path=None): if path is None: - path = resource_path('kbds') + '/keyboard-configuration.yaml' + path = resource_path("kbds") + "/keyboard-configuration.yaml" with open(path) as fp: data = yaml.safe_load(fp) @@ -166,25 +168,24 @@ class KeyboardModel: class KeyboardList: - def __init__(self): - self._kbnames_dir = resource_path('kbds') + self._kbnames_dir = resource_path("kbds") self.serializer = Serializer(compact=True) self._clear() def _file_for_lang(self, code): - return os.path.join(self._kbnames_dir, code + '.jsonl') + return os.path.join(self._kbnames_dir, code + ".jsonl") def _has_language(self, code): return os.path.exists(self._file_for_lang(code)) def load_language(self, code): - if '.' in code: - code = code.split('.')[0] + if "." in code: + code = code.split(".")[0] if not self._has_language(code): - code = code.split('_')[0] + code = code.split("_")[0] if not self._has_language(code): - code = 'C' + code = "C" if code == self.current_lang: return diff --git a/subiquity/models/locale.py b/subiquity/models/locale.py index 66648924..9ca41e8a 100644 --- a/subiquity/models/locale.py +++ b/subiquity/models/locale.py @@ -19,11 +19,11 @@ import subprocess from subiquitycore.utils import arun_command, split_cmd_output -log = logging.getLogger('subiquity.models.locale') +log = logging.getLogger("subiquity.models.locale") class LocaleModel: - """ Model representing locale selection + """Model representing locale selection XXX Only represents *language* selection for now. """ @@ -38,11 +38,7 @@ class LocaleModel: self.selected_language = code async def localectl_set_locale(self) -> None: - cmd = [ - 'localectl', - 'set-locale', - locale.normalize(self.selected_language) - ] + cmd = ["localectl", "set-locale", locale.normalize(self.selected_language)] await arun_command(cmd, check=True) async def try_localectl_set_locale(self) -> None: @@ -60,18 +56,19 @@ class LocaleModel: if self.locale_support == "none": return {} locale = self.selected_language - if '.' not in locale and '_' in locale: - locale += '.UTF-8' - return {'locale': locale} + if "." not in locale and "_" in locale: + locale += ".UTF-8" + return {"locale": locale} async def target_packages(self): if self.selected_language is None: return [] if self.locale_support != "langpack": return [] - lang = self.selected_language.split('.')[0] + lang = self.selected_language.split(".")[0] if lang == "C": return [] return await split_cmd_output( - self.chroot_prefix + ['check-language-support', '-l', lang], None) + self.chroot_prefix + ["check-language-support", "-l", lang], None + ) diff --git a/subiquity/models/mirror.py b/subiquity/models/mirror.py index db27f787..825f0bc6 100644 --- a/subiquity/models/mirror.py +++ b/subiquity/models/mirror.py @@ -72,33 +72,29 @@ primary entry """ import abc -import copy import contextlib +import copy import logging -from typing import ( - Any, Callable, Dict, Iterator, List, - Optional, Sequence, Set, Union, - ) +from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Union from urllib import parse import attr - -from subiquity.common.types import MirrorSelectionFallback - from curtin.commands.apt_config import ( - get_arch_mirrorconfig, - get_mirror, PORTS_ARCHES, PRIMARY_ARCHES, - ) + get_arch_mirrorconfig, + get_mirror, +) from curtin.config import merge_config +from subiquity.common.types import MirrorSelectionFallback + try: from curtin.distro import get_architecture except ImportError: from curtin.util import get_architecture -log = logging.getLogger('subiquity.models.mirror') +log = logging.getLogger("subiquity.models.mirror") DEFAULT_SUPPORTED_ARCHES_URI = "http://archive.ubuntu.com/ubuntu" DEFAULT_PORTS_ARCHES_URI = "http://ports.ubuntu.com/ubuntu-ports" @@ -107,7 +103,8 @@ LEGACY_DEFAULT_PRIMARY_SECTION = [ { "arches": PRIMARY_ARCHES, "uri": DEFAULT_SUPPORTED_ARCHES_URI, - }, { + }, + { "arches": ["default"], "uri": DEFAULT_PORTS_ARCHES_URI, }, @@ -120,9 +117,10 @@ DEFAULT = { @attr.s(auto_attribs=True) class BasePrimaryEntry(abc.ABC): - """ Base class to represent an entry from the 'primary' autoinstall + """Base class to represent an entry from the 'primary' autoinstall section. A BasePrimaryEntry is expected to have a URI and therefore can be - used as a primary candidate. """ + used as a primary candidate.""" + parent: "MirrorModel" = attr.ib(kw_only=True) def stage(self) -> None: @@ -133,19 +131,20 @@ class BasePrimaryEntry(abc.ABC): @abc.abstractmethod def serialize_for_ai(self) -> Any: - """ Serialize the entry for autoinstall. """ + """Serialize the entry for autoinstall.""" @abc.abstractmethod def supports_arch(self, arch: str) -> bool: - """ Tells whether the mirror claims to support the architecture - specified. """ + """Tells whether the mirror claims to support the architecture + specified.""" @attr.s(auto_attribs=True) class PrimaryEntry(BasePrimaryEntry): - """ Represents a single primary mirror candidate; which can be converted + """Represents a single primary mirror candidate; which can be converted to/from an entry of the 'apt->mirror-selection->primary' autoinstall - section. """ + section.""" + # Having uri set to None is only valid for a country mirror. uri: Optional[str] = None # When arches is None, it is assumed that the mirror is compatible with the @@ -183,10 +182,11 @@ class PrimaryEntry(BasePrimaryEntry): class LegacyPrimaryEntry(BasePrimaryEntry): - """ Represents a single primary mirror candidate; which can be converted + """Represents a single primary mirror candidate; which can be converted to/from the whole 'apt->primary' autoinstall section (legacy format). The format is defined by curtin, so we make use of curtin to access - the elements. """ + the elements.""" + def __init__(self, config: List[Any], *, parent: "MirrorModel") -> None: self.config = config super().__init__(parent=parent) @@ -200,8 +200,8 @@ class LegacyPrimaryEntry(BasePrimaryEntry): @uri.setter def uri(self, uri: str) -> None: config = get_arch_mirrorconfig( - {"primary": self.config}, - "primary", self.parent.architecture) + {"primary": self.config}, "primary", self.parent.architecture + ) config["uri"] = uri def mirror_is_default(self) -> bool: @@ -209,8 +209,7 @@ class LegacyPrimaryEntry(BasePrimaryEntry): @classmethod def new_from_default(cls, parent: "MirrorModel") -> "LegacyPrimaryEntry": - return cls(copy.deepcopy(LEGACY_DEFAULT_PRIMARY_SECTION), - parent=parent) + return cls(copy.deepcopy(LEGACY_DEFAULT_PRIMARY_SECTION), parent=parent) def serialize_for_ai(self) -> List[Any]: return self.config @@ -222,18 +221,18 @@ class LegacyPrimaryEntry(BasePrimaryEntry): def countrify_uri(uri: str, cc: str) -> str: - """ Return a URL where the host is prefixed with a country code. """ + """Return a URL where the host is prefixed with a country code.""" parsed = parse.urlparse(uri) - new = parsed._replace(netloc=cc + '.' + parsed.netloc) + new = parsed._replace(netloc=cc + "." + parsed.netloc) return parse.urlunparse(new) CandidateFilter = Callable[[BasePrimaryEntry], bool] -def filter_candidates(candidates: List[BasePrimaryEntry], - *, filters: Sequence[CandidateFilter]) \ - -> Iterator[BasePrimaryEntry]: +def filter_candidates( + candidates: List[BasePrimaryEntry], *, filters: Sequence[CandidateFilter] +) -> Iterator[BasePrimaryEntry]: candidates_iter = iter(candidates) for filt in filters: candidates_iter = filter(filt, candidates_iter) @@ -241,21 +240,20 @@ def filter_candidates(candidates: List[BasePrimaryEntry], class MirrorModel(object): - def __init__(self): self.config = copy.deepcopy(DEFAULT) self.legacy_primary = False self.disabled_components: Set[str] = set() self.primary_elected: Optional[BasePrimaryEntry] = None - self.primary_candidates: List[BasePrimaryEntry] = \ - self._default_primary_entries() + self.primary_candidates: List[ + BasePrimaryEntry + ] = self._default_primary_entries() self.primary_staged: Optional[BasePrimaryEntry] = None self.architecture = get_architecture() # Only useful for legacy primary sections - self.default_mirror = \ - LegacyPrimaryEntry.new_from_default(parent=self).uri + self.default_mirror = LegacyPrimaryEntry.new_from_default(parent=self).uri # What to do if automatic mirror-selection fails. self.fallback = MirrorSelectionFallback.ABORT @@ -263,14 +261,17 @@ class MirrorModel(object): def _default_primary_entries(self) -> List[PrimaryEntry]: return [ PrimaryEntry(parent=self, country_mirror=True), - PrimaryEntry(uri=DEFAULT_SUPPORTED_ARCHES_URI, - arches=PRIMARY_ARCHES, parent=self), - PrimaryEntry(uri=DEFAULT_PORTS_ARCHES_URI, - arches=PORTS_ARCHES, parent=self), + PrimaryEntry( + uri=DEFAULT_SUPPORTED_ARCHES_URI, arches=PRIMARY_ARCHES, parent=self + ), + PrimaryEntry( + uri=DEFAULT_PORTS_ARCHES_URI, arches=PORTS_ARCHES, parent=self + ), ] def get_default_primary_candidates( - self, legacy: Optional[bool] = None) -> Sequence[BasePrimaryEntry]: + self, legacy: Optional[bool] = None + ) -> Sequence[BasePrimaryEntry]: want_legacy = legacy if legacy is not None else self.legacy_primary if want_legacy: return [LegacyPrimaryEntry.new_from_default(parent=self)] @@ -282,15 +283,15 @@ class MirrorModel(object): self.disabled_components = set(data.pop("disable_components")) if "primary" in data and "mirror-selection" in data: - raise ValueError("apt->primary and apt->mirror-selection are" - " mutually exclusive.") + raise ValueError( + "apt->primary and apt->mirror-selection are" " mutually exclusive." + ) self.legacy_primary = "primary" in data primary_candidates = self.get_default_primary_candidates() if "primary" in data: # Legacy sections only support a single candidate - primary_candidates = \ - [LegacyPrimaryEntry(data.pop("primary"), parent=self)] + primary_candidates = [LegacyPrimaryEntry(data.pop("primary"), parent=self)] if "mirror-selection" in data: mirror_selection = data.pop("mirror-selection") if "primary" in mirror_selection: @@ -314,7 +315,8 @@ class MirrorModel(object): return config def _get_apt_config_using_candidate( - self, candidate: BasePrimaryEntry) -> Dict[str, Any]: + self, candidate: BasePrimaryEntry + ) -> Dict[str, Any]: config = self._get_apt_config_common() config["primary"] = candidate.config return config @@ -327,8 +329,7 @@ class MirrorModel(object): assert self.primary_elected is not None return self._get_apt_config_using_candidate(self.primary_elected) - def get_apt_config( - self, final: bool, has_network: bool) -> Dict[str, Any]: + def get_apt_config(self, final: bool, has_network: bool) -> Dict[str, Any]: if not final: return self.get_apt_config_staged() if has_network: @@ -347,15 +348,16 @@ class MirrorModel(object): lambda c: c.uri is not None, lambda c: c.supports_arch(self.architecture), ] - candidate = next(filter_candidates(self.primary_candidates, - filters=filters)) + candidate = next( + filter_candidates(self.primary_candidates, filters=filters) + ) return self._get_apt_config_using_candidate(candidate) # Our last resort is to include no primary section. Curtin will use # its own internal values. return self._get_apt_config_common() def set_country(self, cc): - """ Set the URI of country-mirror candidates. """ + """Set the URI of country-mirror candidates.""" for candidate in self.country_mirror_candidates(): if self.legacy_primary: candidate.uri = countrify_uri(candidate.uri, cc=cc) @@ -367,8 +369,8 @@ class MirrorModel(object): candidate.uri = countrify_uri(uri, cc=cc) def disable_components(self, comps, add: bool) -> None: - """ Add (or remove) a component (e.g., multiverse) from the list of - disabled components. """ + """Add (or remove) a component (e.g., multiverse) from the list of + disabled components.""" comps = set(comps) if add: self.disabled_components |= comps @@ -376,19 +378,17 @@ class MirrorModel(object): self.disabled_components -= comps def create_primary_candidate( - self, uri: Optional[str], - country_mirror: bool = False) -> BasePrimaryEntry: - + self, uri: Optional[str], country_mirror: bool = False + ) -> BasePrimaryEntry: if self.legacy_primary: entry = LegacyPrimaryEntry.new_from_default(parent=self) entry.uri = uri return entry - return PrimaryEntry(uri=uri, country_mirror=country_mirror, - parent=self) + return PrimaryEntry(uri=uri, country_mirror=country_mirror, parent=self) def wants_geoip(self) -> bool: - """ Tell whether geoip results would be useful. """ + """Tell whether geoip results would be useful.""" return next(self.country_mirror_candidates(), None) is not None def country_mirror_candidates(self) -> Iterator[BasePrimaryEntry]: diff --git a/subiquity/models/network.py b/subiquity/models/network.py index eadbe9b2..995b7cb8 100644 --- a/subiquity/models/network.py +++ b/subiquity/models/network.py @@ -20,11 +20,10 @@ from subiquity import cloudinit from subiquitycore.models.network import NetworkModel as CoreNetworkModel from subiquitycore.utils import arun_command -log = logging.getLogger('subiquity.models.network') +log = logging.getLogger("subiquity.models.network") class NetworkModel(CoreNetworkModel): - def __init__(self): super().__init__("subiquity") self.override_config = None @@ -46,59 +45,57 @@ class NetworkModel(CoreNetworkModel): # If host cloud-init version has no readable combined-cloud-config, # default to False. cloud_cfg = cloudinit.get_host_combined_cloud_config() - use_cloudinit_net = cloud_cfg.get('features', {}).get( - 'NETPLAN_CONFIG_ROOT_READ_ONLY', False + use_cloudinit_net = cloud_cfg.get("features", {}).get( + "NETPLAN_CONFIG_ROOT_READ_ONLY", False ) if use_cloudinit_net: r = { - 'write_files': { - 'etc_netplan_installer': { - 'path': ( - 'etc/cloud/cloud.cfg.d/90-installer-network.cfg' - ), - 'content': self.stringify_config(netplan), - 'permissions': '0600', + "write_files": { + "etc_netplan_installer": { + "path": ("etc/cloud/cloud.cfg.d/90-installer-network.cfg"), + "content": self.stringify_config(netplan), + "permissions": "0600", }, } } else: # Separate sensitive wifi config from world-readable config - wifis = netplan['network'].pop('wifis', None) + wifis = netplan["network"].pop("wifis", None) r = { - 'write_files': { + "write_files": { # Disable cloud-init networking - 'no_cloudinit_net': { - 'path': ( - 'etc/cloud/cloud.cfg.d/' - 'subiquity-disable-cloudinit-networking.cfg' + "no_cloudinit_net": { + "path": ( + "etc/cloud/cloud.cfg.d/" + "subiquity-disable-cloudinit-networking.cfg" ), - 'content': 'network: {config: disabled}\n', + "content": "network: {config: disabled}\n", }, # World-readable netplan without sensitive wifi config - 'etc_netplan_installer': { - 'path': 'etc/netplan/00-installer-config.yaml', - 'content': self.stringify_config(netplan), + "etc_netplan_installer": { + "path": "etc/netplan/00-installer-config.yaml", + "content": self.stringify_config(netplan), }, }, } if wifis is not None: netplan_wifi = { - 'network': { - 'version': 2, - 'wifis': wifis, + "network": { + "version": 2, + "wifis": wifis, }, } - r['write_files']['etc_netplan_installer_wifi'] = { - 'path': 'etc/netplan/00-installer-config-wifi.yaml', - 'content': self.stringify_config(netplan_wifi), - 'permissions': '0600', + r["write_files"]["etc_netplan_installer_wifi"] = { + "path": "etc/netplan/00-installer-config-wifi.yaml", + "content": self.stringify_config(netplan_wifi), + "permissions": "0600", } return r async def target_packages(self): if self.needs_wpasupplicant: - return ['wpasupplicant'] + return ["wpasupplicant"] else: return [] diff --git a/subiquity/models/oem.py b/subiquity/models/oem.py index 64ca4055..cf7222ec 100644 --- a/subiquity/models/oem.py +++ b/subiquity/models/oem.py @@ -18,7 +18,7 @@ from typing import Any, Dict, List, Optional, Union import attr -log = logging.getLogger('subiquity.models.oem') +log = logging.getLogger("subiquity.models.oem") @attr.s(auto_attribs=True) diff --git a/subiquity/models/proxy.py b/subiquity/models/proxy.py index caebe3fd..3bc70aa3 100644 --- a/subiquity/models/proxy.py +++ b/subiquity/models/proxy.py @@ -15,17 +15,16 @@ import logging -log = logging.getLogger('subiquity.models.proxy') +log = logging.getLogger("subiquity.models.proxy") -dropin_template = '''\ +dropin_template = """\ [Service] Environment="HTTP_PROXY={proxy}" Environment="HTTPS_PROXY={proxy}" -''' +""" class ProxyModel(object): - def __init__(self): self.proxy = "" @@ -35,8 +34,8 @@ class ProxyModel(object): def get_apt_config(self, final: bool, has_network: bool): if self.proxy: return { - 'http_proxy': self.proxy, - 'https_proxy': self.proxy, + "http_proxy": self.proxy, + "https_proxy": self.proxy, } else: return {} @@ -44,18 +43,19 @@ class ProxyModel(object): def render(self): if self.proxy: return { - 'proxy': { - 'http_proxy': self.proxy, - 'https_proxy': self.proxy, + "proxy": { + "http_proxy": self.proxy, + "https_proxy": self.proxy, + }, + "write_files": { + "snapd_dropin": { + "path": ( + "etc/systemd/system/" "snapd.service.d/snap_proxy.conf" + ), + "content": self.proxy_systemd_dropin(), + "permissions": 0o644, }, - 'write_files': { - 'snapd_dropin': { - 'path': ('etc/systemd/system/' - 'snapd.service.d/snap_proxy.conf'), - 'content': self.proxy_systemd_dropin(), - 'permissions': 0o644, - }, - }, - } + }, + } else: return {} diff --git a/subiquity/models/snaplist.py b/subiquity/models/snaplist.py index 953b2feb..07ea5937 100644 --- a/subiquity/models/snaplist.py +++ b/subiquity/models/snaplist.py @@ -16,11 +16,7 @@ import datetime import logging -from subiquity.common.types import ( - ChannelSnapInfo, - SnapInfo, - ) - +from subiquity.common.types import ChannelSnapInfo, SnapInfo log = logging.getLogger("subiquity.models.snaplist") @@ -44,47 +40,49 @@ class SnapListModel: return s def load_find_data(self, data): - for info in data['result']: - self.update(self._snap_for_name(info['name']), info) + for info in data["result"]: + self.update(self._snap_for_name(info["name"]), info) def add_partial_snap(self, name): self._snaps_for_name(name) def update(self, snap, data): - snap.summary = data['summary'] - snap.publisher = data['developer'] - snap.verified = data['publisher']['validation'] == "verified" - snap.starred = data['publisher']['validation'] == "starred" - snap.description = data['description'] - snap.confinement = data['confinement'] - snap.license = data['license'] + snap.summary = data["summary"] + snap.publisher = data["developer"] + snap.verified = data["publisher"]["validation"] == "verified" + snap.starred = data["publisher"]["validation"] == "starred" + snap.description = data["description"] + snap.confinement = data["confinement"] + snap.license = data["license"] self.complete_snaps.add(snap) def load_info_data(self, data): - info = data['result'][0] - snap = self._snaps_by_name.get(info['name']) + info = data["result"][0] + snap = self._snaps_by_name.get(info["name"]) if snap is None: return if snap not in self.complete_snaps: self.update_snap(snap, info) - channel_map = info['channels'] - for track in info['tracks']: + channel_map = info["channels"] + for track in info["tracks"]: for risk in risks: - channel_name = '{}/{}'.format(track, risk) + channel_name = "{}/{}".format(track, risk) if channel_name in channel_map: channel_data = channel_map[channel_name] if track == "latest": channel_name = risk - snap.channels.append(ChannelSnapInfo( - channel_name=channel_name, - revision=channel_data['revision'], - confinement=channel_data['confinement'], - version=channel_data['version'], - size=channel_data['size'], - released_at=datetime.datetime.strptime( - channel_data['released-at'], - '%Y-%m-%dT%H:%M:%S.%fZ'), - )) + snap.channels.append( + ChannelSnapInfo( + channel_name=channel_name, + revision=channel_data["revision"], + confinement=channel_data["confinement"], + version=channel_data["version"], + size=channel_data["size"], + released_at=datetime.datetime.strptime( + channel_data["released-at"], "%Y-%m-%dT%H:%M:%S.%fZ" + ), + ) + ) return snap def get_snap_list(self): @@ -100,9 +98,9 @@ class SnapListModel: return {} cmds = [] for selection in self.selections: - cmd = ['snap', 'install', '--channel=' + selection.channel] + cmd = ["snap", "install", "--channel=" + selection.channel] if selection.classic: - cmd.append('--classic') + cmd.append("--classic") cmd.append(selection.name) - cmds.append(' '.join(cmd)) - return {'snap': {'commands': cmds}} + cmds.append(" ".join(cmd)) + return {"snap": {"commands": cmds}} diff --git a/subiquity/models/source.py b/subiquity/models/source.py index 689a9ea2..cb21faea 100644 --- a/subiquity/models/source.py +++ b/subiquity/models/source.py @@ -16,13 +16,13 @@ import logging import os import typing -import yaml import attr +import yaml from subiquity.common.serialize import Serializer -log = logging.getLogger('subiquity.models.source') +log = logging.getLogger("subiquity.models.source") @attr.s(auto_attribs=True) @@ -49,33 +49,32 @@ class CatalogEntry: def __attrs_post_init__(self): if not self.variations: - self.variations['default'] = CatalogEntryVariation( + self.variations["default"] = CatalogEntryVariation( path=self.path, size=self.size, - snapd_system_label=self.snapd_system_label) + snapd_system_label=self.snapd_system_label, + ) legacy_server_entry = CatalogEntry( - variant='server', - id='synthesized', - name={'en': 'Ubuntu Server'}, - description={'en': 'the default'}, - path='/media/filesystem', - type='cp', + variant="server", + id="synthesized", + name={"en": "Ubuntu Server"}, + description={"en": "the default"}, + path="/media/filesystem", + type="cp", default=True, size=2 << 30, locale_support="locale-only", variations={ - 'default': CatalogEntryVariation( - path='/media/filesystem', - size=2 << 30), - }) + "default": CatalogEntryVariation(path="/media/filesystem", size=2 << 30), + }, +) class SourceModel: - def __init__(self): - self._dir = '/cdrom/casper' + self._dir = "/cdrom/casper" self.current = legacy_server_entry self.sources = [self.current] self.lang = None @@ -86,8 +85,8 @@ class SourceModel: self.sources = [] self.current = None self.sources = Serializer(ignore_unknown_fields=True).deserialize( - typing.List[CatalogEntry], - yaml.safe_load(fp)) + typing.List[CatalogEntry], yaml.safe_load(fp) + ) for entry in self.sources: if entry.default: self.current = entry @@ -96,7 +95,7 @@ class SourceModel: self.current = self.sources[0] def get_matching_source(self, id_: str) -> CatalogEntry: - """ Return a source object that has the ID requested. """ + """Return a source object that has the ID requested.""" for source in self.sources: if source.id == id_: return source @@ -113,10 +112,10 @@ class SourceModel: if self.lang in self.current.preinstalled_langs: suffix = self.lang else: - suffix = 'no-languages' - path = base + '.' + suffix + ext + suffix = "no-languages" + path = base + "." + suffix + ext scheme = self.current.type - return f'{scheme}://{path}' + return f"{scheme}://{path}" def render(self): return {} diff --git a/subiquity/models/ssh.py b/subiquity/models/ssh.py index c8ebae5b..e6081256 100644 --- a/subiquity/models/ssh.py +++ b/subiquity/models/ssh.py @@ -20,7 +20,6 @@ log = logging.getLogger("subiquity.models.ssh") class SSHModel: - def __init__(self): self.install_server = False self.authorized_keys: List[str] = [] @@ -28,10 +27,10 @@ class SSHModel: # Although the generated config just contains the key above, # we store the imported id so that we can re-fill the form if # we go back to it. - self.ssh_import_id = '' + self.ssh_import_id = "" async def target_packages(self): if self.install_server: - return ['openssh-server'] + return ["openssh-server"] else: return [] diff --git a/subiquity/models/subiquity.py b/subiquity/models/subiquity.py index af1c4ce1..0741d728 100644 --- a/subiquity/models/subiquity.py +++ b/subiquity/models/subiquity.py @@ -14,37 +14,35 @@ # along with this program. If not, see . import asyncio -from collections import OrderedDict import functools import json import logging import os -from typing import Any, Dict, Set import uuid -import yaml +from collections import OrderedDict +from typing import Any, Dict, Set +import yaml from cloudinit.config.schema import ( - SchemaValidationError, - get_schema, - validate_cloudconfig_schema + SchemaValidationError, + get_schema, + validate_cloudconfig_schema, ) try: from cloudinit.config.schema import SchemaProblem except ImportError: - def SchemaProblem(x, y): return (x, y) # TODO(drop on cloud-init 22.3 SRU) + + def SchemaProblem(x, y): + return (x, y) # TODO(drop on cloud-init 22.3 SRU) from curtin.config import merge_config -from subiquitycore.file_util import ( - generate_timestamped_header, - write_file, -) -from subiquitycore.lsb_release import lsb_release - from subiquity.common.resources import get_users_and_groups from subiquity.server.types import InstallerChannels +from subiquitycore.file_util import generate_timestamped_header, write_file +from subiquitycore.lsb_release import lsb_release from .ad import AdModel from .codecs import CodecsModel @@ -66,8 +64,7 @@ from .timezone import TimeZoneModel from .ubuntu_pro import UbuntuProModel from .updates import UpdatesModel - -log = logging.getLogger('subiquity.models.subiquity') +log = logging.getLogger("subiquity.models.subiquity") def merge_cloud_init_config(target, source): @@ -130,13 +127,13 @@ for cfg_file in {cfg_files}: # Emit #cloud-config write_files entry to disable cloud-init after install CLOUDINIT_DISABLE_AFTER_INSTALL = { - "path": "/etc/cloud/cloud-init.disabled", - "defer": True, - "content": ( - "Disabled by Ubuntu live installer after first boot.\n" - "To re-enable cloud-init on this image run:\n" - " sudo cloud-init clean --machine-id\n" - ) + "path": "/etc/cloud/cloud-init.disabled", + "defer": True, + "content": ( + "Disabled by Ubuntu live installer after first boot.\n" + "To re-enable cloud-init on this image run:\n" + " sudo cloud-init clean --machine-id\n" + ), } @@ -156,29 +153,26 @@ class ModelNames: class DebconfSelectionsModel: - def __init__(self): - self.selections = '' + self.selections = "" def render(self): return {} - def get_apt_config( - self, final: bool, has_network: bool) -> Dict[str, Any]: - return {'debconf_selections': {'subiquity': self.selections}} + def get_apt_config(self, final: bool, has_network: bool) -> Dict[str, Any]: + return {"debconf_selections": {"subiquity": self.selections}} class SubiquityModel: """The overall model for subiquity.""" - target = '/target' - chroot_prefix = ['chroot', target] + target = "/target" + chroot_prefix = ["chroot", target] - def __init__(self, root, hub, install_model_names, - postinstall_model_names): + def __init__(self, root, hub, install_model_names, postinstall_model_names): self.root = root self.hub = hub - if root != '/': + if root != "/": self.target = root self.chroot_prefix = [] @@ -212,8 +206,7 @@ class SubiquityModel: self._install_model_names = install_model_names self._postinstall_model_names = postinstall_model_names self._cur_install_model_names = install_model_names.default_names - self._cur_postinstall_model_names = \ - postinstall_model_names.default_names + self._cur_postinstall_model_names = postinstall_model_names.default_names self._install_event = asyncio.Event() self._postinstall_event = asyncio.Event() all_names = set() @@ -222,15 +215,17 @@ class SubiquityModel: for name in all_names: hub.subscribe( (InstallerChannels.CONFIGURED, name), - functools.partial(self._configured, name)) + functools.partial(self._configured, name), + ) def set_source_variant(self, variant): - self._cur_install_model_names = \ - self._install_model_names.for_variant(variant) - self._cur_postinstall_model_names = \ - self._postinstall_model_names.for_variant(variant) - unconfigured_install_model_names = \ + self._cur_install_model_names = self._install_model_names.for_variant(variant) + self._cur_postinstall_model_names = self._postinstall_model_names.for_variant( + variant + ) + unconfigured_install_model_names = ( self._cur_install_model_names - self._configured_names + ) if unconfigured_install_model_names: if self._install_event.is_set(): self._install_event = asyncio.Event() @@ -238,8 +233,9 @@ class SubiquityModel: self._confirmation_task.cancel() else: self._install_event.set() - unconfigured_postinstall_model_names = \ + unconfigured_postinstall_model_names = ( self._cur_postinstall_model_names - self._configured_names + ) if unconfigured_postinstall_model_names: if self._postinstall_event.is_set(): self._postinstall_event = asyncio.Event() @@ -247,27 +243,34 @@ class SubiquityModel: self._postinstall_event.set() def _configured(self, model_name): - """ Add the model to the set of models that have been configured. If + """Add the model to the set of models that have been configured. If there is no more model to configure in the relevant section(s) (i.e., - INSTALL or POSTINSTALL), we trigger the associated event(s). """ - def log_and_trigger(stage: str, names: Set[str], - event: asyncio.Event) -> None: + INSTALL or POSTINSTALL), we trigger the associated event(s).""" + + def log_and_trigger(stage: str, names: Set[str], event: asyncio.Event) -> None: unconfigured = names - self._configured_names log.debug( "model %s for %s stage is configured, to go %s", - model_name, stage, unconfigured) + model_name, + stage, + unconfigured, + ) if not unconfigured: event.set() self._configured_names.add(model_name) if model_name in self._cur_install_model_names: - log_and_trigger(stage="install", - names=self._cur_install_model_names, - event=self._install_event) + log_and_trigger( + stage="install", + names=self._cur_install_model_names, + event=self._install_event, + ) if model_name in self._cur_postinstall_model_names: - log_and_trigger(stage="postinstall", - names=self._cur_postinstall_model_names, - event=self._postinstall_event) + log_and_trigger( + stage="postinstall", + names=self._cur_postinstall_model_names, + event=self._postinstall_event, + ) async def wait_install(self): if len(self._cur_install_model_names) == 0: @@ -281,8 +284,7 @@ class SubiquityModel: async def wait_confirmation(self): if self._confirmation_task is None: - self._confirmation_task = asyncio.create_task( - self._confirmation.wait()) + self._confirmation_task = asyncio.create_task(self._confirmation.wait()) try: await self._confirmation_task except asyncio.CancelledError: @@ -293,8 +295,10 @@ class SubiquityModel: self._confirmation_task = None def is_postinstall_only(self, model_name): - return model_name in self._cur_postinstall_model_names and \ - model_name not in self._cur_install_model_names + return ( + model_name in self._cur_postinstall_model_names + and model_name not in self._cur_install_model_names + ) async def confirm(self): self._confirmation.set() @@ -326,9 +330,9 @@ class SubiquityModel: if warnings: log.warning( "The cloud-init configuration for %s contains" - " deprecated values:\n%s", data_source, "\n".join( - warnings - ) + " deprecated values:\n%s", + data_source, + "\n".join(warnings), ) if e.schema_errors: if data_source == "autoinstall.user-data": @@ -342,50 +346,46 @@ class SubiquityModel: def _cloud_init_config(self): config = { - 'growpart': { - 'mode': 'off', - }, - 'resize_rootfs': False, + "growpart": { + "mode": "off", + }, + "resize_rootfs": False, } if self.identity.hostname is not None: - config['preserve_hostname'] = True + config["preserve_hostname"] = True user = self.identity.user if user: groups = get_users_and_groups(self.chroot_prefix) user_info = { - 'name': user.username, - 'gecos': user.realname, - 'passwd': user.password, - 'shell': '/bin/bash', - 'groups': ','.join(sorted(groups)), - 'lock_passwd': False, - } + "name": user.username, + "gecos": user.realname, + "passwd": user.password, + "shell": "/bin/bash", + "groups": ",".join(sorted(groups)), + "lock_passwd": False, + } if self.ssh.authorized_keys: - user_info['ssh_authorized_keys'] = self.ssh.authorized_keys - config['users'] = [user_info] + user_info["ssh_authorized_keys"] = self.ssh.authorized_keys + config["users"] = [user_info] else: if self.ssh.authorized_keys: - config['ssh_authorized_keys'] = self.ssh.authorized_keys + config["ssh_authorized_keys"] = self.ssh.authorized_keys if self.ssh.install_server: - config['ssh_pwauth'] = self.ssh.pwauth + config["ssh_pwauth"] = self.ssh.pwauth for model_name in self._postinstall_model_names.all(): model = getattr(self, model_name) - if getattr(model, 'make_cloudconfig', None): + if getattr(model, "make_cloudconfig", None): merge_config(config, model.make_cloudconfig()) merge_cloud_init_config(config, self.userdata) - if lsb_release()['release'] not in ("20.04", "22.04"): - config.setdefault("write_files", []).append( - CLOUDINIT_DISABLE_AFTER_INSTALL - ) - self.validate_cloudconfig_schema( - data=config, data_source="system install" - ) + if lsb_release()["release"] not in ("20.04", "22.04"): + config.setdefault("write_files", []).append(CLOUDINIT_DISABLE_AFTER_INSTALL) + self.validate_cloudconfig_schema(data=config, data_source="system install") return config async def target_packages(self): packages = list(self.packages) for model_name in self._postinstall_model_names.all(): - meth = getattr(getattr(self, model_name), 'target_packages', None) + meth = getattr(getattr(self, model_name), "target_packages", None) if meth is not None: packages.extend(await meth()) return packages @@ -395,49 +395,55 @@ class SubiquityModel: # first boot of the target, it reconfigures datasource_list to none # for subsequent boots. # (mwhudson does not entirely know what the above means!) - userdata = '#cloud-config\n' + yaml.dump(self._cloud_init_config()) - metadata = {'instance-id': str(uuid.uuid4())} - config = yaml.dump({ - 'datasource_list': ["None"], - 'datasource': { - "None": { - 'userdata_raw': userdata, - 'metadata': metadata, + userdata = "#cloud-config\n" + yaml.dump(self._cloud_init_config()) + metadata = {"instance-id": str(uuid.uuid4())} + config = yaml.dump( + { + "datasource_list": ["None"], + "datasource": { + "None": { + "userdata_raw": userdata, + "metadata": metadata, }, }, - }) + } + ) files = [ - ('etc/cloud/cloud.cfg.d/99-installer.cfg', config, 0o600), - ('etc/cloud/ds-identify.cfg', 'policy: enabled\n', 0o644), - ] + ("etc/cloud/cloud.cfg.d/99-installer.cfg", config, 0o600), + ("etc/cloud/ds-identify.cfg", "policy: enabled\n", 0o644), + ] # Add cloud-init clean hooks to support golden-image creation. cfg_files = ["/" + path for (path, _content, _cmode) in files] cfg_files.extend(self.network.rendered_config_paths()) - if lsb_release()['release'] not in ("20.04", "22.04"): + if lsb_release()["release"] not in ("20.04", "22.04"): cfg_files.append("/etc/cloud/cloud-init.disabled") if self.identity.hostname is not None: hostname = self.identity.hostname.strip() - files.extend([ - ('etc/hostname', hostname + "\n", 0o644), - ('etc/hosts', HOSTS_CONTENT.format(hostname=hostname), 0o644), - ]) + files.extend( + [ + ("etc/hostname", hostname + "\n", 0o644), + ("etc/hosts", HOSTS_CONTENT.format(hostname=hostname), 0o644), + ] + ) - files.append(( - 'etc/cloud/clean.d/99-installer', - CLOUDINIT_CLEAN_FILE_TMPL.format( - header=generate_timestamped_header(), - cfg_files=json.dumps(sorted(cfg_files)) - ), - 0o755 - )) + files.append( + ( + "etc/cloud/clean.d/99-installer", + CLOUDINIT_CLEAN_FILE_TMPL.format( + header=generate_timestamped_header(), + cfg_files=json.dumps(sorted(cfg_files)), + ), + 0o755, + ) + ) return files def configure_cloud_init(self): if self.target is None: # i.e. reset_partition_only return - if self.source.current.variant == 'core': + if self.source.current.variant == "core": # can probably be supported but requires changes return for path, content, cmode in self._cloud_init_files(): @@ -446,60 +452,58 @@ class SubiquityModel: write_file(path, content, cmode=cmode) def _media_info(self): - if os.path.exists('/cdrom/.disk/info'): - with open('/cdrom/.disk/info') as fp: + if os.path.exists("/cdrom/.disk/info"): + with open("/cdrom/.disk/info") as fp: return fp.read() else: return "media-info" def _machine_id(self): - with open('/etc/machine-id') as fp: + with open("/etc/machine-id") as fp: return fp.read() def render(self): config = { - 'grub': { - 'terminal': 'unmodified', - 'probe_additional_os': True, - 'reorder_uefi': False, + "grub": { + "terminal": "unmodified", + "probe_additional_os": True, + "reorder_uefi": False, + }, + "install": { + "unmount": "disabled", + "save_install_config": False, + "save_install_log": False, + }, + "pollinate": { + "user_agent": { + "subiquity": "%s_%s" + % ( + os.environ.get("SNAP_VERSION", "dry-run"), + os.environ.get("SNAP_REVISION", "dry-run"), + ), }, - - 'install': { - 'unmount': 'disabled', - 'save_install_config': False, - 'save_install_log': False, + }, + "write_files": { + "etc_machine_id": { + "path": "etc/machine-id", + "content": self._machine_id(), + "permissions": 0o444, }, - - 'pollinate': { - 'user_agent': { - 'subiquity': "%s_%s" % (os.environ.get("SNAP_VERSION", - 'dry-run'), - os.environ.get("SNAP_REVISION", - 'dry-run')), - }, + "media_info": { + "path": "var/log/installer/media-info", + "content": self._media_info(), + "permissions": 0o644, }, + }, + } - 'write_files': { - 'etc_machine_id': { - 'path': 'etc/machine-id', - 'content': self._machine_id(), - 'permissions': 0o444, - }, - 'media_info': { - 'path': 'var/log/installer/media-info', - 'content': self._media_info(), - 'permissions': 0o644, - }, - }, - } - - if os.path.exists('/run/casper-md5check.json'): - with open('/run/casper-md5check.json') as fp: - config['write_files']['md5check'] = { - 'path': 'var/log/installer/casper-md5check.json', - 'content': fp.read(), - 'permissions': 0o644, - } + if os.path.exists("/run/casper-md5check.json"): + with open("/run/casper-md5check.json") as fp: + config["write_files"]["md5check"] = { + "path": "var/log/installer/casper-md5check.json", + "content": fp.read(), + "permissions": 0o644, + } for model_name in self._install_model_names.all(): model = getattr(self, model_name) diff --git a/subiquity/models/tests/test_filesystem.py b/subiquity/models/tests/test_filesystem.py index 4c89395d..e0762cec 100644 --- a/subiquity/models/tests/test_filesystem.py +++ b/subiquity/models/tests/test_filesystem.py @@ -19,39 +19,37 @@ from unittest import mock import attr import yaml +from subiquity.common.filesystem import gaps +from subiquity.models.filesystem import ( + LVM_CHUNK_SIZE, + ZFS, + ActionRenderMode, + Bootloader, + Disk, + Filesystem, + FilesystemModel, + NotFinalPartitionError, + Partition, + ZPool, + align_down, + dehumanize_size, + get_canmount, + get_raid_size, + humanize_size, +) from subiquitycore.tests import SubiTestCase from subiquitycore.tests.parameterized import parameterized from subiquitycore.utils import matching_dicts -from subiquity.models.filesystem import ( - ActionRenderMode, - Bootloader, - dehumanize_size, - Disk, - Filesystem, - FilesystemModel, - get_canmount, - get_raid_size, - humanize_size, - NotFinalPartitionError, - Partition, - ZFS, - ZPool, - align_down, - LVM_CHUNK_SIZE, - ) -from subiquity.common.filesystem import gaps - class TestHumanizeSize(unittest.TestCase): - basics = [ - ('1.000M', 2**20), - ('1.500M', 2**20+2**19), - ('1.500M', 2**20+2**19), - ('1023.000M', 1023*2**20), - ('1.000G', 1024*2**20), - ] + ("1.000M", 2**20), + ("1.500M", 2**20 + 2**19), + ("1.500M", 2**20 + 2**19), + ("1023.000M", 1023 * 2**20), + ("1.000G", 1024 * 2**20), + ] def test_basics(self): for string, integer in self.basics: @@ -60,39 +58,32 @@ class TestHumanizeSize(unittest.TestCase): class TestDehumanizeSize(unittest.TestCase): - basics = [ - ('1', 1), - ('134', 134), - - ('0.5B', 0), # Does it make sense to allow this? - ('1B', 1), - - ('1K', 2**10), - ('1k', 2**10), - ('0.5K', 2**9), - ('2.125K', 2**11 + 2**7), - - ('1M', 2**20), - ('1m', 2**20), - ('0.5M', 2**19), - ('2.125M', 2**21 + 2**17), - - ('1G', 2**30), - ('1g', 2**30), - ('0.25G', 2**28), - ('2.5G', 2**31 + 2**29), - - ('1T', 2**40), - ('1t', 2**40), - ('4T', 2**42), - ('4.125T', 2**42 + 2**37), - - ('1P', 2**50), - ('1P', 2**50), - ('0.5P', 2**49), - ('1.5P', 2**50 + 2**49), - ] + ("1", 1), + ("134", 134), + ("0.5B", 0), # Does it make sense to allow this? + ("1B", 1), + ("1K", 2**10), + ("1k", 2**10), + ("0.5K", 2**9), + ("2.125K", 2**11 + 2**7), + ("1M", 2**20), + ("1m", 2**20), + ("0.5M", 2**19), + ("2.125M", 2**21 + 2**17), + ("1G", 2**30), + ("1g", 2**30), + ("0.25G", 2**28), + ("2.5G", 2**31 + 2**29), + ("1T", 2**40), + ("1t", 2**40), + ("4T", 2**42), + ("4.125T", 2**42 + 2**37), + ("1P", 2**50), + ("1P", 2**50), + ("0.5P", 2**49), + ("1.5P", 2**50 + 2**49), + ] def test_basics(self): for input, expected_output in self.basics: @@ -100,13 +91,13 @@ class TestDehumanizeSize(unittest.TestCase): self.assertEqual(expected_output, dehumanize_size(input)) errors = [ - ('', "input cannot be empty"), - ('1u', "unrecognized suffix 'u' in '1u'"), - ('-1', "'-1': cannot be negative"), - ('1.1.1', "'1.1.1' is not valid input"), - ('1rm', "'1rm' is not valid input"), - ('1e6M', "'1e6M' is not valid input"), - ] + ("", "input cannot be empty"), + ("1u", "unrecognized suffix 'u' in '1u'"), + ("-1", "'-1': cannot be negative"), + ("1.1.1", "'1.1.1' is not valid input"), + ("1rm", "'1rm' is not valid input"), + ("1e6M", "'1e6M' is not valid input"), + ] def test_errors(self): for input, expected_error in self.errors: @@ -116,25 +107,21 @@ class TestDehumanizeSize(unittest.TestCase): except ValueError as e: actual_error = str(e) else: - self.fail( - "dehumanize_size({!r}) did not error".format(input)) + self.fail("dehumanize_size({!r}) did not error".format(input)) self.assertEqual(expected_error, actual_error) @attr.s class FakeDev: - size = attr.ib() id = attr.ib(default="id") class TestRoundRaidSize(unittest.TestCase): - def test_lp1816777(self): - self.assertLessEqual( - get_raid_size("raid1", [FakeDev(500107862016)]*2), - 499972571136) + get_raid_size("raid1", [FakeDev(500107862016)] * 2), 499972571136 + ) @attr.s @@ -160,15 +147,14 @@ def make_model(bootloader=None, storage_version=None): def make_disk(fs_model=None, **kw): if fs_model is None: fs_model = make_model() - if 'serial' not in kw: - kw['serial'] = 'serial%s' % len(fs_model._actions) - if 'path' not in kw: - kw['path'] = '/dev/thing%s' % len(fs_model._actions) - if 'ptable' not in kw: - kw['ptable'] = 'gpt' - size = kw.pop('size', 100*(2**30)) - fs_model._actions.append(Disk( - m=fs_model, info=FakeStorageInfo(size=size), **kw)) + if "serial" not in kw: + kw["serial"] = "serial%s" % len(fs_model._actions) + if "path" not in kw: + kw["path"] = "/dev/thing%s" % len(fs_model._actions) + if "ptable" not in kw: + kw["ptable"] = "gpt" + size = kw.pop("size", 100 * (2**30)) + fs_model._actions.append(Disk(m=fs_model, info=FakeStorageInfo(size=size), **kw)) disk = fs_model._actions[-1] return disk @@ -178,26 +164,32 @@ def make_model_and_disk(bootloader=None, **kw): return model, make_disk(model, **kw) -def make_partition(model, device=None, *, preserve=False, size=None, - offset=None, **kw): - flag = kw.pop('flag', None) +def make_partition(model, device=None, *, preserve=False, size=None, offset=None, **kw): + flag = kw.pop("flag", None) if device is None: device = make_disk(model) if size is None or offset is None: gap = gaps.largest_gap(device) if size is None: - size = gap.size//2 + size = gap.size // 2 if offset is None: offset = gap.offset - partition = Partition(m=model, device=device, size=size, offset=offset, - preserve=preserve, flag=flag, **kw) + partition = Partition( + m=model, + device=device, + size=size, + offset=offset, + preserve=preserve, + flag=flag, + **kw, + ) if partition.preserve: partition._info = FakeStorageInfo(size=size) model._actions.append(partition) return partition -def make_filesystem(model, partition, *, fstype='ext4', **kw): +def make_filesystem(model, partition, *, fstype="ext4", **kw): return Filesystem(m=model, volume=partition, fstype=fstype, **kw) @@ -207,9 +199,8 @@ def make_model_and_partition(bootloader=None): def make_raid(model, **kw): - name = 'md%s' % len(model._actions) - r = model.add_raid( - name, 'raid1', {make_disk(model), make_disk(model)}, set()) + name = "md%s" % len(model._actions) + r = model.add_raid(name, "raid1", {make_disk(model), make_disk(model)}, set()) size = r.size for k, v in kw.items(): setattr(r, k, v) @@ -224,7 +215,7 @@ def make_model_and_raid(bootloader=None): def make_vg(model, pvs=None): - name = 'vg%s' % len(model._actions) + name = "vg%s" % len(model._actions) if pvs is None: pvs = {make_disk(model)} @@ -240,7 +231,7 @@ def make_model_and_vg(bootloader=None): def make_lv(model, vg=None, size=None): if vg is None: vg = make_vg(model) - name = 'lv%s' % len(model._actions) + name = "lv%s" % len(model._actions) size = gaps.largest_gap_size(vg) if size is None else size return model.add_logical_volume(vg, name, size) @@ -256,9 +247,8 @@ def make_zpool(model=None, device=None, pool=None, mountpoint=None, **kw): if device is None: device = make_disk(model) if pool is None: - pool = f'pool{len(model._actions)}' - return model.add_zpool( - device=device, pool=pool, mountpoint=mountpoint, **kw) + pool = f"pool{len(model._actions)}" + return model.add_zpool(device=device, pool=pool, mountpoint=mountpoint, **kw) def make_zfs(model, *, pool, **kw): @@ -268,15 +258,13 @@ def make_zfs(model, *, pool, **kw): class TestFilesystemModel(unittest.TestCase): - - def _test_ok_for_xxx(self, model, make_new_device, attr, - test_partitions=True): + def _test_ok_for_xxx(self, model, make_new_device, attr, test_partitions=True): # Newly formatted devs are ok_for_raid dev1 = make_new_device(model) self.assertTrue(getattr(dev1, attr)) # A freshly formatted dev is not ok_for_raid dev2 = make_new_device(model) - model.add_filesystem(dev2, 'ext4') + model.add_filesystem(dev2, "ext4") self.assertFalse(getattr(dev2, attr)) if test_partitions: # A device with a partition is not ok_for_raid @@ -289,11 +277,11 @@ class TestFilesystemModel(unittest.TestCase): # A dev with an existing filesystem is ok (there is no # way to remove the format) dev5 = make_new_device(model, preserve=True) - fs = model.add_filesystem(dev5, 'ext4') + fs = model.add_filesystem(dev5, "ext4") fs.preserve = True self.assertTrue(dev5.ok_for_raid) # But a existing, *mounted* filesystem is not. - model.add_mount(fs, '/') + model.add_mount(fs, "/") self.assertFalse(dev5.ok_for_raid) def test_disk_ok_for_xxx(self): @@ -308,13 +296,13 @@ class TestFilesystemModel(unittest.TestCase): self._test_ok_for_xxx(model, make_partition, "ok_for_raid", False) self._test_ok_for_xxx(model, make_partition, "ok_for_lvm_vg", False) - part = make_partition(make_model(Bootloader.BIOS), flag='bios_grub') + part = make_partition(make_model(Bootloader.BIOS), flag="bios_grub") self.assertFalse(part.ok_for_raid) self.assertFalse(part.ok_for_lvm_vg) - part = make_partition(make_model(Bootloader.UEFI), flag='boot') + part = make_partition(make_model(Bootloader.UEFI), flag="boot") self.assertFalse(part.ok_for_raid) self.assertFalse(part.ok_for_lvm_vg) - part = make_partition(make_model(Bootloader.PREP), flag='prep') + part = make_partition(make_model(Bootloader.PREP), flag="prep") self.assertFalse(part.ok_for_raid) self.assertFalse(part.ok_for_lvm_vg) @@ -339,15 +327,15 @@ def fake_up_blockdata_disk(disk, **kw): model = disk._m if model._probe_data is None: model._probe_data = {} - blockdev = model._probe_data.setdefault('blockdev', {}) + blockdev = model._probe_data.setdefault("blockdev", {}) d = blockdev[disk.path] = { - 'DEVTYPE': 'disk', - 'ID_SERIAL': disk.serial, - 'ID_MODEL': disk.model, - 'attrs': { - 'size': disk.size, - }, - } + "DEVTYPE": "disk", + "ID_SERIAL": disk.serial, + "ID_MODEL": disk.model, + "attrs": { + "size": disk.size, + }, + } d.update(kw) @@ -357,342 +345,381 @@ def fake_up_blockdata(model): class TestAutoInstallConfig(unittest.TestCase): - def test_basic(self): model, disk = make_model_and_disk() fake_up_blockdata(model) - model.apply_autoinstall_config([{'type': 'disk', 'id': 'disk0'}]) + model.apply_autoinstall_config([{"type": "disk", "id": "disk0"}]) [new_disk] = model.all_disks() self.assertIsNot(new_disk, disk) self.assertEqual(new_disk.serial, disk.serial) def test_largest(self): model = make_model() - make_disk(model, serial='smaller', size=10*(2**30)) - make_disk(model, serial='larger', size=11*(2**30)) + make_disk(model, serial="smaller", size=10 * (2**30)) + make_disk(model, serial="larger", size=11 * (2**30)) fake_up_blockdata(model) - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - 'match': { - 'size': 'largest', + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "match": { + "size": "largest", }, - }, - ]) + }, + ] + ) new_disk = model._one(type="disk", id="disk0") self.assertEqual(new_disk.serial, "larger") def test_smallest(self): model = make_model() - make_disk(model, serial='smaller', size=10*(2**30)) - make_disk(model, serial='larger', size=11*(2**30)) + make_disk(model, serial="smaller", size=10 * (2**30)) + make_disk(model, serial="larger", size=11 * (2**30)) fake_up_blockdata(model) - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - 'match': { - 'size': 'smallest', + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "match": { + "size": "smallest", }, - }, - ]) + }, + ] + ) new_disk = model._one(type="disk", id="disk0") self.assertEqual(new_disk.serial, "smaller") def test_smallest_skips_zero_size(self): model = make_model() - make_disk(model, serial='smallest', size=0) - make_disk(model, serial='smaller', size=10*(2**30)) - make_disk(model, serial='larger', size=11*(2**30)) + make_disk(model, serial="smallest", size=0) + make_disk(model, serial="smaller", size=10 * (2**30)) + make_disk(model, serial="larger", size=11 * (2**30)) fake_up_blockdata(model) - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - 'match': { - 'size': 'smallest', + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "match": { + "size": "smallest", }, - }, - ]) + }, + ] + ) new_disk = model._one(type="disk", id="disk0") self.assertEqual(new_disk.serial, "smaller") def test_serial_exact(self): model = make_model() - make_disk(model, serial='aaaa', path='/dev/aaa') - make_disk(model, serial='bbbb', path='/dev/bbb') + make_disk(model, serial="aaaa", path="/dev/aaa") + make_disk(model, serial="bbbb", path="/dev/bbb") fake_up_blockdata(model) - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - 'serial': 'aaaa', - }, - ]) + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "serial": "aaaa", + }, + ] + ) new_disk = model._one(type="disk", id="disk0") self.assertEqual(new_disk.path, "/dev/aaa") def test_serial_glob(self): model = make_model() - make_disk(model, serial='aaaa', path='/dev/aaa') - make_disk(model, serial='bbbb', path='/dev/bbb') + make_disk(model, serial="aaaa", path="/dev/aaa") + make_disk(model, serial="bbbb", path="/dev/bbb") fake_up_blockdata(model) - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - 'match': { - 'serial': 'a*', + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "match": { + "serial": "a*", }, - }, - ]) + }, + ] + ) new_disk = model._one(type="disk", id="disk0") self.assertEqual(new_disk.path, "/dev/aaa") def test_path_exact(self): model = make_model() - make_disk(model, serial='aaaa', path='/dev/aaa') - make_disk(model, serial='bbbb', path='/dev/bbb') + make_disk(model, serial="aaaa", path="/dev/aaa") + make_disk(model, serial="bbbb", path="/dev/bbb") fake_up_blockdata(model) - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - 'path': '/dev/aaa', - }, - ]) + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "path": "/dev/aaa", + }, + ] + ) new_disk = model._one(type="disk", id="disk0") self.assertEqual(new_disk.serial, "aaaa") def test_path_glob(self): model = make_model() - d1 = make_disk(model, serial='aaaa', path='/dev/aaa') - d2 = make_disk(model, serial='bbbb', path='/dev/bbb') + d1 = make_disk(model, serial="aaaa", path="/dev/aaa") + d2 = make_disk(model, serial="bbbb", path="/dev/bbb") fake_up_blockdata_disk(d1) fake_up_blockdata_disk(d2) - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - 'match': { - 'path': '/dev/a*', + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "match": { + "path": "/dev/a*", }, - }, - ]) + }, + ] + ) new_disk = model._one(type="disk", id="disk0") self.assertEqual(new_disk.serial, d1.serial) def test_model_glob(self): model = make_model() - d1 = make_disk(model, serial='aaaa') - d2 = make_disk(model, serial='bbbb') - fake_up_blockdata_disk(d1, ID_MODEL='aaa') - fake_up_blockdata_disk(d2, ID_MODEL='bbb') - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - 'match': { - 'model': 'a*', + d1 = make_disk(model, serial="aaaa") + d2 = make_disk(model, serial="bbbb") + fake_up_blockdata_disk(d1, ID_MODEL="aaa") + fake_up_blockdata_disk(d2, ID_MODEL="bbb") + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "match": { + "model": "a*", }, - }, - ]) + }, + ] + ) new_disk = model._one(type="disk", id="disk0") self.assertEqual(new_disk.serial, d1.serial) def test_vendor_glob(self): model = make_model() - d1 = make_disk(model, serial='aaaa') - d2 = make_disk(model, serial='bbbb') - fake_up_blockdata_disk(d1, ID_VENDOR='aaa') - fake_up_blockdata_disk(d2, ID_VENDOR='bbb') - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - 'match': { - 'vendor': 'a*', + d1 = make_disk(model, serial="aaaa") + d2 = make_disk(model, serial="bbbb") + fake_up_blockdata_disk(d1, ID_VENDOR="aaa") + fake_up_blockdata_disk(d2, ID_VENDOR="bbb") + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "match": { + "vendor": "a*", }, - }, - ]) + }, + ] + ) new_disk = model._one(type="disk", id="disk0") self.assertEqual(new_disk.serial, d1.serial) def test_id_path_glob(self): model = make_model() - d1 = make_disk(model, serial='aaaa') - d2 = make_disk(model, serial='bbbb') - fake_up_blockdata_disk(d1, ID_PATH='pci-0000:00:00.0-nvme-aaa') - fake_up_blockdata_disk(d2, ID_PATH='pci-0000:00:00.0-nvme-bbb') - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - 'match': { - 'id_path': '*aaa', + d1 = make_disk(model, serial="aaaa") + d2 = make_disk(model, serial="bbbb") + fake_up_blockdata_disk(d1, ID_PATH="pci-0000:00:00.0-nvme-aaa") + fake_up_blockdata_disk(d2, ID_PATH="pci-0000:00:00.0-nvme-bbb") + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "match": { + "id_path": "*aaa", }, - }, - ]) + }, + ] + ) new_disk = model._one(type="disk", id="disk0") self.assertEqual(new_disk.serial, d1.serial) def test_devpath_glob(self): model = make_model() - d1 = make_disk(model, serial='aaaa') - d2 = make_disk(model, serial='bbbb') - fake_up_blockdata_disk(d1, DEVPATH='/devices/pci0000:00/aaa') - fake_up_blockdata_disk(d2, DEVPATH='/devices/pci0000:00/bbb') - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - 'match': { - 'devpath': '*aaa', + d1 = make_disk(model, serial="aaaa") + d2 = make_disk(model, serial="bbbb") + fake_up_blockdata_disk(d1, DEVPATH="/devices/pci0000:00/aaa") + fake_up_blockdata_disk(d2, DEVPATH="/devices/pci0000:00/bbb") + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "match": { + "devpath": "*aaa", }, - }, - ]) + }, + ] + ) new_disk = model._one(type="disk", id="disk0") self.assertEqual(new_disk.serial, d1.serial) def test_no_matching_disk(self): model = make_model() - make_disk(model, serial='bbbb') + make_disk(model, serial="bbbb") fake_up_blockdata(model) with self.assertRaises(Exception) as cm: - model.apply_autoinstall_config([{ - 'type': 'disk', - 'id': 'disk0', - 'serial': 'aaaa', - }]) + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "serial": "aaaa", + } + ] + ) self.assertIn("matched no disk", str(cm.exception)) def test_reuse_disk(self): model = make_model() - make_disk(model, serial='aaaa') + make_disk(model, serial="aaaa") fake_up_blockdata(model) with self.assertRaises(Exception) as cm: - model.apply_autoinstall_config([{ - 'type': 'disk', - 'id': 'disk0', - 'serial': 'aaaa', - }, - { - 'type': 'disk', - 'id': 'disk0', - 'serial': 'aaaa', - }]) + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "serial": "aaaa", + }, + { + "type": "disk", + "id": "disk0", + "serial": "aaaa", + }, + ] + ) self.assertIn("was already used", str(cm.exception)) def test_partition_percent(self): model = make_model() - make_disk(model, serial='aaaa', size=dehumanize_size("100M")) + make_disk(model, serial="aaaa", size=dehumanize_size("100M")) fake_up_blockdata(model) - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - }, - { - 'type': 'partition', - 'id': 'part0', - 'device': 'disk0', - 'size': '50%', - }]) + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + }, + { + "type": "partition", + "id": "part0", + "device": "disk0", + "size": "50%", + }, + ] + ) disk = model._one(type="disk") part = model._one(type="partition") - self.assertEqual(part.size, disk.available_for_partitions//2) + self.assertEqual(part.size, disk.available_for_partitions // 2) def test_partition_remaining(self): model = make_model() - make_disk(model, serial='aaaa', size=dehumanize_size("100M")) + make_disk(model, serial="aaaa", size=dehumanize_size("100M")) fake_up_blockdata(model) - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - }, - { - 'type': 'partition', - 'id': 'part0', - 'device': 'disk0', - 'size': dehumanize_size('50M'), - }, - { - 'type': 'partition', - 'id': 'part1', - 'device': 'disk0', - 'size': -1, - }, - ]) + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + }, + { + "type": "partition", + "id": "part0", + "device": "disk0", + "size": dehumanize_size("50M"), + }, + { + "type": "partition", + "id": "part1", + "device": "disk0", + "size": -1, + }, + ] + ) disk = model._one(type="disk") part1 = model._one(type="partition", id="part1") self.assertEqual( - part1.size, disk.available_for_partitions - dehumanize_size('50M')) + part1.size, disk.available_for_partitions - dehumanize_size("50M") + ) def test_partition_not_final_remaining_size(self): model = make_model() - make_disk(model, serial='aaaa', size=dehumanize_size("100M")) + make_disk(model, serial="aaaa", size=dehumanize_size("100M")) fake_up_blockdata(model) with self.assertRaises(NotFinalPartitionError): - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - }, - { - 'type': 'partition', - 'id': 'part0', - 'device': 'disk0', - 'size': dehumanize_size('50M'), - }, - { - 'type': 'partition', - 'id': 'part1', - 'device': 'disk0', - 'size': -1, - }, - { - 'type': 'partition', - 'id': 'part2', - 'device': 'disk0', - 'size': dehumanize_size('10M'), - }, - ]) + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + }, + { + "type": "partition", + "id": "part0", + "device": "disk0", + "size": dehumanize_size("50M"), + }, + { + "type": "partition", + "id": "part1", + "device": "disk0", + "size": -1, + }, + { + "type": "partition", + "id": "part2", + "device": "disk0", + "size": dehumanize_size("10M"), + }, + ] + ) def test_extended_partition_remaining_size(self): model = make_model(bootloader=Bootloader.BIOS, storage_version=2) - make_disk(model, serial='aaaa', size=dehumanize_size("100M")) + make_disk(model, serial="aaaa", size=dehumanize_size("100M")) fake_up_blockdata(model) - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - 'ptable': 'msdos', - }, - { - 'type': 'partition', - 'id': 'part0', - 'device': 'disk0', - 'size': dehumanize_size('40M'), - }, - { - 'type': 'partition', - 'id': 'part1', - 'device': 'disk0', - 'size': -1, - 'flag': 'extended', - }, - { - 'type': 'partition', - 'number': 5, - 'id': 'part5', - 'device': 'disk0', - 'size': dehumanize_size('10M'), - 'flag': 'logical', - }, - ]) + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "ptable": "msdos", + }, + { + "type": "partition", + "id": "part0", + "device": "disk0", + "size": dehumanize_size("40M"), + }, + { + "type": "partition", + "id": "part1", + "device": "disk0", + "size": -1, + "flag": "extended", + }, + { + "type": "partition", + "number": 5, + "id": "part5", + "device": "disk0", + "size": dehumanize_size("10M"), + "flag": "logical", + }, + ] + ) extended = model._one(type="partition", id="part1") # Disk test.img: 100 MiB, 104857600 bytes, 204800 sectors # Units: sectors of 1 * 512 = 512 bytes @@ -709,31 +736,33 @@ class TestAutoInstallConfig(unittest.TestCase): def test_logical_partition_remaining_size(self): model = make_model(bootloader=Bootloader.BIOS, storage_version=2) - make_disk(model, serial='aaaa', size=dehumanize_size("100M")) + make_disk(model, serial="aaaa", size=dehumanize_size("100M")) fake_up_blockdata(model) - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - 'ptable': 'msdos', - }, - { - 'type': 'partition', - 'id': 'part0', - 'device': 'disk0', - 'size': dehumanize_size('40M'), - 'flag': 'extended', - }, - { - 'type': 'partition', - 'number': 5, - 'id': 'part5', - 'device': 'disk0', - 'size': -1, - 'flag': 'logical', - }, - ]) + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "ptable": "msdos", + }, + { + "type": "partition", + "id": "part0", + "device": "disk0", + "size": dehumanize_size("40M"), + "flag": "extended", + }, + { + "type": "partition", + "number": 5, + "id": "part5", + "device": "disk0", + "size": -1, + "flag": "logical", + }, + ] + ) disk = model._one(type="disk") extended = model._one(type="partition", id="part0") logical = model._one(type="partition", id="part5") @@ -760,44 +789,46 @@ class TestAutoInstallConfig(unittest.TestCase): def test_partition_remaining_size_in_extended_and_logical(self): model = make_model(bootloader=Bootloader.BIOS, storage_version=2) - make_disk(model, serial='aaaa', size=dehumanize_size("100M")) + make_disk(model, serial="aaaa", size=dehumanize_size("100M")) fake_up_blockdata(model) - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - 'ptable': 'msdos', - }, - { - 'type': 'partition', - 'id': 'part0', - 'device': 'disk0', - 'size': dehumanize_size('40M'), - }, - { - 'type': 'partition', - 'id': 'part1', - 'device': 'disk0', - 'size': -1, - 'flag': 'extended', - }, - { - 'type': 'partition', - 'number': 5, - 'id': 'part5', - 'device': 'disk0', - 'size': dehumanize_size('10M'), - 'flag': 'logical', - }, - { - 'type': 'partition', - 'number': 6, - 'id': 'part6', - 'device': 'disk0', - 'size': -1, - 'flag': 'logical', - }, - ]) + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "ptable": "msdos", + }, + { + "type": "partition", + "id": "part0", + "device": "disk0", + "size": dehumanize_size("40M"), + }, + { + "type": "partition", + "id": "part1", + "device": "disk0", + "size": -1, + "flag": "extended", + }, + { + "type": "partition", + "number": 5, + "id": "part5", + "device": "disk0", + "size": dehumanize_size("10M"), + "flag": "logical", + }, + { + "type": "partition", + "number": 6, + "id": "part6", + "device": "disk0", + "size": -1, + "flag": "logical", + }, + ] + ) extended = model._one(type="partition", id="part1") p5 = model._one(type="partition", id="part5") p6 = model._one(type="partition", id="part6") @@ -822,62 +853,64 @@ class TestAutoInstallConfig(unittest.TestCase): def test_partition_remaining_size_in_extended_and_logical_multiple(self): model = make_model(bootloader=Bootloader.BIOS, storage_version=2) - make_disk(model, serial='aaaa', size=dehumanize_size("20G")) + make_disk(model, serial="aaaa", size=dehumanize_size("20G")) fake_up_blockdata(model) - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - 'ptable': 'msdos', - }, - { - 'type': 'partition', - 'flag': 'boot', - 'device': 'disk0', - 'size': dehumanize_size('1G'), - 'id': 'partition-boot', - }, - { - 'type': 'partition', - 'device': 'disk0', - 'size': dehumanize_size('6G'), - 'id': 'partition-root', - }, - { - 'type': 'partition', - 'device': 'disk0', - 'size': dehumanize_size('100M'), - 'id': 'partition-swap', - }, - { - 'type': 'partition', - 'device': 'disk0', - 'size': -1, - 'flag': 'extended', - 'id': 'partition-extended', - }, - { - 'type': 'partition', - 'device': 'disk0', - 'size': dehumanize_size('1G'), - 'flag': 'logical', - 'id': 'partition-tmp', - }, - { - 'type': 'partition', - 'device': 'disk0', - 'size': dehumanize_size('2G'), - 'flag': 'logical', - 'id': 'partition-var', - }, - { - 'type': 'partition', - 'device': 'disk0', - 'size': -1, - 'flag': 'logical', - 'id': 'partition-home', - } - ]) + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + "ptable": "msdos", + }, + { + "type": "partition", + "flag": "boot", + "device": "disk0", + "size": dehumanize_size("1G"), + "id": "partition-boot", + }, + { + "type": "partition", + "device": "disk0", + "size": dehumanize_size("6G"), + "id": "partition-root", + }, + { + "type": "partition", + "device": "disk0", + "size": dehumanize_size("100M"), + "id": "partition-swap", + }, + { + "type": "partition", + "device": "disk0", + "size": -1, + "flag": "extended", + "id": "partition-extended", + }, + { + "type": "partition", + "device": "disk0", + "size": dehumanize_size("1G"), + "flag": "logical", + "id": "partition-tmp", + }, + { + "type": "partition", + "device": "disk0", + "size": dehumanize_size("2G"), + "flag": "logical", + "id": "partition-var", + }, + { + "type": "partition", + "device": "disk0", + "size": -1, + "flag": "logical", + "id": "partition-home", + }, + ] + ) p_boot = model._one(type="partition", id="partition-boot") p_root = model._one(type="partition", id="partition-root") p_swap = model._one(type="partition", id="partition-swap") @@ -918,67 +951,73 @@ class TestAutoInstallConfig(unittest.TestCase): def test_lv_percent(self): model = make_model() - make_disk(model, serial='aaaa', size=dehumanize_size("100M")) + make_disk(model, serial="aaaa", size=dehumanize_size("100M")) fake_up_blockdata(model) - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - }, - { - 'type': 'lvm_volgroup', - 'id': 'vg0', - 'name': 'vg0', - 'devices': ['disk0'], - }, - { - 'type': 'lvm_partition', - 'id': 'lv1', - 'name': 'lv1', - 'volgroup': 'vg0', - 'size': "50%", - }, - ]) + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + }, + { + "type": "lvm_volgroup", + "id": "vg0", + "name": "vg0", + "devices": ["disk0"], + }, + { + "type": "lvm_partition", + "id": "lv1", + "name": "lv1", + "volgroup": "vg0", + "size": "50%", + }, + ] + ) vg = model._one(type="lvm_volgroup") lv1 = model._one(type="lvm_partition") - self.assertEqual(lv1.size, vg.available_for_partitions//2) + self.assertEqual(lv1.size, vg.available_for_partitions // 2) def test_lv_remaining(self): model = make_model() - make_disk(model, serial='aaaa', size=dehumanize_size("100M")) + make_disk(model, serial="aaaa", size=dehumanize_size("100M")) fake_up_blockdata(model) - model.apply_autoinstall_config([ - { - 'type': 'disk', - 'id': 'disk0', - }, - { - 'type': 'lvm_volgroup', - 'id': 'vg0', - 'name': 'vg0', - 'devices': ['disk0'], - }, - { - 'type': 'lvm_partition', - 'id': 'lv1', - 'name': 'lv1', - 'volgroup': 'vg0', - 'size': dehumanize_size("50M"), - }, - { - 'type': 'lvm_partition', - 'id': 'lv2', - 'name': 'lv2', - 'volgroup': 'vg0', - 'size': -1, - }, - ]) + model.apply_autoinstall_config( + [ + { + "type": "disk", + "id": "disk0", + }, + { + "type": "lvm_volgroup", + "id": "vg0", + "name": "vg0", + "devices": ["disk0"], + }, + { + "type": "lvm_partition", + "id": "lv1", + "name": "lv1", + "volgroup": "vg0", + "size": dehumanize_size("50M"), + }, + { + "type": "lvm_partition", + "id": "lv2", + "name": "lv2", + "volgroup": "vg0", + "size": -1, + }, + ] + ) vg = model._one(type="lvm_volgroup") - lv2 = model._one(type="lvm_partition", id='lv2') + lv2 = model._one(type="lvm_partition", id="lv2") self.assertEqual( - lv2.size, align_down( - vg.available_for_partitions - dehumanize_size("50M"), - LVM_CHUNK_SIZE)) + lv2.size, + align_down( + vg.available_for_partitions - dehumanize_size("50M"), LVM_CHUNK_SIZE + ), + ) def test_render_does_not_include_unreferenced(self): model = make_model(Bootloader.NONE) @@ -986,9 +1025,9 @@ class TestAutoInstallConfig(unittest.TestCase): disk2 = make_disk(model, preserve=True) disk1p1 = make_partition(model, disk1, preserve=True) disk2p1 = make_partition(model, disk2, preserve=True) - fs = model.add_filesystem(disk1p1, 'ext4') - model.add_mount(fs, '/') - rendered_ids = {action['id'] for action in model._render_actions()} + fs = model.add_filesystem(disk1p1, "ext4") + model.add_mount(fs, "/") + rendered_ids = {action["id"] for action in model._render_actions()} self.assertTrue(disk1.id in rendered_ids) self.assertTrue(disk1p1.id in rendered_ids) self.assertTrue(disk2.id not in rendered_ids) @@ -1000,12 +1039,11 @@ class TestAutoInstallConfig(unittest.TestCase): disk2 = make_disk(model, preserve=True) disk1p1 = make_partition(model, disk1, preserve=True) disk2p1 = make_partition(model, disk2, preserve=True) - fs = model.add_filesystem(disk1p1, 'ext4') - model.add_mount(fs, '/') + fs = model.add_filesystem(disk1p1, "ext4") + model.add_mount(fs, "/") rendered_ids = { - action['id'] - for action in model._render_actions(ActionRenderMode.FOR_API) - } + action["id"] for action in model._render_actions(ActionRenderMode.FOR_API) + } self.assertTrue(disk1.id in rendered_ids) self.assertTrue(disk1p1.id in rendered_ids) self.assertTrue(disk2.id in rendered_ids) @@ -1015,12 +1053,11 @@ class TestAutoInstallConfig(unittest.TestCase): model = make_model(Bootloader.NONE) disk1 = make_disk(model, preserve=True) disk1p1 = make_partition(model, disk1, preserve=True) - fs = model.add_filesystem(disk1p1, 'ext4') - mnt = model.add_mount(fs, '/') + fs = model.add_filesystem(disk1p1, "ext4") + mnt = model.add_mount(fs, "/") rendered_ids = { - action['id'] - for action in model._render_actions(ActionRenderMode.DEVICES) - } + action["id"] for action in model._render_actions(ActionRenderMode.DEVICES) + } self.assertTrue(disk1.id in rendered_ids) self.assertTrue(disk1p1.id in rendered_ids) self.assertTrue(fs.id not in rendered_ids) @@ -1030,29 +1067,31 @@ class TestAutoInstallConfig(unittest.TestCase): model = make_model(Bootloader.NONE) disk1 = make_disk(model, preserve=True) disk1p1 = make_partition(model, disk1, preserve=True) - disk1p1.path = '/dev/vda1' - fs = model.add_filesystem(disk1p1, 'ext4') - mnt = model.add_mount(fs, '/') + disk1p1.path = "/dev/vda1" + fs = model.add_filesystem(disk1p1, "ext4") + mnt = model.add_mount(fs, "/") actions = model._render_actions(ActionRenderMode.FORMAT_MOUNT) - rendered_by_id = {action['id']: action for action in actions} + rendered_by_id = {action["id"]: action for action in actions} self.assertTrue(disk1.id not in rendered_by_id) self.assertTrue(disk1p1.id not in rendered_by_id) self.assertTrue(fs.id in rendered_by_id) self.assertTrue(mnt.id in rendered_by_id) - vol_id = rendered_by_id[fs.id]['volume'] - self.assertEqual(rendered_by_id[vol_id]['type'], 'device') - self.assertEqual(rendered_by_id[vol_id]['path'], '/dev/vda1') + vol_id = rendered_by_id[fs.id]["volume"] + self.assertEqual(rendered_by_id[vol_id]["type"], "device") + self.assertEqual(rendered_by_id[vol_id]["path"], "/dev/vda1") def test_render_includes_all_partitions(self): model = make_model(Bootloader.NONE) disk1 = make_disk(model, preserve=True) - disk1p1 = make_partition(model, disk1, preserve=True, - offset=1 << 20, size=512 << 20) - disk1p2 = make_partition(model, disk1, preserve=True, - offset=513 << 20, size=8192 << 20) - fs = model.add_filesystem(disk1p2, 'ext4') - model.add_mount(fs, '/') - rendered_ids = {action['id'] for action in model._render_actions()} + disk1p1 = make_partition( + model, disk1, preserve=True, offset=1 << 20, size=512 << 20 + ) + disk1p2 = make_partition( + model, disk1, preserve=True, offset=513 << 20, size=8192 << 20 + ) + fs = model.add_filesystem(disk1p2, "ext4") + model.add_mount(fs, "/") + rendered_ids = {action["id"] for action in model._render_actions()} self.assertTrue(disk1.id in rendered_ids) self.assertTrue(disk1p1.id in rendered_ids) self.assertTrue(disk1p2.id in rendered_ids) @@ -1061,13 +1100,13 @@ class TestAutoInstallConfig(unittest.TestCase): model = make_model(Bootloader.NONE) disk1 = make_disk(model, preserve=True) disk1p1 = make_partition(model, disk1, preserve=True) - fs = model.add_filesystem(disk1p1, 'ext4') - model.add_mount(fs, '/') + fs = model.add_filesystem(disk1p1, "ext4") + model.add_mount(fs, "/") actions = model._render_actions() for action in actions: - if action['id'] != disk1p1.id: + if action["id"] != disk1p1.id: continue - self.assertEqual(action['number'], 1) + self.assertEqual(action["number"], 1) def test_render_includes_unmounted_new_partition(self): model = make_model(Bootloader.NONE) @@ -1075,9 +1114,9 @@ class TestAutoInstallConfig(unittest.TestCase): disk2 = make_disk(model) disk1p1 = make_partition(model, disk1, preserve=True) disk2p1 = make_partition(model, disk2) - fs = model.add_filesystem(disk1p1, 'ext4') - model.add_mount(fs, '/') - rendered_ids = {action['id'] for action in model._render_actions()} + fs = model.add_filesystem(disk1p1, "ext4") + model.add_mount(fs, "/") + rendered_ids = {action["id"] for action in model._render_actions()} self.assertTrue(disk1.id in rendered_ids) self.assertTrue(disk1p1.id in rendered_ids) self.assertTrue(disk2.id in rendered_ids) @@ -1093,47 +1132,53 @@ class TestPartitionNumbering(unittest.TestCase): self.cur_idx += 1 def test_gpt(self): - m, d1 = make_model_and_disk(ptable='gpt') + m, d1 = make_model_and_disk(ptable="gpt") for _ in range(8): self.assert_next(make_partition(m, d1)) - @parameterized.expand([ - ['msdos', 4], - ['vtoc', 3], - ]) + @parameterized.expand( + [ + ["msdos", 4], + ["vtoc", 3], + ] + ) def test_all_primary(self, ptable, primaries): m = make_model(storage_version=2) d1 = make_disk(m, ptable=ptable) for _ in range(primaries): self.assert_next(make_partition(m, d1)) - @parameterized.expand([ - ['msdos', 4], - ['vtoc', 3], - ]) + @parameterized.expand( + [ + ["msdos", 4], + ["vtoc", 3], + ] + ) def test_primary_and_extended(self, ptable, primaries): m = make_model(storage_version=2) d1 = make_disk(m, ptable=ptable) for _ in range(primaries - 1): self.assert_next(make_partition(m, d1)) size = gaps.largest_gap_size(d1) - self.assert_next(make_partition(m, d1, flag='extended', size=size)) + self.assert_next(make_partition(m, d1, flag="extended", size=size)) for _ in range(3): - self.assert_next(make_partition(m, d1, flag='logical')) + self.assert_next(make_partition(m, d1, flag="logical")) @parameterized.expand( - [[pt, primaries, i] - for pt, primaries in (('msdos', 4), ('vtoc', 3)) - for i in range(3)] + [ + [pt, primaries, i] + for pt, primaries in (("msdos", 4), ("vtoc", 3)) + for i in range(3) + ] ) def test_delete_logical(self, ptable, primaries, idx_to_remove): m = make_model(storage_version=2) d1 = make_disk(m, ptable=ptable) self.assert_next(make_partition(m, d1)) size = gaps.largest_gap_size(d1) - self.assert_next(make_partition(m, d1, flag='extended', size=size)) + self.assert_next(make_partition(m, d1, flag="extended", size=size)) self.cur_idx = primaries + 1 - parts = [make_partition(m, d1, flag='logical') for _ in range(3)] + parts = [make_partition(m, d1, flag="logical") for _ in range(3)] for p in parts: self.assert_next(p) to_remove = parts.pop(idx_to_remove) @@ -1143,24 +1188,29 @@ class TestPartitionNumbering(unittest.TestCase): self.assert_next(p) @parameterized.expand( - [[pt, primaries, i] - for pt, primaries in (('msdos', 4), ('vtoc', 3)) - for i in range(3)] + [ + [pt, primaries, i] + for pt, primaries in (("msdos", 4), ("vtoc", 3)) + for i in range(3) + ] ) def test_out_of_offset_order(self, ptable, primaries, idx_to_remove): m = make_model(storage_version=2) d1 = make_disk(m, ptable=ptable, size=100 << 20) self.assert_next(make_partition(m, d1, size=10 << 20)) size = gaps.largest_gap_size(d1) - self.assert_next(make_partition(m, d1, flag='extended', size=size)) + self.assert_next(make_partition(m, d1, flag="extended", size=size)) self.cur_idx = primaries + 1 parts = [] - parts.append(make_partition( - m, d1, flag='logical', size=9 << 20, offset=30 << 20)) - parts.append(make_partition( - m, d1, flag='logical', size=9 << 20, offset=20 << 20)) - parts.append(make_partition( - m, d1, flag='logical', size=9 << 20, offset=40 << 20)) + parts.append( + make_partition(m, d1, flag="logical", size=9 << 20, offset=30 << 20) + ) + parts.append( + make_partition(m, d1, flag="logical", size=9 << 20, offset=20 << 20) + ) + parts.append( + make_partition(m, d1, flag="logical", size=9 << 20, offset=40 << 20) + ) for p in parts: self.assert_next(p) to_remove = parts.pop(idx_to_remove) @@ -1169,12 +1219,14 @@ class TestPartitionNumbering(unittest.TestCase): for p in parts: self.assert_next(p) - @parameterized.expand([ - [1, 'msdos', 4], - [1, 'vtoc', 3], - [2, 'msdos', 4], - [2, 'vtoc', 3], - ]) + @parameterized.expand( + [ + [1, "msdos", 4], + [1, "vtoc", 3], + [2, "msdos", 4], + [2, "vtoc", 3], + ] + ) def test_no_extra_primary(self, sv, ptable, primaries): m = make_model(storage_version=sv) d1 = make_disk(m, ptable=ptable, size=100 << 30) @@ -1183,7 +1235,7 @@ class TestPartitionNumbering(unittest.TestCase): with self.assertRaises(Exception): make_partition(m, d1) - @parameterized.expand([['gpt'], ['msdos'], ['vtoc']]) + @parameterized.expand([["gpt"], ["msdos"], ["vtoc"]]) def test_p1_preserved(self, ptable): m = make_model(storage_version=2) d1 = make_disk(m, ptable=ptable) @@ -1197,7 +1249,7 @@ class TestPartitionNumbering(unittest.TestCase): self.assertEqual(3, p3.number) self.assertEqual(False, p3.preserve) - @parameterized.expand([['gpt'], ['msdos'], ['vtoc']]) + @parameterized.expand([["gpt"], ["msdos"], ["vtoc"]]) def test_p2_preserved(self, ptable): m = make_model(storage_version=2) d1 = make_disk(m, ptable=ptable) @@ -1213,7 +1265,7 @@ class TestPartitionNumbering(unittest.TestCase): class TestAlignmentData(unittest.TestCase): - @parameterized.expand([['gpt'], ['msdos'], ['vtoc']]) + @parameterized.expand([["gpt"], ["msdos"], ["vtoc"]]) def test_alignment_gaps_coherence(self, ptable): d1 = make_disk(ptable=ptable) ad = d1.alignment_data() @@ -1224,55 +1276,60 @@ class TestAlignmentData(unittest.TestCase): class TestSwap(unittest.TestCase): def test_basic(self): m = make_model() - with mock.patch.object(m, 'should_add_swapfile', return_value=False): + with mock.patch.object(m, "should_add_swapfile", return_value=False): cfg = m.render() - self.assertEqual({'size': 0}, cfg['swap']) + self.assertEqual({"size": 0}, cfg["swap"]) - @parameterized.expand([ - ['ext4'], - ['btrfs'], - ]) + @parameterized.expand( + [ + ["ext4"], + ["btrfs"], + ] + ) def test_should_add_swapfile_nomount(self, fs): m, d1 = make_model_and_disk(Bootloader.BIOS) d1p1 = make_partition(m, d1) m.add_filesystem(d1p1, fs) self.assertTrue(m.should_add_swapfile()) - @parameterized.expand([ - ['ext4', None, True], - ['btrfs', 5, True], - ['btrfs', 4, False], - ['zfs', None, False], - ]) + @parameterized.expand( + [ + ["ext4", None, True], + ["btrfs", 5, True], + ["btrfs", 4, False], + ["zfs", None, False], + ] + ) def test_should_add_swapfile(self, fs, kern_maj_ver, expected): m, d1 = make_model_and_disk(Bootloader.BIOS) d1p1 = make_partition(m, d1) - m.add_mount(m.add_filesystem(d1p1, fs), '/') - with mock.patch('curtin.swap.get_target_kernel_version', - return_value={'major': kern_maj_ver}): + m.add_mount(m.add_filesystem(d1p1, fs), "/") + with mock.patch( + "curtin.swap.get_target_kernel_version", + return_value={"major": kern_maj_ver}, + ): self.assertEqual(expected, m.should_add_swapfile()) def test_should_add_swapfile_has_swappart(self): m, d1 = make_model_and_disk(Bootloader.BIOS) d1p1 = make_partition(m, d1) d1p2 = make_partition(m, d1) - m.add_mount(m.add_filesystem(d1p1, 'ext4'), '/') - m.add_mount(m.add_filesystem(d1p2, 'swap'), '') + m.add_mount(m.add_filesystem(d1p1, "ext4"), "/") + m.add_mount(m.add_filesystem(d1p2, "swap"), "") self.assertFalse(m.should_add_swapfile()) class TestPartition(unittest.TestCase): - def test_is_logical(self): m = make_model(storage_version=2) - d = make_disk(m, ptable='msdos') - make_partition(m, d, flag='extended') - p3 = make_partition(m, d, number=3, flag='swap') - p4 = make_partition(m, d, number=4, flag='boot') + d = make_disk(m, ptable="msdos") + make_partition(m, d, flag="extended") + p3 = make_partition(m, d, number=3, flag="swap") + p4 = make_partition(m, d, number=4, flag="boot") - p5 = make_partition(m, d, number=5, flag='logical') - p6 = make_partition(m, d, number=6, flag='boot') - p7 = make_partition(m, d, number=7, flag='swap') + p5 = make_partition(m, d, number=5, flag="logical") + p6 = make_partition(m, d, number=6, flag="boot") + p7 = make_partition(m, d, number=7, flag="swap") self.assertFalse(p3.is_logical) self.assertFalse(p4.is_logical) @@ -1282,41 +1339,53 @@ class TestPartition(unittest.TestCase): class TestCanmount(SubiTestCase): - @parameterized.expand(( - ('on', True), - ('"on"', True), - ('true', True), - ('off', False), - ('"off"', False), - ('false', False), - ('noauto', False), - ('"noauto"', False), - )) + @parameterized.expand( + ( + ("on", True), + ('"on"', True), + ("true", True), + ("off", False), + ('"off"', False), + ("false", False), + ("noauto", False), + ('"noauto"', False), + ) + ) def test_present(self, value, expected): - property_yaml = f'canmount: {value}' + property_yaml = f"canmount: {value}" properties = yaml.safe_load(property_yaml) for default in (True, False): - self.assertEqual(expected, get_canmount(properties, default), - f'yaml {property_yaml} default {default}') + self.assertEqual( + expected, + get_canmount(properties, default), + f"yaml {property_yaml} default {default}", + ) - @parameterized.expand(( - ['{}'], - ['something-else: on'], - )) + @parameterized.expand( + ( + ["{}"], + ["something-else: on"], + ) + ) def test_not_present(self, property_yaml): properties = yaml.safe_load(property_yaml) for default in (True, False): - self.assertEqual(default, get_canmount(properties, default), - f'yaml {property_yaml} default {default}') + self.assertEqual( + default, + get_canmount(properties, default), + f"yaml {property_yaml} default {default}", + ) - @parameterized.expand(( - ['asdf'], - ['"true"'], - ['"false"'], - )) + @parameterized.expand( + ( + ["asdf"], + ['"true"'], + ['"false"'], + ) + ) def test_invalid(self, value): with self.assertRaises(ValueError): - properties = yaml.safe_load(f'canmount: {value}') + properties = yaml.safe_load(f"canmount: {value}") get_canmount(properties, False) @@ -1324,16 +1393,20 @@ class TestZPool(SubiTestCase): def test_zpool_to_action(self): m = make_model() d = make_disk(m) - zp = make_zpool(model=m, device=d, mountpoint='/', pool='p1') - zfs = make_zfs(model=m, pool=zp, volume='/ROOTFS') + zp = make_zpool(model=m, device=d, mountpoint="/", pool="p1") + zfs = make_zfs(model=m, pool=zp, volume="/ROOTFS") actions = m._render_actions() - a_zp = dict(matching_dicts(actions, type='zpool')[0]) - a_zfs = dict(matching_dicts(actions, type='zfs')[0]) - e_zp = {'vdevs': [d.id], 'pool': 'p1', 'mountpoint': '/', - 'type': 'zpool', 'id': zp.id} - e_zfs = {'pool': zp.id, 'volume': '/ROOTFS', - 'type': 'zfs', 'id': zfs.id} + a_zp = dict(matching_dicts(actions, type="zpool")[0]) + a_zfs = dict(matching_dicts(actions, type="zfs")[0]) + e_zp = { + "vdevs": [d.id], + "pool": "p1", + "mountpoint": "/", + "type": "zpool", + "id": zp.id, + } + e_zfs = {"pool": zp.id, "volume": "/ROOTFS", "type": "zfs", "id": zfs.id} self.assertEqual(e_zp, a_zp) self.assertEqual(e_zfs, a_zfs) @@ -1342,77 +1415,110 @@ class TestZPool(SubiTestCase): d1 = make_disk(m) d2 = make_disk(m) fake_up_blockdata(m) - blockdevs = m._probe_data['blockdev'] + blockdevs = m._probe_data["blockdev"] config = [ - dict(type='disk', id=d1.id, path=d1.path, ptable=d1.ptable, - serial=d1.serial, info={d1.path: blockdevs[d1.path]}), - dict(type='disk', id=d2.id, path=d2.path, ptable=d2.ptable, - serial=d2.serial, info={d2.path: blockdevs[d2.path]}), - dict(type='zpool', id='zpool-1', vdevs=[d1.id], pool='p1', - mountpoint='/', fs_properties=dict(canmount='on')), - dict(type='zpool', id='zpool-2', vdevs=[d2.id], pool='p2', - mountpoint='/srv', fs_properties=dict(canmount='off')), - dict(type='zfs', id='zfs-1', volume='/ROOT', pool='zpool-1', - properties=dict(canmount='off')), - dict(type='zfs', id='zfs-2', volume='/SRV/srv', pool='zpool-2', - properties=dict(mountpoint='/srv', canmount='on')), + dict( + type="disk", + id=d1.id, + path=d1.path, + ptable=d1.ptable, + serial=d1.serial, + info={d1.path: blockdevs[d1.path]}, + ), + dict( + type="disk", + id=d2.id, + path=d2.path, + ptable=d2.ptable, + serial=d2.serial, + info={d2.path: blockdevs[d2.path]}, + ), + dict( + type="zpool", + id="zpool-1", + vdevs=[d1.id], + pool="p1", + mountpoint="/", + fs_properties=dict(canmount="on"), + ), + dict( + type="zpool", + id="zpool-2", + vdevs=[d2.id], + pool="p2", + mountpoint="/srv", + fs_properties=dict(canmount="off"), + ), + dict( + type="zfs", + id="zfs-1", + volume="/ROOT", + pool="zpool-1", + properties=dict(canmount="off"), + ), + dict( + type="zfs", + id="zfs-2", + volume="/SRV/srv", + pool="zpool-2", + properties=dict(mountpoint="/srv", canmount="on"), + ), ] - objs = m._actions_from_config( - config, blockdevs=None, is_probe_data=False) + objs = m._actions_from_config(config, blockdevs=None, is_probe_data=False) actual_d1, actual_d2, zp1, zp2, zfs_zp1, zfs_zp2 = objs self.assertTrue(isinstance(zp1, ZPool)) - self.assertEqual('zpool-1', zp1.id) + self.assertEqual("zpool-1", zp1.id) self.assertEqual([actual_d1], zp1.vdevs) - self.assertEqual('p1', zp1.pool) - self.assertEqual('/', zp1.mountpoint) - self.assertEqual('/', zp1.path) + self.assertEqual("p1", zp1.pool) + self.assertEqual("/", zp1.mountpoint) + self.assertEqual("/", zp1.path) self.assertEqual([zfs_zp1], zp1._zfses) self.assertTrue(isinstance(zp2, ZPool)) - self.assertEqual('zpool-2', zp2.id) + self.assertEqual("zpool-2", zp2.id) self.assertEqual([actual_d2], zp2.vdevs) - self.assertEqual('p2', zp2.pool) - self.assertEqual('/srv', zp2.mountpoint) + self.assertEqual("p2", zp2.pool) + self.assertEqual("/srv", zp2.mountpoint) self.assertEqual(None, zp2.path) self.assertEqual([zfs_zp2], zp2._zfses) self.assertTrue(isinstance(zfs_zp1, ZFS)) - self.assertEqual('zfs-1', zfs_zp1.id) + self.assertEqual("zfs-1", zfs_zp1.id) self.assertEqual(zp1, zfs_zp1.pool) - self.assertEqual('/ROOT', zfs_zp1.volume) + self.assertEqual("/ROOT", zfs_zp1.volume) self.assertEqual(None, zfs_zp1.path) self.assertTrue(isinstance(zfs_zp2, ZFS)) - self.assertEqual('zfs-2', zfs_zp2.id) + self.assertEqual("zfs-2", zfs_zp2.id) self.assertEqual(zp2, zfs_zp2.pool) - self.assertEqual('/SRV/srv', zfs_zp2.volume) - self.assertEqual('/srv', zfs_zp2.path) + self.assertEqual("/SRV/srv", zfs_zp2.volume) + self.assertEqual("/srv", zfs_zp2.path) class TestRootfs(SubiTestCase): def test_mount_rootfs(self): m, p = make_model_and_partition() fs = make_filesystem(m, p) - m.add_mount(fs, '/') + m.add_mount(fs, "/") self.assertTrue(m.is_root_mounted()) def test_mount_srv(self): m, p = make_model_and_partition() fs = make_filesystem(m, p) - m.add_mount(fs, '/srv') + m.add_mount(fs, "/srv") self.assertFalse(m.is_root_mounted()) def test_zpool_not_rootfs_because_not_canmount(self): m = make_model() - make_zpool(model=m, mountpoint='/', fs_properties=dict(canmount='off')) + make_zpool(model=m, mountpoint="/", fs_properties=dict(canmount="off")) self.assertFalse(m.is_root_mounted()) def test_zpool_rootfs_because_canmount(self): m = make_model() - make_zpool(model=m, mountpoint='/', fs_properties=dict(canmount='on')) + make_zpool(model=m, mountpoint="/", fs_properties=dict(canmount="on")) self.assertTrue(m.is_root_mounted()) def test_zpool_nonrootfs_mountpoint(self): m = make_model() - make_zpool(model=m, mountpoint='/srv') + make_zpool(model=m, mountpoint="/srv") self.assertFalse(m.is_root_mounted()) diff --git a/subiquity/models/tests/test_keyboard.py b/subiquity/models/tests/test_keyboard.py index 27d41f56..fbdecd07 100644 --- a/subiquity/models/tests/test_keyboard.py +++ b/subiquity/models/tests/test_keyboard.py @@ -13,15 +13,10 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from subiquitycore.tests.parameterized import parameterized - -from subiquitycore.tests import SubiTestCase - from subiquity.common.types import KeyboardSetting -from subiquity.models.keyboard import ( - InconsistentMultiLayoutError, - KeyboardModel, -) +from subiquity.models.keyboard import InconsistentMultiLayoutError, KeyboardModel +from subiquitycore.tests import SubiTestCase +from subiquitycore.tests.parameterized import parameterized class TestKeyboardModel(SubiTestCase): @@ -30,9 +25,9 @@ class TestKeyboardModel(SubiTestCase): def testDefaultUS(self): self.assertIsNone(self.model._setting) - self.assertEqual('us', self.model.setting.layout) + self.assertEqual("us", self.model.setting.layout) - @parameterized.expand((['zz'], ['en'])) + @parameterized.expand((["zz"], ["en"])) def testSetToInvalidLayout(self, layout): initial = self.model.setting val = KeyboardSetting(layout=layout) @@ -40,39 +35,41 @@ class TestKeyboardModel(SubiTestCase): self.model.setting = val self.assertEqual(initial, self.model.setting) - @parameterized.expand((['zz'])) + @parameterized.expand((["zz"])) def testSetToInvalidVariant(self, variant): initial = self.model.setting - val = KeyboardSetting(layout='us', variant=variant) + val = KeyboardSetting(layout="us", variant=variant) with self.assertRaises(ValueError): self.model.setting = val self.assertEqual(initial, self.model.setting) def testMultiLayout(self): - val = KeyboardSetting(layout='us,ara', variant=',') + val = KeyboardSetting(layout="us,ara", variant=",") self.model.setting = val self.assertEqual(self.model.setting, val) def testInconsistentMultiLayout(self): initial = self.model.setting - val = KeyboardSetting(layout='us,ara', variant='') + val = KeyboardSetting(layout="us,ara", variant="") with self.assertRaises(InconsistentMultiLayoutError): self.model.setting = val self.assertEqual(self.model.setting, initial) def testInvalidMultiLayout(self): initial = self.model.setting - val = KeyboardSetting(layout='us,ara', variant='zz,') + val = KeyboardSetting(layout="us,ara", variant="zz,") with self.assertRaises(ValueError): self.model.setting = val self.assertEqual(self.model.setting, initial) - @parameterized.expand([ - ['ast_ES.UTF-8', 'es', 'ast'], - ['de_DE.UTF-8', 'de', ''], - ['fr_FR.UTF-8', 'fr', 'latin9'], - ['oc', 'us', ''], - ]) + @parameterized.expand( + [ + ["ast_ES.UTF-8", "es", "ast"], + ["de_DE.UTF-8", "de", ""], + ["fr_FR.UTF-8", "fr", "latin9"], + ["oc", "us", ""], + ] + ) def testSettingForLang(self, lang, layout, variant): val = self.model.setting_for_lang(lang) self.assertEqual(layout, val.layout) @@ -81,20 +78,22 @@ class TestKeyboardModel(SubiTestCase): def testAllLangsHaveKeyboardSuggestion(self): # every language in the list needs a suggestion, # even if only the default - with open('languagelist') as fp: + with open("languagelist") as fp: for line in fp.readlines(): - tokens = line.split(':') + tokens = line.split(":") locale = tokens[1] self.assertIn(locale, self.model.layout_for_lang.keys()) def testLoadSuggestions(self): - data = self.tmp_path('kbd.yaml') - with open(data, 'w') as fp: - fp.write(''' + data = self.tmp_path("kbd.yaml") + with open(data, "w") as fp: + fp.write( + """ aa_BB.UTF-8: layout: aa variant: cc -''') +""" + ) actual = self.model.load_layout_suggestions(data) - expected = {'aa_BB.UTF-8': KeyboardSetting(layout='aa', variant='cc')} + expected = {"aa_BB.UTF-8": KeyboardSetting(layout="aa", variant="cc")} self.assertEqual(expected, actual) diff --git a/subiquity/models/tests/test_locale.py b/subiquity/models/tests/test_locale.py index b043d593..589ce4ee 100644 --- a/subiquity/models/tests/test_locale.py +++ b/subiquity/models/tests/test_locale.py @@ -41,16 +41,16 @@ class TestLocaleModel(unittest.IsolatedAsyncioTestCase): self.model.selected_language = "fr_FR" with mock.patch("subiquity.models.locale.arun_command") as arun_cmd: # Currently, the default for fr_FR is fr_FR.ISO8859-1 - with mock.patch("subiquity.models.locale.locale.normalize", - return_value="fr_FR.UTF-8"): + with mock.patch( + "subiquity.models.locale.locale.normalize", return_value="fr_FR.UTF-8" + ): await self.model.localectl_set_locale() arun_cmd.assert_called_once_with(expected_cmd, check=True) async def test_try_localectl_set_locale(self): self.model.selected_language = "fr_FR.UTF-8" exc = subprocess.CalledProcessError(returncode=1, cmd=["localedef"]) - with mock.patch("subiquity.models.locale.arun_command", - side_effect=exc): + with mock.patch("subiquity.models.locale.arun_command", side_effect=exc): await self.model.try_localectl_set_locale() with mock.patch("subiquity.models.locale.arun_command"): await self.model.try_localectl_set_locale() diff --git a/subiquity/models/tests/test_mirror.py b/subiquity/models/tests/test_mirror.py index c4adacdb..252ba25c 100644 --- a/subiquity/models/tests/test_mirror.py +++ b/subiquity/models/tests/test_mirror.py @@ -18,41 +18,37 @@ import unittest from unittest import mock from subiquity.models.mirror import ( - countrify_uri, LEGACY_DEFAULT_PRIMARY_SECTION, + LegacyPrimaryEntry, MirrorModel, MirrorSelectionFallback, - LegacyPrimaryEntry, PrimaryEntry, - ) + countrify_uri, +) class TestCountrifyUrl(unittest.TestCase): def test_official_archive(self): self.assertEqual( - countrify_uri( - "http://archive.ubuntu.com/ubuntu", - cc="fr"), - "http://fr.archive.ubuntu.com/ubuntu") + countrify_uri("http://archive.ubuntu.com/ubuntu", cc="fr"), + "http://fr.archive.ubuntu.com/ubuntu", + ) self.assertEqual( - countrify_uri( - "http://archive.ubuntu.com/ubuntu", - cc="us"), - "http://us.archive.ubuntu.com/ubuntu") + countrify_uri("http://archive.ubuntu.com/ubuntu", cc="us"), + "http://us.archive.ubuntu.com/ubuntu", + ) def test_ports_archive(self): self.assertEqual( - countrify_uri( - "http://ports.ubuntu.com/ubuntu-ports", - cc="fr"), - "http://fr.ports.ubuntu.com/ubuntu-ports") + countrify_uri("http://ports.ubuntu.com/ubuntu-ports", cc="fr"), + "http://fr.ports.ubuntu.com/ubuntu-ports", + ) self.assertEqual( - countrify_uri( - "http://ports.ubuntu.com/ubuntu-ports", - cc="us"), - "http://us.ports.ubuntu.com/ubuntu-ports") + countrify_uri("http://ports.ubuntu.com/ubuntu-ports", cc="us"), + "http://us.ports.ubuntu.com/ubuntu-ports", + ) class TestPrimaryEntry(unittest.TestCase): @@ -78,21 +74,20 @@ class TestPrimaryEntry(unittest.TestCase): model = MirrorModel() entry = PrimaryEntry.from_config("country-mirror", parent=model) - self.assertEqual(entry, PrimaryEntry(parent=model, - country_mirror=True)) + self.assertEqual(entry, PrimaryEntry(parent=model, country_mirror=True)) with self.assertRaises(ValueError): entry = PrimaryEntry.from_config({}, parent=model) - entry = PrimaryEntry.from_config( - {"uri": "http://mirror"}, parent=model) - self.assertEqual(entry, PrimaryEntry( - uri="http://mirror", parent=model)) + entry = PrimaryEntry.from_config({"uri": "http://mirror"}, parent=model) + self.assertEqual(entry, PrimaryEntry(uri="http://mirror", parent=model)) entry = PrimaryEntry.from_config( - {"uri": "http://mirror", "arches": ["amd64"]}, parent=model) - self.assertEqual(entry, PrimaryEntry( - uri="http://mirror", arches=["amd64"], parent=model)) + {"uri": "http://mirror", "arches": ["amd64"]}, parent=model + ) + self.assertEqual( + entry, PrimaryEntry(uri="http://mirror", arches=["amd64"], parent=model) + ) class TestLegacyPrimaryEntry(unittest.TestCase): @@ -111,8 +106,8 @@ class TestLegacyPrimaryEntry(unittest.TestCase): def test_get_uri(self): self.model.architecture = "amd64" primary = LegacyPrimaryEntry( - [{"uri": "http://myurl", "arches": "amd64"}], - parent=self.model) + [{"uri": "http://myurl", "arches": "amd64"}], parent=self.model + ) self.assertEqual(primary.uri, "http://myurl") def test_set_uri(self): @@ -130,8 +125,9 @@ class TestMirrorModel(unittest.TestCase): self.model_legacy = MirrorModel() self.model_legacy.legacy_primary = True self.model_legacy.primary_candidates = [ - LegacyPrimaryEntry(copy.deepcopy( - LEGACY_DEFAULT_PRIMARY_SECTION), parent=self.model_legacy), + LegacyPrimaryEntry( + copy.deepcopy(LEGACY_DEFAULT_PRIMARY_SECTION), parent=self.model_legacy + ), ] self.candidate_legacy = self.model_legacy.primary_candidates[0] self.model_legacy.primary_staged = self.candidate_legacy @@ -152,7 +148,8 @@ class TestMirrorModel(unittest.TestCase): [ "http://CC.archive.ubuntu.com/ubuntu", "http://CC.ports.ubuntu.com/ubuntu-ports", - ]) + ], + ) do_test(self.model) do_test(self.model_legacy) @@ -167,7 +164,7 @@ class TestMirrorModel(unittest.TestCase): def test_default_disable_components(self): def do_test(model, candidate): config = model.get_apt_config_staged() - self.assertEqual([], config['disable_components']) + self.assertEqual([], config["disable_components"]) # The candidate[0] is a country-mirror, skip it. candidate = self.model.primary_candidates[1] @@ -179,15 +176,14 @@ class TestMirrorModel(unittest.TestCase): # autoinstall loads to the config directly model = MirrorModel() data = { - 'disable_components': ['non-free'], - 'fallback': 'offline-install', + "disable_components": ["non-free"], + "fallback": "offline-install", } model.load_autoinstall_data(data) self.assertFalse(model.legacy_primary) model.primary_candidates[0].stage() - self.assertEqual(set(['non-free']), model.disabled_components) - self.assertEqual(model.primary_candidates, - model._default_primary_entries()) + self.assertEqual(set(["non-free"]), model.disabled_components) + self.assertEqual(model.primary_candidates, model._default_primary_entries()) def test_from_autoinstall_modern(self): data = { @@ -202,16 +198,19 @@ class TestMirrorModel(unittest.TestCase): } model = MirrorModel() model.load_autoinstall_data(data) - self.assertEqual(model.primary_candidates, [ - PrimaryEntry(parent=model, country_mirror=True), - PrimaryEntry(uri="http://mirror", parent=model), - ]) + self.assertEqual( + model.primary_candidates, + [ + PrimaryEntry(parent=model, country_mirror=True), + PrimaryEntry(uri="http://mirror", parent=model), + ], + ) def test_disable_add(self): def do_test(model, candidate): - expected = ['things', 'stuff'] + expected = ["things", "stuff"] model.disable_components(expected.copy(), add=True) - actual = model.get_apt_config_staged()['disable_components'] + actual = model.get_apt_config_staged()["disable_components"] actual.sort() expected.sort() self.assertEqual(expected, actual) @@ -223,11 +222,11 @@ class TestMirrorModel(unittest.TestCase): do_test(self.model_legacy, self.candidate_legacy) def test_disable_remove(self): - self.model.disabled_components = set(['a', 'b', 'things']) - to_remove = ['things', 'stuff'] - expected = ['a', 'b'] + self.model.disabled_components = set(["a", "b", "things"]) + to_remove = ["things", "stuff"] + expected = ["a", "b"] self.model.disable_components(to_remove, add=False) - actual = self.model.get_apt_config_staged()['disable_components'] + actual = self.model.get_apt_config_staged()["disable_components"] actual.sort() expected.sort() self.assertEqual(expected, actual) @@ -242,12 +241,15 @@ class TestMirrorModel(unittest.TestCase): self.model.legacy_primary = False self.model.fallback = MirrorSelectionFallback.OFFLINE_INSTALL self.model.primary_candidates = [ - PrimaryEntry(uri=None, arches=None, - country_mirror=True, parent=self.model), - PrimaryEntry(uri="http://mirror.local/ubuntu", - arches=None, parent=self.model), - PrimaryEntry(uri="http://amd64.mirror.local/ubuntu", - arches=["amd64"], parent=self.model), + PrimaryEntry(uri=None, arches=None, country_mirror=True, parent=self.model), + PrimaryEntry( + uri="http://mirror.local/ubuntu", arches=None, parent=self.model + ), + PrimaryEntry( + uri="http://amd64.mirror.local/ubuntu", + arches=["amd64"], + parent=self.model, + ), ] cfg = self.model.make_autoinstall() self.assertEqual(cfg["disable_components"], ["non-free"]) @@ -259,8 +261,7 @@ class TestMirrorModel(unittest.TestCase): self.model.disabled_components = set(["non-free"]) self.model.fallback = MirrorSelectionFallback.ABORT self.model.legacy_primary = True - self.model.primary_candidates = \ - [LegacyPrimaryEntry(primary, parent=self.model)] + self.model.primary_candidates = [LegacyPrimaryEntry(primary, parent=self.model)] self.model.primary_candidates[0].elect() cfg = self.model.make_autoinstall() self.assertEqual(cfg["disable_components"], ["non-free"]) @@ -269,22 +270,23 @@ class TestMirrorModel(unittest.TestCase): def test_create_primary_candidate(self): self.model.legacy_primary = False - candidate = self.model.create_primary_candidate( - "http://mymirror.valid") + candidate = self.model.create_primary_candidate("http://mymirror.valid") self.assertEqual(candidate.uri, "http://mymirror.valid") self.model.legacy_primary = True - candidate = self.model.create_primary_candidate( - "http://mymirror.valid") + candidate = self.model.create_primary_candidate("http://mymirror.valid") self.assertEqual(candidate.uri, "http://mymirror.valid") def test_wants_geoip(self): country_mirror_candidates = mock.patch.object( - self.model, "country_mirror_candidates", return_value=iter([])) + self.model, "country_mirror_candidates", return_value=iter([]) + ) with country_mirror_candidates: self.assertFalse(self.model.wants_geoip()) country_mirror_candidates = mock.patch.object( - self.model, "country_mirror_candidates", - return_value=iter([PrimaryEntry(parent=self.model)])) + self.model, + "country_mirror_candidates", + return_value=iter([PrimaryEntry(parent=self.model)]), + ) with country_mirror_candidates: self.assertTrue(self.model.wants_geoip()) diff --git a/subiquity/models/tests/test_network.py b/subiquity/models/tests/test_network.py index 22db3d3b..636812c2 100644 --- a/subiquity/models/tests/test_network.py +++ b/subiquity/models/tests/test_network.py @@ -17,9 +17,7 @@ import subprocess import unittest from unittest import mock -from subiquity.models.network import ( - NetworkModel, -) +from subiquity.models.network import NetworkModel class TestNetworkModel(unittest.IsolatedAsyncioTestCase): @@ -42,6 +40,5 @@ class TestNetworkModel(unittest.IsolatedAsyncioTestCase): self.assertFalse(await self.model.is_nm_enabled()) with mock.patch("subiquity.models.network.arun_command") as arun: - arun.side_effect = subprocess.CalledProcessError( - 1, [], None, "error") + arun.side_effect = subprocess.CalledProcessError(1, [], None, "error") self.assertFalse(await self.model.is_nm_enabled()) diff --git a/subiquity/models/tests/test_source.py b/subiquity/models/tests/test_source.py index d3562b19..a399d7d5 100644 --- a/subiquity/models/tests/test_source.py +++ b/subiquity/models/tests/test_source.py @@ -18,36 +18,33 @@ import random import shutil import tempfile import unittest + import yaml +from subiquity.models.source import SourceModel from subiquitycore.tests.util import random_string -from subiquity.models.source import ( - SourceModel, - ) - def make_entry(**fields): raw = { - 'name': { - 'en': random_string(), - }, - 'description': { - 'en': random_string(), - }, - 'id': random_string(), - 'type': random_string(), - 'variant': random_string(), - 'path': random_string(), - 'size': random.randint(1000000, 2000000), - } + "name": { + "en": random_string(), + }, + "description": { + "en": random_string(), + }, + "id": random_string(), + "type": random_string(), + "variant": random_string(), + "path": random_string(), + "size": random.randint(1000000, 2000000), + } for k, v in fields.items(): raw[k] = v return raw class TestSourceModel(unittest.TestCase): - def tdir(self): tdir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, tdir) @@ -56,83 +53,78 @@ class TestSourceModel(unittest.TestCase): def write_and_load_entries(self, model, entries, dir=None): if dir is None: dir = self.tdir() - cat_path = os.path.join(dir, 'catalog.yaml') - with open(cat_path, 'w') as fp: + cat_path = os.path.join(dir, "catalog.yaml") + with open(cat_path, "w") as fp: yaml.dump(entries, fp) with open(cat_path) as fp: model.load_from_file(fp) def test_initially_server(self): model = SourceModel() - self.assertEqual(model.current.variant, 'server') + self.assertEqual(model.current.variant, "server") def test_load_from_file_simple(self): - entry = make_entry(id='id1') + entry = make_entry(id="id1") model = SourceModel() self.write_and_load_entries(model, [entry]) - self.assertEqual(model.current.id, 'id1') + self.assertEqual(model.current.id, "id1") def test_load_from_file_ignores_extra_keys(self): - entry = make_entry(id='id1', foobarbaz=random_string()) + entry = make_entry(id="id1", foobarbaz=random_string()) model = SourceModel() self.write_and_load_entries(model, [entry]) - self.assertEqual(model.current.id, 'id1') + self.assertEqual(model.current.id, "id1") def test_default(self): entries = [ - make_entry(id='id1'), - make_entry(id='id2', default=True), - make_entry(id='id3'), - ] + make_entry(id="id1"), + make_entry(id="id2", default=True), + make_entry(id="id3"), + ] model = SourceModel() self.write_and_load_entries(model, entries) - self.assertEqual(model.current.id, 'id2') + self.assertEqual(model.current.id, "id2") def test_get_source_absolute(self): entry = make_entry( - type='scheme', - path='/foo/bar/baz', - ) + type="scheme", + path="/foo/bar/baz", + ) model = SourceModel() self.write_and_load_entries(model, [entry]) - self.assertEqual( - model.get_source(), 'scheme:///foo/bar/baz') + self.assertEqual(model.get_source(), "scheme:///foo/bar/baz") def test_get_source_relative(self): dir = self.tdir() entry = make_entry( - type='scheme', - path='foo/bar/baz', - ) + type="scheme", + path="foo/bar/baz", + ) model = SourceModel() self.write_and_load_entries(model, [entry], dir) - self.assertEqual( - model.get_source(), - f'scheme://{dir}/foo/bar/baz') + self.assertEqual(model.get_source(), f"scheme://{dir}/foo/bar/baz") def test_get_source_initial(self): model = SourceModel() - self.assertEqual( - model.get_source(), - 'cp:///media/filesystem') + self.assertEqual(model.get_source(), "cp:///media/filesystem") def test_render(self): model = SourceModel() self.assertEqual(model.render(), {}) def test_canary(self): - with open('examples/sources/install-canary.yaml') as fp: + with open("examples/sources/install-canary.yaml") as fp: model = SourceModel() model.load_from_file(fp) self.assertEqual(2, len(model.sources)) - minimal = model.get_matching_source('ubuntu-desktop-minimal') + minimal = model.get_matching_source("ubuntu-desktop-minimal") self.assertIsNotNone(minimal.variations) - self.assertEqual(minimal.size, minimal.variations['default'].size) - self.assertEqual(minimal.path, minimal.variations['default'].path) - self.assertIsNone(minimal.variations['default'].snapd_system_label) + self.assertEqual(minimal.size, minimal.variations["default"].size) + self.assertEqual(minimal.path, minimal.variations["default"].path) + self.assertIsNone(minimal.variations["default"].snapd_system_label) - standard = model.get_matching_source('ubuntu-desktop') + standard = model.get_matching_source("ubuntu-desktop") self.assertIsNotNone(standard.variations) self.assertEqual(2, len(standard.variations)) diff --git a/subiquity/models/tests/test_subiquity.py b/subiquity/models/tests/test_subiquity.py index 6e30f653..db6852ee 100644 --- a/subiquity/models/tests/test_subiquity.py +++ b/subiquity/models/tests/test_subiquity.py @@ -15,18 +15,20 @@ import fnmatch import json +import re import unittest from unittest import mock -import re -import yaml +import yaml from cloudinit.config.schema import SchemaValidationError + try: from cloudinit.config.schema import SchemaProblem except ImportError: - def SchemaProblem(x, y): return (x, y) # TODO(drop on cloud-init 22.3 SRU) -from subiquitycore.pubsub import MessageHub + def SchemaProblem(x, y): + return (x, y) # TODO(drop on cloud-init 22.3 SRU) + from subiquity.common.types import IdentityData from subiquity.models.subiquity import ( @@ -35,13 +37,11 @@ from subiquity.models.subiquity import ( ModelNames, SubiquityModel, ) -from subiquity.server.server import ( - INSTALL_MODEL_NAMES, - POSTINSTALL_MODEL_NAMES, - ) +from subiquity.server.server import INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES from subiquity.server.types import InstallerChannels +from subiquitycore.pubsub import MessageHub -getent_group_output = ''' +getent_group_output = """ root:x:0: daemon:x:1: bin:x:2: @@ -57,56 +57,55 @@ man:x:12: sudo:x:27: ssh:x:118: users:x:100: -''' +""" class TestModelNames(unittest.TestCase): - def test_for_known_variant(self): - model_names = ModelNames({'a'}, var1={'b'}, var2={'c'}) - self.assertEqual(model_names.for_variant('var1'), {'a', 'b'}) + model_names = ModelNames({"a"}, var1={"b"}, var2={"c"}) + self.assertEqual(model_names.for_variant("var1"), {"a", "b"}) def test_for_unknown_variant(self): - model_names = ModelNames({'a'}, var1={'b'}, var2={'c'}) - self.assertEqual(model_names.for_variant('var3'), {'a'}) + model_names = ModelNames({"a"}, var1={"b"}, var2={"c"}) + self.assertEqual(model_names.for_variant("var3"), {"a"}) def test_all(self): - model_names = ModelNames({'a'}, var1={'b'}, var2={'c'}) - self.assertEqual(model_names.all(), {'a', 'b', 'c'}) + model_names = ModelNames({"a"}, var1={"b"}, var2={"c"}) + self.assertEqual(model_names.all(), {"a", "b", "c"}) class TestSubiquityModel(unittest.IsolatedAsyncioTestCase): maxDiff = None def writtenFiles(self, config): - for k, v in config.get('write_files', {}).items(): + for k, v in config.get("write_files", {}).items(): yield v def assertConfigWritesFile(self, config, path): - self.assertIn(path, [s['path'] for s in self.writtenFiles(config)]) + self.assertIn(path, [s["path"] for s in self.writtenFiles(config)]) def writtenFilesMatching(self, config, pattern): files = list(self.writtenFiles(config)) matching = [] for spec in files: - if fnmatch.fnmatch(spec['path'], pattern): + if fnmatch.fnmatch(spec["path"], pattern): matching.append(spec) return matching def writtenFilesMatchingContaining(self, config, pattern, content): matching = [] for spec in self.writtenFilesMatching(config, pattern): - if content in spec['content']: + if content in spec["content"]: matching.append(content) return matching def configVal(self, config, path): cur = config - for component in path.split('.'): + for component in path.split("."): if not isinstance(cur, dict): self.fail( - "extracting {} reached non-dict {} too early".format( - path, cur)) + "extracting {} reached non-dict {} too early".format(path, cur) + ) if component not in cur: self.fail("no value found for {}".format(path)) cur = cur[component] @@ -117,11 +116,11 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase): def assertConfigDoesNotHaveVal(self, config, path): cur = config - for component in path.split('.'): + for component in path.split("."): if not isinstance(cur, dict): self.fail( - "extracting {} reached non-dict {} too early".format( - path, cur)) + "extracting {} reached non-dict {} too early".format(path, cur) + ) if component not in cur: return cur = cur[component] @@ -129,157 +128,149 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase): async def test_configure(self): hub = MessageHub() - model = SubiquityModel( - 'test', hub, ModelNames({'a', 'b'}), ModelNames(set())) - model.set_source_variant('var') - await hub.abroadcast((InstallerChannels.CONFIGURED, 'a')) + model = SubiquityModel("test", hub, ModelNames({"a", "b"}), ModelNames(set())) + model.set_source_variant("var") + await hub.abroadcast((InstallerChannels.CONFIGURED, "a")) self.assertFalse(model._install_event.is_set()) - await hub.abroadcast((InstallerChannels.CONFIGURED, 'b')) + await hub.abroadcast((InstallerChannels.CONFIGURED, "b")) self.assertTrue(model._install_event.is_set()) def make_model(self): return SubiquityModel( - 'test', MessageHub(), INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES) + "test", MessageHub(), INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES + ) def test_proxy_set(self): model = self.make_model() - proxy_val = 'http://my-proxy' + proxy_val = "http://my-proxy" model.proxy.proxy = proxy_val config = model.render() - self.assertConfigHasVal(config, 'proxy.http_proxy', proxy_val) - self.assertConfigHasVal(config, 'proxy.https_proxy', proxy_val) + self.assertConfigHasVal(config, "proxy.http_proxy", proxy_val) + self.assertConfigHasVal(config, "proxy.https_proxy", proxy_val) confs = self.writtenFilesMatchingContaining( config, - 'etc/systemd/system/snapd.service.d/*.conf', - 'HTTP_PROXY=' + proxy_val) + "etc/systemd/system/snapd.service.d/*.conf", + "HTTP_PROXY=" + proxy_val, + ) self.assertTrue(len(confs) > 0) def test_proxy_notset(self): model = self.make_model() config = model.render() - self.assertConfigDoesNotHaveVal(config, 'proxy.http_proxy') - self.assertConfigDoesNotHaveVal(config, 'proxy.https_proxy') - self.assertConfigDoesNotHaveVal(config, 'apt.http_proxy') - self.assertConfigDoesNotHaveVal(config, 'apt.https_proxy') + self.assertConfigDoesNotHaveVal(config, "proxy.http_proxy") + self.assertConfigDoesNotHaveVal(config, "proxy.https_proxy") + self.assertConfigDoesNotHaveVal(config, "apt.http_proxy") + self.assertConfigDoesNotHaveVal(config, "apt.https_proxy") confs = self.writtenFilesMatchingContaining( - config, - 'etc/systemd/system/snapd.service.d/*.conf', - 'HTTP_PROXY=') + config, "etc/systemd/system/snapd.service.d/*.conf", "HTTP_PROXY=" + ) self.assertTrue(len(confs) == 0) def test_keyboard(self): model = self.make_model() config = model.render() - self.assertConfigWritesFile(config, 'etc/default/keyboard') + self.assertConfigWritesFile(config, "etc/default/keyboard") def test_writes_machine_id_media_info(self): model_no_proxy = self.make_model() model_proxy = self.make_model() - model_proxy.proxy.proxy = 'http://something' + model_proxy.proxy.proxy = "http://something" for model in model_no_proxy, model_proxy: config = model.render() - self.assertConfigWritesFile(config, 'etc/machine-id') - self.assertConfigWritesFile(config, 'var/log/installer/media-info') + self.assertConfigWritesFile(config, "etc/machine-id") + self.assertConfigWritesFile(config, "var/log/installer/media-info") def test_storage_version(self): model = self.make_model() config = model.render() - self.assertConfigHasVal(config, 'storage.version', 1) + self.assertConfigHasVal(config, "storage.version", 1) def test_write_netplan(self): model = self.make_model() config = model.render() netplan_content = None - for fspec in config['write_files'].values(): - if fspec['path'].startswith('etc/netplan'): + for fspec in config["write_files"].values(): + if fspec["path"].startswith("etc/netplan"): if netplan_content is not None: self.fail("writing two files to netplan?") - netplan_content = fspec['content'] + netplan_content = fspec["content"] self.assertIsNot(netplan_content, None) netplan = yaml.safe_load(netplan_content) - self.assertConfigHasVal(netplan, 'network.version', 2) + self.assertConfigHasVal(netplan, "network.version", 2) def test_sources(self): model = self.make_model() config = model.render() - self.assertNotIn('sources', config) + self.assertNotIn("sources", config) def test_mirror(self): model = self.make_model() - mirror_val = 'http://my-mirror' + mirror_val = "http://my-mirror" model.mirror.create_primary_candidate(mirror_val).elect() config = model.render() - self.assertNotIn('apt', config) + self.assertNotIn("apt", config) - @mock.patch('subiquitycore.utils.run_command') + @mock.patch("subiquitycore.utils.run_command") def test_cloud_init_user_list_merge(self, run_cmd): main_user = IdentityData( - username='mainuser', - crypted_password='sample_value', - hostname='somehost') - secondary_user = {'name': 'user2'} + username="mainuser", crypted_password="sample_value", hostname="somehost" + ) + secondary_user = {"name": "user2"} - with self.subTest('Main user + secondary user'): + with self.subTest("Main user + secondary user"): model = self.make_model() model.identity.add_user(main_user) - model.userdata = {'users': [secondary_user]} + model.userdata = {"users": [secondary_user]} run_cmd.return_value.stdout = getent_group_output cloud_init_config = model._cloud_init_config() - self.assertEqual(len(cloud_init_config['users']), 2) - self.assertEqual(cloud_init_config['users'][0]['name'], 'mainuser') - self.assertEqual( - cloud_init_config['users'][0]['groups'], - 'adm,sudo,users' - ) - self.assertEqual(cloud_init_config['users'][1]['name'], 'user2') - run_cmd.assert_called_with(['getent', 'group'], check=True) + self.assertEqual(len(cloud_init_config["users"]), 2) + self.assertEqual(cloud_init_config["users"][0]["name"], "mainuser") + self.assertEqual(cloud_init_config["users"][0]["groups"], "adm,sudo,users") + self.assertEqual(cloud_init_config["users"][1]["name"], "user2") + run_cmd.assert_called_with(["getent", "group"], check=True) - with self.subTest('Secondary user only'): + with self.subTest("Secondary user only"): model = self.make_model() - model.userdata = {'users': [secondary_user]} + model.userdata = {"users": [secondary_user]} cloud_init_config = model._cloud_init_config() - self.assertEqual(len(cloud_init_config['users']), 1) - self.assertEqual(cloud_init_config['users'][0]['name'], 'user2') + self.assertEqual(len(cloud_init_config["users"]), 1) + self.assertEqual(cloud_init_config["users"][0]["name"], "user2") - with self.subTest('Invalid user-data raises error'): + with self.subTest("Invalid user-data raises error"): model = self.make_model() - model.userdata = {'bootcmd': "nope"} + model.userdata = {"bootcmd": "nope"} with self.assertRaises(SchemaValidationError) as ctx: model._cloud_init_config() expected_error = ( - "Cloud config schema errors: bootcmd: 'nope' is not of type" - " 'array'" + "Cloud config schema errors: bootcmd: 'nope' is not of type" " 'array'" ) self.assertEqual(expected_error, str(ctx.exception)) - @mock.patch('subiquity.models.subiquity.lsb_release') - @mock.patch('subiquitycore.file_util.datetime.datetime') + @mock.patch("subiquity.models.subiquity.lsb_release") + @mock.patch("subiquitycore.file_util.datetime.datetime") def test_cloud_init_files_emits_datasource_config_and_clean_script( self, datetime, lsb_release ): datetime.utcnow.return_value = "2004-03-05 ..." main_user = IdentityData( - username='mainuser', - crypted_password='sample_pass', - hostname='somehost') + username="mainuser", crypted_password="sample_pass", hostname="somehost" + ) model = self.make_model() model.identity.add_user(main_user) model.userdata = {} expected_files = { - 'etc/cloud/cloud.cfg.d/99-installer.cfg': re.compile( - 'datasource:\n None:\n metadata:\n instance-id: .*\n userdata_raw: "#cloud-config\\\\ngrowpart:\\\\n mode: \\\'off\\\'\\\\npreserve_hostname: true\\\\n\\\\\n' # noqa + "etc/cloud/cloud.cfg.d/99-installer.cfg": re.compile( + "datasource:\n None:\n metadata:\n instance-id: .*\n userdata_raw: \"#cloud-config\\\\ngrowpart:\\\\n mode: \\'off\\'\\\\npreserve_hostname: true\\\\n\\\\\n" # noqa ), - 'etc/hostname': 'somehost\n', - 'etc/cloud/ds-identify.cfg': 'policy: enabled\n', - 'etc/hosts': HOSTS_CONTENT.format(hostname='somehost'), + "etc/hostname": "somehost\n", + "etc/cloud/ds-identify.cfg": "policy: enabled\n", + "etc/hosts": HOSTS_CONTENT.format(hostname="somehost"), } # Avoid removing /etc/hosts and /etc/hostname in cloud-init clean - cfg_files = [ - "/" + key for key in expected_files.keys() if "host" not in key - ] + cfg_files = ["/" + key for key in expected_files.keys() if "host" not in key] cfg_files.append( # Obtained from NetworkModel.render when cloud-init features # NETPLAN_CONFIG_ROOT_READ_ONLY is True @@ -287,14 +278,14 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase): ) header = "# Autogenerated by Subiquity: 2004-03-05 ... UTC\n" with self.subTest( - 'Stable releases Jammy do not disable cloud-init.' - ' NETPLAN_ROOT_READ_ONLY=True uses cloud-init networking' + "Stable releases Jammy do not disable cloud-init." + " NETPLAN_ROOT_READ_ONLY=True uses cloud-init networking" ): lsb_release.return_value = {"release": "22.04"} - expected_files['etc/cloud/clean.d/99-installer'] = ( - CLOUDINIT_CLEAN_FILE_TMPL.format( - header=header, cfg_files=json.dumps(sorted(cfg_files)) - ) + expected_files[ + "etc/cloud/clean.d/99-installer" + ] = CLOUDINIT_CLEAN_FILE_TMPL.format( + header=header, cfg_files=json.dumps(sorted(cfg_files)) ) with unittest.mock.patch( "subiquity.cloudinit.open", @@ -304,57 +295,52 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase): ) ), ): - for (cpath, content, perms) in model._cloud_init_files(): + for cpath, content, perms in model._cloud_init_files(): if isinstance(expected_files[cpath], re.Pattern): - self.assertIsNotNone( - expected_files[cpath].match(content) - ) + self.assertIsNotNone(expected_files[cpath].match(content)) else: self.assertEqual(expected_files[cpath], content) with self.subTest( - 'Kinetic++ disables cloud-init post install.' - ' NETPLAN_ROOT_READ_ONLY=False avoids cloud-init networking' + "Kinetic++ disables cloud-init post install." + " NETPLAN_ROOT_READ_ONLY=False avoids cloud-init networking" ): lsb_release.return_value = {"release": "22.10"} cfg_files.append( # Added by _cloud_init_files for 22.10 and later releases - '/etc/cloud/cloud-init.disabled', + "/etc/cloud/cloud-init.disabled", ) # Obtained from NetworkModel.render - cfg_files.remove('/etc/cloud/cloud.cfg.d/90-installer-network.cfg') - cfg_files.append('/etc/netplan/00-installer-config.yaml') + cfg_files.remove("/etc/cloud/cloud.cfg.d/90-installer-network.cfg") + cfg_files.append("/etc/netplan/00-installer-config.yaml") cfg_files.append( - '/etc/cloud/cloud.cfg.d/' - 'subiquity-disable-cloudinit-networking.cfg' + "/etc/cloud/cloud.cfg.d/" "subiquity-disable-cloudinit-networking.cfg" ) expected_files[ - 'etc/cloud/clean.d/99-installer' + "etc/cloud/clean.d/99-installer" ] = CLOUDINIT_CLEAN_FILE_TMPL.format( header=header, cfg_files=json.dumps(sorted(cfg_files)) ) with unittest.mock.patch( - 'subiquity.cloudinit.open', + "subiquity.cloudinit.open", mock.mock_open( read_data=json.dumps( - {'features': {'NETPLAN_CONFIG_ROOT_READ_ONLY': False}} + {"features": {"NETPLAN_CONFIG_ROOT_READ_ONLY": False}} ) ), ): - for (cpath, content, perms) in model._cloud_init_files(): + for cpath, content, perms in model._cloud_init_files(): if isinstance(expected_files[cpath], re.Pattern): - self.assertIsNotNone( - expected_files[cpath].match(content) - ) + self.assertIsNotNone(expected_files[cpath].match(content)) else: self.assertEqual(expected_files[cpath], content) def test_validatecloudconfig_schema(self): model = self.make_model() - with self.subTest('Valid cloud-config does not error'): + with self.subTest("Valid cloud-config does not error"): model.validate_cloudconfig_schema( data={"ssh_import_id": ["chad.smith"]}, - data_source="autoinstall.user-data" + data_source="autoinstall.user-data", ) # Create our own subclass for focal as schema_deprecations @@ -367,22 +353,18 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase): self.schema_deprecations = schema_deprecations problem = SchemaProblem( - "bogus", - "'bogus' is deprecated, use 'notbogus' instead" + "bogus", "'bogus' is deprecated, use 'notbogus' instead" ) - with self.subTest('Deprecated cloud-config warns'): + with self.subTest("Deprecated cloud-config warns"): with unittest.mock.patch( "subiquity.models.subiquity.validate_cloudconfig_schema" ) as validate: - validate.side_effect = SchemaDeprecation( - schema_deprecations=(problem,) - ) + validate.side_effect = SchemaDeprecation(schema_deprecations=(problem,)) with self.assertLogs( "subiquity.models.subiquity", level="INFO" ) as logs: model.validate_cloudconfig_schema( - data={"bogus": True}, - data_source="autoinstall.user-data" + data={"bogus": True}, data_source="autoinstall.user-data" ) expected = ( "WARNING:subiquity.models.subiquity:The cloud-init" @@ -391,22 +373,20 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase): ) self.assertEqual(logs.output, [expected]) - with self.subTest('Invalid cloud-config schema errors'): + with self.subTest("Invalid cloud-config schema errors"): with self.assertRaises(SchemaValidationError) as ctx: model.validate_cloudconfig_schema( data={"bootcmd": "nope"}, data_source="system info" ) expected_error = ( - "Cloud config schema errors: bootcmd: 'nope' is not of" - " type 'array'" + "Cloud config schema errors: bootcmd: 'nope' is not of" " type 'array'" ) self.assertEqual(expected_error, str(ctx.exception)) - with self.subTest('Prefix autoinstall.user-data cloud-config errors'): + with self.subTest("Prefix autoinstall.user-data cloud-config errors"): with self.assertRaises(SchemaValidationError) as ctx: model.validate_cloudconfig_schema( - data={"bootcmd": "nope"}, - data_source="autoinstall.user-data" + data={"bootcmd": "nope"}, data_source="autoinstall.user-data" ) expected_error = ( "Cloud config schema errors:" diff --git a/subiquity/models/timezone.py b/subiquity/models/timezone.py index 433ae625..6368aa43 100644 --- a/subiquity/models/timezone.py +++ b/subiquity/models/timezone.py @@ -15,13 +15,13 @@ import logging -log = logging.getLogger('subiquity.models.timezone') +log = logging.getLogger("subiquity.models.timezone") class TimeZoneModel(object): - """ Model representing timezone""" + """Model representing timezone""" - timezone = '' + timezone = "" def __init__(self): # This is the raw request from the API / autoinstall. @@ -32,16 +32,16 @@ class TimeZoneModel(object): self._request = None # The actually timezone to set, possibly post-geoip lookup or # possibly manually specified. - self.timezone = '' + self.timezone = "" self.got_from_geoip = False def set(self, value): self._request = value - self.timezone = value if value != 'geoip' else '' + self.timezone = value if value != "geoip" else "" @property def detect_with_geoip(self): - return self._request == 'geoip' + return self._request == "geoip" @property def should_set_tz(self): @@ -54,10 +54,15 @@ class TimeZoneModel(object): def make_cloudconfig(self): if not self.should_set_tz: return {} - return {'timezone': self.timezone} + return {"timezone": self.timezone} def __repr__(self): - return ("").format( - self.request, self.detect_with_geoip, self.should_set_tz, - self.timezone, self.got_from_geoip) + return ( + "" + ).format( + self.request, + self.detect_with_geoip, + self.should_set_tz, + self.timezone, + self.got_from_geoip, + ) diff --git a/subiquity/models/ubuntu_pro.py b/subiquity/models/ubuntu_pro.py index ddff3c49..d1b231f9 100644 --- a/subiquity/models/ubuntu_pro.py +++ b/subiquity/models/ubuntu_pro.py @@ -26,8 +26,9 @@ class UbuntuProModel: the provided token is correct ; nor to retrieve information about the subscription. """ + def __init__(self): - """ Initialize the model. """ + """Initialize the model.""" self.token: str = "" def make_cloudconfig(self) -> dict: diff --git a/subiquity/models/updates.py b/subiquity/models/updates.py index 40a5929b..67781c68 100644 --- a/subiquity/models/updates.py +++ b/subiquity/models/updates.py @@ -15,13 +15,13 @@ import logging -log = logging.getLogger('subiquity.models.updates') +log = logging.getLogger("subiquity.models.updates") class UpdatesModel(object): - """ Model representing updates selection""" + """Model representing updates selection""" - updates = 'security' + updates = "security" def __repr__(self): return "".format(self.updates) diff --git a/subiquity/server/ad_joiner.py b/subiquity/server/ad_joiner.py index 5c245d6f..144ae8bf 100644 --- a/subiquity/server/ad_joiner.py +++ b/subiquity/server/ad_joiner.py @@ -14,30 +14,28 @@ # along with this program. If not, see . import asyncio -from contextlib import contextmanager import logging import os +from contextlib import contextmanager from socket import gethostname from subprocess import CalledProcessError -from subiquitycore.utils import arun_command, run_command -from subiquity.server.curtin import run_curtin_command -from subiquity.common.types import ( - AdConnectionInfo, - AdJoinResult, -) -log = logging.getLogger('subiquity.server.ad_joiner') +from subiquity.common.types import AdConnectionInfo, AdJoinResult +from subiquity.server.curtin import run_curtin_command +from subiquitycore.utils import arun_command, run_command + +log = logging.getLogger("subiquity.server.ad_joiner") @contextmanager def joining_context(hostname: str, root_dir: str): - """ Temporarily adjusts the host name to [hostname] and bind-mounts - interesting system directories in preparation for running realm - in target's [root_dir], undoing it all on exit. """ + """Temporarily adjusts the host name to [hostname] and bind-mounts + interesting system directories in preparation for running realm + in target's [root_dir], undoing it all on exit.""" hostname_current = gethostname() binds = ("/proc", "/sys", "/dev", "/run") try: - hostname_process = run_command(['hostname', hostname]) + hostname_process = run_command(["hostname", hostname]) for bind in binds: bound_dir = os.path.join(root_dir, bind[1:]) if bound_dir != bind: @@ -45,7 +43,7 @@ def joining_context(hostname: str, root_dir: str): yield hostname_process finally: # Restoring the live session hostname. - hostname_process = run_command(['hostname', hostname_current]) + hostname_process = run_command(["hostname", hostname_current]) if hostname_process.returncode: log.info("Failed to restore live session hostname") for bind in reversed(binds): @@ -54,17 +52,18 @@ def joining_context(hostname: str, root_dir: str): run_command(["umount", "-f", bound_dir]) -class AdJoinStrategy(): +class AdJoinStrategy: realm = "/usr/sbin/realm" pam = "/usr/sbin/pam-auth-update" def __init__(self, app): self.app = app - async def do_join(self, info: AdConnectionInfo, hostname: str, context) \ - -> AdJoinResult: - """ This method changes the hostname and perform a real AD join, thus - should only run in a live session. """ + async def do_join( + self, info: AdConnectionInfo, hostname: str, context + ) -> AdJoinResult: + """This method changes the hostname and perform a real AD join, thus + should only run in a live session.""" root_dir = self.app.base_model.target # Set hostname for AD to determine FQDN (no FQDN option in realm join, # only adcli, which only understands the live system, but not chroot) @@ -73,11 +72,21 @@ class AdJoinStrategy(): log.info("Failed to set live session hostname for adcli") return AdJoinResult.JOIN_ERROR - cp = await arun_command([self.realm, "join", "--install", root_dir, - "--user", info.admin_name, - "--computer-name", hostname, - "--unattended", info.domain_name], - input=info.password) + cp = await arun_command( + [ + self.realm, + "join", + "--install", + root_dir, + "--user", + info.admin_name, + "--computer-name", + hostname, + "--unattended", + info.domain_name, + ], + input=info.password, + ) if cp.returncode: # Try again without the computer name. Lab tests shown more @@ -88,10 +97,19 @@ class AdJoinStrategy(): log.debug(cp.stderr) log.debug(cp.stdout) log.debug("Trying again without overriding the computer name:") - cp = await arun_command([self.realm, "join", "--install", - root_dir, "--user", info.admin_name, - "--unattended", info.domain_name], - input=info.password) + cp = await arun_command( + [ + self.realm, + "join", + "--install", + root_dir, + "--user", + info.admin_name, + "--unattended", + info.domain_name, + ], + input=info.password, + ) if cp.returncode: log.debug("Joining operation failed:") @@ -102,11 +120,19 @@ class AdJoinStrategy(): # Enable pam_mkhomedir try: # The function raises if the process fail. - await run_curtin_command(self.app, context, - "in-target", "-t", root_dir, - "--", self.pam, "--package", - "--enable", "mkhomedir", - private_mounts=False) + await run_curtin_command( + self.app, + context, + "in-target", + "-t", + root_dir, + "--", + self.pam, + "--package", + "--enable", + "mkhomedir", + private_mounts=False, + ) return AdJoinResult.OK except CalledProcessError: @@ -120,24 +146,25 @@ class AdJoinStrategy(): class StubStrategy(AdJoinStrategy): - async def do_join(self, info: AdConnectionInfo, hostname: str, context) \ - -> AdJoinResult: - """ Enables testing without real join. The result depends on the - domain name initial character, such that if it is: - - p or P: returns PAM_ERROR. - - j or J: returns JOIN_ERROR. - - returns OK otherwise. """ + async def do_join( + self, info: AdConnectionInfo, hostname: str, context + ) -> AdJoinResult: + """Enables testing without real join. The result depends on the + domain name initial character, such that if it is: + - p or P: returns PAM_ERROR. + - j or J: returns JOIN_ERROR. + - returns OK otherwise.""" initial = info.domain_name[0] - if initial in ('j', 'J'): + if initial in ("j", "J"): return AdJoinResult.JOIN_ERROR - if initial in ('p', 'P'): + if initial in ("p", "P"): return AdJoinResult.PAM_ERROR return AdJoinResult.OK -class AdJoiner(): +class AdJoiner: def __init__(self, app): self._result = AdJoinResult.UNKNOWN self._completion_event = asyncio.Event() @@ -146,8 +173,9 @@ class AdJoiner(): else: self.strategy = AdJoinStrategy(app) - async def join_domain(self, info: AdConnectionInfo, hostname: str, - context) -> AdJoinResult: + async def join_domain( + self, info: AdConnectionInfo, hostname: str, context + ) -> AdJoinResult: if hostname: self._result = await self.strategy.do_join(info, hostname, context) else: diff --git a/subiquity/server/apt.py b/subiquity/server/apt.py index 71f7450b..d06a498d 100644 --- a/subiquity/server/apt.py +++ b/subiquity/server/apt.py @@ -28,13 +28,8 @@ import tempfile from typing import List, Optional import apt_pkg - -from curtin.config import merge_config from curtin.commands.extract import AbstractSourceHandler - -from subiquitycore.file_util import write_file, generate_config_yaml -from subiquitycore.lsb_release import lsb_release -from subiquitycore.utils import astart_command, orig_environ +from curtin.config import merge_config from subiquity.server.curtin import run_curtin_command from subiquity.server.mounter import ( @@ -43,18 +38,21 @@ from subiquity.server.mounter import ( Mountpoint, OverlayCleanupError, OverlayMountpoint, - ) +) +from subiquitycore.file_util import generate_config_yaml, write_file +from subiquitycore.lsb_release import lsb_release +from subiquitycore.utils import astart_command, orig_environ -log = logging.getLogger('subiquity.server.apt') +log = logging.getLogger("subiquity.server.apt") class AptConfigCheckError(Exception): - """ Error to raise when apt-get update fails with the currently applied - configuration. """ + """Error to raise when apt-get update fails with the currently applied + configuration.""" def get_index_targets() -> List[str]: - """ Return the identifier of the data files that would be downloaded during + """Return the identifier of the data files that would be downloaded during apt-get update. NOTE: this uses the default configuration files from the host so this might slightly differ from what we would have in the overlay. @@ -108,8 +106,7 @@ class AptConfigurer: # system, or if it is not, just copy /var/lib/apt/lists from the # 'configured_tree' overlay. - def __init__(self, app, mounter: Mounter, - source_handler: AbstractSourceHandler): + def __init__(self, app, mounter: Mounter, source_handler: AbstractSourceHandler): self.app = app self.mounter = mounter self.source_handler: AbstractSourceHandler = source_handler @@ -133,30 +130,37 @@ class AptConfigurer: self.app.base_model.debconf_selections, ] for model in models: - merge_config(cfg, model.get_apt_config( - final=final, has_network=has_network)) - return {'apt': cfg} + merge_config( + cfg, model.get_apt_config(final=final, has_network=has_network) + ) + return {"apt": cfg} async def apply_apt_config(self, context, final: bool): - self.configured_tree = await self.mounter.setup_overlay( - [self.source_path]) + self.configured_tree = await self.mounter.setup_overlay([self.source_path]) config_location = os.path.join( - self.app.root, 'var/log/installer/subiquity-curtin-apt.conf') + self.app.root, "var/log/installer/subiquity-curtin-apt.conf" + ) generate_config_yaml(config_location, self.apt_config(final)) self.app.note_data_for_apport("CurtinAptConfig", config_location) await run_curtin_command( - self.app, context, 'apt-config', '-t', self.configured_tree.p(), - config=config_location, private_mounts=True) + self.app, + context, + "apt-config", + "-t", + self.configured_tree.p(), + config=config_location, + private_mounts=True, + ) async def run_apt_config_check(self, output: io.StringIO) -> None: - """ Run apt-get update (with various options limiting the amount of + """Run apt-get update (with various options limiting the amount of data donwloaded) in the overlay where the apt configuration was previously deployed. The output of apt-get (stdout + stderr) will be written to the output parameter. Raises a AptConfigCheckError exception if the apt-get command exited - with non-zero. """ + with non-zero.""" assert self.configured_tree is not None pfx = pathlib.Path(self.configured_tree.p()) @@ -173,15 +177,15 @@ class AptConfigurer: # Need to ensure the "partial" directory exists. partial_dir = apt_dirs["State::Lists"] / "partial" - partial_dir.mkdir( - parents=True, exist_ok=True) + partial_dir.mkdir(parents=True, exist_ok=True) try: shutil.chown(partial_dir, user="_apt") except (PermissionError, LookupError) as exc: log.warning("could to set owner of file %s: %r", partial_dir, exc) apt_cmd = [ - "apt-get", "update", + "apt-get", + "update", "-oAPT::Update::Error-Mode=any", # Workaround because the default sandbox user (i.e., _apt) does not # have access to the overlay. @@ -201,8 +205,9 @@ class AptConfigurer: env["APT_CONFIG"] = config_file.name config_file.write(apt_config.dump()) config_file.flush() - proc = await astart_command(apt_cmd, stderr=subprocess.STDOUT, - clean_locale=False, env=env) + proc = await astart_command( + apt_cmd, stderr=subprocess.STDOUT, clean_locale=False, env=env + ) async def _reader(): while not proc.stdout.at_eof(): @@ -223,42 +228,51 @@ class AptConfigurer: async def configure_for_install(self, context): assert self.configured_tree is not None - self.install_tree = await self.mounter.setup_overlay( - [self.configured_tree]) + self.install_tree = await self.mounter.setup_overlay([self.configured_tree]) - os.mkdir(self.install_tree.p('cdrom')) - await self.mounter.mount( - '/cdrom', self.install_tree.p('cdrom'), options='bind') + os.mkdir(self.install_tree.p("cdrom")) + await self.mounter.mount("/cdrom", self.install_tree.p("cdrom"), options="bind") if self.app.base_model.network.has_network: os.rename( - self.install_tree.p('etc/apt/sources.list'), - self.install_tree.p('etc/apt/sources.list.d/original.list')) + self.install_tree.p("etc/apt/sources.list"), + self.install_tree.p("etc/apt/sources.list.d/original.list"), + ) else: - proxy_path = self.install_tree.p( - 'etc/apt/apt.conf.d/90curtin-aptproxy') + proxy_path = self.install_tree.p("etc/apt/apt.conf.d/90curtin-aptproxy") if os.path.exists(proxy_path): os.unlink(proxy_path) - codename = lsb_release(dry_run=self.app.opts.dry_run)['codename'] + codename = lsb_release(dry_run=self.app.opts.dry_run)["codename"] write_file( - self.install_tree.p('etc/apt/sources.list'), - f'deb [check-date=no] file:///cdrom {codename} main restricted\n') + self.install_tree.p("etc/apt/sources.list"), + f"deb [check-date=no] file:///cdrom {codename} main restricted\n", + ) await run_curtin_command( - self.app, context, "in-target", "-t", self.install_tree.p(), - "--", "apt-get", "update", private_mounts=True) + self.app, + context, + "in-target", + "-t", + self.install_tree.p(), + "--", + "apt-get", + "update", + private_mounts=True, + ) return self.install_tree.p() @contextlib.asynccontextmanager async def overlay(self): - overlay = await self.mounter.setup_overlay([ + overlay = await self.mounter.setup_overlay( + [ self.install_tree.upperdir, self.configured_tree.upperdir, self.source_path, - ]) + ] + ) try: yield overlay finally: @@ -269,8 +283,7 @@ class AptConfigurer: # compares equal to the one we discarded earlier. # But really, there should be better ways to handle this. try: - await self.mounter.unmount( - Mountpoint(mountpoint=overlay.mountpoint)) + await self.mounter.unmount(Mountpoint(mountpoint=overlay.mountpoint)) except subprocess.CalledProcessError as exc: raise OverlayCleanupError from exc @@ -286,41 +299,53 @@ class AptConfigurer: async def _restore_dir(dir): shutil.rmtree(target_mnt.p(dir)) - await self.app.command_runner.run([ - 'cp', '-aT', self.configured_tree.p(dir), target_mnt.p(dir), - ]) + await self.app.command_runner.run( + [ + "cp", + "-aT", + self.configured_tree.p(dir), + target_mnt.p(dir), + ] + ) def _restore_file(path: str) -> None: shutil.copyfile(self.configured_tree.p(path), target_mnt.p(path)) # The file only exists if we are online with contextlib.suppress(FileNotFoundError): - os.unlink(target_mnt.p('etc/apt/sources.list.d/original.list')) - _restore_file('etc/apt/sources.list') + os.unlink(target_mnt.p("etc/apt/sources.list.d/original.list")) + _restore_file("etc/apt/sources.list") with contextlib.suppress(FileNotFoundError): - _restore_file('etc/apt/apt.conf.d/90curtin-aptproxy') + _restore_file("etc/apt/apt.conf.d/90curtin-aptproxy") if self.app.base_model.network.has_network: await run_curtin_command( - self.app, context, "in-target", "-t", target_mnt.p(), - "--", "apt-get", "update", private_mounts=True) + self.app, + context, + "in-target", + "-t", + target_mnt.p(), + "--", + "apt-get", + "update", + private_mounts=True, + ) else: - await _restore_dir('var/lib/apt/lists') + await _restore_dir("var/lib/apt/lists") await self.cleanup() try: - d = target_mnt.p('cdrom') + d = target_mnt.p("cdrom") os.rmdir(d) except OSError as ose: - log.warning(f'failed to rmdir {d}: {ose}') + log.warning(f"failed to rmdir {d}: {ose}") async def setup_target(self, context, target: str): # Call this after the rootfs has been extracted to the real target # system but before any configuration is applied to it. target_mnt = Mountpoint(mountpoint=target) - await self.mounter.mount( - '/cdrom', target_mnt.p('cdrom'), options='bind') + await self.mounter.mount("/cdrom", target_mnt.p("cdrom"), options="bind") class DryRunAptConfigurer(AptConfigurer): @@ -339,8 +364,8 @@ class DryRunAptConfigurer(AptConfigurer): await self.cleanup() def get_mirror_check_strategy(self, url: str) -> "MirrorCheckStrategy": - """ For a given mirror URL, return the strategy that we should use to - perform mirror checking. """ + """For a given mirror URL, return the strategy that we should use to + perform mirror checking.""" for known in self.app.dr_cfg.apt_mirrors_known: if "url" in known: if known["url"] != url: @@ -354,16 +379,18 @@ class DryRunAptConfigurer(AptConfigurer): return self.MirrorCheckStrategy(known["strategy"]) return self.MirrorCheckStrategy( - self.app.dr_cfg.apt_mirror_check_default_strategy) + self.app.dr_cfg.apt_mirror_check_default_strategy + ) async def apt_config_check_failure(self, output: io.StringIO) -> None: - """ Pretend that the execution of the apt-get update command results in - a failure. """ + """Pretend that the execution of the apt-get update command results in + a failure.""" url = self.app.base_model.mirror.primary_staged.uri release = lsb_release(dry_run=True)["codename"] host = url.split("/")[2] - output.write(f"""\ + output.write( + f"""\ Ign:1 {url} {release} InRelease Ign:2 {url} {release}-updates InRelease Ign:3 {url} {release}-backports InRelease @@ -389,28 +416,31 @@ E: Failed to fetch {url}/dists/{release}-security/InRelease\ Temporary failure resolving '{host}' E: Some index files failed to download. They have been ignored, or old ones used instead. -""") +""" + ) raise AptConfigCheckError async def apt_config_check_success(self, output: io.StringIO) -> None: - """ Pretend that the execution of the apt-get update command results in - a success. """ + """Pretend that the execution of the apt-get update command results in + a success.""" url = self.app.base_model.mirror.primary_staged.uri release = lsb_release(dry_run=True)["codename"] - output.write(f"""\ + output.write( + f"""\ Get:1 {url} {release} InRelease [267 kB] Get:2 {url} {release}-updates InRelease [109 kB] Get:3 {url} {release}-backports InRelease [99.9 kB] Get:4 {url} {release}-security InRelease [109 kB] Fetched 585 kB in 1s (1057 kB/s) Reading package lists... -""") +""" + ) async def run_apt_config_check(self, output: io.StringIO) -> None: - """ Dry-run implementation of the Apt config check. + """Dry-run implementation of the Apt config check. The strategy used is based on the URL of the primary mirror. The - apt-get command can either run on the host or be faked entirely. """ + apt-get command can either run on the host or be faked entirely.""" assert self.configured_tree is not None failure = self.apt_config_check_failure diff --git a/subiquity/server/contract_selection.py b/subiquity/server/contract_selection.py index 52a46f21..cf162e16 100644 --- a/subiquity/server/contract_selection.py +++ b/subiquity/server/contract_selection.py @@ -16,40 +16,38 @@ import asyncio import logging -from typing import Any, List import time +from typing import Any, List -from subiquity.server.ubuntu_advantage import ( - UAInterface, - APIUsageError, - ) - +from subiquity.server.ubuntu_advantage import APIUsageError, UAInterface from subiquitycore.async_helpers import schedule_task - log = logging.getLogger("subiquity.server.contract_selection") class UPCSExpiredError(Exception): - """ Exception to be raised when a contract selection expired. """ + """Exception to be raised when a contract selection expired.""" class UnknownError(Exception): - """ Exception to be raised in case of unexpected error. """ + """Exception to be raised in case of unexpected error.""" + def __init__(self, message: str = "", errors: List[Any] = []) -> None: self.errors = errors super().__init__(message) class ContractSelection: - """ Represents an already initiated contract selection. """ + """Represents an already initiated contract selection.""" + def __init__( - self, - client: UAInterface, - magic_token: str, - user_code: str, - validity_seconds: int) -> None: - """ Initialize the contract selection. """ + self, + client: UAInterface, + magic_token: str, + user_code: str, + validity_seconds: int, + ) -> None: + """Initialize the contract selection.""" self.client = client self.magic_token = magic_token self.user_code = user_code @@ -57,30 +55,29 @@ class ContractSelection: self.task = asyncio.create_task(self._run_polling()) @classmethod - async def initiate(cls, client: UAInterface) \ - -> "ContractSelection": - """ Initiate a contract selection and return a ContractSelection - request object. """ + async def initiate(cls, client: UAInterface) -> "ContractSelection": + """Initiate a contract selection and return a ContractSelection + request object.""" answer = await client.strategy.magic_initiate_v1() if answer["result"] != "success": raise UnknownError(errors=answer["errors"]) return cls( - client=client, - magic_token=answer["data"]["attributes"]["token"], - validity_seconds=answer["data"]["attributes"]["expires_in"], - user_code=answer["data"]["attributes"]["user_code"]) + client=client, + magic_token=answer["data"]["attributes"]["token"], + validity_seconds=answer["data"]["attributes"]["expires_in"], + user_code=answer["data"]["attributes"]["user_code"], + ) async def _run_polling(self) -> str: - """ Runs the polling and eventually return a contract token. """ + """Runs the polling and eventually return a contract token.""" # Wait an initial 30 seconds before sending the first request, then # send requests at regular interval every 10 seconds. await asyncio.sleep(30 / self.client.strategy.scale_factor) start_time = time.monotonic() - answer = await self.client.strategy.magic_wait_v1( - magic_token=self.magic_token) + answer = await self.client.strategy.magic_wait_v1(magic_token=self.magic_token) call_duration = time.monotonic() - start_time @@ -89,26 +86,28 @@ class ContractSelection: exception = UnknownError() for error in answer["errors"]: - if error["code"] == "magic-attach-token-error" and \ - call_duration >= 60 / self.client.strategy.scale_factor: + if ( + error["code"] == "magic-attach-token-error" + and call_duration >= 60 / self.client.strategy.scale_factor + ): # Assume it's a timeout if it lasted more than 1 minute. exception = UPCSExpiredError(error["title"]) else: - log.warning("magic_wait_v1: %s: %s", - error["code"], error["title"]) + log.warning("magic_wait_v1: %s: %s", error["code"], error["title"]) raise exception def cancel(self): - """ Cancel the polling task and asynchronously delete the associated - resource. """ + """Cancel the polling task and asynchronously delete the associated + resource.""" self.task.cancel() async def delete_resource() -> None: - """ Release the resource on the server. """ + """Release the resource on the server.""" try: answer = await self.client.strategy.magic_revoke_v1( - magic_token=self.magic_token) + magic_token=self.magic_token + ) except APIUsageError as e: log.warning("failed to revoke magic-token: %r", e) @@ -117,7 +116,8 @@ class ContractSelection: log.debug("successfully revoked magic-token") return for error in answer["errors"]: - log.warning("magic_revoke_v1: %s: %s", - error["code"], error["title"]) + log.warning( + "magic_revoke_v1: %s: %s", error["code"], error["title"] + ) schedule_task(delete_resource()) diff --git a/subiquity/server/controller.py b/subiquity/server/controller.py index 06def369..4662f187 100644 --- a/subiquity/server/controller.py +++ b/subiquity/server/controller.py @@ -20,19 +20,15 @@ from typing import Any, Optional import jsonschema -from subiquitycore.context import with_context -from subiquitycore.controller import ( - BaseController, - ) - from subiquity.common.api.server import bind from subiquity.server.types import InstallerChannels +from subiquitycore.context import with_context +from subiquitycore.controller import BaseController log = logging.getLogger("subiquity.server.controller") class SubiquityController(BaseController): - autoinstall_key: Optional[str] = None autoinstall_schema: Any = None autoinstall_default: Any = None @@ -48,10 +44,9 @@ class SubiquityController(BaseController): def __init__(self, app): super().__init__(app) - self.context.set('controller', self) + self.context.set("controller", self) if self.interactive_for_variants is not None: - self.app.hub.subscribe( - InstallerChannels.INSTALL_CONFIRMED, self._confirmed) + self.app.hub.subscribe(InstallerChannels.INSTALL_CONFIRMED, self._confirmed) async def _confirmed(self): variant = self.app.base_model.source.current.variant @@ -102,8 +97,7 @@ class SubiquityController(BaseController): def interactive(self): if not self.app.autoinstall_config: return self._active - i_sections = self.app.autoinstall_config.get( - 'interactive-sections', []) + i_sections = self.app.autoinstall_config.get("interactive-sections", []) if "*" in i_sections: return self._active @@ -111,21 +105,23 @@ class SubiquityController(BaseController): if self.autoinstall_key in i_sections: return self._active - if self.autoinstall_key_alias is not None \ - and self.autoinstall_key_alias in i_sections: + if ( + self.autoinstall_key_alias is not None + and self.autoinstall_key_alias in i_sections + ): return self._active async def configured(self): - """Let the world know that this controller's model is now configured. - """ - with open(self.app.state_path('states', self.name), 'w') as fp: + """Let the world know that this controller's model is now configured.""" + with open(self.app.state_path("states", self.name), "w") as fp: json.dump(self.serialize(), fp) if self.model_name is not None: await self.app.hub.abroadcast( - (InstallerChannels.CONFIGURED, self.model_name)) + (InstallerChannels.CONFIGURED, self.model_name) + ) def load_state(self): - state_path = self.app.state_path('states', self.name) + state_path = self.app.state_path("states", self.name) if not os.path.exists(state_path): return with open(state_path) as fp: @@ -143,6 +139,5 @@ class SubiquityController(BaseController): class NonInteractiveController(SubiquityController): - def interactive(self): return False diff --git a/subiquity/server/controllers/__init__.py b/subiquity/server/controllers/__init__.py index d10bbdfa..8508d9bc 100644 --- a/subiquity/server/controllers/__init__.py +++ b/subiquity/server/controllers/__init__.py @@ -14,7 +14,7 @@ # along with this program. If not, see . from .ad import AdController -from .cmdlist import EarlyController, LateController, ErrorController +from .cmdlist import EarlyController, ErrorController, LateController from .codecs import CodecsController from .debconf import DebconfController from .drivers import DriversController @@ -22,8 +22,8 @@ from .filesystem import FilesystemController from .identity import IdentityController from .install import InstallController from .integrity import IntegrityController -from .keyboard import KeyboardController from .kernel import KernelController +from .keyboard import KeyboardController from .locale import LocaleController from .mirror import MirrorController from .network import NetworkController @@ -43,34 +43,34 @@ from .userdata import UserdataController from .zdev import ZdevController __all__ = [ - 'AdController', - 'CodecsController', - 'DebconfController', - 'DriversController', - 'EarlyController', - 'ErrorController', - 'FilesystemController', - 'IdentityController', - 'IntegrityController', - 'InstallController', - 'KernelController', - 'KeyboardController', - 'LateController', - 'LocaleController', - 'MirrorController', - 'NetworkController', - 'OEMController', - 'PackageController', - 'ProxyController', - 'RefreshController', - 'ReportingController', - 'ShutdownController', - 'SnapListController', - 'SourceController', - 'SSHController', - 'TimeZoneController', - 'UbuntuProController', - 'UpdatesController', - 'UserdataController', - 'ZdevController', + "AdController", + "CodecsController", + "DebconfController", + "DriversController", + "EarlyController", + "ErrorController", + "FilesystemController", + "IdentityController", + "IntegrityController", + "InstallController", + "KernelController", + "KeyboardController", + "LateController", + "LocaleController", + "MirrorController", + "NetworkController", + "OEMController", + "PackageController", + "ProxyController", + "RefreshController", + "ReportingController", + "ShutdownController", + "SnapListController", + "SourceController", + "SSHController", + "TimeZoneController", + "UbuntuProController", + "UpdatesController", + "UserdataController", + "ZdevController", ] diff --git a/subiquity/server/controllers/ad.py b/subiquity/server/controllers/ad.py index daf8606f..8f1ffb08 100644 --- a/subiquity/server/controllers/ad.py +++ b/subiquity/server/controllers/ad.py @@ -17,23 +17,23 @@ import logging import os import re from typing import List, Optional, Set -from subiquitycore.utils import arun_command -from subiquitycore.async_helpers import run_bg_task from subiquity.common.apidef import API from subiquity.common.types import ( - AdConnectionInfo, AdAdminNameValidation, + AdConnectionInfo, AdDomainNameValidation, AdJoinResult, - AdPasswordValidation + AdPasswordValidation, ) from subiquity.server.ad_joiner import AdJoiner from subiquity.server.controller import SubiquityController +from subiquitycore.async_helpers import run_bg_task +from subiquitycore.utils import arun_command -log = logging.getLogger('subiquity.server.controllers.ad') +log = logging.getLogger("subiquity.server.controllers.ad") -DC_RE = r'^[a-zA-Z0-9.-]+$' +DC_RE = r"^[a-zA-Z0-9.-]+$" # Validation has changed since Ubiquity. See the following links on why: # https://bugs.launchpad.net/ubuntu-mate/+bug/1985971 # https://github.com/canonical/subiquity/pull/1553#discussion_r1103063195 @@ -58,26 +58,27 @@ class DcPingStrategy: return AdDomainNameValidation.OK async def discover(self) -> str: - """ Attempts to discover a domain through the network. - Returns the domain or an empty string on error. """ + """Attempts to discover a domain through the network. + Returns the domain or an empty string on error.""" cp = await arun_command([self.cmd, self.arg], env={}) discovered = "" if cp.returncode == 0: # A typical output looks like: # 'creative.com\n type: kerberos\n realm-name: CREATIVE.COM\n...' - discovered = cp.stdout.split('\n')[0].strip() + discovered = cp.stdout.split("\n")[0].strip() return discovered class StubDcPingStrategy(DcPingStrategy): - """ For testing purpose. This class doesn't talk to the network. - Instead its response follows the following rule: + """For testing purpose. This class doesn't talk to the network. + Instead its response follows the following rule: + + - addresses starting with "r" return REALM_NOT_FOUND; + - addresses starting with any other letter return OK.""" - - addresses starting with "r" return REALM_NOT_FOUND; - - addresses starting with any other letter return OK. """ async def ping(self, address: str) -> AdDomainNameValidation: - if address[0] == 'r': + if address[0] == "r": return AdDomainNameValidation.REALM_NOT_FOUND return AdDomainNameValidation.OK @@ -90,38 +91,40 @@ class StubDcPingStrategy(DcPingStrategy): class AdController(SubiquityController): - """ Implements the server part of the Active Directory feature. """ + """Implements the server part of the Active Directory feature.""" + endpoint = API.active_directory # No auto install key and schema for now due password handling uncertainty. autoinstall_key = "active-directory" model_name = "active_directory" autoinstall_schema = { - 'type': 'object', - 'properties': { - 'admin-name': { - 'type': 'string', + "type": "object", + "properties": { + "admin-name": { + "type": "string", }, - 'domain-name': { - 'type': 'string', + "domain-name": { + "type": "string", }, }, - 'additionalProperties': False, + "additionalProperties": False, } - autoinstall_default = {"admin-name": '', 'domain-name': ''} + autoinstall_default = {"admin-name": "", "domain-name": ""} def make_autoinstall(self): info = self.model.conn_info if info is None: return None - return {'admin-name': info.admin_name, 'domain-name': info.domain_name} + return {"admin-name": info.admin_name, "domain-name": info.domain_name} def load_autoinstall_data(self, data): if data is None: return - if 'admin-name' in data and 'domain-name' in data: - info = AdConnectionInfo(admin_name=data['admin-name'], - domain_name=data['domain-name']) + if "admin-name" in data and "domain-name" in data: + info = AdConnectionInfo( + admin_name=data["admin-name"], domain_name=data["domain-name"] + ) self.model.set(info) self.model.do_join = False @@ -157,56 +160,56 @@ class AdController(SubiquityController): return self.model.conn_info async def POST(self, data: AdConnectionInfo) -> None: - """ Configures this controller with the supplied info. - Clients are required to validate the info before POST'ing """ + """Configures this controller with the supplied info. + Clients are required to validate the info before POST'ing""" self.model.set(data) await self.configured() - async def check_admin_name_POST(self, admin_name: str) \ - -> AdAdminNameValidation: + async def check_admin_name_POST(self, admin_name: str) -> AdAdminNameValidation: return AdValidators.admin_user_name(admin_name) - async def check_domain_name_POST(self, domain_name: str) \ - -> List[AdDomainNameValidation]: + async def check_domain_name_POST( + self, domain_name: str + ) -> List[AdDomainNameValidation]: result = AdValidators.domain_name(domain_name) return list(result) - async def ping_domain_controller_POST(self, domain_name: str) \ - -> AdDomainNameValidation: - return await AdValidators.ping_domain_controller(domain_name, - self.ping_strgy) + async def ping_domain_controller_POST( + self, domain_name: str + ) -> AdDomainNameValidation: + return await AdValidators.ping_domain_controller(domain_name, self.ping_strgy) async def check_password_POST(self, password: str) -> AdPasswordValidation: return AdValidators.password(password) async def has_support_GET(self) -> bool: - """ Returns True if the executables required - to configure AD are present in the live system.""" + """Returns True if the executables required + to configure AD are present in the live system.""" return self.ping_strgy.has_support() async def join_result_GET(self, wait: bool = True) -> AdJoinResult: - """ If [wait] is True and the model is set for joining, this method - blocks until an attempt to join a domain completes. - Otherwise returns the current known state. - Most likely it will be AdJoinResult.UNKNOWN. """ + """If [wait] is True and the model is set for joining, this method + blocks until an attempt to join a domain completes. + Otherwise returns the current known state. + Most likely it will be AdJoinResult.UNKNOWN.""" if wait and self.model.do_join: self.join_result = await self.ad_joiner.join_result() return self.join_result async def join_domain(self, hostname: str, context) -> None: - """ To be called from the install controller if the user requested - joining an AD domain """ - await self.ad_joiner.join_domain(self.model.conn_info, hostname, - context) + """To be called from the install controller if the user requested + joining an AD domain""" + await self.ad_joiner.join_domain(self.model.conn_info, hostname, context) # Helper out-of-class functions grouped. class AdValidators: - """ Groups functions that validates the AD info supplied by users. """ + """Groups functions that validates the AD info supplied by users.""" + @staticmethod def admin_user_name(name) -> AdAdminNameValidation: - """ Validates the supplied admin name against known patterns. """ + """Validates the supplied admin name against known patterns.""" if len(name) == 0: log.debug("admin name is empty") @@ -215,15 +218,14 @@ class AdValidators: # Triggers error if any of the forbidden chars is present. regex = re.compile(f"[{re.escape(AD_ACCOUNT_FORBIDDEN_CHARS)}]") if regex.search(name): - log.debug('<%s>: domain admin name contains invalid characters', - name) + log.debug("<%s>: domain admin name contains invalid characters", name) return AdAdminNameValidation.INVALID_CHARS return AdAdminNameValidation.OK @staticmethod def password(value: str) -> AdPasswordValidation: - """ Validates that the password is not empty. """ + """Validates that the password is not empty.""" if not value: return AdPasswordValidation.EMPTY @@ -232,7 +234,7 @@ class AdValidators: @staticmethod def domain_name(name: str) -> Set[AdDomainNameValidation]: - """ Check the correctness of a proposed host name. + """Check the correctness of a proposed host name. Returns a set of the possible errors: - OK = self explanatory. Should be the only result then. @@ -256,34 +258,30 @@ class AdValidators: log.debug("<%s>: %s too long", name, FIELD) result.add(AdDomainNameValidation.TOO_LONG) - if name.startswith('-'): - log.debug("<%s>: %s cannot start with hyphens (-)", - name, FIELD) + if name.startswith("-"): + log.debug("<%s>: %s cannot start with hyphens (-)", name, FIELD) result.add(AdDomainNameValidation.START_HYPHEN) - if name.endswith('-'): - log.debug("<%s>: %s cannot end with hyphens (-)", - name, FIELD) + if name.endswith("-"): + log.debug("<%s>: %s cannot end with hyphens (-)", name, FIELD) result.add(AdDomainNameValidation.END_HYPHEN) - if '..' in name: - log.debug('<%s>: %s cannot contain double dots (..)', name, FIELD) + if ".." in name: + log.debug("<%s>: %s cannot contain double dots (..)", name, FIELD) result.add(AdDomainNameValidation.MULTIPLE_DOTS) - if name.startswith('.'): - log.debug('<%s>: %s cannot start with dots (.)', - name, FIELD) + if name.startswith("."): + log.debug("<%s>: %s cannot start with dots (.)", name, FIELD) result.add(AdDomainNameValidation.START_DOT) - if name.endswith('.'): - log.debug('<%s>: %s cannot end with dots (.)', - name, FIELD) + if name.endswith("."): + log.debug("<%s>: %s cannot end with dots (.)", name, FIELD) result.add(AdDomainNameValidation.END_DOT) regex = re.compile(DC_RE) if not regex.search(name): result.add(AdDomainNameValidation.INVALID_CHARS) - log.debug('<%s>: %s contains invalid characters', name, FIELD) + log.debug("<%s>: %s contains invalid characters", name, FIELD) if result: return result @@ -291,10 +289,11 @@ class AdValidators: return {AdDomainNameValidation.OK} @staticmethod - async def ping_domain_controller(name: str, strategy: DcPingStrategy) \ - -> AdDomainNameValidation: - """ Attempts to find the specified DC in the network. - Returns either OK, EMPTY or REALM_NOT_FOUND. """ + async def ping_domain_controller( + name: str, strategy: DcPingStrategy + ) -> AdDomainNameValidation: + """Attempts to find the specified DC in the network. + Returns either OK, EMPTY or REALM_NOT_FOUND.""" if not name: return AdDomainNameValidation.EMPTY diff --git a/subiquity/server/controllers/cmdlist.py b/subiquity/server/controllers/cmdlist.py index d6149438..b363e5e5 100644 --- a/subiquity/server/controllers/cmdlist.py +++ b/subiquity/server/controllers/cmdlist.py @@ -21,44 +21,43 @@ from typing import List, Sequence, Union import attr from systemd import journal +from subiquity.common.types import ApplicationState +from subiquity.server.controller import NonInteractiveController from subiquitycore.async_helpers import run_bg_task from subiquitycore.context import with_context from subiquitycore.utils import arun_command -from subiquity.common.types import ApplicationState -from subiquity.server.controller import NonInteractiveController - @attr.s(auto_attribs=True) class Command: - """ Represents a command, specified either as a list of arguments or as a - single string. """ + """Represents a command, specified either as a list of arguments or as a + single string.""" + args: Union[str, Sequence[str]] check: bool def desc(self) -> str: - """ Return a user-friendly representation of the command. """ + """Return a user-friendly representation of the command.""" if isinstance(self.args, str): return self.args return shlex.join(self.args) def as_args_list(self) -> List[str]: - """ Return the command as a list of arguments. """ + """Return the command as a list of arguments.""" if isinstance(self.args, str): - return ['sh', '-c', self.args] + return ["sh", "-c", self.args] return list(self.args) class CmdListController(NonInteractiveController): - autoinstall_default = [] autoinstall_schema = { - 'type': 'array', - 'items': { - 'type': ['string', 'array'], - 'items': {'type': 'string'}, - }, - } + "type": "array", + "items": { + "type": ["string", "array"], + "items": {"type": "string"}, + }, + } builtin_cmds: Sequence[Command] = () cmds: Sequence[Command] = () cmd_check = True @@ -82,22 +81,20 @@ class CmdListController(NonInteractiveController): with context.child("command_{}".format(i), desc): args = cmd.as_args_list() if self.syslog_id: - journal.send( - " running " + desc, SYSLOG_IDENTIFIER=self.syslog_id) + journal.send(" running " + desc, SYSLOG_IDENTIFIER=self.syslog_id) args = [ - 'systemd-cat', '--level-prefix=false', - '--identifier=' + self.syslog_id, - ] + args + "systemd-cat", + "--level-prefix=false", + "--identifier=" + self.syslog_id, + ] + args await arun_command( - args, env=env, - stdin=None, stdout=None, stderr=None, - check=cmd.check) + args, env=env, stdin=None, stdout=None, stderr=None, check=cmd.check + ) self.run_event.set() class EarlyController(CmdListController): - - autoinstall_key = 'early-commands' + autoinstall_key = "early-commands" def __init__(self, app): super().__init__(app) @@ -105,8 +102,7 @@ class EarlyController(CmdListController): class LateController(CmdListController): - - autoinstall_key = 'late-commands' + autoinstall_key = "late-commands" def __init__(self, app): super().__init__(app) @@ -130,7 +126,7 @@ class LateController(CmdListController): def env(self): env = super().env() if self.app.base_model.target is not None: - env['TARGET_MOUNT_POINT'] = self.app.base_model.target + env["TARGET_MOUNT_POINT"] = self.app.base_model.target return env def start(self): @@ -144,8 +140,7 @@ class LateController(CmdListController): class ErrorController(CmdListController): - - autoinstall_key = 'error-commands' + autoinstall_key = "error-commands" cmd_check = False @with_context() diff --git a/subiquity/server/controllers/codecs.py b/subiquity/server/controllers/codecs.py index 312b1078..81f21754 100644 --- a/subiquity/server/controllers/codecs.py +++ b/subiquity/server/controllers/codecs.py @@ -19,19 +19,18 @@ from subiquity.common.apidef import API from subiquity.common.types import CodecsData from subiquity.server.controller import SubiquityController -log = logging.getLogger('subiquity.server.controllers.codecs') +log = logging.getLogger("subiquity.server.controllers.codecs") class CodecsController(SubiquityController): - endpoint = API.codecs autoinstall_key = model_name = "codecs" autoinstall_schema = { - 'type': 'object', - 'properties': { - 'install': { - 'type': 'boolean', + "type": "object", + "properties": { + "install": { + "type": "boolean", }, }, } diff --git a/subiquity/server/controllers/debconf.py b/subiquity/server/controllers/debconf.py index 251f765e..2e383410 100644 --- a/subiquity/server/controllers/debconf.py +++ b/subiquity/server/controllers/debconf.py @@ -17,13 +17,12 @@ from subiquity.server.controller import NonInteractiveController class DebconfController(NonInteractiveController): - model_name = "debconf_selections" autoinstall_key = "debconf-selections" autoinstall_default = "" autoinstall_schema = { - 'type': 'string', - } + "type": "string", + } def load_autoinstall_data(self, data): self.model.selections = data diff --git a/subiquity/server/controllers/drivers.py b/subiquity/server/controllers/drivers.py index 21660355..a0f24431 100644 --- a/subiquity/server/controllers/drivers.py +++ b/subiquity/server/controllers/drivers.py @@ -17,8 +17,6 @@ import asyncio import logging from typing import List, Optional -from subiquitycore.context import with_context - from subiquity.common.apidef import API from subiquity.common.types import DriversPayload, DriversResponse from subiquity.server.apt import OverlayCleanupError @@ -28,21 +26,21 @@ from subiquity.server.ubuntu_drivers import ( CommandNotFoundError, UbuntuDriversInterface, get_ubuntu_drivers_interface, - ) +) +from subiquitycore.context import with_context -log = logging.getLogger('subiquity.server.controllers.drivers') +log = logging.getLogger("subiquity.server.controllers.drivers") class DriversController(SubiquityController): - endpoint = API.drivers autoinstall_key = model_name = "drivers" autoinstall_schema = { - 'type': 'object', - 'properties': { - 'install': { - 'type': 'boolean', + "type": "object", + "properties": { + "install": { + "type": "boolean", }, }, } @@ -70,17 +68,15 @@ class DriversController(SubiquityController): def start(self): self._wait_apt = asyncio.Event() + self.app.hub.subscribe(InstallerChannels.APT_CONFIGURED, self._wait_apt.set) self.app.hub.subscribe( - InstallerChannels.APT_CONFIGURED, - self._wait_apt.set) - self.app.hub.subscribe( - (InstallerChannels.CONFIGURED, "source"), - self.restart_querying_drivers_list) + (InstallerChannels.CONFIGURED, "source"), self.restart_querying_drivers_list + ) def restart_querying_drivers_list(self): - """ Start querying the list of available drivers. This method can be + """Start querying the list of available drivers. This method can be invoked multiple times so we need to stop ongoing operations if the - variant changes. """ + variant changes.""" self.ubuntu_drivers = get_ubuntu_drivers_interface(self.app) @@ -114,8 +110,8 @@ class DriversController(SubiquityController): self.drivers = [] else: self.drivers = await self.ubuntu_drivers.list_drivers( - root_dir=d.mountpoint, - context=context) + root_dir=d.mountpoint, context=context + ) except OverlayCleanupError: log.exception("Failed to cleanup overlay. Continuing anyway.") self.list_drivers_done_event.set() @@ -128,10 +124,12 @@ class DriversController(SubiquityController): search_drivers = self.app.controllers.Source.model.search_drivers - return DriversResponse(install=self.model.do_install, - drivers=self.drivers, - local_only=local_only, - search_drivers=search_drivers) + return DriversResponse( + install=self.model.do_install, + drivers=self.drivers, + local_only=local_only, + search_drivers=search_drivers, + ) async def POST(self, data: DriversPayload) -> None: self.model.do_install = data.install diff --git a/subiquity/server/controllers/filesystem.py b/subiquity/server/controllers/filesystem.py index 4d951394..c959c1ac 100644 --- a/subiquity/server/controllers/filesystem.py +++ b/subiquity/server/controllers/filesystem.py @@ -25,38 +25,15 @@ import time from typing import Any, Dict, List, Optional import attr - +import pyudev from curtin.commands.extract import AbstractSourceHandler from curtin.storage_config import ptable_part_type_to_flag -import pyudev - -from subiquitycore.async_helpers import ( - schedule_task, - SingleInstanceTask, - TaskAlreadyRunningError, - ) -from subiquitycore.context import with_context -from subiquitycore.utils import ( - arun_command, - run_command, - ) -from subiquitycore.lsb_release import lsb_release - from subiquity.common.apidef import API from subiquity.common.errorreport import ErrorReportKind -from subiquity.common.filesystem.actions import ( - DeviceAction, - ) -from subiquity.common.filesystem import ( - boot, - gaps, - labels, - sizes, -) -from subiquity.common.filesystem.manipulator import ( - FilesystemManipulator, -) +from subiquity.common.filesystem import boot, gaps, labels, sizes +from subiquity.common.filesystem.actions import DeviceAction +from subiquity.common.filesystem.manipulator import FilesystemManipulator from subiquity.common.types import ( AddPartitionV2, Bootloader, @@ -77,40 +54,42 @@ from subiquity.common.types import ( SizingPolicy, StorageResponse, StorageResponseV2, - ) +) from subiquity.models.filesystem import ( + LVM_CHUNK_SIZE, ActionRenderMode, ArbitraryDevice, - align_up, - align_down, - _Device, - Disk as ModelDisk, - MiB, - Partition as ModelPartition, - LVM_CHUNK_SIZE, - Raid, - ) -from subiquity.server.controller import ( - SubiquityController, - ) +) +from subiquity.models.filesystem import Disk as ModelDisk +from subiquity.models.filesystem import MiB +from subiquity.models.filesystem import Partition as ModelPartition +from subiquity.models.filesystem import Raid, _Device, align_down, align_up from subiquity.server import snapdapi +from subiquity.server.controller import SubiquityController from subiquity.server.mounter import Mounter from subiquity.server.snapdapi import ( StorageEncryptionSupport, StorageSafety, SystemDetails, - ) +) from subiquity.server.types import InstallerChannels - +from subiquitycore.async_helpers import ( + SingleInstanceTask, + TaskAlreadyRunningError, + schedule_task, +) +from subiquitycore.context import with_context +from subiquitycore.lsb_release import lsb_release +from subiquitycore.utils import arun_command, run_command log = logging.getLogger("subiquity.server.controllers.filesystem") -block_discover_log = logging.getLogger('block-discover') +block_discover_log = logging.getLogger("block-discover") # for translators: 'reason' is the reason FDE is unavailable. system_defective_encryption_text = _( "TPM backed full-disk encryption is not available " - "on this device (the reason given was \"{reason}\")." + 'on this device (the reason given was "{reason}").' ) system_multiple_volumes_text = _( @@ -124,7 +103,7 @@ system_non_gpt_text = _( ) -DRY_RUN_RESET_SIZE = 500*MiB +DRY_RUN_RESET_SIZE = 500 * MiB class NoSnapdSystemsOnSource(Exception): @@ -167,21 +146,25 @@ class VariationInfo: return bool(self.capability_info.allowed) def capability_info_for_gap( - self, - gap: gaps.Gap, - ) -> CapabilityInfo: + self, + gap: gaps.Gap, + ) -> CapabilityInfo: if self.label is None: install_min = sizes.calculate_suggested_install_min( - self.min_size, gap.device.alignment_data().part_align) + self.min_size, gap.device.alignment_data().part_align + ) else: install_min = self.min_size r = CapabilityInfo() r.disallowed = list(self.capability_info.disallowed) if self.capability_info.allowed and gap.size < install_min: for capability in self.capability_info.allowed: - r.disallowed.append(GuidedDisallowedCapability( - capability=capability, - reason=GuidedDisallowedCapabilityReason.TOO_SMALL)) + r.disallowed.append( + GuidedDisallowedCapability( + capability=capability, + reason=GuidedDisallowedCapabilityReason.TOO_SMALL, + ) + ) else: r.allowed = list(self.capability_info.allowed) return r @@ -198,15 +181,16 @@ class VariationInfo: GuidedCapability.LVM, GuidedCapability.LVM_LUKS, GuidedCapability.ZFS, - ])) + ] + ), + ) class FilesystemController(SubiquityController, FilesystemManipulator): - endpoint = API.storage autoinstall_key = "storage" - autoinstall_schema = {'type': 'object'} # ... + autoinstall_schema = {"type": "object"} # ... model_name = "filesystem" _configured = False @@ -222,16 +206,18 @@ class FilesystemController(SubiquityController, FilesystemManipulator): self._monitor = None self._errors = {} self._probe_once_task = SingleInstanceTask( - self._probe_once, propagate_errors=False) + self._probe_once, propagate_errors=False + ) self._probe_task = SingleInstanceTask( - self._probe, propagate_errors=False, cancel_restart=False) + self._probe, propagate_errors=False, cancel_restart=False + ) self._examine_systems_task = SingleInstanceTask(self._examine_systems) self.supports_resilient_boot = False self.app.hub.subscribe( - (InstallerChannels.CONFIGURED, 'source'), - self._examine_systems_task.start_sync) - self.app.hub.subscribe( - InstallerChannels.PRE_SHUTDOWN, self._pre_shutdown) + (InstallerChannels.CONFIGURED, "source"), + self._examine_systems_task.start_sync, + ) + self.app.hub.subscribe(InstallerChannels.PRE_SHUTDOWN, self._pre_shutdown) self._variation_info: Dict[str, VariationInfo] = {} self._info: Optional[VariationInfo] = None self._on_volume: Optional[snapdapi.OnVolume] = None @@ -264,10 +250,9 @@ class FilesystemController(SubiquityController, FilesystemManipulator): self.stop_listening_udev() async def _mount_systems_dir(self, variation_name): - self._source_handler = \ - self.app.controllers.Source.get_handler(variation_name) + self._source_handler = self.app.controllers.Source.get_handler(variation_name) source_path = self._source_handler.setup() - cur_systems_dir = '/var/lib/snapd/seed/systems' + cur_systems_dir = "/var/lib/snapd/seed/systems" source_systems_dir = os.path.join(source_path, cur_systems_dir[1:]) if self.app.opts.dry_run: systems_dir_exists = self.app.dr_cfg.systems_dir_exists @@ -278,13 +263,13 @@ class FilesystemController(SubiquityController, FilesystemManipulator): self._system_mounter = Mounter(self.app) if not self.app.opts.dry_run: await self._system_mounter.bind_mount_tree( - source_systems_dir, cur_systems_dir) + source_systems_dir, cur_systems_dir + ) - cur_snaps_dir = '/var/lib/snapd/seed/snaps' + cur_snaps_dir = "/var/lib/snapd/seed/snaps" source_snaps_dir = os.path.join(source_path, cur_snaps_dir[1:]) if not self.app.opts.dry_run: - await self._system_mounter.bind_mount_tree( - source_snaps_dir, cur_snaps_dir) + await self._system_mounter.bind_mount_tree(source_snaps_dir, cur_snaps_dir) async def _unmount_systems_dir(self): if self._system_mounter is not None: @@ -308,23 +293,19 @@ class FilesystemController(SubiquityController, FilesystemManipulator): def info_for_system(self, name: str, label: str, system: SystemDetails): if len(system.volumes) > 1: - log.error( - "Skipping uninstallable system: %s", - system_multiple_volumes_text) + log.error("Skipping uninstallable system: %s", system_multiple_volumes_text) return None [volume] = system.volumes.values() - if volume.schema != 'gpt': - log.error( - "Skipping uninstallable system: %s", - system_non_gpt_text) + if volume.schema != "gpt": + log.error("Skipping uninstallable system: %s", system_non_gpt_text) return None info = VariationInfo( name=name, label=label, system=system, - ) + ) def disallowed_encryption(msg): GCDR = GuidedDisallowedCapabilityReason @@ -332,40 +313,40 @@ class FilesystemController(SubiquityController, FilesystemManipulator): return GuidedDisallowedCapability( capability=GuidedCapability.CORE_BOOT_ENCRYPTED, reason=reason, - message=msg) + message=msg, + ) se = system.storage_encryption if se.support == StorageEncryptionSupport.DEFECTIVE: info.capability_info.disallowed = [ - disallowed_encryption(se.unavailable_reason)] + disallowed_encryption(se.unavailable_reason) + ] return info - offsets_and_sizes = list( - self._offsets_and_sizes_for_volume(volume)) + offsets_and_sizes = list(self._offsets_and_sizes_for_volume(volume)) _structure, last_offset, last_size = offsets_and_sizes[-1] info.min_size = last_offset + last_size if se.support == StorageEncryptionSupport.DISABLED: - info.capability_info.allowed = [ - GuidedCapability.CORE_BOOT_UNENCRYPTED] - msg = _( - "TPM backed full-disk encryption has been disabled") + info.capability_info.allowed = [GuidedCapability.CORE_BOOT_UNENCRYPTED] + msg = _("TPM backed full-disk encryption has been disabled") info.capability_info.disallowed = [disallowed_encryption(msg)] elif se.support == StorageEncryptionSupport.UNAVAILABLE: - info.capability_info.allowed = [ - GuidedCapability.CORE_BOOT_UNENCRYPTED] + info.capability_info.allowed = [GuidedCapability.CORE_BOOT_UNENCRYPTED] info.capability_info.disallowed = [ - disallowed_encryption(se.unavailable_reason)] + disallowed_encryption(se.unavailable_reason) + ] else: if se.storage_safety == StorageSafety.ENCRYPTED: - info.capability_info.allowed = [ - GuidedCapability.CORE_BOOT_ENCRYPTED] + info.capability_info.allowed = [GuidedCapability.CORE_BOOT_ENCRYPTED] elif se.storage_safety == StorageSafety.PREFER_ENCRYPTED: info.capability_info.allowed = [ - GuidedCapability.CORE_BOOT_PREFER_ENCRYPTED] + GuidedCapability.CORE_BOOT_PREFER_ENCRYPTED + ] elif se.storage_safety == StorageSafety.PREFER_UNENCRYPTED: info.capability_info.allowed = [ - GuidedCapability.CORE_BOOT_PREFER_UNENCRYPTED] + GuidedCapability.CORE_BOOT_PREFER_UNENCRYPTED + ] return info @@ -383,7 +364,8 @@ class FilesystemController(SubiquityController, FilesystemManipulator): self._variation_info[name] = info else: self._variation_info[name] = VariationInfo.classic( - name=name, min_size=variation.size) + name=name, min_size=variation.size + ) @with_context() async def apply_autoinstall_config(self, context=None): @@ -396,20 +378,22 @@ class FilesystemController(SubiquityController, FilesystemManipulator): raise self._errors[True][0] if self.ai_data is None: # If there are any classic variations, we default that. - if any(not variation.is_core_boot_classic() - for variation in self._variation_info.values() - if variation.is_valid()): + if any( + not variation.is_core_boot_classic() + for variation in self._variation_info.values() + if variation.is_valid() + ): self.ai_data = { - 'layout': { - 'name': 'lvm', - }, - } + "layout": { + "name": "lvm", + }, + } else: self.ai_data = { - 'layout': { - 'name': 'hybrid', - }, - } + "layout": { + "name": "hybrid", + }, + } await self.convert_autoinstall_config(context=context) if self.reset_partition_only: return @@ -417,8 +401,8 @@ class FilesystemController(SubiquityController, FilesystemManipulator): raise Exception("autoinstall config did not mount root") if self.model.needs_bootloader_partition(): raise Exception( - "autoinstall config did not create needed bootloader " - "partition") + "autoinstall config did not create needed bootloader " "partition" + ) def update_devices(self, device_map): for action in self.model._actions: @@ -438,18 +422,18 @@ class FilesystemController(SubiquityController, FilesystemManipulator): part_align = device.alignment_data().part_align bootfs_size = align_up(sizes.get_bootfs_size(gap.size), part_align) gap_boot, gap_rest = gap.split(bootfs_size) - spec = dict(fstype="ext4", mount='/boot') + spec = dict(fstype="ext4", mount="/boot") self.create_partition(device, gap_boot, spec) part = self.create_partition(device, gap_rest, dict(fstype=None)) - vg_name = 'ubuntu-vg' + vg_name = "ubuntu-vg" i = 0 - while self.model._one(type='lvm_volgroup', name=vg_name) is not None: + while self.model._one(type="lvm_volgroup", name=vg_name) is not None: i += 1 - vg_name = 'ubuntu-vg-{}'.format(i) + vg_name = "ubuntu-vg-{}".format(i) spec = dict(name=vg_name, devices=set([part])) if choice.password is not None: - spec['passphrase'] = choice.password + spec["passphrase"] = choice.password vg = self.create_volgroup(spec) if choice.sizing_policy == SizingPolicy.SCALED: lv_size = sizes.scaled_rootfs_size(vg.size) @@ -457,15 +441,17 @@ class FilesystemController(SubiquityController, FilesystemManipulator): elif choice.sizing_policy == SizingPolicy.ALL: lv_size = vg.size else: - raise Exception(f'Unhandled size policy {choice.sizing_policy}') - log.debug(f'lv_size {lv_size} for {choice.sizing_policy}') + raise Exception(f"Unhandled size policy {choice.sizing_policy}") + log.debug(f"lv_size {lv_size} for {choice.sizing_policy}") self.create_logical_volume( - vg=vg, spec=dict( + vg=vg, + spec=dict( size=lv_size, name="ubuntu-lv", fstype="ext4", mount="/", - )) + ), + ) def guided_zfs(self, gap, choice: GuidedChoiceV2): device = gap.device @@ -476,19 +462,19 @@ class FilesystemController(SubiquityController, FilesystemManipulator): bpool_part = self.create_partition(device, gap_boot, dict(fstype=None)) rpool_part = self.create_partition(device, gap_rest, dict(fstype=None)) - self.create_zpool(bpool_part, 'bpool', '/boot') - self.create_zpool(rpool_part, 'rpool', '/') + self.create_zpool(bpool_part, "bpool", "/boot") + self.create_zpool(rpool_part, "rpool", "/") @functools.singledispatchmethod - def start_guided(self, target: GuidedStorageTarget, - disk: ModelDisk) -> gaps.Gap: + def start_guided(self, target: GuidedStorageTarget, disk: ModelDisk) -> gaps.Gap: """Setup changes to the disk to prepare the gap that we will be doing a guided install into.""" raise NotImplementedError(target) @start_guided.register - def start_guided_reformat(self, target: GuidedStorageTargetReformat, - disk: ModelDisk) -> gaps.Gap: + def start_guided_reformat( + self, target: GuidedStorageTargetReformat, disk: ModelDisk + ) -> gaps.Gap: """Perform the reformat, and return the resulting gap.""" in_use_parts = [p for p in disk.partitions() if p._is_in_use] if in_use_parts: @@ -496,25 +482,27 @@ class FilesystemController(SubiquityController, FilesystemManipulator): if not p._is_in_use: self.delete_partition(p) else: - self.reformat(disk, wipe='superblock-recursive') + self.reformat(disk, wipe="superblock-recursive") return gaps.largest_gap(disk) @start_guided.register - def start_guided_use_gap(self, target: GuidedStorageTargetUseGap, - disk: ModelDisk) -> gaps.Gap: + def start_guided_use_gap( + self, target: GuidedStorageTargetUseGap, disk: ModelDisk + ) -> gaps.Gap: """Lookup the matching model gap.""" return gaps.at_offset(disk, target.gap.offset) @start_guided.register - def start_guided_resize(self, target: GuidedStorageTargetResize, - disk: ModelDisk) -> gaps.Gap: + def start_guided_resize( + self, target: GuidedStorageTargetResize, disk: ModelDisk + ) -> gaps.Gap: """Perform the resize of the target partition, and return the resulting gap.""" partition = self.get_partition(disk, target.partition_number) part_align = disk.alignment_data().part_align new_size = align_up(target.new_size, part_align) if new_size > partition.size: - raise Exception(f'Aligned requested size {new_size} too large') + raise Exception(f"Aligned requested size {new_size} too large") partition.size = new_size partition.resize = True # Calculating where that gap will be can be tricky due to alignment @@ -523,7 +511,7 @@ class FilesystemController(SubiquityController, FilesystemManipulator): gap = gaps.after(disk, partition.offset) if gap is None: pgs = gaps.parts_and_gaps(disk) - raise Exception(f'gap not found after resize, pgs={pgs}') + raise Exception(f"gap not found after resize, pgs={pgs}") return gap def set_info_for_capability(self, capability: GuidedCapability): @@ -535,14 +523,14 @@ class FilesystemController(SubiquityController, FilesystemManipulator): GuidedCapability.CORE_BOOT_ENCRYPTED, GuidedCapability.CORE_BOOT_PREFER_ENCRYPTED, GuidedCapability.CORE_BOOT_PREFER_UNENCRYPTED, - } + } elif capability == GuidedCapability.CORE_BOOT_UNENCRYPTED: # Similar if the request is for uncrypted caps = { GuidedCapability.CORE_BOOT_UNENCRYPTED, GuidedCapability.CORE_BOOT_PREFER_ENCRYPTED, GuidedCapability.CORE_BOOT_PREFER_UNENCRYPTED, - } + } else: # Otherwise, just look for what we were asked for. caps = {capability} @@ -550,14 +538,11 @@ class FilesystemController(SubiquityController, FilesystemManipulator): if caps & set(info.capability_info.allowed): self._info = info return - raise Exception( - "could not find variation for {}".format(capability)) + raise Exception("could not find variation for {}".format(capability)) async def guided( - self, - choice: GuidedChoiceV2, - reset_partition_only: bool = False - ) -> None: + self, choice: GuidedChoiceV2, reset_partition_only: bool = False + ) -> None: if choice.capability == GuidedCapability.MANUAL: return @@ -569,8 +554,7 @@ class FilesystemController(SubiquityController, FilesystemManipulator): if self.is_core_boot_classic(): assert isinstance(choice.target, GuidedStorageTargetReformat) - self.use_tpm = ( - choice.capability == GuidedCapability.CORE_BOOT_ENCRYPTED) + self.use_tpm = choice.capability == GuidedCapability.CORE_BOOT_ENCRYPTED await self.guided_core_boot(disk) return @@ -580,22 +564,25 @@ class FilesystemController(SubiquityController, FilesystemManipulator): # find what's left of the gap after adding boot gap = gap.within() if gap is None: - raise Exception('failed to locate gap after adding boot') + raise Exception("failed to locate gap after adding boot") if choice.reset_partition: if self.app.opts.dry_run: reset_size = DRY_RUN_RESET_SIZE else: - cp = await arun_command(['du', '-sb', '/cdrom']) + cp = await arun_command(["du", "-sb", "/cdrom"]) reset_size = int(cp.stdout.strip().split()[0]) reset_size = align_up(int(reset_size * 1.10), 256 * MiB) reset_gap, gap = gap.split(reset_size) self.reset_partition = self.create_partition( - device=reset_gap.device, gap=reset_gap, - spec={'fstype': 'fat32'}, flag='msftres') + device=reset_gap.device, + gap=reset_gap, + spec={"fstype": "fat32"}, + flag="msftres", + ) self.reset_partition_only = reset_partition_only if reset_partition_only: - for mount in self.model._all(type='mount'): + for mount in self.model._all(type="mount"): self.delete_mount(mount) self.model.target = self.app.base_model.target = None return @@ -607,7 +594,7 @@ class FilesystemController(SubiquityController, FilesystemManipulator): elif choice.capability == GuidedCapability.DIRECT: self.guided_direct(gap) else: - raise ValueError('cannot process capability') + raise ValueError("cannot process capability") async def _probe_response(self, wait, resp_cls): if not self._probe_task.done(): @@ -618,8 +605,8 @@ class FilesystemController(SubiquityController, FilesystemManipulator): return resp_cls(status=ProbeStatus.PROBING) if True in self._errors: return resp_cls( - status=ProbeStatus.FAILED, - error_report=self._errors[True][1].ref()) + status=ProbeStatus.FAILED, error_report=self._errors[True][1].ref() + ) if not self._examine_systems_task.done(): if wait: await self._examine_systems_task.wait() @@ -640,11 +627,13 @@ class FilesystemController(SubiquityController, FilesystemManipulator): error_report=self.full_probe_error(), orig_config=self.model._orig_config, config=self.model._render_actions(mode=ActionRenderMode.FOR_API), - dasd=self.model._probe_data.get('dasd', {}), - storage_version=self.model.storage_version) + dasd=self.model._probe_data.get("dasd", {}), + storage_version=self.model.storage_version, + ) - async def GET(self, wait: bool = False, use_cached_result: bool = False) \ - -> StorageResponse: + async def GET( + self, wait: bool = False, use_cached_result: bool = False + ) -> StorageResponse: if not use_cached_result: probe_resp = await self._probe_response(wait, StorageResponse) if probe_resp is not None: @@ -653,29 +642,31 @@ class FilesystemController(SubiquityController, FilesystemManipulator): async def POST(self, config: list): log.debug(config) - self.model._actions = \ - self.model._actions_from_config( - config, blockdevs=self.model._probe_data['blockdev'], - is_probe_data=False) + self.model._actions = self.model._actions_from_config( + config, blockdevs=self.model._probe_data["blockdev"], is_probe_data=False + ) await self.configured() def potential_boot_disks(self, check_boot=True, with_reformatting=False): disks = [] - for raid in self.model._all(type='raid'): + for raid in self.model._all(type="raid"): if check_boot and not boot.can_be_boot_device( - raid, with_reformatting=with_reformatting): + raid, with_reformatting=with_reformatting + ): continue disks.append(raid) - for disk in self.model._all(type='disk'): + for disk in self.model._all(type="disk"): if check_boot and not boot.can_be_boot_device( - disk, with_reformatting=with_reformatting): + disk, with_reformatting=with_reformatting + ): continue cd = disk.constructed_device() if isinstance(cd, Raid): can_be_boot = False for v in cd._subvolumes: if check_boot and boot.can_be_boot_device( - v, with_reformatting=with_reformatting): + v, with_reformatting=with_reformatting + ): can_be_boot = True if can_be_boot: continue @@ -683,7 +674,7 @@ class FilesystemController(SubiquityController, FilesystemManipulator): return [d for d in disks] def _offsets_and_sizes_for_volume(self, volume): - offset = self.model._partition_alignment_data['gpt'].min_start_offset + offset = self.model._partition_alignment_data["gpt"].min_start_offset for structure in volume.structure: if structure.role == snapdapi.Role.MBR: continue @@ -708,10 +699,11 @@ class FilesystemController(SubiquityController, FilesystemManipulator): else: parts_by_offset_size = { (part.offset, part.size): part for part in disk.partitions() - } + } for _struct, offset, size in self._offsets_and_sizes_for_volume( - self._on_volume): + self._on_volume + ): if (offset, size) in parts_by_offset_size: preserved_parts.add(parts_by_offset_size[(offset, size)]) @@ -724,16 +716,20 @@ class FilesystemController(SubiquityController, FilesystemManipulator): self.reformat(disk, self._on_volume.schema) for structure, offset, size in self._offsets_and_sizes_for_volume( - self._on_volume): + self._on_volume + ): if (offset, size) in parts_by_offset_size: part = parts_by_offset_size[(offset, size)] else: - if structure.role == snapdapi.Role.SYSTEM_DATA and \ - structure == self._on_volume.structure[-1]: + if ( + structure.role == snapdapi.Role.SYSTEM_DATA + and structure == self._on_volume.structure[-1] + ): gap = gaps.at_offset(disk, offset) size = gap.size part = self.model.add_partition( - disk, offset=offset, size=size, check_alignment=False) + disk, offset=offset, size=size, check_alignment=False + ) type_uuid = structure.gpt_part_type_uuid() if type_uuid: @@ -742,17 +738,18 @@ class FilesystemController(SubiquityController, FilesystemManipulator): if structure.name: part.partition_name = structure.name if structure.filesystem: - part.wipe = 'superblock' + part.wipe = "superblock" self.delete_filesystem(part.fs()) fs = self.model.add_filesystem( - part, structure.filesystem, label=structure.label) + part, structure.filesystem, label=structure.label + ) if structure.role == snapdapi.Role.SYSTEM_DATA: - self.model.add_mount(fs, '/') + self.model.add_mount(fs, "/") elif structure.role == snapdapi.Role.SYSTEM_BOOT: - self.model.add_mount(fs, '/boot') - elif part.flag == 'boot': + self.model.add_mount(fs, "/boot") + elif part.flag == "boot": part.grub_device = True - self.model.add_mount(fs, '/boot/efi') + self.model.add_mount(fs, "/boot/efi") if structure.role != snapdapi.Role.NONE: self._role_to_device[structure.role] = part self._device_to_structure[part] = structure @@ -778,13 +775,15 @@ class FilesystemController(SubiquityController, FilesystemManipulator): snapdapi.SystemActionRequest( action=snapdapi.SystemAction.INSTALL, step=snapdapi.SystemActionStep.SETUP_STORAGE_ENCRYPTION, - on_volumes=self._on_volumes())) - role_to_encrypted_device = result['encrypted-devices'] + on_volumes=self._on_volumes(), + ), + ) + role_to_encrypted_device = result["encrypted-devices"] for role, enc_path in role_to_encrypted_device.items(): arb_device = ArbitraryDevice(m=self.model, path=enc_path) self.model._actions.append(arb_device) part = self._role_to_device[role] - for fs in self.model._all(type='format'): + for fs in self.model._all(type="format"): if fs.volume == part: fs.volume = arb_device @@ -797,7 +796,9 @@ class FilesystemController(SubiquityController, FilesystemManipulator): snapdapi.SystemActionRequest( action=snapdapi.SystemAction.INSTALL, step=snapdapi.SystemActionStep.FINISH, - on_volumes=self._on_volumes())) + on_volumes=self._on_volumes(), + ), + ) async def guided_POST(self, data: GuidedChoiceV2) -> StorageResponse: log.debug(data) @@ -812,15 +813,15 @@ class FilesystemController(SubiquityController, FilesystemManipulator): return await self.GET(context) async def has_rst_GET(self) -> bool: - search = '/sys/module/ahci/drivers/pci:ahci/*/remapped_nvme' + search = "/sys/module/ahci/drivers/pci:ahci/*/remapped_nvme" for remapped_nvme in glob.glob(search): - with open(remapped_nvme, 'r') as f: + with open(remapped_nvme, "r") as f: if int(f.read()) > 0: return True return False async def has_bitlocker_GET(self) -> List[Disk]: - '''list of Disks that contain a partition that is BitLockered''' + """list of Disks that contain a partition that is BitLockered""" bitlockered_disks = [] for disk in self.model.all_disks(): for part in disk.partitions(): @@ -837,16 +838,16 @@ class FilesystemController(SubiquityController, FilesystemManipulator): for p in disk.partitions(): if p.number == number: return p - raise ValueError(f'Partition {number} on {disk.id} not found') + raise ValueError(f"Partition {number} on {disk.id} not found") def calculate_suggested_install_min(self): catalog_entry = self.app.base_model.source.current source_min = max( - variation.size - for variation in catalog_entry.variations.values() - ) - align = max((pa.part_align - for pa in self.model._partition_alignment_data.values())) + variation.size for variation in catalog_entry.variations.values() + ) + align = max( + (pa.part_align for pa in self.model._partition_alignment_data.values()) + ) return sizes.calculate_suggested_install_min(source_min, align) async def get_v2_storage_response(self, model, wait, include_raid): @@ -856,23 +857,22 @@ class FilesystemController(SubiquityController, FilesystemManipulator): if include_raid: disks = self.potential_boot_disks(with_reformatting=True) else: - disks = model._all(type='disk') + disks = model._all(type="disk") minsize = self.calculate_suggested_install_min() return StorageResponseV2( - status=ProbeStatus.DONE, - disks=[labels.for_client(d) for d in disks], - need_root=not model.is_root_mounted(), - need_boot=model.needs_bootloader_partition(), - install_minimum_size=minsize, - ) + status=ProbeStatus.DONE, + disks=[labels.for_client(d) for d in disks], + need_root=not model.is_root_mounted(), + need_boot=model.needs_bootloader_partition(), + install_minimum_size=minsize, + ) async def v2_GET( - self, - wait: bool = False, - include_raid: bool = False, - ) -> StorageResponseV2: - return await self.get_v2_storage_response( - self.model, wait, include_raid) + self, + wait: bool = False, + include_raid: bool = False, + ) -> StorageResponseV2: + return await self.get_v2_storage_response(self.model, wait, include_raid) async def v2_POST(self) -> StorageResponseV2: await self.configured() @@ -909,8 +909,7 @@ class FilesystemController(SubiquityController, FilesystemManipulator): classic_capabilities.update(info.capability_info.allowed) return sorted(classic_capabilities, key=lambda x: x.name) - async def v2_guided_GET(self, wait: bool = False) \ - -> GuidedStorageResponseV2: + async def v2_guided_GET(self, wait: bool = False) -> GuidedStorageResponseV2: """Acquire a list of possible guided storage configuration scenarios. Results are sorted by the size of the space potentially available to the install.""" @@ -934,15 +933,16 @@ class FilesystemController(SubiquityController, FilesystemManipulator): reformat = GuidedStorageTargetReformat( disk_id=disk.id, allowed=capability_info.allowed, - disallowed=capability_info.disallowed) + disallowed=capability_info.disallowed, + ) scenarios.append((disk.size, reformat)) for disk in self.potential_boot_disks(with_reformatting=False): parts = [ p for p in disk.partitions() - if p.flag != 'bios_grub' and not p._is_in_use - ] + if p.flag != "bios_grub" and not p._is_in_use + ] if len(parts) < 1: # On an (essentially) empty disk, don't bother to offer it # with UseGap, as it's basically the same as the Reformat @@ -959,83 +959,87 @@ class FilesystemController(SubiquityController, FilesystemManipulator): capability_info.combine(variation.capability_info_for_gap(gap)) api_gap = labels.for_client(gap) use_gap = GuidedStorageTargetUseGap( - disk_id=disk.id, gap=api_gap, + disk_id=disk.id, + gap=api_gap, allowed=capability_info.allowed, - disallowed=capability_info.disallowed) + disallowed=capability_info.disallowed, + ) scenarios.append((gap.size, use_gap)) for disk in self.potential_boot_disks(check_boot=False): part_align = disk.alignment_data().part_align for partition in disk.partitions(): vals = sizes.calculate_guided_resize( - partition.estimated_min_size, partition.size, - install_min, part_align=part_align) + partition.estimated_min_size, + partition.size, + install_min, + part_align=part_align, + ) if vals is None: # Return a reason here continue if not boot.can_be_boot_device( - disk, resize_partition=partition, - with_reformatting=False): + disk, resize_partition=partition, with_reformatting=False + ): # Return a reason here continue resize = GuidedStorageTargetResize.from_recommendations( - partition, vals, allowed=classic_capabilities) + partition, vals, allowed=classic_capabilities + ) scenarios.append((vals.install_max, resize)) scenarios.sort(reverse=True, key=lambda x: x[0]) return GuidedStorageResponseV2( - status=ProbeStatus.DONE, - configured=self.model.guided_configuration, - targets=[s[1] for s in scenarios]) + status=ProbeStatus.DONE, + configured=self.model.guided_configuration, + targets=[s[1] for s in scenarios], + ) - async def v2_guided_POST(self, data: GuidedChoiceV2) \ - -> GuidedStorageResponseV2: + async def v2_guided_POST(self, data: GuidedChoiceV2) -> GuidedStorageResponseV2: log.debug(data) self.locked_probe_data = True await self.guided(data) return await self.v2_guided_GET() - async def v2_reformat_disk_POST(self, data: ReformatDisk) \ - -> StorageResponseV2: + async def v2_reformat_disk_POST(self, data: ReformatDisk) -> StorageResponseV2: self.locked_probe_data = True self.reformat(self.model._one(id=data.disk_id), data.ptable) return await self.v2_GET() - async def v2_add_boot_partition_POST(self, disk_id: str) \ - -> StorageResponseV2: + async def v2_add_boot_partition_POST(self, disk_id: str) -> StorageResponseV2: self.locked_probe_data = True disk = self.model._one(id=disk_id) if boot.is_boot_device(disk): - raise ValueError('device already has bootloader partition') + raise ValueError("device already has bootloader partition") if DeviceAction.TOGGLE_BOOT not in DeviceAction.supported(disk): raise ValueError("disk does not support boot partiton") self.add_boot_disk(disk) return await self.v2_GET() - async def v2_add_partition_POST(self, data: AddPartitionV2) \ - -> StorageResponseV2: + async def v2_add_partition_POST(self, data: AddPartitionV2) -> StorageResponseV2: log.debug(data) self.locked_probe_data = True if data.partition.boot is not None: - raise ValueError('add_partition does not support changing boot') + raise ValueError("add_partition does not support changing boot") disk = self.model._one(id=data.disk_id) requested_size = data.partition.size or 0 if requested_size > data.gap.size: - raise ValueError('new partition too large') + raise ValueError("new partition too large") if requested_size < 1: requested_size = data.gap.size spec = { - 'size': requested_size, - 'fstype': data.partition.format, - 'mount': data.partition.mount, + "size": requested_size, + "fstype": data.partition.format, + "mount": data.partition.mount, } gap = gaps.at_offset(disk, data.gap.offset).split(requested_size)[0] - self.create_partition(disk, gap, spec, wipe='superblock') + self.create_partition(disk, gap, spec, wipe="superblock") return await self.v2_GET() - async def v2_delete_partition_POST(self, data: ModifyPartitionV2) \ - -> StorageResponseV2: + async def v2_delete_partition_POST( + self, data: ModifyPartitionV2 + ) -> StorageResponseV2: log.debug(data) self.locked_probe_data = True disk = self.model._one(id=data.disk_id) @@ -1043,28 +1047,29 @@ class FilesystemController(SubiquityController, FilesystemManipulator): self.delete_partition(partition) return await self.v2_GET() - async def v2_edit_partition_POST(self, data: ModifyPartitionV2) \ - -> StorageResponseV2: + async def v2_edit_partition_POST( + self, data: ModifyPartitionV2 + ) -> StorageResponseV2: log.debug(data) self.locked_probe_data = True disk = self.model._one(id=data.disk_id) partition = self.get_partition(disk, data.partition.number) - if data.partition.size not in (None, partition.size) \ - and self.app.opts.storage_version < 2: - raise ValueError('edit_partition does not support changing size') - if data.partition.boot is not None \ - and data.partition.boot != partition.boot: - raise ValueError('edit_partition does not support changing boot') - spec = {'mount': data.partition.mount or partition.mount} + if ( + data.partition.size not in (None, partition.size) + and self.app.opts.storage_version < 2 + ): + raise ValueError("edit_partition does not support changing size") + if data.partition.boot is not None and data.partition.boot != partition.boot: + raise ValueError("edit_partition does not support changing boot") + spec = {"mount": data.partition.mount or partition.mount} if data.partition.format is not None: if data.partition.format != partition.original_fstype(): if data.partition.wipe is None: - raise ValueError( - 'changing partition format requires a wipe value') - spec['fstype'] = data.partition.format + raise ValueError("changing partition format requires a wipe value") + spec["fstype"] = data.partition.format if data.partition.size is not None: - spec['size'] = data.partition.size - spec['wipe'] = data.partition.wipe + spec["size"] = data.partition.size + spec["wipe"] = data.partition.wipe self.partition_disk_handler(disk, spec, partition=partition) return await self.v2_GET() @@ -1077,17 +1082,17 @@ class FilesystemController(SubiquityController, FilesystemManipulator): await self._probe_task.task - @with_context(name='probe_once', description='restricted={restricted}') + @with_context(name="probe_once", description="restricted={restricted}") async def _probe_once(self, *, context, restricted): if restricted: - probe_types = {'blockdev', 'filesystem'} - fname = 'probe-data-restricted.json' + probe_types = {"blockdev", "filesystem"} + fname = "probe-data-restricted.json" key = "ProbeDataRestricted" else: - probe_types = {'defaults', 'filesystem_sizing'} + probe_types = {"defaults", "filesystem_sizing"} if self.app.opts.use_os_prober: - probe_types |= {'os'} - fname = 'probe-data.json' + probe_types |= {"os"} + fname = "probe-data.json" key = "ProbeData" storage = await self.app.prober.get_storage(probe_types) # It is possible for the user to submit filesystem config @@ -1097,7 +1102,7 @@ class FilesystemController(SubiquityController, FilesystemManipulator): if self._configured: return fpath = os.path.join(self.app.block_log_dir, fname) - with open(fpath, 'w') as fp: + with open(fpath, "w") as fp: json.dump(storage, fp, indent=4) self.app.note_file_for_apport(key, fpath) if not self.locked_probe_data: @@ -1109,14 +1114,15 @@ class FilesystemController(SubiquityController, FilesystemManipulator): @with_context() async def _probe(self, *, context=None): self._errors = {} - for (restricted, kind, short_label) in [ - (False, ErrorReportKind.BLOCK_PROBE_FAIL, "block"), - (True, ErrorReportKind.DISK_PROBE_FAIL, "disk"), - ]: + for restricted, kind, short_label in [ + (False, ErrorReportKind.BLOCK_PROBE_FAIL, "block"), + (True, ErrorReportKind.DISK_PROBE_FAIL, "disk"), + ]: try: start = time.time() await self._probe_once_task.start( - context=context, restricted=restricted) + context=context, restricted=restricted + ) # We wait on the task directly here, not # self._probe_once_task.wait as if _probe_once_task # gets cancelled, we should be cancelled too. @@ -1127,22 +1133,23 @@ class FilesystemController(SubiquityController, FilesystemManipulator): raise except Exception as exc: block_discover_log.exception( - "block probing failed restricted=%s", restricted) + "block probing failed restricted=%s", restricted + ) report = self.app.make_apport_report(kind, "block probing") if report is not None: self._errors[restricted] = (exc, report) continue finally: elapsed = time.time() - start - log.debug(f'{short_label} probing took {elapsed:.1f} seconds') + log.debug(f"{short_label} probing took {elapsed:.1f} seconds") break async def run_autoinstall_guided(self, layout): - name = layout['name'] + name = layout["name"] password = None sizing_policy = None - if name == 'hybrid': + if name == "hybrid": # this check is conceptually unnecessary but results in a # much cleaner error message... core_boot_caps = set() @@ -1153,30 +1160,32 @@ class FilesystemController(SubiquityController, FilesystemManipulator): core_boot_caps.update(variation.capability_info.allowed) if not core_boot_caps: raise Exception( - "can only use name: hybrid when installing core boot " - "classic") - if 'mode' in layout: - raise Exception( - "cannot use 'mode' when installing core boot classic") - encrypted = layout.get('encrypted', None) + "can only use name: hybrid when installing core boot " "classic" + ) + if "mode" in layout: + raise Exception("cannot use 'mode' when installing core boot classic") + encrypted = layout.get("encrypted", None) GC = GuidedCapability if encrypted is None: - if GC.CORE_BOOT_ENCRYPTED in core_boot_caps or \ - GC.CORE_BOOT_PREFER_ENCRYPTED in core_boot_caps: + if ( + GC.CORE_BOOT_ENCRYPTED in core_boot_caps + or GC.CORE_BOOT_PREFER_ENCRYPTED in core_boot_caps + ): capability = GC.CORE_BOOT_ENCRYPTED else: capability = GC.CORE_BOOT_UNENCRYPTED elif encrypted: capability = GC.CORE_BOOT_ENCRYPTED else: - if core_boot_caps == { - GuidedCapability.CORE_BOOT_ENCRYPTED} and \ - not encrypted: + if ( + core_boot_caps == {GuidedCapability.CORE_BOOT_ENCRYPTED} + and not encrypted + ): raise Exception("cannot install this model unencrypted") capability = GC.CORE_BOOT_UNENCRYPTED - match = layout.get("match", {'size': 'largest'}) + match = layout.get("match", {"size": "largest"}) disk = self.model.disk_for_match(self.model.all_disks(), match) - mode = 'reformat_disk' + mode = "reformat_disk" else: # this check is conceptually unnecessary but results in a # much cleaner error message... @@ -1187,61 +1196,75 @@ class FilesystemController(SubiquityController, FilesystemManipulator): break else: raise Exception( - "must use name: hybrid when installing core boot " - "classic") - mode = layout.get('mode', 'reformat_disk') + "must use name: hybrid when installing core boot " "classic" + ) + mode = layout.get("mode", "reformat_disk") self.validate_layout_mode(mode) - password = layout.get('password', None) - if name == 'lvm': + password = layout.get("password", None) + if name == "lvm": sizing_policy = SizingPolicy.from_string( - layout.get('sizing-policy', None)) + layout.get("sizing-policy", None) + ) if password is not None: capability = GuidedCapability.LVM_LUKS else: capability = GuidedCapability.LVM - elif name == 'zfs': + elif name == "zfs": capability = GuidedCapability.ZFS else: capability = GuidedCapability.DIRECT - if mode == 'reformat_disk': - match = layout.get("match", {'size': 'largest'}) + if mode == "reformat_disk": + match = layout.get("match", {"size": "largest"}) disk = self.model.disk_for_match(self.model.all_disks(), match) - target = GuidedStorageTargetReformat( - disk_id=disk.id, allowed=[]) - elif mode == 'use_gap': - bootable = [d for d in self.model.all_disks() - if boot.can_be_boot_device(d, with_reformatting=False)] + target = GuidedStorageTargetReformat(disk_id=disk.id, allowed=[]) + elif mode == "use_gap": + bootable = [ + d + for d in self.model.all_disks() + if boot.can_be_boot_device(d, with_reformatting=False) + ] gap = gaps.largest_gap(bootable) if not gap: - raise Exception("autoinstall cannot configure storage " - "- no gap found large enough for install") + raise Exception( + "autoinstall cannot configure storage " + "- no gap found large enough for install" + ) target = GuidedStorageTargetUseGap( - disk_id=gap.device.id, gap=gap, allowed=[]) + disk_id=gap.device.id, gap=gap, allowed=[] + ) - log.info(f'autoinstall: running guided {capability} install in ' - f'mode {mode} using {target}') + log.info( + f"autoinstall: running guided {capability} install in " + f"mode {mode} using {target}" + ) await self.guided( - GuidedChoiceV2( - target=target, capability=capability, - password=password, sizing_policy=sizing_policy, - reset_partition=layout.get('reset-partition', False)), - reset_partition_only=layout.get('reset-partition-only', False)) + GuidedChoiceV2( + target=target, + capability=capability, + password=password, + sizing_policy=sizing_policy, + reset_partition=layout.get("reset-partition", False), + ), + reset_partition_only=layout.get("reset-partition-only", False), + ) def validate_layout_mode(self, mode): - if mode not in ('reformat_disk', 'use_gap'): - raise ValueError(f'Unknown layout mode {mode}') + if mode not in ("reformat_disk", "use_gap"): + raise ValueError(f"Unknown layout mode {mode}") @with_context() async def convert_autoinstall_config(self, context=None): # Log disabled to prevent LUKS password leak # log.debug("self.ai_data = %s", self.ai_data) - if 'layout' in self.ai_data: - if 'config' in self.ai_data: - log.warning("The 'storage' section should not contain both " - "'layout' and 'config', using 'layout'") - await self.run_autoinstall_guided(self.ai_data['layout']) - elif 'config' in self.ai_data: + if "layout" in self.ai_data: + if "config" in self.ai_data: + log.warning( + "The 'storage' section should not contain both " + "'layout' and 'config', using 'layout'" + ) + await self.run_autoinstall_guided(self.ai_data["layout"]) + elif "config" in self.ai_data: # XXX in principle should there be a way to influence the # variation chosen here? Not with current use cases for # variations anyway. @@ -1253,23 +1276,24 @@ class FilesystemController(SubiquityController, FilesystemManipulator): break else: raise Exception( - "must not use config: when installing core boot classic") - self.model.apply_autoinstall_config(self.ai_data['config']) - self.model.swap = self.ai_data.get('swap') - self.model.grub = self.ai_data.get('grub') + "must not use config: when installing core boot classic" + ) + self.model.apply_autoinstall_config(self.ai_data["config"]) + self.model.swap = self.ai_data.get("swap") + self.model.grub = self.ai_data.get("grub") def start(self): if self.model.bootloader == Bootloader.PREP: self.supports_resilient_boot = False else: - release = lsb_release(dry_run=self.app.opts.dry_run)['release'] - self.supports_resilient_boot = release >= '20.04' + release = lsb_release(dry_run=self.app.opts.dry_run)["release"] + self.supports_resilient_boot = release >= "20.04" self._start_task = schedule_task(self._start()) async def _start(self): context = pyudev.Context() self._monitor = pyudev.Monitor.from_netlink(context) - self._monitor.filter_by(subsystem='block') + self._monitor.filter_by(subsystem="block") self._monitor.enable_receiving() self.start_listening_udev() await self._probe_task.start() @@ -1286,12 +1310,12 @@ class FilesystemController(SubiquityController, FilesystemManipulator): try: self._probe_task.start_sync() except TaskAlreadyRunningError: - log.debug('Skipping run of Probert - probe run already active') + log.debug("Skipping run of Probert - probe run already active") else: - log.debug('Triggered Probert run on udev event') + log.debug("Triggered Probert run on udev event") def _udev_event(self): - cp = run_command(['udevadm', 'settle', '-t', '0']) + cp = run_command(["udevadm", "settle", "-t", "0"]) if cp.returncode != 0: log.debug("waiting 0.1 to let udev event queue settle") self.stop_listening_udev() @@ -1310,16 +1334,13 @@ class FilesystemController(SubiquityController, FilesystemManipulator): def make_autoinstall(self): rendered = self.model.render() - r = { - 'config': rendered['storage']['config'] - } - if 'swap' in rendered: - r['swap'] = rendered['swap'] + r = {"config": rendered["storage"]["config"]} + if "swap" in rendered: + r["swap"] = rendered["swap"] return r async def _pre_shutdown(self): if not self.reset_partition_only: - await self.app.command_runner.run( - ['umount', '--recursive', '/target']) - if len(self.model._all(type='zpool')) > 0: - await self.app.command_runner.run(['zpool', 'export', '-a']) + await self.app.command_runner.run(["umount", "--recursive", "/target"]) + if len(self.model._all(type="zpool")) > 0: + await self.app.command_runner.run(["zpool", "export", "-a"]) diff --git a/subiquity/server/controllers/identity.py b/subiquity/server/controllers/identity.py index 759c03b2..2bb15590 100644 --- a/subiquity/server/controllers/identity.py +++ b/subiquity/server/controllers/identity.py @@ -14,22 +14,21 @@ # along with this program. If not, see . import logging - -import attr import re from typing import Set -from subiquitycore.context import with_context +import attr from subiquity.common.apidef import API from subiquity.common.resources import resource_path from subiquity.common.types import IdentityData, UsernameValidation from subiquity.server.controller import SubiquityController +from subiquitycore.context import with_context -log = logging.getLogger('subiquity.server.controllers.identity') +log = logging.getLogger("subiquity.server.controllers.identity") USERNAME_MAXLEN = 32 -USERNAME_REGEX = r'[a-z_][a-z0-9_-]*' +USERNAME_REGEX = r"[a-z_][a-z0-9_-]*" def _reserved_names_from_file(path: str) -> Set[str]: @@ -38,7 +37,7 @@ def _reserved_names_from_file(path: str) -> Set[str]: with open(path, "r") as f: for line in f: line = line.strip() - if not line or line.startswith('#'): + if not line or line.startswith("#"): continue names.add(line.split()[0]) @@ -46,48 +45,46 @@ def _reserved_names_from_file(path: str) -> Set[str]: class IdentityController(SubiquityController): - endpoint = API.identity autoinstall_key = model_name = "identity" autoinstall_schema = { - 'type': 'object', - 'properties': { - 'realname': {'type': 'string'}, - 'username': {'type': 'string'}, - 'hostname': {'type': 'string'}, - 'password': {'type': 'string'}, - }, - 'required': ['username', 'hostname', 'password'], - 'additionalProperties': False, - } + "type": "object", + "properties": { + "realname": {"type": "string"}, + "username": {"type": "string"}, + "hostname": {"type": "string"}, + "password": {"type": "string"}, + }, + "required": ["username", "hostname", "password"], + "additionalProperties": False, + } - interactive_for_variants = {'desktop', 'server'} + interactive_for_variants = {"desktop", "server"} def __init__(self, app): super().__init__(app) core_reserved_path = resource_path("reserved-usernames") - self._system_reserved_names = \ - _reserved_names_from_file(core_reserved_path) + self._system_reserved_names = _reserved_names_from_file(core_reserved_path) # Let this field be the customisation point for variants. - self.existing_usernames = {'root'} + self.existing_usernames = {"root"} def load_autoinstall_data(self, data): if data is not None: identity_data = IdentityData( - realname=data.get('realname', ''), - username=data['username'], - hostname=data['hostname'], - crypted_password=data['password'], - ) + realname=data.get("realname", ""), + username=data["username"], + hostname=data["hostname"], + crypted_password=data["password"], + ) self.model.add_user(identity_data) @with_context() async def apply_autoinstall_config(self, context=None): if self.model.user: return - if 'user-data' in self.app.autoinstall_config: + if "user-data" in self.app.autoinstall_config: return if self.app.base_model.target is None: return @@ -97,7 +94,7 @@ class IdentityController(SubiquityController): if self.model.user is None: return {} r = attr.asdict(self.model.user) - r['hostname'] = self.model.hostname + r["hostname"] = self.model.hostname return r async def GET(self) -> IdentityData: @@ -112,8 +109,11 @@ class IdentityController(SubiquityController): async def POST(self, data: IdentityData): validated = await self.validate_username_GET(data.username) if validated != UsernameValidation.OK: - raise ValueError("Username <{}> is invalid and should not be" - " submitted.".format(data.username), validated) + raise ValueError( + "Username <{}> is invalid and should not be" + " submitted.".format(data.username), + validated, + ) self.model.add_user(data) await self.configured() diff --git a/subiquity/server/controllers/install.py b/subiquity/server/controllers/install.py index 420857de..c7625b75 100644 --- a/subiquity/server/controllers/install.py +++ b/subiquity/server/controllers/install.py @@ -18,64 +18,39 @@ import copy import json import logging import os -from pathlib import Path import re import shutil import subprocess import tempfile +from pathlib import Path from typing import Any, Dict, List, Optional -from curtin.config import merge_config -from curtin.util import ( - get_efibootmgr, - is_uefi_bootable, - ) import yaml - -from subiquitycore.async_helpers import ( - run_bg_task, - run_in_thread, - ) -from subiquitycore.context import with_context -from subiquitycore.file_util import ( - write_file, - generate_config_yaml, - generate_timestamped_header, - ) -from subiquitycore.utils import arun_command, log_process_streams +from curtin.config import merge_config +from curtin.util import get_efibootmgr, is_uefi_bootable from subiquity.common.errorreport import ErrorReportKind -from subiquity.common.types import ( - ApplicationState, - PackageInstallState, - ) -from subiquity.journald import ( - journald_listen, - ) +from subiquity.common.types import ApplicationState, PackageInstallState +from subiquity.journald import journald_listen from subiquity.models.filesystem import ActionRenderMode -from subiquity.server.controller import ( - SubiquityController, - ) -from subiquity.server.curtin import ( - run_curtin_command, - start_curtin_command, - ) -from subiquity.server.kernel import ( - list_installed_kernels, - ) -from subiquity.server.mounter import ( - Mounter, - ) -from subiquity.server.types import ( - InstallerChannels, - ) - +from subiquity.server.controller import SubiquityController +from subiquity.server.curtin import run_curtin_command, start_curtin_command +from subiquity.server.kernel import list_installed_kernels +from subiquity.server.mounter import Mounter +from subiquity.server.types import InstallerChannels +from subiquitycore.async_helpers import run_bg_task, run_in_thread +from subiquitycore.context import with_context +from subiquitycore.file_util import ( + generate_config_yaml, + generate_timestamped_header, + write_file, +) +from subiquitycore.utils import arun_command, log_process_streams log = logging.getLogger("subiquity.server.controllers.install") class TracebackExtractor: - start_marker = re.compile(r"^Traceback \(most recent call last\):") end_marker = re.compile(r"\S") @@ -94,7 +69,6 @@ class TracebackExtractor: class InstallController(SubiquityController): - def __init__(self, app): super().__init__(app) self.model = app.base_model @@ -119,11 +93,11 @@ class InstallController(SubiquityController): return os.path.join(self.model.target, *path) def log_event(self, event): - self.tb_extractor.feed(event['MESSAGE']) + self.tb_extractor.feed(event["MESSAGE"]) def write_config(self, config_file: Path, config: Any) -> None: - """ Create a YAML file that represents the curtin install configuration - specified. """ + """Create a YAML file that represents the curtin install configuration + specified.""" config_file.parent.mkdir(parents=True, exist_ok=True) generate_config_yaml(str(config_file), config) @@ -144,15 +118,15 @@ class InstallController(SubiquityController): } def filesystem_config( - self, - device_map_path: Path, - mode: ActionRenderMode = ActionRenderMode.DEFAULT, - ) -> Dict[str, Any]: + self, + device_map_path: Path, + mode: ActionRenderMode = ActionRenderMode.DEFAULT, + ) -> Dict[str, Any]: """Return configuration to be used as part of a curtin 'block-meta' step.""" cfg = self.model.filesystem.render(mode=mode) if device_map_path is not None: - cfg['storage']['device_map_path'] = str(device_map_path) + cfg["storage"]["device_map_path"] = str(device_map_path) return cfg def generic_config(self, **kw) -> Dict[str, Any]: @@ -169,23 +143,25 @@ class InstallController(SubiquityController): "install": { "target": target, "resume_data": None, - "extra_rsync_args": ['--no-links'], + "extra_rsync_args": ["--no-links"], } } @with_context(description="umounting /target dir") async def unmount_target(self, *, context, target): - await run_curtin_command(self.app, context, 'unmount', '-t', target, - private_mounts=False) + await run_curtin_command( + self.app, context, "unmount", "-t", target, private_mounts=False + ) if not self.app.opts.dry_run: shutil.rmtree(target) def supports_apt(self) -> bool: - return self.model.target is not None and \ - self.model.source.current.variant != 'core' + return ( + self.model.target is not None + and self.model.source.current.variant != "core" + ) - @with_context( - description="configuring apt", level="INFO", childlevel="DEBUG") + @with_context(description="configuring apt", level="INFO", childlevel="DEBUG") async def configure_apt(self, *, context): mirror = self.app.controllers.Mirror fsc = self.app.controllers.Filesystem @@ -198,23 +174,24 @@ class InstallController(SubiquityController): mirror = self.app.controllers.Mirror await mirror.final_apt_configurer.setup_target(context, self.tpath()) - @with_context( - description="executing curtin install {name} step") + @with_context(description="executing curtin install {name} step") async def run_curtin_step( - self, - context, - name: str, - stages: List[str], - config_file: Path, - source: Optional[str], - config: Dict[str, Any]): + self, + context, + name: str, + stages: List[str], + config_file: Path, + source: Optional[str], + config: Dict[str, Any], + ): """Run a curtin install step.""" self.app.note_file_for_apport( - f"Curtin{name.title().replace(' ', '')}Config", str(config_file)) + f"Curtin{name.title().replace(' ', '')}Config", str(config_file) + ) self.write_config(config_file=config_file, config=config) - log_file = Path(config['install']['log_file']) + log_file = Path(config["install"]["log_file"]) # Make sure the log directory exists. log_file.parent.mkdir(parents=True, exist_ok=True) @@ -224,24 +201,28 @@ class InstallController(SubiquityController): fh.write(f"\n---- [[ subiquity step {name} ]] ----\n") if source is not None: - source_args = (source, ) + source_args = (source,) else: source_args = () await run_curtin_command( - self.app, context, "install", - "--set", f'json:stages={json.dumps(stages)}', + self.app, + context, + "install", + "--set", + f"json:stages={json.dumps(stages)}", *source_args, - config=str(config_file), private_mounts=False) + config=str(config_file), + private_mounts=False, + ) - device_map_path = config.get('storage', {}).get('device_map_path') + device_map_path = config.get("storage", {}).get("device_map_path") if device_map_path is not None: with open(device_map_path) as fp: device_map = json.load(fp) self.app.controllers.Filesystem.update_devices(device_map) - @with_context( - description="installing system", level="INFO", childlevel="DEBUG") + @with_context(description="installing system", level="INFO", childlevel="DEBUG") async def curtin_install(self, *, context, source): if self.app.opts.dry_run: root = Path(self.app.opts.output_base) @@ -253,12 +234,13 @@ class InstallController(SubiquityController): config_dir = logs_dir / "curtin-install" base_config = self.base_config( - logs_dir, Path(tempfile.mkdtemp()) / "resume-data.json") + logs_dir, Path(tempfile.mkdtemp()) / "resume-data.json" + ) self.app.note_file_for_apport( - "CurtinErrors", base_config['install']['error_tarfile']) - self.app.note_file_for_apport( - "CurtinLog", base_config['install']['log_file']) + "CurtinErrors", base_config["install"]["error_tarfile"] + ) + self.app.note_file_for_apport("CurtinLog", base_config["install"]["log_file"]) fs_controller = self.app.controllers.Filesystem @@ -273,74 +255,89 @@ class InstallController(SubiquityController): config_file=config_dir / filename, source=source, config=config, - ) + ) await run_curtin_step(name="initial", stages=[], step_config={}) if fs_controller.reset_partition_only: await run_curtin_step( - name="partitioning", stages=["partitioning"], + name="partitioning", + stages=["partitioning"], step_config=self.filesystem_config( device_map_path=logs_dir / "device-map.json", - ), - ) + ), + ) elif fs_controller.is_core_boot_classic(): await run_curtin_step( - name="partitioning", stages=["partitioning"], + name="partitioning", + stages=["partitioning"], step_config=self.filesystem_config( mode=ActionRenderMode.DEVICES, device_map_path=logs_dir / "device-map-partition.json", - ), - ) + ), + ) if fs_controller.use_tpm: await fs_controller.setup_encryption(context=context) await run_curtin_step( - name="formatting", stages=["partitioning"], + name="formatting", + stages=["partitioning"], step_config=self.filesystem_config( mode=ActionRenderMode.FORMAT_MOUNT, - device_map_path=logs_dir / "device-map-format.json"), - ) + device_map_path=logs_dir / "device-map-format.json", + ), + ) await run_curtin_step( - name="extract", stages=["extract"], + name="extract", + stages=["extract"], step_config=self.generic_config(), source=source, - ) + ) await self.create_core_boot_classic_fstab(context=context) await run_curtin_step( - name="swap", stages=["swap"], + name="swap", + stages=["swap"], step_config=self.generic_config( swap_commands={ - 'subiquity': [ - 'curtin', 'swap', - '--fstab', self.tpath('etc/fstab'), - ], - }), - ) + "subiquity": [ + "curtin", + "swap", + "--fstab", + self.tpath("etc/fstab"), + ], + } + ), + ) await fs_controller.finish_install(context=context) await self.setup_target(context=context) else: await run_curtin_step( - name="partitioning", stages=["partitioning"], + name="partitioning", + stages=["partitioning"], step_config=self.filesystem_config( device_map_path=logs_dir / "device-map.json", - ), - ) + ), + ) await run_curtin_step( - name="extract", stages=["extract"], + name="extract", + stages=["extract"], step_config=self.generic_config(), source=source, - ) + ) if self.app.opts.dry_run: # In dry-run, extract does not do anything. Let's create what's # needed manually. Ideally, we would not hardcode # var/lib/dpkg/status because it is an implementation detail. status = "var/lib/dpkg/status" (root / status).parent.mkdir(parents=True, exist_ok=True) - await arun_command([ - "cp", "-aT", "--", - str(Path("/") / status), - str(root / status), - ]) + await arun_command( + [ + "cp", + "-aT", + "--", + str(Path("/") / status), + str(root / status), + ] + ) await self.setup_target(context=context) # For OEM, we basically mimic what ubuntu-drivers does: @@ -359,13 +356,21 @@ class InstallController(SubiquityController): for pkg in self.model.oem.metapkgs: source_list = f"/etc/apt/sources.list.d/{pkg.name}.list" await run_curtin_command( - self.app, context, - "in-target", "-t", self.tpath(), "--", - "apt-get", "update", - "-o", f"Dir::Etc::SourceList={source_list}", - "-o", "Dir::Etc::SourceParts=/dev/null", + self.app, + context, + "in-target", + "-t", + self.tpath(), + "--", + "apt-get", + "update", + "-o", + f"Dir::Etc::SourceList={source_list}", + "-o", + "Dir::Etc::SourceParts=/dev/null", "--no-list-cleanup", - private_mounts=False) + private_mounts=False, + ) # NOTE In ubuntu-drivers, this is done in a single call to # apt-get install. @@ -379,9 +384,10 @@ class InstallController(SubiquityController): self.model.kernel.curthooks_no_install = True await run_curtin_step( - name="curthooks", stages=["curthooks"], + name="curthooks", + stages=["curthooks"], step_config=self.generic_config(), - ) + ) # If the current source has a snapd_system_label here we should # really write recovery_system={snapd_system_label} to # {target}/var/lib/snapd/modeenv to get snapd to pick it up on @@ -391,10 +397,11 @@ class InstallController(SubiquityController): mounter = Mounter(self.app) async with mounter.mounted(rp.path) as mp: await run_curtin_step( - name="populate recovery", stages=["extract"], + name="populate recovery", + stages=["extract"], step_config=self.rp_config(logs_dir, mp.p()), - source='cp:///cdrom', - ) + source="cp:///cdrom", + ) await self.create_rp_boot_entry(context=context, rp=rp) @with_context(description="creating boot entry for reset partition") @@ -402,37 +409,47 @@ class InstallController(SubiquityController): fs_controller = self.app.controllers.Filesystem if not fs_controller.reset_partition_only: cp = await self.app.command_runner.run( - ['lsblk', '-n', '-o', 'UUID', rp.path], - capture=True) - uuid = cp.stdout.decode('ascii').strip() + ["lsblk", "-n", "-o", "UUID", rp.path], capture=True + ) + uuid = cp.stdout.decode("ascii").strip() conf = grub_reset_conf.format( - HEADER=generate_timestamped_header(), - PARTITION=rp.number, - UUID=uuid) - with open(self.tpath('etc/grub.d/99_reset'), 'w') as fp: + HEADER=generate_timestamped_header(), PARTITION=rp.number, UUID=uuid + ) + with open(self.tpath("etc/grub.d/99_reset"), "w") as fp: os.chmod(fp.fileno(), 0o755) fp.write(conf) await run_curtin_command( - self.app, context, "in-target", "-t", self.tpath(), "--", - "update-grub", private_mounts=False) + self.app, + context, + "in-target", + "-t", + self.tpath(), + "--", + "update-grub", + private_mounts=False, + ) if self.app.opts.dry_run and not is_uefi_bootable(): # Can't even run efibootmgr in this case. return - state = await self.app.package_installer.install_pkg('efibootmgr') + state = await self.app.package_installer.install_pkg("efibootmgr") if state != PackageInstallState.DONE: - raise RuntimeError('could not install efibootmgr') - efi_state_before = get_efibootmgr('/') + raise RuntimeError("could not install efibootmgr") + efi_state_before = get_efibootmgr("/") cmd = [ - 'efibootmgr', '--create', - '--loader', '\\EFI\\boot\\shimx64.efi', - '--disk', rp.device.path, - '--part', str(rp.number), - '--label', "Restore Ubuntu to factory state", - ] + "efibootmgr", + "--create", + "--loader", + "\\EFI\\boot\\shimx64.efi", + "--disk", + rp.device.path, + "--part", + str(rp.number), + "--label", + "Restore Ubuntu to factory state", + ] await self.app.command_runner.run(cmd) - efi_state_after = get_efibootmgr('/') - new_bootnums = ( - set(efi_state_after.entries) - set(efi_state_before.entries)) + efi_state_after = get_efibootmgr("/") + new_bootnums = set(efi_state_after.entries) - set(efi_state_before.entries) if not new_bootnums: return new_bootnum = new_bootnums.pop() @@ -443,24 +460,27 @@ class InstallController(SubiquityController): was_dup = True if was_dup: cmd = [ - 'efibootmgr', '--delete-bootnum', - '--bootnum', new_bootnum, - ] + "efibootmgr", + "--delete-bootnum", + "--bootnum", + new_bootnum, + ] else: cmd = [ - 'efibootmgr', - '--bootorder', ','.join(efi_state_before.order), - ] + "efibootmgr", + "--bootorder", + ",".join(efi_state_before.order), + ] await self.app.command_runner.run(cmd) @with_context(description="creating fstab") async def create_core_boot_classic_fstab(self, *, context): - with open(self.tpath('etc/fstab'), 'w') as fp: + with open(self.tpath("etc/fstab"), "w") as fp: fp.write("/run/mnt/ubuntu-boot/EFI/ubuntu /boot/grub none bind\n") @with_context() async def install(self, *, context): - context.set('is-install-context', True) + context.set("is-install-context", True) try: while True: self.app.update_state(ApplicationState.WAITING) @@ -468,7 +488,7 @@ class InstallController(SubiquityController): await self.model.wait_install() if not self.app.interactive: - if 'autoinstall' in self.app.kernel_cmdline: + if "autoinstall" in self.app.kernel_cmdline: await self.model.confirm() self.app.update_state(ApplicationState.NEEDS_CONFIRMATION) @@ -481,8 +501,7 @@ class InstallController(SubiquityController): if self.model.target is None: for_install_path = None elif self.supports_apt(): - for_install_path = 'cp://' + await self.configure_apt( - context=context) + for_install_path = "cp://" + await self.configure_apt(context=context) await self.app.hub.abroadcast(InstallerChannels.APT_CONFIGURED) else: @@ -490,12 +509,11 @@ class InstallController(SubiquityController): for_install_path = self.model.source.get_source(fsc._info.name) if self.app.controllers.Filesystem.reset_partition: - self.app.package_installer.start_installing_pkg('efibootmgr') + self.app.package_installer.start_installing_pkg("efibootmgr") if self.model.target is not None: if os.path.exists(self.model.target): - await self.unmount_target( - context=context, target=self.model.target) + await self.unmount_target(context=context, target=self.model.target) await self.curtin_install(context=context, source=for_install_path) @@ -513,17 +531,20 @@ class InstallController(SubiquityController): if self.tb_extractor.traceback: kw["Traceback"] = "\n".join(self.tb_extractor.traceback) self.app.make_apport_report( - ErrorReportKind.INSTALL_FAIL, "install failed", **kw) + ErrorReportKind.INSTALL_FAIL, "install failed", **kw + ) raise @with_context( - description="final system configuration", level="INFO", - childlevel="DEBUG") + description="final system configuration", level="INFO", childlevel="DEBUG" + ) async def postinstall(self, *, context): autoinstall_path = os.path.join( - self.app.root, 'var/log/installer/autoinstall-user-data') + self.app.root, "var/log/installer/autoinstall-user-data" + ) autoinstall_config = "#cloud-config\n" + yaml.dump( - {"autoinstall": self.app.make_autoinstall()}) + {"autoinstall": self.app.make_autoinstall()} + ) write_file(autoinstall_path, autoinstall_config) await self.configure_cloud_init(context=context) if self.supports_apt(): @@ -532,22 +553,19 @@ class InstallController(SubiquityController): await self.install_package(context=context, package=package) if self.model.drivers.do_install: with context.child( - "ubuntu-drivers-install", - "installing third-party drivers") as child: + "ubuntu-drivers-install", "installing third-party drivers" + ) as child: udrivers = self.app.controllers.Drivers.ubuntu_drivers - await udrivers.install_drivers( - root_dir=self.tpath(), - context=child) + await udrivers.install_drivers(root_dir=self.tpath(), context=child) if self.model.network.has_network: self.app.update_state(ApplicationState.UU_RUNNING) policy = self.model.updates.updates - await self.run_unattended_upgrades( - context=context, policy=policy) + await self.run_unattended_upgrades(context=context, policy=policy) await self.restore_apt_config(context=context) if self.model.active_directory.do_join: hostname = self.model.identity.hostname if not hostname: - with open(self.tpath('etc/hostname'), 'r') as f: + with open(self.tpath("etc/hostname"), "r") as f: hostname = f.read().strip() await self.app.controllers.Ad.join_domain(hostname, context) @@ -560,21 +578,23 @@ class InstallController(SubiquityController): async def get_target_packages(self, context): return await self.app.base_model.target_packages() - @with_context( - name="install_{package}", - description="installing {package}") + @with_context(name="install_{package}", description="installing {package}") async def install_package(self, *, context, package): - """ Attempt to download the package up-to three times, then install it. - """ + """Attempt to download the package up-to three times, then install it.""" for attempt, attempts_remaining in enumerate(reversed(range(3))): try: - with context.child('retrieving', f'retrieving {package}'): + with context.child("retrieving", f"retrieving {package}"): await run_curtin_command( - self.app, context, 'system-install', '-t', + self.app, + context, + "system-install", + "-t", self.tpath(), - '--download-only', - '--', package, - private_mounts=False) + "--download-only", + "--", + package, + private_mounts=False, + ) except subprocess.CalledProcessError: log.error(f"failed to download package {package}") if attempts_remaining > 0: @@ -584,12 +604,18 @@ class InstallController(SubiquityController): else: break - with context.child('unpacking', f'unpacking {package}'): + with context.child("unpacking", f"unpacking {package}"): await run_curtin_command( - self.app, context, 'system-install', '-t', self.tpath(), - '--assume-downloaded', - '--', package, - private_mounts=False) + self.app, + context, + "system-install", + "-t", + self.tpath(), + "--assume-downloaded", + "--", + package, + private_mounts=False, + ) @with_context(description="restoring apt configuration") async def restore_apt_config(self, context): @@ -604,39 +630,47 @@ class InstallController(SubiquityController): aptdir = self.tpath("etc/apt/apt.conf.d") os.makedirs(aptdir, exist_ok=True) apt_conf_contents = uu_apt_conf - if policy == 'all': + if policy == "all": apt_conf_contents += uu_apt_conf_update_all else: apt_conf_contents += uu_apt_conf_update_security - fname = 'zzzz-temp-installer-unattended-upgrade' - with open(os.path.join(aptdir, fname), 'wb') as apt_conf: + fname = "zzzz-temp-installer-unattended-upgrade" + with open(os.path.join(aptdir, fname), "wb") as apt_conf: apt_conf.write(apt_conf_contents) apt_conf.close() self.unattended_upgrades_ctx = context self.unattended_upgrades_cmd = await start_curtin_command( - self.app, context, "in-target", "-t", self.tpath(), - "--", "unattended-upgrades", "-v", - private_mounts=True) + self.app, + context, + "in-target", + "-t", + self.tpath(), + "--", + "unattended-upgrades", + "-v", + private_mounts=True, + ) try: await self.unattended_upgrades_cmd.wait() except subprocess.CalledProcessError as cpe: - log_process_streams(logging.ERROR, cpe, 'Unattended upgrades') + log_process_streams(logging.ERROR, cpe, "Unattended upgrades") context.description = f"FAILED to apply {policy} updates" self.unattended_upgrades_cmd = None self.unattended_upgrades_ctx = None async def stop_unattended_upgrades(self): with self.unattended_upgrades_ctx.parent.child( - "stop_unattended_upgrades", - "cancelling update"): - await self.app.command_runner.run([ - 'chroot', self.tpath(), - '/usr/share/unattended-upgrades/' - 'unattended-upgrade-shutdown', - '--stop-only', - ]) - if self.app.opts.dry_run and \ - self.unattended_upgrades_cmd is not None: + "stop_unattended_upgrades", "cancelling update" + ): + await self.app.command_runner.run( + [ + "chroot", + self.tpath(), + "/usr/share/unattended-upgrades/" "unattended-upgrade-shutdown", + "--stop-only", + ] + ) + if self.app.opts.dry_run and self.unattended_upgrades_cmd is not None: self.unattended_upgrades_cmd.proc.terminate() diff --git a/subiquity/server/controllers/integrity.py b/subiquity/server/controllers/integrity.py index 85f71060..f6b924d8 100644 --- a/subiquity/server/controllers/integrity.py +++ b/subiquity/server/controllers/integrity.py @@ -16,31 +16,28 @@ import json import logging -from subiquitycore.async_helpers import schedule_task from subiquity.common.apidef import API from subiquity.common.types import CasperMd5Results from subiquity.journald import journald_get_first_match from subiquity.server.controller import SubiquityController +from subiquitycore.async_helpers import schedule_task +log = logging.getLogger("subiquity.server.controllers.integrity") -log = logging.getLogger('subiquity.server.controllers.integrity') - -mock_pass = {'checksum_missmatch': [], 'result': 'pass'} -mock_skip = {'checksum_missmatch': [], 'result': 'skip'} -mock_fail = {'checksum_missmatch': ['./casper/initrd'], 'result': 'fail'} +mock_pass = {"checksum_missmatch": [], "result": "pass"} +mock_skip = {"checksum_missmatch": [], "result": "skip"} +mock_fail = {"checksum_missmatch": ["./casper/initrd"], "result": "fail"} class IntegrityController(SubiquityController): - endpoint = API.integrity - model_name = 'integrity' - result_filepath = '/run/casper-md5check.json' + model_name = "integrity" + result_filepath = "/run/casper-md5check.json" @property def result(self): - return CasperMd5Results( - self.model.md5check_results.get('result', 'unknown')) + return CasperMd5Results(self.model.md5check_results.get("result", "unknown")) async def GET(self) -> CasperMd5Results: return self.result @@ -48,7 +45,8 @@ class IntegrityController(SubiquityController): async def wait_casper_md5check(self): if not self.app.opts.dry_run: await journald_get_first_match( - 'systemd', 'UNIT=casper-md5check.service', 'JOB_RESULT=done') + "systemd", "UNIT=casper-md5check.service", "JOB_RESULT=done" + ) async def get_md5check_results(self): if self.app.opts.dry_run: @@ -57,10 +55,10 @@ class IntegrityController(SubiquityController): try: ret = json.load(fp) except json.JSONDecodeError as jde: - log.debug(f'error reading casper-md5check results: {jde}') + log.debug(f"error reading casper-md5check results: {jde}") return {} else: - log.debug(f'casper-md5check results: {ret}') + log.debug(f"casper-md5check results: {ret}") return ret async def md5check(self): diff --git a/subiquity/server/controllers/kernel.py b/subiquity/server/controllers/kernel.py index b4341155..8bfa12e4 100644 --- a/subiquity/server/controllers/kernel.py +++ b/subiquity/server/controllers/kernel.py @@ -23,7 +23,6 @@ log = logging.getLogger("subiquity.server.controllers.kernel") class KernelController(NonInteractiveController): - model_name = autoinstall_key = "kernel" autoinstall_default = None autoinstall_schema = { @@ -40,7 +39,7 @@ class KernelController(NonInteractiveController): { "type": "object", "required": ["flavor"], - } + }, ], } @@ -51,37 +50,35 @@ class KernelController(NonInteractiveController): # the ISO may have been configured to tell us what kernel to use # /run is the historical location, but a bit harder to craft an ISO to # have it there as it needs to be populated at runtime. - run_mp_file = os.path.join( - self.app.base_model.root, - "run/kernel-meta-package") + run_mp_file = os.path.join(self.app.base_model.root, "run/kernel-meta-package") etc_mp_file = os.path.join( - self.app.base_model.root, - "etc/subiquity/kernel-meta-package") + self.app.base_model.root, "etc/subiquity/kernel-meta-package" + ) for mp_file in (run_mp_file, etc_mp_file): if os.path.exists(mp_file): with open(mp_file) as fp: kernel_package = fp.read().strip() self.model.metapkg_name = kernel_package self.model.explicitly_requested = True - log.debug(f'Using kernel {kernel_package} due to {mp_file}') + log.debug(f"Using kernel {kernel_package} due to {mp_file}") break else: - log.debug('Using default kernel linux-generic') - self.model.metapkg_name = 'linux-generic' + log.debug("Using default kernel linux-generic") + self.model.metapkg_name = "linux-generic" def load_autoinstall_data(self, data): if data is None: return - package = data.get('package') - flavor = data.get('flavor') + package = data.get("package") + flavor = data.get("flavor") if package is None: dry_run: bool = self.app.opts.dry_run if flavor is None: - flavor = 'generic' + flavor = "generic" package = flavor_to_pkgname(flavor, dry_run=dry_run) - log.debug(f'Using kernel {package} due to autoinstall') + log.debug(f"Using kernel {package} due to autoinstall") self.model.metapkg_name = package self.model.explicitly_requested = True def make_autoinstall(self): - return {'package': self.model.metapkg_name} + return {"package": self.model.metapkg_name} diff --git a/subiquity/server/controllers/keyboard.py b/subiquity/server/controllers/keyboard.py index 1bd3b878..190ab58f 100644 --- a/subiquity/server/controllers/keyboard.py +++ b/subiquity/server/controllers/keyboard.py @@ -14,38 +14,77 @@ # along with this program. If not, see . import logging -from typing import Dict, Optional, Sequence, Tuple import os import pwd +from typing import Dict, Optional, Sequence, Tuple import attr -from subiquitycore.context import with_context -from subiquitycore.utils import arun_command - from subiquity.common.apidef import API from subiquity.common.resources import resource_path from subiquity.common.serialize import Serializer -from subiquity.common.types import ( - AnyStep, - KeyboardSetting, - KeyboardSetup, - ) +from subiquity.common.types import AnyStep, KeyboardSetting, KeyboardSetup from subiquity.server.controller import SubiquityController +from subiquitycore.context import with_context +from subiquitycore.utils import arun_command -log = logging.getLogger('subiquity.server.controllers.keyboard') +log = logging.getLogger("subiquity.server.controllers.keyboard") # Non-latin keyboard layouts that are handled in a uniform way standard_non_latin_layouts = set( - ('af', 'am', 'ara', 'ben', 'bd', 'bg', 'bt', 'by', 'et', 'ge', - 'gh', 'gr', 'guj', 'guru', 'il', 'in', 'iq', 'ir', 'iku', 'kan', - 'kh', 'kz', 'la', 'lao', 'lk', 'kg', 'ma', 'mk', 'mm', 'mn', 'mv', - 'mal', 'np', 'ori', 'pk', 'ru', 'scc', 'sy', 'syr', 'tel', 'th', - 'tj', 'tam', 'tib', 'ua', 'ug', 'uz') + ( + "af", + "am", + "ara", + "ben", + "bd", + "bg", + "bt", + "by", + "et", + "ge", + "gh", + "gr", + "guj", + "guru", + "il", + "in", + "iq", + "ir", + "iku", + "kan", + "kh", + "kz", + "la", + "lao", + "lk", + "kg", + "ma", + "mk", + "mm", + "mn", + "mv", + "mal", + "np", + "ori", + "pk", + "ru", + "scc", + "sy", + "syr", + "tel", + "th", + "tj", + "tam", + "tib", + "ua", + "ug", + "uz", + ) ) -default_desktop_user = 'ubuntu' +default_desktop_user = "ubuntu" def latinizable(layout_code, variant_code) -> Optional[Tuple[str, str]]: @@ -53,34 +92,34 @@ def latinizable(layout_code, variant_code) -> Optional[Tuple[str, str]]: If this setting does not allow the typing of latin characters, return a setting that can be switched to one that can. """ - if layout_code == 'rs': - if variant_code.startswith('latin'): + if layout_code == "rs": + if variant_code.startswith("latin"): return None else: - if variant_code == 'yz': - new_variant_code = 'latinyz' - elif variant_code == 'alternatequotes': - new_variant_code = 'latinalternatequotes' + if variant_code == "yz": + new_variant_code = "latinyz" + elif variant_code == "alternatequotes": + new_variant_code = "latinalternatequotes" else: - new_variant_code = 'latin' - return 'rs,rs', new_variant_code + ',' + variant_code - elif layout_code == 'jp': - if variant_code in ('106', 'common', 'OADG109A', 'nicola_f_bs', ''): + new_variant_code = "latin" + return "rs,rs", new_variant_code + "," + variant_code + elif layout_code == "jp": + if variant_code in ("106", "common", "OADG109A", "nicola_f_bs", ""): return None else: - return 'jp,jp', ',' + variant_code - elif layout_code == 'lt': - if variant_code == 'us': - return 'lt,lt', 'us,' + return "jp,jp", "," + variant_code + elif layout_code == "lt": + if variant_code == "us": + return "lt,lt", "us," else: - return 'lt,lt', variant_code + ',us' - elif layout_code == 'me': - if variant_code == 'basic' or variant_code.startswith('latin'): + return "lt,lt", variant_code + ",us" + elif layout_code == "me": + if variant_code == "basic" or variant_code.startswith("latin"): return None else: - return 'me,me', variant_code + ',us' + return "me,me", variant_code + ",us" elif layout_code in standard_non_latin_layouts: - return 'us,' + layout_code, ',' + variant_code + return "us," + layout_code, "," + variant_code else: return None @@ -90,51 +129,49 @@ def for_ui(setting): Attempt to guess a setting the user chose which resulted in the current config. Basically the inverse of latinizable(). """ - if ',' in setting.layout: - layout1, layout2 = setting.layout.split(',', 1) + if "," in setting.layout: + layout1, layout2 = setting.layout.split(",", 1) else: - layout1, layout2 = setting.layout, '' - if ',' in setting.variant: - variant1, variant2 = setting.variant.split(',', 1) + layout1, layout2 = setting.layout, "" + if "," in setting.variant: + variant1, variant2 = setting.variant.split(",", 1) else: - variant1, variant2 = setting.variant, '' - if setting.layout == 'lt,lt': + variant1, variant2 = setting.variant, "" + if setting.layout == "lt,lt": layout = layout1 variant = variant1 - elif setting.layout in ('rs,rs', 'us,rs', 'jp,jp', 'us,jp'): + elif setting.layout in ("rs,rs", "us,rs", "jp,jp", "us,jp"): layout = layout2 variant = variant2 - elif layout1 == 'us' and layout2 in standard_non_latin_layouts: + elif layout1 == "us" and layout2 in standard_non_latin_layouts: layout = layout2 variant = variant2 - elif ',' in setting.layout: + elif "," in setting.layout: # Something unrecognized - layout = 'us' - variant = '' + layout = "us" + variant = "" else: return setting - return KeyboardSetting( - layout=layout, variant=variant, toggle=setting.toggle) + return KeyboardSetting(layout=layout, variant=variant, toggle=setting.toggle) class KeyboardController(SubiquityController): - endpoint = API.keyboard autoinstall_key = model_name = "keyboard" autoinstall_schema = { - 'type': 'object', - 'properties': { - 'layout': {'type': 'string'}, - 'variant': {'type': 'string'}, - 'toggle': {'type': ['string', 'null']}, - }, - 'required': ['layout'], - 'additionalProperties': False, - } + "type": "object", + "properties": { + "layout": {"type": "string"}, + "variant": {"type": "string"}, + "toggle": {"type": ["string", "null"]}, + }, + "required": ["layout"], + "additionalProperties": False, + } def __init__(self, app): - self._kbds_dir = resource_path('kbds') + self._kbds_dir = resource_path("kbds") self.serializer = Serializer(compact=True) self.pc105_steps = None self.needs_set_keyboard = False @@ -159,15 +196,15 @@ class KeyboardController(SubiquityController): async def set_keyboard(self): path = self.model.config_path os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, 'w') as fp: + with open(path, "w") as fp: fp.write(self.model.render_config_file()) cmds = [ - ['setupcon', '--save', '--force', '--keyboard-only'], - [resource_path('bin/subiquity-loadkeys')], - ] + ["setupcon", "--save", "--force", "--keyboard-only"], + [resource_path("bin/subiquity-loadkeys")], + ] if self.opts.dry_run: - scale = os.environ.get('SUBIQUITY_REPLAY_TIMESCALE', "1") - cmds = [['sleep', str(1/float(scale))]] + scale = os.environ.get("SUBIQUITY_REPLAY_TIMESCALE", "1") + cmds = [["sleep", str(1 / float(scale))]] for cmd in cmds: await arun_command(cmd) @@ -176,7 +213,8 @@ class KeyboardController(SubiquityController): self.model.keyboard_list.load_language(lang) return KeyboardSetup( setting=for_ui(self.model.setting_for_lang(lang)), - layouts=self.model.keyboard_list.layouts) + layouts=self.model.keyboard_list.layouts, + ) async def POST(self, data: KeyboardSetting): log.debug(data) @@ -188,46 +226,54 @@ class KeyboardController(SubiquityController): await self.set_keyboard() await self.configured() - async def needs_toggle_GET(self, layout_code: str, - variant_code: str) -> bool: + async def needs_toggle_GET(self, layout_code: str, variant_code: str) -> bool: return latinizable(layout_code, variant_code) is not None async def steps_GET(self, index: Optional[str]) -> AnyStep: if self.pc105_steps is None: - path = os.path.join(self._kbds_dir, 'pc105.json') + path = os.path.join(self._kbds_dir, "pc105.json") with open(path) as fp: self.pc105_steps = self.serializer.from_json( - Dict[str, AnyStep], fp.read()) + Dict[str, AnyStep], fp.read() + ) if index is None: index = "0" return self.pc105_steps[index] - async def input_source_POST(self, data: KeyboardSetting, - user: Optional[str] = None) -> None: + async def input_source_POST( + self, data: KeyboardSetting, user: Optional[str] = None + ) -> None: await self.set_input_source(data.layout, data.variant, user=user) async def set_input_source(self, layout: str, variant: str, **kwargs): - xkb_value = f'{layout}+{variant}' if variant else layout + xkb_value = f"{layout}+{variant}" if variant else layout gsettings = [ - 'gsettings', 'set', 'org.gnome.desktop.input-sources', 'sources', - f"[('xkb','{xkb_value}')]" - ] + "gsettings", + "set", + "org.gnome.desktop.input-sources", + "sources", + f"[('xkb','{xkb_value}')]", + ] await self._run_gui_command(gsettings, **kwargs) - async def _run_gui_command(self, command: Sequence[str], - user: Optional[str] = None): + async def _run_gui_command( + self, command: Sequence[str], user: Optional[str] = None + ): if self.opts.dry_run: - scale = os.environ.get('SUBIQUITY_REPLAY_TIMESCALE', '1') - cmd = ['sleep', str(1/float(scale))] + scale = os.environ.get("SUBIQUITY_REPLAY_TIMESCALE", "1") + cmd = ["sleep", str(1 / float(scale))] else: passwd = pwd.getpwnam(user if user else default_desktop_user) - xdg_runtime_dir = f'/run/user/{passwd.pw_uid}' - dbus_session_bus = f'unix:path={xdg_runtime_dir}/bus' + xdg_runtime_dir = f"/run/user/{passwd.pw_uid}" + dbus_session_bus = f"unix:path={xdg_runtime_dir}/bus" cmd = [ - 'systemd-run', '--wait', f'--uid={passwd.pw_uid}', + "systemd-run", + "--wait", + f"--uid={passwd.pw_uid}", f'--setenv=DISPLAY={os.environ.get("DISPLAY", ":0")}', - f'--setenv=XDG_RUNTIME_DIR={xdg_runtime_dir}', - f'--setenv=DBUS_SESSION_BUS_ADDRESS={dbus_session_bus}', - '--', *command - ] + f"--setenv=XDG_RUNTIME_DIR={xdg_runtime_dir}", + f"--setenv=DBUS_SESSION_BUS_ADDRESS={dbus_session_bus}", + "--", + *command, + ] return await arun_command(cmd) diff --git a/subiquity/server/controllers/locale.py b/subiquity/server/controllers/locale.py index 51f680b1..20ae3091 100644 --- a/subiquity/server/controllers/locale.py +++ b/subiquity/server/controllers/locale.py @@ -16,22 +16,20 @@ import logging import os -from subiquitycore import async_helpers as async_helpers from subiquity.common.apidef import API from subiquity.server.controller import SubiquityController from subiquity.server.types import InstallerChannels +from subiquitycore import async_helpers as async_helpers - -log = logging.getLogger('subiquity.server.controllers.locale') +log = logging.getLogger("subiquity.server.controllers.locale") class LocaleController(SubiquityController): - endpoint = API.locale autoinstall_key = model_name = "locale" - autoinstall_schema = {'type': 'string'} - autoinstall_default = 'en_US.UTF-8' + autoinstall_schema = {"type": "string"} + autoinstall_default = "en_US.UTF-8" def load_autoinstall_data(self, data): os.environ["LANG"] = data @@ -41,12 +39,14 @@ class LocaleController(SubiquityController): # But for autoinstall (not interactive), someone else # initializing to a different value would be unexpected behavior. if not self.interactive() or (self.model.selected_language is None): - self.model.selected_language = os.environ.get("LANG") \ - or self.autoinstall_default + self.model.selected_language = ( + os.environ.get("LANG") or self.autoinstall_default + ) async_helpers.run_bg_task(self.configured()) self.app.hub.subscribe( - (InstallerChannels.CONFIGURED, 'source'), self._set_source) + (InstallerChannels.CONFIGURED, "source"), self._set_source + ) def _set_source(self): current = self.app.base_model.source.current diff --git a/subiquity/server/controllers/mirror.py b/subiquity/server/controllers/mirror.py index e3e692ce..bbc32333 100644 --- a/subiquity/server/controllers/mirror.py +++ b/subiquity/server/controllers/mirror.py @@ -20,8 +20,6 @@ from typing import List, Optional import attr -from subiquitycore.context import with_context - from subiquity.common.apidef import API from subiquity.common.types import ( MirrorCheckResponse, @@ -30,27 +28,24 @@ from subiquity.common.types import ( MirrorPost, MirrorPostResponse, MirrorSelectionFallback, - ) +) from subiquity.models.mirror import filter_candidates -from subiquity.server.apt import ( - AptConfigCheckError, - AptConfigurer, - get_apt_configurer, - ) +from subiquity.server.apt import AptConfigCheckError, AptConfigurer, get_apt_configurer from subiquity.server.controller import SubiquityController from subiquity.server.types import InstallerChannels +from subiquitycore.context import with_context -log = logging.getLogger('subiquity.server.controllers.mirror') +log = logging.getLogger("subiquity.server.controllers.mirror") class NoUsableMirrorError(Exception): - """ Exception to be raised when none of the candidate mirrors passed the - test. """ + """Exception to be raised when none of the candidate mirrors passed the + test.""" class MirrorCheckNotStartedError(Exception): - """ Exception to be raised when trying to cancel a mirror - check that was not started. """ + """Exception to be raised when trying to cancel a mirror + check that was not started.""" @attr.s(auto_attribs=True) @@ -61,50 +56,55 @@ class MirrorCheck: class MirrorController(SubiquityController): - endpoint = API.mirror autoinstall_key = "apt" autoinstall_schema = { # This is obviously incomplete. - 'type': 'object', - 'properties': { - 'preserve_sources_list': {'type': 'boolean'}, - 'primary': {'type': 'array'}, # Legacy format defined by curtin. - 'mirror-selection': { - 'type': 'object', - 'properties': { - 'primary': { - 'type': 'array', - 'items': { - 'anyOf': [ + "type": "object", + "properties": { + "preserve_sources_list": {"type": "boolean"}, + "primary": {"type": "array"}, # Legacy format defined by curtin. + "mirror-selection": { + "type": "object", + "properties": { + "primary": { + "type": "array", + "items": { + "anyOf": [ { - 'type': 'string', - 'const': 'country-mirror', - }, { - 'type': 'object', - 'properties': { - 'uri': {'type': 'string'}, - 'arches': { - 'type': 'array', - 'items': {'type': 'string'}, + "type": "string", + "const": "country-mirror", + }, + { + "type": "object", + "properties": { + "uri": {"type": "string"}, + "arches": { + "type": "array", + "items": {"type": "string"}, }, }, - 'required': ['uri'], + "required": ["uri"], }, ], }, }, }, }, - 'geoip': {'type': 'boolean'}, - 'sources': {'type': 'object'}, - 'disable_components': { - 'type': 'array', - 'items': { - 'type': 'string', - 'enum': ['universe', 'multiverse', 'restricted', - 'contrib', 'non-free'] - } + "geoip": {"type": "boolean"}, + "sources": {"type": "object"}, + "disable_components": { + "type": "array", + "items": { + "type": "string", + "enum": [ + "universe", + "multiverse", + "restricted", + "contrib", + "non-free", + ], + }, }, "preferences": { "type": "array", @@ -126,13 +126,13 @@ class MirrorController(SubiquityController): "pin", "pin-priority", ], - } + }, }, "fallback": { "type": "string", "enum": [fb.value for fb in MirrorSelectionFallback], }, - } + }, } model_name = "mirror" @@ -144,14 +144,13 @@ class MirrorController(SubiquityController): self.network_configured_event = asyncio.Event() self.proxy_configured_event = asyncio.Event() self.app.hub.subscribe(InstallerChannels.GEOIP, self.on_geoip) + self.app.hub.subscribe((InstallerChannels.CONFIGURED, "source"), self.on_source) self.app.hub.subscribe( - (InstallerChannels.CONFIGURED, 'source'), self.on_source) + (InstallerChannels.CONFIGURED, "network"), self.network_configured_event.set + ) self.app.hub.subscribe( - (InstallerChannels.CONFIGURED, 'network'), - self.network_configured_event.set) - self.app.hub.subscribe( - (InstallerChannels.CONFIGURED, 'proxy'), - self.proxy_configured_event.set) + (InstallerChannels.CONFIGURED, "proxy"), self.proxy_configured_event.set + ) self._apt_config_key = None self.test_apt_configurer: Optional[AptConfigurer] = None self.final_apt_configurer: Optional[AptConfigurer] = None @@ -161,12 +160,12 @@ class MirrorController(SubiquityController): def load_autoinstall_data(self, data): if data is None: return - geoip = data.pop('geoip', True) + geoip = data.pop("geoip", True) self.model.load_autoinstall_data(data) self.geoip_enabled = geoip and self.model.wants_geoip() async def try_mirror_checking_once(self) -> None: - """ Try mirror checking and log result. """ + """Try mirror checking and log result.""" output = io.StringIO() try: await self.run_mirror_testing(output) @@ -188,7 +187,7 @@ class MirrorController(SubiquityController): await self.proxy_configured_event.wait() if self.geoip_enabled: try: - with context.child('waiting'): + with context.child("waiting"): await asyncio.wait_for(self.cc_event.wait(), 10) except asyncio.TimeoutError: pass @@ -232,8 +231,7 @@ class MirrorController(SubiquityController): fallback = self.model.fallback if fallback == MirrorSelectionFallback.ABORT: - log.error("aborting the install since no primary mirror is" - " usable") + log.error("aborting the install since no primary mirror is" " usable") # TODO there is no guarantee that raising this exception will # actually abort the install. If this is raised from a request # handler, for instance, it will just return a HTTP 500 error. For @@ -241,8 +239,9 @@ class MirrorController(SubiquityController): # request handlers. raise RuntimeError("aborting install since no mirror is usable") elif fallback == MirrorSelectionFallback.OFFLINE_INSTALL: - log.warning("reverting to an offline install since no primary" - " mirror is usable") + log.warning( + "reverting to an offline install since no primary" " mirror is usable" + ) self.app.base_model.network.force_offline = True elif fallback == MirrorSelectionFallback.CONTINUE_ANYWAY: log.warning("continuing the install despite no usable mirror") @@ -253,21 +252,25 @@ class MirrorController(SubiquityController): lambda c: c.supports_arch(self.model.architecture), ] try: - candidate = next(filter_candidates( - self.model.primary_candidates, filters=filters)) + candidate = next( + filter_candidates(self.model.primary_candidates, filters=filters) + ) except StopIteration: - candidate = next(filter_candidates( - self.model.get_default_primary_candidates(), - filters=filters)) - log.warning("deciding to elect primary mirror %s", - candidate.serialize_for_ai()) + candidate = next( + filter_candidates( + self.model.get_default_primary_candidates(), filters=filters + ) + ) + log.warning( + "deciding to elect primary mirror %s", candidate.serialize_for_ai() + ) candidate.elect() else: raise RuntimeError(f"invalid fallback value: {fallback}") async def run_mirror_selection_or_fallback(self, context): - """ Perform the mirror selection and apply the configured fallback - method if no mirror is usable. """ + """Perform the mirror selection and apply the configured fallback + method if no mirror is usable.""" try: await self.find_and_elect_candidate_mirror(context=context) except NoUsableMirrorError: @@ -287,14 +290,17 @@ class MirrorController(SubiquityController): if self.autoinstall_apply_started: # Alternatively, we should cancel and restart the # apply_autoinstall_config but this is out of scope. - raise RuntimeError("source model has changed but autoinstall" - " configuration is already being applied") + raise RuntimeError( + "source model has changed but autoinstall" + " configuration is already being applied" + ) source_entry = self.app.base_model.source.current - if source_entry.variant == 'core': + if source_entry.variant == "core": self.test_apt_configurer = None else: self.test_apt_configurer = get_apt_configurer( - self.app, self.app.controllers.Source.get_handler()) + self.app, self.app.controllers.Source.get_handler() + ) self.source_configured_event.set() def serialize(self): @@ -310,7 +316,7 @@ class MirrorController(SubiquityController): def make_autoinstall(self): config = self.model.make_autoinstall() - config['geoip'] = self.geoip_enabled + config["geoip"] = self.geoip_enabled return config async def _promote_mirror(self): @@ -321,8 +327,7 @@ class MirrorController(SubiquityController): # null as the body instead. await self.run_mirror_selection_or_fallback(self.context) assert self.final_apt_configurer is not None - await self.final_apt_configurer.apply_apt_config( - self.context, final=True) + await self.final_apt_configurer.apply_apt_config(self.context, final=True) async def run_mirror_testing(self, output: io.StringIO) -> None: await self.source_configured_event.wait() @@ -334,13 +339,13 @@ class MirrorController(SubiquityController): if configurer is None: # i.e. core return - await configurer.apply_apt_config( - self.context, final=False) + await configurer.apply_apt_config(self.context, final=False) await configurer.run_apt_config_check(output) async def wait_config(self, variation_name: str) -> AptConfigurer: self.final_apt_configurer = get_apt_configurer( - self.app, self.app.controllers.Source.get_handler(variation_name)) + self.app, self.app.controllers.Source.get_handler(variation_name) + ) await self._promote_mirror() assert self.final_apt_configurer is not None return self.final_apt_configurer @@ -350,7 +355,7 @@ class MirrorController(SubiquityController): staged: Optional[str] = None candidates: List[str] = [] source_entry = self.app.base_model.source.current - if source_entry.variant == 'core': + if source_entry.variant == "core": relevant = False else: relevant = True @@ -363,10 +368,8 @@ class MirrorController(SubiquityController): # Skip the country-mirrors if they have not been resolved yet. candidates = [c.uri for c in compatibles if c.uri is not None] return MirrorGet( - relevant=relevant, - elected=elected, - candidates=candidates, - staged=staged) + relevant=relevant, elected=elected, candidates=candidates, staged=staged + ) async def POST(self, data: Optional[MirrorPost]) -> MirrorPostResponse: log.debug(data) @@ -379,8 +382,10 @@ class MirrorController(SubiquityController): try: await self.find_and_elect_candidate_mirror(self.context) except NoUsableMirrorError: - log.warning("found no usable mirror, expecting the client to" - " give up or to adjust the settings and retry") + log.warning( + "found no usable mirror, expecting the client to" + " give up or to adjust the settings and retry" + ) return MirrorPostResponse.NO_USABLE_MIRROR else: await self.configured() @@ -390,8 +395,9 @@ class MirrorController(SubiquityController): if not data.candidates: raise ValueError("cannot specify an empty list of candidates") uris = data.candidates - self.model.primary_candidates = \ - [self.model.create_primary_candidate(uri) for uri in uris] + self.model.primary_candidates = [ + self.model.create_primary_candidate(uri) for uri in uris + ] if data.staged is not None: self.model.create_primary_candidate(data.staged).stage() @@ -404,11 +410,11 @@ class MirrorController(SubiquityController): # ability to use a mirror for one install without it ending up in # the autoinstall config. Is it worth it though? def ensure_elected_in_candidates(): - if any(map(lambda c: c.uri == data.elected, - self.model.primary_candidates)): + if any( + map(lambda c: c.uri == data.elected, self.model.primary_candidates) + ): return - self.model.primary_candidates.insert( - 0, self.model.primary_elected) + self.model.primary_candidates.insert(0, self.model.primary_elected) if data.candidates is None: ensure_elected_in_candidates() @@ -423,8 +429,7 @@ class MirrorController(SubiquityController): log.debug(data) self.model.disabled_components = set(data) - async def check_mirror_start_POST( - self, cancel_ongoing: bool = False) -> None: + async def check_mirror_start_POST(self, cancel_ongoing: bool = False) -> None: if self.mirror_check is not None and not self.mirror_check.task.done(): if cancel_ongoing: await self.check_mirror_abort_POST() @@ -432,17 +437,19 @@ class MirrorController(SubiquityController): assert False output = io.StringIO() self.mirror_check = MirrorCheck( - uri=self.model.primary_staged.uri, - task=asyncio.create_task(self.run_mirror_testing(output)), - output=output) + uri=self.model.primary_staged.uri, + task=asyncio.create_task(self.run_mirror_testing(output)), + output=output, + ) async def check_mirror_progress_GET(self) -> Optional[MirrorCheckResponse]: if self.mirror_check is None: return None if self.mirror_check.task.done(): if self.mirror_check.task.exception(): - log.warning("Mirror check failed: %r", - self.mirror_check.task.exception()) + log.warning( + "Mirror check failed: %r", self.mirror_check.task.exception() + ) status = MirrorCheckStatus.FAILED else: status = MirrorCheckStatus.OK @@ -450,9 +457,10 @@ class MirrorController(SubiquityController): status = MirrorCheckStatus.RUNNING return MirrorCheckResponse( - url=self.mirror_check.uri, - status=status, - output=self.mirror_check.output.getvalue()) + url=self.mirror_check.uri, + status=status, + output=self.mirror_check.output.getvalue(), + ) async def check_mirror_abort_POST(self) -> None: if self.mirror_check is None: diff --git a/subiquity/server/controllers/network.py b/subiquity/server/controllers/network.py index ccee3bdd..463c8131 100644 --- a/subiquity/server/controllers/network.py +++ b/subiquity/server/controllers/network.py @@ -19,91 +19,74 @@ from typing import List, Optional import aiohttp -from subiquitycore.async_helpers import ( - run_bg_task, - schedule_task, - ) +from subiquity.common.api.client import make_client_for_conn +from subiquity.common.apidef import API, LinkAction, NetEventAPI +from subiquity.common.errorreport import ErrorReportKind +from subiquity.common.types import NetworkStatus, PackageInstallState +from subiquity.server.controller import SubiquityController +from subiquitycore.async_helpers import run_bg_task, schedule_task from subiquitycore.context import with_context from subiquitycore.controllers.network import BaseNetworkController -from subiquitycore.models.network import ( - BondConfig, - StaticConfig, - WLANConfig, - ) - -from subiquity.common.api.client import make_client_for_conn -from subiquity.common.apidef import ( - API, - LinkAction, - NetEventAPI, - ) -from subiquity.common.errorreport import ErrorReportKind -from subiquity.common.types import ( - NetworkStatus, - PackageInstallState, - ) -from subiquity.server.controller import SubiquityController - +from subiquitycore.models.network import BondConfig, StaticConfig, WLANConfig log = logging.getLogger("subiquity.server.controllers.network") MATCH = { - 'type': 'object', - 'properties': { - 'name': {'type': 'string'}, - 'macaddress': {'type': 'string'}, - 'driver': {'type': 'string'}, - }, - 'additionalProperties': False, - } + "type": "object", + "properties": { + "name": {"type": "string"}, + "macaddress": {"type": "string"}, + "driver": {"type": "string"}, + }, + "additionalProperties": False, +} NETPLAN_SCHEMA = { - 'type': 'object', - 'properties': { - 'version': { - 'type': 'integer', - 'minimum': 2, - 'maximum': 2, - }, - 'ethernets': { - 'type': 'object', - 'properties': { - 'match': MATCH, - } - }, - 'wifis': { - 'type': 'object', - 'properties': { - 'match': MATCH, - } - }, - 'bridges': {'type': 'object'}, - 'bonds': {'type': 'object'}, - 'tunnels': {'type': 'object'}, - 'vlans': {'type': 'object'}, + "type": "object", + "properties": { + "version": { + "type": "integer", + "minimum": 2, + "maximum": 2, }, - 'required': ['version'], - } + "ethernets": { + "type": "object", + "properties": { + "match": MATCH, + }, + }, + "wifis": { + "type": "object", + "properties": { + "match": MATCH, + }, + }, + "bridges": {"type": "object"}, + "bonds": {"type": "object"}, + "tunnels": {"type": "object"}, + "vlans": {"type": "object"}, + }, + "required": ["version"], +} class NetworkController(BaseNetworkController, SubiquityController): - endpoint = API.network ai_data = None autoinstall_key = "network" autoinstall_schema = { - 'oneOf': [ + "oneOf": [ NETPLAN_SCHEMA, { - 'type': 'object', - 'properties': { - 'network': NETPLAN_SCHEMA, - }, - 'required': ['network'], + "type": "object", + "properties": { + "network": NETPLAN_SCHEMA, + }, + "required": ["network"], }, - ], - } + ], + } def __init__(self, app): super().__init__(app) @@ -114,25 +97,26 @@ class NetworkController(BaseNetworkController, SubiquityController): self.pending_wlan_devices = set() def maybe_start_install_wpasupplicant(self): - log.debug('maybe_start_install_wpasupplicant') + log.debug("maybe_start_install_wpasupplicant") if self.install_wpasupplicant_task is not None: return self.install_wpasupplicant_task = asyncio.create_task( - self._install_wpasupplicant()) + self._install_wpasupplicant() + ) def wlan_support_install_state(self): - return self.app.package_installer.state_for_pkg('wpasupplicant') + return self.app.package_installer.state_for_pkg("wpasupplicant") async def _install_wpasupplicant(self): if self.opts.dry_run: - await asyncio.sleep(10/self.app.scale_factor) - a = 'DONE' + await asyncio.sleep(10 / self.app.scale_factor) + a = "DONE" for k in self.app.debug_flags: - if k.startswith('wlan_install='): - a = k.split('=', 2)[1] + if k.startswith("wlan_install="): + a = k.split("=", 2)[1] r = getattr(PackageInstallState, a) else: - r = await self.app.package_installer.install_pkg('wpasupplicant') + r = await self.app.package_installer.install_pkg("wpasupplicant") log.debug("wlan_support_install_finished %s", r) self._call_clients("wlan_support_install_finished", r) if r == PackageInstallState.DONE: @@ -153,12 +137,12 @@ class NetworkController(BaseNetworkController, SubiquityController): # # in your autoinstall config. Continue to support that for # backwards compatibility. - if 'network' in self.ai_data: - self.ai_data = self.ai_data['network'] + if "network" in self.ai_data: + self.ai_data = self.ai_data["network"] def start(self): if self.ai_data is not None: - self.model.override_config = {'network': self.ai_data} + self.model.override_config = {"network": self.ai_data} self.apply_config() if self.interactive(): # If interactive, we want edits in the UI to override @@ -193,8 +177,8 @@ class NetworkController(BaseNetworkController, SubiquityController): with context.child("wait_dhcp"): try: await asyncio.wait_for( - asyncio.wait({e.wait() for e in dhcp_events}), - 10) + asyncio.wait({e.wait() for e in dhcp_events}), 10 + ) except asyncio.TimeoutError: pass @@ -208,9 +192,11 @@ class NetworkController(BaseNetworkController, SubiquityController): self.update_initial_configs() self.apply_config(context) else: - log.debug("NetworkManager is enabled and no network" - " autoinstall section was found. Not applying" - " network settings.") + log.debug( + "NetworkManager is enabled and no network" + " autoinstall section was found. Not applying" + " network settings." + ) want_apply_config = False if want_apply_config: with context.child("wait_for_apply"): @@ -233,33 +219,34 @@ class NetworkController(BaseNetworkController, SubiquityController): log.exception("_apply_config failed") self.model.has_network = False self.app.make_apport_report( - ErrorReportKind.NETWORK_FAIL, "applying network") + ErrorReportKind.NETWORK_FAIL, "applying network" + ) if not self.interactive(): raise def make_autoinstall(self): - return self.model.render_config()['network'] + return self.model.render_config()["network"] async def GET(self) -> NetworkStatus: if not self.view_shown: self.apply_config(silent=True) self.view_shown = True - if self.wlan_support_install_state() == \ - PackageInstallState.DONE: + if self.wlan_support_install_state() == PackageInstallState.DONE: devices = self.model.get_all_netdevs() else: devices = [ - dev for dev in self.model.get_all_netdevs() - if dev.type != 'wlan' - ] + dev for dev in self.model.get_all_netdevs() if dev.type != "wlan" + ] return NetworkStatus( devices=[dev.netdev_info() for dev in devices], - wlan_support_install_state=self.wlan_support_install_state()) + wlan_support_install_state=self.wlan_support_install_state(), + ) async def configured(self): self.model.has_network = self.network_event_receiver.has_default_route self.model.needs_wpasupplicant = ( - self.wlan_support_install_state() == PackageInstallState.DONE) + self.wlan_support_install_state() == PackageInstallState.DONE + ) await super().configured() async def POST(self) -> None: @@ -272,20 +259,25 @@ class NetworkController(BaseNetworkController, SubiquityController): return ips async def subscription_PUT(self, socket_path: str) -> None: - log.debug('added subscription %s', socket_path) + log.debug("added subscription %s", socket_path) conn = aiohttp.UnixConnector(socket_path) client = make_client_for_conn(NetEventAPI, conn) lock = asyncio.Lock() self.clients[socket_path] = (client, conn, lock) run_bg_task( self._call_client( - client, conn, lock, "route_watch", - self.network_event_receiver.has_default_route)) + client, + conn, + lock, + "route_watch", + self.network_event_receiver.has_default_route, + ) + ) async def subscription_DELETE(self, socket_path: str) -> None: if socket_path not in self.clients: return - log.debug('removed subscription %s', socket_path) + log.debug("removed subscription %s", socket_path) client, conn, lock = self.clients.pop(socket_path) async with lock: await conn.close() @@ -294,7 +286,7 @@ class NetworkController(BaseNetworkController, SubiquityController): async with lock: log.debug("_call_client %s %s", meth_name, conn.path) if conn.closed: - log.debug('closed') + log.debug("closed") return try: await getattr(client, meth_name).POST(*args) @@ -303,9 +295,8 @@ class NetworkController(BaseNetworkController, SubiquityController): def _call_clients(self, meth_name, *args): for client, conn, lock in self.clients.values(): - log.debug('creating _call_client task %s %s', conn.path, meth_name) - run_bg_task( - self._call_client(client, conn, lock, meth_name, *args)) + log.debug("creating _call_client task %s %s", conn.path, meth_name) + run_bg_task(self._call_client(client, conn, lock, meth_name, *args)) def apply_starting(self): super().apply_starting() @@ -324,22 +315,23 @@ class NetworkController(BaseNetworkController, SubiquityController): self._call_clients("route_watch", has_default_route) def _send_update(self, act, dev): - with self.context.child( - "_send_update", "{} {}".format(act.name, dev.name)): + with self.context.child("_send_update", "{} {}".format(act.name, dev.name)): log.debug("dev_info {} {}".format(dev.name, dev.config)) dev_info = dev.netdev_info() self._call_clients("update_link", act, dev_info) def new_link(self, dev): super().new_link(dev) - if dev.type == 'wlan': + if dev.type == "wlan": self.maybe_start_install_wpasupplicant() state = self.wlan_support_install_state() if state == PackageInstallState.INSTALLING: self.pending_wlan_devices.add(dev) return - elif state in [PackageInstallState.FAILED, - PackageInstallState.NOT_AVAILABLE]: + elif state in [ + PackageInstallState.FAILED, + PackageInstallState.NOT_AVAILABLE, + ]: return # PackageInstallState.DONE falls through self._send_update(LinkAction.NEW, dev) @@ -352,8 +344,9 @@ class NetworkController(BaseNetworkController, SubiquityController): super().del_link(dev) self._send_update(LinkAction.DEL, dev) - async def set_static_config_POST(self, dev_name: str, ip_version: int, - static_config: StaticConfig) -> None: + async def set_static_config_POST( + self, dev_name: str, ip_version: int, static_config: StaticConfig + ) -> None: self.set_static_config(dev_name, ip_version, static_config) async def enable_dhcp_POST(self, dev_name: str, ip_version: int) -> None: @@ -365,9 +358,9 @@ class NetworkController(BaseNetworkController, SubiquityController): async def vlan_PUT(self, dev_name: str, vlan_id: int) -> None: self.add_vlan(dev_name, vlan_id) - async def add_or_edit_bond_POST(self, existing_name: Optional[str], - new_name: str, - bond_config: BondConfig) -> None: + async def add_or_edit_bond_POST( + self, existing_name: Optional[str], new_name: str, bond_config: BondConfig + ) -> None: self.add_or_update_bond(existing_name, new_name, bond_config) async def set_wlan_POST(self, dev_name: str, wlan: WLANConfig) -> None: diff --git a/subiquity/server/controllers/oem.py b/subiquity/server/controllers/oem.py index 7ed24608..793c752d 100644 --- a/subiquity/server/controllers/oem.py +++ b/subiquity/server/controllers/oem.py @@ -17,8 +17,6 @@ import asyncio import logging from typing import List, Optional -from subiquitycore.context import with_context - from subiquity.common.apidef import API from subiquity.common.types import OEMResponse from subiquity.models.oem import OEMMetaPkg @@ -30,13 +28,13 @@ from subiquity.server.types import InstallerChannels from subiquity.server.ubuntu_drivers import ( CommandNotFoundError, get_ubuntu_drivers_interface, - ) +) +from subiquitycore.context import with_context -log = logging.getLogger('subiquity.server.controllers.oem') +log = logging.getLogger("subiquity.server.controllers.oem") class OEMController(SubiquityController): - endpoint = API.oem autoinstall_key = model_name = "oem" @@ -71,23 +69,20 @@ class OEMController(SubiquityController): def start(self) -> None: self._wait_confirmation = asyncio.Event() self.app.hub.subscribe( - InstallerChannels.INSTALL_CONFIRMED, - self._wait_confirmation.set) + InstallerChannels.INSTALL_CONFIRMED, self._wait_confirmation.set + ) self._wait_apt = asyncio.Event() + self.app.hub.subscribe(InstallerChannels.APT_CONFIGURED, self._wait_apt.set) self.app.hub.subscribe( - InstallerChannels.APT_CONFIGURED, - self._wait_apt.set) - self.app.hub.subscribe( - (InstallerChannels.CONFIGURED, "kernel"), - self.kernel_configured_event.set) + (InstallerChannels.CONFIGURED, "kernel"), self.kernel_configured_event.set + ) async def list_and_mark_configured() -> None: await self.load_metapackages_list() await self.ensure_no_kernel_conflict() await self.configured() - self.load_metapkgs_task = asyncio.create_task( - list_and_mark_configured()) + self.load_metapkgs_task = asyncio.create_task(list_and_mark_configured()) def make_autoinstall(self): return self.model.make_autoinstall() @@ -95,17 +90,24 @@ class OEMController(SubiquityController): def load_autoinstall_data(self, *args, **kwargs) -> None: self.model.load_autoinstall_data(*args, **kwargs) - async def wants_oem_kernel(self, pkgname: str, - *, context, overlay) -> bool: - """ For a given package, tell whether it wants the OEM or the default + async def wants_oem_kernel(self, pkgname: str, *, context, overlay) -> bool: + """For a given package, tell whether it wants the OEM or the default kernel flavor. We look for the Ubuntu-Oem-Kernel-Flavour attribute in the package meta-data. If the attribute is present and has the value - "default", then return False. Otherwise, return True. """ + "default", then return False. Otherwise, return True.""" result = await run_curtin_command( - self.app, context, - "in-target", "-t", overlay.mountpoint, "--", - "apt-cache", "show", pkgname, - capture=True, private_mounts=True) + self.app, + context, + "in-target", + "-t", + overlay.mountpoint, + "--", + "apt-cache", + "show", + pkgname, + capture=True, + private_mounts=True, + ) for line in result.stdout.decode("utf-8").splitlines(): if not line.startswith("Ubuntu-Oem-Kernel-Flavour:"): continue @@ -116,8 +118,7 @@ class OEMController(SubiquityController): elif flavor == "oem": return True else: - log.warning("%s wants unexpected kernel flavor: %s", - pkgname, flavor) + log.warning("%s wants unexpected kernel flavor: %s", pkgname, flavor) return True log.warning("%s has no Ubuntu-Oem-Kernel-Flavour", pkgname) @@ -133,8 +134,7 @@ class OEMController(SubiquityController): variant: str = self.app.base_model.source.current.variant fs_controller = self.app.controllers.Filesystem if fs_controller.is_core_boot_classic(): - log.debug("listing of OEM meta-packages disabled on core boot" - " classic") + log.debug("listing of OEM meta-packages disabled on core boot" " classic") self.model.metapkgs = [] return if not self.model.install_on[variant]: @@ -155,14 +155,17 @@ class OEMController(SubiquityController): self.model.metapkgs = [] else: metapkgs: List[str] = await self.ubuntu_drivers.list_oem( - root_dir=d.mountpoint, - context=context) + root_dir=d.mountpoint, context=context + ) self.model.metapkgs = [ OEMMetaPkg( name=name, wants_oem_kernel=await self.wants_oem_kernel( - name, context=context, overlay=d), - ) for name in metapkgs] + name, context=context, overlay=d + ), + ) + for name in metapkgs + ] except OverlayCleanupError: log.exception("Failed to cleanup overlay. Continuing anyway.") @@ -171,7 +174,8 @@ class OEMController(SubiquityController): if pkg.wants_oem_kernel: kernel_model = self.app.base_model.kernel kernel_model.metapkg_name_override = flavor_to_pkgname( - pkg.name, dry_run=self.app.opts.dry_run) + pkg.name, dry_run=self.app.opts.dry_run + ) log.debug("overriding kernel flavor because of OEM") @@ -187,7 +191,8 @@ class OEMController(SubiquityController): # This should be a dialog or something rather than the content of # an exception, really. But this is a simple way to print out # something in autoinstall. - msg = _("""\ + msg = _( + """\ A specific kernel flavor was requested but it cannot be satistified when \ installing on certified hardware. You should either disable the installation of OEM meta-packages using the \ @@ -195,7 +200,8 @@ following autoinstall snippet or let the installer decide which kernel to install. oem: install: false -""") +""" + ) raise RuntimeError(msg) @with_context() diff --git a/subiquity/server/controllers/package.py b/subiquity/server/controllers/package.py index 6fb65413..61fe07f2 100644 --- a/subiquity/server/controllers/package.py +++ b/subiquity/server/controllers/package.py @@ -17,13 +17,12 @@ from subiquity.server.controller import NonInteractiveController class PackageController(NonInteractiveController): - model_name = autoinstall_key = "packages" autoinstall_default = [] autoinstall_schema = { - 'type': 'array', - 'items': {'type': 'string'}, - } + "type": "array", + "items": {"type": "string"}, + } def load_autoinstall_data(self, data): self.model[:] = data diff --git a/subiquity/server/controllers/proxy.py b/subiquity/server/controllers/proxy.py index 19fd9b83..95b4f06c 100644 --- a/subiquity/server/controllers/proxy.py +++ b/subiquity/server/controllers/proxy.py @@ -16,24 +16,22 @@ import logging import os -from subiquitycore.context import with_context - from subiquity.common.apidef import API from subiquity.server.controller import SubiquityController from subiquity.server.types import InstallerChannels +from subiquitycore.context import with_context -log = logging.getLogger('subiquity.server.controllers.proxy') +log = logging.getLogger("subiquity.server.controllers.proxy") class ProxyController(SubiquityController): - endpoint = API.proxy autoinstall_key = model_name = "proxy" autoinstall_schema = { - 'type': ['string', 'null'], - 'format': 'uri', - } + "type": ["string", "null"], + "format": "uri", + } _set_task = None @@ -43,10 +41,8 @@ class ProxyController(SubiquityController): def start(self): if self.model.proxy: - os.environ['http_proxy'] = os.environ['https_proxy'] = \ - self.model.proxy - self._set_task = self.app.hub.broadcast( - InstallerChannels.NETWORK_PROXY_SET) + os.environ["http_proxy"] = os.environ["https_proxy"] = self.model.proxy + self._set_task = self.app.hub.broadcast(InstallerChannels.NETWORK_PROXY_SET) @with_context() async def apply_autoinstall_config(self, context=None): @@ -67,6 +63,6 @@ class ProxyController(SubiquityController): async def POST(self, data: str): self.model.proxy = data - os.environ['http_proxy'] = os.environ['https_proxy'] = data + os.environ["http_proxy"] = os.environ["https_proxy"] = data self.app.hub.broadcast(InstallerChannels.NETWORK_PROXY_SET) await self.configured() diff --git a/subiquity/server/controllers/refresh.py b/subiquity/server/controllers/refresh.py index e5b4bdbc..2d640da7 100644 --- a/subiquity/server/controllers/refresh.py +++ b/subiquity/server/controllers/refresh.py @@ -21,32 +21,21 @@ from typing import Tuple import requests.exceptions -from subiquitycore.async_helpers import ( - schedule_task, - SingleInstanceTask, - ) -from subiquitycore.context import with_context -from subiquitycore.lsb_release import lsb_release - from subiquity.common.apidef import API -from subiquity.common.types import ( - Change, - RefreshCheckState, - RefreshStatus, - ) -from subiquity.server.controller import ( - SubiquityController, - ) +from subiquity.common.types import Change, RefreshCheckState, RefreshStatus +from subiquity.server.controller import SubiquityController from subiquity.server.snapdapi import ( SnapAction, SnapActionRequest, TaskStatus, post_and_wait, - ) +) from subiquity.server.types import InstallerChannels +from subiquitycore.async_helpers import SingleInstanceTask, schedule_task +from subiquitycore.context import with_context +from subiquitycore.lsb_release import lsb_release - -log = logging.getLogger('subiquity.server.controllers.refresh') +log = logging.getLogger("subiquity.server.controllers.refresh") class SnapChannelSource(enum.Enum): @@ -57,18 +46,17 @@ class SnapChannelSource(enum.Enum): class RefreshController(SubiquityController): - endpoint = API.refresh autoinstall_key = "refresh-installer" autoinstall_schema = { - 'type': 'object', - 'properties': { - 'update': {'type': 'boolean'}, - 'channel': {'type': 'string'}, - }, - 'additionalProperties': False, - } + "type": "object", + "properties": { + "update": {"type": "boolean"}, + "channel": {"type": "string"}, + }, + "additionalProperties": False, + } def __init__(self, app): super().__init__(app) @@ -77,8 +65,9 @@ class RefreshController(SubiquityController): self.configure_task = None self.check_task = None self.status = RefreshStatus(availability=RefreshCheckState.UNKNOWN) - self.app.hub.subscribe(InstallerChannels.SNAPD_NETWORK_CHANGE, - self.snapd_network_changed) + self.app.hub.subscribe( + InstallerChannels.SNAPD_NETWORK_CHANGE, self.snapd_network_changed + ) def load_autoinstall_data(self, data): if data is not None: @@ -86,8 +75,8 @@ class RefreshController(SubiquityController): @property def active(self): - if 'update' in self.ai_data: - return self.ai_data['update'] + if "update" in self.ai_data: + return self.ai_data["update"] else: return self.interactive() @@ -96,7 +85,8 @@ class RefreshController(SubiquityController): return self.configure_task = schedule_task(self.configure_snapd()) self.check_task = SingleInstanceTask( - self.check_for_update, propagate_errors=False) + self.check_for_update, propagate_errors=False + ) self.check_task.start_sync() @with_context() @@ -112,8 +102,7 @@ class RefreshController(SubiquityController): change_id = await self.start_update(context=context) while True: change = await self.get_progress(change_id) - if change.status not in [ - TaskStatus.DO, TaskStatus.DOING, TaskStatus.DONE]: + if change.status not in [TaskStatus.DO, TaskStatus.DOING, TaskStatus.DONE]: raise Exception(f"update failed: {change.status}") await asyncio.sleep(0.1) @@ -132,21 +121,24 @@ class RefreshController(SubiquityController): log.exception("getting snap details") return self.status.current_snap_version = snap.version - for k in 'channel', 'revision', 'version': - self.app.note_data_for_apport( - "Snap" + k.title(), getattr(snap, k)) + for k in "channel", "revision", "version": + self.app.note_data_for_apport("Snap" + k.title(), getattr(snap, k)) subcontext.description = "current version of snap is: %r" % ( - self.status.current_snap_version) + self.status.current_snap_version + ) (channel, source) = self.get_refresh_channel() if source == SnapChannelSource.NOT_FOUND: log.debug("no refresh channel found") return info = lsb_release(dry_run=self.app.opts.dry_run) - expected_channel = 'stable/ubuntu-' + info['release'] - if source == SnapChannelSource.DISK_INFO_FILE \ - and snap.channel != expected_channel: - log.debug(f"snap tracking {snap.channel}, not resetting based " - "on .disk/info") + expected_channel = "stable/ubuntu-" + info["release"] + if ( + source == SnapChannelSource.DISK_INFO_FILE + and snap.channel != expected_channel + ): + log.debug( + f"snap tracking {snap.channel}, not resetting based " "on .disk/info" + ) return desc = "switching {} to {}".format(self.snap_name, channel) with context.child("switching", desc) as subcontext: @@ -154,8 +146,8 @@ class RefreshController(SubiquityController): await post_and_wait( self.app.snapdapi, self.app.snapdapi.v2.snaps[self.snap_name].POST, - SnapActionRequest( - action=SnapAction.SWITCH, channel=channel)) + SnapActionRequest(action=SnapAction.SWITCH, channel=channel), + ) except requests.exceptions.RequestException: log.exception("switching channels") return @@ -163,35 +155,33 @@ class RefreshController(SubiquityController): def get_refresh_channel(self) -> Tuple[str, SnapChannelSource]: """Return the channel we should refresh subiquity to.""" - channel = self.app.kernel_cmdline.get('subiquity-channel') + channel = self.app.kernel_cmdline.get("subiquity-channel") if channel is not None: - log.debug( - "get_refresh_channel: found %s on kernel cmdline", channel) + log.debug("get_refresh_channel: found %s on kernel cmdline", channel) return (channel, SnapChannelSource.CMDLINE) - if 'channel' in self.ai_data: - return (self.ai_data['channel'], SnapChannelSource.AUTOINSTALL) + if "channel" in self.ai_data: + return (self.ai_data["channel"], SnapChannelSource.AUTOINSTALL) - info_file = '/cdrom/.disk/info' + info_file = "/cdrom/.disk/info" try: fp = open(info_file) except FileNotFoundError: if self.opts.dry_run: info = ( 'Ubuntu-Server 18.04.2 LTS "Bionic Beaver" - ' - 'Release amd64 (20190214.3)') + "Release amd64 (20190214.3)" + ) else: - log.debug( - "get_refresh_channel: failed to find .disk/info file") + log.debug("get_refresh_channel: failed to find .disk/info file") return (None, SnapChannelSource.NOT_FOUND) else: with fp: info = fp.read() release = info.split()[1] - return ('stable/ubuntu-' + release, SnapChannelSource.DISK_INFO_FILE) + return ("stable/ubuntu-" + release, SnapChannelSource.DISK_INFO_FILE) def snapd_network_changed(self): - if self.active and \ - self.status.availability == RefreshCheckState.UNKNOWN: + if self.active and self.status.availability == RefreshCheckState.UNKNOWN: self.check_task.start_sync() @with_context() @@ -202,7 +192,7 @@ class RefreshController(SubiquityController): self.status.availability = RefreshCheckState.UNAVAILABLE return try: - result = await self.app.snapdapi.v2.find.GET(select='refresh') + result = await self.app.snapdapi.v2.find.GET(select="refresh") except requests.exceptions.RequestException: log.exception("checking for snap update failed") context.description = "checking for snap update failed" @@ -214,8 +204,8 @@ class RefreshController(SubiquityController): continue self.status.new_snap_version = snap.version context.description = ( - "new version of snap available: %r" - % self.status.new_snap_version) + "new version of snap available: %r" % self.status.new_snap_version + ) self.status.availability = RefreshCheckState.AVAILABLE return else: @@ -225,7 +215,8 @@ class RefreshController(SubiquityController): @with_context() async def start_update(self, context): change_id = await self.app.snapdapi.v2.snaps[self.snap_name].POST( - SnapActionRequest(action=SnapAction.REFRESH, ignore_running=True)) + SnapActionRequest(action=SnapAction.REFRESH, ignore_running=True) + ) context.description = "change id: {}".format(change_id) return change_id diff --git a/subiquity/server/controllers/reporting.py b/subiquity/server/controllers/reporting.py index 7cf1b2de..90ad1ee7 100644 --- a/subiquity/server/controllers/reporting.py +++ b/subiquity/server/controllers/reporting.py @@ -16,15 +16,8 @@ import copy import logging -from curtin.reporter import ( - available_handlers, - update_configuration, - ) -from curtin.reporter.events import ( - report_finish_event, - report_start_event, - status, - ) +from curtin.reporter import available_handlers, update_configuration +from curtin.reporter.events import report_finish_event, report_start_event, status from curtin.reporter.handlers import LogHandler as CurtinLogHandler from subiquity.server.controller import NonInteractiveController @@ -33,33 +26,32 @@ from subiquity.server.controller import NonInteractiveController class LogHandler(CurtinLogHandler): def publish_event(self, event): level = getattr(logging, event.level) - logger = logging.getLogger('') + logger = logging.getLogger("") logger.log(level, event.as_string()) -available_handlers.unregister_item('log') -available_handlers.register_item('log', LogHandler) +available_handlers.unregister_item("log") +available_handlers.register_item("log", LogHandler) INITIAL_CONFIG = { - 'logging': {'type': 'log'}, - } -NON_INTERACTIVE_CONFIG = {'builtin': {'type': 'print'}} + "logging": {"type": "log"}, +} +NON_INTERACTIVE_CONFIG = {"builtin": {"type": "print"}} class ReportingController(NonInteractiveController): - autoinstall_key = "reporting" autoinstall_schema = { - 'type': 'object', - 'additionalProperties': { - 'type': 'object', - 'properties': { - 'type': {'type': 'string'}, - }, - 'required': ['type'], - 'additionalProperties': True, + "type": "object", + "additionalProperties": { + "type": "object", + "properties": { + "type": {"type": "string"}, }, - } + "required": ["type"], + "additionalProperties": True, + }, + } def __init__(self, app): super().__init__(app) @@ -77,10 +69,10 @@ class ReportingController(NonInteractiveController): update_configuration(self.config) def report_start_event(self, context, description): - report_start_event( - context.full_name(), description, level=context.level) + report_start_event(context.full_name(), description, level=context.level) def report_finish_event(self, context, description, result): result = getattr(status, result.name, status.WARN) report_finish_event( - context.full_name(), description, result, level=context.level) + context.full_name(), description, result, level=context.level + ) diff --git a/subiquity/server/controllers/shutdown.py b/subiquity/server/controllers/shutdown.py index aebd6b6e..804c49b0 100644 --- a/subiquity/server/controllers/shutdown.py +++ b/subiquity/server/controllers/shutdown.py @@ -19,30 +19,22 @@ import os import platform import subprocess -from subiquitycore.file_util import ( - open_perms, - set_log_perms, -) -from subiquitycore.async_helpers import run_bg_task -from subiquitycore.context import with_context - from subiquity.common.apidef import API from subiquity.common.types import ShutdownMode from subiquity.server.controller import SubiquityController from subiquity.server.controllers.install import ApplicationState from subiquity.server.types import InstallerChannels +from subiquitycore.async_helpers import run_bg_task +from subiquitycore.context import with_context +from subiquitycore.file_util import open_perms, set_log_perms log = logging.getLogger("subiquity.server.controllers.shutdown") class ShutdownController(SubiquityController): - endpoint = API.shutdown - autoinstall_key = 'shutdown' - autoinstall_schema = { - 'type': 'string', - 'enum': ['reboot', 'poweroff'] - } + autoinstall_key = "shutdown" + autoinstall_schema = {"type": "string", "enum": ["reboot", "poweroff"]} def __init__(self, app): super().__init__(app) @@ -57,9 +49,9 @@ class ShutdownController(SubiquityController): self.mode = ShutdownMode.REBOOT def load_autoinstall_data(self, data): - if data == 'reboot': + if data == "reboot": self.mode = ShutdownMode.REBOOT - elif data == 'poweroff': + elif data == "poweroff": self.mode = ShutdownMode.POWEROFF async def POST(self, mode: ShutdownMode, immediate: bool = False): @@ -93,42 +85,43 @@ class ShutdownController(SubiquityController): @with_context() async def copy_logs_to_target(self, context): - if self.opts.dry_run and 'copy-logs-fail' in self.app.debug_flags: + if self.opts.dry_run and "copy-logs-fail" in self.app.debug_flags: raise PermissionError() if self.app.controllers.Filesystem.reset_partition_only: return - target_logs = os.path.join( - self.app.base_model.target, 'var/log/installer') + target_logs = os.path.join(self.app.base_model.target, "var/log/installer") if self.opts.dry_run: os.makedirs(target_logs, exist_ok=True) else: # Preserve ephemeral boot cloud-init logs if applicable cloudinit_logs = [ - cloudinit_log - for cloudinit_log in ( + cloudinit_log + for cloudinit_log in ( "/var/log/cloud-init.log", - "/var/log/cloud-init-output.log" - ) - if os.path.exists(cloudinit_log) + "/var/log/cloud-init-output.log", + ) + if os.path.exists(cloudinit_log) ] if cloudinit_logs: await self.app.command_runner.run( - ['cp', '-a'] + cloudinit_logs + ['/var/log/installer']) + ["cp", "-a"] + cloudinit_logs + ["/var/log/installer"] + ) await self.app.command_runner.run( - ['cp', '-aT', '/var/log/installer', target_logs]) + ["cp", "-aT", "/var/log/installer", target_logs] + ) # Close the permissions from group writes on the target. set_log_perms(target_logs, isdir=True, group_write=False) - journal_txt = os.path.join(target_logs, 'installer-journal.txt') + journal_txt = os.path.join(target_logs, "installer-journal.txt") try: with open_perms(journal_txt) as output: await self.app.command_runner.run( - ['journalctl', '-b'], - stdout=output, stderr=subprocess.STDOUT) + ["journalctl", "-b"], stdout=output, stderr=subprocess.STDOUT + ) except Exception: log.exception("saving journal failed") - @with_context(description='mode={self.mode.name}') + @with_context(description="mode={self.mode.name}") async def shutdown(self, context): await self.app.hub.abroadcast(InstallerChannels.PRE_SHUTDOWN) # As PRE_SHUTDOWN is supposed to be as close as possible to the @@ -138,9 +131,8 @@ class ShutdownController(SubiquityController): self.app.exit() else: if self.app.state == ApplicationState.DONE: - if platform.machine() == 's390x': - await self.app.command_runner.run( - ["chreipl", "/target/boot"]) + if platform.machine() == "s390x": + await self.app.command_runner.run(["chreipl", "/target/boot"]) if self.mode == ShutdownMode.REBOOT: await self.app.command_runner.run(["/sbin/reboot"]) elif self.mode == ShutdownMode.POWEROFF: diff --git a/subiquity/server/controllers/snaplist.py b/subiquity/server/controllers/snaplist.py index 1b174914..949084e6 100644 --- a/subiquity/server/controllers/snaplist.py +++ b/subiquity/server/controllers/snaplist.py @@ -18,36 +18,28 @@ import logging from typing import List import attr - import requests.exceptions -from subiquitycore.async_helpers import ( - schedule_task, - ) -from subiquitycore.context import with_context - from subiquity.common.apidef import API from subiquity.common.types import ( SnapCheckState, SnapInfo, SnapListResponse, SnapSelection, - ) -from subiquity.server.controller import ( - SubiquityController, - ) +) +from subiquity.server.controller import SubiquityController from subiquity.server.types import InstallerChannels +from subiquitycore.async_helpers import schedule_task +from subiquitycore.context import with_context - -log = logging.getLogger('subiquity.server.controllers.snaplist') +log = logging.getLogger("subiquity.server.controllers.snaplist") class SnapListFetchError(Exception): - """ Exception to raise when the list of snaps could not be fetched. """ + """Exception to raise when the list of snaps could not be fetched.""" class SnapdSnapInfoLoader: - def __init__(self, model, snapd, store_section, context): self.model = model self.store_section = store_section @@ -62,8 +54,8 @@ class SnapdSnapInfoLoader: self.load_list_task_created = asyncio.Event() def _fetch_list_ended(self) -> bool: - """ Tells whether the snap list fetch task has ended without being - cancelled. """ + """Tells whether the snap list fetch task has ended without being + cancelled.""" if None not in self.tasks: return False task = self.get_snap_list_task() @@ -72,13 +64,13 @@ class SnapdSnapInfoLoader: return task.done() def fetch_list_completed(self) -> bool: - """ Tells whether the snap list fetch task has completed. """ + """Tells whether the snap list fetch task has completed.""" if not self._fetch_list_ended(): return False return not self.get_snap_list_task().exception() def fetch_list_failed(self) -> bool: - """ Tells whether the snap list fetch task has failed. """ + """Tells whether the snap list fetch task has failed.""" if not self._fetch_list_ended(): return False return bool(self.get_snap_list_task().exception()) @@ -101,14 +93,14 @@ class SnapdSnapInfoLoader: while self.pending_snaps: snap = self.pending_snaps.pop(0) task = self.tasks[snap] = schedule_task( - self._fetch_info_for_snap(snap=snap)) + self._fetch_info_for_snap(snap=snap) + ) await task @with_context(name="list") async def _load_list(self, context=None): try: - result = await self.snapd.get( - 'v2/find', section=self.store_section) + result = await self.snapd.get("v2/find", section=self.store_section) except requests.exceptions.RequestException: raise SnapListFetchError self.model.load_find_data(result) @@ -120,7 +112,7 @@ class SnapdSnapInfoLoader: @with_context(name="fetch/{snap.name}") async def _fetch_info_for_snap(self, snap, context=None): try: - data = await self.snapd.get('v2/find', name=snap.name) + data = await self.snapd.get("v2/find", name=snap.name) except requests.exceptions.RequestException: log.exception("loading snap info failed") # XXX something better here? @@ -134,52 +126,57 @@ class SnapdSnapInfoLoader: if snap not in self.tasks: if snap in self.pending_snaps: self.pending_snaps.remove(snap) - self.tasks[snap] = schedule_task( - self._fetch_info_for_snap(snap=snap)) + self.tasks[snap] = schedule_task(self._fetch_info_for_snap(snap=snap)) return self.tasks[snap] class SnapListController(SubiquityController): - endpoint = API.snaplist autoinstall_key = "snaps" autoinstall_default = [] autoinstall_schema = { - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'name': {'type': 'string'}, - 'channel': {'type': 'string'}, - 'classic': {'type': 'boolean'}, - }, - 'required': ['name'], - 'additionalProperties': False, + "type": "array", + "items": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "channel": {"type": "string"}, + "classic": {"type": "boolean"}, }, - } + "required": ["name"], + "additionalProperties": False, + }, + } model_name = "snaplist" - interactive_for_variants = {'server'} + interactive_for_variants = {"server"} def _make_loader(self): return SnapdSnapInfoLoader( - self.model, self.app.snapd, self.opts.snap_section, - self.context.child("loader")) + self.model, + self.app.snapd, + self.opts.snap_section, + self.context.child("loader"), + ) def __init__(self, app): super().__init__(app) self.loader = self._make_loader() self.app.hub.subscribe( - InstallerChannels.SNAPD_NETWORK_CHANGE, self.snapd_network_changed) + InstallerChannels.SNAPD_NETWORK_CHANGE, self.snapd_network_changed + ) def load_autoinstall_data(self, ai_data): to_install = [] for snap in ai_data: - to_install.append(SnapSelection( - name=snap['name'], - channel=snap.get('channel', 'stable'), - classic=snap.get('classic', False))) + to_install.append( + SnapSelection( + name=snap["name"], + channel=snap.get("channel", "stable"), + classic=snap.get("classic", False), + ) + ) self.model.set_installed_list(to_install) def snapd_network_changed(self): @@ -198,8 +195,10 @@ class SnapListController(SubiquityController): return [attr.asdict(sel) for sel in self.model.selections] async def GET(self, wait: bool = False) -> SnapListResponse: - if self.loader.fetch_list_failed() \ - or not self.app.base_model.network.has_network: + if ( + self.loader.fetch_list_failed() + or not self.app.base_model.network.has_network + ): return SnapListResponse(status=SnapCheckState.FAILED) if not self.loader.fetch_list_completed() and not wait: return SnapListResponse(status=SnapCheckState.LOADING) @@ -218,7 +217,8 @@ class SnapListController(SubiquityController): return SnapListResponse( status=SnapCheckState.DONE, snaps=self.model.get_snap_list(), - selections=self.model.selections) + selections=self.model.selections, + ) async def POST(self, data: List[SnapSelection]): log.debug(data) diff --git a/subiquity/server/controllers/source.py b/subiquity/server/controllers/source.py index 684038da..2c326474 100644 --- a/subiquity/server/controllers/source.py +++ b/subiquity/server/controllers/source.py @@ -14,31 +14,28 @@ # along with this program. If not, see . import contextlib -from typing import Any, Optional import os +from typing import Any, Optional from curtin.commands.extract import ( AbstractSourceHandler, - get_handler_for_source, TrivialSourceHandler, - ) + get_handler_for_source, +) from curtin.util import sanitize_source from subiquity.common.apidef import API -from subiquity.common.types import ( - SourceSelection, - SourceSelectionAndSetting, - ) +from subiquity.common.types import SourceSelection, SourceSelectionAndSetting from subiquity.server.controller import SubiquityController from subiquity.server.types import InstallerChannels def _translate(d, lang): if lang: - for lang in lang, lang.split('_', 1)[0]: + for lang in lang, lang.split("_", 1)[0]: if lang in d: return d[lang] - return _(d['en']) + return _(d["en"]) def convert_source(source, lang): @@ -49,26 +46,26 @@ def convert_source(source, lang): id=source.id, size=size, variant=source.variant, - default=source.default) + default=source.default, + ) class SourceController(SubiquityController): - model_name = "source" endpoint = API.source autoinstall_key = "source" autoinstall_schema = { - "type": "object", - "properties": { - "search_drivers": { - "type": "boolean", - }, - "id": { - "type": "string", - }, - }, + "type": "object", + "properties": { + "search_drivers": { + "type": "boolean", + }, + "id": { + "type": "string", + }, + }, } # Defaults to true for backward compatibility with existing autoinstall # configurations. Back then, then users were able to install third-party @@ -83,9 +80,9 @@ class SourceController(SubiquityController): def make_autoinstall(self): return { - "search_drivers": self.model.search_drivers, - "id": self.model.current.id, - } + "search_drivers": self.model.search_drivers, + "id": self.model.current.id, + } def load_autoinstall_data(self, data: Any) -> None: if data is None: @@ -103,7 +100,7 @@ class SourceController(SubiquityController): self.ai_source_id = data.get("id") def start(self): - path = '/cdrom/casper/install-sources.yaml' + path = "/cdrom/casper/install-sources.yaml" if self.app.opts.source_catalog is not None: path = self.app.opts.source_catalog if not os.path.exists(path): @@ -112,41 +109,40 @@ class SourceController(SubiquityController): self.model.load_from_file(fp) # Assign the current source if hinted by autoinstall. if self.ai_source_id is not None: - self.model.current = self.model.get_matching_source( - self.ai_source_id) + self.model.current = self.model.get_matching_source(self.ai_source_id) self.app.hub.subscribe( - (InstallerChannels.CONFIGURED, 'locale'), self._set_locale) + (InstallerChannels.CONFIGURED, "locale"), self._set_locale + ) def _set_locale(self): current = self.app.base_model.locale.selected_language - self.model.lang = current.split('_')[0] + self.model.lang = current.split("_")[0] async def GET(self) -> SourceSelectionAndSetting: cur_lang = self.app.base_model.locale.selected_language - cur_lang = cur_lang.rsplit('.', 1)[0] + cur_lang = cur_lang.rsplit(".", 1)[0] return SourceSelectionAndSetting( - [ - convert_source(source, cur_lang) - for source in self.model.sources - ], + [convert_source(source, cur_lang) for source in self.model.sources], self.model.current.id, - search_drivers=self.model.search_drivers) + search_drivers=self.model.search_drivers, + ) - def get_handler(self, variation_name: Optional[str] = None) \ - -> AbstractSourceHandler: + def get_handler( + self, variation_name: Optional[str] = None + ) -> AbstractSourceHandler: handler = get_handler_for_source( - sanitize_source(self.model.get_source(variation_name))) + sanitize_source(self.model.get_source(variation_name)) + ) if self.app.opts.dry_run: - handler = TrivialSourceHandler('/') + handler = TrivialSourceHandler("/") return handler async def configured(self): await super().configured() self.app.base_model.set_source_variant(self.model.current.variant) - async def POST(self, source_id: str, - search_drivers: bool = False) -> None: + async def POST(self, source_id: str, search_drivers: bool = False) -> None: self.model.search_drivers = search_drivers with contextlib.suppress(KeyError): self.model.current = self.model.get_matching_source(source_id) diff --git a/subiquity/server/controllers/ssh.py b/subiquity/server/controllers/ssh.py index d3ff5772..3bd84a78 100644 --- a/subiquity/server/controllers/ssh.py +++ b/subiquity/server/controllers/ssh.py @@ -22,35 +22,30 @@ from subiquity.common.types import ( SSHFetchIdResponse, SSHFetchIdStatus, SSHIdentity, - ) +) from subiquity.server.controller import SubiquityController -from subiquity.server.ssh import ( - SSHFetchError, - SSHKeyFetcher, - DryRunSSHKeyFetcher, - ) +from subiquity.server.ssh import DryRunSSHKeyFetcher, SSHFetchError, SSHKeyFetcher -log = logging.getLogger('subiquity.server.controllers.ssh') +log = logging.getLogger("subiquity.server.controllers.ssh") class SSHController(SubiquityController): - endpoint = API.ssh autoinstall_key = model_name = "ssh" autoinstall_schema = { - 'type': 'object', - 'properties': { - 'install-server': {'type': 'boolean'}, - 'authorized-keys': { - 'type': 'array', - 'items': {'type': 'string'}, - }, - 'allow-pw': {'type': 'boolean'}, + "type": "object", + "properties": { + "install-server": {"type": "boolean"}, + "authorized-keys": { + "type": "array", + "items": {"type": "string"}, + }, + "allow-pw": {"type": "boolean"}, }, } - interactive_for_variants = {'server'} + interactive_for_variants = {"server"} def __init__(self, app): super().__init__(app) @@ -62,22 +57,21 @@ class SSHController(SubiquityController): def load_autoinstall_data(self, data): if data is None: return - self.model.install_server = data.get('install-server', False) - self.model.authorized_keys = data.get('authorized-keys', []) - self.model.pwauth = data.get( - 'allow-pw', not self.model.authorized_keys) + self.model.install_server = data.get("install-server", False) + self.model.authorized_keys = data.get("authorized-keys", []) + self.model.pwauth = data.get("allow-pw", not self.model.authorized_keys) def make_autoinstall(self): return { - 'install-server': self.model.install_server, - 'authorized-keys': self.model.authorized_keys, - 'allow-pw': self.model.pwauth, - } + "install-server": self.model.install_server, + "authorized-keys": self.model.authorized_keys, + "allow-pw": self.model.pwauth, + } async def GET(self) -> SSHData: return SSHData( - install_server=self.model.install_server, - allow_pw=self.model.pwauth) + install_server=self.model.install_server, allow_pw=self.model.pwauth + ) async def POST(self, data: SSHData) -> None: self.model.install_server = data.install_server @@ -90,22 +84,26 @@ class SSHController(SubiquityController): try: for key_material in await self.fetcher.fetch_keys_for_id(user_id): - fingerprint = await self.fetcher.gen_fingerprint_for_key( - key_material) + fingerprint = await self.fetcher.gen_fingerprint_for_key(key_material) fingerprint = fingerprint.replace( - f'# ssh-import-id {user_id}', '' - ).strip() + f"# ssh-import-id {user_id}", "" + ).strip() - key_type, key, key_comment = key_material.split( - ' ', maxsplit=2) - identities.append(SSHIdentity( - key_type=key_type, key=key, key_comment=key_comment, - key_fingerprint=fingerprint - )) - return SSHFetchIdResponse(status=SSHFetchIdStatus.OK, - identities=identities, error=None) + key_type, key, key_comment = key_material.split(" ", maxsplit=2) + identities.append( + SSHIdentity( + key_type=key_type, + key=key, + key_comment=key_comment, + key_fingerprint=fingerprint, + ) + ) + return SSHFetchIdResponse( + status=SSHFetchIdStatus.OK, identities=identities, error=None + ) except SSHFetchError as exc: - return SSHFetchIdResponse(status=exc.status, - identities=None, error=exc.reason) + return SSHFetchIdResponse( + status=exc.status, identities=None, error=exc.reason + ) diff --git a/subiquity/server/controllers/tests/test_ad.py b/subiquity/server/controllers/tests/test_ad.py index 0976197f..425d2d79 100644 --- a/subiquity/server/controllers/tests/test_ad.py +++ b/subiquity/server/controllers/tests/test_ad.py @@ -13,10 +13,8 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from unittest import ( - TestCase, - IsolatedAsyncioTestCase, -) +from unittest import IsolatedAsyncioTestCase, TestCase + from subiquity.common.types import ( AdAdminNameValidation, AdConnectionInfo, @@ -24,109 +22,113 @@ from subiquity.common.types import ( AdJoinResult, AdPasswordValidation, ) -from subiquity.server.controllers.ad import ( - AdController, - AdValidators, -) from subiquity.models.ad import AdModel +from subiquity.server.controllers.ad import AdController, AdValidators from subiquitycore.tests.mocks import make_app class TestADValidation(TestCase): def test_password(self): # We curently only check for empty passwords. - passwd = '' + passwd = "" result = AdValidators.password(passwd) self.assertEqual(AdPasswordValidation.EMPTY, result) - passwd = 'u' + passwd = "u" result = AdValidators.password(passwd) self.assertEqual(AdPasswordValidation.OK, result) def test_domain_name(self): - domain = '' + domain = "" result = AdValidators.domain_name(domain) self.assertEqual({AdDomainNameValidation.EMPTY}, result) - domain = 'u'*64 + domain = "u" * 64 result = AdValidators.domain_name(domain) self.assertEqual({AdDomainNameValidation.TOO_LONG}, result) - domain = '..ubuntu.com' + domain = "..ubuntu.com" result = AdValidators.domain_name(domain) - self.assertEqual({AdDomainNameValidation.MULTIPLE_DOTS, - AdDomainNameValidation.START_DOT}, result) + self.assertEqual( + {AdDomainNameValidation.MULTIPLE_DOTS, AdDomainNameValidation.START_DOT}, + result, + ) - domain = '.ubuntu.com.' + domain = ".ubuntu.com." result = AdValidators.domain_name(domain) - self.assertEqual({AdDomainNameValidation.START_DOT, - AdDomainNameValidation.END_DOT}, result) + self.assertEqual( + {AdDomainNameValidation.START_DOT, AdDomainNameValidation.END_DOT}, result + ) - domain = '-ubuntu.com-' + domain = "-ubuntu.com-" result = AdValidators.domain_name(domain) - self.assertEqual({AdDomainNameValidation.START_HYPHEN, - AdDomainNameValidation.END_HYPHEN}, result) + self.assertEqual( + {AdDomainNameValidation.START_HYPHEN, AdDomainNameValidation.END_HYPHEN}, + result, + ) - domain = 'ubuntu^pro.com' + domain = "ubuntu^pro.com" result = AdValidators.domain_name(domain) self.assertEqual({AdDomainNameValidation.INVALID_CHARS}, result) - domain = 'ubuntu-p^ro.com.' + domain = "ubuntu-p^ro.com." result = AdValidators.domain_name(domain) - self.assertEqual({AdDomainNameValidation.INVALID_CHARS, - AdDomainNameValidation.END_DOT}, result) + self.assertEqual( + {AdDomainNameValidation.INVALID_CHARS, AdDomainNameValidation.END_DOT}, + result, + ) - domain = 'ubuntupro.com' + domain = "ubuntupro.com" result = AdValidators.domain_name(domain) self.assertEqual({AdDomainNameValidation.OK}, result) def test_username(self): - admin = '' + admin = "" result = AdValidators.admin_user_name(admin) self.assertEqual(AdAdminNameValidation.EMPTY, result) - admin = 'ubuntu;pro' + admin = "ubuntu;pro" result = AdValidators.admin_user_name(admin) self.assertEqual(AdAdminNameValidation.INVALID_CHARS, result) - admin = 'ubuntu:pro' + admin = "ubuntu:pro" result = AdValidators.admin_user_name(admin) self.assertEqual(AdAdminNameValidation.INVALID_CHARS, result) - admin = '=ubuntu' + admin = "=ubuntu" result = AdValidators.admin_user_name(admin) self.assertEqual(AdAdminNameValidation.INVALID_CHARS, result) - admin = 'ubuntu@pro' + admin = "ubuntu@pro" result = AdValidators.admin_user_name(admin) self.assertEqual(AdAdminNameValidation.INVALID_CHARS, result) - admin = 'ubuntu+pro' + admin = "ubuntu+pro" result = AdValidators.admin_user_name(admin) self.assertEqual(AdAdminNameValidation.INVALID_CHARS, result) - admin = 'ubuntu\\' + admin = "ubuntu\\" result = AdValidators.admin_user_name(admin) self.assertEqual(AdAdminNameValidation.INVALID_CHARS, result) - admin = 'ubuntu\"pro' + admin = 'ubuntu"pro' result = AdValidators.admin_user_name(admin) self.assertEqual(AdAdminNameValidation.INVALID_CHARS, result) - admin = 'ubuntu[pro' + admin = "ubuntu[pro" result = AdValidators.admin_user_name(admin) self.assertEqual(AdAdminNameValidation.INVALID_CHARS, result) - admin = 'ubuntu>' + admin = "ubuntu>" result = AdValidators.admin_user_name(admin) self.assertEqual(AdAdminNameValidation.INVALID_CHARS, result) - admin = 'ubuntu*pro' + admin = "ubuntu*pro" result = AdValidators.admin_user_name(admin) self.assertEqual(AdAdminNameValidation.INVALID_CHARS, result) # Notice that lowercase is not required. - admin = r'$Ubuntu{}' + admin = r"$Ubuntu{}" result = AdValidators.admin_user_name(admin) self.assertEqual(AdAdminNameValidation.OK, result) @@ -144,36 +146,44 @@ class TestAdJoin(IsolatedAsyncioTestCase): async def test_join_Unknown(self): # Result remains UNKNOWN while AdController.join_domain is not called. - self.controller.model.set(AdConnectionInfo(domain_name='ubuntu.com', - admin_name='Helper', - password='1234')) + self.controller.model.set( + AdConnectionInfo( + domain_name="ubuntu.com", admin_name="Helper", password="1234" + ) + ) result = await self.controller.join_result_GET(wait=False) self.assertEqual(result, AdJoinResult.UNKNOWN) async def test_join_OK(self): # The equivalent of a successful POST - self.controller.model.set(AdConnectionInfo(domain_name='ubuntu.com', - admin_name='Helper', - password='1234')) + self.controller.model.set( + AdConnectionInfo( + domain_name="ubuntu.com", admin_name="Helper", password="1234" + ) + ) # Mimics a client requesting the join result. Blocking by default. result = self.controller.join_result_GET() # Mimics a calling from the install controller. - await self.controller.join_domain('this', 'AD Join') + await self.controller.join_domain("this", "AD Join") self.assertEqual(await result, AdJoinResult.OK) async def test_join_Join_Error(self): - self.controller.model.set(AdConnectionInfo(domain_name='jubuntu.com', - admin_name='Helper', - password='1234')) - await self.controller.join_domain('this', 'AD Join') + self.controller.model.set( + AdConnectionInfo( + domain_name="jubuntu.com", admin_name="Helper", password="1234" + ) + ) + await self.controller.join_domain("this", "AD Join") result = await self.controller.join_result_GET(wait=True) self.assertEqual(result, AdJoinResult.JOIN_ERROR) async def test_join_Pam_Error(self): - self.controller.model.set(AdConnectionInfo(domain_name='pubuntu.com', - admin_name='Helper', - password='1234')) - await self.controller.join_domain('this', 'AD Join') + self.controller.model.set( + AdConnectionInfo( + domain_name="pubuntu.com", admin_name="Helper", password="1234" + ) + ) + await self.controller.join_domain("this", "AD Join") result = await self.controller.join_result_GET(wait=True) self.assertEqual(result, AdJoinResult.PAM_ERROR) diff --git a/subiquity/server/controllers/tests/test_filesystem.py b/subiquity/server/controllers/tests/test_filesystem.py index 957f2ebb..7fd620ef 100644 --- a/subiquity/server/controllers/tests/test_filesystem.py +++ b/subiquity/server/controllers/tests/test_filesystem.py @@ -14,18 +14,11 @@ # along with this program. If not, see . import copy -from unittest import mock, IsolatedAsyncioTestCase import uuid +from unittest import IsolatedAsyncioTestCase, mock from curtin.commands.extract import TrivialSourceHandler -from subiquitycore.tests.parameterized import parameterized - -from subiquitycore.snapd import AsyncSnapd, get_fake_connection -from subiquitycore.tests.mocks import make_app -from subiquitycore.utils import matching_dicts -from subiquitycore.tests.util import random_string - from subiquity.common.filesystem import gaps, labels from subiquity.common.filesystem.actions import DeviceAction from subiquity.common.types import ( @@ -34,9 +27,9 @@ from subiquity.common.types import ( Gap, GapUsable, GuidedCapability, + GuidedChoiceV2, GuidedDisallowedCapability, GuidedDisallowedCapabilityReason, - GuidedChoiceV2, GuidedStorageTargetManual, GuidedStorageTargetReformat, GuidedStorageTargetResize, @@ -46,7 +39,7 @@ from subiquity.common.types import ( ProbeStatus, ReformatDisk, SizingPolicy, - ) +) from subiquity.models.filesystem import dehumanize_size from subiquity.models.source import CatalogEntryVariation from subiquity.models.tests.test_filesystem import ( @@ -54,19 +47,24 @@ from subiquity.models.tests.test_filesystem import ( make_disk, make_model, make_partition, - ) +) from subiquity.server import snapdapi from subiquity.server.controllers.filesystem import ( DRY_RUN_RESET_SIZE, FilesystemController, VariationInfo, - ) +) from subiquity.server.dryrun import DRConfig +from subiquitycore.snapd import AsyncSnapd, get_fake_connection +from subiquitycore.tests.mocks import make_app +from subiquitycore.tests.parameterized import parameterized +from subiquitycore.tests.util import random_string +from subiquitycore.utils import matching_dicts -bootloaders = [(bl, ) for bl in list(Bootloader)] -bootloaders_and_ptables = [(bl, pt) - for bl in list(Bootloader) - for pt in ('gpt', 'msdos', 'vtoc')] +bootloaders = [(bl,) for bl in list(Bootloader)] +bootloaders_and_ptables = [ + (bl, pt) for bl in list(Bootloader) for pt in ("gpt", "msdos", "vtoc") +] default_capabilities = [ @@ -74,53 +72,54 @@ default_capabilities = [ GuidedCapability.LVM, GuidedCapability.LVM_LUKS, GuidedCapability.ZFS, - ] +] default_capabilities_disallowed_too_small = [ GuidedDisallowedCapability( - capability=cap, - reason=GuidedDisallowedCapabilityReason.TOO_SMALL) - for cap in default_capabilities] + capability=cap, reason=GuidedDisallowedCapabilityReason.TOO_SMALL + ) + for cap in default_capabilities +] class TestSubiquityControllerFilesystem(IsolatedAsyncioTestCase): - MOCK_PREFIX = 'subiquity.server.controllers.filesystem.' + MOCK_PREFIX = "subiquity.server.controllers.filesystem." def setUp(self): self.app = make_app() - self.app.opts.bootloader = 'UEFI' + self.app.opts.bootloader = "UEFI" self.app.report_start_event = mock.Mock() self.app.report_finish_event = mock.Mock() self.app.prober = mock.Mock() self.app.prober.get_storage = mock.AsyncMock() - self.app.block_log_dir = '/inexistent' + self.app.block_log_dir = "/inexistent" self.app.note_file_for_apport = mock.Mock() self.fsc = FilesystemController(app=self.app) self.fsc._configured = True async def test_probe_restricted(self): await self.fsc._probe_once(context=None, restricted=True) - expected = {'blockdev', 'filesystem'} + expected = {"blockdev", "filesystem"} self.app.prober.get_storage.assert_called_with(expected) async def test_probe_os_prober_false(self): self.app.opts.use_os_prober = False await self.fsc._probe_once(context=None, restricted=False) actual = self.app.prober.get_storage.call_args.args[0] - self.assertTrue({'defaults'} <= actual) - self.assertNotIn('os', actual) + self.assertTrue({"defaults"} <= actual) + self.assertNotIn("os", actual) async def test_probe_os_prober_true(self): self.app.opts.use_os_prober = True await self.fsc._probe_once(context=None, restricted=False) actual = self.app.prober.get_storage.call_args.args[0] - self.assertTrue({'defaults', 'os'} <= actual) + self.assertTrue({"defaults", "os"} <= actual) async def test_probe_once_fs_configured(self): self.fsc._configured = True self.fsc.queued_probe_data = None - with mock.patch.object(self.fsc.model, 'load_probe_data') as load: + with mock.patch.object(self.fsc.model, "load_probe_data") as load: await self.fsc._probe_once(restricted=True) self.assertIsNone(self.fsc.queued_probe_data) load.assert_not_called() @@ -130,13 +129,13 @@ class TestSubiquityControllerFilesystem(IsolatedAsyncioTestCase): self.fsc.run_autoinstall_guided = mock.AsyncMock() self.fsc.ai_data = { - 'layout': {'name': 'direct'}, + "layout": {"name": "direct"}, } await self.fsc.convert_autoinstall_config() curtin_cfg = model.render() - self.assertNotIn('grub', curtin_cfg) - self.assertNotIn('swap', curtin_cfg) + self.assertNotIn("grub", curtin_cfg) + self.assertNotIn("swap", curtin_cfg) @parameterized.expand(((True,), (False,))) async def test_layout_plus_grub(self, reorder_uefi): @@ -144,61 +143,56 @@ class TestSubiquityControllerFilesystem(IsolatedAsyncioTestCase): self.fsc.run_autoinstall_guided = mock.AsyncMock() self.fsc.ai_data = { - 'layout': {'name': 'direct'}, - 'grub': {'reorder_uefi': reorder_uefi} + "layout": {"name": "direct"}, + "grub": {"reorder_uefi": reorder_uefi}, } await self.fsc.convert_autoinstall_config() curtin_cfg = model.render() - self.assertEqual(reorder_uefi, curtin_cfg['grub']['reorder_uefi']) + self.assertEqual(reorder_uefi, curtin_cfg["grub"]["reorder_uefi"]) - @mock.patch('subiquity.server.controllers.filesystem.open', - mock.mock_open()) + @mock.patch("subiquity.server.controllers.filesystem.open", mock.mock_open()) async def test_probe_once_locked_probe_data(self): self.fsc._configured = False self.fsc.locked_probe_data = True self.fsc.queued_probe_data = None self.app.prober.get_storage = mock.AsyncMock(return_value={}) - with mock.patch.object(self.fsc.model, 'load_probe_data') as load: + with mock.patch.object(self.fsc.model, "load_probe_data") as load: await self.fsc._probe_once(restricted=True) self.assertEqual(self.fsc.queued_probe_data, {}) load.assert_not_called() - @parameterized.expand(((0,), ('1G',))) + @parameterized.expand(((0,), ("1G",))) async def test_layout_plus_swap(self, swapsize): self.fsc.model = model = make_model(Bootloader.UEFI) self.fsc.run_autoinstall_guided = mock.AsyncMock() - self.fsc.ai_data = { - 'layout': {'name': 'direct'}, - 'swap': {'size': swapsize} - } + self.fsc.ai_data = {"layout": {"name": "direct"}, "swap": {"size": swapsize}} await self.fsc.convert_autoinstall_config() curtin_cfg = model.render() - self.assertEqual(swapsize, curtin_cfg['swap']['size']) + self.assertEqual(swapsize, curtin_cfg["swap"]["size"]) - @mock.patch('subiquity.server.controllers.filesystem.open', - mock.mock_open()) + @mock.patch("subiquity.server.controllers.filesystem.open", mock.mock_open()) async def test_probe_once_unlocked_probe_data(self): self.fsc._configured = False self.fsc.locked_probe_data = False self.fsc.queued_probe_data = None self.app.prober.get_storage = mock.AsyncMock(return_value={}) - with mock.patch.object(self.fsc.model, 'load_probe_data') as load: + with mock.patch.object(self.fsc.model, "load_probe_data") as load: await self.fsc._probe_once(restricted=True) self.assertIsNone(self.fsc.queued_probe_data, {}) load.assert_called_once_with({}) async def test_v2_reset_POST_no_queued_data(self): self.fsc.queued_probe_data = None - with mock.patch.object(self.fsc.model, 'load_probe_data') as load: + with mock.patch.object(self.fsc.model, "load_probe_data") as load: await self.fsc.v2_reset_POST() load.assert_not_called() async def test_v2_reset_POST_queued_data(self): self.fsc.queued_probe_data = {} - with mock.patch.object(self.fsc.model, 'load_probe_data') as load: + with mock.patch.object(self.fsc.model, "load_probe_data") as load: await self.fsc.v2_reset_POST() load.assert_called_once_with({}) @@ -209,63 +203,61 @@ class TestSubiquityControllerFilesystem(IsolatedAsyncioTestCase): async def test_v2_reformat_disk_POST(self): self.fsc.locked_probe_data = False - with mock.patch.object(self.fsc, 'reformat') as reformat: - await self.fsc.v2_reformat_disk_POST( - ReformatDisk(disk_id='dev-sda')) + with mock.patch.object(self.fsc, "reformat") as reformat: + await self.fsc.v2_reformat_disk_POST(ReformatDisk(disk_id="dev-sda")) self.assertTrue(self.fsc.locked_probe_data) reformat.assert_called_once() - @mock.patch(MOCK_PREFIX + 'boot.is_boot_device', - mock.Mock(return_value=True)) + @mock.patch(MOCK_PREFIX + "boot.is_boot_device", mock.Mock(return_value=True)) async def test_v2_add_boot_partition_POST_existing_bootloader(self): self.fsc.locked_probe_data = False - with mock.patch.object(self.fsc, 'add_boot_disk') as add_boot_disk: - with self.assertRaisesRegex(ValueError, - r'already\ has\ bootloader'): - await self.fsc.v2_add_boot_partition_POST('dev-sda') + with mock.patch.object(self.fsc, "add_boot_disk") as add_boot_disk: + with self.assertRaisesRegex(ValueError, r"already\ has\ bootloader"): + await self.fsc.v2_add_boot_partition_POST("dev-sda") self.assertTrue(self.fsc.locked_probe_data) add_boot_disk.assert_not_called() - @mock.patch(MOCK_PREFIX + 'boot.is_boot_device', - mock.Mock(return_value=False)) - @mock.patch(MOCK_PREFIX + 'DeviceAction.supported', - mock.Mock(return_value=[])) + @mock.patch(MOCK_PREFIX + "boot.is_boot_device", mock.Mock(return_value=False)) + @mock.patch(MOCK_PREFIX + "DeviceAction.supported", mock.Mock(return_value=[])) async def test_v2_add_boot_partition_POST_not_supported(self): self.fsc.locked_probe_data = False - with mock.patch.object(self.fsc, 'add_boot_disk') as add_boot_disk: - with self.assertRaisesRegex(ValueError, - r'does\ not\ support\ boot'): - await self.fsc.v2_add_boot_partition_POST('dev-sda') + with mock.patch.object(self.fsc, "add_boot_disk") as add_boot_disk: + with self.assertRaisesRegex(ValueError, r"does\ not\ support\ boot"): + await self.fsc.v2_add_boot_partition_POST("dev-sda") self.assertTrue(self.fsc.locked_probe_data) add_boot_disk.assert_not_called() - @mock.patch(MOCK_PREFIX + 'boot.is_boot_device', - mock.Mock(return_value=False)) - @mock.patch(MOCK_PREFIX + 'DeviceAction.supported', - mock.Mock(return_value=[DeviceAction.TOGGLE_BOOT])) + @mock.patch(MOCK_PREFIX + "boot.is_boot_device", mock.Mock(return_value=False)) + @mock.patch( + MOCK_PREFIX + "DeviceAction.supported", + mock.Mock(return_value=[DeviceAction.TOGGLE_BOOT]), + ) async def test_v2_add_boot_partition_POST(self): self.fsc.locked_probe_data = False - with mock.patch.object(self.fsc, 'add_boot_disk') as add_boot_disk: - await self.fsc.v2_add_boot_partition_POST('dev-sda') + with mock.patch.object(self.fsc, "add_boot_disk") as add_boot_disk: + await self.fsc.v2_add_boot_partition_POST("dev-sda") self.assertTrue(self.fsc.locked_probe_data) add_boot_disk.assert_called_once() async def test_v2_add_partition_POST_changing_boot(self): self.fsc.locked_probe_data = False data = AddPartitionV2( - disk_id='dev-sda', - partition=Partition( - format='ext4', - mount='/', - boot=True, - ), gap=Gap( - offset=1 << 20, - size=1000 << 20, - usable=GapUsable.YES, - )) - with mock.patch.object(self.fsc, 'create_partition') as create_part: - with self.assertRaisesRegex(ValueError, - r'does\ not\ support\ changing\ boot'): + disk_id="dev-sda", + partition=Partition( + format="ext4", + mount="/", + boot=True, + ), + gap=Gap( + offset=1 << 20, + size=1000 << 20, + usable=GapUsable.YES, + ), + ) + with mock.patch.object(self.fsc, "create_partition") as create_part: + with self.assertRaisesRegex( + ValueError, r"does\ not\ support\ changing\ boot" + ): await self.fsc.v2_add_partition_POST(data) self.assertTrue(self.fsc.locked_probe_data) create_part.assert_not_called() @@ -273,37 +265,41 @@ class TestSubiquityControllerFilesystem(IsolatedAsyncioTestCase): async def test_v2_add_partition_POST_too_large(self): self.fsc.locked_probe_data = False data = AddPartitionV2( - disk_id='dev-sda', - partition=Partition( - format='ext4', - mount='/', - size=2000 << 20, - ), gap=Gap( - offset=1 << 20, - size=1000 << 20, - usable=GapUsable.YES, - )) - with mock.patch.object(self.fsc, 'create_partition') as create_part: - with self.assertRaisesRegex(ValueError, r'too\ large'): + disk_id="dev-sda", + partition=Partition( + format="ext4", + mount="/", + size=2000 << 20, + ), + gap=Gap( + offset=1 << 20, + size=1000 << 20, + usable=GapUsable.YES, + ), + ) + with mock.patch.object(self.fsc, "create_partition") as create_part: + with self.assertRaisesRegex(ValueError, r"too\ large"): await self.fsc.v2_add_partition_POST(data) self.assertTrue(self.fsc.locked_probe_data) create_part.assert_not_called() - @mock.patch(MOCK_PREFIX + 'gaps.at_offset') + @mock.patch(MOCK_PREFIX + "gaps.at_offset") async def test_v2_add_partition_POST(self, at_offset): at_offset.split = mock.Mock(return_value=[mock.Mock()]) self.fsc.locked_probe_data = False data = AddPartitionV2( - disk_id='dev-sda', - partition=Partition( - format='ext4', - mount='/', - ), gap=Gap( - offset=1 << 20, - size=1000 << 20, - usable=GapUsable.YES, - )) - with mock.patch.object(self.fsc, 'create_partition') as create_part: + disk_id="dev-sda", + partition=Partition( + format="ext4", + mount="/", + ), + gap=Gap( + offset=1 << 20, + size=1000 << 20, + usable=GapUsable.YES, + ), + ) + with mock.patch.object(self.fsc, "create_partition") as create_part: await self.fsc.v2_add_partition_POST(data) self.assertTrue(self.fsc.locked_probe_data) create_part.assert_called_once() @@ -311,11 +307,11 @@ class TestSubiquityControllerFilesystem(IsolatedAsyncioTestCase): async def test_v2_delete_partition_POST(self): self.fsc.locked_probe_data = False data = ModifyPartitionV2( - disk_id='dev-sda', - partition=Partition(number=1), - ) - with mock.patch.object(self.fsc, 'delete_partition') as del_part: - with mock.patch.object(self.fsc, 'get_partition'): + disk_id="dev-sda", + partition=Partition(number=1), + ) + with mock.patch.object(self.fsc, "delete_partition") as del_part: + with mock.patch.object(self.fsc, "get_partition"): await self.fsc.v2_delete_partition_POST(data) self.assertTrue(self.fsc.locked_probe_data) del_part.assert_called_once() @@ -323,14 +319,13 @@ class TestSubiquityControllerFilesystem(IsolatedAsyncioTestCase): async def test_v2_edit_partition_POST_change_boot(self): self.fsc.locked_probe_data = False data = ModifyPartitionV2( - disk_id='dev-sda', - partition=Partition(number=1, boot=True), - ) + disk_id="dev-sda", + partition=Partition(number=1, boot=True), + ) existing = Partition(number=1, size=1000 << 20, boot=False) - with mock.patch.object(self.fsc, 'partition_disk_handler') as handler: - with mock.patch.object(self.fsc, 'get_partition', - return_value=existing): - with self.assertRaisesRegex(ValueError, r'changing\ boot'): + with mock.patch.object(self.fsc, "partition_disk_handler") as handler: + with mock.patch.object(self.fsc, "get_partition", return_value=existing): + with self.assertRaisesRegex(ValueError, r"changing\ boot"): await self.fsc.v2_edit_partition_POST(data) self.assertTrue(self.fsc.locked_probe_data) handler.assert_not_called() @@ -338,13 +333,12 @@ class TestSubiquityControllerFilesystem(IsolatedAsyncioTestCase): async def test_v2_edit_partition_POST(self): self.fsc.locked_probe_data = False data = ModifyPartitionV2( - disk_id='dev-sda', - partition=Partition(number=1), - ) + disk_id="dev-sda", + partition=Partition(number=1), + ) existing = Partition(number=1, size=1000 << 20, boot=False) - with mock.patch.object(self.fsc, 'partition_disk_handler') as handler: - with mock.patch.object(self.fsc, 'get_partition', - return_value=existing): + with mock.patch.object(self.fsc, "partition_disk_handler") as handler: + with mock.patch.object(self.fsc, "get_partition", return_value=existing): await self.fsc.v2_edit_partition_POST(data) self.assertTrue(self.fsc.locked_probe_data) handler.assert_called_once() @@ -352,12 +346,12 @@ class TestSubiquityControllerFilesystem(IsolatedAsyncioTestCase): class TestGuided(IsolatedAsyncioTestCase): boot_expectations = [ - (Bootloader.UEFI, 'gpt', '/boot/efi'), - (Bootloader.UEFI, 'msdos', '/boot/efi'), - (Bootloader.BIOS, 'gpt', None), + (Bootloader.UEFI, "gpt", "/boot/efi"), + (Bootloader.UEFI, "msdos", "/boot/efi"), + (Bootloader.BIOS, "gpt", None), # BIOS + msdos is different - (Bootloader.PREP, 'gpt', None), - (Bootloader.PREP, 'msdos', None), + (Bootloader.PREP, "gpt", None), + (Bootloader.PREP, "msdos", None), ] async def _guided_setup(self, bootloader, ptable, storage_version=None): @@ -368,11 +362,9 @@ class TestGuided(IsolatedAsyncioTestCase): self.controller._examine_systems_task.start_sync() self.app.dr_cfg = DRConfig() self.app.base_model.source.current.variations = { - 'default': CatalogEntryVariation( - path='', size=1), - } - self.app.controllers.Source.get_handler.return_value = \ - TrivialSourceHandler('') + "default": CatalogEntryVariation(path="", size=1), + } + self.app.controllers.Source.get_handler.return_value = TrivialSourceHandler("") await self.controller._examine_systems_task.wait() self.model = make_model(bootloader, storage_version) self.controller.model = self.model @@ -382,54 +374,60 @@ class TestGuided(IsolatedAsyncioTestCase): async def test_guided_direct(self, bootloader, ptable, p1mnt): await self._guided_setup(bootloader, ptable) target = GuidedStorageTargetReformat( - disk_id=self.d1.id, allowed=default_capabilities) + disk_id=self.d1.id, allowed=default_capabilities + ) await self.controller.guided( - GuidedChoiceV2(target=target, capability=GuidedCapability.DIRECT)) + GuidedChoiceV2(target=target, capability=GuidedCapability.DIRECT) + ) [d1p1, d1p2] = self.d1.partitions() self.assertEqual(p1mnt, d1p1.mount) - self.assertEqual('/', d1p2.mount) + self.assertEqual("/", d1p2.mount) self.assertFalse(d1p1.preserve) self.assertFalse(d1p2.preserve) self.assertIsNone(gaps.largest_gap(self.d1)) async def test_guided_reset_partition(self): - await self._guided_setup(Bootloader.UEFI, 'gpt') + await self._guided_setup(Bootloader.UEFI, "gpt") target = GuidedStorageTargetReformat( - disk_id=self.d1.id, allowed=default_capabilities) + disk_id=self.d1.id, allowed=default_capabilities + ) await self.controller.guided( GuidedChoiceV2( - target=target, - capability=GuidedCapability.DIRECT, - reset_partition=True)) + target=target, capability=GuidedCapability.DIRECT, reset_partition=True + ) + ) [d1p1, d1p2, d1p3] = self.d1.partitions() - self.assertEqual('/boot/efi', d1p1.mount) + self.assertEqual("/boot/efi", d1p1.mount) self.assertEqual(None, d1p2.mount) self.assertEqual(DRY_RUN_RESET_SIZE, d1p2.size) - self.assertEqual('/', d1p3.mount) + self.assertEqual("/", d1p3.mount) async def test_guided_reset_partition_only(self): - await self._guided_setup(Bootloader.UEFI, 'gpt') + await self._guided_setup(Bootloader.UEFI, "gpt") target = GuidedStorageTargetReformat( - disk_id=self.d1.id, allowed=default_capabilities) + disk_id=self.d1.id, allowed=default_capabilities + ) await self.controller.guided( GuidedChoiceV2( - target=target, - capability=GuidedCapability.DIRECT, - reset_partition=True), - reset_partition_only=True) + target=target, capability=GuidedCapability.DIRECT, reset_partition=True + ), + reset_partition_only=True, + ) [d1p1, d1p2] = self.d1.partitions() self.assertEqual(None, d1p1.mount) self.assertEqual(None, d1p2.mount) self.assertEqual(DRY_RUN_RESET_SIZE, d1p2.size) async def test_guided_direct_BIOS_MSDOS(self): - await self._guided_setup(Bootloader.BIOS, 'msdos') + await self._guided_setup(Bootloader.BIOS, "msdos") target = GuidedStorageTargetReformat( - disk_id=self.d1.id, allowed=default_capabilities) + disk_id=self.d1.id, allowed=default_capabilities + ) await self.controller.guided( - GuidedChoiceV2(target=target, capability=GuidedCapability.DIRECT)) + GuidedChoiceV2(target=target, capability=GuidedCapability.DIRECT) + ) [d1p1] = self.d1.partitions() - self.assertEqual('/', d1p1.mount) + self.assertEqual("/", d1p1.mount) self.assertFalse(d1p1.preserve) self.assertIsNone(gaps.largest_gap(self.d1)) @@ -437,30 +435,34 @@ class TestGuided(IsolatedAsyncioTestCase): async def test_guided_lvm(self, bootloader, ptable, p1mnt): await self._guided_setup(bootloader, ptable) target = GuidedStorageTargetReformat( - disk_id=self.d1.id, allowed=default_capabilities) - await self.controller.guided(GuidedChoiceV2( - target=target, capability=GuidedCapability.LVM)) + disk_id=self.d1.id, allowed=default_capabilities + ) + await self.controller.guided( + GuidedChoiceV2(target=target, capability=GuidedCapability.LVM) + ) [d1p1, d1p2, d1p3] = self.d1.partitions() self.assertEqual(p1mnt, d1p1.mount) - self.assertEqual('/boot', d1p2.mount) + self.assertEqual("/boot", d1p2.mount) self.assertEqual(None, d1p3.mount) self.assertFalse(d1p1.preserve) self.assertFalse(d1p2.preserve) self.assertFalse(d1p3.preserve) - [vg] = self.model._all(type='lvm_volgroup') + [vg] = self.model._all(type="lvm_volgroup") [part] = list(vg.devices) self.assertEqual(d1p3, part) self.assertIsNone(gaps.largest_gap(self.d1)) async def test_guided_lvm_BIOS_MSDOS(self): - await self._guided_setup(Bootloader.BIOS, 'msdos') + await self._guided_setup(Bootloader.BIOS, "msdos") target = GuidedStorageTargetReformat( - disk_id=self.d1.id, allowed=default_capabilities) + disk_id=self.d1.id, allowed=default_capabilities + ) await self.controller.guided( - GuidedChoiceV2(target=target, capability=GuidedCapability.LVM)) + GuidedChoiceV2(target=target, capability=GuidedCapability.LVM) + ) [d1p1, d1p2] = self.d1.partitions() - self.assertEqual('/boot', d1p1.mount) - [vg] = self.model._all(type='lvm_volgroup') + self.assertEqual("/boot", d1p1.mount) + [vg] = self.model._all(type="lvm_volgroup") [part] = list(vg.devices) self.assertEqual(d1p2, part) self.assertEqual(None, d1p2.mount) @@ -472,9 +474,11 @@ class TestGuided(IsolatedAsyncioTestCase): async def test_guided_zfs(self, bootloader, ptable, p1mnt): await self._guided_setup(bootloader, ptable) target = GuidedStorageTargetReformat( - disk_id=self.d1.id, allowed=default_capabilities) - await self.controller.guided(GuidedChoiceV2( - target=target, capability=GuidedCapability.ZFS)) + disk_id=self.d1.id, allowed=default_capabilities + ) + await self.controller.guided( + GuidedChoiceV2(target=target, capability=GuidedCapability.ZFS) + ) [d1p1, d1p2, d1p3] = self.d1.partitions() self.assertEqual(p1mnt, d1p1.mount) self.assertEqual(None, d1p2.mount) @@ -482,26 +486,28 @@ class TestGuided(IsolatedAsyncioTestCase): self.assertFalse(d1p1.preserve) self.assertFalse(d1p2.preserve) self.assertFalse(d1p3.preserve) - [rpool] = self.model._all(type='zpool', pool='rpool') - self.assertEqual('/', rpool.path) - [bpool] = self.model._all(type='zpool', pool='bpool') - self.assertEqual('/boot', bpool.path) + [rpool] = self.model._all(type="zpool", pool="rpool") + self.assertEqual("/", rpool.path) + [bpool] = self.model._all(type="zpool", pool="bpool") + self.assertEqual("/boot", bpool.path) async def test_guided_zfs_BIOS_MSDOS(self): - await self._guided_setup(Bootloader.BIOS, 'msdos') + await self._guided_setup(Bootloader.BIOS, "msdos") target = GuidedStorageTargetReformat( - disk_id=self.d1.id, allowed=default_capabilities) - await self.controller.guided(GuidedChoiceV2( - target=target, capability=GuidedCapability.ZFS)) + disk_id=self.d1.id, allowed=default_capabilities + ) + await self.controller.guided( + GuidedChoiceV2(target=target, capability=GuidedCapability.ZFS) + ) [d1p1, d1p2] = self.d1.partitions() self.assertEqual(None, d1p1.mount) self.assertEqual(None, d1p2.mount) self.assertFalse(d1p1.preserve) self.assertFalse(d1p2.preserve) - [rpool] = self.model._all(type='zpool', pool='rpool') - self.assertEqual('/', rpool.path) - [bpool] = self.model._all(type='zpool', pool='bpool') - self.assertEqual('/boot', bpool.path) + [rpool] = self.model._all(type="zpool", pool="rpool") + self.assertEqual("/", rpool.path) + [bpool] = self.model._all(type="zpool", pool="bpool") + self.assertEqual("/boot", bpool.path) async def _guided_side_by_side(self, bl, ptable): await self._guided_setup(bl, ptable, storage_version=2) @@ -511,62 +517,77 @@ class TestGuided(IsolatedAsyncioTestCase): if bl == Bootloader.UEFI: # let it pass the is_esp check p._info = FakeStorageInfo(size=p.size) - p._info.raw["ID_PART_ENTRY_TYPE"] = str(0xef) + p._info.raw["ID_PART_ENTRY_TYPE"] = str(0xEF) # Make it more interesting with other partitions. # Also create the extended part if needed. g = gaps.largest_gap(self.d1) - make_partition(self.model, self.d1, preserve=True, - size=10 << 30, offset=g.offset) - if ptable == 'msdos': + make_partition( + self.model, self.d1, preserve=True, size=10 << 30, offset=g.offset + ) + if ptable == "msdos": g = gaps.largest_gap(self.d1) - make_partition(self.model, self.d1, preserve=True, - flag='extended', size=g.size, offset=g.offset) + make_partition( + self.model, + self.d1, + preserve=True, + flag="extended", + size=g.size, + offset=g.offset, + ) g = gaps.largest_gap(self.d1) - make_partition(self.model, self.d1, preserve=True, - flag='logical', size=10 << 30, offset=g.offset) + make_partition( + self.model, + self.d1, + preserve=True, + flag="logical", + size=10 << 30, + offset=g.offset, + ) @parameterized.expand( - [(bl, pt, flag) - for bl in list(Bootloader) - for pt, flag in ( - ('msdos', 'logical'), - ('gpt', None) - )] + [ + (bl, pt, flag) + for bl in list(Bootloader) + for pt, flag in (("msdos", "logical"), ("gpt", None)) + ] ) async def test_guided_direct_side_by_side(self, bl, pt, flag): await self._guided_side_by_side(bl, pt) parts_before = self.d1._partitions.copy() gap = gaps.largest_gap(self.d1) target = GuidedStorageTargetUseGap( - disk_id=self.d1.id, gap=gap, allowed=default_capabilities) + disk_id=self.d1.id, gap=gap, allowed=default_capabilities + ) await self.controller.guided( - GuidedChoiceV2(target=target, capability=GuidedCapability.DIRECT)) + GuidedChoiceV2(target=target, capability=GuidedCapability.DIRECT) + ) parts_after = gaps.parts_and_gaps(self.d1)[:-1] self.assertEqual(parts_before, parts_after) p = gaps.parts_and_gaps(self.d1)[-1] - self.assertEqual('/', p.mount) + self.assertEqual("/", p.mount) self.assertEqual(flag, p.flag) @parameterized.expand( - [(bl, pt, flag) - for bl in list(Bootloader) - for pt, flag in ( - ('msdos', 'logical'), - ('gpt', None) - )] + [ + (bl, pt, flag) + for bl in list(Bootloader) + for pt, flag in (("msdos", "logical"), ("gpt", None)) + ] ) async def test_guided_lvm_side_by_side(self, bl, pt, flag): await self._guided_side_by_side(bl, pt) parts_before = self.d1._partitions.copy() gap = gaps.largest_gap(self.d1) target = GuidedStorageTargetUseGap( - disk_id=self.d1.id, gap=gap, allowed=default_capabilities) + disk_id=self.d1.id, gap=gap, allowed=default_capabilities + ) await self.controller.guided( - GuidedChoiceV2(target=target, capability=GuidedCapability.LVM)) + GuidedChoiceV2(target=target, capability=GuidedCapability.LVM) + ) parts_after = gaps.parts_and_gaps(self.d1)[:-2] self.assertEqual(parts_before, parts_after) p_boot, p_data = gaps.parts_and_gaps(self.d1)[-2:] - self.assertEqual('/boot', p_boot.mount) + self.assertEqual("/boot", p_boot.mount) self.assertEqual(flag, p_boot.flag) self.assertEqual(None, p_data.mount) self.assertEqual(flag, p_data.flag) @@ -578,11 +599,11 @@ class TestLayout(IsolatedAsyncioTestCase): self.app.opts.bootloader = None self.fsc = FilesystemController(app=self.app) - @parameterized.expand([('reformat_disk',), ('use_gap',)]) + @parameterized.expand([("reformat_disk",), ("use_gap",)]) async def test_good_modes(self, mode): self.fsc.validate_layout_mode(mode) - @parameterized.expand([('resize_biggest',), ('use_free',)]) + @parameterized.expand([("resize_biggest",), ("use_free",)]) async def test_bad_modes(self, mode): with self.assertRaises(ValueError): self.fsc.validate_layout_mode(mode) @@ -599,24 +620,28 @@ class TestGuidedV2(IsolatedAsyncioTestCase): self.fsc._examine_systems_task.start_sync() self.app.dr_cfg = DRConfig() self.app.base_model.source.current.variations = { - 'default': CatalogEntryVariation( - path='', size=1), - } - self.app.controllers.Source.get_handler.return_value = \ - TrivialSourceHandler('') + "default": CatalogEntryVariation(path="", size=1), + } + self.app.controllers.Source.get_handler.return_value = TrivialSourceHandler("") await self.fsc._examine_systems_task.wait() self.disk = make_disk(self.model, ptable=ptable, **kw) self.model.storage_version = 2 self.fs_probe = {} self.fsc.model._probe_data = { - 'blockdev': {}, - 'filesystem': self.fs_probe, - } + "blockdev": {}, + "filesystem": self.fs_probe, + } self.fsc._probe_task.task = mock.Mock() self.fsc._examine_systems_task.task = mock.Mock() - if bootloader == Bootloader.BIOS and ptable != 'msdos' and fix_bios: - make_partition(self.model, self.disk, preserve=True, - flag='bios_grub', size=1 << 20, offset=1 << 20) + if bootloader == Bootloader.BIOS and ptable != "msdos" and fix_bios: + make_partition( + self.model, + self.disk, + preserve=True, + flag="bios_grub", + size=1 << 20, + offset=1 << 20, + ) @parameterized.expand(bootloaders_and_ptables) async def test_blank_disk(self, bootloader, ptable): @@ -624,9 +649,10 @@ class TestGuidedV2(IsolatedAsyncioTestCase): await self._setup(bootloader, ptable, fix_bios=False) expected = [ GuidedStorageTargetReformat( - disk_id=self.disk.id, allowed=default_capabilities), + disk_id=self.disk.id, allowed=default_capabilities + ), GuidedStorageTargetManual(), - ] + ] resp = await self.fsc.v2_guided_GET() self.assertEqual(expected, resp.targets) self.assertEqual(ProbeStatus.DONE, resp.status) @@ -640,7 +666,7 @@ class TestGuidedV2(IsolatedAsyncioTestCase): self.assertEqual(ProbeStatus.PROBING, resp.status) async def test_manual(self): - await self._setup(Bootloader.UEFI, 'gpt') + await self._setup(Bootloader.UEFI, "gpt") guided_get_resp = await self.fsc.v2_guided_GET() [reformat, manual] = guided_get_resp.targets self.assertEqual(manual, GuidedStorageTargetManual()) @@ -656,10 +682,12 @@ class TestGuidedV2(IsolatedAsyncioTestCase): resp = await self.fsc.v2_guided_GET() expected = [ GuidedStorageTargetReformat( - disk_id=self.disk.id, allowed=[], - disallowed=default_capabilities_disallowed_too_small), + disk_id=self.disk.id, + allowed=[], + disallowed=default_capabilities_disallowed_too_small, + ), GuidedStorageTargetManual(), - ] + ] self.assertEqual(expected, resp.targets) @parameterized.expand(bootloaders_and_ptables) @@ -667,14 +695,16 @@ class TestGuidedV2(IsolatedAsyncioTestCase): await self._setup(bootloader, ptable, size=100 << 30) p = make_partition(self.model, self.disk, preserve=True, size=50 << 30) gap_offset = p.size + p.offset - self.fs_probe[p._path()] = {'ESTIMATED_MIN_SIZE': 1 << 20} + self.fs_probe[p._path()] = {"ESTIMATED_MIN_SIZE": 1 << 20} resp = await self.fsc.v2_guided_GET() reformat = resp.targets.pop(0) self.assertEqual( GuidedStorageTargetReformat( - disk_id=self.disk.id, allowed=default_capabilities), - reformat) + disk_id=self.disk.id, allowed=default_capabilities + ), + reformat, + ) use_gap = resp.targets.pop(0) self.assertEqual(self.disk.id, use_gap.disk_id) @@ -689,15 +719,18 @@ class TestGuidedV2(IsolatedAsyncioTestCase): @parameterized.expand(bootloaders_and_ptables) async def test_used_full_disk(self, bootloader, ptable): await self._setup(bootloader, ptable) - p = make_partition(self.model, self.disk, preserve=True, - size=gaps.largest_gap_size(self.disk)) - self.fs_probe[p._path()] = {'ESTIMATED_MIN_SIZE': 1 << 20} + p = make_partition( + self.model, self.disk, preserve=True, size=gaps.largest_gap_size(self.disk) + ) + self.fs_probe[p._path()] = {"ESTIMATED_MIN_SIZE": 1 << 20} resp = await self.fsc.v2_guided_GET() reformat = resp.targets.pop(0) self.assertEqual( GuidedStorageTargetReformat( - disk_id=self.disk.id, allowed=default_capabilities), - reformat) + disk_id=self.disk.id, allowed=default_capabilities + ), + reformat, + ) resize = resp.targets.pop(0) self.assertEqual(self.disk.id, resize.disk_id) @@ -710,17 +743,18 @@ class TestGuidedV2(IsolatedAsyncioTestCase): await self._setup(bootloader, ptable, size=250 << 30) # add an extra, filler, partition so that there is no use_gap result make_partition(self.model, self.disk, preserve=True, size=9 << 30) - p = make_partition(self.model, self.disk, preserve=True, - size=240 << 30) - self.fs_probe[p._path()] = {'ESTIMATED_MIN_SIZE': 40 << 30} + p = make_partition(self.model, self.disk, preserve=True, size=240 << 30) + self.fs_probe[p._path()] = {"ESTIMATED_MIN_SIZE": 40 << 30} self.fsc.calculate_suggested_install_min.return_value = 10 << 30 resp = await self.fsc.v2_guided_GET() possible = [t for t in resp.targets if t.allowed] reformat = possible.pop(0) self.assertEqual( GuidedStorageTargetReformat( - disk_id=self.disk.id, allowed=default_capabilities), - reformat) + disk_id=self.disk.id, allowed=default_capabilities + ), + reformat, + ) resize = possible.pop(0) expected = GuidedStorageTargetResize( @@ -730,7 +764,8 @@ class TestGuidedV2(IsolatedAsyncioTestCase): minimum=50 << 30, recommended=200 << 30, maximum=230 << 30, - allowed=default_capabilities) + allowed=default_capabilities, + ) self.assertEqual(expected, resize) self.assertEqual(1, len(possible)) @@ -738,7 +773,7 @@ class TestGuidedV2(IsolatedAsyncioTestCase): async def test_half_disk_reformat(self, bootloader, ptable): await self._setup(bootloader, ptable, size=100 << 30) p = make_partition(self.model, self.disk, preserve=True, size=50 << 30) - self.fs_probe[p._path()] = {'ESTIMATED_MIN_SIZE': 1 << 20} + self.fs_probe[p._path()] = {"ESTIMATED_MIN_SIZE": 1 << 20} guided_get_resp = await self.fsc.v2_guided_GET() reformat = guided_get_resp.targets.pop(0) @@ -750,8 +785,7 @@ class TestGuidedV2(IsolatedAsyncioTestCase): resize = guided_get_resp.targets.pop(0) self.assertTrue(isinstance(resize, GuidedStorageTargetResize)) - data = GuidedChoiceV2( - target=reformat, capability=GuidedCapability.DIRECT) + data = GuidedChoiceV2(target=reformat, capability=GuidedCapability.DIRECT) expected_config = copy.copy(data) resp = await self.fsc.v2_guided_POST(data=data) self.assertEqual(expected_config, resp.configured) @@ -765,7 +799,7 @@ class TestGuidedV2(IsolatedAsyncioTestCase): async def test_half_disk_use_gap(self, bootloader, ptable): await self._setup(bootloader, ptable, size=100 << 30) p = make_partition(self.model, self.disk, preserve=True, size=50 << 30) - self.fs_probe[p._path()] = {'ESTIMATED_MIN_SIZE': 1 << 20} + self.fs_probe[p._path()] = {"ESTIMATED_MIN_SIZE": 1 << 20} resp = await self.fsc.v2_GET() [orig_p, g] = resp.disks[0].partitions[-2:] @@ -777,8 +811,7 @@ class TestGuidedV2(IsolatedAsyncioTestCase): use_gap = guided_get_resp.targets.pop(0) self.assertTrue(isinstance(use_gap, GuidedStorageTargetUseGap)) self.assertEqual(g, use_gap.gap) - data = GuidedChoiceV2( - target=use_gap, capability=GuidedCapability.DIRECT) + data = GuidedChoiceV2(target=use_gap, capability=GuidedCapability.DIRECT) expected_config = copy.copy(data) resp = await self.fsc.v2_guided_POST(data=data) self.assertEqual(expected_config, resp.configured) @@ -787,8 +820,11 @@ class TestGuidedV2(IsolatedAsyncioTestCase): self.assertTrue(isinstance(resize, GuidedStorageTargetResize)) resp = await self.fsc.v2_GET() - existing_part = [p for p in resp.disks[0].partitions - if getattr(p, 'number', None) == orig_p.number][0] + existing_part = [ + p + for p in resp.disks[0].partitions + if getattr(p, "number", None) == orig_p.number + ][0] self.assertEqual(orig_p, existing_part) self.assertFalse(resp.need_root) self.assertFalse(resp.need_boot) @@ -798,7 +834,7 @@ class TestGuidedV2(IsolatedAsyncioTestCase): async def test_half_disk_resize(self, bootloader, ptable): await self._setup(bootloader, ptable, size=100 << 30) p = make_partition(self.model, self.disk, preserve=True, size=50 << 30) - self.fs_probe[p._path()] = {'ESTIMATED_MIN_SIZE': 1 << 20} + self.fs_probe[p._path()] = {"ESTIMATED_MIN_SIZE": 1 << 20} resp = await self.fsc.v2_GET() [orig_p, g] = resp.disks[0].partitions[-2:] @@ -815,27 +851,38 @@ class TestGuidedV2(IsolatedAsyncioTestCase): p_expected = copy.copy(orig_p) p_expected.size = resize.new_size = 20 << 30 p_expected.resize = True - data = GuidedChoiceV2( - target=resize, capability=GuidedCapability.DIRECT) + data = GuidedChoiceV2(target=resize, capability=GuidedCapability.DIRECT) expected_config = copy.copy(data) resp = await self.fsc.v2_guided_POST(data=data) self.assertEqual(expected_config, resp.configured) resp = await self.fsc.v2_GET() - existing_part = [p for p in resp.disks[0].partitions - if getattr(p, 'number', None) == orig_p.number][0] + existing_part = [ + p + for p in resp.disks[0].partitions + if getattr(p, "number", None) == orig_p.number + ][0] self.assertEqual(p_expected, existing_part) self.assertFalse(resp.need_root) self.assertFalse(resp.need_boot) self.assertEqual(1, len(guided_get_resp.targets)) - @parameterized.expand([ - [10], [20], [25], [30], [50], [100], [250], - [1000], [1024], - ]) + @parameterized.expand( + [ + [10], + [20], + [25], + [30], + [50], + [100], + [250], + [1000], + [1024], + ] + ) async def test_lvm_20G_bad_offset(self, disk_size): disk_size = disk_size << 30 - await self._setup(Bootloader.BIOS, 'gpt', size=disk_size) + await self._setup(Bootloader.BIOS, "gpt", size=disk_size) guided_get_resp = await self.fsc.v2_guided_GET() @@ -856,36 +903,39 @@ class TestGuidedV2(IsolatedAsyncioTestCase): self.assertEqual(0, p.size % (1 << 20), p) for i in range(len(parts) - 1): - self.assertEqual( - parts[i + 1].offset, parts[i].offset + parts[i].size) + self.assertEqual(parts[i + 1].offset, parts[i].offset + parts[i].size) self.assertEqual( - disk_size - (1 << 20), parts[-1].offset + parts[-1].size, - disk_size) + disk_size - (1 << 20), parts[-1].offset + parts[-1].size, disk_size + ) async def _sizing_setup(self, bootloader, ptable, disk_size, policy): await self._setup(bootloader, ptable, size=disk_size) resp = await self.fsc.v2_guided_GET() - reformat = [target for target in resp.targets - if isinstance(target, GuidedStorageTargetReformat)][0] - data = GuidedChoiceV2(target=reformat, - capability=GuidedCapability.LVM, - sizing_policy=policy) + reformat = [ + target + for target in resp.targets + if isinstance(target, GuidedStorageTargetReformat) + ][0] + data = GuidedChoiceV2( + target=reformat, capability=GuidedCapability.LVM, sizing_policy=policy + ) await self.fsc.v2_guided_POST(data=data) resp = await self.fsc.GET() - [vg] = matching_dicts(resp.config, type='lvm_volgroup') - [part_id] = vg['devices'] + [vg] = matching_dicts(resp.config, type="lvm_volgroup") + [part_id] = vg["devices"] [part] = matching_dicts(resp.config, id=part_id) - part_size = part['size'] # already an int - [lvm_partition] = matching_dicts(resp.config, type='lvm_partition') - size = dehumanize_size(lvm_partition['size']) + part_size = part["size"] # already an int + [lvm_partition] = matching_dicts(resp.config, type="lvm_partition") + size = dehumanize_size(lvm_partition["size"]) return size, part_size @parameterized.expand(bootloaders_and_ptables) async def test_scaled_disk(self, bootloader, ptable): size, part_size = await self._sizing_setup( - bootloader, ptable, 50 << 30, SizingPolicy.SCALED) + bootloader, ptable, 50 << 30, SizingPolicy.SCALED + ) # expected to be about half, differing by boot and ptable types self.assertLess(20 << 30, size) self.assertLess(size, 30 << 30) @@ -893,7 +943,8 @@ class TestGuidedV2(IsolatedAsyncioTestCase): @parameterized.expand(bootloaders_and_ptables) async def test_unscaled_disk(self, bootloader, ptable): size, part_size = await self._sizing_setup( - bootloader, ptable, 50 << 30, SizingPolicy.ALL) + bootloader, ptable, 50 << 30, SizingPolicy.ALL + ) # there is some subtle differences in sizing depending on # ptable/bootloader and how the rounding goes self.assertLess(part_size - (5 << 20), size) @@ -907,13 +958,14 @@ class TestGuidedV2(IsolatedAsyncioTestCase): # enough space on the rest of the disk. await self._setup(bootloader, ptable, fix_bios=True) make_partition( - self.model, self.disk, preserve=True, - size=4 << 30, is_in_use=True) + self.model, self.disk, preserve=True, size=4 << 30, is_in_use=True + ) expected = [ GuidedStorageTargetReformat( - disk_id=self.disk.id, allowed=default_capabilities), + disk_id=self.disk.id, allowed=default_capabilities + ), GuidedStorageTargetManual(), - ] + ] resp = await self.fsc.v2_guided_GET() self.assertEqual(expected, resp.targets) self.assertEqual(ProbeStatus.DONE, resp.status) @@ -925,16 +977,17 @@ class TestGuidedV2(IsolatedAsyncioTestCase): # even if the disk is full of other partitions. await self._setup(bootloader, ptable, fix_bios=True) make_partition( - self.model, self.disk, preserve=True, - size=4 << 30, is_in_use=True) + self.model, self.disk, preserve=True, size=4 << 30, is_in_use=True + ) make_partition( - self.model, self.disk, preserve=True, - size=gaps.largest_gap_size(self.disk)) + self.model, self.disk, preserve=True, size=gaps.largest_gap_size(self.disk) + ) expected = [ GuidedStorageTargetReformat( - disk_id=self.disk.id, allowed=default_capabilities), + disk_id=self.disk.id, allowed=default_capabilities + ), GuidedStorageTargetManual(), - ] + ] resp = await self.fsc.v2_guided_GET() self.assertEqual(expected, resp.targets) self.assertEqual(ProbeStatus.DONE, resp.status) @@ -946,18 +999,20 @@ class TestGuidedV2(IsolatedAsyncioTestCase): await self._setup(bootloader, ptable, fix_bios=True, size=25 << 30) print(bootloader, ptable) make_partition( - self.model, self.disk, preserve=True, - size=23 << 30, is_in_use=True) + self.model, self.disk, preserve=True, size=23 << 30, is_in_use=True + ) make_partition( - self.model, self.disk, preserve=True, - size=gaps.largest_gap_size(self.disk)) + self.model, self.disk, preserve=True, size=gaps.largest_gap_size(self.disk) + ) resp = await self.fsc.v2_guided_GET() expected = [ GuidedStorageTargetReformat( - disk_id=self.disk.id, allowed=[], - disallowed=default_capabilities_disallowed_too_small), + disk_id=self.disk.id, + allowed=[], + disallowed=default_capabilities_disallowed_too_small, + ), GuidedStorageTargetManual(), - ] + ] self.assertEqual(expected, resp.targets) @parameterized.expand(bootloaders_and_ptables) @@ -967,19 +1022,25 @@ class TestGuidedV2(IsolatedAsyncioTestCase): # and a big enough gap. await self._setup(bootloader, ptable, fix_bios=True) make_partition( - self.model, self.disk, preserve=True, - size=4 << 30, is_in_use=True) + self.model, self.disk, preserve=True, size=4 << 30, is_in_use=True + ) make_partition( - self.model, self.disk, preserve=True, - size=gaps.largest_gap_size(self.disk)//2) + self.model, + self.disk, + preserve=True, + size=gaps.largest_gap_size(self.disk) // 2, + ) expected = [ GuidedStorageTargetReformat( - disk_id=self.disk.id, allowed=default_capabilities), + disk_id=self.disk.id, allowed=default_capabilities + ), GuidedStorageTargetUseGap( - disk_id=self.disk.id, allowed=default_capabilities, - gap=labels.for_client(gaps.largest_gap(self.disk))), + disk_id=self.disk.id, + allowed=default_capabilities, + gap=labels.for_client(gaps.largest_gap(self.disk)), + ), GuidedStorageTargetManual(), - ] + ] resp = await self.fsc.v2_guided_GET() self.assertEqual(expected, resp.targets) self.assertEqual(ProbeStatus.DONE, resp.status) @@ -1020,9 +1081,7 @@ class TestManualBoot(IsolatedAsyncioTestCase): self._setup(bootloader, ptable) d1 = make_disk(self.model) d2 = make_disk(self.model) - make_partition(self.model, d1, - size=gaps.largest_gap_size(d1), - preserve=True) + make_partition(self.model, d1, size=gaps.largest_gap_size(d1), preserve=True) if bootloader == Bootloader.NONE: # NONE will always pass the boot check, even on a full disk bootable = set([d1.id, d2.id]) @@ -1034,21 +1093,18 @@ class TestManualBoot(IsolatedAsyncioTestCase): class TestCoreBootInstallMethods(IsolatedAsyncioTestCase): - def setUp(self): self.app = make_app() self.app.command_runner = mock.AsyncMock() - self.app.opts.bootloader = 'UEFI' + self.app.opts.bootloader = "UEFI" self.app.report_start_event = mock.Mock() self.app.report_finish_event = mock.Mock() self.app.prober = mock.Mock() self.app.prober.get_storage = mock.AsyncMock() - self.app.snapdapi = snapdapi.make_api_client( - AsyncSnapd(get_fake_connection())) + self.app.snapdapi = snapdapi.make_api_client(AsyncSnapd(get_fake_connection())) self.app.dr_cfg = DRConfig() self.app.dr_cfg.systems_dir_exists = True - self.app.controllers.Source.get_handler.return_value = \ - TrivialSourceHandler('') + self.app.controllers.Source.get_handler.return_value = TrivialSourceHandler("") self.fsc = FilesystemController(app=self.app) self.fsc._configured = True self.fsc.model = make_model(Bootloader.UEFI) @@ -1056,79 +1112,88 @@ class TestCoreBootInstallMethods(IsolatedAsyncioTestCase): def _add_details_for_structures(self, structures): self.fsc._info = VariationInfo( - name='foo', - label='system', + name="foo", + label="system", system=snapdapi.SystemDetails( volumes={ - 'pc': snapdapi.Volume(schema='gpt', structure=structures), - } - ) - ) + "pc": snapdapi.Volume(schema="gpt", structure=structures), + } + ), + ) async def test_guided_core_boot(self): disk = make_disk(self.fsc.model) arbitrary_uuid = str(uuid.uuid4()) - self._add_details_for_structures([ - snapdapi.VolumeStructure( - type="83,0FC63DAF-8483-4772-8E79-3D69D8477DE4", - offset=1 << 20, - size=1 << 30, - filesystem='ext4', - name='one'), - snapdapi.VolumeStructure( - type=arbitrary_uuid, - offset=2 << 30, - size=1 << 30, - filesystem='ext4'), - ]) + self._add_details_for_structures( + [ + snapdapi.VolumeStructure( + type="83,0FC63DAF-8483-4772-8E79-3D69D8477DE4", + offset=1 << 20, + size=1 << 30, + filesystem="ext4", + name="one", + ), + snapdapi.VolumeStructure( + type=arbitrary_uuid, offset=2 << 30, size=1 << 30, filesystem="ext4" + ), + ] + ) await self.fsc.guided_core_boot(disk) [part1, part2] = disk.partitions() self.assertEqual(part1.offset, 1 << 20) self.assertEqual(part1.size, 1 << 30) - self.assertEqual(part1.fs().fstype, 'ext4') - self.assertEqual(part1.flag, 'linux') - self.assertEqual(part1.partition_name, 'one') - self.assertEqual( - part1.partition_type, '0FC63DAF-8483-4772-8E79-3D69D8477DE4') + self.assertEqual(part1.fs().fstype, "ext4") + self.assertEqual(part1.flag, "linux") + self.assertEqual(part1.partition_name, "one") + self.assertEqual(part1.partition_type, "0FC63DAF-8483-4772-8E79-3D69D8477DE4") self.assertEqual(part2.flag, None) - self.assertEqual( - part2.partition_type, arbitrary_uuid) + self.assertEqual(part2.partition_type, arbitrary_uuid) async def test_guided_core_boot_reuse(self): disk = make_disk(self.fsc.model) # Add a partition that matches one in the volume structure reused_part = make_partition( - self.fsc.model, disk, offset=1 << 20, size=1 << 30, preserve=True) - self.fsc.model.add_filesystem(reused_part, 'ext4') + self.fsc.model, disk, offset=1 << 20, size=1 << 30, preserve=True + ) + self.fsc.model.add_filesystem(reused_part, "ext4") # And two that do not. make_partition( - self.fsc.model, disk, offset=2 << 30, size=1 << 30, preserve=True) + self.fsc.model, disk, offset=2 << 30, size=1 << 30, preserve=True + ) make_partition( - self.fsc.model, disk, offset=3 << 30, size=1 << 30, preserve=True) - self._add_details_for_structures([ - snapdapi.VolumeStructure( - type="0FC63DAF-8483-4772-8E79-3D69D8477DE4", - offset=1 << 20, - size=1 << 30, - filesystem='ext4'), - ]) + self.fsc.model, disk, offset=3 << 30, size=1 << 30, preserve=True + ) + self._add_details_for_structures( + [ + snapdapi.VolumeStructure( + type="0FC63DAF-8483-4772-8E79-3D69D8477DE4", + offset=1 << 20, + size=1 << 30, + filesystem="ext4", + ), + ] + ) await self.fsc.guided_core_boot(disk) [part] = disk.partitions() self.assertEqual(reused_part, part) - self.assertEqual(reused_part.wipe, 'superblock') - self.assertEqual(part.fs().fstype, 'ext4') + self.assertEqual(reused_part.wipe, "superblock") + self.assertEqual(part.fs().fstype, "ext4") async def test_guided_core_boot_reuse_no_format(self): disk = make_disk(self.fsc.model) existing_part = make_partition( - self.fsc.model, disk, offset=1 << 20, size=1 << 30, preserve=True) - self._add_details_for_structures([ - snapdapi.VolumeStructure( - type="0FC63DAF-8483-4772-8E79-3D69D8477DE4", - offset=1 << 20, - size=1 << 30, - filesystem=None), - ]) + self.fsc.model, disk, offset=1 << 20, size=1 << 30, preserve=True + ) + self._add_details_for_structures( + [ + snapdapi.VolumeStructure( + type="0FC63DAF-8483-4772-8E79-3D69D8477DE4", + offset=1 << 20, + size=1 << 30, + filesystem=None, + ), + ] + ) await self.fsc.guided_core_boot(disk) [part] = disk.partitions() self.assertEqual(existing_part, part) @@ -1136,33 +1201,37 @@ class TestCoreBootInstallMethods(IsolatedAsyncioTestCase): async def test_guided_core_boot_system_data(self): disk = make_disk(self.fsc.model) - self._add_details_for_structures([ - snapdapi.VolumeStructure( - type="21686148-6449-6E6F-744E-656564454649", - offset=1 << 20, - name='BIOS Boot', - size=1 << 20, - role='', - filesystem=''), - snapdapi.VolumeStructure( - type="0FC63DAF-8483-4772-8E79-3D69D8477DE4", - offset=2 << 20, - name='ptname', - size=2 << 30, - role=snapdapi.Role.SYSTEM_DATA, - filesystem='ext4'), - ]) + self._add_details_for_structures( + [ + snapdapi.VolumeStructure( + type="21686148-6449-6E6F-744E-656564454649", + offset=1 << 20, + name="BIOS Boot", + size=1 << 20, + role="", + filesystem="", + ), + snapdapi.VolumeStructure( + type="0FC63DAF-8483-4772-8E79-3D69D8477DE4", + offset=2 << 20, + name="ptname", + size=2 << 30, + role=snapdapi.Role.SYSTEM_DATA, + filesystem="ext4", + ), + ] + ) await self.fsc.guided_core_boot(disk) [bios_part, part] = disk.partitions() self.assertEqual(part.offset, 2 << 20) - self.assertEqual(part.partition_name, 'ptname') - self.assertEqual(part.flag, 'linux') + self.assertEqual(part.partition_name, "ptname") + self.assertEqual(part.flag, "linux") self.assertEqual( - part.size, - disk.size - (2 << 20) - disk.alignment_data().min_end_offset) - self.assertEqual(part.fs().fstype, 'ext4') - self.assertEqual(part.fs().mount().path, '/') - self.assertEqual(part.wipe, 'superblock') + part.size, disk.size - (2 << 20) - disk.alignment_data().min_end_offset + ) + self.assertEqual(part.fs().fstype, "ext4") + self.assertEqual(part.fs().mount().path, "/") + self.assertEqual(part.wipe, "superblock") async def test_from_sample_data(self): # calling this a unit test is definitely questionable. but it @@ -1170,9 +1239,10 @@ class TestCoreBootInstallMethods(IsolatedAsyncioTestCase): self.fsc.model = model = make_model(Bootloader.UEFI) disk = make_disk(model) self.app.base_model.source.current.variations = { - 'default': CatalogEntryVariation( - path='', size=1, snapd_system_label='prefer-encrypted'), - } + "default": CatalogEntryVariation( + path="", size=1, snapd_system_label="prefer-encrypted" + ), + } self.app.dr_cfg.systems_dir_exists = True @@ -1183,42 +1253,44 @@ class TestCoreBootInstallMethods(IsolatedAsyncioTestCase): self.assertEqual(len(response.targets), 1) choice = GuidedChoiceV2( - target=response.targets[0], - capability=GuidedCapability.CORE_BOOT_ENCRYPTED) + target=response.targets[0], capability=GuidedCapability.CORE_BOOT_ENCRYPTED + ) await self.fsc.v2_guided_POST(choice) self.assertEqual(model.storage_version, 2) - partition_count = len([ - structure - for structure in self.fsc._info.system.volumes['pc'].structure - if structure.role != snapdapi.Role.MBR - ]) - self.assertEqual( - partition_count, - len(disk.partitions())) - mounts = {m.path: m.device.volume for m in model._all(type='mount')} - self.assertEqual(set(mounts.keys()), {'/', '/boot', '/boot/efi'}) + partition_count = len( + [ + structure + for structure in self.fsc._info.system.volumes["pc"].structure + if structure.role != snapdapi.Role.MBR + ] + ) + self.assertEqual(partition_count, len(disk.partitions())) + mounts = {m.path: m.device.volume for m in model._all(type="mount")} + self.assertEqual(set(mounts.keys()), {"/", "/boot", "/boot/efi"}) device_map = {p.id: random_string() for p in disk.partitions()} self.fsc.update_devices(device_map) - with mock.patch.object(snapdapi, "post_and_wait", - new_callable=mock.AsyncMock) as mocked: + with mock.patch.object( + snapdapi, "post_and_wait", new_callable=mock.AsyncMock + ) as mocked: mocked.return_value = { - 'encrypted-devices': { - snapdapi.Role.SYSTEM_DATA: 'enc-system-data', - }, - } + "encrypted-devices": { + snapdapi.Role.SYSTEM_DATA: "enc-system-data", + }, + } await self.fsc.setup_encryption(context=self.fsc.context) # setup_encryption mutates the filesystem model objects to # reference the newly created encrypted objects so re-read the # mount to device mapping. - mounts = {m.path: m.device.volume for m in model._all(type='mount')} - self.assertEqual(mounts['/'].path, 'enc-system-data') + mounts = {m.path: m.device.volume for m in model._all(type="mount")} + self.assertEqual(mounts["/"].path, "enc-system-data") - with mock.patch.object(snapdapi, "post_and_wait", - new_callable=mock.AsyncMock) as mocked: + with mock.patch.object( + snapdapi, "post_and_wait", new_callable=mock.AsyncMock + ) as mocked: await self.fsc.finish_install(context=self.fsc.context) mocked.assert_called_once() [call] = mocked.mock_calls @@ -1232,9 +1304,10 @@ class TestCoreBootInstallMethods(IsolatedAsyncioTestCase): self.fsc.model = model = make_model(Bootloader.UEFI) disk = make_disk(model) self.app.base_model.source.current.variations = { - 'default': CatalogEntryVariation( - path='', size=1, snapd_system_label='prefer-encrypted'), - } + "default": CatalogEntryVariation( + path="", size=1, snapd_system_label="prefer-encrypted" + ), + } self.app.dr_cfg.systems_dir_exists = True @@ -1242,26 +1315,27 @@ class TestCoreBootInstallMethods(IsolatedAsyncioTestCase): await self.fsc._examine_systems_task.wait() self.fsc.start() - await self.fsc.run_autoinstall_guided({'name': 'hybrid'}) + await self.fsc.run_autoinstall_guided({"name": "hybrid"}) self.assertEqual(model.storage_version, 2) - partition_count = len([ - structure - for structure in self.fsc._info.system.volumes['pc'].structure - if structure.role != snapdapi.Role.MBR - ]) - self.assertEqual( - partition_count, - len(disk.partitions())) + partition_count = len( + [ + structure + for structure in self.fsc._info.system.volumes["pc"].structure + if structure.role != snapdapi.Role.MBR + ] + ) + self.assertEqual(partition_count, len(disk.partitions())) async def test_from_sample_data_defective(self): self.fsc.model = model = make_model(Bootloader.UEFI) make_disk(model) self.app.base_model.source.current.variations = { - 'default': CatalogEntryVariation( - path='', size=1, snapd_system_label='defective'), - } + "default": CatalogEntryVariation( + path="", size=1, snapd_system_label="defective" + ), + } self.app.dr_cfg.systems_dir_exists = True await self.fsc._examine_systems_task.start() self.fsc.start() @@ -1272,4 +1346,5 @@ class TestCoreBootInstallMethods(IsolatedAsyncioTestCase): disallowed = response.targets[0].disallowed[0] self.assertEqual( disallowed.reason, - GuidedDisallowedCapabilityReason.CORE_BOOT_ENCRYPTION_UNAVAILABLE) + GuidedDisallowedCapabilityReason.CORE_BOOT_ENCRYPTION_UNAVAILABLE, + ) diff --git a/subiquity/server/controllers/tests/test_install.py b/subiquity/server/controllers/tests/test_install.py index 9419e6b7..5c042046 100644 --- a/subiquity/server/controllers/tests/test_install.py +++ b/subiquity/server/controllers/tests/test_install.py @@ -13,26 +13,16 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from pathlib import Path import subprocess import unittest -from unittest.mock import ( - ANY, - AsyncMock, - call, - Mock, - mock_open, - patch, - ) +from pathlib import Path +from unittest.mock import ANY, AsyncMock, Mock, call, mock_open, patch from curtin.util import EFIBootEntry, EFIBootState from subiquity.common.types import PackageInstallState from subiquity.models.tests.test_filesystem import make_model_and_partition -from subiquity.server.controllers.install import ( - InstallController, - ) - +from subiquity.server.controllers.install import InstallController from subiquitycore.tests.mocks import make_app @@ -48,74 +38,80 @@ class TestWriteConfig(unittest.IsolatedAsyncioTestCase): @patch("subiquity.server.controllers.install.run_curtin_command") async def test_run_curtin_install_step(self, run_cmd): - - with patch("subiquity.server.controllers.install.open", - mock_open()) as m_open: + with patch("subiquity.server.controllers.install.open", mock_open()) as m_open: await self.controller.run_curtin_step( - name='MyStep', + name="MyStep", stages=["partitioning", "extract"], config_file=Path("/config.yaml"), source="/source", config=self.controller.base_config( - logs_dir=Path("/"), resume_data_file=Path("resume-data")) - ) + logs_dir=Path("/"), resume_data_file=Path("resume-data") + ), + ) m_open.assert_called_once_with("/curtin-install.log", mode="a") run_cmd.assert_called_once_with( - self.controller.app, - ANY, - "install", - "--set", 'json:stages=["partitioning", "extract"]', - "/source", - config="/config.yaml", - private_mounts=False) + self.controller.app, + ANY, + "install", + "--set", + 'json:stages=["partitioning", "extract"]', + "/source", + config="/config.yaml", + private_mounts=False, + ) @patch("subiquity.server.controllers.install.run_curtin_command") async def test_run_curtin_install_step_no_src(self, run_cmd): - - with patch("subiquity.server.controllers.install.open", - mock_open()) as m_open: + with patch("subiquity.server.controllers.install.open", mock_open()) as m_open: await self.controller.run_curtin_step( - name='MyStep', + name="MyStep", stages=["partitioning", "extract"], config_file=Path("/config.yaml"), source=None, config=self.controller.base_config( - logs_dir=Path("/"), resume_data_file=Path("resume-data")) - ) + logs_dir=Path("/"), resume_data_file=Path("resume-data") + ), + ) m_open.assert_called_once_with("/curtin-install.log", mode="a") run_cmd.assert_called_once_with( - self.controller.app, - ANY, - "install", - "--set", 'json:stages=["partitioning", "extract"]', - config="/config.yaml", - private_mounts=False) + self.controller.app, + ANY, + "install", + "--set", + 'json:stages=["partitioning", "extract"]', + config="/config.yaml", + private_mounts=False, + ) def test_base_config(self): - config = self.controller.base_config( - logs_dir=Path("/logs"), resume_data_file=Path("resume-data")) + logs_dir=Path("/logs"), resume_data_file=Path("resume-data") + ) - self.assertDictEqual(config, { - "install": { - "target": "/target", - "unmount": "disabled", - "save_install_config": False, - "save_install_log": False, - "log_file": "/logs/curtin-install.log", - "log_file_append": True, - "error_tarfile": "/logs/curtin-errors.tar", - "resume_data": "resume-data", + self.assertDictEqual( + config, + { + "install": { + "target": "/target", + "unmount": "disabled", + "save_install_config": False, + "save_install_log": False, + "log_file": "/logs/curtin-install.log", + "log_file_append": True, + "error_tarfile": "/logs/curtin-errors.tar", + "resume_data": "resume-data", } - }) + }, + ) def test_generic_config(self): - with patch.object(self.controller.model, "render", - return_value={"key": "value"}): + with patch.object( + self.controller.model, "render", return_value={"key": "value"} + ): config = self.controller.generic_config(key2="value2") self.assertEqual( @@ -123,65 +119,78 @@ class TestWriteConfig(unittest.IsolatedAsyncioTestCase): { "key": "value", "key2": "value2", - }) + }, + ) efi_state_no_rp = EFIBootState( - current='0000', - timeout='0 seconds', - order=['0000', '0002'], + current="0000", + timeout="0 seconds", + order=["0000", "0002"], entries={ - '0000': EFIBootEntry( - name='ubuntu', - path='HD(1,GPT,...)/File(\\EFI\\ubuntu\\shimx64.efi)'), - '0001': EFIBootEntry( - name='Windows Boot Manager', - path='HD(1,GPT,...,0x82000)/File(\\EFI\\bootmgfw.efi'), - '0002': EFIBootEntry( - name='Linux-Firmware-Updater', - path='HD(1,GPT,...,0x800,0x100000)/File(\\shimx64.efi)\\.fwupd'), - }) + "0000": EFIBootEntry( + name="ubuntu", path="HD(1,GPT,...)/File(\\EFI\\ubuntu\\shimx64.efi)" + ), + "0001": EFIBootEntry( + name="Windows Boot Manager", + path="HD(1,GPT,...,0x82000)/File(\\EFI\\bootmgfw.efi", + ), + "0002": EFIBootEntry( + name="Linux-Firmware-Updater", + path="HD(1,GPT,...,0x800,0x100000)/File(\\shimx64.efi)\\.fwupd", + ), + }, +) efi_state_with_rp = EFIBootState( - current='0000', - timeout='0 seconds', - order=['0000', '0002', '0003'], + current="0000", + timeout="0 seconds", + order=["0000", "0002", "0003"], entries={ - '0000': EFIBootEntry( - name='ubuntu', - path='HD(1,GPT,...)/File(\\EFI\\ubuntu\\shimx64.efi)'), - '0001': EFIBootEntry( - name='Windows Boot Manager', - path='HD(1,GPT,...,0x82000)/File(\\EFI\\bootmgfw.efi'), - '0002': EFIBootEntry( - name='Linux-Firmware-Updater', - path='HD(1,GPT,...,0x800,0x100000)/File(\\shimx64.efi)\\.fwupd'), - '0003': EFIBootEntry( - name='Restore Ubuntu to factory state', - path='HD(1,GPT,...,0x800,0x100000)/File(\\shimx64.efi)'), - }) + "0000": EFIBootEntry( + name="ubuntu", path="HD(1,GPT,...)/File(\\EFI\\ubuntu\\shimx64.efi)" + ), + "0001": EFIBootEntry( + name="Windows Boot Manager", + path="HD(1,GPT,...,0x82000)/File(\\EFI\\bootmgfw.efi", + ), + "0002": EFIBootEntry( + name="Linux-Firmware-Updater", + path="HD(1,GPT,...,0x800,0x100000)/File(\\shimx64.efi)\\.fwupd", + ), + "0003": EFIBootEntry( + name="Restore Ubuntu to factory state", + path="HD(1,GPT,...,0x800,0x100000)/File(\\shimx64.efi)", + ), + }, +) efi_state_with_dup_rp = EFIBootState( - current='0000', - timeout='0 seconds', - order=['0000', '0002', '0004'], + current="0000", + timeout="0 seconds", + order=["0000", "0002", "0004"], entries={ - '0000': EFIBootEntry( - name='ubuntu', - path='HD(1,GPT,...)/File(\\EFI\\ubuntu\\shimx64.efi)'), - '0001': EFIBootEntry( - name='Windows Boot Manager', - path='HD(1,GPT,...,0x82000)/File(\\EFI\\bootmgfw.efi'), - '0002': EFIBootEntry( - name='Linux-Firmware-Updater', - path='HD(1,GPT,...,0x800,0x100000)/File(\\shimx64.efi)\\.fwupd'), - '0003': EFIBootEntry( - name='Restore Ubuntu to factory state', - path='HD(1,GPT,...,0x800,0x100000)/File(\\shimx64.efi)'), - '0004': EFIBootEntry( - name='Restore Ubuntu to factory state', - path='HD(1,GPT,...,0x800,0x100000)/File(\\shimx64.efi)'), - }) + "0000": EFIBootEntry( + name="ubuntu", path="HD(1,GPT,...)/File(\\EFI\\ubuntu\\shimx64.efi)" + ), + "0001": EFIBootEntry( + name="Windows Boot Manager", + path="HD(1,GPT,...,0x82000)/File(\\EFI\\bootmgfw.efi", + ), + "0002": EFIBootEntry( + name="Linux-Firmware-Updater", + path="HD(1,GPT,...,0x800,0x100000)/File(\\shimx64.efi)\\.fwupd", + ), + "0003": EFIBootEntry( + name="Restore Ubuntu to factory state", + path="HD(1,GPT,...,0x800,0x100000)/File(\\shimx64.efi)", + ), + "0004": EFIBootEntry( + name="Restore Ubuntu to factory state", + path="HD(1,GPT,...,0x800,0x100000)/File(\\shimx64.efi)", + ), + }, +) class TestInstallController(unittest.IsolatedAsyncioTestCase): @@ -195,7 +204,8 @@ class TestInstallController(unittest.IsolatedAsyncioTestCase): async def test_install_package(self, m_sleep): run_curtin = "subiquity.server.controllers.install.run_curtin_command" error = subprocess.CalledProcessError( - returncode=1, cmd="curtin system-install git") + returncode=1, cmd="curtin system-install git" + ) with patch(run_curtin): await self.controller.install_package(package="git") @@ -220,46 +230,66 @@ class TestInstallController(unittest.IsolatedAsyncioTestCase): app.command_runner = Mock() self.run = app.command_runner.run = AsyncMock() app.package_installer.install_pkg = AsyncMock() - app.package_installer.install_pkg.return_value = \ - PackageInstallState.DONE + app.package_installer.install_pkg.return_value = PackageInstallState.DONE fsm, self.part = make_model_and_partition() @patch("subiquity.server.controllers.install.get_efibootmgr") async def test_create_rp_boot_entry_add(self, m_get_efibootmgr): - m_get_efibootmgr.side_effect = iter([ - efi_state_no_rp, efi_state_with_rp]) + m_get_efibootmgr.side_effect = iter([efi_state_no_rp, efi_state_with_rp]) self.setup_rp_test() await self.controller.create_rp_boot_entry(rp=self.part) calls = [ - call([ - 'efibootmgr', '--create', - '--loader', '\\EFI\\boot\\shimx64.efi', - '--disk', self.part.device.path, - '--part', str(self.part.number), - '--label', "Restore Ubuntu to factory state", - ]), - call([ - 'efibootmgr', '--bootorder', '0000,0002', - ]), - ] + call( + [ + "efibootmgr", + "--create", + "--loader", + "\\EFI\\boot\\shimx64.efi", + "--disk", + self.part.device.path, + "--part", + str(self.part.number), + "--label", + "Restore Ubuntu to factory state", + ] + ), + call( + [ + "efibootmgr", + "--bootorder", + "0000,0002", + ] + ), + ] self.run.assert_has_awaits(calls) @patch("subiquity.server.controllers.install.get_efibootmgr") async def test_create_rp_boot_entry_dup(self, m_get_efibootmgr): - m_get_efibootmgr.side_effect = iter([ - efi_state_with_rp, efi_state_with_dup_rp]) + m_get_efibootmgr.side_effect = iter([efi_state_with_rp, efi_state_with_dup_rp]) self.setup_rp_test() await self.controller.create_rp_boot_entry(rp=self.part) calls = [ - call([ - 'efibootmgr', '--create', - '--loader', '\\EFI\\boot\\shimx64.efi', - '--disk', self.part.device.path, - '--part', str(self.part.number), - '--label', "Restore Ubuntu to factory state", - ]), - call([ - 'efibootmgr', '--delete-bootnum', '--bootnum', '0004', - ]), - ] + call( + [ + "efibootmgr", + "--create", + "--loader", + "\\EFI\\boot\\shimx64.efi", + "--disk", + self.part.device.path, + "--part", + str(self.part.number), + "--label", + "Restore Ubuntu to factory state", + ] + ), + call( + [ + "efibootmgr", + "--delete-bootnum", + "--bootnum", + "0004", + ] + ), + ] self.run.assert_has_awaits(calls) diff --git a/subiquity/server/controllers/tests/test_integrity.py b/subiquity/server/controllers/tests/test_integrity.py index cb81ac30..cdcca4e5 100644 --- a/subiquity/server/controllers/tests/test_integrity.py +++ b/subiquity/server/controllers/tests/test_integrity.py @@ -13,8 +13,6 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from subiquitycore.tests import SubiTestCase - from subiquity.common.types import CasperMd5Results from subiquity.models.integrity import IntegrityModel from subiquity.server.controllers.integrity import ( @@ -23,13 +21,14 @@ from subiquity.server.controllers.integrity import ( mock_pass, mock_skip, ) +from subiquitycore.tests import SubiTestCase from subiquitycore.tests.mocks import make_app class TestMd5Check(SubiTestCase): def setUp(self): self.app = make_app() - self.app.opts.bootloader = 'UEFI' + self.app.opts.bootloader = "UEFI" self.ic = IntegrityController(app=self.app) self.ic.model = IntegrityModel() diff --git a/subiquity/server/controllers/tests/test_kernel.py b/subiquity/server/controllers/tests/test_kernel.py index fee277bc..da61f542 100644 --- a/subiquity/server/controllers/tests/test_kernel.py +++ b/subiquity/server/controllers/tests/test_kernel.py @@ -16,13 +16,11 @@ import os import os.path -from subiquitycore.tests.parameterized import parameterized - from subiquity.models.kernel import KernelModel from subiquity.server.controllers.kernel import KernelController - from subiquitycore.tests import SubiTestCase from subiquitycore.tests.mocks import make_app +from subiquitycore.tests.parameterized import parameterized class TestMetapackageSelection(SubiTestCase): @@ -33,49 +31,52 @@ class TestMetapackageSelection(SubiTestCase): self.controller.model = KernelModel() def setup_mpfile(self, dirpath, data): - runfile = self.tmp_path(f"{dirpath}/kernel-meta-package", - dir=self.app.base_model.root) + runfile = self.tmp_path( + f"{dirpath}/kernel-meta-package", dir=self.app.base_model.root + ) os.makedirs(os.path.dirname(runfile), exist_ok=True) - with open(runfile, 'w') as fp: + with open(runfile, "w") as fp: fp.write(data) def test_defaults(self): self.controller.start() - self.assertEqual('linux-generic', self.controller.model.metapkg_name) + self.assertEqual("linux-generic", self.controller.model.metapkg_name) def test_mpfile_run(self): - self.setup_mpfile('run', 'linux-aaaa') + self.setup_mpfile("run", "linux-aaaa") self.controller.start() - self.assertEqual('linux-aaaa', self.controller.model.metapkg_name) + self.assertEqual("linux-aaaa", self.controller.model.metapkg_name) def test_mpfile_etc(self): - self.setup_mpfile('etc/subiquity', 'linux-zzzz') + self.setup_mpfile("etc/subiquity", "linux-zzzz") self.controller.start() - self.assertEqual('linux-zzzz', self.controller.model.metapkg_name) + self.assertEqual("linux-zzzz", self.controller.model.metapkg_name) def test_mpfile_both(self): - self.setup_mpfile('run', 'linux-aaaa') - self.setup_mpfile('etc/subiquity', 'linux-zzzz') + self.setup_mpfile("run", "linux-aaaa") + self.setup_mpfile("etc/subiquity", "linux-zzzz") self.controller.start() - self.assertEqual('linux-aaaa', self.controller.model.metapkg_name) + self.assertEqual("linux-aaaa", self.controller.model.metapkg_name) - @parameterized.expand([ - [None, None, 'linux-generic'], - [None, {}, 'linux-generic'], - # when the metapackage file is set, it should be used. - ['linux-zzzz', None, 'linux-zzzz'], - # when we have a metapackage file and autoinstall, use autoinstall. - ['linux-zzzz', {'package': 'linux-aaaa'}, 'linux-aaaa'], - [None, {'package': 'linux-aaaa'}, 'linux-aaaa'], - [None, {'package': 'linux-aaaa', 'flavor': 'bbbb'}, 'linux-aaaa'], - [None, {'flavor': None}, 'linux-generic'], - [None, {'flavor': 'generic'}, 'linux-generic'], - [None, {'flavor': 'hwe'}, 'linux-generic-hwe-20.04'], - [None, {'flavor': 'bbbb'}, 'linux-bbbb-20.04'], - ]) + @parameterized.expand( + [ + [None, None, "linux-generic"], + [None, {}, "linux-generic"], + # when the metapackage file is set, it should be used. + ["linux-zzzz", None, "linux-zzzz"], + # when we have a metapackage file and autoinstall, use autoinstall. + ["linux-zzzz", {"package": "linux-aaaa"}, "linux-aaaa"], + [None, {"package": "linux-aaaa"}, "linux-aaaa"], + [None, {"package": "linux-aaaa", "flavor": "bbbb"}, "linux-aaaa"], + [None, {"flavor": None}, "linux-generic"], + [None, {"flavor": "generic"}, "linux-generic"], + [None, {"flavor": "hwe"}, "linux-generic-hwe-20.04"], + [None, {"flavor": "bbbb"}, "linux-bbbb-20.04"], + ] + ) def test_ai(self, mpfile_data, ai_data, metapkg_name): if mpfile_data is not None: - self.setup_mpfile('etc/subiquity', mpfile_data) + self.setup_mpfile("etc/subiquity", mpfile_data) self.controller.load_autoinstall_data(ai_data) self.controller.start() self.assertEqual(metapkg_name, self.controller.model.metapkg_name) diff --git a/subiquity/server/controllers/tests/test_keyboard.py b/subiquity/server/controllers/tests/test_keyboard.py index 4059a09b..001cf3e2 100644 --- a/subiquity/server/controllers/tests/test_keyboard.py +++ b/subiquity/server/controllers/tests/test_keyboard.py @@ -15,31 +15,24 @@ import os import unittest - from unittest.mock import Mock, patch +from subiquity.common.types import KeyboardSetting +from subiquity.models.keyboard import KeyboardModel +from subiquity.server.controllers.keyboard import KeyboardController from subiquitycore.tests import SubiTestCase from subiquitycore.tests.mocks import make_app from subiquitycore.tests.parameterized import parameterized -from subiquity.models.keyboard import ( - KeyboardModel, - ) -from subiquity.server.controllers.keyboard import ( - KeyboardController, - ) -from subiquity.common.types import KeyboardSetting - class opts: dry_run = True class TestSubiquityModel(SubiTestCase): - async def test_write_config(self): - os.environ['SUBIQUITY_REPLAY_TIMESCALE'] = '100' - new_setting = KeyboardSetting('fr', 'azerty') + os.environ["SUBIQUITY_REPLAY_TIMESCALE"] = "100" + new_setting = KeyboardSetting("fr", "azerty") tmpdir = self.tmp_dir() model = KeyboardModel(tmpdir) model.setting = new_setting @@ -56,27 +49,36 @@ class TestInputSource(unittest.IsolatedAsyncioTestCase): self.app = make_app() self.controller = KeyboardController(self.app) - @parameterized.expand([ - ('us', '', "[('xkb','us')]"), - ('fr', 'latin9', "[('xkb','fr+latin9')]"), - ]) + @parameterized.expand( + [ + ("us", "", "[('xkb','us')]"), + ("fr", "latin9", "[('xkb','fr+latin9')]"), + ] + ) async def test_input_source(self, layout, variant, expected_xkb): - with patch('subiquity.server.controllers.keyboard.arun_command') as \ - mock_arun_command, patch('pwd.getpwnam') as mock_getpwnam: + with patch( + "subiquity.server.controllers.keyboard.arun_command" + ) as mock_arun_command, patch("pwd.getpwnam") as mock_getpwnam: m = Mock() - m.pw_uid = '99' + m.pw_uid = "99" mock_getpwnam.return_value = m self.app.opts.dry_run = False - await self.controller.set_input_source(layout, variant, user='bar') + await self.controller.set_input_source(layout, variant, user="bar") gsettings = [ - 'gsettings', 'set', 'org.gnome.desktop.input-sources', - 'sources', expected_xkb - ] + "gsettings", + "set", + "org.gnome.desktop.input-sources", + "sources", + expected_xkb, + ] cmd = [ - 'systemd-run', '--wait', '--uid=99', + "systemd-run", + "--wait", + "--uid=99", f'--setenv=DISPLAY={os.environ.get("DISPLAY", ":0")}', - '--setenv=XDG_RUNTIME_DIR=/run/user/99', - '--setenv=DBUS_SESSION_BUS_ADDRESS=unix:path=/run/user/99/bus', - '--', *gsettings - ] + "--setenv=XDG_RUNTIME_DIR=/run/user/99", + "--setenv=DBUS_SESSION_BUS_ADDRESS=unix:path=/run/user/99/bus", + "--", + *gsettings, + ] mock_arun_command.assert_called_once_with(cmd) diff --git a/subiquity/server/controllers/tests/test_mirror.py b/subiquity/server/controllers/tests/test_mirror.py index 7c7fc333..c707dd82 100644 --- a/subiquity/server/controllers/tests/test_mirror.py +++ b/subiquity/server/controllers/tests/test_mirror.py @@ -15,17 +15,17 @@ import contextlib import io -import jsonschema import unittest from unittest import mock -from subiquitycore.tests.mocks import make_app +import jsonschema + from subiquity.common.types import MirrorSelectionFallback from subiquity.models.mirror import MirrorModel from subiquity.server.apt import AptConfigCheckError -from subiquity.server.controllers.mirror import MirrorController +from subiquity.server.controllers.mirror import MirrorController, NoUsableMirrorError from subiquity.server.controllers.mirror import log as MirrorLogger -from subiquity.server.controllers.mirror import NoUsableMirrorError +from subiquitycore.tests.mocks import make_app class TestMirrorSchema(unittest.TestCase): @@ -36,15 +36,15 @@ class TestMirrorSchema(unittest.TestCase): self.validate({}) def test_disable_components(self): - self.validate({'disable_components': ['universe']}) + self.validate({"disable_components": ["universe"]}) def test_no_disable_main(self): with self.assertRaises(jsonschema.ValidationError): - self.validate({'disable_components': ['main']}) + self.validate({"disable_components": ["main"]}) def test_no_disable_random_junk(self): with self.assertRaises(jsonschema.ValidationError): - self.validate({'disable_components': ['not-a-component']}) + self.validate({"disable_components": ["not-a-component"]}) class TestMirrorController(unittest.IsolatedAsyncioTestCase): @@ -66,8 +66,9 @@ class TestMirrorController(unittest.IsolatedAsyncioTestCase): def test_make_autoinstall_legacy(self): self.controller.model = MirrorModel() self.controller.model.legacy_primary = True - self.controller.model.primary_candidates = \ + self.controller.model.primary_candidates = ( self.controller.model.get_default_primary_candidates() + ) self.controller.model.primary_candidates[0].elect() config = self.controller.make_autoinstall() self.assertIn("disable_components", config.keys()) @@ -86,19 +87,24 @@ class TestMirrorController(unittest.IsolatedAsyncioTestCase): output = io.StringIO() mock_source_configured = mock.patch.object( - self.controller.source_configured_event, "wait") + self.controller.source_configured_event, "wait" + ) mock_run_apt_config_check = mock.patch.object( - self.controller.test_apt_configurer, "run_apt_config_check", - side_effect=fake_mirror_check_success) + self.controller.test_apt_configurer, + "run_apt_config_check", + side_effect=fake_mirror_check_success, + ) with mock_source_configured, mock_run_apt_config_check: await self.controller.run_mirror_testing(output) self.assertEqual(output.getvalue(), "test is successful!") output = io.StringIO() mock_run_apt_config_check = mock.patch.object( - self.controller.test_apt_configurer, "run_apt_config_check", - side_effect=fake_mirror_check_failure) + self.controller.test_apt_configurer, + "run_apt_config_check", + side_effect=fake_mirror_check_failure, + ) with mock_source_configured, mock_run_apt_config_check: with self.assertRaises(AptConfigCheckError): await self.controller.run_mirror_testing(output) @@ -109,21 +115,22 @@ class TestMirrorController(unittest.IsolatedAsyncioTestCase): with run_test: with self.assertLogs(MirrorLogger, "DEBUG") as debug: await self.controller.try_mirror_checking_once() - self.assertIn("Mirror checking successful", - [record.msg for record in debug.records]) - self.assertIn("APT output follows", - [record.msg for record in debug.records]) + self.assertIn( + "Mirror checking successful", [record.msg for record in debug.records] + ) + self.assertIn("APT output follows", [record.msg for record in debug.records]) - run_test = mock.patch.object(self.controller, "run_mirror_testing", - side_effect=AptConfigCheckError) + run_test = mock.patch.object( + self.controller, "run_mirror_testing", side_effect=AptConfigCheckError + ) with run_test: with self.assertLogs(MirrorLogger, "DEBUG") as debug: with self.assertRaises(AptConfigCheckError): await self.controller.try_mirror_checking_once() - self.assertIn("Mirror checking failed", - [record.msg for record in debug.records]) - self.assertIn("APT output follows", - [record.msg for record in debug.records]) + self.assertIn( + "Mirror checking failed", [record.msg for record in debug.records] + ) + self.assertIn("APT output follows", [record.msg for record in debug.records]) @mock.patch("subiquity.server.controllers.mirror.asyncio.sleep") async def test_find_and_elect_candidate_mirror(self, mock_sleep): @@ -139,39 +146,48 @@ class TestMirrorController(unittest.IsolatedAsyncioTestCase): self.controller.model.primary_candidates = [] with self.assertRaises(NoUsableMirrorError): await self.controller.find_and_elect_candidate_mirror( - self.controller.app.context) + self.controller.app.context + ) self.assertIsNone(self.controller.model.primary_elected) # Test one succeeding candidate self.controller.model.primary_elected = None - self.controller.model.primary_candidates = \ - [self.controller.model.create_primary_candidate("http://mirror")] + self.controller.model.primary_candidates = [ + self.controller.model.create_primary_candidate("http://mirror") + ] with mock.patch.object(self.controller, "try_mirror_checking_once"): await self.controller.find_and_elect_candidate_mirror( - self.controller.app.context) - self.assertEqual(self.controller.model.primary_elected.uri, - "http://mirror") + self.controller.app.context + ) + self.assertEqual(self.controller.model.primary_elected.uri, "http://mirror") # Test one succeeding candidate, on second try self.controller.model.primary_elected = None - self.controller.model.primary_candidates = \ - [self.controller.model.create_primary_candidate("http://mirror")] - with mock.patch.object(self.controller, "try_mirror_checking_once", - side_effect=(AptConfigCheckError, None)): + self.controller.model.primary_candidates = [ + self.controller.model.create_primary_candidate("http://mirror") + ] + with mock.patch.object( + self.controller, + "try_mirror_checking_once", + side_effect=(AptConfigCheckError, None), + ): await self.controller.find_and_elect_candidate_mirror( - self.controller.app.context) - self.assertEqual(self.controller.model.primary_elected.uri, - "http://mirror") + self.controller.app.context + ) + self.assertEqual(self.controller.model.primary_elected.uri, "http://mirror") # Test with a single candidate, failing twice self.controller.model.primary_elected = None - self.controller.model.primary_candidates = \ - [self.controller.model.create_primary_candidate("http://mirror")] - with mock.patch.object(self.controller, "try_mirror_checking_once", - side_effect=AptConfigCheckError): + self.controller.model.primary_candidates = [ + self.controller.model.create_primary_candidate("http://mirror") + ] + with mock.patch.object( + self.controller, "try_mirror_checking_once", side_effect=AptConfigCheckError + ): with self.assertRaises(NoUsableMirrorError): await self.controller.find_and_elect_candidate_mirror( - self.controller.app.context) + self.controller.app.context + ) self.assertIsNone(self.controller.model.primary_elected) # Test with one candidate failing twice, then one succeeding @@ -181,25 +197,28 @@ class TestMirrorController(unittest.IsolatedAsyncioTestCase): self.controller.model.create_primary_candidate("http://success"), ] with mock.patch.object( - self.controller, "try_mirror_checking_once", - side_effect=(AptConfigCheckError, AptConfigCheckError, None)): + self.controller, + "try_mirror_checking_once", + side_effect=(AptConfigCheckError, AptConfigCheckError, None), + ): await self.controller.find_and_elect_candidate_mirror( - self.controller.app.context) - self.assertEqual(self.controller.model.primary_elected.uri, - "http://success") + self.controller.app.context + ) + self.assertEqual(self.controller.model.primary_elected.uri, "http://success") # Test with an unresolved country mirror self.controller.model.primary_elected = None self.controller.model.primary_candidates = [ self.controller.model.create_primary_candidate( - uri=None, country_mirror=True), + uri=None, country_mirror=True + ), self.controller.model.create_primary_candidate("http://success"), ] with mock.patch.object(self.controller, "try_mirror_checking_once"): await self.controller.find_and_elect_candidate_mirror( - self.controller.app.context) - self.assertEqual(self.controller.model.primary_elected.uri, - "http://success") + self.controller.app.context + ) + self.assertEqual(self.controller.model.primary_elected.uri, "http://success") async def test_find_and_elect_candidate_mirror_no_network(self): self.controller.app.context.child = contextlib.nullcontext @@ -210,7 +229,8 @@ class TestMirrorController(unittest.IsolatedAsyncioTestCase): self.controller.cc_event.set() await self.controller.find_and_elect_candidate_mirror( - self.controller.app.context) + self.controller.app.context + ) self.assertIsNone(self.controller.model.primary_elected) async def test_apply_fallback(self): @@ -235,9 +255,11 @@ class TestMirrorController(unittest.IsolatedAsyncioTestCase): controller = self.controller with mock.patch.object(controller, "apply_fallback") as mock_fallback: - with mock.patch.object(controller, - "find_and_elect_candidate_mirror", - side_effect=[None, NoUsableMirrorError]): + with mock.patch.object( + controller, + "find_and_elect_candidate_mirror", + side_effect=[None, NoUsableMirrorError], + ): await controller.run_mirror_selection_or_fallback(context=None) mock_fallback.assert_not_called() await controller.run_mirror_selection_or_fallback(context=None) diff --git a/subiquity/server/controllers/tests/test_refresh.py b/subiquity/server/controllers/tests/test_refresh.py index 3f40986a..94adf24d 100644 --- a/subiquity/server/controllers/tests/test_refresh.py +++ b/subiquity/server/controllers/tests/test_refresh.py @@ -15,28 +15,22 @@ from unittest import mock +from subiquity.server import snapdapi +from subiquity.server.controllers import refresh as refresh_mod +from subiquity.server.controllers.refresh import RefreshController, SnapChannelSource from subiquitycore.snapd import AsyncSnapd, get_fake_connection from subiquitycore.tests import SubiTestCase from subiquitycore.tests.mocks import make_app -from subiquity.server.controllers import refresh as refresh_mod -from subiquity.server.controllers.refresh import ( - RefreshController, - SnapChannelSource, - ) -from subiquity.server import snapdapi - class TestRefreshController(SubiTestCase): - def setUp(self): self.app = make_app() self.app.report_start_event = mock.Mock() self.app.report_finish_event = mock.Mock() self.app.note_data_for_apport = mock.Mock() self.app.prober = mock.Mock() - self.app.snapdapi = snapdapi.make_api_client( - AsyncSnapd(get_fake_connection())) + self.app.snapdapi = snapdapi.make_api_client(AsyncSnapd(get_fake_connection())) self.rc = RefreshController(app=self.app) async def test_configure_snapd_kernel_autoinstall(self): @@ -44,69 +38,66 @@ class TestRefreshController(SubiTestCase): # autoinstall data, switch to it. for src in SnapChannelSource.CMDLINE, SnapChannelSource.AUTOINSTALL: - with mock.patch.object(self.rc, 'get_refresh_channel') as grc: - grc.return_value = ('newchan', src) - with mock.patch.object(refresh_mod, 'post_and_wait') as paw: + with mock.patch.object(self.rc, "get_refresh_channel") as grc: + grc.return_value = ("newchan", src) + with mock.patch.object(refresh_mod, "post_and_wait") as paw: await self.rc.configure_snapd(context=self.rc.context) paw.assert_called_once() request = paw.mock_calls[0].args[2] self.assertEqual(request.action, snapdapi.SnapAction.SWITCH) - self.assertEqual(request.channel, 'newchan') + self.assertEqual(request.channel, "newchan") async def test_configure_snapd_notfound(self): # If a snap channel is not found, ignore that. - with mock.patch.object(self.rc, 'get_refresh_channel') as grc: + with mock.patch.object(self.rc, "get_refresh_channel") as grc: grc.return_value = (None, SnapChannelSource.NOT_FOUND) - with mock.patch.object(refresh_mod, 'post_and_wait') as paw: + with mock.patch.object(refresh_mod, "post_and_wait") as paw: await self.rc.configure_snapd(context=self.rc.context) paw.assert_not_called() - @mock.patch('subiquity.server.controllers.refresh.lsb_release') + @mock.patch("subiquity.server.controllers.refresh.lsb_release") async def test_configure_snapd_disk_info(self, m_lsb): # If a snap channel is found via .disk/info it is applying if # and only if the snap is already tracking stable/ubuntu-XX.YY - m_lsb.return_value = {'release': 'XX.YY'} + m_lsb.return_value = {"release": "XX.YY"} # The ...v2.snaps[snap_name].GET() style of API is cute but # not very easy to mock out. - subiquity_info = await self.app.snapdapi.v2.snaps['subiquity'].GET() + subiquity_info = await self.app.snapdapi.v2.snaps["subiquity"].GET() class StubSnap: async def GET(self): return subiquity_info + POST = None - stub_snaps = {'subiquity': StubSnap()} + stub_snaps = {"subiquity": StubSnap()} # Test with the snap following the expected channel - subiquity_info.channel = 'stable/ubuntu-XX.YY' + subiquity_info.channel = "stable/ubuntu-XX.YY" - with mock.patch.object(self.rc, 'get_refresh_channel') as grc: - grc.return_value = ( - 'newchan', SnapChannelSource.DISK_INFO_FILE) - with mock.patch.object( - self.rc.app.snapdapi.v2, 'snaps', new=stub_snaps): - with mock.patch.object(refresh_mod, 'post_and_wait') as paw: + with mock.patch.object(self.rc, "get_refresh_channel") as grc: + grc.return_value = ("newchan", SnapChannelSource.DISK_INFO_FILE) + with mock.patch.object(self.rc.app.snapdapi.v2, "snaps", new=stub_snaps): + with mock.patch.object(refresh_mod, "post_and_wait") as paw: await self.rc.configure_snapd(context=self.rc.context) paw.assert_called_once() request = paw.mock_calls[0].args[2] self.assertEqual(request.action, snapdapi.SnapAction.SWITCH) - self.assertEqual(request.channel, 'newchan') + self.assertEqual(request.channel, "newchan") # Test with the snap not following the expected channel - subiquity_info.channel = 'something-custom' + subiquity_info.channel = "something-custom" - with mock.patch.object(self.rc, 'get_refresh_channel') as grc: - with mock.patch.object( - self.rc.app.snapdapi.v2, 'snaps', new=stub_snaps): - grc.return_value = ( - 'newchan', SnapChannelSource.DISK_INFO_FILE) - with mock.patch.object(refresh_mod, 'post_and_wait') as paw: + with mock.patch.object(self.rc, "get_refresh_channel") as grc: + with mock.patch.object(self.rc.app.snapdapi.v2, "snaps", new=stub_snaps): + grc.return_value = ("newchan", SnapChannelSource.DISK_INFO_FILE) + with mock.patch.object(refresh_mod, "post_and_wait") as paw: await self.rc.configure_snapd(context=self.rc.context) paw.assert_not_called() diff --git a/subiquity/server/controllers/tests/test_snaplist.py b/subiquity/server/controllers/tests/test_snaplist.py index c0fad271..a31e8e5a 100644 --- a/subiquity/server/controllers/tests/test_snaplist.py +++ b/subiquity/server/controllers/tests/test_snaplist.py @@ -13,15 +13,16 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import requests import unittest from unittest.mock import AsyncMock, Mock +import requests + +from subiquity.models.snaplist import SnapListModel from subiquity.server.controllers.snaplist import ( SnapdSnapInfoLoader, SnapListFetchError, ) -from subiquity.models.snaplist import SnapListModel from subiquitycore.tests.mocks import make_app @@ -34,7 +35,8 @@ class TestSnapdSnapInfoLoader(unittest.IsolatedAsyncioTestCase): self.app.report_finish_event = Mock() self.loader = SnapdSnapInfoLoader( - self.model, self.app.snapd, "server", self.app.context) + self.model, self.app.snapd, "server", self.app.context + ) async def test_list_task_not_started(self): self.assertFalse(self.loader.fetch_list_completed()) diff --git a/subiquity/server/controllers/tests/test_source.py b/subiquity/server/controllers/tests/test_source.py index 8541d4ac..9fd97116 100644 --- a/subiquity/server/controllers/tests/test_source.py +++ b/subiquity/server/controllers/tests/test_source.py @@ -26,26 +26,21 @@ def make_entry(**kw): class TestSubiquityModel(unittest.TestCase): - def test_convert_source(self): entry = make_entry() - source = convert_source(entry, 'C') + source = convert_source(entry, "C") self.assertEqual(source.id, entry.id) def test_convert_translations(self): entry = make_entry( name={ - 'en': 'English', - 'fr': 'French', - 'fr_CA': 'French Canadian', - }) - self.assertEqual( - convert_source(entry, 'C').name, "English") - self.assertEqual( - convert_source(entry, 'en').name, "English") - self.assertEqual( - convert_source(entry, 'fr').name, "French") - self.assertEqual( - convert_source(entry, 'fr_CA').name, "French Canadian") - self.assertEqual( - convert_source(entry, 'fr_BE').name, "French") + "en": "English", + "fr": "French", + "fr_CA": "French Canadian", + } + ) + self.assertEqual(convert_source(entry, "C").name, "English") + self.assertEqual(convert_source(entry, "en").name, "English") + self.assertEqual(convert_source(entry, "fr").name, "French") + self.assertEqual(convert_source(entry, "fr_CA").name, "French Canadian") + self.assertEqual(convert_source(entry, "fr_BE").name, "French") diff --git a/subiquity/server/controllers/tests/test_ssh.py b/subiquity/server/controllers/tests/test_ssh.py index 944e38c4..a141cfaf 100644 --- a/subiquity/server/controllers/tests/test_ssh.py +++ b/subiquity/server/controllers/tests/test_ssh.py @@ -32,35 +32,42 @@ class TestSSHController(unittest.IsolatedAsyncioTestCase): async def test_fetch_id_GET_ok(self): key = "ssh-rsa AAAAA[..] user@host # ssh-import-id lp:user" mock_fetch_keys = mock.patch.object( - self.controller.fetcher, "fetch_keys_for_id", - return_value=[key]) + self.controller.fetcher, "fetch_keys_for_id", return_value=[key] + ) fp = "256 SHA256:rIR9[..] user@host # ssh-import-id lp:user (ED25519)" mock_gen_fingerprint = mock.patch.object( - self.controller.fetcher, "gen_fingerprint_for_key", - return_value=fp) + self.controller.fetcher, "gen_fingerprint_for_key", return_value=fp + ) with mock_fetch_keys, mock_gen_fingerprint: response = await self.controller.fetch_id_GET(user_id="lp:user") self.assertIsInstance(response, SSHFetchIdResponse) self.assertEqual(response.status, SSHFetchIdStatus.OK) - self.assertEqual(response.identities, [SSHIdentity( - key_type="ssh-rsa", - key="AAAAA[..]", - key_comment="user@host # ssh-import-id lp:user", - key_fingerprint="256 SHA256:rIR9[..] user@host (ED25519)", - )]) + self.assertEqual( + response.identities, + [ + SSHIdentity( + key_type="ssh-rsa", + key="AAAAA[..]", + key_comment="user@host # ssh-import-id lp:user", + key_fingerprint="256 SHA256:rIR9[..] user@host (ED25519)", + ) + ], + ) self.assertIsNone(response.error) async def test_fetch_id_GET_import_error(self): stderr = "ERROR No matching keys found for [lp=test2]\n" mock_fetch_keys = mock.patch.object( - self.controller.fetcher, "fetch_keys_for_id", - side_effect=SSHFetchError( - status=SSHFetchIdStatus.IMPORT_ERROR, - reason=stderr)) + self.controller.fetcher, + "fetch_keys_for_id", + side_effect=SSHFetchError( + status=SSHFetchIdStatus.IMPORT_ERROR, reason=stderr + ), + ) with mock_fetch_keys: response = await self.controller.fetch_id_GET(user_id="test2") @@ -72,21 +79,22 @@ class TestSSHController(unittest.IsolatedAsyncioTestCase): async def test_fetch_id_GET_fingerprint_error(self): key = "ssh-rsa AAAAA[..] user@host # ssh-import-id lp:user" mock_fetch_keys = mock.patch.object( - self.controller.fetcher, "fetch_keys_for_id", - return_value=[key]) + self.controller.fetcher, "fetch_keys_for_id", return_value=[key] + ) stderr = "(stdin) is not a public key file\n" mock_gen_fingerprint = mock.patch.object( - self.controller.fetcher, "gen_fingerprint_for_key", - side_effect=SSHFetchError( - status=SSHFetchIdStatus.FINGERPRINT_ERROR, - reason=stderr)) + self.controller.fetcher, + "gen_fingerprint_for_key", + side_effect=SSHFetchError( + status=SSHFetchIdStatus.FINGERPRINT_ERROR, reason=stderr + ), + ) with mock_fetch_keys, mock_gen_fingerprint: response = await self.controller.fetch_id_GET(user_id="test2") - self.assertEqual(response.status, - SSHFetchIdStatus.FINGERPRINT_ERROR) + self.assertEqual(response.status, SSHFetchIdStatus.FINGERPRINT_ERROR) self.assertEqual(response.error, stderr) self.assertIsNone(response.identities) diff --git a/subiquity/server/controllers/tests/test_ubuntu_pro.py b/subiquity/server/controllers/tests/test_ubuntu_pro.py index 99a014c8..2de2f31d 100644 --- a/subiquity/server/controllers/tests/test_ubuntu_pro.py +++ b/subiquity/server/controllers/tests/test_ubuntu_pro.py @@ -15,9 +15,7 @@ import unittest -from subiquity.server.controllers.ubuntu_pro import ( - UbuntuProController, -) +from subiquity.server.controllers.ubuntu_pro import UbuntuProController from subiquity.server.dryrun import DRConfig from subiquitycore.tests.mocks import make_app diff --git a/subiquity/server/controllers/tests/test_userdata.py b/subiquity/server/controllers/tests/test_userdata.py index 4900f7d9..9c2f471e 100644 --- a/subiquity/server/controllers/tests/test_userdata.py +++ b/subiquity/server/controllers/tests/test_userdata.py @@ -15,16 +15,17 @@ import unittest -from subiquity.server.controllers.userdata import ( - UserdataController, -) +from cloudinit.config.schema import SchemaValidationError + +from subiquity.server.controllers.userdata import UserdataController from subiquitycore.tests.mocks import make_app -from cloudinit.config.schema import SchemaValidationError try: from cloudinit.config.schema import SchemaProblem except ImportError: - def SchemaProblem(x, y): return (x, y) # TODO(drop on cloud-init 22.3 SRU) + + def SchemaProblem(x, y): + return (x, y) # TODO(drop on cloud-init 22.3 SRU) class TestUserdataController(unittest.TestCase): @@ -32,7 +33,7 @@ class TestUserdataController(unittest.TestCase): self.controller = UserdataController(make_app()) def test_load_autoinstall_data(self): - with self.subTest('Valid user-data resets userdata model'): + with self.subTest("Valid user-data resets userdata model"): valid_schema = {"ssh_import_id": ["you"]} self.controller.model = {"some": "old userdata"} self.controller.load_autoinstall_data(valid_schema) @@ -40,16 +41,13 @@ class TestUserdataController(unittest.TestCase): fake_error = SchemaValidationError( schema_errors=( - SchemaProblem( - "ssh_import_id", - "'wrong' is not of type 'array'" - ), + SchemaProblem("ssh_import_id", "'wrong' is not of type 'array'"), ), ) invalid_schema = {"ssh_import_id": "wrong"} validate = self.controller.app.base_model.validate_cloudconfig_schema validate.side_effect = fake_error - with self.subTest('Invalid user-data raises error'): + with self.subTest("Invalid user-data raises error"): with self.assertRaises(SchemaValidationError) as ctx: self.controller.load_autoinstall_data(invalid_schema) expected_error = ( @@ -58,6 +56,5 @@ class TestUserdataController(unittest.TestCase): ) self.assertEqual(expected_error, str(ctx.exception)) validate.assert_called_with( - data=invalid_schema, - data_source="autoinstall.user-data" + data=invalid_schema, data_source="autoinstall.user-data" ) diff --git a/subiquity/server/controllers/timezone.py b/subiquity/server/controllers/timezone.py index 27665ddd..be125683 100644 --- a/subiquity/server/controllers/timezone.py +++ b/subiquity/server/controllers/timezone.py @@ -14,76 +14,75 @@ # along with this program. If not, see . import logging +import os import subprocess +from shutil import which from subiquity.common.apidef import API from subiquity.common.types import TimeZoneInfo from subiquity.server.controller import SubiquityController -from shutil import which -import os -log = logging.getLogger('subiquity.server.controllers.timezone') +log = logging.getLogger("subiquity.server.controllers.timezone") def active_timedatectl(): - return which('timedatectl') and os.path.exists('/run/systemd/system') + return which("timedatectl") and os.path.exists("/run/systemd/system") def generate_possible_tzs(): - special_keys = ['', 'geoip'] + special_keys = ["", "geoip"] if not active_timedatectl(): return special_keys - tzcmd = ['timedatectl', 'list-timezones'] + tzcmd = ["timedatectl", "list-timezones"] list_tz_out = subprocess.check_output(tzcmd, universal_newlines=True) real_tzs = list_tz_out.splitlines() return special_keys + real_tzs def timedatectl_settz(app, tz): - tzcmd = ['timedatectl', 'set-timezone', tz] + tzcmd = ["timedatectl", "set-timezone", tz] if app.opts.dry_run: - tzcmd = ['sleep', str(1/app.scale_factor)] + tzcmd = ["sleep", str(1 / app.scale_factor)] try: subprocess.run(tzcmd, universal_newlines=True) except subprocess.CalledProcessError as cpe: - log.error('Failed to set live system timezone: %r', cpe) + log.error("Failed to set live system timezone: %r", cpe) def timedatectl_gettz(): # timedatectl show would be easier, but isn't on bionic - tzcmd = ['timedatectl', 'status'] - env = {'LC_ALL': 'C'} + tzcmd = ["timedatectl", "status"] + env = {"LC_ALL": "C"} # ... # Time zone: America/Denver (MDT, -0600) # ... try: out = subprocess.check_output(tzcmd, env=env, universal_newlines=True) for line in out.splitlines(): - chunks = line.split(':') + chunks = line.split(":") label = chunks[0].strip() - if label != 'Time zone': + if label != "Time zone": continue - chunks = chunks[1].split(' ') + chunks = chunks[1].split(" ") return chunks[1] except subprocess.CalledProcessError as cpe: - log.error('Failed to get live system timezone: %r', cpe) + log.error("Failed to get live system timezone: %r", cpe) except IndexError: - log.error('Failed to acquire system time zone') - log.debug('Failed to find Time zone in timedatectl output') - return 'Etc/UTC' + log.error("Failed to acquire system time zone") + log.debug("Failed to find Time zone in timedatectl output") + return "Etc/UTC" class TimeZoneController(SubiquityController): - endpoint = API.timezone - autoinstall_key = model_name = 'timezone' + autoinstall_key = model_name = "timezone" autoinstall_schema = { - 'type': 'string', - } + "type": "string", + } - autoinstall_default = '' + autoinstall_default = "" def __init__(self, *args, **kwargs): self.possible = None @@ -123,8 +122,7 @@ class TimeZoneController(SubiquityController): async def GET(self) -> TimeZoneInfo: # if someone POSTed before, return that if self.model.timezone: - return TimeZoneInfo(self.model.timezone, - self.model.got_from_geoip) + return TimeZoneInfo(self.model.timezone, self.model.got_from_geoip) # GET requests geoip results if self.app.geoip.timezone: diff --git a/subiquity/server/controllers/ubuntu_pro.py b/subiquity/server/controllers/ubuntu_pro.py index e14dcd7a..238b9cd6 100644 --- a/subiquity/server/controllers/ubuntu_pro.py +++ b/subiquity/server/controllers/ubuntu_pro.py @@ -21,28 +21,25 @@ from typing import Optional from subiquity.common.apidef import API from subiquity.common.types import ( - UbuntuProInfo, UbuntuProCheckTokenAnswer, UbuntuProCheckTokenStatus, + UbuntuProInfo, UbuntuProResponse, UPCSInitiateResponse, - UPCSWaitStatus, UPCSWaitResponse, + UPCSWaitStatus, ) +from subiquity.server.contract_selection import ContractSelection, UPCSExpiredError +from subiquity.server.controller import SubiquityController from subiquity.server.ubuntu_advantage import ( - InvalidTokenError, - ExpiredTokenError, CheckSubscriptionError, - UAInterface, - UAInterfaceStrategy, + ExpiredTokenError, + InvalidTokenError, MockedUAInterfaceStrategy, UAClientUAInterfaceStrategy, + UAInterface, + UAInterfaceStrategy, ) -from subiquity.server.contract_selection import ( - ContractSelection, - UPCSExpiredError, - ) -from subiquity.server.controller import SubiquityController log = logging.getLogger("subiquity.server.controllers.ubuntu_pro") @@ -52,21 +49,21 @@ See https://pkg.go.dev/github.com/btcsuite/btcutil/base58#CheckEncode""" class UPCSAlreadyInitiatedError(Exception): - """ Exception to be raised when trying to initiate a contract selection - while another contract selection is already pending. """ + """Exception to be raised when trying to initiate a contract selection + while another contract selection is already pending.""" class UPCSCancelledError(Exception): - """ Exception to be raised when a contract selection got cancelled. """ + """Exception to be raised when a contract selection got cancelled.""" class UPCSNotInitiatedError(Exception): - """ Exception to be raised when trying to cancel or wait on a contract - selection that was not initiated. """ + """Exception to be raised when trying to cancel or wait on a contract + selection that was not initiated.""" class UbuntuProController(SubiquityController): - """ Represent the server-side Ubuntu Pro controller. """ + """Represent the server-side Ubuntu Pro controller.""" endpoint = API.ubuntu_pro @@ -87,7 +84,7 @@ class UbuntuProController(SubiquityController): } def __init__(self, app) -> None: - """ Initializer for server-side Ubuntu Pro controller. """ + """Initializer for server-side Ubuntu Pro controller.""" strategy: UAInterfaceStrategy if app.opts.dry_run: contracts_url = app.dr_cfg.pro_ua_contracts_url @@ -97,8 +94,7 @@ class UbuntuProController(SubiquityController): strategy.load_default_uaclient_config() strategy.uaclient_config["contract_url"] = contracts_url else: - strategy = MockedUAInterfaceStrategy( - scale_factor=app.scale_factor) + strategy = MockedUAInterfaceStrategy(scale_factor=app.scale_factor) else: # Make sure we execute `$PYTHON "$SNAP/usr/bin/ubuntu-advantage"`. executable = ( @@ -112,58 +108,53 @@ class UbuntuProController(SubiquityController): super().__init__(app) def load_autoinstall_data(self, data: dict) -> None: - """ Load autoinstall data and update the model. """ + """Load autoinstall data and update the model.""" if data is None: return self.model.token = data.get("token", "") def make_autoinstall(self) -> dict: - """ Return a dictionary that can be used as an autoinstall snippet for + """Return a dictionary that can be used as an autoinstall snippet for Ubuntu Pro. """ if not self.model.token: return {} - return { - "token": self.model.token - } + return {"token": self.model.token} def serialize(self) -> str: - """ Save the current state of the model so it can be loaded later. + """Save the current state of the model so it can be loaded later. Currently this function is called automatically by .configured(). """ return self.model.token def deserialize(self, token: str) -> None: - """ Loads the last-known state of the model. """ + """Loads the last-known state of the model.""" self.model.token = token async def GET(self) -> UbuntuProResponse: - """ Handle a GET request coming from the client-side controller. """ + """Handle a GET request coming from the client-side controller.""" has_network = self.app.base_model.network.has_network - return UbuntuProResponse(token=self.model.token, - has_network=has_network) + return UbuntuProResponse(token=self.model.token, has_network=has_network) async def POST(self, data: UbuntuProInfo) -> None: - """ Handle a POST request coming from the client-side controller and + """Handle a POST request coming from the client-side controller and then call .configured(). """ self.model.token = data.token await self.configured() async def skip_POST(self) -> None: - """ When running on a non-LTS release, we want to call this so we can - skip the screen on the client side. """ + """When running on a non-LTS release, we want to call this so we can + skip the screen on the client side.""" await self.configured() - async def check_token_GET(self, token: str) \ - -> UbuntuProCheckTokenAnswer: - """ Handle a GET request asking whether the contract token is valid or + async def check_token_GET(self, token: str) -> UbuntuProCheckTokenAnswer: + """Handle a GET request asking whether the contract token is valid or not. If it is valid, we provide the information about the subscription. """ subscription = None try: - subscription = await \ - self.ua_interface.get_subscription(token=token) + subscription = await self.ua_interface.get_subscription(token=token) except InvalidTokenError: status = UbuntuProCheckTokenStatus.INVALID_TOKEN except ExpiredTokenError: @@ -173,38 +164,36 @@ class UbuntuProController(SubiquityController): else: status = UbuntuProCheckTokenStatus.VALID_TOKEN - return UbuntuProCheckTokenAnswer(status=status, - subscription=subscription) + return UbuntuProCheckTokenAnswer(status=status, subscription=subscription) async def contract_selection_initiate_POST(self) -> UPCSInitiateResponse: - """ Initiate the contract selection request and start the polling. """ + """Initiate the contract selection request and start the polling.""" if self.cs and not self.cs.task.done(): raise UPCSAlreadyInitiatedError self.cs = await ContractSelection.initiate(client=self.ua_interface) return UPCSInitiateResponse( - user_code=self.cs.user_code, - validity_seconds=self.cs.validity_seconds) + user_code=self.cs.user_code, validity_seconds=self.cs.validity_seconds + ) async def contract_selection_wait_GET(self) -> UPCSWaitResponse: - """ Block until the contract selection finishes or times out. + """Block until the contract selection finishes or times out. If the contract selection is successful, the contract token is included - in the response. """ + in the response.""" if self.cs is None: raise UPCSNotInitiatedError try: return UPCSWaitResponse( - status=UPCSWaitStatus.SUCCESS, - contract_token=await asyncio.shield(self.cs.task)) + status=UPCSWaitStatus.SUCCESS, + contract_token=await asyncio.shield(self.cs.task), + ) except UPCSExpiredError: - return UPCSWaitResponse( - status=UPCSWaitStatus.TIMEOUT, - contract_token=None) + return UPCSWaitResponse(status=UPCSWaitStatus.TIMEOUT, contract_token=None) async def contract_selection_cancel_POST(self) -> None: - """ Cancel the currently ongoing contract selection. """ + """Cancel the currently ongoing contract selection.""" if self.cs is None: raise UPCSNotInitiatedError diff --git a/subiquity/server/controllers/updates.py b/subiquity/server/controllers/updates.py index 140c899c..4a75a2ce 100644 --- a/subiquity/server/controllers/updates.py +++ b/subiquity/server/controllers/updates.py @@ -18,22 +18,20 @@ import logging from subiquity.common.apidef import API from subiquity.server.controller import SubiquityController - -log = logging.getLogger('subiquity.server.controllers.updates') +log = logging.getLogger("subiquity.server.controllers.updates") class UpdatesController(SubiquityController): - endpoint = API.updates - possible = ['security', 'all'] + possible = ["security", "all"] autoinstall_key = model_name = "updates" autoinstall_schema = { - 'type': 'string', - 'enum': possible, + "type": "string", + "enum": possible, } - autoinstall_default = 'security' + autoinstall_default = "security" def load_autoinstall_data(self, data): self.deserialize(data) diff --git a/subiquity/server/controllers/userdata.py b/subiquity/server/controllers/userdata.py index c74dda23..8770e78e 100644 --- a/subiquity/server/controllers/userdata.py +++ b/subiquity/server/controllers/userdata.py @@ -21,19 +21,19 @@ log = logging.getLogger("subiquity.server.controllers.userdata") class UserdataController(NonInteractiveController): - - model_name = 'userdata' + model_name = "userdata" autoinstall_key = "user-data" autoinstall_default = {} autoinstall_schema = { - 'type': 'object', - } + "type": "object", + } def load_autoinstall_data(self, data): self.model.clear() if data: self.app.base_model.validate_cloudconfig_schema( - data=data, data_source="autoinstall.user-data", + data=data, + data_source="autoinstall.user-data", ) self.model.update(data) diff --git a/subiquity/server/controllers/zdev.py b/subiquity/server/controllers/zdev.py index 990eac8c..64472241 100644 --- a/subiquity/server/controllers/zdev.py +++ b/subiquity/server/controllers/zdev.py @@ -17,21 +17,23 @@ import asyncio import logging import platform import random -from typing import List - from collections import OrderedDict - -from subiquitycore.utils import arun_command, run_command +from typing import List from subiquity.common.apidef import API from subiquity.common.types import Bootloader, ZdevInfo from subiquity.server.controller import SubiquityController - +from subiquitycore.utils import arun_command, run_command log = logging.getLogger("subiquity.server.controllers.zdev") -lszdev_cmd = ['lszdev', '--quiet', '--pairs', '--columns', - 'id,type,on,exists,pers,auto,failed,names'] +lszdev_cmd = [ + "lszdev", + "--quiet", + "--pairs", + "--columns", + "id,type,on,exists,pers,auto,failed,names", +] lszdev_stock = '''id="0.0.1500" type="dasd-eckd" on="no" exists="yes" pers="no" auto="no" failed="no" names="" id="0.0.1501" type="dasd-eckd" on="yes" exists="yes" pers="auto" auto="yes" failed="no" names="" @@ -591,13 +593,12 @@ id="0.0.c0fe" type="generic-ccw" on="no" exists="yes" pers="no" auto="no" failed class ZdevController(SubiquityController): - endpoint = API.zdev def __init__(self, app): super().__init__(app) if self.opts.dry_run: - if platform.machine() == 's390x': + if platform.machine() == "s390x": zdevinfos = self.lszdev() else: devices = lszdev_stock.splitlines() @@ -612,12 +613,12 @@ class ZdevController(SubiquityController): async def chzdev_POST(self, action: str, zdev: ZdevInfo) -> List[ZdevInfo]: if self.opts.dry_run: - await asyncio.sleep(random.random()*0.4) - on = action == 'enable' + await asyncio.sleep(random.random() * 0.4) + on = action == "enable" self.zdevinfos[zdev.id].on = on self.zdevinfos[zdev.id].pers = on else: - chzdev_cmd = ['chzdev', '--%s' % action, zdev.id] + chzdev_cmd = ["chzdev", "--%s" % action, zdev.id] await arun_command(chzdev_cmd) return await self.GET() diff --git a/subiquity/server/curtin.py b/subiquity/server/curtin.py index 3bee2ac4..38f49c36 100644 --- a/subiquity/server/curtin.py +++ b/subiquity/server/curtin.py @@ -22,29 +22,29 @@ import re import subprocess import sys from typing import Dict, List, Type + import yaml +from subiquity.journald import journald_listen from subiquitycore.context import Context, Status -from subiquity.journald import ( - journald_listen, - ) - -log = logging.getLogger('subiquity.server.curtin') +log = logging.getLogger("subiquity.server.curtin") class _CurtinCommand: - _count = 0 - def __init__(self, opts, runner, command: str, *args: str, - config=None, private_mounts: bool): + def __init__( + self, opts, runner, command: str, *args: str, config=None, private_mounts: bool + ): self.opts = opts self.runner = runner self._event_contexts: Dict[str, Context] = {} _CurtinCommand._count += 1 - self._event_syslog_id = 'curtin_event.%s.%s' % ( - os.getpid(), _CurtinCommand._count) + self._event_syslog_id = "curtin_event.%s.%s" % ( + os.getpid(), + _CurtinCommand._count, + ) self._fd = None self.proc = None self._cmd = self.make_command(command, *args, config=config) @@ -56,17 +56,18 @@ class _CurtinCommand: "MESSAGE": "???", "NAME": "???", "RESULT": "???", - } + } prefix = "CURTIN_" for k, v in event.items(): if k.startswith(prefix): - e[k[len(prefix):]] = v + e[k[len(prefix) :]] = v event_type = e["EVENT_TYPE"] - if event_type == 'start': + if event_type == "start": + def p(name): - parts = name.split('/') + parts = name.split("/") for i in range(len(parts), -1, -1): - yield '/'.join(parts[:i]), '/'.join(parts[i:]) + yield "/".join(parts[:i]), "/".join(parts[i:]) curtin_ctx = None for pre, post in p(e["NAME"]): @@ -77,7 +78,7 @@ class _CurtinCommand: break if curtin_ctx: curtin_ctx.enter() - if event_type == 'finish': + if event_type == "finish": status = getattr(Status, e["RESULT"], Status.WARN) curtin_ctx = self._event_contexts.pop(e["NAME"], None) if curtin_ctx is not None: @@ -85,19 +86,27 @@ class _CurtinCommand: def make_command(self, command: str, *args: str, config=None) -> List[str]: reporting_conf = { - 'subiquity': { - 'type': 'journald', - 'identifier': self._event_syslog_id, - }, - } + "subiquity": { + "type": "journald", + "identifier": self._event_syslog_id, + }, + } cmd = [ - sys.executable, '-m', 'curtin', '--showtrace', '-vvv', - '--set', 'json:reporting=' + json.dumps(reporting_conf), - ] + sys.executable, + "-m", + "curtin", + "--showtrace", + "-vvv", + "--set", + "json:reporting=" + json.dumps(reporting_conf), + ] if config is not None: - cmd.extend([ - '-c', config, - ]) + cmd.extend( + [ + "-c", + config, + ] + ) cmd.append(command) cmd.extend(args) return cmd @@ -107,9 +116,10 @@ class _CurtinCommand: # Yield to the event loop before starting curtin to avoid missing the # first couple of events. await asyncio.sleep(0) - self._event_contexts[''] = context + self._event_contexts[""] = context self.proc = await self.runner.start( - self._cmd, **opts, private_mounts=self.private_mounts) + self._cmd, **opts, private_mounts=self.private_mounts + ) async def wait(self): result = await self.runner.wait(self.proc) @@ -118,7 +128,7 @@ class _CurtinCommand: await asyncio.sleep(0.1) waited += 0.1 log.debug("waited %s seconds for events to drain", waited) - self._event_contexts.pop('', None) + self._event_contexts.pop("", None) asyncio.get_running_loop().remove_reader(self._fd) return result @@ -128,7 +138,6 @@ class _CurtinCommand: class _DryRunCurtinCommand(_CurtinCommand): - stages_mapping = { tuple(): "initial.json", # no stage ("partitioning",): "partitioning.json", @@ -138,7 +147,7 @@ class _DryRunCurtinCommand(_CurtinCommand): } def make_command(self, command, *args, config=None): - if command == 'install': + if command == "install": # Lookup the log file from the config if specified try: with open(config, mode="r") as fh: @@ -156,54 +165,65 @@ class _DryRunCurtinCommand(_CurtinCommand): cmd = [ sys.executable, "scripts/replay-curtin-log.py", - "--event-identifier", self._event_syslog_id, - "--output", log_file, - ] + "--event-identifier", + self._event_syslog_id, + "--output", + log_file, + ] if config: - cmd.extend(['--config', config]) + cmd.extend(["--config", config]) event_log_filename = self.stages_mapping[tuple(stages)] - cmd.extend([ - "--", - f"examples/curtin-events/{event_log_filename}", - ]) + cmd.extend( + [ + "--", + f"examples/curtin-events/{event_log_filename}", + ] + ) return cmd else: return super().make_command(command, *args, config=config) class _FailingDryRunCurtinCommand(_DryRunCurtinCommand): - stages_mapping = { - **_DryRunCurtinCommand.stages_mapping, - **{("extract",): "curtin-events-fail.json"} + **_DryRunCurtinCommand.stages_mapping, + **{("extract",): "curtin-events-fail.json"}, } -async def start_curtin_command(app, context, - command: str, *args: str, - config=None, private_mounts: bool, - **opts) -> _CurtinCommand: +async def start_curtin_command( + app, context, command: str, *args: str, config=None, private_mounts: bool, **opts +) -> _CurtinCommand: cls: Type[_CurtinCommand] if app.opts.dry_run: - if 'install-fail' in app.debug_flags: + if "install-fail" in app.debug_flags: cls = _FailingDryRunCurtinCommand else: cls = _DryRunCurtinCommand else: cls = _CurtinCommand - curtin_cmd = cls(app.opts, app.command_runner, command, *args, - config=config, private_mounts=private_mounts) + curtin_cmd = cls( + app.opts, + app.command_runner, + command, + *args, + config=config, + private_mounts=private_mounts, + ) await curtin_cmd.start(context, **opts) return curtin_cmd async def run_curtin_command( - app, context, - command: str, *args: str, - config=None, - private_mounts: bool, - **opts) -> subprocess.CompletedProcess: + app, context, command: str, *args: str, config=None, private_mounts: bool, **opts +) -> subprocess.CompletedProcess: cmd = await start_curtin_command( - app, context, command, *args, - config=config, private_mounts=private_mounts, **opts) + app, + context, + command, + *args, + config=config, + private_mounts=private_mounts, + **opts, + ) return await cmd.wait() diff --git a/subiquity/server/dryrun.py b/subiquity/server/dryrun.py index 20b3f485..cacf7baf 100644 --- a/subiquity/server/dryrun.py +++ b/subiquity/server/dryrun.py @@ -14,24 +14,24 @@ # along with this program. If not, see . from typing import List, Optional, TypedDict -import yaml import attr +import yaml class DryRunController: - def __init__(self, app): self.app = app self.context = app.context.child("DryRun") async def crash_GET(self) -> None: - 1/0 + 1 / 0 class KnownMirror(TypedDict, total=False): - """ Dictionary type hints for a known mirror. Either url or pattern should - be specified. """ + """Dictionary type hints for a known mirror. Either url or pattern should + be specified.""" + url: str pattern: str @@ -39,16 +39,17 @@ class KnownMirror(TypedDict, total=False): class SSHImport(TypedDict, total=True): - """ Dictionary type hints for a SSH key import. """ + """Dictionary type hints for a SSH key import.""" + strategy: str username: str @attr.s(auto_attribs=True) class DRConfig: - """ Configuration for dry-run-only executions. + """Configuration for dry-run-only executions. All variables here should have default values ; to indicate the behavior we - want by default in dry-run mode. """ + want by default in dry-run mode.""" # Tells whether "$source"/var/lib/snapd/seed/systems exists on the source. systems_dir_exists: bool = False @@ -61,14 +62,15 @@ class DRConfig: apt_mirror_check_default_strategy: str = "run-on-host" apt_mirrors_known: List[KnownMirror] = [ - {"pattern": r"https?://archive\.ubuntu\.com/ubuntu/?", - "strategy": "success"}, - {"pattern": r"https?://[a-z]{2,}\.archive\.ubuntu\.com/ubuntu/?", - "strategy": "success"}, - {"pattern": r"/success/?$", "strategy": "success"}, - {"pattern": r"/rand(om)?/?$", "strategy": "random"}, - {"pattern": r"/host/?$", "strategy": "run-on-host"}, - {"pattern": r"/fail(ed)?/?$", "strategy": "failure"}, + {"pattern": r"https?://archive\.ubuntu\.com/ubuntu/?", "strategy": "success"}, + { + "pattern": r"https?://[a-z]{2,}\.archive\.ubuntu\.com/ubuntu/?", + "strategy": "success", + }, + {"pattern": r"/success/?$", "strategy": "success"}, + {"pattern": r"/rand(om)?/?$", "strategy": "random"}, + {"pattern": r"/host/?$", "strategy": "run-on-host"}, + {"pattern": r"/fail(ed)?/?$", "strategy": "failure"}, ] ssh_import_default_strategy: str = "run-on-host" @@ -79,8 +81,9 @@ class DRConfig: # If running ubuntu-drivers on the host, supply a file to # umockdev-wrapper.py - ubuntu_drivers_run_on_host_umockdev: Optional[str] = \ - "examples/umockdev/dell-certified+nvidia.yaml" + ubuntu_drivers_run_on_host_umockdev: Optional[ + str + ] = "examples/umockdev/dell-certified+nvidia.yaml" @classmethod def load(cls, stream): diff --git a/subiquity/server/errors.py b/subiquity/server/errors.py index 794fff67..859b5bb5 100644 --- a/subiquity/server/errors.py +++ b/subiquity/server/errors.py @@ -13,11 +13,10 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from subiquity.common.types import ErrorReportState, ErrorReportRef +from subiquity.common.types import ErrorReportRef, ErrorReportState class ErrorController: - def __init__(self, app): self.context = app.context.child("Error") self.error_reporter = app.error_reporter diff --git a/subiquity/server/geoip.py b/subiquity/server/geoip.py index f1c90313..e75b2322 100644 --- a/subiquity/server/geoip.py +++ b/subiquity/server/geoip.py @@ -13,19 +13,17 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from abc import ABC, abstractmethod -import aiohttp -import logging import enum +import logging +from abc import ABC, abstractmethod from xml.etree import ElementTree -from subiquitycore.async_helpers import ( - SingleInstanceTask, -) +import aiohttp from subiquity.server.types import InstallerChannels +from subiquitycore.async_helpers import SingleInstanceTask -log = logging.getLogger('subiquity.server.geoip') +log = logging.getLogger("subiquity.server.geoip") class CheckState(enum.IntEnum): @@ -36,17 +34,19 @@ class CheckState(enum.IntEnum): class GeoIPStrategy(ABC): - """ Base class for strategies (e.g. HTTP or dry-run) to retrieve the GeoIP - information. """ + """Base class for strategies (e.g. HTTP or dry-run) to retrieve the GeoIP + information.""" + @abstractmethod async def get_response(self) -> str: - """ Return the GeoIP information as an XML document. """ + """Return the GeoIP information as an XML document.""" class DryRunGeoIPStrategy(GeoIPStrategy): - """ Dry-run implementation to retrieve GeoIP information. """ + """Dry-run implementation to retrieve GeoIP information.""" + async def get_response(self) -> str: - """ Return the GeoIP information as an XML document. """ + """Return the GeoIP information as an XML document.""" return """\ @@ -68,8 +68,9 @@ class DryRunGeoIPStrategy(GeoIPStrategy): class HTTPGeoIPStrategy(GeoIPStrategy): - """ HTTP implementation to retrieve GeoIP information. We use the - geoip.ubuntu.com service. """ + """HTTP implementation to retrieve GeoIP information. We use the + geoip.ubuntu.com service.""" + async def get_response(self) -> str: url = "https://geoip.ubuntu.com/lookup" async with aiohttp.ClientSession() as session: @@ -86,10 +87,10 @@ class GeoIP: self.tz = None self.check_state = CheckState.NOT_STARTED self.lookup_task = SingleInstanceTask(self.lookup) - self.app.hub.subscribe(InstallerChannels.NETWORK_UP, - self.maybe_start_check) - self.app.hub.subscribe(InstallerChannels.NETWORK_PROXY_SET, - self.maybe_start_check) + self.app.hub.subscribe(InstallerChannels.NETWORK_UP, self.maybe_start_check) + self.app.hub.subscribe( + InstallerChannels.NETWORK_PROXY_SET, self.maybe_start_check + ) self.strategy = strategy def maybe_start_check(self): diff --git a/subiquity/server/kernel.py b/subiquity/server/kernel.py index 68e136de..fed507b5 100644 --- a/subiquity/server/kernel.py +++ b/subiquity/server/kernel.py @@ -22,32 +22,39 @@ from subiquitycore.utils import arun_command def flavor_to_pkgname(flavor: str, *, dry_run: bool) -> str: - if flavor == 'generic': - return 'linux-generic' - if flavor == 'hwe': - flavor = 'generic-hwe' + if flavor == "generic": + return "linux-generic" + if flavor == "hwe": + flavor = "generic-hwe" - release = lsb_release(dry_run=dry_run)['release'] + release = lsb_release(dry_run=dry_run)["release"] # Should check this package exists really but # that's a bit tricky until we get cleverer about # the apt config in general. - return f'linux-{flavor}-{release}' + return f"linux-{flavor}-{release}" async def list_installed_kernels(rootfs: pathlib.Path) -> List[str]: - ''' Return the list of linux-image packages installed in rootfs. ''' + """Return the list of linux-image packages installed in rootfs.""" # TODO use python-apt instead coupled with rootdir. # Ideally, we should not hardcode var/lib/dpkg/status which is an # implementation detail. try: - cp = await arun_command([ - 'grep-status', - '--whole-pkg', - '-FProvides', 'linux-image', '--and', '-FStatus', 'installed', - '--show-field=Package', - '--no-field-names', - str(rootfs / pathlib.Path('var/lib/dpkg/status')), - ], check=True) + cp = await arun_command( + [ + "grep-status", + "--whole-pkg", + "-FProvides", + "linux-image", + "--and", + "-FStatus", + "installed", + "--show-field=Package", + "--no-field-names", + str(rootfs / pathlib.Path("var/lib/dpkg/status")), + ], + check=True, + ) except subprocess.CalledProcessError as cpe: # grep-status exits with status 1 when there is no match. if cpe.returncode != 1: diff --git a/subiquity/server/mounter.py b/subiquity/server/mounter.py index 56323a66..2139b7a5 100644 --- a/subiquity/server/mounter.py +++ b/subiquity/server/mounter.py @@ -17,20 +17,19 @@ import contextlib import functools import logging import os -from pathlib import Path import shutil import tempfile +from pathlib import Path from typing import List, Optional, Union import attr from subiquitycore.utils import arun_command -log = logging.getLogger('subiquity.server.mounter') +log = logging.getLogger("subiquity.server.mounter") class TmpFileSet: - def __init__(self): self._tdirs: List[str] = [] @@ -45,23 +44,22 @@ class TmpFileSet: shutil.rmtree(d) self._tdirs.remove(d) except OSError as ose: - log.warning(f'failed to rmtree {d}: {ose}') + log.warning(f"failed to rmtree {d}: {ose}") class OverlayCleanupError(Exception): - """ Exception to raise when an overlay could not be cleaned up. """ + """Exception to raise when an overlay could not be cleaned up.""" class _MountBase: - def p(self, *args: str) -> str: for a in args: - if a.startswith('/'): - raise Exception('no absolute paths here please') + if a.startswith("/"): + raise Exception("no absolute paths here please") return os.path.join(self.mountpoint, *args) def write(self, path, content): - with open(self.p(path), 'w') as fp: + with open(self.p(path), "w") as fp: fp.write(content) @@ -108,11 +106,10 @@ def _lowerdir_for_ovmnt(ovmnt): @lowerdir_for.register(list) def _lowerdir_for_lst(lst): - return ':'.join(reversed([lowerdir_for(item) for item in lst])) + return ":".join(reversed([lowerdir_for(item) for item in lst])) class Mounter: - def __init__(self, app): self.app = app self.tmpfiles = TmpFileSet() @@ -121,9 +118,9 @@ class Mounter: async def mount(self, device, mountpoint=None, options=None, type=None): opts = [] if options is not None: - opts.extend(['-o', options]) + opts.extend(["-o", options]) if type is not None: - opts.extend(['-t', type]) + opts.extend(["-t", type]) if mountpoint is None: mountpoint = tempfile.mkdtemp() created = True @@ -131,13 +128,14 @@ class Mounter: created = False else: path = Path(device) - if options == 'bind' and not path.is_dir(): + if options == "bind" and not path.is_dir(): Path(mountpoint).touch(exist_ok=False) else: os.makedirs(mountpoint, exist_ok=False) created = True await self.app.command_runner.run( - ['mount'] + opts + [device, mountpoint], private_mounts=False) + ["mount"] + opts + [device, mountpoint], private_mounts=False + ) m = Mountpoint(mountpoint=mountpoint, created=created) self._mounts.append(m) return m @@ -146,8 +144,8 @@ class Mounter: if remove: self._mounts.remove(mountpoint) await self.app.command_runner.run( - ['umount', mountpoint.mountpoint], - private_mounts=False) + ["umount", mountpoint.mountpoint], private_mounts=False + ) if mountpoint.created: path = Path(mountpoint.mountpoint) if path.is_dir(): @@ -158,22 +156,18 @@ class Mounter: async def setup_overlay(self, lowers: List[Lower]) -> OverlayMountpoint: tdir = self.tmpfiles.tdir() - target = f'{tdir}/mount' + target = f"{tdir}/mount" lowerdir = lowerdir_for(lowers) - upperdir = f'{tdir}/upper' - workdir = f'{tdir}/work' + upperdir = f"{tdir}/upper" + workdir = f"{tdir}/work" for d in target, workdir, upperdir: os.mkdir(d) - options = f'lowerdir={lowerdir},upperdir={upperdir},workdir={workdir}' + options = f"lowerdir={lowerdir},upperdir={upperdir},workdir={workdir}" - mount = await self.mount( - 'overlay', target, options=options, type='overlay') + mount = await self.mount("overlay", target, options=options, type="overlay") - return OverlayMountpoint( - lowers=lowers, - mountpoint=mount.p(), - upperdir=upperdir) + return OverlayMountpoint(lowers=lowers, mountpoint=mount.p(), upperdir=upperdir) async def cleanup(self): for m in reversed(self._mounts): @@ -184,7 +178,7 @@ class Mounter: """bind-mount files and directories from src that are not already present into dst""" if not os.path.exists(dst): - await self.mount(src, dst, options='bind') + await self.mount(src, dst, options="bind") return for src_dirpath, dirnames, filenames in os.walk(src): dst_dirpath = src_dirpath.replace(src, dst) @@ -193,7 +187,7 @@ class Mounter: if os.path.exists(dst_path): continue src_path = os.path.join(src_dirpath, name) - await self.mount(src_path, dst_path, options='bind') + await self.mount(src_path, dst_path, options="bind") if name in dirnames: dirnames.remove(name) @@ -207,7 +201,6 @@ class Mounter: class DryRunMounter(Mounter): - async def setup_overlay(self, lowers: List[Lower]) -> OverlayMountpoint: # XXX This implementation expects that: # - on first invocation, the lowers list contains a single string @@ -223,11 +216,14 @@ class DryRunMounter(Mounter): else: source = lowerdir target = self.tmpfiles.tdir() - os.mkdir(f'{target}/etc') - await arun_command([ - 'cp', '-aT', f'{source}/etc/apt', f'{target}/etc/apt', - ], check=True) - return OverlayMountpoint( - lowers=[source], - mountpoint=target, - upperdir=None) + os.mkdir(f"{target}/etc") + await arun_command( + [ + "cp", + "-aT", + f"{source}/etc/apt", + f"{target}/etc/apt", + ], + check=True, + ) + return OverlayMountpoint(lowers=[source], mountpoint=target, upperdir=None) diff --git a/subiquity/server/pkghelper.py b/subiquity/server/pkghelper.py index 7664f1e4..bd2c270e 100644 --- a/subiquity/server/pkghelper.py +++ b/subiquity/server/pkghelper.py @@ -20,11 +20,10 @@ from typing import Dict, Optional import apt +from subiquity.common.types import PackageInstallState from subiquitycore.utils import arun_command -from subiquity.common.types import PackageInstallState - -log = logging.getLogger('subiquity.server.pkghelper') +log = logging.getLogger("subiquity.server.pkghelper") class PackageInstaller: @@ -56,37 +55,36 @@ class PackageInstaller: def start_installing_pkg(self, pkgname: str) -> None: if pkgname not in self.pkgs: - self.pkgs[pkgname] = asyncio.create_task( - self._install_pkg(pkgname)) + self.pkgs[pkgname] = asyncio.create_task(self._install_pkg(pkgname)) async def install_pkg(self, pkgname) -> PackageInstallState: self.start_installing_pkg(pkgname) return await self.pkgs[pkgname] async def _install_pkg(self, pkgname: str) -> PackageInstallState: - log.debug('checking if %s is available', pkgname) + log.debug("checking if %s is available", pkgname) binpkg = self.cache.get(pkgname) if not binpkg: - log.debug('%s not found', pkgname) + log.debug("%s not found", pkgname) return PackageInstallState.NOT_AVAILABLE if binpkg.installed: - log.debug('%s already installed', pkgname) + log.debug("%s already installed", pkgname) return PackageInstallState.DONE - if not binpkg.candidate.uri.startswith('cdrom:'): + if not binpkg.candidate.uri.startswith("cdrom:"): log.debug( - '%s not available from cdrom (rather %s)', - pkgname, binpkg.candidate.uri) + "%s not available from cdrom (rather %s)", pkgname, binpkg.candidate.uri + ) return PackageInstallState.NOT_AVAILABLE env = os.environ.copy() - env['DEBIAN_FRONTEND'] = 'noninteractive' + env["DEBIAN_FRONTEND"] = "noninteractive" apt_opts = [ - '--quiet', '--assume-yes', - '--option=Dpkg::Options::=--force-unsafe-io', - '--option=Dpkg::Options::=--force-confold', - ] - cp = await arun_command( - ['apt-get', 'install'] + apt_opts + [pkgname], env=env) - log.debug('apt-get install %s returned %s', pkgname, cp) + "--quiet", + "--assume-yes", + "--option=Dpkg::Options::=--force-unsafe-io", + "--option=Dpkg::Options::=--force-confold", + ] + cp = await arun_command(["apt-get", "install"] + apt_opts + [pkgname], env=env) + log.debug("apt-get install %s returned %s", pkgname, cp) if cp.returncode == 0: return PackageInstallState.DONE else: diff --git a/subiquity/server/runner.py b/subiquity/server/runner.py index c6ab8683..8a4a95d5 100644 --- a/subiquity/server/runner.py +++ b/subiquity/server/runner.py @@ -14,21 +14,22 @@ # along with this program. If not, see . import asyncio -from contextlib import suppress import os import subprocess +from contextlib import suppress from typing import List, Optional from subiquitycore.utils import astart_command class LoggedCommandRunner: - """ Class that executes commands using systemd-run. """ - def __init__(self, ident, - *, use_systemd_user: Optional[bool] = None) -> None: + """Class that executes commands using systemd-run.""" + + def __init__(self, ident, *, use_systemd_user: Optional[bool] = None) -> None: self.ident = ident self.env_allowlist = [ - "PATH", "PYTHONPATH", + "PATH", + "PYTHONPATH", "PYTHON", "TARGET_MOUNT_POINT", "SNAP", @@ -39,14 +40,16 @@ class LoggedCommandRunner: else: self.use_systemd_user = os.geteuid() != 0 - def _forge_systemd_cmd(self, cmd: List[str], - private_mounts: bool, capture: bool) -> List[str]: - """ Return the supplied command prefixed with the systemd-run stuff. - """ + def _forge_systemd_cmd( + self, cmd: List[str], private_mounts: bool, capture: bool + ) -> List[str]: + """Return the supplied command prefixed with the systemd-run stuff.""" prefix = [ "systemd-run", - "--wait", "--same-dir", - "--property", f"SyslogIdentifier={self.ident}", + "--wait", + "--same-dir", + "--property", + f"SyslogIdentifier={self.ident}", ] if private_mounts: prefix.extend(("--property", "PrivateMounts=yes")) @@ -66,26 +69,30 @@ class LoggedCommandRunner: return prefix + cmd - async def start(self, cmd: List[str], - *, private_mounts: bool = False, capture: bool = False) \ - -> asyncio.subprocess.Process: + async def start( + self, cmd: List[str], *, private_mounts: bool = False, capture: bool = False + ) -> asyncio.subprocess.Process: forged: List[str] = self._forge_systemd_cmd( - cmd, private_mounts=private_mounts, capture=capture) + cmd, private_mounts=private_mounts, capture=capture + ) proc = await astart_command(forged) proc.args = forged return proc - async def wait(self, proc: asyncio.subprocess.Process) \ - -> subprocess.CompletedProcess: + async def wait( + self, proc: asyncio.subprocess.Process + ) -> subprocess.CompletedProcess: stdout, stderr = await proc.communicate() # .communicate() forces returncode to be set to a value assert proc.returncode is not None if proc.returncode != 0: raise subprocess.CalledProcessError( - proc.returncode, proc.args, output=stdout, stderr=stderr) + proc.returncode, proc.args, output=stdout, stderr=stderr + ) else: return subprocess.CompletedProcess( - proc.args, proc.returncode, stdout=stdout, stderr=stderr) + proc.args, proc.returncode, stdout=stdout, stderr=stderr + ) async def run(self, cmd: List[str], **opts) -> subprocess.CompletedProcess: proc = await self.start(cmd, **opts) @@ -93,46 +100,44 @@ class LoggedCommandRunner: class DryRunCommandRunner(LoggedCommandRunner): - - def __init__(self, ident, delay, - *, use_systemd_user: Optional[bool] = None) -> None: + def __init__( + self, ident, delay, *, use_systemd_user: Optional[bool] = None + ) -> None: super().__init__(ident, use_systemd_user=use_systemd_user) self.delay = delay - def _forge_systemd_cmd(self, cmd: List[str], - private_mounts: bool, capture: bool) -> List[str]: + def _forge_systemd_cmd( + self, cmd: List[str], private_mounts: bool, capture: bool + ) -> List[str]: if "scripts/replay-curtin-log.py" in cmd: # We actually want to run this command prefixed_command = cmd else: prefixed_command = ["echo", "not running:"] + cmd - return super()._forge_systemd_cmd(prefixed_command, - private_mounts=private_mounts, - capture=capture) + return super()._forge_systemd_cmd( + prefixed_command, private_mounts=private_mounts, capture=capture + ) def _get_delay_for_cmd(self, cmd: List[str]) -> float: - if 'scripts/replay-curtin-log.py' in cmd: + if "scripts/replay-curtin-log.py" in cmd: return 0 - elif 'unattended-upgrades' in cmd: + elif "unattended-upgrades" in cmd: return 3 * self.delay else: return self.delay - async def start(self, cmd: List[str], - *, private_mounts: bool = False, capture: bool = False) \ - -> asyncio.subprocess.Process: + async def start( + self, cmd: List[str], *, private_mounts: bool = False, capture: bool = False + ) -> asyncio.subprocess.Process: delay = self._get_delay_for_cmd(cmd) - proc = await super().start(cmd, - private_mounts=private_mounts, - capture=capture) + proc = await super().start(cmd, private_mounts=private_mounts, capture=capture) await asyncio.sleep(delay) return proc def get_command_runner(app): if app.opts.dry_run: - return DryRunCommandRunner( - app.log_syslog_id, 2/app.scale_factor) + return DryRunCommandRunner(app.log_syslog_id, 2 / app.scale_factor) else: return LoggedCommandRunner(app.log_syslog_id) diff --git a/subiquity/server/server.py b/subiquity/server/server.py index 7ffb47b0..eb06d9a4 100644 --- a/subiquity/server/server.py +++ b/subiquity/server/server.py @@ -20,44 +20,17 @@ import sys import time from typing import List, Optional +import jsonschema +import yaml from aiohttp import web - from cloudinit import safeyaml, stages from cloudinit.config.cc_set_passwords import rand_user_password from cloudinit.distros import ug_util - -import jsonschema - from systemd import journal -import yaml - -from subiquitycore.async_helpers import ( - run_bg_task, - run_in_thread, - ) -from subiquitycore.context import with_context -from subiquitycore.core import Application -from subiquitycore.file_util import ( - copy_file_if_exists, - write_file, - ) -from subiquitycore.prober import Prober -from subiquitycore.ssh import ( - host_key_fingerprints, - user_key_fingerprints, - ) -from subiquitycore.utils import arun_command, run_command - -from subiquity.common.api.server import ( - bind, - controller_for_request, - ) +from subiquity.common.api.server import bind, controller_for_request from subiquity.common.apidef import API -from subiquity.common.errorreport import ( - ErrorReportKind, - ErrorReporter, - ) +from subiquity.common.errorreport import ErrorReporter, ErrorReportKind from subiquity.common.serialize import to_json from subiquity.common.types import ( ApplicationState, @@ -66,47 +39,43 @@ from subiquity.common.types import ( KeyFingerprint, LiveSessionSSHInfo, PasswordKind, - ) -from subiquity.models.subiquity import ( - ModelNames, - SubiquityModel, - ) -from subiquity.server.dryrun import DRConfig +) +from subiquity.models.subiquity import ModelNames, SubiquityModel from subiquity.server.controller import SubiquityController -from subiquity.server.geoip import ( - GeoIP, - DryRunGeoIPStrategy, - HTTPGeoIPStrategy, - ) +from subiquity.server.dryrun import DRConfig from subiquity.server.errors import ErrorController +from subiquity.server.geoip import DryRunGeoIPStrategy, GeoIP, HTTPGeoIPStrategy from subiquity.server.pkghelper import PackageInstaller from subiquity.server.runner import get_command_runner from subiquity.server.snapdapi import make_api_client from subiquity.server.types import InstallerChannels -from subiquitycore.snapd import ( - AsyncSnapd, - get_fake_connection, - SnapdConnection, - ) +from subiquitycore.async_helpers import run_bg_task, run_in_thread +from subiquitycore.context import with_context +from subiquitycore.core import Application +from subiquitycore.file_util import copy_file_if_exists, write_file +from subiquitycore.prober import Prober +from subiquitycore.snapd import AsyncSnapd, SnapdConnection, get_fake_connection +from subiquitycore.ssh import host_key_fingerprints, user_key_fingerprints +from subiquitycore.utils import arun_command, run_command NOPROBERARG = "NOPROBER" -iso_autoinstall_path = 'cdrom/autoinstall.yaml' -root_autoinstall_path = 'autoinstall.yaml' -cloud_autoinstall_path = 'run/subiquity/cloud.autoinstall.yaml' +iso_autoinstall_path = "cdrom/autoinstall.yaml" +root_autoinstall_path = "autoinstall.yaml" +cloud_autoinstall_path = "run/subiquity/cloud.autoinstall.yaml" -log = logging.getLogger('subiquity.server.server') +log = logging.getLogger("subiquity.server.server") class MetaController: - def __init__(self, app): self.app = app self.context = app.context.child("Meta") self.free_only = False - async def status_GET(self, cur: Optional[ApplicationState] = None) \ - -> ApplicationStatus: + async def status_GET( + self, cur: Optional[ApplicationState] = None + ) -> ApplicationStatus: if cur == self.app.state: await self.app.state_event.wait() return ApplicationStatus( @@ -117,7 +86,8 @@ class MetaController: interactive=self.app.interactive, echo_syslog_id=self.app.echo_syslog_id, event_syslog_id=self.app.event_syslog_id, - log_syslog_id=self.app.log_syslog_id) + log_syslog_id=self.app.log_syslog_id, + ) async def confirm_POST(self, tty: str) -> None: self.app.confirming_tty = tty @@ -134,7 +104,7 @@ class MetaController: async def client_variant_POST(self, variant: str) -> None: if variant not in self.app.supported_variants: - raise ValueError(f'unrecognized client variant {variant}') + raise ValueError(f"unrecognized client variant {variant}") self.app.base_model.set_source_variant(variant) self.app.set_source_variant(variant) @@ -154,28 +124,29 @@ class MetaController: user_fingerprints = [ KeyFingerprint(keytype, fingerprint) for keytype, fingerprint in user_key_fingerprints(username) - ] + ] if self.app.installer_user_passwd_kind == PasswordKind.NONE: if not user_key_fingerprints: return None host_fingerprints = [ KeyFingerprint(keytype, fingerprint) for keytype, fingerprint in host_key_fingerprints() - ] + ] return LiveSessionSSHInfo( username=username, password_kind=self.app.installer_user_passwd_kind, password=self.app.installer_user_passwd, authorized_key_fingerprints=user_fingerprints, ips=ips, - host_key_fingerprints=host_fingerprints) + host_key_fingerprints=host_fingerprints, + ) async def free_only_GET(self) -> bool: return self.free_only async def free_only_POST(self, enable: bool) -> None: self.free_only = enable - to_disable = {'restricted', 'multiverse'} + to_disable = {"restricted", "multiverse"} # enabling free only mode means disabling components self.app.base_model.mirror.disable_components(to_disable, enable) @@ -183,14 +154,15 @@ class MetaController: if self.app.autoinstall_config is None: return None - i_sections = self.app.autoinstall_config.get( - 'interactive-sections', None) - if i_sections == ['*']: + i_sections = self.app.autoinstall_config.get("interactive-sections", None) + if i_sections == ["*"]: # expand the asterisk to the actual controller key names - return [controller.autoinstall_key - for controller in self.app.controllers.instances - if controller.interactive() - if controller.autoinstall_key is not None] + return [ + controller.autoinstall_key + for controller in self.app.controllers.instances + if controller.interactive() + if controller.autoinstall_key is not None + ] return i_sections @@ -204,56 +176,60 @@ def get_installer_password_from_cloudinit_log(): with fp: for line in fp: if line.startswith("installer:"): - return line[len("installer:"):].strip() + return line[len("installer:") :].strip() return None -INSTALL_MODEL_NAMES = ModelNames({ - "debconf_selections", - "filesystem", - "kernel", - "keyboard", - "network", - "proxy", - "source", +INSTALL_MODEL_NAMES = ModelNames( + { + "debconf_selections", + "filesystem", + "kernel", + "keyboard", + "network", + "proxy", + "source", }, - desktop={'mirror'}, - server={'mirror'}) + desktop={"mirror"}, + server={"mirror"}, +) -POSTINSTALL_MODEL_NAMES = ModelNames({ - "drivers", - "identity", - "locale", - "network", - "packages", - "snaplist", - "ssh", - "ubuntu_pro", - "userdata", +POSTINSTALL_MODEL_NAMES = ModelNames( + { + "drivers", + "identity", + "locale", + "network", + "packages", + "snaplist", + "ssh", + "ubuntu_pro", + "userdata", }, - desktop={"timezone", "codecs", "active_directory"}) + desktop={"timezone", "codecs", "active_directory"}, +) class SubiquityServer(Application): - - snapd_socket_path = '/run/snapd.socket' + snapd_socket_path = "/run/snapd.socket" base_schema = { - 'type': 'object', - 'properties': { - 'version': { - 'type': 'integer', - 'minimum': 1, - 'maximum': 1, - }, + "type": "object", + "properties": { + "version": { + "type": "integer", + "minimum": 1, + "maximum": 1, }, - 'required': ['version'], - 'additionalProperties': True, - } + }, + "required": ["version"], + "additionalProperties": True, + } project = "subiquity" from subiquity.server import controllers as controllers_mod + controllers = [ "Early", "Reporting", @@ -285,16 +261,17 @@ class SubiquityServer(Application): "Updates", "Late", "Shutdown", - ] + ] supported_variants = ["server", "desktop"] def make_model(self): - root = '/' + root = "/" if self.opts.dry_run: root = os.path.abspath(self.opts.output_base) return SubiquityModel( - root, self.hub, INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES) + root, self.hub, INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES + ) def __init__(self, opts, block_log_dir): super().__init__(opts) @@ -306,29 +283,29 @@ class SubiquityServer(Application): self.state_event = asyncio.Event() self.update_state(ApplicationState.STARTING_UP) self.interactive = None - self.confirming_tty = '' + self.confirming_tty = "" self.fatal_error = None self.running_error_commands = False self.installer_user_name = None self.installer_user_passwd_kind = PasswordKind.NONE self.installer_user_passwd = None - self.echo_syslog_id = 'subiquity_echo.{}'.format(os.getpid()) - self.event_syslog_id = 'subiquity_event.{}'.format(os.getpid()) - self.log_syslog_id = 'subiquity_log.{}'.format(os.getpid()) + self.echo_syslog_id = "subiquity_echo.{}".format(os.getpid()) + self.event_syslog_id = "subiquity_event.{}".format(os.getpid()) + self.log_syslog_id = "subiquity_log.{}".format(os.getpid()) self.command_runner = get_command_runner(self) self.package_installer = PackageInstaller() self.error_reporter = ErrorReporter( - self.context.child("ErrorReporter"), self.opts.dry_run, self.root) + self.context.child("ErrorReporter"), self.opts.dry_run, self.root + ) if opts.machine_config == NOPROBERARG: self.prober = None else: self.prober = Prober(opts.machine_config, self.debug_flags) self.kernel_cmdline = opts.kernel_cmdline if opts.snaps_from_examples: - connection = get_fake_connection( - self.scale_factor, opts.output_base) + connection = get_fake_connection(self.scale_factor, opts.output_base) self.snapd = AsyncSnapd(connection) self.snapdapi = make_api_client(self.snapd) elif os.path.exists(self.snapd_socket_path): @@ -342,8 +319,7 @@ class SubiquityServer(Application): self.event_listeners = [] self.autoinstall_config = None self.hub.subscribe(InstallerChannels.NETWORK_UP, self._network_change) - self.hub.subscribe(InstallerChannels.NETWORK_PROXY_SET, - self._proxy_set) + self.hub.subscribe(InstallerChannels.NETWORK_PROXY_SET, self._proxy_set) if self.opts.dry_run: geoip_strategy = DryRunGeoIPStrategy() else: @@ -362,26 +338,25 @@ class SubiquityServer(Application): self.event_listeners.append(listener) def _maybe_push_to_journal(self, event_type, context, description): - if not context.get('is-install-context') and \ - self.interactive in [True, None]: - controller = context.get('controller') + if not context.get("is-install-context") and self.interactive in [True, None]: + controller = context.get("controller") if controller is None or controller.interactive(): return - if context.get('request'): + if context.get("request"): return - indent = context.full_name().count('/') - 2 - if context.get('is-install-context') and self.interactive: + indent = context.full_name().count("/") - 2 + if context.get("is-install-context") and self.interactive: indent -= 1 msg = context.description else: msg = context.full_name() if description: - msg += ': ' + description - msg = ' ' * indent + msg + msg += ": " + description + msg = " " * indent + msg if context.parent: parent_id = str(context.parent.id) else: - parent_id = '' + parent_id = "" journal.send( msg, PRIORITY=context.level, @@ -389,17 +364,18 @@ class SubiquityServer(Application): SUBIQUITY_CONTEXT_NAME=context.full_name(), SUBIQUITY_EVENT_TYPE=event_type, SUBIQUITY_CONTEXT_ID=str(context.id), - SUBIQUITY_CONTEXT_PARENT_ID=parent_id) + SUBIQUITY_CONTEXT_PARENT_ID=parent_id, + ) def report_start_event(self, context, description): for listener in self.event_listeners: listener.report_start_event(context, description) - self._maybe_push_to_journal('start', context, description) + self._maybe_push_to_journal("start", context, description) def report_finish_event(self, context, description, status): for listener in self.event_listeners: listener.report_finish_event(context, description, status) - self._maybe_push_to_journal('finish', context, description) + self._maybe_push_to_journal("finish", context, description) @property def state(self): @@ -418,8 +394,7 @@ class SubiquityServer(Application): self.error_reporter.note_data_for_apport(key, value) def make_apport_report(self, kind, thing, *, wait=False, **kw): - return self.error_reporter.make_apport_report( - kind, thing, wait=wait, **kw) + return self.error_reporter.make_apport_report(kind, thing, wait=wait, **kw) async def _run_error_cmds(self, report): await report._info_task @@ -433,7 +408,7 @@ class SubiquityServer(Application): self.update_state(ApplicationState.ERROR) def _exception_handler(self, loop, context): - exc = context.get('exception') + exc = context.get("exception") if exc is None: super()._exception_handler(loop, context) return @@ -441,8 +416,8 @@ class SubiquityServer(Application): log.error("top level error", exc_info=exc) if not report: report = self.make_apport_report( - ErrorReportKind.UNKNOWN, "unknown error", - exc=exc) + ErrorReportKind.UNKNOWN, "unknown error", exc=exc + ) self.fatal_error = report if self.interactive: self.update_state(ApplicationState.ERROR) @@ -455,31 +430,29 @@ class SubiquityServer(Application): override_status = None controller = await controller_for_request(request) if isinstance(controller, SubiquityController): - if request.headers.get('x-make-view-request') == 'yes': + if request.headers.get("x-make-view-request") == "yes": if not controller.interactive(): - override_status = 'skip' + override_status = "skip" elif self.state == ApplicationState.NEEDS_CONFIRMATION: - if self.base_model.is_postinstall_only( - controller.model_name): - override_status = 'confirm' + if self.base_model.is_postinstall_only(controller.model_name): + override_status = "confirm" if override_status is not None: - resp = web.Response(headers={'x-status': override_status}) + resp = web.Response(headers={"x-status": override_status}) else: resp = await handler(request) if self.updated: - resp.headers['x-updated'] = 'yes' + resp.headers["x-updated"] = "yes" else: - resp.headers['x-updated'] = 'no' - if resp.get('exception'): - exc = resp['exception'] - log.debug( - 'request to {} crashed'.format(request.raw_path), exc_info=exc) + resp.headers["x-updated"] = "no" + if resp.get("exception"): + exc = resp["exception"] + log.debug("request to {} crashed".format(request.raw_path), exc_info=exc) report = self.make_apport_report( ErrorReportKind.SERVER_REQUEST_FAIL, "request to {}".format(request.raw_path), - exc=exc) - resp.headers['x-error-report'] = to_json( - ErrorReportRef, report.ref()) + exc=exc, + ) + resp.headers["x-error-report"] = to_json(ErrorReportRef, report.ref()) return resp @with_context() @@ -488,14 +461,18 @@ class SubiquityServer(Application): if controller.interactive(): log.debug( "apply_autoinstall_config: skipping %s as interactive", - controller.name) + controller.name, + ) continue await controller.apply_autoinstall_config() await controller.configured() def load_autoinstall_config(self, *, only_early): - log.debug('load_autoinstall_config only_early %s file %s', - only_early, self.autoinstall) + log.debug( + "load_autoinstall_config only_early %s file %s", + only_early, + self.autoinstall, + ) if not self.autoinstall: return with open(self.autoinstall) as fp: @@ -517,10 +494,11 @@ class SubiquityServer(Application): bind(app.router, API.errors, ErrorController(self)) if self.opts.dry_run: from .dryrun import DryRunController + bind(app.router, API.dry_run, DryRunController(self)) for controller in self.controllers.instances: controller.add_routes(app) - runner = web.AppRunner(app, keepalive_timeout=0xffffffff) + runner = web.AppRunner(app, keepalive_timeout=0xFFFFFFFF) await runner.setup() await self.start_site(runner) @@ -542,7 +520,7 @@ class SubiquityServer(Application): try: status_cp = await asyncio.wait_for(status_coro, 600) except asyncio.TimeoutError: - status_txt = '' + status_txt = "" self.cloud_init_ok = False else: status_txt = status_cp.stdout @@ -554,14 +532,12 @@ class SubiquityServer(Application): init.read_cfg() init.fetch(existing="trust") self.cloud = init.cloudify() - if 'autoinstall' in self.cloud.cfg: - cfg = self.cloud.cfg['autoinstall'] + if "autoinstall" in self.cloud.cfg: + cfg = self.cloud.cfg["autoinstall"] target = self.base_relative(cloud_autoinstall_path) write_file(target, safeyaml.dumps(cfg)) else: - log.debug( - "cloud-init status: %r, assumed disabled", - status_txt) + log.debug("cloud-init status: %r, assumed disabled", status_txt) def select_autoinstall(self): # precedence @@ -574,15 +550,17 @@ class SubiquityServer(Application): # autoinstall has been explicitly disabled. if self.opts.autoinstall == "": return None - if self.opts.autoinstall is not None and not \ - os.path.exists(self.opts.autoinstall): - raise Exception( - f'Autoinstall argument {self.opts.autoinstall} not found') + if self.opts.autoinstall is not None and not os.path.exists( + self.opts.autoinstall + ): + raise Exception(f"Autoinstall argument {self.opts.autoinstall} not found") - locations = (self.base_relative(root_autoinstall_path), - self.opts.autoinstall, - self.base_relative(cloud_autoinstall_path), - self.base_relative(iso_autoinstall_path)) + locations = ( + self.base_relative(root_autoinstall_path), + self.opts.autoinstall, + self.base_relative(cloud_autoinstall_path), + self.base_relative(iso_autoinstall_path), + ) for loc in locations: if loc is not None and os.path.exists(loc): @@ -595,7 +573,7 @@ class SubiquityServer(Application): return rootpath def _user_has_password(self, username): - with open('/etc/shadow') as fp: + with open("/etc/shadow") as fp: for line in fp: if line.startswith(username + ":$"): return True @@ -611,23 +589,25 @@ class SubiquityServer(Application): with open(passfile) as fp: contents = fp.read() self.installer_user_passwd_kind = PasswordKind.KNOWN - self.installer_user_name, self.installer_user_passwd = \ - contents.split(':', 1) + self.installer_user_name, self.installer_user_passwd = contents.split( + ":", 1 + ) return def use_passwd(passwd): self.installer_user_passwd = passwd self.installer_user_passwd_kind = PasswordKind.KNOWN - with open(passfile, 'w') as fp: - fp.write(self.installer_user_name + ':' + passwd) + with open(passfile, "w") as fp: + fp.write(self.installer_user_name + ":" + passwd) if self.opts.dry_run: - self.installer_user_name = os.environ['USER'] + self.installer_user_name = os.environ["USER"] use_passwd(rand_user_password()) return (users, _groups) = ug_util.normalize_users_groups( - self.cloud.cfg, self.cloud.distro) + self.cloud.cfg, self.cloud.distro + ) (username, _user_config) = ug_util.extract_default(users) self.installer_user_name = username @@ -646,7 +626,7 @@ class SubiquityServer(Application): self.installer_user_passwd_kind = PasswordKind.UNKNOWN elif not user_key_fingerprints(username): passwd = rand_user_password() - cp = run_command('chpasswd', input=username + ':'+passwd+'\n') + cp = run_command("chpasswd", input=username + ":" + passwd + "\n") if cp.returncode == 0: use_passwd(passwd) else: @@ -671,16 +651,15 @@ class SubiquityServer(Application): # output. await asyncio.sleep(1) await self.controllers.Early.run() - open(stamp_file, 'w').close() + open(stamp_file, "w").close() await asyncio.sleep(1) self.load_autoinstall_config(only_early=False) if self.autoinstall_config: - self.interactive = bool( - self.autoinstall_config.get('interactive-sections')) + self.interactive = bool(self.autoinstall_config.get("interactive-sections")) else: self.interactive = True if not self.interactive and not self.opts.dry_run: - open('/run/casper-no-prompt', 'w').close() + open("/run/casper-no-prompt", "w").close() self.load_serialized_state() self.update_state(ApplicationState.WAITING) await super().start() @@ -699,21 +678,24 @@ class SubiquityServer(Application): if not self.snapd: return await run_in_thread( - self.snapd.connection.configure_proxy, self.base_model.proxy) + self.snapd.connection.configure_proxy, self.base_model.proxy + ) self.hub.broadcast(InstallerChannels.SNAPD_NETWORK_CHANGE) def restart(self): if not self.snapd: return - cmdline = ['snap', 'run', 'subiquity.subiquity-server'] + cmdline = ["snap", "run", "subiquity.subiquity-server"] if self.opts.dry_run: cmdline = [ - sys.executable, '-m', 'subiquity.cmd.server', - ] + sys.argv[1:] + sys.executable, + "-m", + "subiquity.cmd.server", + ] + sys.argv[1:] os.execvp(cmdline[0], cmdline) def make_autoinstall(self): - config = {'version': 1} + config = {"version": 1} for controller in self.controllers.instances: controller_conf = controller.make_autoinstall() if controller_conf: diff --git a/subiquity/server/snapdapi.py b/subiquity/server/snapdapi.py index 1eb421b1..4173c5aa 100644 --- a/subiquity/server/snapdapi.py +++ b/subiquity/server/snapdapi.py @@ -13,28 +13,27 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import aiohttp import asyncio import contextlib import enum import logging from typing import Dict, List, Optional -from subiquity.common.api.client import make_client -from subiquity.common.api.defs import api, path_parameter, Payload -from subiquity.common.serialize import named_field, Serializer -from subiquity.common.types import Change, TaskStatus - +import aiohttp import attr +from subiquity.common.api.client import make_client +from subiquity.common.api.defs import Payload, api, path_parameter +from subiquity.common.serialize import Serializer, named_field +from subiquity.common.types import Change, TaskStatus -log = logging.getLogger('subiquity.server.snapdapi') +log = logging.getLogger("subiquity.server.snapdapi") -RFC3339 = '%Y-%m-%dT%H:%M:%S.%fZ' +RFC3339 = "%Y-%m-%dT%H:%M:%S.%fZ" def date_field(name=None, default=attr.NOTHING): - metadata = {'time_fmt': RFC3339} + metadata = {"time_fmt": RFC3339} if name is not None: metadata.update(named_field(name).metadata) return attr.ib(metadata=metadata, default=default) @@ -44,15 +43,15 @@ ChangeID = str class SnapStatus(enum.Enum): - ACTIVE = 'active' - AVAILABLE = 'available' + ACTIVE = "active" + AVAILABLE = "available" @attr.s(auto_attribs=True) class Publisher: id: str username: str - display_name: str = named_field('display-name') + display_name: str = named_field("display-name") @attr.s(auto_attribs=True) @@ -67,21 +66,21 @@ class Snap: class SnapAction(enum.Enum): - REFRESH = 'refresh' - SWITCH = 'switch' + REFRESH = "refresh" + SWITCH = "switch" @attr.s(auto_attribs=True) class SnapActionRequest: action: SnapAction - channel: str = '' - ignore_running: bool = named_field('ignore-running', False) + channel: str = "" + ignore_running: bool = named_field("ignore-running", False) class ResponseType: - SYNC = 'sync' - ASYNC = 'async' - ERROR = 'error' + SYNC = "sync" + ASYNC = "async" + ERROR = "error" @attr.s(auto_attribs=True) @@ -92,32 +91,31 @@ class Response: class Role: - - NONE = '' - MBR = 'mbr' - SYSTEM_BOOT = 'system-boot' - SYSTEM_BOOT_IMAGE = 'system-boot-image' - SYSTEM_BOOT_SELECT = 'system-boot-select' - SYSTEM_DATA = 'system-data' - SYSTEM_RECOVERY_SELECT = 'system-recovery-select' - SYSTEM_SAVE = 'system-save' - SYSTEM_SEED = 'system-seed' - SYSTEM_SEED_NULL = 'system-seed-null' + NONE = "" + MBR = "mbr" + SYSTEM_BOOT = "system-boot" + SYSTEM_BOOT_IMAGE = "system-boot-image" + SYSTEM_BOOT_SELECT = "system-boot-select" + SYSTEM_DATA = "system-data" + SYSTEM_RECOVERY_SELECT = "system-recovery-select" + SYSTEM_SAVE = "system-save" + SYSTEM_SEED = "system-seed" + SYSTEM_SEED_NULL = "system-seed-null" @attr.s(auto_attribs=True) class RelativeOffset: - relative_to: str = named_field('relative-to') + relative_to: str = named_field("relative-to") offset: int @attr.s(auto_attribs=True) class VolumeContent: - source: str = '' - target: str = '' - image: str = '' + source: str = "" + target: str = "" + image: str = "" offset: Optional[int] = None - offset_write: Optional[RelativeOffset] = named_field('offset-write', None) + offset_write: Optional[RelativeOffset] = named_field("offset-write", None) size: int = 0 unpack: bool = False @@ -130,30 +128,30 @@ class VolumeUpdate: @attr.s(auto_attribs=True) class VolumeStructure: - name: str = '' - label: str = named_field('filesystem-label', '') + name: str = "" + label: str = named_field("filesystem-label", "") offset: Optional[int] = None - offset_write: Optional[RelativeOffset] = named_field('offset-write', None) + offset_write: Optional[RelativeOffset] = named_field("offset-write", None) size: int = 0 - type: str = '' + type: str = "" role: str = Role.NONE id: Optional[str] = None - filesystem: str = '' + filesystem: str = "" content: Optional[List[VolumeContent]] = None update: VolumeUpdate = attr.Factory(VolumeUpdate) def gpt_part_type_uuid(self): - if ',' in self.type: - return self.type.split(',', 1)[1].upper() + if "," in self.type: + return self.type.split(",", 1)[1].upper() else: return self.type @attr.s(auto_attribs=True) class Volume: - schema: str = '' - bootloader: str = '' - id: str = '' + schema: str = "" + bootloader: str = "" + id: str = "" structure: Optional[List[VolumeStructure]] = None @@ -173,39 +171,40 @@ class OnVolume(Volume): @classmethod def from_volume(cls, v: Volume): kw = attr.asdict(v, recurse=False) - kw['structure'] = [ - OnVolumeStructure.from_volume_structure(vs) for vs in v.structure] + kw["structure"] = [ + OnVolumeStructure.from_volume_structure(vs) for vs in v.structure + ] return cls(**kw) class StorageEncryptionSupport(enum.Enum): - DISABLED = 'disabled' - AVAILABLE = 'available' - UNAVAILABLE = 'unavailable' - DEFECTIVE = 'defective' + DISABLED = "disabled" + AVAILABLE = "available" + UNAVAILABLE = "unavailable" + DEFECTIVE = "defective" class StorageSafety(enum.Enum): - UNSET = 'unset' - ENCRYPTED = 'encrypted' - PREFER_ENCRYPTED = 'prefer-encrypted' - PREFER_UNENCRYPTED = 'prefer-unencrypted' + UNSET = "unset" + ENCRYPTED = "encrypted" + PREFER_ENCRYPTED = "prefer-encrypted" + PREFER_UNENCRYPTED = "prefer-unencrypted" class EncryptionType(enum.Enum): - NONE = '' - CRYPTSETUP = 'cryptsetup' - DEVICE_SETUP_HOOK = 'device-setup-hook' + NONE = "" + CRYPTSETUP = "cryptsetup" + DEVICE_SETUP_HOOK = "device-setup-hook" @attr.s(auto_attribs=True) class StorageEncryption: support: StorageEncryptionSupport - storage_safety: StorageSafety = named_field('storage-safety') + storage_safety: StorageSafety = named_field("storage-safety") encryption_type: EncryptionType = named_field( - 'encryption-type', default=EncryptionType.NONE) - unavailable_reason: str = named_field( - 'unavailable-reason', default='') + "encryption-type", default=EncryptionType.NONE + ) + unavailable_reason: str = named_field("unavailable-reason", default="") @attr.s(auto_attribs=True) @@ -213,23 +212,24 @@ class SystemDetails: current: bool = False volumes: Dict[str, Volume] = attr.Factory(dict) storage_encryption: Optional[StorageEncryption] = named_field( - 'storage-encryption', default=None) + "storage-encryption", default=None + ) class SystemAction(enum.Enum): - INSTALL = 'install' + INSTALL = "install" class SystemActionStep(enum.Enum): - SETUP_STORAGE_ENCRYPTION = 'setup-storage-encryption' - FINISH = 'finish' + SETUP_STORAGE_ENCRYPTION = "setup-storage-encryption" + FINISH = "finish" @attr.s(auto_attribs=True) class SystemActionRequest: action: SystemAction step: SystemActionStep - on_volumes: Dict[str, OnVolume] = named_field('on-volumes') + on_volumes: Dict[str, OnVolume] = named_field("on-volumes") @api @@ -240,22 +240,30 @@ class SnapdAPI: class changes: @path_parameter class change_id: - def GET() -> Change: ... + def GET() -> Change: + ... class snaps: @path_parameter class snap_name: - def GET() -> Snap: ... - def POST(action: Payload[SnapActionRequest]) -> ChangeID: ... + def GET() -> Snap: + ... + + def POST(action: Payload[SnapActionRequest]) -> ChangeID: + ... class find: - def GET(name: str = '', select: str = '') -> List[Snap]: ... + def GET(name: str = "", select: str = "") -> List[Snap]: + ... class systems: @path_parameter class label: - def GET() -> SystemDetails: ... - def POST(action: Payload[SystemActionRequest]) -> ChangeID: ... + def GET() -> SystemDetails: + ... + + def POST(action: Payload[SystemActionRequest]) -> ChangeID: + ... class _FakeResponse: @@ -274,7 +282,7 @@ class _FakeError: self.data = data def raise_for_status(self): - raise aiohttp.ClientError(self.data['result']['message']) + raise aiohttp.ClientError(self.data["result"]["message"]) def make_api_client(async_snapd): @@ -292,9 +300,9 @@ def make_api_client(async_snapd): content = await async_snapd.post(path[1:], json, **params) response = snapd_serializer.deserialize(Response, content) if response.type == ResponseType.SYNC: - content = content['result'] + content = content["result"] elif response.type == ResponseType.ASYNC: - content = content['change'] + content = content["change"] elif response.type == ResponseType.ERROR: yield _FakeError() yield _FakeResponse(content) @@ -302,13 +310,12 @@ def make_api_client(async_snapd): return make_client(SnapdAPI, make_request, serializer=snapd_serializer) -snapd_serializer = Serializer( - ignore_unknown_fields=True, serialize_enums_by='value') +snapd_serializer = Serializer(ignore_unknown_fields=True, serialize_enums_by="value") async def post_and_wait(client, meth, *args, **kw): change_id = await meth(*args, **kw) - log.debug('post_and_wait %s', change_id) + log.debug("post_and_wait %s", change_id) while True: result = await client.v2.changes[change_id].GET() diff --git a/subiquity/server/ssh.py b/subiquity/server/ssh.py index 1ecdcde5..de84cd16 100644 --- a/subiquity/server/ssh.py +++ b/subiquity/server/ssh.py @@ -19,10 +19,8 @@ import os import subprocess from typing import List +from subiquity.common.types import SSHFetchIdStatus from subiquitycore.utils import arun_command -from subiquity.common.types import ( - SSHFetchIdStatus, - ) log = logging.getLogger("subiquity.server.ssh") @@ -39,7 +37,7 @@ class SSHKeyFetcher: self.app = app async def fetch_keys_for_id(self, user_id: str) -> List[str]: - cmd = ('ssh-import-id', '--output', '-', '--', user_id) + cmd = ("ssh-import-id", "--output", "-", "--", user_id) env = None if self.app.base_model.proxy.proxy: env = os.environ.copy() @@ -49,24 +47,24 @@ class SSHKeyFetcher: cp = await arun_command(cmd, check=True, env=env) except subprocess.CalledProcessError as exc: log.exception("ssh-import-id failed. stderr: %s", exc.stderr) - raise SSHFetchError(status=SSHFetchIdStatus.IMPORT_ERROR, - reason=exc.stderr) - keys_material: str = cp.stdout.replace('\r', '').strip() + raise SSHFetchError(status=SSHFetchIdStatus.IMPORT_ERROR, reason=exc.stderr) + keys_material: str = cp.stdout.replace("\r", "").strip() return [mat for mat in keys_material.splitlines() if mat] async def gen_fingerprint_for_key(self, key: str) -> str: - """ For a given key, generate the fingerprint. """ + """For a given key, generate the fingerprint.""" # ssh-keygen supports multiple keys at once, but it is simpler to # associate each key with its resulting fingerprint if we call # ssh-keygen multiple times. - cmd = ('ssh-keygen', '-l', '-f', '-') + cmd = ("ssh-keygen", "-l", "-f", "-") try: cp = await arun_command(cmd, check=True, input=key) except subprocess.CalledProcessError as exc: log.exception("ssh-import-id failed. stderr: %s", exc.stderr) - raise SSHFetchError(status=SSHFetchIdStatus.FINGERPRINT_ERROR, - reason=exc.stderr) + raise SSHFetchError( + status=SSHFetchIdStatus.FINGERPRINT_ERROR, reason=exc.stderr + ) return cp.stdout.strip() @@ -79,20 +77,23 @@ class DryRunSSHKeyFetcher(SSHKeyFetcher): async def fetch_keys_fake_success(self, user_id: str) -> List[str]: unused, username = user_id.split(":", maxsplit=1) - return [f"""\ + return [ + f"""\ ssh-ed25519\ AAAAC3NzaC1lZDI1NTE5AAAAIMM/qhS3hS3+IjpJBYXZWCqPKPH9Zag8QYbS548iEjoZ\ - {username}@earth # ssh-import-id {user_id}"""] + {username}@earth # ssh-import-id {user_id}""" + ] async def fetch_keys_fake_failure(self, user_id: str) -> List[str]: unused, username = user_id.split(":", maxsplit=1) - raise SSHFetchError(status=SSHFetchIdStatus.IMPORT_ERROR, - reason=f"ERROR Username {username} not found.") + raise SSHFetchError( + status=SSHFetchIdStatus.IMPORT_ERROR, + reason=f"ERROR Username {username} not found.", + ) async def fetch_keys_for_id(self, user_id: str) -> List[str]: service, username = user_id.split(":", maxsplit=1) - strategy = self.SSHImportStrategy( - self.app.dr_cfg.ssh_import_default_strategy) + strategy = self.SSHImportStrategy(self.app.dr_cfg.ssh_import_default_strategy) for entry in self.app.dr_cfg.ssh_imports: if entry["username"] != username: continue diff --git a/subiquity/server/tests/test_apt.py b/subiquity/server/tests/test_apt.py index b5fc75f8..bd5490c3 100644 --- a/subiquity/server/tests/test_apt.py +++ b/subiquity/server/tests/test_apt.py @@ -15,24 +15,23 @@ import io import subprocess -from unittest.mock import Mock, patch, AsyncMock +from unittest.mock import AsyncMock, Mock, patch from curtin.commands.extract import TrivialSourceHandler -from subiquitycore.tests import SubiTestCase -from subiquitycore.tests.mocks import make_app -from subiquitycore.utils import astart_command -from subiquity.server.apt import ( - AptConfigurer, - DryRunAptConfigurer, - AptConfigCheckError, - OverlayMountpoint, -) -from subiquity.server.dryrun import DRConfig from subiquity.models.mirror import MirrorModel from subiquity.models.proxy import ProxyModel from subiquity.models.subiquity import DebconfSelectionsModel - +from subiquity.server.apt import ( + AptConfigCheckError, + AptConfigurer, + DryRunAptConfigurer, + OverlayMountpoint, +) +from subiquity.server.dryrun import DRConfig +from subiquitycore.tests import SubiTestCase +from subiquitycore.tests.mocks import make_app +from subiquitycore.utils import astart_command APT_UPDATE_SUCCESS = """\ Hit:1 http://mirror focal InRelease @@ -67,8 +66,7 @@ class TestAptConfigurer(SubiTestCase): self.model.debconf_selections = DebconfSelectionsModel() self.model.locale.selected_language = "en_US.UTF-8" self.app = make_app(self.model) - self.configurer = AptConfigurer( - self.app, AsyncMock(), TrivialSourceHandler('')) + self.configurer = AptConfigurer(self.app, AsyncMock(), TrivialSourceHandler("")) self.astart_sym = "subiquity.server.apt.astart_command" @@ -78,7 +76,7 @@ class TestAptConfigurer(SubiTestCase): self.assertNotIn("https_proxy", config["apt"]) def test_apt_config_proxy(self): - proxy = 'http://apt-cacher-ng:3142' + proxy = "http://apt-cacher-ng:3142" self.model.proxy.proxy = proxy config = self.configurer.apt_config(final=True) @@ -87,41 +85,44 @@ class TestAptConfigurer(SubiTestCase): async def test_overlay(self): self.configurer.install_tree = OverlayMountpoint( - upperdir="upperdir-install-tree", - lowers=["lowers1-install-tree"], - mountpoint="mountpoint-install-tree", - ) + upperdir="upperdir-install-tree", + lowers=["lowers1-install-tree"], + mountpoint="mountpoint-install-tree", + ) self.configurer.configured_tree = OverlayMountpoint( - upperdir="upperdir-install-tree", - lowers=["lowers1-install-tree"], - mountpoint="mountpoint-install-tree", - ) + upperdir="upperdir-install-tree", + lowers=["lowers1-install-tree"], + mountpoint="mountpoint-install-tree", + ) self.source = "source" - with patch.object(self.app, "command_runner", - create=True, new_callable=AsyncMock): + with patch.object( + self.app, "command_runner", create=True, new_callable=AsyncMock + ): async with self.configurer.overlay(): pass async def test_run_apt_config_check(self): self.configurer.configured_tree = OverlayMountpoint( - upperdir="upperdir-install-tree", - lowers=["lowers1-install-tree"], - mountpoint="mountpoint-install-tree", - ) + upperdir="upperdir-install-tree", + lowers=["lowers1-install-tree"], + mountpoint="mountpoint-install-tree", + ) async def astart_success(cmd, **kwargs): - """ Simulates apt-get update behaving normally. """ - proc = await astart_command(["sh", "-c", "cat"], - **kwargs, stdin=subprocess.PIPE) + """Simulates apt-get update behaving normally.""" + proc = await astart_command( + ["sh", "-c", "cat"], **kwargs, stdin=subprocess.PIPE + ) proc.stdin.write(APT_UPDATE_SUCCESS.encode("utf-8")) proc.stdin.write_eof() return proc async def astart_failure(cmd, **kwargs): - """ Simulates apt-get update failing. """ - proc = await astart_command(["sh", "-c", "cat; exit 1"], - **kwargs, stdin=subprocess.PIPE) + """Simulates apt-get update failing.""" + proc = await astart_command( + ["sh", "-c", "cat; exit 1"], **kwargs, stdin=subprocess.PIPE + ) proc.stdin.write(APT_UPDATE_FAILURE.encode("utf-8")) proc.stdin.write_eof() return proc @@ -152,30 +153,35 @@ class TestDRAptConfigurer(SubiTestCase): {"url": "http://run-on-host", "strategy": "run-on-host"}, {"pattern": "/random$", "strategy": "random"}, ] - self.configurer = DryRunAptConfigurer(self.app, AsyncMock(), '') + self.configurer = DryRunAptConfigurer(self.app, AsyncMock(), "") self.configurer.configured_tree = OverlayMountpoint( - upperdir="upperdir-install-tree", - lowers=["lowers1-install-tree"], - mountpoint="mountpoint-install-tree", - ) + upperdir="upperdir-install-tree", + lowers=["lowers1-install-tree"], + mountpoint="mountpoint-install-tree", + ) def test_get_mirror_check_strategy(self): Strategy = DryRunAptConfigurer.MirrorCheckStrategy self.assertEqual( - Strategy.SUCCESS, - self.configurer.get_mirror_check_strategy("http://success")) + Strategy.SUCCESS, + self.configurer.get_mirror_check_strategy("http://success"), + ) self.assertEqual( - Strategy.FAILURE, - self.configurer.get_mirror_check_strategy("http://failure")) - self.assertEqual(Strategy.RUN_ON_HOST, - self.configurer.get_mirror_check_strategy( - "http://run-on-host")) - self.assertEqual(Strategy.RANDOM, - self.configurer.get_mirror_check_strategy( - "http://mirror/random")) + Strategy.FAILURE, + self.configurer.get_mirror_check_strategy("http://failure"), + ) self.assertEqual( - Strategy.FAILURE, - self.configurer.get_mirror_check_strategy("http://default")) + Strategy.RUN_ON_HOST, + self.configurer.get_mirror_check_strategy("http://run-on-host"), + ) + self.assertEqual( + Strategy.RANDOM, + self.configurer.get_mirror_check_strategy("http://mirror/random"), + ) + self.assertEqual( + Strategy.FAILURE, + self.configurer.get_mirror_check_strategy("http://default"), + ) async def test_run_apt_config_check_success(self): output = io.StringIO() @@ -194,10 +200,14 @@ class TestDRAptConfigurer(SubiTestCase): output = io.StringIO() self.app.dr_cfg.apt_mirror_check_default_strategy = "random" self.candidate.uri = "http://default" - with patch("subiquity.server.apt.random.choice", - return_value=self.configurer.apt_config_check_success): + with patch( + "subiquity.server.apt.random.choice", + return_value=self.configurer.apt_config_check_success, + ): await self.configurer.run_apt_config_check(output) - with patch("subiquity.server.apt.random.choice", - return_value=self.configurer.apt_config_check_failure): + with patch( + "subiquity.server.apt.random.choice", + return_value=self.configurer.apt_config_check_failure, + ): with self.assertRaises(AptConfigCheckError): await self.configurer.run_apt_config_check(output) diff --git a/subiquity/server/tests/test_contract_selection.py b/subiquity/server/tests/test_contract_selection.py index a297275f..ad82826a 100644 --- a/subiquity/server/tests/test_contract_selection.py +++ b/subiquity/server/tests/test_contract_selection.py @@ -14,13 +14,13 @@ # along with this program. If not, see . import unittest -from unittest.mock import patch, AsyncMock +from unittest.mock import AsyncMock, patch from subiquity.server.contract_selection import ( ContractSelection, UnknownError, UPCSExpiredError, - ) +) class TestContractSelection(unittest.IsolatedAsyncioTestCase): @@ -28,13 +28,15 @@ class TestContractSelection(unittest.IsolatedAsyncioTestCase): self.client = AsyncMock() async def test_init(self): - with patch.object(ContractSelection, "_run_polling", - return_value=None) as run_polling: - - cs = ContractSelection(client=self.client, - magic_token="magic", - user_code="user-code", - validity_seconds=200) + with patch.object( + ContractSelection, "_run_polling", return_value=None + ) as run_polling: + cs = ContractSelection( + client=self.client, + magic_token="magic", + user_code="user-code", + validity_seconds=200, + ) self.assertIs(cs.client, self.client) self.assertEqual(cs.magic_token, "magic") @@ -49,9 +51,7 @@ class TestContractSelection(unittest.IsolatedAsyncioTestCase): "errors": ["Initiate failed"], "result": "failure", } - with patch.object(self.client.strategy, - "magic_initiate_v1", return_value=ret): - + with patch.object(self.client.strategy, "magic_initiate_v1", return_value=ret): with self.assertRaises(UnknownError) as cm: await ContractSelection.initiate(self.client) self.assertIn("Initiate failed", cm.exception.errors) @@ -68,9 +68,7 @@ class TestContractSelection(unittest.IsolatedAsyncioTestCase): }, }, } - with patch.object(self.client.strategy, - "magic_initiate_v1", return_value=ret): - + with patch.object(self.client.strategy, "magic_initiate_v1", return_value=ret): cs = await ContractSelection.initiate(self.client) self.assertEqual(cs.magic_token, "magic-12345") @@ -87,17 +85,17 @@ class TestContractSelection(unittest.IsolatedAsyncioTestCase): }, ], } - with patch.object(ContractSelection, - "_run_polling", return_value=None): - cs = ContractSelection(client=self.client, - magic_token="magic-token", - user_code="ABCDEF", - validity_seconds=200) + with patch.object(ContractSelection, "_run_polling", return_value=None): + cs = ContractSelection( + client=self.client, + magic_token="magic-token", + user_code="ABCDEF", + validity_seconds=200, + ) await cs.task self.client.strategy.scale_factor = 1 - with patch.object(self.client.strategy, - "magic_wait_v1", return_value=ret): + with patch.object(self.client.strategy, "magic_wait_v1", return_value=ret): # 100 seconds elapsed, should be considered a timeout with patch("time.monotonic", side_effect=[100, 200]): with self.assertRaises(UPCSExpiredError): @@ -118,15 +116,15 @@ class TestContractSelection(unittest.IsolatedAsyncioTestCase): }, }, } - with patch.object(ContractSelection, - "_run_polling", return_value=None): - cs = ContractSelection(client=self.client, - magic_token="magic-token", - user_code="ABCDEF", - validity_seconds=200) + with patch.object(ContractSelection, "_run_polling", return_value=None): + cs = ContractSelection( + client=self.client, + magic_token="magic-token", + user_code="ABCDEF", + validity_seconds=200, + ) await cs.task self.client.strategy.scale_factor = 1 - with patch.object(self.client.strategy, - "magic_wait_v1", return_value=ret): + with patch.object(self.client.strategy, "magic_wait_v1", return_value=ret): self.assertEqual((await cs._run_polling()), "C12345257") diff --git a/subiquity/server/tests/test_controller.py b/subiquity/server/tests/test_controller.py index 0d9684bf..b1c0826c 100644 --- a/subiquity/server/tests/test_controller.py +++ b/subiquity/server/tests/test_controller.py @@ -14,14 +14,12 @@ # along with this program. If not, see . import contextlib - from unittest.mock import patch +from subiquity.server.controller import SubiquityController from subiquitycore.tests import SubiTestCase from subiquitycore.tests.mocks import make_app -from subiquity.server.controller import SubiquityController - class TestController(SubiTestCase): def setUp(self): diff --git a/subiquity/server/tests/test_geoip.py b/subiquity/server/tests/test_geoip.py index ebfad3d8..e2a36b25 100644 --- a/subiquity/server/tests/test_geoip.py +++ b/subiquity/server/tests/test_geoip.py @@ -16,14 +16,11 @@ import aiohttp from aioresponses import aioresponses +from subiquity.server.geoip import GeoIP, HTTPGeoIPStrategy from subiquitycore.tests import SubiTestCase from subiquitycore.tests.mocks import make_app -from subiquity.server.geoip import ( - GeoIP, - HTTPGeoIPStrategy, - ) -xml = ''' +xml = """ 1.2.3.4 OK @@ -39,12 +36,12 @@ xml = ''' 707 America/Los_Angeles -''' -partial = '' -incomplete = '-121.7016' -long_cc = 'USA' -empty_tz = '' -empty_cc = '' +""" +partial = "" +incomplete = "-121.7016" +long_cc = "USA" +empty_tz = "" +empty_cc = "" class TestGeoIP(SubiTestCase): @@ -100,7 +97,9 @@ class TestGeoIPBadData(SubiTestCase): async def test_lookup_error(self): with aioresponses() as mocked: - mocked.get("https://geoip.ubuntu.com/lookup", - exception=aiohttp.ClientError('lookup failure')) + mocked.get( + "https://geoip.ubuntu.com/lookup", + exception=aiohttp.ClientError("lookup failure"), + ) self.assertFalse(await self.geoip.lookup()) self.assertIsNone(self.geoip.timezone) diff --git a/subiquity/server/tests/test_kernel.py b/subiquity/server/tests/test_kernel.py index 40a12c3a..024d90a3 100644 --- a/subiquity/server/tests/test_kernel.py +++ b/subiquity/server/tests/test_kernel.py @@ -17,57 +17,58 @@ import subprocess import unittest from unittest import mock -from subiquity.server.kernel import ( - flavor_to_pkgname, - list_installed_kernels, -) +from subiquity.server.kernel import flavor_to_pkgname, list_installed_kernels class TestFlavorToPkgname(unittest.TestCase): def test_flavor_generic(self): - self.assertEqual('linux-generic', - flavor_to_pkgname('generic', dry_run=True)) + self.assertEqual("linux-generic", flavor_to_pkgname("generic", dry_run=True)) def test_flavor_oem(self): - self.assertEqual('linux-oem-20.04', - flavor_to_pkgname('oem', dry_run=True)) + self.assertEqual("linux-oem-20.04", flavor_to_pkgname("oem", dry_run=True)) def test_flavor_hwe(self): - self.assertEqual('linux-generic-hwe-20.04', - flavor_to_pkgname('hwe', dry_run=True)) + self.assertEqual( + "linux-generic-hwe-20.04", flavor_to_pkgname("hwe", dry_run=True) + ) - self.assertEqual('linux-generic-hwe-20.04', - flavor_to_pkgname('generic-hwe', dry_run=True)) + self.assertEqual( + "linux-generic-hwe-20.04", flavor_to_pkgname("generic-hwe", dry_run=True) + ) class TestListInstalledKernels(unittest.IsolatedAsyncioTestCase): async def test_one_kernel(self): rv = subprocess.CompletedProcess([], 0) - rv.stdout = 'linux-image-6.1.0-16-generic' - with mock.patch('subiquity.server.kernel.arun_command', - return_value=rv) as mock_arun: - ret = await list_installed_kernels('/') - self.assertEqual(['linux-image-6.1.0-16-generic'], ret) + rv.stdout = "linux-image-6.1.0-16-generic" + with mock.patch( + "subiquity.server.kernel.arun_command", return_value=rv + ) as mock_arun: + ret = await list_installed_kernels("/") + self.assertEqual(["linux-image-6.1.0-16-generic"], ret) mock_arun.assert_called_once() async def test_two_kernels(self): rv = subprocess.CompletedProcess([], 0) - rv.stdout = '''\ + rv.stdout = """\ linux-image-6.1.0-16-generic linux-image-6.2.0-24-generic -''' - with mock.patch('subiquity.server.kernel.arun_command', - return_value=rv) as mock_arun: - ret = await list_installed_kernels('/') - self.assertEqual(['linux-image-6.1.0-16-generic', - 'linux-image-6.2.0-24-generic'], ret) +""" + with mock.patch( + "subiquity.server.kernel.arun_command", return_value=rv + ) as mock_arun: + ret = await list_installed_kernels("/") + self.assertEqual( + ["linux-image-6.1.0-16-generic", "linux-image-6.2.0-24-generic"], ret + ) mock_arun.assert_called_once() async def test_no_kernel(self): rv = subprocess.CompletedProcess([], 0) - rv.stdout = '\n' - with mock.patch('subiquity.server.kernel.arun_command', - return_value=rv) as mock_arun: - ret = await list_installed_kernels('/') + rv.stdout = "\n" + with mock.patch( + "subiquity.server.kernel.arun_command", return_value=rv + ) as mock_arun: + ret = await list_installed_kernels("/") self.assertEqual([], ret) mock_arun.assert_called_once() diff --git a/subiquity/server/tests/test_mounter.py b/subiquity/server/tests/test_mounter.py index 9fa482f3..093e85bf 100644 --- a/subiquity/server/tests/test_mounter.py +++ b/subiquity/server/tests/test_mounter.py @@ -15,20 +15,19 @@ import os import pathlib -from unittest.mock import AsyncMock, call, Mock, patch +from unittest.mock import AsyncMock, Mock, call, patch -from subiquitycore.tests import SubiTestCase -from subiquitycore.tests.mocks import make_app from subiquity.server.mounter import ( - lowerdir_for, Mounter, Mountpoint, OverlayMountpoint, + lowerdir_for, ) +from subiquitycore.tests import SubiTestCase +from subiquitycore.tests.mocks import make_app class TestMounter(SubiTestCase): - def setUp(self): self.model = Mock() self.app = make_app(self.model) @@ -36,8 +35,9 @@ class TestMounter(SubiTestCase): async def test_mount_unmount(self): mounter = Mounter(self.app) # Make sure we can unmount something that we mounted before. - with patch.object(self.app, "command_runner", - create=True, new_callable=AsyncMock): + with patch.object( + self.app, "command_runner", create=True, new_callable=AsyncMock + ): m = await mounter.mount("/dev/cdrom", self.tmp_dir()) await mounter.unmount(m) @@ -50,10 +50,10 @@ class TestMounter(SubiTestCase): def touch(*paths): for p in paths: - if p.endswith('/'): + if p.endswith("/"): os.mkdir(p) else: - with open(p, 'w'): + with open(p, "w"): pass # Create the following situation: @@ -75,31 +75,29 @@ class TestMounter(SubiTestCase): # * src/only-src-dir -> dst/only-src-dir # * src/only-src-file -> dst/only-src-file - touch(f'{src}/only-src-file') - touch(f'{dst}/only-dst-file') - touch(f'{src}/both-file', f'{dst}/both-file') - touch(f'{src}/only-src-dir/', f'{src}/only-src-dir/file') - touch(f'{dst}/only-dst-dir/', f'{dst}/only-dst-dir/file') - touch(f'{src}/both-dir/', f'{dst}/both-dir/') - touch(f'{src}/both-dir/only-src-file', f'{src}/both-dir/both-file') - touch(f'{dst}/both-dir/only-dst-file', f'{dst}/both-dir/both-file') + touch(f"{src}/only-src-file") + touch(f"{dst}/only-dst-file") + touch(f"{src}/both-file", f"{dst}/both-file") + touch(f"{src}/only-src-dir/", f"{src}/only-src-dir/file") + touch(f"{dst}/only-dst-dir/", f"{dst}/only-dst-dir/file") + touch(f"{src}/both-dir/", f"{dst}/both-dir/") + touch(f"{src}/both-dir/only-src-file", f"{src}/both-dir/both-file") + touch(f"{dst}/both-dir/only-dst-file", f"{dst}/both-dir/both-file") with patch.object(mounter, "mount", new_callable=AsyncMock) as mocked: await mounter.bind_mount_tree(src, dst) - mocked.assert_has_calls([ - call( - f'{src}/only-src-file', - f'{dst}/only-src-file', - options='bind'), - call( - f'{src}/only-src-dir', - f'{dst}/only-src-dir', - options='bind'), - call( - f'{src}/both-dir/only-src-file', - f'{dst}/both-dir/only-src-file', - options='bind'), - ], any_order=True) + mocked.assert_has_calls( + [ + call(f"{src}/only-src-file", f"{dst}/only-src-file", options="bind"), + call(f"{src}/only-src-dir", f"{dst}/only-src-dir", options="bind"), + call( + f"{src}/both-dir/only-src-file", + f"{dst}/both-dir/only-src-file", + options="bind", + ), + ], + any_order=True, + ) self.assertEqual(mocked.call_count, 3) async def test_bind_mount_tree_no_target(self): @@ -107,18 +105,18 @@ class TestMounter(SubiTestCase): # check bind_mount_tree behaviour when the passed dst does not # exist. src = self.tmp_dir() - dst = os.path.join(self.tmp_dir(), 'dst') + dst = os.path.join(self.tmp_dir(), "dst") with patch.object(mounter, "mount", new_callable=AsyncMock) as mocked: await mounter.bind_mount_tree(src, dst) - mocked.assert_called_once_with(src, dst, options='bind') + mocked.assert_called_once_with(src, dst, options="bind") async def test_bind_mount_creates_dest_dir(self): mounter = Mounter(self.app) # When we are bind mounting a directory, the destination should be # created as a directory. src = self.tmp_dir() - dst = pathlib.Path(self.tmp_dir()) / 'dst' + dst = pathlib.Path(self.tmp_dir()) / "dst" self.app.command_runner = AsyncMock() await mounter.bind_mount_tree(src, dst) @@ -128,9 +126,9 @@ class TestMounter(SubiTestCase): mounter = Mounter(self.app) # When we are bind mounting a file, the destination should be created # as a file. - src = pathlib.Path(self.tmp_dir()) / 'src' + src = pathlib.Path(self.tmp_dir()) / "src" src.touch() - dst = pathlib.Path(self.tmp_dir()) / 'dst' + dst = pathlib.Path(self.tmp_dir()) / "dst" self.app.command_runner = AsyncMock() await mounter.bind_mount_tree(src, dst) @@ -140,9 +138,9 @@ class TestMounter(SubiTestCase): mounter = Mounter(self.app) # When we are mounting a device, the destination should be created # as a directory. - src = pathlib.Path(self.tmp_dir()) / 'src' + src = pathlib.Path(self.tmp_dir()) / "src" src.touch() - dst = pathlib.Path(self.tmp_dir()) / 'dst' + dst = pathlib.Path(self.tmp_dir()) / "dst" self.app.command_runner = AsyncMock() await mounter.mount(src, dst) @@ -151,51 +149,46 @@ class TestMounter(SubiTestCase): class TestLowerDirFor(SubiTestCase): def test_lowerdir_for_str(self): - self.assertEqual( - lowerdir_for("/tmp/lower1"), - "/tmp/lower1") + self.assertEqual(lowerdir_for("/tmp/lower1"), "/tmp/lower1") def test_lowerdir_for_str_list(self): self.assertEqual( - lowerdir_for(["/tmp/lower1", "/tmp/lower2"]), - "/tmp/lower2:/tmp/lower1") + lowerdir_for(["/tmp/lower1", "/tmp/lower2"]), "/tmp/lower2:/tmp/lower1" + ) def test_lowerdir_for_mountpoint(self): - self.assertEqual( - lowerdir_for(Mountpoint(mountpoint="/mnt")), - "/mnt") + self.assertEqual(lowerdir_for(Mountpoint(mountpoint="/mnt")), "/mnt") def test_lowerdir_for_simple_overlay(self): overlay = OverlayMountpoint( - lowers=["/tmp/lower1"], - upperdir="/tmp/upper1", - mountpoint="/mnt", + lowers=["/tmp/lower1"], + upperdir="/tmp/upper1", + mountpoint="/mnt", ) self.assertEqual(lowerdir_for(overlay), "/tmp/upper1:/tmp/lower1") def test_lowerdir_for_overlay(self): overlay = OverlayMountpoint( - lowers=["/tmp/lower1", "/tmp/lower2"], - upperdir="/tmp/upper1", - mountpoint="/mnt", + lowers=["/tmp/lower1", "/tmp/lower2"], + upperdir="/tmp/upper1", + mountpoint="/mnt", ) - self.assertEqual( - lowerdir_for(overlay), - "/tmp/upper1:/tmp/lower2:/tmp/lower1") + self.assertEqual(lowerdir_for(overlay), "/tmp/upper1:/tmp/lower2:/tmp/lower1") def test_lowerdir_for_list(self): overlay = OverlayMountpoint( - lowers=["/tmp/overlaylower1", "/tmp/overlaylower2"], - upperdir="/tmp/overlayupper1", - mountpoint="/mnt/overlay", + lowers=["/tmp/overlaylower1", "/tmp/overlaylower2"], + upperdir="/tmp/overlayupper1", + mountpoint="/mnt/overlay", ) mountpoint = Mountpoint(mountpoint="/mnt/mountpoint") lowers = ["/tmp/lower1", "/tmp/lower2"] self.assertEqual( - lowerdir_for([overlay, mountpoint, lowers]), - "/tmp/lower2:/tmp/lower1" + - ":/mnt/mountpoint" + - ":/tmp/overlayupper1:/tmp/overlaylower2:/tmp/overlaylower1") + lowerdir_for([overlay, mountpoint, lowers]), + "/tmp/lower2:/tmp/lower1" + + ":/mnt/mountpoint" + + ":/tmp/overlayupper1:/tmp/overlaylower2:/tmp/overlaylower1", + ) def test_lowerdir_for_other(self): with self.assertRaises(NotImplementedError): diff --git a/subiquity/server/tests/test_runner.py b/subiquity/server/tests/test_runner.py index 9a6e798a..c975b082 100644 --- a/subiquity/server/tests/test_runner.py +++ b/subiquity/server/tests/test_runner.py @@ -16,11 +16,8 @@ import os from unittest.mock import patch +from subiquity.server.runner import DryRunCommandRunner, LoggedCommandRunner from subiquitycore.tests import SubiTestCase -from subiquity.server.runner import ( - LoggedCommandRunner, - DryRunCommandRunner, - ) class TestLoggedCommandRunner(SubiTestCase): @@ -34,12 +31,10 @@ class TestLoggedCommandRunner(SubiTestCase): runner = LoggedCommandRunner(ident="my-identifier") self.assertEqual(runner.use_systemd_user, True) - runner = LoggedCommandRunner(ident="my-identifier", - use_systemd_user=True) + runner = LoggedCommandRunner(ident="my-identifier", use_systemd_user=True) self.assertEqual(runner.use_systemd_user, True) - runner = LoggedCommandRunner(ident="my-identifier", - use_systemd_user=False) + runner = LoggedCommandRunner(ident="my-identifier", use_systemd_user=False) self.assertEqual(runner.use_systemd_user, False) def test_forge_systemd_cmd(self): @@ -55,21 +50,30 @@ class TestLoggedCommandRunner(SubiTestCase): with patch.dict(os.environ, environ, clear=True): cmd = runner._forge_systemd_cmd( - ["/bin/ls", "/root"], - private_mounts=True, capture=False) + ["/bin/ls", "/root"], private_mounts=True, capture=False + ) expected = [ "systemd-run", - "--wait", "--same-dir", - "--property", "SyslogIdentifier=my-id", - "--property", "PrivateMounts=yes", - "--setenv", "PATH=/snap/subiquity/x1/bin", - "--setenv", "PYTHONPATH=/usr/lib/python3.10/site-packages", - "--setenv", "PYTHON=/snap/subiquity/x1/usr/bin/python3.10", - "--setenv", "TARGET_MOUNT_POINT=/target", - "--setenv", "SNAP=/snap/subiquity/x1", + "--wait", + "--same-dir", + "--property", + "SyslogIdentifier=my-id", + "--property", + "PrivateMounts=yes", + "--setenv", + "PATH=/snap/subiquity/x1/bin", + "--setenv", + "PYTHONPATH=/usr/lib/python3.10/site-packages", + "--setenv", + "PYTHON=/snap/subiquity/x1/usr/bin/python3.10", + "--setenv", + "TARGET_MOUNT_POINT=/target", + "--setenv", + "SNAP=/snap/subiquity/x1", "--", - "/bin/ls", "/root", + "/bin/ls", + "/root", ] self.assertEqual(cmd, expected) @@ -80,26 +84,31 @@ class TestLoggedCommandRunner(SubiTestCase): } with patch.dict(os.environ, environ, clear=True): cmd = runner._forge_systemd_cmd( - ["/bin/ls", "/root"], - private_mounts=False, capture=True) + ["/bin/ls", "/root"], private_mounts=False, capture=True + ) expected = [ "systemd-run", - "--wait", "--same-dir", - "--property", "SyslogIdentifier=my-id", + "--wait", + "--same-dir", + "--property", + "SyslogIdentifier=my-id", "--user", "--pipe", - "--setenv", "PYTHONPATH=/usr/lib/python3.10/site-packages", + "--setenv", + "PYTHONPATH=/usr/lib/python3.10/site-packages", "--", - "/bin/ls", "/root", + "/bin/ls", + "/root", ] self.assertEqual(cmd, expected) class TestDryRunCommandRunner(SubiTestCase): def setUp(self): - self.runner = DryRunCommandRunner(ident="my-identifier", - delay=10, use_systemd_user=True) + self.runner = DryRunCommandRunner( + ident="my-identifier", delay=10, use_systemd_user=True + ) def test_init(self): self.assertEqual(self.runner.ident, "my-identifier") @@ -109,19 +118,22 @@ class TestDryRunCommandRunner(SubiTestCase): @patch.object(LoggedCommandRunner, "_forge_systemd_cmd") def test_forge_systemd_cmd(self, mock_super): self.runner._forge_systemd_cmd( - ["/bin/ls", "/root"], - private_mounts=True, capture=True) + ["/bin/ls", "/root"], private_mounts=True, capture=True + ) mock_super.assert_called_once_with( - ["echo", "not running:", "/bin/ls", "/root"], - private_mounts=True, capture=True) + ["echo", "not running:", "/bin/ls", "/root"], + private_mounts=True, + capture=True, + ) mock_super.reset_mock() # Make sure exceptions are handled. - self.runner._forge_systemd_cmd(["scripts/replay-curtin-log.py"], - private_mounts=True, capture=False) + self.runner._forge_systemd_cmd( + ["scripts/replay-curtin-log.py"], private_mounts=True, capture=False + ) mock_super.assert_called_once_with( - ["scripts/replay-curtin-log.py"], - private_mounts=True, capture=False) + ["scripts/replay-curtin-log.py"], private_mounts=True, capture=False + ) def test_get_delay_for_cmd(self): # Most commands use the default delay @@ -129,16 +141,21 @@ class TestDryRunCommandRunner(SubiTestCase): self.assertEqual(delay, 10) # Commands containing "unattended-upgrades" use delay * 3 - delay = self.runner._get_delay_for_cmd([ - "python3", "-m", "curtin", - "in-target", - "-t", "/target", - "--", - "unattended-upgrades", "-v", - ]) + delay = self.runner._get_delay_for_cmd( + [ + "python3", + "-m", + "curtin", + "in-target", + "-t", + "/target", + "--", + "unattended-upgrades", + "-v", + ] + ) self.assertEqual(delay, 30) # Commands having scripts/replay will actually be executed - no delay. - delay = self.runner._get_delay_for_cmd( - ["scripts/replay-curtin-log.py"]) + delay = self.runner._get_delay_for_cmd(["scripts/replay-curtin-log.py"]) self.assertEqual(delay, 0) diff --git a/subiquity/server/tests/test_server.py b/subiquity/server/tests/test_server.py index e106e45b..82d19639 100644 --- a/subiquity/server/tests/test_server.py +++ b/subiquity/server/tests/test_server.py @@ -17,9 +17,7 @@ import os import shlex from unittest.mock import Mock, patch -from subiquitycore.utils import run_command -from subiquitycore.tests import SubiTestCase -from subiquitycore.tests.mocks import make_app +from subiquity.common.types import PasswordKind from subiquity.server.server import ( MetaController, SubiquityServer, @@ -27,17 +25,19 @@ from subiquity.server.server import ( iso_autoinstall_path, root_autoinstall_path, ) -from subiquity.common.types import PasswordKind +from subiquitycore.tests import SubiTestCase +from subiquitycore.tests.mocks import make_app +from subiquitycore.utils import run_command class TestAutoinstallLoad(SubiTestCase): async def asyncSetUp(self): self.tempdir = self.tmp_dir() - os.makedirs(self.tempdir + '/cdrom', exist_ok=True) + os.makedirs(self.tempdir + "/cdrom", exist_ok=True) opts = Mock() opts.dry_run = True opts.output_base = self.tempdir - opts.machine_config = 'examples/machines/simple.json' + opts.machine_config = "examples/machines/simple.json" opts.kernel_cmdline = {} opts.autoinstall = None self.server = SubiquityServer(opts, None) @@ -49,53 +49,53 @@ class TestAutoinstallLoad(SubiTestCase): def create(self, path, contents): path = self.path(path) - with open(path, 'w') as fp: + with open(path, "w") as fp: fp.write(contents) return path def test_autoinstall_disabled(self): - self.create(root_autoinstall_path, 'root') - self.create(cloud_autoinstall_path, 'cloud') - self.create(iso_autoinstall_path, 'iso') + self.create(root_autoinstall_path, "root") + self.create(cloud_autoinstall_path, "cloud") + self.create(iso_autoinstall_path, "iso") self.server.opts.autoinstall = "" self.assertIsNone(self.server.select_autoinstall()) def test_root_wins(self): - root = self.create(root_autoinstall_path, 'root') - autoinstall = self.create(self.path('arg.autoinstall.yaml'), 'arg') + root = self.create(root_autoinstall_path, "root") + autoinstall = self.create(self.path("arg.autoinstall.yaml"), "arg") self.server.opts.autoinstall = autoinstall - self.create(cloud_autoinstall_path, 'cloud') - self.create(iso_autoinstall_path, 'iso') + self.create(cloud_autoinstall_path, "cloud") + self.create(iso_autoinstall_path, "iso") self.assertEqual(root, self.server.select_autoinstall()) - self.assert_contents(root, 'root') + self.assert_contents(root, "root") def test_arg_wins(self): root = self.path(root_autoinstall_path) - arg = self.create(self.path('arg.autoinstall.yaml'), 'arg') + arg = self.create(self.path("arg.autoinstall.yaml"), "arg") self.server.opts.autoinstall = arg - self.create(cloud_autoinstall_path, 'cloud') - self.create(iso_autoinstall_path, 'iso') + self.create(cloud_autoinstall_path, "cloud") + self.create(iso_autoinstall_path, "iso") self.assertEqual(root, self.server.select_autoinstall()) - self.assert_contents(root, 'arg') + self.assert_contents(root, "arg") def test_cloud_wins(self): root = self.path(root_autoinstall_path) - self.create(cloud_autoinstall_path, 'cloud') - self.create(iso_autoinstall_path, 'iso') + self.create(cloud_autoinstall_path, "cloud") + self.create(iso_autoinstall_path, "iso") self.assertEqual(root, self.server.select_autoinstall()) - self.assert_contents(root, 'cloud') + self.assert_contents(root, "cloud") def test_iso_wins(self): root = self.path(root_autoinstall_path) - self.create(iso_autoinstall_path, 'iso') + self.create(iso_autoinstall_path, "iso") self.assertEqual(root, self.server.select_autoinstall()) - self.assert_contents(root, 'iso') + self.assert_contents(root, "iso") def test_nobody_wins(self): self.assertIsNone(self.server.select_autoinstall()) def test_bogus_autoinstall_argument(self): - self.server.opts.autoinstall = self.path('nonexistant.yaml') + self.server.opts.autoinstall = self.path("nonexistant.yaml") with self.assertRaises(Exception): self.server.select_autoinstall() @@ -105,24 +105,21 @@ class TestAutoinstallLoad(SubiTestCase): rootpath = self.path(root_autoinstall_path) cmd = f"sed -i -e '$ a stuff: things' {rootpath}" - contents = f'''\ + contents = f"""\ version: 1 early-commands: ["{cmd}"] -''' - arg = self.create(self.path('arg.autoinstall.yaml'), contents) +""" + arg = self.create(self.path("arg.autoinstall.yaml"), contents) self.server.opts.autoinstall = arg self.server.autoinstall = self.server.select_autoinstall() self.server.load_autoinstall_config(only_early=True) - before_early = {'version': 1, - 'early-commands': [cmd]} + before_early = {"version": 1, "early-commands": [cmd]} self.assertEqual(before_early, self.server.autoinstall_config) run_command(shlex.split(cmd), check=True) self.server.load_autoinstall_config(only_early=False) - after_early = {'version': 1, - 'early-commands': [cmd], - 'stuff': 'things'} + after_early = {"version": 1, "early-commands": [cmd], "stuff": "things"} self.assertEqual(after_early, self.server.autoinstall_config) @@ -134,41 +131,46 @@ class TestMetaController(SubiTestCase): async def test_interactive_sections_empty(self): mc = MetaController(make_app()) - mc.app.autoinstall_config['interactive-sections'] = [] + mc.app.autoinstall_config["interactive-sections"] = [] self.assertEqual([], await mc.interactive_sections_GET()) async def test_interactive_sections_all(self): mc = MetaController(make_app()) - mc.app.autoinstall_config['interactive-sections'] = ['*'] + mc.app.autoinstall_config["interactive-sections"] = ["*"] mc.app.controllers.instances = [ - Mock(autoinstall_key='f', interactive=Mock(return_value=False)), + Mock(autoinstall_key="f", interactive=Mock(return_value=False)), Mock(autoinstall_key=None, interactive=Mock(return_value=True)), - Mock(autoinstall_key='t', interactive=Mock(return_value=True)), + Mock(autoinstall_key="t", interactive=Mock(return_value=True)), ] - self.assertEqual(['t'], await mc.interactive_sections_GET()) + self.assertEqual(["t"], await mc.interactive_sections_GET()) async def test_interactive_sections_one(self): mc = MetaController(make_app()) - mc.app.autoinstall_config['interactive-sections'] = ['network'] - self.assertEqual(['network'], await mc.interactive_sections_GET()) + mc.app.autoinstall_config["interactive-sections"] = ["network"] + self.assertEqual(["network"], await mc.interactive_sections_GET()) class TestDefaultUser(SubiTestCase): - @patch('subiquity.server.server.ug_util.normalize_users_groups', - Mock(return_value=(None, None))) - @patch('subiquity.server.server.ug_util.extract_default', - Mock(return_value=(None, None))) - @patch('subiquity.server.server.user_key_fingerprints', - Mock(side_effect=Exception('should not be called'))) + @patch( + "subiquity.server.server.ug_util.normalize_users_groups", + Mock(return_value=(None, None)), + ) + @patch( + "subiquity.server.server.ug_util.extract_default", + Mock(return_value=(None, None)), + ) + @patch( + "subiquity.server.server.user_key_fingerprints", + Mock(side_effect=Exception("should not be called")), + ) async def test_no_default_user(self): opts = Mock() opts.dry_run = True opts.output_base = self.tmp_dir() - opts.machine_config = 'examples/machines/simple.json' + opts.machine_config = "examples/machines/simple.json" server = SubiquityServer(opts, None) server.cloud = Mock() - server._user_has_password = Mock( - side_effect=Exception('should not be called')) + server._user_has_password = Mock(side_effect=Exception("should not be called")) opts.dry_run = False # exciting! server.set_installer_password() diff --git a/subiquity/server/tests/test_ssh.py b/subiquity/server/tests/test_ssh.py index 2f025829..538b270a 100644 --- a/subiquity/server/tests/test_ssh.py +++ b/subiquity/server/tests/test_ssh.py @@ -13,17 +13,12 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from subprocess import CompletedProcess, CalledProcessError import unittest +from subprocess import CalledProcessError, CompletedProcess from unittest import mock from subiquity.common.types import SSHFetchIdStatus -from subiquity.server.ssh import ( - DryRunSSHKeyFetcher, - SSHFetchError, - SSHKeyFetcher, - ) - +from subiquity.server.ssh import DryRunSSHKeyFetcher, SSHFetchError, SSHKeyFetcher from subiquitycore.tests.mocks import make_app @@ -40,9 +35,12 @@ class TestSSHKeyFetcher(unittest.IsolatedAsyncioTestCase): ssh-rsa AAAAC3NzaC1lZ test@gh/335797 # ssh-import-id gh:test """ keys = await self.fetcher.fetch_keys_for_id(user_id="gh:test") - self.assertEqual(keys, [ + self.assertEqual( + keys, + [ "ssh-rsa AAAAC3NzaC1lZ test@gh/335797 # ssh-import-id gh:test", - ]) + ], + ) async def test_fetch_keys_for_id_two_key_ok(self): with mock.patch(self.arun_command_sym) as mock_arun: @@ -53,10 +51,13 @@ ssh-rsa AAAAC3NzaC1lZ test@host # ssh-import-id lp:test ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test """ keys = await self.fetcher.fetch_keys_for_id(user_id="lp:test") - self.assertEqual(keys, [ - "ssh-rsa AAAAC3NzaC1lZ test@host # ssh-import-id lp:test", - "ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test", - ]) + self.assertEqual( + keys, + [ + "ssh-rsa AAAAC3NzaC1lZ test@host # ssh-import-id lp:test", + "ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test", + ], + ) async def test_fetch_keys_for_id_error(self): with mock.patch(self.arun_command_sym) as mock_arun: @@ -69,8 +70,7 @@ ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test await self.fetcher.fetch_keys_for_id(user_id="lp:test2") self.assertEqual(cm.exception.reason, stderr) - self.assertEqual(cm.exception.status, - SSHFetchIdStatus.IMPORT_ERROR) + self.assertEqual(cm.exception.status, SSHFetchIdStatus.IMPORT_ERROR) async def test_gen_fingerprint_for_key_ok(self): with mock.patch(self.arun_command_sym) as mock_arun: @@ -79,11 +79,15 @@ ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test 256 SHA256:rIR9UVRKspl7+KF75s test@host # ssh-import-id lp:test (ED25519) """ fp = await self.fetcher.gen_fingerprint_for_key( - "ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test") + "ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test" + ) - self.assertEqual(fp, """\ + self.assertEqual( + fp, + """\ 256 SHA256:rIR9UVRKspl7+KF75s test@host # ssh-import-id lp:test (ED25519)\ -""") +""", + ) async def test_gen_fingerprint_for_key_error(self): stderr = "(stdin) is not a public key file.\n" @@ -91,11 +95,11 @@ ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test mock_arun.side_effect = CalledProcessError(1, [], None, stderr) with self.assertRaises(SSHFetchError) as cm: await self.fetcher.gen_fingerprint_for_key( - "ssh-nsa AAAAAC3N test@host # ssh-import-id lp:test") + "ssh-nsa AAAAAC3N test@host # ssh-import-id lp:test" + ) self.assertEqual(cm.exception.reason, stderr) - self.assertEqual(cm.exception.status, - SSHFetchIdStatus.FINGERPRINT_ERROR) + self.assertEqual(cm.exception.status, SSHFetchIdStatus.FINGERPRINT_ERROR) class TestDryRunSSHKeyFetcher(unittest.IsolatedAsyncioTestCase): @@ -104,10 +108,12 @@ class TestDryRunSSHKeyFetcher(unittest.IsolatedAsyncioTestCase): async def test_fetch_keys_fake_success(self): result = await self.fetcher.fetch_keys_fake_success("lp:test") - expected = ["""\ + expected = [ + """\ ssh-ed25519\ AAAAC3NzaC1lZDI1NTE5AAAAIMM/qhS3hS3+IjpJBYXZWCqPKPH9Zag8QYbS548iEjoZ\ - test@earth # ssh-import-id lp:test"""] + test@earth # ssh-import-id lp:test""" + ] self.assertEqual(result, expected) async def test_fetch_keys_fake_failure(self): diff --git a/subiquity/server/tests/test_ubuntu_advantage.py b/subiquity/server/tests/test_ubuntu_advantage.py index 25764a02..366c94c2 100644 --- a/subiquity/server/tests/test_ubuntu_advantage.py +++ b/subiquity/server/tests/test_ubuntu_advantage.py @@ -13,19 +13,19 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from subprocess import CompletedProcess import unittest -from unittest.mock import patch, AsyncMock, ANY +from subprocess import CompletedProcess +from unittest.mock import ANY, AsyncMock, patch from subiquity.common.types import UbuntuProService from subiquity.server.ubuntu_advantage import ( - InvalidTokenError, - ExpiredTokenError, CheckSubscriptionError, - UAInterface, + ExpiredTokenError, + InvalidTokenError, MockedUAInterfaceStrategy, UAClientUAInterfaceStrategy, - ) + UAInterface, +) class TestMockedUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase): @@ -68,19 +68,18 @@ class TestUAClientUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase): self.assertEqual(strategy.executable, ["/usr/bin/ubuntu-advantage"]) # Initialize with a path + interpreter. - strategy = UAClientUAInterfaceStrategy( - ("python3", "/usr/bin/ubuntu-advantage") - ) - self.assertEqual(strategy.executable, - ["python3", "/usr/bin/ubuntu-advantage"]) + strategy = UAClientUAInterfaceStrategy(("python3", "/usr/bin/ubuntu-advantage")) + self.assertEqual(strategy.executable, ["python3", "/usr/bin/ubuntu-advantage"]) async def test_query_info_succeeded(self): strategy = UAClientUAInterfaceStrategy() command = ( "ubuntu-advantage", "status", - "--format", "json", - "--simulate-with-token", "123456789", + "--format", + "json", + "--simulate-with-token", + "123456789", ) with patch(self.arun_command_sym) as mock_arun: @@ -94,8 +93,10 @@ class TestUAClientUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase): command = ( "ubuntu-advantage", "status", - "--format", "json", - "--simulate-with-token", "123456789", + "--format", + "json", + "--simulate-with-token", + "123456789", ) with patch(self.arun_command_sym) as mock_arun: @@ -110,8 +111,10 @@ class TestUAClientUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase): command = ( "ubuntu-advantage", "status", - "--format", "json", - "--simulate-with-token", "123456789", + "--format", + "json", + "--simulate-with-token", + "123456789", ) with patch(self.arun_command_sym) as mock_arun: @@ -141,8 +144,10 @@ class TestUAClientUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase): command = ( "ubuntu-advantage", "status", - "--format", "json", - "--simulate-with-token", "123456789", + "--format", + "json", + "--simulate-with-token", + "123456789", ) with patch(self.arun_command_sym) as mock_arun: @@ -158,15 +163,17 @@ class TestUAClientUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase): mock_arun.return_value = CompletedProcess([], 0) mock_arun.return_value.stdout = '{"result": "success"}' result = await strategy._api_call( - endpoint="u.pro.attach.magic.wait.v1", - params=[("magic_token", "132456")]) + endpoint="u.pro.attach.magic.wait.v1", + params=[("magic_token", "132456")], + ) self.assertEqual(result, {"result": "success"}) command = ( "ubuntu-advantage", "api", "u.pro.attach.magic.wait.v1", - "--args", "magic_token=132456", + "--args", + "magic_token=132456", ) mock_arun.assert_called_once_with(command, check=False, env=ANY) @@ -176,8 +183,8 @@ class TestUAClientUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase): await strategy.magic_initiate_v1() api_call.assert_called_once_with( - endpoint="u.pro.attach.magic.initiate.v1", - params=[]) + endpoint="u.pro.attach.magic.initiate.v1", params=[] + ) async def test_magic_wait_v1(self): strategy = UAClientUAInterfaceStrategy() @@ -185,8 +192,8 @@ class TestUAClientUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase): await strategy.magic_wait_v1(magic_token="ABCDEF") api_call.assert_called_once_with( - endpoint="u.pro.attach.magic.wait.v1", - params=[("magic_token", "ABCDEF")]) + endpoint="u.pro.attach.magic.wait.v1", params=[("magic_token", "ABCDEF")] + ) async def test_magic_revoke_v1(self): strategy = UAClientUAInterfaceStrategy() @@ -194,12 +201,11 @@ class TestUAClientUAInterfaceStrategy(unittest.IsolatedAsyncioTestCase): await strategy.magic_revoke_v1(magic_token="ABCDEF") api_call.assert_called_once_with( - endpoint="u.pro.attach.magic.revoke.v1", - params=[("magic_token", "ABCDEF")]) + endpoint="u.pro.attach.magic.revoke.v1", params=[("magic_token", "ABCDEF")] + ) class TestUAInterface(unittest.IsolatedAsyncioTestCase): - async def test_mocked_get_subscription(self): strategy = MockedUAInterfaceStrategy(scale_factor=1_000_000) interface = UAInterface(strategy) @@ -237,56 +243,66 @@ class TestUAInterface(unittest.IsolatedAsyncioTestCase): "description": "Center for Internet Security Audit Tools", "entitled": "no", "auto_enabled": "no", - "available": "yes" + "available": "yes", }, { "name": "esm-apps", - "description": - "UA Apps: Extended Security Maintenance (ESM)", + "description": "UA Apps: Extended Security Maintenance (ESM)", "entitled": "yes", "auto_enabled": "yes", - "available": "no" + "available": "no", }, { "name": "esm-infra", - "description": - "UA Infra: Extended Security Maintenance (ESM)", + "description": "UA Infra: Extended Security Maintenance (ESM)", "entitled": "yes", "auto_enabled": "yes", - "available": "yes" + "available": "yes", }, { "name": "fips", "description": "NIST-certified core packages", "entitled": "yes", "auto_enabled": "no", - "available": "yes" + "available": "yes", }, - ] + ], } interface.get_subscription_status = AsyncMock(return_value=status) subscription = await interface.get_subscription(token="XXX") - self.assertIn(UbuntuProService( - name="esm-infra", - description="UA Infra: Extended Security Maintenance (ESM)", - auto_enabled=True, - ), subscription.services) - self.assertIn(UbuntuProService( - name="fips", - description="NIST-certified core packages", - auto_enabled=False, - ), subscription.services) - self.assertNotIn(UbuntuProService( - name="esm-apps", - description="UA Apps: Extended Security Maintenance (ESM)", - auto_enabled=True, - ), subscription.services) - self.assertNotIn(UbuntuProService( - name="cis", - description="Center for Internet Security Audit Tools", - auto_enabled=False, - ), subscription.services) + self.assertIn( + UbuntuProService( + name="esm-infra", + description="UA Infra: Extended Security Maintenance (ESM)", + auto_enabled=True, + ), + subscription.services, + ) + self.assertIn( + UbuntuProService( + name="fips", + description="NIST-certified core packages", + auto_enabled=False, + ), + subscription.services, + ) + self.assertNotIn( + UbuntuProService( + name="esm-apps", + description="UA Apps: Extended Security Maintenance (ESM)", + auto_enabled=True, + ), + subscription.services, + ) + self.assertNotIn( + UbuntuProService( + name="cis", + description="Center for Internet Security Audit Tools", + auto_enabled=False, + ), + subscription.services, + ) # Test with "Z" suffix for the expiration date. status["expires"] = "2035-12-31T00:00:00Z" diff --git a/subiquity/server/tests/test_ubuntu_drivers.py b/subiquity/server/tests/test_ubuntu_drivers.py index 0aa87b6d..da3ab7ba 100644 --- a/subiquity/server/tests/test_ubuntu_drivers.py +++ b/subiquity/server/tests/test_ubuntu_drivers.py @@ -13,19 +13,18 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from subprocess import CalledProcessError import unittest -from unittest.mock import patch, AsyncMock, Mock - -from subiquitycore.tests.mocks import make_app +from subprocess import CalledProcessError +from unittest.mock import AsyncMock, Mock, patch from subiquity.server.dryrun import DRConfig from subiquity.server.ubuntu_drivers import ( - UbuntuDriversInterface, - UbuntuDriversClientInterface, - UbuntuDriversRunDriversInterface, CommandNotFoundError, - ) + UbuntuDriversClientInterface, + UbuntuDriversInterface, + UbuntuDriversRunDriversInterface, +) +from subiquitycore.tests.mocks import make_app class TestUbuntuDriversInterface(unittest.IsolatedAsyncioTestCase): @@ -37,35 +36,53 @@ class TestUbuntuDriversInterface(unittest.IsolatedAsyncioTestCase): ubuntu_drivers = UbuntuDriversInterface(self.app, gpgpu=False) self.assertEqual(ubuntu_drivers.app, self.app) - self.assertEqual(ubuntu_drivers.list_drivers_cmd, [ - "ubuntu-drivers", "list", - "--recommended", - ]) - self.assertEqual(ubuntu_drivers.install_drivers_cmd, [ - "ubuntu-drivers", "install", "--no-oem"]) + self.assertEqual( + ubuntu_drivers.list_drivers_cmd, + [ + "ubuntu-drivers", + "list", + "--recommended", + ], + ) + self.assertEqual( + ubuntu_drivers.install_drivers_cmd, + ["ubuntu-drivers", "install", "--no-oem"], + ) ubuntu_drivers = UbuntuDriversInterface(self.app, gpgpu=True) - self.assertEqual(ubuntu_drivers.list_drivers_cmd, [ - "ubuntu-drivers", "list", - "--recommended", "--gpgpu", - ]) - self.assertEqual(ubuntu_drivers.install_drivers_cmd, [ - "ubuntu-drivers", "install", "--no-oem", "--gpgpu"]) + self.assertEqual( + ubuntu_drivers.list_drivers_cmd, + [ + "ubuntu-drivers", + "list", + "--recommended", + "--gpgpu", + ], + ) + self.assertEqual( + ubuntu_drivers.install_drivers_cmd, + ["ubuntu-drivers", "install", "--no-oem", "--gpgpu"], + ) @patch.multiple(UbuntuDriversInterface, __abstractmethods__=set()) @patch("subiquity.server.ubuntu_drivers.run_curtin_command") async def test_install_drivers(self, mock_run_curtin_command): ubuntu_drivers = UbuntuDriversInterface(self.app, gpgpu=False) await ubuntu_drivers.install_drivers( - root_dir="/target", - context="installing third-party drivers") + root_dir="/target", context="installing third-party drivers" + ) mock_run_curtin_command.assert_called_once_with( - self.app, "installing third-party drivers", - "in-target", "-t", "/target", - "--", - "ubuntu-drivers", "install", "--no-oem", - private_mounts=True, - ) + self.app, + "installing third-party drivers", + "in-target", + "-t", + "/target", + "--", + "ubuntu-drivers", + "install", + "--no-oem", + private_mounts=True, + ) @patch.multiple(UbuntuDriversInterface, __abstractmethods__=set()) def test_drivers_from_output(self): @@ -75,8 +92,8 @@ class TestUbuntuDriversInterface(unittest.IsolatedAsyncioTestCase): nvidia-driver-470 linux-modules-nvidia-470-generic-hwe-20.04 """ self.assertEqual( - ubuntu_drivers._drivers_from_output(output=output), - ["nvidia-driver-470"]) + ubuntu_drivers._drivers_from_output(output=output), ["nvidia-driver-470"] + ) # Make sure empty lines are discarded output = """ @@ -87,32 +104,36 @@ nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04 """ self.assertEqual( - ubuntu_drivers._drivers_from_output(output=output), - ["nvidia-driver-470", "nvidia-driver-510"]) + ubuntu_drivers._drivers_from_output(output=output), + ["nvidia-driver-470", "nvidia-driver-510"], + ) class TestUbuntuDriversClientInterface(unittest.IsolatedAsyncioTestCase): def setUp(self): self.app = make_app() - self.ubuntu_drivers = UbuntuDriversClientInterface( - self.app, gpgpu=False) + self.ubuntu_drivers = UbuntuDriversClientInterface(self.app, gpgpu=False) async def test_ensure_cmd_exists(self): with patch.object( - self.app, "command_runner", - create=True, new_callable=AsyncMock) as mock_runner: + self.app, "command_runner", create=True, new_callable=AsyncMock + ) as mock_runner: # On success await self.ubuntu_drivers.ensure_cmd_exists("/target") mock_runner.run.assert_called_once_with( - [ - "chroot", "/target", - "sh", "-c", "command -v ubuntu-drivers", - ]) + [ + "chroot", + "/target", + "sh", + "-c", + "command -v ubuntu-drivers", + ] + ) # On process failure mock_runner.run.side_effect = CalledProcessError( - returncode=1, - cmd=["sh", "-c", "command -v ubuntu-drivers"]) + returncode=1, cmd=["sh", "-c", "command -v ubuntu-drivers"] + ) with self.assertRaises(CommandNotFoundError): await self.ubuntu_drivers.ensure_cmd_exists("/target") @@ -120,38 +141,55 @@ class TestUbuntuDriversClientInterface(unittest.IsolatedAsyncioTestCase): @patch("subiquity.server.ubuntu_drivers.run_curtin_command") async def test_list_drivers(self, mock_run_curtin_command): # Make sure this gets decoded as utf-8. - mock_run_curtin_command.return_value = Mock(stdout=b"""\ + mock_run_curtin_command.return_value = Mock( + stdout=b"""\ nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04 -""") +""" + ) drivers = await self.ubuntu_drivers.list_drivers( - root_dir="/target", - context="listing third-party drivers") + root_dir="/target", context="listing third-party drivers" + ) mock_run_curtin_command.assert_called_once_with( - self.app, "listing third-party drivers", - "in-target", "-t", "/target", - "--", - "ubuntu-drivers", "list", "--recommended", - capture=True, private_mounts=True) + self.app, + "listing third-party drivers", + "in-target", + "-t", + "/target", + "--", + "ubuntu-drivers", + "list", + "--recommended", + capture=True, + private_mounts=True, + ) self.assertEqual(drivers, ["nvidia-driver-510"]) @patch("subiquity.server.ubuntu_drivers.run_curtin_command") async def test_list_oem(self, mock_run_curtin_command): # Make sure this gets decoded as utf-8. - mock_run_curtin_command.return_value = Mock(stdout=b"""\ + mock_run_curtin_command.return_value = Mock( + stdout=b"""\ oem-somerville-tentacool-meta -""") +""" + ) drivers = await self.ubuntu_drivers.list_oem( - root_dir="/target", - context="listing OEM meta-packages") + root_dir="/target", context="listing OEM meta-packages" + ) mock_run_curtin_command.assert_called_once_with( - self.app, "listing OEM meta-packages", - "in-target", "-t", "/target", - "--", - "ubuntu-drivers", "list-oem", - capture=True, private_mounts=True) + self.app, + "listing OEM meta-packages", + "in-target", + "-t", + "/target", + "--", + "ubuntu-drivers", + "list-oem", + capture=True, + private_mounts=True, + ) self.assertEqual(drivers, ["oem-somerville-tentacool-meta"]) @@ -161,40 +199,59 @@ class TestUbuntuDriversRunDriversInterface(unittest.IsolatedAsyncioTestCase): self.app = make_app() self.app.dr_cfg = DRConfig() self.app.dr_cfg.ubuntu_drivers_run_on_host_umockdev = None - self.ubuntu_drivers = UbuntuDriversRunDriversInterface( - self.app, gpgpu=False) + self.ubuntu_drivers = UbuntuDriversRunDriversInterface(self.app, gpgpu=False) def test_init_no_umockdev(self): self.app.dr_cfg.ubuntu_drivers_run_on_host_umockdev = None - ubuntu_drivers = UbuntuDriversRunDriversInterface( - self.app, gpgpu=False) - self.assertEqual(ubuntu_drivers.list_oem_cmd, - ["ubuntu-drivers", "list-oem"]) - self.assertEqual(ubuntu_drivers.list_drivers_cmd[0:2], - ["ubuntu-drivers", "list"]) - self.assertEqual(ubuntu_drivers.install_drivers_cmd[0:2], - ["ubuntu-drivers", "install"]) + ubuntu_drivers = UbuntuDriversRunDriversInterface(self.app, gpgpu=False) + self.assertEqual(ubuntu_drivers.list_oem_cmd, ["ubuntu-drivers", "list-oem"]) + self.assertEqual( + ubuntu_drivers.list_drivers_cmd[0:2], ["ubuntu-drivers", "list"] + ) + self.assertEqual( + ubuntu_drivers.install_drivers_cmd[0:2], ["ubuntu-drivers", "install"] + ) def test_init_with_umockdev(self): self.app.dr_cfg.ubuntu_drivers_run_on_host_umockdev = "/xps.yaml" - ubuntu_drivers = UbuntuDriversRunDriversInterface( - self.app, gpgpu=False) - self.assertEqual(ubuntu_drivers.list_oem_cmd, - ["scripts/umockdev-wrapper.py", - "--config", "/xps.yaml", "--", - "ubuntu-drivers", "list-oem"]) - self.assertEqual(ubuntu_drivers.list_drivers_cmd[0:6], - ["scripts/umockdev-wrapper.py", - "--config", "/xps.yaml", "--", - "ubuntu-drivers", "list"]) - self.assertEqual(ubuntu_drivers.install_drivers_cmd[0:6], - ["scripts/umockdev-wrapper.py", - "--config", "/xps.yaml", "--", - "ubuntu-drivers", "install"]) + ubuntu_drivers = UbuntuDriversRunDriversInterface(self.app, gpgpu=False) + self.assertEqual( + ubuntu_drivers.list_oem_cmd, + [ + "scripts/umockdev-wrapper.py", + "--config", + "/xps.yaml", + "--", + "ubuntu-drivers", + "list-oem", + ], + ) + self.assertEqual( + ubuntu_drivers.list_drivers_cmd[0:6], + [ + "scripts/umockdev-wrapper.py", + "--config", + "/xps.yaml", + "--", + "ubuntu-drivers", + "list", + ], + ) + self.assertEqual( + ubuntu_drivers.install_drivers_cmd[0:6], + [ + "scripts/umockdev-wrapper.py", + "--config", + "/xps.yaml", + "--", + "ubuntu-drivers", + "install", + ], + ) @patch("subiquity.server.ubuntu_drivers.arun_command") async def test_ensure_cmd_exists(self, mock_arun_command): await self.ubuntu_drivers.ensure_cmd_exists("/target") mock_arun_command.assert_called_once_with( - ["sh", "-c", "command -v ubuntu-drivers"], - check=True) + ["sh", "-c", "command -v ubuntu-drivers"], check=True + ) diff --git a/subiquity/server/types.py b/subiquity/server/types.py index 35d8fb96..28426ec0 100644 --- a/subiquity/server/types.py +++ b/subiquity/server/types.py @@ -22,28 +22,28 @@ from subiquitycore.pubsub import CoreChannels class InstallerChannels(CoreChannels): # This message is sent when a value is provided for the http proxy # setting. - NETWORK_PROXY_SET = 'network-proxy-set' + NETWORK_PROXY_SET = "network-proxy-set" # This message is sent there has been a network change that might affect # snapd's ability to talk to the network. - SNAPD_NETWORK_CHANGE = 'snapd-network-change' + SNAPD_NETWORK_CHANGE = "snapd-network-change" # This message is send when results from the geoip service are received. - GEOIP = 'geoip' + GEOIP = "geoip" # (CONFIGURED, $model_name) is sent when a model is marked # configured. Note that this can happen several times for each controller # (at least in theory) if the users goes back to a screen and makes a # different choice. - CONFIGURED = 'configured' + CONFIGURED = "configured" # This message is sent when apt has been configured in the overlay that # will be used as the source for the install step of the # installation. Currently this is only done once, but this might change in # the future. - APT_CONFIGURED = 'apt-configured' + APT_CONFIGURED = "apt-configured" # This message is sent when the user has confirmed that the install has # been confirmed. Once this has happened one can be sure that all the # models in the "install" side of the install/postinstall divide will not # be reconfigured. - INSTALL_CONFIRMED = 'install-confirmed' + INSTALL_CONFIRMED = "install-confirmed" # This message is sent as late as possible, and just before shutdown. This # step is after logfiles have been copied to the system, so should be used # sparingly and only as absolutely required. - PRE_SHUTDOWN = 'pre-shutdown' + PRE_SHUTDOWN = "pre-shutdown" diff --git a/subiquity/server/ubuntu_advantage.py b/subiquity/server/ubuntu_advantage.py index 2c524104..bd962ae6 100644 --- a/subiquity/server/ubuntu_advantage.py +++ b/subiquity/server/ubuntu_advantage.py @@ -15,30 +15,28 @@ """ This module defines utilities to interface with the ubuntu-advantage-tools helper. """ -from abc import ABC, abstractmethod -from datetime import datetime as dt +import asyncio import contextlib import json import logging import os -from subprocess import CompletedProcess import tempfile +from abc import ABC, abstractmethod +from datetime import datetime as dt +from subprocess import CompletedProcess from typing import List, Sequence, Tuple, Union + import yaml -import asyncio -from subiquity.common.types import ( - UbuntuProSubscription, - UbuntuProService, - ) +from subiquity.common.types import UbuntuProService, UbuntuProSubscription from subiquitycore import utils - log = logging.getLogger("subiquity.server.ubuntu_advantage") class InvalidTokenError(Exception): - """ Exception to be raised when the supplied token is invalid. """ + """Exception to be raised when the supplied token is invalid.""" + def __init__(self, token: str, message: str = "") -> None: self.token = token self.message = message @@ -46,7 +44,8 @@ class InvalidTokenError(Exception): class ExpiredTokenError(Exception): - """ Exception to be raised when the supplied token has expired. """ + """Exception to be raised when the supplied token has expired.""" + def __init__(self, token: str, expires: str, message: str = "") -> None: self.token = token self.expires = expires @@ -55,8 +54,9 @@ class ExpiredTokenError(Exception): class CheckSubscriptionError(Exception): - """ Exception to be raised when we are unable to fetch information about - the Ubuntu Advantage subscription. """ + """Exception to be raised when we are unable to fetch information about + the Ubuntu Advantage subscription.""" + def __init__(self, token: str, message: str = "") -> None: self.token = token self.message = message @@ -64,27 +64,29 @@ class CheckSubscriptionError(Exception): class APIUsageError(Exception): - """ Exception to raise when the API call does not return valid JSON. """ + """Exception to raise when the API call does not return valid JSON.""" class UAInterfaceStrategy(ABC): - """ Strategy to query information about a UA subscription. """ + """Strategy to query information about a UA subscription.""" + @abstractmethod async def query_info(self, token: str) -> dict: - """ Return information about the UA subscription based on the token - provided. """ + """Return information about the UA subscription based on the token + provided.""" class MockedUAInterfaceStrategy(UAInterfaceStrategy): - """ Mocked version of the Ubuntu Advantage interface strategy. The info it - returns is based on example files and appearance of the UA token. """ + """Mocked version of the Ubuntu Advantage interface strategy. The info it + returns is based on example files and appearance of the UA token.""" + def __init__(self, scale_factor: int = 1): self.scale_factor = scale_factor self._user_code_count = 0 super().__init__() async def query_info(self, token: str) -> dict: - """ Return the subscription info associated with the supplied + """Return the subscription info associated with the supplied UA token. No actual query is done to the UA servers in this implementation. Instead, we create a response based on the following rules: @@ -111,7 +113,7 @@ class MockedUAInterfaceStrategy(UAInterfaceStrategy): return json.load(stream) def _next_user_code(self) -> str: - """ Return the next user code starting with AAAAAA. """ + """Return the next user code starting with AAAAAA.""" alphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ123456789" count = self._user_code_count @@ -126,8 +128,8 @@ class MockedUAInterfaceStrategy(UAInterfaceStrategy): return user_code async def magic_initiate_v1(self) -> dict: - """ Simulate a success response from u.pro.attach.magic.initiate.v1 - endpoint. """ + """Simulate a success response from u.pro.attach.magic.initiate.v1 + endpoint.""" await asyncio.sleep(1 / self.scale_factor) return { "data": { @@ -144,8 +146,8 @@ class MockedUAInterfaceStrategy(UAInterfaceStrategy): } async def magic_wait_v1(self, magic_token: str) -> dict: - """ Simulate a timeout respnose from u.pro.attach.magic.wait.v1 - endpoint. """ + """Simulate a timeout respnose from u.pro.attach.magic.wait.v1 + endpoint.""" # Simulate a normal timeout await asyncio.sleep(600 / self.scale_factor) return { @@ -153,15 +155,15 @@ class MockedUAInterfaceStrategy(UAInterfaceStrategy): "errors": [ { "title": "The magic attach token is invalid, has " - " expired or never existed", + " expired or never existed", "code": "magic-attach-token-error", }, ], } async def magic_revoke_v1(self, magic_token: str) -> dict: - """ Simulate a success respnose from u.pro.attach.magic.revoke.v1 - endpoint. """ + """Simulate a success respnose from u.pro.attach.magic.revoke.v1 + endpoint.""" await asyncio.sleep(1 / self.scale_factor) return { "data": { @@ -175,19 +177,20 @@ class MockedUAInterfaceStrategy(UAInterfaceStrategy): class UAClientUAInterfaceStrategy(UAInterfaceStrategy): - """ Strategy that relies on UA client script to retrieve the information. - """ + """Strategy that relies on UA client script to retrieve the information.""" + Executable = Union[str, Sequence[str]] scale_factor = 1 def __init__(self, executable: Executable = "ubuntu-advantage") -> None: - """ Initialize the strategy using the path to the ubuntu-advantage + """Initialize the strategy using the path to the ubuntu-advantage executable we want to use. The executable can be specified as a sequence of strings so that we can specify the interpret to use as well. """ - self.executable: List[str] = \ + self.executable: List[str] = ( [executable] if isinstance(executable, str) else list(executable) + ) self.uaclient_config = None super().__init__() @@ -197,8 +200,7 @@ class UAClientUAInterfaceStrategy(UAInterfaceStrategy): @contextlib.contextmanager def uaclient_config_file(self) -> str: - """ Yields a path to a file that contains the uaclient configuration. - """ + """Yields a path to a file that contains the uaclient configuration.""" try: if self.uaclient_config is None: yield "/etc/ubuntu-advantage/uaclient.conf" @@ -210,7 +212,7 @@ class UAClientUAInterfaceStrategy(UAInterfaceStrategy): pass async def query_info(self, token: str) -> dict: - """ Return the subscription info associated with the supplied + """Return the subscription info associated with the supplied UA token. The information will be queried using the UA client executable passed to the initializer. """ @@ -221,8 +223,10 @@ class UAClientUAInterfaceStrategy(UAInterfaceStrategy): command = tuple(self.executable) + ( "status", - "--format", "json", - "--simulate-with-token", token, + "--format", + "json", + "--simulate-with-token", + token, ) with self.uaclient_config_file() as config_file: @@ -233,7 +237,8 @@ class UAClientUAInterfaceStrategy(UAInterfaceStrategy): # inspect it to know the reason of the failure. This is how we # figure out if the contract token was invalid. proc: CompletedProcess = await utils.arun_command( - command, check=False, env=env) + command, check=False, env=env + ) if proc.returncode == 0: # TODO check if we're not returning a string or a list try: @@ -251,8 +256,11 @@ class UAClientUAInterfaceStrategy(UAInterfaceStrategy): for error in data["errors"]: if error["message_code"] == "attach-invalid-token": token_invalid = True - log.debug("error reported by u-a-c: %s: %s", - error["message_code"], error["message"]) + log.debug( + "error reported by u-a-c: %s: %s", + error["message_code"], + error["message"], + ) if token_invalid: raise InvalidTokenError(token) @@ -262,9 +270,7 @@ class UAClientUAInterfaceStrategy(UAInterfaceStrategy): message = "Unable to retrieve subscription information." raise CheckSubscriptionError(token, message=message) - async def _api_call(self, - endpoint: str, - params: Sequence[Tuple[str, str]]) -> dict: + async def _api_call(self, endpoint: str, params: Sequence[Tuple[str, str]]) -> dict: command = list(tuple(self.executable)) + ["api", endpoint, "--args"] for key, value in params: @@ -274,7 +280,8 @@ class UAClientUAInterfaceStrategy(UAInterfaceStrategy): env = os.environ.copy() env["UA_CONFIG_FILE"] = config_file proc: CompletedProcess = await utils.arun_command( - tuple(command), check=False, env=env) + tuple(command), check=False, env=env + ) try: return json.loads(proc.stdout) except json.JSONDecodeError: @@ -282,35 +289,37 @@ class UAClientUAInterfaceStrategy(UAInterfaceStrategy): raise APIUsageError async def magic_initiate_v1(self) -> dict: - """ Call the u.pro.attach.magic.initiate.v1 endpoint. """ + """Call the u.pro.attach.magic.initiate.v1 endpoint.""" return await self._api_call( - endpoint="u.pro.attach.magic.initiate.v1", - params=[]) + endpoint="u.pro.attach.magic.initiate.v1", params=[] + ) async def magic_wait_v1(self, magic_token: str) -> dict: - """ Call the u.pro.attach.magic.wait.v1 endpoint. """ + """Call the u.pro.attach.magic.wait.v1 endpoint.""" return await self._api_call( - endpoint="u.pro.attach.magic.wait.v1", - params=[("magic_token", magic_token)]) + endpoint="u.pro.attach.magic.wait.v1", params=[("magic_token", magic_token)] + ) async def magic_revoke_v1(self, magic_token: str) -> dict: - """ Call the u.pro.attach.magic.revoke.v1 endpoint. """ + """Call the u.pro.attach.magic.revoke.v1 endpoint.""" return await self._api_call( - endpoint="u.pro.attach.magic.revoke.v1", - params=[("magic_token", magic_token)]) + endpoint="u.pro.attach.magic.revoke.v1", + params=[("magic_token", magic_token)], + ) class UAInterface: - """ Interface to obtain Ubuntu Advantage subscription information. """ + """Interface to obtain Ubuntu Advantage subscription information.""" + def __init__(self, strategy: UAInterfaceStrategy): self.strategy = strategy async def get_subscription_status(self, token: str) -> dict: - """ Return a dictionary containing the subscription information. """ + """Return a dictionary containing the subscription information.""" return await self.strategy.query_info(token) async def get_subscription(self, token: str) -> UbuntuProSubscription: - """ Return the name of the contract, the name of the account and the + """Return the name of the contract, the name of the account and the list of activable services (i.e. services that are entitled to the subscription and available on the current hardware). """ @@ -328,14 +337,13 @@ class UAInterface: # the current machine (e.g. on Focal running on a amd64 CPU) ; # whereas # - the entitled field tells us if the contract covers the service. - return service["available"] == "yes" \ - and service["entitled"] == "yes" + return service["available"] == "yes" and service["entitled"] == "yes" def service_from_dict(service: dict) -> UbuntuProService: return UbuntuProService( - name=service["name"], - description=service["description"], - auto_enabled=service["auto_enabled"] == "yes", + name=service["name"], + description=service["description"], + auto_enabled=service["auto_enabled"] == "yes", ) activable_services: List[UbuntuProService] = [] @@ -346,7 +354,8 @@ class UAInterface: activable_services.append(service_from_dict(service)) return UbuntuProSubscription( - account_name=info["account"]["name"], - contract_name=info["contract"]["name"], - contract_token=token, - services=activable_services) + account_name=info["account"]["name"], + contract_name=info["contract"]["name"], + contract_token=token, + services=activable_services, + ) diff --git a/subiquity/server/ubuntu_drivers.py b/subiquity/server/ubuntu_drivers.py index 2ea73689..87742fcc 100644 --- a/subiquity/server/ubuntu_drivers.py +++ b/subiquity/server/ubuntu_drivers.py @@ -15,20 +15,19 @@ """ Module that defines helpers to use the ubuntu-drivers command. """ -from abc import ABC, abstractmethod import logging import subprocess +from abc import ABC, abstractmethod from typing import List, Type from subiquity.server.curtin import run_curtin_command from subiquitycore.utils import arun_command - log = logging.getLogger("subiquity.server.ubuntu_drivers") class CommandNotFoundError(Exception): - """ Exception to be raised when the ubuntu-drivers command is not + """Exception to be raised when the ubuntu-drivers command is not available. """ @@ -38,10 +37,12 @@ class UbuntuDriversInterface(ABC): self.app = app self.list_oem_cmd = [ - "ubuntu-drivers", "list-oem", + "ubuntu-drivers", + "list-oem", ] self.list_drivers_cmd = [ - "ubuntu-drivers", "list", + "ubuntu-drivers", + "list", "--recommended", ] # Because of LP #1966413, the following command will also install @@ -50,7 +51,9 @@ class UbuntuDriversInterface(ABC): # This is not ideal but should be acceptable because we want OEM # meta-packages installed unconditionally (except in autoinstall). self.install_drivers_cmd = [ - "ubuntu-drivers", "install", "--no-oem", + "ubuntu-drivers", + "install", + "--no-oem", ] if gpgpu: self.list_drivers_cmd.append("--gpgpu") @@ -70,13 +73,19 @@ class UbuntuDriversInterface(ABC): async def install_drivers(self, root_dir: str, context) -> None: await run_curtin_command( - self.app, context, - "in-target", "-t", root_dir, "--", *self.install_drivers_cmd, - private_mounts=True) + self.app, + context, + "in-target", + "-t", + root_dir, + "--", + *self.install_drivers_cmd, + private_mounts=True, + ) def _drivers_from_output(self, output: str) -> List[str]: - """ Parse the output of ubuntu-drivers list --recommended and return a - list of drivers. """ + """Parse the output of ubuntu-drivers list --recommended and return a + list of drivers.""" drivers: List[str] = [] # Drivers are listed one per line, but some drivers are followed by a # linux-modules-* package (which we are not interested in showing). @@ -96,8 +105,8 @@ class UbuntuDriversInterface(ABC): return drivers def _oem_metapackages_from_output(self, output: str) -> List[str]: - """ Parse the output of ubuntu-drivers list-oem and return a list of - packages. """ + """Parse the output of ubuntu-drivers list-oem and return a list of + packages.""" metapackages: List[str] = [] # Packages are listed one per line. for line in [x.strip() for x in output.split("\n")]: @@ -109,26 +118,33 @@ class UbuntuDriversInterface(ABC): class UbuntuDriversClientInterface(UbuntuDriversInterface): - """ UbuntuDrivers interface that uses the ubuntu-drivers command from the - specified root directory. """ + """UbuntuDrivers interface that uses the ubuntu-drivers command from the + specified root directory.""" async def ensure_cmd_exists(self, root_dir: str) -> None: # TODO This does not tell us if the "--recommended" option is # available. try: await self.app.command_runner.run( - ['chroot', root_dir, - 'sh', '-c', - "command -v ubuntu-drivers"]) + ["chroot", root_dir, "sh", "-c", "command -v ubuntu-drivers"] + ) except subprocess.CalledProcessError: raise CommandNotFoundError( - f"Command ubuntu-drivers is not available in {root_dir}") + f"Command ubuntu-drivers is not available in {root_dir}" + ) async def list_drivers(self, root_dir: str, context) -> List[str]: result = await run_curtin_command( - self.app, context, - "in-target", "-t", root_dir, "--", *self.list_drivers_cmd, - capture=True, private_mounts=True) + self.app, + context, + "in-target", + "-t", + root_dir, + "--", + *self.list_drivers_cmd, + capture=True, + private_mounts=True, + ) # Currently we have no way to specify universal_newlines=True or # encoding="utf-8" to run_curtin_command so we need to decode the # output. @@ -136,19 +152,26 @@ class UbuntuDriversClientInterface(UbuntuDriversInterface): async def list_oem(self, root_dir: str, context) -> List[str]: result = await run_curtin_command( - self.app, context, - "in-target", "-t", root_dir, "--", *self.list_oem_cmd, - capture=True, private_mounts=True) + self.app, + context, + "in-target", + "-t", + root_dir, + "--", + *self.list_oem_cmd, + capture=True, + private_mounts=True, + ) # Currently we have no way to specify universal_newlines=True or # encoding="utf-8" to run_curtin_command so we need to decode the # output. - return self._oem_metapackages_from_output( - result.stdout.decode("utf-8")) + return self._oem_metapackages_from_output(result.stdout.decode("utf-8")) class UbuntuDriversHasDriversInterface(UbuntuDriversInterface): - """ A dry-run implementation of ubuntu-drivers that returns a hard-coded - list of drivers. """ + """A dry-run implementation of ubuntu-drivers that returns a hard-coded + list of drivers.""" + gpgpu_drivers: List[str] = ["nvidia-driver-470-server"] not_gpgpu_drivers: List[str] = ["nvidia-driver-510"] oem_metapackages: List[str] = ["oem-somerville-tentacool-meta"] @@ -168,8 +191,8 @@ class UbuntuDriversHasDriversInterface(UbuntuDriversInterface): class UbuntuDriversNoDriversInterface(UbuntuDriversHasDriversInterface): - """ A dry-run implementation of ubuntu-drivers that returns a hard-coded - empty list of drivers. """ + """A dry-run implementation of ubuntu-drivers that returns a hard-coded + empty list of drivers.""" gpgpu_drivers: List[str] = [] not_gpgpu_drivers: List[str] = [] @@ -177,8 +200,8 @@ class UbuntuDriversNoDriversInterface(UbuntuDriversHasDriversInterface): class UbuntuDriversRunDriversInterface(UbuntuDriversInterface): - """ A dry-run implementation of ubuntu-drivers that actually runs the - ubuntu-drivers command but locally. """ + """A dry-run implementation of ubuntu-drivers that actually runs the + ubuntu-drivers command but locally.""" def __init__(self, app, gpgpu: bool) -> None: super().__init__(app, gpgpu) @@ -188,29 +211,34 @@ class UbuntuDriversRunDriversInterface(UbuntuDriversInterface): self.list_oem_cmd = [ "scripts/umockdev-wrapper.py", - "--config", app.dr_cfg.ubuntu_drivers_run_on_host_umockdev, - "--"] + self.list_oem_cmd + "--config", + app.dr_cfg.ubuntu_drivers_run_on_host_umockdev, + "--", + ] + self.list_oem_cmd self.list_drivers_cmd = [ "scripts/umockdev-wrapper.py", - "--config", app.dr_cfg.ubuntu_drivers_run_on_host_umockdev, - "--"] + self.list_drivers_cmd + "--config", + app.dr_cfg.ubuntu_drivers_run_on_host_umockdev, + "--", + ] + self.list_drivers_cmd self.install_drivers_cmd = [ "scripts/umockdev-wrapper.py", - "--config", app.dr_cfg.ubuntu_drivers_run_on_host_umockdev, - "--"] + self.install_drivers_cmd + "--config", + app.dr_cfg.ubuntu_drivers_run_on_host_umockdev, + "--", + ] + self.install_drivers_cmd async def ensure_cmd_exists(self, root_dir: str) -> None: # TODO This does not tell us if the "--recommended" option is # available. try: - await arun_command( - ["sh", "-c", "command -v ubuntu-drivers"], - check=True) + await arun_command(["sh", "-c", "command -v ubuntu-drivers"], check=True) except subprocess.CalledProcessError: raise CommandNotFoundError( - "Command ubuntu-drivers is not available in this system") + "Command ubuntu-drivers is not available in this system" + ) async def list_drivers(self, root_dir: str, context) -> List[str]: # We run the command locally - ignoring the root_dir. @@ -227,9 +255,9 @@ def get_ubuntu_drivers_interface(app) -> UbuntuDriversInterface: is_server = app.base_model.source.current.variant == "server" cls: Type[UbuntuDriversInterface] = UbuntuDriversClientInterface if app.opts.dry_run: - if 'no-drivers' in app.debug_flags: + if "no-drivers" in app.debug_flags: cls = UbuntuDriversNoDriversInterface - elif 'run-drivers' in app.debug_flags: + elif "run-drivers" in app.debug_flags: cls = UbuntuDriversRunDriversInterface else: cls = UbuntuDriversHasDriversInterface diff --git a/subiquity/tests/api/test_api.py b/subiquity/tests/api/test_api.py index 12ffc224..b3abe185 100644 --- a/subiquity/tests/api/test_api.py +++ b/subiquity/tests/api/test_api.py @@ -13,33 +13,31 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import aiohttp -from aiohttp.client_exceptions import ClientResponseError -import async_timeout import asyncio import contextlib -from functools import wraps import json import os import re import tempfile +from functools import wraps from typing import Dict, List, Optional from unittest.mock import patch from urllib.parse import unquote +import aiohttp +import async_timeout +from aiohttp.client_exceptions import ClientResponseError + from subiquitycore.tests import SubiTestCase -from subiquitycore.utils import ( - astart_command, - matching_dicts, - ) +from subiquitycore.utils import astart_command, matching_dicts default_timeout = 10 def match(items, **kw): - typename = kw.pop('_type', None) + typename = kw.pop("_type", None) if typename is not None: - kw['$type'] = typename + kw["$type"] = typename return matching_dicts(items, **kw) @@ -49,7 +47,9 @@ def timeout(multiplier=1): async def run(*args, **kwargs): async with async_timeout.timeout(default_timeout * multiplier): return await coro(*args, **kwargs) + return run + return wrapper @@ -62,7 +62,7 @@ class Client: self.session = session def loads(self, data): - if data == '' or data is None: # json.loads likes neither of these + if data == "" or data is None: # json.loads likes neither of these return None return json.loads(data) @@ -70,21 +70,21 @@ class Client: # if the data we're dumping is literally False, # we want that to be 'false' if data or isinstance(data, bool): - return json.dumps(data, separators=(',', ':')) + return json.dumps(data, separators=(",", ":")) elif data is not None: return '""' else: return data async def get(self, query, **kwargs): - return await self.request('GET', query, **kwargs) + return await self.request("GET", query, **kwargs) async def post(self, query, data=None, **kwargs): - return await self.request('POST', query, data, **kwargs) + return await self.request("POST", query, data, **kwargs) - async def request(self, method, query, - data=None, full_response=False, headers=None, - **kwargs): + async def request( + self, method, query, data=None, full_response=False, headers=None, **kwargs + ): """send a GET or POST to the test instance args: method: 'GET' or 'POST' @@ -105,9 +105,9 @@ class Client: """ params = {k: self.dumps(v) for k, v in kwargs.items()} data = self.dumps(data) - async with self.session.request(method, f'http://a{query}', - data=data, params=params, - headers=headers) as resp: + async with self.session.request( + method, f"http://a{query}", data=data, params=params, headers=headers + ) as resp: print(unquote(str(resp.url))) content = await resp.content.read() content = content.decode() @@ -121,36 +121,48 @@ class Client: async def poll_startup(self): for _ in range(default_timeout * 10): try: - resp = await self.get('/meta/status') - if resp["state"] in ('STARTING_UP', 'CLOUD_INIT_WAIT', - 'EARLY_COMMANDS'): - await asyncio.sleep(.5) + resp = await self.get("/meta/status") + if resp["state"] in ( + "STARTING_UP", + "CLOUD_INIT_WAIT", + "EARLY_COMMANDS", + ): + await asyncio.sleep(0.5) continue - if resp["state"] == 'ERROR': - raise Exception('server in error state') + if resp["state"] == "ERROR": + raise Exception("server in error state") return except aiohttp.client_exceptions.ClientConnectorError: - await asyncio.sleep(.5) - raise Exception('timeout on server startup') + await asyncio.sleep(0.5) + raise Exception("timeout on server startup") class Server(Client): async def server_shutdown(self, immediate=True): try: - await self.post('/shutdown', mode='POWEROFF', immediate=immediate) + await self.post("/shutdown", mode="POWEROFF", immediate=immediate) except aiohttp.client_exceptions.ServerDisconnectedError: return - async def spawn(self, output_base, socket, machine_config, - bootloader='uefi', extra_args=None): + async def spawn( + self, output_base, socket, machine_config, bootloader="uefi", extra_args=None + ): env = os.environ.copy() - env['SUBIQUITY_REPLAY_TIMESCALE'] = '100' - cmd = ['python3', '-m', 'subiquity.cmd.server', - '--dry-run', - '--bootloader', bootloader, - '--socket', socket, - '--output-base', output_base, - '--machine-config', machine_config] + env["SUBIQUITY_REPLAY_TIMESCALE"] = "100" + cmd = [ + "python3", + "-m", + "subiquity.cmd.server", + "--dry-run", + "--bootloader", + bootloader, + "--socket", + socket, + "--output-base", + output_base, + "--machine-config", + machine_config, + ] if extra_args is not None: cmd.extend(extra_args) self.proc = await astart_command(cmd, env=env) @@ -176,12 +188,21 @@ class Server(Client): class SystemSetupServer(Server): - async def spawn(self, output_base, socket, machine_config, - bootloader='uefi', extra_args=None): + async def spawn( + self, output_base, socket, machine_config, bootloader="uefi", extra_args=None + ): env = os.environ.copy() - env['SUBIQUITY_REPLAY_TIMESCALE'] = '100' - cmd = ['python3', '-m', 'system_setup.cmd.server', - '--dry-run', '--socket', socket, '--output-base', output_base] + env["SUBIQUITY_REPLAY_TIMESCALE"] = "100" + cmd = [ + "python3", + "-m", + "system_setup.cmd.server", + "--dry-run", + "--socket", + socket, + "--output-base", + output_base, + ] root = os.path.abspath(output_base) conffile = os.path.join(root, "etc/wsl.conf") os.makedirs(os.path.dirname(conffile), exist_ok=True) @@ -206,11 +227,11 @@ class TestAPI(SubiTestCase): @contextlib.contextmanager def edit(self): - with open(self.orig_path, 'r') as fp: + with open(self.orig_path, "r") as fp: data = json.load(fp) yield data - self.path = self.outer.tmp_path('machine-config.json') - with open(self.path, 'w') as fp: + self.path = self.outer.tmp_path("machine-config.json") + with open(self.path, "w") as fp: json.dump(data, fp) def machineConfig(self, path): @@ -230,8 +251,8 @@ async def poll_for_socket_exist(socket_path): # test level timeout will trigger first, this loop is just a fallback if os.path.exists(socket_path): return - await asyncio.sleep(.1) - raise Exception('timeout looking for socket to exist') + await asyncio.sleep(0.1) + raise Exception("timeout looking for socket to exist") @contextlib.contextmanager @@ -241,7 +262,7 @@ def tempdirs(*args, **kwargs): # that the log files can be examined later # * make it an otherwise-unnecessary contextmanager so that the indentation # of the caller can be preserved - prefix = '/tmp/testapi/' + prefix = "/tmp/testapi/" os.makedirs(prefix, exist_ok=True) tempdir = tempfile.mkdtemp(prefix=prefix) print(tempdir) @@ -251,7 +272,7 @@ def tempdirs(*args, **kwargs): @contextlib.asynccontextmanager async def start_server_factory(factory, *args, **kwargs): with tempfile.TemporaryDirectory() as tempdir: - socket_path = f'{tempdir}/socket' + socket_path = f"{tempdir}/socket" conn = aiohttp.UnixConnector(path=socket_path) async with aiohttp.ClientSession(connector=conn) as session: server = factory(session) @@ -268,15 +289,14 @@ async def start_server_factory(factory, *args, **kwargs): async def start_server(*args, set_first_source=True, source=None, **kwargs): async with start_server_factory(Server, *args, **kwargs) as instance: if set_first_source: - sources = await instance.get('/source') + sources = await instance.get("/source") if sources is None: - raise Exception('unexpected /source response') - await instance.post( - '/source', source_id=sources['sources'][0]['id']) + raise Exception("unexpected /source response") + await instance.post("/source", source_id=sources["sources"][0]["id"]) while True: - resp = await instance.get('/storage/v2') + resp = await instance.get("/storage/v2") print(resp) - if resp['status'] != 'PROBING': + if resp["status"] != "PROBING": break await asyncio.sleep(0.5) yield instance @@ -293,7 +313,7 @@ async def connect_server(*args, **kwargs): # This is not used by the tests directly, but can be convenient when # wanting to debug the server process. Change a test's start_server # to connect_server, disable the test timeout, and run just that test. - socket_path = '.subiquity/socket' + socket_path = ".subiquity/socket" conn = aiohttp.UnixConnector(path=socket_path) async with aiohttp.ClientSession(connector=conn) as session: yield Client(session) @@ -302,142 +322,134 @@ async def connect_server(*args, **kwargs): class TestBitlocker(TestAPI): @timeout() async def test_has_bitlocker(self): - async with start_server('examples/machines/win10.json') as inst: - resp = await inst.get('/storage/has_bitlocker') + async with start_server("examples/machines/win10.json") as inst: + resp = await inst.get("/storage/has_bitlocker") self.assertEqual(1, len(resp)) @timeout() async def test_not_bitlocker(self): - async with start_server('examples/machines/simple.json') as inst: - resp = await inst.get('/storage/has_bitlocker') + async with start_server("examples/machines/simple.json") as inst: + resp = await inst.get("/storage/has_bitlocker") self.assertEqual(0, len(resp)) class TestFlow(TestAPI): @timeout(2) async def test_serverish_flow(self): - async with start_server('examples/machines/simple.json') as inst: - await inst.post('/locale', 'en_US.UTF-8') - keyboard = { - 'layout': 'us', - 'variant': '', - 'toggle': None - } - await inst.post('/keyboard', keyboard) - await inst.post('/source', - source_id='ubuntu-server', search_drivers=True) - await inst.post('/network') - await inst.post('/proxy', '') - await inst.post('/mirror', - {'elected': 'http://us.archive.ubuntu.com/ubuntu'}) - - resp = await inst.get('/storage/v2/guided?wait=true') - [reformat, manual] = resp['targets'] + async with start_server("examples/machines/simple.json") as inst: + await inst.post("/locale", "en_US.UTF-8") + keyboard = {"layout": "us", "variant": "", "toggle": None} + await inst.post("/keyboard", keyboard) + await inst.post("/source", source_id="ubuntu-server", search_drivers=True) + await inst.post("/network") + await inst.post("/proxy", "") await inst.post( - '/storage/v2/guided', + "/mirror", {"elected": "http://us.archive.ubuntu.com/ubuntu"} + ) + + resp = await inst.get("/storage/v2/guided?wait=true") + [reformat, manual] = resp["targets"] + await inst.post( + "/storage/v2/guided", { - 'target': reformat, - 'capability': reformat['allowed'][0], - }) - await inst.post('/storage/v2') - await inst.get('/meta/status', cur='WAITING') - await inst.post('/meta/confirm', tty='/dev/tty1') - await inst.get('/meta/status', cur='NEEDS_CONFIRMATION') + "target": reformat, + "capability": reformat["allowed"][0], + }, + ) + await inst.post("/storage/v2") + await inst.get("/meta/status", cur="WAITING") + await inst.post("/meta/confirm", tty="/dev/tty1") + await inst.get("/meta/status", cur="NEEDS_CONFIRMATION") identity = { - 'realname': 'ubuntu', - 'username': 'ubuntu', - 'hostname': 'ubuntu-server', - 'crypted_password': '$6$exDY1mhS4KUYCE/2$zmn9ToZwTKLhCw.b4/' - + 'b.ZRTIZM30JZ4QrOQ2aOXJ8yk96xpcCof0kx' - + 'KwuX1kqLG/ygbJ1f8wxED22bTL4F46P0' + "realname": "ubuntu", + "username": "ubuntu", + "hostname": "ubuntu-server", + "crypted_password": "$6$exDY1mhS4KUYCE/2$zmn9ToZwTKLhCw.b4/" + + "b.ZRTIZM30JZ4QrOQ2aOXJ8yk96xpcCof0kx" + + "KwuX1kqLG/ygbJ1f8wxED22bTL4F46P0", } - await inst.post('/identity', identity) - ssh = { - 'install_server': False, - 'allow_pw': False, - 'authorized_keys': [] - } - await inst.post('/ssh', ssh) - await inst.post('/snaplist', []) - await inst.post('/drivers', {'install': False}) + await inst.post("/identity", identity) + ssh = {"install_server": False, "allow_pw": False, "authorized_keys": []} + await inst.post("/ssh", ssh) + await inst.post("/snaplist", []) + await inst.post("/drivers", {"install": False}) ua_params = { "token": "a1b2c3d4e6f7g8h9I0K1", } - await inst.post('/ubuntu_pro', ua_params) - for state in 'RUNNING', 'WAITING', 'RUNNING', 'UU_RUNNING': - await inst.get('/meta/status', cur=state) + await inst.post("/ubuntu_pro", ua_params) + for state in "RUNNING", "WAITING", "RUNNING", "UU_RUNNING": + await inst.get("/meta/status", cur=state) @timeout() async def test_v2_flow(self): - cfg = self.machineConfig('examples/machines/simple.json') + cfg = self.machineConfig("examples/machines/simple.json") with cfg.edit() as data: - attrs = data['storage']['blockdev']['/dev/sda']['attrs'] - attrs['size'] = str(25 << 30) - extra_args = ['--source-catalog', 'examples/sources/desktop.yaml'] + attrs = data["storage"]["blockdev"]["/dev/sda"]["attrs"] + attrs["size"] = str(25 << 30) + extra_args = ["--source-catalog", "examples/sources/desktop.yaml"] async with start_server(cfg, extra_args=extra_args) as inst: - disk_id = 'disk-sda' - orig_resp = await inst.get('/storage/v2') - [sda] = match(orig_resp['disks'], id=disk_id) - self.assertTrue(len(sda['partitions']) > 0) + disk_id = "disk-sda" + orig_resp = await inst.get("/storage/v2") + [sda] = match(orig_resp["disks"], id=disk_id) + self.assertTrue(len(sda["partitions"]) > 0) - data = {'disk_id': disk_id} - resp = await inst.post('/storage/v2/reformat_disk', data) - [sda] = match(resp['disks'], id=disk_id) - [gap] = sda['partitions'] - self.assertEqual('Gap', gap['$type']) + data = {"disk_id": disk_id} + resp = await inst.post("/storage/v2/reformat_disk", data) + [sda] = match(resp["disks"], id=disk_id) + [gap] = sda["partitions"] + self.assertEqual("Gap", gap["$type"]) data = { - 'disk_id': disk_id, - 'gap': gap, - 'partition': { - 'format': 'ext3', - 'mount': '/', - } + "disk_id": disk_id, + "gap": gap, + "partition": { + "format": "ext3", + "mount": "/", + }, } - add_resp = await inst.post('/storage/v2/add_partition', data) - [sda] = add_resp['disks'] - [root] = match(sda['partitions'], mount='/') - self.assertEqual('ext3', root['format']) + add_resp = await inst.post("/storage/v2/add_partition", data) + [sda] = add_resp["disks"] + [root] = match(sda["partitions"], mount="/") + self.assertEqual("ext3", root["format"]) data = { - 'disk_id': disk_id, - 'partition': { - 'number': root['number'], - 'format': 'ext4', - 'wipe': 'superblock', - } + "disk_id": disk_id, + "partition": { + "number": root["number"], + "format": "ext4", + "wipe": "superblock", + }, } - edit_resp = await inst.post('/storage/v2/edit_partition', data) + edit_resp = await inst.post("/storage/v2/edit_partition", data) - [add_sda] = add_resp['disks'] - [add_root] = match(add_sda['partitions'], mount='/') + [add_sda] = add_resp["disks"] + [add_root] = match(add_sda["partitions"], mount="/") - [edit_sda] = edit_resp['disks'] - [edit_root] = match(edit_sda['partitions'], mount='/') + [edit_sda] = edit_resp["disks"] + [edit_root] = match(edit_sda["partitions"], mount="/") - for key in 'size', 'number', 'mount', 'boot': + for key in "size", "number", "mount", "boot": self.assertEqual(add_root[key], edit_root[key], key) - self.assertEqual('ext4', edit_root['format']) + self.assertEqual("ext4", edit_root["format"]) - del_resp = await inst.post('/storage/v2/delete_partition', data) - [sda] = del_resp['disks'] - [p, g] = sda['partitions'] - self.assertEqual('Partition', p['$type']) - self.assertEqual('Gap', g['$type']) + del_resp = await inst.post("/storage/v2/delete_partition", data) + [sda] = del_resp["disks"] + [p, g] = sda["partitions"] + self.assertEqual("Partition", p["$type"]) + self.assertEqual("Gap", g["$type"]) - reset_resp = await inst.post('/storage/v2/reset') + reset_resp = await inst.post("/storage/v2/reset") self.assertEqual(orig_resp, reset_resp) - resp = await inst.get('/storage/v2/guided') - [reformat] = match(resp['targets'], - _type='GuidedStorageTargetReformat') + resp = await inst.get("/storage/v2/guided") + [reformat] = match(resp["targets"], _type="GuidedStorageTargetReformat") data = { - 'target': reformat, - 'capability': reformat['allowed'][0], - } - await inst.post('/storage/v2/guided', data) - after_guided_resp = await inst.get('/storage/v2') - post_resp = await inst.post('/storage/v2') + "target": reformat, + "capability": reformat["allowed"][0], + } + await inst.post("/storage/v2/guided", data) + after_guided_resp = await inst.get("/storage/v2") + post_resp = await inst.post("/storage/v2") # posting to the endpoint shouldn't change the answer self.assertEqual(after_guided_resp, post_resp) @@ -445,510 +457,493 @@ class TestFlow(TestAPI): class TestGuided(TestAPI): @timeout() async def test_guided_v2_reformat(self): - cfg = 'examples/machines/win10-along-ubuntu.json' + cfg = "examples/machines/win10-along-ubuntu.json" async with start_server(cfg) as inst: - resp = await inst.get('/storage/v2/guided') - [reformat] = match(resp['targets'], - _type='GuidedStorageTargetReformat') + resp = await inst.get("/storage/v2/guided") + [reformat] = match(resp["targets"], _type="GuidedStorageTargetReformat") resp = await inst.post( - '/storage/v2/guided', + "/storage/v2/guided", { - 'target': reformat, - 'capability': reformat['allowed'][0], - }) - self.assertEqual(reformat, resp['configured']['target']) - resp = await inst.get('/storage/v2') - [p1, p2] = resp['disks'][0]['partitions'] + "target": reformat, + "capability": reformat["allowed"][0], + }, + ) + self.assertEqual(reformat, resp["configured"]["target"]) + resp = await inst.get("/storage/v2") + [p1, p2] = resp["disks"][0]["partitions"] expected_p1 = { - '$type': 'Partition', - 'boot': True, - 'format': 'fat32', - 'grub_device': True, - 'mount': '/boot/efi', - 'number': 1, - 'preserve': False, - 'wipe': 'superblock', + "$type": "Partition", + "boot": True, + "format": "fat32", + "grub_device": True, + "mount": "/boot/efi", + "number": 1, + "preserve": False, + "wipe": "superblock", } self.assertDictSubset(expected_p1, p1) expected_p2 = { - 'number': 2, - 'mount': '/', - 'format': 'ext4', - 'preserve': False, - 'wipe': 'superblock', + "number": 2, + "mount": "/", + "format": "ext4", + "preserve": False, + "wipe": "superblock", } self.assertDictSubset(expected_p2, p2) - v1resp = await inst.get('/storage') - parts = match(v1resp['config'], type='partition') + v1resp = await inst.get("/storage") + parts = match(v1resp["config"], type="partition") for p in parts: - self.assertFalse(p['preserve'], p) - self.assertEqual('superblock', p.get('wipe'), p) + self.assertFalse(p["preserve"], p) + self.assertEqual("superblock", p.get("wipe"), p) @timeout() async def test_guided_v2_resize(self): - cfg = 'examples/machines/win10-along-ubuntu.json' - extra = ['--storage-version', '2'] + cfg = "examples/machines/win10-along-ubuntu.json" + extra = ["--storage-version", "2"] async with start_server(cfg, extra_args=extra) as inst: - orig_resp = await inst.get('/storage/v2') - [orig_p1, orig_p2, orig_p3, orig_p4, orig_p5] = \ - orig_resp['disks'][0]['partitions'] - resp = await inst.get('/storage/v2/guided') + orig_resp = await inst.get("/storage/v2") + [orig_p1, orig_p2, orig_p3, orig_p4, orig_p5] = orig_resp["disks"][0][ + "partitions" + ] + resp = await inst.get("/storage/v2/guided") [resize_ntfs, resize_ext4] = match( - resp['targets'], _type='GuidedStorageTargetResize') - resize_ntfs['new_size'] = 30 << 30 + resp["targets"], _type="GuidedStorageTargetResize" + ) + resize_ntfs["new_size"] = 30 << 30 data = { - 'target': resize_ntfs, - 'capability': resize_ntfs['allowed'][0], - } - resp = await inst.post('/storage/v2/guided', data) - self.assertEqual(resize_ntfs, resp['configured']['target']) - resp = await inst.get('/storage/v2') - [p1, p2, p3, p6, p4, p5] = resp['disks'][0]['partitions'] + "target": resize_ntfs, + "capability": resize_ntfs["allowed"][0], + } + resp = await inst.post("/storage/v2/guided", data) + self.assertEqual(resize_ntfs, resp["configured"]["target"]) + resp = await inst.get("/storage/v2") + [p1, p2, p3, p6, p4, p5] = resp["disks"][0]["partitions"] expected_p1 = { - '$type': 'Partition', - 'boot': True, - 'format': 'vfat', - 'grub_device': True, - 'mount': '/boot/efi', - 'number': 1, - 'size': orig_p1['size'], - 'resize': None, - 'wipe': None, + "$type": "Partition", + "boot": True, + "format": "vfat", + "grub_device": True, + "mount": "/boot/efi", + "number": 1, + "size": orig_p1["size"], + "resize": None, + "wipe": None, } self.assertDictSubset(expected_p1, p1) self.assertEqual(orig_p2, p2) self.assertEqual(orig_p4, p4) self.assertEqual(orig_p5, p5) expected_p6 = { - 'number': 6, - 'mount': '/', - 'format': 'ext4', + "number": 6, + "mount": "/", + "format": "ext4", } self.assertDictSubset(expected_p6, p6) @timeout() async def test_guided_v2_use_gap(self): - cfg = self.machineConfig('examples/machines/win10-along-ubuntu.json') + cfg = self.machineConfig("examples/machines/win10-along-ubuntu.json") with cfg.edit() as data: - pt = data['storage']['blockdev']['/dev/sda']['partitiontable'] - [node] = match(pt['partitions'], node='/dev/sda5') - pt['partitions'].remove(node) - del data['storage']['blockdev']['/dev/sda5'] - del data['storage']['filesystem']['/dev/sda5'] - extra = ['--storage-version', '2'] + pt = data["storage"]["blockdev"]["/dev/sda"]["partitiontable"] + [node] = match(pt["partitions"], node="/dev/sda5") + pt["partitions"].remove(node) + del data["storage"]["blockdev"]["/dev/sda5"] + del data["storage"]["filesystem"]["/dev/sda5"] + extra = ["--storage-version", "2"] async with start_server(cfg, extra_args=extra) as inst: - orig_resp = await inst.get('/storage/v2') - [orig_p1, orig_p2, orig_p3, orig_p4, gap] = \ - orig_resp['disks'][0]['partitions'] - resp = await inst.get('/storage/v2/guided') - [use_gap] = match(resp['targets'], - _type='GuidedStorageTargetUseGap') + orig_resp = await inst.get("/storage/v2") + [orig_p1, orig_p2, orig_p3, orig_p4, gap] = orig_resp["disks"][0][ + "partitions" + ] + resp = await inst.get("/storage/v2/guided") + [use_gap] = match(resp["targets"], _type="GuidedStorageTargetUseGap") data = { - 'target': use_gap, - 'capability': use_gap['allowed'][0], - } - resp = await inst.post('/storage/v2/guided', data) - self.assertEqual(use_gap, resp['configured']['target']) - resp = await inst.get('/storage/v2') - [p1, p2, p3, p4, p5] = resp['disks'][0]['partitions'] + "target": use_gap, + "capability": use_gap["allowed"][0], + } + resp = await inst.post("/storage/v2/guided", data) + self.assertEqual(use_gap, resp["configured"]["target"]) + resp = await inst.get("/storage/v2") + [p1, p2, p3, p4, p5] = resp["disks"][0]["partitions"] expected_p1 = { - '$type': 'Partition', - 'boot': True, - 'format': 'vfat', - 'grub_device': True, - 'mount': '/boot/efi', - 'number': 1, - 'size': orig_p1['size'], - 'resize': None, - 'wipe': None, + "$type": "Partition", + "boot": True, + "format": "vfat", + "grub_device": True, + "mount": "/boot/efi", + "number": 1, + "size": orig_p1["size"], + "resize": None, + "wipe": None, } self.assertDictSubset(expected_p1, p1) self.assertEqual(orig_p2, p2) self.assertEqual(orig_p3, p3) self.assertEqual(orig_p4, p4) expected_p5 = { - 'number': 5, - 'mount': '/', - 'format': 'ext4', + "number": 5, + "mount": "/", + "format": "ext4", } self.assertDictSubset(expected_p5, p5) @timeout() async def test_guided_v2_resize_logical(self): - cfg = 'examples/machines/threebuntu-on-msdos.json' - extra = ['--storage-version', '2'] + cfg = "examples/machines/threebuntu-on-msdos.json" + extra = ["--storage-version", "2"] async with start_server(cfg, extra_args=extra) as inst: - resp = await inst.get('/storage/v2/guided') + resp = await inst.get("/storage/v2/guided") [resize] = match( - resp['targets'], _type='GuidedStorageTargetResize', - partition_number=6) + resp["targets"], _type="GuidedStorageTargetResize", partition_number=6 + ) data = { - 'target': resize, - 'capability': resize['allowed'][0], - } - resp = await inst.post('/storage/v2/guided', data) - self.assertEqual(resize, resp['configured']['target']) + "target": resize, + "capability": resize["allowed"][0], + } + resp = await inst.post("/storage/v2/guided", data) + self.assertEqual(resize, resp["configured"]["target"]) # should not throw a Gap Not Found exception class TestCore(TestAPI): @timeout() async def test_basic_core_boot(self): - cfg = self.machineConfig('examples/machines/simple.json') + cfg = self.machineConfig("examples/machines/simple.json") with cfg.edit() as data: - attrs = data['storage']['blockdev']['/dev/sda']['attrs'] - attrs['size'] = str(25 << 30) + attrs = data["storage"]["blockdev"]["/dev/sda"]["attrs"] + attrs["size"] = str(25 << 30) kw = dict( - bootloader='uefi', + bootloader="uefi", extra_args=[ - '--storage-version', '2', - '--source-catalog', 'examples/sources/install-canary.yaml', - '--dry-run-config', 'examples/dry-run-configs/tpm.yaml', - ] + "--storage-version", + "2", + "--source-catalog", + "examples/sources/install-canary.yaml", + "--dry-run-config", + "examples/dry-run-configs/tpm.yaml", + ], ) async with start_server(cfg, **kw) as inst: - await inst.post('/source', source_id='ubuntu-desktop') - resp = await inst.get('/storage/v2/guided', wait=True) - [reformat, manual] = resp['targets'] - self.assertIn('CORE_BOOT_PREFER_ENCRYPTED', - reformat['allowed']) - data = dict(target=reformat, capability='CORE_BOOT_ENCRYPTED') - await inst.post('/storage/v2/guided', data) - v2resp = await inst.get('/storage/v2') - [d] = v2resp['disks'] - pgs = d['partitions'] + await inst.post("/source", source_id="ubuntu-desktop") + resp = await inst.get("/storage/v2/guided", wait=True) + [reformat, manual] = resp["targets"] + self.assertIn("CORE_BOOT_PREFER_ENCRYPTED", reformat["allowed"]) + data = dict(target=reformat, capability="CORE_BOOT_ENCRYPTED") + await inst.post("/storage/v2/guided", data) + v2resp = await inst.get("/storage/v2") + [d] = v2resp["disks"] + pgs = d["partitions"] [p4] = match(pgs, number=4) # FIXME The current model has a ~13GiB gap between p1 and p2. # Presumably this will be removed later. - [p1, g1, p2, p3, p4] = d['partitions'] - e1 = dict(offset=1 << 20, mount='/boot/efi') + [p1, g1, p2, p3, p4] = d["partitions"] + e1 = dict(offset=1 << 20, mount="/boot/efi") self.assertDictSubset(e1, p1) - self.assertDictSubset(dict(mount='/boot'), p2) + self.assertDictSubset(dict(mount="/boot"), p2) self.assertDictSubset(dict(mount=None), p3) - self.assertDictSubset(dict(mount='/'), p4) + self.assertDictSubset(dict(mount="/"), p4) class TestAdd(TestAPI): @timeout() async def test_v2_add_boot_partition(self): - async with start_server('examples/machines/simple.json') as inst: - disk_id = 'disk-sda' + async with start_server("examples/machines/simple.json") as inst: + disk_id = "disk-sda" - resp = await inst.post('/storage/v2') - [sda] = match(resp['disks'], id=disk_id) - [gap] = match(sda['partitions'], _type='Gap') + resp = await inst.post("/storage/v2") + [sda] = match(resp["disks"], id=disk_id) + [gap] = match(sda["partitions"], _type="Gap") data = { - 'disk_id': disk_id, - 'gap': gap, - 'partition': { - 'format': 'ext4', - 'mount': '/', - } + "disk_id": disk_id, + "gap": gap, + "partition": { + "format": "ext4", + "mount": "/", + }, } - single_add = await inst.post('/storage/v2/add_partition', data) - self.assertEqual(2, len(single_add['disks'][0]['partitions'])) - self.assertTrue(single_add['disks'][0]['boot_device']) + single_add = await inst.post("/storage/v2/add_partition", data) + self.assertEqual(2, len(single_add["disks"][0]["partitions"])) + self.assertTrue(single_add["disks"][0]["boot_device"]) - await inst.post('/storage/v2/reset') + await inst.post("/storage/v2/reset") # these manual steps are expected to be mostly equivalent to just # adding the single partition and getting the automatic boot # partition - resp = await inst.post( - '/storage/v2/add_boot_partition', disk_id=disk_id) - [sda] = match(resp['disks'], id=disk_id) - [gap] = match(sda['partitions'], _type='Gap') + resp = await inst.post("/storage/v2/add_boot_partition", disk_id=disk_id) + [sda] = match(resp["disks"], id=disk_id) + [gap] = match(sda["partitions"], _type="Gap") data = { - 'disk_id': disk_id, - 'gap': gap, - 'partition': { - 'format': 'ext4', - 'mount': '/', - } + "disk_id": disk_id, + "gap": gap, + "partition": { + "format": "ext4", + "mount": "/", + }, } - manual_add = await inst.post('/storage/v2/add_partition', data) + manual_add = await inst.post("/storage/v2/add_partition", data) # the only difference is the partition number assigned - when we # explicitly add_boot_partition, that is the first partition # created, versus when we add_partition and get a boot partition # implicitly for resp in single_add, manual_add: - for part in resp['disks'][0]['partitions']: - part.pop('number') - part.pop('path') + for part in resp["disks"][0]["partitions"]: + part.pop("number") + part.pop("path") self.assertEqual(single_add, manual_add) @timeout() async def test_v2_deny_multiple_add_boot_partition(self): - async with start_server('examples/machines/simple.json') as inst: - disk_id = 'disk-sda' - await inst.post('/storage/v2/add_boot_partition', disk_id=disk_id) + async with start_server("examples/machines/simple.json") as inst: + disk_id = "disk-sda" + await inst.post("/storage/v2/add_boot_partition", disk_id=disk_id) with self.assertRaises(ClientResponseError): - await inst.post('/storage/v2/add_boot_partition', - disk_id=disk_id) + await inst.post("/storage/v2/add_boot_partition", disk_id=disk_id) @timeout() async def test_v2_deny_multiple_add_boot_partition_BIOS(self): - cfg = 'examples/machines/simple.json' - async with start_server(cfg, 'bios') as inst: - disk_id = 'disk-sda' - await inst.post('/storage/v2/add_boot_partition', disk_id=disk_id) + cfg = "examples/machines/simple.json" + async with start_server(cfg, "bios") as inst: + disk_id = "disk-sda" + await inst.post("/storage/v2/add_boot_partition", disk_id=disk_id) with self.assertRaises(ClientResponseError): - await inst.post('/storage/v2/add_boot_partition', - disk_id=disk_id) + await inst.post("/storage/v2/add_boot_partition", disk_id=disk_id) @timeout() async def test_add_format_required(self): - disk_id = 'disk-sda' - async with start_server('examples/machines/simple.json') as inst: + disk_id = "disk-sda" + async with start_server("examples/machines/simple.json") as inst: bad_partitions = [ {}, - {'mount': '/'}, + {"mount": "/"}, ] for partition in bad_partitions: - data = {'disk_id': disk_id, 'partition': partition} - with self.assertRaises(ClientResponseError, - msg=f'data {data}'): - await inst.post('/storage/v2/add_partition', data) + data = {"disk_id": disk_id, "partition": partition} + with self.assertRaises(ClientResponseError, msg=f"data {data}"): + await inst.post("/storage/v2/add_partition", data) @timeout() async def test_add_default_size_handling(self): - async with start_server('examples/machines/simple.json') as inst: - disk_id = 'disk-sda' - resp = await inst.get('/storage/v2') - [sda] = match(resp['disks'], id=disk_id) - [gap] = sda['partitions'] + async with start_server("examples/machines/simple.json") as inst: + disk_id = "disk-sda" + resp = await inst.get("/storage/v2") + [sda] = match(resp["disks"], id=disk_id) + [gap] = sda["partitions"] data = { - 'disk_id': disk_id, - 'gap': gap, - 'partition': { - 'format': 'ext4', - 'mount': '/', - } + "disk_id": disk_id, + "gap": gap, + "partition": { + "format": "ext4", + "mount": "/", + }, } - resp = await inst.post('/storage/v2/add_partition', data) - [sda] = match(resp['disks'], id=disk_id) - [sda1, sda2] = sda['partitions'] - self.assertEqual(gap['size'], sda1['size'] + sda2['size']) + resp = await inst.post("/storage/v2/add_partition", data) + [sda] = match(resp["disks"], id=disk_id) + [sda1, sda2] = sda["partitions"] + self.assertEqual(gap["size"], sda1["size"] + sda2["size"]) @timeout() async def test_v2_add_boot_BIOS(self): - cfg = 'examples/machines/simple.json' - async with start_server(cfg, 'bios') as inst: - disk_id = 'disk-sda' - resp = await inst.post('/storage/v2/add_boot_partition', - disk_id=disk_id) - [sda] = match(resp['disks'], id=disk_id) - [sda1] = match(sda['partitions'], number=1) - self.assertTrue(sda['boot_device']) - self.assertTrue(sda1['boot']) + cfg = "examples/machines/simple.json" + async with start_server(cfg, "bios") as inst: + disk_id = "disk-sda" + resp = await inst.post("/storage/v2/add_boot_partition", disk_id=disk_id) + [sda] = match(resp["disks"], id=disk_id) + [sda1] = match(sda["partitions"], number=1) + self.assertTrue(sda["boot_device"]) + self.assertTrue(sda1["boot"]) @timeout() async def test_v2_blank_is_not_boot(self): - cfg = 'examples/machines/simple.json' - async with start_server(cfg, 'bios') as inst: - disk_id = 'disk-sda' - resp = await inst.get('/storage/v2') - [sda] = match(resp['disks'], id=disk_id) - self.assertFalse(sda['boot_device']) + cfg = "examples/machines/simple.json" + async with start_server(cfg, "bios") as inst: + disk_id = "disk-sda" + resp = await inst.get("/storage/v2") + [sda] = match(resp["disks"], id=disk_id) + self.assertFalse(sda["boot_device"]) @timeout() async def test_v2_multi_disk_multi_boot(self): - cfg = 'examples/machines/many-nics-and-disks.json' + cfg = "examples/machines/many-nics-and-disks.json" async with start_server(cfg) as inst: - resp = await inst.get('/storage/v2') - [d1] = match(resp['disks'], id='disk-vda') - [d2] = match(resp['disks'], id='disk-vdb') - await inst.post('/storage/v2/reformat_disk', {'disk_id': d1['id']}) - await inst.post('/storage/v2/reformat_disk', {'disk_id': d2['id']}) - await inst.post('/storage/v2/add_boot_partition', disk_id=d1['id']) - await inst.post('/storage/v2/add_boot_partition', disk_id=d2['id']) + resp = await inst.get("/storage/v2") + [d1] = match(resp["disks"], id="disk-vda") + [d2] = match(resp["disks"], id="disk-vdb") + await inst.post("/storage/v2/reformat_disk", {"disk_id": d1["id"]}) + await inst.post("/storage/v2/reformat_disk", {"disk_id": d2["id"]}) + await inst.post("/storage/v2/add_boot_partition", disk_id=d1["id"]) + await inst.post("/storage/v2/add_boot_partition", disk_id=d2["id"]) # should allow both disks to get a boot partition with no Exception class TestDelete(TestAPI): @timeout() async def test_v2_delete_without_reformat(self): - cfg = 'examples/machines/win10.json' - extra = ['--storage-version', '1'] + cfg = "examples/machines/win10.json" + extra = ["--storage-version", "1"] async with start_server(cfg, extra_args=extra) as inst: - data = { - 'disk_id': 'disk-sda', - 'partition': {'number': 1} - } + data = {"disk_id": "disk-sda", "partition": {"number": 1}} with self.assertRaises(ClientResponseError): - await inst.post('/storage/v2/delete_partition', data) + await inst.post("/storage/v2/delete_partition", data) @timeout() async def test_v2_delete_without_reformat_is_ok_with_sv2(self): - cfg = 'examples/machines/win10.json' - extra = ['--storage-version', '2'] + cfg = "examples/machines/win10.json" + extra = ["--storage-version", "2"] async with start_server(cfg, extra_args=extra) as inst: - data = { - 'disk_id': 'disk-sda', - 'partition': {'number': 1} - } - await inst.post('/storage/v2/delete_partition', data) + data = {"disk_id": "disk-sda", "partition": {"number": 1}} + await inst.post("/storage/v2/delete_partition", data) @timeout() async def test_v2_delete_with_reformat(self): - async with start_server('examples/machines/win10.json') as inst: - disk_id = 'disk-sda' - resp = await inst.post('/storage/v2/reformat_disk', - {'disk_id': disk_id}) - [sda] = resp['disks'] - [gap] = sda['partitions'] + async with start_server("examples/machines/win10.json") as inst: + disk_id = "disk-sda" + resp = await inst.post("/storage/v2/reformat_disk", {"disk_id": disk_id}) + [sda] = resp["disks"] + [gap] = sda["partitions"] data = { - 'disk_id': disk_id, - 'gap': gap, - 'partition': { - 'mount': '/', - 'format': 'ext4', - } + "disk_id": disk_id, + "gap": gap, + "partition": { + "mount": "/", + "format": "ext4", + }, } - await inst.post('/storage/v2/add_partition', data) - data = { - 'disk_id': disk_id, - 'partition': {'number': 1} - } - await inst.post('/storage/v2/delete_partition', data) + await inst.post("/storage/v2/add_partition", data) + data = {"disk_id": disk_id, "partition": {"number": 1}} + await inst.post("/storage/v2/delete_partition", data) @timeout() async def test_delete_nonexistant(self): - async with start_server('examples/machines/win10.json') as inst: - disk_id = 'disk-sda' - await inst.post('/storage/v2/reformat_disk', {'disk_id': disk_id}) - data = { - 'disk_id': disk_id, - 'partition': {'number': 1} - } + async with start_server("examples/machines/win10.json") as inst: + disk_id = "disk-sda" + await inst.post("/storage/v2/reformat_disk", {"disk_id": disk_id}) + data = {"disk_id": disk_id, "partition": {"number": 1}} with self.assertRaises(ClientResponseError): - await inst.post('/storage/v2/delete_partition', data) + await inst.post("/storage/v2/delete_partition", data) class TestEdit(TestAPI): @timeout() async def test_edit_no_change_size(self): - async with start_server('examples/machines/win10.json') as inst: - disk_id = 'disk-sda' - resp = await inst.get('/storage/v2') + async with start_server("examples/machines/win10.json") as inst: + disk_id = "disk-sda" + resp = await inst.get("/storage/v2") - [sda] = match(resp['disks'], id=disk_id) - [sda3] = match(sda['partitions'], number=3) + [sda] = match(resp["disks"], id=disk_id) + [sda3] = match(sda["partitions"], number=3) data = { - 'disk_id': disk_id, - 'partition': { - 'number': 3, - 'size': sda3['size'] - (1 << 30) - } + "disk_id": disk_id, + "partition": {"number": 3, "size": sda3["size"] - (1 << 30)}, } with self.assertRaises(ClientResponseError): - await inst.post('/storage/v2/edit_partition', data) + await inst.post("/storage/v2/edit_partition", data) @timeout() async def test_edit_no_change_grub(self): - async with start_server('examples/machines/win10.json') as inst: - disk_id = 'disk-sda' + async with start_server("examples/machines/win10.json") as inst: + disk_id = "disk-sda" data = { - 'disk_id': disk_id, - 'partition': { - 'number': 3, - 'boot': True, - } + "disk_id": disk_id, + "partition": { + "number": 3, + "boot": True, + }, } with self.assertRaises(ClientResponseError): - await inst.post('/storage/v2/edit_partition', data) + await inst.post("/storage/v2/edit_partition", data) @timeout() async def test_edit_format(self): - async with start_server('examples/machines/win10.json') as inst: - disk_id = 'disk-sda' + async with start_server("examples/machines/win10.json") as inst: + disk_id = "disk-sda" data = { - 'disk_id': disk_id, - 'partition': { - 'number': 3, - 'format': 'btrfs', - 'wipe': 'superblock', - } + "disk_id": disk_id, + "partition": { + "number": 3, + "format": "btrfs", + "wipe": "superblock", + }, } - resp = await inst.post('/storage/v2/edit_partition', data) + resp = await inst.post("/storage/v2/edit_partition", data) - [sda] = match(resp['disks'], id=disk_id) - [sda3] = match(sda['partitions'], number=3) - self.assertEqual('btrfs', sda3['format']) + [sda] = match(resp["disks"], id=disk_id) + [sda3] = match(sda["partitions"], number=3) + self.assertEqual("btrfs", sda3["format"]) @timeout() async def test_edit_mount(self): - async with start_server('examples/machines/win10.json') as inst: - disk_id = 'disk-sda' + async with start_server("examples/machines/win10.json") as inst: + disk_id = "disk-sda" data = { - 'disk_id': disk_id, - 'partition': { - 'number': 3, - 'mount': '/', - } + "disk_id": disk_id, + "partition": { + "number": 3, + "mount": "/", + }, } - resp = await inst.post('/storage/v2/edit_partition', data) + resp = await inst.post("/storage/v2/edit_partition", data) - [sda] = match(resp['disks'], id=disk_id) - [sda3] = match(sda['partitions'], number=3) - self.assertEqual('/', sda3['mount']) + [sda] = match(resp["disks"], id=disk_id) + [sda3] = match(sda["partitions"], number=3) + self.assertEqual("/", sda3["mount"]) @timeout() async def test_edit_format_and_mount(self): - async with start_server('examples/machines/win10.json') as inst: - disk_id = 'disk-sda' + async with start_server("examples/machines/win10.json") as inst: + disk_id = "disk-sda" data = { - 'disk_id': disk_id, - 'partition': { - 'number': 3, - 'format': 'btrfs', - 'mount': '/', - 'wipe': 'superblock', - } + "disk_id": disk_id, + "partition": { + "number": 3, + "format": "btrfs", + "mount": "/", + "wipe": "superblock", + }, } - resp = await inst.post('/storage/v2/edit_partition', data) + resp = await inst.post("/storage/v2/edit_partition", data) - [sda] = match(resp['disks'], id=disk_id) - [sda3] = match(sda['partitions'], number=3) - self.assertEqual('btrfs', sda3['format']) - self.assertEqual('/', sda3['mount']) + [sda] = match(resp["disks"], id=disk_id) + [sda3] = match(sda["partitions"], number=3) + self.assertEqual("btrfs", sda3["format"]) + self.assertEqual("/", sda3["mount"]) @timeout() async def test_v2_reuse(self): - async with start_server('examples/machines/win10.json') as inst: - disk_id = 'disk-sda' - resp = await inst.get('/storage/v2') - [orig_sda] = match(resp['disks'], id=disk_id) - [_, orig_sda2, _, orig_sda4] = orig_sda['partitions'] + async with start_server("examples/machines/win10.json") as inst: + disk_id = "disk-sda" + resp = await inst.get("/storage/v2") + [orig_sda] = match(resp["disks"], id=disk_id) + [_, orig_sda2, _, orig_sda4] = orig_sda["partitions"] data = { - 'disk_id': disk_id, - 'partition': { - 'number': 3, - 'format': 'ext4', - 'mount': '/', - 'wipe': 'superblock', - } + "disk_id": disk_id, + "partition": { + "number": 3, + "format": "ext4", + "mount": "/", + "wipe": "superblock", + }, } - resp = await inst.post('/storage/v2/edit_partition', data) - [sda] = match(resp['disks'], id=disk_id) - [sda1, sda2, sda3, sda4] = sda['partitions'] - self.assertIsNone(sda1['wipe']) - self.assertEqual('/boot/efi', sda1['mount']) - self.assertEqual('vfat', sda1['format']) - self.assertTrue(sda1['boot']) + resp = await inst.post("/storage/v2/edit_partition", data) + [sda] = match(resp["disks"], id=disk_id) + [sda1, sda2, sda3, sda4] = sda["partitions"] + self.assertIsNone(sda1["wipe"]) + self.assertEqual("/boot/efi", sda1["mount"]) + self.assertEqual("vfat", sda1["format"]) + self.assertTrue(sda1["boot"]) self.assertEqual(orig_sda2, sda2) - self.assertIsNotNone(sda3['wipe']) - self.assertEqual('/', sda3['mount']) - self.assertEqual('ext4', sda3['format']) - self.assertFalse(sda3['boot']) + self.assertIsNotNone(sda3["wipe"]) + self.assertEqual("/", sda3["mount"]) + self.assertEqual("ext4", sda3["format"]) + self.assertFalse(sda3["boot"]) self.assertEqual(orig_sda4, sda4) @@ -956,317 +951,312 @@ class TestEdit(TestAPI): class TestReformat(TestAPI): @timeout() async def test_reformat_msdos(self): - cfg = 'examples/machines/simple.json' + cfg = "examples/machines/simple.json" async with start_server(cfg) as inst: data = { - 'disk_id': 'disk-sda', - 'ptable': 'msdos', + "disk_id": "disk-sda", + "ptable": "msdos", } - resp = await inst.post('/storage/v2/reformat_disk', data) - [sda] = resp['disks'] - self.assertEqual('msdos', sda['ptable']) + resp = await inst.post("/storage/v2/reformat_disk", data) + [sda] = resp["disks"] + self.assertEqual("msdos", sda["ptable"]) class TestPartitionTableTypes(TestAPI): @timeout() async def test_ptable_gpt(self): - async with start_server('examples/machines/win10.json') as inst: - resp = await inst.get('/storage/v2') - [sda] = match(resp['disks'], id='disk-sda') - self.assertEqual('gpt', sda['ptable']) + async with start_server("examples/machines/win10.json") as inst: + resp = await inst.get("/storage/v2") + [sda] = match(resp["disks"], id="disk-sda") + self.assertEqual("gpt", sda["ptable"]) @timeout() async def test_ptable_msdos(self): - cfg = 'examples/machines/many-nics-and-disks.json' + cfg = "examples/machines/many-nics-and-disks.json" async with start_server(cfg) as inst: - resp = await inst.get('/storage/v2') - [sda] = match(resp['disks'], id='disk-sda') - self.assertEqual('msdos', sda['ptable']) + resp = await inst.get("/storage/v2") + [sda] = match(resp["disks"], id="disk-sda") + self.assertEqual("msdos", sda["ptable"]) @timeout() async def test_ptable_none(self): - async with start_server('examples/machines/simple.json') as inst: - resp = await inst.get('/storage/v2') - [sda] = match(resp['disks'], id='disk-sda') - self.assertEqual(None, sda['ptable']) + async with start_server("examples/machines/simple.json") as inst: + resp = await inst.get("/storage/v2") + [sda] = match(resp["disks"], id="disk-sda") + self.assertEqual(None, sda["ptable"]) class TestTodos(TestAPI): # server indicators of required client actions @timeout() async def test_todos_simple(self): - async with start_server('examples/machines/simple.json') as inst: - disk_id = 'disk-sda' - resp = await inst.post('/storage/v2/reformat_disk', - {'disk_id': disk_id}) - self.assertTrue(resp['need_root']) - self.assertTrue(resp['need_boot']) + async with start_server("examples/machines/simple.json") as inst: + disk_id = "disk-sda" + resp = await inst.post("/storage/v2/reformat_disk", {"disk_id": disk_id}) + self.assertTrue(resp["need_root"]) + self.assertTrue(resp["need_boot"]) - [sda] = resp['disks'] - [gap] = sda['partitions'] + [sda] = resp["disks"] + [gap] = sda["partitions"] data = { - 'disk_id': disk_id, - 'gap': gap, - 'partition': { - 'format': 'ext4', - 'mount': '/', - } + "disk_id": disk_id, + "gap": gap, + "partition": { + "format": "ext4", + "mount": "/", + }, } - resp = await inst.post('/storage/v2/add_partition', data) - self.assertFalse(resp['need_root']) - self.assertFalse(resp['need_boot']) + resp = await inst.post("/storage/v2/add_partition", data) + self.assertFalse(resp["need_root"]) + self.assertFalse(resp["need_boot"]) @timeout() async def test_todos_manual(self): - async with start_server('examples/machines/simple.json') as inst: - disk_id = 'disk-sda' - resp = await inst.post('/storage/v2/reformat_disk', - {'disk_id': disk_id}) - self.assertTrue(resp['need_root']) - self.assertTrue(resp['need_boot']) + async with start_server("examples/machines/simple.json") as inst: + disk_id = "disk-sda" + resp = await inst.post("/storage/v2/reformat_disk", {"disk_id": disk_id}) + self.assertTrue(resp["need_root"]) + self.assertTrue(resp["need_boot"]) - resp = await inst.post('/storage/v2/add_boot_partition', - disk_id=disk_id) - self.assertTrue(resp['need_root']) - self.assertFalse(resp['need_boot']) + resp = await inst.post("/storage/v2/add_boot_partition", disk_id=disk_id) + self.assertTrue(resp["need_root"]) + self.assertFalse(resp["need_boot"]) - [sda] = resp['disks'] - [gap] = match(sda['partitions'], _type='Gap') + [sda] = resp["disks"] + [gap] = match(sda["partitions"], _type="Gap") data = { - 'disk_id': disk_id, - 'gap': gap, - 'partition': { - 'format': 'ext4', - 'mount': '/', - } + "disk_id": disk_id, + "gap": gap, + "partition": { + "format": "ext4", + "mount": "/", + }, } - resp = await inst.post('/storage/v2/add_partition', data) - self.assertFalse(resp['need_root']) - self.assertFalse(resp['need_boot']) + resp = await inst.post("/storage/v2/add_partition", data) + self.assertFalse(resp["need_root"]) + self.assertFalse(resp["need_boot"]) @timeout() async def test_todos_guided(self): - async with start_server('examples/machines/simple.json') as inst: - resp = await inst.post('/storage/v2/reformat_disk', - {'disk_id': 'disk-sda'}) - self.assertTrue(resp['need_root']) - self.assertTrue(resp['need_boot']) + async with start_server("examples/machines/simple.json") as inst: + resp = await inst.post("/storage/v2/reformat_disk", {"disk_id": "disk-sda"}) + self.assertTrue(resp["need_root"]) + self.assertTrue(resp["need_boot"]) - resp = await inst.get('/storage/v2/guided') - [reformat, manual] = resp['targets'] + resp = await inst.get("/storage/v2/guided") + [reformat, manual] = resp["targets"] data = { - 'target': reformat, - 'capability': reformat['allowed'][0], - } - await inst.post('/storage/v2/guided', data) - resp = await inst.get('/storage/v2') - self.assertFalse(resp['need_root']) - self.assertFalse(resp['need_boot']) + "target": reformat, + "capability": reformat["allowed"][0], + } + await inst.post("/storage/v2/guided", data) + resp = await inst.get("/storage/v2") + self.assertFalse(resp["need_root"]) + self.assertFalse(resp["need_boot"]) class TestInfo(TestAPI): @timeout() async def test_path(self): - async with start_server('examples/machines/simple.json') as inst: - disk_id = 'disk-sda' - resp = await inst.get('/storage/v2') - [sda] = match(resp['disks'], id=disk_id) - self.assertEqual('/dev/sda', sda['path']) + async with start_server("examples/machines/simple.json") as inst: + disk_id = "disk-sda" + resp = await inst.get("/storage/v2") + [sda] = match(resp["disks"], id=disk_id) + self.assertEqual("/dev/sda", sda["path"]) @timeout() async def test_model_and_vendor(self): - async with start_server('examples/machines/simple.json') as inst: - disk_id = 'disk-sda' - resp = await inst.get('/storage/v2') - [sda] = match(resp['disks'], id=disk_id) - self.assertEqual('QEMU HARDDISK', sda['model']) - self.assertEqual('ATA', sda['vendor']) + async with start_server("examples/machines/simple.json") as inst: + disk_id = "disk-sda" + resp = await inst.get("/storage/v2") + [sda] = match(resp["disks"], id=disk_id) + self.assertEqual("QEMU HARDDISK", sda["model"]) + self.assertEqual("ATA", sda["vendor"]) @timeout() async def test_no_vendor(self): - cfg = 'examples/machines/many-nics-and-disks.json' + cfg = "examples/machines/many-nics-and-disks.json" async with start_server(cfg) as inst: - disk_id = 'disk-sda' - resp = await inst.get('/storage/v2') - [sda] = match(resp['disks'], id=disk_id) - self.assertEqual('QEMU HARDDISK', sda['model']) - self.assertEqual(None, sda['vendor']) + disk_id = "disk-sda" + resp = await inst.get("/storage/v2") + [sda] = match(resp["disks"], id=disk_id) + self.assertEqual("QEMU HARDDISK", sda["model"]) + self.assertEqual(None, sda["vendor"]) class TestFree(TestAPI): @timeout() async def test_free_only(self): - async with start_server('examples/machines/simple.json') as inst: - await inst.post('/meta/free_only', enable=True) - components = await inst.get('/mirror/disable_components') + async with start_server("examples/machines/simple.json") as inst: + await inst.post("/meta/free_only", enable=True) + components = await inst.get("/mirror/disable_components") components.sort() - self.assertEqual(['multiverse', 'restricted'], components) + self.assertEqual(["multiverse", "restricted"], components) @timeout() async def test_not_free_only(self): - async with start_server('examples/machines/simple.json') as inst: - comps = ['universe', 'multiverse'] - await inst.post('/mirror/disable_components', comps) - await inst.post('/meta/free_only', enable=False) - components = await inst.get('/mirror/disable_components') - self.assertEqual(['universe'], components) + async with start_server("examples/machines/simple.json") as inst: + comps = ["universe", "multiverse"] + await inst.post("/mirror/disable_components", comps) + await inst.post("/meta/free_only", enable=False) + components = await inst.get("/mirror/disable_components") + self.assertEqual(["universe"], components) class TestOSProbe(TestAPI): @timeout() async def test_win10(self): - async with start_server('examples/machines/win10.json') as inst: - resp = await inst.get('/storage/v2') - [sda] = match(resp['disks'], id='disk-sda') - [sda1] = match(sda['partitions'], number=1) + async with start_server("examples/machines/win10.json") as inst: + resp = await inst.get("/storage/v2") + [sda] = match(resp["disks"], id="disk-sda") + [sda1] = match(sda["partitions"], number=1) expected = { - 'label': 'Windows', - 'long': 'Windows Boot Manager', - 'subpath': '/efi/Microsoft/Boot/bootmgfw.efi', - 'type': 'efi', - 'version': None + "label": "Windows", + "long": "Windows Boot Manager", + "subpath": "/efi/Microsoft/Boot/bootmgfw.efi", + "type": "efi", + "version": None, } - self.assertEqual(expected, sda1['os']) + self.assertEqual(expected, sda1["os"]) class TestPartitionTableEditing(TestAPI): @timeout() async def test_use_free_space_after_existing(self): - cfg = 'examples/machines/ubuntu-and-free-space.json' - extra = ['--storage-version', '2'] + cfg = "examples/machines/ubuntu-and-free-space.json" + extra = ["--storage-version", "2"] async with start_server(cfg, extra_args=extra) as inst: # Disk has 3 existing partitions and free space. Add one to end. # sda1 is an ESP, so that should get implicitly picked up. - resp = await inst.get('/storage/v2') - [sda] = resp['disks'] - [e1, e2, e3, gap] = sda['partitions'] - self.assertEqual('Gap', gap['$type']) + resp = await inst.get("/storage/v2") + [sda] = resp["disks"] + [e1, e2, e3, gap] = sda["partitions"] + self.assertEqual("Gap", gap["$type"]) data = { - 'disk_id': 'disk-sda', - 'gap': gap, - 'partition': { - 'format': 'ext4', - 'mount': '/', - } + "disk_id": "disk-sda", + "gap": gap, + "partition": { + "format": "ext4", + "mount": "/", + }, } - resp = await inst.post('/storage/v2/add_partition', data) - [sda] = resp['disks'] - [p1, p2, p3, p4] = sda['partitions'] - e1.pop('annotations') - e1.update({'mount': '/boot/efi', 'grub_device': True}) + resp = await inst.post("/storage/v2/add_partition", data) + [sda] = resp["disks"] + [p1, p2, p3, p4] = sda["partitions"] + e1.pop("annotations") + e1.update({"mount": "/boot/efi", "grub_device": True}) self.assertDictSubset(e1, p1) self.assertEqual(e2, p2) self.assertEqual(e3, p3) e4 = { - '$type': 'Partition', - 'number': 4, - 'size': gap['size'], - 'offset': gap['offset'], - 'format': 'ext4', - 'mount': '/', + "$type": "Partition", + "number": 4, + "size": gap["size"], + "offset": gap["offset"], + "format": "ext4", + "mount": "/", } self.assertDictSubset(e4, p4) @timeout() async def test_resize(self): - cfg = self.machineConfig( - 'examples/machines/ubuntu-and-free-space.json') + cfg = self.machineConfig("examples/machines/ubuntu-and-free-space.json") with cfg.edit() as data: - blockdev = data['storage']['blockdev'] - sizes = {k: int(v['attrs']['size']) for k, v in blockdev.items()} + blockdev = data["storage"]["blockdev"] + sizes = {k: int(v["attrs"]["size"]) for k, v in blockdev.items()} # expand sda3 to use the rest of the disk - sda3_size = (sizes['/dev/sda'] - sizes['/dev/sda1'] - - sizes['/dev/sda2'] - (2 << 20)) - blockdev['/dev/sda3']['attrs']['size'] = str(sda3_size) + sda3_size = ( + sizes["/dev/sda"] - sizes["/dev/sda1"] - sizes["/dev/sda2"] - (2 << 20) + ) + blockdev["/dev/sda3"]["attrs"]["size"] = str(sda3_size) - extra = ['--storage-version', '2'] + extra = ["--storage-version", "2"] async with start_server(cfg, extra_args=extra) as inst: # Disk has 3 existing partitions and no free space. - resp = await inst.get('/storage/v2') - [sda] = resp['disks'] - [orig_p1, orig_p2, orig_p3] = sda['partitions'] + resp = await inst.get("/storage/v2") + [sda] = resp["disks"] + [orig_p1, orig_p2, orig_p3] = sda["partitions"] p3 = orig_p3.copy() - p3['size'] = 10 << 30 + p3["size"] = 10 << 30 data = { - 'disk_id': 'disk-sda', - 'partition': p3, + "disk_id": "disk-sda", + "partition": p3, } - resp = await inst.post('/storage/v2/edit_partition', data) - [sda] = resp['disks'] - [_, _, actual_p3, g1] = sda['partitions'] - self.assertEqual(10 << 30, actual_p3['size']) - self.assertEqual(True, actual_p3['resize']) - self.assertIsNone(actual_p3['wipe']) - end_size = orig_p3['size'] - (10 << 30) - self.assertEqual(end_size, g1['size']) + resp = await inst.post("/storage/v2/edit_partition", data) + [sda] = resp["disks"] + [_, _, actual_p3, g1] = sda["partitions"] + self.assertEqual(10 << 30, actual_p3["size"]) + self.assertEqual(True, actual_p3["resize"]) + self.assertIsNone(actual_p3["wipe"]) + end_size = orig_p3["size"] - (10 << 30) + self.assertEqual(end_size, g1["size"]) expected_p1 = orig_p1.copy() - expected_p1.pop('annotations') - expected_p1.update({'mount': '/boot/efi', 'grub_device': True}) + expected_p1.pop("annotations") + expected_p1.update({"mount": "/boot/efi", "grub_device": True}) expected_p3 = actual_p3 data = { - 'disk_id': 'disk-sda', - 'gap': g1, - 'partition': { - 'format': 'ext4', - 'mount': '/srv', - } + "disk_id": "disk-sda", + "gap": g1, + "partition": { + "format": "ext4", + "mount": "/srv", + }, } - resp = await inst.post('/storage/v2/add_partition', data) - [sda] = resp['disks'] - [actual_p1, actual_p2, actual_p3, actual_p4] = sda['partitions'] + resp = await inst.post("/storage/v2/add_partition", data) + [sda] = resp["disks"] + [actual_p1, actual_p2, actual_p3, actual_p4] = sda["partitions"] self.assertDictSubset(expected_p1, actual_p1) self.assertEqual(orig_p2, actual_p2) self.assertEqual(expected_p3, actual_p3) - self.assertEqual(end_size, actual_p4['size']) - self.assertEqual('Partition', actual_p4['$type']) + self.assertEqual(end_size, actual_p4["size"]) + self.assertEqual("Partition", actual_p4["$type"]) - v1resp = await inst.get('/storage') - config = v1resp['config'] - [sda3] = match(config, type='partition', number=3) - [sda3_format] = match(config, type='format', volume=sda3['id']) - self.assertTrue(sda3['preserve']) - self.assertTrue(sda3['resize']) - self.assertTrue(sda3_format['preserve']) + v1resp = await inst.get("/storage") + config = v1resp["config"] + [sda3] = match(config, type="partition", number=3) + [sda3_format] = match(config, type="format", volume=sda3["id"]) + self.assertTrue(sda3["preserve"]) + self.assertTrue(sda3["resize"]) + self.assertTrue(sda3_format["preserve"]) @timeout() async def test_est_min_size(self): - cfg = self.machineConfig('examples/machines/win10-along-ubuntu.json') + cfg = self.machineConfig("examples/machines/win10-along-ubuntu.json") with cfg.edit() as data: - fs = data['storage']['filesystem'] - fs['/dev/sda1']['ESTIMATED_MIN_SIZE'] = 0 + fs = data["storage"]["filesystem"] + fs["/dev/sda1"]["ESTIMATED_MIN_SIZE"] = 0 # data file has no sda2 in filesystem - fs['/dev/sda3']['ESTIMATED_MIN_SIZE'] = -1 - fs['/dev/sda4']['ESTIMATED_MIN_SIZE'] = (1 << 20) + 1 + fs["/dev/sda3"]["ESTIMATED_MIN_SIZE"] = -1 + fs["/dev/sda4"]["ESTIMATED_MIN_SIZE"] = (1 << 20) + 1 - extra = ['--storage-version', '2'] + extra = ["--storage-version", "2"] async with start_server(cfg, extra_args=extra) as inst: - resp = await inst.get('/storage/v2') - [sda] = resp['disks'] - [p1, _, p3, p4, _] = sda['partitions'] - self.assertEqual(1 << 20, p1['estimated_min_size']) - self.assertEqual(-1, p3['estimated_min_size']) - self.assertEqual(2 << 20, p4['estimated_min_size']) + resp = await inst.get("/storage/v2") + [sda] = resp["disks"] + [p1, _, p3, p4, _] = sda["partitions"] + self.assertEqual(1 << 20, p1["estimated_min_size"]) + self.assertEqual(-1, p3["estimated_min_size"]) + self.assertEqual(2 << 20, p4["estimated_min_size"]) @timeout() async def test_v2_orig_config(self): - cfg = 'examples/machines/win10-along-ubuntu.json' - extra = ['--storage-version', '2'] + cfg = "examples/machines/win10-along-ubuntu.json" + extra = ["--storage-version", "2"] async with start_server(cfg, extra_args=extra) as inst: - start_resp = await inst.get('/storage/v2') - resp = await inst.get('/storage/v2/guided') - resize = match(resp['targets'], - _type='GuidedStorageTargetResize')[0] - resize['new_size'] = 30 << 30 + start_resp = await inst.get("/storage/v2") + resp = await inst.get("/storage/v2/guided") + resize = match(resp["targets"], _type="GuidedStorageTargetResize")[0] + resize["new_size"] = 30 << 30 data = { - 'target': resize, - 'capability': resize['allowed'][0], - } - await inst.post('/storage/v2/guided', data) - orig_config = await inst.get('/storage/v2/orig_config') - end_resp = await inst.get('/storage/v2') + "target": resize, + "capability": resize["allowed"][0], + } + await inst.post("/storage/v2/guided", data) + orig_config = await inst.get("/storage/v2/orig_config") + end_resp = await inst.get("/storage/v2") self.assertEqual(start_resp, orig_config) self.assertNotEqual(start_resp, end_resp) @@ -1274,121 +1264,119 @@ class TestPartitionTableEditing(TestAPI): class TestGap(TestAPI): @timeout() async def test_blank_disk_is_one_big_gap(self): - async with start_server('examples/machines/simple.json') as inst: - resp = await inst.get('/storage/v2') - [sda] = match(resp['disks'], id='disk-sda') - gap = sda['partitions'][0] + async with start_server("examples/machines/simple.json") as inst: + resp = await inst.get("/storage/v2") + [sda] = match(resp["disks"], id="disk-sda") + gap = sda["partitions"][0] expected = (100 << 30) - (2 << 20) - self.assertEqual(expected, gap['size']) - self.assertEqual('YES', gap['usable']) + self.assertEqual(expected, gap["size"]) + self.assertEqual("YES", gap["usable"]) @timeout() async def test_gap_at_end(self): - async with start_server('examples/machines/simple.json') as inst: - resp = await inst.get('/storage/v2') - [sda] = resp['disks'] - [gap] = match(sda['partitions'], _type='Gap') + async with start_server("examples/machines/simple.json") as inst: + resp = await inst.get("/storage/v2") + [sda] = resp["disks"] + [gap] = match(sda["partitions"], _type="Gap") data = { - 'disk_id': 'disk-sda', - 'gap': gap, - 'partition': { - 'format': 'ext4', - 'mount': '/', - 'size': 4 << 30, - } + "disk_id": "disk-sda", + "gap": gap, + "partition": { + "format": "ext4", + "mount": "/", + "size": 4 << 30, + }, } - resp = await inst.post('/storage/v2/add_partition', data) - [sda] = resp['disks'] - [boot] = match(sda['partitions'], mount='/boot/efi') - [p1, p2, gap] = sda['partitions'] - self.assertEqual('Gap', gap['$type']) - expected = (100 << 30) - p1['size'] - p2['size'] - (2 << 20) - self.assertEqual(expected, gap['size']) + resp = await inst.post("/storage/v2/add_partition", data) + [sda] = resp["disks"] + [boot] = match(sda["partitions"], mount="/boot/efi") + [p1, p2, gap] = sda["partitions"] + self.assertEqual("Gap", gap["$type"]) + expected = (100 << 30) - p1["size"] - p2["size"] - (2 << 20) + self.assertEqual(expected, gap["size"]) @timeout() async def SKIP_test_two_gaps(self): - async with start_server('examples/machines/simple.json') as inst: - disk_id = 'disk-sda' - resp = await inst.post('/storage/v2/add_boot_partition', - disk_id=disk_id) + async with start_server("examples/machines/simple.json") as inst: + disk_id = "disk-sda" + resp = await inst.post("/storage/v2/add_boot_partition", disk_id=disk_id) json_print(resp) - boot_size = resp['disks'][0]['partitions'][0]['size'] + boot_size = resp["disks"][0]["partitions"][0]["size"] root_size = 4 << 30 data = { - 'disk_id': disk_id, - 'partition': { - 'format': 'ext4', - 'mount': '/', - 'size': root_size, - } + "disk_id": disk_id, + "partition": { + "format": "ext4", + "mount": "/", + "size": root_size, + }, } - await inst.post('/storage/v2/add_partition', data) - data = { - 'disk_id': disk_id, - 'partition': {'number': 1} - } - resp = await inst.post('/storage/v2/delete_partition', data) - [sda] = match(resp['disks'], id=disk_id) - self.assertEqual(3, len(sda['partitions'])) + await inst.post("/storage/v2/add_partition", data) + data = {"disk_id": disk_id, "partition": {"number": 1}} + resp = await inst.post("/storage/v2/delete_partition", data) + [sda] = match(resp["disks"], id=disk_id) + self.assertEqual(3, len(sda["partitions"])) - boot_gap = sda['partitions'][0] - self.assertEqual(boot_size, boot_gap['size']) - self.assertEqual('Gap', boot_gap['$type']) + boot_gap = sda["partitions"][0] + self.assertEqual(boot_size, boot_gap["size"]) + self.assertEqual("Gap", boot_gap["$type"]) - root = sda['partitions'][1] - self.assertEqual(root_size, root['size']) - self.assertEqual('Partition', root['$type']) + root = sda["partitions"][1] + self.assertEqual(root_size, root["size"]) + self.assertEqual("Partition", root["$type"]) - end_gap = sda['partitions'][2] + end_gap = sda["partitions"][2] end_size = (10 << 30) - boot_size - root_size - (2 << 20) - self.assertEqual(end_size, end_gap['size']) - self.assertEqual('Gap', end_gap['$type']) + self.assertEqual(end_size, end_gap["size"]) + self.assertEqual("Gap", end_gap["$type"]) class TestRegression(TestAPI): @timeout() async def test_edit_not_trigger_boot_device(self): - async with start_server('examples/machines/simple.json') as inst: - disk_id = 'disk-sda' - resp = await inst.get('/storage/v2') - [sda] = resp['disks'] - [gap] = sda['partitions'] + async with start_server("examples/machines/simple.json") as inst: + disk_id = "disk-sda" + resp = await inst.get("/storage/v2") + [sda] = resp["disks"] + [gap] = sda["partitions"] data = { - 'disk_id': disk_id, - 'gap': gap, - 'partition': { - 'format': 'ext4', - 'mount': '/foo', - } + "disk_id": disk_id, + "gap": gap, + "partition": { + "format": "ext4", + "mount": "/foo", + }, } - resp = await inst.post('/storage/v2/add_partition', data) - [sda] = resp['disks'] - [part] = match(sda['partitions'], mount='/foo') - part.update({ - 'format': 'ext3', - 'mount': '/bar', - 'wipe': 'superblock', - }) - data['partition'] = part - data.pop('gap') - await inst.post('/storage/v2/edit_partition', data) + resp = await inst.post("/storage/v2/add_partition", data) + [sda] = resp["disks"] + [part] = match(sda["partitions"], mount="/foo") + part.update( + { + "format": "ext3", + "mount": "/bar", + "wipe": "superblock", + } + ) + data["partition"] = part + data.pop("gap") + await inst.post("/storage/v2/edit_partition", data) # should not throw an exception complaining about boot @timeout() async def test_osprober_knames(self): - cfg = 'examples/machines/lp-1986676-missing-osprober.json' + cfg = "examples/machines/lp-1986676-missing-osprober.json" async with start_server(cfg) as inst: - resp = await inst.get('/storage/v2') - [nvme] = match(resp['disks'], id='disk-nvme0n1') - [nvme_p2] = match(nvme['partitions'], path='/dev/nvme0n1p2') + resp = await inst.get("/storage/v2") + [nvme] = match(resp["disks"], id="disk-nvme0n1") + [nvme_p2] = match(nvme["partitions"], path="/dev/nvme0n1p2") expected = { - 'long': 'Ubuntu 22.04.1 LTS', - 'label': 'Ubuntu', - 'type': 'linux', - 'subpath': None, - 'version': '22.04.1' + "long": "Ubuntu 22.04.1 LTS", + "label": "Ubuntu", + "type": "linux", + "subpath": None, + "version": "22.04.1", } - self.assertEqual(expected, nvme_p2['os']) + self.assertEqual(expected, nvme_p2["os"]) @timeout() async def test_edit_should_trigger_wipe_when_requested(self): @@ -1396,547 +1384,582 @@ class TestRegression(TestAPI): # The old way this worked was to use changes to the 'format' value to # decide if a wipe was happening or not, and now the client chooses so # explicitly. - cfg = 'examples/machines/win10-along-ubuntu.json' + cfg = "examples/machines/win10-along-ubuntu.json" async with start_server(cfg) as inst: - resp = await inst.get('/storage/v2') - [d1] = resp['disks'] - [p5] = match(d1['partitions'], number=5) - p5.update(dict(mount='/home', wipe='superblock')) - data = dict(disk_id=d1['id'], partition=p5) - resp = await inst.post('/storage/v2/edit_partition', data) + resp = await inst.get("/storage/v2") + [d1] = resp["disks"] + [p5] = match(d1["partitions"], number=5) + p5.update(dict(mount="/home", wipe="superblock")) + data = dict(disk_id=d1["id"], partition=p5) + resp = await inst.post("/storage/v2/edit_partition", data) - v1resp = await inst.get('/storage') - [c_p5] = match(v1resp['config'], number=5) - [c_p5fmt] = match(v1resp['config'], volume=c_p5['id']) - self.assertEqual('superblock', c_p5['wipe']) - self.assertFalse(c_p5fmt['preserve']) - self.assertTrue(c_p5['preserve']) + v1resp = await inst.get("/storage") + [c_p5] = match(v1resp["config"], number=5) + [c_p5fmt] = match(v1resp["config"], volume=c_p5["id"]) + self.assertEqual("superblock", c_p5["wipe"]) + self.assertFalse(c_p5fmt["preserve"]) + self.assertTrue(c_p5["preserve"]) # then let's change our minds and not wipe it - [d1] = resp['disks'] - [p5] = match(d1['partitions'], number=5) - p5['wipe'] = None - data = dict(disk_id=d1['id'], partition=p5) - resp = await inst.post('/storage/v2/edit_partition', data) + [d1] = resp["disks"] + [p5] = match(d1["partitions"], number=5) + p5["wipe"] = None + data = dict(disk_id=d1["id"], partition=p5) + resp = await inst.post("/storage/v2/edit_partition", data) - v1resp = await inst.get('/storage') - [c_p5] = match(v1resp['config'], number=5) - [c_p5fmt] = match(v1resp['config'], volume=c_p5['id']) - self.assertNotIn('wipe', c_p5) - self.assertTrue(c_p5fmt['preserve']) - self.assertTrue(c_p5['preserve']) + v1resp = await inst.get("/storage") + [c_p5] = match(v1resp["config"], number=5) + [c_p5fmt] = match(v1resp["config"], volume=c_p5["id"]) + self.assertNotIn("wipe", c_p5) + self.assertTrue(c_p5fmt["preserve"]) + self.assertTrue(c_p5["preserve"]) @timeout() async def test_edit_should_leave_other_values_alone(self): - cfg = 'examples/machines/win10-along-ubuntu.json' + cfg = "examples/machines/win10-along-ubuntu.json" async with start_server(cfg) as inst: + async def check_preserve(): - v1resp = await inst.get('/storage') - [c_p5] = match(v1resp['config'], number=5) - [c_p5fmt] = match(v1resp['config'], volume=c_p5['id']) - self.assertNotIn('wipe', c_p5) - self.assertTrue(c_p5fmt['preserve']) - self.assertTrue(c_p5['preserve']) + v1resp = await inst.get("/storage") + [c_p5] = match(v1resp["config"], number=5) + [c_p5fmt] = match(v1resp["config"], volume=c_p5["id"]) + self.assertNotIn("wipe", c_p5) + self.assertTrue(c_p5fmt["preserve"]) + self.assertTrue(c_p5["preserve"]) - resp = await inst.get('/storage/v2') - d1 = resp['disks'][0] - [p5] = match(resp['disks'][0]['partitions'], number=5) + resp = await inst.get("/storage/v2") + d1 = resp["disks"][0] + [p5] = match(resp["disks"][0]["partitions"], number=5) orig_p5 = p5.copy() - self.assertEqual('ext4', p5['format']) - self.assertIsNone(p5['mount']) + self.assertEqual("ext4", p5["format"]) + self.assertIsNone(p5["mount"]) - data = {'disk_id': d1['id'], 'partition': p5} - resp = await inst.post('/storage/v2/edit_partition', data) - [p5] = match(resp['disks'][0]['partitions'], number=5) + data = {"disk_id": d1["id"], "partition": p5} + resp = await inst.post("/storage/v2/edit_partition", data) + [p5] = match(resp["disks"][0]["partitions"], number=5) self.assertEqual(orig_p5, p5) await check_preserve() - p5.update({'mount': '/'}) - data = {'disk_id': d1['id'], 'partition': p5} - resp = await inst.post('/storage/v2/edit_partition', data) - [p5] = match(resp['disks'][0]['partitions'], number=5) + p5.update({"mount": "/"}) + data = {"disk_id": d1["id"], "partition": p5} + resp = await inst.post("/storage/v2/edit_partition", data) + [p5] = match(resp["disks"][0]["partitions"], number=5) expected = orig_p5.copy() - expected['mount'] = '/' - expected['annotations'] = [ - 'existing', 'already formatted as ext4', 'mounted at /' + expected["mount"] = "/" + expected["annotations"] = [ + "existing", + "already formatted as ext4", + "mounted at /", ] self.assertEqual(expected, p5) await check_preserve() - data = {'disk_id': d1['id'], 'partition': p5} - resp = await inst.post('/storage/v2/edit_partition', data) - [p5] = match(resp['disks'][0]['partitions'], number=5) + data = {"disk_id": d1["id"], "partition": p5} + resp = await inst.post("/storage/v2/edit_partition", data) + [p5] = match(resp["disks"][0]["partitions"], number=5) self.assertEqual(expected, p5) await check_preserve() @timeout() async def test_no_change_edit(self): - cfg = 'examples/machines/simple.json' - extra = ['--storage-version', '2'] + cfg = "examples/machines/simple.json" + extra = ["--storage-version", "2"] async with start_server(cfg, extra_args=extra) as inst: - resp = await inst.get('/storage/v2') - [d] = resp['disks'] - [g] = d['partitions'] + resp = await inst.get("/storage/v2") + [d] = resp["disks"] + [g] = d["partitions"] data = { - "disk_id": 'disk-sda', + "disk_id": "disk-sda", "gap": g, "partition": { "size": 107372085248, "format": "ext4", - } + }, } - resp = await inst.post('/storage/v2/add_partition', data) - [p] = resp['disks'][0]['partitions'] - self.assertEqual('ext4', p['format']) + resp = await inst.post("/storage/v2/add_partition", data) + [p] = resp["disks"][0]["partitions"] + self.assertEqual("ext4", p["format"]) orig_p = p.copy() - data = {"disk_id": 'disk-sda', "partition": p} - resp = await inst.post('/storage/v2/edit_partition', data) - [p] = resp['disks'][0]['partitions'] + data = {"disk_id": "disk-sda", "partition": p} + resp = await inst.post("/storage/v2/edit_partition", data) + [p] = resp["disks"][0]["partitions"] self.assertEqual(orig_p, p) @timeout() async def test_no_change_edit_swap(self): - '''LP: 2002413 - editing a swap partition would fail with + """LP: 2002413 - editing a swap partition would fail with > Exception: Filesystem(fstype='swap', ...) is already mounted Make sure editing the partition is ok now. - ''' - cfg = 'examples/machines/simple.json' - extra = ['--storage-version', '2'] + """ + cfg = "examples/machines/simple.json" + extra = ["--storage-version", "2"] async with start_server(cfg, extra_args=extra) as inst: - resp = await inst.get('/storage/v2') - [d] = resp['disks'] - [g] = d['partitions'] + resp = await inst.get("/storage/v2") + [d] = resp["disks"] + [g] = d["partitions"] data = { - "disk_id": 'disk-sda', + "disk_id": "disk-sda", "gap": g, "partition": { "size": 8589934592, # 8 GiB "format": "swap", - } + }, } - resp = await inst.post('/storage/v2/add_partition', data) - [p, gap] = resp['disks'][0]['partitions'] - self.assertEqual('swap', p['format']) + resp = await inst.post("/storage/v2/add_partition", data) + [p, gap] = resp["disks"][0]["partitions"] + self.assertEqual("swap", p["format"]) orig_p = p.copy() - data = {"disk_id": 'disk-sda', "partition": p} - resp = await inst.post('/storage/v2/edit_partition', data) - [p, gap] = resp['disks'][0]['partitions'] + data = {"disk_id": "disk-sda", "partition": p} + resp = await inst.post("/storage/v2/edit_partition", data) + [p, gap] = resp["disks"][0]["partitions"] self.assertEqual(orig_p, p) @timeout() async def test_can_create_unformatted_partition(self): - '''We want to offer the same list of fstypes for Subiquity and U-D-I, - but the list is different today. Verify that unformatted partitions - may be created.''' - cfg = 'examples/machines/simple.json' - extra = ['--storage-version', '2'] + """We want to offer the same list of fstypes for Subiquity and U-D-I, + but the list is different today. Verify that unformatted partitions + may be created.""" + cfg = "examples/machines/simple.json" + extra = ["--storage-version", "2"] async with start_server(cfg, extra_args=extra) as inst: - resp = await inst.get('/storage/v2') - [d] = resp['disks'] - [g] = d['partitions'] + resp = await inst.get("/storage/v2") + [d] = resp["disks"] + [g] = d["partitions"] data = { - 'disk_id': 'disk-sda', - 'gap': g, - 'partition': { - 'size': -1, - 'format': None, - } + "disk_id": "disk-sda", + "gap": g, + "partition": { + "size": -1, + "format": None, + }, } - resp = await inst.post('/storage/v2/add_partition', data) - [p] = resp['disks'][0]['partitions'] - self.assertIsNone(p['format']) - v1resp = await inst.get('/storage') - self.assertEqual([], match(v1resp['config'], type='format')) + resp = await inst.post("/storage/v2/add_partition", data) + [p] = resp["disks"][0]["partitions"] + self.assertIsNone(p["format"]) + v1resp = await inst.get("/storage") + self.assertEqual([], match(v1resp["config"], type="format")) @timeout() async def test_guided_v2_resize_logical_middle_partition(self): - '''LP: #2015521 - a logical partition that wasn't the physically last + """LP: #2015521 - a logical partition that wasn't the physically last logical partition was resized to allow creation of more partitions, but the 1MiB space was not left between the newly created partition and the - physically last partition.''' - cfg = 'examples/machines/threebuntu-on-msdos.json' - extra = ['--storage-version', '2'] + physically last partition.""" + cfg = "examples/machines/threebuntu-on-msdos.json" + extra = ["--storage-version", "2"] async with start_server(cfg, extra_args=extra) as inst: - resp = await inst.get('/storage/v2/guided') - [resize] = match(resp['targets'], partition_number=5, - _type='GuidedStorageTargetResize') + resp = await inst.get("/storage/v2/guided") + [resize] = match( + resp["targets"], partition_number=5, _type="GuidedStorageTargetResize" + ) data = { - 'target': resize, - 'capability': resize['allowed'][0], - } - resp = await inst.post('/storage/v2/guided', data) - self.assertEqual(resize, resp['configured']['target']) + "target": resize, + "capability": resize["allowed"][0], + } + resp = await inst.post("/storage/v2/guided", data) + self.assertEqual(resize, resp["configured"]["target"]) - resp = await inst.get('/storage') - parts = match(resp['config'], type='partition', flag='logical') + resp = await inst.get("/storage") + parts = match(resp["config"], type="partition", flag="logical") logicals = [] for part in parts: - part['end'] = part['offset'] + part['size'] + part["end"] = part["offset"] + part["size"] logicals.append(part) - logicals.sort(key=lambda p: p['offset']) + logicals.sort(key=lambda p: p["offset"]) for i in range(len(logicals) - 1): - cur, nxt = logicals[i:i+2] - self.assertLessEqual(cur['end'] + (1 << 20), nxt['offset'], - f'partition overlap {cur} {nxt}') + cur, nxt = logicals[i : i + 2] + self.assertLessEqual( + cur["end"] + (1 << 20), + nxt["offset"], + f"partition overlap {cur} {nxt}", + ) @timeout() async def test_probert_result_during_partitioning(self): - '''LP: #2016901 - when a probert run finishes during manual + """LP: #2016901 - when a probert run finishes during manual partitioning, we used to load the probing data automatically ; essentially discarding any change made by the user so far. This test creates a new partition, simulates the end of a probert run, and then tries to edit the previously created partition. The edit operation would fail in earlier versions because the new partition would have - been discarded. ''' - cfg = 'examples/machines/simple.json' - extra = ['--storage-version', '2'] + been discarded.""" + cfg = "examples/machines/simple.json" + extra = ["--storage-version", "2"] async with start_server(cfg, extra_args=extra) as inst: - names = ['locale', 'keyboard', 'source', 'network', 'proxy', - 'mirror'] - await inst.post('/meta/mark_configured', endpoint_names=names) - resp = await inst.get('/storage/v2') - [d] = resp['disks'] - [g] = d['partitions'] + names = ["locale", "keyboard", "source", "network", "proxy", "mirror"] + await inst.post("/meta/mark_configured", endpoint_names=names) + resp = await inst.get("/storage/v2") + [d] = resp["disks"] + [g] = d["partitions"] data = { - 'disk_id': 'disk-sda', - 'gap': g, - 'partition': { - 'size': -1, - 'mount': '/', - 'format': 'ext4', - } + "disk_id": "disk-sda", + "gap": g, + "partition": { + "size": -1, + "mount": "/", + "format": "ext4", + }, } - add_resp = await inst.post('/storage/v2/add_partition', data) - [sda] = add_resp['disks'] - [root] = match(sda['partitions'], mount='/') + add_resp = await inst.post("/storage/v2/add_partition", data) + [sda] = add_resp["disks"] + [root] = match(sda["partitions"], mount="/") # Now let's make sure we get the results from a probert run to kick # in. - await inst.post('/storage/dry_run_wait_probe') + await inst.post("/storage/dry_run_wait_probe") data = { - 'disk_id': 'disk-sda', - 'partition': { - 'number': root['number'], - } + "disk_id": "disk-sda", + "partition": { + "number": root["number"], + }, } # We should be able to modify the created partition. - await inst.post('/storage/v2/edit_partition', data) + await inst.post("/storage/v2/edit_partition", data) class TestCancel(TestAPI): @timeout() async def test_cancel_drivers(self): - with patch.dict(os.environ, {'SUBIQUITY_DEBUG': 'has-drivers'}): - async with start_server('examples/machines/simple.json') as inst: - await inst.post('/source', source_id="placeholder", - search_drivers=True) + with patch.dict(os.environ, {"SUBIQUITY_DEBUG": "has-drivers"}): + async with start_server("examples/machines/simple.json") as inst: + await inst.post("/source", source_id="placeholder", search_drivers=True) # /drivers?wait=true is expected to block until APT is # configured. # Let's make sure we cancel it. with self.assertRaises(asyncio.TimeoutError): - await asyncio.wait_for(inst.get('/drivers', wait=True), - 0.1) - names = ['locale', 'keyboard', 'source', 'network', 'proxy', - 'mirror', 'storage'] - await inst.post('/meta/mark_configured', endpoint_names=names) - await inst.get('/meta/status', cur='WAITING') - await inst.post('/meta/confirm', tty='/dev/tty1') - await inst.get('/meta/status', cur='NEEDS_CONFIRMATION') + await asyncio.wait_for(inst.get("/drivers", wait=True), 0.1) + names = [ + "locale", + "keyboard", + "source", + "network", + "proxy", + "mirror", + "storage", + ] + await inst.post("/meta/mark_configured", endpoint_names=names) + await inst.get("/meta/status", cur="WAITING") + await inst.post("/meta/confirm", tty="/dev/tty1") + await inst.get("/meta/status", cur="NEEDS_CONFIRMATION") # should not raise ServerDisconnectedError - resp = await inst.get('/drivers', wait=True) - self.assertEqual(['nvidia-driver-470-server'], resp['drivers']) + resp = await inst.get("/drivers", wait=True) + self.assertEqual(["nvidia-driver-470-server"], resp["drivers"]) class TestDrivers(TestAPI): async def _test_source(self, source_id, expected_driver): - with patch.dict(os.environ, {'SUBIQUITY_DEBUG': 'has-drivers'}): - cfg = 'examples/machines/simple.json' - extra = ['--source-catalog', 'examples/sources/mixed.yaml'] + with patch.dict(os.environ, {"SUBIQUITY_DEBUG": "has-drivers"}): + cfg = "examples/machines/simple.json" + extra = ["--source-catalog", "examples/sources/mixed.yaml"] async with start_server(cfg, extra_args=extra) as inst: - await inst.post('/source', source_id=source_id, - search_drivers=True) + await inst.post("/source", source_id=source_id, search_drivers=True) - names = ['locale', 'keyboard', 'source', 'network', 'proxy', - 'mirror', 'storage'] - await inst.post('/meta/mark_configured', endpoint_names=names) - await inst.get('/meta/status', cur='WAITING') - await inst.post('/meta/confirm', tty='/dev/tty1') - await inst.get('/meta/status', cur='NEEDS_CONFIRMATION') + names = [ + "locale", + "keyboard", + "source", + "network", + "proxy", + "mirror", + "storage", + ] + await inst.post("/meta/mark_configured", endpoint_names=names) + await inst.get("/meta/status", cur="WAITING") + await inst.post("/meta/confirm", tty="/dev/tty1") + await inst.get("/meta/status", cur="NEEDS_CONFIRMATION") - resp = await inst.get('/drivers', wait=True) - self.assertEqual([expected_driver], resp['drivers']) + resp = await inst.get("/drivers", wait=True) + self.assertEqual([expected_driver], resp["drivers"]) @timeout() async def test_server_source(self): - await self._test_source('ubuntu-server-minimal', - 'nvidia-driver-470-server') + await self._test_source("ubuntu-server-minimal", "nvidia-driver-470-server") @timeout() async def test_desktop_source(self): - await self._test_source('ubuntu-desktop', 'nvidia-driver-510') + await self._test_source("ubuntu-desktop", "nvidia-driver-510") @timeout() async def test_listing_ongoing(self): - ''' Ensure that the list of drivers returned by /drivers is null while - the list has not been retrieved. ''' - async with start_server('examples/machines/simple.json') as inst: - resp = await inst.get('/drivers', wait=False) - self.assertIsNone(resp['drivers']) + """Ensure that the list of drivers returned by /drivers is null while + the list has not been retrieved.""" + async with start_server("examples/machines/simple.json") as inst: + resp = await inst.get("/drivers", wait=False) + self.assertIsNone(resp["drivers"]) # POSTing to /source will restart the retrieval operation. - await inst.post('/source', source_id='ubuntu-server', - search_drivers=True) + await inst.post("/source", source_id="ubuntu-server", search_drivers=True) - resp = await inst.get('/drivers', wait=False) - self.assertIsNone(resp['drivers']) + resp = await inst.get("/drivers", wait=False) + self.assertIsNone(resp["drivers"]) class TestOEM(TestAPI): @timeout() async def test_listing_ongoing(self): - ''' Ensure that the list of OEM metapackages returned by /oem is - null while the list has not been retrieved. ''' - async with start_server('examples/machines/simple.json') as inst: - resp = await inst.get('/oem', wait=False) - self.assertIsNone(resp['metapackages']) + """Ensure that the list of OEM metapackages returned by /oem is + null while the list has not been retrieved.""" + async with start_server("examples/machines/simple.json") as inst: + resp = await inst.get("/oem", wait=False) + self.assertIsNone(resp["metapackages"]) async def test_listing_empty(self): expected_pkgs = [] - with patch.dict(os.environ, {'SUBIQUITY_DEBUG': 'no-drivers'}): - async with start_server('examples/machines/simple.json') as inst: - await inst.post('/source', source_id='ubuntu-server') - names = ['locale', 'keyboard', 'source', 'network', 'proxy', - 'mirror', 'storage'] - await inst.post('/meta/mark_configured', endpoint_names=names) - await inst.get('/meta/status', cur='WAITING') - await inst.post('/meta/confirm', tty='/dev/tty1') - await inst.get('/meta/status', cur='NEEDS_CONFIRMATION') + with patch.dict(os.environ, {"SUBIQUITY_DEBUG": "no-drivers"}): + async with start_server("examples/machines/simple.json") as inst: + await inst.post("/source", source_id="ubuntu-server") + names = [ + "locale", + "keyboard", + "source", + "network", + "proxy", + "mirror", + "storage", + ] + await inst.post("/meta/mark_configured", endpoint_names=names) + await inst.get("/meta/status", cur="WAITING") + await inst.post("/meta/confirm", tty="/dev/tty1") + await inst.get("/meta/status", cur="NEEDS_CONFIRMATION") - resp = await inst.get('/oem', wait=True) - self.assertEqual(expected_pkgs, resp['metapackages']) + resp = await inst.get("/oem", wait=True) + self.assertEqual(expected_pkgs, resp["metapackages"]) - async def _test_listing_certified(self, source_id: str, - expected: List[str]): - with patch.dict(os.environ, {'SUBIQUITY_DEBUG': 'has-drivers'}): - args = ['--source-catalog', 'examples/sources/mixed.yaml'] - config = 'examples/machines/simple.json' + async def _test_listing_certified(self, source_id: str, expected: List[str]): + with patch.dict(os.environ, {"SUBIQUITY_DEBUG": "has-drivers"}): + args = ["--source-catalog", "examples/sources/mixed.yaml"] + config = "examples/machines/simple.json" async with start_server(config, extra_args=args) as inst: - await inst.post('/source', source_id=source_id) - names = ['locale', 'keyboard', 'source', 'network', 'proxy', - 'mirror', 'storage'] - await inst.post('/meta/mark_configured', endpoint_names=names) - await inst.get('/meta/status', cur='WAITING') - await inst.post('/meta/confirm', tty='/dev/tty1') - await inst.get('/meta/status', cur='NEEDS_CONFIRMATION') + await inst.post("/source", source_id=source_id) + names = [ + "locale", + "keyboard", + "source", + "network", + "proxy", + "mirror", + "storage", + ] + await inst.post("/meta/mark_configured", endpoint_names=names) + await inst.get("/meta/status", cur="WAITING") + await inst.post("/meta/confirm", tty="/dev/tty1") + await inst.get("/meta/status", cur="NEEDS_CONFIRMATION") - resp = await inst.get('/oem', wait=True) - self.assertEqual(expected, resp['metapackages']) + resp = await inst.get("/oem", wait=True) + self.assertEqual(expected, resp["metapackages"]) async def test_listing_certified_ubuntu_server(self): # Listing of OEM meta-packages is intentionally disabled on # ubuntu-server. - await self._test_listing_certified( - source_id='ubuntu-server', - expected=[]) + await self._test_listing_certified(source_id="ubuntu-server", expected=[]) async def test_listing_certified_ubuntu_desktop(self): await self._test_listing_certified( - source_id='ubuntu-desktop', - expected=['oem-somerville-tentacool-meta']) + source_id="ubuntu-desktop", expected=["oem-somerville-tentacool-meta"] + ) class TestSource(TestAPI): @timeout() async def test_optional_search_drivers(self): - async with start_server('examples/machines/simple.json') as inst: - await inst.post('/source', source_id='ubuntu-server') - resp = await inst.get('/source') - self.assertFalse(resp['search_drivers']) + async with start_server("examples/machines/simple.json") as inst: + await inst.post("/source", source_id="ubuntu-server") + resp = await inst.get("/source") + self.assertFalse(resp["search_drivers"]) - await inst.post('/source', source_id='ubuntu-server', - search_drivers=True) - resp = await inst.get('/source') - self.assertTrue(resp['search_drivers']) + await inst.post("/source", source_id="ubuntu-server", search_drivers=True) + resp = await inst.get("/source") + self.assertTrue(resp["search_drivers"]) - await inst.post('/source', source_id='ubuntu-server', - search_drivers=False) - resp = await inst.get('/source') - self.assertFalse(resp['search_drivers']) + await inst.post("/source", source_id="ubuntu-server", search_drivers=False) + resp = await inst.get("/source") + self.assertFalse(resp["search_drivers"]) class TestIdentityValidation(TestAPI): @timeout() async def test_username_validation(self): - async with start_server('examples/machines/simple.json') as inst: - resp = await inst.get('/identity/validate_username', - username='plugdev') - self.assertEqual(resp, 'SYSTEM_RESERVED') + async with start_server("examples/machines/simple.json") as inst: + resp = await inst.get("/identity/validate_username", username="plugdev") + self.assertEqual(resp, "SYSTEM_RESERVED") - resp = await inst.get('/identity/validate_username', - username='root') - self.assertEqual(resp, 'ALREADY_IN_USE') + resp = await inst.get("/identity/validate_username", username="root") + self.assertEqual(resp, "ALREADY_IN_USE") - resp = await inst.get('/identity/validate_username', - username='r'*33) - self.assertEqual(resp, 'TOO_LONG') + resp = await inst.get("/identity/validate_username", username="r" * 33) + self.assertEqual(resp, "TOO_LONG") - resp = await inst.get('/identity/validate_username', - username='01root') - self.assertEqual(resp, 'INVALID_CHARS') + resp = await inst.get("/identity/validate_username", username="01root") + self.assertEqual(resp, "INVALID_CHARS") - resp = await inst.get('/identity/validate_username', - username='o#$%^&') - self.assertEqual(resp, 'INVALID_CHARS') + resp = await inst.get("/identity/validate_username", username="o#$%^&") + self.assertEqual(resp, "INVALID_CHARS") class TestManyPrimaries(TestAPI): @timeout() async def test_create_primaries(self): - cfg = 'examples/machines/simple.json' - extra = ['--storage-version', '2'] + cfg = "examples/machines/simple.json" + extra = ["--storage-version", "2"] async with start_server(cfg, extra_args=extra) as inst: - resp = await inst.get('/storage/v2') - d1 = resp['disks'][0] + resp = await inst.get("/storage/v2") + d1 = resp["disks"][0] - data = {'disk_id': d1['id'], 'ptable': 'msdos'} - resp = await inst.post('/storage/v2/reformat_disk', data) - [gap] = match(resp['disks'][0]['partitions'], _type='Gap') + data = {"disk_id": d1["id"], "ptable": "msdos"} + resp = await inst.post("/storage/v2/reformat_disk", data) + [gap] = match(resp["disks"][0]["partitions"], _type="Gap") for _ in range(4): - self.assertEqual('YES', gap['usable']) + self.assertEqual("YES", gap["usable"]) data = { - 'disk_id': d1['id'], - 'gap': gap, - 'partition': { - 'size': 1 << 30, - 'format': 'ext4', - } + "disk_id": d1["id"], + "gap": gap, + "partition": { + "size": 1 << 30, + "format": "ext4", + }, } - resp = await inst.post('/storage/v2/add_partition', data) - [gap] = match(resp['disks'][0]['partitions'], _type='Gap') + resp = await inst.post("/storage/v2/add_partition", data) + [gap] = match(resp["disks"][0]["partitions"], _type="Gap") - self.assertEqual('TOO_MANY_PRIMARY_PARTS', gap['usable']) + self.assertEqual("TOO_MANY_PRIMARY_PARTS", gap["usable"]) data = { - 'disk_id': d1['id'], - 'gap': gap, - 'partition': { - 'size': 1 << 30, - 'format': 'ext4', - } + "disk_id": d1["id"], + "gap": gap, + "partition": { + "size": 1 << 30, + "format": "ext4", + }, } with self.assertRaises(ClientResponseError): - await inst.post('/storage/v2/add_partition', data) + await inst.post("/storage/v2/add_partition", data) class TestKeyboard(TestAPI): @timeout() async def test_input_source(self): - async with start_server('examples/machines/simple.json') as inst: - data = {'layout': 'fr', 'variant': 'latin9'} - await inst.post('/keyboard/input_source', data, user='foo') + async with start_server("examples/machines/simple.json") as inst: + data = {"layout": "fr", "variant": "latin9"} + await inst.post("/keyboard/input_source", data, user="foo") class TestUbuntuProContractSelection(TestAPI): @timeout() async def test_upcs_flow(self): - async with start_server('examples/machines/simple.json') as inst: + async with start_server("examples/machines/simple.json") as inst: # Wait should fail if no initiate first. with self.assertRaises(Exception): - await inst.get('/ubuntu_pro/contract_selection/wait') + await inst.get("/ubuntu_pro/contract_selection/wait") # Cancel should fail if no initiate first. with self.assertRaises(Exception): - await inst.post('/ubuntu_pro/contract_selection/cancel') + await inst.post("/ubuntu_pro/contract_selection/cancel") - await inst.post('/ubuntu_pro/contract_selection/initiate') + await inst.post("/ubuntu_pro/contract_selection/initiate") # Double initiate should fail with self.assertRaises(Exception): - await inst.post('/ubuntu_pro/contract_selection/initiate') + await inst.post("/ubuntu_pro/contract_selection/initiate") # This call should block for long enough. with self.assertRaises(asyncio.TimeoutError): await asyncio.wait_for( - inst.get('/ubuntu_pro/contract_selection/wait'), - timeout=0.5) + inst.get("/ubuntu_pro/contract_selection/wait"), timeout=0.5 + ) - await inst.post('/ubuntu_pro/contract_selection/cancel') + await inst.post("/ubuntu_pro/contract_selection/cancel") with self.assertRaises(Exception): - await inst.get('/ubuntu_pro/contract_selection/wait') + await inst.get("/ubuntu_pro/contract_selection/wait") class TestAutoinstallServer(TestAPI): @timeout(2) async def test_make_view_requests(self): - cfg = 'examples/machines/simple.json' + cfg = "examples/machines/simple.json" extra = [ - '--autoinstall', 'examples/autoinstall/short.yaml', - '--source-catalog', 'examples/sources/install.yaml', + "--autoinstall", + "examples/autoinstall/short.yaml", + "--source-catalog", + "examples/sources/install.yaml", ] - async with start_server(cfg, extra_args=extra, - set_first_source=False) as inst: + async with start_server(cfg, extra_args=extra, set_first_source=False) as inst: view_request_unspecified, resp = await inst.get( - '/locale', - full_response=True) - self.assertEqual('en_US.UTF-8', view_request_unspecified) - self.assertEqual('ok', resp.headers['x-status']) + "/locale", full_response=True + ) + self.assertEqual("en_US.UTF-8", view_request_unspecified) + self.assertEqual("ok", resp.headers["x-status"]) view_request_no, resp = await inst.get( - '/locale', - headers={'x-make-view-request': 'no'}, - full_response=True) - self.assertEqual('en_US.UTF-8', view_request_no) - self.assertEqual('ok', resp.headers['x-status']) + "/locale", headers={"x-make-view-request": "no"}, full_response=True + ) + self.assertEqual("en_US.UTF-8", view_request_no) + self.assertEqual("ok", resp.headers["x-status"]) view_request_yes, resp = await inst.get( - '/locale', - headers={'x-make-view-request': 'yes'}, - full_response=True) + "/locale", headers={"x-make-view-request": "yes"}, full_response=True + ) self.assertIsNone(view_request_yes) - self.assertEqual('skip', resp.headers['x-status']) + self.assertEqual("skip", resp.headers["x-status"]) @timeout() async def test_interactive(self): - cfg = 'examples/machines/simple.json' - with tempfile.NamedTemporaryFile(mode='w') as tf: - tf.write(''' + cfg = "examples/machines/simple.json" + with tempfile.NamedTemporaryFile(mode="w") as tf: + tf.write( + """ version: 1 interactive-sections: ['*'] - ''') + """ + ) tf.flush() - extra = ['--autoinstall', tf.name] + extra = ["--autoinstall", tf.name] async with start_server(cfg, extra_args=extra) as inst: - resp = await inst.get('/meta/interactive_sections') - expected = set([ - 'locale', 'refresh-installer', 'keyboard', 'source', - 'network', 'ubuntu-pro', 'proxy', 'apt', 'storage', - 'identity', 'ssh', 'snaps', 'codecs', 'drivers', - 'timezone', 'updates', 'shutdown', - ]) + resp = await inst.get("/meta/interactive_sections") + expected = set( + [ + "locale", + "refresh-installer", + "keyboard", + "source", + "network", + "ubuntu-pro", + "proxy", + "apt", + "storage", + "identity", + "ssh", + "snaps", + "codecs", + "drivers", + "timezone", + "updates", + "shutdown", + ] + ) self.assertTrue(expected.issubset(resp)) class TestWSLSetupOptions(TestAPI): @timeout() async def test_wslsetupoptions(self): - cfg = 'examples/machines/simple.json' + cfg = "examples/machines/simple.json" async with start_system_setup_server(cfg) as inst: - await inst.post('/meta/client_variant', variant='wsl_setup') + await inst.post("/meta/client_variant", variant="wsl_setup") - payload = {'install_language_support_packages': False} - endpoint = '/wslsetupoptions' + payload = {"install_language_support_packages": False} + endpoint = "/wslsetupoptions" resp = await inst.get(endpoint) - self.assertTrue(resp['install_language_support_packages']) + self.assertTrue(resp["install_language_support_packages"]) await inst.post(endpoint, payload) resp = await inst.get(endpoint) - self.assertFalse(resp['install_language_support_packages']) + self.assertFalse(resp["install_language_support_packages"]) class TestActiveDirectory(TestAPI): @@ -1944,61 +1967,62 @@ class TestActiveDirectory(TestAPI): async def test_ad(self): # Few tests to assert that the controller is properly wired. # Exhaustive validation test cases are in the unit tests. - cfg = 'examples/machines/simple.json' + cfg = "examples/machines/simple.json" async with start_server(cfg) as instance: - endpoint = '/active_directory' + endpoint = "/active_directory" ad_dict = await instance.get(endpoint) # Starts with the detected domain. - self.assertEqual('ubuntu.com', ad_dict['domain_name']) + self.assertEqual("ubuntu.com", ad_dict["domain_name"]) # Post works by "returning None" ad_dict = { - 'admin_name': 'Ubuntu', - 'domain_name': 'ubuntu.com', - 'password': 'u', + "admin_name": "Ubuntu", + "domain_name": "ubuntu.com", + "password": "u", } result = await instance.post(endpoint, ad_dict) self.assertIsNone(result) # Rejects empty password. - result = await instance.post(endpoint+'/check_password', - data='') - self.assertEqual('EMPTY', result) + result = await instance.post(endpoint + "/check_password", data="") + self.assertEqual("EMPTY", result) # Rejects invalid domain controller names. - result = await instance.post(endpoint + '/check_domain_name', - data='..ubuntu.com') + result = await instance.post( + endpoint + "/check_domain_name", data="..ubuntu.com" + ) - self.assertIn('MULTIPLE_DOTS', result) + self.assertIn("MULTIPLE_DOTS", result) # Rejects invalid usernames. - result = await instance.post(endpoint + '/check_admin_name', - data='ubuntu;pro') - self.assertEqual('INVALID_CHARS', result) + result = await instance.post( + endpoint + "/check_admin_name", data="ubuntu;pro" + ) + self.assertEqual("INVALID_CHARS", result) # Notice that lowercase is not required. - result = await instance.post(endpoint + '/check_admin_name', - data='$Ubuntu') - self.assertEqual('OK', result) + result = await instance.post(endpoint + "/check_admin_name", data="$Ubuntu") + self.assertEqual("OK", result) # Leverages the stub ping strategy - result = await instance.post(endpoint + '/ping_domain_controller', - data='rockbuntu.com') - self.assertEqual('REALM_NOT_FOUND', result) + result = await instance.post( + endpoint + "/ping_domain_controller", data="rockbuntu.com" + ) + self.assertEqual("REALM_NOT_FOUND", result) # Attempts to join with the info supplied above. ad_dict = { - 'admin_name': 'Ubuntu', - 'domain_name': 'jubuntu.com', - 'password': 'u', + "admin_name": "Ubuntu", + "domain_name": "jubuntu.com", + "password": "u", } result = await instance.post(endpoint, ad_dict) self.assertIsNone(result) - join_result_ep = endpoint + '/join_result' + join_result_ep = endpoint + "/join_result" # Without wait this shouldn't block but the result is unknown until # the install controller runs. join_result = await instance.get(join_result_ep, wait=False) - self.assertEqual('UNKNOWN', join_result) + self.assertEqual("UNKNOWN", join_result) # And without the installer controller running, a blocking call # should timeout since joining never happens. with self.assertRaises(asyncio.exceptions.TimeoutError): @@ -2008,30 +2032,29 @@ class TestActiveDirectory(TestAPI): # Helper method @staticmethod async def target_packages() -> List[str]: - """ Returns the list of packages the AD Model wants to install in the - target system.""" + """Returns the list of packages the AD Model wants to install in the + target system.""" from subiquity.models.ad import AdModel + model = AdModel() model.do_join = True return await model.target_packages() async def packages_lookup(self, log_dir: str) -> Dict[str, bool]: - """ Returns a dictionary mapping the additional packages expected - to be installed in the target system and whether they were - referred to or not in the server log. """ + """Returns a dictionary mapping the additional packages expected + to be installed in the target system and whether they were + referred to or not in the server log.""" expected_packages = await self.target_packages() packages_lookup = {p: False for p in expected_packages} log_path = os.path.join(log_dir, "subiquity-server-debug.log") - find_start = \ - 'finish: subiquity/Install/install/postinstall/install_{}:' - log_status = ' SUCCESS: installing {}' + find_start = "finish: subiquity/Install/install/postinstall/install_{}:" + log_status = " SUCCESS: installing {}" with open(log_path, encoding="utf-8") as f: lines = f.readlines() for line in lines: for pack in packages_lookup: - find_line = find_start.format(pack) + \ - log_status.format(pack) + find_line = find_start.format(pack) + log_status.format(pack) pack_found = re.search(find_line, line) is not None if pack_found: packages_lookup[pack] = True @@ -2040,32 +2063,35 @@ class TestActiveDirectory(TestAPI): @timeout() async def test_ad_autoinstall(self): - cfg = 'examples/machines/simple.json' + cfg = "examples/machines/simple.json" extra = [ - '--autoinstall', 'examples/autoinstall/ad.yaml', - '--source-catalog', 'examples/sources/mixed.yaml', - '--kernel-cmdline', 'autoinstall', + "--autoinstall", + "examples/autoinstall/ad.yaml", + "--source-catalog", + "examples/sources/mixed.yaml", + "--kernel-cmdline", + "autoinstall", ] try: - async with start_server(cfg, extra_args=extra, - set_first_source=False) as inst: - endpoint = '/active_directory' + async with start_server( + cfg, extra_args=extra, set_first_source=False + ) as inst: + endpoint = "/active_directory" logdir = inst.output_base() self.assertIsNotNone(logdir) - await inst.post('/meta/client_variant', variant='desktop') + await inst.post("/meta/client_variant", variant="desktop") ad_info = await inst.get(endpoint) - self.assertIsNotNone(ad_info['admin_name']) - self.assertIsNotNone(ad_info['domain_name']) - self.assertEqual('', ad_info['password']) - ad_info['password'] = 'passw0rd' + self.assertIsNotNone(ad_info["admin_name"]) + self.assertIsNotNone(ad_info["domain_name"]) + self.assertEqual("", ad_info["password"]) + ad_info["password"] = "passw0rd" # This should be enough to configure AD controller and cause it # to install packages into and try joining the target system. await inst.post(endpoint, ad_info) # Now this shouldn't hang or timeout - join_result = await inst.get(endpoint + '/join_result', - wait=True) - self.assertEqual('OK', join_result) + join_result = await inst.get(endpoint + "/join_result", wait=True) + self.assertEqual("OK", join_result) packages = await self.packages_lookup(logdir) for k, v in packages.items(): print(f"Checking package {k}") @@ -2082,10 +2108,10 @@ class TestMountDetection(TestAPI): async def test_mount_detection(self): # Test that the partition the installer is running from is # correctly identified as mounted. - cfg = 'examples/machines/booted-from-rp.json' + cfg = "examples/machines/booted-from-rp.json" async with start_server(cfg) as instance: - result = await instance.get('/storage/v2') - [disk1] = result['disks'] - self.assertTrue(disk1['has_in_use_partition']) - disk1p2 = disk1['partitions'][1] - self.assertTrue(disk1p2['is_in_use']) + result = await instance.get("/storage/v2") + [disk1] = result["disks"] + self.assertTrue(disk1["has_in_use_partition"]) + disk1p2 = disk1["partitions"][1] + self.assertTrue(disk1p2["is_in_use"]) diff --git a/subiquity/tests/fakes.py b/subiquity/tests/fakes.py index ad290ce7..cc7f8ec3 100644 --- a/subiquity/tests/fakes.py +++ b/subiquity/tests/fakes.py @@ -1,8 +1,9 @@ import os + import yaml -TOP_DIR = os.path.join('/'.join(__file__.split('/')[:-3])) -TEST_DATA = os.path.join(TOP_DIR, 'subiquity', 'tests', 'data') -FAKE_MACHINE_JSON = os.path.join(TEST_DATA, 'fake_machine.json') +TOP_DIR = os.path.join("/".join(__file__.split("/")[:-3])) +TEST_DATA = os.path.join(TOP_DIR, "subiquity", "tests", "data") +FAKE_MACHINE_JSON = os.path.join(TEST_DATA, "fake_machine.json") FAKE_MACHINE_JSON_DATA = yaml.safe_load(open(FAKE_MACHINE_JSON)) -FAKE_MACHINE_STORAGE_DATA = FAKE_MACHINE_JSON_DATA.get('storage') +FAKE_MACHINE_STORAGE_DATA = FAKE_MACHINE_JSON_DATA.get("storage") diff --git a/subiquity/tests/test_timezonecontroller.py b/subiquity/tests/test_timezonecontroller.py index 50f800e5..c7cae079 100644 --- a/subiquity/tests/test_timezonecontroller.py +++ b/subiquity/tests/test_timezonecontroller.py @@ -23,15 +23,15 @@ from subiquitycore.tests.mocks import make_app class MockGeoIP: - text = '' + text = "" @property def timezone(self): return self.text -tz_denver = 'America/Denver' -tz_utc = 'Etc/UTC' +tz_denver = "America/Denver" +tz_utc = "Etc/UTC" class TestTimeZoneController(SubiTestCase): @@ -44,68 +44,71 @@ class TestTimeZoneController(SubiTestCase): self.tzc.app.geoip = MockGeoIP() self.tzc.app.geoip.text = tz_denver - @mock.patch('subiquity.server.controllers.timezone.timedatectl_settz') - @mock.patch('subiquity.server.controllers.timezone.timedatectl_gettz') - @mock.patch('subiquity.server.controllers.timezone.generate_possible_tzs') + @mock.patch("subiquity.server.controllers.timezone.timedatectl_settz") + @mock.patch("subiquity.server.controllers.timezone.timedatectl_gettz") + @mock.patch("subiquity.server.controllers.timezone.generate_possible_tzs") async def test_good_tzs(self, generate_possible_tzs, tdc_gettz, tdc_settz): - generate_possible_tzs.return_value = \ - ['', 'geoip', 'Pacific/Auckland', 'America/Denver'] + generate_possible_tzs.return_value = [ + "", + "geoip", + "Pacific/Auckland", + "America/Denver", + ] tdc_gettz.return_value = tz_utc goods = [ # val - autoinstall value # | settz - should system set timezone # | | geoip - did we get a value by geoip? # | | | valid_lookup - ('geoip', True, True, True), - ('geoip', False, True, False), + ("geoip", True, True, True), + ("geoip", False, True, False), # empty val is valid and means to set no time zone - ('', False, False, True), - ('Pacific/Auckland', True, False, False), - ('America/Denver', True, False, False), + ("", False, False, True), + ("Pacific/Auckland", True, False, False), + ("America/Denver", True, False, False), ] for val, settz, geoip, valid_lookup in goods: self.tzc_init() - if not val or val == 'geoip': + if not val or val == "geoip": if valid_lookup: tz = TimeZoneInfo(tz_denver, True) else: tz = TimeZoneInfo(tz_utc, False) - self.tzc.app.geoip.text = '' + self.tzc.app.geoip.text = "" else: tz = TimeZoneInfo(val, False) self.tzc.deserialize(val) self.assertEqual(val, self.tzc.serialize()) - self.assertEqual(settz, self.tzc.model.should_set_tz, - self.tzc.model) - self.assertEqual(geoip, self.tzc.model.detect_with_geoip, - self.tzc.model) + self.assertEqual(settz, self.tzc.model.should_set_tz, self.tzc.model) + self.assertEqual(geoip, self.tzc.model.detect_with_geoip, self.tzc.model) self.assertEqual(tz, await self.tzc.GET(), self.tzc.model) cloudconfig = {} if self.tzc.model.should_set_tz: - cloudconfig = {'timezone': tz.timezone} + cloudconfig = {"timezone": tz.timezone} tdc_settz.assert_called_with(self.tzc.app, tz.timezone) - self.assertEqual(cloudconfig, self.tzc.model.make_cloudconfig(), - self.tzc.model) + self.assertEqual( + cloudconfig, self.tzc.model.make_cloudconfig(), self.tzc.model + ) def test_bad_tzs(self): bads = [ - 'dhcp', # possible future value, not supported yet - 'notatimezone', + "dhcp", # possible future value, not supported yet + "notatimezone", ] for b in bads: with self.assertRaises(ValueError): self.tzc.deserialize(b) - @mock.patch('subprocess.run') - @mock.patch('subiquity.server.controllers.timezone.timedatectl_gettz') + @mock.patch("subprocess.run") + @mock.patch("subiquity.server.controllers.timezone.timedatectl_gettz") def test_set_tz_escape_dryrun(self, tdc_gettz, subprocess_run): tdc_gettz.return_value = tz_utc self.tzc.app.dry_run = True - self.tzc.possible = ['geoip'] - self.tzc.deserialize('geoip') - self.assertEqual('sleep', subprocess_run.call_args.args[0][0]) + self.tzc.possible = ["geoip"] + self.tzc.deserialize("geoip") + self.assertEqual("sleep", subprocess_run.call_args.args[0][0]) - @mock.patch('subiquity.server.controllers.timezone.timedatectl_settz') + @mock.patch("subiquity.server.controllers.timezone.timedatectl_settz") async def test_get_tz_should_not_set(self, tdc_settz): await self.tzc.GET() self.assertFalse(self.tzc.model.should_set_tz) diff --git a/subiquity/ui/frame.py b/subiquity/ui/frame.py index 798d917e..a13de48b 100644 --- a/subiquity/ui/frame.py +++ b/subiquity/ui/frame.py @@ -18,12 +18,10 @@ import logging from subiquitycore.ui.frame import SubiquityCoreUI from subiquitycore.view import BaseView - -log = logging.getLogger('subiquity.ui.frame') +log = logging.getLogger("subiquity.ui.frame") class SubiquityUI(SubiquityCoreUI): - block_input = False def __init__(self, app, right_icon): diff --git a/subiquity/ui/mount.py b/subiquity/ui/mount.py index 214ab968..92af83c2 100644 --- a/subiquity/ui/mount.py +++ b/subiquity/ui/mount.py @@ -1,40 +1,30 @@ - import re -from urwid import ( - connect_signal, - Text, - ) +from urwid import Text, connect_signal -from subiquitycore.ui.container import ( - Columns, - Pile, - WidgetWrap, - ) +from subiquitycore.ui.container import Columns, Pile, WidgetWrap from subiquitycore.ui.form import FormField from subiquitycore.ui.interactive import Selector, StringEditor from subiquitycore.ui.utils import Color - common_mountpoints = [ - '/', - '/boot', - '/home', - '/srv', - '/usr', - '/var', - '/var/lib', - ] + "/", + "/boot", + "/home", + "/srv", + "/usr", + "/var", + "/var/lib", +] class _MountEditor(StringEditor): - """ Mountpoint input prompt with input rules - """ + """Mountpoint input prompt with input rules""" def keypress(self, size, key): - ''' restrict what chars we allow for mountpoints ''' + """restrict what chars we allow for mountpoints""" - mountpoint = r'[a-zA-Z0-9_/\.\-]' + mountpoint = r"[a-zA-Z0-9_/\.\-]" if re.match(mountpoint, key) is None: return False @@ -45,16 +35,15 @@ OTHER = object() LEAVE_UNMOUNTED = object() suitable_mountpoints_for_existing_fs = [ - '/home', - '/srv', + "/home", + "/srv", OTHER, LEAVE_UNMOUNTED, - ] +] class MountSelector(WidgetWrap): - - signals = ['change'] + signals = ["change"] def __init__(self, mountpoints): opts = [] @@ -68,11 +57,11 @@ class MountSelector(WidgetWrap): opts.append((mnt, False)) if first_opt is None: first_opt = len(opts) - opts.append((_('Other'), True, OTHER)) - opts.append(('---', False)), - opts.append((_('Leave unmounted'), True, LEAVE_UNMOUNTED)) + opts.append((_("Other"), True, OTHER)) + opts.append(("---", False)), + opts.append((_("Leave unmounted"), True, LEAVE_UNMOUNTED)) self._selector = Selector(opts, first_opt) - connect_signal(self._selector, 'select', self._select_mount) + connect_signal(self._selector, "select", self._select_mount) self._other = _MountEditor() super().__init__(Pile([self._selector])) self._other_showing = False @@ -93,8 +82,11 @@ class MountSelector(WidgetWrap): def _showhide_other(self, show): if show and not self._other_showing: self._w.contents.append( - (Columns([(1, Text("/")), Color.string_input(self._other)]), - self._w.options('pack'))) + ( + Columns([(1, Text("/")), Color.string_input(self._other)]), + self._w.options("pack"), + ) + ) self._other_showing = True elif not show and self._other_showing: del self._w.contents[-1] @@ -105,7 +97,7 @@ class MountSelector(WidgetWrap): if value == OTHER: self._w.focus_position = 1 value = "/" + self._other.value - self._emit('change', value) + self._emit("change", value) @property def value(self): @@ -128,13 +120,12 @@ class MountSelector(WidgetWrap): else: self._selector.value = OTHER self._showhide_other(True) - if not val.startswith('/'): + if not val.startswith("/"): raise ValueError(f"{val} does not start with /") self._other.value = val[1:] class MountField(FormField): - takes_default_style = False def _make_widget(self, form): diff --git a/subiquity/ui/views/__init__.py b/subiquity/ui/views/__init__.py index f6e552af..7197a075 100644 --- a/subiquity/ui/views/__init__.py +++ b/subiquity/ui/views/__init__.py @@ -13,21 +13,19 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from .filesystem import ( - FilesystemView, - GuidedDiskSelectionView, - ) +from .filesystem import FilesystemView, GuidedDiskSelectionView from .identity import IdentityView from .installprogress import ProgressView from .keyboard import KeyboardView from .welcome import WelcomeView from .zdev import ZdevView + __all__ = [ - 'FilesystemView', - 'GuidedDiskSelectionView', - 'IdentityView', - 'KeyboardView', - 'ProgressView', - 'WelcomeView', - 'ZdevView', + "FilesystemView", + "GuidedDiskSelectionView", + "IdentityView", + "KeyboardView", + "ProgressView", + "WelcomeView", + "ZdevView", ] diff --git a/subiquity/ui/views/drivers.py b/subiquity/ui/views/drivers.py index 01e3b7d6..200f9931 100644 --- a/subiquity/ui/views/drivers.py +++ b/subiquity/ui/views/drivers.py @@ -16,44 +16,35 @@ """ Module defining the view for third-party drivers installation. """ -from enum import auto, Enum import logging +from enum import Enum, auto from typing import List, Optional -from urwid import ( - connect_signal, - Text, - ) +from urwid import Text, connect_signal from subiquitycore.async_helpers import run_bg_task from subiquitycore.ui.buttons import back_btn, ok_btn -from subiquitycore.ui.form import ( - Form, - RadioButtonField, -) +from subiquitycore.ui.form import Form, RadioButtonField from subiquitycore.ui.spinner import Spinner from subiquitycore.ui.utils import screen from subiquitycore.view import BaseView - -log = logging.getLogger('subiquity.ui.views.drivers') +log = logging.getLogger("subiquity.ui.views.drivers") class DriversForm(Form): - """ Form that shows a checkbox to configure whether we want to install the - available drivers or not. """ + """Form that shows a checkbox to configure whether we want to install the + available drivers or not.""" cancel_label = _("Back") ok_label = _("Continue") group: List[RadioButtonField] = [] - install = RadioButtonField( - group, - _("Install all third-party drivers")) + install = RadioButtonField(group, _("Install all third-party drivers")) do_not_install = RadioButtonField( - group, - _("Do not install third-party drivers now")) + group, _("Do not install third-party drivers now") + ) class DriversViewStatus(Enum): @@ -63,20 +54,24 @@ class DriversViewStatus(Enum): class DriversView(BaseView): - title = _("Third-party drivers") form = None - def __init__(self, controller, drivers: Optional[List[str]], - install: bool, local_only: bool) -> None: + def __init__( + self, controller, drivers: Optional[List[str]], install: bool, local_only: bool + ) -> None: self.controller = controller self.local_only = local_only self.search_later = [ - Text(_("Note: Once the installation has finished and you are " + - "connected to a network, you can search again for " + - "third-party drivers using the following command:")), + Text( + _( + "Note: Once the installation has finished and you are " + + "connected to a network, you can search again for " + + "third-party drivers using the following command:" + ) + ), Text(""), Text(" $ ubuntu-drivers list --recommended --gpgpu"), ] @@ -89,34 +84,36 @@ class DriversView(BaseView): self.make_main(install, drivers) def make_waiting(self, install: bool) -> None: - """ Change the view into a spinner and start waiting for drivers - asynchronously. """ - self.spinner = Spinner(style='dots') + """Change the view into a spinner and start waiting for drivers + asynchronously.""" + self.spinner = Spinner(style="dots") self.spinner.start() if self.local_only: - looking_for_drivers = _("Not connected to a network. " + - "Looking for applicable third-party " + - "drivers available locally...") + looking_for_drivers = _( + "Not connected to a network. " + + "Looking for applicable third-party " + + "drivers available locally..." + ) else: - looking_for_drivers = _("Looking for applicable third-party " + - "drivers available locally or online...") + looking_for_drivers = _( + "Looking for applicable third-party " + + "drivers available locally or online..." + ) rows = [ Text(looking_for_drivers), Text(""), self.spinner, - ] - self.back_btn = back_btn( - _("Back"), - on_press=lambda sender: self.cancel()) + ] + self.back_btn = back_btn(_("Back"), on_press=lambda sender: self.cancel()) self._w = screen(rows, [self.back_btn]) run_bg_task(self._wait(install)) self.status = DriversViewStatus.WAITING async def _wait(self, install: bool) -> None: - """ Wait until the "list" of drivers is available and change the view - accordingly. """ + """Wait until the "list" of drivers is available and change the view + accordingly.""" drivers = await self.controller._wait_drivers() self.spinner.stop() if drivers: @@ -125,48 +122,50 @@ class DriversView(BaseView): self.make_no_drivers() def make_no_drivers(self) -> None: - """ Change the view into an information page that shows that no - third-party drivers are available for installation. """ + """Change the view into an information page that shows that no + third-party drivers are available for installation.""" if self.local_only: - no_drivers_found = _("No applicable third-party drivers are " + - "available locally.") + no_drivers_found = _( + "No applicable third-party drivers are " + "available locally." + ) else: - no_drivers_found = _("No applicable third-party drivers are " + - "available locally or online.") + no_drivers_found = _( + "No applicable third-party drivers are " + + "available locally or online." + ) rows = [Text(no_drivers_found)] if self.local_only: rows.append(Text("")) rows.extend(self.search_later) - self.cont_btn = ok_btn( - _("Continue"), - on_press=lambda sender: self.done(False)) - self.back_btn = back_btn( - _("Back"), - on_press=lambda sender: self.cancel()) + self.cont_btn = ok_btn(_("Continue"), on_press=lambda sender: self.done(False)) + self.back_btn = back_btn(_("Back"), on_press=lambda sender: self.cancel()) self._w = screen(rows, [self.cont_btn, self.back_btn]) self.status = DriversViewStatus.NO_DRIVERS def make_main(self, install: bool, drivers: List[str]) -> None: - """ Change the view to display the drivers form. """ - self.form = DriversForm(initial={ - "install": bool(install), - "do_not_install": (not install), - }) + """Change the view to display the drivers form.""" + self.form = DriversForm( + initial={ + "install": bool(install), + "do_not_install": (not install), + } + ) excerpt = _( "The following third-party drivers were found. " - "Do you want to install them?") + "Do you want to install them?" + ) def on_cancel(_: DriversForm) -> None: self.cancel() connect_signal( - self.form, 'submit', - lambda result: self.done(result.install.value)) - connect_signal(self.form, 'cancel', on_cancel) + self.form, "submit", lambda result: self.done(result.install.value) + ) + connect_signal(self.form, "cancel", on_cancel) rows = [Text(f"* {driver}") for driver in drivers] rows.append(Text("")) diff --git a/subiquity/ui/views/error.py b/subiquity/ui/views/error.py index 1f8a9278..13453185 100644 --- a/subiquity/ui/views/error.py +++ b/subiquity/ui/views/error.py @@ -16,137 +16,175 @@ import asyncio import logging -from urwid import ( - connect_signal, - disconnect_signal, - Padding, - ProgressBar, - Text, - ) +from urwid import Padding, ProgressBar, Text, connect_signal, disconnect_signal +from subiquity.common.errorreport import ErrorReportKind, ErrorReportState +from subiquity.common.types import CasperMd5Results from subiquitycore.async_helpers import run_bg_task from subiquitycore.ui.buttons import other_btn -from subiquitycore.ui.container import ( - Pile, - ) +from subiquitycore.ui.container import Pile from subiquitycore.ui.spinner import Spinner from subiquitycore.ui.stretchy import Stretchy -from subiquitycore.ui.table import ( - ColSpec, - TablePile, - TableRow, - ) -from subiquitycore.ui.utils import ( - button_pile, - ClickableIcon, - Color, - disabled, - rewrap, - ) -from subiquitycore.ui.width import ( - widget_width, - ) +from subiquitycore.ui.table import ColSpec, TablePile, TableRow +from subiquitycore.ui.utils import ClickableIcon, Color, button_pile, disabled, rewrap +from subiquitycore.ui.width import widget_width -from subiquity.common.errorreport import ( - ErrorReportKind, - ErrorReportState, - ) - -from subiquity.common.types import CasperMd5Results - - -log = logging.getLogger('subiquity.ui.views.error') +log = logging.getLogger("subiquity.ui.views.error") def close_btn(stretchy, label=None): if label is None: label = _("Close") return other_btn( - label, - on_press=lambda sender: stretchy.app.remove_global_overlay(stretchy)) + label, on_press=lambda sender: stretchy.app.remove_global_overlay(stretchy) + ) error_report_intros = { - ErrorReportKind.BLOCK_PROBE_FAIL: _(""" + ErrorReportKind.BLOCK_PROBE_FAIL: _( + """ Sorry, there was a problem examining the storage devices on this system. -"""), - ErrorReportKind.DISK_PROBE_FAIL: _(""" +""" + ), + ErrorReportKind.DISK_PROBE_FAIL: _( + """ Sorry, there was a problem examining the storage devices on this system. -"""), - ErrorReportKind.INSTALL_FAIL: _(""" +""" + ), + ErrorReportKind.INSTALL_FAIL: _( + """ Sorry, there was a problem completing the installation. -"""), - ErrorReportKind.NETWORK_FAIL: _(""" +""" + ), + ErrorReportKind.NETWORK_FAIL: _( + """ Sorry, there was a problem applying the network configuration. -"""), - ErrorReportKind.SERVER_REQUEST_FAIL: _(""" +""" + ), + ErrorReportKind.SERVER_REQUEST_FAIL: _( + """ Sorry, the installer has encountered an internal error. -"""), - ErrorReportKind.UI: _(""" +""" + ), + ErrorReportKind.UI: _( + """ Sorry, the installer has restarted because of an error. -"""), - ErrorReportKind.UNKNOWN: _(""" +""" + ), + ErrorReportKind.UNKNOWN: _( + """ Sorry, an unknown error occurred. -"""), +""" + ), } error_report_state_descriptions = { - ErrorReportState.INCOMPLETE: (_(""" + ErrorReportState.INCOMPLETE: ( + _( + """ Information is being collected from the system that will help the developers diagnose the report. -"""), True), - ErrorReportState.LOADING: (_(""" +""" + ), + True, + ), + ErrorReportState.LOADING: ( + _( + """ Loading report... -"""), True), - ErrorReportState.ERROR_GENERATING: (_(""" +""" + ), + True, + ), + ErrorReportState.ERROR_GENERATING: ( + _( + """ Collecting information from the system failed. See the files in /var/log/installer for more. -"""), False), - ErrorReportState.ERROR_LOADING: (_(""" +""" + ), + False, + ), + ErrorReportState.ERROR_LOADING: ( + _( + """ Loading the report failed. See the files in /var/log/installer for more. -"""), False), +""" + ), + False, + ), } error_report_options = { - ErrorReportKind.BLOCK_PROBE_FAIL: (_(""" + ErrorReportKind.BLOCK_PROBE_FAIL: ( + _( + """ You can continue and the installer will just present the disks present in the system and not other block devices, or you may be able to fix the issue by switching to a shell and reconfiguring the system's block devices manually. -"""), ['debug_shell', 'continue']), - ErrorReportKind.DISK_PROBE_FAIL: (_(""" +""" + ), + ["debug_shell", "continue"], + ), + ErrorReportKind.DISK_PROBE_FAIL: ( + _( + """ You may be able to fix the issue by switching to a shell and reconfiguring the system's block devices manually. -"""), ['debug_shell', 'continue']), - ErrorReportKind.NETWORK_FAIL: (_(""" +""" + ), + ["debug_shell", "continue"], + ), + ErrorReportKind.NETWORK_FAIL: ( + _( + """ You can continue with the installation but it will be assumed the network is not functional. -"""), ['continue']), - ErrorReportKind.SERVER_REQUEST_FAIL: (_(""" +""" + ), + ["continue"], + ), + ErrorReportKind.SERVER_REQUEST_FAIL: ( + _( + """ You can continue or restart the installer. -"""), ['continue', 'restart']), - ErrorReportKind.INSTALL_FAIL: (_(""" +""" + ), + ["continue", "restart"], + ), + ErrorReportKind.INSTALL_FAIL: ( + _( + """ Do you want to try starting the installation again? -"""), ['restart', 'close']), +""" + ), + ["restart", "close"], + ), ErrorReportKind.UI: ( - _("Select continue to try the installation again."), ['continue']), - ErrorReportKind.UNKNOWN: ("", ['close']), + _("Select continue to try the installation again."), + ["continue"], + ), + ErrorReportKind.UNKNOWN: ("", ["close"]), } -submit_text = _(""" +submit_text = _( + """ If you want to help improve the installer, you can send an error report. -""") +""" +) -integrity_check_fail_text = _(""" +integrity_check_fail_text = _( + """ The install media checksum verification failed. It's possible that this crash is related to that checksum failure. Consider verifying the install media and retrying the install. -""") +""" +) class ErrorReportStretchy(Stretchy): - def __init__(self, app, ref, interrupting=True): self.app = app self.error_ref = ref @@ -156,58 +194,55 @@ class ErrorReportStretchy(Stretchy): if self.report is None: run_bg_task(self._wait()) else: - connect_signal(self.report, 'changed', self._report_changed) + connect_signal(self.report, "changed", self._report_changed) self.report.mark_seen() self.interrupting = interrupting self.min_wait = asyncio.create_task(asyncio.sleep(0.1)) self.btns = { - 'cancel': other_btn( - _("Cancel upload"), on_press=self.cancel_upload), - 'close': close_btn(self, _("Close report")), - 'continue': close_btn(self, _("Continue")), - 'debug_shell': other_btn( - _("Switch to a shell"), on_press=self.debug_shell), - 'restart': other_btn( - _("Restart the installer"), on_press=self.restart), - 'submit': other_btn( - _("Send to Canonical"), on_press=self.submit), - 'submitted': disabled(other_btn(_("Sent to Canonical"))), - 'view': other_btn( - _("View full report"), on_press=self.view_report), - } + "cancel": other_btn(_("Cancel upload"), on_press=self.cancel_upload), + "close": close_btn(self, _("Close report")), + "continue": close_btn(self, _("Continue")), + "debug_shell": other_btn(_("Switch to a shell"), on_press=self.debug_shell), + "restart": other_btn(_("Restart the installer"), on_press=self.restart), + "submit": other_btn(_("Send to Canonical"), on_press=self.submit), + "submitted": disabled(other_btn(_("Sent to Canonical"))), + "view": other_btn(_("View full report"), on_press=self.view_report), + } w = 0 for n, b in self.btns.items(): w = max(w, widget_width(b)) for n, b in self.btns.items(): - self.btns[n] = Padding(b, width=w, align='center') + self.btns[n] = Padding(b, width=w, align="center") - self.spinner = Spinner(style='dots') + self.spinner = Spinner(style="dots") self.pile = Pile([]) self.pile.contents[:] = [ - (w, self.pile.options('pack')) for w in self._pile_elements()] + (w, self.pile.options("pack")) for w in self._pile_elements() + ] super().__init__("", [self.pile], 0, 0) - connect_signal(self, 'closed', self.spinner.stop) + connect_signal(self, "closed", self.spinner.stop) async def _wait(self): - self.report = await self.app.error_reporter.get_wait( - self.error_ref) + self.report = await self.app.error_reporter.get_wait(self.error_ref) self.error_ref = self.report.ref() - connect_signal(self.report, 'changed', self._report_changed) + connect_signal(self.report, "changed", self._report_changed) self.report.mark_seen() await self._report_changed_() def pb(self, upload): pb = ProgressBar( - normal='progress_incomplete', - complete='progress_complete', + normal="progress_incomplete", + complete="progress_complete", current=upload.bytes_sent, - done=upload.bytes_to_send) + done=upload.bytes_to_send, + ) def _progress(): pb.done = upload.bytes_to_send pb.current = upload.bytes_sent - connect_signal(upload, 'progress', _progress) + + connect_signal(upload, "progress", _progress) return pb @@ -217,13 +252,13 @@ class ErrorReportStretchy(Stretchy): widgets = [ Text(rewrap(_(error_report_intros[self.error_ref.kind]))), Text(""), - ] + ] self.spinner.stop() if self.error_ref.state == ErrorReportState.DONE: assert self.report - widgets.append(btns['view']) + widgets.append(btns["view"]) widgets.append(Text("")) widgets.append(Text(rewrap(_(submit_text)))) widgets.append(Text("")) @@ -234,36 +269,36 @@ class ErrorReportStretchy(Stretchy): widgets.append(self.upload_pb) else: if self.report.oops_id: - widgets.append(btns['submitted']) + widgets.append(btns["submitted"]) else: - widgets.append(btns['submit']) + widgets.append(btns["submit"]) self.upload_pb = None fs_label, fs_loc = self.report.persistent_details if fs_label is not None: location_text = _( "The error report has been saved to\n\n {loc}\n\non the " - "filesystem with label {label!r}.").format( - loc=fs_loc, label=fs_label) - widgets.extend([ - Text(""), - Text(location_text), - ]) + "filesystem with label {label!r}." + ).format(loc=fs_loc, label=fs_label) + widgets.extend( + [ + Text(""), + Text(location_text), + ] + ) else: text, spin = error_report_state_descriptions[self.error_ref.state] widgets.append(Text(rewrap(_(text)))) if spin: self.spinner.start() - widgets.extend([ - Text(""), - self.spinner]) + widgets.extend([Text(""), self.spinner]) if self.integrity_check_result == CasperMd5Results.FAIL: widgets.append(Text("")) widgets.append(Text(rewrap(_(integrity_check_fail_text)))) if self.report and self.report.uploader: - widgets.extend([Text(""), btns['cancel']]) + widgets.extend([Text(""), btns["cancel"]]) elif self.interrupting: if self.error_ref.state != ErrorReportState.INCOMPLETE: text, btn_names = error_report_options[self.error_ref.kind] @@ -272,10 +307,12 @@ class ErrorReportStretchy(Stretchy): for b in btn_names: widgets.extend([Text(""), btns[b]]) else: - widgets.extend([ - Text(""), - btns['close'], - ]) + widgets.extend( + [ + Text(""), + btns["close"], + ] + ) return widgets @@ -283,8 +320,7 @@ class ErrorReportStretchy(Stretchy): if self.pending: self.pending.cancel() self.pending = asyncio.create_task(asyncio.sleep(0.1)) - self.change_task = asyncio.create_task( - self._report_changed_()) + self.change_task = asyncio.create_task(self._report_changed_()) async def _report_changed_(self): await self.pending @@ -295,7 +331,8 @@ class ErrorReportStretchy(Stretchy): self.error_ref = self.report.ref() self.integrity_check_result = await self.app.client.integrity.GET() self.pile.contents[:] = [ - (w, self.pile.options('pack')) for w in self._pile_elements()] + (w, self.pile.options("pack")) for w in self._pile_elements() + ] if self.pile.selectable(): while not self.pile.focus.selectable(): self.pile.focus_position += 1 @@ -319,21 +356,23 @@ class ErrorReportStretchy(Stretchy): def closed(self): if self.report: - disconnect_signal(self.report, 'changed', self._report_changed) + disconnect_signal(self.report, "changed", self._report_changed) class ErrorReportListStretchy(Stretchy): - def __init__(self, app): self.app = app rows = [ - TableRow([ - Text(""), - Text(_("DATE")), - Text(_("KIND")), - Text(_("STATUS")), - Text(""), - ])] + TableRow( + [ + Text(""), + Text(_("DATE")), + Text(_("KIND")), + Text(_("STATUS")), + Text(""), + ] + ) + ] self.report_to_row = {} self.app.error_reporter.load_reports() for report in self.app.error_reporter.reports: @@ -347,12 +386,11 @@ class ErrorReportListStretchy(Stretchy): self.table, Text(""), button_pile([close_btn(self)]), - ] + ] super().__init__("", widgets, 2, 2) def open_report(self, sender, report): - self.app.add_global_overlay( - ErrorReportStretchy(self.app, report, False)) + self.app.add_global_overlay(ErrorReportStretchy(self.app, report, False)) def state_for_report(self, report): if report.seen: @@ -362,18 +400,17 @@ class ErrorReportListStretchy(Stretchy): def cells_for_report(self, report): date = report.pr.get("Date", "???") icon = ClickableIcon(date) - connect_signal(icon, 'click', self.open_report, report) + connect_signal(icon, "click", self.open_report, report) return [ Text("["), icon, Text(_(report.kind.value)), Text(_(self.state_for_report(report))), Text("]"), - ] + ] def row_for_report(self, report): - return Color.menu_button( - TableRow(self.cells_for_report(report))) + return Color.menu_button(TableRow(self.cells_for_report(report))) def _report_changed(self, report): old_r = self.report_to_row.get(report) diff --git a/subiquity/ui/views/filesystem/__init__.py b/subiquity/ui/views/filesystem/__init__.py index cf66c905..c766c160 100644 --- a/subiquity/ui/views/filesystem/__init__.py +++ b/subiquity/ui/views/filesystem/__init__.py @@ -22,7 +22,8 @@ configuration. from .filesystem import FilesystemView from .guided import GuidedDiskSelectionView + __all__ = [ - 'FilesystemView', - 'GuidedDiskSelectionView', + "FilesystemView", + "GuidedDiskSelectionView", ] diff --git a/subiquity/ui/views/filesystem/compound.py b/subiquity/ui/views/filesystem/compound.py index 1d874307..f34a128d 100644 --- a/subiquity/ui/views/filesystem/compound.py +++ b/subiquity/ui/views/filesystem/compound.py @@ -15,46 +15,26 @@ import logging -from urwid import ( - CheckBox, - connect_signal, - Padding as UrwidPadding, - Text, - ) - -from subiquitycore.ui.container import ( - WidgetWrap, - ) -from subiquitycore.ui.form import ( - Form, - simple_field, - Toggleable, - WantsToKnowFormField, - ) -from subiquitycore.ui.selector import ( - Selector, - ) -from subiquitycore.ui.table import ( - TablePile, - TableRow, - ) -from subiquitycore.ui.utils import ( - Color, - ) +from urwid import CheckBox +from urwid import Padding as UrwidPadding +from urwid import Text, connect_signal from subiquity.common.filesystem import labels -from subiquity.models.filesystem import ( - humanize_size, - ) +from subiquity.models.filesystem import humanize_size +from subiquitycore.ui.container import WidgetWrap +from subiquitycore.ui.form import Form, Toggleable, WantsToKnowFormField, simple_field +from subiquitycore.ui.selector import Selector +from subiquitycore.ui.table import TablePile, TableRow +from subiquitycore.ui.utils import Color -log = logging.getLogger('subiquity.ui.views.filesystem.compound') +log = logging.getLogger("subiquity.ui.views.filesystem.compound") LABEL, DEVICE, PART = range(3) class MultiDeviceChooser(WidgetWrap, WantsToKnowFormField): - signals = ['change'] + signals = ["change"] def __init__(self): self.table = TablePile([], spacing=1) @@ -85,13 +65,11 @@ class MultiDeviceChooser(WidgetWrap, WantsToKnowFormField): @property def active_devices(self): - return {device for device, status in self.devices.items() - if status == 'active'} + return {device for device, status in self.devices.items() if status == "active"} @property def spare_devices(self): - return {device for device, status in self.devices.items() - if status == 'spare'} + return {device for device, status in self.devices.items() if status == "spare"} def set_supports_spares(self, val): if val == self.supports_spares: @@ -106,7 +84,7 @@ class MultiDeviceChooser(WidgetWrap, WantsToKnowFormField): else: for device in list(self.devices): self.device_to_selector[device].enabled = False - self.devices[device] = 'active' + self.devices[device] = "active" self.table.set_contents(self.no_selector_rows) def _state_change_device(self, sender, state, device): @@ -118,11 +96,11 @@ class MultiDeviceChooser(WidgetWrap, WantsToKnowFormField): else: self.device_to_selector[device].enabled = False del self.devices[device] - self._emit('change', self.devices) + self._emit("change", self.devices) def _select_active_spare(self, sender, value, device): self.devices[device] = value - self._emit('change', self.devices) + self._emit("change", self.devices) def _summarize(self, prefix, device): if device.fs() is not None: @@ -133,8 +111,7 @@ class MultiDeviceChooser(WidgetWrap, WantsToKnowFormField): else: text += _(", not mounted") else: - text = prefix + _("unused {device}").format( - device=labels.desc(device)) + text = prefix + _("unused {device}").format(device=labels.desc(device)) return TableRow([(2, Color.info_minor(Text(text)))]) def set_bound_form_field(self, bff): @@ -142,14 +119,20 @@ class MultiDeviceChooser(WidgetWrap, WantsToKnowFormField): self.all_rows = [] for kind, device in bff.form.possible_components: if kind == LABEL: - self.all_rows.append(TableRow([ - Text(" " + labels.label(device)), - Text(humanize_size(device.size), align='right') - ])) + self.all_rows.append( + TableRow( + [ + Text(" " + labels.label(device)), + Text(humanize_size(device.size), align="right"), + ] + ) + ) self.no_selector_rows.append(self.all_rows[-1]) - self.all_rows.append(TableRow([ - (2, Color.info_minor(Text(" " + labels.desc(device)))) - ])) + self.all_rows.append( + TableRow( + [(2, Color.info_minor(Text(" " + labels.desc(device))))] + ) + ) self.no_selector_rows.append(self.all_rows[-1]) else: label = labels.label(device, short=True) @@ -161,20 +144,17 @@ class MultiDeviceChooser(WidgetWrap, WantsToKnowFormField): else: raise Exception("unexpected kind {}".format(kind)) box = CheckBox( - label, - on_state_change=self._state_change_device, - user_data=device) + label, on_state_change=self._state_change_device, user_data=device + ) self.device_to_checkbox[device] = box - size = Text(humanize_size(device.size), align='right') + size = Text(humanize_size(device.size), align="right") self.all_rows.append(Color.menu_button(TableRow([box, size]))) self.no_selector_rows.append(self.all_rows[-1]) - selector = Selector(['active', 'spare']) - connect_signal( - selector, 'select', self._select_active_spare, device) + selector = Selector(["active", "spare"]) + connect_signal(selector, "select", self._select_active_spare, device) selector = Toggleable( - UrwidPadding( - Color.menu_button(selector), - left=len(prefix))) + UrwidPadding(Color.menu_button(selector), left=len(prefix)) + ) selector.enabled = False self.device_to_selector[device] = selector self.all_rows.append(TableRow([(2, selector)])) @@ -189,7 +169,6 @@ MultiDeviceField.takes_default_style = False class CompoundDiskForm(Form): - def __init__(self, model, possible_components, initial): self.model = model self.possible_components = possible_components @@ -214,9 +193,11 @@ class CompoundDiskForm(Form): # reuse) potential_boot_disks.add(d) if not potential_boot_disks - set(mdc.value): - return _("\ + return _( + "\ If you put all disks into RAIDs or LVM VGs, there will be nowhere \ -to put the boot partition.") +to put the boot partition." + ) def get_possible_components(model, existing, cur_devices, device_ok): diff --git a/subiquity/ui/views/filesystem/delete.py b/subiquity/ui/views/filesystem/delete.py index b8a00159..d2c5ed4e 100644 --- a/subiquity/ui/views/filesystem/delete.py +++ b/subiquity/ui/views/filesystem/delete.py @@ -13,36 +13,33 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from gettext import ngettext import logging +from gettext import ngettext from urwid import Text -from subiquitycore.ui.buttons import danger_btn, other_btn -from subiquitycore.ui.table import ( - TablePile, - TableRow, - ) -from subiquitycore.ui.utils import button_pile -from subiquitycore.ui.stretchy import Stretchy - from subiquity.common.filesystem import labels +from subiquitycore.ui.buttons import danger_btn, other_btn +from subiquitycore.ui.stretchy import Stretchy +from subiquitycore.ui.table import TablePile, TableRow +from subiquitycore.ui.utils import button_pile from .helpers import summarize_device - -log = logging.getLogger('subiquity.ui.views.filesystem.delete') +log = logging.getLogger("subiquity.ui.views.filesystem.delete") class ConfirmDeleteStretchy(Stretchy): - def __init__(self, parent, obj): self.parent = parent self.obj = obj lines = [ - Text(_("Do you really want to delete the {desc} {label}?").format( - desc=labels.desc(obj), label=labels.label(obj))), + Text( + _("Do you really want to delete the {desc} {label}?").format( + desc=labels.desc(obj), label=labels.label(obj) + ) + ), Text(""), ] stretchy_index = 0 @@ -50,27 +47,31 @@ class ConfirmDeleteStretchy(Stretchy): if fs is not None: m = fs.mount() if m is not None: - lines.append(Text(_( - "It is formatted as {fstype} and mounted at " - "{path}").format( - fstype=fs.fstype, - path=m.path))) + lines.append( + Text( + _( + "It is formatted as {fstype} and mounted at " "{path}" + ).format(fstype=fs.fstype, path=m.path) + ) + ) else: - lines.append(Text(_( - "It is formatted as {fstype} and not mounted.").format( - fstype=fs.fstype))) - elif hasattr(obj, 'partitions') and len(obj.partitions()) > 0: + lines.append( + Text( + _("It is formatted as {fstype} and not mounted.").format( + fstype=fs.fstype + ) + ) + ) + elif hasattr(obj, "partitions") and len(obj.partitions()) > 0: n = len(obj.partitions()) if obj.type == "lvm_volgroup": line = ngettext( - "It contains 1 logical volume", - "It contains {n} logical volumes", - n) + "It contains 1 logical volume", "It contains {n} logical volumes", n + ) else: line = ngettext( - "It contains 1 partition", - "It contains {n} partitions", - n) + "It contains 1 partition", "It contains {n} partitions", n + ) lines.append(Text(line.format(n=n))) lines.append(Text("")) stretchy_index = len(lines) @@ -81,10 +82,7 @@ class ConfirmDeleteStretchy(Stretchy): lines.append(TablePile(rows)) elif obj.type == "raid" and obj._subvolumes: n = len(obj._subvolumes) - line = ngettext( - "It contains 1 volume.", - "It contains {n} volumes.", - n) + line = ngettext("It contains 1 volume.", "It contains {n} volumes.", n) lines.append(Text(line.format(n=n))) lines.append(Text("")) else: @@ -93,12 +91,14 @@ class ConfirmDeleteStretchy(Stretchy): delete_btn = danger_btn(label=_("Delete"), on_press=self.confirm) widgets = lines + [ Text(""), - button_pile([ - delete_btn, - other_btn(label=_("Cancel"), on_press=self.cancel), - ]), + button_pile( + [ + delete_btn, + other_btn(label=_("Cancel"), on_press=self.cancel), + ] + ), ] - super().__init__("", widgets, stretchy_index, len(lines)+1) + super().__init__("", widgets, stretchy_index, len(lines) + 1) def confirm(self, sender=None): self.parent.controller.delete(self.obj) @@ -110,34 +110,33 @@ class ConfirmDeleteStretchy(Stretchy): class ConfirmReformatStretchy(Stretchy): - def __init__(self, parent, obj): self.parent = parent self.obj = obj fs = obj.fs() if fs is not None: - title = _( - "Remove filesystem from {device}" - ).format(device=labels.desc(obj)) + title = _("Remove filesystem from {device}").format(device=labels.desc(obj)) lines = [ _( "Do you really want to remove the existing filesystem " "from {device}?" - ).format(device=labels.label(obj)), + ).format(device=labels.label(obj)), "", ] m = fs.mount() if m is not None: - lines.append(_( - "It is formatted as {fstype} and mounted at " - "{path}").format( - fstype=fs.fstype, - path=m.path)) + lines.append( + _("It is formatted as {fstype} and mounted at " "{path}").format( + fstype=fs.fstype, path=m.path + ) + ) else: - lines.append(_( - "It is formatted as {fstype} and not mounted.").format( - fstype=fs.fstype)) + lines.append( + _("It is formatted as {fstype} and not mounted.").format( + fstype=fs.fstype + ) + ) else: if obj.type == "lvm_volgroup": things = _("logical volumes") @@ -145,12 +144,12 @@ class ConfirmReformatStretchy(Stretchy): things = _("partitions") # things is either "logical volumes" or "partitions" title = _("Remove all {things} from {obj}").format( - things=things, obj=labels.desc(obj)) + things=things, obj=labels.desc(obj) + ) lines = [ - _( - "Do you really want to remove all {things} from " - "{obj}?").format( - things=things, obj=labels.label(obj)), + _("Do you really want to remove all {things} from " "{obj}?").format( + things=things, obj=labels.label(obj) + ), "", ] # XXX summarize partitions here? @@ -159,10 +158,12 @@ class ConfirmReformatStretchy(Stretchy): widgets = [ Text("\n".join(lines)), Text(""), - button_pile([ - delete_btn, - other_btn(label=_("Cancel"), on_press=self.cancel), - ]), + button_pile( + [ + delete_btn, + other_btn(label=_("Cancel"), on_press=self.cancel), + ] + ), ] super().__init__(title, widgets, 0, 2) diff --git a/subiquity/ui/views/filesystem/disk_info.py b/subiquity/ui/views/filesystem/disk_info.py index 3140b1ec..bf7af250 100644 --- a/subiquity/ui/views/filesystem/disk_info.py +++ b/subiquity/ui/views/filesystem/disk_info.py @@ -14,30 +14,29 @@ # along with this program. If not, see . import logging + from urwid import Text +from subiquity.common.filesystem import labels from subiquitycore.ui.buttons import done_btn -from subiquitycore.ui.utils import button_pile from subiquitycore.ui.stretchy import Stretchy from subiquitycore.ui.table import ColSpec, TablePile, TableRow +from subiquitycore.ui.utils import button_pile -from subiquity.common.filesystem import labels - - -log = logging.getLogger('subiquity.ui.views.filesystem.disk_info') +log = logging.getLogger("subiquity.ui.views.filesystem.disk_info") labels_keys = [ - ('Path:', 'devname'), - ('Multipath:', 'multipath'), - ('Vendor:', 'vendor'), - ('Model:', 'model'), - ('SerialNo:', 'serial'), - ('WWN:', 'wwn'), - ('Size:', 'size'), - ('Bus:', 'bus'), - ('Rotational:', 'rotational'), - ('Path:', 'devpath'), + ("Path:", "devname"), + ("Multipath:", "multipath"), + ("Vendor:", "vendor"), + ("Model:", "model"), + ("SerialNo:", "serial"), + ("WWN:", "wwn"), + ("Size:", "size"), + ("Bus:", "bus"), + ("Rotational:", "rotational"), + ("Path:", "devpath"), ] @@ -48,12 +47,12 @@ class DiskInfoStretchy(Stretchy): rows = [] for label, key in labels_keys: v = str(dinfo[key]) - rows.append(TableRow([Text(label, align='right'), Text(v)])) + rows.append(TableRow([Text(label, align="right"), Text(v)])) widgets = [ TablePile(rows, colspecs={1: ColSpec(can_shrink=True)}), Text(""), button_pile([done_btn(_("Close"), on_press=self.close)]), - ] + ] title = _("Info for {device}").format(device=labels.label(disk)) super().__init__(title, widgets, 0, 2) diff --git a/subiquity/ui/views/filesystem/filesystem.py b/subiquity/ui/views/filesystem/filesystem.py index 0c9ab2a4..9ba17d2a 100644 --- a/subiquity/ui/views/filesystem/filesystem.py +++ b/subiquity/ui/views/filesystem/filesystem.py @@ -22,60 +22,34 @@ configuration. import logging import attr +from urwid import Text, connect_signal -from urwid import ( - connect_signal, - Text, - ) - -from subiquitycore.ui.actionmenu import ( - Action, - ActionMenu, - ActionMenuOpenButton, - ) -from subiquitycore.ui.buttons import ( - back_btn, - done_btn, - menu_btn, - other_btn, - reset_btn, - ) -from subiquitycore.ui.container import ( - ListBox, - WidgetWrap, - ) +from subiquity.common.filesystem import boot, gaps, labels +from subiquity.common.filesystem.actions import DeviceAction +from subiquity.models.filesystem import humanize_size +from subiquitycore.ui.actionmenu import Action, ActionMenu, ActionMenuOpenButton +from subiquitycore.ui.buttons import back_btn, done_btn, menu_btn, other_btn, reset_btn +from subiquitycore.ui.container import ListBox, WidgetWrap from subiquitycore.ui.form import Toggleable from subiquitycore.ui.stretchy import Stretchy -from subiquitycore.ui.table import ( - ColSpec, - TablePile, - TableRow, - ) +from subiquitycore.ui.table import ColSpec, TablePile, TableRow from subiquitycore.ui.utils import ( - button_pile, Color, - make_action_menu_row, Padding, + button_pile, + make_action_menu_row, screen, - ) +) from subiquitycore.view import BaseView -from subiquity.common.filesystem.actions import ( - DeviceAction, - ) -from subiquity.common.filesystem import boot, gaps, labels -from subiquity.models.filesystem import ( - humanize_size, - ) - from .delete import ConfirmDeleteStretchy, ConfirmReformatStretchy from .disk_info import DiskInfoStretchy from .helpers import summarize_device from .lvm import VolGroupStretchy -from .partition import PartitionStretchy, FormatEntireStretchy +from .partition import FormatEntireStretchy, PartitionStretchy from .raid import RaidStretchy -log = logging.getLogger('subiquity.ui.views.filesystem.filesystem') +log = logging.getLogger("subiquity.ui.views.filesystem.filesystem") @attr.s @@ -88,7 +62,7 @@ class MountInfo: @property def split_path(self): - return self.mount.path.split('/') + return self.mount.path.split("/") @property def size(self): @@ -121,23 +95,27 @@ class MountInfo: class MountList(WidgetWrap): - def __init__(self, parent): self.parent = parent - self.table = TablePile([], spacing=2, colspecs={ - 0: ColSpec(rpad=1), - 1: ColSpec(can_shrink=True), - 2: ColSpec(min_width=9), - 4: ColSpec(rpad=1), - 5: ColSpec(rpad=1), - }) + self.table = TablePile( + [], + spacing=2, + colspecs={ + 0: ColSpec(rpad=1), + 1: ColSpec(can_shrink=True), + 2: ColSpec(min_width=9), + 4: ColSpec(rpad=1), + 5: ColSpec(rpad=1), + }, + ) self._no_mounts_content = Color.info_minor( - Text(_("No disks or partitions mounted."))) + Text(_("No disks or partitions mounted.")) + ) super().__init__(self.table) def _mount_action(self, sender, action, mount): - log.debug('_mount_action %s %s', action, mount) - if action == 'unmount': + log.debug("_mount_action %s %s", action, mount) + if action == "unmount": self.parent.controller.delete_mount(mount) self.parent.refresh_model_inputs() @@ -145,8 +123,8 @@ class MountList(WidgetWrap): mountinfos = [ MountInfo(mount=m) for m in sorted( - self.parent.model.all_mounts(), - key=lambda m: (m.path == "", m.path)) + self.parent.model.all_mounts(), key=lambda m: (m.path == "", m.path) + ) ] if len(mountinfos) == 0: self.table.set_contents([]) @@ -154,42 +132,47 @@ class MountList(WidgetWrap): return self._w = self.table - rows = [TableRow([ - Color.info_minor(heading) for heading in [ - Text(" "), - Text(_("MOUNT POINT")), - Text(_("SIZE"), align='center'), - Text(_("TYPE")), - Text(_("DEVICE TYPE")), - Text(" "), - Text(" "), - ]])] + rows = [ + TableRow( + [ + Color.info_minor(heading) + for heading in [ + Text(" "), + Text(_("MOUNT POINT")), + Text(_("SIZE"), align="center"), + Text(_("TYPE")), + Text(_("DEVICE TYPE")), + Text(" "), + Text(" "), + ] + ] + ) + ] for i, mi in enumerate(mountinfos): path_markup = mi.path if path_markup == "": path_markup = "SWAP" else: - for j in range(i-1, -1, -1): + for j in range(i - 1, -1, -1): mi2 = mountinfos[j] if mi.startswith(mi2): - part1 = "/".join(mi.split_path[:len(mi2.split_path)]) - part2 = "/".join( - [''] + mi.split_path[len(mi2.split_path):]) - path_markup = [('info_minor', part1), part2] + part1 = "/".join(mi.split_path[: len(mi2.split_path)]) + part2 = "/".join([""] + mi.split_path[len(mi2.split_path) :]) + path_markup = [("info_minor", part1), part2] break - if j == 0 and mi2.split_path == ['', '']: + if j == 0 and mi2.split_path == ["", ""]: path_markup = [ - ('info_minor', "/"), + ("info_minor", "/"), "/".join(mi.split_path[1:]), - ] - actions = [(_("Unmount"), mi.mount.can_delete(), 'unmount')] + ] + actions = [(_("Unmount"), mi.mount.can_delete(), "unmount")] menu = ActionMenu(actions) - connect_signal(menu, 'action', self._mount_action, mi.mount) + connect_signal(menu, "action", self._mount_action, mi.mount) cells = [ Text("["), Text(path_markup), - Text(mi.size, align='right'), + Text(mi.size, align="right"), Text(mi.fstype), Text(mi.desc), menu, @@ -198,11 +181,12 @@ class MountList(WidgetWrap): row = make_action_menu_row( cells, menu, - attr_map='menu_button', + attr_map="menu_button", focus_map={ - None: 'menu_button focus', - 'info_minor': 'menu_button focus', - }) + None: "menu_button focus", + "info_minor": "menu_button focus", + }, + ) rows.append(row) self.table.set_contents(rows) if self.table._w.focus_position >= len(rows): @@ -212,25 +196,27 @@ class MountList(WidgetWrap): def _stretchy_shower(cls): def impl(self, device): self.parent.show_stretchy_overlay(cls(self.parent, device)) + impl.opens_dialog = True return impl class WhyNotStretchy(Stretchy): - def __init__(self, parent, obj, action, whynot): self.parent = parent self.obj = obj title = "Cannot {action} {type}".format( - action=_(action.value).lower(), - type=labels.desc(obj)) + action=_(action.value).lower(), type=labels.desc(obj) + ) widgets = [ Text(whynot), Text(""), - button_pile([ - other_btn(label=_("Close"), on_press=self.close), - ]), + button_pile( + [ + other_btn(label=_("Close"), on_press=self.close), + ] + ), ] super().__init__(title, widgets, 0, 2) @@ -241,21 +227,25 @@ class WhyNotStretchy(Stretchy): def _whynot_shower(view, action, whynot): def impl(obj): view.show_stretchy_overlay(WhyNotStretchy(view, obj, action, whynot)) + impl.opens_dialog = True return impl class DeviceList(WidgetWrap): - def __init__(self, parent, show_available): self.parent = parent self.show_available = show_available - self.table = TablePile([], spacing=2, colspecs={ - 0: ColSpec(rpad=1), - 2: ColSpec(can_shrink=True), - 4: ColSpec(min_width=9), - 5: ColSpec(rpad=1), - }) + self.table = TablePile( + [], + spacing=2, + colspecs={ + 0: ColSpec(rpad=1), + 2: ColSpec(can_shrink=True), + 4: ColSpec(min_width=9), + 5: ColSpec(rpad=1), + }, + ) if show_available: text = _("No available devices") else: @@ -280,7 +270,7 @@ class DeviceList(WidgetWrap): elif cd.type == "lvm_volgroup": cd.devices.remove(disk) else: - 1/0 + 1 / 0 disk._constructed_device = None self.parent.refresh_model_inputs() @@ -292,8 +282,8 @@ class DeviceList(WidgetWrap): self.parent.refresh_model_inputs() _partition_EDIT = _stretchy_shower( - lambda parent, part: PartitionStretchy(parent, part.device, - partition=part)) + lambda parent, part: PartitionStretchy(parent, part.device, partition=part) + ) _partition_REMOVE = _disk_REMOVE _partition_DELETE = _stretchy_shower(ConfirmDeleteStretchy) @@ -308,16 +298,17 @@ class DeviceList(WidgetWrap): _lvm_volgroup_DELETE = _partition_DELETE _lvm_partition_EDIT = _stretchy_shower( - lambda parent, part: PartitionStretchy(parent, part.volgroup, - partition=part)) + lambda parent, part: PartitionStretchy(parent, part.volgroup, partition=part) + ) _lvm_partition_DELETE = _partition_DELETE _gap_PARTITION = _stretchy_shower( - lambda parent, gap: PartitionStretchy(parent, gap.device, gap=gap)) + lambda parent, gap: PartitionStretchy(parent, gap.device, gap=gap) + ) def _action(self, sender, value, device): action, meth = value - log.debug('_action %s %s', action, device.id) + log.debug("_action %s %s", action, device.id) meth(device) def _label_REMOVE(self, action, device): @@ -329,11 +320,12 @@ class DeviceList(WidgetWrap): def _label_PARTITION(self, action, gap): device = gap.device - if device.type == 'lvm_volgroup': + if device.type == "lvm_volgroup": return _("Create Logical Volume") else: return _("Add {ptype} Partition").format( - ptype=device.ptable_for_new_partition().upper()) + ptype=device.ptable_for_new_partition().upper() + ) def _label_TOGGLE_BOOT(self, action, device): if boot.is_boot_device(device): @@ -348,7 +340,8 @@ class DeviceList(WidgetWrap): device_actions = [] for action in DeviceAction.supported(device): label_meth = getattr( - self, '_label_{}'.format(action.name), lambda a, d: a.str()) + self, "_label_{}".format(action.name), lambda a, d: a.str() + ) label = label_meth(action, device) enabled, whynot = action.can(device) if whynot: @@ -357,27 +350,32 @@ class DeviceList(WidgetWrap): label += " *" meth = _whynot_shower(self.parent, action, whynot) else: - meth_name = '_{}_{}'.format(device.type, action.name) + meth_name = "_{}_{}".format(device.type, action.name) meth = getattr(self, meth_name) - if not whynot and action in [DeviceAction.DELETE, - DeviceAction.REFORMAT]: + if not whynot and action in [DeviceAction.DELETE, DeviceAction.REFORMAT]: label = Color.danger_button(ActionMenuOpenButton(label)) - device_actions.append(Action( - label=label, - enabled=enabled, - value=(action, meth), - opens_dialog=getattr(meth, 'opens_dialog', False))) + device_actions.append( + Action( + label=label, + enabled=enabled, + value=(action, meth), + opens_dialog=getattr(meth, "opens_dialog", False), + ) + ) if not device_actions: return Text("") menu = ActionMenu(device_actions) - connect_signal(menu, 'action', self._action, device) + connect_signal(menu, "action", self._action, device) return menu def refresh_model_inputs(self): devices = [ - d for d in self.parent.model.all_devices() - if (d.available() == self.show_available - or (not self.show_available and d.has_unavailable_partition())) + d + for d in self.parent.model.all_devices() + if ( + d.available() == self.show_available + or (not self.show_available and d.has_unavailable_partition()) + ) ] if len(devices) == 0: self._w = Padding.push_2(self._no_devices_content) @@ -386,21 +384,30 @@ class DeviceList(WidgetWrap): self._w = self.table rows = [] - rows.append(Color.info_minor(TableRow([ - Text(""), - (2, Text(_("DEVICE"))), - Text(_("TYPE")), - Text(_("SIZE"), align="center"), - Text(""), - Text(""), - ]))) + rows.append( + Color.info_minor( + TableRow( + [ + Text(""), + (2, Text(_("DEVICE"))), + Text(_("TYPE")), + Text(_("SIZE"), align="center"), + Text(""), + Text(""), + ] + ) + ) + ) if self.show_available: + def filter(part): if isinstance(part, gaps.Gap): return True return part.available() + else: + def filter(part): if isinstance(part, gaps.Gap): return False @@ -410,9 +417,9 @@ class DeviceList(WidgetWrap): for obj, cells in summarize_device(device, filter): menu = self._action_menu_for_device(obj) if obj is device: - start, end = '[', ']' + start, end = "[", "]" else: - start, end = '', '' + start, end = "", "" cells = [Text(start)] + cells + [menu, Text(end)] if isinstance(menu, ActionMenu): rows.append(make_action_menu_row(cells, menu)) @@ -437,15 +444,15 @@ class FilesystemView(BaseView): self.avail_list = DeviceList(self, True) self.used_list = DeviceList(self, False) self.avail_list.table.bind(self.used_list.table) - self._create_raid_btn = Toggleable(menu_btn( - label=_("Create software RAID (md)"), - on_press=self.create_raid)) - self._create_vg_btn = Toggleable(menu_btn( - label=_("Create volume group (LVM)"), - on_press=self.create_vg)) + self._create_raid_btn = Toggleable( + menu_btn(label=_("Create software RAID (md)"), on_press=self.create_raid) + ) + self._create_vg_btn = Toggleable( + menu_btn(label=_("Create volume group (LVM)"), on_press=self.create_vg) + ) bp = button_pile([self._create_raid_btn, self._create_vg_btn]) - bp.align = 'left' + bp.align = "left" body = [ Text(_("FILE SYSTEM SUMMARY")), @@ -464,13 +471,15 @@ class FilesystemView(BaseView): Text(""), self.used_list, Text(""), - ] + ] self.lb = ListBox(body) self.lb.base_widget.offset_rows = 2 - super().__init__(screen( - self.lb, self._build_buttons(), - focus_buttons=self.model.can_install())) + super().__init__( + screen( + self.lb, self._build_buttons(), focus_buttons=self.model.can_install() + ) + ) self.frame = self._w.base_widget self.showing_guidance = False self.refresh_model_inputs() @@ -484,11 +493,13 @@ class FilesystemView(BaseView): if not todos: return None rows = [ - TableRow([ - Text(_("To continue you need to:")), - Text(todos[0]), - ]), - ] + TableRow( + [ + Text(_("To continue you need to:")), + Text(todos[0]), + ] + ), + ] for todo in todos[1:]: rows.append(TableRow([Text(""), Text(todo)])) rows.append(TableRow([Text("")])) @@ -501,7 +512,7 @@ class FilesystemView(BaseView): self.done_btn, reset_btn(_("Reset"), on_press=self.reset), back_btn(_("Back"), on_press=self.cancel), - ] + ] def refresh_model_inputs(self): lvm_devices = set() @@ -529,8 +540,7 @@ class FilesystemView(BaseView): del self.frame.contents[0] guidance = self._guidance() if guidance: - self.frame.contents.insert( - 0, (guidance, self.frame.options('pack'))) + self.frame.contents.insert(0, (guidance, self.frame.options("pack"))) self.showing_guidance = True else: self.showing_guidance = False diff --git a/subiquity/ui/views/filesystem/guided.py b/subiquity/ui/views/filesystem/guided.py index 8fd93212..dab44f7c 100644 --- a/subiquity/ui/views/filesystem/guided.py +++ b/subiquity/ui/views/filesystem/guided.py @@ -16,35 +16,7 @@ import logging import attr - -from urwid import ( - connect_signal, - Text, - ) - -from subiquitycore.ui.form import ( - BooleanField, - ChoiceField, - Form, - NO_CAPTION, - NO_HELP, - PasswordField, - RadioButtonField, - SubForm, - SubFormField, - ) -from subiquitycore.ui.buttons import other_btn -from subiquitycore.ui.selector import Option -from subiquitycore.ui.table import ( - TablePile, - TableRow, - ) -from subiquitycore.ui.utils import ( - Color, - rewrap, - screen, - ) -from subiquitycore.view import BaseView +from urwid import Text, connect_signal from subiquity.common.types import ( Gap, @@ -56,7 +28,22 @@ from subiquity.common.types import ( Partition, ) from subiquity.models.filesystem import humanize_size - +from subiquitycore.ui.buttons import other_btn +from subiquitycore.ui.form import ( + NO_CAPTION, + NO_HELP, + BooleanField, + ChoiceField, + Form, + PasswordField, + RadioButtonField, + SubForm, + SubFormField, +) +from subiquitycore.ui.selector import Option +from subiquitycore.ui.table import TablePile, TableRow +from subiquitycore.ui.utils import Color, rewrap, screen +from subiquitycore.view import BaseView log = logging.getLogger("subiquity.ui.views.filesystem.guided") @@ -64,7 +51,6 @@ subtitle = _("Configure a guided storage layout, or create a custom one:") class LUKSOptionsForm(SubForm): - passphrase = PasswordField(_("Passphrase:")) confirm_passphrase = PasswordField(_("Confirm passphrase:")) @@ -78,10 +64,9 @@ class LUKSOptionsForm(SubForm): class LVMOptionsForm(SubForm): - def __init__(self, parent): super().__init__(parent) - connect_signal(self.encrypt.widget, 'change', self._toggle) + connect_signal(self.encrypt.widget, "change", self._toggle) self.luks_options.enabled = self.encrypt.value def _toggle(self, sender, val): @@ -94,31 +79,39 @@ class LVMOptionsForm(SubForm): def summarize_device(disk): label = disk.label - rows = [(disk, [ - (2, Text(label)), - Text(disk.type), - Text(humanize_size(disk.size), align="right"), - ])] + rows = [ + ( + disk, + [ + (2, Text(label)), + Text(disk.type), + Text(humanize_size(disk.size), align="right"), + ], + ) + ] if disk.partitions: for part in disk.partitions: if isinstance(part, Partition): details = ", ".join(part.annotations) - rows.append((part, [ - Text(_("partition {number}").format(number=part.number)), - (2, Text(details)), - Text(humanize_size(part.size), align="right"), - ])) + rows.append( + ( + part, + [ + Text(_("partition {number}").format(number=part.number)), + (2, Text(details)), + Text(humanize_size(part.size), align="right"), + ], + ) + ) elif isinstance(part, Gap): # If desired, we could show gaps here. It is less critical, # given that the context is reformatting full disks and the # partition display is showing what is about to be lost. pass else: - raise Exception(f'unhandled partition type {part}') + raise Exception(f"unhandled partition type {part}") else: - rows.append((None, [ - (4, Color.info_minor(Text(", ".join(disk.usage_labels)))) - ])) + rows.append((None, [(4, Color.info_minor(Text(", ".join(disk.usage_labels))))])) return rows @@ -130,43 +123,48 @@ class TPMChoice: tpm_help_texts = { - "AVAILABLE_CAN_BE_DESELECTED": - _("The entire disk will be encrypted and protected by the " - "TPM. If this option is deselected, the disk will be " - "unencrypted and without any protection."), - "AVAILABLE_CANNOT_BE_DESELECTED": - _("The entire disk will be encrypted and protected by the TPM."), + "AVAILABLE_CAN_BE_DESELECTED": _( + "The entire disk will be encrypted and protected by the " + "TPM. If this option is deselected, the disk will be " + "unencrypted and without any protection." + ), + "AVAILABLE_CANNOT_BE_DESELECTED": _( + "The entire disk will be encrypted and protected by the TPM." + ), "UNAVAILABLE": - # for translators: 'reason' is the reason FDE is unavailable. - _("TPM backed full-disk encryption is not available " - "on this device (the reason given was \"{reason}\")."), + # for translators: 'reason' is the reason FDE is unavailable. + _( + "TPM backed full-disk encryption is not available " + 'on this device (the reason given was "{reason}").' + ), } choices = { GuidedCapability.CORE_BOOT_ENCRYPTED: TPMChoice( - enabled=False, default=True, - help=tpm_help_texts['AVAILABLE_CANNOT_BE_DESELECTED']), + enabled=False, + default=True, + help=tpm_help_texts["AVAILABLE_CANNOT_BE_DESELECTED"], + ), GuidedCapability.CORE_BOOT_UNENCRYPTED: TPMChoice( - enabled=False, default=False, - help=tpm_help_texts['UNAVAILABLE']), + enabled=False, default=False, help=tpm_help_texts["UNAVAILABLE"] + ), GuidedCapability.CORE_BOOT_PREFER_ENCRYPTED: TPMChoice( - enabled=True, default=True, - help=tpm_help_texts['AVAILABLE_CAN_BE_DESELECTED']), + enabled=True, default=True, help=tpm_help_texts["AVAILABLE_CAN_BE_DESELECTED"] + ), GuidedCapability.CORE_BOOT_PREFER_UNENCRYPTED: TPMChoice( - enabled=True, default=False, - help=tpm_help_texts['AVAILABLE_CAN_BE_DESELECTED']), + enabled=True, default=False, help=tpm_help_texts["AVAILABLE_CAN_BE_DESELECTED"] + ), } class GuidedChoiceForm(SubForm): - disk = ChoiceField(caption=NO_CAPTION, help=NO_HELP, choices=["x"]) use_lvm = BooleanField(_("Set up this disk as an LVM group"), help=NO_HELP) lvm_options = SubFormField(LVMOptionsForm, "", help=NO_HELP) use_tpm = BooleanField(_("Full disk encryption with TPM")) def __init__(self, parent): - super().__init__(parent, initial={'use_lvm': True}) + super().__init__(parent, initial={"use_lvm": True}) self.tpm_choice = None options = [] @@ -196,18 +194,18 @@ class GuidedChoiceForm(SubForm): self.disk.widget.options = options self.disk.widget.index = initial - connect_signal(self.disk.widget, 'select', self._select_disk) + connect_signal(self.disk.widget, "select", self._select_disk) self._select_disk(None, self.disk.value) - connect_signal(self.use_lvm.widget, 'change', self._toggle_lvm) + connect_signal(self.use_lvm.widget, "change", self._toggle_lvm) self._toggle_lvm(None, self.use_lvm.value) if GuidedCapability.LVM_LUKS not in all_caps: - self.remove_field('lvm_options') + self.remove_field("lvm_options") if GuidedCapability.LVM not in all_caps: - self.remove_field('use_lvm') + self.remove_field("use_lvm") core_boot_caps = [c for c in all_caps if c.is_core_boot()] if not core_boot_caps: - self.remove_field('use_tpm') + self.remove_field("use_tpm") def _select_disk(self, sender, val): self.use_lvm.enabled = GuidedCapability.LVM in val.allowed @@ -215,10 +213,9 @@ class GuidedChoiceForm(SubForm): if core_boot_caps: assert len(val.allowed) == 1 cap = core_boot_caps[0] - reason = '' + reason = "" for disallowed in val.disallowed: - if disallowed.capability == \ - GuidedCapability.CORE_BOOT_ENCRYPTED: + if disallowed.capability == GuidedCapability.CORE_BOOT_ENCRYPTED: reason = disallowed.message self.tpm_choice = choices[cap] self.use_tpm.enabled = self.tpm_choice.enabled @@ -234,7 +231,6 @@ class GuidedChoiceForm(SubForm): class GuidedForm(Form): - group = [] guided = RadioButtonField(group, _("Use an entire disk"), help=NO_HELP) @@ -247,14 +243,15 @@ class GuidedForm(Form): self.targets = targets self.disk_by_id = disk_by_id super().__init__() - connect_signal(self.guided.widget, 'change', self._toggle_guided) + connect_signal(self.guided.widget, "change", self._toggle_guided) def _toggle_guided(self, sender, new_value): self.guided_choice.enabled = new_value self.validated() -HELP = _(""" +HELP = _( + """ The "Use an entire disk" option installs Ubuntu onto the selected disk, replacing any partitions and data already there. @@ -281,23 +278,27 @@ If you choose to use a custom storage layout, no changes are made to the disks and you will have to, at a minimum, select a boot disk and mount a filesystem at /. -""") +""" +) -no_big_disks = _(""" +no_big_disks = _( + """ Block probing did not discover any disks big enough to support guided storage configuration. Manual configuration may still be possible. -""") +""" +) -no_disks = _(""" +no_disks = _( + """ Block probing did not discover any disks. Unfortunately this means that installation will not be possible. -""") +""" +) class GuidedDiskSelectionView(BaseView): - title = _("Guided storage configuration") def __init__(self, controller, targets, disk_by_id): @@ -306,7 +307,7 @@ class GuidedDiskSelectionView(BaseView): reformats = [] any_ok = False offer_manual = False - encryption_unavail_reason = '' + encryption_unavail_reason = "" GCDR = GuidedDisallowedCapabilityReason for target in targets: @@ -322,9 +323,7 @@ class GuidedDiskSelectionView(BaseView): encryption_unavail_reason = disallowed.message if any_ok: - show_form = self.form = GuidedForm( - targets=reformats, - disk_by_id=disk_by_id) + show_form = self.form = GuidedForm(targets=reformats, disk_by_id=disk_by_id) if not offer_manual: show_form = self.form.guided_choice.widget.form @@ -332,18 +331,20 @@ class GuidedDiskSelectionView(BaseView): else: excerpt = _(subtitle) - connect_signal(show_form, 'submit', self.done) - connect_signal(show_form, 'cancel', self.cancel) + connect_signal(show_form, "submit", self.done) + connect_signal(show_form, "cancel", self.cancel) super().__init__( - show_form.as_screen( - focus_buttons=False, excerpt=_(excerpt))) + show_form.as_screen(focus_buttons=False, excerpt=_(excerpt)) + ) elif encryption_unavail_reason: super().__init__( screen( [Text(rewrap(_(encryption_unavail_reason)))], [other_btn(_("Back"), on_press=self.cancel)], - excerpt=_("Cannot install core boot classic system"))) + excerpt=_("Cannot install core boot classic system"), + ) + ) elif disk_by_id and offer_manual: super().__init__( screen( @@ -351,33 +352,32 @@ class GuidedDiskSelectionView(BaseView): [ other_btn(_("OK"), on_press=self.manual), other_btn(_("Back"), on_press=self.cancel), - ])) + ], + ) + ) else: - super().__init__( - screen( - [Text(rewrap(_(no_disks)))], - [])) + super().__init__(screen([Text(rewrap(_(no_disks)))], [])) def local_help(self): return (_("Help on guided storage configuration"), rewrap(_(HELP))) def done(self, sender): results = self.form.as_data() - if results['guided']: - guided_choice = results['guided_choice'] - target = guided_choice['disk'] + if results["guided"]: + guided_choice = results["guided_choice"] + target = guided_choice["disk"] tpm_choice = self.form.guided_choice.widget.form.tpm_choice password = None if tpm_choice is not None: - if guided_choice.get('use_tpm', tpm_choice.default): + if guided_choice.get("use_tpm", tpm_choice.default): capability = GuidedCapability.CORE_BOOT_ENCRYPTED else: capability = GuidedCapability.CORE_BOOT_UNENCRYPTED - elif guided_choice.get('use_lvm', False): - opts = guided_choice.get('lvm_options', {}) - if opts.get('encrypt', False): + elif guided_choice.get("use_lvm", False): + opts = guided_choice.get("lvm_options", {}) + if opts.get("encrypt", False): capability = GuidedCapability.LVM_LUKS - password = opts['luks_options']['passphrase'] + password = opts["luks_options"]["passphrase"] else: capability = GuidedCapability.LVM else: @@ -386,19 +386,21 @@ class GuidedDiskSelectionView(BaseView): target=target, capability=capability, password=password, - ) + ) else: choice = GuidedChoiceV2( target=GuidedStorageTargetManual(), capability=GuidedCapability.MANUAL, - ) + ) self.controller.guided_choice(choice) def manual(self, sender): - self.controller.guided_choice(GuidedChoiceV2( - target=GuidedStorageTargetManual(), - capability=GuidedCapability.MANUAL, - )) + self.controller.guided_choice( + GuidedChoiceV2( + target=GuidedStorageTargetManual(), + capability=GuidedCapability.MANUAL, + ) + ) def cancel(self, btn=None): self.controller.cancel() diff --git a/subiquity/ui/views/filesystem/helpers.py b/subiquity/ui/views/filesystem/helpers.py index 603e9517..35900427 100644 --- a/subiquity/ui/views/filesystem/helpers.py +++ b/subiquity/ui/views/filesystem/helpers.py @@ -16,9 +16,7 @@ from urwid import Text from subiquity.common.filesystem import gaps, labels -from subiquity.models.filesystem import ( - humanize_size, - ) +from subiquity.models.filesystem import humanize_size def summarize_device(device, part_filter=lambda p: True): @@ -34,20 +32,29 @@ def summarize_device(device, part_filter=lambda p: True): anns = labels.annotations(device) + labels.usage_labels(device) if anns: label = "{} ({})".format(label, ", ".join(anns)) - rows = [(device, [ - (2, Text(label)), - Text(labels.desc(device)), - Text(humanize_size(device.size), align="right"), - ])] + rows = [ + ( + device, + [ + (2, Text(label)), + Text(labels.desc(device)), + Text(humanize_size(device.size), align="right"), + ], + ) + ] partitions = gaps.parts_and_gaps(device) for part in partitions: if not part_filter(part): continue - details = ", ".join( - labels.annotations(part) + labels.usage_labels(part)) - rows.append((part, [ - Text(labels.label(part, short=True)), - (2, Text(details)), - Text(humanize_size(part.size), align="right"), - ])) + details = ", ".join(labels.annotations(part) + labels.usage_labels(part)) + rows.append( + ( + part, + [ + Text(labels.label(part, short=True)), + (2, Text(details)), + Text(humanize_size(part.size), align="right"), + ], + ) + ) return rows diff --git a/subiquity/ui/views/filesystem/lvm.py b/subiquity/ui/views/filesystem/lvm.py index c67bdd2f..a1e4fcb2 100644 --- a/subiquity/ui/views/filesystem/lvm.py +++ b/subiquity/ui/views/filesystem/lvm.py @@ -17,47 +17,36 @@ import logging import os import re -from urwid import ( - connect_signal, - Text, - ) +from urwid import Text, connect_signal -from subiquitycore.ui.container import ( - Pile, - ) +from subiquity.models.filesystem import get_lvm_size, humanize_size +from subiquity.ui.views.filesystem.compound import ( + CompoundDiskForm, + MultiDeviceField, + get_possible_components, +) +from subiquitycore.ui.container import Pile from subiquitycore.ui.form import ( BooleanField, PasswordField, ReadOnlyField, - simple_field, WantsToKnowFormField, - ) -from subiquitycore.ui.interactive import ( - StringEditor, - ) -from subiquitycore.ui.stretchy import ( - Stretchy, - ) + simple_field, +) +from subiquitycore.ui.interactive import StringEditor +from subiquitycore.ui.stretchy import Stretchy -from subiquity.models.filesystem import ( - get_lvm_size, - humanize_size, - ) -from subiquity.ui.views.filesystem.compound import ( - CompoundDiskForm, - get_possible_components, - MultiDeviceField, - ) - -log = logging.getLogger('subiquity.ui.views.filesystem.lvm') +log = logging.getLogger("subiquity.ui.views.filesystem.lvm") class VGNameEditor(StringEditor, WantsToKnowFormField): def __init__(self): - self.valid_char_pat = r'[-a-zA-Z0-9_+.]' - self.error_invalid_char = _("The only characters permitted in the " - "name of a volume group are a-z, A-Z, " - "0-9, +, _, . and -") + self.valid_char_pat = r"[-a-zA-Z0-9_+.]" + self.error_invalid_char = _( + "The only characters permitted in the " + "name of a volume group are a-z, A-Z, " + "0-9, +, _, . and -" + ) super().__init__() def valid_char(self, ch): @@ -73,16 +62,14 @@ VGNameField = simple_field(VGNameEditor) class VolGroupForm(CompoundDiskForm): - - def __init__(self, model, possible_components, initial, - vg_names, deleted_vg_names): + def __init__(self, model, possible_components, initial, vg_names, deleted_vg_names): self.vg_names = vg_names self.deleted_vg_names = deleted_vg_names super().__init__(model, possible_components, initial) - connect_signal(self.encrypt.widget, 'change', self._change_encrypt) + connect_signal(self.encrypt.widget, "change", self._change_encrypt) self.confirm_passphrase.use_as_confirmation( - for_field=self.passphrase, - desc=_("Passphrases")) + for_field=self.passphrase, desc=_("Passphrases") + ) self._change_encrypt(None, self.encrypt.value) name = VGNameField(_("Name:")) @@ -101,30 +88,33 @@ class VolGroupForm(CompoundDiskForm): def validate_devices(self): if len(self.devices.value) < 1: - return _("Select at least one device to be part of the volume " - "group.") + return _("Select at least one device to be part of the volume " "group.") def validate_name(self): v = self.name.value if not v: return _("The name of a volume group cannot be empty") - if v.startswith('-'): + if v.startswith("-"): return _("The name of a volume group cannot start with a hyphen") if v in self.vg_names: return _("There is already a volume group named '{name}'").format( - name=self.name.value) - if v in ('.', '..', 'md') or os.path.exists('/dev/' + v): + name=self.name.value + ) + if v in (".", "..", "md") or os.path.exists("/dev/" + v): if v not in self.deleted_vg_names: - return _("{name} is not a valid name for a volume " - "group").format(name=v) + return _("{name} is not a valid name for a volume " "group").format( + name=v + ) def validate_passphrase(self): if self.encrypt.value and len(self.passphrase.value) < 1: return _("Passphrase must be set") def validate_confirm_passphrase(self): - if self.encrypt.value and \ - self.passphrase.value != self.confirm_passphrase.value: + if ( + self.encrypt.value + and self.passphrase.value != self.confirm_passphrase.value + ): return _("Passphrases do not match") @@ -134,27 +124,28 @@ class VolGroupStretchy(Stretchy): self.existing = existing vg_names = {vg.name for vg in parent.model.all_volgroups()} orig_vg_names = { - action['name'] for action in parent.model._orig_config - if action['type'] == 'lvm_volgroup' - } + action["name"] + for action in parent.model._orig_config + if action["type"] == "lvm_volgroup" + } if existing is None: - title = _('Create LVM volume group') - label = _('Create') + title = _("Create LVM volume group") + label = _("Create") x = 0 while True: - name = 'vg{}'.format(x) + name = "vg{}".format(x) if name not in vg_names: break x += 1 initial = { - 'devices': {}, - 'name': name, - 'size': '-', - } + "devices": {}, + "name": name, + "size": "-", + } else: vg_names.remove(existing.name) title = _('Edit volume group "{name}"').format(name=existing.name) - label = _('Save') + label = _("Save") devices = {} key = "" encrypt = False @@ -173,54 +164,54 @@ class VolGroupStretchy(Stretchy): if d.key is not None: key = d.key d = d.volume - devices[d] = 'active' + devices[d] = "active" initial = { - 'devices': devices, - 'name': existing.name, - 'encrypt': encrypt, - 'passphrase': key, - 'confirm_passphrase': key, - } + "devices": devices, + "name": existing.name, + "encrypt": encrypt, + "passphrase": key, + "confirm_passphrase": key, + } possible_components = get_possible_components( - self.parent.model, existing, initial['devices'], - lambda dev: dev.ok_for_lvm_vg) + self.parent.model, + existing, + initial["devices"], + lambda dev: dev.ok_for_lvm_vg, + ) deleted_vg_names = orig_vg_names - vg_names form = self.form = VolGroupForm( - self.parent.model, possible_components, initial, - vg_names, deleted_vg_names) + self.parent.model, possible_components, initial, vg_names, deleted_vg_names + ) self.form.buttons.base_widget[0].set_label(label) self.form.devices.widget.set_supports_spares(False) - connect_signal(form.devices.widget, 'change', self._change_devices) - connect_signal(form, 'submit', self.done) - connect_signal(form, 'cancel', self.cancel) + connect_signal(form.devices.widget, "change", self._change_devices) + connect_signal(form, "submit", self.done) + connect_signal(form, "cancel", self.cancel) rows = form.as_rows() - super().__init__( - title, - [Pile(rows), Text(""), self.form.buttons], - 0, 0) + super().__init__(title, [Pile(rows), Text(""), self.form.buttons], 0, 0) def _change_devices(self, sender, new_devices): if len(sender.active_devices) >= 1: self.form.size.value = humanize_size(get_lvm_size(new_devices)) else: - self.form.size.value = '-' + self.form.size.value = "-" def done(self, sender): result = self.form.as_data() - del result['size'] + del result["size"] mdc = self.form.devices.widget - result['devices'] = mdc.active_devices - if 'confirm_passphrase' in result: - del result['confirm_passphrase'] + result["devices"] = mdc.active_devices + if "confirm_passphrase" in result: + del result["confirm_passphrase"] safe_result = result.copy() - if 'passphrase' in safe_result: - safe_result['passphrase'] = '' + if "passphrase" in safe_result: + safe_result["passphrase"] = "" log.debug("vg_done: {}".format(safe_result)) self.parent.controller.volgroup_handler(self.existing, result) self.parent.refresh_model_inputs() diff --git a/subiquity/ui/views/filesystem/partition.py b/subiquity/ui/views/filesystem/partition.py index cc43ef51..7a71a8cf 100644 --- a/subiquity/ui/views/filesystem/partition.py +++ b/subiquity/ui/views/filesystem/partition.py @@ -22,67 +22,65 @@ configuration. import logging import re -from urwid import connect_signal, Text +from urwid import Text, connect_signal +from subiquity.common.filesystem import boot, gaps, labels +from subiquity.models.filesystem import ( + HUMAN_UNITS, + LVM_CHUNK_SIZE, + Disk, + LVM_VolGroup, + align_up, + dehumanize_size, + humanize_size, +) +from subiquity.ui.mount import ( + MountField, + common_mountpoints, + suitable_mountpoints_for_existing_fs, +) +from subiquitycore.ui.container import Pile from subiquitycore.ui.form import ( BooleanField, Form, FormField, - simple_field, WantsToKnowFormField, + simple_field, ) from subiquitycore.ui.interactive import StringEditor from subiquitycore.ui.selector import Option, Selector -from subiquitycore.ui.container import Pile from subiquitycore.ui.stretchy import Stretchy from subiquitycore.ui.utils import rewrap -from subiquity.common.filesystem import boot, gaps, labels -from subiquity.models.filesystem import ( - align_up, - Disk, - HUMAN_UNITS, - dehumanize_size, - humanize_size, - LVM_CHUNK_SIZE, - LVM_VolGroup, -) -from subiquity.ui.mount import ( - common_mountpoints, - MountField, - suitable_mountpoints_for_existing_fs, - ) - - -log = logging.getLogger('subiquity.ui.views.filesystem.partition') +log = logging.getLogger("subiquity.ui.views.filesystem.partition") class FSTypeField(FormField): - takes_default_style = False def _make_widget(self, form): # This will need to do something different for editing an # existing partition that is already formatted. options = [ - ('ext4', True), - ('xfs', True), - ('btrfs', True), - ('---', False), - ('swap', True), + ("ext4", True), + ("xfs", True), + ("btrfs", True), + ("---", False), + ("swap", True), ] if form.existing_fs_type is None: options = options + [ - ('---', False), - (_('Leave unformatted'), True, None), - ] + ("---", False), + (_("Leave unformatted"), True, None), + ] else: - label = _('Leave formatted as {fstype}').format( - fstype=form.existing_fs_type) + label = _("Leave formatted as {fstype}").format( + fstype=form.existing_fs_type + ) options = [ (label, True, None), - ('---', False), - ] + options + ("---", False), + ] + options sel = Selector(opts=options) sel.value = None return sel @@ -97,7 +95,7 @@ class SizeWidget(StringEditor): val = self.value if not val: return - suffixes = ''.join(HUMAN_UNITS) + ''.join(HUMAN_UNITS).lower() + suffixes = "".join(HUMAN_UNITS) + "".join(HUMAN_UNITS).lower() if val[-1] not in suffixes: unit = self.form.size_str[-1] val += unit @@ -109,17 +107,24 @@ class SizeWidget(StringEditor): if sz > self.form.max_size: self.value = self.form.size_str self.form.size.show_extra( - ('info_minor', - _("Capped partition size at {size}").format( - size=self.form.size_str))) + ( + "info_minor", + _("Capped partition size at {size}").format( + size=self.form.size_str + ), + ) + ) else: aligned_sz = align_up(sz, self.form.alignment) aligned_sz_str = humanize_size(aligned_sz) if aligned_sz != sz and aligned_sz_str != self.form.size.value: self.value = aligned_sz_str self.form.size.show_extra( - ('info_minor', _("Rounded size up to {size}").format( - size=aligned_sz_str))) + ( + "info_minor", + _("Rounded size up to {size}").format(size=aligned_sz_str), + ) + ) class SizeField(FormField): @@ -129,10 +134,12 @@ class SizeField(FormField): class LVNameEditor(StringEditor, WantsToKnowFormField): def __init__(self): - self.valid_char_pat = r'[-a-zA-Z0-9_+.]' - self.error_invalid_char = _("The only characters permitted in the " - "name of a logical volume are a-z, A-Z, " - "0-9, +, _, . and -") + self.valid_char_pat = r"[-a-zA-Z0-9_+.]" + self.error_invalid_char = _( + "The only characters permitted in the " + "name of a logical volume are a-z, A-Z, " + "0-9, +, _, . and -" + ) super().__init__() def valid_char(self, ch): @@ -148,7 +155,6 @@ LVNameField = simple_field(LVNameEditor) class PartitionForm(Form): - def __init__(self, model, max_size, initial, lvm_names, device, alignment): self.model = model self.device = device @@ -157,21 +163,22 @@ class PartitionForm(Form): ofstype = device.original_fstype() if ofstype: self.existing_fs_type = ofstype - initial_path = initial.get('mount') + initial_path = initial.get("mount") self.mountpoints = { - m.path: m.device.volume for m in self.model.all_mounts() - if m.path != initial_path} + m.path: m.device.volume + for m in self.model.all_mounts() + if m.path != initial_path + } self.max_size = max_size if max_size is not None: self.size_str = humanize_size(max_size) - self.size.caption = _("Size (max {size}):").format( - size=self.size_str) + self.size.caption = _("Size (max {size}):").format(size=self.size_str) self.lvm_names = lvm_names self.alignment = alignment super().__init__(initial) if max_size is None: - self.remove_field('size') - connect_signal(self.fstype.widget, 'select', self.select_fstype) + self.remove_field("size") + connect_signal(self.fstype.widget, "select", self.select_fstype) self.form_pile = None self.select_fstype(None, self.fstype.widget.value) @@ -196,8 +203,7 @@ class PartitionForm(Form): fstype_for_check = fstype if fstype_for_check is None: fstype_for_check = self.existing_fs_type - self.mount.enabled = self.model.is_mounted_filesystem( - fstype_for_check) + self.mount.enabled = self.model.is_mounted_filesystem(fstype_for_check) self.fstype.value = fstype self.mount.showing_extra = False self.mount.validate() @@ -207,13 +213,13 @@ class PartitionForm(Form): fstype = FSTypeField(_("Format:")) mount = MountField(_("Mount:")) use_swap = BooleanField( - _("Use as swap"), - help=_("Use this swap partition in the installed system.")) + _("Use as swap"), help=_("Use this swap partition in the installed system.") + ) def clean_size(self, val): if not val: return self.max_size - suffixes = ''.join(HUMAN_UNITS) + ''.join(HUMAN_UNITS).lower() + suffixes = "".join(HUMAN_UNITS) + "".join(HUMAN_UNITS).lower() if val[-1] not in suffixes: val += self.size_str[-1] if val == self.size_str: @@ -233,24 +239,31 @@ class PartitionForm(Form): v = self.name.value if not v: return _("The name of a logical volume cannot be empty") - if v.startswith('-'): + if v.startswith("-"): return _("The name of a logical volume cannot start with a hyphen") - if v in ('.', '..', 'snapshot', 'pvmove'): - return _( - "A logical volume may not be called {name}" - ).format(name=v) - for substring in ['_cdata', '_cmeta', '_corig', '_mlog', '_mimage', - '_pmspare', '_rimage', '_rmeta', '_tdata', - '_tmeta', '_vorigin']: + if v in (".", "..", "snapshot", "pvmove"): + return _("A logical volume may not be called {name}").format(name=v) + for substring in [ + "_cdata", + "_cmeta", + "_corig", + "_mlog", + "_mimage", + "_pmspare", + "_rimage", + "_rmeta", + "_tdata", + "_tmeta", + "_vorigin", + ]: if substring in v: return _( - 'The name of a logical volume may not contain ' - '"{substring}"' - ).format(substring=substring) + "The name of a logical volume may not contain " '"{substring}"' + ).format(substring=substring) if v in self.lvm_names: - return _( - "There is already a logical volume named {name}." - ).format(name=self.name.value) + return _("There is already a logical volume named {name}.").format( + name=self.name.value + ) def validate_mount(self): mount = self.mount.value @@ -258,21 +271,26 @@ class PartitionForm(Form): return # /usr/include/linux/limits.h:PATH_MAX if len(mount) > 4095: - return _('Path exceeds PATH_MAX') + return _("Path exceeds PATH_MAX") dev = self.mountpoints.get(mount) if dev is not None: return _("{device} is already mounted at {path}.").format( - device=labels.label(dev).title(), path=mount) + device=labels.label(dev).title(), path=mount + ) if self.existing_fs_type is not None: if self.fstype.value is None: if mount in common_mountpoints: if mount not in suitable_mountpoints_for_existing_fs: self.mount.show_extra( - ('info_error', - _("Mounting an existing filesystem at " - "{mountpoint} is usually a bad idea, " - "proceed only with caution.").format( - mountpoint=mount))) + ( + "info_error", + _( + "Mounting an existing filesystem at " + "{mountpoint} is usually a bad idea, " + "proceed only with caution." + ).format(mountpoint=mount), + ) + ) def as_rows(self): r = super().as_rows() @@ -281,11 +299,12 @@ class PartitionForm(Form): else: exclude = self.use_swap._table i = r.index(exclude) - del r[i-1:i+1] + del r[i - 1 : i + 1] return r -bios_grub_partition_description = _("""\ +bios_grub_partition_description = _( + """\ Bootloader partition {middle} @@ -295,77 +314,93 @@ space after the MBR for GRUB to store its second-stage core.img, so a small unformatted partition is needed at the start of the disk. It will not contain a filesystem and will not be mounted, and cannot be edited here. -""") +""" +) -unconfigured_bios_grub_partition_middle = _("""\ +unconfigured_bios_grub_partition_middle = _( + """\ If this disk is selected as a boot device, GRUB will be installed onto -the target disk's MBR.""") +the target disk's MBR.""" +) -configured_bios_grub_partition_middle = _("""\ +configured_bios_grub_partition_middle = _( + """\ As this disk has been selected as a boot device, GRUB will be -installed onto the target disk's MBR.""") +installed onto the target disk's MBR.""" +) -unconfigured_boot_partition_description = _("""\ +unconfigured_boot_partition_description = _( + """\ Bootloader partition This is an ESP / "EFI system partition" as required by UEFI. If this disk is selected as a boot device, Grub will be installed onto this partition, which must be formatted as fat32. -""") +""" +) -configured_boot_partition_description = _("""\ +configured_boot_partition_description = _( + """\ Bootloader partition This is an ESP / "EFI system partition" as required by UEFI. As this disk has been selected as a boot device, Grub will be installed onto this partition, which must be formatted as fat32. -""") +""" +) -boot_partition_description_size = _("""\ +boot_partition_description_size = _( + """\ The only aspect of this partition that can be edited is the size. -""") +""" +) -boot_partition_description_reformat = _("""\ +boot_partition_description_reformat = _( + """\ You can choose whether to use the existing filesystem on this partition or reformat it. -""") +""" +) -unconfigured_prep_partition_description = _("""\ +unconfigured_prep_partition_description = _( + """\ Required bootloader partition This is the PReP partion which is required on POWER. If this disk is selected as a boot device, Grub will be installed onto this partition. -""") +""" +) -configured_prep_partition_description = _("""\ +configured_prep_partition_description = _( + """\ Required bootloader partition This is the PReP partion which is required on POWER. As this disk has been selected as a boot device, Grub will be installed onto this partition. -""") +""" +) def initial_data_for_fs(fs): r = {} if fs is not None: if fs.preserve: - r['fstype'] = None + r["fstype"] = None if fs.fstype == "swap": - r['use_swap'] = fs.mount() is not None + r["use_swap"] = fs.mount() is not None else: - r['fstype'] = fs.fstype + r["fstype"] = fs.fstype if fs._m.is_mounted_filesystem(fs.fstype): mount = fs.mount() if mount is not None: - r['mount'] = mount.path + r["mount"] = mount.path else: - r['mount'] = None + r["mount"] = None return r class PartitionStretchy(Stretchy): - def __init__(self, parent, disk, *, partition=None, gap=None): self.disk = disk self.partition = partition @@ -375,7 +410,7 @@ class PartitionStretchy(Stretchy): self.parent = parent if partition is None and gap is None: - raise Exception('bad PartitionStretchy - needs partition or gap') + raise Exception("bad PartitionStretchy - needs partition or gap") initial = {} label = _("Create") @@ -388,42 +423,45 @@ class PartitionStretchy(Stretchy): if partition: if partition.flag in ["bios_grub", "prep"]: label = None - initial['mount'] = None + initial["mount"] = None elif boot.is_esp(partition) and not partition.grub_device: label = None else: label = _("Save") - initial['size'] = humanize_size(self.partition.size) - max_size = partition.size + \ - gaps.movable_trailing_partitions_and_gap_size(partition)[1] + initial["size"] = humanize_size(self.partition.size) + max_size = ( + partition.size + + gaps.movable_trailing_partitions_and_gap_size(partition)[1] + ) if not boot.is_esp(partition): initial.update(initial_data_for_fs(self.partition.fs())) else: if partition.fs() and partition.fs().mount(): - initial['mount'] = '/boot/efi' + initial["mount"] = "/boot/efi" else: - initial['mount'] = None + initial["mount"] = None if isinstance(disk, LVM_VolGroup): - initial['name'] = partition.name + initial["name"] = partition.name lvm_names.remove(partition.name) else: - initial['fstype'] = 'ext4' + initial["fstype"] = "ext4" max_size = self.gap.size if isinstance(disk, LVM_VolGroup): x = 0 while True: - name = 'lv-{}'.format(x) + name = "lv-{}".format(x) if name not in lvm_names: break x += 1 - initial['name'] = name + initial["name"] = name self.form = PartitionForm( - self.model, max_size, initial, lvm_names, partition, alignment) + self.model, max_size, initial, lvm_names, partition, alignment + ) if not isinstance(disk, LVM_VolGroup): - self.form.remove_field('name') + self.form.remove_field("name") if label is not None: self.form.buttons.base_widget[0].set_label(label) @@ -435,18 +473,12 @@ class PartitionStretchy(Stretchy): if boot.is_esp(partition): if partition.original_fstype(): opts = [ - Option(( - _("Use existing fat32 filesystem"), - True, - None - )), + Option((_("Use existing fat32 filesystem"), True, None)), Option(("---", False)), - Option(( - _("Reformat as fresh fat32 filesystem"), - True, - "fat32" - )), - ] + Option( + (_("Reformat as fresh fat32 filesystem"), True, "fat32") + ), + ] self.form.fstype.widget.options = opts if partition.fs().preserve: self.form.fstype.widget.index = 0 @@ -469,8 +501,8 @@ class PartitionStretchy(Stretchy): self.form.name.enabled = False self.form.size.enabled = False - connect_signal(self.form, 'submit', self.done) - connect_signal(self.form, 'cancel', self.cancel) + connect_signal(self.form, "submit", self.done) + connect_signal(self.form, "cancel", self.cancel) rows = [] focus_index = 0 @@ -485,30 +517,36 @@ class PartitionStretchy(Stretchy): else: focus_index = 2 desc = _(unconfigured_boot_partition_description) - rows.extend([ - Text(rewrap(desc)), - Text(""), - ]) + rows.extend( + [ + Text(rewrap(desc)), + Text(""), + ] + ) elif self.partition.flag == "bios_grub": if self.partition.device.grub_device: middle = _(configured_bios_grub_partition_middle) else: middle = _(unconfigured_bios_grub_partition_middle) desc = _(bios_grub_partition_description).format(middle=middle) - rows.extend([ - Text(rewrap(desc)), - Text(""), - ]) + rows.extend( + [ + Text(rewrap(desc)), + Text(""), + ] + ) focus_index = 2 elif self.partition.flag == "prep": if self.partition.grub_device: desc = _(configured_prep_partition_description) else: desc = _(unconfigured_prep_partition_description) - rows.extend([ - Text(rewrap(desc)), - Text(""), - ]) + rows.extend( + [ + Text(rewrap(desc)), + Text(""), + ] + ) focus_index = 2 rows.extend(self.form.as_rows()) self.form.form_pile = Pile(rows) @@ -521,22 +559,22 @@ class PartitionStretchy(Stretchy): if partition is None: if isinstance(disk, LVM_VolGroup): title = _("Adding logical volume to {vgname}").format( - vgname=labels.label(disk)) + vgname=labels.label(disk) + ) else: title = _("Adding {ptype} partition to {device}").format( ptype=disk.ptable_for_new_partition().upper(), - device=labels.label(disk)) + device=labels.label(disk), + ) else: if isinstance(disk, LVM_VolGroup): - title = _( - "Editing logical volume {lvname} of {vgname}" - ).format( - lvname=partition.name, - vgname=labels.label(disk)) + title = _("Editing logical volume {lvname} of {vgname}").format( + lvname=partition.name, vgname=labels.label(disk) + ) else: title = _("Editing partition {number} of {device}").format( - number=partition.number, - device=labels.label(disk)) + number=partition.number, device=labels.label(disk) + ) super().__init__(title, widgets, 0, focus_index) @@ -548,11 +586,11 @@ class PartitionStretchy(Stretchy): spec = form.as_data() if self.partition is not None and boot.is_esp(self.partition): if self.partition.original_fstype() is None: - spec['fstype'] = self.partition.fs().fstype + spec["fstype"] = self.partition.fs().fstype if self.partition.fs().mount() is not None: - spec['mount'] = self.partition.fs().mount().path + spec["mount"] = self.partition.fs().mount().path else: - spec['mount'] = None + spec["mount"] = None if isinstance(self.disk, LVM_VolGroup): handler = self.controller.logical_volume_handler else: @@ -563,9 +601,7 @@ class PartitionStretchy(Stretchy): class FormatEntireStretchy(Stretchy): - def __init__(self, parent, device): - self.device = device self.model = parent.model self.controller = parent.controller @@ -576,23 +612,32 @@ class FormatEntireStretchy(Stretchy): if fs is not None: initial.update(initial_data_for_fs(fs)) elif not isinstance(device, Disk): - initial['fstype'] = 'ext4' + initial["fstype"] = "ext4" self.form = PartitionForm( - self.model, 0, initial, None, device, - alignment=device.alignment_data().part_align) - self.form.remove_field('size') - self.form.remove_field('name') + self.model, + 0, + initial, + None, + device, + alignment=device.alignment_data().part_align, + ) + self.form.remove_field("size") + self.form.remove_field("name") - connect_signal(self.form, 'submit', self.done) - connect_signal(self.form, 'cancel', self.cancel) + connect_signal(self.form, "submit", self.done) + connect_signal(self.form, "cancel", self.cancel) rows = [] if isinstance(device, Disk): rows = [ - Text(_("Formatting and mounting a disk directly is unusual. " - "You probably want to add a partition instead.")), + Text( + _( + "Formatting and mounting a disk directly is unusual. " + "You probably want to add a partition instead." + ) + ), Text(""), - ] + ] rows.extend(self.form.as_rows()) self.form.form_pile = Pile(rows) widgets = [ @@ -601,8 +646,7 @@ class FormatEntireStretchy(Stretchy): self.form.buttons, ] - title = _("Format and/or mount {device}").format( - device=labels.label(device)) + title = _("Format and/or mount {device}").format(device=labels.label(device)) super().__init__(title, widgets, 0, 0) diff --git a/subiquity/ui/views/filesystem/probing.py b/subiquity/ui/views/filesystem/probing.py index d2a43b85..3a6203cc 100644 --- a/subiquity/ui/views/filesystem/probing.py +++ b/subiquity/ui/views/filesystem/probing.py @@ -15,43 +15,38 @@ import logging -from urwid import ( - Text, - ) +from urwid import Text -from subiquitycore.ui.buttons import ( - other_btn, - ) -from subiquitycore.ui.spinner import ( - Spinner, - ) -from subiquitycore.ui.utils import ( - button_pile, - screen, - ) +from subiquitycore.ui.buttons import other_btn +from subiquitycore.ui.spinner import Spinner +from subiquitycore.ui.utils import button_pile, screen from subiquitycore.view import BaseView - log = logging.getLogger("subiquity.ui.views.filesystem.probing") class SlowProbing(BaseView): - title = _("Waiting for storage probing to complete") def __init__(self, controller): self.controller = controller self.spinner = Spinner(style="dots") self.spinner.start() - super().__init__(screen( - [ - Text(_("The installer is probing for block devices to install " - "to. Please wait until it completes.")), - Text(""), - self.spinner, - ], - [other_btn(_("Back"), on_press=self.cancel)] - )) + super().__init__( + screen( + [ + Text( + _( + "The installer is probing for block devices to install " + "to. Please wait until it completes." + ) + ), + Text(""), + self.spinner, + ], + [other_btn(_("Back"), on_press=self.cancel)], + ) + ) def cancel(self, result=None): self.controller.cancel() @@ -60,23 +55,28 @@ class SlowProbing(BaseView): fail_text = _( "Unfortunately probing for devices to install to failed. Please report a " "bug on Launchpad, and if possible include the contents of the " - "/var/log/installer directory.") + "/var/log/installer directory." +) class ProbingFailed(BaseView): - title = _("Probing for devices to install to failed") def __init__(self, controller, error_ref): self.controller = controller self.error_ref = error_ref - super().__init__(screen([ - Text(_(fail_text)), - Text(""), - button_pile( - [other_btn(_("Show Error Report"), on_press=self.show_error)]), - ], - [other_btn(_("Back"), on_press=self.cancel)])) + super().__init__( + screen( + [ + Text(_(fail_text)), + Text(""), + button_pile( + [other_btn(_("Show Error Report"), on_press=self.show_error)] + ), + ], + [other_btn(_("Back"), on_press=self.cancel)], + ) + ) def cancel(self, result=None): self.controller.cancel() diff --git a/subiquity/ui/views/filesystem/raid.py b/subiquity/ui/views/filesystem/raid.py index 594684c9..8ee28053 100644 --- a/subiquity/ui/views/filesystem/raid.py +++ b/subiquity/ui/views/filesystem/raid.py @@ -16,43 +16,31 @@ import logging import re -from urwid import ( - connect_signal, - Text, - ) - -from subiquitycore.ui.container import ( - Pile, - ) -from subiquitycore.ui.form import ( - ChoiceField, - ReadOnlyField, - simple_field, - WantsToKnowFormField, - ) -from subiquitycore.ui.interactive import ( - StringEditor, - ) -from subiquitycore.ui.selector import ( - Option, - ) -from subiquitycore.ui.stretchy import ( - Stretchy, - ) +from urwid import Text, connect_signal from subiquity.models.filesystem import ( get_raid_size, humanize_size, raidlevels, raidlevels_by_value, - ) +) from subiquity.ui.views.filesystem.compound import ( CompoundDiskForm, - get_possible_components, MultiDeviceField, - ) + get_possible_components, +) +from subiquitycore.ui.container import Pile +from subiquitycore.ui.form import ( + ChoiceField, + ReadOnlyField, + WantsToKnowFormField, + simple_field, +) +from subiquitycore.ui.interactive import StringEditor +from subiquitycore.ui.selector import Option +from subiquitycore.ui.stretchy import Stretchy -log = logging.getLogger('subiquity.ui.views.filesystem.raid') +log = logging.getLogger("subiquity.ui.views.filesystem.raid") raidlevel_choices = [ @@ -64,17 +52,20 @@ raidlevel_choices = [ class RaidnameEditor(StringEditor, WantsToKnowFormField): def valid_char(self, ch): - if len(ch) == 1 and ch == '/': + if len(ch) == 1 and ch == "/": self.bff.in_error = True - self.bff.show_extra(("info_error", - _("/ is not permitted " - "in the name of a RAID device"))) + self.bff.show_extra( + ("info_error", _("/ is not permitted " "in the name of a RAID device")) + ) return False elif len(ch) == 1 and ch.isspace(): self.bff.in_error = True - self.bff.show_extra(("info_error", - _("Whitespace is not permitted in the " - "name of a RAID device"))) + self.bff.show_extra( + ( + "info_error", + _("Whitespace is not permitted in the " "name of a RAID device"), + ) + ) return False else: return super().valid_char(ch) @@ -84,7 +75,6 @@ RaidnameField = simple_field(RaidnameEditor) class RaidForm(CompoundDiskForm): - def __init__(self, model, possible_components, initial, raid_names): self.raid_names = raid_names super().__init__(model, possible_components, initial) @@ -98,25 +88,26 @@ class RaidForm(CompoundDiskForm): def clean_name(self, val): if not val: raise ValueError("The name cannot be empty") - if not re.match('md[0-9]+', val): - val = 'md/' + val + if not re.match("md[0-9]+", val): + val = "md/" + val return val def validate_name(self): if self.name.value in self.raid_names: return _("There is already a RAID named '{name}'").format( - name=self.name.value) - if self.name.value in ('/dev/md/.', '/dev/md/..'): + name=self.name.value + ) + if self.name.value in ("/dev/md/.", "/dev/md/.."): return _(". and .. are not valid names for RAID devices") def validate_devices(self): active_device_count = len(self.devices.widget.active_devices) if active_device_count < self.level.value.min_devices: return _( - 'RAID Level "{level}" requires at least {min_active} active' - ' devices').format( - level=self.level.value.name, - min_active=self.level.value.min_devices) + 'RAID Level "{level}" requires at least {min_active} active' " devices" + ).format( + level=self.level.value.name, min_active=self.level.value.min_devices + ) return super().validate_devices() @@ -127,68 +118,65 @@ class RaidStretchy(Stretchy): raid_names = {raid.name for raid in parent.model.all_raids()} if existing is None: title = _('Create software RAID ("MD") disk') - label = _('Create') + label = _("Create") x = 0 while True: - name = 'md{}'.format(x) + name = "md{}".format(x) if name not in raid_names: break x += 1 initial = { - 'devices': {}, - 'name': name, - 'level': raidlevels_by_value["raid1"], - 'size': '-', - } + "devices": {}, + "name": name, + "level": raidlevels_by_value["raid1"], + "size": "-", + } else: raid_names.remove(existing.name) - title = _('Edit software RAID disk "{name}"').format( - name=existing.name) - label = _('Save') + title = _('Edit software RAID disk "{name}"').format(name=existing.name) + label = _("Save") name = existing.name - if name.startswith('md/'): + if name.startswith("md/"): name = name[3:] devices = {} for d in existing.devices: - devices[d] = 'active' + devices[d] = "active" for d in existing.spare_devices: - devices[d] = 'spare' + devices[d] = "spare" initial = { - 'devices': devices, - 'name': name, - 'level': raidlevels_by_value[existing.raidlevel] - } + "devices": devices, + "name": name, + "level": raidlevels_by_value[existing.raidlevel], + } possible_components = get_possible_components( - self.parent.model, existing, initial['devices'], - lambda dev: dev.ok_for_raid) + self.parent.model, existing, initial["devices"], lambda dev: dev.ok_for_raid + ) form = self.form = RaidForm( - self.parent.model, possible_components, initial, raid_names) + self.parent.model, possible_components, initial, raid_names + ) - form.devices.widget.set_supports_spares( - initial['level'].supports_spares) + form.devices.widget.set_supports_spares(initial["level"].supports_spares) form.buttons.base_widget[0].set_label(label) - connect_signal(form.level.widget, 'select', self._select_level) - connect_signal(form.devices.widget, 'change', self._change_devices) - connect_signal(form, 'submit', self.done) - connect_signal(form, 'cancel', self.cancel) + connect_signal(form.level.widget, "select", self._select_level) + connect_signal(form.devices.widget, "change", self._change_devices) + connect_signal(form, "submit", self.done) + connect_signal(form, "cancel", self.cancel) rows = form.as_rows() - super().__init__( - title, - [Pile(rows), Text(""), self.form.buttons], - 0, 0) + super().__init__(title, [Pile(rows), Text(""), self.form.buttons], 0, 0) def _select_level(self, sender, new_level): active_device_count = len(self.form.devices.widget.active_devices) if active_device_count >= new_level.min_devices: self.form.size.value = humanize_size( - get_raid_size(new_level.value, self.form.devices.value)) + get_raid_size(new_level.value, self.form.devices.value) + ) else: - self.form.size.value = '-' + self.form.size.value = "-" self.form.devices.widget.set_supports_spares(new_level.supports_spares) self.form.level.value = new_level self.form.devices.showing_extra = False @@ -197,16 +185,17 @@ class RaidStretchy(Stretchy): def _change_devices(self, sender, new_devices): if len(sender.active_devices) >= self.form.level.value.min_devices: self.form.size.value = humanize_size( - get_raid_size(self.form.level.value.value, new_devices)) + get_raid_size(self.form.level.value.value, new_devices) + ) else: - self.form.size.value = '-' + self.form.size.value = "-" def done(self, sender): result = self.form.as_data() mdc = self.form.devices.widget - result['devices'] = mdc.active_devices - result['spare_devices'] = mdc.spare_devices - log.debug('raid_done: result = {}'.format(result)) + result["devices"] = mdc.active_devices + result["spare_devices"] = mdc.spare_devices + log.debug("raid_done: result = {}".format(result)) self.parent.controller.raid_handler(self.existing, result) self.parent.refresh_model_inputs() self.parent.remove_overlay() diff --git a/subiquity/ui/views/filesystem/tests/test_filesystem.py b/subiquity/ui/views/filesystem/tests/test_filesystem.py index 50f511d3..6e2e60e4 100644 --- a/subiquity/ui/views/filesystem/tests/test_filesystem.py +++ b/subiquity/ui/views/filesystem/tests/test_filesystem.py @@ -3,23 +3,14 @@ from unittest import mock import urwid -from subiquitycore.testing import view_helpers - from subiquity.client.controllers.filesystem import FilesystemController -from subiquity.models.filesystem import ( - Bootloader, - Disk, - FilesystemModel, - ) -from subiquity.models.tests.test_filesystem import ( - FakeStorageInfo, - make_model, - ) +from subiquity.models.filesystem import Bootloader, Disk, FilesystemModel +from subiquity.models.tests.test_filesystem import FakeStorageInfo, make_model from subiquity.ui.views.filesystem.filesystem import FilesystemView +from subiquitycore.testing import view_helpers class FilesystemViewTests(unittest.TestCase): - def make_view(self, model, devices=[]): controller = mock.create_autospec(spec=FilesystemController) controller.ui = mock.Mock() @@ -36,10 +27,13 @@ class FilesystemViewTests(unittest.TestCase): model._actions = [] model._all_ids = set() disk = Disk( - m=model, serial="DISK-SERIAL", path='/dev/thing', - info=FakeStorageInfo(size=100*(2**20), free=50*(2**20))) + m=model, + serial="DISK-SERIAL", + path="/dev/thing", + info=FakeStorageInfo(size=100 * (2**20), free=50 * (2**20)), + ) view = self.make_view(model, [disk]) w = view_helpers.find_with_pred( - view, - lambda w: isinstance(w, urwid.Text) and "DISK-SERIAL" in w.text) + view, lambda w: isinstance(w, urwid.Text) and "DISK-SERIAL" in w.text + ) self.assertIsNotNone(w, "could not find DISK-SERIAL in view") diff --git a/subiquity/ui/views/filesystem/tests/test_lvm.py b/subiquity/ui/views/filesystem/tests/test_lvm.py index e3191d29..ad665dbe 100644 --- a/subiquity/ui/views/filesystem/tests/test_lvm.py +++ b/subiquity/ui/views/filesystem/tests/test_lvm.py @@ -18,14 +18,11 @@ from unittest import mock import urwid -from subiquitycore.testing import view_helpers -from subiquitycore.view import BaseView - from subiquity.client.controllers.filesystem import FilesystemController from subiquity.ui.views.filesystem.lvm import VolGroupStretchy -from subiquity.ui.views.filesystem.tests.test_partition import ( - make_model_and_disk, - ) +from subiquity.ui.views.filesystem.tests.test_partition import make_model_and_disk +from subiquitycore.testing import view_helpers +from subiquitycore.view import BaseView def make_view(model, existing=None): @@ -40,45 +37,42 @@ def make_view(model, existing=None): class LVMViewTests(unittest.TestCase): - def test_create_vg(self): model, disk = make_model_and_disk() - part1 = model.add_partition(disk, size=10*(2**30), offset=0) - part2 = model.add_partition(disk, size=10*(2**30), offset=10*(2**30)) + part1 = model.add_partition(disk, size=10 * (2**30), offset=0) + part2 = model.add_partition(disk, size=10 * (2**30), offset=10 * (2**30)) view, stretchy = make_view(model) form_data = { - 'name': 'vg1', - 'devices': {part1: 'active', part2: 'active'}, - } + "name": "vg1", + "devices": {part1: "active", part2: "active"}, + } expected_data = { - 'name': 'vg1', - 'devices': {part1, part2}, - 'encrypt': False, - } + "name": "vg1", + "devices": {part1, part2}, + "encrypt": False, + } view_helpers.enter_data(stretchy.form, form_data) view_helpers.click(stretchy.form.done_btn.base_widget) - view.controller.volgroup_handler.assert_called_once_with( - None, expected_data) + view.controller.volgroup_handler.assert_called_once_with(None, expected_data) def test_create_vg_encrypted(self): model, disk = make_model_and_disk() - part1 = model.add_partition(disk, size=10*(2**30), offset=0) - part2 = model.add_partition(disk, size=10*(2**30), offset=10*(2**30)) + part1 = model.add_partition(disk, size=10 * (2**30), offset=0) + part2 = model.add_partition(disk, size=10 * (2**30), offset=10 * (2**30)) view, stretchy = make_view(model) form_data = { - 'name': 'vg1', - 'devices': {part1: 'active', part2: 'active'}, - 'encrypt': True, - 'passphrase': 'passw0rd', - 'confirm_passphrase': 'passw0rd', - } + "name": "vg1", + "devices": {part1: "active", part2: "active"}, + "encrypt": True, + "passphrase": "passw0rd", + "confirm_passphrase": "passw0rd", + } expected_data = { - 'name': 'vg1', - 'devices': {part1, part2}, - 'encrypt': True, - 'passphrase': 'passw0rd', - } + "name": "vg1", + "devices": {part1, part2}, + "encrypt": True, + "passphrase": "passw0rd", + } view_helpers.enter_data(stretchy.form, form_data) view_helpers.click(stretchy.form.done_btn.base_widget) - view.controller.volgroup_handler.assert_called_once_with( - None, expected_data) + view.controller.volgroup_handler.assert_called_once_with(None, expected_data) diff --git a/subiquity/ui/views/filesystem/tests/test_partition.py b/subiquity/ui/views/filesystem/tests/test_partition.py index cde9fbaa..ab5363aa 100644 --- a/subiquity/ui/views/filesystem/tests/test_partition.py +++ b/subiquity/ui/views/filesystem/tests/test_partition.py @@ -3,21 +3,16 @@ from unittest import mock import urwid -from subiquitycore.testing import view_helpers -from subiquitycore.view import BaseView - from subiquity.client.controllers.filesystem import FilesystemController from subiquity.common.filesystem import gaps -from subiquity.models.filesystem import ( - dehumanize_size, - ) -from subiquity.models.tests.test_filesystem import ( - make_model_and_disk, - ) +from subiquity.models.filesystem import dehumanize_size +from subiquity.models.tests.test_filesystem import make_model_and_disk from subiquity.ui.views.filesystem.partition import ( FormatEntireStretchy, PartitionStretchy, - ) +) +from subiquitycore.testing import view_helpers +from subiquitycore.view import BaseView def make_partition_view(model, disk, **kw): @@ -43,7 +38,6 @@ def make_format_entire_view(model, disk): class PartitionViewTests(unittest.TestCase): - def test_initial_focus(self): model, disk = make_model_and_disk() gap = gaps.Gap(device=disk, offset=1 << 20, size=99 << 30) @@ -57,60 +51,60 @@ class PartitionViewTests(unittest.TestCase): def test_create_partition(self): valid_data = { - 'size': "1M", - 'fstype': "ext4", - } + "size": "1M", + "fstype": "ext4", + } model, disk = make_model_and_disk() gap = gaps.Gap(device=disk, offset=1 << 20, size=99 << 30) view, stretchy = make_partition_view(model, disk, gap=gap) view_helpers.enter_data(stretchy.form, valid_data) view_helpers.click(stretchy.form.done_btn.base_widget) - valid_data['mount'] = '/' - valid_data['size'] = dehumanize_size(valid_data['size']) - valid_data['use_swap'] = False + valid_data["mount"] = "/" + valid_data["size"] = dehumanize_size(valid_data["size"]) + valid_data["use_swap"] = False view.controller.partition_disk_handler.assert_called_once_with( - stretchy.disk, valid_data, partition=None, gap=gap) + stretchy.disk, valid_data, partition=None, gap=gap + ) def test_edit_partition(self): form_data = { - 'size': "256M", - 'fstype': "xfs", - } + "size": "256M", + "fstype": "xfs", + } model, disk = make_model_and_disk() - partition = model.add_partition(disk, size=512*(2**20), offset=0) + partition = model.add_partition(disk, size=512 * (2**20), offset=0) model.add_filesystem(partition, "ext4") view, stretchy = make_partition_view(model, disk, partition=partition) self.assertTrue(stretchy.form.done_btn.enabled) view_helpers.enter_data(stretchy.form, form_data) view_helpers.click(stretchy.form.done_btn.base_widget) expected_data = { - 'size': dehumanize_size(form_data['size']), - 'fstype': 'xfs', - 'mount': None, - 'use_swap': False, - } + "size": dehumanize_size(form_data["size"]), + "fstype": "xfs", + "mount": None, + "use_swap": False, + } view.controller.partition_disk_handler.assert_called_once_with( - stretchy.disk, expected_data, - partition=stretchy.partition, gap=None) + stretchy.disk, expected_data, partition=stretchy.partition, gap=None + ) def test_size_clamping(self): model, disk = make_model_and_disk() - partition = model.add_partition(disk, size=512*(2**20), offset=0) + partition = model.add_partition(disk, size=512 * (2**20), offset=0) model.add_filesystem(partition, "ext4") view, stretchy = make_partition_view(model, disk, partition=partition) self.assertTrue(stretchy.form.done_btn.enabled) stretchy.form.size.value = "1000T" stretchy.form.size.widget.lost_focus() self.assertTrue(stretchy.form.size.showing_extra) - self.assertIn( - "Capped partition size", stretchy.form.size.under_text.text) + self.assertIn("Capped partition size", stretchy.form.size.under_text.text) def test_edit_existing_partition(self): form_data = { - 'fstype': "xfs", - } + "fstype": "xfs", + } model, disk = make_model_and_disk() - partition = model.add_partition(disk, size=512*(2**20), offset=0) + partition = model.add_partition(disk, size=512 * (2**20), offset=0) partition.preserve = True model.add_filesystem(partition, "ext4") view, stretchy = make_partition_view(model, disk, partition=partition) @@ -119,19 +113,19 @@ class PartitionViewTests(unittest.TestCase): view_helpers.enter_data(stretchy.form, form_data) view_helpers.click(stretchy.form.done_btn.base_widget) expected_data = { - 'fstype': 'xfs', - 'mount': None, - 'use_swap': False, - } + "fstype": "xfs", + "mount": None, + "use_swap": False, + } view.controller.partition_disk_handler.assert_called_once_with( - stretchy.disk, expected_data, - partition=stretchy.partition, gap=None) + stretchy.disk, expected_data, partition=stretchy.partition, gap=None + ) def test_edit_existing_partition_mountpoints(self): # Set up a PartitionStretchy for editing a partition with an # existing filesystem. model, disk = make_model_and_disk() - partition = model.add_partition(disk, size=512*(2**20), offset=0) + partition = model.add_partition(disk, size=512 * (2**20), offset=0) partition.preserve = True partition.number = 1 fs = model.add_filesystem(partition, "ext4") @@ -149,33 +143,32 @@ class PartitionViewTests(unittest.TestCase): # The option for mounting at / is disabled. But /srv is still # enabled. selector = stretchy.form.mount.widget._selector - self.assertFalse(selector.option_by_value('/').enabled) - self.assertTrue(selector.option_by_value('/srv').enabled) + self.assertFalse(selector.option_by_value("/").enabled) + self.assertTrue(selector.option_by_value("/srv").enabled) # Typing in an unsuitable mountpoint triggers a message. stretchy.form.mount.value = "/boot" stretchy.form.mount.validate() self.assertTrue(stretchy.form.mount.showing_extra) - self.assertIn( - "bad idea", stretchy.form.mount.under_text.text) - self.assertIn( - "/boot", stretchy.form.mount.under_text.text) + self.assertIn("bad idea", stretchy.form.mount.under_text.text) + self.assertIn("/boot", stretchy.form.mount.under_text.text) # Selecting to reformat the partition clears the message and # reenables the / option. stretchy.form.select_fstype(None, "ext4") self.assertFalse(stretchy.form.mount.showing_extra) - self.assertTrue(selector.option_by_value('/').enabled) + self.assertTrue(selector.option_by_value("/").enabled) def test_edit_boot_partition(self): form_data = { - 'size': "256M", - } + "size": "256M", + } model, disk = make_model_and_disk() - partition = model.add_partition(disk, size=512*(2**20), - offset=0, flag="boot") + partition = model.add_partition( + disk, size=512 * (2**20), offset=0, flag="boot" + ) fs = model.add_filesystem(partition, "fat32") - model.add_mount(fs, '/boot/efi') + model.add_mount(fs, "/boot/efi") view, stretchy = make_partition_view(model, disk, partition=partition) self.assertFalse(stretchy.form.fstype.enabled) @@ -186,19 +179,20 @@ class PartitionViewTests(unittest.TestCase): view_helpers.enter_data(stretchy.form, form_data) view_helpers.click(stretchy.form.done_btn.base_widget) expected_data = { - 'size': dehumanize_size(form_data['size']), - 'fstype': "fat32", - 'mount': '/boot/efi', - 'use_swap': False, - } + "size": dehumanize_size(form_data["size"]), + "fstype": "fat32", + "mount": "/boot/efi", + "use_swap": False, + } view.controller.partition_disk_handler.assert_called_once_with( - stretchy.disk, expected_data, - partition=stretchy.partition, gap=None) + stretchy.disk, expected_data, partition=stretchy.partition, gap=None + ) def test_edit_existing_unused_boot_partition(self): model, disk = make_model_and_disk() - partition = model.add_partition(disk, size=512*(2**20), - offset=0, flag="boot") + partition = model.add_partition( + disk, size=512 * (2**20), offset=0, flag="boot" + ) fs = model.add_filesystem(partition, "fat32") model._orig_config = model._render_actions() disk.preserve = partition.preserve = fs.preserve = True @@ -211,38 +205,39 @@ class PartitionViewTests(unittest.TestCase): view_helpers.click(stretchy.form.done_btn.base_widget) expected_data = { - 'mount': None, - 'use_swap': False, - } + "mount": None, + "use_swap": False, + } view.controller.partition_disk_handler.assert_called_once_with( - stretchy.disk, expected_data, - partition=stretchy.partition, gap=None) + stretchy.disk, expected_data, partition=stretchy.partition, gap=None + ) def test_edit_existing_used_boot_partition(self): model, disk = make_model_and_disk() - partition = model.add_partition(disk, size=512*(2**20), - offset=0, flag="boot") + partition = model.add_partition( + disk, size=512 * (2**20), offset=0, flag="boot" + ) fs = model.add_filesystem(partition, "fat32") model._orig_config = model._render_actions() partition.grub_device = True disk.preserve = partition.preserve = fs.preserve = True - model.add_mount(fs, '/boot/efi') + model.add_mount(fs, "/boot/efi") view, stretchy = make_partition_view(model, disk, partition=partition) self.assertTrue(stretchy.form.fstype.enabled) self.assertEqual(stretchy.form.fstype.value, None) self.assertFalse(stretchy.form.mount.enabled) - self.assertEqual(stretchy.form.mount.value, '/boot/efi') + self.assertEqual(stretchy.form.mount.value, "/boot/efi") view_helpers.click(stretchy.form.done_btn.base_widget) expected_data = { - 'fstype': None, - 'mount': '/boot/efi', - 'use_swap': False, - } + "fstype": None, + "mount": "/boot/efi", + "use_swap": False, + } view.controller.partition_disk_handler.assert_called_once_with( - stretchy.disk, expected_data, - partition=stretchy.partition, gap=None) + stretchy.disk, expected_data, partition=stretchy.partition, gap=None + ) def test_format_entire_unusual_filesystem(self): model, disk = make_model_and_disk() @@ -250,5 +245,4 @@ class PartitionViewTests(unittest.TestCase): fs.preserve = True model._orig_config = model._render_actions() view, stretchy = make_format_entire_view(model, disk) - self.assertEqual( - stretchy.form.fstype.value, None) + self.assertEqual(stretchy.form.fstype.value, None) diff --git a/subiquity/ui/views/filesystem/tests/test_raid.py b/subiquity/ui/views/filesystem/tests/test_raid.py index db061389..dc1e31fb 100644 --- a/subiquity/ui/views/filesystem/tests/test_raid.py +++ b/subiquity/ui/views/filesystem/tests/test_raid.py @@ -3,18 +3,13 @@ from unittest import mock import urwid +from subiquity.client.controllers.filesystem import FilesystemController +from subiquity.models.filesystem import raidlevels_by_value +from subiquity.ui.views.filesystem.raid import RaidStretchy +from subiquity.ui.views.filesystem.tests.test_partition import make_model_and_disk from subiquitycore.testing import view_helpers from subiquitycore.view import BaseView -from subiquity.client.controllers.filesystem import FilesystemController -from subiquity.models.filesystem import ( - raidlevels_by_value, - ) -from subiquity.ui.views.filesystem.raid import RaidStretchy -from subiquity.ui.views.filesystem.tests.test_partition import ( - make_model_and_disk, - ) - def make_view(model, existing=None): controller = mock.create_autospec(spec=FilesystemController) @@ -28,45 +23,42 @@ def make_view(model, existing=None): class RaidViewTests(unittest.TestCase): - def test_create_raid(self): model, disk = make_model_and_disk() - part1 = model.add_partition(disk, size=10*(2**30), offset=0) - part2 = model.add_partition(disk, size=10*(2**30), offset=10*(2**30)) - part3 = model.add_partition(disk, size=10*(2**30), offset=20*(2**30)) + part1 = model.add_partition(disk, size=10 * (2**30), offset=0) + part2 = model.add_partition(disk, size=10 * (2**30), offset=10 * (2**30)) + part3 = model.add_partition(disk, size=10 * (2**30), offset=20 * (2**30)) view, stretchy = make_view(model) form_data = { - 'name': 'md0', - 'devices': {part1: 'active', part2: 'active', part3: 'spare'}, - } + "name": "md0", + "devices": {part1: "active", part2: "active", part3: "spare"}, + } expected_data = { - 'name': 'md0', - 'devices': {part1, part2}, - 'spare_devices': {part3}, - 'level': raidlevels_by_value["raid1"], - } + "name": "md0", + "devices": {part1, part2}, + "spare_devices": {part3}, + "level": raidlevels_by_value["raid1"], + } view_helpers.enter_data(stretchy.form, form_data) view_helpers.click(stretchy.form.done_btn.base_widget) - view.controller.raid_handler.assert_called_once_with( - None, expected_data) + view.controller.raid_handler.assert_called_once_with(None, expected_data) def test_edit_raid(self): model, disk = make_model_and_disk() - part1 = model.add_partition(disk, size=10*(2**30), offset=0) - part2 = model.add_partition(disk, size=10*(2**30), offset=10*(2**30)) + part1 = model.add_partition(disk, size=10 * (2**30), offset=0) + part2 = model.add_partition(disk, size=10 * (2**30), offset=10 * (2**30)) raid = model.add_raid("md0", "raid1", {part1, part2}, set()) view, stretchy = make_view(model, raid) form_data = { - 'name': 'md1', - 'level': raidlevels_by_value["raid0"], - } + "name": "md1", + "level": raidlevels_by_value["raid0"], + } expected_data = { - 'name': 'md1', - 'devices': {part1, part2}, - 'spare_devices': set(), - 'level': raidlevels_by_value["raid0"], - } + "name": "md1", + "devices": {part1, part2}, + "spare_devices": set(), + "level": raidlevels_by_value["raid0"], + } view_helpers.enter_data(stretchy.form, form_data) view_helpers.click(stretchy.form.done_btn.base_widget) - view.controller.raid_handler.assert_called_once_with( - raid, expected_data) + view.controller.raid_handler.assert_called_once_with(raid, expected_data) diff --git a/subiquity/ui/views/help.py b/subiquity/ui/views/help.py index f217a9a5..f3c5f74f 100644 --- a/subiquity/ui/views/help.py +++ b/subiquity/ui/views/help.py @@ -16,59 +16,31 @@ import logging import os -from urwid import ( - connect_signal, - Divider, - Filler, - PopUpLauncher, - Text, - ) - -from subiquitycore.async_helpers import run_bg_task -from subiquitycore.lsb_release import lsb_release -from subiquitycore.ssh import summarize_host_keys -from subiquitycore.ui.buttons import ( - header_btn, - other_btn, - ) -from subiquitycore.ui.container import ( - Columns, - Pile, - WidgetWrap, - ) -from subiquitycore.ui.utils import ( - button_pile, - ClickableIcon, - Color, - ) -from subiquitycore.ui.stretchy import ( - Stretchy, - ) -from subiquitycore.ui.table import ( - ColSpec, - TablePile, - TableRow, - ) -from subiquitycore.ui.utils import ( - rewrap, - ) -from subiquitycore.ui.width import ( - widget_width, - ) +from urwid import Divider, Filler, PopUpLauncher, Text, connect_signal from subiquity.common.types import PasswordKind from subiquity.ui.views.error import ErrorReportListStretchy +from subiquitycore.async_helpers import run_bg_task +from subiquitycore.lsb_release import lsb_release +from subiquitycore.ssh import summarize_host_keys +from subiquitycore.ui.buttons import header_btn, other_btn +from subiquitycore.ui.container import Columns, Pile, WidgetWrap +from subiquitycore.ui.stretchy import Stretchy +from subiquitycore.ui.table import ColSpec, TablePile, TableRow +from subiquitycore.ui.utils import ClickableIcon, Color, button_pile, rewrap +from subiquitycore.ui.width import widget_width -log = logging.getLogger('subiquity.ui.views.help') +log = logging.getLogger("subiquity.ui.views.help") def close_btn(app, stretchy): return other_btn( - _("Close"), - on_press=lambda sender: app.remove_global_overlay(stretchy)) + _("Close"), on_press=lambda sender: app.remove_global_overlay(stretchy) + ) -ABOUT_INSTALLER = _(""" +ABOUT_INSTALLER = _( + """ Welcome to the Ubuntu Server Installer! The most popular server Linux in the cloud and data centre, this @@ -81,10 +53,12 @@ The installer only requires the up and down arrow keys, space (or return) and the occasional bit of typing. This is version {snap_version} of the installer. -""") +""" +) -ABOUT_INSTALLER_LTS = _(""" +ABOUT_INSTALLER_LTS = _( + """ Welcome to the Ubuntu Server Installer! The most popular server Linux in the cloud and data centre, you can @@ -97,56 +71,78 @@ The installer only requires the up and down arrow keys, space (or return) and the occasional bit of typing. This is version {snap_version} of the installer. -""") +""" +) -SSH_HELP_PROLOGUE = _(""" +SSH_HELP_PROLOGUE = _( + """ It is possible to connect to the installer over the network, which might allow the use of a more capable terminal and can offer more languages -than can be rendered in the Linux console.""") +than can be rendered in the Linux console.""" +) -SSH_HELP_MULTIPLE_ADDRESSES = _(""" +SSH_HELP_MULTIPLE_ADDRESSES = _( + """ To connect, SSH to any of these addresses: -""") +""" +) -SSH_HELP_ONE_ADDRESSES = _(""" -To connect, SSH to {username}@{ip}.""") +SSH_HELP_ONE_ADDRESSES = _( + """ +To connect, SSH to {username}@{ip}.""" +) -SSH_HELP_EPILOGUE_KNOWN_PASS_NO_KEYS = _("""\ -The password you should use is "{password}".""") +SSH_HELP_EPILOGUE_KNOWN_PASS_NO_KEYS = _( + """\ +The password you should use is "{password}".""" +) -SSH_HELP_EPILOGUE_UNKNOWN_PASS_NO_KEYS = _("""\ -You should use the preconfigured password passed to cloud-init.""") +SSH_HELP_EPILOGUE_UNKNOWN_PASS_NO_KEYS = _( + """\ +You should use the preconfigured password passed to cloud-init.""" +) -SSH_HELP_EPILOGUE_ONE_KEY = _("""\ +SSH_HELP_EPILOGUE_ONE_KEY = _( + """\ You can log in with the {keytype} key with fingerprint: {fingerprint} -""") +""" +) -SSH_HELP_EPILOGUE_MULTIPLE_KEYS = _("""\ +SSH_HELP_EPILOGUE_MULTIPLE_KEYS = _( + """\ You can log in with one of the following keys: -""") +""" +) -SSH_HELP_EPILOGUE_KNOWN_PASS_KEYS = _(""" -Or you can use the password "{password}".""") +SSH_HELP_EPILOGUE_KNOWN_PASS_KEYS = _( + """ +Or you can use the password "{password}".""" +) -SSH_HELP_EPILOGUE_UNKNOWN_PASS_KEYS = _(""" -Or you can use the preconfigured password passed to cloud-init.""") +SSH_HELP_EPILOGUE_UNKNOWN_PASS_KEYS = _( + """ +Or you can use the preconfigured password passed to cloud-init.""" +) -SSH_HELP_NO_ADDRESSES = _(""" +SSH_HELP_NO_ADDRESSES = _( + """ Unfortunately this system seems to have no global IP addresses at this time. -""") +""" +) -SSH_HELP_NO_PASSWORD = _(""" +SSH_HELP_NO_PASSWORD = _( + """ Unfortunately the installer was unable to detect the password that has been set. -""") +""" +) def ssh_help_texts(ssh_info): - texts = [_(SSH_HELP_PROLOGUE), ""] if ssh_info is not None: @@ -154,43 +150,61 @@ def ssh_help_texts(ssh_info): texts.append(_(SSH_HELP_MULTIPLE_ADDRESSES)) texts.append("") for ip in ssh_info.ips: - texts.append(Text( - "{}@{}".format(ssh_info.username, ip), align='center')) + texts.append( + Text("{}@{}".format(ssh_info.username, ip), align="center") + ) else: - texts.append(_(SSH_HELP_ONE_ADDRESSES).format( - username=ssh_info.username, ip=str(ssh_info.ips[0]))) + texts.append( + _(SSH_HELP_ONE_ADDRESSES).format( + username=ssh_info.username, ip=str(ssh_info.ips[0]) + ) + ) texts.append("") if ssh_info.authorized_key_fingerprints: if len(ssh_info.authorized_key_fingerprints) == 1: key = ssh_info.authorized_key_fingerprints[0] - texts.append(Text(_(SSH_HELP_EPILOGUE_ONE_KEY).format( - keytype=key.keytype, fingerprint=key.fingerprint))) + texts.append( + Text( + _(SSH_HELP_EPILOGUE_ONE_KEY).format( + keytype=key.keytype, fingerprint=key.fingerprint + ) + ) + ) else: texts.append(_(SSH_HELP_EPILOGUE_MULTIPLE_KEYS)) texts.append("") rows = [] for key in ssh_info.authorized_key_fingerprints: - rows.append( - TableRow([Text(key.keytype), Text(key.fingerprint)])) + rows.append(TableRow([Text(key.keytype), Text(key.fingerprint)])) texts.append(TablePile(rows)) if ssh_info.password_kind == PasswordKind.KNOWN: texts.append("") - texts.append(SSH_HELP_EPILOGUE_KNOWN_PASS_KEYS.format( - password=ssh_info.password)) + texts.append( + SSH_HELP_EPILOGUE_KNOWN_PASS_KEYS.format(password=ssh_info.password) + ) elif ssh_info.password_kind == PasswordKind.UNKNOWN: texts.append("") texts.append(SSH_HELP_EPILOGUE_UNKNOWN_PASS_KEYS) else: if ssh_info.password_kind == PasswordKind.KNOWN: - texts.append(SSH_HELP_EPILOGUE_KNOWN_PASS_NO_KEYS.format( - password=ssh_info.password)) + texts.append( + SSH_HELP_EPILOGUE_KNOWN_PASS_NO_KEYS.format( + password=ssh_info.password + ) + ) elif ssh_info.password_kind == PasswordKind.UNKNOWN: texts.append(SSH_HELP_EPILOGUE_UNKNOWN_PASS_NO_KEYS) texts.append("") - texts.append(Text(summarize_host_keys([ - (key.keytype, key.fingerprint) - for key in ssh_info.host_key_fingerprints - ]))) + texts.append( + Text( + summarize_host_keys( + [ + (key.keytype, key.fingerprint) + for key in ssh_info.host_key_fingerprints + ] + ) + ) + ) else: texts.append("") texts.append(_(SSH_HELP_NO_ADDRESSES)) @@ -199,7 +213,6 @@ def ssh_help_texts(ssh_info): class SimpleTextStretchy(Stretchy): - def __init__(self, app, title, *texts): widgets = [] @@ -208,38 +221,41 @@ class SimpleTextStretchy(Stretchy): text = Text(rewrap(text)) widgets.append(text) - widgets.extend([ - Text(""), - button_pile([close_btn(app, self)]), - ]) - super().__init__(title, widgets, 0, len(widgets)-1) + widgets.extend( + [ + Text(""), + button_pile([close_btn(app, self)]), + ] + ) + super().__init__(title, widgets, 0, len(widgets) - 1) -GLOBAL_KEY_HELP = _("""\ -The following keys can be used at any time:""") +GLOBAL_KEY_HELP = _( + """\ +The following keys can be used at any time:""" +) GLOBAL_KEYS = ( - (_("ESC"), _('go back')), - (_('F1'), _('open help menu')), - (_('Control-Z, F2'), _('switch to shell')), - (_('Control-L, F3'), _('redraw screen')), - ) + (_("ESC"), _("go back")), + (_("F1"), _("open help menu")), + (_("Control-Z, F2"), _("switch to shell")), + (_("Control-L, F3"), _("redraw screen")), +) SERIAL_GLOBAL_HELP_KEYS = ( - (_('Control-T, F4'), _('toggle rich mode (colour, unicode) on and off')), - ) + (_("Control-T, F4"), _("toggle rich mode (colour, unicode) on and off")), +) DRY_RUN_KEYS = ( - (_('Control-X'), _('quit')), - (_('Control-E'), _('generate noisy error report')), - (_('Control-R'), _('generate quiet error report')), - (_('Control-G'), _('pretend to run an install')), - (_('Control-U'), _('crash the ui')), - ) + (_("Control-X"), _("quit")), + (_("Control-E"), _("generate noisy error report")), + (_("Control-R"), _("generate quiet error report")), + (_("Control-G"), _("pretend to run an install")), + (_("Control-U"), _("crash the ui")), +) class GlobalKeyStretchy(Stretchy): - def __init__(self, app): rows = [] keys = GLOBAL_KEYS @@ -248,91 +264,85 @@ class GlobalKeyStretchy(Stretchy): for key, text in keys: rows.append(TableRow([Text(_(key)), Text(_(text))])) if app.opts.dry_run: - dro = _('(dry-run only)') + dro = _("(dry-run only)") for key, text in DRY_RUN_KEYS: - rows.append(TableRow([ - Text(_(key)), - Text(_(text) + ' ' + dro)])) - table = TablePile( - rows, spacing=2, colspecs={1: ColSpec(can_shrink=True)}) + rows.append(TableRow([Text(_(key)), Text(_(text) + " " + dro)])) + table = TablePile(rows, spacing=2, colspecs={1: ColSpec(can_shrink=True)}) widgets = [ - Pile([ - ('pack', Text(rewrap(GLOBAL_KEY_HELP))), - ('pack', Text("")), - ('pack', table), - ]), + Pile( + [ + ("pack", Text(rewrap(GLOBAL_KEY_HELP))), + ("pack", Text("")), + ("pack", table), + ] + ), Text(""), button_pile([close_btn(app, self)]), - ] + ] super().__init__(_("Shortcut Keys"), widgets, 0, 2) -hline = Divider('─') -vline = Text('│') -tlcorner = Text('┌') -trcorner = Text('┐') -blcorner = Text('└') -brcorner = Text('┘') -rtee = Text('┤') -ltee = Text('├') +hline = Divider("─") +vline = Text("│") +tlcorner = Text("┌") +trcorner = Text("┐") +blcorner = Text("└") +brcorner = Text("┘") +rtee = Text("┤") +ltee = Text("├") def menu_item(text, on_press=None): icon = ClickableIcon(" " + text + " ") if on_press is not None: - connect_signal(icon, 'click', on_press) + connect_signal(icon, "click", on_press) return Color.frame_button(icon) class OpenHelpMenu(WidgetWrap): - def __init__(self, parent): self.parent = parent close = header_btn(parent.base_widget.label) - about = menu_item( - _("About this installer"), on_press=self.parent.about) - keys = menu_item( - _("Keyboard shortcuts"), on_press=self.parent.shortcuts) - drop_to_shell = menu_item( - _("Enter shell"), on_press=self.parent.debug_shell) + about = menu_item(_("About this installer"), on_press=self.parent.about) + keys = menu_item(_("Keyboard shortcuts"), on_press=self.parent.shortcuts) + drop_to_shell = menu_item(_("Enter shell"), on_press=self.parent.debug_shell) buttons = { about, close, drop_to_shell, keys, - } + } if self.parent.ssh_info is not None: - ssh_help = menu_item( - _("Help on SSH access"), on_press=self.parent.ssh_help) + ssh_help = menu_item(_("Help on SSH access"), on_press=self.parent.ssh_help) buttons.add(ssh_help) if self.parent.app.opts.run_on_serial: - rich = menu_item( - _("Toggle rich mode"), on_press=self.parent.toggle_rich) + rich = menu_item(_("Toggle rich mode"), on_press=self.parent.toggle_rich) buttons.add(rich) local_title = None - if hasattr(parent.app.ui.body, 'local_help'): + if hasattr(parent.app.ui.body, "local_help"): local_title, local_doc = parent.app.ui.body.local_help() if local_title is not None: local = menu_item( - local_title, - on_press=self.parent.show_local(local_title, local_doc)) + local_title, on_press=self.parent.show_local(local_title, local_doc) + ) buttons.add(local) else: - local = Text( - ('info_minor header', " " + _("Help on this screen") + " ")) + local = Text(("info_minor header", " " + _("Help on this screen") + " ")) self.parent.app.error_reporter.load_reports() if self.parent.app.error_reporter.reports: view_errors = menu_item( _("View error reports").format(local_title), - on_press=self.parent.show_errors) + on_press=self.parent.show_errors, + ) buttons.add(view_errors) else: view_errors = Text( - ('info_minor header', " " + _("View error reports") + " ")) + ("info_minor header", " " + _("View error reports") + " ") + ) for button in buttons: - connect_signal(button.base_widget, 'click', self._close) + connect_signal(button.base_widget, "click", self._close) entries = [ local, @@ -341,50 +351,60 @@ class OpenHelpMenu(WidgetWrap): view_errors, hline, about, - ] + ] if self.parent.ssh_info is not None: entries.append(ssh_help) if self.parent.app.opts.run_on_serial: - entries.extend([ - hline, - rich, - ]) + entries.extend( + [ + hline, + rich, + ] + ) rows = [ - Columns([ - ('fixed', 1, tlcorner), - hline, - (widget_width(close), close), - ('fixed', 1, trcorner), - ]), - ] + Columns( + [ + ("fixed", 1, tlcorner), + hline, + (widget_width(close), close), + ("fixed", 1, trcorner), + ] + ), + ] for entry in entries: if isinstance(entry, Divider): left, right = ltee, rtee else: left = right = vline - rows.append(Columns([ - ('fixed', 1, left), - entry, - ('fixed', 1, right), - ])) + rows.append( + Columns( + [ + ("fixed", 1, left), + entry, + ("fixed", 1, right), + ] + ) + ) rows.append( - Columns([ - (1, blcorner), - hline, - (1, brcorner), - ])) - self.width = max([ - widget_width(b) for b in entries - if not isinstance(b, Divider) - ]) + 2 + Columns( + [ + (1, blcorner), + hline, + (1, brcorner), + ] + ) + ) + self.width = ( + max([widget_width(b) for b in entries if not isinstance(b, Divider)]) + 2 + ) self.height = len(entries) + 2 super().__init__(Color.frame_header(Filler(Pile(rows)))) def keypress(self, size, key): - if key == 'esc': + if key == "esc": self.parent.close_pop_up() else: return super().keypress(size, key) @@ -394,7 +414,6 @@ class OpenHelpMenu(WidgetWrap): class HelpMenu(PopUpLauncher): - def __init__(self, app, about_msg=None): self.app = app self._about_message = about_msg @@ -405,7 +424,8 @@ class HelpMenu(PopUpLauncher): async def _get_ssh_info(self): self.ssh_info = await self.app.wait_with_text_dialog( - self.app.client.meta.ssh_info.GET(), "Getting SSH info") + self.app.client.meta.ssh_info.GET(), "Getting SSH info" + ) self.open_pop_up() def _open(self, sender): @@ -418,11 +438,11 @@ class HelpMenu(PopUpLauncher): def get_pop_up_parameters(self): return { - 'left': widget_width(self.btn) - self._menu.width + 1, - 'top': 0, - 'overlay_width': self._menu.width, - 'overlay_height': self._menu.height, - } + "left": widget_width(self.btn) - self._menu.width + 1, + "top": 0, + "overlay_width": self._menu.width, + "overlay_height": self._menu.height, + } def _show_overlay(self, stretchy): ui = self.app.ui @@ -438,20 +458,22 @@ class HelpMenu(PopUpLauncher): self.current_help = None ui.pile.focus_position = fp - connect_signal(stretchy, 'closed', on_close) + connect_signal(stretchy, "closed", on_close) self.app.add_global_overlay(stretchy) def _default_about_msg(self): info = lsb_release(dry_run=self.app.opts.dry_run) - if 'LTS' in info['description']: + if "LTS" in info["description"]: template = _(ABOUT_INSTALLER_LTS) else: template = _(ABOUT_INSTALLER) - info.update({ - 'snap_version': os.environ.get("SNAP_VERSION", "SNAP_VERSION"), - 'snap_revision': os.environ.get("SNAP_REVISION", "SNAP_REVISION"), - }) + info.update( + { + "snap_version": os.environ.get("SNAP_VERSION", "SNAP_VERSION"), + "snap_revision": os.environ.get("SNAP_REVISION", "SNAP_REVISION"), + } + ) return template.format(**info) def about(self, sender=None): @@ -459,10 +481,8 @@ class HelpMenu(PopUpLauncher): self._about_message = self._default_about_msg() self._show_overlay( - SimpleTextStretchy( - self.app, - _("About the installer"), - self._about_message)) + SimpleTextStretchy(self.app, _("About the installer"), self._about_message) + ) def ssh_help(self, sender=None): texts = ssh_help_texts(self.ssh_info) @@ -472,16 +492,13 @@ class HelpMenu(PopUpLauncher): self.app, _("Help on SSH access"), *texts, - )) + ) + ) def show_local(self, local_title, local_doc): - def cb(sender=None): - self._show_overlay( - SimpleTextStretchy( - self.app, - local_title, - local_doc)) + self._show_overlay(SimpleTextStretchy(self.app, local_title, local_doc)) + return cb def shortcuts(self, sender): diff --git a/subiquity/ui/views/identity.py b/subiquity/ui/views/identity.py index 54e7f859..a8a218b1 100644 --- a/subiquity/ui/views/identity.py +++ b/subiquity/ui/views/identity.py @@ -17,58 +17,51 @@ import logging import os import re -from urwid import ( - connect_signal, - ) -from subiquitycore.async_helpers import schedule_task +from urwid import connect_signal -from subiquitycore.ui.interactive import ( - PasswordEditor, - StringEditor, - ) -from subiquitycore.ui.form import ( - Form, - simple_field, - WantsToKnowFormField, - ) +from subiquity.common.resources import resource_path +from subiquity.common.types import IdentityData, UsernameValidation +from subiquitycore.async_helpers import schedule_task +from subiquitycore.ui.form import Form, WantsToKnowFormField, simple_field +from subiquitycore.ui.interactive import PasswordEditor, StringEditor from subiquitycore.ui.utils import screen from subiquitycore.utils import crypt_password from subiquitycore.view import BaseView -from subiquity.common.resources import resource_path -from subiquity.common.types import IdentityData, UsernameValidation - - log = logging.getLogger("subiquity.ui.views.identity") HOSTNAME_MAXLEN = 64 -HOSTNAME_REGEX = r'[a-z0-9_][a-z0-9_-]*' +HOSTNAME_REGEX = r"[a-z0-9_][a-z0-9_-]*" REALNAME_MAXLEN = 160 SSH_IMPORT_MAXLEN = 256 + 3 # account for lp: or gh: USERNAME_MAXLEN = 32 -USERNAME_REGEX = r'[a-z_][a-z0-9_-]*' +USERNAME_REGEX = r"[a-z_][a-z0-9_-]*" class RealnameEditor(StringEditor, WantsToKnowFormField): def valid_char(self, ch): - if len(ch) == 1 and ch in ':,=': + if len(ch) == 1 and ch in ":,=": self.bff.in_error = True - self.bff.show_extra(("info_error", - _("The characters : , and = are not permitted" - " in this field"))) + self.bff.show_extra( + ( + "info_error", + _("The characters : , and = are not permitted" " in this field"), + ) + ) return False else: return super().valid_char(ch) class _AsyncValidatedMixin: - """ Provides Editor widgets with async validation capabilities """ + """Provides Editor widgets with async validation capabilities""" + def __init__(self): self.validation_task = None self.initial = None self.validation_result = None self._validate_async_inner = None - connect_signal(self, 'change', self._reset_validation) + connect_signal(self, "change", self._reset_validation) def set_initial_state(self, initial): self.initial = initial @@ -84,8 +77,7 @@ class _AsyncValidatedMixin: if self.validation_task is not None: self.validation_task.cancel() - self.validation_task = \ - schedule_task(self._validate_async(self.value)) + self.validation_task = schedule_task(self._validate_async(self.value)) async def _validate_async(self, value): # Retrigger field validation because it's not guaranteed that the async @@ -97,9 +89,10 @@ class _AsyncValidatedMixin: class UsernameEditor(StringEditor, _AsyncValidatedMixin, WantsToKnowFormField): def __init__(self): - self.valid_char_pat = r'[-a-z0-9_]' - self.error_invalid_char = _("The only characters permitted in this " - "field are a-z, 0-9, _ and -") + self.valid_char_pat = r"[-a-z0-9_]" + self.error_invalid_char = _( + "The only characters permitted in this " "field are a-z, 0-9, _ and -" + ) super().__init__() def valid_char(self, ch): @@ -117,7 +110,6 @@ PasswordField = simple_field(PasswordEditor) class IdentityForm(Form): - def __init__(self, controller, initial): self.controller = controller super().__init__(initial=initial) @@ -128,29 +120,29 @@ class IdentityForm(Form): realname = RealnameField(_("Your name:")) hostname = UsernameField( _("Your server's name:"), - help=_("The name it uses when it talks to other computers.")) + help=_("The name it uses when it talks to other computers."), + ) username = UsernameField(_("Pick a username:")) password = PasswordField(_("Choose a password:")) confirm_password = PasswordField(_("Confirm your password:")) def validate_realname(self): if len(self.realname.value) > REALNAME_MAXLEN: - return _( - "Name too long, must be less than {limit}" - ).format(limit=REALNAME_MAXLEN) + return _("Name too long, must be less than {limit}").format( + limit=REALNAME_MAXLEN + ) def validate_hostname(self): if len(self.hostname.value) < 1: return _("Server name must not be empty") if len(self.hostname.value) > HOSTNAME_MAXLEN: - return _( - "Server name too long, must be less than {limit}" - ).format(limit=HOSTNAME_MAXLEN) + return _("Server name too long, must be less than {limit}").format( + limit=HOSTNAME_MAXLEN + ) if not re.match(HOSTNAME_REGEX, self.hostname.value): - return _( - "Hostname must match HOSTNAME_REGEX: " + HOSTNAME_REGEX) + return _("Hostname must match HOSTNAME_REGEX: " + HOSTNAME_REGEX) def validate_username(self): username = self.username.value @@ -158,24 +150,23 @@ class IdentityForm(Form): return _("Username missing") if len(username) > USERNAME_MAXLEN: - return _( - "Username too long, must be less than {limit}" - ).format(limit=USERNAME_MAXLEN) + return _("Username too long, must be less than {limit}").format( + limit=USERNAME_MAXLEN + ) - if not re.match(r'[a-z_][a-z0-9_-]*', username): - return _( - "Username must match USERNAME_REGEX: " + USERNAME_REGEX) + if not re.match(r"[a-z_][a-z0-9_-]*", username): + return _("Username must match USERNAME_REGEX: " + USERNAME_REGEX) state = self.username.widget.validation_result if state == UsernameValidation.SYSTEM_RESERVED: return _( 'The username "{username}" is reserved for use by the system.' - ).format(username=username) + ).format(username=username) if state == UsernameValidation.ALREADY_IN_USE: - return _( - 'The username "{username}" is already in use.' - ).format(username=username) + return _('The username "{username}" is already in use.').format( + username=username + ) def validate_password(self): if len(self.password.value) < 1: @@ -188,50 +179,56 @@ class IdentityForm(Form): class IdentityView(BaseView): title = _("Profile setup") - excerpt = _("Enter the username and password you will use to log in to " - "the system. You can configure SSH access on the next screen " - "but a password is still needed for sudo.") + excerpt = _( + "Enter the username and password you will use to log in to " + "the system. You can configure SSH access on the next screen " + "but a password is still needed for sudo." + ) def __init__(self, controller, identity_data): self.controller = controller - reserved_usernames_path = resource_path('reserved-usernames') + reserved_usernames_path = resource_path("reserved-usernames") reserved_usernames = set() if os.path.exists(reserved_usernames_path): with open(reserved_usernames_path) as fp: for line in fp: line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue reserved_usernames.add(line) else: - reserved_usernames.add('root') + reserved_usernames.add("root") initial = { - 'realname': identity_data.realname, - 'username': identity_data.username, - 'hostname': identity_data.hostname, - } + "realname": identity_data.realname, + "username": identity_data.username, + "hostname": identity_data.hostname, + } self.form = IdentityForm(controller, initial) self.form.confirm_password.use_as_confirmation( - for_field=self.form.password, - desc=_("Passwords")) + for_field=self.form.password, desc=_("Passwords") + ) - connect_signal(self.form, 'submit', self.done) + connect_signal(self.form, "submit", self.done) super().__init__( screen( self.form.as_rows(), [self.form.done_btn], excerpt=_(self.excerpt), - focus_buttons=False)) + focus_buttons=False, + ) + ) def done(self, result): - self.controller.done(IdentityData( - hostname=self.form.hostname.value, - realname=self.form.realname.value, - username=self.form.username.value, - crypted_password=crypt_password(self.form.password.value), - )) + self.controller.done( + IdentityData( + hostname=self.form.hostname.value, + realname=self.form.realname.value, + username=self.form.username.value, + crypted_password=crypt_password(self.form.password.value), + ) + ) diff --git a/subiquity/ui/views/installprogress.py b/subiquity/ui/views/installprogress.py index 6c7902ff..ecb144f1 100644 --- a/subiquity/ui/views/installprogress.py +++ b/subiquity/ui/views/installprogress.py @@ -16,28 +16,18 @@ import asyncio import logging -from urwid import ( - LineBox, - Text, - ) +from urwid import LineBox, Text +from subiquity.common.types import ApplicationState from subiquitycore.async_helpers import run_bg_task -from subiquitycore.view import BaseView -from subiquitycore.ui.buttons import ( - cancel_btn, - danger_btn, - ok_btn, - other_btn, - ) +from subiquitycore.ui.buttons import cancel_btn, danger_btn, ok_btn, other_btn from subiquitycore.ui.container import Columns, ListBox, Pile from subiquitycore.ui.form import Toggleable from subiquitycore.ui.spinner import Spinner -from subiquitycore.ui.utils import button_pile, Padding, rewrap from subiquitycore.ui.stretchy import Stretchy +from subiquitycore.ui.utils import Padding, button_pile, rewrap from subiquitycore.ui.width import widget_width - -from subiquity.common.types import ApplicationState - +from subiquitycore.view import BaseView log = logging.getLogger("subiquity.ui.views.installprogress") @@ -51,40 +41,36 @@ class MyLineBox(LineBox): class ProgressView(BaseView): - title = _("Install progress") def __init__(self, controller): self.controller = controller self.ongoing = {} # context_id -> line containing a spinner - self.reboot_btn = Toggleable(ok_btn( - _("Reboot Now"), on_press=self.reboot)) + self.reboot_btn = Toggleable(ok_btn(_("Reboot Now"), on_press=self.reboot)) self.view_error_btn = cancel_btn( - _("View error report"), on_press=self.view_error) - self.view_log_btn = other_btn( - _("View full log"), on_press=self.view_log) - self.continue_btn = other_btn( - _("Continue"), on_press=self.continue_) + _("View error report"), on_press=self.view_error + ) + self.view_log_btn = other_btn(_("View full log"), on_press=self.view_log) + self.continue_btn = other_btn(_("Continue"), on_press=self.continue_) self.event_listbox = ListBox() self.event_linebox = MyLineBox(self.event_listbox) self.event_buttons = button_pile([self.view_log_btn]) event_body = [ - ('weight', 1, Padding.center_79(self.event_linebox, min_width=76)), - ('pack', Text("")), - ('pack', self.event_buttons), - ('pack', Text("")), + ("weight", 1, Padding.center_79(self.event_linebox, min_width=76)), + ("pack", Text("")), + ("pack", self.event_buttons), + ("pack", Text("")), ] self.event_pile = Pile(event_body) self.log_listbox = ListBox() log_linebox = MyLineBox(self.log_listbox, _("Full installer output")) log_body = [ - ('weight', 1, log_linebox), - ('pack', button_pile([other_btn(_("Close"), - on_press=self.close_log)])), - ] + ("weight", 1, log_linebox), + ("pack", button_pile([other_btn(_("Close"), on_press=self.close_log)])), + ] self.log_pile = Pile(log_body) super().__init__(self.event_pile) @@ -96,17 +82,20 @@ class ProgressView(BaseView): walker.append(line) if at_end: lb.set_focus(len(walker) - 1) - lb.set_focus_valign('bottom') + lb.set_focus_valign("bottom") def event_start(self, context_id, context_parent_id, message): self.event_finish(context_parent_id) walker = self.event_listbox.base_widget.body spinner = Spinner() spinner.start() - new_line = Columns([ - ('pack', Text(message)), - ('pack', spinner), - ], dividechars=1) + new_line = Columns( + [ + ("pack", Text(message)), + ("pack", spinner), + ], + dividechars=1, + ) self.ongoing[context_id] = len(walker) self._add_line(self.event_listbox, new_line) @@ -137,7 +126,7 @@ class ProgressView(BaseView): def _set_buttons(self, buttons): p = self.event_buttons.original_widget - p.contents[:] = [(b, p.options('pack')) for b in buttons] + p.contents[:] = [(b, p.options("pack")) for b in buttons] self._set_button_width() def update_for_state(self, state): @@ -161,12 +150,11 @@ class ProgressView(BaseView): btns = [self.view_log_btn] elif state == ApplicationState.UU_RUNNING: self.title = _("Install complete!") - self.reboot_btn.base_widget.set_label( - _("Cancel update and reboot")) + self.reboot_btn.base_widget.set_label(_("Cancel update and reboot")) btns = [ self.view_log_btn, self.reboot_btn, - ] + ] elif state == ApplicationState.UU_CANCELLING: self.title = _("Install complete!") self.reboot_btn.base_widget.set_label(_("Rebooting...")) @@ -174,25 +162,25 @@ class ProgressView(BaseView): btns = [ self.view_log_btn, self.reboot_btn, - ] + ] elif state == ApplicationState.DONE: self.title = _("Install complete!") self.reboot_btn.base_widget.set_label(_("Reboot Now")) btns = [ self.view_log_btn, self.reboot_btn, - ] + ] elif state == ApplicationState.ERROR: - self.title = _('An error occurred during installation') + self.title = _("An error occurred during installation") self.reboot_btn.base_widget.set_label(_("Reboot Now")) self.reboot_btn.enabled = True btns = [ self.view_log_btn, self.view_error_btn, self.reboot_btn, - ] + ] elif state == ApplicationState.EXITED: - self.title = _('Subiquity server process has exited') + self.title = _("Subiquity server process has exited") btns = [self.view_log_btn] else: raise Exception(state) @@ -219,7 +207,7 @@ class ProgressView(BaseView): self.event_pile.base_widget.focus_position = 2 def reboot(self, btn): - log.debug('reboot clicked') + log.debug("reboot clicked") self.reboot_btn.base_widget.set_label(_("Rebooting...")) self.reboot_btn.enabled = False self.event_buttons.original_widget._select_first_selectable() @@ -236,14 +224,16 @@ class ProgressView(BaseView): self._w = self.event_pile -confirmation_text = _("""\ +confirmation_text = _( + """\ Selecting Continue below will begin the installation process and result in the loss of data on the disks selected to be formatted. You will not be able to return to this or a previous screen once the installation has started. -Are you sure you want to continue?""") +Are you sure you want to continue?""" +) class InstallConfirmation(Stretchy): @@ -252,15 +242,16 @@ class InstallConfirmation(Stretchy): widgets = [ Text(rewrap(_(confirmation_text))), Text(""), - button_pile([ - cancel_btn(_("No"), on_press=self.cancel), - danger_btn(_("Continue"), on_press=self.ok)]), - ] + button_pile( + [ + cancel_btn(_("No"), on_press=self.cancel), + danger_btn(_("Continue"), on_press=self.ok), + ] + ), + ] super().__init__( - _("Confirm destructive action"), - widgets, - stretchy_index=0, - focus_index=2) + _("Confirm destructive action"), widgets, stretchy_index=0, focus_index=2 + ) def ok(self, sender): if isinstance(self.app.ui.body, ProgressView): @@ -277,30 +268,29 @@ class InstallConfirmation(Stretchy): self.app.ui.body.show_continue() -running_text = _("""\ +running_text = _( + """\ The installer running on {tty} is currently installing the system. You can wait for this to complete or switch to a shell. -""") +""" +) class InstallRunning(Stretchy): def __init__(self, app, tty): self.app = app - self.btn = Toggleable(other_btn( - _("Switch to a shell"), on_press=self._debug_shell)) + self.btn = Toggleable( + other_btn(_("Switch to a shell"), on_press=self._debug_shell) + ) self.btn.enabled = False self._enable_task = asyncio.create_task(self._enable()) widgets = [ Text(rewrap(_(running_text).format(tty=tty))), - Text(''), + Text(""), button_pile([self.btn]), - ] - super().__init__( - "", - widgets, - stretchy_index=0, - focus_index=2) + ] + super().__init__("", widgets, stretchy_index=0, focus_index=2) async def _enable(self): await asyncio.sleep(0.5) diff --git a/subiquity/ui/views/keyboard.py b/subiquity/ui/views/keyboard.py index 66f8679c..f16934b8 100644 --- a/subiquity/ui/views/keyboard.py +++ b/subiquity/ui/views/keyboard.py @@ -16,41 +16,24 @@ import locale import logging -from urwid import ( - connect_signal, - LineBox, - Text, - Padding as UrwidPadding - ) - -from subiquitycore.async_helpers import run_bg_task -from subiquitycore.ui.buttons import ( - cancel_btn, - ok_btn, - other_btn, - ) -from subiquitycore.ui.container import ( - Columns, - Pile, - WidgetWrap, - ) -from subiquitycore.ui.form import ( - ChoiceField, - Form, - ) -from subiquitycore.ui.selector import Selector, Option -from subiquitycore.ui.stretchy import ( - Stretchy, - ) -from subiquitycore.ui.utils import button_pile, Color, Padding, screen -from subiquitycore.view import BaseView +from urwid import LineBox +from urwid import Padding as UrwidPadding +from urwid import Text, connect_signal from subiquity.common.types import ( KeyboardSetting, StepKeyPresent, StepPressKey, StepResult, - ) +) +from subiquitycore.async_helpers import run_bg_task +from subiquitycore.ui.buttons import cancel_btn, ok_btn, other_btn +from subiquitycore.ui.container import Columns, Pile, WidgetWrap +from subiquitycore.ui.form import ChoiceField, Form +from subiquitycore.ui.selector import Option, Selector +from subiquitycore.ui.stretchy import Stretchy +from subiquitycore.ui.utils import Color, Padding, button_pile, screen +from subiquitycore.view import BaseView log = logging.getLogger("subiquity.ui.views.keyboard") @@ -60,12 +43,15 @@ class AutoDetectBase(WidgetWrap): self.keyboard_detector = keyboard_detector self.step = step lb = LineBox( - Pile([ - ('pack', Text("")), - ('pack', UrwidPadding(self.make_body(), left=2, right=2)), - ('pack', Text("")) - ]), - _("Keyboard auto-detection")) + Pile( + [ + ("pack", Text("")), + ("pack", UrwidPadding(self.make_body(), left=2, right=2)), + ("pack", Text("")), + ] + ), + _("Keyboard auto-detection"), + ) super().__init__(lb) def start(self): @@ -75,14 +61,13 @@ class AutoDetectBase(WidgetWrap): pass def keypress(self, size, key): - if key == 'esc': + if key == "esc": self.keyboard_detector.backup() else: return super().keypress(size, key) class AutoDetectIntro(AutoDetectBase): - def ok(self, sender): self.keyboard_detector.do_step(None) @@ -90,31 +75,42 @@ class AutoDetectIntro(AutoDetectBase): self.keyboard_detector.abort() def make_body(self): - return Pile([ - Text(_("Keyboard detection starting. You will be asked a " - "series of questions about your keyboard. Press escape " - "at any time to go back to the previous screen.")), + return Pile( + [ + Text( + _( + "Keyboard detection starting. You will be asked a " + "series of questions about your keyboard. Press escape " + "at any time to go back to the previous screen." + ) + ), Text(""), - button_pile([ - ok_btn(label=_("OK"), on_press=self.ok), - ok_btn(label=_("Cancel"), on_press=self.cancel), - ]), - ]) + button_pile( + [ + ok_btn(label=_("OK"), on_press=self.ok), + ok_btn(label=_("Cancel"), on_press=self.cancel), + ] + ), + ] + ) class AutoDetectResult(AutoDetectBase): - - preamble = _("""\ + preamble = _( + """\ Keyboard auto detection completed. Your keyboard was detected as: -""") - postamble = _("""\ +""" + ) + postamble = _( + """\ If this is correct, select Done on the next screen. If not you can select \ another layout or run the automated detection again. -""") +""" + ) @property def _kview(self): @@ -125,17 +121,20 @@ another layout or run the automated detection again. def make_body(self): self.layout, self.variant = self._kview.lookup( - self.step.layout, self.step.variant) + self.step.layout, self.step.variant + ) layout_text = _("Layout") var_text = _("Variant") width = max(len(layout_text), len(var_text), 12) - return Pile([ + return Pile( + [ Text(_(self.preamble)), Text("%*s: %s" % (width, layout_text, self.layout.name)), Text("%*s: %s" % (width, var_text, self.variant.name)), Text(_(self.postamble)), button_pile([ok_btn(label=_("OK"), on_press=self.ok)]), - ]) + ] + ) class AutoDetectPressKey(AutoDetectBase): @@ -149,12 +148,19 @@ class AutoDetectPressKey(AutoDetectBase): def make_body(self): self.error_text = Text("", align="center") - self.pile = Pile([ - ('pack', Text(_("Please press one of the following keys:"))), - ('pack', Text("")), - ('pack', Columns([Text(s, align="center") - for s in self.step.symbols], dividechars=1)), - ]) + self.pile = Pile( + [ + ("pack", Text(_("Please press one of the following keys:"))), + ("pack", Text("")), + ( + "pack", + Columns( + [Text(s, align="center") for s in self.step.symbols], + dividechars=1, + ), + ), + ] + ) return self.pile @property @@ -168,23 +174,25 @@ class AutoDetectPressKey(AutoDetectBase): self.input_filter.exit_keycodes_mode() def error(self, message): - t = Color.info_error(Text(message, align='center')) - self.pile.contents.extend([ - (Text(""), self.pile.options('pack')), - (t, self.pile.options('pack')), - ]) + t = Color.info_error(Text(message, align="center")) + self.pile.contents.extend( + [ + (Text(""), self.pile.options("pack")), + (t, self.pile.options("pack")), + ] + ) def keypress(self, size, key): - log.debug('keypress %r', key) - if key.startswith('release '): + log.debug("keypress %r", key) + if key.startswith("release "): # Escape is key 1 on all keyboards and all layouts except # amigas and very old Macs so this seems safe enough. - if key == 'release 1': - return super().keypress(size, 'esc') + if key == "release 1": + return super().keypress(size, "esc") else: return - elif key.startswith('press '): - code = int(key[len('press '):]) + elif key.startswith("press "): + code = int(key[len("press ") :]) if code not in self.step.keycodes: self.error(_("Input was not recognized, try again")) return @@ -193,16 +201,16 @@ class AutoDetectPressKey(AutoDetectBase): # If we're not on a linux tty, the filtering won't have # happened and so there's no way to get the keycodes. Do # something literally random instead. - if key == 'e': + if key == "e": self.error(_("Input was not recognized, try again")) return import random + v = random.choice(list(self.step.keycodes.values())) self.keyboard_detector.do_step(v) class AutoDetectKeyPresent(AutoDetectBase): - def yes(self, sender): self.keyboard_detector.do_step(self.step.yes) @@ -210,16 +218,20 @@ class AutoDetectKeyPresent(AutoDetectBase): self.keyboard_detector.do_step(self.step.no) def make_body(self): - return Pile([ - Text(_("Is the following key present on your keyboard?")), - Text(""), - Text(self.step.symbol, align="center"), - Text(""), - button_pile([ - ok_btn(label=_("Yes"), on_press=self.yes), - other_btn(label=_("No"), on_press=self.no), - ]), - ]) + return Pile( + [ + Text(_("Is the following key present on your keyboard?")), + Text(""), + Text(self.step.symbol, align="center"), + Text(""), + button_pile( + [ + ok_btn(label=_("Yes"), on_press=self.yes), + other_btn(label=_("No"), on_press=self.no), + ] + ), + ] + ) class Detector: @@ -241,7 +253,7 @@ class Detector: StepResult: AutoDetectResult, StepPressKey: AutoDetectPressKey, StepKeyPresent: AutoDetectKeyPresent, - } + } def backup(self): if len(self.seen_steps) == 0: @@ -264,8 +276,8 @@ class Detector: async def _do_step(self, step_index): log.debug("moving to step %s", step_index) step = await self.keyboard_view.controller.app.wait_with_text_dialog( - self.keyboard_view.controller.get_step(step_index), - "...") + self.keyboard_view.controller.get_step(step_index), "..." + ) self.seen_steps.append(step_index) log.debug("step: %s", step) self.overlay = self.step_cls_to_view_cls[type(step)](self, step) @@ -274,7 +286,8 @@ class Detector: self.keyboard_view.show_overlay(self.overlay) -toggle_text = _("""\ +toggle_text = _( + """\ You will need a way to toggle the keyboard between the national layout and \ the standard Latin layout. @@ -283,37 +296,37 @@ latter case, use the combination Shift+Caps Lock for normal Caps toggle). \ Alt+Shift is also a popular combination; it will however lose its usual \ behavior in Emacs and other programs that use it for specific needs. -Not all listed keys are present on all keyboards. """) +Not all listed keys are present on all keyboards. """ +) toggle_options = [ - (_('Caps Lock'), True, 'caps_toggle'), - (_('Right Alt (AltGr)'), True, 'toggle'), - (_('Right Control'), True, 'rctrl_toggle'), - (_('Right Shift'), True, 'rshift_toggle'), - (_('Right Logo key'), True, 'rwin_toggle'), - (_('Menu key'), True, 'menu_toggle'), - (_('Alt+Shift'), True, 'alt_shift_toggle'), - (_('Control+Shift'), True, 'ctrl_shift_toggle'), - (_('Control+Alt'), True, 'ctrl_alt_toggle'), - (_('Alt+Caps Lock'), True, 'alt_caps_toggle'), - (_('Left Control+Left Shift'), True, 'lctrl_lshift_toggle'), - (_('Left Alt'), True, 'lalt_toggle'), - (_('Left Control'), True, 'lctrl_toggle'), - (_('Left Shift'), True, 'lshift_toggle'), - (_('Left Logo key'), True, 'lwin_toggle'), - (_('Scroll Lock key'), True, 'sclk_toggle'), - (_('No toggling'), True, None), - ] + (_("Caps Lock"), True, "caps_toggle"), + (_("Right Alt (AltGr)"), True, "toggle"), + (_("Right Control"), True, "rctrl_toggle"), + (_("Right Shift"), True, "rshift_toggle"), + (_("Right Logo key"), True, "rwin_toggle"), + (_("Menu key"), True, "menu_toggle"), + (_("Alt+Shift"), True, "alt_shift_toggle"), + (_("Control+Shift"), True, "ctrl_shift_toggle"), + (_("Control+Alt"), True, "ctrl_alt_toggle"), + (_("Alt+Caps Lock"), True, "alt_caps_toggle"), + (_("Left Control+Left Shift"), True, "lctrl_lshift_toggle"), + (_("Left Alt"), True, "lalt_toggle"), + (_("Left Control"), True, "lctrl_toggle"), + (_("Left Shift"), True, "lshift_toggle"), + (_("Left Logo key"), True, "lwin_toggle"), + (_("Scroll Lock key"), True, "sclk_toggle"), + (_("No toggling"), True, None), +] class ToggleQuestion(Stretchy): - def __init__(self, parent, setting): self.parent = parent self.setting = setting self.selector = Selector(toggle_options) - self.selector.value = 'alt_shift_toggle' + self.selector.value = "alt_shift_toggle" if self.parent.initial_setting.toggle: try: self.selector.value = self.parent.initial_setting.toggle @@ -323,21 +336,25 @@ class ToggleQuestion(Stretchy): widgets = [ Text(_(toggle_text)), Text(""), - Padding.center_79(Columns([ - ('pack', Text(_("Shortcut: "))), - self.selector, - ])), + Padding.center_79( + Columns( + [ + ("pack", Text(_("Shortcut: "))), + self.selector, + ] + ) + ), Text(""), - button_pile([ - ok_btn(label=_("OK"), on_press=self.ok), - cancel_btn(label=_("Cancel"), on_press=self.cancel), - ]), - ] + button_pile( + [ + ok_btn(label=_("OK"), on_press=self.ok), + cancel_btn(label=_("Cancel"), on_press=self.cancel), + ] + ), + ] super().__init__( - _("Select layout toggle"), - widgets, - stretchy_index=0, - focus_index=4) + _("Select layout toggle"), widgets, stretchy_index=0, focus_index=4 + ) def ok(self, sender): self.parent.remove_overlay() @@ -349,7 +366,6 @@ class ToggleQuestion(Stretchy): class KeyboardForm(Form): - cancel_label = _("Back") layout = ChoiceField(_("Layout:"), choices=["placeholder"]) @@ -357,7 +373,6 @@ class KeyboardForm(Form): class KeyboardView(BaseView): - title = _("Keyboard configuration") def __init__(self, controller, setup): @@ -370,35 +385,38 @@ class KeyboardView(BaseView): for layout in self.layouts: opts.append(Option((layout.name, True, layout))) opts.sort(key=lambda o: locale.strxfrm(o.label.text)) - connect_signal(self.form, 'submit', self.done) - connect_signal(self.form, 'cancel', self.cancel) + connect_signal(self.form, "submit", self.done) + connect_signal(self.form, "cancel", self.cancel) connect_signal(self.form.layout.widget, "select", self.select_layout) self.form.layout.widget.options = opts - layout, variant = self.lookup( - setup.setting.layout, setup.setting.variant) + layout, variant = self.lookup(setup.setting.layout, setup.setting.variant) self.set_values(layout, variant) if self.controller.opts.run_on_serial: - excerpt = _('Please select the layout of the keyboard directly ' - 'attached to the system, if any.') + excerpt = _( + "Please select the layout of the keyboard directly " + "attached to the system, if any." + ) else: - excerpt = _('Please select your keyboard layout below, or select ' - '"Identify keyboard" to detect your layout ' - 'automatically.') + excerpt = _( + "Please select your keyboard layout below, or select " + '"Identify keyboard" to detect your layout ' + "automatically." + ) lb_contents = self.form.as_rows() if not self.controller.opts.run_on_serial: - lb_contents.extend([ - Text(""), - button_pile([ - other_btn(label=_("Identify keyboard"), - on_press=self.detect)]), - ]) - super().__init__(screen( - lb_contents, - self.form.buttons, - excerpt=excerpt, - narrow_rows=True)) + lb_contents.extend( + [ + Text(""), + button_pile( + [other_btn(label=_("Identify keyboard"), on_press=self.detect)] + ), + ] + ) + super().__init__( + screen(lb_contents, self.form.buttons, excerpt=excerpt, narrow_rows=True) + ) def detect(self, sender): Detector(self).start() @@ -411,7 +429,8 @@ class KeyboardView(BaseView): async def _check_toggle(self, setting): needs_toggle = await self.controller.app.wait_with_text_dialog( - self.controller.needs_toggle(setting), "...") + self.controller.needs_toggle(setting), "..." + ) if needs_toggle: self.show_stretchy_overlay(ToggleQuestion(self, setting)) else: @@ -419,14 +438,15 @@ class KeyboardView(BaseView): def done(self, result): data = result.as_data() - layout = data['layout'] - variant = data.get('variant', layout.variants[0]) + layout = data["layout"] + variant = data.get("variant", layout.variants[0]) setting = KeyboardSetting(layout=layout.code, variant=variant.code) run_bg_task(self._check_toggle(setting)) async def _apply(self, setting): await self.controller.app.wait_with_text_dialog( - self.controller.apply(setting), _("Applying config")) + self.controller.apply(setting), _("Applying config") + ) self.controller.done() def really_done(self, setting): diff --git a/subiquity/ui/views/mirror.py b/subiquity/ui/views/mirror.py index 2ed503d6..82076e45 100644 --- a/subiquity/ui/views/mirror.py +++ b/subiquity/ui/views/mirror.py @@ -20,64 +20,50 @@ import asyncio import logging from typing import Callable, Optional -from urwid import ( - connect_signal, - LineBox, - Padding, - Text, - ) +from urwid import LineBox, Padding, Text, connect_signal import subiquitycore.async_helpers as async_helpers -from subiquitycore.ui.buttons import ( - other_btn, - ) -from subiquitycore.ui.container import ( - ListBox, - Pile, - WidgetWrap, - ) -from subiquitycore.ui.form import ( - Form, - URLField, -) +from subiquity.common.types import MirrorCheckResponse, MirrorCheckStatus, MirrorPost +from subiquitycore.ui.buttons import other_btn +from subiquitycore.ui.container import ListBox, Pile, WidgetWrap +from subiquitycore.ui.form import Form, URLField from subiquitycore.ui.spinner import Spinner -from subiquitycore.ui.table import TableRow, TablePile +from subiquitycore.ui.table import TablePile, TableRow from subiquitycore.ui.utils import button_pile, rewrap from subiquitycore.view import BaseView -from subiquity.common.types import ( - MirrorPost, - MirrorCheckResponse, - MirrorCheckStatus, - ) - -log = logging.getLogger('subiquity.ui.views.mirror') +log = logging.getLogger("subiquity.ui.views.mirror") mirror_help = _( - "You may provide an archive mirror that will be used instead " - "of the default.") + "You may provide an archive mirror that will be used instead " "of the default." +) MIRROR_CHECK_CONFIRMATION_TEXTS = { MirrorCheckStatus.RUNNING: ( _("Mirror check still running"), - _("The check of the mirror URL is still running. You can continue but" - " there is a chance that the installation will fail."), + _( + "The check of the mirror URL is still running. You can continue but" + " there is a chance that the installation will fail." ), + ), MirrorCheckStatus.FAILED: ( _("Mirror check failed"), - _("The check of the mirror URL failed. You can continue but it is very" - " likely that the installation will fail."), + _( + "The check of the mirror URL failed. You can continue but it is very" + " likely that the installation will fail." ), + ), None: ( _("Mirror check has not run"), - _("The check of the mirror has not yet started. You can continue but" - " there is a chance that the installation will fail."), + _( + "The check of the mirror has not yet started. You can continue but" + " there is a chance that the installation will fail." ), - } + ), +} class MirrorForm(Form): - on_validate: Optional[Callable[[str], None]] = None cancel_label = _("Back") @@ -93,33 +79,42 @@ class MirrorForm(Form): # * a boolean stating if network is available # * the status of the mirror check (or None if it hasn't started yet) MIRROR_CHECK_STATUS_TEXTS = { - (False, None): _("The mirror location cannot be checked because no network" - " has been configured."), + (False, None): _( + "The mirror location cannot be checked because no network" + " has been configured." + ), (True, None): _("The mirror location has not yet started."), - (True, MirrorCheckStatus.RUNNING): _("The mirror location is being" - " tested."), + (True, MirrorCheckStatus.RUNNING): _("The mirror location is being" " tested."), (True, MirrorCheckStatus.OK): _("This mirror location passed tests."), - (True, MirrorCheckStatus.FAILED): _("""\ + (True, MirrorCheckStatus.FAILED): _( + """\ This mirror location does not seem to work. The output below may help explain the problem. You can try again once the issue has been fixed (common problems are network issues or the system clock being wrong). -"""), } +""" + ), +} class MirrorView(BaseView): - title = _("Configure Ubuntu archive mirror") - excerpt = _("If you use an alternative mirror for Ubuntu, enter its " - "details here.") + excerpt = _( + "If you use an alternative mirror for Ubuntu, enter its " "details here." + ) - def __init__(self, controller, mirror, - check: Optional[MirrorCheckResponse], has_network: bool): + def __init__( + self, + controller, + mirror, + check: Optional[MirrorCheckResponse], + has_network: bool, + ): self.controller = controller - self.form = MirrorForm(initial={'url': mirror}) + self.form = MirrorForm(initial={"url": mirror}) - connect_signal(self.form, 'submit', self.done) - connect_signal(self.form, 'cancel', self.cancel) + connect_signal(self.form, "submit", self.done) + connect_signal(self.form, "cancel", self.cancel) self.status_text = Text("") self.status_spinner = Spinner() @@ -127,37 +122,48 @@ class MirrorView(BaseView): self.output_text = Text("") self.output_box = LineBox(ListBox([self.output_text])) self.output_wrap = WidgetWrap(self.output_box) - self.retry_btns = button_pile([other_btn( - _("Try again now"), - on_press=lambda sender: self.check_url(self.form.url.value))]) + self.retry_btns = button_pile( + [ + other_btn( + _("Try again now"), + on_press=lambda sender: self.check_url(self.form.url.value), + ) + ] + ) self.has_network = has_network if check is not None: self.update_status(check) else: - self.status_text.set_text(rewrap(_( - MIRROR_CHECK_STATUS_TEXTS[self.has_network, None]))) + self.status_text.set_text( + rewrap(_(MIRROR_CHECK_STATUS_TEXTS[self.has_network, None])) + ) self.last_status: Optional[MirrorCheckStatus] = None - rows = [ - ('pack', Text(_(self.excerpt))), - ('pack', Text("")), - ] + [('pack', r) for r in self.form.as_rows()] + [ - ('pack', Text("")), - ('pack', self.status_wrap), - ('pack', Text("")), - self.output_wrap, - ('pack', Text("")), - ('pack', self.form.buttons), - ('pack', Text("")), + rows = ( + [ + ("pack", Text(_(self.excerpt))), + ("pack", Text("")), ] + + [("pack", r) for r in self.form.as_rows()] + + [ + ("pack", Text("")), + ("pack", self.status_wrap), + ("pack", Text("")), + self.output_wrap, + ("pack", Text("")), + ("pack", self.form.buttons), + ("pack", Text("")), + ] + ) self.form.on_validate = self.on_url_changed pile = Pile(rows) pile.focus_position = len(rows) - 2 - super().__init__(Padding( - pile, align='center', width=("relative", 79), min_width=76)) + super().__init__( + Padding(pile, align="center", width=("relative", 79), min_width=76) + ) def on_url_changed(self, url: str) -> None: async def inner(): @@ -174,15 +180,15 @@ class MirrorView(BaseView): async_helpers.run_bg_task(self._check_url(url)) async def _check_url(self, url, cancel_ongoing=False): - await self.controller.endpoint.POST( - MirrorPost(staged=url)) + await self.controller.endpoint.POST(MirrorPost(staged=url)) await self.controller.endpoint.check_mirror.start.POST(True) state = await self.controller.endpoint.check_mirror.progress.GET() self.update_status(state) def update_status(self, check_state: MirrorCheckResponse): - self.status_text.set_text(rewrap(_( - MIRROR_CHECK_STATUS_TEXTS[self.has_network, check_state.status]))) + self.status_text.set_text( + rewrap(_(MIRROR_CHECK_STATUS_TEXTS[self.has_network, check_state.status])) + ) self.output_text.set_text(check_state.output) async def cb(): @@ -191,20 +197,24 @@ class MirrorView(BaseView): self.update_status(status) if check_state.status == MirrorCheckStatus.FAILED: - self.output_wrap._w = Pile([ - ('pack', self.retry_btns), - ('pack', Text("")), - self.output_box, - ]) + self.output_wrap._w = Pile( + [ + ("pack", self.retry_btns), + ("pack", Text("")), + self.output_box, + ] + ) elif check_state.status == MirrorCheckStatus.RUNNING: self.output_wrap._w = self.output_box if check_state.status == MirrorCheckStatus.RUNNING: async_helpers.run_bg_task(cb()) self.status_spinner.start() - self.status_wrap._w = TablePile([ - TableRow([self.status_text, self.status_spinner]), - ]) + self.status_wrap._w = TablePile( + [ + TableRow([self.status_text, self.status_spinner]), + ] + ) else: self.status_spinner.stop() self.status_wrap._w = self.status_text @@ -215,14 +225,21 @@ class MirrorView(BaseView): async def confirm_continue_anyway() -> None: title, question = MIRROR_CHECK_CONFIRMATION_TEXTS[self.last_status] confirmed = await self.ask_confirmation( - title=title, question=question, - confirm_label=_("Continue"), cancel_label=_("Cancel")) + title=title, + question=question, + confirm_label=_("Continue"), + cancel_label=_("Cancel"), + ) if confirmed: self.controller.done(result.url.value) + log.debug("User input: {}".format(result.as_data())) if self.has_network and self.last_status in [ - MirrorCheckStatus.RUNNING, MirrorCheckStatus.FAILED, None]: + MirrorCheckStatus.RUNNING, + MirrorCheckStatus.FAILED, + None, + ]: async_helpers.run_bg_task(confirm_continue_anyway()) else: self.controller.done(result.url.value) diff --git a/subiquity/ui/views/proxy.py b/subiquity/ui/views/proxy.py index aa2afa42..1c689083 100644 --- a/subiquity/ui/views/proxy.py +++ b/subiquity/ui/views/proxy.py @@ -19,43 +19,42 @@ Provides high level options for Ubuntu install """ import logging + from urwid import connect_signal +from subiquitycore.ui.form import Form, URLField from subiquitycore.view import BaseView -from subiquitycore.ui.form import ( - Form, - URLField, + +log = logging.getLogger("subiquity.ui.views.proxy") + +proxy_help = _( + "If you need to use a HTTP proxy to access the outside world, " + "enter the proxy information here. Otherwise, leave this blank." + "\n\nThe proxy information should be given in the standard " + 'form of "http://[[user][:pass]@]host[:port]/".' ) -log = logging.getLogger('subiquity.ui.views.proxy') - -proxy_help = _("If you need to use a HTTP proxy to access the outside world, " - "enter the proxy information here. Otherwise, leave this blank." - "\n\nThe proxy information should be given in the standard " - "form of \"http://[[user][:pass]@]host[:port]/\".") - - class ProxyForm(Form): - cancel_label = _("Back") url = URLField(_("Proxy address:"), help=proxy_help) class ProxyView(BaseView): - title = _("Configure proxy") - excerpt = _("If this system requires a proxy to connect to the internet, " - "enter its details here.") + excerpt = _( + "If this system requires a proxy to connect to the internet, " + "enter its details here." + ) def __init__(self, controller, proxy): self.controller = controller - self.form = ProxyForm(initial={'url': proxy}) + self.form = ProxyForm(initial={"url": proxy}) - connect_signal(self.form, 'submit', self.done) - connect_signal(self.form, 'cancel', self.cancel) + connect_signal(self.form, "submit", self.done) + connect_signal(self.form, "cancel", self.cancel) super().__init__(self.form.as_screen(excerpt=_(self.excerpt))) diff --git a/subiquity/ui/views/refresh.py b/subiquity/ui/views/refresh.py index 13637dfd..bf3723ec 100644 --- a/subiquity/ui/views/refresh.py +++ b/subiquity/ui/views/refresh.py @@ -18,30 +18,22 @@ import json import logging import aiohttp +from urwid import ProgressBar, Text, WidgetWrap -from urwid import ( - ProgressBar, - Text, - WidgetWrap, - ) - +from subiquity.common.types import RefreshCheckState, TaskStatus from subiquitycore.async_helpers import schedule_task -from subiquitycore.view import BaseView from subiquitycore.ui.buttons import done_btn, other_btn from subiquitycore.ui.container import Columns, ListBox from subiquitycore.ui.spinner import Spinner -from subiquitycore.ui.utils import button_pile, Color, screen +from subiquitycore.ui.utils import Color, button_pile, screen +from subiquitycore.view import BaseView -from subiquity.common.types import RefreshCheckState, TaskStatus - -log = logging.getLogger('subiquity.ui.views.refresh') +log = logging.getLogger("subiquity.ui.views.refresh") class TaskProgressBar(ProgressBar): def __init__(self): - super().__init__( - normal='progress_incomplete', - complete='progress_complete') + super().__init__(normal="progress_incomplete", complete="progress_complete") self._width = 80 self.label = "" @@ -55,25 +47,28 @@ class TaskProgressBar(ProgressBar): suffix = " {:.2f} / {:.2f} MiB".format(current_MiB, done_MiB) remaining = self._width - len(suffix) - 2 if len(self.label) > remaining: - label = self.label[:remaining-3] + '...' + label = self.label[: remaining - 3] + "..." else: - label = self.label + ' ' * (remaining - len(self.label)) + label = self.label + " " * (remaining - len(self.label)) return label + suffix class TaskProgress(WidgetWrap): - def __init__(self): self.mode = "spinning" self.spinner = Spinner() - self.label = Text("", wrap='clip') - cols = Color.progress_incomplete(Columns([ - (1, Text("")), - self.label, - (1, Text("")), - (1, self.spinner), - (1, Text("")), - ])) + self.label = Text("", wrap="clip") + cols = Color.progress_incomplete( + Columns( + [ + (1, Text("")), + self.label, + (1, Text("")), + (1, self.spinner), + (1, Text("")), + ] + ) + ) super().__init__(cols) def update(self, task): @@ -107,34 +102,29 @@ def exc_message(exc): class RefreshView(BaseView): - checking_title = _("Checking for installer update...") checking_excerpt = _( "Contacting the snap store to check if a new version of the " "installer is available." - ) + ) check_failed_title = _("Contacting the snap store failed") - check_failed_excerpt = _( - "Contacting the snap store failed:" - ) + check_failed_excerpt = _("Contacting the snap store failed:") available_title = _("Installer update available") available_excerpt = _( "Version {new} of the installer is now available ({current} is " "currently running)." - ) + ) progress_title = _("Downloading update...") progress_excerpt = _( "Please wait while the updated installer is being downloaded. The " "installer will restart automatically when the download is complete." - ) + ) update_failed_title = _("Update failed") - update_failed_excerpt = _( - "Downloading and applying the update:" - ) + update_failed_excerpt = _("Downloading and applying the update:") def __init__(self, controller): self.controller = controller @@ -156,7 +146,7 @@ class RefreshView(BaseView): buttons = [ done_btn(_("Continue without updating"), on_press=self.done), other_btn(_("Back"), on_press=self.cancel), - ] + ] self.title = self.checking_title self.controller.ui.set_header(self.title) @@ -179,11 +169,13 @@ class RefreshView(BaseView): rows = [Text(exc_message(exc))] - buttons = button_pile([ - done_btn(_("Try again"), on_press=self.try_check_again), - done_btn(_("Continue without updating"), on_press=self.done), - other_btn(_("Back"), on_press=self.cancel), - ]) + buttons = button_pile( + [ + done_btn(_("Try again"), on_press=self.try_check_again), + done_btn(_("Continue without updating"), on_press=self.done), + other_btn(_("Back"), on_press=self.cancel), + ] + ) buttons.base_widget.focus_position = 1 self.title = self.check_failed_title @@ -199,37 +191,41 @@ class RefreshView(BaseView): rows = [ Text(_("You can read the release notes for each version at:")), Text(""), - Text( - "https://github.com/canonical/subiquity/releases", - align='center'), + Text("https://github.com/canonical/subiquity/releases", align="center"), Text(""), Text( - _("If you choose to update, the update will be downloaded " - "and the installation will continue from here."), + _( + "If you choose to update, the update will be downloaded " + "and the installation will continue from here." ), - ] + ), + ] - buttons = button_pile([ - done_btn(_("Update to the new installer"), on_press=self.update), - done_btn(_("Continue without updating"), - on_press=self.skip_update), - other_btn(_("Back"), on_press=self.cancel), - ]) + buttons = button_pile( + [ + done_btn(_("Update to the new installer"), on_press=self.update), + done_btn(_("Continue without updating"), on_press=self.skip_update), + other_btn(_("Back"), on_press=self.cancel), + ] + ) buttons.base_widget.focus_position = 1 excerpt = _(self.available_excerpt).format( current=self.controller.status.current_snap_version, - new=self.controller.status.new_snap_version) + new=self.controller.status.new_snap_version, + ) self.title = self.available_title self.controller.ui.set_header(self.available_title) self._w = screen(rows, buttons, excerpt=excerpt) - if 'update' in self.controller.answers: - if self.controller.answers['update']: + if "update" in self.controller.answers: + if self.controller.answers["update"]: self.update() else: + async def skip(): self.skip_update() + self._skip_task = asyncio.create_task(skip()) def update(self, sender=None): @@ -276,12 +272,13 @@ class RefreshView(BaseView): rows = [Text(msg)] - buttons = button_pile([ - done_btn(_("Try again"), on_press=self.try_update_again), - done_btn(_("Continue without updating"), - on_press=self.skip_update), - other_btn(_("Back"), on_press=self.cancel), - ]) + buttons = button_pile( + [ + done_btn(_("Try again"), on_press=self.try_update_again), + done_btn(_("Continue without updating"), on_press=self.skip_update), + other_btn(_("Back"), on_press=self.cancel), + ] + ) buttons.base_widget.focus_position = 1 self.title = self.update_failed_title diff --git a/subiquity/ui/views/serial.py b/subiquity/ui/views/serial.py index 0b8b2b4c..f40a8123 100644 --- a/subiquity/ui/views/serial.py +++ b/subiquity/ui/views/serial.py @@ -57,11 +57,11 @@ class SerialView(BaseView): def make_serial(self): self.rich_btn = forward_btn( - label="Continue in rich mode", - on_press=self.rich_mode) + label="Continue in rich mode", on_press=self.rich_mode + ) self.basic_btn = forward_btn( - label="Continue in basic mode", - on_press=self.basic_mode) + label="Continue in basic mode", on_press=self.basic_mode + ) btns = [self.rich_btn, self.basic_btn] widgets = [ Text(""), @@ -71,9 +71,9 @@ class SerialView(BaseView): if self.ssh_info: widgets.append(Text(rewrap(SSH_TEXT))) widgets.append(Text("")) - btns.append(other_btn( - label="View SSH instructions", - on_press=self.ssh_help)) + btns.append( + other_btn(label="View SSH instructions", on_press=self.ssh_help) + ) return screen(widgets, btns) def rich_mode(self, sender): diff --git a/subiquity/ui/views/snaplist.py b/subiquity/ui/views/snaplist.py index 18206116..94947d4f 100644 --- a/subiquity/ui/views/snaplist.py +++ b/subiquity/ui/views/snaplist.py @@ -19,48 +19,26 @@ import logging from typing import Tuple import yaml +from urwid import AttrMap, CheckBox, LineBox +from urwid import ListBox as UrwidListBox +from urwid import RadioButton, SelectableIcon, SimpleFocusListWalker, Text -from urwid import ( - AttrMap, - CheckBox, - LineBox, - ListBox as UrwidListBox, - RadioButton, - SelectableIcon, - SimpleFocusListWalker, - Text, - ) - +from subiquity.common.types import SnapCheckState, SnapSelection +from subiquity.models.filesystem import humanize_size from subiquitycore.async_helpers import schedule_task -from subiquitycore.ui.buttons import ok_btn, cancel_btn, other_btn +from subiquitycore.ui.buttons import cancel_btn, ok_btn, other_btn from subiquitycore.ui.container import ( Columns, ListBox, Pile, ScrollBarListBox, WidgetWrap, - ) +) from subiquitycore.ui.spinner import Spinner -from subiquitycore.ui.table import ( - AbstractTable, - ColSpec, - TablePile, - TableRow, - ) -from subiquitycore.ui.utils import ( - button_pile, - Color, - Padding, - screen, - ) +from subiquitycore.ui.table import AbstractTable, ColSpec, TablePile, TableRow +from subiquitycore.ui.utils import Color, Padding, button_pile, screen from subiquitycore.view import BaseView -from subiquity.common.types import ( - SnapCheckState, - SnapSelection, - ) -from subiquity.models.filesystem import humanize_size - log = logging.getLogger("subiquity.ui.views.snaplist") @@ -68,11 +46,10 @@ class StarRadioButton(RadioButton): states = { True: SelectableIcon("(*)"), False: SelectableIcon("( )"), - } + } class NoTabCyclingTableListBox(AbstractTable): - def _make(self, rows): body = SimpleFocusListWalker(rows) return ScrollBarListBox(UrwidListBox(body)) @@ -82,14 +59,14 @@ def format_datetime(d): delta = datetime.datetime.now() - d if delta.total_seconds() < 60: return _("just now") - elif delta.total_seconds() < 60*60: - amount = int(delta.total_seconds()/60) + elif delta.total_seconds() < 60 * 60: + amount = int(delta.total_seconds() / 60) if amount == 1: unit = _("minute") else: unit = _("minutes") - elif delta.total_seconds() < 60*60*24: - amount = int(delta.total_seconds()/(60*60)) + elif delta.total_seconds() < 60 * 60 * 24: + amount = int(delta.total_seconds() / (60 * 60)) if amount == 1: unit = _("hour") else: @@ -106,7 +83,7 @@ def format_datetime(d): def check_mark() -> Tuple[Tuple[str, str], ...]: - """ Return a tuple that can be passed to a urwid.Text() """ + """Return a tuple that can be passed to a urwid.Text()""" # We want a single check-mark "✓" in rich-mode but a double star "**" in # basic mode. Because we have a 1:1 mapping between a given unicode @@ -114,13 +91,12 @@ def check_mark() -> Tuple[Tuple[str, str], ...]: # invisible in rich-mode. In basic mode, they both get substituted with a # star. return ( - ('verified', '\N{check mark}'), - ('verified invisible', '\N{check mark}'), + ("verified", "\N{check mark}"), + ("verified invisible", "\N{check mark}"), ) class SnapInfoView(WidgetWrap): - # This is mostly like a Pile but it tries to be a bit smart about # how to distribute space between the description and channel list # (which can both be arbitrarily long or short). If both are long, @@ -133,7 +109,7 @@ class SnapInfoView(WidgetWrap): self.needs_focus = True self.selections_by_name = {} - self.description = Text(snap.description.replace('\r', '').strip()) + self.description = Text(snap.description.replace("\r", "").strip()) self.lb_description = ListBox([self.description]) latest_update = datetime.datetime.min @@ -144,88 +120,115 @@ class SnapInfoView(WidgetWrap): selection = SnapSelection( name=self.snap.name, channel=csi.channel_name, - classic=csi.confinement == "classic") + classic=csi.confinement == "classic", + ) btn = StarRadioButton( radio_group, csi.channel_name, state=csi.channel_name == cur_channel, on_state_change=self.state_change, - user_data=selection) - channel_rows.append(Color.menu_button(TableRow([ - btn, - Text(csi.version), - Text("(" + csi.revision + ")"), - Text(humanize_size(csi.size)), - Text(format_datetime(csi.released_at)), - Text(csi.confinement), - ]))) + user_data=selection, + ) + channel_rows.append( + Color.menu_button( + TableRow( + [ + btn, + Text(csi.version), + Text("(" + csi.revision + ")"), + Text(humanize_size(csi.size)), + Text(format_datetime(csi.released_at)), + Text(csi.confinement), + ] + ) + ) + ) - first_info_row = TableRow([ - (3, Text( + first_info_row = TableRow( + [ + ( + 3, + Text( + [ + ("info_minor", _("LICENSE: ")), + snap.license, + ], + wrap="clip", + ), + ), + ( + 3, + Text( + [ + ("info_minor", _("LAST UPDATED: ")), + format_datetime(latest_update), + ] + ), + ), + ] + ) + heading_row = Color.info_minor( + TableRow( [ - ('info_minor', _("LICENSE: ")), - snap.license, - ], wrap='clip')), - (3, Text( - [ - ('info_minor', _("LAST UPDATED: ")), - format_datetime(latest_update), - ])), - ]) - heading_row = Color.info_minor(TableRow([ - Text(_("CHANNEL")), - (2, Text(_("VERSION"))), - Text(_("SIZE")), - Text(_("PUBLISHED")), - Text(_("CONFINEMENT")), - ])) + Text(_("CHANNEL")), + (2, Text(_("VERSION"))), + Text(_("SIZE")), + Text(_("PUBLISHED")), + Text(_("CONFINEMENT")), + ] + ) + ) colspecs = { 1: ColSpec(can_shrink=True), - } + } info_table = TablePile( [ first_info_row, TableRow([Text("")]), heading_row, ], - spacing=2, colspecs=colspecs) + spacing=2, + colspecs=colspecs, + ) self.lb_channels = NoTabCyclingTableListBox( - channel_rows, - spacing=2, colspecs=colspecs) + channel_rows, spacing=2, colspecs=colspecs + ) info_table.bind(self.lb_channels) self.info_padding = Padding.pull_1(info_table) - publisher = [('info_minor', _("by: ")), snap.publisher] + publisher = [("info_minor", _("by: ")), snap.publisher] if snap.verified: publisher.append(" ") publisher.extend(check_mark()) elif snap.starred: - publisher.append(('starred', ' \N{circled white star}')) + publisher.append(("starred", " \N{circled white star}")) - title = Columns([ - Text(snap.name), - ('pack', Text(publisher, align='right')), - ], dividechars=1) + title = Columns( + [ + Text(snap.name), + ("pack", Text(publisher, align="right")), + ], + dividechars=1, + ) contents = [ - ('pack', title), - ('pack', Text("")), - ('pack', Text(snap.summary)), - ('pack', Text("")), + ("pack", title), + ("pack", Text("")), + ("pack", Text(snap.summary)), + ("pack", Text("")), self.lb_description, # overwritten in render() - ('pack', Text("")), - ('pack', self.info_padding), - ('pack', Text("")), - ('weight', 1, self.lb_channels), - ] + ("pack", Text("")), + ("pack", self.info_padding), + ("pack", Text("")), + ("weight", 1, self.lb_channels), + ] self.description_index = contents.index(self.lb_description) self.pile = Pile(contents) super().__init__(self.pile) def state_change(self, sender, state, selection): if state: - log.debug( - "selecting %s from %s", self.snap.name, selection.channel) + log.debug("selecting %s from %s", self.snap.name, selection.channel) self.parent.snap_boxes[self.snap.name].set_state(True) self.parent.selections_by_name[self.snap.name] = selection @@ -233,12 +236,12 @@ class SnapInfoView(WidgetWrap): maxcol, maxrow = size rows_available = maxrow - pack_option = self.pile.options('pack') + pack_option = self.pile.options("pack") for w, o in self.pile.contents: if o == pack_option: rows_available -= w.rows((maxcol,), focus) - rows_wanted_description = self.description.rows((maxcol-1,), False) + rows_wanted_description = self.description.rows((maxcol - 1,), False) rows_wanted_channels = 0 for row in self.lb_channels._w.original_widget.body: rows_wanted_channels += row.rows((maxcol,), False) @@ -247,16 +250,19 @@ class SnapInfoView(WidgetWrap): description_rows = rows_wanted_description channel_rows = rows_wanted_channels else: - if rows_wanted_description < 2*rows_available/3: + if rows_wanted_description < 2 * rows_available / 3: description_rows = rows_wanted_description channel_rows = rows_available - description_rows else: channel_rows = max( - min(rows_wanted_channels, int(rows_available/3)), 3) + min(rows_wanted_channels, int(rows_available / 3)), 3 + ) description_rows = rows_available - channel_rows self.pile.contents[self.description_index] = ( - self.lb_description, self.pile.options('given', description_rows)) + self.lb_description, + self.pile.options("given", description_rows), + ) if description_rows >= rows_wanted_description: self.lb_description.base_widget._selectable = False else: @@ -272,7 +278,6 @@ class SnapInfoView(WidgetWrap): class FetchingFailed(WidgetWrap): - def __init__(self, row, snap): self.row = row self.closed = False @@ -284,10 +289,14 @@ class FetchingFailed(WidgetWrap): cancel = cancel_btn(label=_("Cancel"), on_press=self.close) super().__init__( LineBox( - Pile([ - ('pack', Text(' ' + text)), - ('pack', button_pile([retry, cancel])), - ]))) + Pile( + [ + ("pack", Text(" " + text)), + ("pack", button_pile([retry, cancel])), + ] + ) + ) + ) def load(self, sender=None): self.close() @@ -304,21 +313,20 @@ class SnapCheckBox(CheckBox): states = { True: SelectableIcon("[*]"), False: SelectableIcon("[ ]"), - } + } def __init__(self, parent, snap, state): self.parent = parent self.snap = snap - super().__init__( - snap.name, state=state, on_state_change=self.state_change) + super().__init__(snap.name, state=state, on_state_change=self.state_change) async def load_info(self): app = self.parent.controller.app await app.wait_with_text_dialog( - asyncio.shield( - self.parent.controller.get_snap_info(self.snap)), + asyncio.shield(self.parent.controller.get_snap_info(self.snap)), _("Fetching info for {snap}").format(snap=self.snap.name), - can_cancel=True) + can_cancel=True, + ) if len(self.snap.channels) == 0: # or other indication of failure ff = FetchingFailed(self, self.snap) self.parent.show_overlay(ff, width=ff.width) @@ -328,12 +336,17 @@ class SnapCheckBox(CheckBox): if selection is not None: cur_chan = selection.channel siv = SnapInfoView(self.parent, self.snap, cur_chan) - self.parent.show_screen(screen( - siv, - [other_btn( - label=_("Close"), - on_press=self.parent.show_main_screen)], - focus_buttons=False)) + self.parent.show_screen( + screen( + siv, + [ + other_btn( + label=_("Close"), on_press=self.parent.show_main_screen + ) + ], + focus_buttons=False, + ) + ) def keypress(self, size, key): if key.startswith("enter"): @@ -346,15 +359,15 @@ class SnapCheckBox(CheckBox): log.debug("selecting %s", self.snap.name) self.parent.selections_by_name[self.snap.name] = SnapSelection( name=self.snap.name, - channel='stable', - classic=self.snap.confinement == "classic") + channel="stable", + classic=self.snap.confinement == "classic", + ) else: log.debug("unselecting %s", self.snap.name) self.parent.selections_by_name.pop(self.snap.name, None) class SnapListView(BaseView): - title = _("Featured Server Snaps") def __init__(self, controller, data): @@ -366,11 +379,13 @@ class SnapListView(BaseView): self.loaded(data) def wait_load(self): - spinner = Spinner(style='dots') + spinner = Spinner(style="dots") spinner.start() self._w = screen( - [spinner], [ok_btn(label=_("Continue"), on_press=self.done)], - excerpt=_("Loading server snaps from store, please wait...")) + [spinner], + [ok_btn(label=_("Continue"), on_press=self.done)], + excerpt=_("Loading server snaps from store, please wait..."), + ) schedule_task(self._wait_load(spinner)) async def _wait_load(self, spinner): @@ -398,7 +413,8 @@ class SnapListView(BaseView): [ other_btn(label=_("Try again"), on_press=try_again_pressed), ok_btn(label=_("Continue"), on_press=self.done), - ]) + ], + ) def show_main_screen(self, sender=None): self._w = self._main_screen @@ -408,7 +424,7 @@ class SnapListView(BaseView): def get_seed_yaml(self): if self.controller.opts.dry_run: - return '''snaps: + return """snaps: - name: core channel: stable @@ -417,11 +433,11 @@ class SnapListView(BaseView): name: lxd channel: stable/ubuntu-18.04 file: lxd_59.snap -''' - seed_location = '/media/filesystem/var/lib/snapd/seed/seed.yaml' +""" + seed_location = "/media/filesystem/var/lib/snapd/seed/seed.yaml" content = "{}" try: - with open(seed_location, encoding='utf-8', errors='replace') as fp: + with open(seed_location, encoding="utf-8", errors="replace") as fp: content = fp.read() except FileNotFoundError: log.exception("could not find source at %r", seed_location) @@ -434,8 +450,8 @@ class SnapListView(BaseView): log.exception("failed to parse seed.yaml") return set() names = set() - for snap in seed.get('snaps', []): - name = snap.get('name') + for snap in seed.get("snaps", []): + name = snap.get("name") if name: names.add(name) log.debug("pre-seeded snaps %s", names) @@ -444,7 +460,7 @@ class SnapListView(BaseView): def make_main_screen(self, data): self.selections_by_name = { selection.name: selection for selection in data.selections - } + } self.snap_boxes = {} body = [] preinstalled = self.get_preinstalled_snaps() @@ -453,49 +469,56 @@ class SnapListView(BaseView): log.debug("not offering preseeded snap %r", snap.name) continue box = self.snap_boxes[snap.name] = SnapCheckBox( - self, snap, snap.name in self.selections_by_name) + self, snap, snap.name in self.selections_by_name + ) publisher = [snap.publisher] if snap.verified: publisher.extend(check_mark()) elif snap.starred: - publisher.append(('starred', '\N{circled white star}')) + publisher.append(("starred", "\N{circled white star}")) row = [ box, Text(publisher), - Text(snap.summary, wrap='clip'), - Text("\N{BLACK RIGHT-POINTING SMALL TRIANGLE}") - ] - body.append(AttrMap( - TableRow(row), - 'menu_button', - { - None: 'menu_button focus', - 'verified': 'verified focus', - 'verified invisible': 'verified inv focus', - 'starred': 'starred focus', - }, - )) + Text(snap.summary, wrap="clip"), + Text("\N{BLACK RIGHT-POINTING SMALL TRIANGLE}"), + ] + body.append( + AttrMap( + TableRow(row), + "menu_button", + { + None: "menu_button focus", + "verified": "verified focus", + "verified invisible": "verified inv focus", + "starred": "starred focus", + }, + ) + ) table = NoTabCyclingTableListBox( body, colspecs={ 1: ColSpec(omittable=True), 2: ColSpec(pack=False, min_width=40), - }) + }, + ) ok = ok_btn(label=_("Done"), on_press=self.done) cancel = cancel_btn(label=_("Back"), on_press=self.cancel) self._main_screen = screen( - table, [ok, cancel], + table, + [ok, cancel], focus_buttons=False, excerpt=_( "These are popular snaps in server environments. Select or " "deselect with SPACE, press ENTER to see more details of the " - "package, publisher and versions available.")) + "package, publisher and versions available." + ), + ) def done(self, sender=None): log.debug("snaps to install %s", self.selections_by_name) - self.controller.done(sorted( - self.selections_by_name.values(), - key=lambda s: s.name)) + self.controller.done( + sorted(self.selections_by_name.values(), key=lambda s: s.name) + ) def cancel(self, sender=None): if self._w is self._main_screen: diff --git a/subiquity/ui/views/source.py b/subiquity/ui/views/source.py index 651d882b..b5ed8308 100644 --- a/subiquity/ui/views/source.py +++ b/subiquity/ui/views/source.py @@ -16,25 +16,17 @@ import logging from typing import List -from urwid import ( - connect_signal, - Text, -) +from urwid import Text, connect_signal -from subiquitycore.view import BaseView from subiquitycore.ui.container import ListBox -from subiquitycore.ui.form import ( - BooleanField, - Form, - RadioButtonField, -) +from subiquitycore.ui.form import BooleanField, Form, RadioButtonField from subiquitycore.ui.utils import screen +from subiquitycore.view import BaseView -log = logging.getLogger('subiquity.ui.views.source') +log = logging.getLogger("subiquity.ui.views.source") class SourceView(BaseView): - title = _("Choose type of install") def __init__(self, controller, sources, current_id, search_drivers: bool): @@ -43,8 +35,8 @@ class SourceView(BaseView): group: List[RadioButtonField] = [] ns = { - 'cancel_label': _("Back"), - } + "cancel_label": _("Back"), + } initial = {} for default in True, False: @@ -52,24 +44,32 @@ class SourceView(BaseView): if source.default != default: continue ns[source.id] = RadioButtonField( - group, source.name, '\n' + source.description) + group, source.name, "\n" + source.description + ) initial[source.id] = source.id == current_id ns["search_drivers"] = BooleanField( - _("Search for third-party drivers"), "\n" + - _("This software is subject to license terms included with its " - "documentation. Some is proprietary.") + " " + - _("Third-party drivers should not be installed on systems that " - "will be used for FIPS or the real-time kernel.")) + _("Search for third-party drivers"), + "\n" + + _( + "This software is subject to license terms included with its " + "documentation. Some is proprietary." + ) + + " " + + _( + "Third-party drivers should not be installed on systems that " + "will be used for FIPS or the real-time kernel." + ), + ) initial["search_drivers"] = search_drivers - SourceForm = type(Form)('SourceForm', (Form,), ns) - log.debug('%r %r', ns, current_id) + SourceForm = type(Form)("SourceForm", (Form,), ns) + log.debug("%r %r", ns, current_id) self.form = SourceForm(initial=initial) - connect_signal(self.form, 'submit', self.done) - connect_signal(self.form, 'cancel', self.cancel) + connect_signal(self.form, "submit", self.done) + connect_signal(self.form, "cancel", self.cancel) excerpt = _("Choose the base for the installation.") @@ -81,10 +81,9 @@ class SourceView(BaseView): super().__init__( screen( - ListBox(rows), - self.form.buttons, - excerpt=excerpt, - focus_buttons=True)) + ListBox(rows), self.form.buttons, excerpt=excerpt, focus_buttons=True + ) + ) def done(self, result): log.debug("User input: %s", result.as_data()) diff --git a/subiquity/ui/views/ssh.py b/subiquity/ui/views/ssh.py index 5d63a7ac..dd339a2a 100644 --- a/subiquity/ui/views/ssh.py +++ b/subiquity/ui/views/ssh.py @@ -17,80 +17,53 @@ import logging import re from typing import List -from urwid import ( - connect_signal, - LineBox, - Text, - ) - -from subiquitycore.view import ( - BaseView, - ) -from subiquitycore.ui.buttons import ( - cancel_btn, - ok_btn, - ) -from subiquitycore.ui.container import ( - ListBox, - Pile, - WidgetWrap, - ) -from subiquitycore.ui.form import ( - BooleanField, - ChoiceField, - Form, -) -from subiquitycore.ui.spinner import ( - Spinner, - ) -from subiquitycore.ui.stretchy import ( - Stretchy, - ) -from subiquitycore.ui.utils import ( - button_pile, - screen, - SomethingFailed, - ) +from urwid import LineBox, Text, connect_signal from subiquity.common.types import SSHData, SSHIdentity -from subiquity.ui.views.identity import ( - UsernameField, - ) +from subiquity.ui.views.identity import UsernameField +from subiquitycore.ui.buttons import cancel_btn, ok_btn +from subiquitycore.ui.container import ListBox, Pile, WidgetWrap +from subiquitycore.ui.form import BooleanField, ChoiceField, Form +from subiquitycore.ui.spinner import Spinner +from subiquitycore.ui.stretchy import Stretchy +from subiquitycore.ui.utils import SomethingFailed, button_pile, screen +from subiquitycore.view import BaseView - -log = logging.getLogger('subiquity.ui.views.ssh') +log = logging.getLogger("subiquity.ui.views.ssh") SSH_IMPORT_MAXLEN = 256 + 3 # account for lp: or gh: _ssh_import_data = { None: { - 'caption': _("Import Username:"), - 'help': "", - 'valid_char': '.', - 'error_invalid_char': '', - 'regex': '.*', - }, - 'gh': { - 'caption': _("GitHub Username:"), - 'help': _("Enter your GitHub username."), - 'valid_char': r'[a-zA-Z0-9\-]', - 'error_invalid_char': _('A GitHub username may only contain ' - 'alphanumeric characters or hyphens.'), - }, - 'lp': { - 'caption': _("Launchpad Username:"), - 'help': "Enter your Launchpad username.", - 'valid_char': r'[a-z0-9\+\.\-]', - 'error_invalid_char': _('A Launchpad username may only contain ' - 'lower-case alphanumeric characters, hyphens, ' - 'plus, or periods.'), - }, - } + "caption": _("Import Username:"), + "help": "", + "valid_char": ".", + "error_invalid_char": "", + "regex": ".*", + }, + "gh": { + "caption": _("GitHub Username:"), + "help": _("Enter your GitHub username."), + "valid_char": r"[a-zA-Z0-9\-]", + "error_invalid_char": _( + "A GitHub username may only contain " "alphanumeric characters or hyphens." + ), + }, + "lp": { + "caption": _("Launchpad Username:"), + "help": "Enter your Launchpad username.", + "valid_char": r"[a-z0-9\+\.\-]", + "error_invalid_char": _( + "A Launchpad username may only contain " + "lower-case alphanumeric characters, hyphens, " + "plus, or periods." + ), + }, +} class SSHForm(Form): - install_server = BooleanField(_("Install OpenSSH server")) ssh_import_id = ChoiceField( @@ -99,10 +72,11 @@ class SSHForm(Form): (_("No"), True, None), (_("from GitHub"), True, "gh"), (_("from Launchpad"), True, "lp"), - ], - help=_("You can import your SSH keys from GitHub or Launchpad.")) + ], + help=_("You can import your SSH keys from GitHub or Launchpad."), + ) - import_username = UsernameField(_ssh_import_data[None]['caption']) + import_username = UsernameField(_ssh_import_data[None]["caption"]) pwauth = BooleanField(_("Allow password authentication over SSH")) @@ -110,8 +84,7 @@ class SSHForm(Form): def __init__(self, initial): super().__init__(initial=initial) - connect_signal( - self.install_server.widget, 'change', self._toggle_server) + connect_signal(self.install_server.widget, "change", self._toggle_server) self._toggle_server(None, self.install_server.value) def _toggle_server(self, sender, new_value): @@ -140,24 +113,29 @@ class SSHForm(Form): return _("This field must not be blank.") if len(username) > SSH_IMPORT_MAXLEN: return _("SSH id too long, must be < ") + str(SSH_IMPORT_MAXLEN) - if self.ssh_import_id_value == 'lp': + if self.ssh_import_id_value == "lp": lp_regex = r"^[a-z0-9][a-z0-9\+\.\-]*$" if not re.match(lp_regex, self.import_username.value): - return _("A Launchpad username must start with a letter or " - "number. All letters must be lower-case. The " - "characters +, - and . are also allowed after " - "the first character.""") - elif self.ssh_import_id_value == 'gh': - if not re.match(r'^[a-zA-Z0-9\-]+$', username): - return _("A GitHub username may only contain alphanumeric " - "characters or single hyphens, and cannot begin or " - "end with a hyphen.") + return _( + "A Launchpad username must start with a letter or " + "number. All letters must be lower-case. The " + "characters +, - and . are also allowed after " + "the first character." + "" + ) + elif self.ssh_import_id_value == "gh": + if not re.match(r"^[a-zA-Z0-9\-]+$", username): + return _( + "A GitHub username may only contain alphanumeric " + "characters or single hyphens, and cannot begin or " + "end with a hyphen." + ) class FetchingSSHKeys(WidgetWrap): def __init__(self, parent): self.parent = parent - spinner = Spinner(style='dots') + spinner = Spinner(style="dots") spinner.start() text = _("Fetching SSH keys...") button = cancel_btn(label=_("Cancel"), on_press=self.cancel) @@ -166,11 +144,15 @@ class FetchingSSHKeys(WidgetWrap): self.width = len(text) + 4 super().__init__( LineBox( - Pile([ - ('pack', Text(' ' + text)), - ('pack', spinner), - ('pack', button_pile([button])), - ]))) + Pile( + [ + ("pack", Text(" " + text)), + ("pack", spinner), + ("pack", button_pile([button])), + ] + ) + ) + ) def cancel(self, sender): self.parent.controller._fetch_cancel() @@ -188,15 +170,18 @@ class ConfirmSSHKeys(Stretchy): if len(identities) > 1: title = _("Confirm SSH keys") - header = _("Keys with the following fingerprints were fetched. " - "Do you want to use them?") + header = _( + "Keys with the following fingerprints were fetched. " + "Do you want to use them?" + ) else: title = _("Confirm SSH key") - header = _("A key with the following fingerprint was fetched. " - "Do you want to use it?") + header = _( + "A key with the following fingerprint was fetched. " + "Do you want to use it?" + ) - fingerprints = Pile([Text(identity.key_fingerprint) - for identity in identities]) + fingerprints = Pile([Text(identity.key_fingerprint) for identity in identities]) super().__init__( title, @@ -206,22 +191,27 @@ class ConfirmSSHKeys(Stretchy): fingerprints, Text(""), button_pile([ok, cancel]), - ], 2, 4) + ], + 2, + 4, + ) def cancel(self, sender): self.parent.remove_overlay() def ok(self, sender): - self.ssh_data.authorized_keys = \ - [id_.to_authorized_key() for id_ in self.identities] + self.ssh_data.authorized_keys = [ + id_.to_authorized_key() for id_ in self.identities + ] self.parent.controller.done(self.ssh_data) class SSHView(BaseView): - title = _("SSH Setup") - excerpt = _("You can choose to install the OpenSSH server package to " - "enable secure remote access to your server.") + excerpt = _( + "You can choose to install the OpenSSH server package to " + "enable secure remote access to your server." + ) def __init__(self, controller, ssh_data): self.controller = controller @@ -229,15 +219,16 @@ class SSHView(BaseView): initial = { "install_server": ssh_data.install_server, "pwauth": ssh_data.allow_pw, - } + } self.form = SSHForm(initial=initial) - connect_signal(self.form.ssh_import_id.widget, 'select', - self._select_ssh_import_id) + connect_signal( + self.form.ssh_import_id.widget, "select", self._select_ssh_import_id + ) - connect_signal(self.form, 'submit', self.done) - connect_signal(self.form, 'cancel', self.cancel) + connect_signal(self.form, "submit", self.done) + connect_signal(self.form, "cancel", self.cancel) self.form_rows = ListBox(self.form.as_rows()) super().__init__( @@ -245,15 +236,17 @@ class SSHView(BaseView): self.form_rows, self.form.buttons, excerpt=_(self.excerpt), - focus_buttons=False)) + focus_buttons=False, + ) + ) def _select_ssh_import_id(self, sender, val): iu = self.form.import_username data = _ssh_import_data[val] - iu.help = _(data['help']) - iu.caption = _(data['caption']) - iu.widget.valid_char_pat = data['valid_char'] - iu.widget.error_invalid_char = _(data['error_invalid_char']) + iu.help = _(data["help"]) + iu.caption = _(data["caption"]) + iu.widget.valid_char_pat = data["valid_char"] + iu.widget.error_invalid_char = _(data["error_invalid_char"]) iu.enabled = val is not None self.form.pwauth.enabled = val is not None # The logic here is a little tortured but the idea is that if @@ -274,27 +267,28 @@ class SSHView(BaseView): log.debug("User input: {}".format(self.form.as_data())) ssh_data = SSHData( install_server=self.form.install_server.value, - allow_pw=self.form.pwauth.value) + allow_pw=self.form.pwauth.value, + ) # if user specifed a value, allow user to validate fingerprint if self.form.ssh_import_id.value: - ssh_import_id = self.form.ssh_import_id.value + ":" + \ - self.form.import_username.value + ssh_import_id = ( + self.form.ssh_import_id.value + ":" + self.form.import_username.value + ) fsk = FetchingSSHKeys(self) self.show_overlay(fsk, width=fsk.width, min_width=None) self.controller.fetch_ssh_keys( - ssh_import_id=ssh_import_id, ssh_data=ssh_data) + ssh_import_id=ssh_import_id, ssh_data=ssh_data + ) else: self.controller.done(ssh_data) def cancel(self, result=None): self.controller.cancel() - def confirm_ssh_keys(self, ssh_data, ssh_import_id, - identities: List[SSHIdentity]): + def confirm_ssh_keys(self, ssh_data, ssh_import_id, identities: List[SSHIdentity]): self.remove_overlay() - self.show_stretchy_overlay( - ConfirmSSHKeys(self, ssh_data, identities)) + self.show_stretchy_overlay(ConfirmSSHKeys(self, ssh_data, identities)) def fetching_ssh_keys_failed(self, msg, stderr): # FIXME in answers-based runs, the overlay does not exist so we pass diff --git a/subiquity/ui/views/tests/test_identity.py b/subiquity/ui/views/tests/test_identity.py index 2e8fbcdb..4c5dd12c 100644 --- a/subiquity/ui/views/tests/test_identity.py +++ b/subiquity/ui/views/tests/test_identity.py @@ -16,48 +16,45 @@ import unittest from unittest import mock -from subiquitycore.testing import view_helpers - from subiquity.client.controllers.identity import IdentityController from subiquity.common.types import IdentityData, UsernameValidation from subiquity.ui.views.identity import IdentityView - +from subiquitycore.testing import view_helpers valid_data = { - 'realname': 'Real Name', - 'hostname': 'host-name', - 'username': 'username', - 'password': 'password', - 'confirm_password': 'password' - } + "realname": "Real Name", + "hostname": "host-name", + "username": "username", + "password": "password", + "confirm_password": "password", +} too_long = { - 'realname': 'Real Name', - 'hostname': 'host-name', - 'username': 'u'*33, - 'password': 'password', - 'confirm_password': 'password' - } + "realname": "Real Name", + "hostname": "host-name", + "username": "u" * 33, + "password": "password", + "confirm_password": "password", +} already_taken = { - 'realname': 'Real Name', - 'hostname': 'host-name', - 'username': 'root', - 'password': 'password', - 'confirm_password': 'password' - } + "realname": "Real Name", + "hostname": "host-name", + "username": "root", + "password": "password", + "confirm_password": "password", +} system_reserved = { - 'realname': 'Real Name', - 'hostname': 'host-name', - 'username': 'plugdev', - 'password': 'password', - 'confirm_password': 'password' - } + "realname": "Real Name", + "hostname": "host-name", + "username": "plugdev", + "password": "password", + "confirm_password": "password", +} class IdentityViewTests(unittest.IsolatedAsyncioTestCase): - def make_view(self): controller = mock.create_autospec(spec=IdentityController) return IdentityView(controller, IdentityData()) @@ -83,8 +80,9 @@ class IdentityViewTests(unittest.IsolatedAsyncioTestCase): async def test_username_validation_system_reserved(self): view = self.make_view() widget = view.form.username.widget - view.controller.validate_username.return_value = \ + view.controller.validate_username.return_value = ( UsernameValidation.SYSTEM_RESERVED + ) view_helpers.enter_data(view.form, system_reserved) widget.lost_focus() await widget.validation_task @@ -93,8 +91,9 @@ class IdentityViewTests(unittest.IsolatedAsyncioTestCase): async def test_username_validation_in_use(self): view = self.make_view() widget = view.form.username.widget - view.controller.validate_username.return_value = \ + view.controller.validate_username.return_value = ( UsernameValidation.ALREADY_IN_USE + ) view_helpers.enter_data(view.form, already_taken) widget.lost_focus() await widget.validation_task @@ -107,17 +106,18 @@ class IdentityViewTests(unittest.IsolatedAsyncioTestCase): def test_click_done(self): view = self.make_view() - CRYPTED = '' - with mock.patch('subiquity.ui.views.identity.crypt_password') as cp: + CRYPTED = "" + with mock.patch("subiquity.ui.views.identity.crypt_password") as cp: cp.side_effect = lambda p: CRYPTED view_helpers.enter_data(view.form, valid_data) done_btn = view_helpers.find_button_matching(view, "^Done$") view_helpers.click(done_btn) expected = IdentityData( - realname=valid_data['realname'], - username=valid_data['username'], - hostname=valid_data['hostname'], - crypted_password=CRYPTED) + realname=valid_data["realname"], + username=valid_data["username"], + hostname=valid_data["hostname"], + crypted_password=CRYPTED, + ) view.controller.done.assert_called_once_with(expected) async def test_can_tab_to_done_when_valid(self): @@ -135,7 +135,7 @@ class IdentityViewTests(unittest.IsolatedAsyncioTestCase): view_helpers.enter_data(view.form, valid_data) for i in range(100): - view_helpers.keypress(view, 'tab', size=(80, 24)) + view_helpers.keypress(view, "tab", size=(80, 24)) focus_path = view_helpers.get_focus_path(view) for w in reversed(focus_path): if w is view.form.done_btn: diff --git a/subiquity/ui/views/tests/test_installprogress.py b/subiquity/ui/views/tests/test_installprogress.py index 64b374e0..e8e2b502 100644 --- a/subiquity/ui/views/tests/test_installprogress.py +++ b/subiquity/ui/views/tests/test_installprogress.py @@ -1,15 +1,13 @@ import unittest from unittest import mock -from subiquitycore.testing import view_helpers - from subiquity.client.controllers.progress import ProgressController from subiquity.common.types import ApplicationState from subiquity.ui.views.installprogress import ProgressView +from subiquitycore.testing import view_helpers class IdentityViewTests(unittest.TestCase): - def make_view(self): controller = mock.create_autospec(spec=ProgressController) controller.app = mock.Mock() diff --git a/subiquity/ui/views/tests/test_welcome.py b/subiquity/ui/views/tests/test_welcome.py index 17630ebc..014718b1 100644 --- a/subiquity/ui/views/tests/test_welcome.py +++ b/subiquity/ui/views/tests/test_welcome.py @@ -1,15 +1,14 @@ import unittest from unittest import mock -import urwid -from subiquitycore.testing import view_helpers +import urwid from subiquity.client.controllers.welcome import WelcomeController from subiquity.ui.views.welcome import WelcomeView +from subiquitycore.testing import view_helpers class WelcomeViewTests(unittest.TestCase): - def make_view_with_languages(self, languages): controller = mock.create_autospec(spec=WelcomeController) with mock.patch("subiquity.ui.views.welcome.get_languages") as p: @@ -19,18 +18,20 @@ class WelcomeViewTests(unittest.TestCase): def test_basic(self): # Clicking the button for a language calls "switch_language" # on the model and "done" on the controller. - view = self.make_view_with_languages([('code', 'native')]) + view = self.make_view_with_languages([("code", "native")]) but = view_helpers.find_button_matching(view, "^native$") view_helpers.click(but) - view.controller.done.assert_called_once_with(('code', 'native')) + view.controller.done.assert_called_once_with(("code", "native")) def test_initial_focus(self): # The initial focus for the view is the button for the first # language. - view = self.make_view_with_languages([ - ('code1', 'native1'), - ('code2', 'native2'), - ]) + view = self.make_view_with_languages( + [ + ("code1", "native1"), + ("code2", "native2"), + ] + ) for w in reversed(view_helpers.get_focus_path(view)): if isinstance(w, urwid.Button): self.assertEqual(w.label, "native1") diff --git a/subiquity/ui/views/ubuntu_pro.py b/subiquity/ui/views/ubuntu_pro.py index 8b1de997..a3733554 100644 --- a/subiquity/ui/views/ubuntu_pro.py +++ b/subiquity/ui/views/ubuntu_pro.py @@ -19,114 +19,87 @@ import logging import re from typing import Callable, List -from urwid import ( - Columns, - connect_signal, - LineBox, - Text, - Widget, - ) +from urwid import Columns, LineBox, Text, Widget, connect_signal -from subiquity.common.types import ( - UbuntuProCheckTokenStatus, - UbuntuProSubscription, - ) -from subiquitycore.view import BaseView -from subiquitycore.ui.buttons import ( - back_btn, - cancel_btn, - done_btn, - menu_btn, - ok_btn, - ) -from subiquitycore.ui.container import ( - ListBox, - Pile, - WidgetWrap, - ) +from subiquity.common.types import UbuntuProCheckTokenStatus, UbuntuProSubscription +from subiquitycore.ui.buttons import back_btn, cancel_btn, done_btn, menu_btn, ok_btn +from subiquitycore.ui.container import ListBox, Pile, WidgetWrap from subiquitycore.ui.form import ( - Form, - SubForm, - SubFormField, NO_HELP, - simple_field, + Form, RadioButtonField, ReadOnlyField, + SubForm, + SubFormField, WantsToKnowFormField, - ) -from subiquitycore.ui.spinner import ( - Spinner, - ) -from subiquitycore.ui.stretchy import ( - Stretchy, - ) -from subiquitycore.ui.utils import ( - button_pile, - screen, - SomethingFailed, - ) - + simple_field, +) from subiquitycore.ui.interactive import StringEditor +from subiquitycore.ui.spinner import Spinner +from subiquitycore.ui.stretchy import Stretchy +from subiquitycore.ui.utils import SomethingFailed, button_pile, screen +from subiquitycore.view import BaseView - -log = logging.getLogger('subiquity.ui.views.ubuntu_pro') +log = logging.getLogger("subiquity.ui.views.ubuntu_pro") class ContractTokenEditor(StringEditor, WantsToKnowFormField): - """ Represent a text-box editor for the Ubuntu Pro Token. """ + """Represent a text-box editor for the Ubuntu Pro Token.""" + def __init__(self): - """ Initialize the text-field editor for UA token. """ + """Initialize the text-field editor for UA token.""" self.valid_char_pat = r"[1-9A-HJ-NP-Za-km-z]" self.error_invalid_char = _("'{}' is not a valid character.") super().__init__() def valid_char(self, ch: str) -> bool: - """ Tells whether the character passed is within the range of allowed + """Tells whether the character passed is within the range of allowed characters """ if len(ch) == 1 and not re.match(self.valid_char_pat, ch): self.bff.in_error = True - self.bff.show_extra( - ("info_error", self.error_invalid_char.format(ch))) + self.bff.show_extra(("info_error", self.error_invalid_char.format(ch))) return False return super().valid_char(ch) class ContractTokenForm(SubForm): - """ Represents a sub-form requesting Ubuntu Pro token. + """Represents a sub-form requesting Ubuntu Pro token. +---------------------------------------------------------+ | Token: C123456789ABCDEF | | From your admin, or from ubuntu.com/pro | +---------------------------------------------------------+ """ + ContractTokenField = simple_field(ContractTokenEditor) token = ContractTokenField( - _("Token:"), - help=_("From your admin, or from ubuntu.com/pro")) + _("Token:"), help=_("From your admin, or from ubuntu.com/pro") + ) def validate_token(self): - """ Return an error if the token input does not match the expected - format. """ + """Return an error if the token input does not match the expected + format.""" if not 24 <= len(self.token.value) <= 30: return _("Invalid token: should be between 24 and 30 characters") class UbuntuOneForm(SubForm): - """ Represents a sub-form showing a user-code and requesting the user to + """Represents a sub-form showing a user-code and requesting the user to browse ubuntu.com/pro/attach. +---------------------------------------------------------+ | Code: 1A3-4V6 | | Attach machine via your Ubuntu One account | +---------------------------------------------------------+ """ + user_code = ReadOnlyField( - _("Code:"), - help=_("Attach machine via your Ubuntu One account")) + _("Code:"), help=_("Attach machine via your Ubuntu One account") + ) class UpgradeModeForm(Form): - """ Represents a form requesting information to enable Ubuntu Pro. + """Represents a form requesting information to enable Ubuntu Pro. +---------------------------------------------------------+ | (X) Enter code on ubuntu.com/pro/attach | | Code: 1A3-4V6 | @@ -140,31 +113,32 @@ class UpgradeModeForm(Form): | [ Back ] | +---------------------------------------------------------+ """ + cancel_label = _("Back") ok_label = _("Continue") ok_label_waiting = _("Waiting") group: List[RadioButtonField] = [] with_ubuntu_one = RadioButtonField( - group, _("Enter code on ubuntu.com/pro/attach"), - help=NO_HELP) - with_ubuntu_one_subform = SubFormField( - UbuntuOneForm, "", help=NO_HELP) + group, _("Enter code on ubuntu.com/pro/attach"), help=NO_HELP + ) + with_ubuntu_one_subform = SubFormField(UbuntuOneForm, "", help=NO_HELP) with_contract_token = RadioButtonField( - group, _("Or add token manually"), - help=NO_HELP) - with_contract_token_subform = SubFormField( - ContractTokenForm, "", help=NO_HELP) + group, _("Or add token manually"), help=NO_HELP + ) + with_contract_token_subform = SubFormField(ContractTokenForm, "", help=NO_HELP) def __init__(self, initial) -> None: - """ Initializer that configures the callback to run when the radios are - checked/unchecked. """ + """Initializer that configures the callback to run when the radios are + checked/unchecked.""" super().__init__(initial) - connect_signal(self.with_contract_token.widget, - 'change', self._toggle_contract_token_input) - connect_signal(self.with_ubuntu_one.widget, - 'change', self._toggle_ubuntu_one_input) + connect_signal( + self.with_contract_token.widget, "change", self._toggle_contract_token_input + ) + connect_signal( + self.with_ubuntu_one.widget, "change", self._toggle_ubuntu_one_input + ) if initial["with_contract_token_subform"]["token"]: self._toggle_contract_token_input(None, True) @@ -177,26 +151,26 @@ class UpgradeModeForm(Form): self.with_ubuntu_one_subform.widget.form.user_code.in_error = True def _toggle_contract_token_input(self, sender, new_value): - """ Enable/disable the sub-form that requests the contract token. """ + """Enable/disable the sub-form that requests the contract token.""" self.with_contract_token_subform.enabled = new_value if new_value: self.buttons.base_widget[0].set_label(self.ok_label) def _toggle_ubuntu_one_input(self, sender, new_value): - """ Enable/disable the sub-form that shows the user-code. """ + """Enable/disable the sub-form that shows the user-code.""" self.with_ubuntu_one_subform.enabled = new_value if new_value: self.buttons.base_widget[0].set_label(self.ok_label_waiting) def set_user_code(self, user_code: str) -> None: - """ Show the new value of user code in a green backgroud. """ + """Show the new value of user code in a green backgroud.""" value = ("user_code", user_code) self.with_ubuntu_one_subform.widget.form.user_code.value = value class UpgradeYesNoForm(Form): - """ Represents a form asking if we want to upgrade to Ubuntu Pro. + """Represents a form asking if we want to upgrade to Ubuntu Pro. +---------------------------------------------------------+ | (X) Enable Ubuntu Pro | | | @@ -214,17 +188,17 @@ class UpgradeYesNoForm(Form): ok_label = _("Continue") group: List[RadioButtonField] = [] - upgrade = RadioButtonField( - group, _("Enable Ubuntu Pro"), - help=NO_HELP) + upgrade = RadioButtonField(group, _("Enable Ubuntu Pro"), help=NO_HELP) skip = RadioButtonField( - group, _("Skip for now"), - help="\n" + _("You can always enable Ubuntu Pro later via the" - " 'pro attach' command.")) + group, + _("Skip for now"), + help="\n" + + _("You can always enable Ubuntu Pro later via the" " 'pro attach' command."), + ) class UpgradeYesNoFormNoNetwork(UpgradeYesNoForm): - """ Represents a "read-only" form that does not let the user enable Ubuntu + """Represents a "read-only" form that does not let the user enable Ubuntu Pro because the network is unavailable. +---------------------------------------------------------+ | ( ) Enable Ubuntu Pro | @@ -238,18 +212,22 @@ class UpgradeYesNoFormNoNetwork(UpgradeYesNoForm): | [ Back ] | +---------------------------------------------------------+ """ + group: List[RadioButtonField] = [] - upgrade = RadioButtonField( - group, _("Enable Ubuntu Pro"), - help=NO_HELP) + upgrade = RadioButtonField(group, _("Enable Ubuntu Pro"), help=NO_HELP) skip = RadioButtonField( - group, _("Skip Ubuntu Pro setup for now"), - help="\n" + _("Once you are connected to the Internet, you can" - " enable Ubuntu Po via the 'pro attach' command.")) + group, + _("Skip Ubuntu Pro setup for now"), + help="\n" + + _( + "Once you are connected to the Internet, you can" + " enable Ubuntu Po via the 'pro attach' command." + ), + ) def __init__(self): - """ Initializer that disables the relevant fields. """ + """Initializer that disables the relevant fields.""" super().__init__(initial={}) self.upgrade.value = False self.upgrade.enabled = False @@ -257,10 +235,11 @@ class UpgradeYesNoFormNoNetwork(UpgradeYesNoForm): class CheckingContractToken(WidgetWrap): - """ Widget displaying a loading animation while checking ubuntu pro - subscription. """ + """Widget displaying a loading animation while checking ubuntu pro + subscription.""" + def __init__(self, parent: BaseView): - """ Initializes the loading animation widget. """ + """Initializes the loading animation widget.""" self.parent = parent spinner = Spinner(style="dots") spinner.start() @@ -269,62 +248,66 @@ class CheckingContractToken(WidgetWrap): self.width = len(text) + 4 super().__init__( LineBox( - Pile([ - ('pack', Text(' ' + text)), - ('pack', spinner), - ('pack', button_pile([button])), - ]))) + Pile( + [ + ("pack", Text(" " + text)), + ("pack", spinner), + ("pack", button_pile([button])), + ] + ) + ) + ) def cancel(self, sender) -> None: - """ Close the loading animation and cancel the check operation. """ + """Close the loading animation and cancel the check operation.""" self.parent.controller.cancel_check_token() self.parent.remove_overlay() class UbuntuProView(BaseView): - """ Represent the view of the Ubuntu Pro configuration. """ + """Represent the view of the Ubuntu Pro configuration.""" title = _("Upgrade to Ubuntu Pro") subscription_done_label = _("Continue") def __init__(self, controller, token: str, has_network: bool): - """ Initialize the view with the default value for the token. """ + """Initialize the view with the default value for the token.""" self.controller = controller self.has_network = has_network if self.has_network: - self.upgrade_yes_no_form = UpgradeYesNoForm(initial={ - "skip": not token, - "upgrade": bool(token), - }) + self.upgrade_yes_no_form = UpgradeYesNoForm( + initial={ + "skip": not token, + "upgrade": bool(token), + } + ) else: self.upgrade_yes_no_form = UpgradeYesNoFormNoNetwork() - self.upgrade_mode_form = UpgradeModeForm(initial={ - "with_contract_token_subform": {"token": token}, - "with_contract_token": bool(token), - "with_ubuntu_one": not token, - }) + self.upgrade_mode_form = UpgradeModeForm( + initial={ + "with_contract_token_subform": {"token": token}, + "with_contract_token": bool(token), + "with_ubuntu_one": not token, + } + ) def on_upgrade_yes_no_cancel(unused: UpgradeYesNoForm): - """ Function to call when hitting Done from the upgrade/skip - screen. """ + """Function to call when hitting Done from the upgrade/skip + screen.""" self.cancel() def on_upgrade_mode_cancel(unused: UpgradeModeForm): - """ Function to call when hitting Back from the contract token - form. """ + """Function to call when hitting Back from the contract token + form.""" self._w = self.upgrade_yes_no_screen() - connect_signal(self.upgrade_yes_no_form, - 'submit', self.upgrade_yes_no_done) - connect_signal(self.upgrade_yes_no_form, - 'cancel', on_upgrade_yes_no_cancel) - connect_signal(self.upgrade_mode_form, - 'submit', self.upgrade_mode_done) - connect_signal(self.upgrade_mode_form, - 'cancel', on_upgrade_mode_cancel) + connect_signal(self.upgrade_yes_no_form, "submit", self.upgrade_yes_no_done) + connect_signal(self.upgrade_yes_no_form, "cancel", on_upgrade_yes_no_cancel) + connect_signal(self.upgrade_mode_form, "submit", self.upgrade_mode_done) + connect_signal(self.upgrade_mode_form, "cancel", on_upgrade_mode_cancel) # Throwaway tasks self.tasks: List[asyncio.Task] = [] @@ -332,7 +315,7 @@ class UbuntuProView(BaseView): super().__init__(self.upgrade_yes_no_screen()) def upgrade_mode_screen(self) -> Widget: - """ Return a screen that asks the user to provide a contract token or + """Return a screen that asks the user to provide a contract token or input a user-code on the portal. +---------------------------------------------------------+ | To upgrade to Ubuntu Pro, use your existing free | @@ -354,14 +337,15 @@ class UbuntuProView(BaseView): +---------------------------------------------------------+ """ - excerpt = _("To upgrade to Ubuntu Pro, use your existing free" - " personal, or company Ubuntu One account, or provide a" - " token.") + excerpt = _( + "To upgrade to Ubuntu Pro, use your existing free" + " personal, or company Ubuntu One account, or provide a" + " token." + ) how_to_register_btn = menu_btn( - _("How to register"), - on_press=lambda unused: self.show_how_to_register() - ) + _("How to register"), on_press=lambda unused: self.show_how_to_register() + ) bp = button_pile([how_to_register_btn]) bp.align = "left" rows = [ @@ -369,13 +353,14 @@ class UbuntuProView(BaseView): Text(""), ] + self.upgrade_mode_form.as_rows() return screen( - ListBox(rows), - self.upgrade_mode_form.buttons, - excerpt=excerpt, - focus_buttons=True) + ListBox(rows), + self.upgrade_mode_form.buttons, + excerpt=excerpt, + focus_buttons=True, + ) def upgrade_yes_no_screen(self) -> Widget: - """ Return a screen that asks the user to skip or upgrade. + """Return a screen that asks the user to skip or upgrade. +---------------------------------------------------------+ | Upgrade this machine to Ubuntu Pro for security updates | | on a much wider range of packages, until 2032. Assists | @@ -396,17 +381,20 @@ class UbuntuProView(BaseView): """ security_updates_until = 2032 - excerpt = _("Upgrade this machine to Ubuntu Pro for security updates" - " on a much wider range of packages, until" - f" {security_updates_until}. Assists with FedRAMP, FIPS," - " STIG, HIPAA and other compliance or hardening" - " requirements.") - excerpt_no_net = _("An Internet connection is required to enable" - " Ubuntu Pro.") + excerpt = _( + "Upgrade this machine to Ubuntu Pro for security updates" + " on a much wider range of packages, until" + f" {security_updates_until}. Assists with FedRAMP, FIPS," + " STIG, HIPAA and other compliance or hardening" + " requirements." + ) + excerpt_no_net = _( + "An Internet connection is required to enable" " Ubuntu Pro." + ) about_pro_btn = menu_btn( - _("About Ubuntu Pro"), - on_press=lambda unused: self.show_about_ubuntu_pro()) + _("About Ubuntu Pro"), on_press=lambda unused: self.show_about_ubuntu_pro() + ) bp = button_pile([about_pro_btn]) bp.align = "left" @@ -415,13 +403,13 @@ class UbuntuProView(BaseView): Text(""), ] + self.upgrade_yes_no_form.as_rows() return screen( - ListBox(rows), - self.upgrade_yes_no_form.buttons, - excerpt=excerpt if self.has_network else excerpt_no_net, - focus_buttons=True) + ListBox(rows), + self.upgrade_yes_no_form.buttons, + excerpt=excerpt if self.has_network else excerpt_no_net, + focus_buttons=True, + ) - def subscription_screen(self, subscription: UbuntuProSubscription) \ - -> Widget: + def subscription_screen(self, subscription: UbuntuProSubscription) -> Widget: """ +---------------------------------------------------------+ | Subscription: Ubuntu Pro - Physical 24/5 | @@ -452,16 +440,17 @@ class UbuntuProView(BaseView): rows: List[Widget] = [] - rows.extend([ - Text(_("Subscription") + ": " + subscription.contract_name), - ]) + rows.extend( + [ + Text(_("Subscription") + ": " + subscription.contract_name), + ] + ) rows.append(Text("")) if auto_enabled: rows.append(Text(_("List of your enabled services:"))) rows.append(Text("")) - rows.extend( - [Text(f" * {svc.description}") for svc in auto_enabled]) + rows.extend([Text(f" * {svc.description}") for svc in auto_enabled]) if can_be_enabled: if auto_enabled: @@ -471,8 +460,7 @@ class UbuntuProView(BaseView): else: rows.append(Text(_("Available services:"))) rows.append(Text("")) - rows.extend( - [Text(f" * {svc.description}") for svc in can_be_enabled]) + rows.extend([Text(f" * {svc.description}") for svc in can_be_enabled]) def on_continue() -> None: self.controller.next_screen() @@ -480,51 +468,55 @@ class UbuntuProView(BaseView): def on_back() -> None: self._w = self.upgrade_yes_no_screen() - back_button = back_btn( - label=_("Back"), - on_press=lambda unused: on_back()) + back_button = back_btn(label=_("Back"), on_press=lambda unused: on_back()) continue_button = done_btn( - label=self.__class__.subscription_done_label, - on_press=lambda unused: on_continue()) + label=self.__class__.subscription_done_label, + on_press=lambda unused: on_continue(), + ) widgets: List[Widget] = [ Text(""), Pile(rows), Text(""), - Text(_("If you want to change the default enablements for your" - " token, you can do so via the ubuntu.com/pro web" - " interface. Alternatively you can change enabled services" - " using the `pro` command-line tool once the installation" - " is finished.")), + Text( + _( + "If you want to change the default enablements for your" + " token, you can do so via the ubuntu.com/pro web" + " interface. Alternatively you can change enabled services" + " using the `pro` command-line tool once the installation" + " is finished." + ) + ), Text(""), ] return screen( - ListBox(widgets), - buttons=[continue_button, back_button], - excerpt=None, - focus_buttons=True) + ListBox(widgets), + buttons=[continue_button, back_button], + excerpt=None, + focus_buttons=True, + ) + + def contract_token_check_ok(self, subscription: UbuntuProSubscription) -> None: + """Close the "checking-token" overlay and open the token added overlay + instead.""" - def contract_token_check_ok(self, subscription: UbuntuProSubscription) \ - -> None: - """ Close the "checking-token" overlay and open the token added overlay - instead. """ def show_subscription() -> None: self.remove_overlay() self.show_subscription(subscription=subscription) self.remove_overlay() - widget = TokenAddedWidget( - parent=self, - on_continue=show_subscription) + widget = TokenAddedWidget(parent=self, on_continue=show_subscription) self.show_stretchy_overlay(widget) def upgrade_mode_done(self, form: UpgradeModeForm) -> None: - """ Open the loading dialog and asynchronously check if the token is - valid. """ + """Open the loading dialog and asynchronously check if the token is + valid.""" + def on_success(subscription: UbuntuProSubscription) -> None: def noop() -> None: pass + if self.controller.cs_initiated: self.controller.contract_selection_cancel(on_cancelled=noop) self.contract_token_check_ok(subscription) @@ -547,48 +539,51 @@ class UbuntuProView(BaseView): token: str = form.with_contract_token_subform.value["token"] checking_token_overlay = CheckingContractToken(self) - self.show_overlay(checking_token_overlay, - width=checking_token_overlay.width, - min_width=None) + self.show_overlay( + checking_token_overlay, width=checking_token_overlay.width, min_width=None + ) - self.controller.check_token(token, - on_success=on_success, - on_failure=on_failure) + self.controller.check_token(token, on_success=on_success, on_failure=on_failure) def cs_initiated(self, user_code: str) -> None: - """ Function to call when the contract selection has successfully - initiated. It will start the polling asynchronously. """ + """Function to call when the contract selection has successfully + initiated. It will start the polling asynchronously.""" + def reinitiate() -> None: - self.controller.contract_selection_initiate( - on_initiated=self.cs_initiated) + self.controller.contract_selection_initiate(on_initiated=self.cs_initiated) self.upgrade_mode_form.set_user_code(user_code) self.controller.contract_selection_wait( - on_contract_selected=self.on_contract_selected, - on_timeout=reinitiate, - ) + on_contract_selected=self.on_contract_selected, + on_timeout=reinitiate, + ) def on_contract_selected(self, contract_token: str) -> None: - """ Function to call when the contract selection has finished - succesfully. """ + """Function to call when the contract selection has finished + succesfully.""" checking_token_overlay = CheckingContractToken(self) - self.show_overlay(checking_token_overlay, - width=checking_token_overlay.width, - min_width=None) + self.show_overlay( + checking_token_overlay, width=checking_token_overlay.width, min_width=None + ) def on_failure(status: UbuntuProCheckTokenStatus) -> None: - """ Open a message box stating that the contract-token + """Open a message box stating that the contract-token obtained via contract-selection is not valid ; and then go - back to the previous screen. """ + back to the previous screen.""" - log.error("contract-token obtained via contract-selection" - " counld not be validated: %r", status) + log.error( + "contract-token obtained via contract-selection" + " counld not be validated: %r", + status, + ) self._w = self.upgrade_mode_screen() - self.show_stretchy_overlay(SomethingFailed( - self, - msg=_("Internal error"), - stderr=_("Could not add the contract token selected.") - )) + self.show_stretchy_overlay( + SomethingFailed( + self, + msg=_("Internal error"), + stderr=_("Could not add the contract token selected."), + ) + ) # It would be uncommon to have this call fail in production # because the contract token obtained via contract-selection is @@ -596,65 +591,72 @@ class UbuntuProView(BaseView): # environment used (e.g., staging in subiquity and production # in u-a-c) can lead to this error though. self.controller.check_token( - contract_token, - on_success=self.contract_token_check_ok, - on_failure=on_failure) + contract_token, + on_success=self.contract_token_check_ok, + on_failure=on_failure, + ) def upgrade_yes_no_done(self, form: UpgradeYesNoForm) -> None: - """ If skip is selected, move on to the next screen. - Otherwise, show the form requesting a contract token. """ + """If skip is selected, move on to the next screen. + Otherwise, show the form requesting a contract token.""" if form.skip.value: self.controller.done("") else: + def initiate() -> None: self.controller.contract_selection_initiate( - on_initiated=self.cs_initiated) + on_initiated=self.cs_initiated + ) self._w = self.upgrade_mode_screen() if self.controller.cs_initiated: # Cancel the existing contract selection before initiating a # new one. - self.controller.contract_selection_cancel( - on_cancelled=initiate) + self.controller.contract_selection_cancel(on_cancelled=initiate) else: initiate() def cancel(self) -> None: - """ Called when the user presses the Back button. """ + """Called when the user presses the Back button.""" self.controller.cancel() def show_about_ubuntu_pro(self) -> None: - """ Display an overlay that shows information about Ubuntu Pro. """ + """Display an overlay that shows information about Ubuntu Pro.""" self.show_stretchy_overlay(AboutProWidget(self)) def show_how_to_register(self) -> None: - """ Display an overlay that shows instructions to register to - Ubuntu Pro. """ + """Display an overlay that shows instructions to register to + Ubuntu Pro.""" self.show_stretchy_overlay(HowToRegisterWidget(self)) def show_invalid_token(self) -> None: - """ Display an overlay that indicates that the user-supplied token is - invalid. """ + """Display an overlay that indicates that the user-supplied token is + invalid.""" self.show_stretchy_overlay(InvalidTokenWidget(self)) def show_expired_token(self) -> None: - """ Display an overlay that indicates that the user-supplied token has - expired. """ + """Display an overlay that indicates that the user-supplied token has + expired.""" self.show_stretchy_overlay(ExpiredTokenWidget(self)) def show_unknown_error(self) -> None: - """ Display an overlay that indicates that we were unable to retrieve + """Display an overlay that indicates that we were unable to retrieve the subscription information. Reasons can be multiple include lack of network connection, temporary service unavailability, API issue ... The user is prompted to continue anyway or go back. """ - question = _("Unable to check your subscription information." - " Do you want to go back or continue anyway?") + question = _( + "Unable to check your subscription information." + " Do you want to go back or continue anyway?" + ) async def confirm_continue_anyway() -> None: confirmed = await self.ask_confirmation( - title=_("Unknown error"), question=question, - cancel_label=_("Back"), confirm_label=_("Continue anyway")) + title=_("Unknown error"), + question=question, + cancel_label=_("Back"), + confirm_label=_("Continue anyway"), + ) if confirmed: subform = self.upgrade_mode_form.with_contract_token_subform @@ -663,14 +665,14 @@ class UbuntuProView(BaseView): self.tasks.append(asyncio.create_task(confirm_continue_anyway())) def show_subscription(self, subscription: UbuntuProSubscription) -> None: - """ Display a screen with information about the subscription, including + """Display a screen with information about the subscription, including the list of services that can be enabled. After the user confirms, we - will quit the current view and move on. """ + will quit the current view and move on.""" self._w = self.subscription_screen(subscription=subscription) class ExpiredTokenWidget(Stretchy): - """ Widget that shows that the supplied token is expired. + """Widget that shows that the supplied token is expired. +--------------------- Expired token ---------------------+ | | @@ -680,26 +682,25 @@ class ExpiredTokenWidget(Stretchy): | [ Okay ] | +---------------------------------------------------------+ """ + def __init__(self, parent: BaseView) -> None: - """ Initializes the widget. """ + """Initializes the widget.""" self.parent = parent cont = done_btn(label=_("Okay"), on_press=lambda unused: self.close()) widgets = [ - Text(_("Your token has expired. Please use another token" - " to continue.")), + Text(_("Your token has expired. Please use another token" " to continue.")), Text(""), button_pile([cont]), - ] - super().__init__("Expired token", widgets, - stretchy_index=0, focus_index=2) + ] + super().__init__("Expired token", widgets, stretchy_index=0, focus_index=2) def close(self) -> None: - """ Close the overlay. """ + """Close the overlay.""" self.parent.remove_overlay() class InvalidTokenWidget(Stretchy): - """ Widget that shows that the supplied token is invalid. + """Widget that shows that the supplied token is invalid. +--------------------- Invalid token ---------------------+ | | @@ -709,26 +710,30 @@ class InvalidTokenWidget(Stretchy): | [ Okay ] | +---------------------------------------------------------+ """ + def __init__(self, parent: BaseView) -> None: - """ Initializes the widget. """ + """Initializes the widget.""" self.parent = parent cont = done_btn(label=_("Okay"), on_press=lambda unused: self.close()) widgets = [ - Text(_("Your token could not be verified. Please ensure it is" - " correct and try again.")), + Text( + _( + "Your token could not be verified. Please ensure it is" + " correct and try again." + ) + ), Text(""), button_pile([cont]), - ] - super().__init__("Invalid token", widgets, - stretchy_index=0, focus_index=2) + ] + super().__init__("Invalid token", widgets, stretchy_index=0, focus_index=2) def close(self) -> None: - """ Close the overlay. """ + """Close the overlay.""" self.parent.remove_overlay() class TokenAddedWidget(Stretchy): - """ Widget that shows that the supplied token is valid and was "added". + """Widget that shows that the supplied token is valid and was "added". +---------------- Token added successfully ---------------+ | | | Your token has been added successfully and your | @@ -738,29 +743,32 @@ class TokenAddedWidget(Stretchy): | [ Continue ] | +---------------------------------------------------------+ """ + title = _("Token added successfully") done_label = _("Continue") - def __init__(self, parent: UbuntuProView, - on_continue: Callable[[], None]) -> None: - """ Initializes the widget. """ + def __init__(self, parent: UbuntuProView, on_continue: Callable[[], None]) -> None: + """Initializes the widget.""" self.parent = parent cont = done_btn( - label=self.__class__.done_label, - on_press=lambda unused: on_continue()) + label=self.__class__.done_label, on_press=lambda unused: on_continue() + ) widgets = [ - Text(_("Your token has been added successfully and your" - " subscription configuration will be applied at the first" - " boot.")), + Text( + _( + "Your token has been added successfully and your" + " subscription configuration will be applied at the first" + " boot." + ) + ), Text(""), button_pile([cont]), - ] - super().__init__(self.__class__.title, widgets, - stretchy_index=0, focus_index=2) + ] + super().__init__(self.__class__.title, widgets, stretchy_index=0, focus_index=2) class AboutProWidget(Stretchy): - """ Widget showing some information about what Ubuntu Pro offers. + """Widget showing some information about what Ubuntu Pro offers. +------------------- About Ubuntu Pro --------------------+ | | | Ubuntu Pro is the same base Ubuntu, with an additional | @@ -779,34 +787,38 @@ class AboutProWidget(Stretchy): | [ Continue ] | +---------------------------------------------------------+ """ + def __init__(self, parent: UbuntuProView) -> None: - """ Initializes the widget.""" + """Initializes the widget.""" self.parent = parent ok = ok_btn(label=_("Continue"), on_press=lambda unused: self.close()) title = _("About Ubuntu Pro") - header = _("Ubuntu Pro is the same base Ubuntu, with an additional" - " layer of security and compliance services and security" - " patches covering a wider range of packages.") + header = _( + "Ubuntu Pro is the same base Ubuntu, with an additional" + " layer of security and compliance services and security" + " patches covering a wider range of packages." + ) universe_packages = 23000 main_packages = 2300 services = [ - _("Security patch coverage for CVSS critical, high and selected" - f" medium vulnerabilities in all {universe_packages:,} packages" - f' in "universe" (extended from the normal {main_packages:,}' - ' "main" packages).'), + _( + "Security patch coverage for CVSS critical, high and selected" + f" medium vulnerabilities in all {universe_packages:,} packages" + f' in "universe" (extended from the normal {main_packages:,}' + ' "main" packages).' + ), _("10 years of security patch coverage (extended from 5 years)."), _("Kernel Livepatch to reduce required reboots."), _("Ubuntu Security Guide for CIS and DISA-STIG hardening."), - _("FIPS 140-2 NIST-certified crypto-modules for FedRAMP" - " compliance"), + _("FIPS 140-2 NIST-certified crypto-modules for FedRAMP" " compliance"), ] def itemize(item: str, marker: str = "•") -> Columns: - """ Return the text specified in a Text widget prepended with a + """Return the text specified in a Text widget prepended with a bullet point / marker. If the text is too long to fit in a single line, the continuation lines are indented as shown below: +---------------------------+ @@ -815,16 +827,14 @@ class AboutProWidget(Stretchy): | look like. | +---------------------------+ """ - return Columns( - [(len(marker), Text(marker)), Text(item)], dividechars=1) + return Columns([(len(marker), Text(marker)), Text(item)], dividechars=1) widgets: List[Widget] = [ Text(header), Text(""), Pile([itemize(svc, marker=" •") for svc in services]), Text(""), - Text(_("Ubuntu Pro is free for personal use on up to 3" - " machines.")), + Text(_("Ubuntu Pro is free for personal use on up to 3" " machines.")), Text(_("More information is at ubuntu.com/pro")), Text(""), button_pile([ok]), @@ -833,12 +843,12 @@ class AboutProWidget(Stretchy): super().__init__(title, widgets, stretchy_index=2, focus_index=7) def close(self) -> None: - """ Close the overlay. """ + """Close the overlay.""" self.parent.remove_overlay() class HowToRegisterWidget(Stretchy): - """ Widget showing some instructions to register to Ubuntu Pro. + """Widget showing some instructions to register to Ubuntu Pro. +-------------------- How to register --------------------+ | | |_Create your Ubuntu One account with your email. Each | @@ -851,17 +861,20 @@ class HowToRegisterWidget(Stretchy): | [ Continue ] | +---------------------------------------------------------+ """ + def __init__(self, parent: UbuntuProView) -> None: - """ Initializes the widget.""" + """Initializes the widget.""" self.parent = parent ok = ok_btn(label=_("Continue"), on_press=lambda unused: self.close()) title = _("How to register") - header = _("Create your Ubuntu One account with your email. Each" - " Ubuntu One account gets a free personal Ubuntu Pro" - " subscription for up to three machines, including" - " laptops, servers or cloud virtual machines.") + header = _( + "Create your Ubuntu One account with your email. Each" + " Ubuntu One account gets a free personal Ubuntu Pro" + " subscription for up to three machines, including" + " laptops, servers or cloud virtual machines." + ) widgets: List[Widget] = [ Text(header), @@ -874,5 +887,5 @@ class HowToRegisterWidget(Stretchy): super().__init__(title, widgets, stretchy_index=2, focus_index=4) def close(self) -> None: - """ Close the overlay. """ + """Close the overlay.""" self.parent.remove_overlay() diff --git a/subiquity/ui/views/welcome.py b/subiquity/ui/views/welcome.py index e55f3646..4fffb6e9 100644 --- a/subiquity/ui/views/welcome.py +++ b/subiquity/ui/views/welcome.py @@ -22,22 +22,23 @@ import logging from urwid import Text +from subiquity.common.resources import resource_path +from subiquitycore.screen import is_linux_tty from subiquitycore.ui.buttons import forward_btn, other_btn from subiquitycore.ui.container import ListBox from subiquitycore.ui.stretchy import Stretchy from subiquitycore.ui.utils import button_pile, rewrap, screen -from subiquitycore.screen import is_linux_tty from subiquitycore.view import BaseView -from subiquity.common.resources import resource_path - log = logging.getLogger("subiquity.ui.views.welcome") -HELP = _(""" +HELP = _( + """ Select the language to use for the installer and to be configured in the installed system. -""") +""" +) CLOUD_INIT_FAIL_TEXT = """ cloud-init failed to complete after 10 minutes of waiting. This @@ -48,12 +49,12 @@ attach the contents of /var/log, it would be most appreciated. def get_languages(): - lang_path = resource_path('languagelist') + lang_path = resource_path("languagelist") languages = [] with open(lang_path) as lang_file: for line in lang_file: - level, code, name = line.strip().split(':') + level, code, name = line.strip().split(":") if is_linux_tty() and level != "console": continue languages.append((code, name)) @@ -83,9 +84,9 @@ class WelcomeView(BaseView): current_index = i btns.append( forward_btn( - label=native, - on_press=self.choose_language, - user_arg=(code, native))) + label=native, on_press=self.choose_language, user_arg=(code, native) + ) + ) lb = ListBox(btns) back = None @@ -94,12 +95,15 @@ class WelcomeView(BaseView): if current_index is not None: lb.base_widget.focus_position = current_index return screen( - lb, focus_buttons=False, narrow_rows=True, + lb, + focus_buttons=False, + narrow_rows=True, buttons=[back] if back else None, - excerpt=_("Use UP, DOWN and ENTER keys to select your language.")) + excerpt=_("Use UP, DOWN and ENTER keys to select your language."), + ) def choose_language(self, sender, lang): - log.debug('WelcomeView %s', lang) + log.debug("WelcomeView %s", lang) self.controller.done(lang) def local_help(self): @@ -109,20 +113,14 @@ class WelcomeView(BaseView): class CloudInitFail(Stretchy): def __init__(self, app): self.app = app - self.shell_btn = other_btn( - _("Switch to a shell"), on_press=self._debug_shell) - self.close_btn = other_btn( - _("Close"), on_press=self._close) + self.shell_btn = other_btn(_("Switch to a shell"), on_press=self._debug_shell) + self.close_btn = other_btn(_("Close"), on_press=self._close) widgets = [ Text(rewrap(_(CLOUD_INIT_FAIL_TEXT))), - Text(''), + Text(""), button_pile([self.shell_btn, self.close_btn]), - ] - super().__init__( - "", - widgets, - stretchy_index=0, - focus_index=2) + ] + super().__init__("", widgets, stretchy_index=0, focus_index=2) def _debug_shell(self, sender): self.app.debug_shell() diff --git a/subiquity/ui/views/zdev.py b/subiquity/ui/views/zdev.py index 28907d42..6a4a193a 100644 --- a/subiquity/ui/views/zdev.py +++ b/subiquity/ui/views/zdev.py @@ -20,35 +20,17 @@ Provides device activation and configuration on s390x """ import logging -from urwid import ( - connect_signal, - Text, - ) +from urwid import Text, connect_signal from subiquitycore.async_helpers import run_bg_task -from subiquitycore.ui.actionmenu import ( - ActionMenu, - ) -from subiquitycore.ui.buttons import ( - back_btn, - done_btn, - ) -from subiquitycore.ui.container import ( - WidgetWrap, - ) -from subiquitycore.ui.table import ( - ColSpec, - TableListBox, - TableRow, - ) -from subiquitycore.ui.utils import ( - Color, - make_action_menu_row, - screen, - ) +from subiquitycore.ui.actionmenu import ActionMenu +from subiquitycore.ui.buttons import back_btn, done_btn +from subiquitycore.ui.container import WidgetWrap +from subiquitycore.ui.table import ColSpec, TableListBox, TableRow +from subiquitycore.ui.utils import Color, make_action_menu_row, screen from subiquitycore.view import BaseView -log = logging.getLogger('subiquity.ui.views.zdev') +log = logging.getLogger("subiquity.ui.views.zdev") def status(zdevinfo): @@ -65,60 +47,79 @@ def status(zdevinfo): class ZdevList(WidgetWrap): - def __init__(self, parent): self.parent = parent - self.table = TableListBox([], spacing=2, colspecs={ - 0: ColSpec(rpad=2), - 1: ColSpec(rpad=2), - 2: ColSpec(rpad=2), - 3: ColSpec(rpad=2), - }) - self._no_zdev_content = Color.info_minor( - Text(_("No zdev devices found."))) + self.table = TableListBox( + [], + spacing=2, + colspecs={ + 0: ColSpec(rpad=2), + 1: ColSpec(rpad=2), + 2: ColSpec(rpad=2), + 3: ColSpec(rpad=2), + }, + ) + self._no_zdev_content = Color.info_minor(Text(_("No zdev devices found."))) super().__init__(self.table) async def _zdev_action(self, action, zdevinfo): new_zdevinfos = await self.parent.controller.app.wait_with_text_dialog( - self.parent.controller.chzdev(action, zdevinfo), "Updating...") + self.parent.controller.chzdev(action, zdevinfo), "Updating..." + ) self.update(new_zdevinfos) def zdev_action(self, sender, action, zdevinfo): run_bg_task(self._zdev_action(action, zdevinfo)) def update(self, zdevinfos): - rows = [TableRow([ - Color.info_minor(heading) for heading in [ - Text(_("ID")), - Text(_("ONLINE")), - Text(_("NAMES")), - ]])] + rows = [ + TableRow( + [ + Color.info_minor(heading) + for heading in [ + Text(_("ID")), + Text(_("ONLINE")), + Text(_("NAMES")), + ] + ] + ) + ] - typeclass = '' + typeclass = "" for i, zdevinfo in enumerate(zdevinfos): if zdevinfo.typeclass != typeclass: - rows.append(TableRow([ - Text(""), - ])) - rows.append(TableRow([ - Color.info_minor(Text(zdevinfo.type)), - Text(""), - Text("") - ])) + rows.append( + TableRow( + [ + Text(""), + ] + ) + ) + rows.append( + TableRow( + [Color.info_minor(Text(zdevinfo.type)), Text(""), Text("")] + ) + ) typeclass = zdevinfo.typeclass - if zdevinfo.type == 'zfcp-lun': - rows.append(TableRow([ - Color.info_minor(Text(zdevinfo.id[9:])), - status(zdevinfo), - Text(zdevinfo.names), - ])) + if zdevinfo.type == "zfcp-lun": + rows.append( + TableRow( + [ + Color.info_minor(Text(zdevinfo.id[9:])), + status(zdevinfo), + Text(zdevinfo.names), + ] + ) + ) continue - actions = [(_("Enable"), not zdevinfo.on, 'enable'), - (_("Disable"), zdevinfo.on, 'disable')] + actions = [ + (_("Enable"), not zdevinfo.on, "enable"), + (_("Disable"), zdevinfo.on, "disable"), + ] menu = ActionMenu(actions) - connect_signal(menu, 'action', self.zdev_action, zdevinfo) + connect_signal(menu, "action", self.zdev_action, zdevinfo) cells = [ Text(zdevinfo.id), status(zdevinfo), @@ -128,12 +129,13 @@ class ZdevList(WidgetWrap): row = make_action_menu_row( cells, menu, - attr_map='menu_button', + attr_map="menu_button", focus_map={ - None: 'menu_button focus', - 'info_minor': 'menu_button focus', + None: "menu_button focus", + "info_minor": "menu_button focus", }, - cursor_x=0) + cursor_x=0, + ) rows.append(row) self.table.set_contents(rows) if self.table._w.base_widget.focus_position >= len(rows): @@ -144,14 +146,12 @@ class ZdevView(BaseView): title = _("Zdev setup") def __init__(self, controller, zdevinfos): - log.debug('FileSystemView init start()') + log.debug("FileSystemView init start()") self.controller = controller self.zdev_list = ZdevList(self) self.zdev_list.update(zdevinfos) - frame = screen( - self.zdev_list, self._build_buttons(), - focus_buttons=False) + frame = screen(self.zdev_list, self._build_buttons(), focus_buttons=False) super().__init__(frame) # Prevent urwid from putting the first focused widget at the # very top of the display (obscuring the headings) @@ -161,7 +161,7 @@ class ZdevView(BaseView): return [ done_btn(_("Continue"), on_press=self.done), back_btn(_("Back"), on_press=self.cancel), - ] + ] def cancel(self, button=None): self.controller.cancel() diff --git a/subiquitycore/__init__.py b/subiquitycore/__init__.py index 0cccf2c0..56de480e 100644 --- a/subiquitycore/__init__.py +++ b/subiquitycore/__init__.py @@ -16,8 +16,9 @@ """ SubiquityCore """ from subiquitycore import i18n + __all__ = [ - 'i18n', + "i18n", ] __version__ = "0.0.5" diff --git a/subiquitycore/async_helpers.py b/subiquitycore/async_helpers.py index fdcc6ad8..4a0acfc1 100644 --- a/subiquitycore/async_helpers.py +++ b/subiquitycore/async_helpers.py @@ -17,7 +17,6 @@ import asyncio import concurrent.futures import logging - log = logging.getLogger("subiquitycore.async_helpers") @@ -48,7 +47,7 @@ background_tasks = set() def run_bg_task(coro, *args, **kwargs) -> None: - """ Run a background task in a fire-and-forget style. """ + """Run a background task in a fire-and-forget style.""" task = asyncio.create_task(coro, *args, **kwargs) background_tasks.add(task) task.add_done_callback(background_tasks.discard) @@ -65,11 +64,11 @@ async def run_in_thread(func, *args): class TaskAlreadyRunningError(Exception): """Used to let callers know that a task hasn't been started due to cancel_restart == False and the task already running.""" + pass class SingleInstanceTask: - def __init__(self, func, propagate_errors=True, cancel_restart=True): self.func = func self.propagate_errors = propagate_errors @@ -96,7 +95,8 @@ class SingleInstanceTask: if not self.cancel_restart: if self.task is not None and not self.task.done(): raise TaskAlreadyRunningError( - 'Skipping invocation of task - already running') + "Skipping invocation of task - already running" + ) old = self.task coro = self.func(*args, **kw) if asyncio.iscoroutine(coro): diff --git a/subiquitycore/context.py b/subiquitycore/context.py index 954ec041..07c94d71 100644 --- a/subiquitycore/context.py +++ b/subiquitycore/context.py @@ -79,7 +79,7 @@ class Context: while c is not None: names.append(c.name) c = c.parent - return '/'.join(reversed(names)) + return "/".join(reversed(names)) def enter(self, description=None): if description is None: @@ -126,29 +126,31 @@ def with_context(name=None, description="", **context_kw): name = meth.__name__ def convargs(self, kw): - context = kw.get('context') + context = kw.get("context") if context is None: context = self.context - kw['context'] = context.child( + kw["context"] = context.child( name=name.format(**kw), description=description.format(self=self, **kw), - **context_kw) + **context_kw, + ) return kw @functools.wraps(meth) def decorated_sync(self, **kw): kw = convargs(self, kw) - with kw['context']: + with kw["context"]: return meth(self, **kw) @functools.wraps(meth) async def decorated_async(self, **kw): kw = convargs(self, kw) - with kw['context']: + with kw["context"]: return await meth(self, **kw) if inspect.iscoroutinefunction(meth): return decorated_async else: return decorated_sync + return decorate diff --git a/subiquitycore/controller.py b/subiquitycore/controller.py index ab6e9530..72a2768f 100644 --- a/subiquitycore/controller.py +++ b/subiquitycore/controller.py @@ -13,8 +13,8 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from abc import ABC import logging +from abc import ABC from typing import Optional log = logging.getLogger("subiquitycore.controller") @@ -26,7 +26,7 @@ class BaseController(ABC): model_name: Optional[str] = None def __init__(self, app): - self.name = type(self).__name__[:-len("Controller")] + self.name = type(self).__name__[: -len("Controller")] self.opts = app.opts self.app = app self.context = self.app.context.child(self.name, childlevel="DEBUG") diff --git a/subiquitycore/controllers/network.py b/subiquitycore/controllers/network.py index 2f19b824..92d203ac 100644 --- a/subiquitycore/controllers/network.py +++ b/subiquitycore/controllers/network.py @@ -23,11 +23,12 @@ from typing import Optional import pyroute2 import yaml - from probert.network import IFF_UP, NetworkEventReceiver +from subiquitycore import netplan from subiquitycore.async_helpers import SingleInstanceTask from subiquitycore.context import with_context +from subiquitycore.controller import BaseController from subiquitycore.file_util import write_file from subiquitycore.models.network import ( BondConfig, @@ -35,21 +36,12 @@ from subiquitycore.models.network import ( NetDevAction, StaticConfig, WLANConfig, - ) -from subiquitycore import netplan -from subiquitycore.controller import BaseController +) from subiquitycore.pubsub import CoreChannels from subiquitycore.tuicontroller import TuiController from subiquitycore.ui.stretchy import StretchyOverlay -from subiquitycore.ui.views.network import ( - NetworkView, - ) -from subiquitycore.utils import ( - arun_command, - orig_environ, - run_command, - ) - +from subiquitycore.ui.views.network import NetworkView +from subiquitycore.utils import arun_command, orig_environ, run_command log = logging.getLogger("subiquitycore.controllers.network") @@ -84,20 +76,20 @@ class SubiquityNetworkEventReceiver(NetworkEventReceiver): def route_change(self, action, data): super().route_change(action, data) - if data['dst'] != 'default': + if data["dst"] != "default": return - if data['table'] != 254: + if data["table"] != 254: return self.probe_default_routes() self.controller.update_has_default_route(self.has_default_route) def _default_route_exists(self, routes): for route in routes: - if int(route['table']) != 254: + if int(route["table"]) != 254: continue - if route['dst']: + if route["dst"]: continue - if int(route['priority']) >= 20000: + if int(route["priority"]) >= 20000: # network manager probes routes by creating one at 20000 + # the real metric, but those aren't necessarily valid. continue @@ -107,11 +99,12 @@ class SubiquityNetworkEventReceiver(NetworkEventReceiver): def probe_default_routes(self): with pyroute2.NDB() as ndb: self.has_default_route = self._default_route_exists(ndb.routes) - log.debug('default routes %s', self.has_default_route) + log.debug("default routes %s", self.has_default_route) @staticmethod - def create(controller: BaseController, dry_run: bool) \ - -> "SubiquityNetworkEventReceiver": + def create( + controller: BaseController, dry_run: bool + ) -> "SubiquityNetworkEventReceiver": if dry_run: return DryRunSubiquityNetworkEventReceiver(controller) else: @@ -121,10 +114,10 @@ class SubiquityNetworkEventReceiver(NetworkEventReceiver): class DryRunSubiquityNetworkEventReceiver(SubiquityNetworkEventReceiver): def probe_default_routes(self): self.has_default_route = True - log.debug('dryrun default routes %s', self.has_default_route) + log.debug("dryrun default routes %s", self.has_default_route) -default_netplan = ''' +default_netplan = """ network: version: 2 ethernets: @@ -153,11 +146,10 @@ network: access-points: "some-ap": password: password -''' +""" class BaseNetworkController(BaseController): - model_name = "network" root = "/" @@ -170,23 +162,26 @@ class BaseNetworkController(BaseController): netplan_dir = os.path.dirname(netplan_path) if os.path.exists(netplan_dir): import shutil + shutil.rmtree(netplan_dir) os.makedirs(netplan_dir) - with open(netplan_path, 'w') as fp: + with open(netplan_path, "w") as fp: fp.write(default_netplan) self.parse_netplan_configs() self._watching = False - self.network_event_receiver = \ - SubiquityNetworkEventReceiver.create(self, self.opts.dry_run) + self.network_event_receiver = SubiquityNetworkEventReceiver.create( + self, self.opts.dry_run + ) def parse_netplan_configs(self): self.model.parse_netplan_configs(self.root) def start(self): self._observer_handles = [] - self.observer, self._observer_fds = ( - self.app.prober.probe_network(self.network_event_receiver)) + self.observer, self._observer_fds = self.app.prober.probe_network( + self.network_event_receiver + ) self.start_watching() def stop_watching(self): @@ -206,7 +201,7 @@ class BaseNetworkController(BaseController): self._watching = True def _data_ready(self, fd): - cp = run_command(['udevadm', 'settle', '-t', '0']) + cp = run_command(["udevadm", "settle", "-t", "0"]) if cp.returncode != 0: log.debug("waiting 0.1 to let udev event queue settle") self.stop_watching() @@ -237,10 +232,10 @@ class BaseNetworkController(BaseController): @property def netplan_path(self): if self.opts.project == "subiquity": - netplan_config_file_name = '00-installer-config.yaml' + netplan_config_file_name = "00-installer-config.yaml" else: - netplan_config_file_name = '00-snapd-config.yaml' - return os.path.join(self.root, 'etc/netplan', netplan_config_file_name) + netplan_config_file_name = "00-snapd-config.yaml" + return os.path.join(self.root, "etc/netplan", netplan_config_file_name) def apply_config(self, context=None, silent=False): self.apply_config_task.start_sync(context=context, silent=silent) @@ -248,17 +243,17 @@ class BaseNetworkController(BaseController): async def _down_devs(self, devs): for dev in devs: try: - log.debug('downing %s', dev.name) + log.debug("downing %s", dev.name) self.observer.rtlistener.unset_link_flags(dev.ifindex, IFF_UP) except RuntimeError: # We don't actually care very much about this - log.exception('unset_link_flags failed for %s', dev.name) + log.exception("unset_link_flags failed for %s", dev.name) async def _delete_devs(self, devs): for dev in devs: # XXX would be nicer to do this via rtlistener eventually. - log.debug('deleting %s', dev.name) - cmd = ['ip', 'link', 'delete', 'dev', dev.name] + log.debug("deleting %s", dev.name) + cmd = ["ip", "link", "delete", "dev", dev.name] try: await arun_command(cmd, check=True) except subprocess.CalledProcessError as cp: @@ -267,10 +262,10 @@ class BaseNetworkController(BaseController): def _write_config(self): config = self.model.render_config() - log.debug("network config: \n%s", - yaml.dump( - netplan.sanitize_config(config), - default_flow_style=False)) + log.debug( + "network config: \n%s", + yaml.dump(netplan.sanitize_config(config), default_flow_style=False), + ) for p in netplan.configs_in_root(self.root, masked=True): if p == self.netplan_path: @@ -281,8 +276,7 @@ class BaseNetworkController(BaseController): self.parse_netplan_configs() - @with_context( - name="apply_config", description="silent={silent}", level="INFO") + @with_context(name="apply_config", description="silent={silent}", level="INFO") async def _apply_config(self, *, context, silent): devs_to_delete = [] devs_to_down = [] @@ -294,8 +288,7 @@ class BaseNetworkController(BaseController): if dev.dhcp_enabled(v): if not silent: dev.set_dhcp_state(v, DHCPState.PENDING) - self.network_event_receiver.update_link( - dev.ifindex) + self.network_event_receiver.update_link(dev.ifindex) else: dev.set_dhcp_state(v, DHCPState.RECONFIGURE) dev.dhcp_events[v] = e = asyncio.Event() @@ -314,14 +307,15 @@ class BaseNetworkController(BaseController): self.apply_starting() try: + def error(stage): if not silent: self.apply_error(stage) if self.opts.dry_run: - delay = 1/self.app.scale_factor - await arun_command(['sleep', str(delay)]) - if os.path.exists('/lib/netplan/generate'): + delay = 1 / self.app.scale_factor + await arun_command(["sleep", str(delay)]) + if os.path.exists("/lib/netplan/generate"): # If netplan appears to be installed, run generate to # at least test that what we wrote is acceptable to # netplan but clear the SNAP environment variable to @@ -329,23 +323,34 @@ class BaseNetworkController(BaseController): # tries to call netplan over the system bus. env = os.environ.copy() with contextlib.suppress(KeyError): - del env['SNAP'] + del env["SNAP"] await arun_command( - ['netplan', 'generate', '--root', self.root], - check=True, env=env) + ["netplan", "generate", "--root", self.root], + check=True, + env=env, + ) else: if devs_to_down or devs_to_delete: try: await arun_command( - ['systemctl', 'mask', '--runtime', - 'systemd-networkd.service', - 'systemd-networkd.socket'], - check=True) + [ + "systemctl", + "mask", + "--runtime", + "systemd-networkd.service", + "systemd-networkd.socket", + ], + check=True, + ) await arun_command( - ['systemctl', 'stop', - 'systemd-networkd.service', - 'systemd-networkd.socket'], - check=True) + [ + "systemctl", + "stop", + "systemd-networkd.service", + "systemd-networkd.socket", + ], + check=True, + ) except subprocess.CalledProcessError: error("stop-networkd") raise @@ -355,24 +360,30 @@ class BaseNetworkController(BaseController): await self._delete_devs(devs_to_delete) if devs_to_down or devs_to_delete: await arun_command( - ['systemctl', 'unmask', '--runtime', - 'systemd-networkd.service', - 'systemd-networkd.socket'], - check=True) + [ + "systemctl", + "unmask", + "--runtime", + "systemd-networkd.service", + "systemd-networkd.socket", + ], + check=True, + ) env = orig_environ(None) try: - await arun_command(['netplan', 'apply'], - env=env, check=True) + await arun_command(["netplan", "apply"], env=env, check=True) except subprocess.CalledProcessError as cpe: - log.debug('CalledProcessError: ' - f'stdout[{cpe.stdout}] stderr[{cpe.stderr}]') + log.debug( + "CalledProcessError: " + f"stdout[{cpe.stdout}] stderr[{cpe.stderr}]" + ) error("apply") raise if devs_to_down or devs_to_delete: # It's probably running already, but just in case. await arun_command( - ['systemctl', 'start', 'systemd-networkd.socket'], - check=False) + ["systemctl", "start", "systemd-networkd.socket"], check=False + ) finally: if not silent: self.apply_stopping() @@ -381,9 +392,7 @@ class BaseNetworkController(BaseController): return try: - await asyncio.wait_for( - asyncio.wait({e.wait() for e in dhcp_events}), - 10) + await asyncio.wait_for(asyncio.wait({e.wait() for e in dhcp_events}), 10) except asyncio.TimeoutError: pass @@ -393,28 +402,26 @@ class BaseNetworkController(BaseController): dev.set_dhcp_state(v, DHCPState.TIMED_OUT) self.network_event_receiver.update_link(dev.ifindex) - def set_static_config(self, dev_name: str, ip_version: int, - static_config: StaticConfig) -> None: + def set_static_config( + self, dev_name: str, ip_version: int, static_config: StaticConfig + ) -> None: dev = self.model.get_netdev_by_name(dev_name) dev.remove_ip_networks_for_version(ip_version) - dev.config.setdefault('addresses', []).extend(static_config.addresses) + dev.config.setdefault("addresses", []).extend(static_config.addresses) if static_config.gateway: - dev.config['routes'] = [{ - 'to': 'default', - 'via': static_config.gateway - }] + dev.config["routes"] = [{"to": "default", "via": static_config.gateway}] else: dev.remove_routes(ip_version) - ns = dev.config.setdefault('nameservers', {}) - ns.setdefault('addresses', []).extend(static_config.nameservers) - ns.setdefault('search', []).extend(static_config.searchdomains) + ns = dev.config.setdefault("nameservers", {}) + ns.setdefault("addresses", []).extend(static_config.nameservers) + ns.setdefault("search", []).extend(static_config.searchdomains) self.update_link(dev) self.apply_config() def enable_dhcp(self, dev_name: str, ip_version: int) -> None: dev = self.model.get_netdev_by_name(dev_name) dev.remove_ip_networks_for_version(ip_version) - dhcpkey = 'dhcp{v}'.format(v=ip_version) + dhcpkey = "dhcp{v}".format(v=ip_version) dev.config[dhcpkey] = True self.update_link(dev) self.apply_config() @@ -436,11 +443,11 @@ class BaseNetworkController(BaseController): dev = self.model.get_netdev_by_name(dev_name) touched_devices = set() if dev.type == "bond": - for device_name in dev.config['interfaces']: + for device_name in dev.config["interfaces"]: interface = self.model.get_netdev_by_name(device_name) touched_devices.add(interface) elif dev.type == "vlan": - link = self.model.get_netdev_by_name(dev.config['link']) + link = self.model.get_netdev_by_name(dev.config["link"]) touched_devices.add(link) dev.config = None self.del_link(dev) @@ -448,8 +455,9 @@ class BaseNetworkController(BaseController): self.update_link(dev) self.apply_config() - def add_or_update_bond(self, existing_name: Optional[str], - new_name: str, new_info: BondConfig) -> None: + def add_or_update_bond( + self, existing_name: Optional[str], new_name: str, new_info: BondConfig + ) -> None: get_netdev_by_name = self.model.get_netdev_by_name touched_devices = set() for device_name in new_info.interfaces: @@ -461,7 +469,7 @@ class BaseNetworkController(BaseController): self.new_link(new_dev) else: existing = get_netdev_by_name(existing_name) - for interface in existing.config['interfaces']: + for interface in existing.config["interfaces"]: touched_devices.add(get_netdev_by_name(interface)) existing.config.update(new_info.to_config()) if existing.name != new_name: @@ -480,11 +488,11 @@ class BaseNetworkController(BaseController): async def get_info_for_netdev(self, dev_name: str) -> str: device = self.model.get_netdev_by_name(dev_name) if device.info is not None: - return yaml.dump( - device.info.serialize(), default_flow_style=False) + return yaml.dump(device.info.serialize(), default_flow_style=False) else: return "Configured but not yet created {type} interface.".format( - type=device.type) + type=device.type + ) def set_wlan(self, dev_name: str, wlan: WLANConfig) -> None: device = self.model.get_netdev_by_name(dev_name) @@ -492,7 +500,7 @@ class BaseNetworkController(BaseController): if wlan.ssid and not cur_ssid: # Turn DHCP4 on by default when specifying an SSID for # the first time... - device.config['dhcp4'] = True + device.config["dhcp4"] = True device.set_ssid_psk(wlan.ssid, wlan.psk) self.update_link(device) self.apply_config() @@ -502,7 +510,7 @@ class BaseNetworkController(BaseController): try: self.observer.trigger_scan(device.ifindex) except RuntimeError as r: - device.info.wlan['scan_state'] = 'error %s' % (r,) + device.info.wlan["scan_state"] = "error %s" % (r,) self.update_link(device) @abc.abstractmethod @@ -540,12 +548,11 @@ class BaseNetworkController(BaseController): class NetworkAnswersMixin: - async def run_answers(self): - if self.answers.get('accept-default', False): + if self.answers.get("accept-default", False): self.done() - elif self.answers.get('actions', False): - actions = self.answers['actions'] + elif self.answers.get("actions", False): + actions = self.answers["actions"] self.answers.clear() await self._run_actions(actions) @@ -566,54 +573,49 @@ class NetworkAnswersMixin: async def _answers_action(self, action): log.debug("_answers_action %r", action) - if 'obj' in action: - table = self._action_get(action['obj']) - meth = getattr( - self.ui.body, - "_action_{}".format(action['action'])) - action_obj = getattr(NetDevAction, action['action']) + if "obj" in action: + table = self._action_get(action["obj"]) + meth = getattr(self.ui.body, "_action_{}".format(action["action"])) + action_obj = getattr(NetDevAction, action["action"]) self.ui.body._action(None, (action_obj, meth), table) yield body = self.ui.body._w - if action['action'] == "DELETE": + if action["action"] == "DELETE": t = 0.0 while table.dev_info.name in self.view.cur_netdev_names: await asyncio.sleep(0.1) t += 0.1 if t > 5.0: - raise Exception( - "interface did not disappear in 5 secs") + raise Exception("interface did not disappear in 5 secs") log.debug("waited %s for interface to disappear", t) if not isinstance(body, StretchyOverlay): return for k, v in action.items(): - if not k.endswith('data'): + if not k.endswith("data"): continue form_name = "form" submit_key = "submit" - if '-' in k: - prefix = k.split('-')[0] + if "-" in k: + prefix = k.split("-")[0] form_name = prefix + "_form" submit_key = prefix + "-submit" async for _ in self._enter_form_data( - getattr(body.stretchy, form_name), - v, - action.get(submit_key, True)): + getattr(body.stretchy, form_name), v, action.get(submit_key, True) + ): pass - elif action['action'] == 'create-bond': + elif action["action"] == "create-bond": self.ui.body._create_bond() yield body = self.ui.body._w - data = action['data'].copy() - if 'devices' in data: - data['interfaces'] = data.pop('devices') + data = action["data"].copy() + if "devices" in data: + data["interfaces"] = data.pop("devices") async for _ in self._enter_form_data( - body.stretchy.form, - data, - action.get("submit", True)): + body.stretchy.form, data, action.get("submit", True) + ): pass t = 0.0 - while data['name'] not in self.view.cur_netdev_names: + while data["name"] not in self.view.cur_netdev_names: await asyncio.sleep(0.1) t += 0.1 if t > 5.0: @@ -621,15 +623,13 @@ class NetworkAnswersMixin: if t > 0: log.debug("waited %s for bond to appear", t) yield - elif action['action'] == 'done': + elif action["action"] == "done": self.ui.body.done() else: raise Exception("could not process action {}".format(action)) -class NetworkController(BaseNetworkController, TuiController, - NetworkAnswersMixin): - +class NetworkController(BaseNetworkController, TuiController, NetworkAnswersMixin): def __init__(self, app): super().__init__(app) self.view = None @@ -638,15 +638,14 @@ class NetworkController(BaseNetworkController, TuiController, def make_ui(self): if not self.view_shown: self.update_initial_configs() - netdev_infos = [ - dev.netdev_info() for dev in self.model.get_all_netdevs() - ] + netdev_infos = [dev.netdev_info() for dev in self.model.get_all_netdevs()] self.view = NetworkView(self, netdev_infos) if not self.view_shown: self.apply_config(silent=True) self.view_shown = True self.view.update_has_default_route( - self.network_event_receiver.has_default_route) + self.network_event_receiver.has_default_route + ) return self.view def end_ui(self): diff --git a/subiquitycore/controllers/tests/test_network.py b/subiquitycore/controllers/tests/test_network.py index aec6d17e..ac9f7b92 100644 --- a/subiquitycore/controllers/tests/test_network.py +++ b/subiquitycore/controllers/tests/test_network.py @@ -27,98 +27,111 @@ class TestRoutes(unittest.IsolatedAsyncioTestCase): self.assertFalse(self.er._default_route_exists([])) def test_one_good(self): - routes = [{ - "target": "localhost", - "tflags": 0, - "table": 254, - "ifname": "ens3", - "dst": "", - "dst_len": 0, - "priority": 100, - "gateway": "10.0.2.2" - }] + routes = [ + { + "target": "localhost", + "tflags": 0, + "table": 254, + "ifname": "ens3", + "dst": "", + "dst_len": 0, + "priority": 100, + "gateway": "10.0.2.2", + } + ] self.assertTrue(self.er._default_route_exists(routes)) def test_mix(self): - routes = [{ - "target": "localhost", - "tflags": 0, - "table": 254, - "ifname": "ens3", - "dst": "", - "dst_len": 0, - "priority": 100, - "gateway": "10.0.2.2" - }, { - "target": "localhost", - "tflags": 0, - "table": 254, - "ifname": "ens3", - "dst": "10.0.2.0", - "dst_len": 24, - "priority": 100, - "gateway": None - }, { - "target": "localhost", - "tflags": 0, - "table": 255, - "ifname": "ens3", - "dst": "10.0.2.0", - "dst_len": 24, - "priority": 100, - "gateway": None - }, { - "target": "localhost", - "tflags": 0, - "table": 254, - "ifname": "ens3", - "dst": "10.0.2.0", - "dst_len": 24, - "priority": 20100, - "gateway": None - }] + routes = [ + { + "target": "localhost", + "tflags": 0, + "table": 254, + "ifname": "ens3", + "dst": "", + "dst_len": 0, + "priority": 100, + "gateway": "10.0.2.2", + }, + { + "target": "localhost", + "tflags": 0, + "table": 254, + "ifname": "ens3", + "dst": "10.0.2.0", + "dst_len": 24, + "priority": 100, + "gateway": None, + }, + { + "target": "localhost", + "tflags": 0, + "table": 255, + "ifname": "ens3", + "dst": "10.0.2.0", + "dst_len": 24, + "priority": 100, + "gateway": None, + }, + { + "target": "localhost", + "tflags": 0, + "table": 254, + "ifname": "ens3", + "dst": "10.0.2.0", + "dst_len": 24, + "priority": 20100, + "gateway": None, + }, + ] self.assertTrue(self.er._default_route_exists(routes)) def test_one_other(self): - routes = [{ - "target": "localhost", - "tflags": 0, - "table": 254, - "ifname": "ens3", - "dst": "10.0.2.0", - "dst_len": 24, - "priority": 100, - "gateway": None - }] + routes = [ + { + "target": "localhost", + "tflags": 0, + "table": 254, + "ifname": "ens3", + "dst": "10.0.2.0", + "dst_len": 24, + "priority": 100, + "gateway": None, + } + ] self.assertFalse(self.er._default_route_exists(routes)) def test_wrong_table(self): - routes = [{ - "target": "localhost", - "tflags": 0, - "table": 255, - "ifname": "ens3", - "dst": "", - "dst_len": 0, - "priority": 100, - "gateway": "10.0.2.2" - }] + routes = [ + { + "target": "localhost", + "tflags": 0, + "table": 255, + "ifname": "ens3", + "dst": "", + "dst_len": 0, + "priority": 100, + "gateway": "10.0.2.2", + } + ] self.assertFalse(self.er._default_route_exists(routes)) def test_wrong_priority(self): - routes = [{ - "target": "localhost", - "tflags": 0, - "table": 254, - "ifname": "ens3", - "dst": "", - "dst_len": 0, - "priority": 20100, - "gateway": "10.0.2.2" - }] + routes = [ + { + "target": "localhost", + "tflags": 0, + "table": 254, + "ifname": "ens3", + "dst": "", + "dst_len": 0, + "priority": 20100, + "gateway": "10.0.2.2", + } + ] self.assertFalse(self.er._default_route_exists(routes)) diff --git a/subiquitycore/controllerset.py b/subiquitycore/controllerset.py index 140dcd4d..f35f00a1 100644 --- a/subiquitycore/controllerset.py +++ b/subiquitycore/controllerset.py @@ -15,7 +15,6 @@ class ControllerSet: - def __init__(self, controllers_mod, names, init_args=()): self.controllers_mod = controllers_mod self.controller_names = names[:] @@ -24,7 +23,7 @@ class ControllerSet: self.instances = [] def _get_controller_class(self, name): - cls_name = name+"Controller" + cls_name = name + "Controller" return getattr(self.controllers_mod, cls_name) def load(self, name): diff --git a/subiquitycore/core.py b/subiquitycore/core.py index 3f03ec5d..c3de3f8f 100644 --- a/subiquitycore/core.py +++ b/subiquitycore/core.py @@ -19,17 +19,14 @@ import logging import os from subiquitycore.async_helpers import run_bg_task -from subiquitycore.context import ( - Context, - ) +from subiquitycore.context import Context from subiquitycore.controllerset import ControllerSet from subiquitycore.pubsub import MessageHub -log = logging.getLogger('subiquitycore.core') +log = logging.getLogger("subiquitycore.core") class Application: - # A concrete subclass must set project and controllers attributes, e.g.: # # project = "subiquity" @@ -56,38 +53,36 @@ class Application: # subiquitycore/prober.py # - copy-logs-fail: makes post-install copying of logs fail, see # subiquity/controllers/installprogress.py - self.debug_flags = os.environ.get('SUBIQUITY_DEBUG', '').split(',') + self.debug_flags = os.environ.get("SUBIQUITY_DEBUG", "").split(",") self.opts = opts opts.project = self.project - self.root = '/' + self.root = "/" if opts.dry_run: self.root = opts.output_base - self.state_dir = os.path.join(self.root, 'run', self.project) - os.makedirs(self.state_path('states'), exist_ok=True) + self.state_dir = os.path.join(self.root, "run", self.project) + os.makedirs(self.state_path("states"), exist_ok=True) - self.scale_factor = float( - os.environ.get('SUBIQUITY_REPLAY_TIMESCALE', "1")) - self.updated = os.path.exists(self.state_path('updating')) + self.scale_factor = float(os.environ.get("SUBIQUITY_REPLAY_TIMESCALE", "1")) + self.updated = os.path.exists(self.state_path("updating")) self.hub = MessageHub() - asyncio.get_running_loop().set_exception_handler( - self._exception_handler) + asyncio.get_running_loop().set_exception_handler(self._exception_handler) self.load_controllers(self.controllers) self.context = Context.new(self) self.exit_event = asyncio.Event() self.controllers_have_started = asyncio.Event() def load_controllers(self, controllers): - """ Load the corresponding list of controllers + """Load the corresponding list of controllers - Those will need to be restarted if already started """ + Those will need to be restarted if already started""" self.controllers = ControllerSet( - self.controllers_mod, controllers, - init_args=(self,)) + self.controllers_mod, controllers, init_args=(self,) + ) def _exception_handler(self, loop, context): - exc = context.get('exception') + exc = context.get("exception") if exc: self.exit_event.set() self._exc = exc @@ -101,7 +96,7 @@ class Application: cur = self.controllers.cur if cur is None: return - with open(self.state_path('states', cur.name), 'w') as fp: + with open(self.state_path("states", cur.name), "w") as fp: json.dump(cur.serialize(), fp) def report_start_event(self, context, description): @@ -114,7 +109,7 @@ class Application: level = getattr(logging, context.level) log.log(level, "finish: %s %s", description, status.name) -# EventLoop ------------------------------------------------------------------- + # EventLoop ------------------------------------------------------------------- def exit(self): self.exit_event.set() diff --git a/subiquitycore/file_util.py b/subiquitycore/file_util.py index 23524d93..50d57349 100644 --- a/subiquitycore/file_util.py +++ b/subiquitycore/file_util.py @@ -24,15 +24,18 @@ import tempfile import yaml _DEF_PERMS_FILE = 0o640 -_DEF_GROUP = 'adm' +_DEF_GROUP = "adm" -log = logging.getLogger('subiquitycore.file_util') +log = logging.getLogger("subiquitycore.file_util") def set_log_perms(target, *, isdir=True, group_write=False, mode=None): if os.getuid() != 0: - log.warning('set_log_perms: running as non-root - not adjusting' + - ' group owner or permissions for ' + target) + log.warning( + "set_log_perms: running as non-root - not adjusting" + + " group owner or permissions for " + + target + ) return if mode is None: mode = _DEF_PERMS_FILE @@ -53,7 +56,7 @@ def open_perms(filename, *, cmode=None): try: dirname = os.path.dirname(filename) os.makedirs(dirname, exist_ok=True) - tf = tempfile.NamedTemporaryFile(dir=dirname, delete=False, mode='w') + tf = tempfile.NamedTemporaryFile(dir=dirname, delete=False, mode="w") yield tf tf.close() set_log_perms(tf.name, mode=cmode) @@ -71,7 +74,7 @@ def write_file(filename, content, **kwargs): def generate_timestamped_header() -> str: now = datetime.datetime.utcnow() - return f'# Autogenerated by Subiquity: {now} UTC\n' + return f"# Autogenerated by Subiquity: {now} UTC\n" def generate_config_yaml(filename, content, **kwargs): diff --git a/subiquitycore/i18n.py b/subiquitycore/i18n.py index 5af5b2a2..177243dd 100644 --- a/subiquitycore/i18n.py +++ b/subiquitycore/i18n.py @@ -18,37 +18,44 @@ import gettext import os import syslog -syslog.syslog('i18n file is ' + __file__) -localedir = '/usr/share/locale' -if __file__.startswith('/snap/'): - localedir = os.path.realpath(__file__ + '/../../../../../share/locale') -build_mo = os.path.realpath(__file__ + '/../../build/mo/') +syslog.syslog("i18n file is " + __file__) +localedir = "/usr/share/locale" +if __file__.startswith("/snap/"): + localedir = os.path.realpath(__file__ + "/../../../../../share/locale") +build_mo = os.path.realpath(__file__ + "/../../build/mo/") if os.path.isdir(build_mo): localedir = build_mo -syslog.syslog('Final localedir is ' + localedir) +syslog.syslog("Final localedir is " + localedir) -def switch_language(code='en_US'): - syslog.syslog('switch_language ' + code) +def switch_language(code="en_US"): + syslog.syslog("switch_language " + code) fake_trans = os.environ.get("FAKE_TRANSLATE", "0") - if code != 'en_US' and fake_trans == "mangle": + if code != "en_US" and fake_trans == "mangle": + def my_gettext(message): return "_(%s)" % message + elif fake_trans not in ("0", ""): + def my_gettext(message): return message + elif code: translation = gettext.translation( - 'subiquity', localedir=localedir, languages=[code], fallback=True) + "subiquity", localedir=localedir, languages=[code], fallback=True + ) def my_gettext(message): if not message: return message return translation.gettext(message) + import builtins - builtins.__dict__['_'] = my_gettext + + builtins.__dict__["_"] = my_gettext switch_language() -__all__ = ['switch_language'] +__all__ = ["switch_language"] diff --git a/subiquitycore/log.py b/subiquitycore/log.py index 1dcace9e..f04f2d1e 100644 --- a/subiquitycore/log.py +++ b/subiquitycore/log.py @@ -19,7 +19,7 @@ import os from subiquitycore.file_util import set_log_perms -def setup_logger(dir, base='subiquity'): +def setup_logger(dir, base="subiquity"): os.makedirs(dir, exist_ok=True) # Create the log directory in such a way that users in the group may # write to this directory in the installation environment. @@ -30,7 +30,7 @@ def setup_logger(dir, base='subiquity'): r = {} - for level in 'info', 'debug': + for level in "info", "debug": nopid_file = os.path.join(dir, "{}-{}.log".format(base, level)) logfile = "{}.{}".format(nopid_file, os.getpid()) handler = logging.FileHandler(logfile) @@ -44,7 +44,9 @@ def setup_logger(dir, base='subiquity'): handler.setLevel(getattr(logging, level.upper())) handler.setFormatter( logging.Formatter( - "%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s")) + "%(asctime)s %(levelname)s %(name)s:%(lineno)d %(message)s" + ) + ) logger.addHandler(handler) r[level] = logfile diff --git a/subiquitycore/lsb_release.py b/subiquitycore/lsb_release.py index 0714092b..9e3ac136 100644 --- a/subiquitycore/lsb_release.py +++ b/subiquitycore/lsb_release.py @@ -30,5 +30,5 @@ def lsb_release(path=None, dry_run: bool = False) -> Dict[str, str]: return ret -if __name__ == '__main__': +if __name__ == "__main__": print(lsb_release()) diff --git a/subiquitycore/models/__init__.py b/subiquitycore/models/__init__.py index c6bf3020..42af1eb9 100644 --- a/subiquitycore/models/__init__.py +++ b/subiquitycore/models/__init__.py @@ -14,4 +14,5 @@ # along with this program. If not, see . from .network import NetworkModel -__all__ = ['NetworkModel'] + +__all__ = ["NetworkModel"] diff --git a/subiquitycore/models/network.py b/subiquitycore/models/network.py index 8002dbdf..587ba056 100644 --- a/subiquitycore/models/network.py +++ b/subiquitycore/models/network.py @@ -14,24 +14,31 @@ # along with this program. If not, see . import enum -from gettext import pgettext import ipaddress import logging -import yaml +from gettext import pgettext from socket import AF_INET, AF_INET6 -import attr from typing import Dict, List, Optional +import attr +import yaml + from subiquitycore import netplan - NETDEV_IGNORED_IFACE_TYPES = [ - 'lo', 'bridge', 'tun', 'tap', 'dummy', 'sit', 'can', '???' + "lo", + "bridge", + "tun", + "tap", + "dummy", + "sit", + "can", + "???", ] -NETDEV_ALLOWED_VIRTUAL_IFACE_TYPES = ['vlan', 'bond'] +NETDEV_ALLOWED_VIRTUAL_IFACE_TYPES = ["vlan", "bond"] -log = logging.getLogger('subiquitycore.models.network') +log = logging.getLogger("subiquitycore.models.network") def addr_version(ip): @@ -103,21 +110,22 @@ class BondConfig: def to_config(self): mode = self.mode params = { - 'mode': self.mode, - } + "mode": self.mode, + } if mode in BondParameters.supports_xmit_hash_policy: - params['transmit-hash-policy'] = self.xmit_hash_policy + params["transmit-hash-policy"] = self.xmit_hash_policy if mode in BondParameters.supports_lacp_rate: - params['lacp-rate'] = self.lacp_rate + params["lacp-rate"] = self.lacp_rate return { - 'interfaces': self.interfaces, - 'parameters': params, - } + "interfaces": self.interfaces, + "parameters": params, + } @attr.s(auto_attribs=True) class NetDevInfo: """All the information about a NetworkDev that the view code needs.""" + name: str type: str @@ -148,41 +156,40 @@ class BondParameters: # configured. modes = [ - 'balance-rr', - 'active-backup', - 'balance-xor', - 'broadcast', - '802.3ad', - 'balance-tlb', - 'balance-alb', + "balance-rr", + "active-backup", + "balance-xor", + "broadcast", + "802.3ad", + "balance-tlb", + "balance-alb", ] supports_xmit_hash_policy = { - 'balance-xor', - '802.3ad', - 'balance-tlb', + "balance-xor", + "802.3ad", + "balance-tlb", } xmit_hash_policies = [ - 'layer2', - 'layer2+3', - 'layer3+4', - 'encap2+3', - 'encap3+4', + "layer2", + "layer2+3", + "layer3+4", + "encap2+3", + "encap3+4", ] supports_lacp_rate = { - '802.3ad', + "802.3ad", } lacp_rates = [ - 'slow', - 'fast', + "slow", + "fast", ] class NetworkDev(object): - def __init__(self, model, name, typ): self._model = model self._name = name @@ -194,10 +201,10 @@ class NetworkDev(object): self._dhcp_state = { 4: None, 6: None, - } + } def netdev_info(self) -> NetDevInfo: - if self.type == 'eth': + if self.type == "eth": is_connected = bool(self.info.is_connected) else: is_connected = True @@ -205,34 +212,36 @@ class NetworkDev(object): for dev2 in self._model.get_all_netdevs(): if dev2.type != "bond": continue - if self.name in dev2.config.get('interfaces', []): + if self.name in dev2.config.get("interfaces", []): bond_master = dev2.name break bond: Optional[BondConfig] = None - if self.type == 'bond' and self.config is not None: - params = self.config['parameters'] + if self.type == "bond" and self.config is not None: + params = self.config["parameters"] bond = BondConfig( - interfaces=self.config['interfaces'], - mode=params['mode'], - xmit_hash_policy=params.get('xmit-hash-policy'), - lacp_rate=params.get('lacp-rate')) + interfaces=self.config["interfaces"], + mode=params["mode"], + xmit_hash_policy=params.get("xmit-hash-policy"), + lacp_rate=params.get("lacp-rate"), + ) vlan: Optional[VLANConfig] = None - if self.type == 'vlan' and self.config is not None: - vlan = VLANConfig(id=self.config['id'], link=self.config['link']) + if self.type == "vlan" and self.config is not None: + vlan = VLANConfig(id=self.config["id"], link=self.config["link"]) wlan: Optional[WLANStatus] = None - if self.type == 'wlan': + if self.type == "wlan": ssid, psk = self.configured_ssid wlan = WLANStatus( config=WLANConfig(ssid=ssid, psk=psk), - scan_state=self.info.wlan['scan_state'], - visible_ssids=self.info.wlan['visible_ssids']) + scan_state=self.info.wlan["scan_state"], + visible_ssids=self.info.wlan["visible_ssids"], + ) dhcp_addresses = self.dhcp_addresses() configured_addresses: Dict[int, List[str]] = {4: [], 6: []} if self.config is not None: - for addr in self.config.get('addresses', []): + for addr in self.config.get("addresses", []): configured_addresses[addr_version(addr)].append(addr) - ns = self.config.get('nameservers', {}) + ns = self.config.get("nameservers", {}) else: ns = {} dhcp_statuses = {} @@ -241,16 +250,18 @@ class NetworkDev(object): dhcp_statuses[v] = DHCPStatus( enabled=self.dhcp_enabled(v), state=self._dhcp_state[v], - addresses=dhcp_addresses[v]) + addresses=dhcp_addresses[v], + ) if self.config is not None: - gateway = self.config.get('gateway' + str(v)) + gateway = self.config.get("gateway" + str(v)) else: gateway = None static_configs[v] = StaticConfig( addresses=configured_addresses[v], gateway=gateway, - nameservers=ns.get('nameservers', []), - searchdomains=ns.get('search', [])) + nameservers=ns.get("nameservers", []), + searchdomains=ns.get("search", []), + ) return NetDevInfo( name=self.name, type=self.type, @@ -266,14 +277,14 @@ class NetworkDev(object): is_used=self.is_used, disabled_reason=self.disabled_reason, enabled_actions=[ - action for action in NetDevAction - if self.supports_action(action) - ], - hwaddr=getattr(self.info, 'hwaddr', None), - vendor=getattr(self.info, 'vendor', None), - model=getattr(self.info, 'model', None), + action for action in NetDevAction if self.supports_action(action) + ], + hwaddr=getattr(self.info, "hwaddr", None), + vendor=getattr(self.info, "vendor", None), + model=getattr(self.info, "model", None), is_virtual=self.is_virtual, - has_config=self.config is not None) + has_config=self.config is not None, + ) def dhcp_addresses(self): r = {4: [], 6: []} @@ -285,7 +296,7 @@ class NetworkDev(object): v = 6 else: continue - if a.source == 'dhcp': + if a.source == "dhcp": r[v].append(str(a.address)) return r @@ -293,10 +304,10 @@ class NetworkDev(object): if self.config is None: return False else: - return self.config.get('dhcp{v}'.format(v=version), False) + return self.config.get("dhcp{v}".format(v=version), False) def dhcp_state(self, version): - if not self.config.get('dhcp{v}'.format(v=version), False): + if not self.config.get("dhcp{v}".format(v=version), False): return None return self._dhcp_state[version] @@ -316,7 +327,9 @@ class NetworkDev(object): if new_name in self._model.devices_by_name: raise RuntimeError( "renaming {old_name} over {new_name}".format( - old_name=self.name, new_name=new_name)) + old_name=self.name, new_name=new_name + ) + ) self._model.devices_by_name[new_name] = self if self.info is not None: dead_device = NetworkDev(self._model, self.name, self.type) @@ -331,18 +344,18 @@ class NetworkDev(object): @property def configured_ssid(self): - for ssid, settings in self.config.get('access-points', {}).items(): - psk = settings.get('password') + for ssid, settings in self.config.get("access-points", {}).items(): + psk = settings.get("password") return ssid, psk return None, None def set_ssid_psk(self, ssid, psk): - aps = self.config.setdefault('access-points', {}) + aps = self.config.setdefault("access-points", {}) aps.clear() if ssid is not None: aps[ssid] = {} if psk is not None: - aps[ssid]['password'] = psk + aps[ssid]["password"] = psk @property def ifindex(self): @@ -359,7 +372,7 @@ class NetworkDev(object): def is_bond_slave(self): for dev in self._model.get_all_netdevs(): if dev.type == "bond": - if self.name in dev.config.get('interfaces', []): + if self.name in dev.config.get("interfaces", []): return True return False @@ -367,17 +380,20 @@ class NetworkDev(object): def is_used(self): for dev in self._model.get_all_netdevs(): if dev.type == "bond": - if self.name in dev.config.get('interfaces', []): + if self.name in dev.config.get("interfaces", []): return True if dev.type == "vlan": - if self.name == dev.config.get('link'): + if self.name == dev.config.get("link"): return True return False @property def actual_global_ip_addresses(self): - return [addr.ip for _, addr in sorted(self.info.addresses.items()) - if addr.scope == "global"] + return [ + addr.ip + for _, addr in sorted(self.info.addresses.items()) + if addr.scope == "global" + ] _supports_INFO = True _supports_EDIT_WLAN = property(lambda self: self.type == "wlan") @@ -385,29 +401,32 @@ class NetworkDev(object): _supports_EDIT_IPV6 = True _supports_EDIT_BOND = property(lambda self: self.type == "bond") _supports_ADD_VLAN = property( - lambda self: self.type != "vlan" and not self.is_bond_slave) - _supports_DELETE = property( - lambda self: self.is_virtual and not self.is_used) + lambda self: self.type != "vlan" and not self.is_bond_slave + ) + _supports_DELETE = property(lambda self: self.is_virtual and not self.is_used) def remove_ip_networks_for_version(self, version): - self.config.pop('dhcp{v}'.format(v=version), None) + self.config.pop("dhcp{v}".format(v=version), None) self.remove_routes(version) addrs = [] - for ip in self.config.get('addresses', []): + for ip in self.config.get("addresses", []): if addr_version(ip) != version: addrs.append(ip) if addrs: - self.config['addresses'] = addrs + self.config["addresses"] = addrs else: - self.config.pop('addresses', None) + self.config.pop("addresses", None) def remove_routes(self, version): - routes = [route for route in self.config.get('routes', []) - if addr_version(route['via']) != version] + routes = [ + route + for route in self.config.get("routes", []) + if addr_version(route["via"]) != version + ] if routes: - self.config['routes'] = routes + self.config["routes"] = routes else: - self.config.pop('routes', None) + self.config.pop("routes", None) class NetworkModel(object): @@ -425,7 +444,7 @@ class NetworkModel(object): @has_network.setter def has_network(self, val): - log.debug('has_network %s', val) + log.debug("has_network %s", val) self._has_network = val def parse_netplan_configs(self, netplan_root): @@ -435,7 +454,7 @@ class NetworkModel(object): def new_link(self, ifindex, link): log.debug("new_link %s %s %s", ifindex, link.name, link.type) if link.type in NETDEV_IGNORED_IFACE_TYPES: - log.debug('ignoring based on type') + log.debug("ignoring based on type") return is_virtual = link.is_virtual if link.type == "wlan": @@ -443,7 +462,7 @@ class NetworkModel(object): # they are real for testing purposes. is_virtual = False if is_virtual and link.type not in NETDEV_ALLOWED_VIRTUAL_IFACE_TYPES: - log.debug('ignoring based on is_virtual') + log.debug("ignoring based on is_virtual") return dev = self.devices_by_name.get(link.name) if dev is not None: @@ -459,14 +478,17 @@ class NetworkModel(object): if is_virtual and not config: # If we see a virtual device without there already # being a config for it, we just ignore it. - log.debug('ignoring virtual device with no config') + log.debug("ignoring virtual device with no config") return dev = NetworkDev(self, link.name, link.type) dev.info = link dev.config = config - log.debug("new_link %s %s with config %s", - ifindex, link.name, - netplan.sanitize_interface_config(dev.config)) + log.debug( + "new_link %s %s with config %s", + ifindex, + link.name, + netplan.sanitize_interface_config(dev.config), + ) self.devices_by_name[link.name] = dev return dev @@ -493,15 +515,15 @@ class NetworkModel(object): def new_vlan(self, device_name, tag): name = "{name}.{tag}".format(name=device_name, tag=tag) - dev = self.devices_by_name[name] = NetworkDev(self, name, 'vlan') + dev = self.devices_by_name[name] = NetworkDev(self, name, "vlan") dev.config = { - 'link': device_name, - 'id': tag, - } + "link": device_name, + "id": tag, + } return dev def new_bond(self, name, bond_config): - dev = self.devices_by_name[name] = NetworkDev(self, name, 'bond') + dev = self.devices_by_name[name] = NetworkDev(self, name, "bond") dev.config = bond_config.to_config() return dev @@ -515,27 +537,28 @@ class NetworkModel(object): return self.devices_by_name[name] def stringify_config(self, config): - return '\n'.join([ - "# This is the network config written by '{}'".format( - self.project), - yaml.dump(config, default_flow_style=False), - ]) + return "\n".join( + [ + "# This is the network config written by '{}'".format(self.project), + yaml.dump(config, default_flow_style=False), + ] + ) def render_config(self): config = { - 'network': { - 'version': 2, + "network": { + "version": 2, }, } type_to_key = { - 'eth': 'ethernets', - 'bond': 'bonds', - 'wlan': 'wifis', - 'vlan': 'vlans', - } + "eth": "ethernets", + "bond": "bonds", + "wlan": "wifis", + "vlan": "vlans", + } for dev in self.get_all_netdevs(): key = type_to_key[dev.type] - configs = config['network'].setdefault(key, {}) + configs = config["network"].setdefault(key, {}) if dev.config or dev.is_used: configs[dev.name] = dev.config @@ -544,21 +567,23 @@ class NetworkModel(object): def rendered_config_paths(self): """Return a list of file paths rendered by this model.""" return [ - '/' + write_file['path'] - for write_file in self.render().get('write_files').values() + "/" + write_file["path"] + for write_file in self.render().get("write_files").values() ] def render(self): return { - 'write_files': { - 'etc_netplan_installer': { - 'path': 'etc/netplan/00-installer-config.yaml', - 'content': self.stringify_config(self.render_config()), - }, - 'nonet': { - 'path': ('etc/cloud/cloud.cfg.d/' - 'subiquity-disable-cloudinit-networking.cfg'), - 'content': 'network: {config: disabled}\n', - }, + "write_files": { + "etc_netplan_installer": { + "path": "etc/netplan/00-installer-config.yaml", + "content": self.stringify_config(self.render_config()), }, - } + "nonet": { + "path": ( + "etc/cloud/cloud.cfg.d/" + "subiquity-disable-cloudinit-networking.cfg" + ), + "content": "network: {config: disabled}\n", + }, + }, + } diff --git a/subiquitycore/models/tests/test_network.py b/subiquitycore/models/tests/test_network.py index ea1fde39..ce99f64b 100644 --- a/subiquitycore/models/tests/test_network.py +++ b/subiquitycore/models/tests/test_network.py @@ -21,21 +21,21 @@ class TestRouteManagement(SubiTestCase): def setUp(self): self.nd = NetworkDev(None, None, None) self.ipv4s = [ - {'to': 'default', 'via': '10.0.2.2'}, - {'to': '1.2.3.0/24', 'via': '1.2.3.4'}, + {"to": "default", "via": "10.0.2.2"}, + {"to": "1.2.3.0/24", "via": "1.2.3.4"}, ] self.ipv6s = [ - {'to': 'default', 'via': '1111::2222'}, - {'to': '3333::0/64', 'via': '3333::4444'}, + {"to": "default", "via": "1111::2222"}, + {"to": "3333::0/64", "via": "3333::4444"}, ] - self.nd.config = {'routes': self.ipv4s + self.ipv6s} + self.nd.config = {"routes": self.ipv4s + self.ipv6s} def test_remove_v4(self): self.nd.remove_routes(4) expected = self.ipv6s - self.assertEqual(expected, self.nd.config['routes']) + self.assertEqual(expected, self.nd.config["routes"]) def test_remove_v6(self): self.nd.remove_routes(6) expected = self.ipv4s - self.assertEqual(expected, self.nd.config['routes']) + self.assertEqual(expected, self.nd.config["routes"]) diff --git a/subiquitycore/netplan.py b/subiquitycore/netplan.py index ed42a07b..236ace22 100644 --- a/subiquitycore/netplan.py +++ b/subiquitycore/netplan.py @@ -1,17 +1,18 @@ import copy -import glob import fnmatch -import os +import glob import logging +import os + import yaml log = logging.getLogger("subiquitycore.netplan") def _sanitize_inteface_config(iface_config): - for ap, ap_config in iface_config.get('access-points', {}).items(): - if 'password' in ap_config: - ap_config['password'] = '' + for ap, ap_config in iface_config.get("access-points", {}).items(): + if "password" in ap_config: + ap_config["password"] = "" def sanitize_interface_config(iface_config): @@ -23,7 +24,7 @@ def sanitize_interface_config(iface_config): def sanitize_config(config): """Return a copy of config with passwords redacted.""" config = copy.deepcopy(config) - interfaces = config.get('network', {}).get('wifis', {}).items() + interfaces = config.get("network", {}).get("wifis", {}).items() for iface, iface_config in interfaces: _sanitize_inteface_config(iface_config) return config @@ -48,7 +49,7 @@ class Config: except yaml.ReaderError as e: log.info("could not parse config: %s", e) return - network = config.get('network') + network = config.get("network") if network is None: log.info("no 'network' key in config") return @@ -56,10 +57,10 @@ class Config: if version != 2: log.info("network has no/unexpected version %s", version) return - for phys_key in 'ethernets', 'wifis': + for phys_key in "ethernets", "wifis": for dev, dev_config in network.get(phys_key, {}).items(): self.physical_devices.append(_PhysicalDevice(dev, dev_config)) - for virt_key in 'bonds', 'vlans': + for virt_key in "bonds", "vlans": for dev, dev_config in network.get(virt_key, {}).items(): self.virtual_devices.append(_VirtualDevice(dev, dev_config)) @@ -69,14 +70,17 @@ class Config: if dev.name == link.name: return copy.deepcopy(dev.config) else: - allowed_matches = ('macaddress',) - match_key = 'match' + allowed_matches = ("macaddress",) + match_key = "match" for dev in self.physical_devices: if dev.matches_link(link): config = copy.deepcopy(dev.config) if match_key in config: - match = {k: v for k, v in config[match_key].items() - if k in allowed_matches} + match = { + k: v + for k, v in config[match_key].items() + if k in allowed_matches + } if match: config[match_key] = match else: @@ -96,19 +100,17 @@ class Config: class _PhysicalDevice: def __init__(self, name, config): - match = config.get('match') + match = config.get("match") if match is None: self.match_name = name self.match_mac = None self.match_driver = None else: - self.match_name = match.get('name') - self.match_mac = match.get('macaddress') - self.match_driver = match.get('driver') + self.match_name = match.get("name") + self.match_mac = match.get("macaddress") + self.match_driver = match.get("driver") self.config = config - log.debug( - "config for %s = %s" % ( - name, sanitize_interface_config(self.config))) + log.debug("config for %s = %s" % (name, sanitize_interface_config(self.config))) def matches_link(self, link): if self.match_name is not None: @@ -130,9 +132,7 @@ class _VirtualDevice: def __init__(self, name, config): self.name = name self.config = config - log.debug( - "config for %s = %s" % ( - name, sanitize_interface_config(self.config))) + log.debug("config for %s = %s" % (name, sanitize_interface_config(self.config))) def configs_in_root(root, masked=False): @@ -155,7 +155,7 @@ def configs_in_root(root, masked=False): """returned key is basename + string-precidence based on dir.""" bname = os.path.basename(path) bdir = path[rootlen + 1] - bdir = bdir[:bdir.find(os.path.sep)] + bdir = bdir[: bdir.find(os.path.sep)] return "%s/%s" % (bname, bdir) if not masked: diff --git a/subiquitycore/palette.py b/subiquitycore/palette.py index f2ea5b69..490c79c9 100644 --- a/subiquitycore/palette.py +++ b/subiquitycore/palette.py @@ -16,115 +16,98 @@ COLORS = [ # black - ("bg", (0x11, 0x11, 0x11)), + ("bg", (0x11, 0x11, 0x11)), # dark red - ("danger", (0xff, 0x00, 0x00)), + ("danger", (0xFF, 0x00, 0x00)), # dark green - ("good", (0x0e, 0x84, 0x20)), + ("good", (0x0E, 0x84, 0x20)), # brown - ("orange", (0xe9, 0x54, 0x20)), + ("orange", (0xE9, 0x54, 0x20)), # dark blue - ("neutral", (0x00, 0x7a, 0xa6)), + ("neutral", (0x00, 0x7A, 0xA6)), # dark magenta - ("brand", (0x33, 0x33, 0x33)), + ("brand", (0x33, 0x33, 0x33)), # dark cyan - ("gray", (0x80, 0x80, 0x80)), + ("gray", (0x80, 0x80, 0x80)), # light gray - ("fg", (0xff, 0xff, 0xff)), + ("fg", (0xFF, 0xFF, 0xFF)), ] PALETTE_COLOR = [ - ('frame_header_fringe', 'orange', 'bg'), - ('frame_header', 'fg', 'orange'), - ('body', 'fg', 'bg'), - - ('done_button', 'fg', 'bg'), - ('danger_button', 'fg', 'bg'), - ('other_button', 'fg', 'bg'), - ('done_button focus', 'fg', 'good'), - ('danger_button focus', 'fg', 'danger'), - ('other_button focus', 'fg', 'gray'), - - ('menu_button', 'fg', 'bg'), - ('menu_button focus', 'fg', 'gray'), - - ('frame_button', 'fg', 'orange'), - ('frame_button focus', 'orange', 'fg'), - - ('info_primary', 'fg', 'bg'), - ('info_minor', 'gray', 'bg'), - ('info_minor header', 'gray', 'orange'), - ('info_error', 'danger', 'bg'), - - ('string_input', 'bg', 'fg'), - ('string_input focus', 'fg', 'gray'), - - ('progress_incomplete', 'fg', 'gray'), - ('progress_complete', 'fg', 'neutral'), - ('scrollbar', 'brand', 'bg'), - ('scrollbar focus', 'gray', 'bg'), - - ('verified', 'good', 'bg'), - ('verified focus', 'good', 'gray'), - ('verified invisible', 'bg', 'bg'), - ('verified inv focus', 'gray', 'gray'), - - ('starred', 'orange', 'bg'), - ('starred focus', 'orange', 'gray'), - - ('user_code', 'fg', 'good'), + ("frame_header_fringe", "orange", "bg"), + ("frame_header", "fg", "orange"), + ("body", "fg", "bg"), + ("done_button", "fg", "bg"), + ("danger_button", "fg", "bg"), + ("other_button", "fg", "bg"), + ("done_button focus", "fg", "good"), + ("danger_button focus", "fg", "danger"), + ("other_button focus", "fg", "gray"), + ("menu_button", "fg", "bg"), + ("menu_button focus", "fg", "gray"), + ("frame_button", "fg", "orange"), + ("frame_button focus", "orange", "fg"), + ("info_primary", "fg", "bg"), + ("info_minor", "gray", "bg"), + ("info_minor header", "gray", "orange"), + ("info_error", "danger", "bg"), + ("string_input", "bg", "fg"), + ("string_input focus", "fg", "gray"), + ("progress_incomplete", "fg", "gray"), + ("progress_complete", "fg", "neutral"), + ("scrollbar", "brand", "bg"), + ("scrollbar focus", "gray", "bg"), + ("verified", "good", "bg"), + ("verified focus", "good", "gray"), + ("verified invisible", "bg", "bg"), + ("verified inv focus", "gray", "gray"), + ("starred", "orange", "bg"), + ("starred focus", "orange", "gray"), + ("user_code", "fg", "good"), ] PALETTE_MONO = [ - ('frame_header_fringe', 'white', 'black'), - ('frame_header', 'black', 'white'), - ('body', 'white', 'black'), - - ('done_button', 'white', 'black'), - ('danger_button', 'white', 'black'), - ('other_button', 'white', 'black'), - ('done_button focus', 'black', 'white'), - ('danger_button focus', 'black', 'white'), - ('other_button focus', 'black', 'white'), - - ('menu_button', 'white', 'black'), - ('menu_button focus', 'black', 'white'), - ('frame_button', 'black', 'white'), - ('frame_button focus', 'white', 'black'), - - ('info_primary', 'white', 'black'), - ('info_minor', 'white', 'black'), - ('info_minor header', 'black', 'white'), - ('info_error', 'white', 'black'), - - ('string_input', 'white', 'black'), - ('string_input focus', 'black', 'white'), - - ('progress_incomplete', 'white', 'black'), - ('progress_complete', 'black', 'white'), - ('scrollbar_fg', 'white', 'black'), - ('scrollbar_bg', 'white', 'black'), - - ('verified', 'white', 'black'), - ('verified focus', 'black', 'white'), - ('verified invisible', 'white', 'black'), - ('verified inv focus', 'black', 'white'), - - ('starred', 'white', 'black'), - ('starred focus', 'black', 'white'), - - ('user_code', 'white', 'black'), + ("frame_header_fringe", "white", "black"), + ("frame_header", "black", "white"), + ("body", "white", "black"), + ("done_button", "white", "black"), + ("danger_button", "white", "black"), + ("other_button", "white", "black"), + ("done_button focus", "black", "white"), + ("danger_button focus", "black", "white"), + ("other_button focus", "black", "white"), + ("menu_button", "white", "black"), + ("menu_button focus", "black", "white"), + ("frame_button", "black", "white"), + ("frame_button focus", "white", "black"), + ("info_primary", "white", "black"), + ("info_minor", "white", "black"), + ("info_minor header", "black", "white"), + ("info_error", "white", "black"), + ("string_input", "white", "black"), + ("string_input focus", "black", "white"), + ("progress_incomplete", "white", "black"), + ("progress_complete", "black", "white"), + ("scrollbar_fg", "white", "black"), + ("scrollbar_bg", "white", "black"), + ("verified", "white", "black"), + ("verified focus", "black", "white"), + ("verified invisible", "white", "black"), + ("verified inv focus", "black", "white"), + ("starred", "white", "black"), + ("starred focus", "black", "white"), + ("user_code", "white", "black"), ] urwid_8_names = ( - 'black', - 'dark red', - 'dark green', - 'brown', - 'dark blue', - 'dark magenta', - 'dark cyan', - 'light gray', + "black", + "dark red", + "dark green", + "brown", + "dark blue", + "dark magenta", + "dark cyan", + "light gray", ) @@ -142,8 +125,7 @@ def _urwidize_palette(colors, styles): # an urwid palette by mapping the names in colors to the standard # name. if len(colors) != 8: - raise Exception( - "make_palette must be passed a list of exactly 8 colors") + raise Exception("make_palette must be passed a list of exactly 8 colors") urwid_name = dict(zip([c[0] for c in colors], urwid_8_names)) urwid_palette = [] diff --git a/subiquitycore/prober.py b/subiquitycore/prober.py index 31e118bb..61fe7f89 100644 --- a/subiquitycore/prober.py +++ b/subiquitycore/prober.py @@ -15,45 +15,41 @@ import asyncio import logging + import yaml +from probert.network import StoredDataObserver, UdevObserver from subiquitycore.async_helpers import run_in_thread -from probert.network import ( - StoredDataObserver, - UdevObserver, - ) - -log = logging.getLogger('subiquitycore.prober') +log = logging.getLogger("subiquitycore.prober") -class Prober(): +class Prober: def __init__(self, machine_config, debug_flags): self.saved_config = None if machine_config: self.saved_config = yaml.safe_load(machine_config) self.debug_flags = debug_flags - log.debug('Prober() init finished, data:{}'.format(self.saved_config)) + log.debug("Prober() init finished, data:{}".format(self.saved_config)) def probe_network(self, receiver): if self.saved_config is not None: - observer = StoredDataObserver( - self.saved_config['network'], receiver) + observer = StoredDataObserver(self.saved_config["network"], receiver) else: observer = UdevObserver(receiver) return observer, observer.start() async def get_storage(self, probe_types=None): if self.saved_config is not None: - flag = 'bpfail-full' + flag = "bpfail-full" if probe_types is not None: - flag = 'bpfail-restricted' + flag = "bpfail-restricted" if flag in self.debug_flags: await asyncio.sleep(2) - 1/0 - r = self.saved_config['storage'].copy() - if probe_types is not None and 'defaults' not in probe_types: - for k in self.saved_config['storage']: + 1 / 0 + r = self.saved_config["storage"].copy() + if probe_types is not None and "defaults" not in probe_types: + for k in self.saved_config["storage"]: if k not in probe_types: r[k] = {} return r @@ -63,7 +59,8 @@ class Prober(): # Until probert is completely free of blocking IO, we should continue # running it in a separate thread. def run_probert(probe_types): - return asyncio.run(Storage().probe(probe_types=probe_types, - parallelize=True)) + return asyncio.run( + Storage().probe(probe_types=probe_types, parallelize=True) + ) return await run_in_thread(run_probert, probe_types) diff --git a/subiquitycore/pubsub.py b/subiquitycore/pubsub.py index df40cb76..5c861c55 100644 --- a/subiquitycore/pubsub.py +++ b/subiquitycore/pubsub.py @@ -18,11 +18,10 @@ import inspect class CoreChannels: - NETWORK_UP = 'network-up' + NETWORK_UP = "network-up" class MessageHub: - def __init__(self): self.subscriptions = {} diff --git a/subiquitycore/screen.py b/subiquitycore/screen.py index 43672cf2..a5c3a15a 100644 --- a/subiquitycore/screen.py +++ b/subiquitycore/screen.py @@ -23,7 +23,7 @@ import urwid from subiquitycore.palette import COLORS, urwid_8_names -log = logging.getLogger('subiquitycore.screen') +log = logging.getLogger("subiquitycore.screen") # From uapi/linux/kd.h: @@ -31,11 +31,10 @@ KDGKBTYPE = 0x4B33 # get keyboard type GIO_CMAP = 0x4B70 # gets colour palette on VGA+ PIO_CMAP = 0x4B71 # sets colour palette on VGA+ -UO_R, UO_G, UO_B = 0xe9, 0x54, 0x20 +UO_R, UO_G, UO_B = 0xE9, 0x54, 0x20 class SubiquityScreen(urwid.raw_display.Screen): - # This class fixes a bug in urwid's screen: # # Calling screen.stop() sends the INPUT_DESCRIPTORS_CHANGED signal. This @@ -68,18 +67,17 @@ class SubiquityScreen(urwid.raw_display.Screen): class LinuxScreen(SubiquityScreen): - def __init__(self, colors, **kwargs): self._colors = colors super().__init__(**kwargs) def start(self): - self.curpal = bytearray(16*3) + self.curpal = bytearray(16 * 3) fcntl.ioctl(sys.stdout.fileno(), GIO_CMAP, self.curpal) newpal = self.curpal.copy() for i in range(8): for j in range(3): - newpal[i*3+j] = self._colors[i][1][j] + newpal[i * 3 + j] = self._colors[i][1][j] fcntl.ioctl(self._term_input_file.fileno(), PIO_CMAP, newpal) super().start() @@ -89,10 +87,8 @@ class LinuxScreen(SubiquityScreen): class TwentyFourBitScreen(SubiquityScreen): - def __init__(self, colors, **kwargs): - self._urwid_name_to_rgb = { - n: colors[i][1] for i, n in enumerate(urwid_8_names)} + self._urwid_name_to_rgb = {n: colors[i][1] for i, n in enumerate(urwid_8_names)} super().__init__(**kwargs) def _cc(self, color): @@ -103,22 +99,20 @@ class TwentyFourBitScreen(SubiquityScreen): maximum compatibility; they are the only colors used when the mono palette is selected. """ - if color == 'white': - return '7' - elif color == 'black': - return '0' - elif color == 'default': - return '9' + if color == "white": + return "7" + elif color == "black": + return "0" + elif color == "default": + return "9" else: # This is almost but not quite a ISO 8613-3 code -- that # would use colons to separate the rgb values instead. But # it's what xterm, and hence everything else, supports. - return '8;2;{};{};{}'.format(*self._urwid_name_to_rgb[color]) + return "8;2;{};{};{}".format(*self._urwid_name_to_rgb[color]) def _attrspec_to_escape(self, a): - return '\x1b[0;3{};4{}m'.format( - self._cc(a.foreground), - self._cc(a.background)) + return "\x1b[0;3{};4{}m".format(self._cc(a.foreground), self._cc(a.background)) _is_linux_tty = None @@ -128,12 +122,12 @@ def is_linux_tty(): global _is_linux_tty if _is_linux_tty is None: try: - r = fcntl.ioctl(sys.stdout.fileno(), KDGKBTYPE, ' ') + r = fcntl.ioctl(sys.stdout.fileno(), KDGKBTYPE, " ") except IOError as e: log.debug("KDGKBTYPE failed %r", e) return False - log.debug("KDGKBTYPE returned %r, is_linux_tty %s", r, r == b'\x02') - _is_linux_tty = r == b'\x02' + log.debug("KDGKBTYPE returned %r, is_linux_tty %s", r, r == b"\x02") + _is_linux_tty = r == b"\x02" return _is_linux_tty diff --git a/subiquitycore/snapd.py b/subiquitycore/snapd.py index 05cbb156..afcd89f4 100644 --- a/subiquitycore/snapd.py +++ b/subiquitycore/snapd.py @@ -14,24 +14,20 @@ # along with this program. If not, see . import asyncio -from functools import partial import glob import json import logging import os import time -from urllib.parse import ( - quote_plus, - urlencode, - ) +from functools import partial +from urllib.parse import quote_plus, urlencode + +import requests_unixsocket from subiquitycore.async_helpers import run_in_thread from subiquitycore.utils import run_command -import requests_unixsocket - - -log = logging.getLogger('subiquitycore.snapd') +log = logging.getLogger("subiquitycore.snapd") # Every method in this module blocks. Do not call them from the main thread! @@ -43,38 +39,34 @@ class SnapdConnection: def get(self, path, **args): if args: - path += '?' + urlencode(args) + path += "?" + urlencode(args) with requests_unixsocket.Session() as session: return session.get(self.url_base + path, timeout=60) def post(self, path, body, **args): if args: - path += '?' + urlencode(args) + path += "?" + urlencode(args) with requests_unixsocket.Session() as session: - return session.post( - self.url_base + path, data=json.dumps(body), - timeout=60) + return session.post(self.url_base + path, data=json.dumps(body), timeout=60) def configure_proxy(self, proxy): log.debug("restarting snapd to pick up proxy config") - dropin_dir = os.path.join( - self.root, 'etc/systemd/system/snapd.service.d') + dropin_dir = os.path.join(self.root, "etc/systemd/system/snapd.service.d") os.makedirs(dropin_dir, exist_ok=True) - with open(os.path.join(dropin_dir, 'snap_proxy.conf'), 'w') as fp: + with open(os.path.join(dropin_dir, "snap_proxy.conf"), "w") as fp: fp.write(proxy.proxy_systemd_dropin()) - if self.root == '/': + if self.root == "/": cmds = [ - ['systemctl', 'daemon-reload'], - ['systemctl', 'restart', 'snapd.service'], - ] + ["systemctl", "daemon-reload"], + ["systemctl", "restart", "snapd.service"], + ] else: - cmds = [['sleep', '2']] + cmds = [["sleep", "2"]] for cmd in cmds: run_command(cmd) class _FakeFileResponse: - def __init__(self, path): self.path = path @@ -87,7 +79,6 @@ class _FakeFileResponse: class _FakeMemoryResponse: - def __init__(self, data): self.data = data @@ -128,59 +119,65 @@ class FakeSnapdConnection: def configure_proxy(self, proxy): log.debug("pretending to restart snapd to pick up proxy config") - time.sleep(2/self.scale_factor) + time.sleep(2 / self.scale_factor) def post(self, path, body, **args): - if path == "v2/snaps/subiquity" and body['action'] == 'refresh': + if path == "v2/snaps/subiquity" and body["action"] == "refresh": # The post-refresh hook does this in the real world. - update_marker_file = self.output_base + '/run/subiquity/updating' - open(update_marker_file, 'w').close() - return _FakeMemoryResponse({ - "type": "async", - "change": "7", - "status-code": 200, - "status": "OK", - }) + update_marker_file = self.output_base + "/run/subiquity/updating" + open(update_marker_file, "w").close() + return _FakeMemoryResponse( + { + "type": "async", + "change": "7", + "status-code": 200, + "status": "OK", + } + ) change = None - if path == "v2/snaps/subiquity" and body['action'] == 'switch': + if path == "v2/snaps/subiquity" and body["action"] == "switch": change = "8" - if path.startswith('v2/systems/') and body['action'] == 'install': - system = path.split('/')[2] - step = body['step'] - if step == 'finish': - if system == 'finish-fail': + if path.startswith("v2/systems/") and body["action"] == "install": + system = path.split("/")[2] + step = body["step"] + if step == "finish": + if system == "finish-fail": change = "15" else: change = "5" - elif step == 'setup-storage-encryption': + elif step == "setup-storage-encryption": change = "6" if change is not None: - return _FakeMemoryResponse({ - "type": "async", - "change": change, - "status-code": 200, - "status": "Accepted", - }) + return _FakeMemoryResponse( + { + "type": "async", + "change": change, + "status-code": 200, + "status": "Accepted", + } + ) raise Exception( - "Don't know how to fake POST response to {}".format((path, args))) + "Don't know how to fake POST response to {}".format((path, args)) + ) def get(self, path, **args): - if 'change' not in path: - time.sleep(1/self.scale_factor) - filename = path.replace('/', '-') + if "change" not in path: + time.sleep(1 / self.scale_factor) + filename = path.replace("/", "-") if args: - filename += '-' + urlencode(sorted(args.items())) + filename += "-" + urlencode(sorted(args.items())) if filename in self.response_sets: return self.response_sets[filename].next() filepath = os.path.join(self.snap_data_dir, filename) - if os.path.exists(filepath + '.json'): - return _FakeFileResponse(filepath + '.json') + if os.path.exists(filepath + ".json"): + return _FakeFileResponse(filepath + ".json") if os.path.isdir(filepath): - files = sorted(glob.glob(os.path.join(filepath, '*.json'))) + files = sorted(glob.glob(os.path.join(filepath, "*.json"))) rs = self.response_sets[filename] = ResponseSet(files) return rs.next() raise Exception( - "Don't know how to fake GET response to {}".format((path, args))) + "Don't know how to fake GET response to {}".format((path, args)) + ) def get_fake_connection(scale_factor=1000, output_base=None): @@ -188,30 +185,29 @@ def get_fake_connection(scale_factor=1000, output_base=None): if output_base is None: output_base = os.path.join(proj_dir, ".subiquity") return FakeSnapdConnection( - os.path.join(proj_dir, "examples", "snaps"), - scale_factor, output_base) + os.path.join(proj_dir, "examples", "snaps"), scale_factor, output_base + ) class AsyncSnapd: - def __init__(self, connection): self.connection = connection async def get(self, path, **args): - response = await run_in_thread( - partial(self.connection.get, path, **args)) + response = await run_in_thread(partial(self.connection.get, path, **args)) response.raise_for_status() return response.json() async def post(self, path, body, **args): response = await run_in_thread( - partial(self.connection.post, path, body, **args)) + partial(self.connection.post, path, body, **args) + ) response.raise_for_status() return response.json() async def post_and_wait(self, path, body, **args): - change = (await self.post(path, body, **args))['change'] - change_path = 'v2/changes/{}'.format(change) + change = (await self.post(path, body, **args))["change"] + change_path = "v2/changes/{}".format(change) while True: result = await self.get(change_path) if result["result"]["status"] == "Done": diff --git a/subiquitycore/ssh.py b/subiquitycore/ssh.py index 560e3cb5..75e7197f 100644 --- a/subiquitycore/ssh.py +++ b/subiquitycore/ssh.py @@ -19,7 +19,6 @@ import pwd from subiquitycore.utils import run_command - log = logging.getLogger("subiquitycore.ssh") @@ -29,7 +28,7 @@ def host_key_fingerprints(): Returns a sequence of (key-type, fingerprint) pairs. """ try: - config = run_command(['sshd', '-T']) + config = run_command(["sshd", "-T"]) except FileNotFoundError: log.debug("sshd not found") return [] @@ -38,7 +37,7 @@ def host_key_fingerprints(): return [] keyfiles = [] for line in config.stdout.splitlines(): - if line.startswith('hostkey '): + if line.startswith("hostkey "): keyfiles.append(line.split(None, 1)[1]) info = [] for keyfile in keyfiles: @@ -48,29 +47,33 @@ def host_key_fingerprints(): def fingerprints(keyfile): info = [] - cp = run_command(['ssh-keygen', '-lf', keyfile]) + cp = run_command(["ssh-keygen", "-lf", keyfile]) if cp.returncode != 0: log.debug("ssh-keygen -lf %s failed %r", keyfile, cp.stderr) return info for line in cp.stdout.splitlines(): - parts = line.strip().replace('\r', '').split() + parts = line.strip().replace("\r", "").split() fingerprint = parts[1] - keytype = parts[-1].strip('()') + keytype = parts[-1].strip("()") info.append((keytype, fingerprint)) return info -host_keys_intro = _("""\ +host_keys_intro = _( + """\ The host key fingerprints are: -""") +""" +) host_key_tmpl = """ {keytype:{width}} {fingerprint}""" -single_host_key_tmpl = _("""\ +single_host_key_tmpl = _( + """\ The {keytype} host key fingerprint is: {fingerprint} -""") +""" +) def host_key_info(): @@ -79,17 +82,18 @@ def host_key_info(): def summarize_host_keys(fingerprints): if len(fingerprints) == 0: - return '' + return "" if len(fingerprints) == 1: [(keytype, fingerprint)] = fingerprints - return _(single_host_key_tmpl).format(keytype=keytype, - fingerprint=fingerprint) + return _(single_host_key_tmpl).format(keytype=keytype, fingerprint=fingerprint) lines = [_(host_keys_intro)] longest_type = max([len(keytype) for keytype, _ in fingerprints]) for keytype, fingerprint in fingerprints: - lines.append(host_key_tmpl.format(keytype=keytype, - fingerprint=fingerprint, - width=longest_type)) + lines.append( + host_key_tmpl.format( + keytype=keytype, fingerprint=fingerprint, width=longest_type + ) + ) return "".join(lines) @@ -99,7 +103,7 @@ def user_key_fingerprints(username): except KeyError: log.exception("getpwnam(%s) failed", username) return [] - user_key_file = '{}/.ssh/authorized_keys'.format(user_info.pw_dir) + user_key_file = "{}/.ssh/authorized_keys".format(user_info.pw_dir) if os.path.exists(user_key_file): return fingerprints(user_key_file) else: @@ -108,15 +112,17 @@ def user_key_fingerprints(username): def get_ips_standalone(): from probert.prober import Prober + from subiquitycore.models.network import NETDEV_IGNORED_IFACE_TYPES + prober = Prober() prober.probe_network() - links = prober.get_results()['network']['links'] + links = prober.get_results()["network"]["links"] ips = [] - for link in sorted(links, key=lambda link: link['netlink_data']['name']): - if link['type'] in NETDEV_IGNORED_IFACE_TYPES: + for link in sorted(links, key=lambda link: link["netlink_data"]["name"]): + if link["type"] in NETDEV_IGNORED_IFACE_TYPES: continue - for addr in link['addresses']: - if addr['scope'] == "global": - ips.append(addr['address'].split('/')[0]) + for addr in link["addresses"]: + if addr["scope"] == "global": + ips.append(addr["address"].split("/")[0]) return ips diff --git a/subiquitycore/testing/view_helpers.py b/subiquitycore/testing/view_helpers.py index ae66aab4..1e87c0ff 100644 --- a/subiquitycore/testing/view_helpers.py +++ b/subiquitycore/testing/view_helpers.py @@ -1,4 +1,5 @@ import re + import urwid from subiquitycore.ui.stretchy import StretchyOverlay @@ -7,13 +8,12 @@ from subiquitycore.ui.stretchy import StretchyOverlay def find_with_pred(w, pred, return_path=False): def _walk(w, path): if not isinstance(w, urwid.Widget): - raise RuntimeError( - "_walk walked to non-widget %r via %r" % (w, path)) + raise RuntimeError("_walk walked to non-widget %r via %r" % (w, path)) if pred(w): return w, path - if hasattr(w, '_wrapped_widget'): + if hasattr(w, "_wrapped_widget"): return _walk(w._wrapped_widget, (w,) + path) - if hasattr(w, 'original_widget'): + if hasattr(w, "original_widget"): return _walk(w.original_widget, (w,) + path) if isinstance(w, urwid.ListBox): for w in w.body: @@ -25,7 +25,7 @@ def find_with_pred(w, pred, return_path=False): r, p = _walk(w, (w,) + path) if r: return r, p - elif hasattr(w, 'contents'): + elif hasattr(w, "contents"): contents = w.contents for w, _ in contents: r, p = _walk(w, (w,) + path) @@ -36,6 +36,7 @@ def find_with_pred(w, pred, return_path=False): if r: return r, p return None, None + r, p = _walk(w, ()) if return_path: return r, p @@ -46,11 +47,12 @@ def find_with_pred(w, pred, return_path=False): def find_button_matching(w, pat, return_path=False): def pred(w): return isinstance(w, urwid.Button) and re.match(pat, w.label) + return find_with_pred(w, pred, return_path) def click(but): - but._emit('click') + but._emit("click") def keypress(w, key, size=(30, 1)): @@ -61,8 +63,7 @@ def get_focus_path(w): path = [] while True: path.append(w) - if isinstance(w, urwid.ListBox) and (w.set_focus_pending == - "first selectable"): + if isinstance(w, urwid.ListBox) and (w.set_focus_pending == "first selectable"): for w2 in w.body: if w2.selectable(): w = w2 @@ -71,9 +72,9 @@ def get_focus_path(w): break if w.focus is not None: w = w.focus - elif hasattr(w, '_wrapped_widget'): + elif hasattr(w, "_wrapped_widget"): w = w._wrapped_widget - elif hasattr(w, 'original_widget'): + elif hasattr(w, "original_widget"): w = w.original_widget else: break diff --git a/subiquitycore/tests/__init__.py b/subiquitycore/tests/__init__.py index 4d6a5870..44eaca6e 100644 --- a/subiquitycore/tests/__init__.py +++ b/subiquitycore/tests/__init__.py @@ -2,7 +2,6 @@ import functools import os import shutil import tempfile - from unittest import IsolatedAsyncioTestCase @@ -10,8 +9,7 @@ class SubiTestCase(IsolatedAsyncioTestCase): def tmp_dir(self, dir=None, cleanup=True): # return a full path to a temporary directory that will be cleaned up. if dir is None: - tmpd = tempfile.mkdtemp( - prefix="subiquity-%s." % self.__class__.__name__) + tmpd = tempfile.mkdtemp(prefix="subiquity-%s." % self.__class__.__name__) else: tmpd = tempfile.mkdtemp(dir=dir) self.addCleanup(functools.partial(shutil.rmtree, tmpd)) @@ -26,7 +24,7 @@ class SubiTestCase(IsolatedAsyncioTestCase): return os.path.normpath(os.path.abspath(os.path.join(dir, path))) def assert_contents(self, path, expected_contents): - with open(path, 'r') as fp: + with open(path, "r") as fp: self.assertEqual(expected_contents, fp.read()) @@ -34,7 +32,7 @@ def populate_dir(path, files): if not os.path.exists(path): os.makedirs(path) ret = [] - for (name, content) in files.items(): + for name, content in files.items(): p = os.path.sep.join([path, name]) if not os.path.isdir(os.path.dirname(p)): os.makedirs(os.path.dirname(p)) @@ -42,7 +40,7 @@ def populate_dir(path, files): if isinstance(content, bytes): fp.write(content) else: - fp.write(content.encode('utf-8')) + fp.write(content.encode("utf-8")) fp.close() ret.append(p) return ret diff --git a/subiquitycore/tests/parameterized.py b/subiquitycore/tests/parameterized.py index d7258b27..8a0d9b3e 100644 --- a/subiquitycore/tests/parameterized.py +++ b/subiquitycore/tests/parameterized.py @@ -56,10 +56,13 @@ class parameterized(_parameterized): @classmethod def param_as_standalone_func(cls, p, func, name): if inspect.iscoroutinefunction(func): + @wraps(func) async def standalone_func(*a): return await func(*(a + p.args), **p.kwargs) + else: + @wraps(func) def standalone_func(*a): return func(*(a + p.args), **p.kwargs) diff --git a/subiquitycore/tests/test_async_helpers.py b/subiquitycore/tests/test_async_helpers.py index 392d6954..1a1c701e 100644 --- a/subiquitycore/tests/test_async_helpers.py +++ b/subiquitycore/tests/test_async_helpers.py @@ -17,11 +17,7 @@ import asyncio import unittest from unittest.mock import AsyncMock -from subiquitycore.async_helpers import ( - SingleInstanceTask, - TaskAlreadyRunningError, -) - +from subiquitycore.async_helpers import SingleInstanceTask, TaskAlreadyRunningError from subiquitycore.tests.parameterized import parameterized @@ -30,12 +26,12 @@ class TestSingleInstanceTask(unittest.IsolatedAsyncioTestCase): async def test_cancellable(self, cancel_restart, expected_call_count): async def fn(): await asyncio.sleep(3) - raise Exception('timeout') + raise Exception("timeout") mock_fn = AsyncMock(side_effect=fn) sit = SingleInstanceTask(mock_fn, cancel_restart=cancel_restart) await sit.start() - await asyncio.sleep(.01) + await asyncio.sleep(0.01) try: await sit.start() except TaskAlreadyRunningError: @@ -53,6 +49,7 @@ class TestSITWait(unittest.IsolatedAsyncioTestCase): async def test_wait_started(self): async def fn(): pass + sit = SingleInstanceTask(fn) await sit.start() await asyncio.wait_for(sit.wait(), timeout=1.0) @@ -60,7 +57,8 @@ class TestSITWait(unittest.IsolatedAsyncioTestCase): async def test_wait_not_started(self): async def fn(): - self.fail('not supposed to be called') + self.fail("not supposed to be called") + sit = SingleInstanceTask(fn) self.assertFalse(sit.done()) with self.assertRaises(asyncio.TimeoutError): diff --git a/subiquitycore/tests/test_file_util.py b/subiquitycore/tests/test_file_util.py index ae0a5b43..ea0b4849 100644 --- a/subiquitycore/tests/test_file_util.py +++ b/subiquitycore/tests/test_file_util.py @@ -19,13 +19,13 @@ from subiquitycore.tests import SubiTestCase class TestCopy(SubiTestCase): def test_copied_to_non_exist_dir(self): - data = 'stuff things' - src = self.tmp_path('src') - tgt = self.tmp_path('create-me/target') - with open(src, 'w') as fp: + data = "stuff things" + src = self.tmp_path("src") + tgt = self.tmp_path("create-me/target") + with open(src, "w") as fp: fp.write(data) copy_file_if_exists(src, tgt) self.assert_contents(tgt, data) def test_copied_non_exist_src(self): - copy_file_if_exists('/does/not/exist', '/ditto') + copy_file_if_exists("/does/not/exist", "/ditto") diff --git a/subiquitycore/tests/test_lsb_release.py b/subiquitycore/tests/test_lsb_release.py index 933c37e1..a8d19802 100644 --- a/subiquitycore/tests/test_lsb_release.py +++ b/subiquitycore/tests/test_lsb_release.py @@ -14,7 +14,7 @@ # along with this program. If not, see . import unittest -from unittest.mock import patch, mock_open +from unittest.mock import mock_open, patch from subiquitycore.lsb_release import lsb_release diff --git a/subiquitycore/tests/test_netplan.py b/subiquitycore/tests/test_netplan.py index 614db140..cefe9484 100644 --- a/subiquitycore/tests/test_netplan.py +++ b/subiquitycore/tests/test_netplan.py @@ -1,40 +1,45 @@ import os -from subiquitycore.tests import SubiTestCase, populate_dir from subiquitycore.netplan import configs_in_root +from subiquitycore.tests import SubiTestCase, populate_dir class TestConfigsInRoot(SubiTestCase): def test_masked_true(self): """configs_in_root masked=False should not return masked files.""" my_dir = self.tmp_dir() - unmasked = ['run/netplan/00base.yaml', - 'lib/netplan/01system.yaml', 'etc/netplan/99end.yaml'] + unmasked = [ + "run/netplan/00base.yaml", + "lib/netplan/01system.yaml", + "etc/netplan/99end.yaml", + ] masked = ["etc/netplan/00base.yaml"] populate_dir(my_dir, {f: "key: here\n" for f in unmasked + masked}) self.assertEqual( - [os.path.join(my_dir, p) for p in unmasked], - configs_in_root(my_dir)) + [os.path.join(my_dir, p) for p in unmasked], configs_in_root(my_dir) + ) def test_masked_false(self): """configs_in_root mask=True should return all configs.""" my_dir = self.tmp_dir() yamls = [ - 'etc/netplan/00base.yaml', - 'run/netplan/00base.yaml', - 'lib/netplan/01system.yaml', - 'etc/netplan/99end.yaml'] + "etc/netplan/00base.yaml", + "run/netplan/00base.yaml", + "lib/netplan/01system.yaml", + "etc/netplan/99end.yaml", + ] populate_dir(my_dir, {f: "someyaml: here\n" for f in yamls}) self.assertEqual( [os.path.join(my_dir, p) for p in yamls], - configs_in_root(my_dir, masked=True)) + configs_in_root(my_dir, masked=True), + ) def test_only_includes_yaml(self): """configs_in_root should only return *.yaml files.""" my_dir = self.tmp_dir() - yamls = ['etc/netplan/00base.yaml', 'etc/netplan/99end.yaml'] - nonyamls = ['etc/netplan/ignored.yaml.dist', 'run/netplan/my.cfg'] + yamls = ["etc/netplan/00base.yaml", "etc/netplan/99end.yaml"] + nonyamls = ["etc/netplan/ignored.yaml.dist", "run/netplan/my.cfg"] populate_dir(my_dir, {f: "someyaml: here\n" for f in yamls + nonyamls}) self.assertEqual( - [os.path.join(my_dir, p) for p in yamls], - configs_in_root(my_dir)) + [os.path.join(my_dir, p) for p in yamls], configs_in_root(my_dir) + ) diff --git a/subiquitycore/tests/test_prober.py b/subiquitycore/tests/test_prober.py index fbaf89cc..fcf0f57e 100644 --- a/subiquitycore/tests/test_prober.py +++ b/subiquitycore/tests/test_prober.py @@ -13,15 +13,14 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from subiquitycore.tests import SubiTestCase - from subiquitycore.prober import Prober +from subiquitycore.tests import SubiTestCase class TestProber(SubiTestCase): async def test_none_and_defaults_equal(self): - with open('examples/machines/simple.json', 'r') as fp: + with open("examples/machines/simple.json", "r") as fp: prober = Prober(machine_config=fp, debug_flags=()) none_storage = await prober.get_storage(probe_types=None) - defaults_storage = await prober.get_storage(probe_types={'defaults'}) + defaults_storage = await prober.get_storage(probe_types={"defaults"}) self.assertEqual(defaults_storage, none_storage) diff --git a/subiquitycore/tests/test_pubsub.py b/subiquitycore/tests/test_pubsub.py index 3160a36a..6675bd59 100644 --- a/subiquitycore/tests/test_pubsub.py +++ b/subiquitycore/tests/test_pubsub.py @@ -43,7 +43,7 @@ class TestMessageHub(unittest.IsolatedAsyncioTestCase): async def test_message_arg(self): cb = MagicMock() - channel_id = 'test-message-arg' + channel_id = "test-message-arg" self.hub.subscribe(channel_id, cb) - await self.hub.abroadcast(channel_id, '0', 1, 'two', [3], four=4) - cb.assert_called_once_with('0', 1, 'two', [3], four=4) + await self.hub.abroadcast(channel_id, "0", 1, "two", [3], four=4) + cb.assert_called_once_with("0", 1, "two", [3], four=4) diff --git a/subiquitycore/tests/test_utils.py b/subiquitycore/tests/test_utils.py index 108ce6fb..d25c6718 100644 --- a/subiquitycore/tests/test_utils.py +++ b/subiquitycore/tests/test_utils.py @@ -26,45 +26,45 @@ class TestOrigEnviron(SubiTestCase): self.assertEqual(expected, orig_environ(env)) def test_orig_path(self): - env = {'PATH': 'a', 'PATH_ORIG': 'b'} - expected = {'PATH': 'b'} + env = {"PATH": "a", "PATH_ORIG": "b"} + expected = {"PATH": "b"} self.assertEqual(expected, orig_environ(env)) def test_not_this_key(self): - env = {'PATH': 'a', 'PATH_ORIG_AAAAA': 'b'} + env = {"PATH": "a", "PATH_ORIG_AAAAA": "b"} expected = env self.assertEqual(expected, orig_environ(env)) def test_remove_empty_key(self): - env = {'STUFF': 'a', 'STUFF_ORIG': ''} + env = {"STUFF": "a", "STUFF_ORIG": ""} expected = {} self.assertEqual(expected, orig_environ(env)) def test_no_ld_library_path(self): - env = {'LD_LIBRARY_PATH': 'a'} + env = {"LD_LIBRARY_PATH": "a"} expected = {} self.assertEqual(expected, orig_environ(env)) def test_practical(self): - snap = '/snap/subiquity/1234' + snap = "/snap/subiquity/1234" env = { - 'TERM': 'linux', - 'LD_LIBRARY_PATH': '/var/lib/snapd/lib/gl', - 'PYTHONIOENCODING_ORIG': '', - 'PYTHONIOENCODING': 'utf-8', - 'SUBIQUITY_ROOT_ORIG': '', - 'SUBIQUITY_ROOT': snap, - 'PYTHON_ORIG': '', - 'PYTHON': f'{snap}/usr/bin/python3.10', - 'PYTHONPATH_ORIG': '', - 'PYTHONPATH': f'{snap}/stuff/things', - 'PY3OR2_PYTHON_ORIG': '', - 'PY3OR2_PYTHON': f'{snap}/usr/bin/python3.10', - 'PATH_ORIG': '/usr/bin:/bin', - 'PATH': '/usr/bin:/bin:/snap/bin' + "TERM": "linux", + "LD_LIBRARY_PATH": "/var/lib/snapd/lib/gl", + "PYTHONIOENCODING_ORIG": "", + "PYTHONIOENCODING": "utf-8", + "SUBIQUITY_ROOT_ORIG": "", + "SUBIQUITY_ROOT": snap, + "PYTHON_ORIG": "", + "PYTHON": f"{snap}/usr/bin/python3.10", + "PYTHONPATH_ORIG": "", + "PYTHONPATH": f"{snap}/stuff/things", + "PY3OR2_PYTHON_ORIG": "", + "PY3OR2_PYTHON": f"{snap}/usr/bin/python3.10", + "PATH_ORIG": "/usr/bin:/bin", + "PATH": "/usr/bin:/bin:/snap/bin", } expected = { - 'TERM': 'linux', - 'PATH': '/usr/bin:/bin', + "TERM": "linux", + "PATH": "/usr/bin:/bin", } self.assertEqual(expected, orig_environ(env)) diff --git a/subiquitycore/tests/test_view.py b/subiquitycore/tests/test_view.py index cb0e740a..be28dc0d 100644 --- a/subiquitycore/tests/test_view.py +++ b/subiquitycore/tests/test_view.py @@ -16,9 +16,9 @@ import urwid from subiquitycore.tests import SubiTestCase -from subiquitycore.view import BaseView, OverlayNotFoundError from subiquitycore.ui.stretchy import Stretchy, StretchyOverlay from subiquitycore.ui.utils import undisabled +from subiquitycore.view import BaseView, OverlayNotFoundError class InstrumentedStretchy(Stretchy): @@ -35,7 +35,6 @@ class InstrumentedStretchy(Stretchy): class TestBaseView(SubiTestCase): - def get_stretchy_chain(self, view): view = view._w r = [] diff --git a/subiquitycore/tests/util.py b/subiquitycore/tests/util.py index 63cef422..ab2b1a83 100644 --- a/subiquitycore/tests/util.py +++ b/subiquitycore/tests/util.py @@ -18,4 +18,4 @@ import string def random_string(): - return ''.join(random.choice(string.ascii_letters) for _ in range(8)) + return "".join(random.choice(string.ascii_letters) for _ in range(8)) diff --git a/subiquitycore/tui.py b/subiquitycore/tui.py index 3554ba3f..1671377e 100644 --- a/subiquitycore/tui.py +++ b/subiquitycore/tui.py @@ -21,42 +21,38 @@ import signal from typing import Callable, Optional, Union import urwid - import yaml -from subiquitycore.async_helpers import ( - run_bg_task, - schedule_task, - ) +from subiquitycore.async_helpers import run_bg_task, schedule_task from subiquitycore.core import Application -from subiquitycore.palette import ( - PALETTE_COLOR, - PALETTE_MONO, - ) +from subiquitycore.palette import PALETTE_COLOR, PALETTE_MONO from subiquitycore.screen import make_screen from subiquitycore.tuicontroller import Skip -from subiquitycore.ui.utils import LoadingDialog from subiquitycore.ui.frame import SubiquityCoreUI +from subiquitycore.ui.utils import LoadingDialog from subiquitycore.utils import astart_command from subiquitycore.view import BaseView -log = logging.getLogger('subiquitycore.tui') +log = logging.getLogger("subiquitycore.tui") def extend_dec_special_charmap(): - urwid.escape.DEC_SPECIAL_CHARMAP.update({ - ord('\N{BLACK RIGHT-POINTING SMALL TRIANGLE}'): '>', - ord('\N{BLACK LEFT-POINTING SMALL TRIANGLE}'): '<', - ord('\N{BLACK DOWN-POINTING SMALL TRIANGLE}'): 'v', - ord('\N{BLACK UP-POINTING SMALL TRIANGLE}'): '^', - ord('\N{check mark}'): '*', - ord('\N{circled white star}'): '*', - ord('\N{bullet}'): '*', - ord('\N{lower half block}'): '=', - ord('\N{upper half block}'): '=', - ord('\N{FULL BLOCK}'): urwid.escape.DEC_SPECIAL_CHARMAP[ - ord('\N{BOX DRAWINGS LIGHT VERTICAL}')], - }) + urwid.escape.DEC_SPECIAL_CHARMAP.update( + { + ord("\N{BLACK RIGHT-POINTING SMALL TRIANGLE}"): ">", + ord("\N{BLACK LEFT-POINTING SMALL TRIANGLE}"): "<", + ord("\N{BLACK DOWN-POINTING SMALL TRIANGLE}"): "v", + ord("\N{BLACK UP-POINTING SMALL TRIANGLE}"): "^", + ord("\N{check mark}"): "*", + ord("\N{circled white star}"): "*", + ord("\N{bullet}"): "*", + ord("\N{lower half block}"): "=", + ord("\N{upper half block}"): "=", + ord("\N{FULL BLOCK}"): urwid.escape.DEC_SPECIAL_CHARMAP[ + ord("\N{BOX DRAWINGS LIGHT VERTICAL}") + ], + } + ) # When waiting for something of unknown duration, block the UI for at @@ -68,7 +64,6 @@ MIN_SHOW_PROGRESS_TIME = 1.0 class TuiApplication(Application): - make_ui = SubiquityCoreUI def __init__(self, opts): @@ -80,7 +75,7 @@ class TuiApplication(Application): self.answers = yaml.safe_load(opts.answers.read()) log.debug("Loaded answers %s", self.answers) if not opts.dry_run: - open('/run/casper-no-prompt', 'w').close() + open("/run/casper-no-prompt", "w").close() # Set rich_mode to the opposite of what we want, so we can # call toggle_rich to get the right things set up. @@ -89,15 +84,15 @@ class TuiApplication(Application): self.cur_screen = None self.fg_proc = None - def run_command_in_foreground(self, cmd, before_hook=None, after_hook=None, - **kw): + def run_command_in_foreground(self, cmd, before_hook=None, after_hook=None, **kw): if self.fg_proc is not None: raise Exception("cannot run two fg processes at once") screen = self.urwid_loop.screen async def _run(): self.fg_proc = proc = await astart_command( - cmd, stdin=None, stdout=None, stderr=None, **kw) + cmd, stdin=None, stdout=None, stderr=None, **kw + ) await proc.communicate() self.fg_proc = None # One of the main use cases for this function is to run interactive @@ -122,14 +117,14 @@ class TuiApplication(Application): after_hook() screen.stop() - urwid.emit_signal( - screen, urwid.display_common.INPUT_DESCRIPTORS_CHANGED) + urwid.emit_signal(screen, urwid.display_common.INPUT_DESCRIPTORS_CHANGED) if before_hook is not None: before_hook() schedule_task(_run()) - async def make_view_for_controller(self, new) \ - -> Union[BaseView, Callable[[], BaseView]]: + async def make_view_for_controller( + self, new + ) -> Union[BaseView, Callable[[], BaseView]]: new.context.enter("starting UI") if self.opts.screens and new.name not in self.opts.screens: raise Skip @@ -161,8 +156,7 @@ class TuiApplication(Application): await asyncio.sleep(MAX_BLOCK_TIME) self.ui.block_input = False nonlocal min_show_task - min_show_task = asyncio.create_task( - asyncio.sleep(MIN_SHOW_PROGRESS_TIME)) + min_show_task = asyncio.create_task(asyncio.sleep(MIN_SHOW_PROGRESS_TIME)) show() self.ui.block_input = True @@ -183,8 +177,7 @@ class TuiApplication(Application): def show_progress(self): raise NotImplementedError - async def wait_with_text_dialog(self, awaitable, message, - *, can_cancel=False): + async def wait_with_text_dialog(self, awaitable, message, *, can_cancel=False): ld = None task_to_cancel = None @@ -194,6 +187,7 @@ class TuiApplication(Application): async def w(): return await orig + awaitable = task_to_cancel = asyncio.create_task(w()) else: task_to_cancel = None @@ -206,14 +200,14 @@ class TuiApplication(Application): def hide_load(): ld.close() - return await self._wait_with_indication( - awaitable, show_load, hide_load) + return await self._wait_with_indication(awaitable, show_load, hide_load) async def wait_with_progress(self, awaitable): return await self._wait_with_indication(awaitable, self.show_progress) - async def _move_screen(self, increment, coro) \ - -> Optional[Union[BaseView, Callable[[], BaseView]]]: + async def _move_screen( + self, increment, coro + ) -> Optional[Union[BaseView, Callable[[], BaseView]]]: if coro is not None: await coro old, self.cur_screen = self.cur_screen, None @@ -241,7 +235,8 @@ class TuiApplication(Application): async def move_screen(self, increment, coro): view_or_callable = await self.wait_with_progress( - self._move_screen(increment, coro)) + self._move_screen(increment, coro) + ) if view_or_callable is not None: if callable(view_or_callable): view = view_or_callable() @@ -265,11 +260,11 @@ class TuiApplication(Application): def toggle_rich(self): if self.rich_mode: - urwid.util.set_encoding('ascii') + urwid.util.set_encoding("ascii") new_palette = PALETTE_MONO self.rich_mode = False else: - urwid.util.set_encoding('utf-8') + urwid.util.set_encoding("utf-8") new_palette = PALETTE_COLOR self.rich_mode = True urwid.CanvasCache.clear() @@ -277,11 +272,11 @@ class TuiApplication(Application): self.urwid_loop.screen.clear() def unhandled_input(self, key): - if self.opts.dry_run and key == 'ctrl x': + if self.opts.dry_run and key == "ctrl x": self.exit() - elif key == 'f3': + elif key == "f3": self.urwid_loop.screen.clear() - elif self.opts.run_on_serial and key in ['ctrl t', 'f4']: + elif self.opts.run_on_serial and key in ["ctrl t", "f4"]: self.toggle_rich() def extra_urwid_loop_args(self): @@ -297,12 +292,14 @@ class TuiApplication(Application): screen = self.make_screen(input, output) screen.register_palette(PALETTE_COLOR) self.urwid_loop = urwid.MainLoop( - self.ui, screen=screen, - handle_mouse=False, pop_ups=True, + self.ui, + screen=screen, + handle_mouse=False, + pop_ups=True, unhandled_input=self.unhandled_input, event_loop=urwid.AsyncioEventLoop(loop=asyncio.get_running_loop()), - **self.extra_urwid_loop_args() - ) + **self.extra_urwid_loop_args(), + ) extend_dec_special_charmap() self.toggle_rich() self.urwid_loop.start() diff --git a/subiquitycore/tuicontroller.py b/subiquitycore/tuicontroller.py index e3f72a46..777739ed 100644 --- a/subiquitycore/tuicontroller.py +++ b/subiquitycore/tuicontroller.py @@ -13,9 +13,9 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from abc import abstractmethod import asyncio import logging +from abc import abstractmethod from subiquitycore.controller import BaseController @@ -61,17 +61,17 @@ class TuiController(BaseController): # Stuff for fine grained actions, used by filesystem and network # controller at time of writing this comment. - async def _enter_form_data(self, form, data, submit, clean_suffix=''): + async def _enter_form_data(self, form, data, submit, clean_suffix=""): for k, v in data.items(): - c = getattr( - self, '_action_clean_{}_{}'.format(k, clean_suffix), None) + c = getattr(self, "_action_clean_{}_{}".format(k, clean_suffix), None) if c is None: - c = getattr(self, '_action_clean_{}'.format(k), lambda x: x) + c = getattr(self, "_action_clean_{}".format(k), lambda x: x) field = getattr(form, k) from subiquitycore.ui.selector import Selector + v = c(v) if isinstance(field.widget, Selector): - field.widget._emit('select', v) + field.widget._emit("select", v) field.value = v yield for bf in form._fields: @@ -83,7 +83,7 @@ class TuiController(BaseController): form._click_done(None) async def _run_actions(self, actions): - delay = 0.2/self.app.scale_factor + delay = 0.2 / self.app.scale_factor for action in actions: async for _ in self._answers_action(action): await asyncio.sleep(delay) @@ -91,7 +91,6 @@ class TuiController(BaseController): class RepeatedController(BaseController): - def __init__(self, orig, index): self.name = "{}-{}".format(orig.name, index) self.orig = orig diff --git a/subiquitycore/ui/actionmenu.py b/subiquitycore/ui/actionmenu.py index 8e854876..861ff69f 100644 --- a/subiquitycore/ui/actionmenu.py +++ b/subiquitycore/ui/actionmenu.py @@ -14,24 +14,19 @@ # along with this program. If not, see . import attr - from urwid import ( ACTIVATE, AttrWrap, Button, - connect_signal, LineBox, PopUpLauncher, SelectableIcon, Text, Widget, - ) - -from subiquitycore.ui.container import ( - Columns, - ListBox, - WidgetWrap, + connect_signal, ) + +from subiquitycore.ui.container import Columns, ListBox, WidgetWrap from subiquitycore.ui.utils import Color @@ -69,8 +64,7 @@ class _ActionMenuDialog(WidgetWrap): else: btn = Color.menu_button(ActionMenuButton(action.label)) width = max(width, len(btn.base_widget.label)) - connect_signal( - btn.base_widget, 'click', self.click, action.value) + connect_signal(btn.base_widget, "click", self.click, action.value) else: label = action.label if isinstance(label, Widget): @@ -80,12 +74,15 @@ class _ActionMenuDialog(WidgetWrap): rhs = "\N{BLACK RIGHT-POINTING SMALL TRIANGLE}" else: rhs = "" - btn = Columns([ - ('fixed', 1, Text("")), - Text(label), - ('fixed', 1, Text(rhs)), - ], dividechars=1) - btn = AttrWrap(btn, 'info_minor') + btn = Columns( + [ + ("fixed", 1, Text("")), + Text(label), + ("fixed", 1, Text(rhs)), + ], + dividechars=1, + ) + btn = AttrWrap(btn, "info_minor") group.append(btn) self.width = width super().__init__(Color.body(LineBox(ListBox(group)))) @@ -98,7 +95,7 @@ class _ActionMenuDialog(WidgetWrap): self.parent.close_pop_up() def keypress(self, size, key): - if key == 'esc': + if key == "esc": self.parent.close_pop_up() else: return super().keypress(size, key) @@ -116,11 +113,9 @@ class Action: class ActionMenu(PopUpLauncher): + signals = ["action", "open", "close"] - signals = ['action', 'open', 'close'] - - def __init__(self, opts, - icon="\N{BLACK RIGHT-POINTING SMALL TRIANGLE}"): + def __init__(self, opts, icon="\N{BLACK RIGHT-POINTING SMALL TRIANGLE}"): self._actions = [] for opt in opts: if not isinstance(opt, Action): @@ -157,8 +152,8 @@ class ActionMenu(PopUpLauncher): def get_pop_up_parameters(self): width = self._dialog.width + 7 return { - 'left': 1, - 'top': -1, - 'overlay_width': width, - 'overlay_height': len(self._actions) + 3, - } + "left": 1, + "top": -1, + "overlay_width": width, + "overlay_height": len(self._actions) + 3, + } diff --git a/subiquitycore/ui/anchors.py b/subiquitycore/ui/anchors.py index 6f7dbf2e..edc91fb2 100644 --- a/subiquitycore/ui/anchors.py +++ b/subiquitycore/ui/anchors.py @@ -15,21 +15,13 @@ import logging -from urwid import ( - int_scale, - SolidFill, - Text, - ) +from urwid import SolidFill, Text, int_scale -from subiquitycore.ui.container import ( - Columns, - Pile, - WidgetWrap, - ) +from subiquitycore.ui.container import Columns, Pile, WidgetWrap from subiquitycore.ui.utils import Color from subiquitycore.ui.width import widget_width -log = logging.getLogger('subiquitycore.ui.anchors') +log = logging.getLogger("subiquitycore.ui.anchors") class HeaderColumns(Columns): @@ -53,7 +45,7 @@ class HeaderColumns(Columns): center = max(int_scale(79, 101, maxcol + 1), 76) message = center - btn - pad = (maxcol - center)//2 + pad = (maxcol - center) // 2 if pad <= 0: pad = 1 message = maxcol - 2 - btn @@ -61,22 +53,25 @@ class HeaderColumns(Columns): class Header(WidgetWrap): - """ Header Widget """ + """Header Widget""" def __init__(self, title, right_icon): if isinstance(title, str): title = Text(title) - title = HeaderColumns([ - Text(""), - title, - right_icon, - Text(""), - ]) + title = HeaderColumns( + [ + Text(""), + title, + right_icon, + Text(""), + ] + ) super().__init__( - Pile([ - (1, Color.frame_header_fringe( - SolidFill("\N{lower half block}"))), + Pile( + [ + (1, Color.frame_header_fringe(SolidFill("\N{lower half block}"))), Color.frame_header(title), - (1, Color.frame_header_fringe( - SolidFill("\N{upper half block}"))), - ])) + (1, Color.frame_header_fringe(SolidFill("\N{upper half block}"))), + ] + ) + ) diff --git a/subiquitycore/ui/buttons.py b/subiquitycore/ui/buttons.py index a3178c26..42917602 100644 --- a/subiquitycore/ui/buttons.py +++ b/subiquitycore/ui/buttons.py @@ -26,13 +26,15 @@ def _stylized_button(left, right, style): btn = Btn(label, on_press=on_press, user_data=user_arg) btn._w.contents[2] = ( btn._w.contents[2][0], - btn._w.options('given', len(right))) - super().__init__(btn, style + '_button', style + '_button focus') + btn._w.options("given", len(right)), + ) + super().__init__(btn, style + "_button", style + "_button focus") + return StyleAttrMap def action_button(style): - return _stylized_button('[', ']', style) + return _stylized_button("[", "]", style) _forward_rhs = "\N{BLACK RIGHT-POINTING SMALL TRIANGLE} ]" diff --git a/subiquitycore/ui/confirmation.py b/subiquitycore/ui/confirmation.py index 34e53608..34b2a97f 100644 --- a/subiquitycore/ui/confirmation.py +++ b/subiquitycore/ui/confirmation.py @@ -17,22 +17,23 @@ from typing import Callable import urwid -from subiquitycore.ui.buttons import ( - back_btn, - done_btn, - ) -from subiquitycore.ui.utils import button_pile +from subiquitycore.ui.buttons import back_btn, done_btn from subiquitycore.ui.stretchy import Stretchy +from subiquitycore.ui.utils import button_pile class ConfirmationOverlay(Stretchy): - """ An overlay widget that asks the user to confirm or cancel an action. - """ - def __init__(self, title: str, question: str, - confirm_label: str, cancel_label: str, - on_confirm: Callable[[], None], - on_cancel: Callable[[], None]) -> None: + """An overlay widget that asks the user to confirm or cancel an action.""" + def __init__( + self, + title: str, + question: str, + confirm_label: str, + cancel_label: str, + on_confirm: Callable[[], None], + on_cancel: Callable[[], None], + ) -> None: self.on_cancel_cb = on_cancel self.on_confirm_cb = on_confirm self.choice_made = False @@ -40,12 +41,12 @@ class ConfirmationOverlay(Stretchy): widgets = [ urwid.Text(question), urwid.Text(""), - button_pile([ - back_btn( - label=cancel_label, on_press=lambda u: self.on_cancel()), - done_btn( - label=confirm_label, on_press=lambda u: self.on_confirm()), - ]), + button_pile( + [ + back_btn(label=cancel_label, on_press=lambda u: self.on_cancel()), + done_btn(label=confirm_label, on_press=lambda u: self.on_confirm()), + ] + ), ] super().__init__(title, widgets, 0, 2) diff --git a/subiquitycore/ui/container.py b/subiquitycore/ui/container.py index a5de7395..ed915cb1 100644 --- a/subiquitycore/ui/container.py +++ b/subiquitycore/ui/container.py @@ -66,7 +66,7 @@ import operator import urwid -log = logging.getLogger('subiquitycore.ui.container') +log = logging.getLogger("subiquitycore.ui.container") # The theory of tab-cycling: # @@ -111,12 +111,12 @@ def _has_other_selectable(widgets, cur_focus): for key, command in list(urwid.command_map._command.items()): - if command in ('next selectable', 'prev selectable', urwid.ACTIVATE): - urwid.command_map[key + ' no wrap'] = command + if command in ("next selectable", "prev selectable", urwid.ACTIVATE): + urwid.command_map[key + " no wrap"] = command enter_advancing_command_map = urwid.command_map.copy() -enter_advancing_command_map['enter'] = 'next selectable' -enter_advancing_command_map['enter no wrap'] = 'next selectable' +enter_advancing_command_map["enter"] = "next selectable" +enter_advancing_command_map["enter no wrap"] = "next selectable" class TabCyclingPile(urwid.Pile): @@ -159,11 +159,11 @@ class TabCyclingPile(urwid.Pile): # start subiquity change: new code downkey = key - if not key.endswith(' no wrap') and (self._command_map[key] in - ('next selectable', - 'prev selectable')): + if not key.endswith(" no wrap") and ( + self._command_map[key] in ("next selectable", "prev selectable") + ): if _has_other_selectable(self._widgets(), self.focus_position): - downkey += ' no wrap' + downkey += " no wrap" # end subiquity change item_rows = None @@ -176,40 +176,43 @@ class TabCyclingPile(urwid.Pile): # start subiquity change: pass downkey to focus, not key, # do not return if command_map[upkey] is next/prev selectable upkey = self.focus.keypress(tsize, downkey) - if self._command_map[upkey] not in ('cursor up', 'cursor down', - 'next selectable', - 'prev selectable'): + if self._command_map[upkey] not in ( + "cursor up", + "cursor down", + "next selectable", + "prev selectable", + ): return upkey # end subiquity change # begin subiquity change: new code - if self._command_map[key] == 'next selectable': + if self._command_map[key] == "next selectable": next_fp = self.focus_position + 1 for i, (w, o) in enumerate(self._contents[next_fp:], next_fp): if w.selectable(): self.set_focus(i) _maybe_call(w, "_select_first_selectable") return - if not key.endswith(' no wrap'): + if not key.endswith(" no wrap"): self._select_first_selectable() return key - elif self._command_map[key] == 'prev selectable': - positions = self._contents[:self.focus_position] + elif self._command_map[key] == "prev selectable": + positions = self._contents[: self.focus_position] for i, (w, o) in reversed(list(enumerate(positions))): if w.selectable(): self.set_focus(i) _maybe_call(w, "_select_last_selectable") return - if not key.endswith(' no wrap'): + if not key.endswith(" no wrap"): self._select_last_selectable() return key # continued subiquity change: set 'meth' appropriately - elif self._command_map[key] == 'cursor up': - candidates = list(range(i-1, -1, -1)) # count backwards to 0 - meth = '_select_last_selectable' + elif self._command_map[key] == "cursor up": + candidates = list(range(i - 1, -1, -1)) # count backwards to 0 + meth = "_select_last_selectable" else: # self._command_map[key] == 'cursor down' - candidates = list(range(i+1, len(self.contents))) - meth = '_select_first_selectable' + candidates = list(range(i + 1, len(self.contents))) + meth = "_select_first_selectable" # end subiquity change if not item_rows: @@ -224,18 +227,17 @@ class TabCyclingPile(urwid.Pile): # begin subiquity change: new code _maybe_call(self.focus, meth) # end subiquity change - if not hasattr(self.focus, 'move_cursor_to_coords'): + if not hasattr(self.focus, "move_cursor_to_coords"): return rows = item_rows[j] - if self._command_map[key] == 'cursor up': - rowlist = list(range(rows-1, -1, -1)) + if self._command_map[key] == "cursor up": + rowlist = list(range(rows - 1, -1, -1)) else: # self._command_map[key] == 'cursor down' rowlist = list(range(rows)) for row in rowlist: tsize = self.get_item_size(size, j, True, item_rows) - if self.focus_item.move_cursor_to_coords( - tsize, self.pref_col, row): + if self.focus_item.move_cursor_to_coords(tsize, self.pref_col, row): break return @@ -244,17 +246,21 @@ class TabCyclingPile(urwid.Pile): class OneSelectableColumns(urwid.Columns): - def __init__(self, widget_list, dividechars=0, focus_column=None, - min_width=1, box_columns=None): - super().__init__(widget_list, dividechars, focus_column, - min_width, box_columns) + def __init__( + self, + widget_list, + dividechars=0, + focus_column=None, + min_width=1, + box_columns=None, + ): + super().__init__(widget_list, dividechars, focus_column, min_width, box_columns) selectables = 0 for w, o in self._contents: if w.selectable(): selectables += 1 if selectables > 1: - raise Exception( - "subiquity only supports one selectable in a Columns") + raise Exception("subiquity only supports one selectable in a Columns") def _select_first_selectable(self): """Select first selectable child (possibily recursively).""" @@ -279,7 +285,7 @@ class TabCyclingListBox(urwid.ListBox): def __init__(self, body): # urwid.ListBox converts an arbitrary sequence argument to a # PollingListWalker, which doesn't work with the below code. - if getattr(body, 'get_focus', None) is None: + if getattr(body, "get_focus", None) is None: body = urwid.SimpleFocusListWalker(body) super().__init__(body) @@ -313,11 +319,11 @@ class TabCyclingListBox(urwid.ListBox): def keypress(self, size, key): downkey = key - if not key.endswith(' no wrap') and (self._command_map[key] in - ('next selectable', - 'prev selectable')): + if not key.endswith(" no wrap") and ( + self._command_map[key] in ("next selectable", "prev selectable") + ): if _has_other_selectable(self.body, self.focus_position): - downkey += ' no wrap' + downkey += " no wrap" upkey = super().keypress(size, downkey) if upkey != downkey: key = upkey @@ -325,24 +331,24 @@ class TabCyclingListBox(urwid.ListBox): if len(self.body) == 0: return key - if self._command_map[key] == 'next selectable': + if self._command_map[key] == "next selectable": next_fp = self.focus_position + 1 for i, w in enumerate(self.body[next_fp:], next_fp): if w.selectable(): self.set_focus(i) _maybe_call(w, "_select_first_selectable") return None - if not key.endswith(' no wrap'): + if not key.endswith(" no wrap"): self._select_first_selectable() return key - elif self._command_map[key] == 'prev selectable': - positions = self.body[:self.focus_position] + elif self._command_map[key] == "prev selectable": + positions = self.body[: self.focus_position] for i, w in reversed(list(enumerate(positions))): if w.selectable(): self.set_focus(i) _maybe_call(w, "_select_last_selectable") return None - if not key.endswith(' no wrap'): + if not key.endswith(" no wrap"): self._select_last_selectable() return key else: @@ -350,7 +356,6 @@ class TabCyclingListBox(urwid.ListBox): class FocusTrackingMixin: - def __init__(self, *args, **kw): super().__init__(*args, **kw) self._contents.set_focus_changed_callback(self._focus_changed) @@ -361,10 +366,10 @@ class FocusTrackingMixin: except IndexError: pass else: - _maybe_call(cur_focus, 'lost_focus') + _maybe_call(cur_focus, "lost_focus") def gained_focus(self): - _maybe_call(self.focus, 'gained_focus') + _maybe_call(self.focus, "gained_focus") def _focus_changed(self, new_focus): try: @@ -372,8 +377,8 @@ class FocusTrackingMixin: except IndexError: pass else: - _maybe_call(cur_focus, 'lost_focus') - _maybe_call(self[new_focus], 'gained_focus') + _maybe_call(cur_focus, "lost_focus") + _maybe_call(self[new_focus], "gained_focus") self._invalidate() @@ -386,7 +391,6 @@ class FocusTrackingColumns(FocusTrackingMixin, OneSelectableColumns): class FocusTrackingListBox(TabCyclingListBox): - def __init__(self, *args, **kw): super().__init__(*args, **kw) self.body.set_focus_changed_callback(self._focus_changed) @@ -394,15 +398,15 @@ class FocusTrackingListBox(TabCyclingListBox): def _focus_changed(self, new_focus): if not self._in_set_focus_complete: - _maybe_call(self.focus, 'lost_focus') - _maybe_call(self.body[new_focus], 'gained_focus') + _maybe_call(self.focus, "lost_focus") + _maybe_call(self.body[new_focus], "gained_focus") self._invalidate() def lost_focus(self): - _maybe_call(self.focus, 'lost_focus') + _maybe_call(self.focus, "lost_focus") def gained_focus(self): - _maybe_call(self.focus, 'gained_focus') + _maybe_call(self.focus, "gained_focus") def _set_focus_complete(self, size, focus): self._in_set_focus_complete = True @@ -417,21 +421,22 @@ Pile = FocusTrackingPile class ScrollBarListBox(urwid.WidgetDecoration): - top = urwid.SolidFill("\N{BLACK UP-POINTING SMALL TRIANGLE}") box = urwid.SolidFill("\N{FULL BLOCK}") bg = urwid.SolidFill(" ") bot = urwid.SolidFill("\N{BLACK DOWN-POINTING SMALL TRIANGLE}") def __init__(self, lb, *, always_scroll=False): - pile = Pile([ - ('fixed', 1, self.top), - ('weight', 1, self.bg), - ('weight', 1, self.box), - ('weight', 1, self.bg), - ('fixed', 1, self.bot), - ]) - self.bar = urwid.AttrMap(pile, 'scrollbar', 'scrollbar focus') + pile = Pile( + [ + ("fixed", 1, self.top), + ("weight", 1, self.bg), + ("weight", 1, self.box), + ("weight", 1, self.bg), + ("fixed", 1, self.bot), + ] + ) + self.bar = urwid.AttrMap(pile, "scrollbar", "scrollbar focus") self._always_scroll = always_scroll super().__init__(lb) @@ -443,7 +448,7 @@ class ScrollBarListBox(urwid.WidgetDecoration): def keypress(self, size, key): lb = self.original_widget if not self._scroll(size, True): - size = (size[0]-1, size[1]) + size = (size[0] - 1, size[1]) return lb.keypress(size, key) def render(self, size, focus=False): @@ -474,11 +479,11 @@ class ScrollBarListBox(urwid.WidgetDecoration): # Calculate the number of rows off the top and bottom of # the listbox. - if 'top' in visible: + if "top" in visible: top = 0 else: top = height_before_focus + inset - offset - if 'bottom' in visible: + if "bottom" in visible: bottom = 0 else: bottom = height - top - maxrow @@ -494,30 +499,32 @@ class ScrollBarListBox(urwid.WidgetDecoration): # rows and that being < 1 simplifies to the below. o = self.bar.base_widget.options if maxrow < top + bottom: - boxopt = o('given', 1) + boxopt = o("given", 1) else: - boxopt = o('weight', maxrow) + boxopt = o("weight", maxrow) top_arrow = self.top if top == 0: - top_arrow = urwid.AttrMap(top_arrow, 'scrollbar') + top_arrow = urwid.AttrMap(top_arrow, "scrollbar") bot_arrow = self.bot if bottom == 0: - bot_arrow = urwid.AttrMap(bot_arrow, 'scrollbar') + bot_arrow = urwid.AttrMap(bot_arrow, "scrollbar") self.bar.base_widget.contents[:] = [ - (top_arrow, o('given', 1)), - (self.bg, o('weight', top)), + (top_arrow, o("given", 1)), + (self.bg, o("weight", top)), (self.box, boxopt), - (self.bg, o('weight', bottom)), - (bot_arrow, o('given', 1)), - ] + (self.bg, o("weight", bottom)), + (bot_arrow, o("given", 1)), + ] lb_canvas = lb.render((maxcol - 1, maxrow), focus) bar_canvas = self.bar.render((1, maxrow), focus) - return urwid.CanvasJoin([ - (lb_canvas, lb.focus_position, True, maxcol - 1), - (bar_canvas, None, False, 1), - ]) + return urwid.CanvasJoin( + [ + (lb_canvas, lb.focus_position, True, maxcol - 1), + (bar_canvas, None, False, 1), + ] + ) def ListBox(body=None, *, always_scroll=False): @@ -525,11 +532,9 @@ def ListBox(body=None, *, always_scroll=False): # PollingListWalker, which doesn't work with our code. if body is None: body = [] - if body is getattr(body, 'get_focus', None) is None: + if body is getattr(body, "get_focus", None) is None: body = urwid.SimpleFocusListWalker(body) - return ScrollBarListBox( - FocusTrackingListBox(body), - always_scroll=always_scroll) + return ScrollBarListBox(FocusTrackingListBox(body), always_scroll=always_scroll) get_delegate = operator.attrgetter("_wrapped_widget.base_widget") @@ -543,9 +548,11 @@ class OurWidgetWrap(urwid.WidgetWrap): # does. _select_first_selectable = property( - lambda self: get_delegate(self)._select_first_selectable) + lambda self: get_delegate(self)._select_first_selectable + ) _select_last_selectable = property( - lambda self: get_delegate(self)._select_last_selectable) + lambda self: get_delegate(self)._select_last_selectable + ) lost_focus = property(lambda self: get_delegate(self).lost_focus) gained_focus = property(lambda self: get_delegate(self).gained_focus) diff --git a/subiquitycore/ui/form.py b/subiquitycore/ui/form.py index 3adaf258..ea73c588 100644 --- a/subiquitycore/ui/form.py +++ b/subiquitycore/ui/form.py @@ -17,45 +17,29 @@ import abc import logging from urllib.parse import urlparse +from urwid import CheckBox, MetaSignals +from urwid import Padding as UrwidPadding from urwid import ( - CheckBox, - connect_signal, - delegate_to_widget_mixin, - emit_signal, - MetaSignals, - Padding as UrwidPadding, RadioButton, Text, WidgetDecoration, - ) + connect_signal, + delegate_to_widget_mixin, + emit_signal, +) from subiquitycore.ui.buttons import cancel_btn, done_btn -from subiquitycore.ui.container import ( - Pile, - WidgetWrap, -) +from subiquitycore.ui.container import Pile, WidgetWrap from subiquitycore.ui.interactive import ( - PasswordEditor, - IntegerEditor, - StringEditor, EmailEditor, - ) + IntegerEditor, + PasswordEditor, + StringEditor, +) from subiquitycore.ui.selector import Selector -from subiquitycore.ui.table import ( - ColSpec, - TablePile, - TableRow, - ) -from subiquitycore.ui.utils import ( - button_pile, - Color, - disabled, - screen, - ) -from subiquitycore.ui.width import ( - widget_width, - ) - +from subiquitycore.ui.table import ColSpec, TablePile, TableRow +from subiquitycore.ui.utils import Color, button_pile, disabled, screen +from subiquitycore.ui.width import widget_width log = logging.getLogger("subiquitycore.ui.form") @@ -71,9 +55,7 @@ NO_CAPTION = object() NO_HELP = object() -class Toggleable(delegate_to_widget_mixin('_original_widget'), - WidgetDecoration): - +class Toggleable(delegate_to_widget_mixin("_original_widget"), WidgetDecoration): has_original_width = True def __init__(self, original): @@ -95,7 +77,6 @@ class Toggleable(delegate_to_widget_mixin('_original_widget'), class _Validator(WidgetWrap): - def __init__(self, field, w): self.field = field super().__init__(w) @@ -105,7 +86,7 @@ class _Validator(WidgetWrap): def lost_focus(self): self.field.showing_extra = False - lf = getattr(self._w.base_widget, 'lost_focus', None) + lf = getattr(self._w.base_widget, "lost_focus", None) if lf is not None: lf() self.field.validate() @@ -113,6 +94,7 @@ class _Validator(WidgetWrap): class WantsToKnowFormField(object): """A marker class.""" + def set_bound_form_field(self, bff): self.bff = bff @@ -121,7 +103,6 @@ form_colspecs = {1: ColSpec(pack=False)} class BoundFormField(object): - def __init__(self, field, form, widget): self.field = field self.form = form @@ -134,13 +115,13 @@ class BoundFormField(object): self._build_table() - if 'change' in getattr(widget, 'signals', []): - connect_signal(widget, 'change', self._change) + if "change" in getattr(widget, "signals", []): + connect_signal(widget, "change", self._change) if isinstance(widget, WantsToKnowFormField): widget.set_bound_form_field(self) def is_in_error(self) -> bool: - """ Tells whether this field is in error. """ + """Tells whether this field is in error.""" return self.in_error def _build_table(self): @@ -159,17 +140,16 @@ class BoundFormField(object): self.caption_text = Text(_(self.field.caption)) if self.field.caption_first: - self.caption_text.align = 'right' + self.caption_text.align = "right" first_row = [self.caption_text, _Validator(self, widget)] else: first_row = [ _Validator( self, - UrwidPadding( - widget, align='right', - width=widget_width(widget))), + UrwidPadding(widget, align="right", width=widget_width(widget)), + ), self.caption_text, - ] + ] second_row = [Text(""), self.under_text] rows = [first_row] @@ -228,7 +208,7 @@ class BoundFormField(object): else: self.in_error = True if show_error: - self.show_extra(('info_error', r)) + self.show_extra(("info_error", r)) self.form.validated() def show_extra(self, extra_markup): @@ -237,7 +217,7 @@ class BoundFormField(object): @property def value(self): - return self.clean(getattr(self, 'tmpval', self.widget.value)) + return self.clean(getattr(self, "tmpval", self.widget.value)) @value.setter def value(self, val): @@ -249,7 +229,7 @@ class BoundFormField(object): return self._help elif self.field.help is not None: if isinstance(self.field.help, str): - return ('info_minor', _(self.field.help)) + return ("info_minor", _(self.field.help)) else: return self.field.help else: @@ -280,22 +260,23 @@ class BoundFormField(object): for row in self._rows: row.enabled = val - def use_as_confirmation(self, for_field: "BoundFormField", desc: str) \ - -> None: - """ Mark this field as a confirmation field for another field. + def use_as_confirmation(self, for_field: "BoundFormField", desc: str) -> None: + """Mark this field as a confirmation field for another field. This will automatically compare the value of both fields when this - field (a.k.a., the confirmation field) is changed. """ + field (a.k.a., the confirmation field) is changed.""" + def _check_confirmation(sender, new_text): if not for_field.value.startswith(new_text): self.show_extra(("info_error", _(f"{desc} do not match"))) else: self.show_extra("") + connect_signal(self.widget, "change", _check_confirmation) class BoundSubFormField(BoundFormField): def is_in_error(self): - """ Tells whether this field is in error. We will also check if the + """Tells whether this field is in error. We will also check if the subform (if enabled) reports an error. """ if super().is_in_error(): @@ -308,7 +289,6 @@ class BoundSubFormField(BoundFormField): class FormField(abc.ABC): - next_index = 0 takes_default_style = True caption_first = True @@ -318,7 +298,7 @@ class FormField(abc.ABC): self.caption = caption # Allows styling at instantiation if isinstance(help, str): - self.help = ('info_minor', help) + self.help = ("info_minor", help) else: self.help = help self.index = FormField.next_index @@ -337,6 +317,7 @@ def simple_field(widget_maker): class Field(FormField): def _make_widget(self, form): return widget_maker() + return Field @@ -347,7 +328,6 @@ EmailField = simple_field(EmailEditor) class RadioButtonEditor(RadioButton): - reserve_columns = 3 @property @@ -360,7 +340,6 @@ class RadioButtonEditor(RadioButton): class RadioButtonField(FormField): - caption_first = False takes_default_style = False @@ -382,7 +361,7 @@ class RadioButtonField(FormField): class URLEditor(StringEditor, WantsToKnowFormField): - def __init__(self, allowed_schemes=frozenset(['http', 'https'])): + def __init__(self, allowed_schemes=frozenset(["http", "https"])): self.allowed_schemes = allowed_schemes super().__init__() @@ -403,8 +382,8 @@ class URLEditor(StringEditor, WantsToKnowFormField): else: schemes = schemes[0] raise ValueError( - _("This field must be a {schemes} URL.").format( - schemes=schemes)) + _("This field must be a {schemes} URL.").format(schemes=schemes) + ) return v @@ -412,7 +391,6 @@ URLField = simple_field(URLEditor) class ChoiceField(FormField): - takes_default_style = False def __init__(self, caption=None, help=None, choices=[]): @@ -424,7 +402,6 @@ class ChoiceField(FormField): class ReadOnlyWidget(Text): - @property def value(self): return self.text @@ -435,7 +412,6 @@ class ReadOnlyWidget(Text): class ReadOnlyField(FormField): - takes_default_style = False def _make_widget(self, form): @@ -443,7 +419,6 @@ class ReadOnlyField(FormField): class CheckBoxEditor(CheckBox): - reserve_columns = 3 @property @@ -456,16 +431,14 @@ class CheckBoxEditor(CheckBox): class BooleanField(FormField): - caption_first = False takes_default_style = False def _make_widget(self, form): - return CheckBoxEditor('') + return CheckBoxEditor("") class MetaForm(MetaSignals): - def __init__(self, name, bases, attrs): super().__init__(name, bases, attrs) _unbound_fields = [] @@ -480,17 +453,18 @@ class MetaForm(MetaSignals): class Form(object, metaclass=MetaForm): - - signals = ['submit', 'cancel'] + signals = ["submit", "cancel"] ok_label = _("Done") cancel_label = _("Cancel") def __init__(self, initial={}): - self.done_btn = Toggleable(done_btn(_(self.ok_label), - on_press=self._click_done)) - self.cancel_btn = Toggleable(cancel_btn(_(self.cancel_label), - on_press=self._click_cancel)) + self.done_btn = Toggleable( + done_btn(_(self.ok_label), on_press=self._click_done) + ) + self.cancel_btn = Toggleable( + cancel_btn(_(self.cancel_label), on_press=self._click_cancel) + ) self.buttons = button_pile([self.done_btn, self.cancel_btn]) self._fields = [] for field in self._unbound_fields: @@ -509,10 +483,10 @@ class Form(object, metaclass=MetaForm): bf.field.value = data[bf.field.name] def _click_done(self, sender): - emit_signal(self, 'submit', self) + emit_signal(self, "submit", self) def _click_cancel(self, sender): - emit_signal(self, 'cancel', self) + emit_signal(self, "cancel", self) def remove_field(self, field_name): new_fields = [] @@ -535,12 +509,15 @@ class Form(object, metaclass=MetaForm): def as_screen(self, focus_buttons=True, excerpt=None, narrow_rows=False): return screen( - self.as_rows(), self.buttons, - focus_buttons=focus_buttons, excerpt=excerpt, - narrow_rows=narrow_rows) + self.as_rows(), + self.buttons, + focus_buttons=focus_buttons, + excerpt=excerpt, + narrow_rows=narrow_rows, + ) def has_validation_error(self) -> bool: - """ Tells if any field is in error. """ + """Tells if any field is in error.""" return any(map(lambda f: f.is_in_error(), self._fields)) def validated(self): @@ -559,7 +536,6 @@ class Form(object, metaclass=MetaForm): class SubFormWidget(WidgetWrap): - def __init__(self, form): self.form = form super().__init__(Pile(form.as_rows())) @@ -575,7 +551,6 @@ class SubFormWidget(WidgetWrap): class SubFormField(FormField): - takes_default_style = False bound_field_class = BoundSubFormField @@ -589,12 +564,11 @@ class SubFormField(FormField): class SubForm(Form): - def __init__(self, parent, **kw): self.parent = parent super().__init__(**kw) def validated(self): - """ Propagate the validation to the parent. """ + """Propagate the validation to the parent.""" self.parent.validated() super().validated() diff --git a/subiquitycore/ui/frame.py b/subiquitycore/ui/frame.py index 37a9346d..0e0da083 100644 --- a/subiquitycore/ui/frame.py +++ b/subiquitycore/ui/frame.py @@ -17,31 +17,26 @@ import logging -from urwid import ( - Text, - ) +from urwid import Text + from subiquitycore.ui.anchors import Header -from subiquitycore.ui.container import ( - ListBox, - Pile, - WidgetWrap, - ) +from subiquitycore.ui.container import ListBox, Pile, WidgetWrap from subiquitycore.ui.utils import Color - -log = logging.getLogger('subiquitycore.ui.frame') +log = logging.getLogger("subiquitycore.ui.frame") class SubiquityCoreUI(WidgetWrap): - right_icon = Text("") def __init__(self): self.header = Header("", self.right_icon) - self.pile = Pile([ - ('pack', self.header), - ListBox([Text("")]), - ]) + self.pile = Pile( + [ + ("pack", self.header), + ListBox([Text("")]), + ] + ) self.pile.focus_position = 1 super().__init__(Color.body(self.pile)) diff --git a/subiquitycore/ui/interactive.py b/subiquitycore/ui/interactive.py index dacda29c..5a0c21a9 100644 --- a/subiquitycore/ui/interactive.py +++ b/subiquitycore/ui/interactive.py @@ -16,25 +16,20 @@ """ Re-usable input widgets """ -from functools import partial import logging import re +from functools import partial -from urwid import ( - Edit, - IntEdit, - ) +from urwid import Edit, IntEdit -from subiquitycore.ui.container import ( - WidgetWrap, - ) +from subiquitycore.ui.container import WidgetWrap from subiquitycore.ui.selector import Selector log = logging.getLogger("subiquitycore.ui.interactive") class StringEditor(Edit): - """ Edit input class + """Edit input class Attaches its result to the `value` accessor. """ @@ -49,8 +44,8 @@ class StringEditor(Edit): class PasswordEditor(StringEditor): - """ Password input prompt with masking - """ + """Password input prompt with masking""" + def __init__(self, mask="*"): super().__init__(mask=mask) @@ -66,20 +61,19 @@ class RestrictedEditor(StringEditor): return len(ch) == 1 and self.matcher.match(ch) is not None -RealnameEditor = partial(RestrictedEditor, r'[^:,=]') -EmailEditor = partial(RestrictedEditor, r'[-a-zA-Z0-9_.@+=]') +RealnameEditor = partial(RestrictedEditor, r"[^:,=]") +EmailEditor = partial(RestrictedEditor, r"[-a-zA-Z0-9_.@+=]") class UsernameEditor(StringEditor): - """ Username input prompt with input rules - """ + """Username input prompt with input rules""" def keypress(self, size, key): - ''' restrict what chars we allow for username ''' + """restrict what chars we allow for username""" if self._command_map[key] is not None: return super().keypress(size, key) new_text = self.insert_text_result(key)[0] - username = r'[a-z_][a-z0-9_-]*' + username = r"[a-z_][a-z0-9_-]*" # don't allow non username chars if new_text != "" and re.match(username, new_text) is None: return False @@ -87,8 +81,8 @@ class UsernameEditor(StringEditor): class IntegerEditor(WidgetWrap): - """ IntEdit input class - """ + """IntEdit input class""" + def __init__(self, default=0): self._edit = IntEdit(default=default) super().__init__(self._edit) @@ -103,8 +97,8 @@ class IntegerEditor(WidgetWrap): class YesNo(Selector): - """ Yes/No selector - """ + """Yes/No selector""" + def __init__(self): - opts = [_('Yes'), _('No')] + opts = [_("Yes"), _("No")] super().__init__(opts) diff --git a/subiquitycore/ui/selector.py b/subiquitycore/ui/selector.py index 8ec5965f..aabe73d8 100644 --- a/subiquitycore/ui/selector.py +++ b/subiquitycore/ui/selector.py @@ -13,30 +13,17 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from urwid import ( - ACTIVATE, - AttrWrap, - CompositeCanvas, - connect_signal, - LineBox, - Padding as UrwidPadding, - PopUpLauncher, - Text, - ) +from urwid import ACTIVATE, AttrWrap, CompositeCanvas, LineBox +from urwid import Padding as UrwidPadding +from urwid import PopUpLauncher, Text, connect_signal -from subiquitycore.ui.container import ( - Columns, - ListBox, - WidgetWrap, - ) -from subiquitycore.ui.utils import ( - Color, - ) +from subiquitycore.ui.container import Columns, ListBox, WidgetWrap +from subiquitycore.ui.utils import Color from subiquitycore.ui.width import widget_width class ClickableThing(WidgetWrap): - signals = ['click'] + signals = ["click"] def selectable(self): return True @@ -62,7 +49,7 @@ class ClickableThing(WidgetWrap): def keypress(self, size, key): if self._command_map[key] != ACTIVATE: return key - self._emit('click') + self._emit("click") class _PopUpSelectDialog(WidgetWrap): @@ -74,23 +61,25 @@ class _PopUpSelectDialog(WidgetWrap): for i, option in enumerate(self.parent._options): if option.enabled: btn = ClickableThing(option.label) - connect_signal(btn, 'click', self.click, i) + connect_signal(btn, "click", self.click, i) if i == cur_index: - rhs = '\N{BLACK LEFT-POINTING SMALL TRIANGLE} ' + rhs = "\N{BLACK LEFT-POINTING SMALL TRIANGLE} " else: - rhs = '' + rhs = "" else: btn = option.label - rhs = '' - row = Columns([ - (1, Text("")), - btn, - (2, Text(rhs)), - ]) + rhs = "" + row = Columns( + [ + (1, Text("")), + btn, + (2, Text(rhs)), + ] + ) if option.enabled: - row = AttrWrap(row, 'menu_button', 'menu_button focus') + row = AttrWrap(row, "menu_button", "menu_button focus") else: - row = AttrWrap(row, 'info_minor') + row = AttrWrap(row, "info_minor") btn = UrwidPadding(row, width=self.parent._padding.width) group.append(btn) list_box = ListBox(group) @@ -102,7 +91,7 @@ class _PopUpSelectDialog(WidgetWrap): self.parent.close_pop_up() def keypress(self, size, key): - if key == 'esc': + if key == "esc": self.parent.close_pop_up() else: return super().keypress(size, key) @@ -113,7 +102,6 @@ class SelectorError(Exception): class Option: - def __init__(self, val): if not isinstance(val, tuple): if isinstance(val, Option): @@ -170,17 +158,24 @@ class Selector(WidgetWrap): (A bit like