format with black + isort

Dan Bungert 2023-07-25 15:26:25 -06:00
parent 93a34a64fa
commit 34d40643ad
298 changed files with 14900 additions and 14044 deletions
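black and isort are automatic formatters: black normalizes layout (double quotes, long calls exploded to one argument per line with a trailing comma) and isort alphabetizes and groups imports, so the hunks below are mechanical rewrites with no behaviour change. A minimal sketch of the before/after shape, assuming only stock tool defaults — the --dry-run option does appear in the console-conf hunks further down, but the surrounding script here is illustrative, not taken from the repository:

    # Before: flake8-era style, single quotes, hand-wrapped call.
    #
    #   import sys
    #   import os
    #   import argparse
    #
    #   parser = argparse.ArgumentParser(prog='console-conf')
    #   parser.add_argument('--dry-run', action='store_true',
    #                       dest='dry_run',
    #                       help='menu-only, do not call installer function')

    # After black + isort: imports sorted, double quotes, exploded call with a
    # trailing comma. os and sys are unused here; they are kept only to show
    # isort's alphabetical ordering.
    import argparse
    import os
    import sys

    parser = argparse.ArgumentParser(prog="console-conf")
    parser.add_argument(
        "--dry-run",
        action="store_true",
        dest="dry_run",
        help="menu-only, do not call installer function",
    )
    print(parser.parse_args(["--dry-run"]).dry_run)  # -> True

Commits like this are usually paired with tool configuration and a CI check so later diffs stay free of formatting churn; that configuration is not shown in the hunks here.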

View File

@@ -16,6 +16,7 @@
 """ Console-Conf """
+
 from subiquitycore import i18n
 
 __all__ = [
-    'i18n',
+    "i18n",
 ]

View File

@@ -16,42 +16,63 @@
 import argparse
 import asyncio
-import sys
-import os
 import logging
-from subiquitycore.log import setup_logger
-from subiquitycore import __version__ as VERSION
+import os
+import sys
+
 from console_conf.core import ConsoleConf, RecoveryChooser
+from subiquitycore import __version__ as VERSION
+from subiquitycore.log import setup_logger
 
 
 def parse_options(argv):
     parser = argparse.ArgumentParser(
-        description=(
-            'console-conf - Pre-Ownership Configuration for Ubuntu Core'),
-        prog='console-conf')
-    parser.add_argument('--dry-run', action='store_true',
-                        dest='dry_run',
-                        help='menu-only, do not call installer function')
-    parser.add_argument('--serial', action='store_true',
-                        dest='run_on_serial',
-                        help='Run the installer over serial console.')
-    parser.add_argument('--ascii', action='store_true',
-                        dest='ascii',
-                        help='Run the installer in ascii mode.')
-    parser.add_argument('--machine-config', metavar='CONFIG',
-                        dest='machine_config',
-                        type=argparse.FileType(),
-                        help="Don't Probe. Use probe data file")
-    parser.add_argument('--screens', action='append', dest='screens',
-                        default=[])
-    parser.add_argument('--answers')
-    parser.add_argument('--recovery-chooser-mode', action='store_true',
-                        dest='chooser_systems',
-                        help=('Run as a recovery chooser interacting with the '
-                              'calling process over stdin/stdout streams'))
-    parser.add_argument('--output-base', action='store', dest='output_base',
-                        default='.subiquity',
-                        help='in dryrun, control basedir of files')
+        description=("console-conf - Pre-Ownership Configuration for Ubuntu Core"),
+        prog="console-conf",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        dest="dry_run",
+        help="menu-only, do not call installer function",
+    )
+    parser.add_argument(
+        "--serial",
+        action="store_true",
+        dest="run_on_serial",
+        help="Run the installer over serial console.",
+    )
+    parser.add_argument(
+        "--ascii",
+        action="store_true",
+        dest="ascii",
+        help="Run the installer in ascii mode.",
+    )
+    parser.add_argument(
+        "--machine-config",
+        metavar="CONFIG",
+        dest="machine_config",
+        type=argparse.FileType(),
+        help="Don't Probe. Use probe data file",
+    )
+    parser.add_argument("--screens", action="append", dest="screens", default=[])
+    parser.add_argument("--answers")
+    parser.add_argument(
+        "--recovery-chooser-mode",
+        action="store_true",
+        dest="chooser_systems",
+        help=(
+            "Run as a recovery chooser interacting with the "
+            "calling process over stdin/stdout streams"
+        ),
+    )
+    parser.add_argument(
+        "--output-base",
+        action="store",
+        dest="output_base",
+        default=".subiquity",
+        help="in dryrun, control basedir of files",
+    )
     return parser.parse_args(argv)

@@ -64,7 +85,7 @@ def main():
     if opts.dry_run:
         LOGDIR = opts.output_base
     setup_logger(dir=LOGDIR)
-    logger = logging.getLogger('console_conf')
+    logger = logging.getLogger("console_conf")
     logger.info("Starting console-conf v{}".format(VERSION))
     logger.info("Arguments passed: {}".format(sys.argv))

@@ -73,8 +94,7 @@ def main():
         # when running as a chooser, the stdin/stdout streams are set up by
         # the process that runs us, attempt to restore the tty in/out by
         # looking at stderr
-        chooser_input, chooser_output = restore_std_streams_from(
-            sys.stderr)
+        chooser_input, chooser_output = restore_std_streams_from(sys.stderr)
         interface = RecoveryChooser(opts, chooser_input, chooser_output)
     else:
         interface = ConsoleConf(opts)

@@ -91,10 +111,10 @@ def restore_std_streams_from(from_file):
     tty = os.ttyname(from_file.fileno())
     # we have tty now
     chooser_input, chooser_output = sys.stdin, sys.stdout
-    sys.stdin = open(tty, 'r')
-    sys.stdout = open(tty, 'w')
+    sys.stdin = open(tty, "r")
+    sys.stdout = open(tty, "w")
     return chooser_input, chooser_output
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     sys.exit(main())

View File

@@ -17,20 +17,18 @@
 import logging
 import sys
 
-from subiquitycore.log import setup_logger
-from subiquitycore import __version__ as VERSION
 from console_conf.controllers.identity import write_login_details_standalone
+from subiquitycore import __version__ as VERSION
+from subiquitycore.log import setup_logger
 
 
 def main():
-    logger = setup_logger(dir='/var/log/console-conf')
-    logger = logging.getLogger('console_conf')
-    logger.info(
-        "Starting console-conf-write-login-details v{}".format(VERSION))
+    logger = setup_logger(dir="/var/log/console-conf")
+    logger = logging.getLogger("console_conf")
+    logger.info("Starting console-conf-write-login-details v{}".format(VERSION))
     logger.info("Arguments passed: {}".format(sys.argv))
     return write_login_details_standalone()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     sys.exit(main())

View File

@@ -15,19 +15,17 @@
 """ console-conf controllers """
 
-from .identity import IdentityController
 from subiquitycore.controllers.network import NetworkController
-from .welcome import WelcomeController, RecoveryChooserWelcomeController
-from .chooser import (
-    RecoveryChooserController,
-    RecoveryChooserConfirmController
-)
+
+from .chooser import RecoveryChooserConfirmController, RecoveryChooserController
+from .identity import IdentityController
+from .welcome import RecoveryChooserWelcomeController, WelcomeController
 
 __all__ = [
-    'IdentityController',
-    'NetworkController',
-    'WelcomeController',
-    'RecoveryChooserWelcomeController',
-    'RecoveryChooserController',
-    'RecoveryChooserConfirmController',
+    "IdentityController",
+    "NetworkController",
+    "WelcomeController",
+    "RecoveryChooserWelcomeController",
+    "RecoveryChooserController",
+    "RecoveryChooserConfirmController",
 ]

View File

@@ -15,18 +15,16 @@
 import logging
 
 from console_conf.ui.views import (
-    ChooserView,
-    ChooserCurrentSystemView,
     ChooserConfirmView,
-)
+    ChooserCurrentSystemView,
+    ChooserView,
+)
 from subiquitycore.tuicontroller import TuiController
 
 log = logging.getLogger("console_conf.controllers.chooser")
 
 
 class RecoveryChooserBaseController(TuiController):
-
     def __init__(self, app):
         super().__init__(app)
         self.model = app.base_model

@@ -40,7 +38,6 @@ class RecoveryChooserBaseController(TuiController):
 
 class RecoveryChooserController(RecoveryChooserBaseController):
-
     def __init__(self, app):
         super().__init__(app)
         self._model_view, self._all_view = self._make_views()

@@ -64,9 +61,9 @@ class RecoveryChooserController(RecoveryChooserBaseController):
         if self.model.current and self.model.current.actions:
             # only when we have a current system and it has actions available
             more = len(self.model.systems) > 1
-            current_view = ChooserCurrentSystemView(self,
-                                                    self.model.current,
-                                                    has_more=more)
+            current_view = ChooserCurrentSystemView(
+                self, self.model.current, has_more=more
+            )
         all_view = ChooserView(self, self.model.systems)
         return current_view, all_view

@@ -80,8 +77,7 @@ class RecoveryChooserController(RecoveryChooserBaseController):
         self.ui.set_body(self._all_view)
 
     def back(self):
-        if self._current_view == self._all_view and \
-                self._model_view is not None:
+        if self._current_view == self._all_view and self._model_view is not None:
             # back in the all-systems screen goes back to the current model
             # screen, provided we have one
             self._current_view = self._model_view

View File

@ -20,18 +20,17 @@ import pwd
import shlex import shlex
import sys import sys
from subiquitycore.ssh import host_key_info, get_ips_standalone from console_conf.ui.views import IdentityView, LoginView
from subiquitycore.snapd import SnapdConnection from subiquitycore.snapd import SnapdConnection
from subiquitycore.ssh import get_ips_standalone, host_key_info
from subiquitycore.tuicontroller import TuiController from subiquitycore.tuicontroller import TuiController
from subiquitycore.utils import disable_console_conf, run_command from subiquitycore.utils import disable_console_conf, run_command
from console_conf.ui.views import IdentityView, LoginView log = logging.getLogger("console_conf.controllers.identity")
log = logging.getLogger('console_conf.controllers.identity')
def get_core_version(): def get_core_version():
""" For a ubuntu-core system, return its version or None """ """For a ubuntu-core system, return its version or None"""
path = "/usr/lib/os-release" path = "/usr/lib/os-release"
try: try:
@ -53,32 +52,32 @@ def get_core_version():
def get_managed(): def get_managed():
""" Check if device is managed """ """Check if device is managed"""
con = SnapdConnection('', '/run/snapd.socket') con = SnapdConnection("", "/run/snapd.socket")
return con.get('v2/system-info').json()['result']['managed'] return con.get("v2/system-info").json()["result"]["managed"]
def get_realname(username): def get_realname(username):
try: try:
info = pwd.getpwnam(username) info = pwd.getpwnam(username)
except KeyError: except KeyError:
return '' return ""
return info.pw_gecos.split(',', 1)[0] return info.pw_gecos.split(",", 1)[0]
def get_device_owner(): def get_device_owner():
""" Get device owner, if any """ """Get device owner, if any"""
con = SnapdConnection('', '/run/snapd.socket') con = SnapdConnection("", "/run/snapd.socket")
for user in con.get('v2/users').json()['result']: for user in con.get("v2/users").json()["result"]:
if 'username' not in user: if "username" not in user:
continue continue
username = user['username'] username = user["username"]
homedir = '/home/' + username homedir = "/home/" + username
if os.path.isdir(homedir): if os.path.isdir(homedir):
return { return {
'username': username, "username": username,
'realname': get_realname(username), "realname": get_realname(username),
'homedir': homedir, "homedir": homedir,
} }
return None return None
@ -110,15 +109,22 @@ def write_login_details(fp, username, ips):
tty_name = os.ttyname(0)[5:] # strip off the /dev/ tty_name = os.ttyname(0)[5:] # strip off the /dev/
version = get_core_version() or "16" version = get_core_version() or "16"
if len(ips) == 0: if len(ips) == 0:
fp.write(login_details_tmpl_no_ip.format( fp.write(
sshcommands=sshcommands, tty_name=tty_name, version=version)) login_details_tmpl_no_ip.format(
sshcommands=sshcommands, tty_name=tty_name, version=version
)
)
else: else:
first_ip = ips[0] first_ip = ips[0]
fp.write(login_details_tmpl.format(sshcommands=sshcommands, fp.write(
login_details_tmpl.format(
sshcommands=sshcommands,
host_key_info=host_key_info(), host_key_info=host_key_info(),
tty_name=tty_name, tty_name=tty_name,
first_ip=first_ip, first_ip=first_ip,
version=version)) version=version,
)
)
def write_login_details_standalone(): def write_login_details_standalone():
@ -131,18 +137,16 @@ def write_login_details_standalone():
else: else:
tty_name = os.ttyname(0)[5:] tty_name = os.ttyname(0)[5:]
version = get_core_version() or "16" version = get_core_version() or "16"
print(login_details_tmpl_no_ip.format( print(login_details_tmpl_no_ip.format(tty_name=tty_name, version=version))
tty_name=tty_name, version=version))
return 2 return 2
if owner is None: if owner is None:
print("device managed without user @ {}".format(', '.join(ips))) print("device managed without user @ {}".format(", ".join(ips)))
else: else:
write_login_details(sys.stdout, owner['username'], ips) write_login_details(sys.stdout, owner["username"], ips)
return 0 return 0
class IdentityController(TuiController): class IdentityController(TuiController):
def __init__(self, app): def __init__(self, app):
super().__init__(app) super().__init__(app)
self.model = app.base_model.identity self.model = app.base_model.identity
@ -152,11 +156,9 @@ class IdentityController(TuiController):
device_owner = get_device_owner() device_owner = get_device_owner()
if device_owner: if device_owner:
self.model.add_user(device_owner) self.model.add_user(device_owner)
key_file = os.path.join(device_owner['homedir'], key_file = os.path.join(device_owner["homedir"], ".ssh/authorized_keys")
".ssh/authorized_keys") cp = run_command(["ssh-keygen", "-lf", key_file])
cp = run_command(['ssh-keygen', '-lf', key_file]) self.model.user.fingerprints = cp.stdout.replace("\r", "").splitlines()
self.model.user.fingerprints = (
cp.stdout.replace('\r', '').splitlines())
return self.make_login_view() return self.make_login_view()
else: else:
return IdentityView(self.model, self) return IdentityView(self.model, self)
@ -164,35 +166,35 @@ class IdentityController(TuiController):
def identity_done(self, email): def identity_done(self, email):
if self.opts.dry_run: if self.opts.dry_run:
result = { result = {
'realname': email, "realname": email,
'username': email, "username": email,
} }
self.model.add_user(result) self.model.add_user(result)
login_details_path = self.opts.output_base + '/login-details.txt' login_details_path = self.opts.output_base + "/login-details.txt"
else: else:
self.app.urwid_loop.draw_screen() self.app.urwid_loop.draw_screen()
cp = run_command( cp = run_command(["snap", "create-user", "--sudoer", "--json", email])
["snap", "create-user", "--sudoer", "--json", email])
if cp.returncode != 0: if cp.returncode != 0:
if isinstance(self.ui.body, IdentityView): if isinstance(self.ui.body, IdentityView):
self.ui.body.snap_create_user_failed( self.ui.body.snap_create_user_failed(
"Creating user failed:", cp.stderr) "Creating user failed:", cp.stderr
)
return return
else: else:
data = json.loads(cp.stdout) data = json.loads(cp.stdout)
result = { result = {
'realname': email, "realname": email,
'username': data['username'], "username": data["username"],
} }
os.makedirs('/run/console-conf', exist_ok=True) os.makedirs("/run/console-conf", exist_ok=True)
login_details_path = '/run/console-conf/login-details.txt' login_details_path = "/run/console-conf/login-details.txt"
self.model.add_user(result) self.model.add_user(result)
ips = [] ips = []
net_model = self.app.base_model.network net_model = self.app.base_model.network
for dev in net_model.get_all_netdevs(): for dev in net_model.get_all_netdevs():
ips.extend(dev.actual_global_ip_addresses) ips.extend(dev.actual_global_ip_addresses)
with open(login_details_path, 'w') as fp: with open(login_details_path, "w") as fp:
write_login_details(fp, result['username'], ips) write_login_details(fp, result["username"], ips)
self.login() self.login()
def cancel(self): def cancel(self):

View File

@ -15,16 +15,12 @@
import unittest import unittest
from unittest import mock from unittest import mock
from subiquitycore.tests.mocks import make_app
from console_conf.controllers.chooser import ( from console_conf.controllers.chooser import (
RecoveryChooserController,
RecoveryChooserConfirmController, RecoveryChooserConfirmController,
) RecoveryChooserController,
from console_conf.models.systems import ( )
RecoverySystemsModel, from console_conf.models.systems import RecoverySystemsModel, SelectedSystemAction
SelectedSystemAction, from subiquitycore.tests.mocks import make_app
)
model1_non_current = { model1_non_current = {
"current": False, "current": False,
@ -43,7 +39,7 @@ model1_non_current = {
"actions": [ "actions": [
{"title": "action 1", "mode": "action1"}, {"title": "action 1", "mode": "action1"},
{"title": "action 2", "mode": "action2"}, {"title": "action 2", "mode": "action2"},
] ],
} }
model2_current = { model2_current = {
@ -62,21 +58,19 @@ model2_current = {
}, },
"actions": [ "actions": [
{"title": "action 1", "mode": "action1"}, {"title": "action 1", "mode": "action1"},
] ],
} }
def make_model(): def make_model():
return RecoverySystemsModel.from_systems([model1_non_current, return RecoverySystemsModel.from_systems([model1_non_current, model2_current])
model2_current])
class TestChooserConfirmController(unittest.TestCase): class TestChooserConfirmController(unittest.TestCase):
def test_back(self): def test_back(self):
app = make_app() app = make_app()
c = RecoveryChooserConfirmController(app) c = RecoveryChooserConfirmController(app)
c.model = mock.Mock(selection='selection') c.model = mock.Mock(selection="selection")
c.back() c.back()
app.respond.assert_not_called() app.respond.assert_not_called()
app.exit.assert_not_called() app.exit.assert_not_called()
@ -86,12 +80,12 @@ class TestChooserConfirmController(unittest.TestCase):
def test_confirm(self): def test_confirm(self):
app = make_app() app = make_app()
c = RecoveryChooserConfirmController(app) c = RecoveryChooserConfirmController(app)
c.model = mock.Mock(selection='selection') c.model = mock.Mock(selection="selection")
c.confirm() c.confirm()
app.respond.assert_called_with('selection') app.respond.assert_called_with("selection")
app.exit.assert_called() app.exit.assert_called()
@mock.patch('console_conf.controllers.chooser.ChooserConfirmView') @mock.patch("console_conf.controllers.chooser.ChooserConfirmView")
def test_confirm_view(self, ccv): def test_confirm_view(self, ccv):
app = make_app() app = make_app()
c = RecoveryChooserConfirmController(app) c = RecoveryChooserConfirmController(app)
@ -102,24 +96,25 @@ class TestChooserConfirmController(unittest.TestCase):
class TestChooserController(unittest.TestCase): class TestChooserController(unittest.TestCase):
def test_select(self): def test_select(self):
app = make_app(model=make_model()) app = make_app(model=make_model())
c = RecoveryChooserController(app) c = RecoveryChooserController(app)
c.select(c.model.systems[0], c.model.systems[0].actions[0]) c.select(c.model.systems[0], c.model.systems[0].actions[0])
exp = SelectedSystemAction(system=c.model.systems[0], exp = SelectedSystemAction(
action=c.model.systems[0].actions[0]) system=c.model.systems[0], action=c.model.systems[0].actions[0]
)
self.assertEqual(c.model.selection, exp) self.assertEqual(c.model.selection, exp)
app.next_screen.assert_called() app.next_screen.assert_called()
app.respond.assert_not_called() app.respond.assert_not_called()
app.exit.assert_not_called() app.exit.assert_not_called()
@mock.patch('console_conf.controllers.chooser.ChooserCurrentSystemView', @mock.patch(
return_value='current') "console_conf.controllers.chooser.ChooserCurrentSystemView",
@mock.patch('console_conf.controllers.chooser.ChooserView', return_value="current",
return_value='all') )
@mock.patch("console_conf.controllers.chooser.ChooserView", return_value="all")
def test_current_ui_first(self, cv, ccsv): def test_current_ui_first(self, cv, ccsv):
app = make_app(model=make_model()) app = make_app(model=make_model())
c = RecoveryChooserController(app) c = RecoveryChooserController(app)
@ -128,15 +123,16 @@ class TestChooserController(unittest.TestCase):
# as well as all systems view # as well as all systems view
cv.assert_called_with(c, c.model.systems) cv.assert_called_with(c, c.model.systems)
v = c.make_ui() v = c.make_ui()
self.assertEqual(v, 'current') self.assertEqual(v, "current")
# user selects more options and the view is replaced # user selects more options and the view is replaced
c.more_options() c.more_options()
c.ui.set_body.assert_called_with('all') c.ui.set_body.assert_called_with("all")
@mock.patch('console_conf.controllers.chooser.ChooserCurrentSystemView', @mock.patch(
return_value='current') "console_conf.controllers.chooser.ChooserCurrentSystemView",
@mock.patch('console_conf.controllers.chooser.ChooserView', return_value="current",
return_value='all') )
@mock.patch("console_conf.controllers.chooser.ChooserView", return_value="all")
def test_current_current_all_there_and_back(self, cv, ccsv): def test_current_current_all_there_and_back(self, cv, ccsv):
app = make_app(model=make_model()) app = make_app(model=make_model())
c = RecoveryChooserController(app) c = RecoveryChooserController(app)
@ -145,21 +141,22 @@ class TestChooserController(unittest.TestCase):
cv.assert_called_with(c, c.model.systems) cv.assert_called_with(c, c.model.systems)
v = c.make_ui() v = c.make_ui()
self.assertEqual(v, 'current') self.assertEqual(v, "current")
# user selects more options and the view is replaced # user selects more options and the view is replaced
c.more_options() c.more_options()
c.ui.set_body.assert_called_with('all') c.ui.set_body.assert_called_with("all")
# go back now # go back now
c.back() c.back()
c.ui.set_body.assert_called_with('current') c.ui.set_body.assert_called_with("current")
# nothing # nothing
c.back() c.back()
c.ui.set_body.not_called() c.ui.set_body.not_called()
@mock.patch('console_conf.controllers.chooser.ChooserCurrentSystemView', @mock.patch(
return_value='current') "console_conf.controllers.chooser.ChooserCurrentSystemView",
@mock.patch('console_conf.controllers.chooser.ChooserView', return_value="current",
return_value='all') )
@mock.patch("console_conf.controllers.chooser.ChooserView", return_value="all")
def test_only_one_and_current(self, cv, ccsv): def test_only_one_and_current(self, cv, ccsv):
model = RecoverySystemsModel.from_systems([model2_current]) model = RecoverySystemsModel.from_systems([model2_current])
app = make_app(model=model) app = make_app(model=model)
@ -169,15 +166,16 @@ class TestChooserController(unittest.TestCase):
cv.assert_called_with(c, c.model.systems) cv.assert_called_with(c, c.model.systems)
v = c.make_ui() v = c.make_ui()
self.assertEqual(v, 'current') self.assertEqual(v, "current")
# going back does nothing # going back does nothing
c.back() c.back()
c.ui.set_body.not_called() c.ui.set_body.not_called()
@mock.patch('console_conf.controllers.chooser.ChooserCurrentSystemView', @mock.patch(
return_value='current') "console_conf.controllers.chooser.ChooserCurrentSystemView",
@mock.patch('console_conf.controllers.chooser.ChooserView', return_value="current",
return_value='all') )
@mock.patch("console_conf.controllers.chooser.ChooserView", return_value="all")
def test_all_systems_first_no_current(self, cv, ccsv): def test_all_systems_first_no_current(self, cv, ccsv):
model = RecoverySystemsModel.from_systems([model1_non_current]) model = RecoverySystemsModel.from_systems([model1_non_current])
app = make_app(model=model) app = make_app(model=model)
@ -192,4 +190,4 @@ class TestChooserController(unittest.TestCase):
ccsv.assert_not_called() ccsv.assert_not_called()
v = c.make_ui() v = c.make_ui()
self.assertEqual(v, 'all') self.assertEqual(v, "all")

View File

@@ -13,13 +13,11 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 
-from console_conf.ui.views import WelcomeView, ChooserWelcomeView
-
+from console_conf.ui.views import ChooserWelcomeView, WelcomeView
 from subiquitycore.tuicontroller import TuiController
 
 
 class WelcomeController(TuiController):
-
     welcome_view = WelcomeView
 
     def make_ui(self):

@@ -34,7 +32,6 @@ class WelcomeController(TuiController):
 
 class RecoveryChooserWelcomeController(WelcomeController):
-
     welcome_view = ChooserWelcomeView
 
     def __init__(self, app):

View File

@@ -15,18 +15,17 @@
 import logging
 
-from subiquitycore.prober import Prober
-from subiquitycore.tui import TuiApplication
-
 from console_conf.models.console_conf import ConsoleConfModel
 from console_conf.models.systems import RecoverySystemsModel
+from subiquitycore.prober import Prober
+from subiquitycore.tui import TuiApplication
 
 log = logging.getLogger("console_conf.core")
 
 
 class ConsoleConf(TuiApplication):
     from console_conf import controllers as controllers_mod
 
     project = "console_conf"
     make_model = ConsoleConfModel

@@ -43,8 +42,8 @@ class ConsoleConf(TuiApplication):
 
 class RecoveryChooser(TuiApplication):
     from console_conf import controllers as controllers_mod
 
     project = "console_conf"
 
     controllers = [

View File

@@ -1,5 +1,5 @@
 from subiquitycore.models.network import NetworkModel
 
 from .identity import IdentityModel

View File

@@ -17,8 +17,7 @@ import logging
 
 import attr
 
-log = logging.getLogger('console_conf.models.identity')
+log = logging.getLogger("console_conf.models.identity")
 
 
 @attr.s

@@ -30,8 +29,7 @@ class User(object):
 
 class IdentityModel(object):
-    """ Model representing user identity
-    """
+    """Model representing user identity"""
 
     def __init__(self):
         self._user = None

View File

@@ -13,8 +13,8 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 
-import logging
 import json
+import logging
 
 import attr
 import jsonschema

@@ -113,9 +113,7 @@ class RecoverySystemsModel:
             m = syst["model"]
             b = syst["brand"]
             model = SystemModel(
-                model=m["model"],
-                brand_id=m["brand-id"],
-                display_name=m["display-name"]
+                model=m["model"], brand_id=m["brand-id"], display_name=m["display-name"]
             )
             brand = Brand(
                 ID=b["id"],

@@ -180,11 +178,13 @@ class RecoverySystem:
     actions = attr.ib()
 
     def __eq__(self, other):
-        return self.current == other.current and \
-            self.label == other.label and \
-            self.model == other.model and \
-            self.brand == other.brand and \
-            self.actions == other.actions
+        return (
+            self.current == other.current
+            and self.label == other.label
+            and self.model == other.model
+            and self.brand == other.brand
+            and self.actions == other.actions
+        )
 
 
 @attr.s

@@ -195,10 +195,12 @@ class Brand:
     validation = attr.ib()
 
     def __eq__(self, other):
-        return self.ID == other.ID and \
-            self.username == other.username and \
-            self.display_name == other.display_name and \
-            self.validation == other.validation
+        return (
+            self.ID == other.ID
+            and self.username == other.username
+            and self.display_name == other.display_name
+            and self.validation == other.validation
+        )
 
 
 @attr.s

@@ -208,9 +210,11 @@ class SystemModel:
     display_name = attr.ib()
 
     def __eq__(self, other):
-        return self.model == other.model and \
-            self.brand_id == other.brand_id and \
-            self.display_name == other.display_name
+        return (
+            self.model == other.model
+            and self.brand_id == other.brand_id
+            and self.display_name == other.display_name
+        )
 
 
 @attr.s

@@ -219,8 +223,7 @@ class SystemAction:
     mode = attr.ib()
 
     def __eq__(self, other):
-        return self.title == other.title and \
-            self.mode == other.mode
+        return self.title == other.title and self.mode == other.mode
 
 
 @attr.s

@@ -229,5 +232,4 @@ class SelectedSystemAction:
     action = attr.ib()
 
     def __eq__(self, other):
-        return self.system == other.system and \
-            self.action == other.action
+        return self.system == other.system and self.action == other.action

View File

@ -13,23 +13,23 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
import json import json
import jsonschema import unittest
from io import StringIO from io import StringIO
import jsonschema
from console_conf.models.systems import ( from console_conf.models.systems import (
RecoverySystemsModel,
RecoverySystem,
Brand, Brand,
SystemModel, RecoverySystem,
SystemAction, RecoverySystemsModel,
SelectedSystemAction, SelectedSystemAction,
) SystemAction,
SystemModel,
)
class RecoverySystemsModelTests(unittest.TestCase): class RecoverySystemsModelTests(unittest.TestCase):
reference = { reference = {
"systems": [ "systems": [
{ {
@ -49,7 +49,7 @@ class RecoverySystemsModelTests(unittest.TestCase):
"actions": [ "actions": [
{"title": "reinstall", "mode": "install"}, {"title": "reinstall", "mode": "install"},
{"title": "recover", "mode": "recover"}, {"title": "recover", "mode": "recover"},
] ],
}, },
{ {
"label": "other", "label": "other",
@ -66,30 +66,34 @@ class RecoverySystemsModelTests(unittest.TestCase):
}, },
"actions": [ "actions": [
{"title": "reinstall", "mode": "install"}, {"title": "reinstall", "mode": "install"},
] ],
} },
] ]
} }
def test_from_systems_stream_happy(self): def test_from_systems_stream_happy(self):
raw = json.dumps(self.reference) raw = json.dumps(self.reference)
systems = RecoverySystemsModel.from_systems_stream(StringIO(raw)) systems = RecoverySystemsModel.from_systems_stream(StringIO(raw))
exp = RecoverySystemsModel([ exp = RecoverySystemsModel(
[
RecoverySystem( RecoverySystem(
current=True, current=True,
label="1234", label="1234",
model=SystemModel( model=SystemModel(
model="core20-amd64", model="core20-amd64",
brand_id="brand-id", brand_id="brand-id",
display_name="Core 20 AMD64 system"), display_name="Core 20 AMD64 system",
),
brand=Brand( brand=Brand(
ID="brand-id", ID="brand-id",
username="brand-username", username="brand-username",
display_name="this is my brand", display_name="this is my brand",
validation="verified"), validation="verified",
actions=[SystemAction(title="reinstall", mode="install"), ),
SystemAction(title="recover", mode="recover")] actions=[
SystemAction(title="reinstall", mode="install"),
SystemAction(title="recover", mode="recover"),
],
), ),
RecoverySystem( RecoverySystem(
current=False, current=False,
@ -97,15 +101,18 @@ class RecoverySystemsModelTests(unittest.TestCase):
model=SystemModel( model=SystemModel(
model="my-brand-box", model="my-brand-box",
brand_id="other-brand-id", brand_id="other-brand-id",
display_name="Funky box"), display_name="Funky box",
),
brand=Brand( brand=Brand(
ID="other-brand-id", ID="other-brand-id",
username="other-brand", username="other-brand",
display_name="my brand", display_name="my brand",
validation="unproven"), validation="unproven",
actions=[SystemAction(title="reinstall", mode="install")]
), ),
]) actions=[SystemAction(title="reinstall", mode="install")],
),
]
)
self.assertEqual(systems.systems, exp.systems) self.assertEqual(systems.systems, exp.systems)
self.assertEqual(systems.current, exp.systems[0]) self.assertEqual(systems.current, exp.systems[0])
@ -114,7 +121,8 @@ class RecoverySystemsModelTests(unittest.TestCase):
RecoverySystemsModel.from_systems_stream(StringIO("{}")) RecoverySystemsModel.from_systems_stream(StringIO("{}"))
def test_from_systems_stream_invalid_missing_system_label(self): def test_from_systems_stream_invalid_missing_system_label(self):
raw = json.dumps({ raw = json.dumps(
{
"systems": [ "systems": [
{ {
"brand": { "brand": {
@ -131,15 +139,17 @@ class RecoverySystemsModelTests(unittest.TestCase):
"actions": [ "actions": [
{"title": "reinstall", "mode": "install"}, {"title": "reinstall", "mode": "install"},
{"title": "recover", "mode": "recover"}, {"title": "recover", "mode": "recover"},
] ],
}, },
] ]
}) }
)
with self.assertRaises(jsonschema.ValidationError): with self.assertRaises(jsonschema.ValidationError):
RecoverySystemsModel.from_systems_stream(StringIO(raw)) RecoverySystemsModel.from_systems_stream(StringIO(raw))
def test_from_systems_stream_invalid_missing_brand(self): def test_from_systems_stream_invalid_missing_brand(self):
raw = json.dumps({ raw = json.dumps(
{
"systems": [ "systems": [
{ {
"label": "1234", "label": "1234",
@ -151,15 +161,17 @@ class RecoverySystemsModelTests(unittest.TestCase):
"actions": [ "actions": [
{"title": "reinstall", "mode": "install"}, {"title": "reinstall", "mode": "install"},
{"title": "recover", "mode": "recover"}, {"title": "recover", "mode": "recover"},
] ],
}, },
] ]
}) }
)
with self.assertRaises(jsonschema.ValidationError): with self.assertRaises(jsonschema.ValidationError):
RecoverySystemsModel.from_systems_stream(StringIO(raw)) RecoverySystemsModel.from_systems_stream(StringIO(raw))
def test_from_systems_stream_invalid_missing_model(self): def test_from_systems_stream_invalid_missing_model(self):
raw = json.dumps({ raw = json.dumps(
{
"systems": [ "systems": [
{ {
"label": "1234", "label": "1234",
@ -172,15 +184,17 @@ class RecoverySystemsModelTests(unittest.TestCase):
"actions": [ "actions": [
{"title": "reinstall", "mode": "install"}, {"title": "reinstall", "mode": "install"},
{"title": "recover", "mode": "recover"}, {"title": "recover", "mode": "recover"},
] ],
}, },
] ]
}) }
)
with self.assertRaises(jsonschema.ValidationError): with self.assertRaises(jsonschema.ValidationError):
RecoverySystemsModel.from_systems_stream(StringIO(raw)) RecoverySystemsModel.from_systems_stream(StringIO(raw))
def test_from_systems_stream_valid_no_actions(self): def test_from_systems_stream_valid_no_actions(self):
raw = json.dumps({ raw = json.dumps(
{
"systems": [ "systems": [
{ {
"label": "1234", "label": "1234",
@ -198,17 +212,20 @@ class RecoverySystemsModelTests(unittest.TestCase):
"actions": [], "actions": [],
}, },
] ]
}) }
)
RecoverySystemsModel.from_systems_stream(StringIO(raw)) RecoverySystemsModel.from_systems_stream(StringIO(raw))
def test_selection(self): def test_selection(self):
raw = json.dumps(self.reference) raw = json.dumps(self.reference)
model = RecoverySystemsModel.from_systems_stream(StringIO(raw)) model = RecoverySystemsModel.from_systems_stream(StringIO(raw))
model.select(model.systems[1], model.systems[1].actions[0]) model.select(model.systems[1], model.systems[1].actions[0])
self.assertEqual(model.selection, self.assertEqual(
model.selection,
SelectedSystemAction( SelectedSystemAction(
system=model.systems[1], system=model.systems[1], action=model.systems[1].actions[0]
action=model.systems[1].actions[0])) ),
)
model.unselect() model.unselect()
self.assertIsNone(model.selection) self.assertIsNone(model.selection)
@ -221,13 +238,16 @@ class RecoverySystemsModelTests(unittest.TestCase):
stream = StringIO() stream = StringIO()
RecoverySystemsModel.to_response_stream(model.selection, stream) RecoverySystemsModel.to_response_stream(model.selection, stream)
fromjson = json.loads(stream.getvalue()) fromjson = json.loads(stream.getvalue())
self.assertEqual(fromjson, { self.assertEqual(
fromjson,
{
"label": "other", "label": "other",
"action": { "action": {
"mode": "install", "mode": "install",
"title": "reinstall", "title": "reinstall",
}, },
}) },
)
def test_no_current(self): def test_no_current(self):
reference = { reference = {
@ -249,7 +269,7 @@ class RecoverySystemsModelTests(unittest.TestCase):
"actions": [ "actions": [
{"title": "reinstall", "mode": "install"}, {"title": "reinstall", "mode": "install"},
{"title": "recover", "mode": "recover"}, {"title": "recover", "mode": "recover"},
] ],
}, },
], ],
} }

View File

@@ -15,14 +15,10 @@
 """ ConsoleConf UI Views """
 
+from .chooser import ChooserConfirmView, ChooserCurrentSystemView, ChooserView
 from .identity import IdentityView
 from .login import LoginView
-from .welcome import WelcomeView, ChooserWelcomeView
-from .chooser import (
-    ChooserView,
-    ChooserCurrentSystemView,
-    ChooserConfirmView
-)
+from .welcome import ChooserWelcomeView, WelcomeView
 
 __all__ = [
     "IdentityView",

View File

@ -20,29 +20,14 @@ Chooser provides a view with recovery chooser actions.
""" """
import logging import logging
from urwid import ( from urwid import Text, connect_signal
connect_signal,
Text,
)
from subiquitycore.ui.buttons import (
danger_btn,
forward_btn,
back_btn,
)
from subiquitycore.ui.actionmenu import (
Action,
ActionMenu,
)
from subiquitycore.ui.container import Pile, ListBox
from subiquitycore.ui.utils import (
button_pile,
screen,
make_action_menu_row,
Color,
)
from subiquitycore.ui.table import TableRow, TablePile
from subiquitycore.view import BaseView
from subiquitycore.ui.actionmenu import Action, ActionMenu
from subiquitycore.ui.buttons import back_btn, danger_btn, forward_btn
from subiquitycore.ui.container import ListBox, Pile
from subiquitycore.ui.table import TablePile, TableRow
from subiquitycore.ui.utils import Color, button_pile, make_action_menu_row, screen
from subiquitycore.view import BaseView
log = logging.getLogger("console_conf.ui.views.chooser") log = logging.getLogger("console_conf.ui.views.chooser")
@ -69,28 +54,31 @@ class ChooserCurrentSystemView(ChooserBaseView):
def __init__(self, controller, current, has_more=False): def __init__(self, controller, current, has_more=False):
self.controller = controller self.controller = controller
log.debug('more systems available: %s', has_more) log.debug("more systems available: %s", has_more)
log.debug('current system: %s', current) log.debug("current system: %s", current)
actions = [] actions = []
for action in sorted(current.actions, key=by_preferred_action_type): for action in sorted(current.actions, key=by_preferred_action_type):
actions.append(forward_btn(label=action.title.capitalize(), actions.append(
forward_btn(
label=action.title.capitalize(),
on_press=self._current_system_action, on_press=self._current_system_action,
user_arg=(current, action))) user_arg=(current, action),
)
)
if has_more: if has_more:
# add a button to show the other systems # add a button to show the other systems
actions.append(Text("")) actions.append(Text(""))
actions.append(forward_btn(label="Show all available systems", actions.append(
on_press=self._more_options)) forward_btn(
label="Show all available systems", on_press=self._more_options
)
)
lb = ListBox(actions) lb = ListBox(actions)
super().__init__(current, super().__init__(current, screen(lb, narrow_rows=True, excerpt=self.excerpt))
screen(
lb,
narrow_rows=True,
excerpt=self.excerpt))
def _current_system_action(self, sender, arg): def _current_system_action(self, sender, arg):
current, action = arg current, action = arg
@ -104,43 +92,54 @@ class ChooserCurrentSystemView(ChooserBaseView):
class ChooserView(ChooserBaseView): class ChooserView(ChooserBaseView):
excerpt = ("Select one of available recovery systems and a desired " excerpt = (
"action to execute.") "Select one of available recovery systems and a desired " "action to execute."
)
def __init__(self, controller, systems): def __init__(self, controller, systems):
self.controller = controller self.controller = controller
heading_table = TablePile([ heading_table = TablePile(
TableRow([ [
Color.info_minor(Text(header)) for header in [ TableRow(
"LABEL", "MODEL", "PUBLISHER", "" [
Color.info_minor(Text(header))
for header in ["LABEL", "MODEL", "PUBLISHER", ""]
] ]
]) )
], ],
spacing=2) spacing=2,
)
trows = [] trows = []
systems = sorted(systems, systems = sorted(
key=lambda s: (s.brand.display_name, systems,
key=lambda s: (
s.brand.display_name,
s.model.display_name, s.model.display_name,
s.current, s.current,
s.label)) s.label,
),
)
for s in systems: for s in systems:
actions = [] actions = []
log.debug('actions: %s', s.actions) log.debug("actions: %s", s.actions)
for act in sorted(s.actions, key=by_preferred_action_type): for act in sorted(s.actions, key=by_preferred_action_type):
actions.append(Action(label=act.title.capitalize(), actions.append(
value=act, Action(label=act.title.capitalize(), value=act, enabled=True)
enabled=True)) )
menu = ActionMenu(actions) menu = ActionMenu(actions)
connect_signal(menu, 'action', self._system_action, s) connect_signal(menu, "action", self._system_action, s)
srow = make_action_menu_row([ srow = make_action_menu_row(
[
Text(s.label), Text(s.label),
Text(s.model.display_name), Text(s.model.display_name),
Text(s.brand.display_name), Text(s.brand.display_name),
Text("(installed)" if s.current else ""), Text("(installed)" if s.current else ""),
menu, menu,
], menu) ],
menu,
)
trows.append(srow) trows.append(srow)
systems_table = TablePile(trows, spacing=2) systems_table = TablePile(trows, spacing=2)
@ -154,12 +153,15 @@ class ChooserView(ChooserBaseView):
# back to options of current system # back to options of current system
buttons.append(back_btn("BACK", on_press=self.back)) buttons.append(back_btn("BACK", on_press=self.back))
super().__init__(controller.model.current, super().__init__(
controller.model.current,
screen( screen(
rows=rows, rows=rows,
buttons=button_pile(buttons), buttons=button_pile(buttons),
focus_buttons=False, focus_buttons=False,
excerpt=self.excerpt)) excerpt=self.excerpt,
),
)
def _system_action(self, sender, action, system): def _system_action(self, sender, action, system):
self.controller.select(system, action) self.controller.select(system, action)
@ -169,22 +171,24 @@ class ChooserView(ChooserBaseView):
class ChooserConfirmView(ChooserBaseView): class ChooserConfirmView(ChooserBaseView):
canned_summary = { canned_summary = {
"run": "Continue running the system without any changes.", "run": "Continue running the system without any changes.",
"recover": ("You have requested to reboot the system into recovery " "recover": ("You have requested to reboot the system into recovery " "mode."),
"mode."), "install": (
"install": ("You are about to {action_lower} the system version " "You are about to {action_lower} the system version "
"{version} for {model} from {publisher}.\n\n" "{version} for {model} from {publisher}.\n\n"
"This will remove all existing user data on the " "This will remove all existing user data on the "
"device.\n\n" "device.\n\n"
"The system will reboot in the process."), "The system will reboot in the process."
),
} }
default_summary = ("You are about to execute action \"{action}\" using " default_summary = (
'You are about to execute action "{action}" using '
"system version {version} for device {model} from " "system version {version} for device {model} from "
"{publisher}.\n\n" "{publisher}.\n\n"
"Make sure you understand the consequences of " "Make sure you understand the consequences of "
"performing this action.") "performing this action."
)
def __init__(self, controller, selection): def __init__(self, controller, selection):
self.controller = controller self.controller = controller
@ -193,21 +197,21 @@ class ChooserConfirmView(ChooserBaseView):
danger_btn("CONFIRM", on_press=self.confirm), danger_btn("CONFIRM", on_press=self.confirm),
back_btn("BACK", on_press=self.back), back_btn("BACK", on_press=self.back),
] ]
fmt = self.canned_summary.get(selection.action.mode, fmt = self.canned_summary.get(selection.action.mode, self.default_summary)
self.default_summary) summary = fmt.format(
summary = fmt.format(action=selection.action.title, action=selection.action.title,
action_lower=selection.action.title.lower(), action_lower=selection.action.title.lower(),
model=selection.system.model.display_name, model=selection.system.model.display_name,
publisher=selection.system.brand.display_name, publisher=selection.system.brand.display_name,
version=selection.system.label) version=selection.system.label,
)
rows = [ rows = [
Text(summary), Text(summary),
] ]
super().__init__(controller.model.current, super().__init__(
screen( controller.model.current,
rows=rows, screen(rows=rows, buttons=button_pile(buttons), focus_buttons=False),
buttons=button_pile(buttons), )
focus_buttons=False))
def confirm(self, result): def confirm(self, result):
self.controller.confirm() self.controller.confirm()

View File

@@ -14,20 +14,18 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 
 import logging
 
 from urwid import connect_signal
 
-from subiquitycore.view import BaseView
+from subiquitycore.ui.form import EmailField, Form
 from subiquitycore.ui.utils import SomethingFailed
-from subiquitycore.ui.form import (
-    Form,
-    EmailField,
-)
+from subiquitycore.view import BaseView
 
 log = logging.getLogger("console_conf.ui.views.identity")
 
-sso_help = ("If you do not have an account, visit "
-            "https://login.ubuntu.com to create one.")
+sso_help = (
+    "If you do not have an account, visit " "https://login.ubuntu.com to create one."
+)
 
 
 class IdentityForm(Form):

@@ -45,11 +43,12 @@ class IdentityView(BaseView):
         self.controller = controller
         self.form = IdentityForm()
 
-        connect_signal(self.form, 'submit', self.done)
-        connect_signal(self.form, 'cancel', self.cancel)
-        super().__init__(self.form.as_screen(focus_buttons=False,
-                                             excerpt=_(self.excerpt)))
+        connect_signal(self.form, "submit", self.done)
+        connect_signal(self.form, "cancel", self.cancel)
+        super().__init__(
+            self.form.as_screen(focus_buttons=False, excerpt=_(self.excerpt))
+        )
 
     def cancel(self, button=None):
         self.controller.cancel()

View File

@ -24,7 +24,7 @@ from urwid import Text
from subiquitycore.ui.buttons import done_btn from subiquitycore.ui.buttons import done_btn
from subiquitycore.ui.container import ListBox, Pile from subiquitycore.ui.container import ListBox, Pile
from subiquitycore.ui.utils import button_pile, Padding from subiquitycore.ui.utils import Padding, button_pile
from subiquitycore.view import BaseView from subiquitycore.view import BaseView
log = logging.getLogger("console_conf.ui.views.login") log = logging.getLogger("console_conf.ui.views.login")
@ -41,15 +41,23 @@ class LoginView(BaseView):
self.items = [] self.items = []
super().__init__( super().__init__(
Pile([ Pile(
('pack', Text("")), [
("pack", Text("")),
Padding.center_79(ListBox(self._build_model_inputs())), Padding.center_79(ListBox(self._build_model_inputs())),
('pack', Pile([ (
('pack', Text("")), "pack",
Pile(
[
("pack", Text("")),
button_pile(self._build_buttons()), button_pile(self._build_buttons()),
('pack', Text("")), ("pack", Text("")),
])), ]
])) ),
),
]
)
)
def _build_buttons(self): def _build_buttons(self):
return [ return [
@ -68,19 +76,19 @@ class LoginView(BaseView):
sl.append(Text("no owner")) sl.append(Text("no owner"))
return sl return sl
local_tpl = ( local_tpl = "This device is registered to {realname}."
"This device is registered to {realname}.")
remote_tpl = ( remote_tpl = (
"\n\nRemote access was enabled via authentication with SSO user" "\n\nRemote access was enabled via authentication with SSO user"
" <{username}>.\nPublic SSH keys were added to the device " " <{username}>.\nPublic SSH keys were added to the device "
"for remote access.\n\n{realname} can connect remotely to this " "for remote access.\n\n{realname} can connect remotely to this "
"device via SSH:") "device via SSH:"
)
sl = [] sl = []
login_info = { login_info = {
'realname': user.realname, "realname": user.realname,
'username': user.username, "username": user.username,
} }
login_text = local_tpl.format(**login_info) login_text = local_tpl.format(**login_info)
login_text += remote_tpl.format(**login_info) login_text += remote_tpl.format(**login_info)

View File

@@ -14,16 +14,11 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 
 import unittest
 
-from console_conf.models.systems import (
-    SystemAction,
-)
-from console_conf.ui.views.chooser import (
-    by_preferred_action_type,
-)
+from console_conf.models.systems import SystemAction
+from console_conf.ui.views.chooser import by_preferred_action_type
 
 
 class TestSorting(unittest.TestCase):
-
     def test_simple(self):
         data = [
             SystemAction(title="b-other", mode="unknown"),

@@ -39,5 +34,4 @@ class TestSorting(unittest.TestCase):
             SystemAction(title="a-other", mode="some-other"),
             SystemAction(title="b-other", mode="unknown"),
         ]
-        self.assertSequenceEqual(sorted(data, key=by_preferred_action_type),
-                                 exp)
+        self.assertSequenceEqual(sorted(data, key=by_preferred_action_type), exp)

View File

@@ -16,8 +16,9 @@
 """ Subiquity """
+
 from subiquitycore import i18n
 
 __all__ = [
-    'i18n',
+    "i18n",
 ]
 
 __version__ = "0.0.5"

View File

@@ -1,5 +1,6 @@
 import sys
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     from subiquity.cmd.tui import main
+
     sys.exit(main())

View File

@ -25,46 +25,28 @@ from typing import Callable, Dict, List, Optional, Union
import aiohttp import aiohttp
from subiquitycore.async_helpers import (
run_bg_task,
run_in_thread,
)
from subiquitycore.screen import is_linux_tty
from subiquitycore.tuicontroller import Skip
from subiquitycore.tui import TuiApplication
from subiquitycore.utils import orig_environ
from subiquitycore.view import BaseView
from subiquity.client.controller import Confirm from subiquity.client.controller import Confirm
from subiquity.client.keycodes import ( from subiquity.client.keycodes import KeyCodesFilter, NoOpKeycodesFilter
NoOpKeycodesFilter,
KeyCodesFilter,
)
from subiquity.common.api.client import make_client_for_conn from subiquity.common.api.client import make_client_for_conn
from subiquity.common.apidef import API from subiquity.common.apidef import API
from subiquity.common.errorreport import ( from subiquity.common.errorreport import ErrorReporter
ErrorReporter,
)
from subiquity.common.serialize import from_json from subiquity.common.serialize import from_json
from subiquity.common.types import ( from subiquity.common.types import ApplicationState, ErrorReportKind, ErrorReportRef
ApplicationState,
ErrorReportKind,
ErrorReportRef,
)
from subiquity.journald import journald_listen from subiquity.journald import journald_listen
from subiquity.server.server import POSTINSTALL_MODEL_NAMES from subiquity.server.server import POSTINSTALL_MODEL_NAMES
from subiquity.ui.frame import SubiquityUI from subiquity.ui.frame import SubiquityUI
from subiquity.ui.views.error import ErrorReportStretchy from subiquity.ui.views.error import ErrorReportStretchy
from subiquity.ui.views.help import HelpMenu, ssh_help_texts from subiquity.ui.views.help import HelpMenu, ssh_help_texts
from subiquity.ui.views.installprogress import ( from subiquity.ui.views.installprogress import InstallConfirmation
InstallConfirmation, from subiquity.ui.views.welcome import CloudInitFail
) from subiquitycore.async_helpers import run_bg_task, run_in_thread
from subiquity.ui.views.welcome import ( from subiquitycore.screen import is_linux_tty
CloudInitFail, from subiquitycore.tui import TuiApplication
) from subiquitycore.tuicontroller import Skip
from subiquitycore.utils import orig_environ
from subiquitycore.view import BaseView
log = logging.getLogger("subiquity.client.client")
log = logging.getLogger('subiquity.client.client')
class Abort(Exception): class Abort(Exception):
@ -72,7 +54,8 @@ class Abort(Exception):
self.error_report_ref = error_report_ref self.error_report_ref = error_report_ref
DEBUG_SHELL_INTRO = _("""\ DEBUG_SHELL_INTRO = _(
"""\
Installer shell session activated. Installer shell session activated.
This shell session is running inside the installer environment. You This shell session is running inside the installer environment. You
@ -81,18 +64,19 @@ example by typing Control-D or 'exit'.
Be aware that this is an ephemeral environment. Changes to this Be aware that this is an ephemeral environment. Changes to this
environment will not survive a reboot. If the install has started, the environment will not survive a reboot. If the install has started, the
installed system will be mounted at /target.""") installed system will be mounted at /target."""
)
class SubiquityClient(TuiApplication): class SubiquityClient(TuiApplication):
snapd_socket_path: Optional[str] = "/run/snapd.socket"
snapd_socket_path: Optional[str] = '/run/snapd.socket'
variant: Optional[str] = None variant: Optional[str] = None
cmdline = ['snap', 'run', 'subiquity'] cmdline = ["snap", "run", "subiquity"]
dryrun_cmdline_module = 'subiquity.cmd.tui' dryrun_cmdline_module = "subiquity.cmd.tui"
from subiquity.client import controllers as controllers_mod from subiquity.client import controllers as controllers_mod
project = "subiquity" project = "subiquity"
def make_model(self): def make_model(self):
@ -142,11 +126,11 @@ class SubiquityClient(TuiApplication):
except OSError: except OSError:
self.our_tty = "not a tty" self.our_tty = "not a tty"
self.in_make_view_cvar = contextvars.ContextVar( self.in_make_view_cvar = contextvars.ContextVar("in_make_view", default=False)
'in_make_view', default=False)
self.error_reporter = ErrorReporter( self.error_reporter = ErrorReporter(
self.context.child("ErrorReporter"), self.opts.dry_run, self.root) self.context.child("ErrorReporter"), self.opts.dry_run, self.root
)
self.note_data_for_apport("SnapUpdated", str(self.updated)) self.note_data_for_apport("SnapUpdated", str(self.updated))
self.note_data_for_apport("UsingAnswers", str(bool(self.answers))) self.note_data_for_apport("UsingAnswers", str(bool(self.answers)))
@ -169,12 +153,9 @@ class SubiquityClient(TuiApplication):
def restart(self, remove_last_screen=True, restart_server=False): def restart(self, remove_last_screen=True, restart_server=False):
log.debug(f"restart {remove_last_screen} {restart_server}") log.debug(f"restart {remove_last_screen} {restart_server}")
if self.fg_proc is not None: if self.fg_proc is not None:
log.debug( log.debug("killing foreground process %s before restarting", self.fg_proc)
"killing foreground process %s before restarting",
self.fg_proc)
self.restarting = True self.restarting = True
run_bg_task( run_bg_task(self._kill_fg_proc(remove_last_screen, restart_server))
self._kill_fg_proc(remove_last_screen, restart_server))
return return
if remove_last_screen: if remove_last_screen:
self._remove_last_screen() self._remove_last_screen()
@ -187,50 +168,56 @@ class SubiquityClient(TuiApplication):
self.urwid_loop.stop() self.urwid_loop.stop()
cmdline = self.cmdline cmdline = self.cmdline
if self.opts.dry_run: if self.opts.dry_run:
cmdline = [ cmdline = (
sys.executable, '-m', self.dryrun_cmdline_module, [
] + sys.argv[1:] + ['--socket', self.opts.socket] sys.executable,
"-m",
self.dryrun_cmdline_module,
]
+ sys.argv[1:]
+ ["--socket", self.opts.socket]
)
if self.opts.server_pid is not None: if self.opts.server_pid is not None:
cmdline.extend(['--server-pid', self.opts.server_pid]) cmdline.extend(["--server-pid", self.opts.server_pid])
log.debug("restarting %r", cmdline) log.debug("restarting %r", cmdline)
os.execvpe(cmdline[0], cmdline, orig_environ(os.environ)) os.execvpe(cmdline[0], cmdline, orig_environ(os.environ))
def resp_hook(self, response): def resp_hook(self, response):
headers = response.headers headers = response.headers
if 'x-updated' in headers: if "x-updated" in headers:
if self.server_updated is None: if self.server_updated is None:
self.server_updated = headers['x-updated'] self.server_updated = headers["x-updated"]
elif self.server_updated != headers['x-updated']: elif self.server_updated != headers["x-updated"]:
self.restart(remove_last_screen=False) self.restart(remove_last_screen=False)
raise Abort raise Abort
status = headers.get('x-status') status = headers.get("x-status")
if status == 'skip': if status == "skip":
raise Skip raise Skip
elif status == 'confirm': elif status == "confirm":
raise Confirm raise Confirm
if headers.get('x-error-report') is not None: if headers.get("x-error-report") is not None:
ref = from_json(ErrorReportRef, headers['x-error-report']) ref = from_json(ErrorReportRef, headers["x-error-report"])
raise Abort(ref) raise Abort(ref)
try: try:
response.raise_for_status() response.raise_for_status()
except aiohttp.ClientError: except aiohttp.ClientError:
report = self.error_reporter.make_apport_report( report = self.error_reporter.make_apport_report(
ErrorReportKind.SERVER_REQUEST_FAIL, ErrorReportKind.SERVER_REQUEST_FAIL,
"request to {}".format(response.url.path)) "request to {}".format(response.url.path),
)
raise Abort(report.ref()) raise Abort(report.ref())
return response return response
async def noninteractive_confirmation(self): async def noninteractive_confirmation(self):
await asyncio.sleep(1) await asyncio.sleep(1)
yes = _('yes') yes = _("yes")
no = _('no') no = _("no")
answer = no answer = no
print(_("Confirmation is required to continue.")) print(_("Confirmation is required to continue."))
print(_("Add 'autoinstall' to your kernel command line to avoid this")) print(_("Add 'autoinstall' to your kernel command line to avoid this"))
print() print()
prompt = "\n\n{} ({}|{})".format( prompt = "\n\n{} ({}|{})".format(_("Continue with autoinstall?"), yes, no)
_("Continue with autoinstall?"), yes, no)
while answer != yes: while answer != yes:
print(prompt) print(prompt)
answer = await run_in_thread(input) answer = await run_in_thread(input)
@ -260,7 +247,8 @@ class SubiquityClient(TuiApplication):
if app_state == ApplicationState.NEEDS_CONFIRMATION: if app_state == ApplicationState.NEEDS_CONFIRMATION:
if confirm_task is None: if confirm_task is None:
confirm_task = asyncio.create_task( confirm_task = asyncio.create_task(
self.noninteractive_confirmation()) self.noninteractive_confirmation()
)
elif confirm_task is not None: elif confirm_task is not None:
confirm_task.cancel() confirm_task.cancel()
confirm_task = None confirm_task = None
@ -272,21 +260,20 @@ class SubiquityClient(TuiApplication):
app_status = await self._status_get(app_state) app_status = await self._status_get(app_state)
def subiquity_event_noninteractive(self, event): def subiquity_event_noninteractive(self, event):
if event['SUBIQUITY_EVENT_TYPE'] == 'start': if event["SUBIQUITY_EVENT_TYPE"] == "start":
print('start: ' + event["MESSAGE"]) print("start: " + event["MESSAGE"])
elif event['SUBIQUITY_EVENT_TYPE'] == 'finish': elif event["SUBIQUITY_EVENT_TYPE"] == "finish":
print('finish: ' + event["MESSAGE"]) print("finish: " + event["MESSAGE"])
async def connect(self): async def connect(self):
def p(s): def p(s):
print(s, end='', flush=True) print(s, end="", flush=True)
async def spin(message): async def spin(message):
p(message + '... ') p(message + "... ")
while True: while True:
for t in ['-', '\\', '|', '/']: for t in ["-", "\\", "|", "/"]:
p('\x08' + t) p("\x08" + t)
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
async def spinning_wait(message, task): async def spinning_wait(message, task):
@ -295,16 +282,18 @@ class SubiquityClient(TuiApplication):
return await task return await task
finally: finally:
spinner.cancel() spinner.cancel()
p('\x08 \n') p("\x08 \n")
status = await spinning_wait("connecting", self._status_get()) status = await spinning_wait("connecting", self._status_get())
journald_listen([status.echo_syslog_id], lambda e: print(e['MESSAGE'])) journald_listen([status.echo_syslog_id], lambda e: print(e["MESSAGE"]))
if status.state == ApplicationState.STARTING_UP: if status.state == ApplicationState.STARTING_UP:
status = await spinning_wait( status = await spinning_wait(
"starting up", self._status_get(cur=status.state)) "starting up", self._status_get(cur=status.state)
)
if status.state == ApplicationState.CLOUD_INIT_WAIT: if status.state == ApplicationState.CLOUD_INIT_WAIT:
status = await spinning_wait( status = await spinning_wait(
"waiting for cloud-init", self._status_get(cur=status.state)) "waiting for cloud-init", self._status_get(cur=status.state)
)
if status.state == ApplicationState.EARLY_COMMANDS: if status.state == ApplicationState.EARLY_COMMANDS:
print("running early commands") print("running early commands")
status = await self._status_get(cur=status.state) status = await self._status_get(cur=status.state)
@ -316,12 +305,13 @@ class SubiquityClient(TuiApplication):
def header_func(): def header_func():
if self.in_make_view_cvar.get(): if self.in_make_view_cvar.get():
return {'x-make-view-request': 'yes'} return {"x-make-view-request": "yes"}
else: else:
return None return None
self.client = make_client_for_conn(API, conn, self.resp_hook, self.client = make_client_for_conn(
header_func=header_func) API, conn, self.resp_hook, header_func=header_func
)
self.error_reporter.client = self.client self.error_reporter.client = self.client
status = await self.connect() status = await self.connect()
@ -332,11 +322,14 @@ class SubiquityClient(TuiApplication):
texts = ssh_help_texts(ssh_info) texts = ssh_help_texts(ssh_info)
for line in texts: for line in texts:
import urwid import urwid
if isinstance(line, urwid.Widget): if isinstance(line, urwid.Widget):
line = '\n'.join([ line = "\n".join(
line.decode('utf-8').rstrip() [
line.decode("utf-8").rstrip()
for line in line.render((1000,)).text for line in line.render((1000,)).text
]) ]
)
print(line) print(line)
return return
@ -355,11 +348,11 @@ class SubiquityClient(TuiApplication):
# the progress page # the progress page
if hasattr(self.controllers, "Progress"): if hasattr(self.controllers, "Progress"):
journald_listen( journald_listen(
[status.event_syslog_id], [status.event_syslog_id], self.controllers.Progress.event
self.controllers.Progress.event) )
journald_listen( journald_listen(
[status.log_syslog_id], [status.log_syslog_id], self.controllers.Progress.log_line
self.controllers.Progress.log_line) )
if not status.cloud_init_ok: if not status.cloud_init_ok:
self.add_global_overlay(CloudInitFail(self)) self.add_global_overlay(CloudInitFail(self))
self.error_reporter.load_reports() self.error_reporter.load_reports()
@ -376,18 +369,16 @@ class SubiquityClient(TuiApplication):
# does not matter as the settings will soon be clobbered but # does not matter as the settings will soon be clobbered but
# for a non-interactive one we need to clear things up or the # for a non-interactive one we need to clear things up or the
# prompting for confirmation will be confusing. # prompting for confirmation will be confusing.
os.system('stty sane') os.system("stty sane")
journald_listen( journald_listen(
[status.event_syslog_id], [status.event_syslog_id], self.subiquity_event_noninteractive, seek=True
self.subiquity_event_noninteractive, )
seek=True) run_bg_task(self.noninteractive_watch_app_state(status))
run_bg_task(
self.noninteractive_watch_app_state(status))
def _exception_handler(self, loop, context): def _exception_handler(self, loop, context):
exc = context.get('exception') exc = context.get("exception")
if self.restarting: if self.restarting:
log.debug('ignoring %s %s during restart', exc, type(exc)) log.debug("ignoring %s %s during restart", exc, type(exc))
return return
if isinstance(exc, Abort): if isinstance(exc, Abort):
if self.interactive: if self.interactive:
@ -405,8 +396,8 @@ class SubiquityClient(TuiApplication):
print("generating crash report") print("generating crash report")
try: try:
report = self.make_apport_report( report = self.make_apport_report(
ErrorReportKind.UI, "Installer UI", interrupt=False, ErrorReportKind.UI, "Installer UI", interrupt=False, wait=True
wait=True) )
if report is not None: if report is not None:
print("report saved to {path}".format(path=report.path)) print("report saved to {path}".format(path=report.path))
except Exception: except Exception:
@ -426,7 +417,7 @@ class SubiquityClient(TuiApplication):
# the server up to a second to exit, and then we signal it. # the server up to a second to exit, and then we signal it.
pid = int(self.opts.server_pid) pid = int(self.opts.server_pid)
print(f'giving the server [{pid}] up to a second to exit') print(f"giving the server [{pid}] up to a second to exit")
for unused in range(10): for unused in range(10):
try: try:
if os.waitpid(pid, os.WNOHANG) != (0, 0): if os.waitpid(pid, os.WNOHANG) != (0, 0):
@ -435,9 +426,9 @@ class SubiquityClient(TuiApplication):
# If we attached to an existing server process, # If we attached to an existing server process,
# waitpid will fail. # waitpid will fail.
pass pass
await asyncio.sleep(.1) await asyncio.sleep(0.1)
else: else:
print('killing server {}'.format(pid)) print("killing server {}".format(pid))
os.kill(pid, 2) os.kill(pid, 2)
os.waitpid(pid, 0) os.waitpid(pid, 0)
@ -448,21 +439,19 @@ class SubiquityClient(TuiApplication):
if source.id == source_selection.current_id: if source.id == source_selection.current_id:
current = source current = source
break break
if current is not None and current.variant != 'server': if current is not None and current.variant != "server":
# If using server to install desktop, mark the controllers # If using server to install desktop, mark the controllers
# the TUI client does not currently have interfaces for as # the TUI client does not currently have interfaces for as
# configured. # configured.
needed = POSTINSTALL_MODEL_NAMES.for_variant(current.variant) needed = POSTINSTALL_MODEL_NAMES.for_variant(current.variant)
for c in self.controllers.instances: for c in self.controllers.instances:
if getattr(c, 'endpoint_name', None) is not None: if getattr(c, "endpoint_name", None) is not None:
needed.discard(c.endpoint_name) needed.discard(c.endpoint_name)
if needed: if needed:
log.info( log.info("marking additional endpoints as configured: %s", needed)
"marking additional endpoints as configured: %s",
needed)
await self.client.meta.mark_configured.POST(list(needed)) await self.client.meta.mark_configured.POST(list(needed))
# TODO: remove this when TUI gets an Active Directory screen: # TODO: remove this when TUI gets an Active Directory screen:
await self.client.meta.mark_configured.POST(['active_directory']) await self.client.meta.mark_configured.POST(["active_directory"])
await self.client.meta.confirm.POST(self.our_tty) await self.client.meta.confirm.POST(self.our_tty)
def add_global_overlay(self, overlay): def add_global_overlay(self, overlay):
@ -477,7 +466,7 @@ class SubiquityClient(TuiApplication):
self.ui.body.remove_overlay(overlay) self.ui.body.remove_overlay(overlay)
def _remove_last_screen(self): def _remove_last_screen(self):
last_screen = self.state_path('last-screen') last_screen = self.state_path("last-screen")
if os.path.exists(last_screen): if os.path.exists(last_screen):
os.unlink(last_screen) os.unlink(last_screen)
@ -488,7 +477,7 @@ class SubiquityClient(TuiApplication):
def select_initial_screen(self): def select_initial_screen(self):
last_screen = None last_screen = None
if self.updated: if self.updated:
state_path = self.state_path('last-screen') state_path = self.state_path("last-screen")
if os.path.exists(state_path): if os.path.exists(state_path):
with open(state_path) as fp: with open(state_path) as fp:
last_screen = fp.read().strip() last_screen = fp.read().strip()
@ -502,7 +491,7 @@ class SubiquityClient(TuiApplication):
async def _select_initial_screen(self, index): async def _select_initial_screen(self, index):
endpoint_names = [] endpoint_names = []
for c in self.controllers.instances[:index]: for c in self.controllers.instances[:index]:
if getattr(c, 'endpoint_name', None) is not None: if getattr(c, "endpoint_name", None) is not None:
endpoint_names.append(c.endpoint_name) endpoint_names.append(c.endpoint_name)
if endpoint_names: if endpoint_names:
await self.client.meta.mark_configured.POST(endpoint_names) await self.client.meta.mark_configured.POST(endpoint_names)
@ -521,11 +510,12 @@ class SubiquityClient(TuiApplication):
log.debug("showing InstallConfirmation over %s", self.ui.body) log.debug("showing InstallConfirmation over %s", self.ui.body)
overlay = InstallConfirmation(self) overlay = InstallConfirmation(self)
self.add_global_overlay(overlay) self.add_global_overlay(overlay)
if self.answers.get('filesystem-confirmed', False): if self.answers.get("filesystem-confirmed", False):
overlay.ok(None) overlay.ok(None)
async def _start_answers_for_view( async def _start_answers_for_view(
self, controller, view: Union[BaseView, Callable[[], BaseView]]): self, controller, view: Union[BaseView, Callable[[], BaseView]]
):
def noop(): def noop():
return view return view
@ -551,19 +541,19 @@ class SubiquityClient(TuiApplication):
self.in_make_view_cvar.reset(tok) self.in_make_view_cvar.reset(tok)
if new.answers: if new.answers:
run_bg_task(self._start_answers_for_view(new, view)) run_bg_task(self._start_answers_for_view(new, view))
with open(self.state_path('last-screen'), 'w') as fp: with open(self.state_path("last-screen"), "w") as fp:
fp.write(new.name) fp.write(new.name)
return view return view
def show_progress(self): def show_progress(self):
if hasattr(self.controllers, 'Progress'): if hasattr(self.controllers, "Progress"):
self.ui.set_body(self.controllers.Progress.progress_view) self.ui.set_body(self.controllers.Progress.progress_view)
def unhandled_input(self, key): def unhandled_input(self, key):
if key == 'f1': if key == "f1":
if not self.ui.right_icon.current_help: if not self.ui.right_icon.current_help:
self.ui.right_icon.open_pop_up() self.ui.right_icon.open_pop_up()
elif key in ['ctrl z', 'f2']: elif key in ["ctrl z", "f2"]:
self.debug_shell() self.debug_shell()
elif self.opts.dry_run: elif self.opts.dry_run:
self.unhandled_input_dry_run(key) self.unhandled_input_dry_run(key)
@ -571,22 +561,22 @@ class SubiquityClient(TuiApplication):
super().unhandled_input(key) super().unhandled_input(key)
def unhandled_input_dry_run(self, key): def unhandled_input_dry_run(self, key):
if key in ['ctrl e', 'ctrl r']: if key in ["ctrl e", "ctrl r"]:
interrupt = key == 'ctrl e' interrupt = key == "ctrl e"
try: try:
1/0 1 / 0
except ZeroDivisionError: except ZeroDivisionError:
self.make_apport_report( self.make_apport_report(
ErrorReportKind.UNKNOWN, "example", interrupt=interrupt) ErrorReportKind.UNKNOWN, "example", interrupt=interrupt
elif key == 'ctrl u': )
1/0 elif key == "ctrl u":
elif key == 'ctrl b': 1 / 0
elif key == "ctrl b":
run_bg_task(self.client.dry_run.crash.GET()) run_bg_task(self.client.dry_run.crash.GET())
else: else:
super().unhandled_input(key) super().unhandled_input(key)
def debug_shell(self, after_hook=None): def debug_shell(self, after_hook=None):
def _before(): def _before():
os.system("clear") os.system("clear")
print(DEBUG_SHELL_INTRO) print(DEBUG_SHELL_INTRO)
@ -594,7 +584,8 @@ class SubiquityClient(TuiApplication):
env = orig_environ(os.environ) env = orig_environ(os.environ)
cmd = ["bash"] cmd = ["bash"]
self.run_command_in_foreground( self.run_command_in_foreground(
cmd, env=env, before_hook=_before, after_hook=after_hook, cwd='/') cmd, env=env, before_hook=_before, after_hook=after_hook, cwd="/"
)
def note_file_for_apport(self, key, path): def note_file_for_apport(self, key, path):
self.error_reporter.note_file_for_apport(key, path) self.error_reporter.note_file_for_apport(key, path)
@ -603,8 +594,7 @@ class SubiquityClient(TuiApplication):
self.error_reporter.note_data_for_apport(key, value) self.error_reporter.note_data_for_apport(key, value)
def make_apport_report(self, kind, thing, *, interrupt, wait=False, **kw): def make_apport_report(self, kind, thing, *, interrupt, wait=False, **kw):
report = self.error_reporter.make_apport_report( report = self.error_reporter.make_apport_report(kind, thing, wait=wait, **kw)
kind, thing, wait=wait, **kw)
if report is not None and interrupt: if report is not None and interrupt:
self.show_error_report(report.ref()) self.show_error_report(report.ref())
@ -614,7 +604,7 @@ class SubiquityClient(TuiApplication):
def show_error_report(self, error_ref): def show_error_report(self, error_ref):
log.debug("show_error_report %r", error_ref.base) log.debug("show_error_report %r", error_ref.base)
if isinstance(self.ui.body, BaseView): if isinstance(self.ui.body, BaseView):
w = getattr(self.ui.body._w, 'stretchy', None) w = getattr(self.ui.body._w, "stretchy", None)
if isinstance(w, ErrorReportStretchy): if isinstance(w, ErrorReportStretchy):
# Don't show an error if already looking at one. # Don't show an error if already looking at one.
return return
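
Almost every hunk in this file is the same mechanical transformation rather than a behaviour change: black normalizes string literals to double quotes and re-wraps expressions to its default 88-column limit, joining hand-wrapped calls that now fit on one line and exploding the rest one argument per line with a trailing comma. An illustrative, self-contained before/after sketch of that pattern (the function and its names are invented for the example, not taken from subiquity):

import json


def describe(name: str, version: str, *, dry_run: bool = False) -> str:
    # Before black (hypothetical hand-formatted original):
    #     return json.dumps({'name': name,
    #                        'version': version,
    #                        'dry_run': dry_run})
    # After black: quotes become double quotes and, because the one-line form
    # fits within the default 88-column limit, the call is joined back up.
    return json.dumps({"name": name, "version": version, "dry_run": dry_run})


if __name__ == "__main__":
    print(describe("subiquity", "23.04", dry_run=True))
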


@ -16,9 +16,7 @@
import logging import logging
from typing import Optional from typing import Optional
from subiquitycore.tuicontroller import ( from subiquitycore.tuicontroller import TuiController
TuiController,
)
log = logging.getLogger("subiquity.client.controller") log = logging.getLogger("subiquity.client.controller")
@ -28,7 +26,6 @@ class Confirm(Exception):
class SubiquityTuiController(TuiController): class SubiquityTuiController(TuiController):
endpoint_name: Optional[str] = None endpoint_name: Optional[str] = None
def __init__(self, app): def __init__(self, app):


@ -14,6 +14,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from subiquitycore.tuicontroller import RepeatedController from subiquitycore.tuicontroller import RepeatedController
from .drivers import DriversController from .drivers import DriversController
from .filesystem import FilesystemController from .filesystem import FilesystemController
from .identity import IdentityController from .identity import IdentityController
@ -33,21 +34,21 @@ from .zdev import ZdevController
# see SubiquityClient.controllers for another list # see SubiquityClient.controllers for another list
__all__ = [ __all__ = [
'DriversController', "DriversController",
'FilesystemController', "FilesystemController",
'IdentityController', "IdentityController",
'KeyboardController', "KeyboardController",
'MirrorController', "MirrorController",
'NetworkController', "NetworkController",
'ProgressController', "ProgressController",
'ProxyController', "ProxyController",
'RefreshController', "RefreshController",
'RepeatedController', "RepeatedController",
'SerialController', "SerialController",
'SnapListController', "SnapListController",
'SourceController', "SourceController",
'SSHController', "SSHController",
'UbuntuProController', "UbuntuProController",
'WelcomeController', "WelcomeController",
'ZdevController', "ZdevController",
] ]
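
The import churn at the top of this and the following files is isort running with its black-compatible profile: imports are split into stdlib, third-party, and first-party sections, each section alphabetized, plain import lines grouped ahead of from-imports (the ordering visible in the hunks below), and parenthesized single-name imports collapsed onto one line. A small standalone sketch of the resulting layout (stdlib only, so it runs anywhere; in the real tree aiohttp and the subiquity/subiquitycore packages form the later sections):

# isort (black profile): sections separated by a blank line, "import" lines
# before "from" imports, names alphabetized within each section.
import asyncio
import logging
from typing import List, Optional

# Third-party (e.g. aiohttp) and first-party (subiquity*, subiquitycore*)
# sections would follow here in the real tree.

log = logging.getLogger("isort_example")


async def main(names: Optional[List[str]] = None) -> None:
    log.debug("screens: %s", names or [])
    await asyncio.sleep(0)


if __name__ == "__main__":
    asyncio.run(main(["drivers", "filesystem"]))
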


@ -17,18 +17,16 @@ import asyncio
import logging import logging
from typing import List from typing import List
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import DriversPayload, DriversResponse
from subiquity.ui.views.drivers import DriversView, DriversViewStatus
from subiquitycore.tuicontroller import Skip from subiquitycore.tuicontroller import Skip
from subiquity.common.types import DriversPayload, DriversResponse log = logging.getLogger("subiquity.client.controllers.drivers")
from subiquity.client.controller import SubiquityTuiController
from subiquity.ui.views.drivers import DriversView, DriversViewStatus
log = logging.getLogger('subiquity.client.controllers.drivers')
class DriversController(SubiquityTuiController): class DriversController(SubiquityTuiController):
endpoint_name = "drivers"
endpoint_name = 'drivers'
async def make_ui(self) -> DriversView: async def make_ui(self) -> DriversView:
response: DriversResponse = await self.endpoint.GET() response: DriversResponse = await self.endpoint.GET()
@ -37,8 +35,9 @@ class DriversController(SubiquityTuiController):
await self.endpoint.POST(DriversPayload(install=False)) await self.endpoint.POST(DriversPayload(install=False))
raise Skip raise Skip
return DriversView(self, response.drivers, return DriversView(
response.install, response.local_only) self, response.drivers, response.install, response.local_only
)
async def _wait_drivers(self) -> List[str]: async def _wait_drivers(self) -> List[str]:
response: DriversResponse = await self.endpoint.GET(wait=True) response: DriversResponse = await self.endpoint.GET(wait=True)
@ -46,12 +45,10 @@ class DriversController(SubiquityTuiController):
return response.drivers return response.drivers
async def run_answers(self): async def run_answers(self):
if 'install' not in self.answers: if "install" not in self.answers:
return return
from subiquitycore.testing.view_helpers import ( from subiquitycore.testing.view_helpers import click
click,
)
view = self.app.ui.body view = self.app.ui.body
while view.status == DriversViewStatus.WAITING: while view.status == DriversViewStatus.WAITING:
@ -60,7 +57,7 @@ class DriversController(SubiquityTuiController):
click(view.cont_btn.base_widget) click(view.cont_btn.base_widget)
return return
view.form.install.value = self.answers['install'] view.form.install.value = self.answers["install"]
click(view.form.done_btn.base_widget) click(view.form.done_btn.base_widget)
@ -69,5 +66,4 @@ class DriversController(SubiquityTuiController):
def done(self, install: bool) -> None: def done(self, install: bool) -> None:
log.debug("DriversController.done next_screen install=%s", install) log.debug("DriversController.done next_screen install=%s", install)
self.app.next_screen( self.app.next_screen(self.endpoint.POST(DriversPayload(install=install)))
self.endpoint.POST(DriversPayload(install=install)))


@ -17,50 +17,37 @@ import asyncio
import logging import logging
from typing import Callable, Optional from typing import Callable, Optional
from subiquitycore.async_helpers import run_bg_task
from subiquitycore.lsb_release import lsb_release
from subiquitycore.view import BaseView
from subiquity.client.controller import SubiquityTuiController from subiquity.client.controller import SubiquityTuiController
from subiquity.common.filesystem import gaps from subiquity.common.filesystem import gaps
from subiquity.common.filesystem.manipulator import FilesystemManipulator from subiquity.common.filesystem.manipulator import FilesystemManipulator
from subiquity.common.types import ( from subiquity.common.types import (
ProbeStatus,
GuidedCapability, GuidedCapability,
GuidedChoiceV2, GuidedChoiceV2,
GuidedStorageResponseV2, GuidedStorageResponseV2,
GuidedStorageTargetManual, GuidedStorageTargetManual,
ProbeStatus,
StorageResponseV2, StorageResponseV2,
) )
from subiquity.models.filesystem import ( from subiquity.models.filesystem import Bootloader, FilesystemModel, raidlevels_by_value
Bootloader, from subiquity.ui.views import FilesystemView, GuidedDiskSelectionView
FilesystemModel, from subiquity.ui.views.filesystem.probing import ProbingFailed, SlowProbing
raidlevels_by_value, from subiquitycore.async_helpers import run_bg_task
) from subiquitycore.lsb_release import lsb_release
from subiquity.ui.views import ( from subiquitycore.view import BaseView
FilesystemView,
GuidedDiskSelectionView,
)
from subiquity.ui.views.filesystem.probing import (
SlowProbing,
ProbingFailed,
)
log = logging.getLogger("subiquity.client.controllers.filesystem") log = logging.getLogger("subiquity.client.controllers.filesystem")
class FilesystemController(SubiquityTuiController, FilesystemManipulator): class FilesystemController(SubiquityTuiController, FilesystemManipulator):
endpoint_name = "storage"
endpoint_name = 'storage'
def __init__(self, app): def __init__(self, app):
super().__init__(app) super().__init__(app)
self.model = None self.model = None
self.answers.setdefault('guided', False) self.answers.setdefault("guided", False)
self.answers.setdefault('tpm-default', False) self.answers.setdefault("tpm-default", False)
self.answers.setdefault('guided-index', 0) self.answers.setdefault("guided-index", 0)
self.answers.setdefault('manual', []) self.answers.setdefault("manual", [])
self.current_view: Optional[BaseView] = None self.current_view: Optional[BaseView] = None
async def make_ui(self) -> Callable[[], BaseView]: async def make_ui(self) -> Callable[[], BaseView]:
@ -90,8 +77,7 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
if isinstance(self.ui.body, SlowProbing): if isinstance(self.ui.body, SlowProbing):
self.ui.set_body(self.current_view) self.ui.set_body(self.current_view)
else: else:
log.debug("not refreshing the display. Current display is %r", log.debug("not refreshing the display. Current display is %r", self.ui.body)
self.ui.body)
async def make_guided_ui( async def make_guided_ui(
self, self,
@ -101,12 +87,9 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
self.app.show_error_report(status.error_report) self.app.show_error_report(status.error_report)
return ProbingFailed(self, status.error_report) return ProbingFailed(self, status.error_report)
response: StorageResponseV2 = await self.endpoint.v2.GET( response: StorageResponseV2 = await self.endpoint.v2.GET(include_raid=True)
include_raid=True)
disk_by_id = { disk_by_id = {disk.id: disk for disk in response.disks}
disk.id: disk for disk in response.disks
}
if status.error_report: if status.error_report:
self.app.show_error_report(status.error_report) self.app.show_error_report(status.error_report)
@ -118,48 +101,46 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
while not isinstance(self.ui.body, GuidedDiskSelectionView): while not isinstance(self.ui.body, GuidedDiskSelectionView):
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
if self.answers['tpm-default']: if self.answers["tpm-default"]:
self.app.answers['filesystem-confirmed'] = True self.app.answers["filesystem-confirmed"] = True
self.ui.body.done(self.ui.body.form) self.ui.body.done(self.ui.body.form)
if self.answers['guided']: if self.answers["guided"]:
targets = self.ui.body.form.targets targets = self.ui.body.form.targets
if 'guided-index' in self.answers: if "guided-index" in self.answers:
target = targets[self.answers['guided-index']] target = targets[self.answers["guided-index"]]
elif 'guided-label' in self.answers: elif "guided-label" in self.answers:
label = self.answers['guided-label'] label = self.answers["guided-label"]
disk_by_id = self.ui.body.form.disk_by_id disk_by_id = self.ui.body.form.disk_by_id
[target] = [ [target] = [t for t in targets if disk_by_id[t.disk_id].label == label]
t method = self.answers.get("guided-method")
for t in targets
if disk_by_id[t.disk_id].label == label
]
method = self.answers.get('guided-method')
value = { value = {
'disk': target, "disk": target,
'use_lvm': method == "lvm", "use_lvm": method == "lvm",
} }
passphrase = self.answers.get('guided-passphrase') passphrase = self.answers.get("guided-passphrase")
if passphrase is not None: if passphrase is not None:
value['lvm_options'] = { value["lvm_options"] = {
'encrypt': True, "encrypt": True,
'luks_options': { "luks_options": {
'passphrase': passphrase, "passphrase": passphrase,
'confirm_passphrase': passphrase, "confirm_passphrase": passphrase,
} },
} }
self.ui.body.form.guided_choice.value = value self.ui.body.form.guided_choice.value = value
self.ui.body.done(None) self.ui.body.done(None)
self.app.answers['filesystem-confirmed'] = True self.app.answers["filesystem-confirmed"] = True
while not isinstance(self.ui.body, FilesystemView): while not isinstance(self.ui.body, FilesystemView):
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
self.finish() self.finish()
elif self.answers['manual']: elif self.answers["manual"]:
await self._guided_choice( await self._guided_choice(
GuidedChoiceV2( GuidedChoiceV2(
target=GuidedStorageTargetManual(), target=GuidedStorageTargetManual(),
capability=GuidedCapability.MANUAL)) capability=GuidedCapability.MANUAL,
await self._run_actions(self.answers['manual']) )
self.answers['manual'] = [] )
await self._run_actions(self.answers["manual"])
self.answers["manual"] = []
def _action_get(self, id): def _action_get(self, id):
dev_spec = id[0].split() dev_spec = id[0].split()
@ -168,7 +149,7 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
if dev_spec[1] == "index": if dev_spec[1] == "index":
dev = self.model.all_disks()[int(dev_spec[2])] dev = self.model.all_disks()[int(dev_spec[2])]
elif dev_spec[1] == "serial": elif dev_spec[1] == "serial":
dev = self.model._one(type='disk', serial=dev_spec[2]) dev = self.model._one(type="disk", serial=dev_spec[2])
elif dev_spec[0] == "raid": elif dev_spec[0] == "raid":
if dev_spec[1] == "name": if dev_spec[1] == "name":
for r in self.model.all_raids(): for r in self.model.all_raids():
@ -192,16 +173,13 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
raise Exception("could not resolve {}".format(id)) raise Exception("could not resolve {}".format(id))
def _action_clean_devices_raid(self, devices): def _action_clean_devices_raid(self, devices):
r = { r = {self._action_get(d): v for d, v in zip(devices[::2], devices[1::2])}
self._action_get(d): v
for d, v in zip(devices[::2], devices[1::2])
}
for d in r: for d in r:
assert d.ok_for_raid assert d.ok_for_raid
return r return r
def _action_clean_devices_vg(self, devices): def _action_clean_devices_vg(self, devices):
r = {self._action_get(d): 'active' for d in devices} r = {self._action_get(d): "active" for d in devices}
for d in r: for d in r:
assert d.ok_for_lvm_vg assert d.ok_for_lvm_vg
return r return r
@ -210,12 +188,13 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
return raidlevels_by_value[level] return raidlevels_by_value[level]
async def _answers_action(self, action): async def _answers_action(self, action):
from subiquitycore.ui.stretchy import StretchyOverlay
from subiquity.ui.views.filesystem.delete import ConfirmDeleteStretchy from subiquity.ui.views.filesystem.delete import ConfirmDeleteStretchy
from subiquitycore.ui.stretchy import StretchyOverlay
log.debug("_answers_action %r", action) log.debug("_answers_action %r", action)
if 'obj' in action: if "obj" in action:
obj = self._action_get(action['obj']) obj = self._action_get(action["obj"])
action_name = action['action'] action_name = action["action"]
if action_name == "MAKE_BOOT": if action_name == "MAKE_BOOT":
action_name = "TOGGLE_BOOT" action_name = "TOGGLE_BOOT"
if action_name == "CREATE_LV": if action_name == "CREATE_LV":
@ -223,8 +202,8 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
if action_name == "PARTITION": if action_name == "PARTITION":
obj = gaps.largest_gap(obj) obj = gaps.largest_gap(obj)
meth = getattr( meth = getattr(
self.ui.body.avail_list, self.ui.body.avail_list, "_{}_{}".format(obj.type, action_name)
"_{}_{}".format(obj.type, action_name)) )
meth(obj) meth(obj)
yield yield
body = self.ui.body._w body = self.ui.body._w
@ -235,34 +214,35 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
body.stretchy.done() body.stretchy.done()
else: else:
async for _ in self._enter_form_data( async for _ in self._enter_form_data(
body.stretchy.form, body.stretchy.form, action["data"], action.get("submit", True)
action['data'], ):
action.get("submit", True)):
pass pass
elif action['action'] == 'create-raid': elif action["action"] == "create-raid":
self.ui.body.create_raid() self.ui.body.create_raid()
yield yield
body = self.ui.body._w body = self.ui.body._w
async for _ in self._enter_form_data( async for _ in self._enter_form_data(
body.stretchy.form, body.stretchy.form,
action['data'], action["data"],
action.get("submit", True), action.get("submit", True),
clean_suffix='raid'): clean_suffix="raid",
):
pass pass
elif action['action'] == 'create-vg': elif action["action"] == "create-vg":
self.ui.body.create_vg() self.ui.body.create_vg()
yield yield
body = self.ui.body._w body = self.ui.body._w
async for _ in self._enter_form_data( async for _ in self._enter_form_data(
body.stretchy.form, body.stretchy.form,
action['data'], action["data"],
action.get("submit", True), action.get("submit", True),
clean_suffix='vg'): clean_suffix="vg",
):
pass pass
elif action['action'] == 'done': elif action["action"] == "done":
if not self.ui.body.done_btn.enabled: if not self.ui.body.done_btn.enabled:
raise Exception("answers did not provide complete fs config") raise Exception("answers did not provide complete fs config")
self.app.answers['filesystem-confirmed'] = True self.app.answers["filesystem-confirmed"] = True
self.finish() self.finish()
else: else:
raise Exception("could not process action {}".format(action)) raise Exception("could not process action {}".format(action))
@ -278,8 +258,8 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
if self.model.bootloader == Bootloader.PREP: if self.model.bootloader == Bootloader.PREP:
self.supports_resilient_boot = False self.supports_resilient_boot = False
else: else:
release = lsb_release(dry_run=self.app.opts.dry_run)['release'] release = lsb_release(dry_run=self.app.opts.dry_run)["release"]
self.supports_resilient_boot = release >= '20.04' self.supports_resilient_boot = release >= "20.04"
self.ui.set_body(FilesystemView(self.model, self)) self.ui.set_body(FilesystemView(self.model, self))
def guided_choice(self, choice): def guided_choice(self, choice):
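
Buried in the reformatting above, FilesystemController.run_answers shows how the answers file drives guided partitioning: "guided-method" selects plain versus LVM layout, and an optional "guided-passphrase" turns the LVM options into a LUKS-encrypted setup. The mapping, pulled out of the UI plumbing into a standalone sketch (an illustrative helper, not actual subiquity code; the key names are the ones visible in the hunk above):

from typing import Any, Dict, Optional


def guided_value_from_answers(answers: Dict[str, Any], target: Any) -> Dict[str, Any]:
    # Same shape as the value assigned to form.guided_choice.value above.
    value: Dict[str, Any] = {
        "disk": target,
        "use_lvm": answers.get("guided-method") == "lvm",
    }
    passphrase: Optional[str] = answers.get("guided-passphrase")
    if passphrase is not None:
        value["lvm_options"] = {
            "encrypt": True,
            "luks_options": {
                "passphrase": passphrase,
                "confirm_passphrase": passphrase,
            },
        }
    return value


if __name__ == "__main__":
    print(
        guided_value_from_answers(
            {"guided-method": "lvm", "guided-passphrase": "s3cret"},
            target="disk-sda",
        )
    )
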


@ -19,34 +19,34 @@ from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import IdentityData from subiquity.common.types import IdentityData
from subiquity.ui.views import IdentityView from subiquity.ui.views import IdentityView
log = logging.getLogger('subiquity.client.controllers.identity') log = logging.getLogger("subiquity.client.controllers.identity")
class IdentityController(SubiquityTuiController): class IdentityController(SubiquityTuiController):
endpoint_name = "identity"
endpoint_name = 'identity'
async def make_ui(self): async def make_ui(self):
data = await self.endpoint.GET() data = await self.endpoint.GET()
return IdentityView(self, data) return IdentityView(self, data)
def run_answers(self): def run_answers(self):
if all(elem in self.answers for elem in if all(
['realname', 'username', 'password', 'hostname']): elem in self.answers
for elem in ["realname", "username", "password", "hostname"]
):
identity = IdentityData( identity = IdentityData(
realname=self.answers['realname'], realname=self.answers["realname"],
username=self.answers['username'], username=self.answers["username"],
hostname=self.answers['hostname'], hostname=self.answers["hostname"],
crypted_password=self.answers['password']) crypted_password=self.answers["password"],
)
self.done(identity) self.done(identity)
def cancel(self): def cancel(self):
self.app.prev_screen() self.app.prev_screen()
def done(self, identity_data): def done(self, identity_data):
log.debug( log.debug("IdentityController.done next_screen user_spec=%s", identity_data)
"IdentityController.done next_screen user_spec=%s",
identity_data)
self.app.next_screen(self.endpoint.POST(identity_data)) self.app.next_screen(self.endpoint.POST(identity_data))
async def validate_username(self, username): async def validate_username(self, username):


@ -20,21 +20,20 @@ from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import KeyboardSetting from subiquity.common.types import KeyboardSetting
from subiquity.ui.views import KeyboardView from subiquity.ui.views import KeyboardView
log = logging.getLogger('subiquity.client.controllers.keyboard') log = logging.getLogger("subiquity.client.controllers.keyboard")
class KeyboardController(SubiquityTuiController): class KeyboardController(SubiquityTuiController):
endpoint_name = "keyboard"
endpoint_name = 'keyboard'
async def make_ui(self): async def make_ui(self):
setup = await self.endpoint.GET() setup = await self.endpoint.GET()
return KeyboardView(self, setup) return KeyboardView(self, setup)
async def run_answers(self): async def run_answers(self):
if 'layout' in self.answers: if "layout" in self.answers:
layout = self.answers['layout'] layout = self.answers["layout"]
variant = self.answers.get('variant', '') variant = self.answers.get("variant", "")
await self.apply(KeyboardSetting(layout=layout, variant=variant)) await self.apply(KeyboardSetting(layout=layout, variant=variant))
self.done() self.done()
@ -43,7 +42,8 @@ class KeyboardController(SubiquityTuiController):
async def needs_toggle(self, setting): async def needs_toggle(self, setting):
return await self.endpoint.needs_toggle.GET( return await self.endpoint.needs_toggle.GET(
layout_code=setting.layout, variant_code=setting.variant) layout_code=setting.layout, variant_code=setting.variant
)
async def apply(self, setting): async def apply(self, setting):
await self.endpoint.POST(setting) await self.endpoint.POST(setting)


@ -16,22 +16,16 @@
import asyncio import asyncio
import logging import logging
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import MirrorCheckStatus, MirrorGet, MirrorPost
from subiquity.ui.views.mirror import MirrorView
from subiquitycore.tuicontroller import Skip from subiquitycore.tuicontroller import Skip
from subiquity.common.types import ( log = logging.getLogger("subiquity.client.controllers.mirror")
MirrorCheckStatus,
MirrorGet,
MirrorPost,
)
from subiquity.client.controller import SubiquityTuiController
from subiquity.ui.views.mirror import MirrorView
log = logging.getLogger('subiquity.client.controllers.mirror')
class MirrorController(SubiquityTuiController): class MirrorController(SubiquityTuiController):
endpoint_name = "mirror"
endpoint_name = 'mirror'
async def make_ui(self): async def make_ui(self):
mirror_response: MirrorGet = await self.endpoint.GET() mirror_response: MirrorGet = await self.endpoint.GET()
@ -57,19 +51,18 @@ class MirrorController(SubiquityTuiController):
async def run_answers(self): async def run_answers(self):
async def wait_mirror_check() -> None: async def wait_mirror_check() -> None:
""" Wait until the mirror check has finished running. """ """Wait until the mirror check has finished running."""
while True: while True:
last_status = self.app.ui.body.last_status last_status = self.app.ui.body.last_status
if last_status not in [None, MirrorCheckStatus.RUNNING]: if last_status not in [None, MirrorCheckStatus.RUNNING]:
return return
await asyncio.sleep(.1) await asyncio.sleep(0.1)
if 'mirror' in self.answers: if "mirror" in self.answers:
self.app.ui.body.form.url.value = self.answers['mirror'] self.app.ui.body.form.url.value = self.answers["mirror"]
await wait_mirror_check() await wait_mirror_check()
self.app.ui.body.form._click_done(None) self.app.ui.body.form._click_done(None)
elif 'country-code' in self.answers \ elif "country-code" in self.answers or "accept-default" in self.answers:
or 'accept-default' in self.answers:
await wait_mirror_check() await wait_mirror_check()
self.app.ui.body.form._click_done(None) self.app.ui.body.form._click_done(None)
@ -78,5 +71,4 @@ class MirrorController(SubiquityTuiController):
def done(self, mirror): def done(self, mirror):
log.debug("MirrorController.done next_screen mirror=%s", mirror) log.debug("MirrorController.done next_screen mirror=%s", mirror)
self.app.next_screen(self.endpoint.POST( self.app.next_screen(self.endpoint.POST(MirrorPost(elected=mirror)))
MirrorPost(elected=mirror)))
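
wait_mirror_check above is a plain poll-until-settled loop: keep reading the last mirror-check status off the view and yield to the event loop until it is neither unset nor still running. The same shape as a generic, runnable coroutine (a sketch only; the real code compares against MirrorCheckStatus.RUNNING on the view object):

import asyncio
from typing import Callable, Optional


async def wait_until_settled(
    get_status: Callable[[], Optional[str]],
    pending: tuple = (None, "RUNNING"),
    interval: float = 0.1,
) -> Optional[str]:
    # Poll until the provider reports something other than "not started yet"
    # or "still running", mirroring wait_mirror_check above.
    while True:
        status = get_status()
        if status not in pending:
            return status
        await asyncio.sleep(interval)


if __name__ == "__main__":
    statuses = iter([None, "RUNNING", "RUNNING", "OK"])
    print(asyncio.run(wait_until_settled(lambda: next(statuses), interval=0.01)))
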


@ -21,6 +21,10 @@ from typing import Optional
from aiohttp import web from aiohttp import web
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.api.server import make_server_at_path
from subiquity.common.apidef import LinkAction, NetEventAPI
from subiquity.common.types import ErrorReportKind, PackageInstallState
from subiquitycore.async_helpers import run_bg_task from subiquitycore.async_helpers import run_bg_task
from subiquitycore.controllers.network import NetworkAnswersMixin from subiquitycore.controllers.network import NetworkAnswersMixin
from subiquitycore.models.network import ( from subiquitycore.models.network import (
@ -28,23 +32,14 @@ from subiquitycore.models.network import (
NetDevInfo, NetDevInfo,
StaticConfig, StaticConfig,
WLANConfig, WLANConfig,
) )
from subiquitycore.ui.views.network import NetworkView from subiquitycore.ui.views.network import NetworkView
from subiquity.client.controller import SubiquityTuiController log = logging.getLogger("subiquity.client.controllers.network")
from subiquity.common.api.server import make_server_at_path
from subiquity.common.apidef import LinkAction, NetEventAPI
from subiquity.common.types import (
ErrorReportKind,
PackageInstallState,
)
log = logging.getLogger('subiquity.client.controllers.network')
class NetworkController(SubiquityTuiController, NetworkAnswersMixin): class NetworkController(SubiquityTuiController, NetworkAnswersMixin):
endpoint_name = "network"
endpoint_name = 'network'
def __init__(self, app): def __init__(self, app):
super().__init__(app) super().__init__(app)
@ -53,18 +48,18 @@ class NetworkController(SubiquityTuiController, NetworkAnswersMixin):
@web.middleware @web.middleware
async def middleware(self, request, handler): async def middleware(self, request, handler):
resp = await handler(request) resp = await handler(request)
if resp.get('exception'): if resp.get("exception"):
exc = resp['exception'] exc = resp["exception"]
log.debug( log.debug("request to {} crashed".format(request.raw_path), exc_info=exc)
'request to {} crashed'.format(request.raw_path), exc_info=exc)
self.app.make_apport_report( self.app.make_apport_report(
ErrorReportKind.NETWORK_CLIENT_FAIL, ErrorReportKind.NETWORK_CLIENT_FAIL,
"request to {}".format(request.raw_path), "request to {}".format(request.raw_path),
exc=exc, interrupt=True) exc=exc,
interrupt=True,
)
return resp return resp
async def update_link_POST(self, act: LinkAction, async def update_link_POST(self, act: LinkAction, info: NetDevInfo) -> None:
info: NetDevInfo) -> None:
if self.view is None: if self.view is None:
return return
if act == LinkAction.NEW: if act == LinkAction.NEW:
@ -91,15 +86,17 @@ class NetworkController(SubiquityTuiController, NetworkAnswersMixin):
self.view.show_network_error(stage) self.view.show_network_error(stage)
async def wlan_support_install_finished_POST( async def wlan_support_install_finished_POST(
self, state: PackageInstallState) -> None: self, state: PackageInstallState
) -> None:
if self.view: if self.view:
self.view.update_for_wlan_support_install_state(state.name) self.view.update_for_wlan_support_install_state(state.name)
async def subscribe(self): async def subscribe(self):
self.tdir = tempfile.mkdtemp() self.tdir = tempfile.mkdtemp()
self.sock_path = os.path.join(self.tdir, 'socket') self.sock_path = os.path.join(self.tdir, "socket")
self.site = await make_server_at_path( self.site = await make_server_at_path(
self.sock_path, NetEventAPI, self, middlewares=[self.middleware]) self.sock_path, NetEventAPI, self, middlewares=[self.middleware]
)
await self.endpoint.subscription.PUT(self.sock_path) await self.endpoint.subscription.PUT(self.sock_path)
async def unsubscribe(self): async def unsubscribe(self):
@ -110,8 +107,8 @@ class NetworkController(SubiquityTuiController, NetworkAnswersMixin):
async def make_ui(self): async def make_ui(self):
network_status = await self.endpoint.GET() network_status = await self.endpoint.GET()
self.view = NetworkView( self.view = NetworkView(
self, network_status.devices, self, network_status.devices, network_status.wlan_support_install_state.name
network_status.wlan_support_install_state.name) )
await self.subscribe() await self.subscribe()
return self.view return self.view
@ -126,40 +123,37 @@ class NetworkController(SubiquityTuiController, NetworkAnswersMixin):
def done(self): def done(self):
self.app.next_screen(self.endpoint.POST()) self.app.next_screen(self.endpoint.POST())
def set_static_config(self, dev_name: str, ip_version: int, def set_static_config(
static_config: StaticConfig) -> None: self, dev_name: str, ip_version: int, static_config: StaticConfig
) -> None:
run_bg_task( run_bg_task(
self.endpoint.set_static_config.POST( self.endpoint.set_static_config.POST(dev_name, ip_version, static_config)
dev_name, ip_version, static_config)) )
def enable_dhcp(self, dev_name, ip_version: int) -> None: def enable_dhcp(self, dev_name, ip_version: int) -> None:
run_bg_task( run_bg_task(self.endpoint.enable_dhcp.POST(dev_name, ip_version))
self.endpoint.enable_dhcp.POST(dev_name, ip_version))
def disable_network(self, dev_name: str, ip_version: int) -> None: def disable_network(self, dev_name: str, ip_version: int) -> None:
run_bg_task( run_bg_task(self.endpoint.disable.POST(dev_name, ip_version))
self.endpoint.disable.POST(dev_name, ip_version))
def add_vlan(self, dev_name: str, vlan_id: int): def add_vlan(self, dev_name: str, vlan_id: int):
run_bg_task( run_bg_task(self.endpoint.vlan.PUT(dev_name, vlan_id))
self.endpoint.vlan.PUT(dev_name, vlan_id))
def set_wlan(self, dev_name: str, wlan: WLANConfig) -> None: def set_wlan(self, dev_name: str, wlan: WLANConfig) -> None:
run_bg_task( run_bg_task(self.endpoint.set_wlan.POST(dev_name, wlan))
self.endpoint.set_wlan.POST(dev_name, wlan))
def start_scan(self, dev_name: str) -> None: def start_scan(self, dev_name: str) -> None:
run_bg_task( run_bg_task(self.endpoint.start_scan.POST(dev_name))
self.endpoint.start_scan.POST(dev_name))
def delete_link(self, dev_name: str): def delete_link(self, dev_name: str):
run_bg_task(self.endpoint.delete.POST(dev_name)) run_bg_task(self.endpoint.delete.POST(dev_name))
def add_or_update_bond(self, existing_name: Optional[str], def add_or_update_bond(
new_name: str, new_info: BondConfig) -> None: self, existing_name: Optional[str], new_name: str, new_info: BondConfig
) -> None:
run_bg_task( run_bg_task(
self.endpoint.add_or_edit_bond.POST( self.endpoint.add_or_edit_bond.POST(existing_name, new_name, new_info)
existing_name, new_name, new_info)) )
async def get_info_for_netdev(self, dev_name: str) -> str: async def get_info_for_netdev(self, dev_name: str) -> str:
return await self.endpoint.info.GET(dev_name) return await self.endpoint.info.GET(dev_name)
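
All the synchronous UI callbacks in this controller (set_static_config, enable_dhcp, add_vlan, and so on) hand their POSTs to run_bg_task from subiquitycore.async_helpers, i.e. fire-and-forget scheduling on the running event loop. A generic stand-in with the behaviour one would assume from these call sites (an assumption; the real helper may differ in detail):

import asyncio
from typing import Coroutine, Set

_background_tasks: Set[asyncio.Task] = set()


def run_bg_task(coro: Coroutine) -> None:
    # Assumed behaviour of subiquitycore's run_bg_task: schedule the coroutine
    # on the running loop and keep a reference so the task is not
    # garbage-collected before it completes.
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


async def main() -> None:
    async def post(endpoint: str) -> None:
        await asyncio.sleep(0.01)
        print("POST", endpoint)

    # A urwid callback would call this synchronously, much like the
    # NetworkController methods above.
    run_bg_task(post("/network/enable_dhcp"))
    await asyncio.sleep(0.05)


if __name__ == "__main__":
    asyncio.run(main())
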


@ -18,25 +18,16 @@ import logging
import aiohttp import aiohttp
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import ApplicationState, ShutdownMode
from subiquity.ui.views.installprogress import InstallRunning, ProgressView
from subiquitycore.async_helpers import run_bg_task from subiquitycore.async_helpers import run_bg_task
from subiquitycore.context import with_context from subiquitycore.context import with_context
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import (
ApplicationState,
ShutdownMode,
)
from subiquity.ui.views.installprogress import (
InstallRunning,
ProgressView,
)
log = logging.getLogger("subiquity.client.controllers.progress") log = logging.getLogger("subiquity.client.controllers.progress")
class ProgressController(SubiquityTuiController): class ProgressController(SubiquityTuiController):
def __init__(self, app): def __init__(self, app):
super().__init__(app) super().__init__(app)
self.progress_view = ProgressView(self) self.progress_view = ProgressView(self)
@ -49,13 +40,13 @@ class ProgressController(SubiquityTuiController):
self.progress_view.event_start( self.progress_view.event_start(
event["SUBIQUITY_CONTEXT_ID"], event["SUBIQUITY_CONTEXT_ID"],
event.get("SUBIQUITY_CONTEXT_PARENT_ID"), event.get("SUBIQUITY_CONTEXT_PARENT_ID"),
event["MESSAGE"]) event["MESSAGE"],
)
elif event["SUBIQUITY_EVENT_TYPE"] == "finish": elif event["SUBIQUITY_EVENT_TYPE"] == "finish":
self.progress_view.event_finish( self.progress_view.event_finish(event["SUBIQUITY_CONTEXT_ID"])
event["SUBIQUITY_CONTEXT_ID"])
def log_line(self, event): def log_line(self, event):
log_line = event['MESSAGE'] log_line = event["MESSAGE"]
self.progress_view.add_log_line(log_line) self.progress_view.add_log_line(log_line)
def cancel(self): def cancel(self):
@ -79,8 +70,7 @@ class ProgressController(SubiquityTuiController):
install_running = None install_running = None
while True: while True:
try: try:
app_status = await self.app.client.meta.status.GET( app_status = await self.app.client.meta.status.GET(cur=self.app_state)
cur=self.app_state)
except aiohttp.ClientError: except aiohttp.ClientError:
await asyncio.sleep(1) await asyncio.sleep(1)
continue continue
@ -103,7 +93,8 @@ class ProgressController(SubiquityTuiController):
if self.app_state == ApplicationState.RUNNING: if self.app_state == ApplicationState.RUNNING:
if app_status.confirming_tty != self.app.our_tty: if app_status.confirming_tty != self.app.our_tty:
install_running = InstallRunning( install_running = InstallRunning(
self.app, app_status.confirming_tty) self.app, app_status.confirming_tty
)
self.app.add_global_overlay(install_running) self.app.add_global_overlay(install_running)
else: else:
if install_running is not None: if install_running is not None:
@ -111,7 +102,7 @@ class ProgressController(SubiquityTuiController):
install_running = None install_running = None
if self.app_state == ApplicationState.DONE: if self.app_state == ApplicationState.DONE:
if self.answers.get('reboot', False): if self.answers.get("reboot", False):
self.click_reboot() self.click_reboot()
def make_ui(self): def make_ui(self):
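
The status-polling loop reformatted above is the client's long poll against the server: GET the application status while passing the last known state, and on aiohttp connection errors (for instance while the server restarts) sleep a second and retry. A condensed, runnable sketch of that loop with generic names (the real code calls self.app.client.meta.status.GET(cur=...) and reacts to ApplicationState values; this sketch only assumes aiohttp, which the client already depends on):

import asyncio
from typing import Awaitable, Callable

import aiohttp


async def watch_app_status(
    fetch: Callable[[str], Awaitable[str]],
    state: str = "STARTING_UP",
) -> None:
    # Long-poll the server for status changes, as in the loop above: pass the
    # last known state, and on connection errors back off and retry.
    while True:
        try:
            state = await fetch(state)
        except aiohttp.ClientError:
            await asyncio.sleep(1)
            continue
        print("application state:", state)
        if state == "DONE":
            return


if __name__ == "__main__":
    states = iter(["WAITING", "RUNNING", "DONE"])

    async def fake_fetch(cur: str) -> str:
        await asyncio.sleep(0)
        return next(states)

    asyncio.run(watch_app_status(fake_fetch))
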


@ -18,20 +18,19 @@ import logging
from subiquity.client.controller import SubiquityTuiController from subiquity.client.controller import SubiquityTuiController
from subiquity.ui.views.proxy import ProxyView from subiquity.ui.views.proxy import ProxyView
log = logging.getLogger('subiquity.client.controllers.proxy') log = logging.getLogger("subiquity.client.controllers.proxy")
class ProxyController(SubiquityTuiController): class ProxyController(SubiquityTuiController):
endpoint_name = "proxy"
endpoint_name = 'proxy'
async def make_ui(self): async def make_ui(self):
proxy = await self.endpoint.GET() proxy = await self.endpoint.GET()
return ProxyView(self, proxy) return ProxyView(self, proxy)
def run_answers(self): def run_answers(self):
if 'proxy' in self.answers: if "proxy" in self.answers:
self.done(self.answers['proxy']) self.done(self.answers["proxy"])
def cancel(self): def cancel(self):
self.app.prev_screen() self.app.prev_screen()


@ -18,25 +18,16 @@ import logging
import aiohttp import aiohttp
from subiquitycore.tuicontroller import ( from subiquity.client.controller import SubiquityTuiController
Skip, from subiquity.common.types import RefreshCheckState
)
from subiquity.common.types import (
RefreshCheckState,
)
from subiquity.client.controller import (
SubiquityTuiController,
)
from subiquity.ui.views.refresh import RefreshView from subiquity.ui.views.refresh import RefreshView
from subiquitycore.tuicontroller import Skip
log = logging.getLogger("subiquity.client.controllers.refresh")
log = logging.getLogger('subiquity.client.controllers.refresh')
class RefreshController(SubiquityTuiController): class RefreshController(SubiquityTuiController):
endpoint_name = "refresh"
endpoint_name = 'refresh'
def __init__(self, app): def __init__(self, app):
super().__init__(app) super().__init__(app)
@ -61,8 +52,10 @@ class RefreshController(SubiquityTuiController):
self.offered_first_time = True self.offered_first_time = True
elif index == 2: elif index == 2:
if not self.offered_first_time: if not self.offered_first_time:
if self.status.availability in [RefreshCheckState.UNKNOWN, if self.status.availability in [
RefreshCheckState.AVAILABLE]: RefreshCheckState.UNKNOWN,
RefreshCheckState.AVAILABLE,
]:
show = True show = True
else: else:
raise AssertionError("unexpected index {}".format(index)) raise AssertionError("unexpected index {}".format(index))


@ -19,7 +19,7 @@ from subiquity.client.controller import SubiquityTuiController
from subiquity.ui.views.serial import SerialView from subiquity.ui.views.serial import SerialView
from subiquitycore.tuicontroller import Skip from subiquitycore.tuicontroller import Skip
log = logging.getLogger('subiquity.client.controllers.serial') log = logging.getLogger("subiquity.client.controllers.serial")
class SerialController(SubiquityTuiController): class SerialController(SubiquityTuiController):
@ -30,11 +30,11 @@ class SerialController(SubiquityTuiController):
return SerialView(self, ssh_info) return SerialView(self, ssh_info)
def run_answers(self): def run_answers(self):
if 'rich' in self.answers: if "rich" in self.answers:
if self.answers['rich']: if self.answers["rich"]:
self.ui.body.rich_btn.base_widget._emit('click') self.ui.body.rich_btn.base_widget._emit("click")
else: else:
self.ui.body.basic_btn.base_widget._emit('click') self.ui.body.basic_btn.base_widget._emit("click")
def done(self, rich): def done(self, rich):
log.debug("SerialController.done rich %s next_screen", rich) log.debug("SerialController.done rich %s next_screen", rich)


@ -16,48 +16,37 @@
import logging import logging
from typing import List from typing import List
from subiquitycore.tuicontroller import ( from subiquity.client.controller import SubiquityTuiController
Skip, from subiquity.common.types import SnapCheckState, SnapSelection
)
from subiquity.client.controller import (
SubiquityTuiController,
)
from subiquity.common.types import (
SnapCheckState,
SnapSelection,
)
from subiquity.ui.views.snaplist import SnapListView from subiquity.ui.views.snaplist import SnapListView
from subiquitycore.tuicontroller import Skip
log = logging.getLogger('subiquity.client.controllers.snaplist') log = logging.getLogger("subiquity.client.controllers.snaplist")
class SnapListController(SubiquityTuiController): class SnapListController(SubiquityTuiController):
endpoint_name = "snaplist"
endpoint_name = 'snaplist'
async def make_ui(self): async def make_ui(self):
data = await self.endpoint.GET() data = await self.endpoint.GET()
if data.status == SnapCheckState.FAILED: if data.status == SnapCheckState.FAILED:
# If loading snaps failed or network is disabled, skip the screen. # If loading snaps failed or network is disabled, skip the screen.
log.debug('snaplist GET failed, mark done') log.debug("snaplist GET failed, mark done")
await self.endpoint.POST([]) await self.endpoint.POST([])
raise Skip raise Skip
return SnapListView(self, data) return SnapListView(self, data)
def run_answers(self): def run_answers(self):
if 'snaps' in self.answers: if "snaps" in self.answers:
selections = [] selections = []
for snap_name, selection in self.answers['snaps'].items(): for snap_name, selection in self.answers["snaps"].items():
if 'is_classic' in selection: if "is_classic" in selection:
selection['classic'] = selection.pop('is_classic') selection["classic"] = selection.pop("is_classic")
selections.append(SnapSelection(name=snap_name, **selection)) selections.append(SnapSelection(name=snap_name, **selection))
self.done(selections) self.done(selections)
def done(self, selections: List[SnapSelection]): def done(self, selections: List[SnapSelection]):
log.debug( log.debug("SnapListController.done next_screen snaps_to_install=%s", selections)
"SnapListController.done next_screen snaps_to_install=%s",
selections)
self.app.next_screen(self.endpoint.POST(selections)) self.app.next_screen(self.endpoint.POST(selections))
def cancel(self, sender=None): def cancel(self, sender=None):


@ -18,26 +18,24 @@ import logging
from subiquity.client.controller import SubiquityTuiController from subiquity.client.controller import SubiquityTuiController
from subiquity.ui.views.source import SourceView from subiquity.ui.views.source import SourceView
log = logging.getLogger('subiquity.client.controllers.source') log = logging.getLogger("subiquity.client.controllers.source")
class SourceController(SubiquityTuiController): class SourceController(SubiquityTuiController):
endpoint_name = "source"
endpoint_name = 'source'
async def make_ui(self): async def make_ui(self):
sources = await self.endpoint.GET() sources = await self.endpoint.GET()
return SourceView(self, return SourceView(
sources.sources, self, sources.sources, sources.current_id, sources.search_drivers
sources.current_id, )
sources.search_drivers)
def run_answers(self): def run_answers(self):
form = self.app.ui.body.form form = self.app.ui.body.form
if "search_drivers" in self.answers: if "search_drivers" in self.answers:
form.search_drivers.value = self.answers["search_drivers"] form.search_drivers.value = self.answers["search_drivers"]
if 'source' in self.answers: if "source" in self.answers:
wanted_id = self.answers['source'] wanted_id = self.answers["source"]
for bf in form._fields: for bf in form._fields:
if bf is form.search_drivers: if bf is form.search_drivers:
continue continue
@ -48,6 +46,9 @@ class SourceController(SubiquityTuiController):
self.app.prev_screen() self.app.prev_screen()
def done(self, source_id, search_drivers: bool): def done(self, source_id, search_drivers: bool):
log.debug("SourceController.done source_id=%s, search_drivers=%s", log.debug(
source_id, search_drivers) "SourceController.done source_id=%s, search_drivers=%s",
source_id,
search_drivers,
)
self.app.next_screen(self.endpoint.POST(source_id, search_drivers)) self.app.next_screen(self.endpoint.POST(source_id, search_drivers))

View File

@ -15,18 +15,13 @@
import logging import logging
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import SSHData, SSHFetchIdResponse, SSHFetchIdStatus
from subiquity.ui.views.ssh import SSHView
from subiquitycore.async_helpers import schedule_task from subiquitycore.async_helpers import schedule_task
from subiquitycore.context import with_context from subiquitycore.context import with_context
from subiquity.client.controller import SubiquityTuiController log = logging.getLogger("subiquity.client.controllers.ssh")
from subiquity.common.types import (
SSHData,
SSHFetchIdResponse,
SSHFetchIdStatus,
)
from subiquity.ui.views.ssh import SSHView
log = logging.getLogger('subiquity.client.controllers.ssh')
class FetchSSHKeysFailure(Exception): class FetchSSHKeysFailure(Exception):
@ -36,35 +31,31 @@ class FetchSSHKeysFailure(Exception):
class SSHController(SubiquityTuiController): class SSHController(SubiquityTuiController):
endpoint_name = "ssh"
endpoint_name = 'ssh'
def __init__(self, app): def __init__(self, app):
super().__init__(app) super().__init__(app)
self._fetch_task = None self._fetch_task = None
if not self.answers: if not self.answers:
identity_answers = self.app.answers.get('Identity', {}) identity_answers = self.app.answers.get("Identity", {})
if 'ssh-import-id' in identity_answers: if "ssh-import-id" in identity_answers:
self.answers['ssh-import-id'] = identity_answers[ self.answers["ssh-import-id"] = identity_answers["ssh-import-id"]
'ssh-import-id']
async def make_ui(self): async def make_ui(self):
ssh_data = await self.endpoint.GET() ssh_data = await self.endpoint.GET()
return SSHView(self, ssh_data) return SSHView(self, ssh_data)
def run_answers(self): def run_answers(self):
if 'ssh-import-id' in self.answers: if "ssh-import-id" in self.answers:
import_id = self.answers['ssh-import-id'] import_id = self.answers["ssh-import-id"]
ssh = SSHData( ssh = SSHData(install_server=True, authorized_keys=[], allow_pw=True)
install_server=True,
authorized_keys=[],
allow_pw=True)
self.fetch_ssh_keys(ssh_import_id=import_id, ssh_data=ssh) self.fetch_ssh_keys(ssh_import_id=import_id, ssh_data=ssh)
else: else:
ssh = SSHData( ssh = SSHData(
install_server=self.answers.get("install_server", False), install_server=self.answers.get("install_server", False),
authorized_keys=self.answers.get("authorized_keys", []), authorized_keys=self.answers.get("authorized_keys", []),
allow_pw=self.answers.get("pwauth", True)) allow_pw=self.answers.get("pwauth", True),
)
self.done(ssh) self.done(ssh)
def cancel(self): def cancel(self):
@ -75,44 +66,50 @@ class SSHController(SubiquityTuiController):
return return
self._fetch_task.cancel() self._fetch_task.cancel()
@with_context( @with_context(name="ssh_import_id", description="{ssh_import_id}")
name="ssh_import_id", description="{ssh_import_id}")
async def _fetch_ssh_keys(self, *, context, ssh_import_id, ssh_data): async def _fetch_ssh_keys(self, *, context, ssh_import_id, ssh_data):
with self.context.child("ssh_import_id", ssh_import_id): with self.context.child("ssh_import_id", ssh_import_id):
response: SSHFetchIdResponse = await \ response: SSHFetchIdResponse = await self.endpoint.fetch_id.GET(
self.endpoint.fetch_id.GET(ssh_import_id) ssh_import_id
)
if response.status == SSHFetchIdStatus.IMPORT_ERROR: if response.status == SSHFetchIdStatus.IMPORT_ERROR:
if isinstance(self.ui.body, SSHView): if isinstance(self.ui.body, SSHView):
self.ui.body.fetching_ssh_keys_failed( self.ui.body.fetching_ssh_keys_failed(
_("Importing keys failed:"), response.error) _("Importing keys failed:"), response.error
)
return return
elif response.status == SSHFetchIdStatus.FINGERPRINT_ERROR: elif response.status == SSHFetchIdStatus.FINGERPRINT_ERROR:
if isinstance(self.ui.body, SSHView): if isinstance(self.ui.body, SSHView):
self.ui.body.fetching_ssh_keys_failed( self.ui.body.fetching_ssh_keys_failed(
_("ssh-keygen failed to show fingerprint of" _(
" downloaded keys:"), "ssh-keygen failed to show fingerprint of"
response.error) " downloaded keys:"
),
response.error,
)
return return
identities = response.identities identities = response.identities
if 'ssh-import-id' in self.app.answers.get("Identity", {}): if "ssh-import-id" in self.app.answers.get("Identity", {}):
ssh_data.authorized_keys = \ ssh_data.authorized_keys = [
[id_.to_authorized_key() for id_ in identities] id_.to_authorized_key() for id_ in identities
]
self.done(ssh_data) self.done(ssh_data)
else: else:
if isinstance(self.ui.body, SSHView): if isinstance(self.ui.body, SSHView):
self.ui.body.confirm_ssh_keys( self.ui.body.confirm_ssh_keys(ssh_data, ssh_import_id, identities)
ssh_data, ssh_import_id, identities)
else: else:
log.debug("ui.body of unexpected instance: %s", log.debug(
type(self.ui.body).__name__) "ui.body of unexpected instance: %s",
type(self.ui.body).__name__,
)
def fetch_ssh_keys(self, ssh_import_id, ssh_data): def fetch_ssh_keys(self, ssh_import_id, ssh_data):
self._fetch_task = schedule_task( self._fetch_task = schedule_task(
self._fetch_ssh_keys( self._fetch_ssh_keys(ssh_import_id=ssh_import_id, ssh_data=ssh_data)
ssh_import_id=ssh_import_id, ssh_data=ssh_data)) )
def done(self, result): def done(self, result):
log.debug("SSHController.done next_screen result=%s", result) log.debug("SSHController.done next_screen result=%s", result)

View File

@ -22,23 +22,21 @@ from typing import Callable, Optional
import aiohttp import aiohttp
from urwid import Widget from urwid import Widget
from subiquitycore.async_helpers import schedule_task
from subiquity.client.controller import SubiquityTuiController from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import UbuntuProCheckTokenStatus as TokenStatus
from subiquity.common.types import ( from subiquity.common.types import (
UbuntuProInfo, UbuntuProInfo,
UbuntuProResponse, UbuntuProResponse,
UbuntuProCheckTokenStatus as TokenStatus,
UbuntuProSubscription, UbuntuProSubscription,
UPCSWaitStatus, UPCSWaitStatus,
) )
from subiquity.ui.views.ubuntu_pro import ( from subiquity.ui.views.ubuntu_pro import (
UbuntuProView,
UpgradeYesNoForm,
UpgradeModeForm,
TokenAddedWidget, TokenAddedWidget,
) UbuntuProView,
UpgradeModeForm,
UpgradeYesNoForm,
)
from subiquitycore.async_helpers import schedule_task
from subiquitycore.lsb_release import lsb_release from subiquitycore.lsb_release import lsb_release
from subiquitycore.tuicontroller import Skip from subiquitycore.tuicontroller import Skip
@ -46,12 +44,12 @@ log = logging.getLogger("subiquity.client.controllers.ubuntu_pro")
class UbuntuProController(SubiquityTuiController): class UbuntuProController(SubiquityTuiController):
""" Client-side controller for Ubuntu Pro configuration. """ """Client-side controller for Ubuntu Pro configuration."""
endpoint_name = "ubuntu_pro" endpoint_name = "ubuntu_pro"
def __init__(self, app) -> None: def __init__(self, app) -> None:
""" Initializer for the client-side UA controller. """ """Initializer for the client-side UA controller."""
self._check_task: Optional[asyncio.Future] = None self._check_task: Optional[asyncio.Future] = None
self._magic_attach_task: Optional[asyncio.Future] = None self._magic_attach_task: Optional[asyncio.Future] = None
@ -60,7 +58,7 @@ class UbuntuProController(SubiquityTuiController):
super().__init__(app) super().__init__(app)
async def make_ui(self) -> UbuntuProView: async def make_ui(self) -> UbuntuProView:
""" Generate the UI, based on the data provided by the model. """ """Generate the UI, based on the data provided by the model."""
dry_run: bool = self.app.opts.dry_run dry_run: bool = self.app.opts.dry_run
@ -70,13 +68,13 @@ class UbuntuProController(SubiquityTuiController):
raise Skip("Not running LTS version") raise Skip("Not running LTS version")
ubuntu_pro_info: UbuntuProResponse = await self.endpoint.GET() ubuntu_pro_info: UbuntuProResponse = await self.endpoint.GET()
return UbuntuProView(self, return UbuntuProView(
token=ubuntu_pro_info.token, self, token=ubuntu_pro_info.token, has_network=ubuntu_pro_info.has_network
has_network=ubuntu_pro_info.has_network) )
async def run_answers(self) -> None: async def run_answers(self) -> None:
""" Interact with the UI to go through the pre-attach process if """Interact with the UI to go through the pre-attach process if
requested. """ requested."""
if "token" not in self.answers: if "token" not in self.answers:
return return
@ -100,8 +98,7 @@ class UbuntuProController(SubiquityTuiController):
click(find_button_matching(view, UpgradeYesNoForm.ok_label)) click(find_button_matching(view, UpgradeYesNoForm.ok_label))
def run_token_screen(token: str) -> None: def run_token_screen(token: str) -> None:
keypress(view.upgrade_mode_form.with_contract_token.widget, keypress(view.upgrade_mode_form.with_contract_token.widget, key="enter")
key="enter")
data = {"with_contract_token_subform": {"token": token}} data = {"with_contract_token_subform": {"token": token}}
            # TODO: at this point, it would be good to trigger the validation             # TODO: at this point, it would be good to trigger the validation
# code for the token field. # code for the token field.
@ -117,13 +114,12 @@ class UbuntuProController(SubiquityTuiController):
# Wait until the "Token added successfully" overlay is shown. # Wait until the "Token added successfully" overlay is shown.
while not find_with_pred(view, is_token_added_overlay): while not find_with_pred(view, is_token_added_overlay):
await asyncio.sleep(.2) await asyncio.sleep(0.2)
click(find_button_matching(view, TokenAddedWidget.done_label)) click(find_button_matching(view, TokenAddedWidget.done_label))
def run_subscription_screen() -> None: def run_subscription_screen() -> None:
click(find_button_matching(view._w, click(find_button_matching(view._w, UbuntuProView.subscription_done_label))
UbuntuProView.subscription_done_label))
if not self.answers["token"]: if not self.answers["token"]:
run_yes_no_screen(skip=True) run_yes_no_screen(skip=True)
@ -134,11 +130,14 @@ class UbuntuProController(SubiquityTuiController):
await run_token_added_overlay() await run_token_added_overlay()
run_subscription_screen() run_subscription_screen()
def check_token(self, token: str, def check_token(
self,
token: str,
on_success: Callable[[UbuntuProSubscription], None], on_success: Callable[[UbuntuProSubscription], None],
on_failure: Callable[[TokenStatus], None], on_failure: Callable[[TokenStatus], None],
) -> None: ) -> None:
""" Asynchronously check the token passed as an argument. """ """Asynchronously check the token passed as an argument."""
async def inner() -> None: async def inner() -> None:
answer = await self.endpoint.check_token.GET(token) answer = await self.endpoint.check_token.GET(token)
if answer.status == TokenStatus.VALID_TOKEN: if answer.status == TokenStatus.VALID_TOKEN:
@ -150,29 +149,32 @@ class UbuntuProController(SubiquityTuiController):
self._check_task = schedule_task(inner()) self._check_task = schedule_task(inner())
def cancel_check_token(self) -> None: def cancel_check_token(self) -> None:
""" Cancel the asynchronous token check (if started). """ """Cancel the asynchronous token check (if started)."""
if self._check_task is not None: if self._check_task is not None:
self._check_task.cancel() self._check_task.cancel()
def contract_selection_initiate( def contract_selection_initiate(self, on_initiated: Callable[[str], None]) -> None:
self, on_initiated: Callable[[str], None]) -> None: """Initiate the contract selection asynchronously. Calls on_initiated
""" Initiate the contract selection asynchronously. Calls on_initiated when the contract-selection has initiated."""
when the contract-selection has initiated. """
async def inner() -> None: async def inner() -> None:
answer = await self.endpoint.contract_selection.initiate.POST() answer = await self.endpoint.contract_selection.initiate.POST()
self.cs_initiated = True self.cs_initiated = True
on_initiated(answer.user_code) on_initiated(answer.user_code)
self._magic_attach_task = schedule_task(inner()) self._magic_attach_task = schedule_task(inner())
def contract_selection_wait( def contract_selection_wait(
self, self,
on_contract_selected: Callable[[str], None], on_contract_selected: Callable[[str], None],
on_timeout: Callable[[], None]) -> None: on_timeout: Callable[[], None],
""" Asynchronously wait for the contract selection to finish. ) -> None:
"""Asynchronously wait for the contract selection to finish.
Calls on_contract_selected with the contract-token once the contract Calls on_contract_selected with the contract-token once the contract
gets selected gets selected
Otherwise, calls on_timeout if no contract got selected within the Otherwise, calls on_timeout if no contract got selected within the
allowed time frame. """ allowed time frame."""
async def inner() -> None: async def inner() -> None:
try: try:
answer = await self.endpoint.contract_selection.wait.GET() answer = await self.endpoint.contract_selection.wait.GET()
@ -187,9 +189,9 @@ class UbuntuProController(SubiquityTuiController):
self._monitor_task = schedule_task(inner()) self._monitor_task = schedule_task(inner())
def contract_selection_cancel( def contract_selection_cancel(self, on_cancelled: Callable[[], None]) -> None:
self, on_cancelled: Callable[[], None]) -> None: """Cancel the asynchronous contract selection (if started)."""
""" Cancel the asynchronous contract selection (if started). """
async def inner() -> None: async def inner() -> None:
await self.endpoint.contract_selection.cancel.POST() await self.endpoint.contract_selection.cancel.POST()
self.cs_initiated = False self.cs_initiated = False
@ -203,12 +205,10 @@ class UbuntuProController(SubiquityTuiController):
self.app.prev_screen() self.app.prev_screen()
def done(self, token: str) -> None: def done(self, token: str) -> None:
""" Submit the token and move on to the next screen. """ """Submit the token and move on to the next screen."""
self.app.next_screen( self.app.next_screen(self.endpoint.POST(UbuntuProInfo(token=token)))
self.endpoint.POST(UbuntuProInfo(token=token))
)
def next_screen(self) -> None: def next_screen(self) -> None:
""" Move on to the next screen. Assume the token should not be """Move on to the next screen. Assume the token should not be
submitted (or has already been submitted). """ submitted (or has already been submitted)."""
self.app.next_screen() self.app.next_screen()

View File

@ -15,17 +15,16 @@
import logging import logging
from subiquitycore import i18n
from subiquitycore.tuicontroller import Skip
from subiquity.client.controller import SubiquityTuiController from subiquity.client.controller import SubiquityTuiController
from subiquity.ui.views.welcome import WelcomeView from subiquity.ui.views.welcome import WelcomeView
from subiquitycore import i18n
from subiquitycore.tuicontroller import Skip
log = logging.getLogger('subiquity.client.controllers.welcome') log = logging.getLogger("subiquity.client.controllers.welcome")
class WelcomeController(SubiquityTuiController): class WelcomeController(SubiquityTuiController):
endpoint_name = "locale"
endpoint_name = 'locale'
async def make_ui(self): async def make_ui(self):
if not self.app.rich_mode: if not self.app.rich_mode:
@ -36,11 +35,11 @@ class WelcomeController(SubiquityTuiController):
return WelcomeView(self, language, self.serial) return WelcomeView(self, language, self.serial)
def run_answers(self): def run_answers(self):
if 'lang' in self.answers: if "lang" in self.answers:
self.done((self.answers['lang'], "")) self.done((self.answers["lang"], ""))
def done(self, lang): def done(self, lang):
""" Completes this controller. lang must be a tuple of strings """Completes this controller. lang must be a tuple of strings
containing the language code and its native representation containing the language code and its native representation
respectively. respectively.
""" """

View File

@ -18,20 +18,18 @@ import logging
from subiquity.client.controller import SubiquityTuiController from subiquity.client.controller import SubiquityTuiController
from subiquity.ui.views import ZdevView from subiquity.ui.views import ZdevView
log = logging.getLogger("subiquity.client.controllers.zdev") log = logging.getLogger("subiquity.client.controllers.zdev")
class ZdevController(SubiquityTuiController): class ZdevController(SubiquityTuiController):
endpoint_name = "zdev"
endpoint_name = 'zdev'
async def make_ui(self): async def make_ui(self):
infos = await self.endpoint.GET() infos = await self.endpoint.GET()
return ZdevView(self, infos) return ZdevView(self, infos)
def run_answers(self): def run_answers(self):
if 'accept-default' in self.answers: if "accept-default" in self.answers:
self.done() self.done()
def cancel(self): def cancel(self):

View File

@ -19,7 +19,7 @@ import os
import struct import struct
import sys import sys
log = logging.getLogger('subiquity.client.keycodes') log = logging.getLogger("subiquity.client.keycodes")
# /usr/include/linux/kd.h # /usr/include/linux/kd.h
K_RAW = 0x00 K_RAW = 0x00
@ -46,7 +46,7 @@ class KeyCodesFilter:
""" """
def __init__(self): def __init__(self):
self._fd = os.open("/proc/self/fd/"+str(sys.stdin.fileno()), os.O_RDWR) self._fd = os.open("/proc/self/fd/" + str(sys.stdin.fileno()), os.O_RDWR)
self.filtering = False self.filtering = False
def enter_keycodes_mode(self): def enter_keycodes_mode(self):
@ -56,7 +56,7 @@ class KeyCodesFilter:
# well). # well).
o = bytearray(4) o = bytearray(4)
fcntl.ioctl(self._fd, KDGKBMODE, o) fcntl.ioctl(self._fd, KDGKBMODE, o)
self._old_mode = struct.unpack('i', o)[0] self._old_mode = struct.unpack("i", o)[0]
# Set the keyboard mode to K_MEDIUMRAW, which causes the keyboard # Set the keyboard mode to K_MEDIUMRAW, which causes the keyboard
# driver in the kernel to pass us keycodes. # driver in the kernel to pass us keycodes.
fcntl.ioctl(self._fd, KDSKBMODE, K_MEDIUMRAW) fcntl.ioctl(self._fd, KDSKBMODE, K_MEDIUMRAW)
@ -76,17 +76,16 @@ class KeyCodesFilter:
while i < len(codes): while i < len(codes):
# This is straight from showkeys.c. # This is straight from showkeys.c.
if codes[i] & 0x80: if codes[i] & 0x80:
p = 'release ' p = "release "
else: else:
p = 'press ' p = "press "
if i + 2 < n and (codes[i] & 0x7f) == 0: if i + 2 < n and (codes[i] & 0x7F) == 0:
if (codes[i + 1] & 0x80) != 0: if (codes[i + 1] & 0x80) != 0:
if (codes[i + 2] & 0x80) != 0: if (codes[i + 2] & 0x80) != 0:
kc = (((codes[i + 1] & 0x7f) << 7) | kc = ((codes[i + 1] & 0x7F) << 7) | (codes[i + 2] & 0x7F)
(codes[i + 2] & 0x7f))
i += 3 i += 3
else: else:
kc = codes[i] & 0x7f kc = codes[i] & 0x7F
i += 1 i += 1
r.append(p + str(kc)) r.append(p + str(kc))
return r return r

View File

@ -28,11 +28,13 @@ def setup_environment():
locale.setlocale(locale.LC_CTYPE, "C.UTF-8") locale.setlocale(locale.LC_CTYPE, "C.UTF-8")
# Prefer utils from $SUBIQUITY_ROOT, over system-wide # Prefer utils from $SUBIQUITY_ROOT, over system-wide
root = os.environ.get('SUBIQUITY_ROOT') root = os.environ.get("SUBIQUITY_ROOT")
if root: if root:
os.environ['PATH'] = os.pathsep.join([ os.environ["PATH"] = os.pathsep.join(
os.path.join(root, 'bin'), [
os.path.join(root, 'usr', 'bin'), os.path.join(root, "bin"),
os.environ['PATH'], os.path.join(root, "usr", "bin"),
]) os.environ["PATH"],
os.environ["APPORT_DATA_DIR"] = os.path.join(root, 'share/apport') ]
)
os.environ["APPORT_DATA_DIR"] = os.path.join(root, "share/apport")

View File

@ -29,16 +29,16 @@ from subiquity.server.server import SubiquityServer
def make_schema(app): def make_schema(app):
schema = copy.deepcopy(app.base_schema) schema = copy.deepcopy(app.base_schema)
for controller in app.controllers.instances: for controller in app.controllers.instances:
ckey = getattr(controller, 'autoinstall_key', None) ckey = getattr(controller, "autoinstall_key", None)
if ckey is None: if ckey is None:
continue continue
cschema = getattr(controller, "autoinstall_schema", None) cschema = getattr(controller, "autoinstall_schema", None)
if cschema is None: if cschema is None:
continue continue
schema['properties'][ckey] = cschema schema["properties"][ckey] = cschema
ckey_alias = getattr(controller, 'autoinstall_key_alias', None) ckey_alias = getattr(controller, "autoinstall_key_alias", None)
if ckey_alias is None: if ckey_alias is None:
continue continue
@ -46,15 +46,15 @@ def make_schema(app):
cschema["deprecated"] = True cschema["deprecated"] = True
cschema["description"] = f"Compatibility only - use {ckey} instead" cschema["description"] = f"Compatibility only - use {ckey} instead"
schema['properties'][ckey_alias] = cschema schema["properties"][ckey_alias] = cschema
return schema return schema
def make_app(): def make_app():
parser = make_server_args_parser() parser = make_server_args_parser()
opts, unknown = parser.parse_known_args(['--dry-run']) opts, unknown = parser.parse_known_args(["--dry-run"])
app = SubiquityServer(opts, '') app = SubiquityServer(opts, "")
# This is needed because the ubuntu-pro server controller accesses dr_cfg # This is needed because the ubuntu-pro server controller accesses dr_cfg
# in the initializer. # in the initializer.
app.dr_cfg = DRConfig() app.dr_cfg = DRConfig()
@ -72,5 +72,5 @@ def main():
asyncio.run(run_with_loop()) asyncio.run(run_with_loop())
if __name__ == '__main__': if __name__ == "__main__":
sys.exit(main()) sys.exit(main())

View File

@ -25,10 +25,7 @@ import attr
from subiquitycore.log import setup_logger from subiquitycore.log import setup_logger
from .common import ( from .common import LOGDIR, setup_environment
LOGDIR,
setup_environment,
)
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
@ -41,8 +38,8 @@ class CommandLineParams:
def from_cmdline(cls, cmdline): def from_cmdline(cls, cmdline):
r = cls(cmdline) r = cls(cmdline)
for tok in shlex.split(cmdline): for tok in shlex.split(cmdline):
if '=' in tok: if "=" in tok:
k, v = tok.split('=', 1) k, v = tok.split("=", 1)
r._values[k] = v r._values[k] = v
else: else:
r._tokens.add(tok) r._tokens.add(tok)
@ -57,74 +54,103 @@ class CommandLineParams:
def make_server_args_parser(): def make_server_args_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='SUbiquity - Ubiquity for Servers', description="SUbiquity - Ubiquity for Servers", prog="subiquity"
prog='subiquity') )
parser.add_argument('--dry-run', action='store_true',
dest='dry_run',
help='menu-only, do not call installer function')
parser.add_argument('--dry-run-config', type=argparse.FileType())
parser.add_argument('--socket')
parser.add_argument('--machine-config', metavar='CONFIG',
dest='machine_config',
type=argparse.FileType(),
help="Don't Probe. Use probe data file")
parser.add_argument('--bootloader',
choices=['none', 'bios', 'prep', 'uefi'],
help='Override style of bootloader to use')
parser.add_argument( parser.add_argument(
'--autoinstall', action='store', "--dry-run",
help=('Path to autoinstall file. Empty value disables autoinstall. ' action="store_true",
'By default tries to load /autoinstall.yaml, ' dest="dry_run",
'or autoinstall data from cloud-init.')) help="menu-only, do not call installer function",
with open('/proc/cmdline') as fp: )
parser.add_argument("--dry-run-config", type=argparse.FileType())
parser.add_argument("--socket")
parser.add_argument(
"--machine-config",
metavar="CONFIG",
dest="machine_config",
type=argparse.FileType(),
help="Don't Probe. Use probe data file",
)
parser.add_argument(
"--bootloader",
choices=["none", "bios", "prep", "uefi"],
help="Override style of bootloader to use",
)
parser.add_argument(
"--autoinstall",
action="store",
help=(
"Path to autoinstall file. Empty value disables autoinstall. "
"By default tries to load /autoinstall.yaml, "
"or autoinstall data from cloud-init."
),
)
with open("/proc/cmdline") as fp:
cmdline = fp.read() cmdline = fp.read()
parser.add_argument( parser.add_argument(
'--kernel-cmdline', action='store', default=cmdline, "--kernel-cmdline",
type=CommandLineParams.from_cmdline) action="store",
default=cmdline,
type=CommandLineParams.from_cmdline,
)
parser.add_argument( parser.add_argument(
'--snaps-from-examples', action='store_const', const=True, "--snaps-from-examples",
action="store_const",
const=True,
dest="snaps_from_examples", dest="snaps_from_examples",
help=("Load snap details from examples/snaps instead of store. " help=(
"Load snap details from examples/snaps instead of store. "
"Default in dry-run mode. " "Default in dry-run mode. "
"See examples/snaps/README.md for more.")) "See examples/snaps/README.md for more."
),
)
parser.add_argument( parser.add_argument(
'--no-snaps-from-examples', action='store_const', const=False, "--no-snaps-from-examples",
action="store_const",
const=False,
dest="snaps_from_examples", dest="snaps_from_examples",
help=("Load snap details from store instead of examples. " help=(
"Load snap details from store instead of examples. "
"Default in when not in dry-run mode. " "Default in when not in dry-run mode. "
"See examples/snaps/README.md for more.")) "See examples/snaps/README.md for more."
),
)
parser.add_argument( parser.add_argument(
'--snap-section', action='store', default='server', "--snap-section",
help=("Show snaps from this section of the store in the snap " action="store",
"list screen.")) default="server",
help=("Show snaps from this section of the store in the snap " "list screen."),
)
parser.add_argument("--source-catalog", dest="source_catalog", action="store")
parser.add_argument( parser.add_argument(
'--source-catalog', dest='source_catalog', action='store') "--output-base",
action="store",
dest="output_base",
default=".subiquity",
help="in dryrun, control basedir of files",
)
parser.add_argument("--storage-version", action="store", type=int)
parser.add_argument("--use-os-prober", action="store_true", default=False)
parser.add_argument( parser.add_argument(
'--output-base', action='store', dest='output_base', "--postinst-hooks-dir", default="/etc/subiquity/postinst.d", type=pathlib.Path
default='.subiquity', )
help='in dryrun, control basedir of files')
parser.add_argument(
'--storage-version', action='store', type=int)
parser.add_argument(
'--use-os-prober', action='store_true', default=False)
parser.add_argument(
'--postinst-hooks-dir', default='/etc/subiquity/postinst.d',
type=pathlib.Path)
return parser return parser
def main(): def main():
print('starting server') print("starting server")
setup_environment() setup_environment()
# setup_environment sets $APPORT_DATA_DIR which must be set before # setup_environment sets $APPORT_DATA_DIR which must be set before
# apport is imported, which is done by this import: # apport is imported, which is done by this import:
from subiquity.server.server import SubiquityServer
from subiquity.server.dryrun import DRConfig from subiquity.server.dryrun import DRConfig
from subiquity.server.server import SubiquityServer
parser = make_server_args_parser() parser = make_server_args_parser()
opts = parser.parse_args(sys.argv[1:]) opts = parser.parse_args(sys.argv[1:])
if opts.storage_version is None: if opts.storage_version is None:
opts.storage_version = int(opts.kernel_cmdline.get( opts.storage_version = int(
'subiquity-storage-version', 1)) opts.kernel_cmdline.get("subiquity-storage-version", 1)
)
logdir = LOGDIR logdir = LOGDIR
if opts.dry_run: if opts.dry_run:
if opts.dry_run_config: if opts.dry_run_config:
@ -139,25 +165,26 @@ def main():
dr_cfg = None dr_cfg = None
if opts.socket is None: if opts.socket is None:
if opts.dry_run: if opts.dry_run:
opts.socket = opts.output_base + '/socket' opts.socket = opts.output_base + "/socket"
else: else:
opts.socket = '/run/subiquity/socket' opts.socket = "/run/subiquity/socket"
os.makedirs(os.path.dirname(opts.socket), exist_ok=True) os.makedirs(os.path.dirname(opts.socket), exist_ok=True)
block_log_dir = os.path.join(logdir, "block") block_log_dir = os.path.join(logdir, "block")
os.makedirs(block_log_dir, exist_ok=True) os.makedirs(block_log_dir, exist_ok=True)
handler = logging.FileHandler(os.path.join(block_log_dir, 'discover.log')) handler = logging.FileHandler(os.path.join(block_log_dir, "discover.log"))
handler.setLevel('DEBUG') handler.setLevel("DEBUG")
handler.setFormatter( handler.setFormatter(
logging.Formatter("%(asctime)s %(name)s:%(lineno)d %(message)s")) logging.Formatter("%(asctime)s %(name)s:%(lineno)d %(message)s")
logging.getLogger('probert').addHandler(handler) )
handler.addFilter(lambda rec: rec.name != 'probert.network') logging.getLogger("probert").addHandler(handler)
logging.getLogger('curtin').addHandler(handler) handler.addFilter(lambda rec: rec.name != "probert.network")
logging.getLogger('block-discover').addHandler(handler) logging.getLogger("curtin").addHandler(handler)
logging.getLogger("block-discover").addHandler(handler)
logfiles = setup_logger(dir=logdir, base='subiquity-server') logfiles = setup_logger(dir=logdir, base="subiquity-server")
logger = logging.getLogger('subiquity') logger = logging.getLogger("subiquity")
version = os.environ.get("SNAP_REVISION", "unknown") version = os.environ.get("SNAP_REVISION", "unknown")
snap = os.environ.get("SNAP", "unknown") snap = os.environ.get("SNAP", "unknown")
logger.info(f"Starting Subiquity server revision {version} of snap {snap}") logger.info(f"Starting Subiquity server revision {version} of snap {snap}")
@ -168,18 +195,16 @@ def main():
async def run_with_loop(): async def run_with_loop():
server = SubiquityServer(opts, block_log_dir) server = SubiquityServer(opts, block_log_dir)
server.dr_cfg = dr_cfg server.dr_cfg = dr_cfg
server.note_file_for_apport( server.note_file_for_apport("InstallerServerLog", logfiles["debug"])
"InstallerServerLog", logfiles['debug']) server.note_file_for_apport("InstallerServerLogInfo", logfiles["info"])
server.note_file_for_apport(
"InstallerServerLogInfo", logfiles['info'])
server.note_file_for_apport( server.note_file_for_apport(
"UdiLog", "UdiLog",
os.path.realpath( os.path.realpath("/var/log/installer/ubuntu_desktop_installer.log"),
"/var/log/installer/ubuntu_desktop_installer.log")) )
await server.run() await server.run()
asyncio.run(run_with_loop()) asyncio.run(run_with_loop())
if __name__ == '__main__': if __name__ == "__main__":
sys.exit(main()) sys.exit(main())

View File

@ -16,53 +16,65 @@
import argparse import argparse
import asyncio import asyncio
import fcntl
import logging import logging
import os import os
import fcntl
import subprocess import subprocess
import sys import sys
from subiquitycore.log import setup_logger from subiquitycore.log import setup_logger
from .common import ( from .common import LOGDIR, setup_environment
LOGDIR,
setup_environment,
)
from .server import make_server_args_parser from .server import make_server_args_parser
def make_client_args_parser(): def make_client_args_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Subiquity - Ubiquity for Servers', description="Subiquity - Ubiquity for Servers", prog="subiquity"
prog='subiquity') )
try: try:
ascii_default = os.ttyname(0) == "/dev/ttysclp0" ascii_default = os.ttyname(0) == "/dev/ttysclp0"
except OSError: except OSError:
ascii_default = False ascii_default = False
parser.set_defaults(ascii=ascii_default) parser.set_defaults(ascii=ascii_default)
parser.add_argument('--dry-run', action='store_true', parser.add_argument(
dest='dry_run', "--dry-run",
help='menu-only, do not call installer function') action="store_true",
parser.add_argument('--socket') dest="dry_run",
parser.add_argument('--serial', action='store_true', help="menu-only, do not call installer function",
dest='run_on_serial', )
help='Run the installer over serial console.') parser.add_argument("--socket")
parser.add_argument('--ssh', action='store_true', parser.add_argument(
dest='ssh', "--serial",
help='Print ssh login details') action="store_true",
parser.add_argument('--ascii', action='store_true', dest="run_on_serial",
dest='ascii', help="Run the installer over serial console.",
help='Run the installer in ascii mode.') )
parser.add_argument('--unicode', action='store_false', parser.add_argument(
dest='ascii', "--ssh", action="store_true", dest="ssh", help="Print ssh login details"
help='Run the installer in unicode mode.') )
parser.add_argument('--screens', action='append', dest='screens', parser.add_argument(
default=[]) "--ascii",
parser.add_argument('--answers') action="store_true",
parser.add_argument('--server-pid') dest="ascii",
parser.add_argument('--output-base', action='store', dest='output_base', help="Run the installer in ascii mode.",
default='.subiquity', )
help='in dryrun, control basedir of files') parser.add_argument(
"--unicode",
action="store_false",
dest="ascii",
help="Run the installer in unicode mode.",
)
parser.add_argument("--screens", action="append", dest="screens", default=[])
parser.add_argument("--answers")
parser.add_argument("--server-pid")
parser.add_argument(
"--output-base",
action="store",
dest="output_base",
default=".subiquity",
help="in dryrun, control basedir of files",
)
return parser return parser
@ -74,27 +86,28 @@ def main():
# setup_environment sets $APPORT_DATA_DIR which must be set before # setup_environment sets $APPORT_DATA_DIR which must be set before
# apport is imported, which is done by this import: # apport is imported, which is done by this import:
from subiquity.client.client import SubiquityClient from subiquity.client.client import SubiquityClient
parser = make_client_args_parser() parser = make_client_args_parser()
args = sys.argv[1:] args = sys.argv[1:]
if '--dry-run' in args: if "--dry-run" in args:
opts, unknown = parser.parse_known_args(args) opts, unknown = parser.parse_known_args(args)
if opts.socket is None: if opts.socket is None:
base = opts.output_base base = opts.output_base
os.makedirs(base, exist_ok=True) os.makedirs(base, exist_ok=True)
if os.path.exists(f'{base}/run/subiquity/server-state'): if os.path.exists(f"{base}/run/subiquity/server-state"):
os.unlink(f'{base}/run/subiquity/server-state') os.unlink(f"{base}/run/subiquity/server-state")
sock_path = base + '/socket' sock_path = base + "/socket"
opts.socket = sock_path opts.socket = sock_path
server_args = ['--dry-run', '--socket=' + sock_path] + unknown server_args = ["--dry-run", "--socket=" + sock_path] + unknown
server_args.extend(('--output-base', base)) server_args.extend(("--output-base", base))
server_parser = make_server_args_parser() server_parser = make_server_args_parser()
server_parser.parse_args(server_args) # just to check server_parser.parse_args(server_args) # just to check
server_stdout = open(os.path.join(base, 'server-stdout'), 'w') server_stdout = open(os.path.join(base, "server-stdout"), "w")
server_stderr = open(os.path.join(base, 'server-stderr'), 'w') server_stderr = open(os.path.join(base, "server-stderr"), "w")
server_cmd = [sys.executable, '-m', 'subiquity.cmd.server'] + \ server_cmd = [sys.executable, "-m", "subiquity.cmd.server"] + server_args
server_args
server_proc = subprocess.Popen( server_proc = subprocess.Popen(
server_cmd, stdout=server_stdout, stderr=server_stderr) server_cmd, stdout=server_stdout, stderr=server_stderr
)
opts.server_pid = str(server_proc.pid) opts.server_pid = str(server_proc.pid)
print("running server pid {}".format(server_proc.pid)) print("running server pid {}".format(server_proc.pid))
elif opts.server_pid is not None: elif opts.server_pid is not None:
@ -104,14 +117,14 @@ def main():
else: else:
opts = parser.parse_args(args) opts = parser.parse_args(args)
if opts.socket is None: if opts.socket is None:
opts.socket = '/run/subiquity/socket' opts.socket = "/run/subiquity/socket"
os.makedirs(os.path.basename(opts.socket), exist_ok=True) os.makedirs(os.path.basename(opts.socket), exist_ok=True)
logdir = LOGDIR logdir = LOGDIR
if opts.dry_run: if opts.dry_run:
logdir = opts.output_base logdir = opts.output_base
logfiles = setup_logger(dir=logdir, base='subiquity-client') logfiles = setup_logger(dir=logdir, base="subiquity-client")
logger = logging.getLogger('subiquity') logger = logging.getLogger("subiquity")
version = os.environ.get("SNAP_REVISION", "unknown") version = os.environ.get("SNAP_REVISION", "unknown")
snap = os.environ.get("SNAP", "unknown") snap = os.environ.get("SNAP", "unknown")
logger.info(f"Starting Subiquity TUI revision {version} of snap {snap}") logger.info(f"Starting Subiquity TUI revision {version} of snap {snap}")
@ -127,21 +140,18 @@ def main():
try: try:
fcntl.flock(opts.answers, fcntl.LOCK_EX | fcntl.LOCK_NB) fcntl.flock(opts.answers, fcntl.LOCK_EX | fcntl.LOCK_NB)
except OSError: except OSError:
                logger.exception(                 logger.exception("Failed to lock auto answers file, proceeding without it.")
                    'Failed to lock auto answers file, proceeding without it.')
opts.answers.close() opts.answers.close()
opts.answers = None opts.answers = None
async def run_with_loop(): async def run_with_loop():
subiquity_interface = SubiquityClient(opts) subiquity_interface = SubiquityClient(opts)
subiquity_interface.note_file_for_apport( subiquity_interface.note_file_for_apport("InstallerLog", logfiles["debug"])
"InstallerLog", logfiles['debug']) subiquity_interface.note_file_for_apport("InstallerLogInfo", logfiles["info"])
subiquity_interface.note_file_for_apport(
"InstallerLogInfo", logfiles['info'])
await subiquity_interface.run() await subiquity_interface.run()
asyncio.run(run_with_loop()) asyncio.run(run_with_loop())
if __name__ == '__main__': if __name__ == "__main__":
sys.exit(main()) sys.exit(main())

View File

@ -19,6 +19,7 @@ import inspect
import aiohttp import aiohttp
from subiquity.common.serialize import Serializer from subiquity.common.serialize import Serializer
from .defs import Payload from .defs import Payload
@ -27,7 +28,7 @@ def _wrap(make_request, path, meth, serializer, serialize_query_args):
meth_params = sig.parameters meth_params = sig.parameters
payload_arg = None payload_arg = None
for name, param in meth_params.items(): for name, param in meth_params.items():
if getattr(param.annotation, '__origin__', None) is Payload: if getattr(param.annotation, "__origin__", None) is Payload:
payload_arg = name payload_arg = name
payload_ann = param.annotation.__args__[0] payload_ann = param.annotation.__args__[0]
r_ann = sig.return_annotation r_ann = sig.return_annotation
@ -41,14 +42,14 @@ def _wrap(make_request, path, meth, serializer, serialize_query_args):
data = serializer.serialize(payload_ann, value) data = serializer.serialize(payload_ann, value)
else: else:
if serialize_query_args: if serialize_query_args:
value = serializer.to_json( value = serializer.to_json(meth_params[arg_name].annotation, value)
meth_params[arg_name].annotation, value)
query_args[arg_name] = value query_args[arg_name] = value
async with make_request( async with make_request(
meth.__name__, path.format(**self.path_args), meth.__name__, path.format(**self.path_args), json=data, params=query_args
json=data, params=query_args) as resp: ) as resp:
resp.raise_for_status() resp.raise_for_status()
return serializer.deserialize(r_ann, await resp.json()) return serializer.deserialize(r_ann, await resp.json())
return impl return impl
@ -73,19 +74,24 @@ def make_client_cls(endpoint_cls, make_request, serializer=None):
if serializer is None: if serializer is None:
serializer = Serializer() serializer = Serializer()
ns = {'__init__': client_init} ns = {"__init__": client_init}
for k, v in endpoint_cls.__dict__.items(): for k, v in endpoint_cls.__dict__.items():
if isinstance(v, type): if isinstance(v, type):
if getattr(v, '__parameter__', False): if getattr(v, "__parameter__", False):
ns['__getitem__'] = make_getitem(v, make_request, serializer) ns["__getitem__"] = make_getitem(v, make_request, serializer)
else: else:
ns[k] = make_client(v, make_request, serializer) ns[k] = make_client(v, make_request, serializer)
elif callable(v): elif callable(v):
ns[k] = _wrap(make_request, endpoint_cls.fullpath, v, serializer, ns[k] = _wrap(
endpoint_cls.serialize_query_args) make_request,
endpoint_cls.fullpath,
v,
serializer,
endpoint_cls.serialize_query_args,
)
return type('ClientFor({})'.format(endpoint_cls.__name__), (object,), ns) return type("ClientFor({})".format(endpoint_cls.__name__), (object,), ns)
def make_client(endpoint_cls, make_request, serializer=None, path_args=None): def make_client(endpoint_cls, make_request, serializer=None, path_args=None):
@ -93,10 +99,9 @@ def make_client(endpoint_cls, make_request, serializer=None, path_args=None):
def make_client_for_conn( def make_client_for_conn(
endpoint_cls, conn, resp_hook=lambda r: r, serializer=None, endpoint_cls, conn, resp_hook=lambda r: r, serializer=None, header_func=None
header_func=None): ):
session = aiohttp.ClientSession( session = aiohttp.ClientSession(connector=conn, connector_owner=False)
connector=conn, connector_owner=False)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def make_request(method, path, *, params, json): async def make_request(method, path, *, params, json):
@ -105,14 +110,14 @@ def make_client_for_conn(
# hardcode something here (I guess the "a" gets sent along to the # hardcode something here (I guess the "a" gets sent along to the
# server in the Host: header and the server could in principle do # server in the Host: header and the server could in principle do
# something like virtual host based selection but well....) # something like virtual host based selection but well....)
url = 'http://a' + path url = "http://a" + path
if header_func is not None: if header_func is not None:
headers = header_func() headers = header_func()
else: else:
headers = None headers = None
async with session.request( async with session.request(
method, url, json=json, params=params, method, url, json=json, params=params, headers=headers, timeout=0
headers=headers, timeout=0) as response: ) as response:
yield resp_hook(response) yield resp_hook(response)
return make_client(endpoint_cls, make_request, serializer) return make_client(endpoint_cls, make_request, serializer)

View File

@ -27,8 +27,10 @@ class InvalidQueryArgs(InvalidAPIDefinition):
self.param = param self.param = param
def __str__(self): def __str__(self):
return (f"{self.callable.__qualname__} does not serialize query " return (
f"arguments but has non-str parameter '{self.param}'") f"{self.callable.__qualname__} does not serialize query "
f"arguments but has non-str parameter '{self.param}'"
)
class MultiplePathParameters(InvalidAPIDefinition): class MultiplePathParameters(InvalidAPIDefinition):
@ -38,45 +40,50 @@ class MultiplePathParameters(InvalidAPIDefinition):
self.param2 = param2 self.param2 = param2
def __str__(self): def __str__(self):
return (f"{self.cls.__name__} has multiple path parameters " return (
f"{self.param1!r} and {self.param2!r}") f"{self.cls.__name__} has multiple path parameters "
f"{self.param1!r} and {self.param2!r}"
)
def api(cls, prefix_names=(), prefix_path=(), path_params=(), def api(
serialize_query_args=True): cls, prefix_names=(), prefix_path=(), path_params=(), serialize_query_args=True
if hasattr(cls, 'serialize_query_args'): ):
if hasattr(cls, "serialize_query_args"):
serialize_query_args = cls.serialize_query_args serialize_query_args = cls.serialize_query_args
else: else:
cls.serialize_query_args = serialize_query_args cls.serialize_query_args = serialize_query_args
cls.fullpath = '/' + '/'.join(prefix_path) cls.fullpath = "/" + "/".join(prefix_path)
cls.fullname = prefix_names cls.fullname = prefix_names
seen_path_param = None seen_path_param = None
for k, v in cls.__dict__.items(): for k, v in cls.__dict__.items():
if isinstance(v, type): if isinstance(v, type):
v.__shortname__ = k v.__shortname__ = k
v.__name__ = cls.__name__ + '.' + k v.__name__ = cls.__name__ + "." + k
path_part = k path_part = k
path_param = () path_param = ()
if getattr(v, '__parameter__', False): if getattr(v, "__parameter__", False):
if seen_path_param: if seen_path_param:
raise MultiplePathParameters(cls, seen_path_param, k) raise MultiplePathParameters(cls, seen_path_param, k)
seen_path_param = k seen_path_param = k
path_part = '{' + path_part + '}' path_part = "{" + path_part + "}"
path_param = (k,) path_param = (k,)
api( api(
v, v,
prefix_names + (k,), prefix_names + (k,),
prefix_path + (path_part,), prefix_path + (path_part,),
path_params + path_param, path_params + path_param,
serialize_query_args) serialize_query_args,
)
if callable(v): if callable(v):
v.__qualname__ = cls.__name__ + '.' + k v.__qualname__ = cls.__name__ + "." + k
if not cls.serialize_query_args: if not cls.serialize_query_args:
params = inspect.signature(v).parameters params = inspect.signature(v).parameters
for param_name, param in params.items(): for param_name, param in params.items():
if param.annotation is not str and \ if (
getattr(param.annotation, '__origin__', None) is not \ param.annotation is not str
Payload: and getattr(param.annotation, "__origin__", None) is not Payload
):
raise InvalidQueryArgs(v, param) raise InvalidQueryArgs(v, param)
v.__path_params__ = path_params v.__path_params__ = path_params
return cls return cls
@ -96,8 +103,12 @@ def path_parameter(cls):
def simple_endpoint(typ): def simple_endpoint(typ):
class endpoint: class endpoint:
def GET() -> typ: ... def GET() -> typ:
def POST(data: Payload[typ]): ... ...
def POST(data: Payload[typ]):
...
return endpoint return endpoint

View File

@ -24,7 +24,7 @@ from subiquity.common.serialize import Serializer
from .defs import Payload from .defs import Payload
log = logging.getLogger('subiquity.common.api.server') log = logging.getLogger("subiquity.common.api.server")
class BindError(Exception): class BindError(Exception):
@ -47,24 +47,26 @@ class SignatureMisatchError(BindError):
self.actual = actual self.actual = actual
def __str__(self): def __str__(self):
return (f"implementation of {self.methname} has wrong signature, " return (
f"should be {self.expected} but is {self.actual}") f"implementation of {self.methname} has wrong signature, "
f"should be {self.expected} but is {self.actual}"
)
def trim(text): def trim(text):
if text is None: if text is None:
return '' return ""
elif len(text) > 80: elif len(text) > 80:
return text[:77] + '...' return text[:77] + "..."
else: else:
return text return text
async def check_controllers_started(definition, controller, request): async def check_controllers_started(definition, controller, request):
if not hasattr(controller, 'app'): if not hasattr(controller, "app"):
return return
if getattr(definition, 'allowed_before_start', False): if getattr(definition, "allowed_before_start", False):
return return
# Most endpoints should not be responding to requests # Most endpoints should not be responding to requests
@ -76,13 +78,14 @@ async def check_controllers_started(definition, controller, request):
if controller.app.controllers_have_started.is_set(): if controller.app.controllers_have_started.is_set():
return return
log.debug(f'{request.path} waiting on start') log.debug(f"{request.path} waiting on start")
await controller.app.controllers_have_started.wait() await controller.app.controllers_have_started.wait()
log.debug(f'{request.path} resuming') log.debug(f"{request.path} resuming")
def _make_handler(controller, definition, implementation, serializer, def _make_handler(
serialize_query_args): controller, definition, implementation, serializer, serialize_query_args
):
def_sig = inspect.signature(definition) def_sig = inspect.signature(definition)
def_ret_ann = def_sig.return_annotation def_ret_ann = def_sig.return_annotation
def_params = def_sig.parameters def_params = def_sig.parameters
@ -99,27 +102,26 @@ def _make_handler(controller, definition, implementation, serializer,
for param_name in definition.__path_params__: for param_name in definition.__path_params__:
check_def_params.append( check_def_params.append(
inspect.Parameter( inspect.Parameter(
param_name, param_name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str
inspect.Parameter.POSITIONAL_OR_KEYWORD, )
annotation=str)) )
for param_name, param in def_params.items(): for param_name, param in def_params.items():
if param_name in ('request', 'context'): if param_name in ("request", "context"):
raise Exception( raise Exception(
"api method {} cannot have parameter called request or " "api method {} cannot have parameter called request or "
"context".format(definition)) "context".format(definition)
if getattr(param.annotation, '__origin__', None) is Payload: )
if getattr(param.annotation, "__origin__", None) is Payload:
data_arg = param_name data_arg = param_name
data_annotation = param.annotation.__args__[0] data_annotation = param.annotation.__args__[0]
check_def_params.append(param.replace(annotation=data_annotation)) check_def_params.append(param.replace(annotation=data_annotation))
else: else:
query_args_anns.append( query_args_anns.append((param_name, param.annotation, param.default))
(param_name, param.annotation, param.default))
check_def_params.append(param) check_def_params.append(param)
check_impl_params = [ check_impl_params = [
p for p in impl_params.values() p for p in impl_params.values() if p.name not in ("context", "request")
if p.name not in ('context', 'request')
] ]
check_impl_sig = impl_sig.replace(parameters=check_impl_params) check_impl_sig = impl_sig.replace(parameters=check_impl_params)
@ -127,17 +129,19 @@ def _make_handler(controller, definition, implementation, serializer,
if check_impl_sig != check_def_sig: if check_impl_sig != check_def_sig:
raise SignatureMisatchError( raise SignatureMisatchError(
definition.__qualname__, check_def_sig, check_impl_sig) definition.__qualname__, check_def_sig, check_impl_sig
)
async def handler(request): async def handler(request):
context = controller.context.child(implementation.__name__) context = controller.context.child(implementation.__name__)
with context: with context:
context.set('request', request) context.set("request", request)
args = {} args = {}
try: try:
if data_annotation is not None: if data_annotation is not None:
args[data_arg] = serializer.from_json( args[data_arg] = serializer.from_json(
data_annotation, await request.text()) data_annotation, await request.text()
)
for arg, ann, default in query_args_anns: for arg, ann, default in query_args_anns:
if arg in request.query: if arg in request.query:
v = request.query[arg] v = request.query[arg]
@ -146,34 +150,35 @@ def _make_handler(controller, definition, implementation, serializer,
elif default != inspect._empty: elif default != inspect._empty:
v = default v = default
else: else:
raise TypeError( raise TypeError('missing required argument "{}"'.format(arg))
'missing required argument "{}"'.format(arg))
args[arg] = v args[arg] = v
for param_name in definition.__path_params__: for param_name in definition.__path_params__:
args[param_name] = request.match_info[param_name] args[param_name] = request.match_info[param_name]
if 'context' in impl_params: if "context" in impl_params:
args['context'] = context args["context"] = context
if 'request' in impl_params: if "request" in impl_params:
args['request'] = request args["request"] = request
await check_controllers_started( await check_controllers_started(definition, controller, request)
definition, controller, request)
result = await implementation(**args) result = await implementation(**args)
resp = web.json_response( resp = web.json_response(
serializer.serialize(def_ret_ann, result), serializer.serialize(def_ret_ann, result),
headers={'x-status': 'ok'}) headers={"x-status": "ok"},
)
except Exception as exc: except Exception as exc:
tb = traceback.TracebackException.from_exception(exc) tb = traceback.TracebackException.from_exception(exc)
resp = web.Response( resp = web.Response(
text="".join(tb.format()), text="".join(tb.format()),
status=500, status=500,
headers={ headers={
'x-status': 'error', "x-status": "error",
'x-error-type': type(exc).__name__, "x-error-type": type(exc).__name__,
'x-error-msg': str(exc), "x-error-msg": str(exc),
}) },
resp['exception'] = exc )
context.description = '{} {}'.format(resp.status, trim(resp.text)) resp["exception"] = exc
context.description = "{} {}".format(resp.status, trim(resp.text))
return resp return resp
handler.controller = controller handler.controller = controller
return handler return handler
@ -181,7 +186,7 @@ def _make_handler(controller, definition, implementation, serializer,
async def controller_for_request(request): async def controller_for_request(request):
match_info = await request.app.router.resolve(request) match_info = await request.app.router.resolve(request)
return getattr(match_info.handler, 'controller', None) return getattr(match_info.handler, "controller", None)
def bind(router, endpoint, controller, serializer=None, _depth=None): def bind(router, endpoint, controller, serializer=None, _depth=None):
@ -203,8 +208,9 @@ def bind(router, endpoint, controller, serializer=None, _depth=None):
method=method, method=method,
path=endpoint.fullpath, path=endpoint.fullpath,
handler=_make_handler( handler=_make_handler(
controller, v, impl, serializer, controller, v, impl, serializer, endpoint.serialize_query_args
endpoint.serialize_query_args)) ),
)
async def make_server_at_path(socket_path, endpoint, controller, **kw): async def make_server_at_path(socket_path, endpoint, controller, **kw):

View File

@ -17,12 +17,7 @@ import contextlib
import unittest import unittest
from subiquity.common.api.client import make_client from subiquity.common.api.client import make_client
from subiquity.common.api.defs import ( from subiquity.common.api.defs import InvalidQueryArgs, Payload, api, path_parameter
api,
InvalidQueryArgs,
path_parameter,
Payload,
)
def extract(c): def extract(c):
@ -46,20 +41,21 @@ class FakeResponse:
class TestClient(unittest.TestCase): class TestClient(unittest.TestCase):
def test_simple(self): def test_simple(self):
@api @api
class API: class API:
class endpoint: class endpoint:
def GET() -> str: ... def GET() -> str:
def POST(data: Payload[str]) -> None: ... ...
def POST(data: Payload[str]) -> None:
...
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def make_request(method, path, *, params, json): async def make_request(method, path, *, params, json):
requests.append((method, path, params, json)) requests.append((method, path, params, json))
if method == "GET": if method == "GET":
v = 'value' v = "value"
else: else:
v = None v = None
yield FakeResponse(v) yield FakeResponse(v)
@ -68,85 +64,95 @@ class TestClient(unittest.TestCase):
requests = [] requests = []
r = extract(client.endpoint.GET()) r = extract(client.endpoint.GET())
self.assertEqual(r, 'value') self.assertEqual(r, "value")
self.assertEqual(requests, [("GET", '/endpoint', {}, None)]) self.assertEqual(requests, [("GET", "/endpoint", {}, None)])
requests = [] requests = []
r = extract(client.endpoint.POST('value')) r = extract(client.endpoint.POST("value"))
self.assertEqual(r, None) self.assertEqual(r, None)
self.assertEqual( self.assertEqual(requests, [("POST", "/endpoint", {}, "value")])
requests, [("POST", '/endpoint', {}, 'value')])
def test_args(self): def test_args(self):
@api @api
class API: class API:
def GET(arg: str) -> str: ... def GET(arg: str) -> str:
...
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def make_request(method, path, *, params, json): async def make_request(method, path, *, params, json):
requests.append((method, path, params, json)) requests.append((method, path, params, json))
yield FakeResponse(params['arg']) yield FakeResponse(params["arg"])
client = make_client(API, make_request) client = make_client(API, make_request)
requests = [] requests = []
r = extract(client.GET(arg='v')) r = extract(client.GET(arg="v"))
self.assertEqual(r, '"v"') self.assertEqual(r, '"v"')
self.assertEqual(requests, [("GET", '/', {'arg': '"v"'}, None)]) self.assertEqual(requests, [("GET", "/", {"arg": '"v"'}, None)])
def test_path_params(self): def test_path_params(self):
@api @api
class API: class API:
@path_parameter @path_parameter
class param: class param:
def GET(arg: str) -> str: ... def GET(arg: str) -> str:
...
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def make_request(method, path, *, params, json): async def make_request(method, path, *, params, json):
requests.append((method, path, params, json)) requests.append((method, path, params, json))
yield FakeResponse(params['arg']) yield FakeResponse(params["arg"])
client = make_client(API, make_request) client = make_client(API, make_request)
requests = [] requests = []
r = extract(client['foo'].GET(arg='v')) r = extract(client["foo"].GET(arg="v"))
self.assertEqual(r, '"v"') self.assertEqual(r, '"v"')
self.assertEqual(requests, [("GET", '/foo', {'arg': '"v"'}, None)]) self.assertEqual(requests, [("GET", "/foo", {"arg": '"v"'}, None)])
def test_serialize_query_args(self): def test_serialize_query_args(self):
@api @api
class API: class API:
serialize_query_args = False serialize_query_args = False
def GET(arg: str) -> str: ...
def GET(arg: str) -> str:
...
class meth: class meth:
def GET(arg: str, payload: Payload[int]) -> str: ... def GET(arg: str, payload: Payload[int]) -> str:
...
class more: class more:
serialize_query_args = True serialize_query_args = True
def GET(arg: str) -> str: ...
def GET(arg: str) -> str:
...
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def make_request(method, path, *, params, json): async def make_request(method, path, *, params, json):
requests.append((method, path, params, json)) requests.append((method, path, params, json))
yield FakeResponse(params['arg']) yield FakeResponse(params["arg"])
client = make_client(API, make_request) client = make_client(API, make_request)
requests = [] requests = []
extract(client.GET(arg='v')) extract(client.GET(arg="v"))
extract(client.meth.GET(arg='v', payload=1)) extract(client.meth.GET(arg="v", payload=1))
extract(client.meth.more.GET(arg='v')) extract(client.meth.more.GET(arg="v"))
self.assertEqual(requests, [ self.assertEqual(
("GET", '/', {'arg': 'v'}, None), requests,
("GET", '/meth', {'arg': 'v'}, 1), [
("GET", '/meth/more', {'arg': '"v"'}, None)]) ("GET", "/", {"arg": "v"}, None),
("GET", "/meth", {"arg": "v"}, 1),
("GET", "/meth/more", {"arg": '"v"'}, None),
],
)
class API2: class API2:
serialize_query_args = False serialize_query_args = False
def GET(arg: int) -> str: ...
def GET(arg: int) -> str:
...
with self.assertRaises(InvalidQueryArgs): with self.assertRaises(InvalidQueryArgs):
api(API2) api(API2)
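A reading aid for the serialize_query_args test above (not additional coverage): the flag defaults to on, is inherited by nested endpoints, and can be overridden per class, and the recorded requests show the wire effect directly. Restated with the same client/extract names used in that test:

extract(client.GET(arg="v"))                  # ("GET", "/",          {"arg": "v"},   None)  - API turns serialization off
extract(client.meth.GET(arg="v", payload=1))  # ("GET", "/meth",      {"arg": "v"},   1)     - meth inherits the setting
extract(client.meth.more.GET(arg="v"))        # ("GET", "/meth/more", {"arg": '"v"'}, None)  - more turns it back on, so the value is JSON-encoded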
View File
@ -13,82 +13,77 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import attr
import contextlib import contextlib
import functools import functools
import unittest import unittest
import aiohttp import aiohttp
import attr
from aiohttp import web from aiohttp import web
from subiquity.common.api.client import make_client from subiquity.common.api.client import make_client
from subiquity.common.api.defs import ( from subiquity.common.api.defs import (
api,
MultiplePathParameters, MultiplePathParameters,
path_parameter,
Payload, Payload,
) api,
path_parameter,
)
from .test_server import ( from .test_server import ControllerBase, makeTestClient
makeTestClient,
ControllerBase,
)
def make_request(client, method, path, *, params, json): def make_request(client, method, path, *, params, json):
return client.request( return client.request(method, path, params=params, json=json)
method, path, params=params, json=json)
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def makeE2EClient(api, impl, async def makeE2EClient(api, impl, *, middlewares=(), make_request=make_request):
*, middlewares=(), make_request=make_request): async with makeTestClient(api, impl, middlewares=middlewares) as client:
async with makeTestClient(
api, impl, middlewares=middlewares) as client:
mr = functools.partial(make_request, client) mr = functools.partial(make_request, client)
yield make_client(api, mr) yield make_client(api, mr)
class TestEndToEnd(unittest.IsolatedAsyncioTestCase): class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
async def test_simple(self): async def test_simple(self):
@api @api
class API: class API:
def GET() -> str: ... def GET() -> str:
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def GET(self) -> str: async def GET(self) -> str:
return 'value' return "value"
async with makeE2EClient(API, Impl()) as client: async with makeE2EClient(API, Impl()) as client:
self.assertEqual(await client.GET(), 'value') self.assertEqual(await client.GET(), "value")
async def test_nested(self): async def test_nested(self):
@api @api
class API: class API:
class endpoint: class endpoint:
class nested: class nested:
def GET() -> str: ... def GET() -> str:
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def endpoint_nested_GET(self) -> str: async def endpoint_nested_GET(self) -> str:
return 'value' return "value"
async with makeE2EClient(API, Impl()) as client: async with makeE2EClient(API, Impl()) as client:
self.assertEqual(await client.endpoint.nested.GET(), 'value') self.assertEqual(await client.endpoint.nested.GET(), "value")
async def test_args(self): async def test_args(self):
@api @api
class API: class API:
def GET(arg1: str, arg2: str) -> str: ... def GET(arg1: str, arg2: str) -> str:
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def GET(self, arg1: str, arg2: str) -> str: async def GET(self, arg1: str, arg2: str) -> str:
return '{}+{}'.format(arg1, arg2) return "{}+{}".format(arg1, arg2)
async with makeE2EClient(API, Impl()) as client: async with makeE2EClient(API, Impl()) as client:
self.assertEqual( self.assertEqual(await client.GET(arg1="A", arg2="B"), "A+B")
await client.GET(arg1="A", arg2="B"), 'A+B')
async def test_path_args(self): async def test_path_args(self):
@api @api
@ -97,26 +92,29 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
class param1: class param1:
@path_parameter @path_parameter
class param2: class param2:
def GET(arg1: str, arg2: str) -> str: ... def GET(arg1: str, arg2: str) -> str:
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def param1_param2_GET(self, param1: str, param2: str, async def param1_param2_GET(
arg1: str, arg2: str) -> str: self, param1: str, param2: str, arg1: str, arg2: str
return '{}+{}+{}+{}'.format(param1, param2, arg1, arg2) ) -> str:
return "{}+{}+{}+{}".format(param1, param2, arg1, arg2)
async with makeE2EClient(API, Impl()) as client: async with makeE2EClient(API, Impl()) as client:
self.assertEqual( self.assertEqual(await client["1"]["2"].GET(arg1="A", arg2="B"), "1+2+A+B")
await client["1"]["2"].GET(arg1="A", arg2="B"), '1+2+A+B')
def test_only_one_path_parameter(self): def test_only_one_path_parameter(self):
class API: class API:
@path_parameter @path_parameter
class param1: class param1:
def GET(arg: str) -> str: ... def GET(arg: str) -> str:
...
@path_parameter @path_parameter
class param2: class param2:
def GET(arg: str) -> str: ... def GET(arg: str) -> str:
...
with self.assertRaises(MultiplePathParameters): with self.assertRaises(MultiplePathParameters):
api(API) api(API)
@ -124,35 +122,32 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
async def test_defaults(self): async def test_defaults(self):
@api @api
class API: class API:
def GET(arg1: str = "arg1", arg2: str = "arg2") -> str: ... def GET(arg1: str = "arg1", arg2: str = "arg2") -> str:
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def GET(self, arg1: str = "arg1", arg2: str = "arg2") -> str: async def GET(self, arg1: str = "arg1", arg2: str = "arg2") -> str:
return '{}+{}'.format(arg1, arg2) return "{}+{}".format(arg1, arg2)
async with makeE2EClient(API, Impl()) as client: async with makeE2EClient(API, Impl()) as client:
self.assertEqual( self.assertEqual(await client.GET(arg1="A", arg2="B"), "A+B")
await client.GET(arg1="A", arg2="B"), 'A+B') self.assertEqual(await client.GET(arg1="A"), "A+arg2")
self.assertEqual( self.assertEqual(await client.GET(arg2="B"), "arg1+B")
await client.GET(arg1="A"), 'A+arg2')
self.assertEqual(
await client.GET(arg2="B"), 'arg1+B')
async def test_post(self): async def test_post(self):
@api @api
class API: class API:
def POST(data: Payload[dict]) -> str: ... def POST(data: Payload[dict]) -> str:
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def POST(self, data: dict) -> str: async def POST(self, data: dict) -> str:
return data['key'] return data["key"]
async with makeE2EClient(API, Impl()) as client: async with makeE2EClient(API, Impl()) as client:
self.assertEqual( self.assertEqual(await client.POST({"key": "value"}), "value")
await client.POST({'key': 'value'}), 'value')
async def test_typed(self): async def test_typed(self):
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class In: class In:
val: int val: int
@ -164,11 +159,12 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
@api @api
class API: class API:
class doubler: class doubler:
def POST(data: In) -> Out: ... def POST(data: In) -> Out:
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def doubler_POST(self, data: In) -> Out: async def doubler_POST(self, data: In) -> Out:
return Out(doubled=data.val*2) return Out(doubled=data.val * 2)
async with makeE2EClient(API, Impl()) as client: async with makeE2EClient(API, Impl()) as client:
out = await client.doubler.POST(In(3)) out = await client.doubler.POST(In(3))
@ -177,17 +173,16 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
async def test_middleware(self): async def test_middleware(self):
@api @api
class API: class API:
def GET() -> int: ... def GET() -> int:
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def GET(self) -> int: async def GET(self) -> int:
return 1/0 return 1 / 0
@web.middleware @web.middleware
async def middleware(request, handler): async def middleware(request, handler):
return web.Response( return web.Response(status=200, headers={"x-status": "skip"})
status=200,
headers={'x-status': 'skip'})
class Skip(Exception): class Skip(Exception):
pass pass
@ -195,15 +190,15 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def custom_make_request(client, method, path, *, params, json): async def custom_make_request(client, method, path, *, params, json):
async with make_request( async with make_request(
client, method, path, params=params, json=json) as resp: client, method, path, params=params, json=json
if resp.headers.get('x-status') == 'skip': ) as resp:
if resp.headers.get("x-status") == "skip":
raise Skip raise Skip
yield resp yield resp
async with makeE2EClient( async with makeE2EClient(
API, Impl(), API, Impl(), middlewares=[middleware], make_request=custom_make_request
middlewares=[middleware], ) as client:
make_request=custom_make_request) as client:
with self.assertRaises(Skip): with self.assertRaises(Skip):
await client.GET() await client.GET()
@ -211,10 +206,12 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
@api @api
class API: class API:
class good: class good:
def GET(x: int) -> int: ... def GET(x: int) -> int:
...
class bad: class bad:
def GET(x: int) -> int: ... def GET(x: int) -> int:
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def good_GET(self, x: int) -> int: async def good_GET(self, x: int) -> int:
@ -233,23 +230,25 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
@api @api
class API: class API:
class good: class good:
def GET(x: int) -> int: ... def GET(x: int) -> int:
...
class bad: class bad:
def GET(x: int) -> int: ... def GET(x: int) -> int:
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def good_GET(self, x: int) -> int: async def good_GET(self, x: int) -> int:
return x + 1 return x + 1
async def bad_GET(self, x: int) -> int: async def bad_GET(self, x: int) -> int:
1/0 1 / 0
@web.middleware @web.middleware
async def middleware(request, handler): async def middleware(request, handler):
resp = await handler(request) resp = await handler(request)
if resp.get('exception'): if resp.get("exception"):
resp.headers['x-status'] = 'ERROR' resp.headers["x-status"] = "ERROR"
return resp return resp
class Abort(Exception): class Abort(Exception):
@ -258,15 +257,15 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def custom_make_request(client, method, path, *, params, json): async def custom_make_request(client, method, path, *, params, json):
async with make_request( async with make_request(
client, method, path, params=params, json=json) as resp: client, method, path, params=params, json=json
if resp.headers.get('x-status') == 'ERROR': ) as resp:
if resp.headers.get("x-status") == "ERROR":
raise Abort raise Abort
yield resp yield resp
async with makeE2EClient( async with makeE2EClient(
API, Impl(), API, Impl(), middlewares=[middleware], make_request=custom_make_request
middlewares=[middleware], ) as client:
make_request=custom_make_request) as client:
r = await client.good.GET(2) r = await client.good.GET(2)
self.assertEqual(r, 3) self.assertEqual(r, 3)
with self.assertRaises(Abort): with self.assertRaises(Abort):
View File
@ -17,38 +17,30 @@ import contextlib
import unittest import unittest
from unittest import mock from unittest import mock
from aiohttp.test_utils import TestClient, TestServer
from aiohttp import web from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer
from subiquitycore.context import Context from subiquity.common.api.defs import Payload, allowed_before_start, api, path_parameter
from subiquity.common.api.defs import (
allowed_before_start,
api,
path_parameter,
Payload,
)
from subiquity.common.api.server import ( from subiquity.common.api.server import (
bind,
controller_for_request,
MissingImplementationError, MissingImplementationError,
SignatureMisatchError, SignatureMisatchError,
) bind,
controller_for_request,
)
from subiquitycore.context import Context
class TestApp: class TestApp:
def report_start_event(self, context, description): def report_start_event(self, context, description):
pass pass
def report_finish_event(self, context, description, result): def report_finish_event(self, context, description, result):
pass pass
project = 'test' project = "test"
class ControllerBase: class ControllerBase:
def __init__(self): def __init__(self):
self.context = Context.new(TestApp()) self.context = Context.new(TestApp())
self.app = mock.Mock() self.app = mock.Mock()
@ -64,109 +56,109 @@ async def makeTestClient(api, impl, middlewares=()):
class TestBind(unittest.IsolatedAsyncioTestCase): class TestBind(unittest.IsolatedAsyncioTestCase):
async def assertResponse(self, coro, value): async def assertResponse(self, coro, value):
resp = await coro resp = await coro
if resp.headers.get('x-error-msg'): if resp.headers.get("x-error-msg"):
raise Exception( raise Exception(
resp.headers['x-error-type'] + ': ' + resp.headers["x-error-type"] + ": " + resp.headers["x-error-msg"]
resp.headers['x-error-msg']) )
self.assertEqual(resp.status, 200) self.assertEqual(resp.status, 200)
self.assertEqual(await resp.json(), value) self.assertEqual(await resp.json(), value)
async def test_simple(self): async def test_simple(self):
@api @api
class API: class API:
def GET() -> str: ... def GET() -> str:
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def GET(self) -> str: async def GET(self) -> str:
return 'value' return "value"
async with makeTestClient(API, Impl()) as client: async with makeTestClient(API, Impl()) as client:
await self.assertResponse( await self.assertResponse(client.get("/"), "value")
client.get("/"), 'value')
async def test_nested(self): async def test_nested(self):
@api @api
class API: class API:
class endpoint: class endpoint:
class nested: class nested:
def get(): ... def get():
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def nested_get(self, request, context): async def nested_get(self, request, context):
return 'nested' return "nested"
async with makeTestClient(API.endpoint, Impl()) as client: async with makeTestClient(API.endpoint, Impl()) as client:
await self.assertResponse( await self.assertResponse(client.get("/endpoint/nested"), "nested")
client.get("/endpoint/nested"), 'nested')
async def test_args(self): async def test_args(self):
@api @api
class API: class API:
def GET(arg: str): ... def GET(arg: str):
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def GET(self, arg: str): async def GET(self, arg: str):
return arg return arg
async with makeTestClient(API, Impl()) as client: async with makeTestClient(API, Impl()) as client:
await self.assertResponse( await self.assertResponse(client.get('/?arg="whut"'), "whut")
client.get('/?arg="whut"'), 'whut')
async def test_missing_argument(self): async def test_missing_argument(self):
@api @api
class API: class API:
def GET(arg: str): ... def GET(arg: str):
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def GET(self, arg: str): async def GET(self, arg: str):
return arg return arg
async with makeTestClient(API, Impl()) as client: async with makeTestClient(API, Impl()) as client:
resp = await client.get('/') resp = await client.get("/")
self.assertEqual(resp.status, 500) self.assertEqual(resp.status, 500)
self.assertEqual(resp.headers['x-status'], 'error') self.assertEqual(resp.headers["x-status"], "error")
self.assertEqual(resp.headers['x-error-type'], 'TypeError') self.assertEqual(resp.headers["x-error-type"], "TypeError")
self.assertEqual( self.assertEqual(
resp.headers['x-error-msg'], resp.headers["x-error-msg"], 'missing required argument "arg"'
'missing required argument "arg"') )
async def test_error(self): async def test_error(self):
@api @api
class API: class API:
def GET(): ... def GET():
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def GET(self): async def GET(self):
return 1/0 return 1 / 0
async with makeTestClient(API, Impl()) as client: async with makeTestClient(API, Impl()) as client:
resp = await client.get('/') resp = await client.get("/")
self.assertEqual(resp.status, 500) self.assertEqual(resp.status, 500)
self.assertEqual(resp.headers['x-status'], 'error') self.assertEqual(resp.headers["x-status"], "error")
self.assertEqual( self.assertEqual(resp.headers["x-error-type"], "ZeroDivisionError")
resp.headers['x-error-type'], 'ZeroDivisionError')
async def test_post(self): async def test_post(self):
@api @api
class API: class API:
def POST(data: Payload[str]) -> str: ... def POST(data: Payload[str]) -> str:
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def POST(self, data: str) -> str: async def POST(self, data: str) -> str:
return data return data
async with makeTestClient(API, Impl()) as client: async with makeTestClient(API, Impl()) as client:
await self.assertResponse( await self.assertResponse(client.post("/", json="value"), "value")
client.post("/", json='value'),
'value')
def test_missing_method(self): def test_missing_method(self):
@api @api
class API: class API:
def GET(arg: str): ... def GET(arg: str):
...
class Impl(ControllerBase): class Impl(ControllerBase):
pass pass
@ -179,7 +171,8 @@ class TestBind(unittest.IsolatedAsyncioTestCase):
def test_signature_checking(self): def test_signature_checking(self):
@api @api
class API: class API:
def GET(arg: str): ... def GET(arg: str):
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def GET(self, arg: int): async def GET(self, arg: int):
@ -191,31 +184,28 @@ class TestBind(unittest.IsolatedAsyncioTestCase):
self.assertEqual(cm.exception.methname, "API.GET") self.assertEqual(cm.exception.methname, "API.GET")
async def test_middleware(self): async def test_middleware(self):
@web.middleware @web.middleware
async def middleware(request, handler): async def middleware(request, handler):
resp = await handler(request) resp = await handler(request)
exc = resp.get('exception') exc = resp.get("exception")
if resp.get('exception') is not None: if resp.get("exception") is not None:
resp.headers['x-error-type'] = type(exc).__name__ resp.headers["x-error-type"] = type(exc).__name__
return resp return resp
@api @api
class API: class API:
def GET() -> str: ... def GET() -> str:
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def GET(self) -> str: async def GET(self) -> str:
return 1/0 return 1 / 0
async with makeTestClient( async with makeTestClient(API, Impl(), middlewares=[middleware]) as client:
API, Impl(), middlewares=[middleware]) as client:
resp = await client.get("/") resp = await client.get("/")
self.assertEqual( self.assertEqual(resp.headers["x-error-type"], "ZeroDivisionError")
resp.headers['x-error-type'], 'ZeroDivisionError')
async def test_controller_for_request(self): async def test_controller_for_request(self):
seen_controller = None seen_controller = None
@web.middleware @web.middleware
@ -227,18 +217,18 @@ class TestBind(unittest.IsolatedAsyncioTestCase):
@api @api
class API: class API:
class meth: class meth:
def GET() -> str: ... def GET() -> str:
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def GET(self) -> str: async def GET(self) -> str:
return '' return ""
impl = Impl() impl = Impl()
async with makeTestClient( async with makeTestClient(API.meth, impl, middlewares=[middleware]) as client:
API.meth, impl, middlewares=[middleware]) as client:
resp = await client.get("/meth") resp = await client.get("/meth")
self.assertEqual(await resp.json(), '') self.assertEqual(await resp.json(), "")
self.assertIs(impl, seen_controller) self.assertIs(impl, seen_controller)
@ -246,14 +236,19 @@ class TestBind(unittest.IsolatedAsyncioTestCase):
@api @api
class API: class API:
serialize_query_args = False serialize_query_args = False
def GET(arg: str) -> str: ...
def GET(arg: str) -> str:
...
class meth: class meth:
def GET(arg: str) -> str: ... def GET(arg: str) -> str:
...
class more: class more:
serialize_query_args = True serialize_query_args = True
def GET(arg: str) -> str: ...
def GET(arg: str) -> str:
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def GET(self, arg: str) -> str: async def GET(self, arg: str) -> str:
@ -266,67 +261,64 @@ class TestBind(unittest.IsolatedAsyncioTestCase):
return arg return arg
async with makeTestClient(API, Impl()) as client: async with makeTestClient(API, Impl()) as client:
await self.assertResponse( await self.assertResponse(client.get("/?arg=whut"), "whut")
client.get('/?arg=whut'), 'whut') await self.assertResponse(client.get("/meth?arg=whut"), "whut")
await self.assertResponse( await self.assertResponse(client.get('/meth/more?arg="whut"'), "whut")
client.get('/meth?arg=whut'), 'whut')
await self.assertResponse(
client.get('/meth/more?arg="whut"'), 'whut')
async def test_path_parameters(self): async def test_path_parameters(self):
@api @api
class API: class API:
@path_parameter @path_parameter
class param: class param:
def GET(arg: int): ... def GET(arg: int):
...
class Impl(ControllerBase): class Impl(ControllerBase):
async def param_GET(self, param: str, arg: int): async def param_GET(self, param: str, arg: int):
return param + str(arg) return param + str(arg)
async with makeTestClient(API, Impl()) as client: async with makeTestClient(API, Impl()) as client:
await self.assertResponse( await self.assertResponse(client.get("/value?arg=2"), "value2")
client.get('/value?arg=2'), 'value2')
async def test_early_connect_ok(self): async def test_early_connect_ok(self):
@api @api
class API: class API:
class can_be_used_early: class can_be_used_early:
@allowed_before_start @allowed_before_start
def GET(): ... def GET():
...
class Impl(ControllerBase): class Impl(ControllerBase):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.app.controllers_have_started = mock.AsyncMock() self.app.controllers_have_started = mock.AsyncMock()
self.app.controllers_have_started.is_set = mock.Mock( self.app.controllers_have_started.is_set = mock.Mock(return_value=False)
return_value=False)
async def can_be_used_early_GET(self): async def can_be_used_early_GET(self):
pass pass
impl = Impl() impl = Impl()
async with makeTestClient(API, impl) as client: async with makeTestClient(API, impl) as client:
await client.get('/can_be_used_early') await client.get("/can_be_used_early")
impl.app.controllers_have_started.wait.assert_not_called() impl.app.controllers_have_started.wait.assert_not_called()
async def test_early_connect_wait(self): async def test_early_connect_wait(self):
@api @api
class API: class API:
class must_not_be_used_early: class must_not_be_used_early:
def GET(): ... def GET():
...
class Impl(ControllerBase): class Impl(ControllerBase):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.app.controllers_have_started = mock.AsyncMock() self.app.controllers_have_started = mock.AsyncMock()
self.app.controllers_have_started.is_set = mock.Mock( self.app.controllers_have_started.is_set = mock.Mock(return_value=False)
return_value=False)
async def must_not_be_used_early_GET(self): async def must_not_be_used_early_GET(self):
pass pass
impl = Impl() impl = Impl()
async with makeTestClient(API, impl) as client: async with makeTestClient(API, impl) as client:
await client.get('/must_not_be_used_early') await client.get("/must_not_be_used_early")
impl.app.controllers_have_started.wait.assert_called_once() impl.app.controllers_have_started.wait.assert_called_once()
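An aside on the server tests being reformatted here: test_missing_method and test_signature_checking above are synchronous, which reflects that bind() validates the implementation up front, before any request is served, raising MissingImplementationError or SignatureMisatchError as appropriate. A minimal sketch of that behaviour, reusing this module's ControllerBase and the imports at its top; the StubAPI/WrongImpl names and the bare web.Application() router are illustrative assumptions, not part of the test suite:

@api
class StubAPI:
    def GET(arg: str) -> str:
        ...

class WrongImpl(ControllerBase):
    async def GET(self, arg: int) -> str:   # annotation deliberately disagrees with StubAPI
        return str(arg)

try:
    bind(web.Application().router, StubAPI, WrongImpl())
except SignatureMisatchError as exc:
    print(exc.methname)   # expected to mirror the "API.GET" assertion above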
View File
@ -16,79 +16,79 @@
import enum import enum
from typing import List, Optional from typing import List, Optional
from subiquitycore.models.network import (
BondConfig,
NetDevInfo,
StaticConfig,
WLANConfig,
)
from subiquity.common.api.defs import ( from subiquity.common.api.defs import (
Payload,
allowed_before_start, allowed_before_start,
api, api,
Payload,
simple_endpoint, simple_endpoint,
) )
from subiquity.common.types import ( from subiquity.common.types import (
AdConnectionInfo,
AdAdminNameValidation, AdAdminNameValidation,
AdConnectionInfo,
AdDomainNameValidation, AdDomainNameValidation,
AddPartitionV2,
AdJoinResult, AdJoinResult,
AdPasswordValidation, AdPasswordValidation,
AddPartitionV2,
AnyStep, AnyStep,
ApplicationState, ApplicationState,
ApplicationStatus, ApplicationStatus,
CasperMd5Results, CasperMd5Results,
Change, Change,
CodecsData,
Disk, Disk,
DriversPayload,
DriversResponse,
ErrorReportRef, ErrorReportRef,
GuidedChoiceV2, GuidedChoiceV2,
GuidedStorageResponseV2, GuidedStorageResponseV2,
IdentityData,
KeyboardSetting, KeyboardSetting,
KeyboardSetup, KeyboardSetup,
IdentityData, LiveSessionSSHInfo,
NetworkStatus, MirrorCheckResponse,
MirrorSelectionFallback,
MirrorGet, MirrorGet,
MirrorPost, MirrorPost,
MirrorPostResponse, MirrorPostResponse,
MirrorCheckResponse, MirrorSelectionFallback,
ModifyPartitionV2, ModifyPartitionV2,
NetworkStatus,
OEMResponse,
PackageInstallState,
ReformatDisk, ReformatDisk,
RefreshStatus, RefreshStatus,
ShutdownMode, ShutdownMode,
CodecsData,
DriversResponse,
DriversPayload,
OEMResponse,
SnapInfo, SnapInfo,
SnapListResponse, SnapListResponse,
SnapSelection, SnapSelection,
SourceSelectionAndSetting, SourceSelectionAndSetting,
SSHData, SSHData,
SSHFetchIdResponse, SSHFetchIdResponse,
LiveSessionSSHInfo,
StorageResponse, StorageResponse,
StorageResponseV2, StorageResponseV2,
TimeZoneInfo, TimeZoneInfo,
UbuntuProCheckTokenAnswer,
UbuntuProInfo, UbuntuProInfo,
UbuntuProResponse, UbuntuProResponse,
UbuntuProCheckTokenAnswer,
UPCSInitiateResponse, UPCSInitiateResponse,
UPCSWaitResponse, UPCSWaitResponse,
UsernameValidation, UsernameValidation,
PackageInstallState,
ZdevInfo,
WSLConfigurationBase,
WSLConfigurationAdvanced, WSLConfigurationAdvanced,
WSLConfigurationBase,
WSLSetupOptions, WSLSetupOptions,
) ZdevInfo,
)
from subiquitycore.models.network import (
BondConfig,
NetDevInfo,
StaticConfig,
WLANConfig,
)
@api @api
class API: class API:
"""The API offered by the subiquity installer process.""" """The API offered by the subiquity installer process."""
locale = simple_endpoint(str) locale = simple_endpoint(str)
proxy = simple_endpoint(str) proxy = simple_endpoint(str)
updates = simple_endpoint(str) updates = simple_endpoint(str)
@ -99,8 +99,7 @@ class API:
class meta: class meta:
class status: class status:
@allowed_before_start @allowed_before_start
def GET(cur: Optional[ApplicationState] = None) \ def GET(cur: Optional[ApplicationState] = None) -> ApplicationStatus:
-> ApplicationStatus:
"""Get the installer state.""" """Get the installer state."""
class mark_configured: class mark_configured:
@ -111,6 +110,7 @@ class API:
def POST(variant: str) -> None: def POST(variant: str) -> None:
"""Choose the install variant - """Choose the install variant -
desktop/server/wsl_setup/wsl_reconfigure""" desktop/server/wsl_setup/wsl_reconfigure"""
def GET() -> str: def GET() -> str:
"""Get the install variant - """Get the install variant -
desktop/server/wsl_setup/wsl_reconfigure""" desktop/server/wsl_setup/wsl_reconfigure"""
@ -124,10 +124,12 @@ class API:
"""Restart the server process.""" """Restart the server process."""
class ssh_info: class ssh_info:
def GET() -> Optional[LiveSessionSSHInfo]: ... def GET() -> Optional[LiveSessionSSHInfo]:
...
class free_only: class free_only:
def GET() -> bool: ... def GET() -> bool:
...
def POST(enable: bool) -> None: def POST(enable: bool) -> None:
"""Enable or disable free-only mode. Currently only controlls """Enable or disable free-only mode. Currently only controlls
@ -135,7 +137,8 @@ class API:
confirmation of filesystem changes""" confirmation of filesystem changes"""
class interactive_sections: class interactive_sections:
def GET() -> Optional[List[str]]: ... def GET() -> Optional[List[str]]:
...
class errors: class errors:
class wait: class wait:
@ -159,38 +162,55 @@ class API:
"""Start the update and return the change id.""" """Start the update and return the change id."""
class progress: class progress:
def GET(change_id: str) -> Change: ... def GET(change_id: str) -> Change:
...
class keyboard: class keyboard:
def GET() -> KeyboardSetup: ... def GET() -> KeyboardSetup:
def POST(data: Payload[KeyboardSetting]): ... ...
def POST(data: Payload[KeyboardSetting]):
...
class needs_toggle: class needs_toggle:
def GET(layout_code: str, variant_code: str) -> bool: ... def GET(layout_code: str, variant_code: str) -> bool:
...
class steps: class steps:
def GET(index: Optional[str]) -> AnyStep: ... def GET(index: Optional[str]) -> AnyStep:
...
class input_source: class input_source:
def POST(data: Payload[KeyboardSetting], def POST(
user: Optional[str] = None) -> None: ... data: Payload[KeyboardSetting], user: Optional[str] = None
) -> None:
...
class source: class source:
def GET() -> SourceSelectionAndSetting: ... def GET() -> SourceSelectionAndSetting:
def POST(source_id: str, search_drivers: bool = False) -> None: ... ...
def POST(source_id: str, search_drivers: bool = False) -> None:
...
class zdev: class zdev:
def GET() -> List[ZdevInfo]: ... def GET() -> List[ZdevInfo]:
...
class chzdev: class chzdev:
def POST(action: str, zdev: ZdevInfo) -> List[ZdevInfo]: ... def POST(action: str, zdev: ZdevInfo) -> List[ZdevInfo]:
...
class network: class network:
def GET() -> NetworkStatus: ... def GET() -> NetworkStatus:
def POST() -> None: ... ...
def POST() -> None:
...
class has_network: class has_network:
def GET() -> bool: ... def GET() -> bool:
...
class global_addresses: class global_addresses:
def GET() -> List[str]: def GET() -> List[str]:
@ -201,8 +221,12 @@ class API:
The socket must serve the NetEventAPI below. The socket must serve the NetEventAPI below.
""" """
def PUT(socket_path: str) -> None: ...
def DELETE(socket_path: str) -> None: ... def PUT(socket_path: str) -> None:
...
def DELETE(socket_path: str) -> None:
...
# These methods could definitely be more RESTish, like maybe a # These methods could definitely be more RESTish, like maybe a
# GET request to /network/interfaces/$name should return netplan # GET request to /network/interfaces/$name should return netplan
@ -237,77 +261,100 @@ class API:
# So methods on nics get an extra dev_name: str parameter) # So methods on nics get an extra dev_name: str parameter)
class set_static_config: class set_static_config:
def POST(dev_name: str, ip_version: int, def POST(
static_config: Payload[StaticConfig]) -> None: ... dev_name: str, ip_version: int, static_config: Payload[StaticConfig]
) -> None:
...
class enable_dhcp: class enable_dhcp:
def POST(dev_name: str, ip_version: int) -> None: ... def POST(dev_name: str, ip_version: int) -> None:
...
class disable: class disable:
def POST(dev_name: str, ip_version: int) -> None: ... def POST(dev_name: str, ip_version: int) -> None:
...
class vlan: class vlan:
def PUT(dev_name: str, vlan_id: int) -> None: ... def PUT(dev_name: str, vlan_id: int) -> None:
...
class add_or_edit_bond: class add_or_edit_bond:
def POST(existing_name: Optional[str], new_name: str, def POST(
bond_config: Payload[BondConfig]) -> None: ... existing_name: Optional[str],
new_name: str,
bond_config: Payload[BondConfig],
) -> None:
...
class start_scan: class start_scan:
def POST(dev_name: str) -> None: ... def POST(dev_name: str) -> None:
...
class set_wlan: class set_wlan:
def POST(dev_name: str, wlan: WLANConfig) -> None: ... def POST(dev_name: str, wlan: WLANConfig) -> None:
...
class delete: class delete:
def POST(dev_name: str) -> None: ... def POST(dev_name: str) -> None:
...
class info: class info:
def GET(dev_name: str) -> str: ... def GET(dev_name: str) -> str:
...
class storage: class storage:
class guided: class guided:
def POST(data: Payload[GuidedChoiceV2]) \ def POST(data: Payload[GuidedChoiceV2]) -> StorageResponse:
-> StorageResponse:
pass pass
def GET(wait: bool = False, use_cached_result: bool = False) \ def GET(wait: bool = False, use_cached_result: bool = False) -> StorageResponse:
-> StorageResponse: ... ...
def POST(config: Payload[list]): ... def POST(config: Payload[list]):
...
class dry_run_wait_probe: class dry_run_wait_probe:
"""This endpoint only works in dry-run mode.""" """This endpoint only works in dry-run mode."""
def POST() -> None: ...
def POST() -> None:
...
class reset: class reset:
def POST() -> StorageResponse: ... def POST() -> StorageResponse:
...
class has_rst: class has_rst:
def GET() -> bool: def GET() -> bool:
pass pass
class has_bitlocker: class has_bitlocker:
def GET() -> List[Disk]: ... def GET() -> List[Disk]:
...
class v2: class v2:
def GET( def GET(
wait: bool = False, wait: bool = False,
include_raid: bool = False, include_raid: bool = False,
) -> StorageResponseV2: ... ) -> StorageResponseV2:
...
def POST() -> StorageResponseV2: ... def POST() -> StorageResponseV2:
...
class orig_config: class orig_config:
def GET() -> StorageResponseV2: ... def GET() -> StorageResponseV2:
...
class guided: class guided:
def GET(wait: bool = False) -> GuidedStorageResponseV2: ... def GET(wait: bool = False) -> GuidedStorageResponseV2:
def POST(data: Payload[GuidedChoiceV2]) \ ...
-> GuidedStorageResponseV2: ...
def POST(data: Payload[GuidedChoiceV2]) -> GuidedStorageResponseV2:
...
class reset: class reset:
def POST() -> StorageResponseV2: ... def POST() -> StorageResponseV2:
...
class ensure_transaction: class ensure_transaction:
"""This call will ensure that a transaction is initiated. """This call will ensure that a transaction is initiated.
@ -316,18 +363,22 @@ class API:
A transaction will also be initiated by any v2_storage POST A transaction will also be initiated by any v2_storage POST
request that modifies the partitioning configuration (e.g., request that modifies the partitioning configuration (e.g.,
add_partition, edit_partition, ...) but ensure_transaction can add_partition, edit_partition, ...) but ensure_transaction can
be called early if desired. """ be called early if desired."""
def POST() -> None: ...
def POST() -> None:
...
class reformat_disk: class reformat_disk:
def POST(data: Payload[ReformatDisk]) \ def POST(data: Payload[ReformatDisk]) -> StorageResponseV2:
-> StorageResponseV2: ... ...
class add_boot_partition: class add_boot_partition:
"""Mark a given disk as bootable, which may cause a partition """Mark a given disk as bootable, which may cause a partition
to be added to the disk. It is an error to call this for a to be added to the disk. It is an error to call this for a
disk for which can_be_boot_device is False.""" disk for which can_be_boot_device is False."""
def POST(disk_id: str) -> StorageResponseV2: ...
def POST(disk_id: str) -> StorageResponseV2:
...
class add_partition: class add_partition:
"""required field format and mount, optional field size """required field format and mount, optional field size
@ -339,15 +390,17 @@ class API:
add_boot_partition for more control over this. add_boot_partition for more control over this.
format=None means an unformatted partition format=None means an unformatted partition
""" """
def POST(data: Payload[AddPartitionV2]) \
-> StorageResponseV2: ... def POST(data: Payload[AddPartitionV2]) -> StorageResponseV2:
...
class delete_partition: class delete_partition:
"""required field number """required field number
It is an error to modify other Partition fields. It is an error to modify other Partition fields.
""" """
def POST(data: Payload[ModifyPartitionV2]) \
-> StorageResponseV2: ... def POST(data: Payload[ModifyPartitionV2]) -> StorageResponseV2:
...
class edit_partition: class edit_partition:
"""required field number """required field number
@ -355,124 +408,173 @@ class API:
It is an error to do wipe=null and change the format. It is an error to do wipe=null and change the format.
It is an error to modify other Partition fields. It is an error to modify other Partition fields.
""" """
def POST(data: Payload[ModifyPartitionV2]) \
-> StorageResponseV2: ... def POST(data: Payload[ModifyPartitionV2]) -> StorageResponseV2:
...
class codecs: class codecs:
def GET() -> CodecsData: ... def GET() -> CodecsData:
def POST(data: Payload[CodecsData]) -> None: ... ...
def POST(data: Payload[CodecsData]) -> None:
...
class drivers: class drivers:
def GET(wait: bool = False) -> DriversResponse: ... def GET(wait: bool = False) -> DriversResponse:
def POST(data: Payload[DriversPayload]) -> None: ... ...
def POST(data: Payload[DriversPayload]) -> None:
...
class oem: class oem:
def GET(wait: bool = False) -> OEMResponse: ... def GET(wait: bool = False) -> OEMResponse:
...
class snaplist: class snaplist:
def GET(wait: bool = False) -> SnapListResponse: ... def GET(wait: bool = False) -> SnapListResponse:
def POST(data: Payload[List[SnapSelection]]): ... ...
def POST(data: Payload[List[SnapSelection]]):
...
class snap_info: class snap_info:
def GET(snap_name: str) -> SnapInfo: ... def GET(snap_name: str) -> SnapInfo:
...
class timezone: class timezone:
def GET() -> TimeZoneInfo: ... def GET() -> TimeZoneInfo:
def POST(tz: str): ... ...
def POST(tz: str):
...
class shutdown: class shutdown:
def POST(mode: ShutdownMode, immediate: bool = False): ... def POST(mode: ShutdownMode, immediate: bool = False):
...
class mirror: class mirror:
def GET() -> MirrorGet: ... def GET() -> MirrorGet:
...
def POST(data: Payload[Optional[MirrorPost]]) \ def POST(data: Payload[Optional[MirrorPost]]) -> MirrorPostResponse:
-> MirrorPostResponse: ... ...
class disable_components: class disable_components:
def GET() -> List[str]: ... def GET() -> List[str]:
def POST(data: Payload[List[str]]): ... ...
def POST(data: Payload[List[str]]):
...
class check_mirror: class check_mirror:
class start: class start:
def POST(cancel_ongoing: bool = False) -> None: ... def POST(cancel_ongoing: bool = False) -> None:
...
class progress: class progress:
def GET() -> Optional[MirrorCheckResponse]: ... def GET() -> Optional[MirrorCheckResponse]:
...
class abort: class abort:
def POST() -> None: ... def POST() -> None:
...
fallback = simple_endpoint(MirrorSelectionFallback) fallback = simple_endpoint(MirrorSelectionFallback)
class ubuntu_pro: class ubuntu_pro:
def GET() -> UbuntuProResponse: ... def GET() -> UbuntuProResponse:
def POST(data: Payload[UbuntuProInfo]) -> None: ... ...
def POST(data: Payload[UbuntuProInfo]) -> None:
...
class skip: class skip:
def POST() -> None: ... def POST() -> None:
...
class check_token: class check_token:
def GET(token: Payload[str]) \ def GET(token: Payload[str]) -> UbuntuProCheckTokenAnswer:
-> UbuntuProCheckTokenAnswer: ... ...
class contract_selection: class contract_selection:
class initiate: class initiate:
def POST() -> UPCSInitiateResponse: ... def POST() -> UPCSInitiateResponse:
...
class wait: class wait:
def GET() -> UPCSWaitResponse: ... def GET() -> UPCSWaitResponse:
...
class cancel: class cancel:
def POST() -> None: ... def POST() -> None:
...
class identity: class identity:
def GET() -> IdentityData: ... def GET() -> IdentityData:
def POST(data: Payload[IdentityData]): ... ...
def POST(data: Payload[IdentityData]):
...
class validate_username: class validate_username:
def GET(username: str) -> UsernameValidation: ... def GET(username: str) -> UsernameValidation:
...
class ssh: class ssh:
def GET() -> SSHData: ... def GET() -> SSHData:
def POST(data: Payload[SSHData]) -> None: ... ...
def POST(data: Payload[SSHData]) -> None:
...
class fetch_id: class fetch_id:
def GET(user_id: str) -> SSHFetchIdResponse: ... def GET(user_id: str) -> SSHFetchIdResponse:
...
class integrity: class integrity:
def GET() -> CasperMd5Results: ... def GET() -> CasperMd5Results:
...
class active_directory: class active_directory:
def GET() -> Optional[AdConnectionInfo]: ... def GET() -> Optional[AdConnectionInfo]:
...
# POST expects input validated by the check methods below: # POST expects input validated by the check methods below:
def POST(data: Payload[AdConnectionInfo]) -> None: ... def POST(data: Payload[AdConnectionInfo]) -> None:
...
class has_support: class has_support:
""" Whether the live system supports Active Directory or not. """Whether the live system supports Active Directory or not.
Network status is not considered. Network status is not considered.
Clients should call this before showing the AD page. """ Clients should call this before showing the AD page."""
def GET() -> bool: ...
def GET() -> bool:
...
class check_domain_name: class check_domain_name:
""" Applies basic validation to the candidate domain name, """Applies basic validation to the candidate domain name,
without calling the domain network. """ without calling the domain network."""
def POST(domain_name: Payload[str]) \
-> List[AdDomainNameValidation]: ... def POST(domain_name: Payload[str]) -> List[AdDomainNameValidation]:
...
class ping_domain_controller: class ping_domain_controller:
""" Attempts to contact the controller of the provided domain. """ """Attempts to contact the controller of the provided domain."""
def POST(domain_name: Payload[str]) \
-> AdDomainNameValidation: ... def POST(domain_name: Payload[str]) -> AdDomainNameValidation:
...
class check_admin_name: class check_admin_name:
def POST(admin_name: Payload[str]) -> AdAdminNameValidation: ... def POST(admin_name: Payload[str]) -> AdAdminNameValidation:
...
class check_password: class check_password:
def POST(password: Payload[str]) -> AdPasswordValidation: ... def POST(password: Payload[str]) -> AdPasswordValidation:
...
class join_result: class join_result:
def GET(wait: bool = True) -> AdJoinResult: ... def GET(wait: bool = True) -> AdJoinResult:
...
class LinkAction(enum.Enum): class LinkAction(enum.Enum):
@ -484,19 +586,25 @@ class LinkAction(enum.Enum):
@api @api
class NetEventAPI: class NetEventAPI:
class wlan_support_install_finished: class wlan_support_install_finished:
def POST(state: PackageInstallState) -> None: ... def POST(state: PackageInstallState) -> None:
...
class update_link: class update_link:
def POST(act: LinkAction, info: Payload[NetDevInfo]) -> None: ... def POST(act: LinkAction, info: Payload[NetDevInfo]) -> None:
...
class route_watch: class route_watch:
def POST(has_default_route: bool) -> None: ... def POST(has_default_route: bool) -> None:
...
class apply_starting: class apply_starting:
def POST() -> None: ... def POST() -> None:
...
class apply_stopping: class apply_stopping:
def POST() -> None: ... def POST() -> None:
...
class apply_error: class apply_error:
def POST(stage: str) -> None: ... def POST(stage: str) -> None:
...
View File
@ -27,33 +27,20 @@ from typing import Iterable, Set
import apport import apport
import apport.crashdb import apport.crashdb
import apport.hookutils import apport.hookutils
import attr import attr
import bson import bson
import requests import requests
import urwid import urwid
from subiquitycore.async_helpers import ( from subiquity.common.types import ErrorReportKind, ErrorReportRef, ErrorReportState
run_in_thread, from subiquitycore.async_helpers import run_in_thread, schedule_task
schedule_task,
)
from subiquity.common.types import ( log = logging.getLogger("subiquity.common.errorreport")
ErrorReportKind,
ErrorReportRef,
ErrorReportState,
)
log = logging.getLogger('subiquity.common.errorreport')
@attr.s(eq=False) @attr.s(eq=False)
class Upload(metaclass=urwid.MetaSignals): class Upload(metaclass=urwid.MetaSignals):
signals = ['progress'] signals = ["progress"]
bytes_to_send = attr.ib() bytes_to_send = attr.ib()
bytes_sent = attr.ib(default=0) bytes_sent = attr.ib(default=0)
@ -68,13 +55,13 @@ class Upload(metaclass=urwid.MetaSignals):
def _progress(self): def _progress(self):
os.read(self.pipe_r, 4096) os.read(self.pipe_r, 4096)
urwid.emit_signal(self, 'progress') urwid.emit_signal(self, "progress")
def _bg_update(self, sent, to_send=None): def _bg_update(self, sent, to_send=None):
self.bytes_sent = sent self.bytes_sent = sent
if to_send is not None: if to_send is not None:
self.bytes_to_send = to_send self.bytes_to_send = to_send
os.write(self.pipe_w, b'x') os.write(self.pipe_w, b"x")
def stop(self): def stop(self):
asyncio.get_running_loop().remove_reader(self.pipe_r) asyncio.get_running_loop().remove_reader(self.pipe_r)
@ -84,7 +71,6 @@ class Upload(metaclass=urwid.MetaSignals):
@attr.s(eq=False) @attr.s(eq=False)
class ErrorReport(metaclass=urwid.MetaSignals): class ErrorReport(metaclass=urwid.MetaSignals):
signals = ["changed"] signals = ["changed"]
reporter = attr.ib() reporter = attr.ib()
@ -101,17 +87,19 @@ class ErrorReport(metaclass=urwid.MetaSignals):
@classmethod @classmethod
def new(cls, reporter, kind): def new(cls, reporter, kind):
base = "{:.9f}.{}".format(time.time(), kind.name.lower()) base = "{:.9f}.{}".format(time.time(), kind.name.lower())
crash_file = open( crash_file = open(os.path.join(reporter.crash_directory, base + ".crash"), "wb")
os.path.join(reporter.crash_directory, base + ".crash"),
'wb')
pr = apport.Report('Bug') pr = apport.Report("Bug")
pr['CrashDB'] = repr(reporter.crashdb_spec) pr["CrashDB"] = repr(reporter.crashdb_spec)
r = cls( r = cls(
reporter=reporter, base=base, pr=pr, file=crash_file, reporter=reporter,
base=base,
pr=pr,
file=crash_file,
state=ErrorReportState.INCOMPLETE, state=ErrorReportState.INCOMPLETE,
context=reporter.context.child(base)) context=reporter.context.child(base),
)
r.set_meta("kind", kind.name) r.set_meta("kind", kind.name)
return r return r
@ -119,11 +107,15 @@ class ErrorReport(metaclass=urwid.MetaSignals):
def from_file(cls, reporter, fpath): def from_file(cls, reporter, fpath):
base = os.path.splitext(os.path.basename(fpath))[0] base = os.path.splitext(os.path.basename(fpath))[0]
report = cls( report = cls(
reporter, base, pr=apport.Report(date='???'), reporter,
state=ErrorReportState.LOADING, file=open(fpath, 'rb'), base,
context=reporter.context.child(base)) pr=apport.Report(date="???"),
state=ErrorReportState.LOADING,
file=open(fpath, "rb"),
context=reporter.context.child(base),
)
try: try:
fp = open(report.meta_path, 'r') fp = open(report.meta_path, "r")
except FileNotFoundError: except FileNotFoundError:
pass pass
else: else:
@ -140,26 +132,27 @@ class ErrorReport(metaclass=urwid.MetaSignals):
if not self.reporter.dry_run: if not self.reporter.dry_run:
self.pr.add_hooks_info(None) self.pr.add_hooks_info(None)
apport.hookutils.attach_hardware(self.pr) apport.hookutils.attach_hardware(self.pr)
self.pr['Syslog'] = apport.hookutils.recent_syslog(re.compile('.')) self.pr["Syslog"] = apport.hookutils.recent_syslog(re.compile("."))
snap_name = os.environ.get('SNAP_NAME', '') snap_name = os.environ.get("SNAP_NAME", "")
if snap_name != '': if snap_name != "":
self.add_tags([snap_name]) self.add_tags([snap_name])
# Because apport-cli will in general be run on a different # Because apport-cli will in general be run on a different
# machine, we make some slightly obscure alterations to the report # machine, we make some slightly obscure alterations to the report
# to make this go better. # to make this go better.
# apport-cli gets upset if neither of these are present. # apport-cli gets upset if neither of these are present.
self.pr['Package'] = 'subiquity ' + os.environ.get( self.pr["Package"] = "subiquity " + os.environ.get(
"SNAP_REVISION", "SNAP_REVISION") "SNAP_REVISION", "SNAP_REVISION"
self.pr['SourcePackage'] = 'subiquity' )
self.pr["SourcePackage"] = "subiquity"
# If ExecutableTimestamp is present, apport-cli will try to check # If ExecutableTimestamp is present, apport-cli will try to check
# that ExecutablePath hasn't changed. But it won't be there. # that ExecutablePath hasn't changed. But it won't be there.
del self.pr['ExecutableTimestamp'] del self.pr["ExecutableTimestamp"]
# apport-cli gets upset at the probert C extensions it sees in # apport-cli gets upset at the probert C extensions it sees in
# here. /proc/maps is very unlikely to be interesting for us # here. /proc/maps is very unlikely to be interesting for us
# anyway. # anyway.
del self.pr['ProcMaps'] del self.pr["ProcMaps"]
self.pr.write(self._file) self.pr.write(self._file)
async def add_info(): async def add_info():
@ -175,6 +168,7 @@ class ErrorReport(metaclass=urwid.MetaSignals):
self._file.close() self._file.close()
self._file = None self._file = None
urwid.emit_signal(self, "changed") urwid.emit_signal(self, "changed")
if wait: if wait:
with self._context.child("add_info") as context: with self._context.child("add_info") as context:
_bg_add_info() _bg_add_info()
@ -210,13 +204,11 @@ class ErrorReport(metaclass=urwid.MetaSignals):
if uploader.cancelled: if uploader.cancelled:
log.debug("upload for %s cancelled", self.base) log.debug("upload for %s cancelled", self.base)
return return
yield data[i:i+chunk_size] yield data[i : i + chunk_size]
uploader._bg_update(uploader.bytes_sent + chunk_size) uploader._bg_update(uploader.bytes_sent + chunk_size)
def _bg_upload(): def _bg_upload():
for_upload = { for_upload = {"Kind": self.kind.value}
"Kind": self.kind.value
}
for k, v in self.pr.items(): for k, v in self.pr.items():
if len(v) < 1024 or k in { if len(v) < 1024 or k in {
"InstallerLogInfo", "InstallerLogInfo",
@ -236,8 +228,9 @@ class ErrorReport(metaclass=urwid.MetaSignals):
data = bson.BSON().encode(for_upload) data = bson.BSON().encode(for_upload)
self.uploader._bg_update(0, len(data)) self.uploader._bg_update(0, len(data))
headers = { headers = {
'user-agent': 'subiquity/{}'.format( "user-agent": "subiquity/{}".format(
os.environ.get("SNAP_VERSION", "SNAP_VERSION")), os.environ.get("SNAP_VERSION", "SNAP_VERSION")
),
} }
response = requests.post(url, data=chunk(data), headers=headers) response = requests.post(url, data=chunk(data), headers=headers)
response.raise_for_status() response.raise_for_status()
@ -254,28 +247,27 @@ class ErrorReport(metaclass=urwid.MetaSignals):
context.description = oops_id context.description = oops_id
uploader.stop() uploader.stop()
self.uploader = None self.uploader = None
urwid.emit_signal(self, 'changed') urwid.emit_signal(self, "changed")
urwid.emit_signal(self, 'changed') urwid.emit_signal(self, "changed")
uploader.start() uploader.start()
schedule_task(upload()) schedule_task(upload())
def _path_with_ext(self, ext): def _path_with_ext(self, ext):
return os.path.join( return os.path.join(self.reporter.crash_directory, self.base + "." + ext)
self.reporter.crash_directory, self.base + '.' + ext)
@property @property
def meta_path(self): def meta_path(self):
return self._path_with_ext('meta') return self._path_with_ext("meta")
@property @property
def path(self): def path(self):
return self._path_with_ext('crash') return self._path_with_ext("crash")
def set_meta(self, key, value): def set_meta(self, key, value):
self.meta[key] = value self.meta[key] = value
with open(self.meta_path, 'w') as fp: with open(self.meta_path, "w") as fp:
json.dump(self.meta, fp, indent=4) json.dump(self.meta, fp, indent=4)
def mark_seen(self): def mark_seen(self):
@ -300,9 +292,8 @@ class ErrorReport(metaclass=urwid.MetaSignals):
"""Return fs-label, path-on-fs to report.""" """Return fs-label, path-on-fs to report."""
# Not sure if this is more or less sane than shelling out to # Not sure if this is more or less sane than shelling out to
# findmnt(1). # findmnt(1).
looking_for = os.path.abspath( looking_for = os.path.abspath(os.path.normpath(self.reporter.crash_directory))
os.path.normpath(self.reporter.crash_directory)) for line in open("/proc/self/mountinfo").readlines():
for line in open('/proc/self/mountinfo').readlines():
parts = line.strip().split() parts = line.strip().split()
if os.path.normpath(parts[4]) == looking_for: if os.path.normpath(parts[4]) == looking_for:
devname = parts[9] devname = parts[9]
@ -310,19 +301,19 @@ class ErrorReport(metaclass=urwid.MetaSignals):
break break
else: else:
if self.reporter.dry_run: if self.reporter.dry_run:
path = ('install-logs/2019-11-06.0/crash/' + path = "install-logs/2019-11-06.0/crash/" + self.base + ".crash"
self.base +
'.crash')
return "casper-rw", path return "casper-rw", path
return None, None return None, None
import pyudev import pyudev
c = pyudev.Context() c = pyudev.Context()
devs = list(c.list_devices( devs = list(
subsystem='block', DEVNAME=os.path.realpath(devname))) c.list_devices(subsystem="block", DEVNAME=os.path.realpath(devname))
)
if not devs: if not devs:
return None, None return None, None
label = devs[0].get('ID_FS_LABEL_ENC', '') label = devs[0].get("ID_FS_LABEL_ENC", "")
return label, root[1:] + '/' + self.base + '.crash' return label, root[1:] + "/" + self.base + ".crash"
def ref(self): def ref(self):
return ErrorReportRef( return ErrorReportRef(
@ -349,22 +340,21 @@ class ErrorReport(metaclass=urwid.MetaSignals):
class ErrorReporter(object): class ErrorReporter(object):
def __init__(self, context, dry_run, root, client=None): def __init__(self, context, dry_run, root, client=None):
self.context = context self.context = context
self.dry_run = dry_run self.dry_run = dry_run
self.crash_directory = os.path.join(root, 'var/crash') self.crash_directory = os.path.join(root, "var/crash")
self.client = client self.client = client
self.reports = [] self.reports = []
self._reports_by_base = {} self._reports_by_base = {}
self._reports_by_exception = {} self._reports_by_exception = {}
self.crashdb_spec = { self.crashdb_spec = {
'impl': 'launchpad', "impl": "launchpad",
'project': 'subiquity', "project": "subiquity",
} }
if dry_run: if dry_run:
self.crashdb_spec['launchpad_instance'] = 'staging' self.crashdb_spec["launchpad_instance"] = "staging"
self._apport_data = [] self._apport_data = []
self._apport_files = [] self._apport_files = []
@ -398,7 +388,7 @@ class ErrorReporter(object):
return self._reports_by_exception.get(exc) return self._reports_by_exception.get(exc)
def make_apport_report(self, kind, thing, *, wait=False, exc=None, **kw): def make_apport_report(self, kind, thing, *, wait=False, exc=None, **kw):
if not self.dry_run and not os.path.exists('/cdrom/.disk/info'): if not self.dry_run and not os.path.exists("/cdrom/.disk/info"):
return None return None
log.debug("generating crash report") log.debug("generating crash report")
@ -415,16 +405,14 @@ class ErrorReporter(object):
if exc is None: if exc is None:
exc = sys.exc_info()[1] exc = sys.exc_info()[1]
if exc is not None: if exc is not None:
report.pr["Title"] = "{} crashed with {}".format( report.pr["Title"] = "{} crashed with {}".format(thing, type(exc).__name__)
thing, type(exc).__name__)
tb = traceback.TracebackException.from_exception(exc) tb = traceback.TracebackException.from_exception(exc)
report.pr['Traceback'] = "".join(tb.format()) report.pr["Traceback"] = "".join(tb.format())
self._reports_by_exception[exc] = report self._reports_by_exception[exc] = report
else: else:
report.pr["Title"] = thing report.pr["Title"] = thing
log.info( log.info("saving crash report %r to %s", report.pr["Title"], report.path)
"saving crash report %r to %s", report.pr["Title"], report.path)
apport_files = self._apport_files[:] apport_files = self._apport_files[:]
apport_data = self._apport_data.copy() apport_data = self._apport_data.copy()
@ -456,8 +444,7 @@ class ErrorReporter(object):
await self.client.errors.wait.GET(error_ref) await self.client.errors.wait.GET(error_ref)
path = os.path.join( path = os.path.join(self.crash_directory, error_ref.base + ".crash")
self.crash_directory, error_ref.base + '.crash')
report = ErrorReport.from_file(self, path) report = ErrorReport.from_file(self, path)
self.reports.insert(0, report) self.reports.insert(0, report)
self._reports_by_base[error_ref.base] = report self._reports_by_base[error_ref.base] = report
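An illustrative, self-contained sketch (not subiquity code) of the /proc/self/mountinfo lookup performed in the hunks above; it splits on the " - " separator described in proc(5) instead of relying on fixed field indices, and the function name is invented:

    import os
    from typing import Optional

    def device_for_path(path: str) -> Optional[str]:
        """Return the mount source (e.g. "/dev/sda2") backing *path*, or None."""
        target = os.path.abspath(os.path.normpath(path))
        with open("/proc/self/mountinfo") as fp:
            for line in fp:
                # proc(5): before " - " come mount ID, parent ID, major:minor, root,
                # mount point, options and optional tagged fields; after it come
                # fstype, mount source and super options.
                pre, _, post = line.partition(" - ")
                pre_fields = pre.split()
                post_fields = post.split()
                if os.path.normpath(pre_fields[4]) == target:
                    return post_fields[1]
        return None

As the hunk above shows, the report path is then paired with the filesystem label that pyudev reports as ID_FS_LABEL_ENC for the resolved block device.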

@ -27,8 +27,7 @@ from subiquity.models.filesystem import (
Partition, Partition,
Raid, Raid,
raidlevels_by_value, raidlevels_by_value,
) )
_checkers = {} _checkers = {}
@ -37,8 +36,9 @@ def make_checker(action):
@functools.singledispatch @functools.singledispatch
def impl(device): def impl(device):
raise NotImplementedError( raise NotImplementedError(
"checker for %s on %s not implemented" % ( "checker for %s on %s not implemented" % (action, device)
action, device)) )
_checkers[action] = impl _checkers[action] = impl
return impl return impl
@ -75,8 +75,7 @@ class DeviceAction(enum.Enum):
@functools.singledispatch @functools.singledispatch
def _supported_actions(device): def _supported_actions(device):
raise NotImplementedError( raise NotImplementedError("_supported_actions({}) not defined".format(device))
"_supported_actions({}) not defined".format(device))
@_supported_actions.register(Disk) @_supported_actions.register(Disk)
@ -117,7 +116,7 @@ def _raid_actions(raid):
DeviceAction.REFORMAT, DeviceAction.REFORMAT,
] ]
if raid._m.bootloader == Bootloader.UEFI: if raid._m.bootloader == Bootloader.UEFI:
if raid.container and raid.container.metadata == 'imsm': if raid.container and raid.container.metadata == "imsm":
actions.append(DeviceAction.TOGGLE_BOOT) actions.append(DeviceAction.TOGGLE_BOOT)
return actions return actions
@ -161,11 +160,10 @@ def _can_edit_generic(device):
if cd is None: if cd is None:
return True return True
return _( return _(
"Cannot edit {selflabel} as it is part of the {cdtype} " "Cannot edit {selflabel} as it is part of the {cdtype} " "{cdname}."
"{cdname}.").format( ).format(
selflabel=labels.label(device), selflabel=labels.label(device), cdtype=labels.desc(cd), cdname=labels.label(cd)
cdtype=labels.desc(cd), )
cdname=labels.label(cd))
@_can_edit.register(Partition) @_can_edit.register(Partition)
@ -183,9 +181,9 @@ def _can_edit_raid(raid):
if raid.preserve: if raid.preserve:
return _("Cannot edit pre-existing RAIDs.") return _("Cannot edit pre-existing RAIDs.")
elif len(raid._partitions) > 0: elif len(raid._partitions) > 0:
return _( return _("Cannot edit {raidlabel} because it has partitions.").format(
"Cannot edit {raidlabel} because it has partitions.").format( raidlabel=labels.label(raid)
raidlabel=labels.label(raid)) )
else: else:
return _can_edit_generic(raid) return _can_edit_generic(raid)
@ -195,10 +193,9 @@ def _can_edit_vg(vg):
if vg.preserve: if vg.preserve:
return _("Cannot edit pre-existing volume groups.") return _("Cannot edit pre-existing volume groups.")
elif len(vg._partitions) > 0: elif len(vg._partitions) > 0:
return _( return _("Cannot edit {vglabel} because it has logical " "volumes.").format(
"Cannot edit {vglabel} because it has logical " vglabel=labels.label(vg)
"volumes.").format( )
vglabel=labels.label(vg))
else: else:
return _can_edit_generic(vg) return _can_edit_generic(vg)
@ -251,11 +248,13 @@ def _can_remove_device(device):
if cd is None: if cd is None:
return False return False
if cd.preserve: if cd.preserve:
return _("Cannot remove {selflabel} from pre-existing {cdtype} " return _(
"{cdlabel}.").format( "Cannot remove {selflabel} from pre-existing {cdtype} " "{cdlabel}."
).format(
selflabel=labels.label(device), selflabel=labels.label(device),
cdtype=labels.desc(cd), cdtype=labels.desc(cd),
cdlabel=labels.label(cd)) cdlabel=labels.label(cd),
)
if isinstance(cd, Raid): if isinstance(cd, Raid):
if device in cd.spare_devices: if device in cd.spare_devices:
return True return True
@ -263,19 +262,23 @@ def _can_remove_device(device):
if len(cd.devices) == min_devices: if len(cd.devices) == min_devices:
return _( return _(
"Removing {selflabel} would leave the {cdtype} {cdlabel} with " "Removing {selflabel} would leave the {cdtype} {cdlabel} with "
"less than {min_devices} devices.").format( "less than {min_devices} devices."
).format(
selflabel=labels.label(device), selflabel=labels.label(device),
cdtype=labels.desc(cd), cdtype=labels.desc(cd),
cdlabel=labels.label(cd), cdlabel=labels.label(cd),
min_devices=min_devices) min_devices=min_devices,
)
elif isinstance(cd, LVM_VolGroup): elif isinstance(cd, LVM_VolGroup):
if len(cd.devices) == 1: if len(cd.devices) == 1:
return _( return _(
"Removing {selflabel} would leave the {cdtype} {cdlabel} with " "Removing {selflabel} would leave the {cdtype} {cdlabel} with "
"no devices.").format( "no devices."
).format(
selflabel=labels.label(device), selflabel=labels.label(device),
cdtype=labels.desc(cd), cdtype=labels.desc(cd),
cdlabel=labels.label(cd)) cdlabel=labels.label(cd),
)
return True return True
@ -287,11 +290,10 @@ def _can_delete_generic(device):
if cd is None: if cd is None:
return True return True
return _( return _(
"Cannot delete {selflabel} as it is part of the {cdtype} " "Cannot delete {selflabel} as it is part of the {cdtype} " "{cdname}."
"{cdname}.").format( ).format(
selflabel=labels.label(device), selflabel=labels.label(device), cdtype=labels.desc(cd), cdname=labels.label(cd)
cdtype=labels.desc(cd), )
cdname=labels.label(cd))
@_can_delete.register(Partition) @_can_delete.register(Partition)
@ -299,8 +301,10 @@ def _can_delete_partition(partition):
if partition._is_in_use: if partition._is_in_use:
return False return False
if partition.device._has_preexisting_partition(): if partition.device._has_preexisting_partition():
return _("Cannot delete a single partition from a device that " return _(
"already has partitions.") "Cannot delete a single partition from a device that "
"already has partitions."
)
if boot.is_bootloader_partition(partition): if boot.is_bootloader_partition(partition):
return _("Cannot delete required bootloader partition") return _("Cannot delete required bootloader partition")
return _can_delete_generic(partition) return _can_delete_generic(partition)
@ -317,7 +321,8 @@ def _can_delete_raid_vg(device):
cd = p.constructed_device() cd = p.constructed_device()
return _( return _(
"Cannot delete {devicelabel} as partition {partnum} is part " "Cannot delete {devicelabel} as partition {partnum} is part "
"of the {cdtype} {cdname}.").format( "of the {cdtype} {cdname}."
).format(
devicelabel=labels.label(device), devicelabel=labels.label(device),
partnum=p.number, partnum=p.number,
cdtype=labels.desc(cd), cdtype=labels.desc(cd),
@ -325,10 +330,8 @@ def _can_delete_raid_vg(device):
) )
if mounted_partitions > 1: if mounted_partitions > 1:
return _( return _(
"Cannot delete {devicelabel} because it has {count} mounted " "Cannot delete {devicelabel} because it has {count} mounted " "partitions."
"partitions.").format( ).format(devicelabel=labels.label(device), count=mounted_partitions)
devicelabel=labels.label(device),
count=mounted_partitions)
elif mounted_partitions == 1: elif mounted_partitions == 1:
return _( return _(
"Cannot delete {devicelabel} because it has 1 mounted partition." "Cannot delete {devicelabel} because it has 1 mounted partition."
@ -340,8 +343,10 @@ def _can_delete_raid_vg(device):
@_can_delete.register(LVM_LogicalVolume) @_can_delete.register(LVM_LogicalVolume)
def _can_delete_lv(lv): def _can_delete_lv(lv):
if lv.volgroup._has_preexisting_partition(): if lv.volgroup._has_preexisting_partition():
return _("Cannot delete a single logical volume from a volume " return _(
"group that already has logical volumes.") "Cannot delete a single logical volume from a volume "
"group that already has logical volumes."
)
return True return True
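The checker machinery above builds one functools.singledispatch function per DeviceAction and registers a per-device-type overload on it; a trimmed, runnable sketch of that pattern (the Disk/Partition stubs and the "EDIT" key are stand-ins, not the real model classes or enum):

    import functools

    _checkers = {}

    def make_checker(action):
        @functools.singledispatch
        def impl(device):
            raise NotImplementedError(
                "checker for %s on %s not implemented" % (action, device)
            )
        _checkers[action] = impl
        return impl

    _can_edit = make_checker("EDIT")

    class Disk:
        pass

    class Partition:
        pass

    @_can_edit.register(Disk)
    def _can_edit_disk(disk):
        return True

    @_can_edit.register(Partition)
    def _can_edit_partition(partition):
        # Checkers return True/False or a user-facing string explaining why
        # the action is not possible.
        return "Cannot edit this partition."

    print(_can_edit(Disk()))       # True
    print(_can_edit(Partition()))  # Cannot edit this partition.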

@ -21,16 +21,9 @@ from typing import Any, Optional
import attr import attr
from subiquity.common.filesystem import gaps, sizes from subiquity.common.filesystem import gaps, sizes
from subiquity.models.filesystem import ( from subiquity.models.filesystem import Bootloader, Disk, Partition, Raid, align_up
align_up,
Disk,
Raid,
Bootloader,
Partition,
)
log = logging.getLogger("subiquity.common.filesystem.boot")
log = logging.getLogger('subiquity.common.filesystem.boot')
@functools.singledispatch @functools.singledispatch
@ -55,7 +48,7 @@ def _is_boot_device_raid(raid):
bl = raid._m.bootloader bl = raid._m.bootloader
if bl != Bootloader.UEFI: if bl != Bootloader.UEFI:
return False return False
if not raid.container or raid.container.metadata != 'imsm': if not raid.container or raid.container.metadata != "imsm":
return False return False
return any(p.grub_device for p in raid._partitions) return any(p.grub_device for p in raid._partitions)
@ -85,8 +78,7 @@ class CreatePartPlan(MakeBootDevicePlan):
args: dict = attr.ib(factory=dict) args: dict = attr.ib(factory=dict)
def apply(self, manipulator): def apply(self, manipulator):
manipulator.create_partition( manipulator.create_partition(self.gap.device, self.gap, self.spec, **self.args)
self.gap.device, self.gap, self.spec, **self.args)
def _can_resize_part(inst, field, part): def _can_resize_part(inst, field, part):
@ -166,34 +158,39 @@ class MultiStepPlan(MakeBootDevicePlan):
def get_boot_device_plan_bios(device) -> Optional[MakeBootDevicePlan]: def get_boot_device_plan_bios(device) -> Optional[MakeBootDevicePlan]:
attr_plan = SetAttrPlan(device, 'grub_device', True) attr_plan = SetAttrPlan(device, "grub_device", True)
if device.ptable == 'msdos': if device.ptable == "msdos":
return attr_plan return attr_plan
pgs = gaps.parts_and_gaps(device) pgs = gaps.parts_and_gaps(device)
if len(pgs) > 0: if len(pgs) > 0:
if isinstance(pgs[0], Partition) and pgs[0].flag == "bios_grub": if isinstance(pgs[0], Partition) and pgs[0].flag == "bios_grub":
return attr_plan return attr_plan
gap = gaps.Gap(device=device, gap = gaps.Gap(
device=device,
offset=device.alignment_data().min_start_offset, offset=device.alignment_data().min_start_offset,
size=sizes.BIOS_GRUB_SIZE_BYTES) size=sizes.BIOS_GRUB_SIZE_BYTES,
)
create_part_plan = CreatePartPlan( create_part_plan = CreatePartPlan(
gap=gap, gap=gap,
spec=dict(size=sizes.BIOS_GRUB_SIZE_BYTES, fstype=None, mount=None), spec=dict(size=sizes.BIOS_GRUB_SIZE_BYTES, fstype=None, mount=None),
args=dict(flag='bios_grub')) args=dict(flag="bios_grub"),
)
movable = [] movable = []
for pg in pgs: for pg in pgs:
if isinstance(pg, gaps.Gap): if isinstance(pg, gaps.Gap):
if pg.size >= sizes.BIOS_GRUB_SIZE_BYTES: if pg.size >= sizes.BIOS_GRUB_SIZE_BYTES:
return MultiStepPlan(plans=[ return MultiStepPlan(
plans=[
SlidePlan( SlidePlan(
parts=movable, parts=movable, offset_delta=sizes.BIOS_GRUB_SIZE_BYTES
offset_delta=sizes.BIOS_GRUB_SIZE_BYTES), ),
create_part_plan, create_part_plan,
attr_plan, attr_plan,
]) ]
)
else: else:
return None return None
elif pg.preserve: elif pg.preserve:
@ -204,23 +201,21 @@ def get_boot_device_plan_bios(device) -> Optional[MakeBootDevicePlan]:
if not movable: if not movable:
return None return None
largest_i, largest_part = max( largest_i, largest_part = max(enumerate(movable), key=lambda i_p: i_p[1].size)
enumerate(movable), return MultiStepPlan(
key=lambda i_p: i_p[1].size) plans=[
return MultiStepPlan(plans=[ ResizePlan(part=largest_part, size_delta=-sizes.BIOS_GRUB_SIZE_BYTES),
ResizePlan(
part=largest_part,
size_delta=-sizes.BIOS_GRUB_SIZE_BYTES),
SlidePlan( SlidePlan(
parts=movable[:largest_i+1], parts=movable[: largest_i + 1], offset_delta=sizes.BIOS_GRUB_SIZE_BYTES
offset_delta=sizes.BIOS_GRUB_SIZE_BYTES), ),
create_part_plan, create_part_plan,
attr_plan, attr_plan,
]) ]
)
def get_add_part_plan(device, *, spec, args, resize_partition=None): def get_add_part_plan(device, *, spec, args, resize_partition=None):
size = spec['size'] size = spec["size"]
partitions = device.partitions() partitions = device.partitions()
create_part_plan = CreatePartPlan(gap=None, spec=spec, args=args) create_part_plan = CreatePartPlan(gap=None, spec=spec, args=args)
@ -238,70 +233,79 @@ def get_add_part_plan(device, *, spec, args, resize_partition=None):
return None return None
offset = resize_partition.offset + resize_partition.size - size offset = resize_partition.offset + resize_partition.size - size
create_part_plan.gap = gaps.Gap( create_part_plan.gap = gaps.Gap(device=device, offset=offset, size=size)
device=device, offset=offset, size=size) return MultiStepPlan(
return MultiStepPlan(plans=[ plans=[
ResizePlan( ResizePlan(
part=resize_partition, part=resize_partition,
size_delta=-size, size_delta=-size,
allow_resize_preserved=True, allow_resize_preserved=True,
), ),
create_part_plan, create_part_plan,
]) ]
)
else: else:
new_primaries = [p for p in partitions new_primaries = [
p
for p in partitions
if not p.preserve if not p.preserve
if p.flag not in ('extended', 'logical')] if p.flag not in ("extended", "logical")
]
if not new_primaries: if not new_primaries:
return None return None
largest_part = max(new_primaries, key=lambda p: p.size) largest_part = max(new_primaries, key=lambda p: p.size)
if size > largest_part.size // 2: if size > largest_part.size // 2:
return None return None
create_part_plan.gap = gaps.Gap( create_part_plan.gap = gaps.Gap(
device=device, offset=largest_part.offset, size=size) device=device, offset=largest_part.offset, size=size
return MultiStepPlan(plans=[ )
ResizePlan( return MultiStepPlan(
part=largest_part, plans=[
size_delta=-size), ResizePlan(part=largest_part, size_delta=-size),
SlidePlan( SlidePlan(parts=[largest_part], offset_delta=size),
parts=[largest_part],
offset_delta=size),
create_part_plan, create_part_plan,
]) ]
)
def get_boot_device_plan_uefi(device, resize_partition): def get_boot_device_plan_uefi(device, resize_partition):
for part in device.partitions(): for part in device.partitions():
if is_esp(part): if is_esp(part):
plans = [SetAttrPlan(part, 'grub_device', True)] plans = [SetAttrPlan(part, "grub_device", True)]
if device._m._mount_for_path('/boot/efi') is None: if device._m._mount_for_path("/boot/efi") is None:
plans.append(MountBootEfiPlan(part)) plans.append(MountBootEfiPlan(part))
return MultiStepPlan(plans=plans) return MultiStepPlan(plans=plans)
part_align = device.alignment_data().part_align part_align = device.alignment_data().part_align
size = align_up(sizes.get_efi_size(device.size), part_align) size = align_up(sizes.get_efi_size(device.size), part_align)
spec = dict(size=size, fstype='fat32', mount=None) spec = dict(size=size, fstype="fat32", mount=None)
if device._m._mount_for_path("/boot/efi") is None: if device._m._mount_for_path("/boot/efi") is None:
spec['mount'] = '/boot/efi' spec["mount"] = "/boot/efi"
return get_add_part_plan( return get_add_part_plan(
device, spec=spec, args=dict(flag='boot', grub_device=True), device,
resize_partition=resize_partition) spec=spec,
args=dict(flag="boot", grub_device=True),
resize_partition=resize_partition,
)
def get_boot_device_plan_prep(device, resize_partition): def get_boot_device_plan_prep(device, resize_partition):
for part in device.partitions(): for part in device.partitions():
if part.flag == "prep": if part.flag == "prep":
return MultiStepPlan(plans=[ return MultiStepPlan(
SetAttrPlan(part, 'grub_device', True), plans=[
SetAttrPlan(part, 'wipe', 'zero') SetAttrPlan(part, "grub_device", True),
]) SetAttrPlan(part, "wipe", "zero"),
]
)
return get_add_part_plan( return get_add_part_plan(
device, device,
spec=dict(size=sizes.PREP_GRUB_SIZE_BYTES, fstype=None, mount=None), spec=dict(size=sizes.PREP_GRUB_SIZE_BYTES, fstype=None, mount=None),
args=dict(flag='prep', grub_device=True, wipe='zero'), args=dict(flag="prep", grub_device=True, wipe="zero"),
resize_partition=resize_partition) resize_partition=resize_partition,
)
def get_boot_device_plan(device, resize_partition=None): def get_boot_device_plan(device, resize_partition=None):
@ -317,12 +321,11 @@ def get_boot_device_plan(device, resize_partition=None):
return get_boot_device_plan_prep(device, resize_partition) return get_boot_device_plan_prep(device, resize_partition)
if bl == Bootloader.NONE: if bl == Bootloader.NONE:
return NoOpBootPlan() return NoOpBootPlan()
raise Exception(f'unexpected bootloader {bl} here') raise Exception(f"unexpected bootloader {bl} here")
@functools.singledispatch @functools.singledispatch
def can_be_boot_device(device, *, def can_be_boot_device(device, *, resize_partition=None, with_reformatting=False):
resize_partition=None, with_reformatting=False):
"""Can `device` be made into a boot device? """Can `device` be made into a boot device?
If with_reformatting=True, return true if the device can be made If with_reformatting=True, return true if the device can be made
@ -332,8 +335,7 @@ def can_be_boot_device(device, *,
@can_be_boot_device.register(Disk) @can_be_boot_device.register(Disk)
def _can_be_boot_device_disk(disk, *, def _can_be_boot_device_disk(disk, *, resize_partition=None, with_reformatting=False):
resize_partition=None, with_reformatting=False):
if with_reformatting: if with_reformatting:
disk = disk._reformatted() disk = disk._reformatted()
plan = get_boot_device_plan(disk, resize_partition=resize_partition) plan = get_boot_device_plan(disk, resize_partition=resize_partition)
@ -341,12 +343,11 @@ def _can_be_boot_device_disk(disk, *,
@can_be_boot_device.register(Raid) @can_be_boot_device.register(Raid)
def _can_be_boot_device_raid(raid, *, def _can_be_boot_device_raid(raid, *, resize_partition=None, with_reformatting=False):
resize_partition=None, with_reformatting=False):
bl = raid._m.bootloader bl = raid._m.bootloader
if bl != Bootloader.UEFI: if bl != Bootloader.UEFI:
return False return False
if not raid.container or raid.container.metadata != 'imsm': if not raid.container or raid.container.metadata != "imsm":
return False return False
if with_reformatting: if with_reformatting:
return True return True
@ -376,7 +377,7 @@ def _is_esp_partition(partition):
if typecode is None: if typecode is None:
return False return False
try: try:
return int(typecode, 0) == 0xef return int(typecode, 0) == 0xEF
except ValueError: except ValueError:
# In case there was garbage in the udev entry... # In case there was garbage in the udev entry...
return False return False
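can_be_boot_device and the get_boot_device_plan_* helpers above hinge on small plan objects that each expose apply(manipulator) and compose via MultiStepPlan; a reduced sketch of that shape (field names and bodies are simplified stand-ins, not the full implementations):

    import attr

    @attr.s(auto_attribs=True)
    class SetAttrPlan:
        device: object
        attr_name: str
        value: object

        def apply(self, manipulator):
            setattr(self.device, self.attr_name, self.value)

    @attr.s(auto_attribs=True)
    class MultiStepPlan:
        plans: list

        def apply(self, manipulator):
            for plan in self.plans:
                plan.apply(manipulator)

    # A plan is built first, then applied in one go, e.g.:
    #   plan = MultiStepPlan(plans=[SetAttrPlan(part, "grub_device", True), ...])
    #   if plan is not None:
    #       plan.apply(manipulator)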

@ -14,21 +14,21 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import functools import functools
from typing import Tuple, List from typing import List, Tuple
import attr import attr
from subiquity.common.types import GapUsable from subiquity.common.types import GapUsable
from subiquity.models.filesystem import ( from subiquity.models.filesystem import (
align_up,
align_down,
Disk,
LVM_CHUNK_SIZE, LVM_CHUNK_SIZE,
Disk,
LVM_LogicalVolume, LVM_LogicalVolume,
LVM_VolGroup, LVM_VolGroup,
Partition, Partition,
Raid, Raid,
) align_down,
align_up,
)
# should also set on_setattr=None with attrs 20.1.0 # should also set on_setattr=None with attrs 20.1.0
@ -40,11 +40,11 @@ class Gap:
in_extended: bool = False in_extended: bool = False
usable: str = GapUsable.YES usable: str = GapUsable.YES
type: str = 'gap' type: str = "gap"
@property @property
def id(self): def id(self):
return 'gap-' + self.device.id return "gap-" + self.device.id
@property @property
def is_usable(self): def is_usable(self):
@ -55,21 +55,25 @@ class Gap:
the supplied size. If size is equal to the gap size, the second gap is the supplied size. If size is equal to the gap size, the second gap is
None. The original gap is unmodified.""" None. The original gap is unmodified."""
if size > self.size: if size > self.size:
raise ValueError('requested size larger than gap') raise ValueError("requested size larger than gap")
if size == self.size: if size == self.size:
return (self, None) return (self, None)
first_gap = Gap(device=self.device, first_gap = Gap(
device=self.device,
offset=self.offset, offset=self.offset,
size=size, size=size,
in_extended=self.in_extended, in_extended=self.in_extended,
usable=self.usable) usable=self.usable,
)
if self.in_extended: if self.in_extended:
size += self.device.alignment_data().ebr_space size += self.device.alignment_data().ebr_space
rest_gap = Gap(device=self.device, rest_gap = Gap(
device=self.device,
offset=self.offset + size, offset=self.offset + size,
size=self.size - size, size=self.size - size,
in_extended=self.in_extended, in_extended=self.in_extended,
usable=self.usable) usable=self.usable,
)
return (first_gap, rest_gap) return (first_gap, rest_gap)
def within(self): def within(self):
@ -134,11 +138,15 @@ def find_disk_gaps_v2(device, info=None):
else: else:
usable = GapUsable.TOO_MANY_PRIMARY_PARTS usable = GapUsable.TOO_MANY_PRIMARY_PARTS
if end - start >= info.min_gap_size: if end - start >= info.min_gap_size:
result.append(Gap(device=device, result.append(
Gap(
device=device,
offset=start, offset=start,
size=end - start, size=end - start,
in_extended=in_extended, in_extended=in_extended,
usable=usable)) usable=usable,
)
)
prev_end = info.min_start_offset prev_end = info.min_start_offset
@ -158,8 +166,7 @@ def find_disk_gaps_v2(device, info=None):
gap_start = au(prev_end) gap_start = au(prev_end)
if extended_end is not None: if extended_end is not None:
gap_start = min( gap_start = min(extended_end, au(gap_start + info.ebr_space))
extended_end, au(gap_start + info.ebr_space))
if extended_end is not None and gap_end >= extended_end: if extended_end is not None and gap_end >= extended_end:
maybe_add_gap(gap_start, ad(extended_end), True) maybe_add_gap(gap_start, ad(extended_end), True)
@ -247,7 +254,7 @@ def largest_gap_size(device, in_extended=None):
@functools.singledispatch @functools.singledispatch
def movable_trailing_partitions_and_gap_size(partition): def movable_trailing_partitions_and_gap_size(partition):
""" For a given partition (or LVM logical volume), return the total, """For a given partition (or LVM logical volume), return the total,
potentially available, free space immediately following the partition. potentially available, free space immediately following the partition.
By potentially available, we mean that to claim that much free space, some By potentially available, we mean that to claim that much free space, some
other partitions might need to be moved. other partitions might need to be moved.
@ -259,13 +266,14 @@ def movable_trailing_partitions_and_gap_size(partition):
@movable_trailing_partitions_and_gap_size.register @movable_trailing_partitions_and_gap_size.register
def _movable_trailing_partitions_and_gap_size_partition(partition: Partition) \ def _movable_trailing_partitions_and_gap_size_partition(
-> Tuple[List[Partition], int]: partition: Partition,
) -> Tuple[List[Partition], int]:
pgs = parts_and_gaps(partition.device) pgs = parts_and_gaps(partition.device)
part_idx = pgs.index(partition) part_idx = pgs.index(partition)
trailing_partitions = [] trailing_partitions = []
in_extended = partition.is_logical in_extended = partition.is_logical
for pg in pgs[part_idx + 1:]: for pg in pgs[part_idx + 1 :]:
if isinstance(pg, Partition): if isinstance(pg, Partition):
if pg.preserve: if pg.preserve:
break break
@ -281,8 +289,9 @@ def _movable_trailing_partitions_and_gap_size_partition(partition: Partition) \
@movable_trailing_partitions_and_gap_size.register @movable_trailing_partitions_and_gap_size.register
def _movable_trailing_partitions_and_gap_size_lvm(volume: LVM_LogicalVolume) \ def _movable_trailing_partitions_and_gap_size_lvm(
-> Tuple[List[LVM_LogicalVolume], int]: volume: LVM_LogicalVolume,
) -> Tuple[List[LVM_LogicalVolume], int]:
# In a Volume Group, there is no need to move partitions around, one can # In a Volume Group, there is no need to move partitions around, one can
# always use the remaining space. # always use the remaining space.
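A tiny stand-in with invented numbers for the split() contract documented above: the original gap is never mutated, and asking for the full size returns (self, None). (The real Gap also accounts for ebr_space when the gap sits inside an extended partition, which this stand-in omits.)

    from dataclasses import dataclass
    from typing import Optional, Tuple

    MiB = 1 << 20

    @dataclass(frozen=True)
    class TinyGap:
        offset: int
        size: int

        def split(self, size: int) -> Tuple["TinyGap", Optional["TinyGap"]]:
            if size > self.size:
                raise ValueError("requested size larger than gap")
            if size == self.size:
                return (self, None)
            return (TinyGap(self.offset, size),
                    TinyGap(self.offset + size, self.size - size))

    g = TinyGap(offset=1 * MiB, size=100 * MiB)
    first, rest = g.split(40 * MiB)
    assert (first.size, rest.offset, rest.size) == (40 * MiB, 41 * MiB, 60 * MiB)
    assert g.split(100 * MiB) == (g, None)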

@ -18,18 +18,18 @@ import functools
from subiquity.common import types from subiquity.common import types
from subiquity.common.filesystem import boot, gaps from subiquity.common.filesystem import boot, gaps
from subiquity.models.filesystem import ( from subiquity.models.filesystem import (
ZFS,
Disk, Disk,
LVM_LogicalVolume, LVM_LogicalVolume,
LVM_VolGroup, LVM_VolGroup,
Partition, Partition,
Raid, Raid,
ZFS,
ZPool, ZPool,
) )
def _annotations_generic(device): def _annotations_generic(device):
preserve = getattr(device, 'preserve', None) preserve = getattr(device, "preserve", None)
if preserve is None: if preserve is None:
return [] return []
elif preserve: elif preserve:
@ -128,16 +128,15 @@ def _desc_partition(partition):
@desc.register(Raid) @desc.register(Raid)
def _desc_raid(raid): def _desc_raid(raid):
level = raid.raidlevel level = raid.raidlevel
if level.lower().startswith('raid'): if level.lower().startswith("raid"):
level = level[4:] level = level[4:]
if raid.container: if raid.container:
raid_type = raid.container.metadata raid_type = raid.container.metadata
elif raid.metadata == "imsm": elif raid.metadata == "imsm":
raid_type = 'imsm' # maybe VROC? raid_type = "imsm" # maybe VROC?
else: else:
raid_type = _("software") raid_type = _("software")
return _("{type} RAID {level}").format( return _("{type} RAID {level}").format(type=raid_type, level=level)
type=raid_type, level=level)
@desc.register(LVM_VolGroup) @desc.register(LVM_VolGroup)
@ -199,7 +198,8 @@ def _label_partition(partition, *, short=False):
return p + _("partition {number}").format(number=partition.number) return p + _("partition {number}").format(number=partition.number)
else: else:
return p + _("partition {number} of {device}").format( return p + _("partition {number} of {device}").format(
number=partition.number, device=label(partition.device)) number=partition.number, device=label(partition.device)
)
@label.register(gaps.Gap) @label.register(gaps.Gap)
@ -215,9 +215,8 @@ def _usage_labels_generic(device, *, exclude_final_unused=False):
if cd is not None: if cd is not None:
return [ return [
_("{component_name} of {desc} {name}").format( _("{component_name} of {desc} {name}").format(
component_name=cd.component_name, component_name=cd.component_name, desc=desc(cd), name=cd.name
desc=desc(cd), ),
name=cd.name),
] ]
fs = device.fs() fs = device.fs()
if fs is not None: if fs is not None:
@ -279,10 +278,11 @@ def _usage_labels_disk(disk):
@usage_labels.register(Raid) @usage_labels.register(Raid)
def _usage_labels_raid(raid): def _usage_labels_raid(raid):
if raid.metadata == 'imsm' and raid._subvolumes: if raid.metadata == "imsm" and raid._subvolumes:
return [ return [
_('container for {devices}').format( _("container for {devices}").format(
devices=', '.join([label(v) for v in raid._subvolumes])) devices=", ".join([label(v) for v in raid._subvolumes])
)
] ]
return _usage_labels_generic(raid) return _usage_labels_generic(raid)
@ -309,7 +309,7 @@ def _for_client_disk(disk, *, min_size=0):
return types.Disk( return types.Disk(
id=disk.id, id=disk.id,
label=label(disk), label=label(disk),
path=getattr(disk, 'path', None), path=getattr(disk, "path", None),
type=desc(disk), type=desc(disk),
size=disk.size, size=disk.size,
ptable=disk.ptable, ptable=disk.ptable,
@ -319,9 +319,10 @@ def _for_client_disk(disk, *, min_size=0):
boot_device=boot.is_boot_device(disk), boot_device=boot.is_boot_device(disk),
can_be_boot_device=boot.can_be_boot_device(disk), can_be_boot_device=boot.can_be_boot_device(disk),
ok_for_guided=disk.size >= min_size, ok_for_guided=disk.size >= min_size,
model=getattr(disk, 'model', None), model=getattr(disk, "model", None),
vendor=getattr(disk, 'vendor', None), vendor=getattr(disk, "vendor", None),
has_in_use_partition=disk._has_in_use_partition) has_in_use_partition=disk._has_in_use_partition,
)
@for_client.register(Partition) @for_client.register(Partition)
@ -341,7 +342,8 @@ def _for_client_partition(partition, *, min_size=0):
estimated_min_size=partition.estimated_min_size, estimated_min_size=partition.estimated_min_size,
mount=partition.mount, mount=partition.mount,
format=partition.format, format=partition.format,
is_in_use=partition._is_in_use) is_in_use=partition._is_in_use,
)
@for_client.register(gaps.Gap) @for_client.register(gaps.Gap)
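The RAID description above trims a leading "raid" from the level string and prefixes either the container metadata ("imsm") or "software"; the same formatting in isolation, with invented inputs and no model objects:

    def describe_raid(raidlevel: str, raid_type: str = "software") -> str:
        level = raidlevel
        if level.lower().startswith("raid"):
            level = level[4:]
        return "{type} RAID {level}".format(type=raid_type, level=level)

    assert describe_raid("raid1") == "software RAID 1"
    assert describe_raid("raid10", raid_type="imsm") == "imsm RAID 10"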

@ -19,20 +19,16 @@ from curtin.block import get_resize_fstypes
from subiquity.common.filesystem import boot, gaps from subiquity.common.filesystem import boot, gaps
from subiquity.common.types import Bootloader from subiquity.common.types import Bootloader
from subiquity.models.filesystem import ( from subiquity.models.filesystem import Partition, align_up
align_up,
Partition,
)
log = logging.getLogger('subiquity.common.filesystem.manipulator') log = logging.getLogger("subiquity.common.filesystem.manipulator")
class FilesystemManipulator: class FilesystemManipulator:
def create_mount(self, fs, spec): def create_mount(self, fs, spec):
if spec.get('mount') is None: if spec.get("mount") is None:
return return
mount = self.model.add_mount(fs, spec['mount']) mount = self.model.add_mount(fs, spec["mount"])
if self.model.needs_bootloader_partition(): if self.model.needs_bootloader_partition():
vol = fs.volume vol = fs.volume
if vol.type == "partition" and boot.can_be_boot_device(vol.device): if vol.type == "partition" and boot.can_be_boot_device(vol.device):
@ -45,20 +41,20 @@ class FilesystemManipulator:
self.model.remove_mount(mount) self.model.remove_mount(mount)
def create_filesystem(self, volume, spec): def create_filesystem(self, volume, spec):
if spec.get('fstype') is None: if spec.get("fstype") is None:
# prep partitions are always wiped (and never have a filesystem) # prep partitions are always wiped (and never have a filesystem)
if getattr(volume, 'flag', None) != 'prep': if getattr(volume, "flag", None) != "prep":
volume.wipe = spec.get('wipe', volume.wipe) volume.wipe = spec.get("wipe", volume.wipe)
fstype = volume.original_fstype() fstype = volume.original_fstype()
if fstype is None: if fstype is None:
return None return None
else: else:
fstype = spec['fstype'] fstype = spec["fstype"]
# partition editing routines are expected to provide the # partition editing routines are expected to provide the
# appropriate wipe value. partition additions may not, give them a # appropriate wipe value. partition additions may not, give them a
# basic wipe value consistent with older behavior until we prove # basic wipe value consistent with older behavior until we prove
# that everyone does so correctly. # that everyone does so correctly.
volume.wipe = spec.get('wipe', 'superblock') volume.wipe = spec.get("wipe", "superblock")
preserve = volume.wipe is None preserve = volume.wipe is None
fs = self.model.add_filesystem(volume, fstype, preserve) fs = self.model.add_filesystem(volume, fstype, preserve)
if isinstance(volume, Partition): if isinstance(volume, Partition):
@ -66,9 +62,9 @@ class FilesystemManipulator:
volume.flag = "swap" volume.flag = "swap"
elif volume.flag == "swap": elif volume.flag == "swap":
volume.flag = "" volume.flag = ""
if spec.get('fstype') == "swap": if spec.get("fstype") == "swap":
self.model.add_mount(fs, "") self.model.add_mount(fs, "")
elif spec.get('fstype') is None and spec.get('use_swap'): elif spec.get("fstype") is None and spec.get("use_swap"):
self.model.add_mount(fs, "") self.model.add_mount(fs, "")
else: else:
self.create_mount(fs, spec) self.create_mount(fs, spec)
@ -79,35 +75,39 @@ class FilesystemManipulator:
return return
self.delete_mount(fs.mount()) self.delete_mount(fs.mount())
self.model.remove_filesystem(fs) self.model.remove_filesystem(fs)
delete_format = delete_filesystem delete_format = delete_filesystem
def create_partition(self, device, gap, spec, **kw): def create_partition(self, device, gap, spec, **kw):
flag = kw.pop('flag', None) flag = kw.pop("flag", None)
if gap.in_extended: if gap.in_extended:
if flag not in (None, 'logical'): if flag not in (None, "logical"):
log.debug(f'overriding flag {flag} ' log.debug(
'due to being in an extended partition') f"overriding flag {flag} " "due to being in an extended partition"
flag = 'logical' )
flag = "logical"
part = self.model.add_partition( part = self.model.add_partition(
device, size=gap.size, offset=gap.offset, flag=flag, **kw) device, size=gap.size, offset=gap.offset, flag=flag, **kw
)
self.create_filesystem(part, spec) self.create_filesystem(part, spec)
return part return part
def delete_partition(self, part, override_preserve=False): def delete_partition(self, part, override_preserve=False):
if not override_preserve and part.device.preserve and \ if (
self.model.storage_version < 2: not override_preserve
and part.device.preserve
and self.model.storage_version < 2
):
raise Exception("cannot delete partitions from preserved disks") raise Exception("cannot delete partitions from preserved disks")
self.clear(part) self.clear(part)
self.model.remove_partition(part) self.model.remove_partition(part)
def create_raid(self, spec): def create_raid(self, spec):
for d in spec['devices'] | spec['spare_devices']: for d in spec["devices"] | spec["spare_devices"]:
self.clear(d) self.clear(d)
raid = self.model.add_raid( raid = self.model.add_raid(
spec['name'], spec["name"], spec["level"].value, spec["devices"], spec["spare_devices"]
spec['level'].value, )
spec['devices'],
spec['spare_devices'])
return raid return raid
def delete_raid(self, raid): def delete_raid(self, raid):
@ -119,69 +119,74 @@ class FilesystemManipulator:
for p in list(raid.partitions()): for p in list(raid.partitions()):
self.delete_partition(p, True) self.delete_partition(p, True)
for d in set(raid.devices) | set(raid.spare_devices): for d in set(raid.devices) | set(raid.spare_devices):
d.wipe = 'superblock' d.wipe = "superblock"
self.model.remove_raid(raid) self.model.remove_raid(raid)
def create_volgroup(self, spec): def create_volgroup(self, spec):
devices = set() devices = set()
key = spec.get('passphrase') key = spec.get("passphrase")
for device in spec['devices']: for device in spec["devices"]:
self.clear(device) self.clear(device)
if key: if key:
device = self.model.add_dm_crypt(device, key) device = self.model.add_dm_crypt(device, key)
devices.add(device) devices.add(device)
return self.model.add_volgroup(name=spec['name'], devices=devices) return self.model.add_volgroup(name=spec["name"], devices=devices)
create_lvm_volgroup = create_volgroup create_lvm_volgroup = create_volgroup
def delete_volgroup(self, vg): def delete_volgroup(self, vg):
for lv in list(vg.partitions()): for lv in list(vg.partitions()):
self.delete_logical_volume(lv) self.delete_logical_volume(lv)
for d in vg.devices: for d in vg.devices:
d.wipe = 'superblock' d.wipe = "superblock"
if d.type == "dm_crypt": if d.type == "dm_crypt":
self.model.remove_dm_crypt(d) self.model.remove_dm_crypt(d)
self.model.remove_volgroup(vg) self.model.remove_volgroup(vg)
delete_lvm_volgroup = delete_volgroup delete_lvm_volgroup = delete_volgroup
def create_logical_volume(self, vg, spec): def create_logical_volume(self, vg, spec):
lv = self.model.add_logical_volume( lv = self.model.add_logical_volume(vg=vg, name=spec["name"], size=spec["size"])
vg=vg,
name=spec['name'],
size=spec['size'])
self.create_filesystem(lv, spec) self.create_filesystem(lv, spec)
return lv return lv
create_lvm_partition = create_logical_volume create_lvm_partition = create_logical_volume
def delete_logical_volume(self, lv): def delete_logical_volume(self, lv):
self.clear(lv) self.clear(lv)
self.model.remove_logical_volume(lv) self.model.remove_logical_volume(lv)
delete_lvm_partition = delete_logical_volume delete_lvm_partition = delete_logical_volume
def create_zpool(self, device, pool, mountpoint): def create_zpool(self, device, pool, mountpoint):
fs_properties = dict( fs_properties = dict(
acltype='posixacl', acltype="posixacl",
relatime='on', relatime="on",
canmount='on', canmount="on",
compression='gzip', compression="gzip",
devices='off', devices="off",
xattr='sa', xattr="sa",
) )
pool_properties = dict(ashift=12) pool_properties = dict(ashift=12)
self.model.add_zpool( self.model.add_zpool(
device, pool, mountpoint, device,
fs_properties=fs_properties, pool_properties=pool_properties) pool,
mountpoint,
fs_properties=fs_properties,
pool_properties=pool_properties,
)
def delete(self, obj): def delete(self, obj):
if obj is None: if obj is None:
return return
getattr(self, 'delete_' + obj.type)(obj) getattr(self, "delete_" + obj.type)(obj)
def clear(self, obj, wipe=None): def clear(self, obj, wipe=None):
if obj.type == "disk": if obj.type == "disk":
obj.preserve = False obj.preserve = False
if wipe is None: if wipe is None:
wipe = 'superblock' wipe = "superblock"
obj.wipe = wipe obj.wipe = wipe
for subobj in obj.fs(), obj.constructed_device(): for subobj in obj.fs(), obj.constructed_device():
self.delete(subobj) self.delete(subobj)
@ -202,14 +207,14 @@ class FilesystemManipulator:
return True return True
def partition_disk_handler(self, disk, spec, *, partition=None, gap=None): def partition_disk_handler(self, disk, spec, *, partition=None, gap=None):
log.debug('partition_disk_handler: %s %s %s %s', log.debug("partition_disk_handler: %s %s %s %s", disk, spec, partition, gap)
disk, spec, partition, gap)
if partition is not None: if partition is not None:
if 'size' in spec and spec['size'] != partition.size: if "size" in spec and spec["size"] != partition.size:
trailing, gap_size = \ trailing, gap_size = gaps.movable_trailing_partitions_and_gap_size(
gaps.movable_trailing_partitions_and_gap_size(partition) partition
new_size = align_up(spec['size']) )
new_size = align_up(spec["size"])
size_change = new_size - partition.size size_change = new_size - partition.size
if size_change > gap_size: if size_change > gap_size:
raise Exception("partition size too large") raise Exception("partition size too large")
@ -226,12 +231,12 @@ class FilesystemManipulator:
if len(disk.partitions()) == 0: if len(disk.partitions()) == 0:
if disk.type == "disk": if disk.type == "disk":
disk.preserve = False disk.preserve = False
disk.wipe = 'superblock-recursive' disk.wipe = "superblock-recursive"
elif disk.type == "raid": elif disk.type == "raid":
disk.wipe = 'superblock-recursive' disk.wipe = "superblock-recursive"
needs_boot = self.model.needs_bootloader_partition() needs_boot = self.model.needs_bootloader_partition()
log.debug('model needs a bootloader partition? {}'.format(needs_boot)) log.debug("model needs a bootloader partition? {}".format(needs_boot))
can_be_boot = boot.can_be_boot_device(disk) can_be_boot = boot.can_be_boot_device(disk)
if needs_boot and len(disk.partitions()) == 0 and can_be_boot: if needs_boot and len(disk.partitions()) == 0 and can_be_boot:
self.add_boot_disk(disk) self.add_boot_disk(disk)
@ -241,13 +246,13 @@ class FilesystemManipulator:
# 1) with len(partitions()) == 0 there could only have been 1 gap # 1) with len(partitions()) == 0 there could only have been 1 gap
# 2) having just done add_boot_disk(), the gap is no longer valid. # 2) having just done add_boot_disk(), the gap is no longer valid.
gap = gaps.largest_gap(disk) gap = gaps.largest_gap(disk)
if spec['size'] > gap.size: if spec["size"] > gap.size:
log.debug( log.debug(
"Adjusting request down from %s to %s", "Adjusting request down from %s to %s", spec["size"], gap.size
spec['size'], gap.size) )
spec['size'] = gap.size spec["size"] = gap.size
gap = gap.split(spec['size'])[0] gap = gap.split(spec["size"])[0]
self.create_partition(disk, gap, spec) self.create_partition(disk, gap, spec)
log.debug("Successfully added partition") log.debug("Successfully added partition")
@ -256,13 +261,13 @@ class FilesystemManipulator:
# keep the partition name for compat with PartitionStretchy.handler # keep the partition name for compat with PartitionStretchy.handler
lv = partition lv = partition
log.debug('logical_volume_handler: %s %s %s', vg, lv, spec) log.debug("logical_volume_handler: %s %s %s", vg, lv, spec)
if lv is not None: if lv is not None:
if 'name' in spec: if "name" in spec:
lv.name = spec['name'] lv.name = spec["name"]
if 'size' in spec: if "size" in spec:
lv.size = align_up(spec['size']) lv.size = align_up(spec["size"])
if gaps.largest_gap_size(vg) < 0: if gaps.largest_gap_size(vg) < 0:
raise Exception("lv size too large") raise Exception("lv size too large")
self.delete_filesystem(lv.fs()) self.delete_filesystem(lv.fs())
@ -272,7 +277,7 @@ class FilesystemManipulator:
self.create_logical_volume(vg, spec) self.create_logical_volume(vg, spec)
def add_format_handler(self, volume, spec): def add_format_handler(self, volume, spec):
log.debug('add_format_handler %s %s', volume, spec) log.debug("add_format_handler %s %s", volume, spec)
self.clear(volume) self.clear(volume)
self.create_filesystem(volume, spec) self.create_filesystem(volume, spec)
@ -281,47 +286,46 @@ class FilesystemManipulator:
if existing is not None: if existing is not None:
for d in existing.devices | existing.spare_devices: for d in existing.devices | existing.spare_devices:
d._constructed_device = None d._constructed_device = None
for d in spec['devices'] | spec['spare_devices']: for d in spec["devices"] | spec["spare_devices"]:
self.clear(d) self.clear(d)
d._constructed_device = existing d._constructed_device = existing
existing.name = spec['name'] existing.name = spec["name"]
existing.raidlevel = spec['level'].value existing.raidlevel = spec["level"].value
existing.devices = spec['devices'] existing.devices = spec["devices"]
existing.spare_devices = spec['spare_devices'] existing.spare_devices = spec["spare_devices"]
else: else:
self.create_raid(spec) self.create_raid(spec)
def volgroup_handler(self, existing, spec): def volgroup_handler(self, existing, spec):
if existing is not None: if existing is not None:
key = spec.get('passphrase') key = spec.get("passphrase")
for d in existing.devices: for d in existing.devices:
if d.type == "dm_crypt": if d.type == "dm_crypt":
self.model.remove_dm_crypt(d) self.model.remove_dm_crypt(d)
d = d.volume d = d.volume
d._constructed_device = None d._constructed_device = None
devices = set() devices = set()
for d in spec['devices']: for d in spec["devices"]:
self.clear(d) self.clear(d)
if key: if key:
d = self.model.add_dm_crypt(d, key) d = self.model.add_dm_crypt(d, key)
d._constructed_device = existing d._constructed_device = existing
devices.add(d) devices.add(d)
existing.name = spec['name'] existing.name = spec["name"]
existing.devices = devices existing.devices = devices
else: else:
self.create_volgroup(spec) self.create_volgroup(spec)
def _mount_esp(self, part): def _mount_esp(self, part):
if part.fs() is None: if part.fs() is None:
self.model.add_filesystem(part, 'fat32') self.model.add_filesystem(part, "fat32")
self.model.add_mount(part.fs(), '/boot/efi') self.model.add_mount(part.fs(), "/boot/efi")
def remove_boot_disk(self, boot_disk): def remove_boot_disk(self, boot_disk):
if self.model.bootloader == Bootloader.BIOS: if self.model.bootloader == Bootloader.BIOS:
boot_disk.grub_device = False boot_disk.grub_device = False
partitions = [ partitions = [
p for p in boot_disk.partitions() p for p in boot_disk.partitions() if boot.is_bootloader_partition(p)
if boot.is_bootloader_partition(p)
] ]
remount = False remount = False
if boot_disk.preserve: if boot_disk.preserve:
@ -339,7 +343,8 @@ class FilesystemManipulator:
if not p.fs().preserve and p.original_fstype(): if not p.fs().preserve and p.original_fstype():
self.delete_filesystem(p.fs()) self.delete_filesystem(p.fs())
self.model.add_filesystem( self.model.add_filesystem(
p, p.original_fstype(), preserve=True) p, p.original_fstype(), preserve=True
)
else: else:
full = gaps.largest_gap_size(boot_disk) == 0 full = gaps.largest_gap_size(boot_disk) == 0
tot_size = 0 tot_size = 0
@ -349,11 +354,10 @@ class FilesystemManipulator:
remount = True remount = True
self.delete_partition(p) self.delete_partition(p)
if full: if full:
largest_part = max( largest_part = max(boot_disk.partitions(), key=lambda p: p.size)
boot_disk.partitions(), key=lambda p: p.size)
largest_part.size += tot_size largest_part.size += tot_size
if self.model.bootloader == Bootloader.UEFI and remount: if self.model.bootloader == Bootloader.UEFI and remount:
part = self.model._one(type='partition', grub_device=True) part = self.model._one(type="partition", grub_device=True)
if part: if part:
self._mount_esp(part) self._mount_esp(part)
@ -363,7 +367,7 @@ class FilesystemManipulator:
self.remove_boot_disk(disk) self.remove_boot_disk(disk)
plan = boot.get_boot_device_plan(new_boot_disk) plan = boot.get_boot_device_plan(new_boot_disk)
if plan is None: if plan is None:
raise ValueError(f'No known plan to make {new_boot_disk} bootable') raise ValueError(f"No known plan to make {new_boot_disk} bootable")
plan.apply(self) plan.apply(self)
if not new_boot_disk._has_preexisting_partition(): if not new_boot_disk._has_preexisting_partition():
if new_boot_disk.type == "disk": if new_boot_disk.type == "disk":
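partition_disk_handler above only grows a partition when the aligned size increase fits into the space reported by movable_trailing_partitions_and_gap_size(); the check with invented numbers (align_up re-implemented locally so the snippet runs on its own):

    MiB = 1 << 20

    def align_up(size: int, block: int = MiB) -> int:
        return (size + block - 1) // block * block

    current_size = 10_000 * MiB
    requested = 10_500 * MiB + 12345       # unaligned size coming from the UI
    trailing_gap = 600 * MiB               # from movable_trailing_partitions_and_gap_size()

    new_size = align_up(requested)         # 10_501 MiB
    size_change = new_size - current_size  # 501 MiB
    assert size_change <= trailing_gap     # otherwise: "partition size too large"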

@ -16,17 +16,10 @@
import math import math
import attr import attr
from curtin import swap from curtin import swap
from subiquity.models.filesystem import (
align_up,
align_down,
MiB,
GiB,
)
from subiquity.common.types import GuidedResizeValues from subiquity.common.types import GuidedResizeValues
from subiquity.models.filesystem import GiB, MiB, align_down, align_up
BIOS_GRUB_SIZE_BYTES = 1 * MiB BIOS_GRUB_SIZE_BYTES = 1 * MiB
PREP_GRUB_SIZE_BYTES = 8 * MiB PREP_GRUB_SIZE_BYTES = 8 * MiB
@ -39,18 +32,11 @@ class PartitionScaleFactors:
maximum: int maximum: int
uefi_scale = PartitionScaleFactors( uefi_scale = PartitionScaleFactors(minimum=538 * MiB, priority=538, maximum=1075 * MiB)
minimum=538 * MiB,
priority=538,
maximum=1075 * MiB)
bootfs_scale = PartitionScaleFactors( bootfs_scale = PartitionScaleFactors(
minimum=1792 * MiB, minimum=1792 * MiB, priority=1024, maximum=2048 * MiB
priority=1024, )
maximum=2048 * MiB) rootfs_scale = PartitionScaleFactors(minimum=900 * MiB, priority=10000, maximum=-1)
rootfs_scale = PartitionScaleFactors(
minimum=900 * MiB,
priority=10000,
maximum=-1)
def scale_partitions(all_factors, available_space): def scale_partitions(all_factors, available_space):
@ -96,14 +82,15 @@ def get_bootfs_size(available_space):
# decide to keep with the existing partition, or allocate to the new install # decide to keep with the existing partition, or allocate to the new install
# 4) Assume that the installs will grow proportionally to their minimum sizes, # 4) Assume that the installs will grow proportionally to their minimum sizes,
# and split according to the ratio of the minimum sizes # and split according to the ratio of the minimum sizes
def calculate_guided_resize(part_min: int, part_size: int, install_min: int, def calculate_guided_resize(
part_align: int = MiB) -> GuidedResizeValues: part_min: int, part_size: int, install_min: int, part_align: int = MiB
) -> GuidedResizeValues:
if part_min < 0: if part_min < 0:
return None return None
part_size = align_up(part_size, part_align) part_size = align_up(part_size, part_align)
other_room_to_grow = max(2 * GiB, math.ceil(.25 * part_min)) other_room_to_grow = max(2 * GiB, math.ceil(0.25 * part_min))
padded_other_min = part_min + other_room_to_grow padded_other_min = part_min + other_room_to_grow
other_min = min(align_up(padded_other_min, part_align), part_size) other_min = min(align_up(padded_other_min, part_align), part_size)
@ -118,7 +105,10 @@ def calculate_guided_resize(part_min: int, part_size: int, install_min: int,
recommended = align_up(raw_recommended, part_align) recommended = align_up(raw_recommended, part_align)
return GuidedResizeValues( return GuidedResizeValues(
install_max=plausible_free_space, install_max=plausible_free_space,
minimum=other_min, recommended=recommended, maximum=other_max) minimum=other_min,
recommended=recommended,
maximum=other_max,
)
# Factors for suggested minimum install size: # Factors for suggested minimum install size:
@ -142,10 +132,9 @@ def calculate_guided_resize(part_min: int, part_size: int, install_min: int,
# 4) Swap - Ubuntu policy recommends swap, either in the form of a partition or # 4) Swap - Ubuntu policy recommends swap, either in the form of a partition or
# a swapfile. Curtin has an existing swap size recommendation at # a swapfile. Curtin has an existing swap size recommendation at
# curtin.swap.suggested_swapsize(). # curtin.swap.suggested_swapsize().
def calculate_suggested_install_min(source_min: int, def calculate_suggested_install_min(source_min: int, part_align: int = MiB) -> int:
part_align: int = MiB) -> int:
room_for_boot = bootfs_scale.maximum room_for_boot = bootfs_scale.maximum
room_to_grow = max(2 * GiB, math.ceil(.8 * source_min)) room_to_grow = max(2 * GiB, math.ceil(0.8 * source_min))
room_for_swap = swap.suggested_swapsize() room_for_swap = swap.suggested_swapsize()
total = source_min + room_for_boot + room_to_grow + room_for_swap total = source_min + room_for_boot + room_to_grow + room_for_swap
return align_up(total, part_align) return align_up(total, part_align)
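Worked numbers for the suggested-minimum arithmetic above, using an invented 6 GiB source minimum and an assumed 1 GiB swap figure (the real value comes from curtin.swap.suggested_swapsize()); align_up is re-implemented locally so the snippet is self-contained:

    import math

    MiB = 1 << 20
    GiB = 1 << 30

    def align_up(size: int, block: int = MiB) -> int:
        return (size + block - 1) // block * block

    source_min = 6 * GiB                               # minimum size of the source
    room_for_boot = 2048 * MiB                         # bootfs_scale.maximum
    room_to_grow = max(2 * GiB, math.ceil(0.8 * source_min))
    room_for_swap = 1 * GiB                            # assumption for the example
    total = align_up(source_min + room_for_boot + room_to_grow + room_for_swap)
    print(round(total / GiB, 1))                       # ~13.8 GiB suggested minimum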

@ -15,10 +15,8 @@
import unittest import unittest
from subiquity.common.filesystem.actions import (
DeviceAction,
)
from subiquity.common.filesystem import gaps from subiquity.common.filesystem import gaps
from subiquity.common.filesystem.actions import DeviceAction
from subiquity.models.filesystem import Bootloader from subiquity.models.filesystem import Bootloader
from subiquity.models.tests.test_filesystem import ( from subiquity.models.tests.test_filesystem import (
make_disk, make_disk,
@ -32,11 +30,10 @@ from subiquity.models.tests.test_filesystem import (
make_partition, make_partition,
make_raid, make_raid,
make_vg, make_vg,
) )
class TestActions(unittest.TestCase): class TestActions(unittest.TestCase):
def assertActionNotSupported(self, obj, action): def assertActionNotSupported(self, obj, action):
self.assertNotIn(action, DeviceAction.supported(obj)) self.assertNotIn(action, DeviceAction.supported(obj))
@ -51,7 +48,7 @@ class TestActions(unittest.TestCase):
def _test_remove_action(self, model, objects): def _test_remove_action(self, model, objects):
self.assertActionNotPossible(objects[0], DeviceAction.REMOVE) self.assertActionNotPossible(objects[0], DeviceAction.REMOVE)
vg = model.add_volgroup('vg1', {objects[0], objects[1]}) vg = model.add_volgroup("vg1", {objects[0], objects[1]})
self.assertActionPossible(objects[0], DeviceAction.REMOVE) self.assertActionPossible(objects[0], DeviceAction.REMOVE)
# Cannot remove a device from a preexisting VG # Cannot remove a device from a preexisting VG
@ -63,7 +60,7 @@ class TestActions(unittest.TestCase):
vg.devices.remove(objects[1]) vg.devices.remove(objects[1])
objects[1]._constructed_device = None objects[1]._constructed_device = None
self.assertActionNotPossible(objects[0], DeviceAction.REMOVE) self.assertActionNotPossible(objects[0], DeviceAction.REMOVE)
raid = model.add_raid('md0', 'raid1', set(objects[2:]), set()) raid = model.add_raid("md0", "raid1", set(objects[2:]), set())
self.assertActionPossible(objects[2], DeviceAction.REMOVE) self.assertActionPossible(objects[2], DeviceAction.REMOVE)
# Cannot remove a device from a preexisting RAID # Cannot remove a device from a preexisting RAID
@ -91,22 +88,22 @@ class TestActions(unittest.TestCase):
self.assertActionNotPossible(disk1, DeviceAction.REFORMAT) self.assertActionNotPossible(disk1, DeviceAction.REFORMAT)
disk1p1 = make_partition(model, disk1, preserve=False) disk1p1 = make_partition(model, disk1, preserve=False)
self.assertActionPossible(disk1, DeviceAction.REFORMAT) self.assertActionPossible(disk1, DeviceAction.REFORMAT)
model.add_volgroup('vg0', {disk1p1}) model.add_volgroup("vg0", {disk1p1})
self.assertActionNotPossible(disk1, DeviceAction.REFORMAT) self.assertActionNotPossible(disk1, DeviceAction.REFORMAT)
disk2 = make_disk(model, preserve=True) disk2 = make_disk(model, preserve=True)
self.assertActionNotPossible(disk2, DeviceAction.REFORMAT) self.assertActionNotPossible(disk2, DeviceAction.REFORMAT)
disk2p1 = make_partition(model, disk2, preserve=True) disk2p1 = make_partition(model, disk2, preserve=True)
self.assertActionPossible(disk2, DeviceAction.REFORMAT) self.assertActionPossible(disk2, DeviceAction.REFORMAT)
model.add_volgroup('vg1', {disk2p1}) model.add_volgroup("vg1", {disk2p1})
self.assertActionNotPossible(disk2, DeviceAction.REFORMAT) self.assertActionNotPossible(disk2, DeviceAction.REFORMAT)
disk3 = make_disk(model, preserve=False) disk3 = make_disk(model, preserve=False)
model.add_volgroup('vg2', {disk3}) model.add_volgroup("vg2", {disk3})
self.assertActionNotPossible(disk3, DeviceAction.REFORMAT) self.assertActionNotPossible(disk3, DeviceAction.REFORMAT)
disk4 = make_disk(model, preserve=True) disk4 = make_disk(model, preserve=True)
model.add_volgroup('vg2', {disk4}) model.add_volgroup("vg2", {disk4})
self.assertActionNotPossible(disk4, DeviceAction.REFORMAT) self.assertActionNotPossible(disk4, DeviceAction.REFORMAT)
def test_disk_action_PARTITION(self): def test_disk_action_PARTITION(self):
@ -119,7 +116,7 @@ class TestActions(unittest.TestCase):
make_partition(model, disk) make_partition(model, disk)
self.assertActionNotPossible(disk, DeviceAction.FORMAT) self.assertActionNotPossible(disk, DeviceAction.FORMAT)
disk2 = make_disk(model) disk2 = make_disk(model)
model.add_volgroup('vg1', {disk2}) model.add_volgroup("vg1", {disk2})
self.assertActionNotPossible(disk2, DeviceAction.FORMAT) self.assertActionNotPossible(disk2, DeviceAction.FORMAT)
def test_disk_action_REMOVE(self): def test_disk_action_REMOVE(self):
@ -138,18 +135,18 @@ class TestActions(unittest.TestCase):
def test_disk_action_TOGGLE_BOOT_BIOS(self): def test_disk_action_TOGGLE_BOOT_BIOS(self):
model = make_model(Bootloader.BIOS) model = make_model(Bootloader.BIOS)
# Disks with msdos partition tables can always be the BIOS boot disk. # Disks with msdos partition tables can always be the BIOS boot disk.
dos_disk = make_disk(model, ptable='msdos', preserve=True) dos_disk = make_disk(model, ptable="msdos", preserve=True)
self.assertActionPossible(dos_disk, DeviceAction.TOGGLE_BOOT) self.assertActionPossible(dos_disk, DeviceAction.TOGGLE_BOOT)
# Even if they have existing partitions # Even if they have existing partitions
make_partition( make_partition(
model, dos_disk, size=gaps.largest_gap_size(dos_disk), model, dos_disk, size=gaps.largest_gap_size(dos_disk), preserve=True
preserve=True) )
self.assertActionPossible(dos_disk, DeviceAction.TOGGLE_BOOT) self.assertActionPossible(dos_disk, DeviceAction.TOGGLE_BOOT)
# (we never create dos partition tables so no need to test # (we never create dos partition tables so no need to test
# preserve=False case). # preserve=False case).
# GPT disks with new partition tables can always be the BIOS boot disk # GPT disks with new partition tables can always be the BIOS boot disk
gpt_disk = make_disk(model, ptable='gpt', preserve=False) gpt_disk = make_disk(model, ptable="gpt", preserve=False)
self.assertActionPossible(gpt_disk, DeviceAction.TOGGLE_BOOT) self.assertActionPossible(gpt_disk, DeviceAction.TOGGLE_BOOT)
# Even if they are filled with partitions (we resize partitions to fit) # Even if they are filled with partitions (we resize partitions to fit)
make_partition(model, gpt_disk, size=gaps.largest_gap_size(dos_disk)) make_partition(model, gpt_disk, size=gaps.largest_gap_size(dos_disk))
@ -157,25 +154,37 @@ class TestActions(unittest.TestCase):
# GPT disks with existing partition tables but no partitions can be the # GPT disks with existing partition tables but no partitions can be the
# BIOS boot disk (in general we ignore existing empty partition tables) # BIOS boot disk (in general we ignore existing empty partition tables)
gpt_disk2 = make_disk(model, ptable='gpt', preserve=True) gpt_disk2 = make_disk(model, ptable="gpt", preserve=True)
self.assertActionPossible(gpt_disk2, DeviceAction.TOGGLE_BOOT) self.assertActionPossible(gpt_disk2, DeviceAction.TOGGLE_BOOT)
# If there is an existing *partition* though, it cannot be the boot # If there is an existing *partition* though, it cannot be the boot
# disk # disk
make_partition(model, gpt_disk2, preserve=True) make_partition(model, gpt_disk2, preserve=True)
self.assertActionNotPossible(gpt_disk2, DeviceAction.TOGGLE_BOOT) self.assertActionNotPossible(gpt_disk2, DeviceAction.TOGGLE_BOOT)
# Unless there is already a bios_grub partition we can reuse # Unless there is already a bios_grub partition we can reuse
gpt_disk3 = make_disk(model, ptable='gpt', preserve=True) gpt_disk3 = make_disk(model, ptable="gpt", preserve=True)
make_partition(model, gpt_disk3, flag="bios_grub", preserve=True, make_partition(
offset=1 << 20, size=512 << 20) model,
make_partition(model, gpt_disk3, preserve=True, gpt_disk3,
offset=513 << 20, size=8192 << 20) flag="bios_grub",
preserve=True,
offset=1 << 20,
size=512 << 20,
)
make_partition(
model, gpt_disk3, preserve=True, offset=513 << 20, size=8192 << 20
)
self.assertActionPossible(gpt_disk3, DeviceAction.TOGGLE_BOOT) self.assertActionPossible(gpt_disk3, DeviceAction.TOGGLE_BOOT)
# Edge case city: the bios_grub partition has to be first # Edge case city: the bios_grub partition has to be first
gpt_disk4 = make_disk(model, ptable='gpt', preserve=True) gpt_disk4 = make_disk(model, ptable="gpt", preserve=True)
make_partition(model, gpt_disk4, preserve=True, make_partition(model, gpt_disk4, preserve=True, offset=1 << 20, size=8192 << 20)
offset=1 << 20, size=8192 << 20) make_partition(
make_partition(model, gpt_disk4, flag="bios_grub", preserve=True, model,
offset=8193 << 20, size=512 << 20) gpt_disk4,
flag="bios_grub",
preserve=True,
offset=8193 << 20,
size=512 << 20,
)
self.assertActionNotPossible(gpt_disk4, DeviceAction.TOGGLE_BOOT) self.assertActionNotPossible(gpt_disk4, DeviceAction.TOGGLE_BOOT)
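The offsets and sizes in the edge case above are plain shift arithmetic; a minimal sketch of the numbers involved (nothing beyond Python integers is assumed):

MiB = 1 << 20
bios_grub_offset = 1 << 20    # 1 MiB: bios_grub at the very start of the disk
bios_grub_size = 512 << 20    # 512 MiB
data_offset = 513 << 20       # 513 MiB: immediately after the bios_grub partition
assert bios_grub_offset + bios_grub_size == data_offset
assert data_offset // MiB == 513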
def _test_TOGGLE_BOOT_boot_partition(self, bl, flag): def _test_TOGGLE_BOOT_boot_partition(self, bl, flag):
@ -188,20 +197,20 @@ class TestActions(unittest.TestCase):
new_disk = make_disk(model, preserve=False) new_disk = make_disk(model, preserve=False)
self.assertActionPossible(new_disk, DeviceAction.TOGGLE_BOOT) self.assertActionPossible(new_disk, DeviceAction.TOGGLE_BOOT)
# Even if they are filled with partitions (we resize partitions to fit) # Even if they are filled with partitions (we resize partitions to fit)
make_partition( make_partition(model, new_disk, size=gaps.largest_gap_size(new_disk))
model, new_disk, size=gaps.largest_gap_size(new_disk))
self.assertActionPossible(new_disk, DeviceAction.TOGGLE_BOOT) self.assertActionPossible(new_disk, DeviceAction.TOGGLE_BOOT)
# A disk with an existing but empty partition table can also be the # A disk with an existing but empty partition table can also be the
# UEFI/PREP boot disk. # UEFI/PREP boot disk.
old_disk = make_disk(model, preserve=True, ptable='gpt') old_disk = make_disk(model, preserve=True, ptable="gpt")
self.assertActionPossible(old_disk, DeviceAction.TOGGLE_BOOT) self.assertActionPossible(old_disk, DeviceAction.TOGGLE_BOOT)
# If there is an existing partition though, it cannot. # If there is an existing partition though, it cannot.
make_partition(model, old_disk, preserve=True) make_partition(model, old_disk, preserve=True)
self.assertActionNotPossible(old_disk, DeviceAction.TOGGLE_BOOT) self.assertActionNotPossible(old_disk, DeviceAction.TOGGLE_BOOT)
# If there is an existing ESP/PReP partition though, fine! # If there is an existing ESP/PReP partition though, fine!
make_partition(model, old_disk, flag=flag, preserve=True, make_partition(
offset=1 << 20, size=512 << 20) model, old_disk, flag=flag, preserve=True, offset=1 << 20, size=512 << 20
)
self.assertActionPossible(old_disk, DeviceAction.TOGGLE_BOOT) self.assertActionPossible(old_disk, DeviceAction.TOGGLE_BOOT)
def test_disk_action_TOGGLE_BOOT_UEFI(self): def test_disk_action_TOGGLE_BOOT_UEFI(self):
@ -217,7 +226,7 @@ class TestActions(unittest.TestCase):
def test_partition_action_EDIT(self): def test_partition_action_EDIT(self):
model, part = make_model_and_partition() model, part = make_model_and_partition()
self.assertActionPossible(part, DeviceAction.EDIT) self.assertActionPossible(part, DeviceAction.EDIT)
model.add_volgroup('vg1', {part}) model.add_volgroup("vg1", {part})
self.assertActionNotPossible(part, DeviceAction.EDIT) self.assertActionNotPossible(part, DeviceAction.EDIT)
def test_partition_action_REFORMAT(self): def test_partition_action_REFORMAT(self):
@ -243,20 +252,20 @@ class TestActions(unittest.TestCase):
model = make_model() model = make_model()
part1 = make_partition(model) part1 = make_partition(model)
self.assertActionPossible(part1, DeviceAction.DELETE) self.assertActionPossible(part1, DeviceAction.DELETE)
fs = model.add_filesystem(part1, 'ext4') fs = model.add_filesystem(part1, "ext4")
self.assertActionPossible(part1, DeviceAction.DELETE) self.assertActionPossible(part1, DeviceAction.DELETE)
model.add_mount(fs, '/') model.add_mount(fs, "/")
self.assertActionPossible(part1, DeviceAction.DELETE) self.assertActionPossible(part1, DeviceAction.DELETE)
part2 = make_partition(model) part2 = make_partition(model)
model.add_volgroup('vg1', {part2}) model.add_volgroup("vg1", {part2})
self.assertActionNotPossible(part2, DeviceAction.DELETE) self.assertActionNotPossible(part2, DeviceAction.DELETE)
part = make_partition(make_model(Bootloader.BIOS), flag='bios_grub') part = make_partition(make_model(Bootloader.BIOS), flag="bios_grub")
self.assertActionNotPossible(part, DeviceAction.DELETE) self.assertActionNotPossible(part, DeviceAction.DELETE)
part = make_partition(make_model(Bootloader.UEFI), flag='boot') part = make_partition(make_model(Bootloader.UEFI), flag="boot")
self.assertActionNotPossible(part, DeviceAction.DELETE) self.assertActionNotPossible(part, DeviceAction.DELETE)
part = make_partition(make_model(Bootloader.PREP), flag='prep') part = make_partition(make_model(Bootloader.PREP), flag="prep")
self.assertActionNotPossible(part, DeviceAction.DELETE) self.assertActionNotPossible(part, DeviceAction.DELETE)
# You cannot delete a partition from a disk that has # You cannot delete a partition from a disk that has
@ -277,7 +286,7 @@ class TestActions(unittest.TestCase):
model = make_model() model = make_model()
raid1 = make_raid(model) raid1 = make_raid(model)
self.assertActionPossible(raid1, DeviceAction.EDIT) self.assertActionPossible(raid1, DeviceAction.EDIT)
model.add_volgroup('vg1', {raid1}) model.add_volgroup("vg1", {raid1})
self.assertActionNotPossible(raid1, DeviceAction.EDIT) self.assertActionNotPossible(raid1, DeviceAction.EDIT)
raid2 = make_raid(model) raid2 = make_raid(model)
make_partition(model, raid2) make_partition(model, raid2)
@ -294,23 +303,23 @@ class TestActions(unittest.TestCase):
self.assertActionNotPossible(raid1, DeviceAction.REFORMAT) self.assertActionNotPossible(raid1, DeviceAction.REFORMAT)
raid1p1 = make_partition(model, raid1) raid1p1 = make_partition(model, raid1)
self.assertActionPossible(raid1, DeviceAction.REFORMAT) self.assertActionPossible(raid1, DeviceAction.REFORMAT)
model.add_volgroup('vg0', {raid1p1}) model.add_volgroup("vg0", {raid1p1})
self.assertActionNotPossible(raid1, DeviceAction.REFORMAT) self.assertActionNotPossible(raid1, DeviceAction.REFORMAT)
raid2 = make_raid(model, preserve=True) raid2 = make_raid(model, preserve=True)
self.assertActionNotPossible(raid2, DeviceAction.REFORMAT) self.assertActionNotPossible(raid2, DeviceAction.REFORMAT)
raid2p1 = make_partition(model, raid2, preserve=True) raid2p1 = make_partition(model, raid2, preserve=True)
self.assertActionPossible(raid2, DeviceAction.REFORMAT) self.assertActionPossible(raid2, DeviceAction.REFORMAT)
model.add_volgroup('vg1', {raid2p1}) model.add_volgroup("vg1", {raid2p1})
self.assertActionNotPossible(raid2, DeviceAction.REFORMAT) self.assertActionNotPossible(raid2, DeviceAction.REFORMAT)
raid3 = make_raid(model) raid3 = make_raid(model)
model.add_volgroup('vg2', {raid3}) model.add_volgroup("vg2", {raid3})
self.assertActionNotPossible(raid3, DeviceAction.REFORMAT) self.assertActionNotPossible(raid3, DeviceAction.REFORMAT)
raid4 = make_raid(model) raid4 = make_raid(model)
raid4.preserve = True raid4.preserve = True
model.add_volgroup('vg2', {raid4}) model.add_volgroup("vg2", {raid4})
self.assertActionNotPossible(raid4, DeviceAction.REFORMAT) self.assertActionNotPossible(raid4, DeviceAction.REFORMAT)
def test_raid_action_PARTITION(self): def test_raid_action_PARTITION(self):
@ -323,7 +332,7 @@ class TestActions(unittest.TestCase):
make_partition(model, raid) make_partition(model, raid)
self.assertActionNotPossible(raid, DeviceAction.FORMAT) self.assertActionNotPossible(raid, DeviceAction.FORMAT)
raid2 = make_raid(model) raid2 = make_raid(model)
model.add_volgroup('vg1', {raid2}) model.add_volgroup("vg1", {raid2})
self.assertActionNotPossible(raid2, DeviceAction.FORMAT) self.assertActionNotPossible(raid2, DeviceAction.FORMAT)
def test_raid_action_REMOVE(self): def test_raid_action_REMOVE(self):
@ -338,20 +347,20 @@ class TestActions(unittest.TestCase):
self.assertActionPossible(raid1, DeviceAction.DELETE) self.assertActionPossible(raid1, DeviceAction.DELETE)
part = make_partition(model, raid1) part = make_partition(model, raid1)
self.assertActionPossible(raid1, DeviceAction.DELETE) self.assertActionPossible(raid1, DeviceAction.DELETE)
fs = model.add_filesystem(part, 'ext4') fs = model.add_filesystem(part, "ext4")
self.assertActionPossible(raid1, DeviceAction.DELETE) self.assertActionPossible(raid1, DeviceAction.DELETE)
model.add_mount(fs, '/') model.add_mount(fs, "/")
self.assertActionNotPossible(raid1, DeviceAction.DELETE) self.assertActionNotPossible(raid1, DeviceAction.DELETE)
raid2 = make_raid(model) raid2 = make_raid(model)
self.assertActionPossible(raid2, DeviceAction.DELETE) self.assertActionPossible(raid2, DeviceAction.DELETE)
fs = model.add_filesystem(raid2, 'ext4') fs = model.add_filesystem(raid2, "ext4")
self.assertActionPossible(raid2, DeviceAction.DELETE) self.assertActionPossible(raid2, DeviceAction.DELETE)
model.add_mount(fs, '/') model.add_mount(fs, "/")
self.assertActionPossible(raid2, DeviceAction.DELETE) self.assertActionPossible(raid2, DeviceAction.DELETE)
raid2 = make_raid(model) raid2 = make_raid(model)
model.add_volgroup('vg0', {raid2}) model.add_volgroup("vg0", {raid2})
self.assertActionNotPossible(raid2, DeviceAction.DELETE) self.assertActionNotPossible(raid2, DeviceAction.DELETE)
def test_raid_action_TOGGLE_BOOT(self): def test_raid_action_TOGGLE_BOOT(self):
@ -365,7 +374,7 @@ class TestActions(unittest.TestCase):
def test_vg_action_EDIT(self): def test_vg_action_EDIT(self):
model, vg = make_model_and_vg() model, vg = make_model_and_vg()
self.assertActionPossible(vg, DeviceAction.EDIT) self.assertActionPossible(vg, DeviceAction.EDIT)
model.add_logical_volume(vg, 'lv1', size=gaps.largest_gap_size(vg)) model.add_logical_volume(vg, "lv1", size=gaps.largest_gap_size(vg))
self.assertActionNotPossible(vg, DeviceAction.EDIT) self.assertActionNotPossible(vg, DeviceAction.EDIT)
vg2 = make_vg(model) vg2 = make_vg(model)
@ -392,12 +401,11 @@ class TestActions(unittest.TestCase):
model, vg = make_model_and_vg() model, vg = make_model_and_vg()
self.assertActionPossible(vg, DeviceAction.DELETE) self.assertActionPossible(vg, DeviceAction.DELETE)
self.assertActionPossible(vg, DeviceAction.DELETE) self.assertActionPossible(vg, DeviceAction.DELETE)
lv = model.add_logical_volume( lv = model.add_logical_volume(vg, "lv0", size=gaps.largest_gap_size(vg) // 2)
vg, 'lv0', size=gaps.largest_gap_size(vg)//2)
self.assertActionPossible(vg, DeviceAction.DELETE) self.assertActionPossible(vg, DeviceAction.DELETE)
fs = model.add_filesystem(lv, 'ext4') fs = model.add_filesystem(lv, "ext4")
self.assertActionPossible(vg, DeviceAction.DELETE) self.assertActionPossible(vg, DeviceAction.DELETE)
model.add_mount(fs, '/') model.add_mount(fs, "/")
self.assertActionNotPossible(vg, DeviceAction.DELETE) self.assertActionNotPossible(vg, DeviceAction.DELETE)
def test_vg_action_TOGGLE_BOOT(self): def test_vg_action_TOGGLE_BOOT(self):
@ -431,9 +439,9 @@ class TestActions(unittest.TestCase):
def test_lv_action_DELETE(self): def test_lv_action_DELETE(self):
model, lv = make_model_and_lv() model, lv = make_model_and_lv()
self.assertActionPossible(lv, DeviceAction.DELETE) self.assertActionPossible(lv, DeviceAction.DELETE)
fs = model.add_filesystem(lv, 'ext4') fs = model.add_filesystem(lv, "ext4")
self.assertActionPossible(lv, DeviceAction.DELETE) self.assertActionPossible(lv, DeviceAction.DELETE)
model.add_mount(fs, '/') model.add_mount(fs, "/")
self.assertActionPossible(lv, DeviceAction.DELETE) self.assertActionPossible(lv, DeviceAction.DELETE)
lv2 = make_lv(model) lv2 = make_lv(model)


@ -13,20 +13,20 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from functools import partial
import unittest import unittest
from functools import partial
from unittest import mock from unittest import mock
from subiquitycore.tests.parameterized import parameterized from subiquity.common.filesystem import gaps
from subiquity.common.types import GapUsable
from subiquity.models.filesystem import ( from subiquity.models.filesystem import (
Disk,
LVM_CHUNK_SIZE, LVM_CHUNK_SIZE,
LVM_OVERHEAD, LVM_OVERHEAD,
Disk,
MiB, MiB,
PartitionAlignmentData, PartitionAlignmentData,
Raid, Raid,
) )
from subiquity.models.tests.test_filesystem import ( from subiquity.models.tests.test_filesystem import (
make_disk, make_disk,
make_lv, make_lv,
@ -35,22 +35,19 @@ from subiquity.models.tests.test_filesystem import (
make_model_and_lv, make_model_and_lv,
make_partition, make_partition,
make_vg, make_vg,
) )
from subiquitycore.tests.parameterized import parameterized
from subiquity.common.filesystem import gaps
from subiquity.common.types import GapUsable
class GapTestCase(unittest.TestCase): class GapTestCase(unittest.TestCase):
def use_alignment_data(self, alignment_data): def use_alignment_data(self, alignment_data):
m = mock.patch('subiquity.common.filesystem.gaps.parts_and_gaps') m = mock.patch("subiquity.common.filesystem.gaps.parts_and_gaps")
p = m.start() p = m.start()
self.addCleanup(m.stop) self.addCleanup(m.stop)
p.side_effect = partial( p.side_effect = partial(gaps.find_disk_gaps_v2, info=alignment_data)
gaps.find_disk_gaps_v2, info=alignment_data)
for cls in Disk, Raid: for cls in Disk, Raid:
md = mock.patch.object(cls, 'alignment_data') md = mock.patch.object(cls, "alignment_data")
p = md.start() p = md.start()
self.addCleanup(md.stop) self.addCleanup(md.stop)
p.return_value = alignment_data p.return_value = alignment_data
@ -86,14 +83,21 @@ class TestSplitGap(GapTestCase):
@parameterized.expand([[1], [10]]) @parameterized.expand([[1], [10]])
def test_split_in_extended(self, ebr_space): def test_split_in_extended(self, ebr_space):
self.use_alignment_data(PartitionAlignmentData( self.use_alignment_data(
part_align=1, min_gap_size=1, min_start_offset=0, PartitionAlignmentData(
min_end_offset=0, primary_part_limit=4, ebr_space=ebr_space)) part_align=1,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=4,
ebr_space=ebr_space,
)
)
m = make_model(storage_version=2) m = make_model(storage_version=2)
d = make_disk(m, size=100, ptable='dos') d = make_disk(m, size=100, ptable="dos")
[gap] = gaps.parts_and_gaps(d) [gap] = gaps.parts_and_gaps(d)
make_partition(m, d, offset=gap.offset, size=gap.size, flag='extended') make_partition(m, d, offset=gap.offset, size=gap.size, flag="extended")
[p, g] = gaps.parts_and_gaps(d) [p, g] = gaps.parts_and_gaps(d)
self.assertTrue(g.in_extended) self.assertTrue(g.in_extended)
g1, g2 = g.split(10) g1, g2 = g.split(10)
@ -205,125 +209,158 @@ class TestAfter(unittest.TestCase):
class TestDiskGaps(unittest.TestCase): class TestDiskGaps(unittest.TestCase):
def test_no_partition_gpt(self): def test_no_partition_gpt(self):
size = 1 << 30 size = 1 << 30
d = make_disk(size=size, ptable='gpt') d = make_disk(size=size, ptable="gpt")
self.assertEqual( self.assertEqual(
gaps.find_disk_gaps_v2(d), gaps.find_disk_gaps_v2(d), [gaps.Gap(d, MiB, size - 2 * MiB, False)]
[gaps.Gap(d, MiB, size - 2*MiB, False)]) )
def test_no_partition_dos(self): def test_no_partition_dos(self):
size = 1 << 30 size = 1 << 30
d = make_disk(size=size, ptable='dos') d = make_disk(size=size, ptable="dos")
self.assertEqual( self.assertEqual(
gaps.find_disk_gaps_v2(d), gaps.find_disk_gaps_v2(d), [gaps.Gap(d, MiB, size - MiB, False)]
[gaps.Gap(d, MiB, size - MiB, False)]) )
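A worked instance of the two expectations above, assuming only the arithmetic and not any further subiquity behaviour: GPT reserves a MiB at each end of the disk, DOS only at the start.

MiB = 1 << 20
size = 1 << 30                          # 1 GiB disk, i.e. 1024 MiB
assert (size - 2 * MiB) // MiB == 1022  # GPT gap: offset 1 MiB, 1022 MiB long
assert (size - MiB) // MiB == 1023      # DOS gap: offset 1 MiB, 1023 MiB long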
def test_all_partition(self): def test_all_partition(self):
info = PartitionAlignmentData( info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0, part_align=10,
min_end_offset=0, primary_part_limit=10) min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100) m, d = make_model_and_disk(size=100)
p = make_partition(m, d, offset=0, size=100) p = make_partition(m, d, offset=0, size=100)
self.assertEqual( self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p])
gaps.find_disk_gaps_v2(d, info),
[p])
def test_all_partition_with_min_offsets(self): def test_all_partition_with_min_offsets(self):
info = PartitionAlignmentData( info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=10, part_align=10,
min_end_offset=10, primary_part_limit=10) min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100) m, d = make_model_and_disk(size=100)
p = make_partition(m, d, offset=10, size=80) p = make_partition(m, d, offset=10, size=80)
self.assertEqual( self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p])
gaps.find_disk_gaps_v2(d, info),
[p])
def test_half_partition(self): def test_half_partition(self):
info = PartitionAlignmentData( info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0, part_align=10,
min_end_offset=0, primary_part_limit=10) min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100) m, d = make_model_and_disk(size=100)
p = make_partition(m, d, offset=0, size=50) p = make_partition(m, d, offset=0, size=50)
self.assertEqual( self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p, gaps.Gap(d, 50, 50)])
gaps.find_disk_gaps_v2(d, info),
[p, gaps.Gap(d, 50, 50)])
def test_gap_in_middle(self): def test_gap_in_middle(self):
info = PartitionAlignmentData( info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0, part_align=10,
min_end_offset=0, primary_part_limit=10) min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100) m, d = make_model_and_disk(size=100)
p1 = make_partition(m, d, offset=0, size=20) p1 = make_partition(m, d, offset=0, size=20)
p2 = make_partition(m, d, offset=80, size=20) p2 = make_partition(m, d, offset=80, size=20)
self.assertEqual( self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p1, gaps.Gap(d, 20, 60), p2])
gaps.find_disk_gaps_v2(d, info),
[p1, gaps.Gap(d, 20, 60), p2])
def test_small_gap(self): def test_small_gap(self):
info = PartitionAlignmentData( info = PartitionAlignmentData(
part_align=10, min_gap_size=20, min_start_offset=0, part_align=10,
min_end_offset=0, primary_part_limit=10) min_gap_size=20,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100) m, d = make_model_and_disk(size=100)
p1 = make_partition(m, d, offset=0, size=40) p1 = make_partition(m, d, offset=0, size=40)
p2 = make_partition(m, d, offset=50, size=50) p2 = make_partition(m, d, offset=50, size=50)
self.assertEqual( self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p1, p2])
gaps.find_disk_gaps_v2(d, info),
[p1, p2])
def test_align_gap(self): def test_align_gap(self):
info = PartitionAlignmentData( info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0, part_align=10,
min_end_offset=0, primary_part_limit=10) min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100) m, d = make_model_and_disk(size=100)
p1 = make_partition(m, d, offset=0, size=17) p1 = make_partition(m, d, offset=0, size=17)
p2 = make_partition(m, d, offset=53, size=47) p2 = make_partition(m, d, offset=53, size=47)
self.assertEqual( self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p1, gaps.Gap(d, 20, 30), p2])
gaps.find_disk_gaps_v2(d, info),
[p1, gaps.Gap(d, 20, 30), p2])
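The expected Gap(d, 20, 30) above is consistent with rounding the raw hole (17..53) inward to the part_align of 10; a sketch of that rounding, stated as an assumption about what find_disk_gaps_v2 reports rather than a description of its implementation:

part_align = 10
hole_start, hole_end = 17, 53                                # between p1 and p2
aligned_start = -(-hole_start // part_align) * part_align    # round up   -> 20
aligned_end = (hole_end // part_align) * part_align          # round down -> 50
assert (aligned_start, aligned_end - aligned_start) == (20, 30)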
def test_all_extended(self): def test_all_extended(self):
info = PartitionAlignmentData( info = PartitionAlignmentData(
part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0, part_align=5,
ebr_space=2, primary_part_limit=10) min_gap_size=1,
m, d = make_model_and_disk(size=100, ptable='dos') min_start_offset=0,
p = make_partition(m, d, offset=0, size=100, flag='extended') min_end_offset=0,
ebr_space=2,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100, ptable="dos")
p = make_partition(m, d, offset=0, size=100, flag="extended")
self.assertEqual( self.assertEqual(
gaps.find_disk_gaps_v2(d, info), gaps.find_disk_gaps_v2(d, info),
[ [
p, p,
gaps.Gap(d, 5, 95, True), gaps.Gap(d, 5, 95, True),
]) ],
)
def test_half_extended(self): def test_half_extended(self):
info = PartitionAlignmentData( info = PartitionAlignmentData(
part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0, part_align=5,
ebr_space=2, primary_part_limit=10) min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
ebr_space=2,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100) m, d = make_model_and_disk(size=100)
p = make_partition(m, d, offset=0, size=50, flag='extended') p = make_partition(m, d, offset=0, size=50, flag="extended")
self.assertEqual( self.assertEqual(
gaps.find_disk_gaps_v2(d, info), gaps.find_disk_gaps_v2(d, info),
[p, gaps.Gap(d, 5, 45, True), gaps.Gap(d, 50, 50, False)]) [p, gaps.Gap(d, 5, 45, True), gaps.Gap(d, 50, 50, False)],
)
def test_half_extended_one_logical(self): def test_half_extended_one_logical(self):
info = PartitionAlignmentData( info = PartitionAlignmentData(
part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0, part_align=5,
ebr_space=2, primary_part_limit=10) min_gap_size=1,
m, d = make_model_and_disk(size=100, ptable='dos') min_start_offset=0,
p1 = make_partition(m, d, offset=0, size=50, flag='extended') min_end_offset=0,
p2 = make_partition(m, d, offset=5, size=45, flag='logical') ebr_space=2,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100, ptable="dos")
p1 = make_partition(m, d, offset=0, size=50, flag="extended")
p2 = make_partition(m, d, offset=5, size=45, flag="logical")
self.assertEqual( self.assertEqual(
gaps.find_disk_gaps_v2(d, info), gaps.find_disk_gaps_v2(d, info), [p1, p2, gaps.Gap(d, 50, 50, False)]
[p1, p2, gaps.Gap(d, 50, 50, False)]) )
def test_half_extended_half_logical(self): def test_half_extended_half_logical(self):
info = PartitionAlignmentData( info = PartitionAlignmentData(
part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0, part_align=5,
ebr_space=2, primary_part_limit=10) min_gap_size=1,
m, d = make_model_and_disk(size=100, ptable='dos') min_start_offset=0,
p1 = make_partition(m, d, offset=0, size=50, flag='extended') min_end_offset=0,
p2 = make_partition(m, d, offset=5, size=25, flag='logical') ebr_space=2,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100, ptable="dos")
p1 = make_partition(m, d, offset=0, size=50, flag="extended")
p2 = make_partition(m, d, offset=5, size=25, flag="logical")
self.assertEqual( self.assertEqual(
gaps.find_disk_gaps_v2(d, info), gaps.find_disk_gaps_v2(d, info),
[ [
@ -331,159 +368,204 @@ class TestDiskGaps(unittest.TestCase):
p2, p2,
gaps.Gap(d, 35, 15, True), gaps.Gap(d, 35, 15, True),
gaps.Gap(d, 50, 50, False), gaps.Gap(d, 50, 50, False),
]) ],
)
def test_gap_before_primary(self): def test_gap_before_primary(self):
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# [ g1 ][ p1 (primary) ] # [ g1 ][ p1 (primary) ]
info = PartitionAlignmentData( info = PartitionAlignmentData(
part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0, part_align=5,
ebr_space=1, primary_part_limit=10) min_gap_size=1,
m, d = make_model_and_disk(size=100, ptable='dos') min_start_offset=0,
min_end_offset=0,
ebr_space=1,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100, ptable="dos")
p1 = make_partition(m, d, offset=50, size=50) p1 = make_partition(m, d, offset=50, size=50)
self.assertEqual( self.assertEqual(
gaps.find_disk_gaps_v2(d, info), gaps.find_disk_gaps_v2(d, info), [gaps.Gap(d, 0, 50, False), p1]
[gaps.Gap(d, 0, 50, False), p1]) )
def test_gap_in_extended_before_logical(self): def test_gap_in_extended_before_logical(self):
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# [ p1 (extended) ] # [ p1 (extended) ]
# [ g1 ] [ p5 (logical) ] # [ g1 ] [ p5 (logical) ]
info = PartitionAlignmentData( info = PartitionAlignmentData(
part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0, part_align=5,
ebr_space=1, primary_part_limit=10) min_gap_size=1,
m, d = make_model_and_disk(size=100, ptable='dos') min_start_offset=0,
p1 = make_partition(m, d, offset=0, size=100, flag='extended') min_end_offset=0,
p5 = make_partition(m, d, offset=50, size=50, flag='logical') ebr_space=1,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100, ptable="dos")
p1 = make_partition(m, d, offset=0, size=100, flag="extended")
p5 = make_partition(m, d, offset=50, size=50, flag="logical")
self.assertEqual( self.assertEqual(
gaps.find_disk_gaps_v2(d, info), gaps.find_disk_gaps_v2(d, info), [p1, gaps.Gap(d, 5, 40, True), p5]
[p1, gaps.Gap(d, 5, 40, True), p5]) )
def test_unusable_gap_primaries(self): def test_unusable_gap_primaries(self):
info = PartitionAlignmentData( info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0, part_align=10,
min_end_offset=0, primary_part_limit=1) min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=1,
)
m, d = make_model_and_disk(size=100) m, d = make_model_and_disk(size=100)
p = make_partition(m, d, offset=0, size=90) p = make_partition(m, d, offset=0, size=90)
g = gaps.Gap(d, offset=90, size=10, g = gaps.Gap(d, offset=90, size=10, usable=GapUsable.TOO_MANY_PRIMARY_PARTS)
usable=GapUsable.TOO_MANY_PRIMARY_PARTS) self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p, g])
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[p, g])
def test_usable_gap_primaries(self): def test_usable_gap_primaries(self):
info = PartitionAlignmentData( info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0, part_align=10,
min_end_offset=0, primary_part_limit=2) min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=2,
)
m, d = make_model_and_disk(size=100) m, d = make_model_and_disk(size=100)
p = make_partition(m, d, offset=0, size=90) p = make_partition(m, d, offset=0, size=90)
g = gaps.Gap(d, offset=90, size=10, usable=GapUsable.YES) g = gaps.Gap(d, offset=90, size=10, usable=GapUsable.YES)
self.assertEqual( self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p, g])
gaps.find_disk_gaps_v2(d, info),
[p, g])
def test_usable_gap_extended(self): def test_usable_gap_extended(self):
info = PartitionAlignmentData( info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0, part_align=10,
min_end_offset=0, primary_part_limit=1) min_gap_size=1,
m, d = make_model_and_disk(size=100, ptable='dos') min_start_offset=0,
p = make_partition(m, d, offset=0, size=100, flag='extended') min_end_offset=0,
g = gaps.Gap(d, offset=0, size=100, primary_part_limit=1,
in_extended=True, usable=GapUsable.YES) )
self.assertEqual( m, d = make_model_and_disk(size=100, ptable="dos")
gaps.find_disk_gaps_v2(d, info), p = make_partition(m, d, offset=0, size=100, flag="extended")
[p, g]) g = gaps.Gap(d, offset=0, size=100, in_extended=True, usable=GapUsable.YES)
self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p, g])
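The three tests above encode a usability rule for gaps. Restated as a toy predicate (this paraphrases the expected values using the GapUsable enum already imported in this file; it is not the real implementation):

def toy_usable(primary_parts_used, primary_part_limit, in_extended):
    # Gaps inside an extended partition hold logical partitions, so the
    # primary-partition limit does not apply to them.
    if in_extended or primary_parts_used < primary_part_limit:
        return GapUsable.YES
    return GapUsable.TOO_MANY_PRIMARY_PARTS

assert toy_usable(1, 1, False) is GapUsable.TOO_MANY_PRIMARY_PARTS
assert toy_usable(1, 2, False) is GapUsable.YES
assert toy_usable(1, 1, True) is GapUsable.YES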
class TestMovableTrailingPartitionsAndGapSize(GapTestCase): class TestMovableTrailingPartitionsAndGapSize(GapTestCase):
def test_largest_non_extended(self): def test_largest_non_extended(self):
self.use_alignment_data(PartitionAlignmentData( self.use_alignment_data(
part_align=10, min_gap_size=1, min_start_offset=0, PartitionAlignmentData(
min_end_offset=0, primary_part_limit=4)) part_align=10,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=4,
)
)
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# [ p1 ] [ p2 (extended) ] # [ p1 ] [ p2 (extended) ]
# [ g1 ][ g2 ] # [ g1 ][ g2 ]
m, d = make_model_and_disk(size=100, ptable='dos') m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=0, size=10) make_partition(m, d, offset=0, size=10)
make_partition(m, d, offset=30, size=70, flag='extended') make_partition(m, d, offset=30, size=70, flag="extended")
g = gaps.Gap(d, offset=10, size=20, g = gaps.Gap(d, offset=10, size=20, in_extended=False, usable=GapUsable.YES)
in_extended=False, usable=GapUsable.YES)
self.assertEqual(g, gaps.largest_gap(d, in_extended=False)) self.assertEqual(g, gaps.largest_gap(d, in_extended=False))
def test_largest_yes_extended(self): def test_largest_yes_extended(self):
self.use_alignment_data(PartitionAlignmentData( self.use_alignment_data(
part_align=10, min_gap_size=1, min_start_offset=0, PartitionAlignmentData(
min_end_offset=0, primary_part_limit=4)) part_align=10,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=4,
)
)
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# [ p1 ] [ p2 (extended) ] # [ p1 ] [ p2 (extended) ]
# [ g1 ][ g2 ] # [ g1 ][ g2 ]
m, d = make_model_and_disk(size=100, ptable='dos') m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=0, size=10) make_partition(m, d, offset=0, size=10)
make_partition(m, d, offset=30, size=70, flag='extended') make_partition(m, d, offset=30, size=70, flag="extended")
g = gaps.Gap(d, offset=30, size=70, g = gaps.Gap(d, offset=30, size=70, in_extended=True, usable=GapUsable.YES)
in_extended=True, usable=GapUsable.YES)
self.assertEqual(g, gaps.largest_gap(d, in_extended=True)) self.assertEqual(g, gaps.largest_gap(d, in_extended=True))
def test_largest_any_extended(self): def test_largest_any_extended(self):
self.use_alignment_data(PartitionAlignmentData( self.use_alignment_data(
part_align=10, min_gap_size=1, min_start_offset=0, PartitionAlignmentData(
min_end_offset=0, primary_part_limit=4)) part_align=10,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=4,
)
)
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# [ p1 ] [ p2 (extended) ] # [ p1 ] [ p2 (extended) ]
# [ g1 ][ g2 ] # [ g1 ][ g2 ]
m, d = make_model_and_disk(size=100, ptable='dos') m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=0, size=10) make_partition(m, d, offset=0, size=10)
make_partition(m, d, offset=30, size=70, flag='extended') make_partition(m, d, offset=30, size=70, flag="extended")
g = gaps.Gap(d, offset=30, size=70, g = gaps.Gap(d, offset=30, size=70, in_extended=True, usable=GapUsable.YES)
in_extended=True, usable=GapUsable.YES)
self.assertEqual(g, gaps.largest_gap(d)) self.assertEqual(g, gaps.largest_gap(d))
def test_no_next_gap(self): def test_no_next_gap(self):
self.use_alignment_data(PartitionAlignmentData( self.use_alignment_data(
part_align=10, min_gap_size=1, min_start_offset=10, PartitionAlignmentData(
min_end_offset=10, primary_part_limit=10)) part_align=10,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
)
)
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 ]##### # #####[ p1 ]#####
m, d = make_model_and_disk(size=100) m, d = make_model_and_disk(size=100)
p = make_partition(m, d, offset=10, size=80) p = make_partition(m, d, offset=10, size=80)
self.assertEqual( self.assertEqual(([], 0), gaps.movable_trailing_partitions_and_gap_size(p))
([], 0),
gaps.movable_trailing_partitions_and_gap_size(p))
def test_immediately_trailing_gap(self): def test_immediately_trailing_gap(self):
self.use_alignment_data(PartitionAlignmentData( self.use_alignment_data(
part_align=10, min_gap_size=1, min_start_offset=10, PartitionAlignmentData(
min_end_offset=10, primary_part_limit=10)) part_align=10,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
)
)
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 ] [ p2 ] ##### # #####[ p1 ] [ p2 ] #####
m, d = make_model_and_disk(size=100) m, d = make_model_and_disk(size=100)
p1 = make_partition(m, d, offset=10, size=20) p1 = make_partition(m, d, offset=10, size=20)
p2 = make_partition(m, d, offset=50, size=20) p2 = make_partition(m, d, offset=50, size=20)
self.assertEqual( self.assertEqual(([], 20), gaps.movable_trailing_partitions_and_gap_size(p1))
([], 20), self.assertEqual(([], 20), gaps.movable_trailing_partitions_and_gap_size(p2))
gaps.movable_trailing_partitions_and_gap_size(p1))
self.assertEqual(
([], 20),
gaps.movable_trailing_partitions_and_gap_size(p2))
def test_one_trailing_movable_partition_and_gap(self): def test_one_trailing_movable_partition_and_gap(self):
self.use_alignment_data(PartitionAlignmentData( self.use_alignment_data(
part_align=10, min_gap_size=1, min_start_offset=10, PartitionAlignmentData(
min_end_offset=10, primary_part_limit=10)) part_align=10,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
)
)
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 ][ p2 ] ##### # #####[ p1 ][ p2 ] #####
m, d = make_model_and_disk(size=100) m, d = make_model_and_disk(size=100)
p1 = make_partition(m, d, offset=10, size=40) p1 = make_partition(m, d, offset=10, size=40)
p2 = make_partition(m, d, offset=50, size=10) p2 = make_partition(m, d, offset=50, size=10)
self.assertEqual( self.assertEqual(([p2], 30), gaps.movable_trailing_partitions_and_gap_size(p1))
([p2], 30),
gaps.movable_trailing_partitions_and_gap_size(p1))
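The expected ([p2], 30) above can be read off the layout comment, assuming only the alignment data used in this test: the usable area ends at 100 - min_end_offset = 90, p2 ends at 60, so p2 is movable and 30 units can be reclaimed behind p1.

disk_size, min_end_offset = 100, 10
p2_end = 50 + 10                                # p2 offset + p2 size
assert (disk_size - min_end_offset) - p2_end == 30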
def test_two_trailing_movable_partitions_and_gap(self): def test_two_trailing_movable_partitions_and_gap(self):
self.use_alignment_data(PartitionAlignmentData( self.use_alignment_data(
part_align=10, min_gap_size=1, min_start_offset=10, PartitionAlignmentData(
min_end_offset=10, primary_part_limit=10)) part_align=10,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
)
)
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 ][ p2 ][ p3 ] ##### # #####[ p1 ][ p2 ][ p3 ] #####
m, d = make_model_and_disk(size=100) m, d = make_model_and_disk(size=100)
@ -491,122 +573,152 @@ class TestMovableTrailingPartitionsAndGapSize(GapTestCase):
p2 = make_partition(m, d, offset=50, size=10) p2 = make_partition(m, d, offset=50, size=10)
p3 = make_partition(m, d, offset=60, size=10) p3 = make_partition(m, d, offset=60, size=10)
self.assertEqual( self.assertEqual(
([p2, p3], 20), ([p2, p3], 20), gaps.movable_trailing_partitions_and_gap_size(p1)
gaps.movable_trailing_partitions_and_gap_size(p1)) )
def test_one_trailing_movable_partition_and_no_gap(self): def test_one_trailing_movable_partition_and_no_gap(self):
self.use_alignment_data(PartitionAlignmentData( self.use_alignment_data(
part_align=10, min_gap_size=1, min_start_offset=10, PartitionAlignmentData(
min_end_offset=10, primary_part_limit=10)) part_align=10,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
)
)
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 ][ p2 ]##### # #####[ p1 ][ p2 ]#####
m, d = make_model_and_disk(size=100) m, d = make_model_and_disk(size=100)
p1 = make_partition(m, d, offset=10, size=40) p1 = make_partition(m, d, offset=10, size=40)
p2 = make_partition(m, d, offset=50, size=40) p2 = make_partition(m, d, offset=50, size=40)
self.assertEqual( self.assertEqual(([p2], 0), gaps.movable_trailing_partitions_and_gap_size(p1))
([p2], 0),
gaps.movable_trailing_partitions_and_gap_size(p1))
def test_full_extended_partition_then_gap(self): def test_full_extended_partition_then_gap(self):
self.use_alignment_data(PartitionAlignmentData( self.use_alignment_data(
part_align=1, min_gap_size=1, min_start_offset=10, PartitionAlignmentData(
min_end_offset=10, primary_part_limit=10, ebr_space=2)) part_align=1,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
ebr_space=2,
)
)
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 (extended) ] ##### # #####[ p1 (extended) ] #####
# ######[ p5 (logical) ] ##### # ######[ p5 (logical) ] #####
m, d = make_model_and_disk(size=100, ptable='dos') m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=10, size=40, flag='extended') make_partition(m, d, offset=10, size=40, flag="extended")
p5 = make_partition(m, d, offset=12, size=38, flag='logical') p5 = make_partition(m, d, offset=12, size=38, flag="logical")
self.assertEqual( self.assertEqual(([], 0), gaps.movable_trailing_partitions_and_gap_size(p5))
([], 0),
gaps.movable_trailing_partitions_and_gap_size(p5))
def test_full_extended_partition_then_part(self): def test_full_extended_partition_then_part(self):
self.use_alignment_data(PartitionAlignmentData( self.use_alignment_data(
part_align=1, min_gap_size=1, min_start_offset=10, PartitionAlignmentData(
min_end_offset=10, primary_part_limit=10, ebr_space=2)) part_align=1,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
ebr_space=2,
)
)
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 (extended) ][ p2 ]##### # #####[ p1 (extended) ][ p2 ]#####
# ######[ p5 (logical) ] ##### # ######[ p5 (logical) ] #####
m, d = make_model_and_disk(size=100, ptable='dos') m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=10, size=40, flag='extended') make_partition(m, d, offset=10, size=40, flag="extended")
make_partition(m, d, offset=50, size=40) make_partition(m, d, offset=50, size=40)
p5 = make_partition(m, d, offset=12, size=38, flag='logical') p5 = make_partition(m, d, offset=12, size=38, flag="logical")
self.assertEqual( self.assertEqual(([], 0), gaps.movable_trailing_partitions_and_gap_size(p5))
([], 0),
gaps.movable_trailing_partitions_and_gap_size(p5))
def test_gap_in_extended_partition(self): def test_gap_in_extended_partition(self):
self.use_alignment_data(PartitionAlignmentData( self.use_alignment_data(
part_align=1, min_gap_size=1, min_start_offset=10, PartitionAlignmentData(
min_end_offset=10, primary_part_limit=10, ebr_space=2)) part_align=1,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
ebr_space=2,
)
)
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 (extended) ] ##### # #####[ p1 (extended) ] #####
# ######[ p5 (logical)] ##### # ######[ p5 (logical)] #####
m, d = make_model_and_disk(size=100, ptable='dos') m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=10, size=40, flag='extended') make_partition(m, d, offset=10, size=40, flag="extended")
p5 = make_partition(m, d, offset=12, size=30, flag='logical') p5 = make_partition(m, d, offset=12, size=30, flag="logical")
self.assertEqual( self.assertEqual(([], 6), gaps.movable_trailing_partitions_and_gap_size(p5))
([], 6),
gaps.movable_trailing_partitions_and_gap_size(p5))
def test_trailing_logical_partition_then_gap(self): def test_trailing_logical_partition_then_gap(self):
self.use_alignment_data(PartitionAlignmentData( self.use_alignment_data(
part_align=1, min_gap_size=1, min_start_offset=10, PartitionAlignmentData(
min_end_offset=10, primary_part_limit=10, ebr_space=2)) part_align=1,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
ebr_space=2,
)
)
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 (extended) ]##### # #####[ p1 (extended) ]#####
# ######[ p5 (logical)] [ p6 (logical)] ##### # ######[ p5 (logical)] [ p6 (logical)] #####
m, d = make_model_and_disk(size=100, ptable='dos') m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=10, size=80, flag='extended') make_partition(m, d, offset=10, size=80, flag="extended")
p5 = make_partition(m, d, offset=12, size=30, flag='logical') p5 = make_partition(m, d, offset=12, size=30, flag="logical")
p6 = make_partition(m, d, offset=44, size=30, flag='logical') p6 = make_partition(m, d, offset=44, size=30, flag="logical")
self.assertEqual( self.assertEqual(([p6], 14), gaps.movable_trailing_partitions_and_gap_size(p5))
([p6], 14),
gaps.movable_trailing_partitions_and_gap_size(p5))
def test_trailing_logical_partition_then_no_gap(self): def test_trailing_logical_partition_then_no_gap(self):
self.use_alignment_data(PartitionAlignmentData( self.use_alignment_data(
part_align=1, min_gap_size=1, min_start_offset=10, PartitionAlignmentData(
min_end_offset=10, primary_part_limit=10, ebr_space=2)) part_align=1,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
ebr_space=2,
)
)
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 (extended) ]##### # #####[ p1 (extended) ]#####
# ######[ p5 (logical)] [ p6 (logical) ] ##### # ######[ p5 (logical)] [ p6 (logical) ] #####
m, d = make_model_and_disk(size=100, ptable='dos') m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=10, size=80, flag='extended') make_partition(m, d, offset=10, size=80, flag="extended")
p5 = make_partition(m, d, offset=12, size=30, flag='logical') p5 = make_partition(m, d, offset=12, size=30, flag="logical")
p6 = make_partition(m, d, offset=44, size=44, flag='logical') p6 = make_partition(m, d, offset=44, size=44, flag="logical")
self.assertEqual( self.assertEqual(([p6], 0), gaps.movable_trailing_partitions_and_gap_size(p5))
([p6], 0),
gaps.movable_trailing_partitions_and_gap_size(p5))
def test_trailing_preserved_partition(self): def test_trailing_preserved_partition(self):
self.use_alignment_data(PartitionAlignmentData( self.use_alignment_data(
part_align=10, min_gap_size=1, min_start_offset=10, PartitionAlignmentData(
min_end_offset=10, primary_part_limit=10)) part_align=10,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
)
)
# 0----10---20---30---40---50---60---70---80---90---100 # 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 ][ p2 p ] ##### # #####[ p1 ][ p2 p ] #####
m, d = make_model_and_disk(size=100) m, d = make_model_and_disk(size=100)
p1 = make_partition(m, d, offset=10, size=40) p1 = make_partition(m, d, offset=10, size=40)
make_partition(m, d, offset=50, size=20, preserve=True) make_partition(m, d, offset=50, size=20, preserve=True)
self.assertEqual( self.assertEqual(([], 0), gaps.movable_trailing_partitions_and_gap_size(p1))
([], 0),
gaps.movable_trailing_partitions_and_gap_size(p1))
def test_trailing_lvm_logical_volume_no_gap(self): def test_trailing_lvm_logical_volume_no_gap(self):
model, lv = make_model_and_lv() model, lv = make_model_and_lv()
self.assertEqual( self.assertEqual(([], 0), gaps.movable_trailing_partitions_and_gap_size(lv))
([], 0),
gaps.movable_trailing_partitions_and_gap_size(lv))
def test_trailing_lvm_logical_volume_with_gap(self): def test_trailing_lvm_logical_volume_with_gap(self):
model, disk = make_model_and_disk( model, disk = make_model_and_disk(size=100 * LVM_CHUNK_SIZE + LVM_OVERHEAD)
size=100 * LVM_CHUNK_SIZE + LVM_OVERHEAD)
vg = make_vg(model, pvs={disk}) vg = make_vg(model, pvs={disk})
lv = make_lv(model, vg=vg, size=40 * LVM_CHUNK_SIZE) lv = make_lv(model, vg=vg, size=40 * LVM_CHUNK_SIZE)
self.assertEqual( self.assertEqual(
([], LVM_CHUNK_SIZE * 60), ([], LVM_CHUNK_SIZE * 60), gaps.movable_trailing_partitions_and_gap_size(lv)
gaps.movable_trailing_partitions_and_gap_size(lv)) )
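The expected free space above is chunk arithmetic, assuming only the constants already imported in this file: once LVM_OVERHEAD is set aside, the disk provides 100 LVM chunks, and the logical volume consumes 40 of them.

total_chunks, lv_chunks = 100, 40
free_chunks = total_chunks - lv_chunks
assert free_chunks == 60    # matches the LVM_CHUNK_SIZE * 60 in the assertion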
class TestLargestGaps(unittest.TestCase): class TestLargestGaps(unittest.TestCase):
@ -667,6 +779,7 @@ class TestLargestGaps(unittest.TestCase):
class TestUsable(unittest.TestCase): class TestUsable(unittest.TestCase):
def test_strings(self): def test_strings(self):
self.assertEqual('YES', GapUsable.YES.name) self.assertEqual("YES", GapUsable.YES.name)
self.assertEqual('TOO_MANY_PRIMARY_PARTS', self.assertEqual(
GapUsable.TOO_MANY_PRIMARY_PARTS.name) "TOO_MANY_PRIMARY_PARTS", GapUsable.TOO_MANY_PRIMARY_PARTS.name
)


@ -16,22 +16,17 @@
import unittest import unittest
from subiquity.common.filesystem.labels import ( from subiquity.common.filesystem.labels import annotations, for_client, usage_labels
annotations,
for_client,
usage_labels,
)
from subiquity.models.tests.test_filesystem import ( from subiquity.models.tests.test_filesystem import (
make_model, make_model,
make_model_and_disk, make_model_and_disk,
make_model_and_partition, make_model_and_partition,
make_model_and_raid, make_model_and_raid,
make_partition, make_partition,
) )
class TestAnnotations(unittest.TestCase): class TestAnnotations(unittest.TestCase):
def test_disk_annotations(self): def test_disk_annotations(self):
# disks never have annotations # disks never have annotations
model, disk = make_model_and_disk() model, disk = make_model_and_disk()
@ -42,93 +37,88 @@ class TestAnnotations(unittest.TestCase):
def test_partition_annotations(self): def test_partition_annotations(self):
model = make_model() model = make_model()
part = make_partition(model) part = make_partition(model)
self.assertEqual(annotations(part), ['new']) self.assertEqual(annotations(part), ["new"])
part.preserve = True part.preserve = True
self.assertEqual(annotations(part), ['existing']) self.assertEqual(annotations(part), ["existing"])
model = make_model() model = make_model()
part = make_partition(model, flag="bios_grub") part = make_partition(model, flag="bios_grub")
self.assertEqual( self.assertEqual(annotations(part), ["new", "BIOS grub spacer"])
annotations(part), ['new', 'BIOS grub spacer'])
part.preserve = True part.preserve = True
self.assertEqual( self.assertEqual(
annotations(part), annotations(part), ["existing", "unconfigured", "BIOS grub spacer"]
['existing', 'unconfigured', 'BIOS grub spacer']) )
part.device.grub_device = True part.device.grub_device = True
self.assertEqual( self.assertEqual(
annotations(part), annotations(part), ["existing", "configured", "BIOS grub spacer"]
['existing', 'configured', 'BIOS grub spacer']) )
model = make_model() model = make_model()
part = make_partition(model, flag="boot", grub_device=True) part = make_partition(model, flag="boot", grub_device=True)
self.assertEqual(annotations(part), ['new', 'backup ESP']) self.assertEqual(annotations(part), ["new", "backup ESP"])
fs = model.add_filesystem(part, fstype="fat32") fs = model.add_filesystem(part, fstype="fat32")
model.add_mount(fs, "/boot/efi") model.add_mount(fs, "/boot/efi")
self.assertEqual(annotations(part), ['new', 'primary ESP']) self.assertEqual(annotations(part), ["new", "primary ESP"])
model = make_model() model = make_model()
part = make_partition(model, flag="boot", preserve=True) part = make_partition(model, flag="boot", preserve=True)
self.assertEqual(annotations(part), ['existing', 'unused ESP']) self.assertEqual(annotations(part), ["existing", "unused ESP"])
part.grub_device = True part.grub_device = True
self.assertEqual(annotations(part), ['existing', 'backup ESP']) self.assertEqual(annotations(part), ["existing", "backup ESP"])
fs = model.add_filesystem(part, fstype="fat32") fs = model.add_filesystem(part, fstype="fat32")
model.add_mount(fs, "/boot/efi") model.add_mount(fs, "/boot/efi")
self.assertEqual(annotations(part), ['existing', 'primary ESP']) self.assertEqual(annotations(part), ["existing", "primary ESP"])
model = make_model() model = make_model()
part = make_partition(model, flag="prep", grub_device=True) part = make_partition(model, flag="prep", grub_device=True)
self.assertEqual(annotations(part), ['new', 'PReP']) self.assertEqual(annotations(part), ["new", "PReP"])
model = make_model() model = make_model()
part = make_partition(model, flag="prep", preserve=True) part = make_partition(model, flag="prep", preserve=True)
self.assertEqual( self.assertEqual(annotations(part), ["existing", "PReP", "unconfigured"])
annotations(part), ['existing', 'PReP', 'unconfigured'])
part.grub_device = True part.grub_device = True
self.assertEqual( self.assertEqual(annotations(part), ["existing", "PReP", "configured"])
annotations(part), ['existing', 'PReP', 'configured'])
def test_vg_default_annotations(self): def test_vg_default_annotations(self):
model, disk = make_model_and_disk() model, disk = make_model_and_disk()
vg = model.add_volgroup('vg-0', {disk}) vg = model.add_volgroup("vg-0", {disk})
self.assertEqual(annotations(vg), ['new']) self.assertEqual(annotations(vg), ["new"])
vg.preserve = True vg.preserve = True
self.assertEqual(annotations(vg), ['existing']) self.assertEqual(annotations(vg), ["existing"])
def test_vg_encrypted_annotations(self): def test_vg_encrypted_annotations(self):
model, disk = make_model_and_disk() model, disk = make_model_and_disk()
dm_crypt = model.add_dm_crypt(disk, key='passw0rd') dm_crypt = model.add_dm_crypt(disk, key="passw0rd")
vg = model.add_volgroup('vg-0', {dm_crypt}) vg = model.add_volgroup("vg-0", {dm_crypt})
self.assertEqual(annotations(vg), ['new', 'encrypted']) self.assertEqual(annotations(vg), ["new", "encrypted"])
class TestUsageLabels(unittest.TestCase): class TestUsageLabels(unittest.TestCase):
def test_partition_usage_labels(self): def test_partition_usage_labels(self):
model, partition = make_model_and_partition() model, partition = make_model_and_partition()
self.assertEqual(usage_labels(partition), ["unused"]) self.assertEqual(usage_labels(partition), ["unused"])
fs = model.add_filesystem(partition, 'ext4') fs = model.add_filesystem(partition, "ext4")
self.assertEqual( self.assertEqual(
usage_labels(partition), usage_labels(partition), ["to be formatted as ext4", "not mounted"]
["to be formatted as ext4", "not mounted"]) )
model._orig_config = model._render_actions() model._orig_config = model._render_actions()
fs.preserve = True fs.preserve = True
partition.preserve = True partition.preserve = True
self.assertEqual( self.assertEqual(
usage_labels(partition), usage_labels(partition), ["already formatted as ext4", "not mounted"]
["already formatted as ext4", "not mounted"]) )
model.remove_filesystem(fs) model.remove_filesystem(fs)
fs2 = model.add_filesystem(partition, 'ext4') fs2 = model.add_filesystem(partition, "ext4")
self.assertEqual( self.assertEqual(
usage_labels(partition), usage_labels(partition), ["to be reformatted as ext4", "not mounted"]
["to be reformatted as ext4", "not mounted"]) )
model.add_mount(fs2, '/') model.add_mount(fs2, "/")
self.assertEqual( self.assertEqual(
usage_labels(partition), usage_labels(partition), ["to be reformatted as ext4", "mounted at /"]
["to be reformatted as ext4", "mounted at /"]) )
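The test above walks one partition through the full set of usage labels; condensed into a single mapping (the label lists are copied verbatim from the assertions, the state names on the left are descriptive only):

label_progression = {
    "no filesystem": ["unused"],
    "new fs, new partition": ["to be formatted as ext4", "not mounted"],
    "preserved fs, preserved partition": ["already formatted as ext4", "not mounted"],
    "new fs, preserved partition": ["to be reformatted as ext4", "not mounted"],
    "new fs, mounted": ["to be reformatted as ext4", "mounted at /"],
}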
class TestForClient(unittest.TestCase): class TestForClient(unittest.TestCase):
def test_for_client_raid_parts(self): def test_for_client_raid_parts(self):
model, raid = make_model_and_raid() model, raid = make_model_and_raid()
make_partition(model, raid) make_partition(model, raid)


@ -19,24 +19,17 @@ from unittest import mock
import attr import attr
from subiquitycore.tests.parameterized import parameterized
from subiquity.common.filesystem.actions import (
DeviceAction,
)
from subiquity.common.filesystem import boot, gaps, sizes from subiquity.common.filesystem import boot, gaps, sizes
from subiquity.common.filesystem.actions import DeviceAction
from subiquity.common.filesystem.manipulator import FilesystemManipulator from subiquity.common.filesystem.manipulator import FilesystemManipulator
from subiquity.models.filesystem import Bootloader, MiB, Partition
from subiquity.models.tests.test_filesystem import ( from subiquity.models.tests.test_filesystem import (
make_disk, make_disk,
make_filesystem,
make_model, make_model,
make_partition, make_partition,
make_filesystem, )
) from subiquitycore.tests.parameterized import parameterized
from subiquity.models.filesystem import (
Bootloader,
MiB,
Partition,
)
def make_manipulator(bootloader=None, storage_version=1): def make_manipulator(bootloader=None, storage_version=1):
@ -78,18 +71,16 @@ class Create:
class TestFilesystemManipulator(unittest.TestCase): class TestFilesystemManipulator(unittest.TestCase):
def test_delete_encrypted_vg(self): def test_delete_encrypted_vg(self):
manipulator, disk = make_manipulator_and_disk() manipulator, disk = make_manipulator_and_disk()
spec = { spec = {
'passphrase': 'passw0rd', "passphrase": "passw0rd",
'devices': {disk}, "devices": {disk},
'name': 'vg0', "name": "vg0",
} }
vg = manipulator.create_volgroup(spec) vg = manipulator.create_volgroup(spec)
manipulator.delete_volgroup(vg) manipulator.delete_volgroup(vg)
dm_crypts = [ dm_crypts = [a for a in manipulator.model._actions if a.type == "dm_crypt"]
a for a in manipulator.model._actions if a.type == 'dm_crypt']
self.assertEqual(dm_crypts, []) self.assertEqual(dm_crypts, [])
def test_wipe_existing_fs(self): def test_wipe_existing_fs(self):
@ -98,18 +89,19 @@ class TestFilesystemManipulator(unittest.TestCase):
# example: it's ext2, wipe and make it ext2 again # example: it's ext2, wipe and make it ext2 again
manipulator, d = make_manipulator_and_disk() manipulator, d = make_manipulator_and_disk()
p = make_partition(manipulator.model, d, preserve=True) p = make_partition(manipulator.model, d, preserve=True)
fs = make_filesystem(manipulator.model, partition=p, fstype='ext2', fs = make_filesystem(
preserve=True) manipulator.model, partition=p, fstype="ext2", preserve=True
)
manipulator.model._actions.append(fs) manipulator.model._actions.append(fs)
manipulator.delete_filesystem(fs) manipulator.delete_filesystem(fs)
spec = {'wipe': 'random'} spec = {"wipe": "random"}
with mock.patch.object(p, 'original_fstype') as original_fstype: with mock.patch.object(p, "original_fstype") as original_fstype:
original_fstype.return_value = 'ext2' original_fstype.return_value = "ext2"
fs = manipulator.create_filesystem(p, spec) fs = manipulator.create_filesystem(p, spec)
self.assertTrue(p.preserve) self.assertTrue(p.preserve)
self.assertEqual('random', p.wipe) self.assertEqual("random", p.wipe)
self.assertFalse(fs.preserve) self.assertFalse(fs.preserve)
self.assertEqual('ext2', fs.fstype) self.assertEqual("ext2", fs.fstype)
def test_can_only_add_boot_once(self): def test_can_only_add_boot_once(self):
# This is really testing model code but it's much easier to test with a # This is really testing model code but it's much easier to test with a
@ -122,7 +114,8 @@ class TestFilesystemManipulator(unittest.TestCase):
self.assertFalse( self.assertFalse(
DeviceAction.TOGGLE_BOOT.can(disk)[0], DeviceAction.TOGGLE_BOOT.can(disk)[0],
"add_boot_disk(disk) did not make _can_TOGGLE_BOOT false " "add_boot_disk(disk) did not make _can_TOGGLE_BOOT false "
"with bootloader {}".format(bl)) "with bootloader {}".format(bl),
)
def assertIsMountedAtBootEFI(self, device): def assertIsMountedAtBootEFI(self, device):
efi_mnts = device._m._all(type="mount", path="/boot/efi") efi_mnts = device._m._all(type="mount", path="/boot/efi")
@ -136,20 +129,23 @@ class TestFilesystemManipulator(unittest.TestCase):
def add_existing_boot_partition(self, manipulator, disk): def add_existing_boot_partition(self, manipulator, disk):
if manipulator.model.bootloader == Bootloader.BIOS: if manipulator.model.bootloader == Bootloader.BIOS:
part = manipulator.model.add_partition( part = manipulator.model.add_partition(
disk, size=1 << 20, offset=0, flag="bios_grub") disk, size=1 << 20, offset=0, flag="bios_grub"
)
elif manipulator.model.bootloader == Bootloader.UEFI: elif manipulator.model.bootloader == Bootloader.UEFI:
part = manipulator.model.add_partition( part = manipulator.model.add_partition(
disk, size=512 << 20, offset=0, flag="boot") disk, size=512 << 20, offset=0, flag="boot"
)
elif manipulator.model.bootloader == Bootloader.PREP: elif manipulator.model.bootloader == Bootloader.PREP:
part = manipulator.model.add_partition( part = manipulator.model.add_partition(
disk, size=8 << 20, offset=0, flag="prep") disk, size=8 << 20, offset=0, flag="prep"
)
part.preserve = True part.preserve = True
return part return part
def assertIsBootDisk(self, manipulator, disk): def assertIsBootDisk(self, manipulator, disk):
if manipulator.model.bootloader == Bootloader.BIOS: if manipulator.model.bootloader == Bootloader.BIOS:
self.assertTrue(disk.grub_device) self.assertTrue(disk.grub_device)
if disk.ptable == 'gpt': if disk.ptable == "gpt":
self.assertEqual(disk.partitions()[0].flag, "bios_grub") self.assertEqual(disk.partitions()[0].flag, "bios_grub")
elif manipulator.model.bootloader == Bootloader.UEFI: elif manipulator.model.bootloader == Bootloader.UEFI:
for part in disk.partitions(): for part in disk.partitions():
@ -159,7 +155,7 @@ class TestFilesystemManipulator(unittest.TestCase):
elif manipulator.model.bootloader == Bootloader.PREP: elif manipulator.model.bootloader == Bootloader.PREP:
for part in disk.partitions(): for part in disk.partitions():
if part.flag == "prep" and part.grub_device: if part.flag == "prep" and part.grub_device:
self.assertEqual(part.wipe, 'zero') self.assertEqual(part.wipe, "zero")
return return
self.fail("{} is not a boot disk".format(disk)) self.fail("{} is not a boot disk".format(disk))
@ -186,7 +182,8 @@ class TestFilesystemManipulator(unittest.TestCase):
disk2 = make_disk(manipulator.model, preserve=False) disk2 = make_disk(manipulator.model, preserve=False)
gap = gaps.largest_gap(disk2) gap = gaps.largest_gap(disk2)
disk2p1 = manipulator.model.add_partition( disk2p1 = manipulator.model.add_partition(
disk2, size=gap.size, offset=gap.offset) disk2, size=gap.size, offset=gap.offset
)
manipulator.add_boot_disk(disk1) manipulator.add_boot_disk(disk1)
self.assertIsBootDisk(manipulator, disk1) self.assertIsBootDisk(manipulator, disk1)
@ -202,8 +199,7 @@ class TestFilesystemManipulator(unittest.TestCase):
self.assertNotMounted(disk2.partitions()[0]) self.assertNotMounted(disk2.partitions()[0])
self.assertEqual(len(disk2.partitions()), 2) self.assertEqual(len(disk2.partitions()), 2)
self.assertEqual(disk2.partitions()[1], disk2p1) self.assertEqual(disk2.partitions()[1], disk2p1)
self.assertEqual( self.assertEqual(disk2.partitions()[0].size + disk2p1.size, size_before)
disk2.partitions()[0].size + disk2p1.size, size_before)
manipulator.remove_boot_disk(disk1) manipulator.remove_boot_disk(disk1)
self.assertIsNotBootDisk(manipulator, disk1) self.assertIsNotBootDisk(manipulator, disk1)
@ -228,7 +224,8 @@ class TestFilesystemManipulator(unittest.TestCase):
disk2 = make_disk(manipulator.model, preserve=False) disk2 = make_disk(manipulator.model, preserve=False)
gap = gaps.largest_gap(disk2) gap = gaps.largest_gap(disk2)
disk2p1 = manipulator.model.add_partition( disk2p1 = manipulator.model.add_partition(
disk2, size=gap.size, offset=gap.offset) disk2, size=gap.size, offset=gap.offset
)
manipulator.add_boot_disk(disk1) manipulator.add_boot_disk(disk1)
self.assertIsBootDisk(manipulator, disk1) self.assertIsBootDisk(manipulator, disk1)
@ -243,8 +240,7 @@ class TestFilesystemManipulator(unittest.TestCase):
self.assertIsMountedAtBootEFI(disk2.partitions()[0]) self.assertIsMountedAtBootEFI(disk2.partitions()[0])
self.assertEqual(len(disk2.partitions()), 2) self.assertEqual(len(disk2.partitions()), 2)
self.assertEqual(disk2.partitions()[1], disk2p1) self.assertEqual(disk2.partitions()[1], disk2p1)
self.assertEqual( self.assertEqual(disk2.partitions()[0].size + disk2p1.size, size_before)
disk2.partitions()[0].size + disk2p1.size, size_before)
def test_boot_disk_existing(self): def test_boot_disk_existing(self):
for bl in Bootloader: for bl in Bootloader:
@ -272,13 +268,16 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(Bootloader.UEFI) manipulator = make_manipulator(Bootloader.UEFI)
disk1 = make_disk(manipulator.model, preserve=True) disk1 = make_disk(manipulator.model, preserve=True)
disk1p1 = manipulator.model.add_partition( disk1p1 = manipulator.model.add_partition(
disk1, size=512 << 20, offset=0, flag="boot") disk1, size=512 << 20, offset=0, flag="boot"
)
disk1p1.preserve = True disk1p1.preserve = True
disk1p2 = manipulator.model.add_partition( disk1p2 = manipulator.model.add_partition(
disk1, size=8192 << 20, offset=513 << 20) disk1, size=8192 << 20, offset=513 << 20
)
disk1p2.preserve = True disk1p2.preserve = True
manipulator.partition_disk_handler( manipulator.partition_disk_handler(
disk1, {'fstype': 'ext4', 'mount': '/'}, partition=disk1p2) disk1, {"fstype": "ext4", "mount": "/"}, partition=disk1p2
)
efi_mnt = manipulator.model._mount_for_path("/boot/efi") efi_mnt = manipulator.model._mount_for_path("/boot/efi")
self.assertEqual(efi_mnt.device.volume, disk1p1) self.assertEqual(efi_mnt.device.volume, disk1p1)
@ -295,14 +294,15 @@ class TestFilesystemManipulator(unittest.TestCase):
@contextlib.contextmanager @contextlib.contextmanager
def assertPartitionOperations(self, disk, *ops): def assertPartitionOperations(self, disk, *ops):
existing_parts = set(disk.partitions()) existing_parts = set(disk.partitions())
part_details = {} part_details = {}
for op in ops: for op in ops:
if isinstance(op, MoveResize): if isinstance(op, MoveResize):
part_details[op.part] = ( part_details[op.part] = (
op.part.offset + op.offset, op.part.size + op.size) op.part.offset + op.offset,
op.part.size + op.size,
)
try: try:
yield yield
@ -314,8 +314,8 @@ class TestFilesystemManipulator(unittest.TestCase):
if isinstance(op, MoveResize): if isinstance(op, MoveResize):
new_parts.remove(op.part) new_parts.remove(op.part)
self.assertEqual( self.assertEqual(
part_details[op.part], part_details[op.part], (op.part.offset, op.part.size)
(op.part.offset, op.part.size)) )
else: else:
for part in created_parts: for part in created_parts:
lhs = (part.offset, part.size) lhs = (part.offset, part.size)
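The assertPartitionOperations helper above follows a common contextmanager testing pattern: snapshot the interesting state before yielding to the with-block, then assert on the resulting delta afterwards. A minimal, self-contained sketch of the same pattern (the assertGrowsBy name and the list-length check are invented for illustration and are not part of subiquity):

import contextlib
import unittest


class ContextAssertionSketch(unittest.TestCase):
    @contextlib.contextmanager
    def assertGrowsBy(self, items, expected_delta):
        # Snapshot state on entry, run the with-block body, then check the delta.
        before = len(items)
        yield
        self.assertEqual(len(items) - before, expected_delta)

    def test_usage(self):
        parts = []
        with self.assertGrowsBy(parts, 2):
            parts.append("bios_grub")
            parts.append("root")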
@ -339,7 +339,8 @@ class TestFilesystemManipulator(unittest.TestCase):
disk, disk,
Create( Create(
offset=disk.alignment_data().min_start_offset, offset=disk.alignment_data().min_start_offset,
size=sizes.BIOS_GRUB_SIZE_BYTES), size=sizes.BIOS_GRUB_SIZE_BYTES,
),
): ):
manipulator.add_boot_disk(disk) manipulator.add_boot_disk(disk)
@ -349,18 +350,19 @@ class TestFilesystemManipulator(unittest.TestCase):
def test_add_boot_BIOS_full(self, version): def test_add_boot_BIOS_full(self, version):
manipulator = make_manipulator(Bootloader.BIOS, version) manipulator = make_manipulator(Bootloader.BIOS, version)
disk = make_disk(manipulator.model, preserve=True) disk = make_disk(manipulator.model, preserve=True)
part = make_partition( part = make_partition(manipulator.model, disk, size=gaps.largest_gap_size(disk))
manipulator.model, disk, size=gaps.largest_gap_size(disk))
with self.assertPartitionOperations( with self.assertPartitionOperations(
disk, disk,
Create( Create(
offset=disk.alignment_data().min_start_offset, offset=disk.alignment_data().min_start_offset,
size=sizes.BIOS_GRUB_SIZE_BYTES), size=sizes.BIOS_GRUB_SIZE_BYTES,
),
MoveResize( MoveResize(
part=part, part=part,
offset=sizes.BIOS_GRUB_SIZE_BYTES, offset=sizes.BIOS_GRUB_SIZE_BYTES,
size=-sizes.BIOS_GRUB_SIZE_BYTES), size=-sizes.BIOS_GRUB_SIZE_BYTES,
),
): ):
manipulator.add_boot_disk(disk) manipulator.add_boot_disk(disk)
@ -371,14 +373,15 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(Bootloader.BIOS, version) manipulator = make_manipulator(Bootloader.BIOS, version)
disk = make_disk(manipulator.model, preserve=True) disk = make_disk(manipulator.model, preserve=True)
part = make_partition( part = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk)//2) manipulator.model, disk, size=gaps.largest_gap_size(disk) // 2
)
with self.assertPartitionOperations( with self.assertPartitionOperations(
disk, disk,
Create( Create(
offset=disk.alignment_data().min_start_offset, offset=disk.alignment_data().min_start_offset,
size=sizes.BIOS_GRUB_SIZE_BYTES), size=sizes.BIOS_GRUB_SIZE_BYTES,
Move( ),
part=part, offset=sizes.BIOS_GRUB_SIZE_BYTES), Move(part=part, offset=sizes.BIOS_GRUB_SIZE_BYTES),
): ):
manipulator.add_boot_disk(disk) manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk) self.assertIsBootDisk(manipulator, disk)
@ -388,22 +391,25 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(Bootloader.BIOS) manipulator = make_manipulator(Bootloader.BIOS)
# 2002MiB so that the space available for partitioning (2000MiB) # 2002MiB so that the space available for partitioning (2000MiB)
# divided by 4 is a whole number of megabytes. # divided by 4 is a whole number of megabytes.
disk = make_disk(manipulator.model, preserve=True, size=2002*MiB) disk = make_disk(manipulator.model, preserve=True, size=2002 * MiB)
part_smaller = make_partition( part_smaller = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk)//4) manipulator.model, disk, size=gaps.largest_gap_size(disk) // 4
)
part_larger = make_partition( part_larger = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk)) manipulator.model, disk, size=gaps.largest_gap_size(disk)
)
with self.assertPartitionOperations( with self.assertPartitionOperations(
disk, disk,
Create( Create(
offset=disk.alignment_data().min_start_offset, offset=disk.alignment_data().min_start_offset,
size=sizes.BIOS_GRUB_SIZE_BYTES), size=sizes.BIOS_GRUB_SIZE_BYTES,
Move( ),
part=part_smaller, offset=sizes.BIOS_GRUB_SIZE_BYTES), Move(part=part_smaller, offset=sizes.BIOS_GRUB_SIZE_BYTES),
MoveResize( MoveResize(
part=part_larger, part=part_larger,
offset=sizes.BIOS_GRUB_SIZE_BYTES, offset=sizes.BIOS_GRUB_SIZE_BYTES,
size=-sizes.BIOS_GRUB_SIZE_BYTES), size=-sizes.BIOS_GRUB_SIZE_BYTES,
),
): ):
manipulator.add_boot_disk(disk) manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk) self.assertIsBootDisk(manipulator, disk)
@ -413,31 +419,31 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(Bootloader.BIOS, version) manipulator = make_manipulator(Bootloader.BIOS, version)
disk = make_disk(manipulator.model, preserve=True) disk = make_disk(manipulator.model, preserve=True)
full_size = gaps.largest_gap_size(disk) full_size = gaps.largest_gap_size(disk)
make_partition( make_partition(manipulator.model, disk, size=full_size, preserve=True)
manipulator.model, disk, size=full_size, preserve=True)
self.assertFalse(boot.can_be_boot_device(disk)) self.assertFalse(boot.can_be_boot_device(disk))
def test_no_add_boot_BIOS_preserved_v1(self): def test_no_add_boot_BIOS_preserved_v1(self):
manipulator = make_manipulator(Bootloader.BIOS, 1) manipulator = make_manipulator(Bootloader.BIOS, 1)
disk = make_disk(manipulator.model, preserve=True) disk = make_disk(manipulator.model, preserve=True)
half_size = gaps.largest_gap_size(disk)//2 half_size = gaps.largest_gap_size(disk) // 2
make_partition( make_partition(
manipulator.model, disk, size=half_size, offset=half_size, manipulator.model, disk, size=half_size, offset=half_size, preserve=True
preserve=True) )
self.assertFalse(boot.can_be_boot_device(disk)) self.assertFalse(boot.can_be_boot_device(disk))
def test_add_boot_BIOS_preserved_v2(self): def test_add_boot_BIOS_preserved_v2(self):
manipulator = make_manipulator(Bootloader.BIOS, 2) manipulator = make_manipulator(Bootloader.BIOS, 2)
disk = make_disk(manipulator.model, preserve=True) disk = make_disk(manipulator.model, preserve=True)
half_size = gaps.largest_gap_size(disk)//2 half_size = gaps.largest_gap_size(disk) // 2
part = make_partition( part = make_partition(
manipulator.model, disk, size=half_size, offset=half_size, manipulator.model, disk, size=half_size, offset=half_size, preserve=True
preserve=True) )
with self.assertPartitionOperations( with self.assertPartitionOperations(
disk, disk,
Create( Create(
offset=disk.alignment_data().min_start_offset, offset=disk.alignment_data().min_start_offset,
size=sizes.BIOS_GRUB_SIZE_BYTES), size=sizes.BIOS_GRUB_SIZE_BYTES,
),
Unchanged(part=part), Unchanged(part=part),
): ):
manipulator.add_boot_disk(disk) manipulator.add_boot_disk(disk)
@ -447,21 +453,23 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(Bootloader.BIOS, 2) manipulator = make_manipulator(Bootloader.BIOS, 2)
# 2002MiB so that the space available for partitioning (2000MiB) # 2002MiB so that the space available for partitioning (2000MiB)
# divided by 4 is a whole number of megabytes. # divided by 4 is a whole number of megabytes.
disk = make_disk(manipulator.model, preserve=True, size=2002*MiB) disk = make_disk(manipulator.model, preserve=True, size=2002 * MiB)
avail = gaps.largest_gap_size(disk) avail = gaps.largest_gap_size(disk)
p1_new = make_partition( p1_new = make_partition(manipulator.model, disk, size=avail // 4)
manipulator.model, disk, size=avail//4)
p2_preserved = make_partition( p2_preserved = make_partition(
manipulator.model, disk, size=3*avail//4, preserve=True) manipulator.model, disk, size=3 * avail // 4, preserve=True
)
with self.assertPartitionOperations( with self.assertPartitionOperations(
disk, disk,
Create( Create(
offset=disk.alignment_data().min_start_offset, offset=disk.alignment_data().min_start_offset,
size=sizes.BIOS_GRUB_SIZE_BYTES), size=sizes.BIOS_GRUB_SIZE_BYTES,
),
MoveResize( MoveResize(
part=p1_new, part=p1_new,
offset=sizes.BIOS_GRUB_SIZE_BYTES, offset=sizes.BIOS_GRUB_SIZE_BYTES,
size=-sizes.BIOS_GRUB_SIZE_BYTES), size=-sizes.BIOS_GRUB_SIZE_BYTES,
),
Unchanged(part=p2_preserved), Unchanged(part=p2_preserved),
): ):
manipulator.add_boot_disk(disk) manipulator.add_boot_disk(disk)
@ -471,10 +479,8 @@ class TestFilesystemManipulator(unittest.TestCase):
def test_no_add_boot_BIOS_preserved_at_start(self, version): def test_no_add_boot_BIOS_preserved_at_start(self, version):
manipulator = make_manipulator(Bootloader.BIOS, version) manipulator = make_manipulator(Bootloader.BIOS, version)
disk = make_disk(manipulator.model, preserve=True) disk = make_disk(manipulator.model, preserve=True)
half_size = gaps.largest_gap_size(disk)//2 half_size = gaps.largest_gap_size(disk) // 2
make_partition( make_partition(manipulator.model, disk, size=half_size, preserve=True)
manipulator.model, disk, size=half_size,
preserve=True)
self.assertFalse(boot.can_be_boot_device(disk)) self.assertFalse(boot.can_be_boot_device(disk))
@parameterized.expand([(1,), (2,)]) @parameterized.expand([(1,), (2,)])
@ -482,16 +488,16 @@ class TestFilesystemManipulator(unittest.TestCase):
# Showed up in VM testing - the ISO created by cloud-localds shows up as a # Showed up in VM testing - the ISO created by cloud-localds shows up as a
# potential install target, and weird things happen. # potential install target, and weird things happen.
manipulator = make_manipulator(Bootloader.BIOS, version) manipulator = make_manipulator(Bootloader.BIOS, version)
disk = make_disk(manipulator.model, ptable='gpt', size=1 << 20) disk = make_disk(manipulator.model, ptable="gpt", size=1 << 20)
self.assertFalse(boot.can_be_boot_device(disk)) self.assertFalse(boot.can_be_boot_device(disk))
@parameterized.expand([(1,), (2,)]) @parameterized.expand([(1,), (2,)])
def test_add_boot_BIOS_msdos(self, version): def test_add_boot_BIOS_msdos(self, version):
manipulator = make_manipulator(Bootloader.BIOS, version) manipulator = make_manipulator(Bootloader.BIOS, version)
disk = make_disk(manipulator.model, preserve=True, ptable='msdos') disk = make_disk(manipulator.model, preserve=True, ptable="msdos")
part = make_partition( part = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk), manipulator.model, disk, size=gaps.largest_gap_size(disk), preserve=True
preserve=True) )
with self.assertPartitionOperations( with self.assertPartitionOperations(
disk, disk,
Unchanged(part=part), Unchanged(part=part),
@ -512,7 +518,8 @@ class TestFilesystemManipulator(unittest.TestCase):
(Bootloader.UEFI, 1), (Bootloader.UEFI, 1),
(Bootloader.PREP, 1), (Bootloader.PREP, 1),
(Bootloader.UEFI, 2), (Bootloader.UEFI, 2),
(Bootloader.PREP, 2)] (Bootloader.PREP, 2),
]
@parameterized.expand(ADD_BOOT_PARAMS) @parameterized.expand(ADD_BOOT_PARAMS)
def test_add_boot_empty(self, bl, version): def test_add_boot_empty(self, bl, version):
@ -522,7 +529,8 @@ class TestFilesystemManipulator(unittest.TestCase):
disk, disk,
Create( Create(
offset=disk.alignment_data().min_start_offset, offset=disk.alignment_data().min_start_offset,
size=self.boot_size_for_disk(disk)), size=self.boot_size_for_disk(disk),
),
): ):
manipulator.add_boot_disk(disk) manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk) self.assertIsBootDisk(manipulator, disk)
@ -531,18 +539,12 @@ class TestFilesystemManipulator(unittest.TestCase):
def test_add_boot_full(self, bl, version): def test_add_boot_full(self, bl, version):
manipulator = make_manipulator(bl, version) manipulator = make_manipulator(bl, version)
disk = make_disk(manipulator.model, preserve=True) disk = make_disk(manipulator.model, preserve=True)
part = make_partition( part = make_partition(manipulator.model, disk, size=gaps.largest_gap_size(disk))
manipulator.model, disk, size=gaps.largest_gap_size(disk))
size = self.boot_size_for_disk(disk) size = self.boot_size_for_disk(disk)
with self.assertPartitionOperations( with self.assertPartitionOperations(
disk, disk,
Create( Create(offset=disk.alignment_data().min_start_offset, size=size),
offset=disk.alignment_data().min_start_offset, MoveResize(part=part, offset=size, size=-size),
size=size),
MoveResize(
part=part,
offset=size,
size=-size),
): ):
manipulator.add_boot_disk(disk) manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk) self.assertIsBootDisk(manipulator, disk)
@ -552,14 +554,13 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(bl, version) manipulator = make_manipulator(bl, version)
disk = make_disk(manipulator.model, preserve=True) disk = make_disk(manipulator.model, preserve=True)
part = make_partition( part = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk)//2) manipulator.model, disk, size=gaps.largest_gap_size(disk) // 2
)
size = self.boot_size_for_disk(disk) size = self.boot_size_for_disk(disk)
with self.assertPartitionOperations( with self.assertPartitionOperations(
disk, disk,
Unchanged(part=part), Unchanged(part=part),
Create( Create(offset=part.offset + part.size, size=size),
offset=part.offset + part.size,
size=size),
): ):
manipulator.add_boot_disk(disk) manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk) self.assertIsBootDisk(manipulator, disk)
@ -569,22 +570,19 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(bl, version) manipulator = make_manipulator(bl, version)
# 2002MiB so that the space available for partitioning (2000MiB) # 2002MiB so that the space available for partitioning (2000MiB)
# divided by 4 is a whole number of megabytes. # divided by 4 is a whole number of megabytes.
disk = make_disk(manipulator.model, preserve=True, size=2002*MiB) disk = make_disk(manipulator.model, preserve=True, size=2002 * MiB)
part_smaller = make_partition( part_smaller = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk)//4) manipulator.model, disk, size=gaps.largest_gap_size(disk) // 4
)
part_larger = make_partition( part_larger = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk)) manipulator.model, disk, size=gaps.largest_gap_size(disk)
)
size = self.boot_size_for_disk(disk) size = self.boot_size_for_disk(disk)
with self.assertPartitionOperations( with self.assertPartitionOperations(
disk, disk,
Unchanged(part_smaller), Unchanged(part_smaller),
Create( Create(offset=part_smaller.offset + part_smaller.size, size=size),
offset=part_smaller.offset + part_smaller.size, MoveResize(part=part_larger, offset=size, size=-size),
size=size),
MoveResize(
part=part_larger,
offset=size,
size=-size),
): ):
manipulator.add_boot_disk(disk) manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk) self.assertIsBootDisk(manipulator, disk)
@ -594,18 +592,15 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(bl, version) manipulator = make_manipulator(bl, version)
disk = make_disk(manipulator.model, preserve=True) disk = make_disk(manipulator.model, preserve=True)
full_size = gaps.largest_gap_size(disk) full_size = gaps.largest_gap_size(disk)
make_partition( make_partition(manipulator.model, disk, size=full_size, preserve=True)
manipulator.model, disk, size=full_size, preserve=True)
self.assertFalse(boot.can_be_boot_device(disk)) self.assertFalse(boot.can_be_boot_device(disk))
@parameterized.expand(ADD_BOOT_PARAMS) @parameterized.expand(ADD_BOOT_PARAMS)
def test_add_boot_half_preserved_partition(self, bl, version): def test_add_boot_half_preserved_partition(self, bl, version):
manipulator = make_manipulator(bl, version) manipulator = make_manipulator(bl, version)
disk = make_disk(manipulator.model, preserve=True) disk = make_disk(manipulator.model, preserve=True)
half_size = gaps.largest_gap_size(disk)//2 half_size = gaps.largest_gap_size(disk) // 2
part = make_partition( part = make_partition(manipulator.model, disk, size=half_size, preserve=True)
manipulator.model, disk, size=half_size,
preserve=True)
size = self.boot_size_for_disk(disk) size = self.boot_size_for_disk(disk)
if version == 1: if version == 1:
self.assertFalse(boot.can_be_boot_device(disk)) self.assertFalse(boot.can_be_boot_device(disk))
@ -613,9 +608,7 @@ class TestFilesystemManipulator(unittest.TestCase):
with self.assertPartitionOperations( with self.assertPartitionOperations(
disk, disk,
Unchanged(part), Unchanged(part),
Create( Create(offset=part.offset + part.size, size=size),
offset=part.offset + part.size,
size=size),
): ):
manipulator.add_boot_disk(disk) manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk) self.assertIsBootDisk(manipulator, disk)
@ -626,23 +619,16 @@ class TestFilesystemManipulator(unittest.TestCase):
# situation in a v2 world. # situation in a v2 world.
manipulator = make_manipulator(bl, 2) manipulator = make_manipulator(bl, 2)
disk = make_disk(manipulator.model, preserve=True) disk = make_disk(manipulator.model, preserve=True)
half_size = gaps.largest_gap_size(disk)//2 half_size = gaps.largest_gap_size(disk) // 2
old_part = make_partition( old_part = make_partition(manipulator.model, disk, size=half_size)
manipulator.model, disk, size=half_size) new_part = make_partition(manipulator.model, disk, size=half_size)
new_part = make_partition(
manipulator.model, disk, size=half_size)
old_part.preserve = True old_part.preserve = True
size = self.boot_size_for_disk(disk) size = self.boot_size_for_disk(disk)
with self.assertPartitionOperations( with self.assertPartitionOperations(
disk, disk,
Unchanged(old_part), Unchanged(old_part),
Create( Create(offset=new_part.offset, size=size),
offset=new_part.offset, MoveResize(part=new_part, offset=size, size=-size),
size=size),
MoveResize(
part=new_part,
offset=size,
size=-size),
): ):
manipulator.add_boot_disk(disk) manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk) self.assertIsBootDisk(manipulator, disk)
@ -650,14 +636,14 @@ class TestFilesystemManipulator(unittest.TestCase):
@parameterized.expand(ADD_BOOT_PARAMS) @parameterized.expand(ADD_BOOT_PARAMS)
def test_add_boot_existing_boot_partition(self, bl, version): def test_add_boot_existing_boot_partition(self, bl, version):
manipulator = make_manipulator(bl, version) manipulator = make_manipulator(bl, version)
disk = make_disk(manipulator.model, preserve=True, size=2002*MiB) disk = make_disk(manipulator.model, preserve=True, size=2002 * MiB)
if bl == Bootloader.UEFI: if bl == Bootloader.UEFI:
flag = 'boot' flag = "boot"
else: else:
flag = 'prep' flag = "prep"
boot_part = make_partition( boot_part = make_partition(
manipulator.model, disk, size=self.boot_size_for_disk(disk), manipulator.model, disk, size=self.boot_size_for_disk(disk), flag=flag
flag=flag) )
rest_part = make_partition(manipulator.model, disk) rest_part = make_partition(manipulator.model, disk)
boot_part.preserve = rest_part.preserve = True boot_part.preserve = rest_part.preserve = True
with self.assertPartitionOperations( with self.assertPartitionOperations(
@ -671,39 +657,39 @@ class TestFilesystemManipulator(unittest.TestCase):
def test_logical_no_esp_in_gap(self): def test_logical_no_esp_in_gap(self):
manipulator = make_manipulator(Bootloader.UEFI, storage_version=2) manipulator = make_manipulator(Bootloader.UEFI, storage_version=2)
m = manipulator.model m = manipulator.model
d1 = make_disk(m, preserve=True, ptable='msdos') d1 = make_disk(m, preserve=True, ptable="msdos")
size = gaps.largest_gap_size(d1) size = gaps.largest_gap_size(d1)
make_partition(m, d1, flag='extended', preserve=True, size=size) make_partition(m, d1, flag="extended", preserve=True, size=size)
size = gaps.largest_gap_size(d1) // 2 size = gaps.largest_gap_size(d1) // 2
make_partition(m, d1, flag='logical', preserve=True, size=size) make_partition(m, d1, flag="logical", preserve=True, size=size)
self.assertTrue(gaps.largest_gap_size(d1) > sizes.uefi_scale.maximum) self.assertTrue(gaps.largest_gap_size(d1) > sizes.uefi_scale.maximum)
self.assertFalse(boot.can_be_boot_device(d1)) self.assertFalse(boot.can_be_boot_device(d1))
def test_logical_no_esp_by_resize(self): def test_logical_no_esp_by_resize(self):
manipulator = make_manipulator(Bootloader.UEFI, storage_version=2) manipulator = make_manipulator(Bootloader.UEFI, storage_version=2)
m = manipulator.model m = manipulator.model
d1 = make_disk(m, preserve=True, ptable='msdos') d1 = make_disk(m, preserve=True, ptable="msdos")
size = gaps.largest_gap_size(d1) size = gaps.largest_gap_size(d1)
make_partition(m, d1, flag='extended', preserve=True, size=size) make_partition(m, d1, flag="extended", preserve=True, size=size)
size = gaps.largest_gap_size(d1) size = gaps.largest_gap_size(d1)
p5 = make_partition(m, d1, flag='logical', preserve=True, size=size) p5 = make_partition(m, d1, flag="logical", preserve=True, size=size)
self.assertFalse(boot.can_be_boot_device(d1, resize_partition=p5)) self.assertFalse(boot.can_be_boot_device(d1, resize_partition=p5))
def test_logical_no_esp_in_new_part(self): def test_logical_no_esp_in_new_part(self):
manipulator = make_manipulator(Bootloader.UEFI, storage_version=2) manipulator = make_manipulator(Bootloader.UEFI, storage_version=2)
m = manipulator.model m = manipulator.model
d1 = make_disk(m, preserve=True, ptable='msdos') d1 = make_disk(m, preserve=True, ptable="msdos")
size = gaps.largest_gap_size(d1) size = gaps.largest_gap_size(d1)
make_partition(m, d1, flag='extended', size=size) make_partition(m, d1, flag="extended", size=size)
size = gaps.largest_gap_size(d1) size = gaps.largest_gap_size(d1)
make_partition(m, d1, flag='logical', size=size) make_partition(m, d1, flag="logical", size=size)
self.assertFalse(boot.can_be_boot_device(d1)) self.assertFalse(boot.can_be_boot_device(d1))
@parameterized.expand([[Bootloader.UEFI], [Bootloader.PREP]]) @parameterized.expand([[Bootloader.UEFI], [Bootloader.PREP]])
def test_too_many_primaries(self, bl): def test_too_many_primaries(self, bl):
manipulator = make_manipulator(bl, storage_version=2) manipulator = make_manipulator(bl, storage_version=2)
m = manipulator.model m = manipulator.model
d1 = make_disk(m, preserve=True, ptable='msdos', size=100 << 30) d1 = make_disk(m, preserve=True, ptable="msdos", size=100 << 30)
tenth = 10 << 30 tenth = 10 << 30
for i in range(4): for i in range(4):
make_partition(m, d1, preserve=True, size=tenth) make_partition(m, d1, preserve=True, size=tenth)
@ -715,12 +701,12 @@ class TestFilesystemManipulator(unittest.TestCase):
def test_logical_when_in_extended(self): def test_logical_when_in_extended(self):
manipulator = make_manipulator(storage_version=2) manipulator = make_manipulator(storage_version=2)
m = manipulator.model m = manipulator.model
d1 = make_disk(m, preserve=True, ptable='msdos') d1 = make_disk(m, preserve=True, ptable="msdos")
size = gaps.largest_gap_size(d1) size = gaps.largest_gap_size(d1)
make_partition(m, d1, flag='extended', size=size) make_partition(m, d1, flag="extended", size=size)
gap = gaps.largest_gap(d1) gap = gaps.largest_gap(d1)
p = manipulator.create_partition(d1, gap, spec={}, flag='primary') p = manipulator.create_partition(d1, gap, spec={}, flag="primary")
self.assertEqual(p.flag, 'logical') self.assertEqual(p.flag, "logical")
class TestReformat(unittest.TestCase): class TestReformat(unittest.TestCase):
@ -733,19 +719,19 @@ class TestReformat(unittest.TestCase):
self.assertEqual(None, disk.ptable) self.assertEqual(None, disk.ptable)
def test_reformat_keep_current(self): def test_reformat_keep_current(self):
disk = make_disk(self.manipulator.model, ptable='msdos') disk = make_disk(self.manipulator.model, ptable="msdos")
self.manipulator.reformat(disk) self.manipulator.reformat(disk)
self.assertEqual('msdos', disk.ptable) self.assertEqual("msdos", disk.ptable)
def test_reformat_to_gpt(self): def test_reformat_to_gpt(self):
disk = make_disk(self.manipulator.model, ptable=None) disk = make_disk(self.manipulator.model, ptable=None)
self.manipulator.reformat(disk, 'gpt') self.manipulator.reformat(disk, "gpt")
self.assertEqual('gpt', disk.ptable) self.assertEqual("gpt", disk.ptable)
def test_reformat_to_msdos(self): def test_reformat_to_msdos(self):
disk = make_disk(self.manipulator.model, ptable=None) disk = make_disk(self.manipulator.model, ptable=None)
self.manipulator.reformat(disk, 'msdos') self.manipulator.reformat(disk, "msdos")
self.assertEqual('msdos', disk.ptable) self.assertEqual("msdos", disk.ptable)
class TestCanResize(unittest.TestCase): class TestCanResize(unittest.TestCase):
@ -760,11 +746,11 @@ class TestCanResize(unittest.TestCase):
def test_resize_ext4(self): def test_resize_ext4(self):
disk = make_disk(self.manipulator.model, ptable=None) disk = make_disk(self.manipulator.model, ptable=None)
part = make_partition(self.manipulator.model, disk, preserve=True) part = make_partition(self.manipulator.model, disk, preserve=True)
make_filesystem(self.manipulator.model, partition=part, fstype='ext4') make_filesystem(self.manipulator.model, partition=part, fstype="ext4")
self.assertTrue(self.manipulator.can_resize_partition(part)) self.assertTrue(self.manipulator.can_resize_partition(part))
def test_resize_invalid(self): def test_resize_invalid(self):
disk = make_disk(self.manipulator.model, ptable=None) disk = make_disk(self.manipulator.model, ptable=None)
part = make_partition(self.manipulator.model, disk, preserve=True) part = make_partition(self.manipulator.model, disk, preserve=True)
make_filesystem(self.manipulator.model, partition=part, fstype='asdf') make_filesystem(self.manipulator.model, partition=part, fstype="asdf")
self.assertFalse(self.manipulator.can_resize_partition(part)) self.assertFalse(self.manipulator.can_resize_partition(part))
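Almost every hunk in this file is the same mechanical transformation: black normalizes string literals to double quotes, adds spaces around operators such as //, and explodes long call expressions one argument per line with a trailing comma (assuming black's default 88-column limit; the project's actual configuration is not part of this excerpt). A small runnable illustration of the resulting call style, using a standard-library function rather than the test helpers:

import textwrap

# Written on one line this call would exceed the line-length limit, so black
# leaves it wrapped with one argument per line and a trailing comma.
summary = textwrap.shorten(
    "format with black + isort: double quotes, wrapped calls, sorted imports",
    width=40,
    placeholder="...",
)
print(summary)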
View File
@ -18,19 +18,17 @@ import unittest
from unittest import mock from unittest import mock
from subiquity.common.filesystem.sizes import ( from subiquity.common.filesystem.sizes import (
PartitionScaleFactors,
bootfs_scale, bootfs_scale,
calculate_guided_resize, calculate_guided_resize,
calculate_suggested_install_min, calculate_suggested_install_min,
get_efi_size,
get_bootfs_size, get_bootfs_size,
PartitionScaleFactors, get_efi_size,
scale_partitions, scale_partitions,
uefi_scale, uefi_scale,
) )
from subiquity.common.types import GuidedResizeValues from subiquity.common.types import GuidedResizeValues
from subiquity.models.filesystem import ( from subiquity.models.filesystem import MiB
MiB,
)
class TestPartitionSizeScaling(unittest.TestCase): class TestPartitionSizeScaling(unittest.TestCase):
@ -87,46 +85,54 @@ class TestPartitionSizeScaling(unittest.TestCase):
class TestCalculateGuidedResize(unittest.TestCase): class TestCalculateGuidedResize(unittest.TestCase):
def test_ignore_nonresizable(self): def test_ignore_nonresizable(self):
actual = calculate_guided_resize( actual = calculate_guided_resize(
part_min=-1, part_size=100 << 30, install_min=10 << 30) part_min=-1, part_size=100 << 30, install_min=10 << 30
)
self.assertIsNone(actual) self.assertIsNone(actual)
def test_too_small(self): def test_too_small(self):
actual = calculate_guided_resize( actual = calculate_guided_resize(
part_min=95 << 30, part_size=100 << 30, install_min=10 << 30) part_min=95 << 30, part_size=100 << 30, install_min=10 << 30
)
self.assertIsNone(actual) self.assertIsNone(actual)
def test_even_split(self): def test_even_split(self):
# 8 GiB * 1.25 == 10 GiB # 8 GiB * 1.25 == 10 GiB
size = 10 << 30 size = 10 << 30
actual = calculate_guided_resize( actual = calculate_guided_resize(
part_min=8 << 30, part_size=100 << 30, install_min=size) part_min=8 << 30, part_size=100 << 30, install_min=size
)
expected = GuidedResizeValues( expected = GuidedResizeValues(
install_max=(100 << 30) - size, install_max=(100 << 30) - size,
minimum=size, recommended=50 << 30, maximum=(100 << 30) - size) minimum=size,
recommended=50 << 30,
maximum=(100 << 30) - size,
)
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
def test_weighted_split(self): def test_weighted_split(self):
actual = calculate_guided_resize( actual = calculate_guided_resize(
part_min=40 << 30, part_size=240 << 30, install_min=10 << 30) part_min=40 << 30, part_size=240 << 30, install_min=10 << 30
)
expected = GuidedResizeValues( expected = GuidedResizeValues(
install_max=190 << 30, install_max=190 << 30,
minimum=50 << 30, recommended=200 << 30, maximum=230 << 30) minimum=50 << 30,
recommended=200 << 30,
maximum=230 << 30,
)
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
def test_unaligned_size(self): def test_unaligned_size(self):
def assertAligned(val): def assertAligned(val):
if val % MiB != 0: if val % MiB != 0:
self.fail( self.fail(
"val of %s not aligned (%s) with delta %s" % ( "val of %s not aligned (%s) with delta %s" % (val, val % MiB, delta)
val, val % MiB, delta)) )
for i in range(1000): for i in range(1000):
delta = random.randrange(MiB) delta = random.randrange(MiB)
actual = calculate_guided_resize( actual = calculate_guided_resize(
part_min=40 << 30, part_min=40 << 30, part_size=(240 << 30) + delta, install_min=10 << 30
part_size=(240 << 30) + delta, )
install_min=10 << 30)
assertAligned(actual.install_max) assertAligned(actual.install_max)
assertAligned(actual.minimum) assertAligned(actual.minimum)
assertAligned(actual.recommended) assertAligned(actual.recommended)
@ -134,8 +140,8 @@ class TestCalculateGuidedResize(unittest.TestCase):
class TestCalculateInstallMin(unittest.TestCase): class TestCalculateInstallMin(unittest.TestCase):
@mock.patch('subiquity.common.filesystem.sizes.swap.suggested_swapsize') @mock.patch("subiquity.common.filesystem.sizes.swap.suggested_swapsize")
@mock.patch('subiquity.common.filesystem.sizes.bootfs_scale') @mock.patch("subiquity.common.filesystem.sizes.bootfs_scale")
def test_small_setups(self, bootfs_scale, swapsize): def test_small_setups(self, bootfs_scale, swapsize):
swapsize.return_value = 1 << 30 swapsize.return_value = 1 << 30
bootfs_scale.maximum = 1 << 30 bootfs_scale.maximum = 1 << 30
@ -143,24 +149,24 @@ class TestCalculateInstallMin(unittest.TestCase):
# with a small source, we hit the default 2GiB padding # with a small source, we hit the default 2GiB padding
self.assertEqual(5 << 30, calculate_suggested_install_min(source_min)) self.assertEqual(5 << 30, calculate_suggested_install_min(source_min))
@mock.patch('subiquity.common.filesystem.sizes.swap.suggested_swapsize') @mock.patch("subiquity.common.filesystem.sizes.swap.suggested_swapsize")
@mock.patch('subiquity.common.filesystem.sizes.bootfs_scale') @mock.patch("subiquity.common.filesystem.sizes.bootfs_scale")
def test_small_setups_big_swap(self, bootfs_scale, swapsize): def test_small_setups_big_swap(self, bootfs_scale, swapsize):
swapsize.return_value = 10 << 30 swapsize.return_value = 10 << 30
bootfs_scale.maximum = 1 << 30 bootfs_scale.maximum = 1 << 30
source_min = 1 << 30 source_min = 1 << 30
self.assertEqual(14 << 30, calculate_suggested_install_min(source_min)) self.assertEqual(14 << 30, calculate_suggested_install_min(source_min))
@mock.patch('subiquity.common.filesystem.sizes.swap.suggested_swapsize') @mock.patch("subiquity.common.filesystem.sizes.swap.suggested_swapsize")
@mock.patch('subiquity.common.filesystem.sizes.bootfs_scale') @mock.patch("subiquity.common.filesystem.sizes.bootfs_scale")
def test_small_setups_big_boot(self, bootfs_scale, swapsize): def test_small_setups_big_boot(self, bootfs_scale, swapsize):
swapsize.return_value = 1 << 30 swapsize.return_value = 1 << 30
bootfs_scale.maximum = 10 << 30 bootfs_scale.maximum = 10 << 30
source_min = 1 << 30 source_min = 1 << 30
self.assertEqual(14 << 30, calculate_suggested_install_min(source_min)) self.assertEqual(14 << 30, calculate_suggested_install_min(source_min))
@mock.patch('subiquity.common.filesystem.sizes.swap.suggested_swapsize') @mock.patch("subiquity.common.filesystem.sizes.swap.suggested_swapsize")
@mock.patch('subiquity.common.filesystem.sizes.bootfs_scale') @mock.patch("subiquity.common.filesystem.sizes.bootfs_scale")
def test_big_source(self, bootfs_scale, swapsize): def test_big_source(self, bootfs_scale, swapsize):
swapsize.return_value = 1 << 30 swapsize.return_value = 1 << 30
bootfs_scale.maximum = 2 << 30 bootfs_scale.maximum = 2 << 30
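For readers unfamiliar with the function under test: calculate_guided_resize takes the existing partition's minimum and current size plus the install minimum, and returns a GuidedResizeValues tuple of bounds. A minimal sketch of the even-split case exercised above, assuming the subiquity source tree is importable:

from subiquity.common.filesystem.sizes import calculate_guided_resize

GiB = 1 << 30

# Values from test_even_split: an 8 GiB-minimum partition on a 100 GiB disk
# with a 10 GiB install minimum is offered an even 50/50 split.
values = calculate_guided_resize(
    part_min=8 * GiB,
    part_size=100 * GiB,
    install_min=10 * GiB,
)
print(values.recommended // GiB)  # 50
print(values.install_max // GiB)  # 90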
View File
@ -20,7 +20,7 @@ import logging
import os import os
from typing import List from typing import List
log = logging.getLogger('subiquity.common.resources') log = logging.getLogger("subiquity.common.resources")
def resource_path(relative_path): def resource_path(relative_path):
@ -31,17 +31,17 @@ def get_users_and_groups(chroot_prefix=[]) -> List:
# prevent import when calling just resource_path # prevent import when calling just resource_path
from subiquitycore.utils import run_command from subiquitycore.utils import run_command
users_and_groups_path = resource_path('users-and-groups') users_and_groups_path = resource_path("users-and-groups")
groups = ['admin'] groups = ["admin"]
if os.path.exists(users_and_groups_path): if os.path.exists(users_and_groups_path):
with open(users_and_groups_path) as f: with open(users_and_groups_path) as f:
groups = f.read().split() groups = f.read().split()
groups.append('sudo') groups.append("sudo")
command = chroot_prefix + ['getent', 'group'] command = chroot_prefix + ["getent", "group"]
cp = run_command(command, check=True) cp = run_command(command, check=True)
target_groups = set() target_groups = set()
for line in cp.stdout.splitlines(): for line in cp.stdout.splitlines():
target_groups.add(line.split(':')[0]) target_groups.add(line.split(":")[0])
return list(target_groups.intersection(groups)) return list(target_groups.intersection(groups))
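get_users_and_groups boils down to one set intersection: the group names parsed from getent group output, filtered against the groups the installer is willing to offer. A self-contained sketch of that filtering step, with made-up sample data standing in for the real getent output and users-and-groups file:

# Sample data invented for the illustration.
getent_output = "admin:x:115:\nsudo:x:27:ubuntu\nlpadmin:x:120:\n"
wanted_groups = ["admin", "sudo", "docker"]  # resource file contents plus "sudo"

# Same parsing as above: the group name is the first colon-separated field.
present = {line.split(":")[0] for line in getent_output.splitlines()}
print(sorted(present.intersection(wanted_groups)))  # ['admin', 'sudo']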
View File
@ -15,19 +15,19 @@
import datetime import datetime
import enum import enum
import json
import inspect import inspect
import json
import typing import typing
import attr import attr
def named_field(name, default=attr.NOTHING): def named_field(name, default=attr.NOTHING):
return attr.ib(metadata={'name': name}, default=default) return attr.ib(metadata={"name": name}, default=default)
def _field_name(field): def _field_name(field):
return field.metadata.get('name', field.name) return field.metadata.get("name", field.name)
class SerializationError(Exception): class SerializationError(Exception):
@ -39,7 +39,7 @@ class SerializationError(Exception):
def __str__(self): def __str__(self):
p = self.path p = self.path
if not p: if not p:
p = 'top-level' p = "top-level"
return f"processing {self.obj}: at {p}, {self.message}" return f"processing {self.obj}: at {p}, {self.message}"
@ -53,13 +53,12 @@ class SerializationContext:
@classmethod @classmethod
def new(cls, obj, *, serializing): def new(cls, obj, *, serializing):
return SerializationContext(obj, obj, '', {}, serializing) return SerializationContext(obj, obj, "", {}, serializing)
def child(self, path, cur, metadata=None): def child(self, path, cur, metadata=None):
if metadata is None: if metadata is None:
metadata = self.metadata metadata = self.metadata
return attr.evolve( return attr.evolve(self, path=self.path + path, cur=cur, metadata=metadata)
self, path=self.path + path, cur=cur, metadata=metadata)
def error(self, message): def error(self, message):
raise SerializationError(self.obj, self.path, message) raise SerializationError(self.obj, self.path, message)
@ -74,9 +73,9 @@ class SerializationContext:
class Serializer: class Serializer:
def __init__(
def __init__(self, *, compact=False, ignore_unknown_fields=False, self, *, compact=False, ignore_unknown_fields=False, serialize_enums_by="name"
serialize_enums_by="name"): ):
self.compact = compact self.compact = compact
self.ignore_unknown_fields = ignore_unknown_fields self.ignore_unknown_fields = ignore_unknown_fields
assert serialize_enums_by in ("value", "name") assert serialize_enums_by in ("value", "name")
@ -119,14 +118,14 @@ class Serializer:
if self.compact: if self.compact:
r.insert(0, a.__name__) r.insert(0, a.__name__)
else: else:
r['$type'] = a.__name__ r["$type"] = a.__name__
return r return r
context.error(f"type of {context.cur} not found in {args}") context.error(f"type of {context.cur} not found in {args}")
else: else:
if self.compact: if self.compact:
n = context.cur.pop(0) n = context.cur.pop(0)
else: else:
n = context.cur.pop('$type') n = context.cur.pop("$type")
for a in args: for a in args:
if a.__name__ == n: if a.__name__ == n:
return meth(a, context) return meth(a, context)
@ -135,8 +134,7 @@ class Serializer:
def _walk_List(self, meth, args, context): def _walk_List(self, meth, args, context):
return [ return [
meth(args[0], context.child(f'[{i}]', v)) meth(args[0], context.child(f"[{i}]", v)) for i, v in enumerate(context.cur)
for i, v in enumerate(context.cur)
] ]
def _walk_Dict(self, meth, args, context): def _walk_Dict(self, meth, args, context):
@ -145,10 +143,13 @@ class Serializer:
input_items = context.cur input_items = context.cur
else: else:
input_items = context.cur.items() input_items = context.cur.items()
output_items = [[ output_items = [
meth(k_ann, context.child(f'/{k}', k)), [
meth(v_ann, context.child(f'[{k}]', v)) meth(k_ann, context.child(f"/{k}", k)),
] for k, v in input_items] meth(v_ann, context.child(f"[{k}]", v)),
]
for k, v in input_items
]
if context.serializing and k_ann is not str: if context.serializing and k_ann is not str:
return output_items return output_items
return dict(output_items) return dict(output_items)
@ -156,12 +157,12 @@ class Serializer:
def _serialize_dict(self, annotation, context): def _serialize_dict(self, annotation, context):
context.assert_type(annotation) context.assert_type(annotation)
for k in context.cur: for k in context.cur:
context.child(f'/{k}', k).assert_type(str) context.child(f"/{k}", k).assert_type(str)
return context.cur return context.cur
def _serialize_datetime(self, annotation, context): def _serialize_datetime(self, annotation, context):
context.assert_type(annotation) context.assert_type(annotation)
fmt = context.metadata.get('time_fmt') fmt = context.metadata.get("time_fmt")
if fmt is not None: if fmt is not None:
return context.cur.strftime(fmt) return context.cur.strftime(fmt)
else: else:
@ -170,15 +171,19 @@ class Serializer:
def _serialize_attr(self, annotation, context): def _serialize_attr(self, annotation, context):
serialized = [] serialized = []
for field in attr.fields(annotation): for field in attr.fields(annotation):
serialized.append(( serialized.append(
(
_field_name(field), _field_name(field),
self._serialize( self._serialize(
field.type, field.type,
context.child( context.child(
f'.{field.name}', f".{field.name}",
getattr(context.cur, field.name), getattr(context.cur, field.name),
field.metadata)), field.metadata,
)) ),
),
)
)
if self.compact: if self.compact:
return [s[1] for s in serialized] return [s[1] for s in serialized]
else: else:
@ -196,7 +201,7 @@ class Serializer:
return context.cur return context.cur
if attr.has(annotation): if attr.has(annotation):
return self._serialize_attr(annotation, context) return self._serialize_attr(annotation, context)
origin = getattr(annotation, '__origin__', None) origin = getattr(annotation, "__origin__", None)
if origin is not None: if origin is not None:
args = annotation.__args__ args = annotation.__args__
return self.typing_walkers[origin](self._serialize, args, context) return self.typing_walkers[origin](self._serialize, args, context)
@ -214,7 +219,7 @@ class Serializer:
return self._serialize(annotation, context) return self._serialize(annotation, context)
def _deserialize_datetime(self, annotation, context): def _deserialize_datetime(self, annotation, context):
fmt = context.metadata.get('time_fmt') fmt = context.metadata.get("time_fmt")
if fmt is None: if fmt is None:
context.error("cannot serialize datetime without format") context.error("cannot serialize datetime without format")
return datetime.datetime.strptime(context.cur, fmt) return datetime.datetime.strptime(context.cur, fmt)
@ -224,19 +229,19 @@ class Serializer:
context.assert_type(list) context.assert_type(list)
args = [] args = []
for field, value in zip(attr.fields(annotation), context.cur): for field, value in zip(attr.fields(annotation), context.cur):
args.append(self._deserialize( args.append(
self._deserialize(
field.type, field.type,
context.child(f'[{field.name!r}]', value, field.metadata))) context.child(f"[{field.name!r}]", value, field.metadata),
)
)
return annotation(*args) return annotation(*args)
else: else:
context.assert_type(dict) context.assert_type(dict)
args = {} args = {}
fields = { fields = {_field_name(field): field for field in attr.fields(annotation)}
_field_name(field): field for field in attr.fields(annotation)
}
for key, value in context.cur.items(): for key, value in context.cur.items():
if key not in fields and ( if key not in fields and (key == "$type" or self.ignore_unknown_fields):
key == '$type' or self.ignore_unknown_fields):
# Union types can contain a '$type' field that is not # Union types can contain a '$type' field that is not
# actually one of the keys. This happens if an object is # actually one of the keys. This happens if an object is
# serialized as part of a Union, sent to an API caller, # serialized as part of a Union, sent to an API caller,
@ -245,8 +250,8 @@ class Serializer:
continue continue
field = fields[key] field = fields[key]
args[field.name] = self._deserialize( args[field.name] = self._deserialize(
field.type, field.type, context.child(f"[{key!r}]", value, field.metadata)
context.child(f'[{key!r}]', value, field.metadata)) )
return annotation(**args) return annotation(**args)
def _deserialize_enum(self, annotation, context): def _deserialize_enum(self, annotation, context):
@ -263,10 +268,11 @@ class Serializer:
return context.cur return context.cur
if attr.has(annotation): if attr.has(annotation):
return self._deserialize_attr(annotation, context) return self._deserialize_attr(annotation, context)
origin = getattr(annotation, '__origin__', None) origin = getattr(annotation, "__origin__", None)
if origin is not None: if origin is not None:
return self.typing_walkers[origin]( return self.typing_walkers[origin](
self._deserialize, annotation.__args__, context) self._deserialize, annotation.__args__, context
)
if isinstance(annotation, type) and issubclass(annotation, enum.Enum): if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
return self._deserialize_enum(annotation, context) return self._deserialize_enum(annotation, context)
return self.type_deserializers[annotation](annotation, context) return self.type_deserializers[annotation](annotation, context)
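The "$type" handling in _deserialize_attr above is easiest to see end to end: serializing through a typing.Union records the concrete class name, and deserializing against the concrete type simply skips that extra key. A minimal sketch with two made-up attrs classes, assuming the subiquity tree and the attrs package are importable:

import typing

import attr

from subiquity.common.serialize import Serializer


@attr.s(auto_attribs=True)
class Cat:
    name: str


@attr.s(auto_attribs=True)
class Dog:
    name: str


s = Serializer()

# Serializing via the Union adds the "$type" discriminator to the payload.
payload = s.serialize(typing.Union[Cat, Dog], Cat(name="mog"))
print(payload)  # {'name': 'mog', '$type': 'Cat'}

# Deserializing against the concrete type tolerates the extra "$type" key.
print(s.deserialize(Cat, payload))  # Cat(name='mog')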
View File
@ -13,7 +13,6 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import attr
import datetime import datetime
import enum import enum
import inspect import inspect
@ -22,11 +21,9 @@ import string
import typing import typing
import unittest import unittest
from subiquity.common.serialize import ( import attr
named_field,
Serializer, from subiquity.common.serialize import SerializationError, Serializer, named_field
SerializationError,
)
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
@ -36,9 +33,7 @@ class Data:
@staticmethod @staticmethod
def make_random(): def make_random():
return Data( return Data(random.choice(string.ascii_letters), random.randint(0, 1000))
random.choice(string.ascii_letters),
random.randint(0, 1000))
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
@ -50,8 +45,8 @@ class Container:
def make_random(): def make_random():
return Container( return Container(
data=Data.make_random(), data=Data.make_random(),
data_list=[Data.make_random() data_list=[Data.make_random() for i in range(random.randint(10, 20))],
for i in range(random.randint(10, 20))]) )
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
@ -67,7 +62,6 @@ class MyEnum(enum.Enum):
class CommonSerializerTests: class CommonSerializerTests:
simple_examples = [ simple_examples = [
(int, 1), (int, 1),
(str, "v"), (str, "v"),
@ -77,17 +71,14 @@ class CommonSerializerTests:
] ]
def assertSerializesTo(self, annotation, value, expected): def assertSerializesTo(self, annotation, value, expected):
self.assertEqual( self.assertEqual(self.serializer.serialize(annotation, value), expected)
self.serializer.serialize(annotation, value), expected)
def assertDeserializesTo(self, annotation, value, expected): def assertDeserializesTo(self, annotation, value, expected):
self.assertEqual( self.assertEqual(self.serializer.deserialize(annotation, value), expected)
self.serializer.deserialize(annotation, value), expected)
def assertRoundtrips(self, annotation, value): def assertRoundtrips(self, annotation, value):
serialized = self.serializer.to_json(annotation, value) serialized = self.serializer.to_json(annotation, value)
self.assertEqual( self.assertEqual(self.serializer.from_json(annotation, serialized), value)
self.serializer.from_json(annotation, serialized), value)
def assertSerialization(self, annotation, value, expected): def assertSerialization(self, annotation, value, expected):
self.assertSerializesTo(annotation, value, expected) self.assertSerializesTo(annotation, value, expected)
@ -114,8 +105,7 @@ class CommonSerializerTests:
self.assertSerialization(typ, val, val) self.assertSerialization(typ, val, val)
def test_non_string_key_dict(self): def test_non_string_key_dict(self):
self.assertRaises( self.assertRaises(Exception, self.serializer.serialize, dict, {1: 2})
Exception, self.serializer.serialize, dict, {1: 2})
def test_roundtrip_dict(self): def test_roundtrip_dict(self):
ann = typing.Dict[int, str] ann = typing.Dict[int, str]
@ -141,7 +131,8 @@ class CommonSerializerTests:
def test_enums_by_value(self): def test_enums_by_value(self):
self.serializer = type(self.serializer)( self.serializer = type(self.serializer)(
compact=self.serializer.compact, serialize_enums_by="value") compact=self.serializer.compact, serialize_enums_by="value"
)
self.assertSerialization(MyEnum, MyEnum.name, "value") self.assertSerialization(MyEnum, MyEnum.name, "value")
def test_serialize_any(self): def test_serialize_any(self):
@ -151,19 +142,19 @@ class CommonSerializerTests:
class TestSerializer(CommonSerializerTests, unittest.TestCase): class TestSerializer(CommonSerializerTests, unittest.TestCase):
serializer = Serializer() serializer = Serializer()
def test_datetime(self): def test_datetime(self):
@attr.s @attr.s
class C: class C:
d: datetime.datetime = attr.ib(metadata={'time_fmt': '%Y-%m-%d'}) d: datetime.datetime = attr.ib(metadata={"time_fmt": "%Y-%m-%d"})
c = C(datetime.datetime(2022, 1, 1)) c = C(datetime.datetime(2022, 1, 1))
self.assertSerialization(C, c, {"d": "2022-01-01"}) self.assertSerialization(C, c, {"d": "2022-01-01"})
def test_serialize_attr(self): def test_serialize_attr(self):
data = Data.make_random() data = Data.make_random()
expected = {'field1': data.field1, 'field2': data.field2} expected = {"field1": data.field1, "field2": data.field2}
self.assertSerialization(Data, data, expected) self.assertSerialization(Data, data, expected)
def test_serialize_container(self): def test_serialize_container(self):
@ -171,9 +162,9 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
data2 = Data.make_random() data2 = Data.make_random()
container = Container(data1, [data2]) container = Container(data1, [data2])
expected = { expected = {
'data': {'field1': data1.field1, 'field2': data1.field2}, "data": {"field1": data1.field1, "field2": data1.field2},
'data_list': [ "data_list": [
{'field1': data2.field1, 'field2': data2.field2}, {"field1": data2.field1, "field2": data2.field2},
], ],
} }
self.assertSerialization(Container, container, expected) self.assertSerialization(Container, container, expected)
@ -181,9 +172,9 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
def test_serialize_union(self): def test_serialize_union(self):
data = Data.make_random() data = Data.make_random()
expected = { expected = {
'$type': 'Data', "$type": "Data",
'field1': data.field1, "field1": data.field1,
'field2': data.field2, "field2": data.field2,
} }
self.assertSerialization(typing.Union[Data, Container], data, expected) self.assertSerialization(typing.Union[Data, Container], data, expected)
@ -193,18 +184,18 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
# API entrypoint, one that isn't taking a Union, it must be cool with # API entrypoint, one that isn't taking a Union, it must be cool with
# the excess $type field. # the excess $type field.
data = { data = {
'$type': 'Data', "$type": "Data",
'field1': '1', "field1": "1",
'field2': 2, "field2": 2,
} }
expected = Data(field1='1', field2=2) expected = Data(field1="1", field2=2)
self.assertDeserializesTo(Data, data, expected) self.assertDeserializesTo(Data, data, expected)
def test_reject_unknown_fields_by_default(self): def test_reject_unknown_fields_by_default(self):
serializer = Serializer() serializer = Serializer()
data = Data.make_random() data = Data.make_random()
serialized = serializer.serialize(Data, data) serialized = serializer.serialize(Data, data)
serialized['foobar'] = 'baz' serialized["foobar"] = "baz"
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
serializer.deserialize(Data, serialized) serializer.deserialize(Data, serialized)
@ -212,11 +203,10 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
serializer = Serializer(ignore_unknown_fields=True) serializer = Serializer(ignore_unknown_fields=True)
data = Data.make_random() data = Data.make_random()
serialized = serializer.serialize(Data, data) serialized = serializer.serialize(Data, data)
serialized['foobar'] = 'baz' serialized["foobar"] = "baz"
self.assertEqual(serializer.deserialize(Data, serialized), data) self.assertEqual(serializer.deserialize(Data, serialized), data)
def test_override_field_name(self): def test_override_field_name(self):
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class Object: class Object:
x: int x: int
@ -224,10 +214,10 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
z: int = named_field("field-z", 0) z: int = named_field("field-z", 0)
self.assertSerialization( self.assertSerialization(
Object, Object(1, 2), {"x": 1, "field-y": 2, "field-z": 0}) Object, Object(1, 2), {"x": 1, "field-y": 2, "field-z": 0}
)
def test_embedding(self): def test_embedding(self):
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class Base1: class Base1:
x: str x: str
@ -249,28 +239,28 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
self.assertSerialization( self.assertSerialization(
Derived2, Derived2,
Derived2(b=Derived1(x="a", y=1), c=2), Derived2(b=Derived1(x="a", y=1), c=2),
{"b": {"x": "a", "y": 1}, "c": 2}) {"b": {"x": "a", "y": 1}, "c": 2},
)
def test_error_paths(self): def test_error_paths(self):
with self.assertRaises(SerializationError) as catcher: with self.assertRaises(SerializationError) as catcher:
self.serializer.serialize(str, 1) self.serializer.serialize(str, 1)
self.assertEqual(catcher.exception.path, '') self.assertEqual(catcher.exception.path, "")
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class Type: class Type:
field1: str = named_field('field-1') field1: str = named_field("field-1")
field2: int field2: int
with self.assertRaises(SerializationError) as catcher: with self.assertRaises(SerializationError) as catcher:
self.serializer.serialize(Type, Data(2, 3)) self.serializer.serialize(Type, Data(2, 3))
self.assertEqual(catcher.exception.path, '.field1') self.assertEqual(catcher.exception.path, ".field1")
with self.assertRaises(SerializationError) as catcher: with self.assertRaises(SerializationError) as catcher:
self.serializer.deserialize(Type, {'field-1': 1, 'field2': 2}) self.serializer.deserialize(Type, {"field-1": 1, "field2": 2})
self.assertEqual(catcher.exception.path, "['field-1']") self.assertEqual(catcher.exception.path, "['field-1']")
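The tests above exercise the field-name override; a minimal sketch of that behavior, assuming named_field is exported by subiquity.common.serialize alongside Serializer (as these tests suggest):

import attr

from subiquity.common.serialize import Serializer, named_field


@attr.s(auto_attribs=True)
class Example:
    x: int
    y: int = named_field("field-y")


s = Serializer()
# The wire name "field-y" is used in place of the attribute name.
assert s.serialize(Example, Example(1, 2)) == {"x": 1, "field-y": 2}
assert s.deserialize(Example, {"x": 1, "field-y": 2}) == Example(1, 2)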
class TestCompactSerializer(CommonSerializerTests, unittest.TestCase): class TestCompactSerializer(CommonSerializerTests, unittest.TestCase):
serializer = Serializer(compact=True) serializer = Serializer(compact=True)
def test_serialize_attr(self): def test_serialize_attr(self):
@ -290,70 +280,57 @@ class TestCompactSerializer(CommonSerializerTests, unittest.TestCase):
def test_serialize_union(self): def test_serialize_union(self):
data = Data.make_random() data = Data.make_random()
expected = ['Data', data.field1, data.field2] expected = ["Data", data.field1, data.field2]
self.assertSerialization(typing.Union[Data, Container], data, expected) self.assertSerialization(typing.Union[Data, Container], data, expected)
class TestOptionalAndDefault(CommonSerializerTests, unittest.TestCase): class TestOptionalAndDefault(CommonSerializerTests, unittest.TestCase):
serializer = Serializer() serializer = Serializer()
def test_happy(self): def test_happy(self):
data = { data = {
'int': 11, "int": 11,
'optional_int': 12, "optional_int": 12,
'int_default': 3, "int_default": 3,
'optional_int_default': 4 "optional_int_default": 4,
} }
expected = OptionalAndDefault(11, 12) expected = OptionalAndDefault(11, 12)
self.assertDeserializesTo(OptionalAndDefault, data, expected) self.assertDeserializesTo(OptionalAndDefault, data, expected)
def test_skip_int(self): def test_skip_int(self):
data = { data = {"optional_int": 12, "int_default": 13, "optional_int_default": 14}
'optional_int': 12,
'int_default': 13,
'optional_int_default': 14
}
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
self.serializer.deserialize(OptionalAndDefault, data) self.serializer.deserialize(OptionalAndDefault, data)
def test_skip_optional_int(self): def test_skip_optional_int(self):
data = { data = {"int": 11, "int_default": 13, "optional_int_default": 14}
'int': 11,
'int_default': 13,
'optional_int_default': 14
}
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
self.serializer.deserialize(OptionalAndDefault, data) self.serializer.deserialize(OptionalAndDefault, data)
def test_optional_int_none(self): def test_optional_int_none(self):
data = { data = {
'int': 11, "int": 11,
'optional_int': None, "optional_int": None,
'int_default': 13, "int_default": 13,
'optional_int_default': 14 "optional_int_default": 14,
} }
expected = OptionalAndDefault(11, None, 13, 14) expected = OptionalAndDefault(11, None, 13, 14)
self.assertDeserializesTo(OptionalAndDefault, data, expected) self.assertDeserializesTo(OptionalAndDefault, data, expected)
def test_skip_int_default(self): def test_skip_int_default(self):
data = { data = {"int": 11, "optional_int": 12, "optional_int_default": 14}
'int': 11,
'optional_int': 12,
'optional_int_default': 14
}
expected = OptionalAndDefault(11, 12, 3, 14) expected = OptionalAndDefault(11, 12, 3, 14)
self.assertDeserializesTo(OptionalAndDefault, data, expected) self.assertDeserializesTo(OptionalAndDefault, data, expected)
def test_skip_optional_int_default(self): def test_skip_optional_int_default(self):
data = { data = {
'int': 11, "int": 11,
'optional_int': 12, "optional_int": 12,
'int_default': 13, "int_default": 13,
} }
expected = OptionalAndDefault(11, 12, 13, 4) expected = OptionalAndDefault(11, 12, 13, 4)
@ -361,10 +338,10 @@ class TestOptionalAndDefault(CommonSerializerTests, unittest.TestCase):
def test_optional_int_default_none(self): def test_optional_int_default_none(self):
data = { data = {
'int': 11, "int": 11,
'optional_int': 12, "optional_int": 12,
'int_default': 13, "int_default": 13,
'optional_int_default': None, "optional_int_default": None,
} }
expected = OptionalAndDefault(11, 12, 13, None) expected = OptionalAndDefault(11, 12, 13, None)
@ -372,11 +349,11 @@ class TestOptionalAndDefault(CommonSerializerTests, unittest.TestCase):
def test_extra(self): def test_extra(self):
data = { data = {
'int': 11, "int": 11,
'optional_int': 12, "optional_int": 12,
'int_default': 13, "int_default": 13,
'optional_int_default': 14, "optional_int_default": 14,
'extra': 15 "extra": 15,
} }
with self.assertRaises(KeyError): with self.assertRaises(KeyError):

View File
@ -15,19 +15,16 @@
import unittest import unittest
from subiquity.common.types import ( from subiquity.common.types import GuidedCapability, SizingPolicy
GuidedCapability,
SizingPolicy,
)
class TestSizingPolicy(unittest.TestCase): class TestSizingPolicy(unittest.TestCase):
def test_all(self): def test_all(self):
actual = SizingPolicy.from_string('all') actual = SizingPolicy.from_string("all")
self.assertEqual(SizingPolicy.ALL, actual) self.assertEqual(SizingPolicy.ALL, actual)
def test_scaled_size(self): def test_scaled_size(self):
actual = SizingPolicy.from_string('scaled') actual = SizingPolicy.from_string("scaled")
self.assertEqual(SizingPolicy.SCALED, actual) self.assertEqual(SizingPolicy.SCALED, actual)
def test_default(self): def test_default(self):

View File
@ -56,7 +56,7 @@ class ErrorReportRef:
class ApplicationState(enum.Enum): class ApplicationState(enum.Enum):
""" Represents the state of the application at a given time. """ """Represents the state of the application at a given time."""
# States reported during the initial stages of the installation. # States reported during the initial stages of the installation.
STARTING_UP = enum.auto() STARTING_UP = enum.auto()
@ -125,8 +125,8 @@ class RefreshCheckState(enum.Enum):
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class RefreshStatus: class RefreshStatus:
availability: RefreshCheckState availability: RefreshCheckState
current_snap_version: str = '' current_snap_version: str = ""
new_snap_version: str = '' new_snap_version: str = ""
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
@ -163,7 +163,7 @@ class KeyboardSetting:
# Ideally, we would internally represent a keyboard setting as a # Ideally, we would internally represent a keyboard setting as a
# toggle + a list of [layout, variant]. # toggle + a list of [layout, variant].
layout: str layout: str
variant: str = '' variant: str = ""
toggle: Optional[str] = None toggle: Optional[str] = None
@ -216,7 +216,7 @@ class ZdevInfo:
@classmethod @classmethod
def from_row(cls, row): def from_row(cls, row):
row = dict((k.split('=', 1) for k in shlex.split(row))) row = dict((k.split("=", 1) for k in shlex.split(row)))
for k, v in row.items(): for k, v in row.items():
if v == "yes": if v == "yes":
row[k] = True row[k] = True
@ -228,8 +228,8 @@ class ZdevInfo:
@property @property
def typeclass(self): def typeclass(self):
if self.type.startswith('zfcp'): if self.type.startswith("zfcp"):
return 'zfcp' return "zfcp"
return self.type return self.type
@ -352,22 +352,25 @@ class GuidedCapability(enum.Enum):
CORE_BOOT_PREFER_UNENCRYPTED = enum.auto() CORE_BOOT_PREFER_UNENCRYPTED = enum.auto()
def is_lvm(self) -> bool: def is_lvm(self) -> bool:
return self in [GuidedCapability.LVM, return self in [GuidedCapability.LVM, GuidedCapability.LVM_LUKS]
GuidedCapability.LVM_LUKS]
def is_core_boot(self) -> bool: def is_core_boot(self) -> bool:
return self in [GuidedCapability.CORE_BOOT_ENCRYPTED, return self in [
GuidedCapability.CORE_BOOT_ENCRYPTED,
GuidedCapability.CORE_BOOT_UNENCRYPTED, GuidedCapability.CORE_BOOT_UNENCRYPTED,
GuidedCapability.CORE_BOOT_PREFER_ENCRYPTED, GuidedCapability.CORE_BOOT_PREFER_ENCRYPTED,
GuidedCapability.CORE_BOOT_PREFER_UNENCRYPTED] GuidedCapability.CORE_BOOT_PREFER_UNENCRYPTED,
]
def supports_manual_customization(self) -> bool: def supports_manual_customization(self) -> bool:
# After posting this capability to guided_POST, is it possible # After posting this capability to guided_POST, is it possible
# for the user to customize the layout further? # for the user to customize the layout further?
return self in [GuidedCapability.MANUAL, return self in [
GuidedCapability.MANUAL,
GuidedCapability.DIRECT, GuidedCapability.DIRECT,
GuidedCapability.LVM, GuidedCapability.LVM,
GuidedCapability.LVM_LUKS] GuidedCapability.LVM_LUKS,
]
def is_zfs(self) -> bool: def is_zfs(self) -> bool:
return self in [GuidedCapability.ZFS] return self in [GuidedCapability.ZFS]
@ -414,11 +417,11 @@ class SizingPolicy(enum.Enum):
@classmethod @classmethod
def from_string(cls, value): def from_string(cls, value):
if value is None or value == 'scaled': if value is None or value == "scaled":
return cls.SCALED return cls.SCALED
if value == 'all': if value == "all":
return cls.ALL return cls.ALL
raise Exception(f'Unknown SizingPolicy value {value}') raise Exception(f"Unknown SizingPolicy value {value}")
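A quick sketch of the behavior reformatted above (the import path matches the test module earlier in this diff):

from subiquity.common.types import SizingPolicy

assert SizingPolicy.from_string(None) is SizingPolicy.SCALED
assert SizingPolicy.from_string("scaled") is SizingPolicy.SCALED
assert SizingPolicy.from_string("all") is SizingPolicy.ALL
# Any other value raises an Exception.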
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
@ -470,15 +473,16 @@ class GuidedStorageTargetUseGap:
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class GuidedStorageTargetManual: class GuidedStorageTargetManual:
allowed: List[GuidedCapability] = attr.Factory( allowed: List[GuidedCapability] = attr.Factory(lambda: [GuidedCapability.MANUAL])
lambda: [GuidedCapability.MANUAL])
disallowed: List[GuidedDisallowedCapability] = attr.Factory(list) disallowed: List[GuidedDisallowedCapability] = attr.Factory(list)
GuidedStorageTarget = Union[GuidedStorageTargetReformat, GuidedStorageTarget = Union[
GuidedStorageTargetReformat,
GuidedStorageTargetResize, GuidedStorageTargetResize,
GuidedStorageTargetUseGap, GuidedStorageTargetUseGap,
GuidedStorageTargetManual] GuidedStorageTargetManual,
]
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
@ -519,10 +523,10 @@ class ReformatDisk:
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class IdentityData: class IdentityData:
realname: str = '' realname: str = ""
username: str = '' username: str = ""
crypted_password: str = attr.ib(default='', repr=False) crypted_password: str = attr.ib(default="", repr=False)
hostname: str = '' hostname: str = ""
class UsernameValidation(enum.Enum): class UsernameValidation(enum.Enum):
@ -542,7 +546,8 @@ class SSHData:
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class SSHIdentity: class SSHIdentity:
""" Represents a SSH identity (public key + fingerprint). """ """Represents a SSH identity (public key + fingerprint)."""
key_type: str key_type: str
key: str key: str
key_comment: str key_comment: str
@ -579,25 +584,26 @@ class ChannelSnapInfo:
version: str version: str
size: int size: int
released_at: datetime.datetime = attr.ib( released_at: datetime.datetime = attr.ib(
metadata={'time_fmt': '%Y-%m-%dT%H:%M:%S.%fZ'}) metadata={"time_fmt": "%Y-%m-%dT%H:%M:%S.%fZ"}
)
@attr.s(auto_attribs=True, eq=False) @attr.s(auto_attribs=True, eq=False)
class SnapInfo: class SnapInfo:
name: str name: str
summary: str = '' summary: str = ""
publisher: str = '' publisher: str = ""
verified: bool = False verified: bool = False
starred: bool = False starred: bool = False
description: str = '' description: str = ""
confinement: str = '' confinement: str = ""
license: str = '' license: str = ""
channels: List[ChannelSnapInfo] = attr.Factory(list) channels: List[ChannelSnapInfo] = attr.Factory(list)
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class DriversResponse: class DriversResponse:
""" Response to GET request to drivers. """Response to GET request to drivers.
:install: tells whether third-party drivers will be installed (if any is :install: tells whether third-party drivers will be installed (if any is
available). available).
:drivers: tells what third-party drivers will be installed should we decide :drivers: tells what third-party drivers will be installed should we decide
@ -606,6 +612,7 @@ class DriversResponse:
:local_only: tells if we are looking for drivers only from the ISO. :local_only: tells if we are looking for drivers only from the ISO.
:search_drivers: enables or disables drivers listing. :search_drivers: enables or disables drivers listing.
""" """
install: bool install: bool
drivers: Optional[List[str]] drivers: Optional[List[str]]
local_only: bool local_only: bool
@ -654,7 +661,8 @@ class UbuntuProInfo:
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class UbuntuProResponse: class UbuntuProResponse:
""" Response to GET request to /ubuntu_pro """ """Response to GET request to /ubuntu_pro"""
token: str = attr.ib(repr=False) token: str = attr.ib(repr=False)
has_network: bool has_network: bool
@ -668,7 +676,8 @@ class UbuntuProCheckTokenStatus(enum.Enum):
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class UPCSInitiateResponse: class UPCSInitiateResponse:
""" Response to Ubuntu Pro contract selection initiate request. """ """Response to Ubuntu Pro contract selection initiate request."""
user_code: str user_code: str
validity_seconds: int validity_seconds: int
@ -680,7 +689,8 @@ class UPCSWaitStatus(enum.Enum):
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class UPCSWaitResponse: class UPCSWaitResponse:
""" Response to Ubuntu Pro contract selection wait request. """ """Response to Ubuntu Pro contract selection wait request."""
status: UPCSWaitStatus status: UPCSWaitStatus
contract_token: Optional[str] contract_token: Optional[str]
@ -715,8 +725,8 @@ class ShutdownMode(enum.Enum):
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class WSLConfigurationBase: class WSLConfigurationBase:
automount_root: str = attr.ib(default='/mnt/') automount_root: str = attr.ib(default="/mnt/")
automount_options: str = '' automount_options: str = ""
network_generatehosts: bool = attr.ib(default=True) network_generatehosts: bool = attr.ib(default=True)
network_generateresolvconf: bool = attr.ib(default=True) network_generateresolvconf: bool = attr.ib(default=True)
@ -750,7 +760,7 @@ class TaskStatus(enum.Enum):
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class TaskProgress: class TaskProgress:
label: str = '' label: str = ""
done: int = 0 done: int = 0
total: int = 0 total: int = 0
@ -777,10 +787,10 @@ class Change:
class CasperMd5Results(enum.Enum): class CasperMd5Results(enum.Enum):
UNKNOWN = 'unknown' UNKNOWN = "unknown"
FAIL = 'fail' FAIL = "fail"
PASS = 'pass' PASS = "pass"
SKIP = 'skip' SKIP = "skip"
class MirrorCheckStatus(enum.Enum): class MirrorCheckStatus(enum.Enum):
@ -817,9 +827,9 @@ class MirrorGet:
class MirrorSelectionFallback(enum.Enum): class MirrorSelectionFallback(enum.Enum):
ABORT = 'abort' ABORT = "abort"
CONTINUE_ANYWAY = 'continue-anyway' CONTINUE_ANYWAY = "continue-anyway"
OFFLINE_INSTALL = 'offline-install' OFFLINE_INSTALL = "offline-install"
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
@ -830,32 +840,32 @@ class AdConnectionInfo:
class AdAdminNameValidation(enum.Enum): class AdAdminNameValidation(enum.Enum):
OK = 'OK' OK = "OK"
EMPTY = 'Empty' EMPTY = "Empty"
INVALID_CHARS = 'Contains invalid characters' INVALID_CHARS = "Contains invalid characters"
class AdDomainNameValidation(enum.Enum): class AdDomainNameValidation(enum.Enum):
OK = 'OK' OK = "OK"
EMPTY = 'Empty' EMPTY = "Empty"
TOO_LONG = 'Too long' TOO_LONG = "Too long"
INVALID_CHARS = 'Contains invalid characters' INVALID_CHARS = "Contains invalid characters"
START_DOT = 'Starts with a dot' START_DOT = "Starts with a dot"
END_DOT = 'Ends with a dot' END_DOT = "Ends with a dot"
START_HYPHEN = 'Starts with a hyphen' START_HYPHEN = "Starts with a hyphen"
END_HYPHEN = 'Ends with a hyphen' END_HYPHEN = "Ends with a hyphen"
MULTIPLE_DOTS = 'Contains multiple dots' MULTIPLE_DOTS = "Contains multiple dots"
REALM_NOT_FOUND = 'Could not find the domain controller' REALM_NOT_FOUND = "Could not find the domain controller"
class AdPasswordValidation(enum.Enum): class AdPasswordValidation(enum.Enum):
OK = 'OK' OK = "OK"
EMPTY = 'Empty' EMPTY = "Empty"
class AdJoinResult(enum.Enum): class AdJoinResult(enum.Enum):
OK = 'OK' OK = "OK"
JOIN_ERROR = 'Failed to join' JOIN_ERROR = "Failed to join"
EMPTY_HOSTNAME = 'Target hostname cannot be empty' EMPTY_HOSTNAME = "Target hostname cannot be empty"
PAM_ERROR = 'Failed to update pam-auth' PAM_ERROR = "Failed to update pam-auth"
UNKNOWN = "Didn't attempt to join yet" UNKNOWN = "Didn't attempt to join yet"

View File
@ -23,7 +23,7 @@ def journald_listen(identifiers, callback, seek=False):
reader = journal.Reader() reader = journal.Reader()
args = [] args = []
for identifier in identifiers: for identifier in identifiers:
if '=' in identifier: if "=" in identifier:
args.append(identifier) args.append(identifier)
else: else:
args.append("SYSLOG_IDENTIFIER={}".format(identifier)) args.append("SYSLOG_IDENTIFIER={}".format(identifier))
@ -37,6 +37,7 @@ def journald_listen(identifiers, callback, seek=False):
return return
for event in reader: for event in reader:
callback(event) callback(event)
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
loop.add_reader(reader.fileno(), watch) loop.add_reader(reader.fileno(), watch)
return reader.fileno() return reader.fileno()
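A minimal usage sketch for the helper above, assuming an asyncio event loop is running and using a hypothetical "subiquity" syslog identifier:

import asyncio


def on_entry(event):
    # Bare identifiers are expanded to SYSLOG_IDENTIFIER=<identifier>;
    # strings already containing "=" are passed to journald as-is.
    print(event.get("MESSAGE"))


async def main():
    fd = journald_listen(["subiquity"], on_entry)
    try:
        await asyncio.sleep(30)
    finally:
        asyncio.get_running_loop().remove_reader(fd)


asyncio.run(main())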

View File
@ -18,11 +18,10 @@ import logging
from subiquitycore.async_helpers import run_in_thread from subiquitycore.async_helpers import run_in_thread
log = logging.getLogger('subiquity.lockfile') log = logging.getLogger("subiquity.lockfile")
class _LockContext: class _LockContext:
def __init__(self, lockfile, flags): def __init__(self, lockfile, flags):
self.lockfile = lockfile self.lockfile = lockfile
self.flags = flags self.flags = flags
@ -49,10 +48,9 @@ class _LockContext:
class Lockfile: class Lockfile:
def __init__(self, path): def __init__(self, path):
self.path = path self.path = path
self.fp = open(path, 'a+') self.fp = open(path, "a+")
def read_content(self): def read_content(self):
self.fp.seek(0) self.fp.seek(0)

View File
@ -18,11 +18,12 @@ from typing import Optional
from subiquity.common.types import AdConnectionInfo from subiquity.common.types import AdConnectionInfo
log = logging.getLogger('subiquity.models.ad') log = logging.getLogger("subiquity.models.ad")
class AdModel: class AdModel:
""" Models the Active Directory feature """ """Models the Active Directory feature"""
def __init__(self) -> None: def __init__(self) -> None:
self.do_join = False self.do_join = False
self.conn_info: Optional[AdConnectionInfo] = None self.conn_info: Optional[AdConnectionInfo] = None

View File
@ -15,7 +15,7 @@
import logging import logging
log = logging.getLogger('subiquity.models.codecs') log = logging.getLogger("subiquity.models.codecs")
class CodecsModel: class CodecsModel:

View File
@ -15,7 +15,7 @@
import logging import logging
log = logging.getLogger('subiquity.models.drivers') log = logging.getLogger("subiquity.models.drivers")
class DriversModel: class DriversModel:

File diff suppressed because it is too large

View File
@ -17,8 +17,7 @@ import logging
import attr import attr
log = logging.getLogger("subiquity.models.identity")
log = logging.getLogger('subiquity.models.identity')
@attr.s @attr.s
@ -29,8 +28,7 @@ class User(object):
class IdentityModel(object): class IdentityModel(object):
""" Model representing user identity """Model representing user identity"""
"""
def __init__(self): def __init__(self):
self._user = None self._user = None
@ -39,11 +37,11 @@ class IdentityModel(object):
def add_user(self, identity_data): def add_user(self, identity_data):
self._hostname = identity_data.hostname self._hostname = identity_data.hostname
d = {} d = {}
d['realname'] = identity_data.realname d["realname"] = identity_data.realname
d['username'] = identity_data.username d["username"] = identity_data.username
d['password'] = identity_data.crypted_password d["password"] = identity_data.crypted_password
if not d['realname']: if not d["realname"]:
d['realname'] = identity_data.username d["realname"] = identity_data.username
self._user = User(**d) self._user = User(**d)
@property @property

View File
@ -17,7 +17,6 @@ from typing import Optional
class KernelModel: class KernelModel:
# The name of the kernel metapackage that we intend to install. # The name of the kernel metapackage that we intend to install.
metapkg_name: Optional[str] = None metapkg_name: Optional[str] = None
# During the installation, if we detect that a different kernel version is # During the installation, if we detect that a different kernel version is
@ -42,10 +41,10 @@ class KernelModel:
def render(self): def render(self):
if self.curthooks_no_install: if self.curthooks_no_install:
return {'kernel': None} return {"kernel": None}
return { return {
'kernel': { "kernel": {
'package': self.needed_kernel, "package": self.needed_kernel,
}, },
} }

View File
@ -14,17 +14,14 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging import logging
import re
import os import os
import re
import yaml import yaml
from subiquity.common.resources import resource_path from subiquity.common.resources import resource_path
from subiquity.common.serialize import Serializer from subiquity.common.serialize import Serializer
from subiquity.common.types import ( from subiquity.common.types import KeyboardLayout, KeyboardSetting
KeyboardLayout,
KeyboardSetting,
)
log = logging.getLogger("subiquity.models.keyboard") log = logging.getLogger("subiquity.models.keyboard")
@ -44,12 +41,13 @@ BACKSPACE="guess"
class InconsistentMultiLayoutError(ValueError): class InconsistentMultiLayoutError(ValueError):
""" Exception to raise when a multi layout has a different number of """Exception to raise when a multi layout has a different number of
layouts and variants. """ layouts and variants."""
def __init__(self, layouts: str, variants: str) -> None: def __init__(self, layouts: str, variants: str) -> None:
super().__init__( super().__init__(
f'inconsistent multi-layout: layouts="{layouts}"' f'inconsistent multi-layout: layouts="{layouts}"' f' variants="{variants}"'
f' variants="{variants}"') )
def from_config_file(config_file): def from_config_file(config_file):
@ -57,10 +55,10 @@ def from_config_file(config_file):
content = fp.read() content = fp.read()
def optval(opt, default): def optval(opt, default):
match = re.search(r'(?m)^\s*%s=(.*)$' % (opt,), content) match = re.search(r"(?m)^\s*%s=(.*)$" % (opt,), content)
if match: if match:
r = match.group(1).strip('"') r = match.group(1).strip('"')
if r != '': if r != "":
return r return r
return default return default
@ -68,24 +66,22 @@ def from_config_file(config_file):
XKBVARIANT = optval("XKBVARIANT", "") XKBVARIANT = optval("XKBVARIANT", "")
XKBOPTIONS = optval("XKBOPTIONS", "") XKBOPTIONS = optval("XKBOPTIONS", "")
toggle = None toggle = None
for option in XKBOPTIONS.split(','): for option in XKBOPTIONS.split(","):
if option.startswith('grp:'): if option.startswith("grp:"):
toggle = option[4:] toggle = option[4:]
return KeyboardSetting(layout=XKBLAYOUT, variant=XKBVARIANT, toggle=toggle) return KeyboardSetting(layout=XKBLAYOUT, variant=XKBVARIANT, toggle=toggle)
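A worked example of the parsing above, using hypothetical file contents (the layouts, variant, and toggle shown are illustrative only):

import tempfile

sample = """\
XKBLAYOUT="us,ru"
XKBVARIANT=","
XKBOPTIONS="grp:alt_shift_toggle"
"""
with tempfile.NamedTemporaryFile("w", suffix=".keyboard", delete=False) as fp:
    fp.write(sample)
setting = from_config_file(fp.name)
# setting == KeyboardSetting(layout="us,ru", variant=",", toggle="alt_shift_toggle")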
class KeyboardModel: class KeyboardModel:
def __init__(self, root): def __init__(self, root):
self.config_path = os.path.join( self.config_path = os.path.join(root, "etc", "default", "keyboard")
root, 'etc', 'default', 'keyboard')
self.layout_for_lang = self.load_layout_suggestions() self.layout_for_lang = self.load_layout_suggestions()
if os.path.exists(self.config_path): if os.path.exists(self.config_path):
self.default_setting = from_config_file(self.config_path) self.default_setting = from_config_file(self.config_path)
else: else:
self.default_setting = self.layout_for_lang['en_US.UTF-8'] self.default_setting = self.layout_for_lang["en_US.UTF-8"]
self.keyboard_list = KeyboardList() self.keyboard_list = KeyboardList()
self.keyboard_list.load_language('C') self.keyboard_list.load_language("C")
self._setting = None self._setting = None
@property @property
@ -105,41 +101,47 @@ class KeyboardModel:
if len(layout_tokens) != len(variant_tokens): if len(layout_tokens) != len(variant_tokens):
raise InconsistentMultiLayoutError( raise InconsistentMultiLayoutError(
layouts=setting.layout, variants=setting.variant) layouts=setting.layout, variants=setting.variant
)
for layout, variant in zip(layout_tokens, variant_tokens): for layout, variant in zip(layout_tokens, variant_tokens):
kbd_layout = self.keyboard_list.layout_map.get(layout) kbd_layout = self.keyboard_list.layout_map.get(layout)
if kbd_layout is None: if kbd_layout is None:
raise ValueError(f'Unknown keyboard layout "{layout}"') raise ValueError(f'Unknown keyboard layout "{layout}"')
if not any(kbd_variant.code == variant if not any(
for kbd_variant in kbd_layout.variants): kbd_variant.code == variant for kbd_variant in kbd_layout.variants
raise ValueError(f'Unknown keyboard variant "{variant}" ' ):
f'for layout "{layout}"') raise ValueError(
f'Unknown keyboard variant "{variant}" ' f'for layout "{layout}"'
)
def render_config_file(self): def render_config_file(self):
options = "" options = ""
if self.setting.toggle: if self.setting.toggle:
options = "grp:" + self.setting.toggle options = "grp:" + self.setting.toggle
return etc_default_keyboard_template.format( return etc_default_keyboard_template.format(
layout=self.setting.layout, layout=self.setting.layout, variant=self.setting.variant, options=options
variant=self.setting.variant, )
options=options)
def render(self): def render(self):
return { return {
'write_files': { "write_files": {
'etc_default_keyboard': { "etc_default_keyboard": {
'path': 'etc/default/keyboard', "path": "etc/default/keyboard",
'content': self.render_config_file(), "content": self.render_config_file(),
'permissions': 0o644, "permissions": 0o644,
}, },
}, },
'curthooks_commands': { "curthooks_commands": {
# The below command must be run after updating # The below command must be run after updating
# etc/default/keyboard on the target so that the initramfs uses # etc/default/keyboard on the target so that the initramfs uses
# the keyboard mapping selected by the user. See LP #1894009 # the keyboard mapping selected by the user. See LP #1894009
'002-setupcon-save-only': [ "002-setupcon-save-only": [
'curtin', 'in-target', '--', 'setupcon', '--save-only', "curtin",
"in-target",
"--",
"setupcon",
"--save-only",
], ],
}, },
} }
@ -154,7 +156,7 @@ class KeyboardModel:
def load_layout_suggestions(self, path=None): def load_layout_suggestions(self, path=None):
if path is None: if path is None:
path = resource_path('kbds') + '/keyboard-configuration.yaml' path = resource_path("kbds") + "/keyboard-configuration.yaml"
with open(path) as fp: with open(path) as fp:
data = yaml.safe_load(fp) data = yaml.safe_load(fp)
@ -166,25 +168,24 @@ class KeyboardModel:
class KeyboardList: class KeyboardList:
def __init__(self): def __init__(self):
self._kbnames_dir = resource_path('kbds') self._kbnames_dir = resource_path("kbds")
self.serializer = Serializer(compact=True) self.serializer = Serializer(compact=True)
self._clear() self._clear()
def _file_for_lang(self, code): def _file_for_lang(self, code):
return os.path.join(self._kbnames_dir, code + '.jsonl') return os.path.join(self._kbnames_dir, code + ".jsonl")
def _has_language(self, code): def _has_language(self, code):
return os.path.exists(self._file_for_lang(code)) return os.path.exists(self._file_for_lang(code))
def load_language(self, code): def load_language(self, code):
if '.' in code: if "." in code:
code = code.split('.')[0] code = code.split(".")[0]
if not self._has_language(code): if not self._has_language(code):
code = code.split('_')[0] code = code.split("_")[0]
if not self._has_language(code): if not self._has_language(code):
code = 'C' code = "C"
if code == self.current_lang: if code == self.current_lang:
return return

View File
@ -19,11 +19,11 @@ import subprocess
from subiquitycore.utils import arun_command, split_cmd_output from subiquitycore.utils import arun_command, split_cmd_output
log = logging.getLogger('subiquity.models.locale') log = logging.getLogger("subiquity.models.locale")
class LocaleModel: class LocaleModel:
""" Model representing locale selection """Model representing locale selection
XXX Only represents *language* selection for now. XXX Only represents *language* selection for now.
""" """
@ -38,11 +38,7 @@ class LocaleModel:
self.selected_language = code self.selected_language = code
async def localectl_set_locale(self) -> None: async def localectl_set_locale(self) -> None:
cmd = [ cmd = ["localectl", "set-locale", locale.normalize(self.selected_language)]
'localectl',
'set-locale',
locale.normalize(self.selected_language)
]
await arun_command(cmd, check=True) await arun_command(cmd, check=True)
async def try_localectl_set_locale(self) -> None: async def try_localectl_set_locale(self) -> None:
@ -60,18 +56,19 @@ class LocaleModel:
if self.locale_support == "none": if self.locale_support == "none":
return {} return {}
locale = self.selected_language locale = self.selected_language
if '.' not in locale and '_' in locale: if "." not in locale and "_" in locale:
locale += '.UTF-8' locale += ".UTF-8"
return {'locale': locale} return {"locale": locale}
async def target_packages(self): async def target_packages(self):
if self.selected_language is None: if self.selected_language is None:
return [] return []
if self.locale_support != "langpack": if self.locale_support != "langpack":
return [] return []
lang = self.selected_language.split('.')[0] lang = self.selected_language.split(".")[0]
if lang == "C": if lang == "C":
return [] return []
return await split_cmd_output( return await split_cmd_output(
self.chroot_prefix + ['check-language-support', '-l', lang], None) self.chroot_prefix + ["check-language-support", "-l", lang], None
)

View File
@ -72,33 +72,29 @@ primary entry
""" """
import abc import abc
import copy
import contextlib import contextlib
import copy
import logging import logging
from typing import ( from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Union
Any, Callable, Dict, Iterator, List,
Optional, Sequence, Set, Union,
)
from urllib import parse from urllib import parse
import attr import attr
from subiquity.common.types import MirrorSelectionFallback
from curtin.commands.apt_config import ( from curtin.commands.apt_config import (
get_arch_mirrorconfig,
get_mirror,
PORTS_ARCHES, PORTS_ARCHES,
PRIMARY_ARCHES, PRIMARY_ARCHES,
) get_arch_mirrorconfig,
get_mirror,
)
from curtin.config import merge_config from curtin.config import merge_config
from subiquity.common.types import MirrorSelectionFallback
try: try:
from curtin.distro import get_architecture from curtin.distro import get_architecture
except ImportError: except ImportError:
from curtin.util import get_architecture from curtin.util import get_architecture
log = logging.getLogger('subiquity.models.mirror') log = logging.getLogger("subiquity.models.mirror")
DEFAULT_SUPPORTED_ARCHES_URI = "http://archive.ubuntu.com/ubuntu" DEFAULT_SUPPORTED_ARCHES_URI = "http://archive.ubuntu.com/ubuntu"
DEFAULT_PORTS_ARCHES_URI = "http://ports.ubuntu.com/ubuntu-ports" DEFAULT_PORTS_ARCHES_URI = "http://ports.ubuntu.com/ubuntu-ports"
@ -107,7 +103,8 @@ LEGACY_DEFAULT_PRIMARY_SECTION = [
{ {
"arches": PRIMARY_ARCHES, "arches": PRIMARY_ARCHES,
"uri": DEFAULT_SUPPORTED_ARCHES_URI, "uri": DEFAULT_SUPPORTED_ARCHES_URI,
}, { },
{
"arches": ["default"], "arches": ["default"],
"uri": DEFAULT_PORTS_ARCHES_URI, "uri": DEFAULT_PORTS_ARCHES_URI,
}, },
@ -120,9 +117,10 @@ DEFAULT = {
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class BasePrimaryEntry(abc.ABC): class BasePrimaryEntry(abc.ABC):
""" Base class to represent an entry from the 'primary' autoinstall """Base class to represent an entry from the 'primary' autoinstall
section. A BasePrimaryEntry is expected to have a URI and therefore can be section. A BasePrimaryEntry is expected to have a URI and therefore can be
used as a primary candidate. """ used as a primary candidate."""
parent: "MirrorModel" = attr.ib(kw_only=True) parent: "MirrorModel" = attr.ib(kw_only=True)
def stage(self) -> None: def stage(self) -> None:
@ -133,19 +131,20 @@ class BasePrimaryEntry(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def serialize_for_ai(self) -> Any: def serialize_for_ai(self) -> Any:
""" Serialize the entry for autoinstall. """ """Serialize the entry for autoinstall."""
@abc.abstractmethod @abc.abstractmethod
def supports_arch(self, arch: str) -> bool: def supports_arch(self, arch: str) -> bool:
""" Tells whether the mirror claims to support the architecture """Tells whether the mirror claims to support the architecture
specified. """ specified."""
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class PrimaryEntry(BasePrimaryEntry): class PrimaryEntry(BasePrimaryEntry):
""" Represents a single primary mirror candidate; which can be converted """Represents a single primary mirror candidate; which can be converted
to/from an entry of the 'apt->mirror-selection->primary' autoinstall to/from an entry of the 'apt->mirror-selection->primary' autoinstall
section. """ section."""
# Having uri set to None is only valid for a country mirror. # Having uri set to None is only valid for a country mirror.
uri: Optional[str] = None uri: Optional[str] = None
# When arches is None, it is assumed that the mirror is compatible with the # When arches is None, it is assumed that the mirror is compatible with the
@ -183,10 +182,11 @@ class PrimaryEntry(BasePrimaryEntry):
class LegacyPrimaryEntry(BasePrimaryEntry): class LegacyPrimaryEntry(BasePrimaryEntry):
""" Represents a single primary mirror candidate; which can be converted """Represents a single primary mirror candidate; which can be converted
to/from the whole 'apt->primary' autoinstall section (legacy format). to/from the whole 'apt->primary' autoinstall section (legacy format).
The format is defined by curtin, so we make use of curtin to access The format is defined by curtin, so we make use of curtin to access
the elements. """ the elements."""
def __init__(self, config: List[Any], *, parent: "MirrorModel") -> None: def __init__(self, config: List[Any], *, parent: "MirrorModel") -> None:
self.config = config self.config = config
super().__init__(parent=parent) super().__init__(parent=parent)
@ -200,8 +200,8 @@ class LegacyPrimaryEntry(BasePrimaryEntry):
@uri.setter @uri.setter
def uri(self, uri: str) -> None: def uri(self, uri: str) -> None:
config = get_arch_mirrorconfig( config = get_arch_mirrorconfig(
{"primary": self.config}, {"primary": self.config}, "primary", self.parent.architecture
"primary", self.parent.architecture) )
config["uri"] = uri config["uri"] = uri
def mirror_is_default(self) -> bool: def mirror_is_default(self) -> bool:
@ -209,8 +209,7 @@ class LegacyPrimaryEntry(BasePrimaryEntry):
@classmethod @classmethod
def new_from_default(cls, parent: "MirrorModel") -> "LegacyPrimaryEntry": def new_from_default(cls, parent: "MirrorModel") -> "LegacyPrimaryEntry":
return cls(copy.deepcopy(LEGACY_DEFAULT_PRIMARY_SECTION), return cls(copy.deepcopy(LEGACY_DEFAULT_PRIMARY_SECTION), parent=parent)
parent=parent)
def serialize_for_ai(self) -> List[Any]: def serialize_for_ai(self) -> List[Any]:
return self.config return self.config
@ -222,18 +221,18 @@ class LegacyPrimaryEntry(BasePrimaryEntry):
def countrify_uri(uri: str, cc: str) -> str: def countrify_uri(uri: str, cc: str) -> str:
""" Return a URL where the host is prefixed with a country code. """ """Return a URL where the host is prefixed with a country code."""
parsed = parse.urlparse(uri) parsed = parse.urlparse(uri)
new = parsed._replace(netloc=cc + '.' + parsed.netloc) new = parsed._replace(netloc=cc + "." + parsed.netloc)
return parse.urlunparse(new) return parse.urlunparse(new)
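For instance (country code chosen arbitrarily):

assert countrify_uri("http://archive.ubuntu.com/ubuntu", cc="fr") == (
    "http://fr.archive.ubuntu.com/ubuntu"
)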
CandidateFilter = Callable[[BasePrimaryEntry], bool] CandidateFilter = Callable[[BasePrimaryEntry], bool]
def filter_candidates(candidates: List[BasePrimaryEntry], def filter_candidates(
*, filters: Sequence[CandidateFilter]) \ candidates: List[BasePrimaryEntry], *, filters: Sequence[CandidateFilter]
-> Iterator[BasePrimaryEntry]: ) -> Iterator[BasePrimaryEntry]:
candidates_iter = iter(candidates) candidates_iter = iter(candidates)
for filt in filters: for filt in filters:
candidates_iter = filter(filt, candidates_iter) candidates_iter = filter(filt, candidates_iter)
@ -241,21 +240,20 @@ def filter_candidates(candidates: List[BasePrimaryEntry],
class MirrorModel(object): class MirrorModel(object):
def __init__(self): def __init__(self):
self.config = copy.deepcopy(DEFAULT) self.config = copy.deepcopy(DEFAULT)
self.legacy_primary = False self.legacy_primary = False
self.disabled_components: Set[str] = set() self.disabled_components: Set[str] = set()
self.primary_elected: Optional[BasePrimaryEntry] = None self.primary_elected: Optional[BasePrimaryEntry] = None
self.primary_candidates: List[BasePrimaryEntry] = \ self.primary_candidates: List[
self._default_primary_entries() BasePrimaryEntry
] = self._default_primary_entries()
self.primary_staged: Optional[BasePrimaryEntry] = None self.primary_staged: Optional[BasePrimaryEntry] = None
self.architecture = get_architecture() self.architecture = get_architecture()
# Only useful for legacy primary sections # Only useful for legacy primary sections
self.default_mirror = \ self.default_mirror = LegacyPrimaryEntry.new_from_default(parent=self).uri
LegacyPrimaryEntry.new_from_default(parent=self).uri
# What to do if automatic mirror-selection fails. # What to do if automatic mirror-selection fails.
self.fallback = MirrorSelectionFallback.ABORT self.fallback = MirrorSelectionFallback.ABORT
@ -263,14 +261,17 @@ class MirrorModel(object):
def _default_primary_entries(self) -> List[PrimaryEntry]: def _default_primary_entries(self) -> List[PrimaryEntry]:
return [ return [
PrimaryEntry(parent=self, country_mirror=True), PrimaryEntry(parent=self, country_mirror=True),
PrimaryEntry(uri=DEFAULT_SUPPORTED_ARCHES_URI, PrimaryEntry(
arches=PRIMARY_ARCHES, parent=self), uri=DEFAULT_SUPPORTED_ARCHES_URI, arches=PRIMARY_ARCHES, parent=self
PrimaryEntry(uri=DEFAULT_PORTS_ARCHES_URI, ),
arches=PORTS_ARCHES, parent=self), PrimaryEntry(
uri=DEFAULT_PORTS_ARCHES_URI, arches=PORTS_ARCHES, parent=self
),
] ]
def get_default_primary_candidates( def get_default_primary_candidates(
self, legacy: Optional[bool] = None) -> Sequence[BasePrimaryEntry]: self, legacy: Optional[bool] = None
) -> Sequence[BasePrimaryEntry]:
want_legacy = legacy if legacy is not None else self.legacy_primary want_legacy = legacy if legacy is not None else self.legacy_primary
if want_legacy: if want_legacy:
return [LegacyPrimaryEntry.new_from_default(parent=self)] return [LegacyPrimaryEntry.new_from_default(parent=self)]
@ -282,15 +283,15 @@ class MirrorModel(object):
self.disabled_components = set(data.pop("disable_components")) self.disabled_components = set(data.pop("disable_components"))
if "primary" in data and "mirror-selection" in data: if "primary" in data and "mirror-selection" in data:
raise ValueError("apt->primary and apt->mirror-selection are" raise ValueError(
" mutually exclusive.") "apt->primary and apt->mirror-selection are" " mutually exclusive."
)
self.legacy_primary = "primary" in data self.legacy_primary = "primary" in data
primary_candidates = self.get_default_primary_candidates() primary_candidates = self.get_default_primary_candidates()
if "primary" in data: if "primary" in data:
# Legacy sections only support a single candidate # Legacy sections only support a single candidate
primary_candidates = \ primary_candidates = [LegacyPrimaryEntry(data.pop("primary"), parent=self)]
[LegacyPrimaryEntry(data.pop("primary"), parent=self)]
if "mirror-selection" in data: if "mirror-selection" in data:
mirror_selection = data.pop("mirror-selection") mirror_selection = data.pop("mirror-selection")
if "primary" in mirror_selection: if "primary" in mirror_selection:
@ -314,7 +315,8 @@ class MirrorModel(object):
return config return config
def _get_apt_config_using_candidate( def _get_apt_config_using_candidate(
self, candidate: BasePrimaryEntry) -> Dict[str, Any]: self, candidate: BasePrimaryEntry
) -> Dict[str, Any]:
config = self._get_apt_config_common() config = self._get_apt_config_common()
config["primary"] = candidate.config config["primary"] = candidate.config
return config return config
@ -327,8 +329,7 @@ class MirrorModel(object):
assert self.primary_elected is not None assert self.primary_elected is not None
return self._get_apt_config_using_candidate(self.primary_elected) return self._get_apt_config_using_candidate(self.primary_elected)
def get_apt_config( def get_apt_config(self, final: bool, has_network: bool) -> Dict[str, Any]:
self, final: bool, has_network: bool) -> Dict[str, Any]:
if not final: if not final:
return self.get_apt_config_staged() return self.get_apt_config_staged()
if has_network: if has_network:
@ -347,15 +348,16 @@ class MirrorModel(object):
lambda c: c.uri is not None, lambda c: c.uri is not None,
lambda c: c.supports_arch(self.architecture), lambda c: c.supports_arch(self.architecture),
] ]
candidate = next(filter_candidates(self.primary_candidates, candidate = next(
filters=filters)) filter_candidates(self.primary_candidates, filters=filters)
)
return self._get_apt_config_using_candidate(candidate) return self._get_apt_config_using_candidate(candidate)
# Our last resort is to include no primary section. Curtin will use # Our last resort is to include no primary section. Curtin will use
# its own internal values. # its own internal values.
return self._get_apt_config_common() return self._get_apt_config_common()
def set_country(self, cc): def set_country(self, cc):
""" Set the URI of country-mirror candidates. """ """Set the URI of country-mirror candidates."""
for candidate in self.country_mirror_candidates(): for candidate in self.country_mirror_candidates():
if self.legacy_primary: if self.legacy_primary:
candidate.uri = countrify_uri(candidate.uri, cc=cc) candidate.uri = countrify_uri(candidate.uri, cc=cc)
@ -367,8 +369,8 @@ class MirrorModel(object):
candidate.uri = countrify_uri(uri, cc=cc) candidate.uri = countrify_uri(uri, cc=cc)
def disable_components(self, comps, add: bool) -> None: def disable_components(self, comps, add: bool) -> None:
""" Add (or remove) a component (e.g., multiverse) from the list of """Add (or remove) a component (e.g., multiverse) from the list of
disabled components. """ disabled components."""
comps = set(comps) comps = set(comps)
if add: if add:
self.disabled_components |= comps self.disabled_components |= comps
@ -376,19 +378,17 @@ class MirrorModel(object):
self.disabled_components -= comps self.disabled_components -= comps
def create_primary_candidate( def create_primary_candidate(
self, uri: Optional[str], self, uri: Optional[str], country_mirror: bool = False
country_mirror: bool = False) -> BasePrimaryEntry: ) -> BasePrimaryEntry:
if self.legacy_primary: if self.legacy_primary:
entry = LegacyPrimaryEntry.new_from_default(parent=self) entry = LegacyPrimaryEntry.new_from_default(parent=self)
entry.uri = uri entry.uri = uri
return entry return entry
return PrimaryEntry(uri=uri, country_mirror=country_mirror, return PrimaryEntry(uri=uri, country_mirror=country_mirror, parent=self)
parent=self)
def wants_geoip(self) -> bool: def wants_geoip(self) -> bool:
""" Tell whether geoip results would be useful. """ """Tell whether geoip results would be useful."""
return next(self.country_mirror_candidates(), None) is not None return next(self.country_mirror_candidates(), None) is not None
def country_mirror_candidates(self) -> Iterator[BasePrimaryEntry]: def country_mirror_candidates(self) -> Iterator[BasePrimaryEntry]:

View File
@ -20,11 +20,10 @@ from subiquity import cloudinit
from subiquitycore.models.network import NetworkModel as CoreNetworkModel from subiquitycore.models.network import NetworkModel as CoreNetworkModel
from subiquitycore.utils import arun_command from subiquitycore.utils import arun_command
log = logging.getLogger('subiquity.models.network') log = logging.getLogger("subiquity.models.network")
class NetworkModel(CoreNetworkModel): class NetworkModel(CoreNetworkModel):
def __init__(self): def __init__(self):
super().__init__("subiquity") super().__init__("subiquity")
self.override_config = None self.override_config = None
@ -46,59 +45,57 @@ class NetworkModel(CoreNetworkModel):
# If host cloud-init version has no readable combined-cloud-config, # If host cloud-init version has no readable combined-cloud-config,
# default to False. # default to False.
cloud_cfg = cloudinit.get_host_combined_cloud_config() cloud_cfg = cloudinit.get_host_combined_cloud_config()
use_cloudinit_net = cloud_cfg.get('features', {}).get( use_cloudinit_net = cloud_cfg.get("features", {}).get(
'NETPLAN_CONFIG_ROOT_READ_ONLY', False "NETPLAN_CONFIG_ROOT_READ_ONLY", False
) )
if use_cloudinit_net: if use_cloudinit_net:
r = { r = {
'write_files': { "write_files": {
'etc_netplan_installer': { "etc_netplan_installer": {
'path': ( "path": ("etc/cloud/cloud.cfg.d/90-installer-network.cfg"),
'etc/cloud/cloud.cfg.d/90-installer-network.cfg' "content": self.stringify_config(netplan),
), "permissions": "0600",
'content': self.stringify_config(netplan),
'permissions': '0600',
}, },
} }
} }
else: else:
# Separate sensitive wifi config from world-readable config # Separate sensitive wifi config from world-readable config
wifis = netplan['network'].pop('wifis', None) wifis = netplan["network"].pop("wifis", None)
r = { r = {
'write_files': { "write_files": {
# Disable cloud-init networking # Disable cloud-init networking
'no_cloudinit_net': { "no_cloudinit_net": {
'path': ( "path": (
'etc/cloud/cloud.cfg.d/' "etc/cloud/cloud.cfg.d/"
'subiquity-disable-cloudinit-networking.cfg' "subiquity-disable-cloudinit-networking.cfg"
), ),
'content': 'network: {config: disabled}\n', "content": "network: {config: disabled}\n",
}, },
# World-readable netplan without sensitive wifi config # World-readable netplan without sensitive wifi config
'etc_netplan_installer': { "etc_netplan_installer": {
'path': 'etc/netplan/00-installer-config.yaml', "path": "etc/netplan/00-installer-config.yaml",
'content': self.stringify_config(netplan), "content": self.stringify_config(netplan),
}, },
}, },
} }
if wifis is not None: if wifis is not None:
netplan_wifi = { netplan_wifi = {
'network': { "network": {
'version': 2, "version": 2,
'wifis': wifis, "wifis": wifis,
}, },
} }
r['write_files']['etc_netplan_installer_wifi'] = { r["write_files"]["etc_netplan_installer_wifi"] = {
'path': 'etc/netplan/00-installer-config-wifi.yaml', "path": "etc/netplan/00-installer-config-wifi.yaml",
'content': self.stringify_config(netplan_wifi), "content": self.stringify_config(netplan_wifi),
'permissions': '0600', "permissions": "0600",
} }
return r return r
async def target_packages(self): async def target_packages(self):
if self.needs_wpasupplicant: if self.needs_wpasupplicant:
return ['wpasupplicant'] return ["wpasupplicant"]
else: else:
return [] return []

View File
@ -18,7 +18,7 @@ from typing import Any, Dict, List, Optional, Union
import attr import attr
log = logging.getLogger('subiquity.models.oem') log = logging.getLogger("subiquity.models.oem")
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)

View File
@ -15,17 +15,16 @@
import logging import logging
log = logging.getLogger('subiquity.models.proxy') log = logging.getLogger("subiquity.models.proxy")
dropin_template = '''\ dropin_template = """\
[Service] [Service]
Environment="HTTP_PROXY={proxy}" Environment="HTTP_PROXY={proxy}"
Environment="HTTPS_PROXY={proxy}" Environment="HTTPS_PROXY={proxy}"
''' """
class ProxyModel(object): class ProxyModel(object):
def __init__(self): def __init__(self):
self.proxy = "" self.proxy = ""
@ -35,8 +34,8 @@ class ProxyModel(object):
def get_apt_config(self, final: bool, has_network: bool): def get_apt_config(self, final: bool, has_network: bool):
if self.proxy: if self.proxy:
return { return {
'http_proxy': self.proxy, "http_proxy": self.proxy,
'https_proxy': self.proxy, "https_proxy": self.proxy,
} }
else: else:
return {} return {}
@ -44,16 +43,17 @@ class ProxyModel(object):
def render(self): def render(self):
if self.proxy: if self.proxy:
return { return {
'proxy': { "proxy": {
'http_proxy': self.proxy, "http_proxy": self.proxy,
'https_proxy': self.proxy, "https_proxy": self.proxy,
}, },
'write_files': { "write_files": {
'snapd_dropin': { "snapd_dropin": {
'path': ('etc/systemd/system/' "path": (
'snapd.service.d/snap_proxy.conf'), "etc/systemd/system/" "snapd.service.d/snap_proxy.conf"
'content': self.proxy_systemd_dropin(), ),
'permissions': 0o644, "content": self.proxy_systemd_dropin(),
"permissions": 0o644,
}, },
}, },
} }
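Filled in with a hypothetical proxy URL, the systemd drop-in rendered from dropin_template above would read:

print(dropin_template.format(proxy="http://squid.internal:3128"))
# [Service]
# Environment="HTTP_PROXY=http://squid.internal:3128"
# Environment="HTTPS_PROXY=http://squid.internal:3128"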

View File
@ -16,11 +16,7 @@
import datetime import datetime
import logging import logging
from subiquity.common.types import ( from subiquity.common.types import ChannelSnapInfo, SnapInfo
ChannelSnapInfo,
SnapInfo,
)
log = logging.getLogger("subiquity.models.snaplist") log = logging.getLogger("subiquity.models.snaplist")
@ -44,47 +40,49 @@ class SnapListModel:
return s return s
def load_find_data(self, data): def load_find_data(self, data):
for info in data['result']: for info in data["result"]:
self.update(self._snap_for_name(info['name']), info) self.update(self._snap_for_name(info["name"]), info)
def add_partial_snap(self, name): def add_partial_snap(self, name):
self._snaps_for_name(name) self._snaps_for_name(name)
def update(self, snap, data): def update(self, snap, data):
snap.summary = data['summary'] snap.summary = data["summary"]
snap.publisher = data['developer'] snap.publisher = data["developer"]
snap.verified = data['publisher']['validation'] == "verified" snap.verified = data["publisher"]["validation"] == "verified"
snap.starred = data['publisher']['validation'] == "starred" snap.starred = data["publisher"]["validation"] == "starred"
snap.description = data['description'] snap.description = data["description"]
snap.confinement = data['confinement'] snap.confinement = data["confinement"]
snap.license = data['license'] snap.license = data["license"]
self.complete_snaps.add(snap) self.complete_snaps.add(snap)
def load_info_data(self, data): def load_info_data(self, data):
info = data['result'][0] info = data["result"][0]
snap = self._snaps_by_name.get(info['name']) snap = self._snaps_by_name.get(info["name"])
if snap is None: if snap is None:
return return
if snap not in self.complete_snaps: if snap not in self.complete_snaps:
self.update_snap(snap, info) self.update_snap(snap, info)
channel_map = info['channels'] channel_map = info["channels"]
for track in info['tracks']: for track in info["tracks"]:
for risk in risks: for risk in risks:
channel_name = '{}/{}'.format(track, risk) channel_name = "{}/{}".format(track, risk)
if channel_name in channel_map: if channel_name in channel_map:
channel_data = channel_map[channel_name] channel_data = channel_map[channel_name]
if track == "latest": if track == "latest":
channel_name = risk channel_name = risk
snap.channels.append(ChannelSnapInfo( snap.channels.append(
ChannelSnapInfo(
channel_name=channel_name, channel_name=channel_name,
revision=channel_data['revision'], revision=channel_data["revision"],
confinement=channel_data['confinement'], confinement=channel_data["confinement"],
version=channel_data['version'], version=channel_data["version"],
size=channel_data['size'], size=channel_data["size"],
released_at=datetime.datetime.strptime( released_at=datetime.datetime.strptime(
channel_data['released-at'], channel_data["released-at"], "%Y-%m-%dT%H:%M:%S.%fZ"
'%Y-%m-%dT%H:%M:%S.%fZ'), ),
)) )
)
return snap return snap
def get_snap_list(self): def get_snap_list(self):
@ -100,9 +98,9 @@ class SnapListModel:
return {} return {}
cmds = [] cmds = []
for selection in self.selections: for selection in self.selections:
cmd = ['snap', 'install', '--channel=' + selection.channel] cmd = ["snap", "install", "--channel=" + selection.channel]
if selection.classic: if selection.classic:
cmd.append('--classic') cmd.append("--classic")
cmd.append(selection.name) cmd.append(selection.name)
cmds.append(' '.join(cmd)) cmds.append(" ".join(cmd))
return {'snap': {'commands': cmds}} return {"snap": {"commands": cmds}}
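As a worked sketch of the command assembly above, a hypothetical selection of the "code" snap from its stable channel with classic confinement would yield:

cmd = ["snap", "install", "--channel=stable", "--classic", "code"]
assert " ".join(cmd) == "snap install --channel=stable --classic code"
# render() would then return {"snap": {"commands": [" ".join(cmd)]}}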

View File
@ -16,13 +16,13 @@
import logging import logging
import os import os
import typing import typing
import yaml
import attr import attr
import yaml
from subiquity.common.serialize import Serializer from subiquity.common.serialize import Serializer
log = logging.getLogger('subiquity.models.source') log = logging.getLogger("subiquity.models.source")
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
@ -49,33 +49,32 @@ class CatalogEntry:
def __attrs_post_init__(self): def __attrs_post_init__(self):
if not self.variations: if not self.variations:
self.variations['default'] = CatalogEntryVariation( self.variations["default"] = CatalogEntryVariation(
path=self.path, path=self.path,
size=self.size, size=self.size,
snapd_system_label=self.snapd_system_label) snapd_system_label=self.snapd_system_label,
)
legacy_server_entry = CatalogEntry( legacy_server_entry = CatalogEntry(
variant='server', variant="server",
id='synthesized', id="synthesized",
name={'en': 'Ubuntu Server'}, name={"en": "Ubuntu Server"},
description={'en': 'the default'}, description={"en": "the default"},
path='/media/filesystem', path="/media/filesystem",
type='cp', type="cp",
default=True, default=True,
size=2 << 30, size=2 << 30,
locale_support="locale-only", locale_support="locale-only",
variations={ variations={
'default': CatalogEntryVariation( "default": CatalogEntryVariation(path="/media/filesystem", size=2 << 30),
path='/media/filesystem', },
size=2 << 30), )
})
class SourceModel: class SourceModel:
def __init__(self): def __init__(self):
self._dir = '/cdrom/casper' self._dir = "/cdrom/casper"
self.current = legacy_server_entry self.current = legacy_server_entry
self.sources = [self.current] self.sources = [self.current]
self.lang = None self.lang = None
@ -86,8 +85,8 @@ class SourceModel:
self.sources = [] self.sources = []
self.current = None self.current = None
self.sources = Serializer(ignore_unknown_fields=True).deserialize( self.sources = Serializer(ignore_unknown_fields=True).deserialize(
typing.List[CatalogEntry], typing.List[CatalogEntry], yaml.safe_load(fp)
yaml.safe_load(fp)) )
for entry in self.sources: for entry in self.sources:
if entry.default: if entry.default:
self.current = entry self.current = entry
@ -96,7 +95,7 @@ class SourceModel:
self.current = self.sources[0] self.current = self.sources[0]
def get_matching_source(self, id_: str) -> CatalogEntry: def get_matching_source(self, id_: str) -> CatalogEntry:
""" Return a source object that has the ID requested. """ """Return a source object that has the ID requested."""
for source in self.sources: for source in self.sources:
if source.id == id_: if source.id == id_:
return source return source
@ -113,10 +112,10 @@ class SourceModel:
if self.lang in self.current.preinstalled_langs: if self.lang in self.current.preinstalled_langs:
suffix = self.lang suffix = self.lang
else: else:
suffix = 'no-languages' suffix = "no-languages"
path = base + '.' + suffix + ext path = base + "." + suffix + ext
scheme = self.current.type scheme = self.current.type
return f'{scheme}://{path}' return f"{scheme}://{path}"
def render(self): def render(self):
return {} return {}

View File
@ -20,7 +20,6 @@ log = logging.getLogger("subiquity.models.ssh")
class SSHModel: class SSHModel:
def __init__(self): def __init__(self):
self.install_server = False self.install_server = False
self.authorized_keys: List[str] = [] self.authorized_keys: List[str] = []
@ -28,10 +27,10 @@ class SSHModel:
# Although the generated config just contains the key above, # Although the generated config just contains the key above,
# we store the imported id so that we can re-fill the form if # we store the imported id so that we can re-fill the form if
# we go back to it. # we go back to it.
self.ssh_import_id = '' self.ssh_import_id = ""
async def target_packages(self): async def target_packages(self):
if self.install_server: if self.install_server:
return ['openssh-server'] return ["openssh-server"]
else: else:
return [] return []

View File
@ -14,37 +14,35 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio import asyncio
from collections import OrderedDict
import functools import functools
import json import json
import logging import logging
import os import os
from typing import Any, Dict, Set
import uuid import uuid
import yaml from collections import OrderedDict
from typing import Any, Dict, Set
import yaml
from cloudinit.config.schema import ( from cloudinit.config.schema import (
SchemaValidationError, SchemaValidationError,
get_schema, get_schema,
validate_cloudconfig_schema validate_cloudconfig_schema,
) )
try: try:
from cloudinit.config.schema import SchemaProblem from cloudinit.config.schema import SchemaProblem
except ImportError: except ImportError:
def SchemaProblem(x, y): return (x, y) # TODO(drop on cloud-init 22.3 SRU)
def SchemaProblem(x, y):
return (x, y) # TODO(drop on cloud-init 22.3 SRU)
from curtin.config import merge_config from curtin.config import merge_config
from subiquitycore.file_util import (
generate_timestamped_header,
write_file,
)
from subiquitycore.lsb_release import lsb_release
from subiquity.common.resources import get_users_and_groups from subiquity.common.resources import get_users_and_groups
from subiquity.server.types import InstallerChannels from subiquity.server.types import InstallerChannels
from subiquitycore.file_util import generate_timestamped_header, write_file
from subiquitycore.lsb_release import lsb_release
from .ad import AdModel from .ad import AdModel
from .codecs import CodecsModel from .codecs import CodecsModel
@ -66,8 +64,7 @@ from .timezone import TimeZoneModel
from .ubuntu_pro import UbuntuProModel from .ubuntu_pro import UbuntuProModel
from .updates import UpdatesModel from .updates import UpdatesModel
log = logging.getLogger("subiquity.models.subiquity")
log = logging.getLogger('subiquity.models.subiquity')
def merge_cloud_init_config(target, source): def merge_cloud_init_config(target, source):
@ -136,7 +133,7 @@ CLOUDINIT_DISABLE_AFTER_INSTALL = {
"Disabled by Ubuntu live installer after first boot.\n" "Disabled by Ubuntu live installer after first boot.\n"
"To re-enable cloud-init on this image run:\n" "To re-enable cloud-init on this image run:\n"
" sudo cloud-init clean --machine-id\n" " sudo cloud-init clean --machine-id\n"
) ),
} }
@ -156,29 +153,26 @@ class ModelNames:
class DebconfSelectionsModel: class DebconfSelectionsModel:
def __init__(self): def __init__(self):
self.selections = '' self.selections = ""
def render(self): def render(self):
return {} return {}
def get_apt_config( def get_apt_config(self, final: bool, has_network: bool) -> Dict[str, Any]:
self, final: bool, has_network: bool) -> Dict[str, Any]: return {"debconf_selections": {"subiquity": self.selections}}
return {'debconf_selections': {'subiquity': self.selections}}
class SubiquityModel: class SubiquityModel:
"""The overall model for subiquity.""" """The overall model for subiquity."""
target = '/target' target = "/target"
chroot_prefix = ['chroot', target] chroot_prefix = ["chroot", target]
def __init__(self, root, hub, install_model_names, def __init__(self, root, hub, install_model_names, postinstall_model_names):
postinstall_model_names):
self.root = root self.root = root
self.hub = hub self.hub = hub
if root != '/': if root != "/":
self.target = root self.target = root
self.chroot_prefix = [] self.chroot_prefix = []
@ -212,8 +206,7 @@ class SubiquityModel:
self._install_model_names = install_model_names self._install_model_names = install_model_names
self._postinstall_model_names = postinstall_model_names self._postinstall_model_names = postinstall_model_names
self._cur_install_model_names = install_model_names.default_names self._cur_install_model_names = install_model_names.default_names
self._cur_postinstall_model_names = \ self._cur_postinstall_model_names = postinstall_model_names.default_names
postinstall_model_names.default_names
self._install_event = asyncio.Event() self._install_event = asyncio.Event()
self._postinstall_event = asyncio.Event() self._postinstall_event = asyncio.Event()
all_names = set() all_names = set()
@ -222,15 +215,17 @@ class SubiquityModel:
for name in all_names: for name in all_names:
hub.subscribe( hub.subscribe(
(InstallerChannels.CONFIGURED, name), (InstallerChannels.CONFIGURED, name),
functools.partial(self._configured, name)) functools.partial(self._configured, name),
)
def set_source_variant(self, variant): def set_source_variant(self, variant):
self._cur_install_model_names = \ self._cur_install_model_names = self._install_model_names.for_variant(variant)
self._install_model_names.for_variant(variant) self._cur_postinstall_model_names = self._postinstall_model_names.for_variant(
self._cur_postinstall_model_names = \ variant
self._postinstall_model_names.for_variant(variant) )
unconfigured_install_model_names = \ unconfigured_install_model_names = (
self._cur_install_model_names - self._configured_names self._cur_install_model_names - self._configured_names
)
if unconfigured_install_model_names: if unconfigured_install_model_names:
if self._install_event.is_set(): if self._install_event.is_set():
self._install_event = asyncio.Event() self._install_event = asyncio.Event()
@ -238,8 +233,9 @@ class SubiquityModel:
self._confirmation_task.cancel() self._confirmation_task.cancel()
else: else:
self._install_event.set() self._install_event.set()
unconfigured_postinstall_model_names = \ unconfigured_postinstall_model_names = (
self._cur_postinstall_model_names - self._configured_names self._cur_postinstall_model_names - self._configured_names
)
if unconfigured_postinstall_model_names: if unconfigured_postinstall_model_names:
if self._postinstall_event.is_set(): if self._postinstall_event.is_set():
self._postinstall_event = asyncio.Event() self._postinstall_event = asyncio.Event()
@ -247,27 +243,34 @@ class SubiquityModel:
self._postinstall_event.set() self._postinstall_event.set()
def _configured(self, model_name): def _configured(self, model_name):
""" Add the model to the set of models that have been configured. If """Add the model to the set of models that have been configured. If
there is no more model to configure in the relevant section(s) (i.e., there is no more model to configure in the relevant section(s) (i.e.,
INSTALL or POSTINSTALL), we trigger the associated event(s). """ INSTALL or POSTINSTALL), we trigger the associated event(s)."""
def log_and_trigger(stage: str, names: Set[str],
event: asyncio.Event) -> None: def log_and_trigger(stage: str, names: Set[str], event: asyncio.Event) -> None:
unconfigured = names - self._configured_names unconfigured = names - self._configured_names
log.debug( log.debug(
"model %s for %s stage is configured, to go %s", "model %s for %s stage is configured, to go %s",
model_name, stage, unconfigured) model_name,
stage,
unconfigured,
)
if not unconfigured: if not unconfigured:
event.set() event.set()
self._configured_names.add(model_name) self._configured_names.add(model_name)
if model_name in self._cur_install_model_names: if model_name in self._cur_install_model_names:
log_and_trigger(stage="install", log_and_trigger(
stage="install",
names=self._cur_install_model_names, names=self._cur_install_model_names,
event=self._install_event) event=self._install_event,
)
if model_name in self._cur_postinstall_model_names: if model_name in self._cur_postinstall_model_names:
log_and_trigger(stage="postinstall", log_and_trigger(
stage="postinstall",
names=self._cur_postinstall_model_names, names=self._cur_postinstall_model_names,
event=self._postinstall_event) event=self._postinstall_event,
)
async def wait_install(self): async def wait_install(self):
if len(self._cur_install_model_names) == 0: if len(self._cur_install_model_names) == 0:
@ -281,8 +284,7 @@ class SubiquityModel:
async def wait_confirmation(self): async def wait_confirmation(self):
if self._confirmation_task is None: if self._confirmation_task is None:
self._confirmation_task = asyncio.create_task( self._confirmation_task = asyncio.create_task(self._confirmation.wait())
self._confirmation.wait())
try: try:
await self._confirmation_task await self._confirmation_task
except asyncio.CancelledError: except asyncio.CancelledError:
@ -293,8 +295,10 @@ class SubiquityModel:
self._confirmation_task = None self._confirmation_task = None
def is_postinstall_only(self, model_name): def is_postinstall_only(self, model_name):
return model_name in self._cur_postinstall_model_names and \ return (
model_name not in self._cur_install_model_names model_name in self._cur_postinstall_model_names
and model_name not in self._cur_install_model_names
)
async def confirm(self): async def confirm(self):
self._confirmation.set() self._confirmation.set()
@ -326,9 +330,9 @@ class SubiquityModel:
if warnings: if warnings:
log.warning( log.warning(
"The cloud-init configuration for %s contains" "The cloud-init configuration for %s contains"
" deprecated values:\n%s", data_source, "\n".join( " deprecated values:\n%s",
warnings data_source,
) "\n".join(warnings),
) )
if e.schema_errors: if e.schema_errors:
if data_source == "autoinstall.user-data": if data_source == "autoinstall.user-data":
@ -342,50 +346,46 @@ class SubiquityModel:
def _cloud_init_config(self): def _cloud_init_config(self):
config = { config = {
'growpart': { "growpart": {
'mode': 'off', "mode": "off",
}, },
'resize_rootfs': False, "resize_rootfs": False,
} }
if self.identity.hostname is not None: if self.identity.hostname is not None:
config['preserve_hostname'] = True config["preserve_hostname"] = True
user = self.identity.user user = self.identity.user
if user: if user:
groups = get_users_and_groups(self.chroot_prefix) groups = get_users_and_groups(self.chroot_prefix)
user_info = { user_info = {
'name': user.username, "name": user.username,
'gecos': user.realname, "gecos": user.realname,
'passwd': user.password, "passwd": user.password,
'shell': '/bin/bash', "shell": "/bin/bash",
'groups': ','.join(sorted(groups)), "groups": ",".join(sorted(groups)),
'lock_passwd': False, "lock_passwd": False,
} }
if self.ssh.authorized_keys: if self.ssh.authorized_keys:
user_info['ssh_authorized_keys'] = self.ssh.authorized_keys user_info["ssh_authorized_keys"] = self.ssh.authorized_keys
config['users'] = [user_info] config["users"] = [user_info]
else: else:
if self.ssh.authorized_keys: if self.ssh.authorized_keys:
config['ssh_authorized_keys'] = self.ssh.authorized_keys config["ssh_authorized_keys"] = self.ssh.authorized_keys
if self.ssh.install_server: if self.ssh.install_server:
config['ssh_pwauth'] = self.ssh.pwauth config["ssh_pwauth"] = self.ssh.pwauth
for model_name in self._postinstall_model_names.all(): for model_name in self._postinstall_model_names.all():
model = getattr(self, model_name) model = getattr(self, model_name)
if getattr(model, 'make_cloudconfig', None): if getattr(model, "make_cloudconfig", None):
merge_config(config, model.make_cloudconfig()) merge_config(config, model.make_cloudconfig())
merge_cloud_init_config(config, self.userdata) merge_cloud_init_config(config, self.userdata)
if lsb_release()['release'] not in ("20.04", "22.04"): if lsb_release()["release"] not in ("20.04", "22.04"):
config.setdefault("write_files", []).append( config.setdefault("write_files", []).append(CLOUDINIT_DISABLE_AFTER_INSTALL)
CLOUDINIT_DISABLE_AFTER_INSTALL self.validate_cloudconfig_schema(data=config, data_source="system install")
)
self.validate_cloudconfig_schema(
data=config, data_source="system install"
)
return config return config
async def target_packages(self): async def target_packages(self):
packages = list(self.packages) packages = list(self.packages)
for model_name in self._postinstall_model_names.all(): for model_name in self._postinstall_model_names.all():
meth = getattr(getattr(self, model_name), 'target_packages', None) meth = getattr(getattr(self, model_name), "target_packages", None)
if meth is not None: if meth is not None:
packages.extend(await meth()) packages.extend(await meth())
return packages return packages
@ -395,49 +395,55 @@ class SubiquityModel:
# first boot of the target, it reconfigures datasource_list to none # first boot of the target, it reconfigures datasource_list to none
# for subsequent boots. # for subsequent boots.
# (mwhudson does not entirely know what the above means!) # (mwhudson does not entirely know what the above means!)
userdata = '#cloud-config\n' + yaml.dump(self._cloud_init_config()) userdata = "#cloud-config\n" + yaml.dump(self._cloud_init_config())
metadata = {'instance-id': str(uuid.uuid4())} metadata = {"instance-id": str(uuid.uuid4())}
config = yaml.dump({ config = yaml.dump(
'datasource_list': ["None"], {
'datasource': { "datasource_list": ["None"],
"datasource": {
"None": { "None": {
'userdata_raw': userdata, "userdata_raw": userdata,
'metadata': metadata, "metadata": metadata,
}, },
}, },
}) }
)
files = [ files = [
('etc/cloud/cloud.cfg.d/99-installer.cfg', config, 0o600), ("etc/cloud/cloud.cfg.d/99-installer.cfg", config, 0o600),
('etc/cloud/ds-identify.cfg', 'policy: enabled\n', 0o644), ("etc/cloud/ds-identify.cfg", "policy: enabled\n", 0o644),
] ]
# Add cloud-init clean hooks to support golden-image creation. # Add cloud-init clean hooks to support golden-image creation.
cfg_files = ["/" + path for (path, _content, _cmode) in files] cfg_files = ["/" + path for (path, _content, _cmode) in files]
cfg_files.extend(self.network.rendered_config_paths()) cfg_files.extend(self.network.rendered_config_paths())
if lsb_release()['release'] not in ("20.04", "22.04"): if lsb_release()["release"] not in ("20.04", "22.04"):
cfg_files.append("/etc/cloud/cloud-init.disabled") cfg_files.append("/etc/cloud/cloud-init.disabled")
if self.identity.hostname is not None: if self.identity.hostname is not None:
hostname = self.identity.hostname.strip() hostname = self.identity.hostname.strip()
files.extend([ files.extend(
('etc/hostname', hostname + "\n", 0o644), [
('etc/hosts', HOSTS_CONTENT.format(hostname=hostname), 0o644), ("etc/hostname", hostname + "\n", 0o644),
]) ("etc/hosts", HOSTS_CONTENT.format(hostname=hostname), 0o644),
]
)
files.append(( files.append(
'etc/cloud/clean.d/99-installer', (
"etc/cloud/clean.d/99-installer",
CLOUDINIT_CLEAN_FILE_TMPL.format( CLOUDINIT_CLEAN_FILE_TMPL.format(
header=generate_timestamped_header(), header=generate_timestamped_header(),
cfg_files=json.dumps(sorted(cfg_files)) cfg_files=json.dumps(sorted(cfg_files)),
), ),
0o755 0o755,
)) )
)
return files return files
def configure_cloud_init(self): def configure_cloud_init(self):
if self.target is None: if self.target is None:
# i.e. reset_partition_only # i.e. reset_partition_only
return return
if self.source.current.variant == 'core': if self.source.current.variant == "core":
# can probably be supported but requires changes # can probably be supported but requires changes
return return
for path, content, cmode in self._cloud_init_files(): for path, content, cmode in self._cloud_init_files():
@ -446,59 +452,57 @@ class SubiquityModel:
write_file(path, content, cmode=cmode) write_file(path, content, cmode=cmode)
def _media_info(self): def _media_info(self):
if os.path.exists('/cdrom/.disk/info'): if os.path.exists("/cdrom/.disk/info"):
with open('/cdrom/.disk/info') as fp: with open("/cdrom/.disk/info") as fp:
return fp.read() return fp.read()
else: else:
return "media-info" return "media-info"
def _machine_id(self): def _machine_id(self):
with open('/etc/machine-id') as fp: with open("/etc/machine-id") as fp:
return fp.read() return fp.read()
def render(self): def render(self):
config = { config = {
'grub': { "grub": {
'terminal': 'unmodified', "terminal": "unmodified",
'probe_additional_os': True, "probe_additional_os": True,
'reorder_uefi': False, "reorder_uefi": False,
}, },
"install": {
'install': { "unmount": "disabled",
'unmount': 'disabled', "save_install_config": False,
'save_install_config': False, "save_install_log": False,
'save_install_log': False,
}, },
"pollinate": {
'pollinate': { "user_agent": {
'user_agent': { "subiquity": "%s_%s"
'subiquity': "%s_%s" % (os.environ.get("SNAP_VERSION", % (
'dry-run'), os.environ.get("SNAP_VERSION", "dry-run"),
os.environ.get("SNAP_REVISION", os.environ.get("SNAP_REVISION", "dry-run"),
'dry-run')), ),
}, },
}, },
"write_files": {
'write_files': { "etc_machine_id": {
'etc_machine_id': { "path": "etc/machine-id",
'path': 'etc/machine-id', "content": self._machine_id(),
'content': self._machine_id(), "permissions": 0o444,
'permissions': 0o444,
}, },
'media_info': { "media_info": {
'path': 'var/log/installer/media-info', "path": "var/log/installer/media-info",
'content': self._media_info(), "content": self._media_info(),
'permissions': 0o644, "permissions": 0o644,
}, },
}, },
} }
if os.path.exists('/run/casper-md5check.json'): if os.path.exists("/run/casper-md5check.json"):
with open('/run/casper-md5check.json') as fp: with open("/run/casper-md5check.json") as fp:
config['write_files']['md5check'] = { config["write_files"]["md5check"] = {
'path': 'var/log/installer/casper-md5check.json', "path": "var/log/installer/casper-md5check.json",
'content': fp.read(), "content": fp.read(),
'permissions': 0o644, "permissions": 0o644,
} }
for model_name in self._install_model_names.all(): for model_name in self._install_model_names.all():

File diff suppressed because it is too large


@ -13,15 +13,10 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from subiquitycore.tests.parameterized import parameterized
from subiquitycore.tests import SubiTestCase
from subiquity.common.types import KeyboardSetting from subiquity.common.types import KeyboardSetting
from subiquity.models.keyboard import ( from subiquity.models.keyboard import InconsistentMultiLayoutError, KeyboardModel
InconsistentMultiLayoutError, from subiquitycore.tests import SubiTestCase
KeyboardModel, from subiquitycore.tests.parameterized import parameterized
)
class TestKeyboardModel(SubiTestCase): class TestKeyboardModel(SubiTestCase):
@ -30,9 +25,9 @@ class TestKeyboardModel(SubiTestCase):
def testDefaultUS(self): def testDefaultUS(self):
self.assertIsNone(self.model._setting) self.assertIsNone(self.model._setting)
self.assertEqual('us', self.model.setting.layout) self.assertEqual("us", self.model.setting.layout)
@parameterized.expand((['zz'], ['en'])) @parameterized.expand((["zz"], ["en"]))
def testSetToInvalidLayout(self, layout): def testSetToInvalidLayout(self, layout):
initial = self.model.setting initial = self.model.setting
val = KeyboardSetting(layout=layout) val = KeyboardSetting(layout=layout)
@ -40,39 +35,41 @@ class TestKeyboardModel(SubiTestCase):
self.model.setting = val self.model.setting = val
self.assertEqual(initial, self.model.setting) self.assertEqual(initial, self.model.setting)
@parameterized.expand((['zz'])) @parameterized.expand((["zz"]))
def testSetToInvalidVariant(self, variant): def testSetToInvalidVariant(self, variant):
initial = self.model.setting initial = self.model.setting
val = KeyboardSetting(layout='us', variant=variant) val = KeyboardSetting(layout="us", variant=variant)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.model.setting = val self.model.setting = val
self.assertEqual(initial, self.model.setting) self.assertEqual(initial, self.model.setting)
def testMultiLayout(self): def testMultiLayout(self):
val = KeyboardSetting(layout='us,ara', variant=',') val = KeyboardSetting(layout="us,ara", variant=",")
self.model.setting = val self.model.setting = val
self.assertEqual(self.model.setting, val) self.assertEqual(self.model.setting, val)
def testInconsistentMultiLayout(self): def testInconsistentMultiLayout(self):
initial = self.model.setting initial = self.model.setting
val = KeyboardSetting(layout='us,ara', variant='') val = KeyboardSetting(layout="us,ara", variant="")
with self.assertRaises(InconsistentMultiLayoutError): with self.assertRaises(InconsistentMultiLayoutError):
self.model.setting = val self.model.setting = val
self.assertEqual(self.model.setting, initial) self.assertEqual(self.model.setting, initial)
def testInvalidMultiLayout(self): def testInvalidMultiLayout(self):
initial = self.model.setting initial = self.model.setting
val = KeyboardSetting(layout='us,ara', variant='zz,') val = KeyboardSetting(layout="us,ara", variant="zz,")
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.model.setting = val self.model.setting = val
self.assertEqual(self.model.setting, initial) self.assertEqual(self.model.setting, initial)
@parameterized.expand([ @parameterized.expand(
['ast_ES.UTF-8', 'es', 'ast'], [
['de_DE.UTF-8', 'de', ''], ["ast_ES.UTF-8", "es", "ast"],
['fr_FR.UTF-8', 'fr', 'latin9'], ["de_DE.UTF-8", "de", ""],
['oc', 'us', ''], ["fr_FR.UTF-8", "fr", "latin9"],
]) ["oc", "us", ""],
]
)
def testSettingForLang(self, lang, layout, variant): def testSettingForLang(self, lang, layout, variant):
val = self.model.setting_for_lang(lang) val = self.model.setting_for_lang(lang)
self.assertEqual(layout, val.layout) self.assertEqual(layout, val.layout)
@ -81,20 +78,22 @@ class TestKeyboardModel(SubiTestCase):
def testAllLangsHaveKeyboardSuggestion(self): def testAllLangsHaveKeyboardSuggestion(self):
# every language in the list needs a suggestion, # every language in the list needs a suggestion,
# even if only the default # even if only the default
with open('languagelist') as fp: with open("languagelist") as fp:
for line in fp.readlines(): for line in fp.readlines():
tokens = line.split(':') tokens = line.split(":")
locale = tokens[1] locale = tokens[1]
self.assertIn(locale, self.model.layout_for_lang.keys()) self.assertIn(locale, self.model.layout_for_lang.keys())
def testLoadSuggestions(self): def testLoadSuggestions(self):
data = self.tmp_path('kbd.yaml') data = self.tmp_path("kbd.yaml")
with open(data, 'w') as fp: with open(data, "w") as fp:
fp.write(''' fp.write(
"""
aa_BB.UTF-8: aa_BB.UTF-8:
layout: aa layout: aa
variant: cc variant: cc
''') """
)
actual = self.model.load_layout_suggestions(data) actual = self.model.load_layout_suggestions(data)
expected = {'aa_BB.UTF-8': KeyboardSetting(layout='aa', variant='cc')} expected = {"aa_BB.UTF-8": KeyboardSetting(layout="aa", variant="cc")}
self.assertEqual(expected, actual) self.assertEqual(expected, actual)


@ -41,16 +41,16 @@ class TestLocaleModel(unittest.IsolatedAsyncioTestCase):
self.model.selected_language = "fr_FR" self.model.selected_language = "fr_FR"
with mock.patch("subiquity.models.locale.arun_command") as arun_cmd: with mock.patch("subiquity.models.locale.arun_command") as arun_cmd:
# Currently, the default for fr_FR is fr_FR.ISO8859-1 # Currently, the default for fr_FR is fr_FR.ISO8859-1
with mock.patch("subiquity.models.locale.locale.normalize", with mock.patch(
return_value="fr_FR.UTF-8"): "subiquity.models.locale.locale.normalize", return_value="fr_FR.UTF-8"
):
await self.model.localectl_set_locale() await self.model.localectl_set_locale()
arun_cmd.assert_called_once_with(expected_cmd, check=True) arun_cmd.assert_called_once_with(expected_cmd, check=True)
async def test_try_localectl_set_locale(self): async def test_try_localectl_set_locale(self):
self.model.selected_language = "fr_FR.UTF-8" self.model.selected_language = "fr_FR.UTF-8"
exc = subprocess.CalledProcessError(returncode=1, cmd=["localedef"]) exc = subprocess.CalledProcessError(returncode=1, cmd=["localedef"])
with mock.patch("subiquity.models.locale.arun_command", with mock.patch("subiquity.models.locale.arun_command", side_effect=exc):
side_effect=exc):
await self.model.try_localectl_set_locale() await self.model.try_localectl_set_locale()
with mock.patch("subiquity.models.locale.arun_command"): with mock.patch("subiquity.models.locale.arun_command"):
await self.model.try_localectl_set_locale() await self.model.try_localectl_set_locale()


@ -18,41 +18,37 @@ import unittest
from unittest import mock from unittest import mock
from subiquity.models.mirror import ( from subiquity.models.mirror import (
countrify_uri,
LEGACY_DEFAULT_PRIMARY_SECTION, LEGACY_DEFAULT_PRIMARY_SECTION,
LegacyPrimaryEntry,
MirrorModel, MirrorModel,
MirrorSelectionFallback, MirrorSelectionFallback,
LegacyPrimaryEntry,
PrimaryEntry, PrimaryEntry,
) countrify_uri,
)
class TestCountrifyUrl(unittest.TestCase): class TestCountrifyUrl(unittest.TestCase):
def test_official_archive(self): def test_official_archive(self):
self.assertEqual( self.assertEqual(
countrify_uri( countrify_uri("http://archive.ubuntu.com/ubuntu", cc="fr"),
"http://archive.ubuntu.com/ubuntu", "http://fr.archive.ubuntu.com/ubuntu",
cc="fr"), )
"http://fr.archive.ubuntu.com/ubuntu")
self.assertEqual( self.assertEqual(
countrify_uri( countrify_uri("http://archive.ubuntu.com/ubuntu", cc="us"),
"http://archive.ubuntu.com/ubuntu", "http://us.archive.ubuntu.com/ubuntu",
cc="us"), )
"http://us.archive.ubuntu.com/ubuntu")
def test_ports_archive(self): def test_ports_archive(self):
self.assertEqual( self.assertEqual(
countrify_uri( countrify_uri("http://ports.ubuntu.com/ubuntu-ports", cc="fr"),
"http://ports.ubuntu.com/ubuntu-ports", "http://fr.ports.ubuntu.com/ubuntu-ports",
cc="fr"), )
"http://fr.ports.ubuntu.com/ubuntu-ports")
self.assertEqual( self.assertEqual(
countrify_uri( countrify_uri("http://ports.ubuntu.com/ubuntu-ports", cc="us"),
"http://ports.ubuntu.com/ubuntu-ports", "http://us.ports.ubuntu.com/ubuntu-ports",
cc="us"), )
"http://us.ports.ubuntu.com/ubuntu-ports")
class TestPrimaryEntry(unittest.TestCase): class TestPrimaryEntry(unittest.TestCase):
@ -78,21 +74,20 @@ class TestPrimaryEntry(unittest.TestCase):
model = MirrorModel() model = MirrorModel()
entry = PrimaryEntry.from_config("country-mirror", parent=model) entry = PrimaryEntry.from_config("country-mirror", parent=model)
self.assertEqual(entry, PrimaryEntry(parent=model, self.assertEqual(entry, PrimaryEntry(parent=model, country_mirror=True))
country_mirror=True))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
entry = PrimaryEntry.from_config({}, parent=model) entry = PrimaryEntry.from_config({}, parent=model)
entry = PrimaryEntry.from_config( entry = PrimaryEntry.from_config({"uri": "http://mirror"}, parent=model)
{"uri": "http://mirror"}, parent=model) self.assertEqual(entry, PrimaryEntry(uri="http://mirror", parent=model))
self.assertEqual(entry, PrimaryEntry(
uri="http://mirror", parent=model))
entry = PrimaryEntry.from_config( entry = PrimaryEntry.from_config(
{"uri": "http://mirror", "arches": ["amd64"]}, parent=model) {"uri": "http://mirror", "arches": ["amd64"]}, parent=model
self.assertEqual(entry, PrimaryEntry( )
uri="http://mirror", arches=["amd64"], parent=model)) self.assertEqual(
entry, PrimaryEntry(uri="http://mirror", arches=["amd64"], parent=model)
)
class TestLegacyPrimaryEntry(unittest.TestCase): class TestLegacyPrimaryEntry(unittest.TestCase):
@ -111,8 +106,8 @@ class TestLegacyPrimaryEntry(unittest.TestCase):
def test_get_uri(self): def test_get_uri(self):
self.model.architecture = "amd64" self.model.architecture = "amd64"
primary = LegacyPrimaryEntry( primary = LegacyPrimaryEntry(
[{"uri": "http://myurl", "arches": "amd64"}], [{"uri": "http://myurl", "arches": "amd64"}], parent=self.model
parent=self.model) )
self.assertEqual(primary.uri, "http://myurl") self.assertEqual(primary.uri, "http://myurl")
def test_set_uri(self): def test_set_uri(self):
@ -130,8 +125,9 @@ class TestMirrorModel(unittest.TestCase):
self.model_legacy = MirrorModel() self.model_legacy = MirrorModel()
self.model_legacy.legacy_primary = True self.model_legacy.legacy_primary = True
self.model_legacy.primary_candidates = [ self.model_legacy.primary_candidates = [
LegacyPrimaryEntry(copy.deepcopy( LegacyPrimaryEntry(
LEGACY_DEFAULT_PRIMARY_SECTION), parent=self.model_legacy), copy.deepcopy(LEGACY_DEFAULT_PRIMARY_SECTION), parent=self.model_legacy
),
] ]
self.candidate_legacy = self.model_legacy.primary_candidates[0] self.candidate_legacy = self.model_legacy.primary_candidates[0]
self.model_legacy.primary_staged = self.candidate_legacy self.model_legacy.primary_staged = self.candidate_legacy
@ -152,7 +148,8 @@ class TestMirrorModel(unittest.TestCase):
[ [
"http://CC.archive.ubuntu.com/ubuntu", "http://CC.archive.ubuntu.com/ubuntu",
"http://CC.ports.ubuntu.com/ubuntu-ports", "http://CC.ports.ubuntu.com/ubuntu-ports",
]) ],
)
do_test(self.model) do_test(self.model)
do_test(self.model_legacy) do_test(self.model_legacy)
@ -167,7 +164,7 @@ class TestMirrorModel(unittest.TestCase):
def test_default_disable_components(self): def test_default_disable_components(self):
def do_test(model, candidate): def do_test(model, candidate):
config = model.get_apt_config_staged() config = model.get_apt_config_staged()
self.assertEqual([], config['disable_components']) self.assertEqual([], config["disable_components"])
# The candidate[0] is a country-mirror, skip it. # The candidate[0] is a country-mirror, skip it.
candidate = self.model.primary_candidates[1] candidate = self.model.primary_candidates[1]
@ -179,15 +176,14 @@ class TestMirrorModel(unittest.TestCase):
# autoinstall loads to the config directly # autoinstall loads to the config directly
model = MirrorModel() model = MirrorModel()
data = { data = {
'disable_components': ['non-free'], "disable_components": ["non-free"],
'fallback': 'offline-install', "fallback": "offline-install",
} }
model.load_autoinstall_data(data) model.load_autoinstall_data(data)
self.assertFalse(model.legacy_primary) self.assertFalse(model.legacy_primary)
model.primary_candidates[0].stage() model.primary_candidates[0].stage()
self.assertEqual(set(['non-free']), model.disabled_components) self.assertEqual(set(["non-free"]), model.disabled_components)
self.assertEqual(model.primary_candidates, self.assertEqual(model.primary_candidates, model._default_primary_entries())
model._default_primary_entries())
def test_from_autoinstall_modern(self): def test_from_autoinstall_modern(self):
data = { data = {
@ -202,16 +198,19 @@ class TestMirrorModel(unittest.TestCase):
} }
model = MirrorModel() model = MirrorModel()
model.load_autoinstall_data(data) model.load_autoinstall_data(data)
self.assertEqual(model.primary_candidates, [ self.assertEqual(
model.primary_candidates,
[
PrimaryEntry(parent=model, country_mirror=True), PrimaryEntry(parent=model, country_mirror=True),
PrimaryEntry(uri="http://mirror", parent=model), PrimaryEntry(uri="http://mirror", parent=model),
]) ],
)
def test_disable_add(self): def test_disable_add(self):
def do_test(model, candidate): def do_test(model, candidate):
expected = ['things', 'stuff'] expected = ["things", "stuff"]
model.disable_components(expected.copy(), add=True) model.disable_components(expected.copy(), add=True)
actual = model.get_apt_config_staged()['disable_components'] actual = model.get_apt_config_staged()["disable_components"]
actual.sort() actual.sort()
expected.sort() expected.sort()
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
@ -223,11 +222,11 @@ class TestMirrorModel(unittest.TestCase):
do_test(self.model_legacy, self.candidate_legacy) do_test(self.model_legacy, self.candidate_legacy)
def test_disable_remove(self): def test_disable_remove(self):
self.model.disabled_components = set(['a', 'b', 'things']) self.model.disabled_components = set(["a", "b", "things"])
to_remove = ['things', 'stuff'] to_remove = ["things", "stuff"]
expected = ['a', 'b'] expected = ["a", "b"]
self.model.disable_components(to_remove, add=False) self.model.disable_components(to_remove, add=False)
actual = self.model.get_apt_config_staged()['disable_components'] actual = self.model.get_apt_config_staged()["disable_components"]
actual.sort() actual.sort()
expected.sort() expected.sort()
self.assertEqual(expected, actual) self.assertEqual(expected, actual)
@ -242,12 +241,15 @@ class TestMirrorModel(unittest.TestCase):
self.model.legacy_primary = False self.model.legacy_primary = False
self.model.fallback = MirrorSelectionFallback.OFFLINE_INSTALL self.model.fallback = MirrorSelectionFallback.OFFLINE_INSTALL
self.model.primary_candidates = [ self.model.primary_candidates = [
PrimaryEntry(uri=None, arches=None, PrimaryEntry(uri=None, arches=None, country_mirror=True, parent=self.model),
country_mirror=True, parent=self.model), PrimaryEntry(
PrimaryEntry(uri="http://mirror.local/ubuntu", uri="http://mirror.local/ubuntu", arches=None, parent=self.model
arches=None, parent=self.model), ),
PrimaryEntry(uri="http://amd64.mirror.local/ubuntu", PrimaryEntry(
arches=["amd64"], parent=self.model), uri="http://amd64.mirror.local/ubuntu",
arches=["amd64"],
parent=self.model,
),
] ]
cfg = self.model.make_autoinstall() cfg = self.model.make_autoinstall()
self.assertEqual(cfg["disable_components"], ["non-free"]) self.assertEqual(cfg["disable_components"], ["non-free"])
@ -259,8 +261,7 @@ class TestMirrorModel(unittest.TestCase):
self.model.disabled_components = set(["non-free"]) self.model.disabled_components = set(["non-free"])
self.model.fallback = MirrorSelectionFallback.ABORT self.model.fallback = MirrorSelectionFallback.ABORT
self.model.legacy_primary = True self.model.legacy_primary = True
self.model.primary_candidates = \ self.model.primary_candidates = [LegacyPrimaryEntry(primary, parent=self.model)]
[LegacyPrimaryEntry(primary, parent=self.model)]
self.model.primary_candidates[0].elect() self.model.primary_candidates[0].elect()
cfg = self.model.make_autoinstall() cfg = self.model.make_autoinstall()
self.assertEqual(cfg["disable_components"], ["non-free"]) self.assertEqual(cfg["disable_components"], ["non-free"])
@ -269,22 +270,23 @@ class TestMirrorModel(unittest.TestCase):
def test_create_primary_candidate(self): def test_create_primary_candidate(self):
self.model.legacy_primary = False self.model.legacy_primary = False
candidate = self.model.create_primary_candidate( candidate = self.model.create_primary_candidate("http://mymirror.valid")
"http://mymirror.valid")
self.assertEqual(candidate.uri, "http://mymirror.valid") self.assertEqual(candidate.uri, "http://mymirror.valid")
self.model.legacy_primary = True self.model.legacy_primary = True
candidate = self.model.create_primary_candidate( candidate = self.model.create_primary_candidate("http://mymirror.valid")
"http://mymirror.valid")
self.assertEqual(candidate.uri, "http://mymirror.valid") self.assertEqual(candidate.uri, "http://mymirror.valid")
def test_wants_geoip(self): def test_wants_geoip(self):
country_mirror_candidates = mock.patch.object( country_mirror_candidates = mock.patch.object(
self.model, "country_mirror_candidates", return_value=iter([])) self.model, "country_mirror_candidates", return_value=iter([])
)
with country_mirror_candidates: with country_mirror_candidates:
self.assertFalse(self.model.wants_geoip()) self.assertFalse(self.model.wants_geoip())
country_mirror_candidates = mock.patch.object( country_mirror_candidates = mock.patch.object(
self.model, "country_mirror_candidates", self.model,
return_value=iter([PrimaryEntry(parent=self.model)])) "country_mirror_candidates",
return_value=iter([PrimaryEntry(parent=self.model)]),
)
with country_mirror_candidates: with country_mirror_candidates:
self.assertTrue(self.model.wants_geoip()) self.assertTrue(self.model.wants_geoip())


@ -17,9 +17,7 @@ import subprocess
import unittest import unittest
from unittest import mock from unittest import mock
from subiquity.models.network import ( from subiquity.models.network import NetworkModel
NetworkModel,
)
class TestNetworkModel(unittest.IsolatedAsyncioTestCase): class TestNetworkModel(unittest.IsolatedAsyncioTestCase):
@ -42,6 +40,5 @@ class TestNetworkModel(unittest.IsolatedAsyncioTestCase):
self.assertFalse(await self.model.is_nm_enabled()) self.assertFalse(await self.model.is_nm_enabled())
with mock.patch("subiquity.models.network.arun_command") as arun: with mock.patch("subiquity.models.network.arun_command") as arun:
arun.side_effect = subprocess.CalledProcessError( arun.side_effect = subprocess.CalledProcessError(1, [], None, "error")
1, [], None, "error")
self.assertFalse(await self.model.is_nm_enabled()) self.assertFalse(await self.model.is_nm_enabled())


@ -18,28 +18,26 @@ import random
import shutil import shutil
import tempfile import tempfile
import unittest import unittest
import yaml import yaml
from subiquity.models.source import SourceModel
from subiquitycore.tests.util import random_string from subiquitycore.tests.util import random_string
from subiquity.models.source import (
SourceModel,
)
def make_entry(**fields): def make_entry(**fields):
raw = { raw = {
'name': { "name": {
'en': random_string(), "en": random_string(),
}, },
'description': { "description": {
'en': random_string(), "en": random_string(),
}, },
'id': random_string(), "id": random_string(),
'type': random_string(), "type": random_string(),
'variant': random_string(), "variant": random_string(),
'path': random_string(), "path": random_string(),
'size': random.randint(1000000, 2000000), "size": random.randint(1000000, 2000000),
} }
for k, v in fields.items(): for k, v in fields.items():
raw[k] = v raw[k] = v
@ -47,7 +45,6 @@ def make_entry(**fields):
class TestSourceModel(unittest.TestCase): class TestSourceModel(unittest.TestCase):
def tdir(self): def tdir(self):
tdir = tempfile.mkdtemp() tdir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, tdir) self.addCleanup(shutil.rmtree, tdir)
@ -56,83 +53,78 @@ class TestSourceModel(unittest.TestCase):
def write_and_load_entries(self, model, entries, dir=None): def write_and_load_entries(self, model, entries, dir=None):
if dir is None: if dir is None:
dir = self.tdir() dir = self.tdir()
cat_path = os.path.join(dir, 'catalog.yaml') cat_path = os.path.join(dir, "catalog.yaml")
with open(cat_path, 'w') as fp: with open(cat_path, "w") as fp:
yaml.dump(entries, fp) yaml.dump(entries, fp)
with open(cat_path) as fp: with open(cat_path) as fp:
model.load_from_file(fp) model.load_from_file(fp)
def test_initially_server(self): def test_initially_server(self):
model = SourceModel() model = SourceModel()
self.assertEqual(model.current.variant, 'server') self.assertEqual(model.current.variant, "server")
def test_load_from_file_simple(self): def test_load_from_file_simple(self):
entry = make_entry(id='id1') entry = make_entry(id="id1")
model = SourceModel() model = SourceModel()
self.write_and_load_entries(model, [entry]) self.write_and_load_entries(model, [entry])
self.assertEqual(model.current.id, 'id1') self.assertEqual(model.current.id, "id1")
def test_load_from_file_ignores_extra_keys(self): def test_load_from_file_ignores_extra_keys(self):
entry = make_entry(id='id1', foobarbaz=random_string()) entry = make_entry(id="id1", foobarbaz=random_string())
model = SourceModel() model = SourceModel()
self.write_and_load_entries(model, [entry]) self.write_and_load_entries(model, [entry])
self.assertEqual(model.current.id, 'id1') self.assertEqual(model.current.id, "id1")
def test_default(self): def test_default(self):
entries = [ entries = [
make_entry(id='id1'), make_entry(id="id1"),
make_entry(id='id2', default=True), make_entry(id="id2", default=True),
make_entry(id='id3'), make_entry(id="id3"),
] ]
model = SourceModel() model = SourceModel()
self.write_and_load_entries(model, entries) self.write_and_load_entries(model, entries)
self.assertEqual(model.current.id, 'id2') self.assertEqual(model.current.id, "id2")
def test_get_source_absolute(self): def test_get_source_absolute(self):
entry = make_entry( entry = make_entry(
type='scheme', type="scheme",
path='/foo/bar/baz', path="/foo/bar/baz",
) )
model = SourceModel() model = SourceModel()
self.write_and_load_entries(model, [entry]) self.write_and_load_entries(model, [entry])
self.assertEqual( self.assertEqual(model.get_source(), "scheme:///foo/bar/baz")
model.get_source(), 'scheme:///foo/bar/baz')
def test_get_source_relative(self): def test_get_source_relative(self):
dir = self.tdir() dir = self.tdir()
entry = make_entry( entry = make_entry(
type='scheme', type="scheme",
path='foo/bar/baz', path="foo/bar/baz",
) )
model = SourceModel() model = SourceModel()
self.write_and_load_entries(model, [entry], dir) self.write_and_load_entries(model, [entry], dir)
self.assertEqual( self.assertEqual(model.get_source(), f"scheme://{dir}/foo/bar/baz")
model.get_source(),
f'scheme://{dir}/foo/bar/baz')
def test_get_source_initial(self): def test_get_source_initial(self):
model = SourceModel() model = SourceModel()
self.assertEqual( self.assertEqual(model.get_source(), "cp:///media/filesystem")
model.get_source(),
'cp:///media/filesystem')
def test_render(self): def test_render(self):
model = SourceModel() model = SourceModel()
self.assertEqual(model.render(), {}) self.assertEqual(model.render(), {})
def test_canary(self): def test_canary(self):
with open('examples/sources/install-canary.yaml') as fp: with open("examples/sources/install-canary.yaml") as fp:
model = SourceModel() model = SourceModel()
model.load_from_file(fp) model.load_from_file(fp)
self.assertEqual(2, len(model.sources)) self.assertEqual(2, len(model.sources))
minimal = model.get_matching_source('ubuntu-desktop-minimal') minimal = model.get_matching_source("ubuntu-desktop-minimal")
self.assertIsNotNone(minimal.variations) self.assertIsNotNone(minimal.variations)
self.assertEqual(minimal.size, minimal.variations['default'].size) self.assertEqual(minimal.size, minimal.variations["default"].size)
self.assertEqual(minimal.path, minimal.variations['default'].path) self.assertEqual(minimal.path, minimal.variations["default"].path)
self.assertIsNone(minimal.variations['default'].snapd_system_label) self.assertIsNone(minimal.variations["default"].snapd_system_label)
standard = model.get_matching_source('ubuntu-desktop') standard = model.get_matching_source("ubuntu-desktop")
self.assertIsNotNone(standard.variations) self.assertIsNotNone(standard.variations)
self.assertEqual(2, len(standard.variations)) self.assertEqual(2, len(standard.variations))


@ -15,18 +15,20 @@
import fnmatch import fnmatch
import json import json
import re
import unittest import unittest
from unittest import mock from unittest import mock
import re
import yaml
import yaml
from cloudinit.config.schema import SchemaValidationError from cloudinit.config.schema import SchemaValidationError
try: try:
from cloudinit.config.schema import SchemaProblem from cloudinit.config.schema import SchemaProblem
except ImportError: except ImportError:
def SchemaProblem(x, y): return (x, y) # TODO(drop on cloud-init 22.3 SRU)
from subiquitycore.pubsub import MessageHub def SchemaProblem(x, y):
return (x, y) # TODO(drop on cloud-init 22.3 SRU)
from subiquity.common.types import IdentityData from subiquity.common.types import IdentityData
from subiquity.models.subiquity import ( from subiquity.models.subiquity import (
@ -35,13 +37,11 @@ from subiquity.models.subiquity import (
ModelNames, ModelNames,
SubiquityModel, SubiquityModel,
) )
from subiquity.server.server import ( from subiquity.server.server import INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES
INSTALL_MODEL_NAMES,
POSTINSTALL_MODEL_NAMES,
)
from subiquity.server.types import InstallerChannels from subiquity.server.types import InstallerChannels
from subiquitycore.pubsub import MessageHub
getent_group_output = ''' getent_group_output = """
root:x:0: root:x:0:
daemon:x:1: daemon:x:1:
bin:x:2: bin:x:2:
@ -57,56 +57,55 @@ man:x:12:
sudo:x:27: sudo:x:27:
ssh:x:118: ssh:x:118:
users:x:100: users:x:100:
''' """
class TestModelNames(unittest.TestCase): class TestModelNames(unittest.TestCase):
def test_for_known_variant(self): def test_for_known_variant(self):
model_names = ModelNames({'a'}, var1={'b'}, var2={'c'}) model_names = ModelNames({"a"}, var1={"b"}, var2={"c"})
self.assertEqual(model_names.for_variant('var1'), {'a', 'b'}) self.assertEqual(model_names.for_variant("var1"), {"a", "b"})
def test_for_unknown_variant(self): def test_for_unknown_variant(self):
model_names = ModelNames({'a'}, var1={'b'}, var2={'c'}) model_names = ModelNames({"a"}, var1={"b"}, var2={"c"})
self.assertEqual(model_names.for_variant('var3'), {'a'}) self.assertEqual(model_names.for_variant("var3"), {"a"})
def test_all(self): def test_all(self):
model_names = ModelNames({'a'}, var1={'b'}, var2={'c'}) model_names = ModelNames({"a"}, var1={"b"}, var2={"c"})
self.assertEqual(model_names.all(), {'a', 'b', 'c'}) self.assertEqual(model_names.all(), {"a", "b", "c"})
class TestSubiquityModel(unittest.IsolatedAsyncioTestCase): class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
maxDiff = None maxDiff = None
def writtenFiles(self, config): def writtenFiles(self, config):
for k, v in config.get('write_files', {}).items(): for k, v in config.get("write_files", {}).items():
yield v yield v
def assertConfigWritesFile(self, config, path): def assertConfigWritesFile(self, config, path):
self.assertIn(path, [s['path'] for s in self.writtenFiles(config)]) self.assertIn(path, [s["path"] for s in self.writtenFiles(config)])
def writtenFilesMatching(self, config, pattern): def writtenFilesMatching(self, config, pattern):
files = list(self.writtenFiles(config)) files = list(self.writtenFiles(config))
matching = [] matching = []
for spec in files: for spec in files:
if fnmatch.fnmatch(spec['path'], pattern): if fnmatch.fnmatch(spec["path"], pattern):
matching.append(spec) matching.append(spec)
return matching return matching
def writtenFilesMatchingContaining(self, config, pattern, content): def writtenFilesMatchingContaining(self, config, pattern, content):
matching = [] matching = []
for spec in self.writtenFilesMatching(config, pattern): for spec in self.writtenFilesMatching(config, pattern):
if content in spec['content']: if content in spec["content"]:
matching.append(content) matching.append(content)
return matching return matching
def configVal(self, config, path): def configVal(self, config, path):
cur = config cur = config
for component in path.split('.'): for component in path.split("."):
if not isinstance(cur, dict): if not isinstance(cur, dict):
self.fail( self.fail(
"extracting {} reached non-dict {} too early".format( "extracting {} reached non-dict {} too early".format(path, cur)
path, cur)) )
if component not in cur: if component not in cur:
self.fail("no value found for {}".format(path)) self.fail("no value found for {}".format(path))
cur = cur[component] cur = cur[component]
@ -117,11 +116,11 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
def assertConfigDoesNotHaveVal(self, config, path): def assertConfigDoesNotHaveVal(self, config, path):
cur = config cur = config
for component in path.split('.'): for component in path.split("."):
if not isinstance(cur, dict): if not isinstance(cur, dict):
self.fail( self.fail(
"extracting {} reached non-dict {} too early".format( "extracting {} reached non-dict {} too early".format(path, cur)
path, cur)) )
if component not in cur: if component not in cur:
return return
cur = cur[component] cur = cur[component]
@ -129,157 +128,149 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
async def test_configure(self): async def test_configure(self):
hub = MessageHub() hub = MessageHub()
model = SubiquityModel( model = SubiquityModel("test", hub, ModelNames({"a", "b"}), ModelNames(set()))
'test', hub, ModelNames({'a', 'b'}), ModelNames(set())) model.set_source_variant("var")
model.set_source_variant('var') await hub.abroadcast((InstallerChannels.CONFIGURED, "a"))
await hub.abroadcast((InstallerChannels.CONFIGURED, 'a'))
self.assertFalse(model._install_event.is_set()) self.assertFalse(model._install_event.is_set())
await hub.abroadcast((InstallerChannels.CONFIGURED, 'b')) await hub.abroadcast((InstallerChannels.CONFIGURED, "b"))
self.assertTrue(model._install_event.is_set()) self.assertTrue(model._install_event.is_set())
def make_model(self): def make_model(self):
return SubiquityModel( return SubiquityModel(
'test', MessageHub(), INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES) "test", MessageHub(), INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES
)
def test_proxy_set(self): def test_proxy_set(self):
model = self.make_model() model = self.make_model()
proxy_val = 'http://my-proxy' proxy_val = "http://my-proxy"
model.proxy.proxy = proxy_val model.proxy.proxy = proxy_val
config = model.render() config = model.render()
self.assertConfigHasVal(config, 'proxy.http_proxy', proxy_val) self.assertConfigHasVal(config, "proxy.http_proxy", proxy_val)
self.assertConfigHasVal(config, 'proxy.https_proxy', proxy_val) self.assertConfigHasVal(config, "proxy.https_proxy", proxy_val)
confs = self.writtenFilesMatchingContaining( confs = self.writtenFilesMatchingContaining(
config, config,
'etc/systemd/system/snapd.service.d/*.conf', "etc/systemd/system/snapd.service.d/*.conf",
'HTTP_PROXY=' + proxy_val) "HTTP_PROXY=" + proxy_val,
)
self.assertTrue(len(confs) > 0) self.assertTrue(len(confs) > 0)
def test_proxy_notset(self): def test_proxy_notset(self):
model = self.make_model() model = self.make_model()
config = model.render() config = model.render()
self.assertConfigDoesNotHaveVal(config, 'proxy.http_proxy') self.assertConfigDoesNotHaveVal(config, "proxy.http_proxy")
self.assertConfigDoesNotHaveVal(config, 'proxy.https_proxy') self.assertConfigDoesNotHaveVal(config, "proxy.https_proxy")
self.assertConfigDoesNotHaveVal(config, 'apt.http_proxy') self.assertConfigDoesNotHaveVal(config, "apt.http_proxy")
self.assertConfigDoesNotHaveVal(config, 'apt.https_proxy') self.assertConfigDoesNotHaveVal(config, "apt.https_proxy")
confs = self.writtenFilesMatchingContaining( confs = self.writtenFilesMatchingContaining(
config, config, "etc/systemd/system/snapd.service.d/*.conf", "HTTP_PROXY="
'etc/systemd/system/snapd.service.d/*.conf', )
'HTTP_PROXY=')
self.assertTrue(len(confs) == 0) self.assertTrue(len(confs) == 0)
def test_keyboard(self): def test_keyboard(self):
model = self.make_model() model = self.make_model()
config = model.render() config = model.render()
self.assertConfigWritesFile(config, 'etc/default/keyboard') self.assertConfigWritesFile(config, "etc/default/keyboard")
def test_writes_machine_id_media_info(self): def test_writes_machine_id_media_info(self):
model_no_proxy = self.make_model() model_no_proxy = self.make_model()
model_proxy = self.make_model() model_proxy = self.make_model()
model_proxy.proxy.proxy = 'http://something' model_proxy.proxy.proxy = "http://something"
for model in model_no_proxy, model_proxy: for model in model_no_proxy, model_proxy:
config = model.render() config = model.render()
self.assertConfigWritesFile(config, 'etc/machine-id') self.assertConfigWritesFile(config, "etc/machine-id")
self.assertConfigWritesFile(config, 'var/log/installer/media-info') self.assertConfigWritesFile(config, "var/log/installer/media-info")
def test_storage_version(self): def test_storage_version(self):
model = self.make_model() model = self.make_model()
config = model.render() config = model.render()
self.assertConfigHasVal(config, 'storage.version', 1) self.assertConfigHasVal(config, "storage.version", 1)
def test_write_netplan(self): def test_write_netplan(self):
model = self.make_model() model = self.make_model()
config = model.render() config = model.render()
netplan_content = None netplan_content = None
for fspec in config['write_files'].values(): for fspec in config["write_files"].values():
if fspec['path'].startswith('etc/netplan'): if fspec["path"].startswith("etc/netplan"):
if netplan_content is not None: if netplan_content is not None:
self.fail("writing two files to netplan?") self.fail("writing two files to netplan?")
netplan_content = fspec['content'] netplan_content = fspec["content"]
self.assertIsNot(netplan_content, None) self.assertIsNot(netplan_content, None)
netplan = yaml.safe_load(netplan_content) netplan = yaml.safe_load(netplan_content)
self.assertConfigHasVal(netplan, 'network.version', 2) self.assertConfigHasVal(netplan, "network.version", 2)
def test_sources(self): def test_sources(self):
model = self.make_model() model = self.make_model()
config = model.render() config = model.render()
self.assertNotIn('sources', config) self.assertNotIn("sources", config)
def test_mirror(self): def test_mirror(self):
model = self.make_model() model = self.make_model()
mirror_val = 'http://my-mirror' mirror_val = "http://my-mirror"
model.mirror.create_primary_candidate(mirror_val).elect() model.mirror.create_primary_candidate(mirror_val).elect()
config = model.render() config = model.render()
self.assertNotIn('apt', config) self.assertNotIn("apt", config)
@mock.patch('subiquitycore.utils.run_command') @mock.patch("subiquitycore.utils.run_command")
def test_cloud_init_user_list_merge(self, run_cmd): def test_cloud_init_user_list_merge(self, run_cmd):
main_user = IdentityData( main_user = IdentityData(
username='mainuser', username="mainuser", crypted_password="sample_value", hostname="somehost"
crypted_password='sample_value', )
hostname='somehost') secondary_user = {"name": "user2"}
secondary_user = {'name': 'user2'}
with self.subTest('Main user + secondary user'): with self.subTest("Main user + secondary user"):
model = self.make_model() model = self.make_model()
model.identity.add_user(main_user) model.identity.add_user(main_user)
model.userdata = {'users': [secondary_user]} model.userdata = {"users": [secondary_user]}
run_cmd.return_value.stdout = getent_group_output run_cmd.return_value.stdout = getent_group_output
cloud_init_config = model._cloud_init_config() cloud_init_config = model._cloud_init_config()
self.assertEqual(len(cloud_init_config['users']), 2) self.assertEqual(len(cloud_init_config["users"]), 2)
self.assertEqual(cloud_init_config['users'][0]['name'], 'mainuser') self.assertEqual(cloud_init_config["users"][0]["name"], "mainuser")
self.assertEqual( self.assertEqual(cloud_init_config["users"][0]["groups"], "adm,sudo,users")
cloud_init_config['users'][0]['groups'], self.assertEqual(cloud_init_config["users"][1]["name"], "user2")
'adm,sudo,users' run_cmd.assert_called_with(["getent", "group"], check=True)
)
self.assertEqual(cloud_init_config['users'][1]['name'], 'user2')
run_cmd.assert_called_with(['getent', 'group'], check=True)
with self.subTest('Secondary user only'): with self.subTest("Secondary user only"):
model = self.make_model() model = self.make_model()
model.userdata = {'users': [secondary_user]} model.userdata = {"users": [secondary_user]}
cloud_init_config = model._cloud_init_config() cloud_init_config = model._cloud_init_config()
self.assertEqual(len(cloud_init_config['users']), 1) self.assertEqual(len(cloud_init_config["users"]), 1)
self.assertEqual(cloud_init_config['users'][0]['name'], 'user2') self.assertEqual(cloud_init_config["users"][0]["name"], "user2")
with self.subTest('Invalid user-data raises error'): with self.subTest("Invalid user-data raises error"):
model = self.make_model() model = self.make_model()
model.userdata = {'bootcmd': "nope"} model.userdata = {"bootcmd": "nope"}
with self.assertRaises(SchemaValidationError) as ctx: with self.assertRaises(SchemaValidationError) as ctx:
model._cloud_init_config() model._cloud_init_config()
expected_error = ( expected_error = (
"Cloud config schema errors: bootcmd: 'nope' is not of type" "Cloud config schema errors: bootcmd: 'nope' is not of type" " 'array'"
" 'array'"
) )
self.assertEqual(expected_error, str(ctx.exception)) self.assertEqual(expected_error, str(ctx.exception))
@mock.patch('subiquity.models.subiquity.lsb_release') @mock.patch("subiquity.models.subiquity.lsb_release")
@mock.patch('subiquitycore.file_util.datetime.datetime') @mock.patch("subiquitycore.file_util.datetime.datetime")
def test_cloud_init_files_emits_datasource_config_and_clean_script( def test_cloud_init_files_emits_datasource_config_and_clean_script(
self, datetime, lsb_release self, datetime, lsb_release
): ):
datetime.utcnow.return_value = "2004-03-05 ..." datetime.utcnow.return_value = "2004-03-05 ..."
main_user = IdentityData( main_user = IdentityData(
username='mainuser', username="mainuser", crypted_password="sample_pass", hostname="somehost"
crypted_password='sample_pass', )
hostname='somehost')
model = self.make_model() model = self.make_model()
model.identity.add_user(main_user) model.identity.add_user(main_user)
model.userdata = {} model.userdata = {}
expected_files = { expected_files = {
'etc/cloud/cloud.cfg.d/99-installer.cfg': re.compile( "etc/cloud/cloud.cfg.d/99-installer.cfg": re.compile(
'datasource:\n None:\n metadata:\n instance-id: .*\n userdata_raw: "#cloud-config\\\\ngrowpart:\\\\n mode: \\\'off\\\'\\\\npreserve_hostname: true\\\\n\\\\\n' # noqa "datasource:\n None:\n metadata:\n instance-id: .*\n userdata_raw: \"#cloud-config\\\\ngrowpart:\\\\n mode: \\'off\\'\\\\npreserve_hostname: true\\\\n\\\\\n" # noqa
), ),
'etc/hostname': 'somehost\n', "etc/hostname": "somehost\n",
'etc/cloud/ds-identify.cfg': 'policy: enabled\n', "etc/cloud/ds-identify.cfg": "policy: enabled\n",
'etc/hosts': HOSTS_CONTENT.format(hostname='somehost'), "etc/hosts": HOSTS_CONTENT.format(hostname="somehost"),
} }
# Avoid removing /etc/hosts and /etc/hostname in cloud-init clean # Avoid removing /etc/hosts and /etc/hostname in cloud-init clean
cfg_files = [ cfg_files = ["/" + key for key in expected_files.keys() if "host" not in key]
"/" + key for key in expected_files.keys() if "host" not in key
]
cfg_files.append( cfg_files.append(
# Obtained from NetworkModel.render when cloud-init features # Obtained from NetworkModel.render when cloud-init features
# NETPLAN_CONFIG_ROOT_READ_ONLY is True # NETPLAN_CONFIG_ROOT_READ_ONLY is True
@ -287,15 +278,15 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
) )
header = "# Autogenerated by Subiquity: 2004-03-05 ... UTC\n" header = "# Autogenerated by Subiquity: 2004-03-05 ... UTC\n"
with self.subTest( with self.subTest(
'Stable releases Jammy do not disable cloud-init.' "Stable releases Jammy do not disable cloud-init."
' NETPLAN_ROOT_READ_ONLY=True uses cloud-init networking' " NETPLAN_ROOT_READ_ONLY=True uses cloud-init networking"
): ):
lsb_release.return_value = {"release": "22.04"} lsb_release.return_value = {"release": "22.04"}
expected_files['etc/cloud/clean.d/99-installer'] = ( expected_files[
CLOUDINIT_CLEAN_FILE_TMPL.format( "etc/cloud/clean.d/99-installer"
] = CLOUDINIT_CLEAN_FILE_TMPL.format(
header=header, cfg_files=json.dumps(sorted(cfg_files)) header=header, cfg_files=json.dumps(sorted(cfg_files))
) )
)
with unittest.mock.patch( with unittest.mock.patch(
"subiquity.cloudinit.open", "subiquity.cloudinit.open",
mock.mock_open( mock.mock_open(
@ -304,57 +295,52 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
) )
), ),
): ):
for (cpath, content, perms) in model._cloud_init_files(): for cpath, content, perms in model._cloud_init_files():
if isinstance(expected_files[cpath], re.Pattern): if isinstance(expected_files[cpath], re.Pattern):
self.assertIsNotNone( self.assertIsNotNone(expected_files[cpath].match(content))
expected_files[cpath].match(content)
)
else: else:
self.assertEqual(expected_files[cpath], content) self.assertEqual(expected_files[cpath], content)
with self.subTest( with self.subTest(
'Kinetic++ disables cloud-init post install.' "Kinetic++ disables cloud-init post install."
' NETPLAN_ROOT_READ_ONLY=False avoids cloud-init networking' " NETPLAN_ROOT_READ_ONLY=False avoids cloud-init networking"
): ):
lsb_release.return_value = {"release": "22.10"} lsb_release.return_value = {"release": "22.10"}
cfg_files.append( cfg_files.append(
# Added by _cloud_init_files for 22.10 and later releases # Added by _cloud_init_files for 22.10 and later releases
'/etc/cloud/cloud-init.disabled', "/etc/cloud/cloud-init.disabled",
) )
# Obtained from NetworkModel.render # Obtained from NetworkModel.render
cfg_files.remove('/etc/cloud/cloud.cfg.d/90-installer-network.cfg') cfg_files.remove("/etc/cloud/cloud.cfg.d/90-installer-network.cfg")
cfg_files.append('/etc/netplan/00-installer-config.yaml') cfg_files.append("/etc/netplan/00-installer-config.yaml")
cfg_files.append( cfg_files.append(
'/etc/cloud/cloud.cfg.d/' "/etc/cloud/cloud.cfg.d/" "subiquity-disable-cloudinit-networking.cfg"
'subiquity-disable-cloudinit-networking.cfg'
) )
expected_files[ expected_files[
'etc/cloud/clean.d/99-installer' "etc/cloud/clean.d/99-installer"
] = CLOUDINIT_CLEAN_FILE_TMPL.format( ] = CLOUDINIT_CLEAN_FILE_TMPL.format(
header=header, cfg_files=json.dumps(sorted(cfg_files)) header=header, cfg_files=json.dumps(sorted(cfg_files))
) )
with unittest.mock.patch( with unittest.mock.patch(
'subiquity.cloudinit.open', "subiquity.cloudinit.open",
mock.mock_open( mock.mock_open(
read_data=json.dumps( read_data=json.dumps(
{'features': {'NETPLAN_CONFIG_ROOT_READ_ONLY': False}} {"features": {"NETPLAN_CONFIG_ROOT_READ_ONLY": False}}
) )
), ),
): ):
for (cpath, content, perms) in model._cloud_init_files(): for cpath, content, perms in model._cloud_init_files():
if isinstance(expected_files[cpath], re.Pattern): if isinstance(expected_files[cpath], re.Pattern):
self.assertIsNotNone( self.assertIsNotNone(expected_files[cpath].match(content))
expected_files[cpath].match(content)
)
else: else:
self.assertEqual(expected_files[cpath], content) self.assertEqual(expected_files[cpath], content)
def test_validatecloudconfig_schema(self): def test_validatecloudconfig_schema(self):
model = self.make_model() model = self.make_model()
with self.subTest('Valid cloud-config does not error'): with self.subTest("Valid cloud-config does not error"):
model.validate_cloudconfig_schema( model.validate_cloudconfig_schema(
data={"ssh_import_id": ["chad.smith"]}, data={"ssh_import_id": ["chad.smith"]},
data_source="autoinstall.user-data" data_source="autoinstall.user-data",
) )
# Create our own subclass for focal as schema_deprecations # Create our own subclass for focal as schema_deprecations
@ -367,22 +353,18 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
self.schema_deprecations = schema_deprecations self.schema_deprecations = schema_deprecations
problem = SchemaProblem( problem = SchemaProblem(
"bogus", "bogus", "'bogus' is deprecated, use 'notbogus' instead"
"'bogus' is deprecated, use 'notbogus' instead"
) )
with self.subTest('Deprecated cloud-config warns'): with self.subTest("Deprecated cloud-config warns"):
with unittest.mock.patch( with unittest.mock.patch(
"subiquity.models.subiquity.validate_cloudconfig_schema" "subiquity.models.subiquity.validate_cloudconfig_schema"
) as validate: ) as validate:
validate.side_effect = SchemaDeprecation( validate.side_effect = SchemaDeprecation(schema_deprecations=(problem,))
schema_deprecations=(problem,)
)
with self.assertLogs( with self.assertLogs(
"subiquity.models.subiquity", level="INFO" "subiquity.models.subiquity", level="INFO"
) as logs: ) as logs:
model.validate_cloudconfig_schema( model.validate_cloudconfig_schema(
data={"bogus": True}, data={"bogus": True}, data_source="autoinstall.user-data"
data_source="autoinstall.user-data"
) )
expected = ( expected = (
"WARNING:subiquity.models.subiquity:The cloud-init" "WARNING:subiquity.models.subiquity:The cloud-init"
@ -391,22 +373,20 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
) )
self.assertEqual(logs.output, [expected]) self.assertEqual(logs.output, [expected])
with self.subTest('Invalid cloud-config schema errors'): with self.subTest("Invalid cloud-config schema errors"):
with self.assertRaises(SchemaValidationError) as ctx: with self.assertRaises(SchemaValidationError) as ctx:
model.validate_cloudconfig_schema( model.validate_cloudconfig_schema(
data={"bootcmd": "nope"}, data_source="system info" data={"bootcmd": "nope"}, data_source="system info"
) )
expected_error = ( expected_error = (
"Cloud config schema errors: bootcmd: 'nope' is not of" "Cloud config schema errors: bootcmd: 'nope' is not of" " type 'array'"
" type 'array'"
) )
self.assertEqual(expected_error, str(ctx.exception)) self.assertEqual(expected_error, str(ctx.exception))
with self.subTest('Prefix autoinstall.user-data cloud-config errors'): with self.subTest("Prefix autoinstall.user-data cloud-config errors"):
with self.assertRaises(SchemaValidationError) as ctx: with self.assertRaises(SchemaValidationError) as ctx:
model.validate_cloudconfig_schema( model.validate_cloudconfig_schema(
data={"bootcmd": "nope"}, data={"bootcmd": "nope"}, data_source="autoinstall.user-data"
data_source="autoinstall.user-data"
) )
expected_error = ( expected_error = (
"Cloud config schema errors:" "Cloud config schema errors:"

View File

@ -15,13 +15,13 @@
import logging import logging
log = logging.getLogger('subiquity.models.timezone') log = logging.getLogger("subiquity.models.timezone")
class TimeZoneModel(object): class TimeZoneModel(object):
""" Model representing timezone""" """Model representing timezone"""
timezone = '' timezone = ""
def __init__(self): def __init__(self):
# This is the raw request from the API / autoinstall. # This is the raw request from the API / autoinstall.
@ -32,16 +32,16 @@ class TimeZoneModel(object):
self._request = None self._request = None
# The actual timezone to set, possibly post-geoip lookup or # The actual timezone to set, possibly post-geoip lookup or
# possibly manually specified. # possibly manually specified.
self.timezone = '' self.timezone = ""
self.got_from_geoip = False self.got_from_geoip = False
def set(self, value): def set(self, value):
self._request = value self._request = value
self.timezone = value if value != 'geoip' else '' self.timezone = value if value != "geoip" else ""
@property @property
def detect_with_geoip(self): def detect_with_geoip(self):
return self._request == 'geoip' return self._request == "geoip"
@property @property
def should_set_tz(self): def should_set_tz(self):
@ -54,10 +54,15 @@ class TimeZoneModel(object):
def make_cloudconfig(self): def make_cloudconfig(self):
if not self.should_set_tz: if not self.should_set_tz:
return {} return {}
return {'timezone': self.timezone} return {"timezone": self.timezone}
def __repr__(self): def __repr__(self):
return ("<TimeZone: request {} detect {} should_set {} " + return (
"timezone {} gfg {}>").format( "<TimeZone: request {} detect {} should_set {} " + "timezone {} gfg {}>"
self.request, self.detect_with_geoip, self.should_set_tz, ).format(
self.timezone, self.got_from_geoip) self.request,
self.detect_with_geoip,
self.should_set_tz,
self.timezone,
self.got_from_geoip,
)
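For illustration only (not part of this commit): a minimal, self-contained sketch of the "geoip" sentinel behaviour visible in the hunk above. The class name TZSketch is made up; only what is shown here is reproduced.

class TZSketch:
    """Stand-in mirroring the visible parts of TimeZoneModel."""

    def __init__(self):
        self._request = None       # raw value from the API / autoinstall
        self.timezone = ""         # the timezone that will actually be set
        self.got_from_geoip = False

    def set(self, value):
        self._request = value
        # "geoip" is a sentinel: leave timezone empty until geoip resolves it
        self.timezone = value if value != "geoip" else ""

    @property
    def detect_with_geoip(self):
        return self._request == "geoip"

tz = TZSketch()
tz.set("geoip")
print(tz.detect_with_geoip, repr(tz.timezone))   # True ''
tz.set("Europe/London")
print(tz.detect_with_geoip, repr(tz.timezone))   # False 'Europe/London'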

View File

@ -26,8 +26,9 @@ class UbuntuProModel:
the provided token is correct; nor to retrieve information about the the provided token is correct; nor to retrieve information about the
subscription. subscription.
""" """
def __init__(self): def __init__(self):
""" Initialize the model. """ """Initialize the model."""
self.token: str = "" self.token: str = ""
def make_cloudconfig(self) -> dict: def make_cloudconfig(self) -> dict:

View File

@ -15,13 +15,13 @@
import logging import logging
log = logging.getLogger('subiquity.models.updates') log = logging.getLogger("subiquity.models.updates")
class UpdatesModel(object): class UpdatesModel(object):
""" Model representing updates selection""" """Model representing updates selection"""
updates = 'security' updates = "security"
def __repr__(self): def __repr__(self):
return "<Updates: {}>".format(self.updates) return "<Updates: {}>".format(self.updates)

View File

@ -14,30 +14,28 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio import asyncio
from contextlib import contextmanager
import logging import logging
import os import os
from contextlib import contextmanager
from socket import gethostname from socket import gethostname
from subprocess import CalledProcessError from subprocess import CalledProcessError
from subiquitycore.utils import arun_command, run_command
from subiquity.server.curtin import run_curtin_command
from subiquity.common.types import (
AdConnectionInfo,
AdJoinResult,
)
log = logging.getLogger('subiquity.server.ad_joiner') from subiquity.common.types import AdConnectionInfo, AdJoinResult
from subiquity.server.curtin import run_curtin_command
from subiquitycore.utils import arun_command, run_command
log = logging.getLogger("subiquity.server.ad_joiner")
@contextmanager @contextmanager
def joining_context(hostname: str, root_dir: str): def joining_context(hostname: str, root_dir: str):
""" Temporarily adjusts the host name to [hostname] and bind-mounts """Temporarily adjusts the host name to [hostname] and bind-mounts
interesting system directories in preparation for running realm interesting system directories in preparation for running realm
in target's [root_dir], undoing it all on exit. """ in target's [root_dir], undoing it all on exit."""
hostname_current = gethostname() hostname_current = gethostname()
binds = ("/proc", "/sys", "/dev", "/run") binds = ("/proc", "/sys", "/dev", "/run")
try: try:
hostname_process = run_command(['hostname', hostname]) hostname_process = run_command(["hostname", hostname])
for bind in binds: for bind in binds:
bound_dir = os.path.join(root_dir, bind[1:]) bound_dir = os.path.join(root_dir, bind[1:])
if bound_dir != bind: if bound_dir != bind:
@ -45,7 +43,7 @@ def joining_context(hostname: str, root_dir: str):
yield hostname_process yield hostname_process
finally: finally:
# Restoring the live session hostname. # Restoring the live session hostname.
hostname_process = run_command(['hostname', hostname_current]) hostname_process = run_command(["hostname", hostname_current])
if hostname_process.returncode: if hostname_process.returncode:
log.info("Failed to restore live session hostname") log.info("Failed to restore live session hostname")
for bind in reversed(binds): for bind in reversed(binds):
@ -54,17 +52,18 @@ def joining_context(hostname: str, root_dir: str):
run_command(["umount", "-f", bound_dir]) run_command(["umount", "-f", bound_dir])
class AdJoinStrategy(): class AdJoinStrategy:
realm = "/usr/sbin/realm" realm = "/usr/sbin/realm"
pam = "/usr/sbin/pam-auth-update" pam = "/usr/sbin/pam-auth-update"
def __init__(self, app): def __init__(self, app):
self.app = app self.app = app
async def do_join(self, info: AdConnectionInfo, hostname: str, context) \ async def do_join(
-> AdJoinResult: self, info: AdConnectionInfo, hostname: str, context
) -> AdJoinResult:
""" This method changes the hostname and performs a real AD join, thus """This method changes the hostname and performs a real AD join, thus
should only run in a live session. """ should only run in a live session."""
root_dir = self.app.base_model.target root_dir = self.app.base_model.target
# Set hostname for AD to determine FQDN (no FQDN option in realm join, # Set hostname for AD to determine FQDN (no FQDN option in realm join,
# only adcli, which only understands the live system, but not chroot) # only adcli, which only understands the live system, but not chroot)
@ -73,11 +72,21 @@ class AdJoinStrategy():
log.info("Failed to set live session hostname for adcli") log.info("Failed to set live session hostname for adcli")
return AdJoinResult.JOIN_ERROR return AdJoinResult.JOIN_ERROR
cp = await arun_command([self.realm, "join", "--install", root_dir, cp = await arun_command(
"--user", info.admin_name, [
"--computer-name", hostname, self.realm,
"--unattended", info.domain_name], "join",
input=info.password) "--install",
root_dir,
"--user",
info.admin_name,
"--computer-name",
hostname,
"--unattended",
info.domain_name,
],
input=info.password,
)
if cp.returncode: if cp.returncode:
# Try again without the computer name. Lab tests shown more # Try again without the computer name. Lab tests shown more
@ -88,10 +97,19 @@ class AdJoinStrategy():
log.debug(cp.stderr) log.debug(cp.stderr)
log.debug(cp.stdout) log.debug(cp.stdout)
log.debug("Trying again without overriding the computer name:") log.debug("Trying again without overriding the computer name:")
cp = await arun_command([self.realm, "join", "--install", cp = await arun_command(
root_dir, "--user", info.admin_name, [
"--unattended", info.domain_name], self.realm,
input=info.password) "join",
"--install",
root_dir,
"--user",
info.admin_name,
"--unattended",
info.domain_name,
],
input=info.password,
)
if cp.returncode: if cp.returncode:
log.debug("Joining operation failed:") log.debug("Joining operation failed:")
@ -102,11 +120,19 @@ class AdJoinStrategy():
# Enable pam_mkhomedir # Enable pam_mkhomedir
try: try:
# The function raises if the process fails. # The function raises if the process fails.
await run_curtin_command(self.app, context, await run_curtin_command(
"in-target", "-t", root_dir, self.app,
"--", self.pam, "--package", context,
"--enable", "mkhomedir", "in-target",
private_mounts=False) "-t",
root_dir,
"--",
self.pam,
"--package",
"--enable",
"mkhomedir",
private_mounts=False,
)
return AdJoinResult.OK return AdJoinResult.OK
except CalledProcessError: except CalledProcessError:
@ -120,24 +146,25 @@ class AdJoinStrategy():
class StubStrategy(AdJoinStrategy): class StubStrategy(AdJoinStrategy):
async def do_join(self, info: AdConnectionInfo, hostname: str, context) \ async def do_join(
-> AdJoinResult: self, info: AdConnectionInfo, hostname: str, context
) -> AdJoinResult:
""" Enables testing without real join. The result depends on the """Enables testing without real join. The result depends on the
domain name initial character, such that if it is: domain name initial character, such that if it is:
- p or P: returns PAM_ERROR. - p or P: returns PAM_ERROR.
- j or J: returns JOIN_ERROR. - j or J: returns JOIN_ERROR.
- returns OK otherwise. """ - returns OK otherwise."""
initial = info.domain_name[0] initial = info.domain_name[0]
if initial in ('j', 'J'): if initial in ("j", "J"):
return AdJoinResult.JOIN_ERROR return AdJoinResult.JOIN_ERROR
if initial in ('p', 'P'): if initial in ("p", "P"):
return AdJoinResult.PAM_ERROR return AdJoinResult.PAM_ERROR
return AdJoinResult.OK return AdJoinResult.OK
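For illustration only (not part of this commit): a small runnable sketch of the stub's dispatch on the domain name's first character. The local AdJoinResult enum below is a simplified stand-in for subiquity.common.types.AdJoinResult.

from enum import Enum, auto

class AdJoinResult(Enum):      # simplified stand-in, not the real type
    OK = auto()
    JOIN_ERROR = auto()
    PAM_ERROR = auto()

def stub_join_result(domain_name: str) -> AdJoinResult:
    # Mirrors StubStrategy.do_join as shown above
    initial = domain_name[0]
    if initial in ("j", "J"):
        return AdJoinResult.JOIN_ERROR
    if initial in ("p", "P"):
        return AdJoinResult.PAM_ERROR
    return AdJoinResult.OK

assert stub_join_result("jdomain.example") is AdJoinResult.JOIN_ERROR
assert stub_join_result("Pdomain.example") is AdJoinResult.PAM_ERROR
assert stub_join_result("ad.example.com") is AdJoinResult.OK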
class AdJoiner(): class AdJoiner:
def __init__(self, app): def __init__(self, app):
self._result = AdJoinResult.UNKNOWN self._result = AdJoinResult.UNKNOWN
self._completion_event = asyncio.Event() self._completion_event = asyncio.Event()
@ -146,8 +173,9 @@ class AdJoiner():
else: else:
self.strategy = AdJoinStrategy(app) self.strategy = AdJoinStrategy(app)
async def join_domain(self, info: AdConnectionInfo, hostname: str, async def join_domain(
context) -> AdJoinResult: self, info: AdConnectionInfo, hostname: str, context
) -> AdJoinResult:
if hostname: if hostname:
self._result = await self.strategy.do_join(info, hostname, context) self._result = await self.strategy.do_join(info, hostname, context)
else: else:

View File

@ -28,13 +28,8 @@ import tempfile
from typing import List, Optional from typing import List, Optional
import apt_pkg import apt_pkg
from curtin.config import merge_config
from curtin.commands.extract import AbstractSourceHandler from curtin.commands.extract import AbstractSourceHandler
from curtin.config import merge_config
from subiquitycore.file_util import write_file, generate_config_yaml
from subiquitycore.lsb_release import lsb_release
from subiquitycore.utils import astart_command, orig_environ
from subiquity.server.curtin import run_curtin_command from subiquity.server.curtin import run_curtin_command
from subiquity.server.mounter import ( from subiquity.server.mounter import (
@ -43,18 +38,21 @@ from subiquity.server.mounter import (
Mountpoint, Mountpoint,
OverlayCleanupError, OverlayCleanupError,
OverlayMountpoint, OverlayMountpoint,
) )
from subiquitycore.file_util import generate_config_yaml, write_file
from subiquitycore.lsb_release import lsb_release
from subiquitycore.utils import astart_command, orig_environ
log = logging.getLogger('subiquity.server.apt') log = logging.getLogger("subiquity.server.apt")
class AptConfigCheckError(Exception): class AptConfigCheckError(Exception):
""" Error to raise when apt-get update fails with the currently applied """Error to raise when apt-get update fails with the currently applied
configuration. """ configuration."""
def get_index_targets() -> List[str]: def get_index_targets() -> List[str]:
""" Return the identifier of the data files that would be downloaded during """Return the identifier of the data files that would be downloaded during
apt-get update. apt-get update.
NOTE: this uses the default configuration files from the host so this might NOTE: this uses the default configuration files from the host so this might
slightly differ from what we would have in the overlay. slightly differ from what we would have in the overlay.
@ -108,8 +106,7 @@ class AptConfigurer:
# system, or if it is not, just copy /var/lib/apt/lists from the # system, or if it is not, just copy /var/lib/apt/lists from the
# 'configured_tree' overlay. # 'configured_tree' overlay.
def __init__(self, app, mounter: Mounter, def __init__(self, app, mounter: Mounter, source_handler: AbstractSourceHandler):
source_handler: AbstractSourceHandler):
self.app = app self.app = app
self.mounter = mounter self.mounter = mounter
self.source_handler: AbstractSourceHandler = source_handler self.source_handler: AbstractSourceHandler = source_handler
@ -133,30 +130,37 @@ class AptConfigurer:
self.app.base_model.debconf_selections, self.app.base_model.debconf_selections,
] ]
for model in models: for model in models:
merge_config(cfg, model.get_apt_config( merge_config(
final=final, has_network=has_network)) cfg, model.get_apt_config(final=final, has_network=has_network)
return {'apt': cfg} )
return {"apt": cfg}
async def apply_apt_config(self, context, final: bool): async def apply_apt_config(self, context, final: bool):
self.configured_tree = await self.mounter.setup_overlay( self.configured_tree = await self.mounter.setup_overlay([self.source_path])
[self.source_path])
config_location = os.path.join( config_location = os.path.join(
self.app.root, 'var/log/installer/subiquity-curtin-apt.conf') self.app.root, "var/log/installer/subiquity-curtin-apt.conf"
)
generate_config_yaml(config_location, self.apt_config(final)) generate_config_yaml(config_location, self.apt_config(final))
self.app.note_data_for_apport("CurtinAptConfig", config_location) self.app.note_data_for_apport("CurtinAptConfig", config_location)
await run_curtin_command( await run_curtin_command(
self.app, context, 'apt-config', '-t', self.configured_tree.p(), self.app,
config=config_location, private_mounts=True) context,
"apt-config",
"-t",
self.configured_tree.p(),
config=config_location,
private_mounts=True,
)
async def run_apt_config_check(self, output: io.StringIO) -> None: async def run_apt_config_check(self, output: io.StringIO) -> None:
""" Run apt-get update (with various options limiting the amount of """Run apt-get update (with various options limiting the amount of
data downloaded) in the overlay where the apt configuration was data downloaded) in the overlay where the apt configuration was
previously deployed. The output of apt-get (stdout + stderr) will be previously deployed. The output of apt-get (stdout + stderr) will be
written to the output parameter. written to the output parameter.
Raises an AptConfigCheckError exception if the apt-get command exited Raises an AptConfigCheckError exception if the apt-get command exited
with non-zero. """ with non-zero."""
assert self.configured_tree is not None assert self.configured_tree is not None
pfx = pathlib.Path(self.configured_tree.p()) pfx = pathlib.Path(self.configured_tree.p())
@ -173,15 +177,15 @@ class AptConfigurer:
# Need to ensure the "partial" directory exists. # Need to ensure the "partial" directory exists.
partial_dir = apt_dirs["State::Lists"] / "partial" partial_dir = apt_dirs["State::Lists"] / "partial"
partial_dir.mkdir( partial_dir.mkdir(parents=True, exist_ok=True)
parents=True, exist_ok=True)
try: try:
shutil.chown(partial_dir, user="_apt") shutil.chown(partial_dir, user="_apt")
except (PermissionError, LookupError) as exc: except (PermissionError, LookupError) as exc:
log.warning("could to set owner of file %s: %r", partial_dir, exc) log.warning("could to set owner of file %s: %r", partial_dir, exc)
apt_cmd = [ apt_cmd = [
"apt-get", "update", "apt-get",
"update",
"-oAPT::Update::Error-Mode=any", "-oAPT::Update::Error-Mode=any",
# Workaround because the default sandbox user (i.e., _apt) does not # Workaround because the default sandbox user (i.e., _apt) does not
# have access to the overlay. # have access to the overlay.
@ -201,8 +205,9 @@ class AptConfigurer:
env["APT_CONFIG"] = config_file.name env["APT_CONFIG"] = config_file.name
config_file.write(apt_config.dump()) config_file.write(apt_config.dump())
config_file.flush() config_file.flush()
proc = await astart_command(apt_cmd, stderr=subprocess.STDOUT, proc = await astart_command(
clean_locale=False, env=env) apt_cmd, stderr=subprocess.STDOUT, clean_locale=False, env=env
)
async def _reader(): async def _reader():
while not proc.stdout.at_eof(): while not proc.stdout.at_eof():
@ -223,42 +228,51 @@ class AptConfigurer:
async def configure_for_install(self, context): async def configure_for_install(self, context):
assert self.configured_tree is not None assert self.configured_tree is not None
self.install_tree = await self.mounter.setup_overlay( self.install_tree = await self.mounter.setup_overlay([self.configured_tree])
[self.configured_tree])
os.mkdir(self.install_tree.p('cdrom')) os.mkdir(self.install_tree.p("cdrom"))
await self.mounter.mount( await self.mounter.mount("/cdrom", self.install_tree.p("cdrom"), options="bind")
'/cdrom', self.install_tree.p('cdrom'), options='bind')
if self.app.base_model.network.has_network: if self.app.base_model.network.has_network:
os.rename( os.rename(
self.install_tree.p('etc/apt/sources.list'), self.install_tree.p("etc/apt/sources.list"),
self.install_tree.p('etc/apt/sources.list.d/original.list')) self.install_tree.p("etc/apt/sources.list.d/original.list"),
)
else: else:
proxy_path = self.install_tree.p( proxy_path = self.install_tree.p("etc/apt/apt.conf.d/90curtin-aptproxy")
'etc/apt/apt.conf.d/90curtin-aptproxy')
if os.path.exists(proxy_path): if os.path.exists(proxy_path):
os.unlink(proxy_path) os.unlink(proxy_path)
codename = lsb_release(dry_run=self.app.opts.dry_run)['codename'] codename = lsb_release(dry_run=self.app.opts.dry_run)["codename"]
write_file( write_file(
self.install_tree.p('etc/apt/sources.list'), self.install_tree.p("etc/apt/sources.list"),
f'deb [check-date=no] file:///cdrom {codename} main restricted\n') f"deb [check-date=no] file:///cdrom {codename} main restricted\n",
)
await run_curtin_command( await run_curtin_command(
self.app, context, "in-target", "-t", self.install_tree.p(), self.app,
"--", "apt-get", "update", private_mounts=True) context,
"in-target",
"-t",
self.install_tree.p(),
"--",
"apt-get",
"update",
private_mounts=True,
)
return self.install_tree.p() return self.install_tree.p()
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def overlay(self): async def overlay(self):
overlay = await self.mounter.setup_overlay([ overlay = await self.mounter.setup_overlay(
[
self.install_tree.upperdir, self.install_tree.upperdir,
self.configured_tree.upperdir, self.configured_tree.upperdir,
self.source_path, self.source_path,
]) ]
)
try: try:
yield overlay yield overlay
finally: finally:
@ -269,8 +283,7 @@ class AptConfigurer:
# compares equal to the one we discarded earlier. # compares equal to the one we discarded earlier.
# But really, there should be better ways to handle this. # But really, there should be better ways to handle this.
try: try:
await self.mounter.unmount( await self.mounter.unmount(Mountpoint(mountpoint=overlay.mountpoint))
Mountpoint(mountpoint=overlay.mountpoint))
except subprocess.CalledProcessError as exc: except subprocess.CalledProcessError as exc:
raise OverlayCleanupError from exc raise OverlayCleanupError from exc
@ -286,41 +299,53 @@ class AptConfigurer:
async def _restore_dir(dir): async def _restore_dir(dir):
shutil.rmtree(target_mnt.p(dir)) shutil.rmtree(target_mnt.p(dir))
await self.app.command_runner.run([ await self.app.command_runner.run(
'cp', '-aT', self.configured_tree.p(dir), target_mnt.p(dir), [
]) "cp",
"-aT",
self.configured_tree.p(dir),
target_mnt.p(dir),
]
)
def _restore_file(path: str) -> None: def _restore_file(path: str) -> None:
shutil.copyfile(self.configured_tree.p(path), target_mnt.p(path)) shutil.copyfile(self.configured_tree.p(path), target_mnt.p(path))
# The file only exists if we are online # The file only exists if we are online
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
os.unlink(target_mnt.p('etc/apt/sources.list.d/original.list')) os.unlink(target_mnt.p("etc/apt/sources.list.d/original.list"))
_restore_file('etc/apt/sources.list') _restore_file("etc/apt/sources.list")
with contextlib.suppress(FileNotFoundError): with contextlib.suppress(FileNotFoundError):
_restore_file('etc/apt/apt.conf.d/90curtin-aptproxy') _restore_file("etc/apt/apt.conf.d/90curtin-aptproxy")
if self.app.base_model.network.has_network: if self.app.base_model.network.has_network:
await run_curtin_command( await run_curtin_command(
self.app, context, "in-target", "-t", target_mnt.p(), self.app,
"--", "apt-get", "update", private_mounts=True) context,
"in-target",
"-t",
target_mnt.p(),
"--",
"apt-get",
"update",
private_mounts=True,
)
else: else:
await _restore_dir('var/lib/apt/lists') await _restore_dir("var/lib/apt/lists")
await self.cleanup() await self.cleanup()
try: try:
d = target_mnt.p('cdrom') d = target_mnt.p("cdrom")
os.rmdir(d) os.rmdir(d)
except OSError as ose: except OSError as ose:
log.warning(f'failed to rmdir {d}: {ose}') log.warning(f"failed to rmdir {d}: {ose}")
async def setup_target(self, context, target: str): async def setup_target(self, context, target: str):
# Call this after the rootfs has been extracted to the real target # Call this after the rootfs has been extracted to the real target
# system but before any configuration is applied to it. # system but before any configuration is applied to it.
target_mnt = Mountpoint(mountpoint=target) target_mnt = Mountpoint(mountpoint=target)
await self.mounter.mount( await self.mounter.mount("/cdrom", target_mnt.p("cdrom"), options="bind")
'/cdrom', target_mnt.p('cdrom'), options='bind')
class DryRunAptConfigurer(AptConfigurer): class DryRunAptConfigurer(AptConfigurer):
@ -339,8 +364,8 @@ class DryRunAptConfigurer(AptConfigurer):
await self.cleanup() await self.cleanup()
def get_mirror_check_strategy(self, url: str) -> "MirrorCheckStrategy": def get_mirror_check_strategy(self, url: str) -> "MirrorCheckStrategy":
""" For a given mirror URL, return the strategy that we should use to """For a given mirror URL, return the strategy that we should use to
perform mirror checking. """ perform mirror checking."""
for known in self.app.dr_cfg.apt_mirrors_known: for known in self.app.dr_cfg.apt_mirrors_known:
if "url" in known: if "url" in known:
if known["url"] != url: if known["url"] != url:
@ -354,16 +379,18 @@ class DryRunAptConfigurer(AptConfigurer):
return self.MirrorCheckStrategy(known["strategy"]) return self.MirrorCheckStrategy(known["strategy"])
return self.MirrorCheckStrategy( return self.MirrorCheckStrategy(
self.app.dr_cfg.apt_mirror_check_default_strategy) self.app.dr_cfg.apt_mirror_check_default_strategy
)
async def apt_config_check_failure(self, output: io.StringIO) -> None: async def apt_config_check_failure(self, output: io.StringIO) -> None:
""" Pretend that the execution of the apt-get update command results in """Pretend that the execution of the apt-get update command results in
a failure. """ a failure."""
url = self.app.base_model.mirror.primary_staged.uri url = self.app.base_model.mirror.primary_staged.uri
release = lsb_release(dry_run=True)["codename"] release = lsb_release(dry_run=True)["codename"]
host = url.split("/")[2] host = url.split("/")[2]
output.write(f"""\ output.write(
f"""\
Ign:1 {url} {release} InRelease Ign:1 {url} {release} InRelease
Ign:2 {url} {release}-updates InRelease Ign:2 {url} {release}-updates InRelease
Ign:3 {url} {release}-backports InRelease Ign:3 {url} {release}-backports InRelease
@ -389,28 +416,31 @@ E: Failed to fetch {url}/dists/{release}-security/InRelease\
Temporary failure resolving '{host}' Temporary failure resolving '{host}'
E: Some index files failed to download. They have been ignored, E: Some index files failed to download. They have been ignored,
or old ones used instead. or old ones used instead.
""") """
)
raise AptConfigCheckError raise AptConfigCheckError
async def apt_config_check_success(self, output: io.StringIO) -> None: async def apt_config_check_success(self, output: io.StringIO) -> None:
""" Pretend that the execution of the apt-get update command results in """Pretend that the execution of the apt-get update command results in
a success. """ a success."""
url = self.app.base_model.mirror.primary_staged.uri url = self.app.base_model.mirror.primary_staged.uri
release = lsb_release(dry_run=True)["codename"] release = lsb_release(dry_run=True)["codename"]
output.write(f"""\ output.write(
f"""\
Get:1 {url} {release} InRelease [267 kB] Get:1 {url} {release} InRelease [267 kB]
Get:2 {url} {release}-updates InRelease [109 kB] Get:2 {url} {release}-updates InRelease [109 kB]
Get:3 {url} {release}-backports InRelease [99.9 kB] Get:3 {url} {release}-backports InRelease [99.9 kB]
Get:4 {url} {release}-security InRelease [109 kB] Get:4 {url} {release}-security InRelease [109 kB]
Fetched 585 kB in 1s (1057 kB/s) Fetched 585 kB in 1s (1057 kB/s)
Reading package lists... Reading package lists...
""") """
)
async def run_apt_config_check(self, output: io.StringIO) -> None: async def run_apt_config_check(self, output: io.StringIO) -> None:
""" Dry-run implementation of the Apt config check. """Dry-run implementation of the Apt config check.
The strategy used is based on the URL of the primary mirror. The The strategy used is based on the URL of the primary mirror. The
apt-get command can either run on the host or be faked entirely. """ apt-get command can either run on the host or be faked entirely."""
assert self.configured_tree is not None assert self.configured_tree is not None
failure = self.apt_config_check_failure failure = self.apt_config_check_failure

View File

@ -16,40 +16,38 @@
import asyncio import asyncio
import logging import logging
from typing import Any, List
import time import time
from typing import Any, List
from subiquity.server.ubuntu_advantage import ( from subiquity.server.ubuntu_advantage import APIUsageError, UAInterface
UAInterface,
APIUsageError,
)
from subiquitycore.async_helpers import schedule_task from subiquitycore.async_helpers import schedule_task
log = logging.getLogger("subiquity.server.contract_selection") log = logging.getLogger("subiquity.server.contract_selection")
class UPCSExpiredError(Exception): class UPCSExpiredError(Exception):
""" Exception to be raised when a contract selection expired. """ """Exception to be raised when a contract selection expired."""
class UnknownError(Exception): class UnknownError(Exception):
""" Exception to be raised in case of unexpected error. """ """Exception to be raised in case of unexpected error."""
def __init__(self, message: str = "", errors: List[Any] = []) -> None: def __init__(self, message: str = "", errors: List[Any] = []) -> None:
self.errors = errors self.errors = errors
super().__init__(message) super().__init__(message)
class ContractSelection: class ContractSelection:
""" Represents an already initiated contract selection. """ """Represents an already initiated contract selection."""
def __init__( def __init__(
self, self,
client: UAInterface, client: UAInterface,
magic_token: str, magic_token: str,
user_code: str, user_code: str,
validity_seconds: int) -> None: validity_seconds: int,
""" Initialize the contract selection. """ ) -> None:
"""Initialize the contract selection."""
self.client = client self.client = client
self.magic_token = magic_token self.magic_token = magic_token
self.user_code = user_code self.user_code = user_code
@ -57,10 +55,9 @@ class ContractSelection:
self.task = asyncio.create_task(self._run_polling()) self.task = asyncio.create_task(self._run_polling())
@classmethod @classmethod
async def initiate(cls, client: UAInterface) \ async def initiate(cls, client: UAInterface) -> "ContractSelection":
-> "ContractSelection":
""" Initiate a contract selection and return a ContractSelection """Initiate a contract selection and return a ContractSelection
request object. """ request object."""
answer = await client.strategy.magic_initiate_v1() answer = await client.strategy.magic_initiate_v1()
if answer["result"] != "success": if answer["result"] != "success":
@ -70,17 +67,17 @@ class ContractSelection:
client=client, client=client,
magic_token=answer["data"]["attributes"]["token"], magic_token=answer["data"]["attributes"]["token"],
validity_seconds=answer["data"]["attributes"]["expires_in"], validity_seconds=answer["data"]["attributes"]["expires_in"],
user_code=answer["data"]["attributes"]["user_code"]) user_code=answer["data"]["attributes"]["user_code"],
)
async def _run_polling(self) -> str: async def _run_polling(self) -> str:
""" Runs the polling and eventually return a contract token. """ """Runs the polling and eventually return a contract token."""
# Wait an initial 30 seconds before sending the first request, then # Wait an initial 30 seconds before sending the first request, then
# send requests at regular interval every 10 seconds. # send requests at regular interval every 10 seconds.
await asyncio.sleep(30 / self.client.strategy.scale_factor) await asyncio.sleep(30 / self.client.strategy.scale_factor)
start_time = time.monotonic() start_time = time.monotonic()
answer = await self.client.strategy.magic_wait_v1( answer = await self.client.strategy.magic_wait_v1(magic_token=self.magic_token)
magic_token=self.magic_token)
call_duration = time.monotonic() - start_time call_duration = time.monotonic() - start_time
@ -89,26 +86,28 @@ class ContractSelection:
exception = UnknownError() exception = UnknownError()
for error in answer["errors"]: for error in answer["errors"]:
if error["code"] == "magic-attach-token-error" and \ if (
call_duration >= 60 / self.client.strategy.scale_factor: error["code"] == "magic-attach-token-error"
and call_duration >= 60 / self.client.strategy.scale_factor
):
# Assume it's a timeout if it lasted more than 1 minute. # Assume it's a timeout if it lasted more than 1 minute.
exception = UPCSExpiredError(error["title"]) exception = UPCSExpiredError(error["title"])
else: else:
log.warning("magic_wait_v1: %s: %s", log.warning("magic_wait_v1: %s: %s", error["code"], error["title"])
error["code"], error["title"])
raise exception raise exception
def cancel(self): def cancel(self):
""" Cancel the polling task and asynchronously delete the associated """Cancel the polling task and asynchronously delete the associated
resource. """ resource."""
self.task.cancel() self.task.cancel()
async def delete_resource() -> None: async def delete_resource() -> None:
""" Release the resource on the server. """ """Release the resource on the server."""
try: try:
answer = await self.client.strategy.magic_revoke_v1( answer = await self.client.strategy.magic_revoke_v1(
magic_token=self.magic_token) magic_token=self.magic_token
)
except APIUsageError as e: except APIUsageError as e:
log.warning("failed to revoke magic-token: %r", e) log.warning("failed to revoke magic-token: %r", e)
@ -117,7 +116,8 @@ class ContractSelection:
log.debug("successfully revoked magic-token") log.debug("successfully revoked magic-token")
return return
for error in answer["errors"]: for error in answer["errors"]:
log.warning("magic_revoke_v1: %s: %s", log.warning(
error["code"], error["title"]) "magic_revoke_v1: %s: %s", error["code"], error["title"]
)
schedule_task(delete_resource()) schedule_task(delete_resource())
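For illustration only (not part of this commit): a hedged sketch of the expiry heuristic in _run_polling() above — a magic-attach-token-error is treated as an expired selection only when the wait call itself ran for at least 60 seconds (divided by scale_factor in dry-run); anything else is reported as an unknown error. Strings are returned here instead of raising the real exceptions.

def classify_wait_error(code: str, call_duration: float, scale_factor: float = 1.0) -> str:
    # Mirrors the branch in ContractSelection._run_polling shown above
    if code == "magic-attach-token-error" and call_duration >= 60 / scale_factor:
        return "expired"    # the real code raises UPCSExpiredError
    return "unknown"        # the real code logs a warning and raises UnknownError

assert classify_wait_error("magic-attach-token-error", 75.0) == "expired"
assert classify_wait_error("magic-attach-token-error", 5.0) == "unknown"
assert classify_wait_error("some-other-error", 120.0) == "unknown"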

View File

@ -20,19 +20,15 @@ from typing import Any, Optional
import jsonschema import jsonschema
from subiquitycore.context import with_context
from subiquitycore.controller import (
BaseController,
)
from subiquity.common.api.server import bind from subiquity.common.api.server import bind
from subiquity.server.types import InstallerChannels from subiquity.server.types import InstallerChannels
from subiquitycore.context import with_context
from subiquitycore.controller import BaseController
log = logging.getLogger("subiquity.server.controller") log = logging.getLogger("subiquity.server.controller")
class SubiquityController(BaseController): class SubiquityController(BaseController):
autoinstall_key: Optional[str] = None autoinstall_key: Optional[str] = None
autoinstall_schema: Any = None autoinstall_schema: Any = None
autoinstall_default: Any = None autoinstall_default: Any = None
@ -48,10 +44,9 @@ class SubiquityController(BaseController):
def __init__(self, app): def __init__(self, app):
super().__init__(app) super().__init__(app)
self.context.set('controller', self) self.context.set("controller", self)
if self.interactive_for_variants is not None: if self.interactive_for_variants is not None:
self.app.hub.subscribe( self.app.hub.subscribe(InstallerChannels.INSTALL_CONFIRMED, self._confirmed)
InstallerChannels.INSTALL_CONFIRMED, self._confirmed)
async def _confirmed(self): async def _confirmed(self):
variant = self.app.base_model.source.current.variant variant = self.app.base_model.source.current.variant
@ -102,8 +97,7 @@ class SubiquityController(BaseController):
def interactive(self): def interactive(self):
if not self.app.autoinstall_config: if not self.app.autoinstall_config:
return self._active return self._active
i_sections = self.app.autoinstall_config.get( i_sections = self.app.autoinstall_config.get("interactive-sections", [])
'interactive-sections', [])
if "*" in i_sections: if "*" in i_sections:
return self._active return self._active
@ -111,21 +105,23 @@ class SubiquityController(BaseController):
if self.autoinstall_key in i_sections: if self.autoinstall_key in i_sections:
return self._active return self._active
if self.autoinstall_key_alias is not None \ if (
and self.autoinstall_key_alias in i_sections: self.autoinstall_key_alias is not None
and self.autoinstall_key_alias in i_sections
):
return self._active return self._active
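For illustration only (not part of this commit): a simplified, self-contained stand-in for the interactive() logic above. The final non-interactive fallthrough is an assumption, since it is not visible in this hunk.

from typing import List, Optional

def is_interactive(
    autoinstall_config: dict,
    key: Optional[str],
    alias: Optional[str] = None,
    active: bool = True,
) -> bool:
    if not autoinstall_config:
        return active
    i_sections: List[str] = autoinstall_config.get("interactive-sections", [])
    if "*" in i_sections:
        return active
    if key in i_sections:
        return active
    if alias is not None and alias in i_sections:
        return active
    return False    # assumed fallthrough; not shown in the hunk above

print(is_interactive({"interactive-sections": ["network"]}, "network"))   # True
print(is_interactive({"interactive-sections": ["network"]}, "storage"))   # False
print(is_interactive({}, "storage"))                                      # True (no autoinstall config)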
async def configured(self): async def configured(self):
"""Let the world know that this controller's model is now configured. """Let the world know that this controller's model is now configured."""
""" with open(self.app.state_path("states", self.name), "w") as fp:
with open(self.app.state_path('states', self.name), 'w') as fp:
json.dump(self.serialize(), fp) json.dump(self.serialize(), fp)
if self.model_name is not None: if self.model_name is not None:
await self.app.hub.abroadcast( await self.app.hub.abroadcast(
(InstallerChannels.CONFIGURED, self.model_name)) (InstallerChannels.CONFIGURED, self.model_name)
)
def load_state(self): def load_state(self):
state_path = self.app.state_path('states', self.name) state_path = self.app.state_path("states", self.name)
if not os.path.exists(state_path): if not os.path.exists(state_path):
return return
with open(state_path) as fp: with open(state_path) as fp:
@ -143,6 +139,5 @@ class SubiquityController(BaseController):
class NonInteractiveController(SubiquityController): class NonInteractiveController(SubiquityController):
def interactive(self): def interactive(self):
return False return False

Some files were not shown because too many files have changed in this diff