format with black + isort

This commit is contained in:
Dan Bungert 2023-07-25 15:26:25 -06:00
parent 93a34a64fa
commit 34d40643ad
298 changed files with 14900 additions and 14044 deletions

View File

@ -16,6 +16,7 @@
""" Console-Conf """
from subiquitycore import i18n
__all__ = [
'i18n',
"i18n",
]

View File

@ -16,42 +16,63 @@
import argparse
import asyncio
import sys
import os
import logging
from subiquitycore.log import setup_logger
from subiquitycore import __version__ as VERSION
import os
import sys
from console_conf.core import ConsoleConf, RecoveryChooser
from subiquitycore import __version__ as VERSION
from subiquitycore.log import setup_logger
def parse_options(argv):
parser = argparse.ArgumentParser(
description=(
'console-conf - Pre-Ownership Configuration for Ubuntu Core'),
prog='console-conf')
parser.add_argument('--dry-run', action='store_true',
dest='dry_run',
help='menu-only, do not call installer function')
parser.add_argument('--serial', action='store_true',
dest='run_on_serial',
help='Run the installer over serial console.')
parser.add_argument('--ascii', action='store_true',
dest='ascii',
help='Run the installer in ascii mode.')
parser.add_argument('--machine-config', metavar='CONFIG',
dest='machine_config',
description=("console-conf - Pre-Ownership Configuration for Ubuntu Core"),
prog="console-conf",
)
parser.add_argument(
"--dry-run",
action="store_true",
dest="dry_run",
help="menu-only, do not call installer function",
)
parser.add_argument(
"--serial",
action="store_true",
dest="run_on_serial",
help="Run the installer over serial console.",
)
parser.add_argument(
"--ascii",
action="store_true",
dest="ascii",
help="Run the installer in ascii mode.",
)
parser.add_argument(
"--machine-config",
metavar="CONFIG",
dest="machine_config",
type=argparse.FileType(),
help="Don't Probe. Use probe data file")
parser.add_argument('--screens', action='append', dest='screens',
default=[])
parser.add_argument('--answers')
parser.add_argument('--recovery-chooser-mode', action='store_true',
dest='chooser_systems',
help=('Run as a recovery chooser interacting with the '
'calling process over stdin/stdout streams'))
parser.add_argument('--output-base', action='store', dest='output_base',
default='.subiquity',
help='in dryrun, control basedir of files')
help="Don't Probe. Use probe data file",
)
parser.add_argument("--screens", action="append", dest="screens", default=[])
parser.add_argument("--answers")
parser.add_argument(
"--recovery-chooser-mode",
action="store_true",
dest="chooser_systems",
help=(
"Run as a recovery chooser interacting with the "
"calling process over stdin/stdout streams"
),
)
parser.add_argument(
"--output-base",
action="store",
dest="output_base",
default=".subiquity",
help="in dryrun, control basedir of files",
)
return parser.parse_args(argv)
@ -64,7 +85,7 @@ def main():
if opts.dry_run:
LOGDIR = opts.output_base
setup_logger(dir=LOGDIR)
logger = logging.getLogger('console_conf')
logger = logging.getLogger("console_conf")
logger.info("Starting console-conf v{}".format(VERSION))
logger.info("Arguments passed: {}".format(sys.argv))
@ -73,8 +94,7 @@ def main():
# when running as a chooser, the stdin/stdout streams are set up by
# the process that runs us, attempt to restore the tty in/out by
# looking at stderr
chooser_input, chooser_output = restore_std_streams_from(
sys.stderr)
chooser_input, chooser_output = restore_std_streams_from(sys.stderr)
interface = RecoveryChooser(opts, chooser_input, chooser_output)
else:
interface = ConsoleConf(opts)
@ -91,10 +111,10 @@ def restore_std_streams_from(from_file):
tty = os.ttyname(from_file.fileno())
# we have tty now
chooser_input, chooser_output = sys.stdin, sys.stdout
sys.stdin = open(tty, 'r')
sys.stdout = open(tty, 'w')
sys.stdin = open(tty, "r")
sys.stdout = open(tty, "w")
return chooser_input, chooser_output
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(main())

View File

@ -17,20 +17,18 @@
import logging
import sys
from subiquitycore.log import setup_logger
from subiquitycore import __version__ as VERSION
from console_conf.controllers.identity import write_login_details_standalone
from subiquitycore import __version__ as VERSION
from subiquitycore.log import setup_logger
def main():
logger = setup_logger(dir='/var/log/console-conf')
logger = logging.getLogger('console_conf')
logger.info(
"Starting console-conf-write-login-details v{}".format(VERSION))
logger = setup_logger(dir="/var/log/console-conf")
logger = logging.getLogger("console_conf")
logger.info("Starting console-conf-write-login-details v{}".format(VERSION))
logger.info("Arguments passed: {}".format(sys.argv))
return write_login_details_standalone()
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(main())

View File

@ -15,19 +15,17 @@
""" console-conf controllers """
from .identity import IdentityController
from subiquitycore.controllers.network import NetworkController
from .welcome import WelcomeController, RecoveryChooserWelcomeController
from .chooser import (
RecoveryChooserController,
RecoveryChooserConfirmController
)
from .chooser import RecoveryChooserConfirmController, RecoveryChooserController
from .identity import IdentityController
from .welcome import RecoveryChooserWelcomeController, WelcomeController
__all__ = [
'IdentityController',
'NetworkController',
'WelcomeController',
'RecoveryChooserWelcomeController',
'RecoveryChooserController',
'RecoveryChooserConfirmController',
"IdentityController",
"NetworkController",
"WelcomeController",
"RecoveryChooserWelcomeController",
"RecoveryChooserController",
"RecoveryChooserConfirmController",
]

View File

@ -15,18 +15,16 @@
import logging
from console_conf.ui.views import (
ChooserView,
ChooserCurrentSystemView,
ChooserConfirmView,
ChooserCurrentSystemView,
ChooserView,
)
from subiquitycore.tuicontroller import TuiController
log = logging.getLogger("console_conf.controllers.chooser")
class RecoveryChooserBaseController(TuiController):
def __init__(self, app):
super().__init__(app)
self.model = app.base_model
@ -40,7 +38,6 @@ class RecoveryChooserBaseController(TuiController):
class RecoveryChooserController(RecoveryChooserBaseController):
def __init__(self, app):
super().__init__(app)
self._model_view, self._all_view = self._make_views()
@ -64,9 +61,9 @@ class RecoveryChooserController(RecoveryChooserBaseController):
if self.model.current and self.model.current.actions:
# only when we have a current system and it has actions available
more = len(self.model.systems) > 1
current_view = ChooserCurrentSystemView(self,
self.model.current,
has_more=more)
current_view = ChooserCurrentSystemView(
self, self.model.current, has_more=more
)
all_view = ChooserView(self, self.model.systems)
return current_view, all_view
@ -80,8 +77,7 @@ class RecoveryChooserController(RecoveryChooserBaseController):
self.ui.set_body(self._all_view)
def back(self):
if self._current_view == self._all_view and \
self._model_view is not None:
if self._current_view == self._all_view and self._model_view is not None:
# back in the all-systems screen goes back to the current model
# screen, provided we have one
self._current_view = self._model_view

View File

@ -20,14 +20,13 @@ import pwd
import shlex
import sys
from subiquitycore.ssh import host_key_info, get_ips_standalone
from console_conf.ui.views import IdentityView, LoginView
from subiquitycore.snapd import SnapdConnection
from subiquitycore.ssh import get_ips_standalone, host_key_info
from subiquitycore.tuicontroller import TuiController
from subiquitycore.utils import disable_console_conf, run_command
from console_conf.ui.views import IdentityView, LoginView
log = logging.getLogger('console_conf.controllers.identity')
log = logging.getLogger("console_conf.controllers.identity")
def get_core_version():
@ -54,31 +53,31 @@ def get_core_version():
def get_managed():
"""Check if device is managed"""
con = SnapdConnection('', '/run/snapd.socket')
return con.get('v2/system-info').json()['result']['managed']
con = SnapdConnection("", "/run/snapd.socket")
return con.get("v2/system-info").json()["result"]["managed"]
def get_realname(username):
try:
info = pwd.getpwnam(username)
except KeyError:
return ''
return info.pw_gecos.split(',', 1)[0]
return ""
return info.pw_gecos.split(",", 1)[0]
def get_device_owner():
"""Get device owner, if any"""
con = SnapdConnection('', '/run/snapd.socket')
for user in con.get('v2/users').json()['result']:
if 'username' not in user:
con = SnapdConnection("", "/run/snapd.socket")
for user in con.get("v2/users").json()["result"]:
if "username" not in user:
continue
username = user['username']
homedir = '/home/' + username
username = user["username"]
homedir = "/home/" + username
if os.path.isdir(homedir):
return {
'username': username,
'realname': get_realname(username),
'homedir': homedir,
"username": username,
"realname": get_realname(username),
"homedir": homedir,
}
return None
@ -110,15 +109,22 @@ def write_login_details(fp, username, ips):
tty_name = os.ttyname(0)[5:] # strip off the /dev/
version = get_core_version() or "16"
if len(ips) == 0:
fp.write(login_details_tmpl_no_ip.format(
sshcommands=sshcommands, tty_name=tty_name, version=version))
fp.write(
login_details_tmpl_no_ip.format(
sshcommands=sshcommands, tty_name=tty_name, version=version
)
)
else:
first_ip = ips[0]
fp.write(login_details_tmpl.format(sshcommands=sshcommands,
fp.write(
login_details_tmpl.format(
sshcommands=sshcommands,
host_key_info=host_key_info(),
tty_name=tty_name,
first_ip=first_ip,
version=version))
version=version,
)
)
def write_login_details_standalone():
@ -131,18 +137,16 @@ def write_login_details_standalone():
else:
tty_name = os.ttyname(0)[5:]
version = get_core_version() or "16"
print(login_details_tmpl_no_ip.format(
tty_name=tty_name, version=version))
print(login_details_tmpl_no_ip.format(tty_name=tty_name, version=version))
return 2
if owner is None:
print("device managed without user @ {}".format(', '.join(ips)))
print("device managed without user @ {}".format(", ".join(ips)))
else:
write_login_details(sys.stdout, owner['username'], ips)
write_login_details(sys.stdout, owner["username"], ips)
return 0
class IdentityController(TuiController):
def __init__(self, app):
super().__init__(app)
self.model = app.base_model.identity
@ -152,11 +156,9 @@ class IdentityController(TuiController):
device_owner = get_device_owner()
if device_owner:
self.model.add_user(device_owner)
key_file = os.path.join(device_owner['homedir'],
".ssh/authorized_keys")
cp = run_command(['ssh-keygen', '-lf', key_file])
self.model.user.fingerprints = (
cp.stdout.replace('\r', '').splitlines())
key_file = os.path.join(device_owner["homedir"], ".ssh/authorized_keys")
cp = run_command(["ssh-keygen", "-lf", key_file])
self.model.user.fingerprints = cp.stdout.replace("\r", "").splitlines()
return self.make_login_view()
else:
return IdentityView(self.model, self)
@ -164,35 +166,35 @@ class IdentityController(TuiController):
def identity_done(self, email):
if self.opts.dry_run:
result = {
'realname': email,
'username': email,
"realname": email,
"username": email,
}
self.model.add_user(result)
login_details_path = self.opts.output_base + '/login-details.txt'
login_details_path = self.opts.output_base + "/login-details.txt"
else:
self.app.urwid_loop.draw_screen()
cp = run_command(
["snap", "create-user", "--sudoer", "--json", email])
cp = run_command(["snap", "create-user", "--sudoer", "--json", email])
if cp.returncode != 0:
if isinstance(self.ui.body, IdentityView):
self.ui.body.snap_create_user_failed(
"Creating user failed:", cp.stderr)
"Creating user failed:", cp.stderr
)
return
else:
data = json.loads(cp.stdout)
result = {
'realname': email,
'username': data['username'],
"realname": email,
"username": data["username"],
}
os.makedirs('/run/console-conf', exist_ok=True)
login_details_path = '/run/console-conf/login-details.txt'
os.makedirs("/run/console-conf", exist_ok=True)
login_details_path = "/run/console-conf/login-details.txt"
self.model.add_user(result)
ips = []
net_model = self.app.base_model.network
for dev in net_model.get_all_netdevs():
ips.extend(dev.actual_global_ip_addresses)
with open(login_details_path, 'w') as fp:
write_login_details(fp, result['username'], ips)
with open(login_details_path, "w") as fp:
write_login_details(fp, result["username"], ips)
self.login()
def cancel(self):

View File

@ -15,16 +15,12 @@
import unittest
from unittest import mock
from subiquitycore.tests.mocks import make_app
from console_conf.controllers.chooser import (
RecoveryChooserController,
RecoveryChooserConfirmController,
RecoveryChooserController,
)
from console_conf.models.systems import (
RecoverySystemsModel,
SelectedSystemAction,
)
from console_conf.models.systems import RecoverySystemsModel, SelectedSystemAction
from subiquitycore.tests.mocks import make_app
model1_non_current = {
"current": False,
@ -43,7 +39,7 @@ model1_non_current = {
"actions": [
{"title": "action 1", "mode": "action1"},
{"title": "action 2", "mode": "action2"},
]
],
}
model2_current = {
@ -62,21 +58,19 @@ model2_current = {
},
"actions": [
{"title": "action 1", "mode": "action1"},
]
],
}
def make_model():
return RecoverySystemsModel.from_systems([model1_non_current,
model2_current])
return RecoverySystemsModel.from_systems([model1_non_current, model2_current])
class TestChooserConfirmController(unittest.TestCase):
def test_back(self):
app = make_app()
c = RecoveryChooserConfirmController(app)
c.model = mock.Mock(selection='selection')
c.model = mock.Mock(selection="selection")
c.back()
app.respond.assert_not_called()
app.exit.assert_not_called()
@ -86,12 +80,12 @@ class TestChooserConfirmController(unittest.TestCase):
def test_confirm(self):
app = make_app()
c = RecoveryChooserConfirmController(app)
c.model = mock.Mock(selection='selection')
c.model = mock.Mock(selection="selection")
c.confirm()
app.respond.assert_called_with('selection')
app.respond.assert_called_with("selection")
app.exit.assert_called()
@mock.patch('console_conf.controllers.chooser.ChooserConfirmView')
@mock.patch("console_conf.controllers.chooser.ChooserConfirmView")
def test_confirm_view(self, ccv):
app = make_app()
c = RecoveryChooserConfirmController(app)
@ -102,24 +96,25 @@ class TestChooserConfirmController(unittest.TestCase):
class TestChooserController(unittest.TestCase):
def test_select(self):
app = make_app(model=make_model())
c = RecoveryChooserController(app)
c.select(c.model.systems[0], c.model.systems[0].actions[0])
exp = SelectedSystemAction(system=c.model.systems[0],
action=c.model.systems[0].actions[0])
exp = SelectedSystemAction(
system=c.model.systems[0], action=c.model.systems[0].actions[0]
)
self.assertEqual(c.model.selection, exp)
app.next_screen.assert_called()
app.respond.assert_not_called()
app.exit.assert_not_called()
@mock.patch('console_conf.controllers.chooser.ChooserCurrentSystemView',
return_value='current')
@mock.patch('console_conf.controllers.chooser.ChooserView',
return_value='all')
@mock.patch(
"console_conf.controllers.chooser.ChooserCurrentSystemView",
return_value="current",
)
@mock.patch("console_conf.controllers.chooser.ChooserView", return_value="all")
def test_current_ui_first(self, cv, ccsv):
app = make_app(model=make_model())
c = RecoveryChooserController(app)
@ -128,15 +123,16 @@ class TestChooserController(unittest.TestCase):
# as well as all systems view
cv.assert_called_with(c, c.model.systems)
v = c.make_ui()
self.assertEqual(v, 'current')
self.assertEqual(v, "current")
# user selects more options and the view is replaced
c.more_options()
c.ui.set_body.assert_called_with('all')
c.ui.set_body.assert_called_with("all")
@mock.patch('console_conf.controllers.chooser.ChooserCurrentSystemView',
return_value='current')
@mock.patch('console_conf.controllers.chooser.ChooserView',
return_value='all')
@mock.patch(
"console_conf.controllers.chooser.ChooserCurrentSystemView",
return_value="current",
)
@mock.patch("console_conf.controllers.chooser.ChooserView", return_value="all")
def test_current_current_all_there_and_back(self, cv, ccsv):
app = make_app(model=make_model())
c = RecoveryChooserController(app)
@ -145,21 +141,22 @@ class TestChooserController(unittest.TestCase):
cv.assert_called_with(c, c.model.systems)
v = c.make_ui()
self.assertEqual(v, 'current')
self.assertEqual(v, "current")
# user selects more options and the view is replaced
c.more_options()
c.ui.set_body.assert_called_with('all')
c.ui.set_body.assert_called_with("all")
# go back now
c.back()
c.ui.set_body.assert_called_with('current')
c.ui.set_body.assert_called_with("current")
# nothing
c.back()
c.ui.set_body.not_called()
@mock.patch('console_conf.controllers.chooser.ChooserCurrentSystemView',
return_value='current')
@mock.patch('console_conf.controllers.chooser.ChooserView',
return_value='all')
@mock.patch(
"console_conf.controllers.chooser.ChooserCurrentSystemView",
return_value="current",
)
@mock.patch("console_conf.controllers.chooser.ChooserView", return_value="all")
def test_only_one_and_current(self, cv, ccsv):
model = RecoverySystemsModel.from_systems([model2_current])
app = make_app(model=model)
@ -169,15 +166,16 @@ class TestChooserController(unittest.TestCase):
cv.assert_called_with(c, c.model.systems)
v = c.make_ui()
self.assertEqual(v, 'current')
self.assertEqual(v, "current")
# going back does nothing
c.back()
c.ui.set_body.not_called()
@mock.patch('console_conf.controllers.chooser.ChooserCurrentSystemView',
return_value='current')
@mock.patch('console_conf.controllers.chooser.ChooserView',
return_value='all')
@mock.patch(
"console_conf.controllers.chooser.ChooserCurrentSystemView",
return_value="current",
)
@mock.patch("console_conf.controllers.chooser.ChooserView", return_value="all")
def test_all_systems_first_no_current(self, cv, ccsv):
model = RecoverySystemsModel.from_systems([model1_non_current])
app = make_app(model=model)
@ -192,4 +190,4 @@ class TestChooserController(unittest.TestCase):
ccsv.assert_not_called()
v = c.make_ui()
self.assertEqual(v, 'all')
self.assertEqual(v, "all")

View File

@ -13,13 +13,11 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from console_conf.ui.views import WelcomeView, ChooserWelcomeView
from console_conf.ui.views import ChooserWelcomeView, WelcomeView
from subiquitycore.tuicontroller import TuiController
class WelcomeController(TuiController):
welcome_view = WelcomeView
def make_ui(self):
@ -34,7 +32,6 @@ class WelcomeController(TuiController):
class RecoveryChooserWelcomeController(WelcomeController):
welcome_view = ChooserWelcomeView
def __init__(self, app):

View File

@ -15,18 +15,17 @@
import logging
from subiquitycore.prober import Prober
from subiquitycore.tui import TuiApplication
from console_conf.models.console_conf import ConsoleConfModel
from console_conf.models.systems import RecoverySystemsModel
from subiquitycore.prober import Prober
from subiquitycore.tui import TuiApplication
log = logging.getLogger("console_conf.core")
class ConsoleConf(TuiApplication):
from console_conf import controllers as controllers_mod
project = "console_conf"
make_model = ConsoleConfModel
@ -43,8 +42,8 @@ class ConsoleConf(TuiApplication):
class RecoveryChooser(TuiApplication):
from console_conf import controllers as controllers_mod
project = "console_conf"
controllers = [

View File

@ -1,5 +1,5 @@
from subiquitycore.models.network import NetworkModel
from .identity import IdentityModel

View File

@ -17,8 +17,7 @@ import logging
import attr
log = logging.getLogger('console_conf.models.identity')
log = logging.getLogger("console_conf.models.identity")
@attr.s
@ -30,8 +29,7 @@ class User(object):
class IdentityModel(object):
""" Model representing user identity
"""
"""Model representing user identity"""
def __init__(self):
self._user = None

View File

@ -13,8 +13,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
import json
import logging
import attr
import jsonschema
@ -113,9 +113,7 @@ class RecoverySystemsModel:
m = syst["model"]
b = syst["brand"]
model = SystemModel(
model=m["model"],
brand_id=m["brand-id"],
display_name=m["display-name"]
model=m["model"], brand_id=m["brand-id"], display_name=m["display-name"]
)
brand = Brand(
ID=b["id"],
@ -180,11 +178,13 @@ class RecoverySystem:
actions = attr.ib()
def __eq__(self, other):
return self.current == other.current and \
self.label == other.label and \
self.model == other.model and \
self.brand == other.brand and \
self.actions == other.actions
return (
self.current == other.current
and self.label == other.label
and self.model == other.model
and self.brand == other.brand
and self.actions == other.actions
)
@attr.s
@ -195,10 +195,12 @@ class Brand:
validation = attr.ib()
def __eq__(self, other):
return self.ID == other.ID and \
self.username == other.username and \
self.display_name == other.display_name and \
self.validation == other.validation
return (
self.ID == other.ID
and self.username == other.username
and self.display_name == other.display_name
and self.validation == other.validation
)
@attr.s
@ -208,9 +210,11 @@ class SystemModel:
display_name = attr.ib()
def __eq__(self, other):
return self.model == other.model and \
self.brand_id == other.brand_id and \
self.display_name == other.display_name
return (
self.model == other.model
and self.brand_id == other.brand_id
and self.display_name == other.display_name
)
@attr.s
@ -219,8 +223,7 @@ class SystemAction:
mode = attr.ib()
def __eq__(self, other):
return self.title == other.title and \
self.mode == other.mode
return self.title == other.title and self.mode == other.mode
@attr.s
@ -229,5 +232,4 @@ class SelectedSystemAction:
action = attr.ib()
def __eq__(self, other):
return self.system == other.system and \
self.action == other.action
return self.system == other.system and self.action == other.action

View File

@ -13,23 +13,23 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
import json
import jsonschema
import unittest
from io import StringIO
import jsonschema
from console_conf.models.systems import (
RecoverySystemsModel,
RecoverySystem,
Brand,
SystemModel,
SystemAction,
RecoverySystem,
RecoverySystemsModel,
SelectedSystemAction,
SystemAction,
SystemModel,
)
class RecoverySystemsModelTests(unittest.TestCase):
reference = {
"systems": [
{
@ -49,7 +49,7 @@ class RecoverySystemsModelTests(unittest.TestCase):
"actions": [
{"title": "reinstall", "mode": "install"},
{"title": "recover", "mode": "recover"},
]
],
},
{
"label": "other",
@ -66,30 +66,34 @@ class RecoverySystemsModelTests(unittest.TestCase):
},
"actions": [
{"title": "reinstall", "mode": "install"},
]
}
],
},
]
}
def test_from_systems_stream_happy(self):
raw = json.dumps(self.reference)
systems = RecoverySystemsModel.from_systems_stream(StringIO(raw))
exp = RecoverySystemsModel([
exp = RecoverySystemsModel(
[
RecoverySystem(
current=True,
label="1234",
model=SystemModel(
model="core20-amd64",
brand_id="brand-id",
display_name="Core 20 AMD64 system"),
display_name="Core 20 AMD64 system",
),
brand=Brand(
ID="brand-id",
username="brand-username",
display_name="this is my brand",
validation="verified"),
actions=[SystemAction(title="reinstall", mode="install"),
SystemAction(title="recover", mode="recover")]
validation="verified",
),
actions=[
SystemAction(title="reinstall", mode="install"),
SystemAction(title="recover", mode="recover"),
],
),
RecoverySystem(
current=False,
@ -97,15 +101,18 @@ class RecoverySystemsModelTests(unittest.TestCase):
model=SystemModel(
model="my-brand-box",
brand_id="other-brand-id",
display_name="Funky box"),
display_name="Funky box",
),
brand=Brand(
ID="other-brand-id",
username="other-brand",
display_name="my brand",
validation="unproven"),
actions=[SystemAction(title="reinstall", mode="install")]
validation="unproven",
),
])
actions=[SystemAction(title="reinstall", mode="install")],
),
]
)
self.assertEqual(systems.systems, exp.systems)
self.assertEqual(systems.current, exp.systems[0])
@ -114,7 +121,8 @@ class RecoverySystemsModelTests(unittest.TestCase):
RecoverySystemsModel.from_systems_stream(StringIO("{}"))
def test_from_systems_stream_invalid_missing_system_label(self):
raw = json.dumps({
raw = json.dumps(
{
"systems": [
{
"brand": {
@ -131,15 +139,17 @@ class RecoverySystemsModelTests(unittest.TestCase):
"actions": [
{"title": "reinstall", "mode": "install"},
{"title": "recover", "mode": "recover"},
]
],
},
]
})
}
)
with self.assertRaises(jsonschema.ValidationError):
RecoverySystemsModel.from_systems_stream(StringIO(raw))
def test_from_systems_stream_invalid_missing_brand(self):
raw = json.dumps({
raw = json.dumps(
{
"systems": [
{
"label": "1234",
@ -151,15 +161,17 @@ class RecoverySystemsModelTests(unittest.TestCase):
"actions": [
{"title": "reinstall", "mode": "install"},
{"title": "recover", "mode": "recover"},
]
],
},
]
})
}
)
with self.assertRaises(jsonschema.ValidationError):
RecoverySystemsModel.from_systems_stream(StringIO(raw))
def test_from_systems_stream_invalid_missing_model(self):
raw = json.dumps({
raw = json.dumps(
{
"systems": [
{
"label": "1234",
@ -172,15 +184,17 @@ class RecoverySystemsModelTests(unittest.TestCase):
"actions": [
{"title": "reinstall", "mode": "install"},
{"title": "recover", "mode": "recover"},
]
],
},
]
})
}
)
with self.assertRaises(jsonschema.ValidationError):
RecoverySystemsModel.from_systems_stream(StringIO(raw))
def test_from_systems_stream_valid_no_actions(self):
raw = json.dumps({
raw = json.dumps(
{
"systems": [
{
"label": "1234",
@ -198,17 +212,20 @@ class RecoverySystemsModelTests(unittest.TestCase):
"actions": [],
},
]
})
}
)
RecoverySystemsModel.from_systems_stream(StringIO(raw))
def test_selection(self):
raw = json.dumps(self.reference)
model = RecoverySystemsModel.from_systems_stream(StringIO(raw))
model.select(model.systems[1], model.systems[1].actions[0])
self.assertEqual(model.selection,
self.assertEqual(
model.selection,
SelectedSystemAction(
system=model.systems[1],
action=model.systems[1].actions[0]))
system=model.systems[1], action=model.systems[1].actions[0]
),
)
model.unselect()
self.assertIsNone(model.selection)
@ -221,13 +238,16 @@ class RecoverySystemsModelTests(unittest.TestCase):
stream = StringIO()
RecoverySystemsModel.to_response_stream(model.selection, stream)
fromjson = json.loads(stream.getvalue())
self.assertEqual(fromjson, {
self.assertEqual(
fromjson,
{
"label": "other",
"action": {
"mode": "install",
"title": "reinstall",
},
})
},
)
def test_no_current(self):
reference = {
@ -249,7 +269,7 @@ class RecoverySystemsModelTests(unittest.TestCase):
"actions": [
{"title": "reinstall", "mode": "install"},
{"title": "recover", "mode": "recover"},
]
],
},
],
}

View File

@ -15,14 +15,10 @@
""" ConsoleConf UI Views """
from .chooser import ChooserConfirmView, ChooserCurrentSystemView, ChooserView
from .identity import IdentityView
from .login import LoginView
from .welcome import WelcomeView, ChooserWelcomeView
from .chooser import (
ChooserView,
ChooserCurrentSystemView,
ChooserConfirmView
)
from .welcome import ChooserWelcomeView, WelcomeView
__all__ = [
"IdentityView",

View File

@ -20,29 +20,14 @@ Chooser provides a view with recovery chooser actions.
"""
import logging
from urwid import (
connect_signal,
Text,
)
from subiquitycore.ui.buttons import (
danger_btn,
forward_btn,
back_btn,
)
from subiquitycore.ui.actionmenu import (
Action,
ActionMenu,
)
from subiquitycore.ui.container import Pile, ListBox
from subiquitycore.ui.utils import (
button_pile,
screen,
make_action_menu_row,
Color,
)
from subiquitycore.ui.table import TableRow, TablePile
from subiquitycore.view import BaseView
from urwid import Text, connect_signal
from subiquitycore.ui.actionmenu import Action, ActionMenu
from subiquitycore.ui.buttons import back_btn, danger_btn, forward_btn
from subiquitycore.ui.container import ListBox, Pile
from subiquitycore.ui.table import TablePile, TableRow
from subiquitycore.ui.utils import Color, button_pile, make_action_menu_row, screen
from subiquitycore.view import BaseView
log = logging.getLogger("console_conf.ui.views.chooser")
@ -69,28 +54,31 @@ class ChooserCurrentSystemView(ChooserBaseView):
def __init__(self, controller, current, has_more=False):
self.controller = controller
log.debug('more systems available: %s', has_more)
log.debug('current system: %s', current)
log.debug("more systems available: %s", has_more)
log.debug("current system: %s", current)
actions = []
for action in sorted(current.actions, key=by_preferred_action_type):
actions.append(forward_btn(label=action.title.capitalize(),
actions.append(
forward_btn(
label=action.title.capitalize(),
on_press=self._current_system_action,
user_arg=(current, action)))
user_arg=(current, action),
)
)
if has_more:
# add a button to show the other systems
actions.append(Text(""))
actions.append(forward_btn(label="Show all available systems",
on_press=self._more_options))
actions.append(
forward_btn(
label="Show all available systems", on_press=self._more_options
)
)
lb = ListBox(actions)
super().__init__(current,
screen(
lb,
narrow_rows=True,
excerpt=self.excerpt))
super().__init__(current, screen(lb, narrow_rows=True, excerpt=self.excerpt))
def _current_system_action(self, sender, arg):
current, action = arg
@ -104,43 +92,54 @@ class ChooserCurrentSystemView(ChooserBaseView):
class ChooserView(ChooserBaseView):
excerpt = ("Select one of available recovery systems and a desired "
"action to execute.")
excerpt = (
"Select one of available recovery systems and a desired " "action to execute."
)
def __init__(self, controller, systems):
self.controller = controller
heading_table = TablePile([
TableRow([
Color.info_minor(Text(header)) for header in [
"LABEL", "MODEL", "PUBLISHER", ""
heading_table = TablePile(
[
TableRow(
[
Color.info_minor(Text(header))
for header in ["LABEL", "MODEL", "PUBLISHER", ""]
]
])
)
],
spacing=2)
spacing=2,
)
trows = []
systems = sorted(systems,
key=lambda s: (s.brand.display_name,
systems = sorted(
systems,
key=lambda s: (
s.brand.display_name,
s.model.display_name,
s.current,
s.label))
s.label,
),
)
for s in systems:
actions = []
log.debug('actions: %s', s.actions)
log.debug("actions: %s", s.actions)
for act in sorted(s.actions, key=by_preferred_action_type):
actions.append(Action(label=act.title.capitalize(),
value=act,
enabled=True))
actions.append(
Action(label=act.title.capitalize(), value=act, enabled=True)
)
menu = ActionMenu(actions)
connect_signal(menu, 'action', self._system_action, s)
srow = make_action_menu_row([
connect_signal(menu, "action", self._system_action, s)
srow = make_action_menu_row(
[
Text(s.label),
Text(s.model.display_name),
Text(s.brand.display_name),
Text("(installed)" if s.current else ""),
menu,
], menu)
],
menu,
)
trows.append(srow)
systems_table = TablePile(trows, spacing=2)
@ -154,12 +153,15 @@ class ChooserView(ChooserBaseView):
# back to options of current system
buttons.append(back_btn("BACK", on_press=self.back))
super().__init__(controller.model.current,
super().__init__(
controller.model.current,
screen(
rows=rows,
buttons=button_pile(buttons),
focus_buttons=False,
excerpt=self.excerpt))
excerpt=self.excerpt,
),
)
def _system_action(self, sender, action, system):
self.controller.select(system, action)
@ -169,22 +171,24 @@ class ChooserView(ChooserBaseView):
class ChooserConfirmView(ChooserBaseView):
canned_summary = {
"run": "Continue running the system without any changes.",
"recover": ("You have requested to reboot the system into recovery "
"mode."),
"install": ("You are about to {action_lower} the system version "
"recover": ("You have requested to reboot the system into recovery " "mode."),
"install": (
"You are about to {action_lower} the system version "
"{version} for {model} from {publisher}.\n\n"
"This will remove all existing user data on the "
"device.\n\n"
"The system will reboot in the process."),
"The system will reboot in the process."
),
}
default_summary = ("You are about to execute action \"{action}\" using "
default_summary = (
'You are about to execute action "{action}" using '
"system version {version} for device {model} from "
"{publisher}.\n\n"
"Make sure you understand the consequences of "
"performing this action.")
"performing this action."
)
def __init__(self, controller, selection):
self.controller = controller
@ -193,21 +197,21 @@ class ChooserConfirmView(ChooserBaseView):
danger_btn("CONFIRM", on_press=self.confirm),
back_btn("BACK", on_press=self.back),
]
fmt = self.canned_summary.get(selection.action.mode,
self.default_summary)
summary = fmt.format(action=selection.action.title,
fmt = self.canned_summary.get(selection.action.mode, self.default_summary)
summary = fmt.format(
action=selection.action.title,
action_lower=selection.action.title.lower(),
model=selection.system.model.display_name,
publisher=selection.system.brand.display_name,
version=selection.system.label)
version=selection.system.label,
)
rows = [
Text(summary),
]
super().__init__(controller.model.current,
screen(
rows=rows,
buttons=button_pile(buttons),
focus_buttons=False))
super().__init__(
controller.model.current,
screen(rows=rows, buttons=button_pile(buttons), focus_buttons=False),
)
def confirm(self, result):
self.controller.confirm()

View File

@ -14,20 +14,18 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
from urwid import connect_signal
from subiquitycore.view import BaseView
from subiquitycore.ui.form import EmailField, Form
from subiquitycore.ui.utils import SomethingFailed
from subiquitycore.ui.form import (
Form,
EmailField,
)
from subiquitycore.view import BaseView
log = logging.getLogger("console_conf.ui.views.identity")
sso_help = ("If you do not have an account, visit "
"https://login.ubuntu.com to create one.")
sso_help = (
"If you do not have an account, visit " "https://login.ubuntu.com to create one."
)
class IdentityForm(Form):
@ -45,11 +43,12 @@ class IdentityView(BaseView):
self.controller = controller
self.form = IdentityForm()
connect_signal(self.form, 'submit', self.done)
connect_signal(self.form, 'cancel', self.cancel)
connect_signal(self.form, "submit", self.done)
connect_signal(self.form, "cancel", self.cancel)
super().__init__(self.form.as_screen(focus_buttons=False,
excerpt=_(self.excerpt)))
super().__init__(
self.form.as_screen(focus_buttons=False, excerpt=_(self.excerpt))
)
def cancel(self, button=None):
self.controller.cancel()

View File

@ -24,7 +24,7 @@ from urwid import Text
from subiquitycore.ui.buttons import done_btn
from subiquitycore.ui.container import ListBox, Pile
from subiquitycore.ui.utils import button_pile, Padding
from subiquitycore.ui.utils import Padding, button_pile
from subiquitycore.view import BaseView
log = logging.getLogger("console_conf.ui.views.login")
@ -41,15 +41,23 @@ class LoginView(BaseView):
self.items = []
super().__init__(
Pile([
('pack', Text("")),
Pile(
[
("pack", Text("")),
Padding.center_79(ListBox(self._build_model_inputs())),
('pack', Pile([
('pack', Text("")),
(
"pack",
Pile(
[
("pack", Text("")),
button_pile(self._build_buttons()),
('pack', Text("")),
])),
]))
("pack", Text("")),
]
),
),
]
)
)
def _build_buttons(self):
return [
@ -68,19 +76,19 @@ class LoginView(BaseView):
sl.append(Text("no owner"))
return sl
local_tpl = (
"This device is registered to {realname}.")
local_tpl = "This device is registered to {realname}."
remote_tpl = (
"\n\nRemote access was enabled via authentication with SSO user"
" <{username}>.\nPublic SSH keys were added to the device "
"for remote access.\n\n{realname} can connect remotely to this "
"device via SSH:")
"device via SSH:"
)
sl = []
login_info = {
'realname': user.realname,
'username': user.username,
"realname": user.realname,
"username": user.username,
}
login_text = local_tpl.format(**login_info)
login_text += remote_tpl.format(**login_info)

View File

@ -14,16 +14,11 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from console_conf.models.systems import (
SystemAction,
)
from console_conf.ui.views.chooser import (
by_preferred_action_type,
)
from console_conf.models.systems import SystemAction
from console_conf.ui.views.chooser import by_preferred_action_type
class TestSorting(unittest.TestCase):
def test_simple(self):
data = [
SystemAction(title="b-other", mode="unknown"),
@ -39,5 +34,4 @@ class TestSorting(unittest.TestCase):
SystemAction(title="a-other", mode="some-other"),
SystemAction(title="b-other", mode="unknown"),
]
self.assertSequenceEqual(sorted(data, key=by_preferred_action_type),
exp)
self.assertSequenceEqual(sorted(data, key=by_preferred_action_type), exp)

View File

@ -16,8 +16,9 @@
""" Subiquity """
from subiquitycore import i18n
__all__ = [
'i18n',
"i18n",
]
__version__ = "0.0.5"

View File

@ -1,5 +1,6 @@
import sys
if __name__ == '__main__':
if __name__ == "__main__":
from subiquity.cmd.tui import main
sys.exit(main())

View File

@ -25,46 +25,28 @@ from typing import Callable, Dict, List, Optional, Union
import aiohttp
from subiquitycore.async_helpers import (
run_bg_task,
run_in_thread,
)
from subiquitycore.screen import is_linux_tty
from subiquitycore.tuicontroller import Skip
from subiquitycore.tui import TuiApplication
from subiquitycore.utils import orig_environ
from subiquitycore.view import BaseView
from subiquity.client.controller import Confirm
from subiquity.client.keycodes import (
NoOpKeycodesFilter,
KeyCodesFilter,
)
from subiquity.client.keycodes import KeyCodesFilter, NoOpKeycodesFilter
from subiquity.common.api.client import make_client_for_conn
from subiquity.common.apidef import API
from subiquity.common.errorreport import (
ErrorReporter,
)
from subiquity.common.errorreport import ErrorReporter
from subiquity.common.serialize import from_json
from subiquity.common.types import (
ApplicationState,
ErrorReportKind,
ErrorReportRef,
)
from subiquity.common.types import ApplicationState, ErrorReportKind, ErrorReportRef
from subiquity.journald import journald_listen
from subiquity.server.server import POSTINSTALL_MODEL_NAMES
from subiquity.ui.frame import SubiquityUI
from subiquity.ui.views.error import ErrorReportStretchy
from subiquity.ui.views.help import HelpMenu, ssh_help_texts
from subiquity.ui.views.installprogress import (
InstallConfirmation,
)
from subiquity.ui.views.welcome import (
CloudInitFail,
)
from subiquity.ui.views.installprogress import InstallConfirmation
from subiquity.ui.views.welcome import CloudInitFail
from subiquitycore.async_helpers import run_bg_task, run_in_thread
from subiquitycore.screen import is_linux_tty
from subiquitycore.tui import TuiApplication
from subiquitycore.tuicontroller import Skip
from subiquitycore.utils import orig_environ
from subiquitycore.view import BaseView
log = logging.getLogger('subiquity.client.client')
log = logging.getLogger("subiquity.client.client")
class Abort(Exception):
@ -72,7 +54,8 @@ class Abort(Exception):
self.error_report_ref = error_report_ref
DEBUG_SHELL_INTRO = _("""\
DEBUG_SHELL_INTRO = _(
"""\
Installer shell session activated.
This shell session is running inside the installer environment. You
@ -81,18 +64,19 @@ example by typing Control-D or 'exit'.
Be aware that this is an ephemeral environment. Changes to this
environment will not survive a reboot. If the install has started, the
installed system will be mounted at /target.""")
installed system will be mounted at /target."""
)
class SubiquityClient(TuiApplication):
snapd_socket_path: Optional[str] = '/run/snapd.socket'
snapd_socket_path: Optional[str] = "/run/snapd.socket"
variant: Optional[str] = None
cmdline = ['snap', 'run', 'subiquity']
dryrun_cmdline_module = 'subiquity.cmd.tui'
cmdline = ["snap", "run", "subiquity"]
dryrun_cmdline_module = "subiquity.cmd.tui"
from subiquity.client import controllers as controllers_mod
project = "subiquity"
def make_model(self):
@ -142,11 +126,11 @@ class SubiquityClient(TuiApplication):
except OSError:
self.our_tty = "not a tty"
self.in_make_view_cvar = contextvars.ContextVar(
'in_make_view', default=False)
self.in_make_view_cvar = contextvars.ContextVar("in_make_view", default=False)
self.error_reporter = ErrorReporter(
self.context.child("ErrorReporter"), self.opts.dry_run, self.root)
self.context.child("ErrorReporter"), self.opts.dry_run, self.root
)
self.note_data_for_apport("SnapUpdated", str(self.updated))
self.note_data_for_apport("UsingAnswers", str(bool(self.answers)))
@ -169,12 +153,9 @@ class SubiquityClient(TuiApplication):
def restart(self, remove_last_screen=True, restart_server=False):
log.debug(f"restart {remove_last_screen} {restart_server}")
if self.fg_proc is not None:
log.debug(
"killing foreground process %s before restarting",
self.fg_proc)
log.debug("killing foreground process %s before restarting", self.fg_proc)
self.restarting = True
run_bg_task(
self._kill_fg_proc(remove_last_screen, restart_server))
run_bg_task(self._kill_fg_proc(remove_last_screen, restart_server))
return
if remove_last_screen:
self._remove_last_screen()
@ -187,50 +168,56 @@ class SubiquityClient(TuiApplication):
self.urwid_loop.stop()
cmdline = self.cmdline
if self.opts.dry_run:
cmdline = [
sys.executable, '-m', self.dryrun_cmdline_module,
] + sys.argv[1:] + ['--socket', self.opts.socket]
cmdline = (
[
sys.executable,
"-m",
self.dryrun_cmdline_module,
]
+ sys.argv[1:]
+ ["--socket", self.opts.socket]
)
if self.opts.server_pid is not None:
cmdline.extend(['--server-pid', self.opts.server_pid])
cmdline.extend(["--server-pid", self.opts.server_pid])
log.debug("restarting %r", cmdline)
os.execvpe(cmdline[0], cmdline, orig_environ(os.environ))
def resp_hook(self, response):
headers = response.headers
if 'x-updated' in headers:
if "x-updated" in headers:
if self.server_updated is None:
self.server_updated = headers['x-updated']
elif self.server_updated != headers['x-updated']:
self.server_updated = headers["x-updated"]
elif self.server_updated != headers["x-updated"]:
self.restart(remove_last_screen=False)
raise Abort
status = headers.get('x-status')
if status == 'skip':
status = headers.get("x-status")
if status == "skip":
raise Skip
elif status == 'confirm':
elif status == "confirm":
raise Confirm
if headers.get('x-error-report') is not None:
ref = from_json(ErrorReportRef, headers['x-error-report'])
if headers.get("x-error-report") is not None:
ref = from_json(ErrorReportRef, headers["x-error-report"])
raise Abort(ref)
try:
response.raise_for_status()
except aiohttp.ClientError:
report = self.error_reporter.make_apport_report(
ErrorReportKind.SERVER_REQUEST_FAIL,
"request to {}".format(response.url.path))
"request to {}".format(response.url.path),
)
raise Abort(report.ref())
return response
async def noninteractive_confirmation(self):
await asyncio.sleep(1)
yes = _('yes')
no = _('no')
yes = _("yes")
no = _("no")
answer = no
print(_("Confirmation is required to continue."))
print(_("Add 'autoinstall' to your kernel command line to avoid this"))
print()
prompt = "\n\n{} ({}|{})".format(
_("Continue with autoinstall?"), yes, no)
prompt = "\n\n{} ({}|{})".format(_("Continue with autoinstall?"), yes, no)
while answer != yes:
print(prompt)
answer = await run_in_thread(input)
@ -260,7 +247,8 @@ class SubiquityClient(TuiApplication):
if app_state == ApplicationState.NEEDS_CONFIRMATION:
if confirm_task is None:
confirm_task = asyncio.create_task(
self.noninteractive_confirmation())
self.noninteractive_confirmation()
)
elif confirm_task is not None:
confirm_task.cancel()
confirm_task = None
@ -272,21 +260,20 @@ class SubiquityClient(TuiApplication):
app_status = await self._status_get(app_state)
def subiquity_event_noninteractive(self, event):
if event['SUBIQUITY_EVENT_TYPE'] == 'start':
print('start: ' + event["MESSAGE"])
elif event['SUBIQUITY_EVENT_TYPE'] == 'finish':
print('finish: ' + event["MESSAGE"])
if event["SUBIQUITY_EVENT_TYPE"] == "start":
print("start: " + event["MESSAGE"])
elif event["SUBIQUITY_EVENT_TYPE"] == "finish":
print("finish: " + event["MESSAGE"])
async def connect(self):
def p(s):
print(s, end='', flush=True)
print(s, end="", flush=True)
async def spin(message):
p(message + '... ')
p(message + "... ")
while True:
for t in ['-', '\\', '|', '/']:
p('\x08' + t)
for t in ["-", "\\", "|", "/"]:
p("\x08" + t)
await asyncio.sleep(0.5)
async def spinning_wait(message, task):
@ -295,16 +282,18 @@ class SubiquityClient(TuiApplication):
return await task
finally:
spinner.cancel()
p('\x08 \n')
p("\x08 \n")
status = await spinning_wait("connecting", self._status_get())
journald_listen([status.echo_syslog_id], lambda e: print(e['MESSAGE']))
journald_listen([status.echo_syslog_id], lambda e: print(e["MESSAGE"]))
if status.state == ApplicationState.STARTING_UP:
status = await spinning_wait(
"starting up", self._status_get(cur=status.state))
"starting up", self._status_get(cur=status.state)
)
if status.state == ApplicationState.CLOUD_INIT_WAIT:
status = await spinning_wait(
"waiting for cloud-init", self._status_get(cur=status.state))
"waiting for cloud-init", self._status_get(cur=status.state)
)
if status.state == ApplicationState.EARLY_COMMANDS:
print("running early commands")
status = await self._status_get(cur=status.state)
@ -316,12 +305,13 @@ class SubiquityClient(TuiApplication):
def header_func():
if self.in_make_view_cvar.get():
return {'x-make-view-request': 'yes'}
return {"x-make-view-request": "yes"}
else:
return None
self.client = make_client_for_conn(API, conn, self.resp_hook,
header_func=header_func)
self.client = make_client_for_conn(
API, conn, self.resp_hook, header_func=header_func
)
self.error_reporter.client = self.client
status = await self.connect()
@ -332,11 +322,14 @@ class SubiquityClient(TuiApplication):
texts = ssh_help_texts(ssh_info)
for line in texts:
import urwid
if isinstance(line, urwid.Widget):
line = '\n'.join([
line.decode('utf-8').rstrip()
line = "\n".join(
[
line.decode("utf-8").rstrip()
for line in line.render((1000,)).text
])
]
)
print(line)
return
@ -355,11 +348,11 @@ class SubiquityClient(TuiApplication):
# the progress page
if hasattr(self.controllers, "Progress"):
journald_listen(
[status.event_syslog_id],
self.controllers.Progress.event)
[status.event_syslog_id], self.controllers.Progress.event
)
journald_listen(
[status.log_syslog_id],
self.controllers.Progress.log_line)
[status.log_syslog_id], self.controllers.Progress.log_line
)
if not status.cloud_init_ok:
self.add_global_overlay(CloudInitFail(self))
self.error_reporter.load_reports()
@ -376,18 +369,16 @@ class SubiquityClient(TuiApplication):
# does not matter as the settings will soon be clobbered but
# for a non-interactive one we need to clear things up or the
# prompting for confirmation will be confusing.
os.system('stty sane')
os.system("stty sane")
journald_listen(
[status.event_syslog_id],
self.subiquity_event_noninteractive,
seek=True)
run_bg_task(
self.noninteractive_watch_app_state(status))
[status.event_syslog_id], self.subiquity_event_noninteractive, seek=True
)
run_bg_task(self.noninteractive_watch_app_state(status))
def _exception_handler(self, loop, context):
exc = context.get('exception')
exc = context.get("exception")
if self.restarting:
log.debug('ignoring %s %s during restart', exc, type(exc))
log.debug("ignoring %s %s during restart", exc, type(exc))
return
if isinstance(exc, Abort):
if self.interactive:
@ -405,8 +396,8 @@ class SubiquityClient(TuiApplication):
print("generating crash report")
try:
report = self.make_apport_report(
ErrorReportKind.UI, "Installer UI", interrupt=False,
wait=True)
ErrorReportKind.UI, "Installer UI", interrupt=False, wait=True
)
if report is not None:
print("report saved to {path}".format(path=report.path))
except Exception:
@ -426,7 +417,7 @@ class SubiquityClient(TuiApplication):
# the server up to a second to exit, and then we signal it.
pid = int(self.opts.server_pid)
print(f'giving the server [{pid}] up to a second to exit')
print(f"giving the server [{pid}] up to a second to exit")
for unused in range(10):
try:
if os.waitpid(pid, os.WNOHANG) != (0, 0):
@ -435,9 +426,9 @@ class SubiquityClient(TuiApplication):
# If we attached to an existing server process,
# waitpid will fail.
pass
await asyncio.sleep(.1)
await asyncio.sleep(0.1)
else:
print('killing server {}'.format(pid))
print("killing server {}".format(pid))
os.kill(pid, 2)
os.waitpid(pid, 0)
@ -448,21 +439,19 @@ class SubiquityClient(TuiApplication):
if source.id == source_selection.current_id:
current = source
break
if current is not None and current.variant != 'server':
if current is not None and current.variant != "server":
# If using server to install desktop, mark the controllers
# the TUI client does not currently have interfaces for as
# configured.
needed = POSTINSTALL_MODEL_NAMES.for_variant(current.variant)
for c in self.controllers.instances:
if getattr(c, 'endpoint_name', None) is not None:
if getattr(c, "endpoint_name", None) is not None:
needed.discard(c.endpoint_name)
if needed:
log.info(
"marking additional endpoints as configured: %s",
needed)
log.info("marking additional endpoints as configured: %s", needed)
await self.client.meta.mark_configured.POST(list(needed))
# TODO: remove this when TUI gets an Active Directory screen:
await self.client.meta.mark_configured.POST(['active_directory'])
await self.client.meta.mark_configured.POST(["active_directory"])
await self.client.meta.confirm.POST(self.our_tty)
def add_global_overlay(self, overlay):
@ -477,7 +466,7 @@ class SubiquityClient(TuiApplication):
self.ui.body.remove_overlay(overlay)
def _remove_last_screen(self):
last_screen = self.state_path('last-screen')
last_screen = self.state_path("last-screen")
if os.path.exists(last_screen):
os.unlink(last_screen)
@ -488,7 +477,7 @@ class SubiquityClient(TuiApplication):
def select_initial_screen(self):
last_screen = None
if self.updated:
state_path = self.state_path('last-screen')
state_path = self.state_path("last-screen")
if os.path.exists(state_path):
with open(state_path) as fp:
last_screen = fp.read().strip()
@ -502,7 +491,7 @@ class SubiquityClient(TuiApplication):
async def _select_initial_screen(self, index):
endpoint_names = []
for c in self.controllers.instances[:index]:
if getattr(c, 'endpoint_name', None) is not None:
if getattr(c, "endpoint_name", None) is not None:
endpoint_names.append(c.endpoint_name)
if endpoint_names:
await self.client.meta.mark_configured.POST(endpoint_names)
@ -521,11 +510,12 @@ class SubiquityClient(TuiApplication):
log.debug("showing InstallConfirmation over %s", self.ui.body)
overlay = InstallConfirmation(self)
self.add_global_overlay(overlay)
if self.answers.get('filesystem-confirmed', False):
if self.answers.get("filesystem-confirmed", False):
overlay.ok(None)
async def _start_answers_for_view(
self, controller, view: Union[BaseView, Callable[[], BaseView]]):
self, controller, view: Union[BaseView, Callable[[], BaseView]]
):
def noop():
return view
@ -551,19 +541,19 @@ class SubiquityClient(TuiApplication):
self.in_make_view_cvar.reset(tok)
if new.answers:
run_bg_task(self._start_answers_for_view(new, view))
with open(self.state_path('last-screen'), 'w') as fp:
with open(self.state_path("last-screen"), "w") as fp:
fp.write(new.name)
return view
def show_progress(self):
if hasattr(self.controllers, 'Progress'):
if hasattr(self.controllers, "Progress"):
self.ui.set_body(self.controllers.Progress.progress_view)
def unhandled_input(self, key):
if key == 'f1':
if key == "f1":
if not self.ui.right_icon.current_help:
self.ui.right_icon.open_pop_up()
elif key in ['ctrl z', 'f2']:
elif key in ["ctrl z", "f2"]:
self.debug_shell()
elif self.opts.dry_run:
self.unhandled_input_dry_run(key)
@ -571,22 +561,22 @@ class SubiquityClient(TuiApplication):
super().unhandled_input(key)
def unhandled_input_dry_run(self, key):
if key in ['ctrl e', 'ctrl r']:
interrupt = key == 'ctrl e'
if key in ["ctrl e", "ctrl r"]:
interrupt = key == "ctrl e"
try:
1 / 0
except ZeroDivisionError:
self.make_apport_report(
ErrorReportKind.UNKNOWN, "example", interrupt=interrupt)
elif key == 'ctrl u':
ErrorReportKind.UNKNOWN, "example", interrupt=interrupt
)
elif key == "ctrl u":
1 / 0
elif key == 'ctrl b':
elif key == "ctrl b":
run_bg_task(self.client.dry_run.crash.GET())
else:
super().unhandled_input(key)
def debug_shell(self, after_hook=None):
def _before():
os.system("clear")
print(DEBUG_SHELL_INTRO)
@ -594,7 +584,8 @@ class SubiquityClient(TuiApplication):
env = orig_environ(os.environ)
cmd = ["bash"]
self.run_command_in_foreground(
cmd, env=env, before_hook=_before, after_hook=after_hook, cwd='/')
cmd, env=env, before_hook=_before, after_hook=after_hook, cwd="/"
)
def note_file_for_apport(self, key, path):
self.error_reporter.note_file_for_apport(key, path)
@ -603,8 +594,7 @@ class SubiquityClient(TuiApplication):
self.error_reporter.note_data_for_apport(key, value)
def make_apport_report(self, kind, thing, *, interrupt, wait=False, **kw):
report = self.error_reporter.make_apport_report(
kind, thing, wait=wait, **kw)
report = self.error_reporter.make_apport_report(kind, thing, wait=wait, **kw)
if report is not None and interrupt:
self.show_error_report(report.ref())
@ -614,7 +604,7 @@ class SubiquityClient(TuiApplication):
def show_error_report(self, error_ref):
log.debug("show_error_report %r", error_ref.base)
if isinstance(self.ui.body, BaseView):
w = getattr(self.ui.body._w, 'stretchy', None)
w = getattr(self.ui.body._w, "stretchy", None)
if isinstance(w, ErrorReportStretchy):
# Don't show an error if already looking at one.
return

View File

@ -16,9 +16,7 @@
import logging
from typing import Optional
from subiquitycore.tuicontroller import (
TuiController,
)
from subiquitycore.tuicontroller import TuiController
log = logging.getLogger("subiquity.client.controller")
@ -28,7 +26,6 @@ class Confirm(Exception):
class SubiquityTuiController(TuiController):
endpoint_name: Optional[str] = None
def __init__(self, app):

View File

@ -14,6 +14,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from subiquitycore.tuicontroller import RepeatedController
from .drivers import DriversController
from .filesystem import FilesystemController
from .identity import IdentityController
@ -33,21 +34,21 @@ from .zdev import ZdevController
# see SubiquityClient.controllers for another list
__all__ = [
'DriversController',
'FilesystemController',
'IdentityController',
'KeyboardController',
'MirrorController',
'NetworkController',
'ProgressController',
'ProxyController',
'RefreshController',
'RepeatedController',
'SerialController',
'SnapListController',
'SourceController',
'SSHController',
'UbuntuProController',
'WelcomeController',
'ZdevController',
"DriversController",
"FilesystemController",
"IdentityController",
"KeyboardController",
"MirrorController",
"NetworkController",
"ProgressController",
"ProxyController",
"RefreshController",
"RepeatedController",
"SerialController",
"SnapListController",
"SourceController",
"SSHController",
"UbuntuProController",
"WelcomeController",
"ZdevController",
]

View File

@ -17,18 +17,16 @@ import asyncio
import logging
from typing import List
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import DriversPayload, DriversResponse
from subiquity.ui.views.drivers import DriversView, DriversViewStatus
from subiquitycore.tuicontroller import Skip
from subiquity.common.types import DriversPayload, DriversResponse
from subiquity.client.controller import SubiquityTuiController
from subiquity.ui.views.drivers import DriversView, DriversViewStatus
log = logging.getLogger('subiquity.client.controllers.drivers')
log = logging.getLogger("subiquity.client.controllers.drivers")
class DriversController(SubiquityTuiController):
endpoint_name = 'drivers'
endpoint_name = "drivers"
async def make_ui(self) -> DriversView:
response: DriversResponse = await self.endpoint.GET()
@ -37,8 +35,9 @@ class DriversController(SubiquityTuiController):
await self.endpoint.POST(DriversPayload(install=False))
raise Skip
return DriversView(self, response.drivers,
response.install, response.local_only)
return DriversView(
self, response.drivers, response.install, response.local_only
)
async def _wait_drivers(self) -> List[str]:
response: DriversResponse = await self.endpoint.GET(wait=True)
@ -46,12 +45,10 @@ class DriversController(SubiquityTuiController):
return response.drivers
async def run_answers(self):
if 'install' not in self.answers:
if "install" not in self.answers:
return
from subiquitycore.testing.view_helpers import (
click,
)
from subiquitycore.testing.view_helpers import click
view = self.app.ui.body
while view.status == DriversViewStatus.WAITING:
@ -60,7 +57,7 @@ class DriversController(SubiquityTuiController):
click(view.cont_btn.base_widget)
return
view.form.install.value = self.answers['install']
view.form.install.value = self.answers["install"]
click(view.form.done_btn.base_widget)
@ -69,5 +66,4 @@ class DriversController(SubiquityTuiController):
def done(self, install: bool) -> None:
log.debug("DriversController.done next_screen install=%s", install)
self.app.next_screen(
self.endpoint.POST(DriversPayload(install=install)))
self.app.next_screen(self.endpoint.POST(DriversPayload(install=install)))

View File

@ -17,50 +17,37 @@ import asyncio
import logging
from typing import Callable, Optional
from subiquitycore.async_helpers import run_bg_task
from subiquitycore.lsb_release import lsb_release
from subiquitycore.view import BaseView
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.filesystem import gaps
from subiquity.common.filesystem.manipulator import FilesystemManipulator
from subiquity.common.types import (
ProbeStatus,
GuidedCapability,
GuidedChoiceV2,
GuidedStorageResponseV2,
GuidedStorageTargetManual,
ProbeStatus,
StorageResponseV2,
)
from subiquity.models.filesystem import (
Bootloader,
FilesystemModel,
raidlevels_by_value,
)
from subiquity.ui.views import (
FilesystemView,
GuidedDiskSelectionView,
)
from subiquity.ui.views.filesystem.probing import (
SlowProbing,
ProbingFailed,
)
from subiquity.models.filesystem import Bootloader, FilesystemModel, raidlevels_by_value
from subiquity.ui.views import FilesystemView, GuidedDiskSelectionView
from subiquity.ui.views.filesystem.probing import ProbingFailed, SlowProbing
from subiquitycore.async_helpers import run_bg_task
from subiquitycore.lsb_release import lsb_release
from subiquitycore.view import BaseView
log = logging.getLogger("subiquity.client.controllers.filesystem")
class FilesystemController(SubiquityTuiController, FilesystemManipulator):
endpoint_name = 'storage'
endpoint_name = "storage"
def __init__(self, app):
super().__init__(app)
self.model = None
self.answers.setdefault('guided', False)
self.answers.setdefault('tpm-default', False)
self.answers.setdefault('guided-index', 0)
self.answers.setdefault('manual', [])
self.answers.setdefault("guided", False)
self.answers.setdefault("tpm-default", False)
self.answers.setdefault("guided-index", 0)
self.answers.setdefault("manual", [])
self.current_view: Optional[BaseView] = None
async def make_ui(self) -> Callable[[], BaseView]:
@ -90,8 +77,7 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
if isinstance(self.ui.body, SlowProbing):
self.ui.set_body(self.current_view)
else:
log.debug("not refreshing the display. Current display is %r",
self.ui.body)
log.debug("not refreshing the display. Current display is %r", self.ui.body)
async def make_guided_ui(
self,
@ -101,12 +87,9 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
self.app.show_error_report(status.error_report)
return ProbingFailed(self, status.error_report)
response: StorageResponseV2 = await self.endpoint.v2.GET(
include_raid=True)
response: StorageResponseV2 = await self.endpoint.v2.GET(include_raid=True)
disk_by_id = {
disk.id: disk for disk in response.disks
}
disk_by_id = {disk.id: disk for disk in response.disks}
if status.error_report:
self.app.show_error_report(status.error_report)
@ -118,48 +101,46 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
while not isinstance(self.ui.body, GuidedDiskSelectionView):
await asyncio.sleep(0.1)
if self.answers['tpm-default']:
self.app.answers['filesystem-confirmed'] = True
if self.answers["tpm-default"]:
self.app.answers["filesystem-confirmed"] = True
self.ui.body.done(self.ui.body.form)
if self.answers['guided']:
if self.answers["guided"]:
targets = self.ui.body.form.targets
if 'guided-index' in self.answers:
target = targets[self.answers['guided-index']]
elif 'guided-label' in self.answers:
label = self.answers['guided-label']
if "guided-index" in self.answers:
target = targets[self.answers["guided-index"]]
elif "guided-label" in self.answers:
label = self.answers["guided-label"]
disk_by_id = self.ui.body.form.disk_by_id
[target] = [
t
for t in targets
if disk_by_id[t.disk_id].label == label
]
method = self.answers.get('guided-method')
[target] = [t for t in targets if disk_by_id[t.disk_id].label == label]
method = self.answers.get("guided-method")
value = {
'disk': target,
'use_lvm': method == "lvm",
"disk": target,
"use_lvm": method == "lvm",
}
passphrase = self.answers.get('guided-passphrase')
passphrase = self.answers.get("guided-passphrase")
if passphrase is not None:
value['lvm_options'] = {
'encrypt': True,
'luks_options': {
'passphrase': passphrase,
'confirm_passphrase': passphrase,
}
value["lvm_options"] = {
"encrypt": True,
"luks_options": {
"passphrase": passphrase,
"confirm_passphrase": passphrase,
},
}
self.ui.body.form.guided_choice.value = value
self.ui.body.done(None)
self.app.answers['filesystem-confirmed'] = True
self.app.answers["filesystem-confirmed"] = True
while not isinstance(self.ui.body, FilesystemView):
await asyncio.sleep(0.1)
self.finish()
elif self.answers['manual']:
elif self.answers["manual"]:
await self._guided_choice(
GuidedChoiceV2(
target=GuidedStorageTargetManual(),
capability=GuidedCapability.MANUAL))
await self._run_actions(self.answers['manual'])
self.answers['manual'] = []
capability=GuidedCapability.MANUAL,
)
)
await self._run_actions(self.answers["manual"])
self.answers["manual"] = []
def _action_get(self, id):
dev_spec = id[0].split()
@ -168,7 +149,7 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
if dev_spec[1] == "index":
dev = self.model.all_disks()[int(dev_spec[2])]
elif dev_spec[1] == "serial":
dev = self.model._one(type='disk', serial=dev_spec[2])
dev = self.model._one(type="disk", serial=dev_spec[2])
elif dev_spec[0] == "raid":
if dev_spec[1] == "name":
for r in self.model.all_raids():
@ -192,16 +173,13 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
raise Exception("could not resolve {}".format(id))
def _action_clean_devices_raid(self, devices):
r = {
self._action_get(d): v
for d, v in zip(devices[::2], devices[1::2])
}
r = {self._action_get(d): v for d, v in zip(devices[::2], devices[1::2])}
for d in r:
assert d.ok_for_raid
return r
def _action_clean_devices_vg(self, devices):
r = {self._action_get(d): 'active' for d in devices}
r = {self._action_get(d): "active" for d in devices}
for d in r:
assert d.ok_for_lvm_vg
return r
@ -210,12 +188,13 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
return raidlevels_by_value[level]
async def _answers_action(self, action):
from subiquitycore.ui.stretchy import StretchyOverlay
from subiquity.ui.views.filesystem.delete import ConfirmDeleteStretchy
from subiquitycore.ui.stretchy import StretchyOverlay
log.debug("_answers_action %r", action)
if 'obj' in action:
obj = self._action_get(action['obj'])
action_name = action['action']
if "obj" in action:
obj = self._action_get(action["obj"])
action_name = action["action"]
if action_name == "MAKE_BOOT":
action_name = "TOGGLE_BOOT"
if action_name == "CREATE_LV":
@ -223,8 +202,8 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
if action_name == "PARTITION":
obj = gaps.largest_gap(obj)
meth = getattr(
self.ui.body.avail_list,
"_{}_{}".format(obj.type, action_name))
self.ui.body.avail_list, "_{}_{}".format(obj.type, action_name)
)
meth(obj)
yield
body = self.ui.body._w
@ -235,34 +214,35 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
body.stretchy.done()
else:
async for _ in self._enter_form_data(
body.stretchy.form,
action['data'],
action.get("submit", True)):
body.stretchy.form, action["data"], action.get("submit", True)
):
pass
elif action['action'] == 'create-raid':
elif action["action"] == "create-raid":
self.ui.body.create_raid()
yield
body = self.ui.body._w
async for _ in self._enter_form_data(
body.stretchy.form,
action['data'],
action["data"],
action.get("submit", True),
clean_suffix='raid'):
clean_suffix="raid",
):
pass
elif action['action'] == 'create-vg':
elif action["action"] == "create-vg":
self.ui.body.create_vg()
yield
body = self.ui.body._w
async for _ in self._enter_form_data(
body.stretchy.form,
action['data'],
action["data"],
action.get("submit", True),
clean_suffix='vg'):
clean_suffix="vg",
):
pass
elif action['action'] == 'done':
elif action["action"] == "done":
if not self.ui.body.done_btn.enabled:
raise Exception("answers did not provide complete fs config")
self.app.answers['filesystem-confirmed'] = True
self.app.answers["filesystem-confirmed"] = True
self.finish()
else:
raise Exception("could not process action {}".format(action))
@ -278,8 +258,8 @@ class FilesystemController(SubiquityTuiController, FilesystemManipulator):
if self.model.bootloader == Bootloader.PREP:
self.supports_resilient_boot = False
else:
release = lsb_release(dry_run=self.app.opts.dry_run)['release']
self.supports_resilient_boot = release >= '20.04'
release = lsb_release(dry_run=self.app.opts.dry_run)["release"]
self.supports_resilient_boot = release >= "20.04"
self.ui.set_body(FilesystemView(self.model, self))
def guided_choice(self, choice):

View File

@ -19,34 +19,34 @@ from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import IdentityData
from subiquity.ui.views import IdentityView
log = logging.getLogger('subiquity.client.controllers.identity')
log = logging.getLogger("subiquity.client.controllers.identity")
class IdentityController(SubiquityTuiController):
endpoint_name = 'identity'
endpoint_name = "identity"
async def make_ui(self):
data = await self.endpoint.GET()
return IdentityView(self, data)
def run_answers(self):
if all(elem in self.answers for elem in
['realname', 'username', 'password', 'hostname']):
if all(
elem in self.answers
for elem in ["realname", "username", "password", "hostname"]
):
identity = IdentityData(
realname=self.answers['realname'],
username=self.answers['username'],
hostname=self.answers['hostname'],
crypted_password=self.answers['password'])
realname=self.answers["realname"],
username=self.answers["username"],
hostname=self.answers["hostname"],
crypted_password=self.answers["password"],
)
self.done(identity)
def cancel(self):
self.app.prev_screen()
def done(self, identity_data):
log.debug(
"IdentityController.done next_screen user_spec=%s",
identity_data)
log.debug("IdentityController.done next_screen user_spec=%s", identity_data)
self.app.next_screen(self.endpoint.POST(identity_data))
async def validate_username(self, username):

View File

@ -20,21 +20,20 @@ from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import KeyboardSetting
from subiquity.ui.views import KeyboardView
log = logging.getLogger('subiquity.client.controllers.keyboard')
log = logging.getLogger("subiquity.client.controllers.keyboard")
class KeyboardController(SubiquityTuiController):
endpoint_name = 'keyboard'
endpoint_name = "keyboard"
async def make_ui(self):
setup = await self.endpoint.GET()
return KeyboardView(self, setup)
async def run_answers(self):
if 'layout' in self.answers:
layout = self.answers['layout']
variant = self.answers.get('variant', '')
if "layout" in self.answers:
layout = self.answers["layout"]
variant = self.answers.get("variant", "")
await self.apply(KeyboardSetting(layout=layout, variant=variant))
self.done()
@ -43,7 +42,8 @@ class KeyboardController(SubiquityTuiController):
async def needs_toggle(self, setting):
return await self.endpoint.needs_toggle.GET(
layout_code=setting.layout, variant_code=setting.variant)
layout_code=setting.layout, variant_code=setting.variant
)
async def apply(self, setting):
await self.endpoint.POST(setting)

View File

@ -16,22 +16,16 @@
import asyncio
import logging
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import MirrorCheckStatus, MirrorGet, MirrorPost
from subiquity.ui.views.mirror import MirrorView
from subiquitycore.tuicontroller import Skip
from subiquity.common.types import (
MirrorCheckStatus,
MirrorGet,
MirrorPost,
)
from subiquity.client.controller import SubiquityTuiController
from subiquity.ui.views.mirror import MirrorView
log = logging.getLogger('subiquity.client.controllers.mirror')
log = logging.getLogger("subiquity.client.controllers.mirror")
class MirrorController(SubiquityTuiController):
endpoint_name = 'mirror'
endpoint_name = "mirror"
async def make_ui(self):
mirror_response: MirrorGet = await self.endpoint.GET()
@ -62,14 +56,13 @@ class MirrorController(SubiquityTuiController):
last_status = self.app.ui.body.last_status
if last_status not in [None, MirrorCheckStatus.RUNNING]:
return
await asyncio.sleep(.1)
await asyncio.sleep(0.1)
if 'mirror' in self.answers:
self.app.ui.body.form.url.value = self.answers['mirror']
if "mirror" in self.answers:
self.app.ui.body.form.url.value = self.answers["mirror"]
await wait_mirror_check()
self.app.ui.body.form._click_done(None)
elif 'country-code' in self.answers \
or 'accept-default' in self.answers:
elif "country-code" in self.answers or "accept-default" in self.answers:
await wait_mirror_check()
self.app.ui.body.form._click_done(None)
@ -78,5 +71,4 @@ class MirrorController(SubiquityTuiController):
def done(self, mirror):
log.debug("MirrorController.done next_screen mirror=%s", mirror)
self.app.next_screen(self.endpoint.POST(
MirrorPost(elected=mirror)))
self.app.next_screen(self.endpoint.POST(MirrorPost(elected=mirror)))

View File

@ -21,6 +21,10 @@ from typing import Optional
from aiohttp import web
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.api.server import make_server_at_path
from subiquity.common.apidef import LinkAction, NetEventAPI
from subiquity.common.types import ErrorReportKind, PackageInstallState
from subiquitycore.async_helpers import run_bg_task
from subiquitycore.controllers.network import NetworkAnswersMixin
from subiquitycore.models.network import (
@ -31,20 +35,11 @@ from subiquitycore.models.network import (
)
from subiquitycore.ui.views.network import NetworkView
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.api.server import make_server_at_path
from subiquity.common.apidef import LinkAction, NetEventAPI
from subiquity.common.types import (
ErrorReportKind,
PackageInstallState,
)
log = logging.getLogger('subiquity.client.controllers.network')
log = logging.getLogger("subiquity.client.controllers.network")
class NetworkController(SubiquityTuiController, NetworkAnswersMixin):
endpoint_name = 'network'
endpoint_name = "network"
def __init__(self, app):
super().__init__(app)
@ -53,18 +48,18 @@ class NetworkController(SubiquityTuiController, NetworkAnswersMixin):
@web.middleware
async def middleware(self, request, handler):
resp = await handler(request)
if resp.get('exception'):
exc = resp['exception']
log.debug(
'request to {} crashed'.format(request.raw_path), exc_info=exc)
if resp.get("exception"):
exc = resp["exception"]
log.debug("request to {} crashed".format(request.raw_path), exc_info=exc)
self.app.make_apport_report(
ErrorReportKind.NETWORK_CLIENT_FAIL,
"request to {}".format(request.raw_path),
exc=exc, interrupt=True)
exc=exc,
interrupt=True,
)
return resp
async def update_link_POST(self, act: LinkAction,
info: NetDevInfo) -> None:
async def update_link_POST(self, act: LinkAction, info: NetDevInfo) -> None:
if self.view is None:
return
if act == LinkAction.NEW:
@ -91,15 +86,17 @@ class NetworkController(SubiquityTuiController, NetworkAnswersMixin):
self.view.show_network_error(stage)
async def wlan_support_install_finished_POST(
self, state: PackageInstallState) -> None:
self, state: PackageInstallState
) -> None:
if self.view:
self.view.update_for_wlan_support_install_state(state.name)
async def subscribe(self):
self.tdir = tempfile.mkdtemp()
self.sock_path = os.path.join(self.tdir, 'socket')
self.sock_path = os.path.join(self.tdir, "socket")
self.site = await make_server_at_path(
self.sock_path, NetEventAPI, self, middlewares=[self.middleware])
self.sock_path, NetEventAPI, self, middlewares=[self.middleware]
)
await self.endpoint.subscription.PUT(self.sock_path)
async def unsubscribe(self):
@ -110,8 +107,8 @@ class NetworkController(SubiquityTuiController, NetworkAnswersMixin):
async def make_ui(self):
network_status = await self.endpoint.GET()
self.view = NetworkView(
self, network_status.devices,
network_status.wlan_support_install_state.name)
self, network_status.devices, network_status.wlan_support_install_state.name
)
await self.subscribe()
return self.view
@ -126,40 +123,37 @@ class NetworkController(SubiquityTuiController, NetworkAnswersMixin):
def done(self):
self.app.next_screen(self.endpoint.POST())
def set_static_config(self, dev_name: str, ip_version: int,
static_config: StaticConfig) -> None:
def set_static_config(
self, dev_name: str, ip_version: int, static_config: StaticConfig
) -> None:
run_bg_task(
self.endpoint.set_static_config.POST(
dev_name, ip_version, static_config))
self.endpoint.set_static_config.POST(dev_name, ip_version, static_config)
)
def enable_dhcp(self, dev_name, ip_version: int) -> None:
run_bg_task(
self.endpoint.enable_dhcp.POST(dev_name, ip_version))
run_bg_task(self.endpoint.enable_dhcp.POST(dev_name, ip_version))
def disable_network(self, dev_name: str, ip_version: int) -> None:
run_bg_task(
self.endpoint.disable.POST(dev_name, ip_version))
run_bg_task(self.endpoint.disable.POST(dev_name, ip_version))
def add_vlan(self, dev_name: str, vlan_id: int):
run_bg_task(
self.endpoint.vlan.PUT(dev_name, vlan_id))
run_bg_task(self.endpoint.vlan.PUT(dev_name, vlan_id))
def set_wlan(self, dev_name: str, wlan: WLANConfig) -> None:
run_bg_task(
self.endpoint.set_wlan.POST(dev_name, wlan))
run_bg_task(self.endpoint.set_wlan.POST(dev_name, wlan))
def start_scan(self, dev_name: str) -> None:
run_bg_task(
self.endpoint.start_scan.POST(dev_name))
run_bg_task(self.endpoint.start_scan.POST(dev_name))
def delete_link(self, dev_name: str):
run_bg_task(self.endpoint.delete.POST(dev_name))
def add_or_update_bond(self, existing_name: Optional[str],
new_name: str, new_info: BondConfig) -> None:
def add_or_update_bond(
self, existing_name: Optional[str], new_name: str, new_info: BondConfig
) -> None:
run_bg_task(
self.endpoint.add_or_edit_bond.POST(
existing_name, new_name, new_info))
self.endpoint.add_or_edit_bond.POST(existing_name, new_name, new_info)
)
async def get_info_for_netdev(self, dev_name: str) -> str:
return await self.endpoint.info.GET(dev_name)

View File

@ -18,25 +18,16 @@ import logging
import aiohttp
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import ApplicationState, ShutdownMode
from subiquity.ui.views.installprogress import InstallRunning, ProgressView
from subiquitycore.async_helpers import run_bg_task
from subiquitycore.context import with_context
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import (
ApplicationState,
ShutdownMode,
)
from subiquity.ui.views.installprogress import (
InstallRunning,
ProgressView,
)
log = logging.getLogger("subiquity.client.controllers.progress")
class ProgressController(SubiquityTuiController):
def __init__(self, app):
super().__init__(app)
self.progress_view = ProgressView(self)
@ -49,13 +40,13 @@ class ProgressController(SubiquityTuiController):
self.progress_view.event_start(
event["SUBIQUITY_CONTEXT_ID"],
event.get("SUBIQUITY_CONTEXT_PARENT_ID"),
event["MESSAGE"])
event["MESSAGE"],
)
elif event["SUBIQUITY_EVENT_TYPE"] == "finish":
self.progress_view.event_finish(
event["SUBIQUITY_CONTEXT_ID"])
self.progress_view.event_finish(event["SUBIQUITY_CONTEXT_ID"])
def log_line(self, event):
log_line = event['MESSAGE']
log_line = event["MESSAGE"]
self.progress_view.add_log_line(log_line)
def cancel(self):
@ -79,8 +70,7 @@ class ProgressController(SubiquityTuiController):
install_running = None
while True:
try:
app_status = await self.app.client.meta.status.GET(
cur=self.app_state)
app_status = await self.app.client.meta.status.GET(cur=self.app_state)
except aiohttp.ClientError:
await asyncio.sleep(1)
continue
@ -103,7 +93,8 @@ class ProgressController(SubiquityTuiController):
if self.app_state == ApplicationState.RUNNING:
if app_status.confirming_tty != self.app.our_tty:
install_running = InstallRunning(
self.app, app_status.confirming_tty)
self.app, app_status.confirming_tty
)
self.app.add_global_overlay(install_running)
else:
if install_running is not None:
@ -111,7 +102,7 @@ class ProgressController(SubiquityTuiController):
install_running = None
if self.app_state == ApplicationState.DONE:
if self.answers.get('reboot', False):
if self.answers.get("reboot", False):
self.click_reboot()
def make_ui(self):

View File

@ -18,20 +18,19 @@ import logging
from subiquity.client.controller import SubiquityTuiController
from subiquity.ui.views.proxy import ProxyView
log = logging.getLogger('subiquity.client.controllers.proxy')
log = logging.getLogger("subiquity.client.controllers.proxy")
class ProxyController(SubiquityTuiController):
endpoint_name = 'proxy'
endpoint_name = "proxy"
async def make_ui(self):
proxy = await self.endpoint.GET()
return ProxyView(self, proxy)
def run_answers(self):
if 'proxy' in self.answers:
self.done(self.answers['proxy'])
if "proxy" in self.answers:
self.done(self.answers["proxy"])
def cancel(self):
self.app.prev_screen()

View File

@ -18,25 +18,16 @@ import logging
import aiohttp
from subiquitycore.tuicontroller import (
Skip,
)
from subiquity.common.types import (
RefreshCheckState,
)
from subiquity.client.controller import (
SubiquityTuiController,
)
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import RefreshCheckState
from subiquity.ui.views.refresh import RefreshView
from subiquitycore.tuicontroller import Skip
log = logging.getLogger('subiquity.client.controllers.refresh')
log = logging.getLogger("subiquity.client.controllers.refresh")
class RefreshController(SubiquityTuiController):
endpoint_name = 'refresh'
endpoint_name = "refresh"
def __init__(self, app):
super().__init__(app)
@ -61,8 +52,10 @@ class RefreshController(SubiquityTuiController):
self.offered_first_time = True
elif index == 2:
if not self.offered_first_time:
if self.status.availability in [RefreshCheckState.UNKNOWN,
RefreshCheckState.AVAILABLE]:
if self.status.availability in [
RefreshCheckState.UNKNOWN,
RefreshCheckState.AVAILABLE,
]:
show = True
else:
raise AssertionError("unexpected index {}".format(index))

View File

@ -19,7 +19,7 @@ from subiquity.client.controller import SubiquityTuiController
from subiquity.ui.views.serial import SerialView
from subiquitycore.tuicontroller import Skip
log = logging.getLogger('subiquity.client.controllers.serial')
log = logging.getLogger("subiquity.client.controllers.serial")
class SerialController(SubiquityTuiController):
@ -30,11 +30,11 @@ class SerialController(SubiquityTuiController):
return SerialView(self, ssh_info)
def run_answers(self):
if 'rich' in self.answers:
if self.answers['rich']:
self.ui.body.rich_btn.base_widget._emit('click')
if "rich" in self.answers:
if self.answers["rich"]:
self.ui.body.rich_btn.base_widget._emit("click")
else:
self.ui.body.basic_btn.base_widget._emit('click')
self.ui.body.basic_btn.base_widget._emit("click")
def done(self, rich):
log.debug("SerialController.done rich %s next_screen", rich)

View File

@ -16,48 +16,37 @@
import logging
from typing import List
from subiquitycore.tuicontroller import (
Skip,
)
from subiquity.client.controller import (
SubiquityTuiController,
)
from subiquity.common.types import (
SnapCheckState,
SnapSelection,
)
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import SnapCheckState, SnapSelection
from subiquity.ui.views.snaplist import SnapListView
from subiquitycore.tuicontroller import Skip
log = logging.getLogger('subiquity.client.controllers.snaplist')
log = logging.getLogger("subiquity.client.controllers.snaplist")
class SnapListController(SubiquityTuiController):
endpoint_name = 'snaplist'
endpoint_name = "snaplist"
async def make_ui(self):
data = await self.endpoint.GET()
if data.status == SnapCheckState.FAILED:
# If loading snaps failed or network is disabled, skip the screen.
log.debug('snaplist GET failed, mark done')
log.debug("snaplist GET failed, mark done")
await self.endpoint.POST([])
raise Skip
return SnapListView(self, data)
def run_answers(self):
if 'snaps' in self.answers:
if "snaps" in self.answers:
selections = []
for snap_name, selection in self.answers['snaps'].items():
if 'is_classic' in selection:
selection['classic'] = selection.pop('is_classic')
for snap_name, selection in self.answers["snaps"].items():
if "is_classic" in selection:
selection["classic"] = selection.pop("is_classic")
selections.append(SnapSelection(name=snap_name, **selection))
self.done(selections)
def done(self, selections: List[SnapSelection]):
log.debug(
"SnapListController.done next_screen snaps_to_install=%s",
selections)
log.debug("SnapListController.done next_screen snaps_to_install=%s", selections)
self.app.next_screen(self.endpoint.POST(selections))
def cancel(self, sender=None):

View File

@ -18,26 +18,24 @@ import logging
from subiquity.client.controller import SubiquityTuiController
from subiquity.ui.views.source import SourceView
log = logging.getLogger('subiquity.client.controllers.source')
log = logging.getLogger("subiquity.client.controllers.source")
class SourceController(SubiquityTuiController):
endpoint_name = 'source'
endpoint_name = "source"
async def make_ui(self):
sources = await self.endpoint.GET()
return SourceView(self,
sources.sources,
sources.current_id,
sources.search_drivers)
return SourceView(
self, sources.sources, sources.current_id, sources.search_drivers
)
def run_answers(self):
form = self.app.ui.body.form
if "search_drivers" in self.answers:
form.search_drivers.value = self.answers["search_drivers"]
if 'source' in self.answers:
wanted_id = self.answers['source']
if "source" in self.answers:
wanted_id = self.answers["source"]
for bf in form._fields:
if bf is form.search_drivers:
continue
@ -48,6 +46,9 @@ class SourceController(SubiquityTuiController):
self.app.prev_screen()
def done(self, source_id, search_drivers: bool):
log.debug("SourceController.done source_id=%s, search_drivers=%s",
source_id, search_drivers)
log.debug(
"SourceController.done source_id=%s, search_drivers=%s",
source_id,
search_drivers,
)
self.app.next_screen(self.endpoint.POST(source_id, search_drivers))

View File

@ -15,18 +15,13 @@
import logging
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import SSHData, SSHFetchIdResponse, SSHFetchIdStatus
from subiquity.ui.views.ssh import SSHView
from subiquitycore.async_helpers import schedule_task
from subiquitycore.context import with_context
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import (
SSHData,
SSHFetchIdResponse,
SSHFetchIdStatus,
)
from subiquity.ui.views.ssh import SSHView
log = logging.getLogger('subiquity.client.controllers.ssh')
log = logging.getLogger("subiquity.client.controllers.ssh")
class FetchSSHKeysFailure(Exception):
@ -36,35 +31,31 @@ class FetchSSHKeysFailure(Exception):
class SSHController(SubiquityTuiController):
endpoint_name = 'ssh'
endpoint_name = "ssh"
def __init__(self, app):
super().__init__(app)
self._fetch_task = None
if not self.answers:
identity_answers = self.app.answers.get('Identity', {})
if 'ssh-import-id' in identity_answers:
self.answers['ssh-import-id'] = identity_answers[
'ssh-import-id']
identity_answers = self.app.answers.get("Identity", {})
if "ssh-import-id" in identity_answers:
self.answers["ssh-import-id"] = identity_answers["ssh-import-id"]
async def make_ui(self):
ssh_data = await self.endpoint.GET()
return SSHView(self, ssh_data)
def run_answers(self):
if 'ssh-import-id' in self.answers:
import_id = self.answers['ssh-import-id']
ssh = SSHData(
install_server=True,
authorized_keys=[],
allow_pw=True)
if "ssh-import-id" in self.answers:
import_id = self.answers["ssh-import-id"]
ssh = SSHData(install_server=True, authorized_keys=[], allow_pw=True)
self.fetch_ssh_keys(ssh_import_id=import_id, ssh_data=ssh)
else:
ssh = SSHData(
install_server=self.answers.get("install_server", False),
authorized_keys=self.answers.get("authorized_keys", []),
allow_pw=self.answers.get("pwauth", True))
allow_pw=self.answers.get("pwauth", True),
)
self.done(ssh)
def cancel(self):
@ -75,44 +66,50 @@ class SSHController(SubiquityTuiController):
return
self._fetch_task.cancel()
@with_context(
name="ssh_import_id", description="{ssh_import_id}")
@with_context(name="ssh_import_id", description="{ssh_import_id}")
async def _fetch_ssh_keys(self, *, context, ssh_import_id, ssh_data):
with self.context.child("ssh_import_id", ssh_import_id):
response: SSHFetchIdResponse = await \
self.endpoint.fetch_id.GET(ssh_import_id)
response: SSHFetchIdResponse = await self.endpoint.fetch_id.GET(
ssh_import_id
)
if response.status == SSHFetchIdStatus.IMPORT_ERROR:
if isinstance(self.ui.body, SSHView):
self.ui.body.fetching_ssh_keys_failed(
_("Importing keys failed:"), response.error)
_("Importing keys failed:"), response.error
)
return
elif response.status == SSHFetchIdStatus.FINGERPRINT_ERROR:
if isinstance(self.ui.body, SSHView):
self.ui.body.fetching_ssh_keys_failed(
_("ssh-keygen failed to show fingerprint of"
" downloaded keys:"),
response.error)
_(
"ssh-keygen failed to show fingerprint of"
" downloaded keys:"
),
response.error,
)
return
identities = response.identities
if 'ssh-import-id' in self.app.answers.get("Identity", {}):
ssh_data.authorized_keys = \
[id_.to_authorized_key() for id_ in identities]
if "ssh-import-id" in self.app.answers.get("Identity", {}):
ssh_data.authorized_keys = [
id_.to_authorized_key() for id_ in identities
]
self.done(ssh_data)
else:
if isinstance(self.ui.body, SSHView):
self.ui.body.confirm_ssh_keys(
ssh_data, ssh_import_id, identities)
self.ui.body.confirm_ssh_keys(ssh_data, ssh_import_id, identities)
else:
log.debug("ui.body of unexpected instance: %s",
type(self.ui.body).__name__)
log.debug(
"ui.body of unexpected instance: %s",
type(self.ui.body).__name__,
)
def fetch_ssh_keys(self, ssh_import_id, ssh_data):
self._fetch_task = schedule_task(
self._fetch_ssh_keys(
ssh_import_id=ssh_import_id, ssh_data=ssh_data))
self._fetch_ssh_keys(ssh_import_id=ssh_import_id, ssh_data=ssh_data)
)
def done(self, result):
log.debug("SSHController.done next_screen result=%s", result)

View File

@ -22,23 +22,21 @@ from typing import Callable, Optional
import aiohttp
from urwid import Widget
from subiquitycore.async_helpers import schedule_task
from subiquity.client.controller import SubiquityTuiController
from subiquity.common.types import UbuntuProCheckTokenStatus as TokenStatus
from subiquity.common.types import (
UbuntuProInfo,
UbuntuProResponse,
UbuntuProCheckTokenStatus as TokenStatus,
UbuntuProSubscription,
UPCSWaitStatus,
)
from subiquity.ui.views.ubuntu_pro import (
UbuntuProView,
UpgradeYesNoForm,
UpgradeModeForm,
TokenAddedWidget,
UbuntuProView,
UpgradeModeForm,
UpgradeYesNoForm,
)
from subiquitycore.async_helpers import schedule_task
from subiquitycore.lsb_release import lsb_release
from subiquitycore.tuicontroller import Skip
@ -70,9 +68,9 @@ class UbuntuProController(SubiquityTuiController):
raise Skip("Not running LTS version")
ubuntu_pro_info: UbuntuProResponse = await self.endpoint.GET()
return UbuntuProView(self,
token=ubuntu_pro_info.token,
has_network=ubuntu_pro_info.has_network)
return UbuntuProView(
self, token=ubuntu_pro_info.token, has_network=ubuntu_pro_info.has_network
)
async def run_answers(self) -> None:
"""Interact with the UI to go through the pre-attach process if
@ -100,8 +98,7 @@ class UbuntuProController(SubiquityTuiController):
click(find_button_matching(view, UpgradeYesNoForm.ok_label))
def run_token_screen(token: str) -> None:
keypress(view.upgrade_mode_form.with_contract_token.widget,
key="enter")
keypress(view.upgrade_mode_form.with_contract_token.widget, key="enter")
data = {"with_contract_token_subform": {"token": token}}
# TODO: add this point, it would be good to trigger the validation
# code for the token field.
@ -117,13 +114,12 @@ class UbuntuProController(SubiquityTuiController):
# Wait until the "Token added successfully" overlay is shown.
while not find_with_pred(view, is_token_added_overlay):
await asyncio.sleep(.2)
await asyncio.sleep(0.2)
click(find_button_matching(view, TokenAddedWidget.done_label))
def run_subscription_screen() -> None:
click(find_button_matching(view._w,
UbuntuProView.subscription_done_label))
click(find_button_matching(view._w, UbuntuProView.subscription_done_label))
if not self.answers["token"]:
run_yes_no_screen(skip=True)
@ -134,11 +130,14 @@ class UbuntuProController(SubiquityTuiController):
await run_token_added_overlay()
run_subscription_screen()
def check_token(self, token: str,
def check_token(
self,
token: str,
on_success: Callable[[UbuntuProSubscription], None],
on_failure: Callable[[TokenStatus], None],
) -> None:
"""Asynchronously check the token passed as an argument."""
async def inner() -> None:
answer = await self.endpoint.check_token.GET(token)
if answer.status == TokenStatus.VALID_TOKEN:
@ -154,25 +153,28 @@ class UbuntuProController(SubiquityTuiController):
if self._check_task is not None:
self._check_task.cancel()
def contract_selection_initiate(
self, on_initiated: Callable[[str], None]) -> None:
def contract_selection_initiate(self, on_initiated: Callable[[str], None]) -> None:
"""Initiate the contract selection asynchronously. Calls on_initiated
when the contract-selection has initiated."""
async def inner() -> None:
answer = await self.endpoint.contract_selection.initiate.POST()
self.cs_initiated = True
on_initiated(answer.user_code)
self._magic_attach_task = schedule_task(inner())
def contract_selection_wait(
self,
on_contract_selected: Callable[[str], None],
on_timeout: Callable[[], None]) -> None:
on_timeout: Callable[[], None],
) -> None:
"""Asynchronously wait for the contract selection to finish.
Calls on_contract_selected with the contract-token once the contract
gets selected
Otherwise, calls on_timeout if no contract got selected within the
allowed time frame."""
async def inner() -> None:
try:
answer = await self.endpoint.contract_selection.wait.GET()
@ -187,9 +189,9 @@ class UbuntuProController(SubiquityTuiController):
self._monitor_task = schedule_task(inner())
def contract_selection_cancel(
self, on_cancelled: Callable[[], None]) -> None:
def contract_selection_cancel(self, on_cancelled: Callable[[], None]) -> None:
"""Cancel the asynchronous contract selection (if started)."""
async def inner() -> None:
await self.endpoint.contract_selection.cancel.POST()
self.cs_initiated = False
@ -204,9 +206,7 @@ class UbuntuProController(SubiquityTuiController):
def done(self, token: str) -> None:
"""Submit the token and move on to the next screen."""
self.app.next_screen(
self.endpoint.POST(UbuntuProInfo(token=token))
)
self.app.next_screen(self.endpoint.POST(UbuntuProInfo(token=token)))
def next_screen(self) -> None:
"""Move on to the next screen. Assume the token should not be

View File

@ -15,17 +15,16 @@
import logging
from subiquitycore import i18n
from subiquitycore.tuicontroller import Skip
from subiquity.client.controller import SubiquityTuiController
from subiquity.ui.views.welcome import WelcomeView
from subiquitycore import i18n
from subiquitycore.tuicontroller import Skip
log = logging.getLogger('subiquity.client.controllers.welcome')
log = logging.getLogger("subiquity.client.controllers.welcome")
class WelcomeController(SubiquityTuiController):
endpoint_name = 'locale'
endpoint_name = "locale"
async def make_ui(self):
if not self.app.rich_mode:
@ -36,8 +35,8 @@ class WelcomeController(SubiquityTuiController):
return WelcomeView(self, language, self.serial)
def run_answers(self):
if 'lang' in self.answers:
self.done((self.answers['lang'], ""))
if "lang" in self.answers:
self.done((self.answers["lang"], ""))
def done(self, lang):
"""Completes this controller. lang must be a tuple of strings

View File

@ -18,20 +18,18 @@ import logging
from subiquity.client.controller import SubiquityTuiController
from subiquity.ui.views import ZdevView
log = logging.getLogger("subiquity.client.controllers.zdev")
class ZdevController(SubiquityTuiController):
endpoint_name = 'zdev'
endpoint_name = "zdev"
async def make_ui(self):
infos = await self.endpoint.GET()
return ZdevView(self, infos)
def run_answers(self):
if 'accept-default' in self.answers:
if "accept-default" in self.answers:
self.done()
def cancel(self):

View File

@ -19,7 +19,7 @@ import os
import struct
import sys
log = logging.getLogger('subiquity.client.keycodes')
log = logging.getLogger("subiquity.client.keycodes")
# /usr/include/linux/kd.h
K_RAW = 0x00
@ -56,7 +56,7 @@ class KeyCodesFilter:
# well).
o = bytearray(4)
fcntl.ioctl(self._fd, KDGKBMODE, o)
self._old_mode = struct.unpack('i', o)[0]
self._old_mode = struct.unpack("i", o)[0]
# Set the keyboard mode to K_MEDIUMRAW, which causes the keyboard
# driver in the kernel to pass us keycodes.
fcntl.ioctl(self._fd, KDSKBMODE, K_MEDIUMRAW)
@ -76,17 +76,16 @@ class KeyCodesFilter:
while i < len(codes):
# This is straight from showkeys.c.
if codes[i] & 0x80:
p = 'release '
p = "release "
else:
p = 'press '
if i + 2 < n and (codes[i] & 0x7f) == 0:
p = "press "
if i + 2 < n and (codes[i] & 0x7F) == 0:
if (codes[i + 1] & 0x80) != 0:
if (codes[i + 2] & 0x80) != 0:
kc = (((codes[i + 1] & 0x7f) << 7) |
(codes[i + 2] & 0x7f))
kc = ((codes[i + 1] & 0x7F) << 7) | (codes[i + 2] & 0x7F)
i += 3
else:
kc = codes[i] & 0x7f
kc = codes[i] & 0x7F
i += 1
r.append(p + str(kc))
return r

View File

@ -28,11 +28,13 @@ def setup_environment():
locale.setlocale(locale.LC_CTYPE, "C.UTF-8")
# Prefer utils from $SUBIQUITY_ROOT, over system-wide
root = os.environ.get('SUBIQUITY_ROOT')
root = os.environ.get("SUBIQUITY_ROOT")
if root:
os.environ['PATH'] = os.pathsep.join([
os.path.join(root, 'bin'),
os.path.join(root, 'usr', 'bin'),
os.environ['PATH'],
])
os.environ["APPORT_DATA_DIR"] = os.path.join(root, 'share/apport')
os.environ["PATH"] = os.pathsep.join(
[
os.path.join(root, "bin"),
os.path.join(root, "usr", "bin"),
os.environ["PATH"],
]
)
os.environ["APPORT_DATA_DIR"] = os.path.join(root, "share/apport")

View File

@ -29,16 +29,16 @@ from subiquity.server.server import SubiquityServer
def make_schema(app):
schema = copy.deepcopy(app.base_schema)
for controller in app.controllers.instances:
ckey = getattr(controller, 'autoinstall_key', None)
ckey = getattr(controller, "autoinstall_key", None)
if ckey is None:
continue
cschema = getattr(controller, "autoinstall_schema", None)
if cschema is None:
continue
schema['properties'][ckey] = cschema
schema["properties"][ckey] = cschema
ckey_alias = getattr(controller, 'autoinstall_key_alias', None)
ckey_alias = getattr(controller, "autoinstall_key_alias", None)
if ckey_alias is None:
continue
@ -46,15 +46,15 @@ def make_schema(app):
cschema["deprecated"] = True
cschema["description"] = f"Compatibility only - use {ckey} instead"
schema['properties'][ckey_alias] = cschema
schema["properties"][ckey_alias] = cschema
return schema
def make_app():
parser = make_server_args_parser()
opts, unknown = parser.parse_known_args(['--dry-run'])
app = SubiquityServer(opts, '')
opts, unknown = parser.parse_known_args(["--dry-run"])
app = SubiquityServer(opts, "")
# This is needed because the ubuntu-pro server controller accesses dr_cfg
# in the initializer.
app.dr_cfg = DRConfig()
@ -72,5 +72,5 @@ def main():
asyncio.run(run_with_loop())
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(main())

View File

@ -25,10 +25,7 @@ import attr
from subiquitycore.log import setup_logger
from .common import (
LOGDIR,
setup_environment,
)
from .common import LOGDIR, setup_environment
@attr.s(auto_attribs=True)
@ -41,8 +38,8 @@ class CommandLineParams:
def from_cmdline(cls, cmdline):
r = cls(cmdline)
for tok in shlex.split(cmdline):
if '=' in tok:
k, v = tok.split('=', 1)
if "=" in tok:
k, v = tok.split("=", 1)
r._values[k] = v
else:
r._tokens.add(tok)
@ -57,74 +54,103 @@ class CommandLineParams:
def make_server_args_parser():
parser = argparse.ArgumentParser(
description='SUbiquity - Ubiquity for Servers',
prog='subiquity')
parser.add_argument('--dry-run', action='store_true',
dest='dry_run',
help='menu-only, do not call installer function')
parser.add_argument('--dry-run-config', type=argparse.FileType())
parser.add_argument('--socket')
parser.add_argument('--machine-config', metavar='CONFIG',
dest='machine_config',
type=argparse.FileType(),
help="Don't Probe. Use probe data file")
parser.add_argument('--bootloader',
choices=['none', 'bios', 'prep', 'uefi'],
help='Override style of bootloader to use')
description="SUbiquity - Ubiquity for Servers", prog="subiquity"
)
parser.add_argument(
'--autoinstall', action='store',
help=('Path to autoinstall file. Empty value disables autoinstall. '
'By default tries to load /autoinstall.yaml, '
'or autoinstall data from cloud-init.'))
with open('/proc/cmdline') as fp:
"--dry-run",
action="store_true",
dest="dry_run",
help="menu-only, do not call installer function",
)
parser.add_argument("--dry-run-config", type=argparse.FileType())
parser.add_argument("--socket")
parser.add_argument(
"--machine-config",
metavar="CONFIG",
dest="machine_config",
type=argparse.FileType(),
help="Don't Probe. Use probe data file",
)
parser.add_argument(
"--bootloader",
choices=["none", "bios", "prep", "uefi"],
help="Override style of bootloader to use",
)
parser.add_argument(
"--autoinstall",
action="store",
help=(
"Path to autoinstall file. Empty value disables autoinstall. "
"By default tries to load /autoinstall.yaml, "
"or autoinstall data from cloud-init."
),
)
with open("/proc/cmdline") as fp:
cmdline = fp.read()
parser.add_argument(
'--kernel-cmdline', action='store', default=cmdline,
type=CommandLineParams.from_cmdline)
"--kernel-cmdline",
action="store",
default=cmdline,
type=CommandLineParams.from_cmdline,
)
parser.add_argument(
'--snaps-from-examples', action='store_const', const=True,
"--snaps-from-examples",
action="store_const",
const=True,
dest="snaps_from_examples",
help=("Load snap details from examples/snaps instead of store. "
help=(
"Load snap details from examples/snaps instead of store. "
"Default in dry-run mode. "
"See examples/snaps/README.md for more."))
"See examples/snaps/README.md for more."
),
)
parser.add_argument(
'--no-snaps-from-examples', action='store_const', const=False,
"--no-snaps-from-examples",
action="store_const",
const=False,
dest="snaps_from_examples",
help=("Load snap details from store instead of examples. "
help=(
"Load snap details from store instead of examples. "
"Default in when not in dry-run mode. "
"See examples/snaps/README.md for more."))
"See examples/snaps/README.md for more."
),
)
parser.add_argument(
'--snap-section', action='store', default='server',
help=("Show snaps from this section of the store in the snap "
"list screen."))
"--snap-section",
action="store",
default="server",
help=("Show snaps from this section of the store in the snap " "list screen."),
)
parser.add_argument("--source-catalog", dest="source_catalog", action="store")
parser.add_argument(
'--source-catalog', dest='source_catalog', action='store')
"--output-base",
action="store",
dest="output_base",
default=".subiquity",
help="in dryrun, control basedir of files",
)
parser.add_argument("--storage-version", action="store", type=int)
parser.add_argument("--use-os-prober", action="store_true", default=False)
parser.add_argument(
'--output-base', action='store', dest='output_base',
default='.subiquity',
help='in dryrun, control basedir of files')
parser.add_argument(
'--storage-version', action='store', type=int)
parser.add_argument(
'--use-os-prober', action='store_true', default=False)
parser.add_argument(
'--postinst-hooks-dir', default='/etc/subiquity/postinst.d',
type=pathlib.Path)
"--postinst-hooks-dir", default="/etc/subiquity/postinst.d", type=pathlib.Path
)
return parser
def main():
print('starting server')
print("starting server")
setup_environment()
# setup_environment sets $APPORT_DATA_DIR which must be set before
# apport is imported, which is done by this import:
from subiquity.server.server import SubiquityServer
from subiquity.server.dryrun import DRConfig
from subiquity.server.server import SubiquityServer
parser = make_server_args_parser()
opts = parser.parse_args(sys.argv[1:])
if opts.storage_version is None:
opts.storage_version = int(opts.kernel_cmdline.get(
'subiquity-storage-version', 1))
opts.storage_version = int(
opts.kernel_cmdline.get("subiquity-storage-version", 1)
)
logdir = LOGDIR
if opts.dry_run:
if opts.dry_run_config:
@ -139,25 +165,26 @@ def main():
dr_cfg = None
if opts.socket is None:
if opts.dry_run:
opts.socket = opts.output_base + '/socket'
opts.socket = opts.output_base + "/socket"
else:
opts.socket = '/run/subiquity/socket'
opts.socket = "/run/subiquity/socket"
os.makedirs(os.path.dirname(opts.socket), exist_ok=True)
block_log_dir = os.path.join(logdir, "block")
os.makedirs(block_log_dir, exist_ok=True)
handler = logging.FileHandler(os.path.join(block_log_dir, 'discover.log'))
handler.setLevel('DEBUG')
handler = logging.FileHandler(os.path.join(block_log_dir, "discover.log"))
handler.setLevel("DEBUG")
handler.setFormatter(
logging.Formatter("%(asctime)s %(name)s:%(lineno)d %(message)s"))
logging.getLogger('probert').addHandler(handler)
handler.addFilter(lambda rec: rec.name != 'probert.network')
logging.getLogger('curtin').addHandler(handler)
logging.getLogger('block-discover').addHandler(handler)
logging.Formatter("%(asctime)s %(name)s:%(lineno)d %(message)s")
)
logging.getLogger("probert").addHandler(handler)
handler.addFilter(lambda rec: rec.name != "probert.network")
logging.getLogger("curtin").addHandler(handler)
logging.getLogger("block-discover").addHandler(handler)
logfiles = setup_logger(dir=logdir, base='subiquity-server')
logfiles = setup_logger(dir=logdir, base="subiquity-server")
logger = logging.getLogger('subiquity')
logger = logging.getLogger("subiquity")
version = os.environ.get("SNAP_REVISION", "unknown")
snap = os.environ.get("SNAP", "unknown")
logger.info(f"Starting Subiquity server revision {version} of snap {snap}")
@ -168,18 +195,16 @@ def main():
async def run_with_loop():
server = SubiquityServer(opts, block_log_dir)
server.dr_cfg = dr_cfg
server.note_file_for_apport(
"InstallerServerLog", logfiles['debug'])
server.note_file_for_apport(
"InstallerServerLogInfo", logfiles['info'])
server.note_file_for_apport("InstallerServerLog", logfiles["debug"])
server.note_file_for_apport("InstallerServerLogInfo", logfiles["info"])
server.note_file_for_apport(
"UdiLog",
os.path.realpath(
"/var/log/installer/ubuntu_desktop_installer.log"))
os.path.realpath("/var/log/installer/ubuntu_desktop_installer.log"),
)
await server.run()
asyncio.run(run_with_loop())
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(main())

View File

@ -16,53 +16,65 @@
import argparse
import asyncio
import fcntl
import logging
import os
import fcntl
import subprocess
import sys
from subiquitycore.log import setup_logger
from .common import (
LOGDIR,
setup_environment,
)
from .common import LOGDIR, setup_environment
from .server import make_server_args_parser
def make_client_args_parser():
parser = argparse.ArgumentParser(
description='Subiquity - Ubiquity for Servers',
prog='subiquity')
description="Subiquity - Ubiquity for Servers", prog="subiquity"
)
try:
ascii_default = os.ttyname(0) == "/dev/ttysclp0"
except OSError:
ascii_default = False
parser.set_defaults(ascii=ascii_default)
parser.add_argument('--dry-run', action='store_true',
dest='dry_run',
help='menu-only, do not call installer function')
parser.add_argument('--socket')
parser.add_argument('--serial', action='store_true',
dest='run_on_serial',
help='Run the installer over serial console.')
parser.add_argument('--ssh', action='store_true',
dest='ssh',
help='Print ssh login details')
parser.add_argument('--ascii', action='store_true',
dest='ascii',
help='Run the installer in ascii mode.')
parser.add_argument('--unicode', action='store_false',
dest='ascii',
help='Run the installer in unicode mode.')
parser.add_argument('--screens', action='append', dest='screens',
default=[])
parser.add_argument('--answers')
parser.add_argument('--server-pid')
parser.add_argument('--output-base', action='store', dest='output_base',
default='.subiquity',
help='in dryrun, control basedir of files')
parser.add_argument(
"--dry-run",
action="store_true",
dest="dry_run",
help="menu-only, do not call installer function",
)
parser.add_argument("--socket")
parser.add_argument(
"--serial",
action="store_true",
dest="run_on_serial",
help="Run the installer over serial console.",
)
parser.add_argument(
"--ssh", action="store_true", dest="ssh", help="Print ssh login details"
)
parser.add_argument(
"--ascii",
action="store_true",
dest="ascii",
help="Run the installer in ascii mode.",
)
parser.add_argument(
"--unicode",
action="store_false",
dest="ascii",
help="Run the installer in unicode mode.",
)
parser.add_argument("--screens", action="append", dest="screens", default=[])
parser.add_argument("--answers")
parser.add_argument("--server-pid")
parser.add_argument(
"--output-base",
action="store",
dest="output_base",
default=".subiquity",
help="in dryrun, control basedir of files",
)
return parser
@ -74,27 +86,28 @@ def main():
# setup_environment sets $APPORT_DATA_DIR which must be set before
# apport is imported, which is done by this import:
from subiquity.client.client import SubiquityClient
parser = make_client_args_parser()
args = sys.argv[1:]
if '--dry-run' in args:
if "--dry-run" in args:
opts, unknown = parser.parse_known_args(args)
if opts.socket is None:
base = opts.output_base
os.makedirs(base, exist_ok=True)
if os.path.exists(f'{base}/run/subiquity/server-state'):
os.unlink(f'{base}/run/subiquity/server-state')
sock_path = base + '/socket'
if os.path.exists(f"{base}/run/subiquity/server-state"):
os.unlink(f"{base}/run/subiquity/server-state")
sock_path = base + "/socket"
opts.socket = sock_path
server_args = ['--dry-run', '--socket=' + sock_path] + unknown
server_args.extend(('--output-base', base))
server_args = ["--dry-run", "--socket=" + sock_path] + unknown
server_args.extend(("--output-base", base))
server_parser = make_server_args_parser()
server_parser.parse_args(server_args) # just to check
server_stdout = open(os.path.join(base, 'server-stdout'), 'w')
server_stderr = open(os.path.join(base, 'server-stderr'), 'w')
server_cmd = [sys.executable, '-m', 'subiquity.cmd.server'] + \
server_args
server_stdout = open(os.path.join(base, "server-stdout"), "w")
server_stderr = open(os.path.join(base, "server-stderr"), "w")
server_cmd = [sys.executable, "-m", "subiquity.cmd.server"] + server_args
server_proc = subprocess.Popen(
server_cmd, stdout=server_stdout, stderr=server_stderr)
server_cmd, stdout=server_stdout, stderr=server_stderr
)
opts.server_pid = str(server_proc.pid)
print("running server pid {}".format(server_proc.pid))
elif opts.server_pid is not None:
@ -104,14 +117,14 @@ def main():
else:
opts = parser.parse_args(args)
if opts.socket is None:
opts.socket = '/run/subiquity/socket'
opts.socket = "/run/subiquity/socket"
os.makedirs(os.path.basename(opts.socket), exist_ok=True)
logdir = LOGDIR
if opts.dry_run:
logdir = opts.output_base
logfiles = setup_logger(dir=logdir, base='subiquity-client')
logfiles = setup_logger(dir=logdir, base="subiquity-client")
logger = logging.getLogger('subiquity')
logger = logging.getLogger("subiquity")
version = os.environ.get("SNAP_REVISION", "unknown")
snap = os.environ.get("SNAP", "unknown")
logger.info(f"Starting Subiquity TUI revision {version} of snap {snap}")
@ -127,21 +140,18 @@ def main():
try:
fcntl.flock(opts.answers, fcntl.LOCK_EX | fcntl.LOCK_NB)
except OSError:
logger.exception(
'Failed to lock auto answers file, proceding without it.')
logger.exception("Failed to lock auto answers file, proceding without it.")
opts.answers.close()
opts.answers = None
async def run_with_loop():
subiquity_interface = SubiquityClient(opts)
subiquity_interface.note_file_for_apport(
"InstallerLog", logfiles['debug'])
subiquity_interface.note_file_for_apport(
"InstallerLogInfo", logfiles['info'])
subiquity_interface.note_file_for_apport("InstallerLog", logfiles["debug"])
subiquity_interface.note_file_for_apport("InstallerLogInfo", logfiles["info"])
await subiquity_interface.run()
asyncio.run(run_with_loop())
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(main())

View File

@ -19,6 +19,7 @@ import inspect
import aiohttp
from subiquity.common.serialize import Serializer
from .defs import Payload
@ -27,7 +28,7 @@ def _wrap(make_request, path, meth, serializer, serialize_query_args):
meth_params = sig.parameters
payload_arg = None
for name, param in meth_params.items():
if getattr(param.annotation, '__origin__', None) is Payload:
if getattr(param.annotation, "__origin__", None) is Payload:
payload_arg = name
payload_ann = param.annotation.__args__[0]
r_ann = sig.return_annotation
@ -41,14 +42,14 @@ def _wrap(make_request, path, meth, serializer, serialize_query_args):
data = serializer.serialize(payload_ann, value)
else:
if serialize_query_args:
value = serializer.to_json(
meth_params[arg_name].annotation, value)
value = serializer.to_json(meth_params[arg_name].annotation, value)
query_args[arg_name] = value
async with make_request(
meth.__name__, path.format(**self.path_args),
json=data, params=query_args) as resp:
meth.__name__, path.format(**self.path_args), json=data, params=query_args
) as resp:
resp.raise_for_status()
return serializer.deserialize(r_ann, await resp.json())
return impl
@ -73,19 +74,24 @@ def make_client_cls(endpoint_cls, make_request, serializer=None):
if serializer is None:
serializer = Serializer()
ns = {'__init__': client_init}
ns = {"__init__": client_init}
for k, v in endpoint_cls.__dict__.items():
if isinstance(v, type):
if getattr(v, '__parameter__', False):
ns['__getitem__'] = make_getitem(v, make_request, serializer)
if getattr(v, "__parameter__", False):
ns["__getitem__"] = make_getitem(v, make_request, serializer)
else:
ns[k] = make_client(v, make_request, serializer)
elif callable(v):
ns[k] = _wrap(make_request, endpoint_cls.fullpath, v, serializer,
endpoint_cls.serialize_query_args)
ns[k] = _wrap(
make_request,
endpoint_cls.fullpath,
v,
serializer,
endpoint_cls.serialize_query_args,
)
return type('ClientFor({})'.format(endpoint_cls.__name__), (object,), ns)
return type("ClientFor({})".format(endpoint_cls.__name__), (object,), ns)
def make_client(endpoint_cls, make_request, serializer=None, path_args=None):
@ -93,10 +99,9 @@ def make_client(endpoint_cls, make_request, serializer=None, path_args=None):
def make_client_for_conn(
endpoint_cls, conn, resp_hook=lambda r: r, serializer=None,
header_func=None):
session = aiohttp.ClientSession(
connector=conn, connector_owner=False)
endpoint_cls, conn, resp_hook=lambda r: r, serializer=None, header_func=None
):
session = aiohttp.ClientSession(connector=conn, connector_owner=False)
@contextlib.asynccontextmanager
async def make_request(method, path, *, params, json):
@ -105,14 +110,14 @@ def make_client_for_conn(
# hardcode something here (I guess the "a" gets sent along to the
# server in the Host: header and the server could in principle do
# something like virtual host based selection but well....)
url = 'http://a' + path
url = "http://a" + path
if header_func is not None:
headers = header_func()
else:
headers = None
async with session.request(
method, url, json=json, params=params,
headers=headers, timeout=0) as response:
method, url, json=json, params=params, headers=headers, timeout=0
) as response:
yield resp_hook(response)
return make_client(endpoint_cls, make_request, serializer)

View File

@ -27,8 +27,10 @@ class InvalidQueryArgs(InvalidAPIDefinition):
self.param = param
def __str__(self):
return (f"{self.callable.__qualname__} does not serialize query "
f"arguments but has non-str parameter '{self.param}'")
return (
f"{self.callable.__qualname__} does not serialize query "
f"arguments but has non-str parameter '{self.param}'"
)
class MultiplePathParameters(InvalidAPIDefinition):
@ -38,45 +40,50 @@ class MultiplePathParameters(InvalidAPIDefinition):
self.param2 = param2
def __str__(self):
return (f"{self.cls.__name__} has multiple path parameters "
f"{self.param1!r} and {self.param2!r}")
return (
f"{self.cls.__name__} has multiple path parameters "
f"{self.param1!r} and {self.param2!r}"
)
def api(cls, prefix_names=(), prefix_path=(), path_params=(),
serialize_query_args=True):
if hasattr(cls, 'serialize_query_args'):
def api(
cls, prefix_names=(), prefix_path=(), path_params=(), serialize_query_args=True
):
if hasattr(cls, "serialize_query_args"):
serialize_query_args = cls.serialize_query_args
else:
cls.serialize_query_args = serialize_query_args
cls.fullpath = '/' + '/'.join(prefix_path)
cls.fullpath = "/" + "/".join(prefix_path)
cls.fullname = prefix_names
seen_path_param = None
for k, v in cls.__dict__.items():
if isinstance(v, type):
v.__shortname__ = k
v.__name__ = cls.__name__ + '.' + k
v.__name__ = cls.__name__ + "." + k
path_part = k
path_param = ()
if getattr(v, '__parameter__', False):
if getattr(v, "__parameter__", False):
if seen_path_param:
raise MultiplePathParameters(cls, seen_path_param, k)
seen_path_param = k
path_part = '{' + path_part + '}'
path_part = "{" + path_part + "}"
path_param = (k,)
api(
v,
prefix_names + (k,),
prefix_path + (path_part,),
path_params + path_param,
serialize_query_args)
serialize_query_args,
)
if callable(v):
v.__qualname__ = cls.__name__ + '.' + k
v.__qualname__ = cls.__name__ + "." + k
if not cls.serialize_query_args:
params = inspect.signature(v).parameters
for param_name, param in params.items():
if param.annotation is not str and \
getattr(param.annotation, '__origin__', None) is not \
Payload:
if (
param.annotation is not str
and getattr(param.annotation, "__origin__", None) is not Payload
):
raise InvalidQueryArgs(v, param)
v.__path_params__ = path_params
return cls
@ -96,8 +103,12 @@ def path_parameter(cls):
def simple_endpoint(typ):
class endpoint:
def GET() -> typ: ...
def POST(data: Payload[typ]): ...
def GET() -> typ:
...
def POST(data: Payload[typ]):
...
return endpoint

View File

@ -24,7 +24,7 @@ from subiquity.common.serialize import Serializer
from .defs import Payload
log = logging.getLogger('subiquity.common.api.server')
log = logging.getLogger("subiquity.common.api.server")
class BindError(Exception):
@ -47,24 +47,26 @@ class SignatureMisatchError(BindError):
self.actual = actual
def __str__(self):
return (f"implementation of {self.methname} has wrong signature, "
f"should be {self.expected} but is {self.actual}")
return (
f"implementation of {self.methname} has wrong signature, "
f"should be {self.expected} but is {self.actual}"
)
def trim(text):
if text is None:
return ''
return ""
elif len(text) > 80:
return text[:77] + '...'
return text[:77] + "..."
else:
return text
async def check_controllers_started(definition, controller, request):
if not hasattr(controller, 'app'):
if not hasattr(controller, "app"):
return
if getattr(definition, 'allowed_before_start', False):
if getattr(definition, "allowed_before_start", False):
return
# Most endpoints should not be responding to requests
@ -76,13 +78,14 @@ async def check_controllers_started(definition, controller, request):
if controller.app.controllers_have_started.is_set():
return
log.debug(f'{request.path} waiting on start')
log.debug(f"{request.path} waiting on start")
await controller.app.controllers_have_started.wait()
log.debug(f'{request.path} resuming')
log.debug(f"{request.path} resuming")
def _make_handler(controller, definition, implementation, serializer,
serialize_query_args):
def _make_handler(
controller, definition, implementation, serializer, serialize_query_args
):
def_sig = inspect.signature(definition)
def_ret_ann = def_sig.return_annotation
def_params = def_sig.parameters
@ -99,27 +102,26 @@ def _make_handler(controller, definition, implementation, serializer,
for param_name in definition.__path_params__:
check_def_params.append(
inspect.Parameter(
param_name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=str))
param_name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=str
)
)
for param_name, param in def_params.items():
if param_name in ('request', 'context'):
if param_name in ("request", "context"):
raise Exception(
"api method {} cannot have parameter called request or "
"context".format(definition))
if getattr(param.annotation, '__origin__', None) is Payload:
"context".format(definition)
)
if getattr(param.annotation, "__origin__", None) is Payload:
data_arg = param_name
data_annotation = param.annotation.__args__[0]
check_def_params.append(param.replace(annotation=data_annotation))
else:
query_args_anns.append(
(param_name, param.annotation, param.default))
query_args_anns.append((param_name, param.annotation, param.default))
check_def_params.append(param)
check_impl_params = [
p for p in impl_params.values()
if p.name not in ('context', 'request')
p for p in impl_params.values() if p.name not in ("context", "request")
]
check_impl_sig = impl_sig.replace(parameters=check_impl_params)
@ -127,17 +129,19 @@ def _make_handler(controller, definition, implementation, serializer,
if check_impl_sig != check_def_sig:
raise SignatureMisatchError(
definition.__qualname__, check_def_sig, check_impl_sig)
definition.__qualname__, check_def_sig, check_impl_sig
)
async def handler(request):
context = controller.context.child(implementation.__name__)
with context:
context.set('request', request)
context.set("request", request)
args = {}
try:
if data_annotation is not None:
args[data_arg] = serializer.from_json(
data_annotation, await request.text())
data_annotation, await request.text()
)
for arg, ann, default in query_args_anns:
if arg in request.query:
v = request.query[arg]
@ -146,34 +150,35 @@ def _make_handler(controller, definition, implementation, serializer,
elif default != inspect._empty:
v = default
else:
raise TypeError(
'missing required argument "{}"'.format(arg))
raise TypeError('missing required argument "{}"'.format(arg))
args[arg] = v
for param_name in definition.__path_params__:
args[param_name] = request.match_info[param_name]
if 'context' in impl_params:
args['context'] = context
if 'request' in impl_params:
args['request'] = request
await check_controllers_started(
definition, controller, request)
if "context" in impl_params:
args["context"] = context
if "request" in impl_params:
args["request"] = request
await check_controllers_started(definition, controller, request)
result = await implementation(**args)
resp = web.json_response(
serializer.serialize(def_ret_ann, result),
headers={'x-status': 'ok'})
headers={"x-status": "ok"},
)
except Exception as exc:
tb = traceback.TracebackException.from_exception(exc)
resp = web.Response(
text="".join(tb.format()),
status=500,
headers={
'x-status': 'error',
'x-error-type': type(exc).__name__,
'x-error-msg': str(exc),
})
resp['exception'] = exc
context.description = '{} {}'.format(resp.status, trim(resp.text))
"x-status": "error",
"x-error-type": type(exc).__name__,
"x-error-msg": str(exc),
},
)
resp["exception"] = exc
context.description = "{} {}".format(resp.status, trim(resp.text))
return resp
handler.controller = controller
return handler
@ -181,7 +186,7 @@ def _make_handler(controller, definition, implementation, serializer,
async def controller_for_request(request):
match_info = await request.app.router.resolve(request)
return getattr(match_info.handler, 'controller', None)
return getattr(match_info.handler, "controller", None)
def bind(router, endpoint, controller, serializer=None, _depth=None):
@ -203,8 +208,9 @@ def bind(router, endpoint, controller, serializer=None, _depth=None):
method=method,
path=endpoint.fullpath,
handler=_make_handler(
controller, v, impl, serializer,
endpoint.serialize_query_args))
controller, v, impl, serializer, endpoint.serialize_query_args
),
)
async def make_server_at_path(socket_path, endpoint, controller, **kw):

View File

@ -17,12 +17,7 @@ import contextlib
import unittest
from subiquity.common.api.client import make_client
from subiquity.common.api.defs import (
api,
InvalidQueryArgs,
path_parameter,
Payload,
)
from subiquity.common.api.defs import InvalidQueryArgs, Payload, api, path_parameter
def extract(c):
@ -46,20 +41,21 @@ class FakeResponse:
class TestClient(unittest.TestCase):
def test_simple(self):
@api
class API:
class endpoint:
def GET() -> str: ...
def POST(data: Payload[str]) -> None: ...
def GET() -> str:
...
def POST(data: Payload[str]) -> None:
...
@contextlib.asynccontextmanager
async def make_request(method, path, *, params, json):
requests.append((method, path, params, json))
if method == "GET":
v = 'value'
v = "value"
else:
v = None
yield FakeResponse(v)
@ -68,85 +64,95 @@ class TestClient(unittest.TestCase):
requests = []
r = extract(client.endpoint.GET())
self.assertEqual(r, 'value')
self.assertEqual(requests, [("GET", '/endpoint', {}, None)])
self.assertEqual(r, "value")
self.assertEqual(requests, [("GET", "/endpoint", {}, None)])
requests = []
r = extract(client.endpoint.POST('value'))
r = extract(client.endpoint.POST("value"))
self.assertEqual(r, None)
self.assertEqual(
requests, [("POST", '/endpoint', {}, 'value')])
self.assertEqual(requests, [("POST", "/endpoint", {}, "value")])
def test_args(self):
@api
class API:
def GET(arg: str) -> str: ...
def GET(arg: str) -> str:
...
@contextlib.asynccontextmanager
async def make_request(method, path, *, params, json):
requests.append((method, path, params, json))
yield FakeResponse(params['arg'])
yield FakeResponse(params["arg"])
client = make_client(API, make_request)
requests = []
r = extract(client.GET(arg='v'))
r = extract(client.GET(arg="v"))
self.assertEqual(r, '"v"')
self.assertEqual(requests, [("GET", '/', {'arg': '"v"'}, None)])
self.assertEqual(requests, [("GET", "/", {"arg": '"v"'}, None)])
def test_path_params(self):
@api
class API:
@path_parameter
class param:
def GET(arg: str) -> str: ...
def GET(arg: str) -> str:
...
@contextlib.asynccontextmanager
async def make_request(method, path, *, params, json):
requests.append((method, path, params, json))
yield FakeResponse(params['arg'])
yield FakeResponse(params["arg"])
client = make_client(API, make_request)
requests = []
r = extract(client['foo'].GET(arg='v'))
r = extract(client["foo"].GET(arg="v"))
self.assertEqual(r, '"v"')
self.assertEqual(requests, [("GET", '/foo', {'arg': '"v"'}, None)])
self.assertEqual(requests, [("GET", "/foo", {"arg": '"v"'}, None)])
def test_serialize_query_args(self):
@api
class API:
serialize_query_args = False
def GET(arg: str) -> str: ...
def GET(arg: str) -> str:
...
class meth:
def GET(arg: str, payload: Payload[int]) -> str: ...
def GET(arg: str, payload: Payload[int]) -> str:
...
class more:
serialize_query_args = True
def GET(arg: str) -> str: ...
def GET(arg: str) -> str:
...
@contextlib.asynccontextmanager
async def make_request(method, path, *, params, json):
requests.append((method, path, params, json))
yield FakeResponse(params['arg'])
yield FakeResponse(params["arg"])
client = make_client(API, make_request)
requests = []
extract(client.GET(arg='v'))
extract(client.meth.GET(arg='v', payload=1))
extract(client.meth.more.GET(arg='v'))
self.assertEqual(requests, [
("GET", '/', {'arg': 'v'}, None),
("GET", '/meth', {'arg': 'v'}, 1),
("GET", '/meth/more', {'arg': '"v"'}, None)])
extract(client.GET(arg="v"))
extract(client.meth.GET(arg="v", payload=1))
extract(client.meth.more.GET(arg="v"))
self.assertEqual(
requests,
[
("GET", "/", {"arg": "v"}, None),
("GET", "/meth", {"arg": "v"}, 1),
("GET", "/meth/more", {"arg": '"v"'}, None),
],
)
class API2:
serialize_query_args = False
def GET(arg: int) -> str: ...
def GET(arg: int) -> str:
...
with self.assertRaises(InvalidQueryArgs):
api(API2)

View File

@ -13,82 +13,77 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import attr
import contextlib
import functools
import unittest
import aiohttp
import attr
from aiohttp import web
from subiquity.common.api.client import make_client
from subiquity.common.api.defs import (
api,
MultiplePathParameters,
path_parameter,
Payload,
api,
path_parameter,
)
from .test_server import (
makeTestClient,
ControllerBase,
)
from .test_server import ControllerBase, makeTestClient
def make_request(client, method, path, *, params, json):
return client.request(
method, path, params=params, json=json)
return client.request(method, path, params=params, json=json)
@contextlib.asynccontextmanager
async def makeE2EClient(api, impl,
*, middlewares=(), make_request=make_request):
async with makeTestClient(
api, impl, middlewares=middlewares) as client:
async def makeE2EClient(api, impl, *, middlewares=(), make_request=make_request):
async with makeTestClient(api, impl, middlewares=middlewares) as client:
mr = functools.partial(make_request, client)
yield make_client(api, mr)
class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
async def test_simple(self):
@api
class API:
def GET() -> str: ...
def GET() -> str:
...
class Impl(ControllerBase):
async def GET(self) -> str:
return 'value'
return "value"
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(await client.GET(), 'value')
self.assertEqual(await client.GET(), "value")
async def test_nested(self):
@api
class API:
class endpoint:
class nested:
def GET() -> str: ...
def GET() -> str:
...
class Impl(ControllerBase):
async def endpoint_nested_GET(self) -> str:
return 'value'
return "value"
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(await client.endpoint.nested.GET(), 'value')
self.assertEqual(await client.endpoint.nested.GET(), "value")
async def test_args(self):
@api
class API:
def GET(arg1: str, arg2: str) -> str: ...
def GET(arg1: str, arg2: str) -> str:
...
class Impl(ControllerBase):
async def GET(self, arg1: str, arg2: str) -> str:
return '{}+{}'.format(arg1, arg2)
return "{}+{}".format(arg1, arg2)
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(
await client.GET(arg1="A", arg2="B"), 'A+B')
self.assertEqual(await client.GET(arg1="A", arg2="B"), "A+B")
async def test_path_args(self):
@api
@ -97,26 +92,29 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
class param1:
@path_parameter
class param2:
def GET(arg1: str, arg2: str) -> str: ...
def GET(arg1: str, arg2: str) -> str:
...
class Impl(ControllerBase):
async def param1_param2_GET(self, param1: str, param2: str,
arg1: str, arg2: str) -> str:
return '{}+{}+{}+{}'.format(param1, param2, arg1, arg2)
async def param1_param2_GET(
self, param1: str, param2: str, arg1: str, arg2: str
) -> str:
return "{}+{}+{}+{}".format(param1, param2, arg1, arg2)
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(
await client["1"]["2"].GET(arg1="A", arg2="B"), '1+2+A+B')
self.assertEqual(await client["1"]["2"].GET(arg1="A", arg2="B"), "1+2+A+B")
def test_only_one_path_parameter(self):
class API:
@path_parameter
class param1:
def GET(arg: str) -> str: ...
def GET(arg: str) -> str:
...
@path_parameter
class param2:
def GET(arg: str) -> str: ...
def GET(arg: str) -> str:
...
with self.assertRaises(MultiplePathParameters):
api(API)
@ -124,35 +122,32 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
async def test_defaults(self):
@api
class API:
def GET(arg1: str = "arg1", arg2: str = "arg2") -> str: ...
def GET(arg1: str = "arg1", arg2: str = "arg2") -> str:
...
class Impl(ControllerBase):
async def GET(self, arg1: str = "arg1", arg2: str = "arg2") -> str:
return '{}+{}'.format(arg1, arg2)
return "{}+{}".format(arg1, arg2)
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(
await client.GET(arg1="A", arg2="B"), 'A+B')
self.assertEqual(
await client.GET(arg1="A"), 'A+arg2')
self.assertEqual(
await client.GET(arg2="B"), 'arg1+B')
self.assertEqual(await client.GET(arg1="A", arg2="B"), "A+B")
self.assertEqual(await client.GET(arg1="A"), "A+arg2")
self.assertEqual(await client.GET(arg2="B"), "arg1+B")
async def test_post(self):
@api
class API:
def POST(data: Payload[dict]) -> str: ...
def POST(data: Payload[dict]) -> str:
...
class Impl(ControllerBase):
async def POST(self, data: dict) -> str:
return data['key']
return data["key"]
async with makeE2EClient(API, Impl()) as client:
self.assertEqual(
await client.POST({'key': 'value'}), 'value')
self.assertEqual(await client.POST({"key": "value"}), "value")
async def test_typed(self):
@attr.s(auto_attribs=True)
class In:
val: int
@ -164,7 +159,8 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
@api
class API:
class doubler:
def POST(data: In) -> Out: ...
def POST(data: In) -> Out:
...
class Impl(ControllerBase):
async def doubler_POST(self, data: In) -> Out:
@ -177,7 +173,8 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
async def test_middleware(self):
@api
class API:
def GET() -> int: ...
def GET() -> int:
...
class Impl(ControllerBase):
async def GET(self) -> int:
@ -185,9 +182,7 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
@web.middleware
async def middleware(request, handler):
return web.Response(
status=200,
headers={'x-status': 'skip'})
return web.Response(status=200, headers={"x-status": "skip"})
class Skip(Exception):
pass
@ -195,15 +190,15 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
@contextlib.asynccontextmanager
async def custom_make_request(client, method, path, *, params, json):
async with make_request(
client, method, path, params=params, json=json) as resp:
if resp.headers.get('x-status') == 'skip':
client, method, path, params=params, json=json
) as resp:
if resp.headers.get("x-status") == "skip":
raise Skip
yield resp
async with makeE2EClient(
API, Impl(),
middlewares=[middleware],
make_request=custom_make_request) as client:
API, Impl(), middlewares=[middleware], make_request=custom_make_request
) as client:
with self.assertRaises(Skip):
await client.GET()
@ -211,10 +206,12 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
@api
class API:
class good:
def GET(x: int) -> int: ...
def GET(x: int) -> int:
...
class bad:
def GET(x: int) -> int: ...
def GET(x: int) -> int:
...
class Impl(ControllerBase):
async def good_GET(self, x: int) -> int:
@ -233,10 +230,12 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
@api
class API:
class good:
def GET(x: int) -> int: ...
def GET(x: int) -> int:
...
class bad:
def GET(x: int) -> int: ...
def GET(x: int) -> int:
...
class Impl(ControllerBase):
async def good_GET(self, x: int) -> int:
@ -248,8 +247,8 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
@web.middleware
async def middleware(request, handler):
resp = await handler(request)
if resp.get('exception'):
resp.headers['x-status'] = 'ERROR'
if resp.get("exception"):
resp.headers["x-status"] = "ERROR"
return resp
class Abort(Exception):
@ -258,15 +257,15 @@ class TestEndToEnd(unittest.IsolatedAsyncioTestCase):
@contextlib.asynccontextmanager
async def custom_make_request(client, method, path, *, params, json):
async with make_request(
client, method, path, params=params, json=json) as resp:
if resp.headers.get('x-status') == 'ERROR':
client, method, path, params=params, json=json
) as resp:
if resp.headers.get("x-status") == "ERROR":
raise Abort
yield resp
async with makeE2EClient(
API, Impl(),
middlewares=[middleware],
make_request=custom_make_request) as client:
API, Impl(), middlewares=[middleware], make_request=custom_make_request
) as client:
r = await client.good.GET(2)
self.assertEqual(r, 3)
with self.assertRaises(Abort):

View File

@ -17,38 +17,30 @@ import contextlib
import unittest
from unittest import mock
from aiohttp.test_utils import TestClient, TestServer
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer
from subiquitycore.context import Context
from subiquity.common.api.defs import (
allowed_before_start,
api,
path_parameter,
Payload,
)
from subiquity.common.api.defs import Payload, allowed_before_start, api, path_parameter
from subiquity.common.api.server import (
bind,
controller_for_request,
MissingImplementationError,
SignatureMisatchError,
bind,
controller_for_request,
)
from subiquitycore.context import Context
class TestApp:
def report_start_event(self, context, description):
pass
def report_finish_event(self, context, description, result):
pass
project = 'test'
project = "test"
class ControllerBase:
def __init__(self):
self.context = Context.new(TestApp())
self.app = mock.Mock()
@ -64,109 +56,109 @@ async def makeTestClient(api, impl, middlewares=()):
class TestBind(unittest.IsolatedAsyncioTestCase):
async def assertResponse(self, coro, value):
resp = await coro
if resp.headers.get('x-error-msg'):
if resp.headers.get("x-error-msg"):
raise Exception(
resp.headers['x-error-type'] + ': ' +
resp.headers['x-error-msg'])
resp.headers["x-error-type"] + ": " + resp.headers["x-error-msg"]
)
self.assertEqual(resp.status, 200)
self.assertEqual(await resp.json(), value)
async def test_simple(self):
@api
class API:
def GET() -> str: ...
def GET() -> str:
...
class Impl(ControllerBase):
async def GET(self) -> str:
return 'value'
return "value"
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.get("/"), 'value')
await self.assertResponse(client.get("/"), "value")
async def test_nested(self):
@api
class API:
class endpoint:
class nested:
def get(): ...
def get():
...
class Impl(ControllerBase):
async def nested_get(self, request, context):
return 'nested'
return "nested"
async with makeTestClient(API.endpoint, Impl()) as client:
await self.assertResponse(
client.get("/endpoint/nested"), 'nested')
await self.assertResponse(client.get("/endpoint/nested"), "nested")
async def test_args(self):
@api
class API:
def GET(arg: str): ...
def GET(arg: str):
...
class Impl(ControllerBase):
async def GET(self, arg: str):
return arg
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.get('/?arg="whut"'), 'whut')
await self.assertResponse(client.get('/?arg="whut"'), "whut")
async def test_missing_argument(self):
@api
class API:
def GET(arg: str): ...
def GET(arg: str):
...
class Impl(ControllerBase):
async def GET(self, arg: str):
return arg
async with makeTestClient(API, Impl()) as client:
resp = await client.get('/')
resp = await client.get("/")
self.assertEqual(resp.status, 500)
self.assertEqual(resp.headers['x-status'], 'error')
self.assertEqual(resp.headers['x-error-type'], 'TypeError')
self.assertEqual(resp.headers["x-status"], "error")
self.assertEqual(resp.headers["x-error-type"], "TypeError")
self.assertEqual(
resp.headers['x-error-msg'],
'missing required argument "arg"')
resp.headers["x-error-msg"], 'missing required argument "arg"'
)
async def test_error(self):
@api
class API:
def GET(): ...
def GET():
...
class Impl(ControllerBase):
async def GET(self):
return 1 / 0
async with makeTestClient(API, Impl()) as client:
resp = await client.get('/')
resp = await client.get("/")
self.assertEqual(resp.status, 500)
self.assertEqual(resp.headers['x-status'], 'error')
self.assertEqual(
resp.headers['x-error-type'], 'ZeroDivisionError')
self.assertEqual(resp.headers["x-status"], "error")
self.assertEqual(resp.headers["x-error-type"], "ZeroDivisionError")
async def test_post(self):
@api
class API:
def POST(data: Payload[str]) -> str: ...
def POST(data: Payload[str]) -> str:
...
class Impl(ControllerBase):
async def POST(self, data: str) -> str:
return data
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.post("/", json='value'),
'value')
await self.assertResponse(client.post("/", json="value"), "value")
def test_missing_method(self):
@api
class API:
def GET(arg: str): ...
def GET(arg: str):
...
class Impl(ControllerBase):
pass
@ -179,7 +171,8 @@ class TestBind(unittest.IsolatedAsyncioTestCase):
def test_signature_checking(self):
@api
class API:
def GET(arg: str): ...
def GET(arg: str):
...
class Impl(ControllerBase):
async def GET(self, arg: int):
@ -191,31 +184,28 @@ class TestBind(unittest.IsolatedAsyncioTestCase):
self.assertEqual(cm.exception.methname, "API.GET")
async def test_middleware(self):
@web.middleware
async def middleware(request, handler):
resp = await handler(request)
exc = resp.get('exception')
if resp.get('exception') is not None:
resp.headers['x-error-type'] = type(exc).__name__
exc = resp.get("exception")
if resp.get("exception") is not None:
resp.headers["x-error-type"] = type(exc).__name__
return resp
@api
class API:
def GET() -> str: ...
def GET() -> str:
...
class Impl(ControllerBase):
async def GET(self) -> str:
return 1 / 0
async with makeTestClient(
API, Impl(), middlewares=[middleware]) as client:
async with makeTestClient(API, Impl(), middlewares=[middleware]) as client:
resp = await client.get("/")
self.assertEqual(
resp.headers['x-error-type'], 'ZeroDivisionError')
self.assertEqual(resp.headers["x-error-type"], "ZeroDivisionError")
async def test_controller_for_request(self):
seen_controller = None
@web.middleware
@ -227,18 +217,18 @@ class TestBind(unittest.IsolatedAsyncioTestCase):
@api
class API:
class meth:
def GET() -> str: ...
def GET() -> str:
...
class Impl(ControllerBase):
async def GET(self) -> str:
return ''
return ""
impl = Impl()
async with makeTestClient(
API.meth, impl, middlewares=[middleware]) as client:
async with makeTestClient(API.meth, impl, middlewares=[middleware]) as client:
resp = await client.get("/meth")
self.assertEqual(await resp.json(), '')
self.assertEqual(await resp.json(), "")
self.assertIs(impl, seen_controller)
@ -246,14 +236,19 @@ class TestBind(unittest.IsolatedAsyncioTestCase):
@api
class API:
serialize_query_args = False
def GET(arg: str) -> str: ...
def GET(arg: str) -> str:
...
class meth:
def GET(arg: str) -> str: ...
def GET(arg: str) -> str:
...
class more:
serialize_query_args = True
def GET(arg: str) -> str: ...
def GET(arg: str) -> str:
...
class Impl(ControllerBase):
async def GET(self, arg: str) -> str:
@ -266,67 +261,64 @@ class TestBind(unittest.IsolatedAsyncioTestCase):
return arg
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.get('/?arg=whut'), 'whut')
await self.assertResponse(
client.get('/meth?arg=whut'), 'whut')
await self.assertResponse(
client.get('/meth/more?arg="whut"'), 'whut')
await self.assertResponse(client.get("/?arg=whut"), "whut")
await self.assertResponse(client.get("/meth?arg=whut"), "whut")
await self.assertResponse(client.get('/meth/more?arg="whut"'), "whut")
async def test_path_parameters(self):
@api
class API:
@path_parameter
class param:
def GET(arg: int): ...
def GET(arg: int):
...
class Impl(ControllerBase):
async def param_GET(self, param: str, arg: int):
return param + str(arg)
async with makeTestClient(API, Impl()) as client:
await self.assertResponse(
client.get('/value?arg=2'), 'value2')
await self.assertResponse(client.get("/value?arg=2"), "value2")
async def test_early_connect_ok(self):
@api
class API:
class can_be_used_early:
@allowed_before_start
def GET(): ...
def GET():
...
class Impl(ControllerBase):
def __init__(self):
super().__init__()
self.app.controllers_have_started = mock.AsyncMock()
self.app.controllers_have_started.is_set = mock.Mock(
return_value=False)
self.app.controllers_have_started.is_set = mock.Mock(return_value=False)
async def can_be_used_early_GET(self):
pass
impl = Impl()
async with makeTestClient(API, impl) as client:
await client.get('/can_be_used_early')
await client.get("/can_be_used_early")
impl.app.controllers_have_started.wait.assert_not_called()
async def test_early_connect_wait(self):
@api
class API:
class must_not_be_used_early:
def GET(): ...
def GET():
...
class Impl(ControllerBase):
def __init__(self):
super().__init__()
self.app.controllers_have_started = mock.AsyncMock()
self.app.controllers_have_started.is_set = mock.Mock(
return_value=False)
self.app.controllers_have_started.is_set = mock.Mock(return_value=False)
async def must_not_be_used_early_GET(self):
pass
impl = Impl()
async with makeTestClient(API, impl) as client:
await client.get('/must_not_be_used_early')
await client.get("/must_not_be_used_early")
impl.app.controllers_have_started.wait.assert_called_once()

View File

@ -16,6 +16,67 @@
import enum
from typing import List, Optional
from subiquity.common.api.defs import (
Payload,
allowed_before_start,
api,
simple_endpoint,
)
from subiquity.common.types import (
AdAdminNameValidation,
AdConnectionInfo,
AdDomainNameValidation,
AddPartitionV2,
AdJoinResult,
AdPasswordValidation,
AnyStep,
ApplicationState,
ApplicationStatus,
CasperMd5Results,
Change,
CodecsData,
Disk,
DriversPayload,
DriversResponse,
ErrorReportRef,
GuidedChoiceV2,
GuidedStorageResponseV2,
IdentityData,
KeyboardSetting,
KeyboardSetup,
LiveSessionSSHInfo,
MirrorCheckResponse,
MirrorGet,
MirrorPost,
MirrorPostResponse,
MirrorSelectionFallback,
ModifyPartitionV2,
NetworkStatus,
OEMResponse,
PackageInstallState,
ReformatDisk,
RefreshStatus,
ShutdownMode,
SnapInfo,
SnapListResponse,
SnapSelection,
SourceSelectionAndSetting,
SSHData,
SSHFetchIdResponse,
StorageResponse,
StorageResponseV2,
TimeZoneInfo,
UbuntuProCheckTokenAnswer,
UbuntuProInfo,
UbuntuProResponse,
UPCSInitiateResponse,
UPCSWaitResponse,
UsernameValidation,
WSLConfigurationAdvanced,
WSLConfigurationBase,
WSLSetupOptions,
ZdevInfo,
)
from subiquitycore.models.network import (
BondConfig,
NetDevInfo,
@ -23,72 +84,11 @@ from subiquitycore.models.network import (
WLANConfig,
)
from subiquity.common.api.defs import (
allowed_before_start,
api,
Payload,
simple_endpoint,
)
from subiquity.common.types import (
AdConnectionInfo,
AdAdminNameValidation,
AdDomainNameValidation,
AdJoinResult,
AdPasswordValidation,
AddPartitionV2,
AnyStep,
ApplicationState,
ApplicationStatus,
CasperMd5Results,
Change,
Disk,
ErrorReportRef,
GuidedChoiceV2,
GuidedStorageResponseV2,
KeyboardSetting,
KeyboardSetup,
IdentityData,
NetworkStatus,
MirrorSelectionFallback,
MirrorGet,
MirrorPost,
MirrorPostResponse,
MirrorCheckResponse,
ModifyPartitionV2,
ReformatDisk,
RefreshStatus,
ShutdownMode,
CodecsData,
DriversResponse,
DriversPayload,
OEMResponse,
SnapInfo,
SnapListResponse,
SnapSelection,
SourceSelectionAndSetting,
SSHData,
SSHFetchIdResponse,
LiveSessionSSHInfo,
StorageResponse,
StorageResponseV2,
TimeZoneInfo,
UbuntuProInfo,
UbuntuProResponse,
UbuntuProCheckTokenAnswer,
UPCSInitiateResponse,
UPCSWaitResponse,
UsernameValidation,
PackageInstallState,
ZdevInfo,
WSLConfigurationBase,
WSLConfigurationAdvanced,
WSLSetupOptions,
)
@api
class API:
"""The API offered by the subiquity installer process."""
locale = simple_endpoint(str)
proxy = simple_endpoint(str)
updates = simple_endpoint(str)
@ -99,8 +99,7 @@ class API:
class meta:
class status:
@allowed_before_start
def GET(cur: Optional[ApplicationState] = None) \
-> ApplicationStatus:
def GET(cur: Optional[ApplicationState] = None) -> ApplicationStatus:
"""Get the installer state."""
class mark_configured:
@ -111,6 +110,7 @@ class API:
def POST(variant: str) -> None:
"""Choose the install variant -
desktop/server/wsl_setup/wsl_reconfigure"""
def GET() -> str:
"""Get the install variant -
desktop/server/wsl_setup/wsl_reconfigure"""
@ -124,10 +124,12 @@ class API:
"""Restart the server process."""
class ssh_info:
def GET() -> Optional[LiveSessionSSHInfo]: ...
def GET() -> Optional[LiveSessionSSHInfo]:
...
class free_only:
def GET() -> bool: ...
def GET() -> bool:
...
def POST(enable: bool) -> None:
"""Enable or disable free-only mode. Currently only controlls
@ -135,7 +137,8 @@ class API:
confirmation of filesystem changes"""
class interactive_sections:
def GET() -> Optional[List[str]]: ...
def GET() -> Optional[List[str]]:
...
class errors:
class wait:
@ -159,38 +162,55 @@ class API:
"""Start the update and return the change id."""
class progress:
def GET(change_id: str) -> Change: ...
def GET(change_id: str) -> Change:
...
class keyboard:
def GET() -> KeyboardSetup: ...
def POST(data: Payload[KeyboardSetting]): ...
def GET() -> KeyboardSetup:
...
def POST(data: Payload[KeyboardSetting]):
...
class needs_toggle:
def GET(layout_code: str, variant_code: str) -> bool: ...
def GET(layout_code: str, variant_code: str) -> bool:
...
class steps:
def GET(index: Optional[str]) -> AnyStep: ...
def GET(index: Optional[str]) -> AnyStep:
...
class input_source:
def POST(data: Payload[KeyboardSetting],
user: Optional[str] = None) -> None: ...
def POST(
data: Payload[KeyboardSetting], user: Optional[str] = None
) -> None:
...
class source:
def GET() -> SourceSelectionAndSetting: ...
def POST(source_id: str, search_drivers: bool = False) -> None: ...
def GET() -> SourceSelectionAndSetting:
...
def POST(source_id: str, search_drivers: bool = False) -> None:
...
class zdev:
def GET() -> List[ZdevInfo]: ...
def GET() -> List[ZdevInfo]:
...
class chzdev:
def POST(action: str, zdev: ZdevInfo) -> List[ZdevInfo]: ...
def POST(action: str, zdev: ZdevInfo) -> List[ZdevInfo]:
...
class network:
def GET() -> NetworkStatus: ...
def POST() -> None: ...
def GET() -> NetworkStatus:
...
def POST() -> None:
...
class has_network:
def GET() -> bool: ...
def GET() -> bool:
...
class global_addresses:
def GET() -> List[str]:
@ -201,8 +221,12 @@ class API:
The socket must serve the NetEventAPI below.
"""
def PUT(socket_path: str) -> None: ...
def DELETE(socket_path: str) -> None: ...
def PUT(socket_path: str) -> None:
...
def DELETE(socket_path: str) -> None:
...
# These methods could definitely be more RESTish, like maybe a
# GET request to /network/interfaces/$name should return netplan
@ -237,77 +261,100 @@ class API:
# So methods on nics get an extra dev_name: str parameter)
class set_static_config:
def POST(dev_name: str, ip_version: int,
static_config: Payload[StaticConfig]) -> None: ...
def POST(
dev_name: str, ip_version: int, static_config: Payload[StaticConfig]
) -> None:
...
class enable_dhcp:
def POST(dev_name: str, ip_version: int) -> None: ...
def POST(dev_name: str, ip_version: int) -> None:
...
class disable:
def POST(dev_name: str, ip_version: int) -> None: ...
def POST(dev_name: str, ip_version: int) -> None:
...
class vlan:
def PUT(dev_name: str, vlan_id: int) -> None: ...
def PUT(dev_name: str, vlan_id: int) -> None:
...
class add_or_edit_bond:
def POST(existing_name: Optional[str], new_name: str,
bond_config: Payload[BondConfig]) -> None: ...
def POST(
existing_name: Optional[str],
new_name: str,
bond_config: Payload[BondConfig],
) -> None:
...
class start_scan:
def POST(dev_name: str) -> None: ...
def POST(dev_name: str) -> None:
...
class set_wlan:
def POST(dev_name: str, wlan: WLANConfig) -> None: ...
def POST(dev_name: str, wlan: WLANConfig) -> None:
...
class delete:
def POST(dev_name: str) -> None: ...
def POST(dev_name: str) -> None:
...
class info:
def GET(dev_name: str) -> str: ...
def GET(dev_name: str) -> str:
...
class storage:
class guided:
def POST(data: Payload[GuidedChoiceV2]) \
-> StorageResponse:
def POST(data: Payload[GuidedChoiceV2]) -> StorageResponse:
pass
def GET(wait: bool = False, use_cached_result: bool = False) \
-> StorageResponse: ...
def GET(wait: bool = False, use_cached_result: bool = False) -> StorageResponse:
...
def POST(config: Payload[list]): ...
def POST(config: Payload[list]):
...
class dry_run_wait_probe:
"""This endpoint only works in dry-run mode."""
def POST() -> None: ...
def POST() -> None:
...
class reset:
def POST() -> StorageResponse: ...
def POST() -> StorageResponse:
...
class has_rst:
def GET() -> bool:
pass
class has_bitlocker:
def GET() -> List[Disk]: ...
def GET() -> List[Disk]:
...
class v2:
def GET(
wait: bool = False,
include_raid: bool = False,
) -> StorageResponseV2: ...
) -> StorageResponseV2:
...
def POST() -> StorageResponseV2: ...
def POST() -> StorageResponseV2:
...
class orig_config:
def GET() -> StorageResponseV2: ...
def GET() -> StorageResponseV2:
...
class guided:
def GET(wait: bool = False) -> GuidedStorageResponseV2: ...
def POST(data: Payload[GuidedChoiceV2]) \
-> GuidedStorageResponseV2: ...
def GET(wait: bool = False) -> GuidedStorageResponseV2:
...
def POST(data: Payload[GuidedChoiceV2]) -> GuidedStorageResponseV2:
...
class reset:
def POST() -> StorageResponseV2: ...
def POST() -> StorageResponseV2:
...
class ensure_transaction:
"""This call will ensure that a transaction is initiated.
@ -317,17 +364,21 @@ class API:
request that modifies the partitioning configuration (e.g.,
add_partition, edit_partition, ...) but ensure_transaction can
be called early if desired."""
def POST() -> None: ...
def POST() -> None:
...
class reformat_disk:
def POST(data: Payload[ReformatDisk]) \
-> StorageResponseV2: ...
def POST(data: Payload[ReformatDisk]) -> StorageResponseV2:
...
class add_boot_partition:
"""Mark a given disk as bootable, which may cause a partition
to be added to the disk. It is an error to call this for a
disk for which can_be_boot_device is False."""
def POST(disk_id: str) -> StorageResponseV2: ...
def POST(disk_id: str) -> StorageResponseV2:
...
class add_partition:
"""required field format and mount, optional field size
@ -339,15 +390,17 @@ class API:
add_boot_partition for more control over this.
format=None means an unformatted partition
"""
def POST(data: Payload[AddPartitionV2]) \
-> StorageResponseV2: ...
def POST(data: Payload[AddPartitionV2]) -> StorageResponseV2:
...
class delete_partition:
"""required field number
It is an error to modify other Partition fields.
"""
def POST(data: Payload[ModifyPartitionV2]) \
-> StorageResponseV2: ...
def POST(data: Payload[ModifyPartitionV2]) -> StorageResponseV2:
...
class edit_partition:
"""required field number
@ -355,124 +408,173 @@ class API:
It is an error to do wipe=null and change the format.
It is an error to modify other Partition fields.
"""
def POST(data: Payload[ModifyPartitionV2]) \
-> StorageResponseV2: ...
def POST(data: Payload[ModifyPartitionV2]) -> StorageResponseV2:
...
class codecs:
def GET() -> CodecsData: ...
def POST(data: Payload[CodecsData]) -> None: ...
def GET() -> CodecsData:
...
def POST(data: Payload[CodecsData]) -> None:
...
class drivers:
def GET(wait: bool = False) -> DriversResponse: ...
def POST(data: Payload[DriversPayload]) -> None: ...
def GET(wait: bool = False) -> DriversResponse:
...
def POST(data: Payload[DriversPayload]) -> None:
...
class oem:
def GET(wait: bool = False) -> OEMResponse: ...
def GET(wait: bool = False) -> OEMResponse:
...
class snaplist:
def GET(wait: bool = False) -> SnapListResponse: ...
def POST(data: Payload[List[SnapSelection]]): ...
def GET(wait: bool = False) -> SnapListResponse:
...
def POST(data: Payload[List[SnapSelection]]):
...
class snap_info:
def GET(snap_name: str) -> SnapInfo: ...
def GET(snap_name: str) -> SnapInfo:
...
class timezone:
def GET() -> TimeZoneInfo: ...
def POST(tz: str): ...
def GET() -> TimeZoneInfo:
...
def POST(tz: str):
...
class shutdown:
def POST(mode: ShutdownMode, immediate: bool = False): ...
def POST(mode: ShutdownMode, immediate: bool = False):
...
class mirror:
def GET() -> MirrorGet: ...
def GET() -> MirrorGet:
...
def POST(data: Payload[Optional[MirrorPost]]) \
-> MirrorPostResponse: ...
def POST(data: Payload[Optional[MirrorPost]]) -> MirrorPostResponse:
...
class disable_components:
def GET() -> List[str]: ...
def POST(data: Payload[List[str]]): ...
def GET() -> List[str]:
...
def POST(data: Payload[List[str]]):
...
class check_mirror:
class start:
def POST(cancel_ongoing: bool = False) -> None: ...
def POST(cancel_ongoing: bool = False) -> None:
...
class progress:
def GET() -> Optional[MirrorCheckResponse]: ...
def GET() -> Optional[MirrorCheckResponse]:
...
class abort:
def POST() -> None: ...
def POST() -> None:
...
fallback = simple_endpoint(MirrorSelectionFallback)
class ubuntu_pro:
def GET() -> UbuntuProResponse: ...
def POST(data: Payload[UbuntuProInfo]) -> None: ...
def GET() -> UbuntuProResponse:
...
def POST(data: Payload[UbuntuProInfo]) -> None:
...
class skip:
def POST() -> None: ...
def POST() -> None:
...
class check_token:
def GET(token: Payload[str]) \
-> UbuntuProCheckTokenAnswer: ...
def GET(token: Payload[str]) -> UbuntuProCheckTokenAnswer:
...
class contract_selection:
class initiate:
def POST() -> UPCSInitiateResponse: ...
def POST() -> UPCSInitiateResponse:
...
class wait:
def GET() -> UPCSWaitResponse: ...
def GET() -> UPCSWaitResponse:
...
class cancel:
def POST() -> None: ...
def POST() -> None:
...
class identity:
def GET() -> IdentityData: ...
def POST(data: Payload[IdentityData]): ...
def GET() -> IdentityData:
...
def POST(data: Payload[IdentityData]):
...
class validate_username:
def GET(username: str) -> UsernameValidation: ...
def GET(username: str) -> UsernameValidation:
...
class ssh:
def GET() -> SSHData: ...
def POST(data: Payload[SSHData]) -> None: ...
def GET() -> SSHData:
...
def POST(data: Payload[SSHData]) -> None:
...
class fetch_id:
def GET(user_id: str) -> SSHFetchIdResponse: ...
def GET(user_id: str) -> SSHFetchIdResponse:
...
class integrity:
def GET() -> CasperMd5Results: ...
def GET() -> CasperMd5Results:
...
class active_directory:
def GET() -> Optional[AdConnectionInfo]: ...
def GET() -> Optional[AdConnectionInfo]:
...
# POST expects input validated by the check methods below:
def POST(data: Payload[AdConnectionInfo]) -> None: ...
def POST(data: Payload[AdConnectionInfo]) -> None:
...
class has_support:
"""Whether the live system supports Active Directory or not.
Network status is not considered.
Clients should call this before showing the AD page."""
def GET() -> bool: ...
def GET() -> bool:
...
class check_domain_name:
"""Applies basic validation to the candidate domain name,
without calling the domain network."""
def POST(domain_name: Payload[str]) \
-> List[AdDomainNameValidation]: ...
def POST(domain_name: Payload[str]) -> List[AdDomainNameValidation]:
...
class ping_domain_controller:
"""Attempts to contact the controller of the provided domain."""
def POST(domain_name: Payload[str]) \
-> AdDomainNameValidation: ...
def POST(domain_name: Payload[str]) -> AdDomainNameValidation:
...
class check_admin_name:
def POST(admin_name: Payload[str]) -> AdAdminNameValidation: ...
def POST(admin_name: Payload[str]) -> AdAdminNameValidation:
...
class check_password:
def POST(password: Payload[str]) -> AdPasswordValidation: ...
def POST(password: Payload[str]) -> AdPasswordValidation:
...
class join_result:
def GET(wait: bool = True) -> AdJoinResult: ...
def GET(wait: bool = True) -> AdJoinResult:
...
class LinkAction(enum.Enum):
@ -484,19 +586,25 @@ class LinkAction(enum.Enum):
@api
class NetEventAPI:
class wlan_support_install_finished:
def POST(state: PackageInstallState) -> None: ...
def POST(state: PackageInstallState) -> None:
...
class update_link:
def POST(act: LinkAction, info: Payload[NetDevInfo]) -> None: ...
def POST(act: LinkAction, info: Payload[NetDevInfo]) -> None:
...
class route_watch:
def POST(has_default_route: bool) -> None: ...
def POST(has_default_route: bool) -> None:
...
class apply_starting:
def POST() -> None: ...
def POST() -> None:
...
class apply_stopping:
def POST() -> None: ...
def POST() -> None:
...
class apply_error:
def POST(stage: str) -> None: ...
def POST(stage: str) -> None:
...

View File

@ -27,33 +27,20 @@ from typing import Iterable, Set
import apport
import apport.crashdb
import apport.hookutils
import attr
import bson
import requests
import urwid
from subiquitycore.async_helpers import (
run_in_thread,
schedule_task,
)
from subiquity.common.types import ErrorReportKind, ErrorReportRef, ErrorReportState
from subiquitycore.async_helpers import run_in_thread, schedule_task
from subiquity.common.types import (
ErrorReportKind,
ErrorReportRef,
ErrorReportState,
)
log = logging.getLogger('subiquity.common.errorreport')
log = logging.getLogger("subiquity.common.errorreport")
@attr.s(eq=False)
class Upload(metaclass=urwid.MetaSignals):
signals = ['progress']
signals = ["progress"]
bytes_to_send = attr.ib()
bytes_sent = attr.ib(default=0)
@ -68,13 +55,13 @@ class Upload(metaclass=urwid.MetaSignals):
def _progress(self):
os.read(self.pipe_r, 4096)
urwid.emit_signal(self, 'progress')
urwid.emit_signal(self, "progress")
def _bg_update(self, sent, to_send=None):
self.bytes_sent = sent
if to_send is not None:
self.bytes_to_send = to_send
os.write(self.pipe_w, b'x')
os.write(self.pipe_w, b"x")
def stop(self):
asyncio.get_running_loop().remove_reader(self.pipe_r)
@ -84,7 +71,6 @@ class Upload(metaclass=urwid.MetaSignals):
@attr.s(eq=False)
class ErrorReport(metaclass=urwid.MetaSignals):
signals = ["changed"]
reporter = attr.ib()
@ -101,17 +87,19 @@ class ErrorReport(metaclass=urwid.MetaSignals):
@classmethod
def new(cls, reporter, kind):
base = "{:.9f}.{}".format(time.time(), kind.name.lower())
crash_file = open(
os.path.join(reporter.crash_directory, base + ".crash"),
'wb')
crash_file = open(os.path.join(reporter.crash_directory, base + ".crash"), "wb")
pr = apport.Report('Bug')
pr['CrashDB'] = repr(reporter.crashdb_spec)
pr = apport.Report("Bug")
pr["CrashDB"] = repr(reporter.crashdb_spec)
r = cls(
reporter=reporter, base=base, pr=pr, file=crash_file,
reporter=reporter,
base=base,
pr=pr,
file=crash_file,
state=ErrorReportState.INCOMPLETE,
context=reporter.context.child(base))
context=reporter.context.child(base),
)
r.set_meta("kind", kind.name)
return r
@ -119,11 +107,15 @@ class ErrorReport(metaclass=urwid.MetaSignals):
def from_file(cls, reporter, fpath):
base = os.path.splitext(os.path.basename(fpath))[0]
report = cls(
reporter, base, pr=apport.Report(date='???'),
state=ErrorReportState.LOADING, file=open(fpath, 'rb'),
context=reporter.context.child(base))
reporter,
base,
pr=apport.Report(date="???"),
state=ErrorReportState.LOADING,
file=open(fpath, "rb"),
context=reporter.context.child(base),
)
try:
fp = open(report.meta_path, 'r')
fp = open(report.meta_path, "r")
except FileNotFoundError:
pass
else:
@ -140,26 +132,27 @@ class ErrorReport(metaclass=urwid.MetaSignals):
if not self.reporter.dry_run:
self.pr.add_hooks_info(None)
apport.hookutils.attach_hardware(self.pr)
self.pr['Syslog'] = apport.hookutils.recent_syslog(re.compile('.'))
snap_name = os.environ.get('SNAP_NAME', '')
if snap_name != '':
self.pr["Syslog"] = apport.hookutils.recent_syslog(re.compile("."))
snap_name = os.environ.get("SNAP_NAME", "")
if snap_name != "":
self.add_tags([snap_name])
# Because apport-cli will in general be run on a different
# machine, we make some slightly obscure alterations to the report
# to make this go better.
# apport-cli gets upset if neither of these are present.
self.pr['Package'] = 'subiquity ' + os.environ.get(
"SNAP_REVISION", "SNAP_REVISION")
self.pr['SourcePackage'] = 'subiquity'
self.pr["Package"] = "subiquity " + os.environ.get(
"SNAP_REVISION", "SNAP_REVISION"
)
self.pr["SourcePackage"] = "subiquity"
# If ExecutableTimestamp is present, apport-cli will try to check
# that ExecutablePath hasn't changed. But it won't be there.
del self.pr['ExecutableTimestamp']
del self.pr["ExecutableTimestamp"]
# apport-cli gets upset at the probert C extensions it sees in
# here. /proc/maps is very unlikely to be interesting for us
# anyway.
del self.pr['ProcMaps']
del self.pr["ProcMaps"]
self.pr.write(self._file)
async def add_info():
@ -175,6 +168,7 @@ class ErrorReport(metaclass=urwid.MetaSignals):
self._file.close()
self._file = None
urwid.emit_signal(self, "changed")
if wait:
with self._context.child("add_info") as context:
_bg_add_info()
@ -214,9 +208,7 @@ class ErrorReport(metaclass=urwid.MetaSignals):
uploader._bg_update(uploader.bytes_sent + chunk_size)
def _bg_upload():
for_upload = {
"Kind": self.kind.value
}
for_upload = {"Kind": self.kind.value}
for k, v in self.pr.items():
if len(v) < 1024 or k in {
"InstallerLogInfo",
@ -236,8 +228,9 @@ class ErrorReport(metaclass=urwid.MetaSignals):
data = bson.BSON().encode(for_upload)
self.uploader._bg_update(0, len(data))
headers = {
'user-agent': 'subiquity/{}'.format(
os.environ.get("SNAP_VERSION", "SNAP_VERSION")),
"user-agent": "subiquity/{}".format(
os.environ.get("SNAP_VERSION", "SNAP_VERSION")
),
}
response = requests.post(url, data=chunk(data), headers=headers)
response.raise_for_status()
@ -254,28 +247,27 @@ class ErrorReport(metaclass=urwid.MetaSignals):
context.description = oops_id
uploader.stop()
self.uploader = None
urwid.emit_signal(self, 'changed')
urwid.emit_signal(self, "changed")
urwid.emit_signal(self, 'changed')
urwid.emit_signal(self, "changed")
uploader.start()
schedule_task(upload())
def _path_with_ext(self, ext):
return os.path.join(
self.reporter.crash_directory, self.base + '.' + ext)
return os.path.join(self.reporter.crash_directory, self.base + "." + ext)
@property
def meta_path(self):
return self._path_with_ext('meta')
return self._path_with_ext("meta")
@property
def path(self):
return self._path_with_ext('crash')
return self._path_with_ext("crash")
def set_meta(self, key, value):
self.meta[key] = value
with open(self.meta_path, 'w') as fp:
with open(self.meta_path, "w") as fp:
json.dump(self.meta, fp, indent=4)
def mark_seen(self):
@ -300,9 +292,8 @@ class ErrorReport(metaclass=urwid.MetaSignals):
"""Return fs-label, path-on-fs to report."""
# Not sure if this is more or less sane than shelling out to
# findmnt(1).
looking_for = os.path.abspath(
os.path.normpath(self.reporter.crash_directory))
for line in open('/proc/self/mountinfo').readlines():
looking_for = os.path.abspath(os.path.normpath(self.reporter.crash_directory))
for line in open("/proc/self/mountinfo").readlines():
parts = line.strip().split()
if os.path.normpath(parts[4]) == looking_for:
devname = parts[9]
@ -310,19 +301,19 @@ class ErrorReport(metaclass=urwid.MetaSignals):
break
else:
if self.reporter.dry_run:
path = ('install-logs/2019-11-06.0/crash/' +
self.base +
'.crash')
path = "install-logs/2019-11-06.0/crash/" + self.base + ".crash"
return "casper-rw", path
return None, None
import pyudev
c = pyudev.Context()
devs = list(c.list_devices(
subsystem='block', DEVNAME=os.path.realpath(devname)))
devs = list(
c.list_devices(subsystem="block", DEVNAME=os.path.realpath(devname))
)
if not devs:
return None, None
label = devs[0].get('ID_FS_LABEL_ENC', '')
return label, root[1:] + '/' + self.base + '.crash'
label = devs[0].get("ID_FS_LABEL_ENC", "")
return label, root[1:] + "/" + self.base + ".crash"
def ref(self):
return ErrorReportRef(
@ -349,22 +340,21 @@ class ErrorReport(metaclass=urwid.MetaSignals):
class ErrorReporter(object):
def __init__(self, context, dry_run, root, client=None):
self.context = context
self.dry_run = dry_run
self.crash_directory = os.path.join(root, 'var/crash')
self.crash_directory = os.path.join(root, "var/crash")
self.client = client
self.reports = []
self._reports_by_base = {}
self._reports_by_exception = {}
self.crashdb_spec = {
'impl': 'launchpad',
'project': 'subiquity',
"impl": "launchpad",
"project": "subiquity",
}
if dry_run:
self.crashdb_spec['launchpad_instance'] = 'staging'
self.crashdb_spec["launchpad_instance"] = "staging"
self._apport_data = []
self._apport_files = []
@ -398,7 +388,7 @@ class ErrorReporter(object):
return self._reports_by_exception.get(exc)
def make_apport_report(self, kind, thing, *, wait=False, exc=None, **kw):
if not self.dry_run and not os.path.exists('/cdrom/.disk/info'):
if not self.dry_run and not os.path.exists("/cdrom/.disk/info"):
return None
log.debug("generating crash report")
@ -415,16 +405,14 @@ class ErrorReporter(object):
if exc is None:
exc = sys.exc_info()[1]
if exc is not None:
report.pr["Title"] = "{} crashed with {}".format(
thing, type(exc).__name__)
report.pr["Title"] = "{} crashed with {}".format(thing, type(exc).__name__)
tb = traceback.TracebackException.from_exception(exc)
report.pr['Traceback'] = "".join(tb.format())
report.pr["Traceback"] = "".join(tb.format())
self._reports_by_exception[exc] = report
else:
report.pr["Title"] = thing
log.info(
"saving crash report %r to %s", report.pr["Title"], report.path)
log.info("saving crash report %r to %s", report.pr["Title"], report.path)
apport_files = self._apport_files[:]
apport_data = self._apport_data.copy()
@ -456,8 +444,7 @@ class ErrorReporter(object):
await self.client.errors.wait.GET(error_ref)
path = os.path.join(
self.crash_directory, error_ref.base + '.crash')
path = os.path.join(self.crash_directory, error_ref.base + ".crash")
report = ErrorReport.from_file(self, path)
self.reports.insert(0, report)
self._reports_by_base[error_ref.base] = report

View File

@ -29,7 +29,6 @@ from subiquity.models.filesystem import (
raidlevels_by_value,
)
_checkers = {}
@ -37,8 +36,9 @@ def make_checker(action):
@functools.singledispatch
def impl(device):
raise NotImplementedError(
"checker for %s on %s not implemented" % (
action, device))
"checker for %s on %s not implemented" % (action, device)
)
_checkers[action] = impl
return impl
@ -75,8 +75,7 @@ class DeviceAction(enum.Enum):
@functools.singledispatch
def _supported_actions(device):
raise NotImplementedError(
"_supported_actions({}) not defined".format(device))
raise NotImplementedError("_supported_actions({}) not defined".format(device))
@_supported_actions.register(Disk)
@ -117,7 +116,7 @@ def _raid_actions(raid):
DeviceAction.REFORMAT,
]
if raid._m.bootloader == Bootloader.UEFI:
if raid.container and raid.container.metadata == 'imsm':
if raid.container and raid.container.metadata == "imsm":
actions.append(DeviceAction.TOGGLE_BOOT)
return actions
@ -161,11 +160,10 @@ def _can_edit_generic(device):
if cd is None:
return True
return _(
"Cannot edit {selflabel} as it is part of the {cdtype} "
"{cdname}.").format(
selflabel=labels.label(device),
cdtype=labels.desc(cd),
cdname=labels.label(cd))
"Cannot edit {selflabel} as it is part of the {cdtype} " "{cdname}."
).format(
selflabel=labels.label(device), cdtype=labels.desc(cd), cdname=labels.label(cd)
)
@_can_edit.register(Partition)
@ -183,9 +181,9 @@ def _can_edit_raid(raid):
if raid.preserve:
return _("Cannot edit pre-existing RAIDs.")
elif len(raid._partitions) > 0:
return _(
"Cannot edit {raidlabel} because it has partitions.").format(
raidlabel=labels.label(raid))
return _("Cannot edit {raidlabel} because it has partitions.").format(
raidlabel=labels.label(raid)
)
else:
return _can_edit_generic(raid)
@ -195,10 +193,9 @@ def _can_edit_vg(vg):
if vg.preserve:
return _("Cannot edit pre-existing volume groups.")
elif len(vg._partitions) > 0:
return _(
"Cannot edit {vglabel} because it has logical "
"volumes.").format(
vglabel=labels.label(vg))
return _("Cannot edit {vglabel} because it has logical " "volumes.").format(
vglabel=labels.label(vg)
)
else:
return _can_edit_generic(vg)
@ -251,11 +248,13 @@ def _can_remove_device(device):
if cd is None:
return False
if cd.preserve:
return _("Cannot remove {selflabel} from pre-existing {cdtype} "
"{cdlabel}.").format(
return _(
"Cannot remove {selflabel} from pre-existing {cdtype} " "{cdlabel}."
).format(
selflabel=labels.label(device),
cdtype=labels.desc(cd),
cdlabel=labels.label(cd))
cdlabel=labels.label(cd),
)
if isinstance(cd, Raid):
if device in cd.spare_devices:
return True
@ -263,19 +262,23 @@ def _can_remove_device(device):
if len(cd.devices) == min_devices:
return _(
"Removing {selflabel} would leave the {cdtype} {cdlabel} with "
"less than {min_devices} devices.").format(
"less than {min_devices} devices."
).format(
selflabel=labels.label(device),
cdtype=labels.desc(cd),
cdlabel=labels.label(cd),
min_devices=min_devices)
min_devices=min_devices,
)
elif isinstance(cd, LVM_VolGroup):
if len(cd.devices) == 1:
return _(
"Removing {selflabel} would leave the {cdtype} {cdlabel} with "
"no devices.").format(
"no devices."
).format(
selflabel=labels.label(device),
cdtype=labels.desc(cd),
cdlabel=labels.label(cd))
cdlabel=labels.label(cd),
)
return True
@ -287,11 +290,10 @@ def _can_delete_generic(device):
if cd is None:
return True
return _(
"Cannot delete {selflabel} as it is part of the {cdtype} "
"{cdname}.").format(
selflabel=labels.label(device),
cdtype=labels.desc(cd),
cdname=labels.label(cd))
"Cannot delete {selflabel} as it is part of the {cdtype} " "{cdname}."
).format(
selflabel=labels.label(device), cdtype=labels.desc(cd), cdname=labels.label(cd)
)
@_can_delete.register(Partition)
@ -299,8 +301,10 @@ def _can_delete_partition(partition):
if partition._is_in_use:
return False
if partition.device._has_preexisting_partition():
return _("Cannot delete a single partition from a device that "
"already has partitions.")
return _(
"Cannot delete a single partition from a device that "
"already has partitions."
)
if boot.is_bootloader_partition(partition):
return _("Cannot delete required bootloader partition")
return _can_delete_generic(partition)
@ -317,7 +321,8 @@ def _can_delete_raid_vg(device):
cd = p.constructed_device()
return _(
"Cannot delete {devicelabel} as partition {partnum} is part "
"of the {cdtype} {cdname}.").format(
"of the {cdtype} {cdname}."
).format(
devicelabel=labels.label(device),
partnum=p.number,
cdtype=labels.desc(cd),
@ -325,10 +330,8 @@ def _can_delete_raid_vg(device):
)
if mounted_partitions > 1:
return _(
"Cannot delete {devicelabel} because it has {count} mounted "
"partitions.").format(
devicelabel=labels.label(device),
count=mounted_partitions)
"Cannot delete {devicelabel} because it has {count} mounted " "partitions."
).format(devicelabel=labels.label(device), count=mounted_partitions)
elif mounted_partitions == 1:
return _(
"Cannot delete {devicelabel} because it has 1 mounted partition."
@ -340,8 +343,10 @@ def _can_delete_raid_vg(device):
@_can_delete.register(LVM_LogicalVolume)
def _can_delete_lv(lv):
if lv.volgroup._has_preexisting_partition():
return _("Cannot delete a single logical volume from a volume "
"group that already has logical volumes.")
return _(
"Cannot delete a single logical volume from a volume "
"group that already has logical volumes."
)
return True

View File

@ -21,16 +21,9 @@ from typing import Any, Optional
import attr
from subiquity.common.filesystem import gaps, sizes
from subiquity.models.filesystem import (
align_up,
Disk,
Raid,
Bootloader,
Partition,
)
from subiquity.models.filesystem import Bootloader, Disk, Partition, Raid, align_up
log = logging.getLogger('subiquity.common.filesystem.boot')
log = logging.getLogger("subiquity.common.filesystem.boot")
@functools.singledispatch
@ -55,7 +48,7 @@ def _is_boot_device_raid(raid):
bl = raid._m.bootloader
if bl != Bootloader.UEFI:
return False
if not raid.container or raid.container.metadata != 'imsm':
if not raid.container or raid.container.metadata != "imsm":
return False
return any(p.grub_device for p in raid._partitions)
@ -85,8 +78,7 @@ class CreatePartPlan(MakeBootDevicePlan):
args: dict = attr.ib(factory=dict)
def apply(self, manipulator):
manipulator.create_partition(
self.gap.device, self.gap, self.spec, **self.args)
manipulator.create_partition(self.gap.device, self.gap, self.spec, **self.args)
def _can_resize_part(inst, field, part):
@ -166,34 +158,39 @@ class MultiStepPlan(MakeBootDevicePlan):
def get_boot_device_plan_bios(device) -> Optional[MakeBootDevicePlan]:
attr_plan = SetAttrPlan(device, 'grub_device', True)
if device.ptable == 'msdos':
attr_plan = SetAttrPlan(device, "grub_device", True)
if device.ptable == "msdos":
return attr_plan
pgs = gaps.parts_and_gaps(device)
if len(pgs) > 0:
if isinstance(pgs[0], Partition) and pgs[0].flag == "bios_grub":
return attr_plan
gap = gaps.Gap(device=device,
gap = gaps.Gap(
device=device,
offset=device.alignment_data().min_start_offset,
size=sizes.BIOS_GRUB_SIZE_BYTES)
size=sizes.BIOS_GRUB_SIZE_BYTES,
)
create_part_plan = CreatePartPlan(
gap=gap,
spec=dict(size=sizes.BIOS_GRUB_SIZE_BYTES, fstype=None, mount=None),
args=dict(flag='bios_grub'))
args=dict(flag="bios_grub"),
)
movable = []
for pg in pgs:
if isinstance(pg, gaps.Gap):
if pg.size >= sizes.BIOS_GRUB_SIZE_BYTES:
return MultiStepPlan(plans=[
return MultiStepPlan(
plans=[
SlidePlan(
parts=movable,
offset_delta=sizes.BIOS_GRUB_SIZE_BYTES),
parts=movable, offset_delta=sizes.BIOS_GRUB_SIZE_BYTES
),
create_part_plan,
attr_plan,
])
]
)
else:
return None
elif pg.preserve:
@ -204,23 +201,21 @@ def get_boot_device_plan_bios(device) -> Optional[MakeBootDevicePlan]:
if not movable:
return None
largest_i, largest_part = max(
enumerate(movable),
key=lambda i_p: i_p[1].size)
return MultiStepPlan(plans=[
ResizePlan(
part=largest_part,
size_delta=-sizes.BIOS_GRUB_SIZE_BYTES),
largest_i, largest_part = max(enumerate(movable), key=lambda i_p: i_p[1].size)
return MultiStepPlan(
plans=[
ResizePlan(part=largest_part, size_delta=-sizes.BIOS_GRUB_SIZE_BYTES),
SlidePlan(
parts=movable[:largest_i+1],
offset_delta=sizes.BIOS_GRUB_SIZE_BYTES),
parts=movable[: largest_i + 1], offset_delta=sizes.BIOS_GRUB_SIZE_BYTES
),
create_part_plan,
attr_plan,
])
]
)
def get_add_part_plan(device, *, spec, args, resize_partition=None):
size = spec['size']
size = spec["size"]
partitions = device.partitions()
create_part_plan = CreatePartPlan(gap=None, spec=spec, args=args)
@ -238,70 +233,79 @@ def get_add_part_plan(device, *, spec, args, resize_partition=None):
return None
offset = resize_partition.offset + resize_partition.size - size
create_part_plan.gap = gaps.Gap(
device=device, offset=offset, size=size)
return MultiStepPlan(plans=[
create_part_plan.gap = gaps.Gap(device=device, offset=offset, size=size)
return MultiStepPlan(
plans=[
ResizePlan(
part=resize_partition,
size_delta=-size,
allow_resize_preserved=True,
),
create_part_plan,
])
]
)
else:
new_primaries = [p for p in partitions
new_primaries = [
p
for p in partitions
if not p.preserve
if p.flag not in ('extended', 'logical')]
if p.flag not in ("extended", "logical")
]
if not new_primaries:
return None
largest_part = max(new_primaries, key=lambda p: p.size)
if size > largest_part.size // 2:
return None
create_part_plan.gap = gaps.Gap(
device=device, offset=largest_part.offset, size=size)
return MultiStepPlan(plans=[
ResizePlan(
part=largest_part,
size_delta=-size),
SlidePlan(
parts=[largest_part],
offset_delta=size),
device=device, offset=largest_part.offset, size=size
)
return MultiStepPlan(
plans=[
ResizePlan(part=largest_part, size_delta=-size),
SlidePlan(parts=[largest_part], offset_delta=size),
create_part_plan,
])
]
)
def get_boot_device_plan_uefi(device, resize_partition):
for part in device.partitions():
if is_esp(part):
plans = [SetAttrPlan(part, 'grub_device', True)]
if device._m._mount_for_path('/boot/efi') is None:
plans = [SetAttrPlan(part, "grub_device", True)]
if device._m._mount_for_path("/boot/efi") is None:
plans.append(MountBootEfiPlan(part))
return MultiStepPlan(plans=plans)
part_align = device.alignment_data().part_align
size = align_up(sizes.get_efi_size(device.size), part_align)
spec = dict(size=size, fstype='fat32', mount=None)
spec = dict(size=size, fstype="fat32", mount=None)
if device._m._mount_for_path("/boot/efi") is None:
spec['mount'] = '/boot/efi'
spec["mount"] = "/boot/efi"
return get_add_part_plan(
device, spec=spec, args=dict(flag='boot', grub_device=True),
resize_partition=resize_partition)
device,
spec=spec,
args=dict(flag="boot", grub_device=True),
resize_partition=resize_partition,
)
def get_boot_device_plan_prep(device, resize_partition):
for part in device.partitions():
if part.flag == "prep":
return MultiStepPlan(plans=[
SetAttrPlan(part, 'grub_device', True),
SetAttrPlan(part, 'wipe', 'zero')
])
return MultiStepPlan(
plans=[
SetAttrPlan(part, "grub_device", True),
SetAttrPlan(part, "wipe", "zero"),
]
)
return get_add_part_plan(
device,
spec=dict(size=sizes.PREP_GRUB_SIZE_BYTES, fstype=None, mount=None),
args=dict(flag='prep', grub_device=True, wipe='zero'),
resize_partition=resize_partition)
args=dict(flag="prep", grub_device=True, wipe="zero"),
resize_partition=resize_partition,
)
def get_boot_device_plan(device, resize_partition=None):
@ -317,12 +321,11 @@ def get_boot_device_plan(device, resize_partition=None):
return get_boot_device_plan_prep(device, resize_partition)
if bl == Bootloader.NONE:
return NoOpBootPlan()
raise Exception(f'unexpected bootloader {bl} here')
raise Exception(f"unexpected bootloader {bl} here")
@functools.singledispatch
def can_be_boot_device(device, *,
resize_partition=None, with_reformatting=False):
def can_be_boot_device(device, *, resize_partition=None, with_reformatting=False):
"""Can `device` be made into a boot device?
If with_reformatting=True, return true if the device can be made
@ -332,8 +335,7 @@ def can_be_boot_device(device, *,
@can_be_boot_device.register(Disk)
def _can_be_boot_device_disk(disk, *,
resize_partition=None, with_reformatting=False):
def _can_be_boot_device_disk(disk, *, resize_partition=None, with_reformatting=False):
if with_reformatting:
disk = disk._reformatted()
plan = get_boot_device_plan(disk, resize_partition=resize_partition)
@ -341,12 +343,11 @@ def _can_be_boot_device_disk(disk, *,
@can_be_boot_device.register(Raid)
def _can_be_boot_device_raid(raid, *,
resize_partition=None, with_reformatting=False):
def _can_be_boot_device_raid(raid, *, resize_partition=None, with_reformatting=False):
bl = raid._m.bootloader
if bl != Bootloader.UEFI:
return False
if not raid.container or raid.container.metadata != 'imsm':
if not raid.container or raid.container.metadata != "imsm":
return False
if with_reformatting:
return True
@ -376,7 +377,7 @@ def _is_esp_partition(partition):
if typecode is None:
return False
try:
return int(typecode, 0) == 0xef
return int(typecode, 0) == 0xEF
except ValueError:
# In case there was garbage in the udev entry...
return False

View File

@ -14,20 +14,20 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import functools
from typing import Tuple, List
from typing import List, Tuple
import attr
from subiquity.common.types import GapUsable
from subiquity.models.filesystem import (
align_up,
align_down,
Disk,
LVM_CHUNK_SIZE,
Disk,
LVM_LogicalVolume,
LVM_VolGroup,
Partition,
Raid,
align_down,
align_up,
)
@ -40,11 +40,11 @@ class Gap:
in_extended: bool = False
usable: str = GapUsable.YES
type: str = 'gap'
type: str = "gap"
@property
def id(self):
return 'gap-' + self.device.id
return "gap-" + self.device.id
@property
def is_usable(self):
@ -55,21 +55,25 @@ class Gap:
the supplied size. If size is equal to the gap size, the second gap is
None. The original gap is unmodified."""
if size > self.size:
raise ValueError('requested size larger than gap')
raise ValueError("requested size larger than gap")
if size == self.size:
return (self, None)
first_gap = Gap(device=self.device,
first_gap = Gap(
device=self.device,
offset=self.offset,
size=size,
in_extended=self.in_extended,
usable=self.usable)
usable=self.usable,
)
if self.in_extended:
size += self.device.alignment_data().ebr_space
rest_gap = Gap(device=self.device,
rest_gap = Gap(
device=self.device,
offset=self.offset + size,
size=self.size - size,
in_extended=self.in_extended,
usable=self.usable)
usable=self.usable,
)
return (first_gap, rest_gap)
def within(self):
@ -134,11 +138,15 @@ def find_disk_gaps_v2(device, info=None):
else:
usable = GapUsable.TOO_MANY_PRIMARY_PARTS
if end - start >= info.min_gap_size:
result.append(Gap(device=device,
result.append(
Gap(
device=device,
offset=start,
size=end - start,
in_extended=in_extended,
usable=usable))
usable=usable,
)
)
prev_end = info.min_start_offset
@ -158,8 +166,7 @@ def find_disk_gaps_v2(device, info=None):
gap_start = au(prev_end)
if extended_end is not None:
gap_start = min(
extended_end, au(gap_start + info.ebr_space))
gap_start = min(extended_end, au(gap_start + info.ebr_space))
if extended_end is not None and gap_end >= extended_end:
maybe_add_gap(gap_start, ad(extended_end), True)
@ -259,8 +266,9 @@ def movable_trailing_partitions_and_gap_size(partition):
@movable_trailing_partitions_and_gap_size.register
def _movable_trailing_partitions_and_gap_size_partition(partition: Partition) \
-> Tuple[List[Partition], int]:
def _movable_trailing_partitions_and_gap_size_partition(
partition: Partition,
) -> Tuple[List[Partition], int]:
pgs = parts_and_gaps(partition.device)
part_idx = pgs.index(partition)
trailing_partitions = []
@ -281,8 +289,9 @@ def _movable_trailing_partitions_and_gap_size_partition(partition: Partition) \
@movable_trailing_partitions_and_gap_size.register
def _movable_trailing_partitions_and_gap_size_lvm(volume: LVM_LogicalVolume) \
-> Tuple[List[LVM_LogicalVolume], int]:
def _movable_trailing_partitions_and_gap_size_lvm(
volume: LVM_LogicalVolume,
) -> Tuple[List[LVM_LogicalVolume], int]:
# In a Volume Group, there is no need to move partitions around, one can
# always use the remaining space.

View File

@ -18,18 +18,18 @@ import functools
from subiquity.common import types
from subiquity.common.filesystem import boot, gaps
from subiquity.models.filesystem import (
ZFS,
Disk,
LVM_LogicalVolume,
LVM_VolGroup,
Partition,
Raid,
ZFS,
ZPool,
)
def _annotations_generic(device):
preserve = getattr(device, 'preserve', None)
preserve = getattr(device, "preserve", None)
if preserve is None:
return []
elif preserve:
@ -128,16 +128,15 @@ def _desc_partition(partition):
@desc.register(Raid)
def _desc_raid(raid):
level = raid.raidlevel
if level.lower().startswith('raid'):
if level.lower().startswith("raid"):
level = level[4:]
if raid.container:
raid_type = raid.container.metadata
elif raid.metadata == "imsm":
raid_type = 'imsm' # maybe VROC?
raid_type = "imsm" # maybe VROC?
else:
raid_type = _("software")
return _("{type} RAID {level}").format(
type=raid_type, level=level)
return _("{type} RAID {level}").format(type=raid_type, level=level)
@desc.register(LVM_VolGroup)
@ -199,7 +198,8 @@ def _label_partition(partition, *, short=False):
return p + _("partition {number}").format(number=partition.number)
else:
return p + _("partition {number} of {device}").format(
number=partition.number, device=label(partition.device))
number=partition.number, device=label(partition.device)
)
@label.register(gaps.Gap)
@ -215,9 +215,8 @@ def _usage_labels_generic(device, *, exclude_final_unused=False):
if cd is not None:
return [
_("{component_name} of {desc} {name}").format(
component_name=cd.component_name,
desc=desc(cd),
name=cd.name),
component_name=cd.component_name, desc=desc(cd), name=cd.name
),
]
fs = device.fs()
if fs is not None:
@ -279,10 +278,11 @@ def _usage_labels_disk(disk):
@usage_labels.register(Raid)
def _usage_labels_raid(raid):
if raid.metadata == 'imsm' and raid._subvolumes:
if raid.metadata == "imsm" and raid._subvolumes:
return [
_('container for {devices}').format(
devices=', '.join([label(v) for v in raid._subvolumes]))
_("container for {devices}").format(
devices=", ".join([label(v) for v in raid._subvolumes])
)
]
return _usage_labels_generic(raid)
@ -309,7 +309,7 @@ def _for_client_disk(disk, *, min_size=0):
return types.Disk(
id=disk.id,
label=label(disk),
path=getattr(disk, 'path', None),
path=getattr(disk, "path", None),
type=desc(disk),
size=disk.size,
ptable=disk.ptable,
@ -319,9 +319,10 @@ def _for_client_disk(disk, *, min_size=0):
boot_device=boot.is_boot_device(disk),
can_be_boot_device=boot.can_be_boot_device(disk),
ok_for_guided=disk.size >= min_size,
model=getattr(disk, 'model', None),
vendor=getattr(disk, 'vendor', None),
has_in_use_partition=disk._has_in_use_partition)
model=getattr(disk, "model", None),
vendor=getattr(disk, "vendor", None),
has_in_use_partition=disk._has_in_use_partition,
)
@for_client.register(Partition)
@ -341,7 +342,8 @@ def _for_client_partition(partition, *, min_size=0):
estimated_min_size=partition.estimated_min_size,
mount=partition.mount,
format=partition.format,
is_in_use=partition._is_in_use)
is_in_use=partition._is_in_use,
)
@for_client.register(gaps.Gap)

View File

@ -19,20 +19,16 @@ from curtin.block import get_resize_fstypes
from subiquity.common.filesystem import boot, gaps
from subiquity.common.types import Bootloader
from subiquity.models.filesystem import (
align_up,
Partition,
)
from subiquity.models.filesystem import Partition, align_up
log = logging.getLogger('subiquity.common.filesystem.manipulator')
log = logging.getLogger("subiquity.common.filesystem.manipulator")
class FilesystemManipulator:
def create_mount(self, fs, spec):
if spec.get('mount') is None:
if spec.get("mount") is None:
return
mount = self.model.add_mount(fs, spec['mount'])
mount = self.model.add_mount(fs, spec["mount"])
if self.model.needs_bootloader_partition():
vol = fs.volume
if vol.type == "partition" and boot.can_be_boot_device(vol.device):
@ -45,20 +41,20 @@ class FilesystemManipulator:
self.model.remove_mount(mount)
def create_filesystem(self, volume, spec):
if spec.get('fstype') is None:
if spec.get("fstype") is None:
# prep partitions are always wiped (and never have a filesystem)
if getattr(volume, 'flag', None) != 'prep':
volume.wipe = spec.get('wipe', volume.wipe)
if getattr(volume, "flag", None) != "prep":
volume.wipe = spec.get("wipe", volume.wipe)
fstype = volume.original_fstype()
if fstype is None:
return None
else:
fstype = spec['fstype']
fstype = spec["fstype"]
# partition editing routines are expected to provide the
# appropriate wipe value. partition additions may not, give them a
# basic wipe value consistent with older behavior until we prove
# that everyone does so correctly.
volume.wipe = spec.get('wipe', 'superblock')
volume.wipe = spec.get("wipe", "superblock")
preserve = volume.wipe is None
fs = self.model.add_filesystem(volume, fstype, preserve)
if isinstance(volume, Partition):
@ -66,9 +62,9 @@ class FilesystemManipulator:
volume.flag = "swap"
elif volume.flag == "swap":
volume.flag = ""
if spec.get('fstype') == "swap":
if spec.get("fstype") == "swap":
self.model.add_mount(fs, "")
elif spec.get('fstype') is None and spec.get('use_swap'):
elif spec.get("fstype") is None and spec.get("use_swap"):
self.model.add_mount(fs, "")
else:
self.create_mount(fs, spec)
@ -79,35 +75,39 @@ class FilesystemManipulator:
return
self.delete_mount(fs.mount())
self.model.remove_filesystem(fs)
delete_format = delete_filesystem
def create_partition(self, device, gap, spec, **kw):
flag = kw.pop('flag', None)
flag = kw.pop("flag", None)
if gap.in_extended:
if flag not in (None, 'logical'):
log.debug(f'overriding flag {flag} '
'due to being in an extended partition')
flag = 'logical'
if flag not in (None, "logical"):
log.debug(
f"overriding flag {flag} " "due to being in an extended partition"
)
flag = "logical"
part = self.model.add_partition(
device, size=gap.size, offset=gap.offset, flag=flag, **kw)
device, size=gap.size, offset=gap.offset, flag=flag, **kw
)
self.create_filesystem(part, spec)
return part
def delete_partition(self, part, override_preserve=False):
if not override_preserve and part.device.preserve and \
self.model.storage_version < 2:
if (
not override_preserve
and part.device.preserve
and self.model.storage_version < 2
):
raise Exception("cannot delete partitions from preserved disks")
self.clear(part)
self.model.remove_partition(part)
def create_raid(self, spec):
for d in spec['devices'] | spec['spare_devices']:
for d in spec["devices"] | spec["spare_devices"]:
self.clear(d)
raid = self.model.add_raid(
spec['name'],
spec['level'].value,
spec['devices'],
spec['spare_devices'])
spec["name"], spec["level"].value, spec["devices"], spec["spare_devices"]
)
return raid
def delete_raid(self, raid):
@ -119,69 +119,74 @@ class FilesystemManipulator:
for p in list(raid.partitions()):
self.delete_partition(p, True)
for d in set(raid.devices) | set(raid.spare_devices):
d.wipe = 'superblock'
d.wipe = "superblock"
self.model.remove_raid(raid)
def create_volgroup(self, spec):
devices = set()
key = spec.get('passphrase')
for device in spec['devices']:
key = spec.get("passphrase")
for device in spec["devices"]:
self.clear(device)
if key:
device = self.model.add_dm_crypt(device, key)
devices.add(device)
return self.model.add_volgroup(name=spec['name'], devices=devices)
return self.model.add_volgroup(name=spec["name"], devices=devices)
create_lvm_volgroup = create_volgroup
def delete_volgroup(self, vg):
for lv in list(vg.partitions()):
self.delete_logical_volume(lv)
for d in vg.devices:
d.wipe = 'superblock'
d.wipe = "superblock"
if d.type == "dm_crypt":
self.model.remove_dm_crypt(d)
self.model.remove_volgroup(vg)
delete_lvm_volgroup = delete_volgroup
def create_logical_volume(self, vg, spec):
lv = self.model.add_logical_volume(
vg=vg,
name=spec['name'],
size=spec['size'])
lv = self.model.add_logical_volume(vg=vg, name=spec["name"], size=spec["size"])
self.create_filesystem(lv, spec)
return lv
create_lvm_partition = create_logical_volume
def delete_logical_volume(self, lv):
self.clear(lv)
self.model.remove_logical_volume(lv)
delete_lvm_partition = delete_logical_volume
def create_zpool(self, device, pool, mountpoint):
fs_properties = dict(
acltype='posixacl',
relatime='on',
canmount='on',
compression='gzip',
devices='off',
xattr='sa',
acltype="posixacl",
relatime="on",
canmount="on",
compression="gzip",
devices="off",
xattr="sa",
)
pool_properties = dict(ashift=12)
self.model.add_zpool(
device, pool, mountpoint,
fs_properties=fs_properties, pool_properties=pool_properties)
device,
pool,
mountpoint,
fs_properties=fs_properties,
pool_properties=pool_properties,
)
def delete(self, obj):
if obj is None:
return
getattr(self, 'delete_' + obj.type)(obj)
getattr(self, "delete_" + obj.type)(obj)
def clear(self, obj, wipe=None):
if obj.type == "disk":
obj.preserve = False
if wipe is None:
wipe = 'superblock'
wipe = "superblock"
obj.wipe = wipe
for subobj in obj.fs(), obj.constructed_device():
self.delete(subobj)
@ -202,14 +207,14 @@ class FilesystemManipulator:
return True
def partition_disk_handler(self, disk, spec, *, partition=None, gap=None):
log.debug('partition_disk_handler: %s %s %s %s',
disk, spec, partition, gap)
log.debug("partition_disk_handler: %s %s %s %s", disk, spec, partition, gap)
if partition is not None:
if 'size' in spec and spec['size'] != partition.size:
trailing, gap_size = \
gaps.movable_trailing_partitions_and_gap_size(partition)
new_size = align_up(spec['size'])
if "size" in spec and spec["size"] != partition.size:
trailing, gap_size = gaps.movable_trailing_partitions_and_gap_size(
partition
)
new_size = align_up(spec["size"])
size_change = new_size - partition.size
if size_change > gap_size:
raise Exception("partition size too large")
@ -226,12 +231,12 @@ class FilesystemManipulator:
if len(disk.partitions()) == 0:
if disk.type == "disk":
disk.preserve = False
disk.wipe = 'superblock-recursive'
disk.wipe = "superblock-recursive"
elif disk.type == "raid":
disk.wipe = 'superblock-recursive'
disk.wipe = "superblock-recursive"
needs_boot = self.model.needs_bootloader_partition()
log.debug('model needs a bootloader partition? {}'.format(needs_boot))
log.debug("model needs a bootloader partition? {}".format(needs_boot))
can_be_boot = boot.can_be_boot_device(disk)
if needs_boot and len(disk.partitions()) == 0 and can_be_boot:
self.add_boot_disk(disk)
@ -241,13 +246,13 @@ class FilesystemManipulator:
# 1) with len(partitions()) == 0 there could only have been 1 gap
# 2) having just done add_boot_disk(), the gap is no longer valid.
gap = gaps.largest_gap(disk)
if spec['size'] > gap.size:
if spec["size"] > gap.size:
log.debug(
"Adjusting request down from %s to %s",
spec['size'], gap.size)
spec['size'] = gap.size
"Adjusting request down from %s to %s", spec["size"], gap.size
)
spec["size"] = gap.size
gap = gap.split(spec['size'])[0]
gap = gap.split(spec["size"])[0]
self.create_partition(disk, gap, spec)
log.debug("Successfully added partition")
@ -256,13 +261,13 @@ class FilesystemManipulator:
# keep the partition name for compat with PartitionStretchy.handler
lv = partition
log.debug('logical_volume_handler: %s %s %s', vg, lv, spec)
log.debug("logical_volume_handler: %s %s %s", vg, lv, spec)
if lv is not None:
if 'name' in spec:
lv.name = spec['name']
if 'size' in spec:
lv.size = align_up(spec['size'])
if "name" in spec:
lv.name = spec["name"]
if "size" in spec:
lv.size = align_up(spec["size"])
if gaps.largest_gap_size(vg) < 0:
raise Exception("lv size too large")
self.delete_filesystem(lv.fs())
@ -272,7 +277,7 @@ class FilesystemManipulator:
self.create_logical_volume(vg, spec)
def add_format_handler(self, volume, spec):
log.debug('add_format_handler %s %s', volume, spec)
log.debug("add_format_handler %s %s", volume, spec)
self.clear(volume)
self.create_filesystem(volume, spec)
@ -281,47 +286,46 @@ class FilesystemManipulator:
if existing is not None:
for d in existing.devices | existing.spare_devices:
d._constructed_device = None
for d in spec['devices'] | spec['spare_devices']:
for d in spec["devices"] | spec["spare_devices"]:
self.clear(d)
d._constructed_device = existing
existing.name = spec['name']
existing.raidlevel = spec['level'].value
existing.devices = spec['devices']
existing.spare_devices = spec['spare_devices']
existing.name = spec["name"]
existing.raidlevel = spec["level"].value
existing.devices = spec["devices"]
existing.spare_devices = spec["spare_devices"]
else:
self.create_raid(spec)
def volgroup_handler(self, existing, spec):
if existing is not None:
key = spec.get('passphrase')
key = spec.get("passphrase")
for d in existing.devices:
if d.type == "dm_crypt":
self.model.remove_dm_crypt(d)
d = d.volume
d._constructed_device = None
devices = set()
for d in spec['devices']:
for d in spec["devices"]:
self.clear(d)
if key:
d = self.model.add_dm_crypt(d, key)
d._constructed_device = existing
devices.add(d)
existing.name = spec['name']
existing.name = spec["name"]
existing.devices = devices
else:
self.create_volgroup(spec)
def _mount_esp(self, part):
if part.fs() is None:
self.model.add_filesystem(part, 'fat32')
self.model.add_mount(part.fs(), '/boot/efi')
self.model.add_filesystem(part, "fat32")
self.model.add_mount(part.fs(), "/boot/efi")
def remove_boot_disk(self, boot_disk):
if self.model.bootloader == Bootloader.BIOS:
boot_disk.grub_device = False
partitions = [
p for p in boot_disk.partitions()
if boot.is_bootloader_partition(p)
p for p in boot_disk.partitions() if boot.is_bootloader_partition(p)
]
remount = False
if boot_disk.preserve:
@ -339,7 +343,8 @@ class FilesystemManipulator:
if not p.fs().preserve and p.original_fstype():
self.delete_filesystem(p.fs())
self.model.add_filesystem(
p, p.original_fstype(), preserve=True)
p, p.original_fstype(), preserve=True
)
else:
full = gaps.largest_gap_size(boot_disk) == 0
tot_size = 0
@ -349,11 +354,10 @@ class FilesystemManipulator:
remount = True
self.delete_partition(p)
if full:
largest_part = max(
boot_disk.partitions(), key=lambda p: p.size)
largest_part = max(boot_disk.partitions(), key=lambda p: p.size)
largest_part.size += tot_size
if self.model.bootloader == Bootloader.UEFI and remount:
part = self.model._one(type='partition', grub_device=True)
part = self.model._one(type="partition", grub_device=True)
if part:
self._mount_esp(part)
@ -363,7 +367,7 @@ class FilesystemManipulator:
self.remove_boot_disk(disk)
plan = boot.get_boot_device_plan(new_boot_disk)
if plan is None:
raise ValueError(f'No known plan to make {new_boot_disk} bootable')
raise ValueError(f"No known plan to make {new_boot_disk} bootable")
plan.apply(self)
if not new_boot_disk._has_preexisting_partition():
if new_boot_disk.type == "disk":

View File

@ -16,17 +16,10 @@
import math
import attr
from curtin import swap
from subiquity.models.filesystem import (
align_up,
align_down,
MiB,
GiB,
)
from subiquity.common.types import GuidedResizeValues
from subiquity.models.filesystem import GiB, MiB, align_down, align_up
BIOS_GRUB_SIZE_BYTES = 1 * MiB
PREP_GRUB_SIZE_BYTES = 8 * MiB
@ -39,18 +32,11 @@ class PartitionScaleFactors:
maximum: int
uefi_scale = PartitionScaleFactors(
minimum=538 * MiB,
priority=538,
maximum=1075 * MiB)
uefi_scale = PartitionScaleFactors(minimum=538 * MiB, priority=538, maximum=1075 * MiB)
bootfs_scale = PartitionScaleFactors(
minimum=1792 * MiB,
priority=1024,
maximum=2048 * MiB)
rootfs_scale = PartitionScaleFactors(
minimum=900 * MiB,
priority=10000,
maximum=-1)
minimum=1792 * MiB, priority=1024, maximum=2048 * MiB
)
rootfs_scale = PartitionScaleFactors(minimum=900 * MiB, priority=10000, maximum=-1)
def scale_partitions(all_factors, available_space):
@ -96,14 +82,15 @@ def get_bootfs_size(available_space):
# decide to keep with the existing partition, or allocate to the new install
# 4) Assume that the installs will grow proportionally to their minimum sizes,
# and split according to the ratio of the minimum sizes
def calculate_guided_resize(part_min: int, part_size: int, install_min: int,
part_align: int = MiB) -> GuidedResizeValues:
def calculate_guided_resize(
part_min: int, part_size: int, install_min: int, part_align: int = MiB
) -> GuidedResizeValues:
if part_min < 0:
return None
part_size = align_up(part_size, part_align)
other_room_to_grow = max(2 * GiB, math.ceil(.25 * part_min))
other_room_to_grow = max(2 * GiB, math.ceil(0.25 * part_min))
padded_other_min = part_min + other_room_to_grow
other_min = min(align_up(padded_other_min, part_align), part_size)
@ -118,7 +105,10 @@ def calculate_guided_resize(part_min: int, part_size: int, install_min: int,
recommended = align_up(raw_recommended, part_align)
return GuidedResizeValues(
install_max=plausible_free_space,
minimum=other_min, recommended=recommended, maximum=other_max)
minimum=other_min,
recommended=recommended,
maximum=other_max,
)
# Factors for suggested minimum install size:
@ -142,10 +132,9 @@ def calculate_guided_resize(part_min: int, part_size: int, install_min: int,
# 4) Swap - Ubuntu policy recommends swap, either in the form of a partition or
# a swapfile. Curtin has an existing swap size recommendation at
# curtin.swap.suggested_swapsize().
def calculate_suggested_install_min(source_min: int,
part_align: int = MiB) -> int:
def calculate_suggested_install_min(source_min: int, part_align: int = MiB) -> int:
room_for_boot = bootfs_scale.maximum
room_to_grow = max(2 * GiB, math.ceil(.8 * source_min))
room_to_grow = max(2 * GiB, math.ceil(0.8 * source_min))
room_for_swap = swap.suggested_swapsize()
total = source_min + room_for_boot + room_to_grow + room_for_swap
return align_up(total, part_align)

View File

@ -15,10 +15,8 @@
import unittest
from subiquity.common.filesystem.actions import (
DeviceAction,
)
from subiquity.common.filesystem import gaps
from subiquity.common.filesystem.actions import DeviceAction
from subiquity.models.filesystem import Bootloader
from subiquity.models.tests.test_filesystem import (
make_disk,
@ -36,7 +34,6 @@ from subiquity.models.tests.test_filesystem import (
class TestActions(unittest.TestCase):
def assertActionNotSupported(self, obj, action):
self.assertNotIn(action, DeviceAction.supported(obj))
@ -51,7 +48,7 @@ class TestActions(unittest.TestCase):
def _test_remove_action(self, model, objects):
self.assertActionNotPossible(objects[0], DeviceAction.REMOVE)
vg = model.add_volgroup('vg1', {objects[0], objects[1]})
vg = model.add_volgroup("vg1", {objects[0], objects[1]})
self.assertActionPossible(objects[0], DeviceAction.REMOVE)
# Cannot remove a device from a preexisting VG
@ -63,7 +60,7 @@ class TestActions(unittest.TestCase):
vg.devices.remove(objects[1])
objects[1]._constructed_device = None
self.assertActionNotPossible(objects[0], DeviceAction.REMOVE)
raid = model.add_raid('md0', 'raid1', set(objects[2:]), set())
raid = model.add_raid("md0", "raid1", set(objects[2:]), set())
self.assertActionPossible(objects[2], DeviceAction.REMOVE)
# Cannot remove a device from a preexisting RAID
@ -91,22 +88,22 @@ class TestActions(unittest.TestCase):
self.assertActionNotPossible(disk1, DeviceAction.REFORMAT)
disk1p1 = make_partition(model, disk1, preserve=False)
self.assertActionPossible(disk1, DeviceAction.REFORMAT)
model.add_volgroup('vg0', {disk1p1})
model.add_volgroup("vg0", {disk1p1})
self.assertActionNotPossible(disk1, DeviceAction.REFORMAT)
disk2 = make_disk(model, preserve=True)
self.assertActionNotPossible(disk2, DeviceAction.REFORMAT)
disk2p1 = make_partition(model, disk2, preserve=True)
self.assertActionPossible(disk2, DeviceAction.REFORMAT)
model.add_volgroup('vg1', {disk2p1})
model.add_volgroup("vg1", {disk2p1})
self.assertActionNotPossible(disk2, DeviceAction.REFORMAT)
disk3 = make_disk(model, preserve=False)
model.add_volgroup('vg2', {disk3})
model.add_volgroup("vg2", {disk3})
self.assertActionNotPossible(disk3, DeviceAction.REFORMAT)
disk4 = make_disk(model, preserve=True)
model.add_volgroup('vg2', {disk4})
model.add_volgroup("vg2", {disk4})
self.assertActionNotPossible(disk4, DeviceAction.REFORMAT)
def test_disk_action_PARTITION(self):
@ -119,7 +116,7 @@ class TestActions(unittest.TestCase):
make_partition(model, disk)
self.assertActionNotPossible(disk, DeviceAction.FORMAT)
disk2 = make_disk(model)
model.add_volgroup('vg1', {disk2})
model.add_volgroup("vg1", {disk2})
self.assertActionNotPossible(disk2, DeviceAction.FORMAT)
def test_disk_action_REMOVE(self):
@ -138,18 +135,18 @@ class TestActions(unittest.TestCase):
def test_disk_action_TOGGLE_BOOT_BIOS(self):
model = make_model(Bootloader.BIOS)
# Disks with msdos partition tables can always be the BIOS boot disk.
dos_disk = make_disk(model, ptable='msdos', preserve=True)
dos_disk = make_disk(model, ptable="msdos", preserve=True)
self.assertActionPossible(dos_disk, DeviceAction.TOGGLE_BOOT)
# Even if they have existing partitions
make_partition(
model, dos_disk, size=gaps.largest_gap_size(dos_disk),
preserve=True)
model, dos_disk, size=gaps.largest_gap_size(dos_disk), preserve=True
)
self.assertActionPossible(dos_disk, DeviceAction.TOGGLE_BOOT)
# (we never create dos partition tables so no need to test
# preserve=False case).
# GPT disks with new partition tables can always be the BIOS boot disk
gpt_disk = make_disk(model, ptable='gpt', preserve=False)
gpt_disk = make_disk(model, ptable="gpt", preserve=False)
self.assertActionPossible(gpt_disk, DeviceAction.TOGGLE_BOOT)
# Even if they are filled with partitions (we resize partitions to fit)
make_partition(model, gpt_disk, size=gaps.largest_gap_size(dos_disk))
@ -157,25 +154,37 @@ class TestActions(unittest.TestCase):
# GPT disks with existing partition tables but no partitions can be the
# BIOS boot disk (in general we ignore existing empty partition tables)
gpt_disk2 = make_disk(model, ptable='gpt', preserve=True)
gpt_disk2 = make_disk(model, ptable="gpt", preserve=True)
self.assertActionPossible(gpt_disk2, DeviceAction.TOGGLE_BOOT)
# If there is an existing *partition* though, it cannot be the boot
# disk
make_partition(model, gpt_disk2, preserve=True)
self.assertActionNotPossible(gpt_disk2, DeviceAction.TOGGLE_BOOT)
# Unless there is already a bios_grub partition we can reuse
gpt_disk3 = make_disk(model, ptable='gpt', preserve=True)
make_partition(model, gpt_disk3, flag="bios_grub", preserve=True,
offset=1 << 20, size=512 << 20)
make_partition(model, gpt_disk3, preserve=True,
offset=513 << 20, size=8192 << 20)
gpt_disk3 = make_disk(model, ptable="gpt", preserve=True)
make_partition(
model,
gpt_disk3,
flag="bios_grub",
preserve=True,
offset=1 << 20,
size=512 << 20,
)
make_partition(
model, gpt_disk3, preserve=True, offset=513 << 20, size=8192 << 20
)
self.assertActionPossible(gpt_disk3, DeviceAction.TOGGLE_BOOT)
# Edge case city: the bios_grub partition has to be first
gpt_disk4 = make_disk(model, ptable='gpt', preserve=True)
make_partition(model, gpt_disk4, preserve=True,
offset=1 << 20, size=8192 << 20)
make_partition(model, gpt_disk4, flag="bios_grub", preserve=True,
offset=8193 << 20, size=512 << 20)
gpt_disk4 = make_disk(model, ptable="gpt", preserve=True)
make_partition(model, gpt_disk4, preserve=True, offset=1 << 20, size=8192 << 20)
make_partition(
model,
gpt_disk4,
flag="bios_grub",
preserve=True,
offset=8193 << 20,
size=512 << 20,
)
self.assertActionNotPossible(gpt_disk4, DeviceAction.TOGGLE_BOOT)
def _test_TOGGLE_BOOT_boot_partition(self, bl, flag):
@ -188,20 +197,20 @@ class TestActions(unittest.TestCase):
new_disk = make_disk(model, preserve=False)
self.assertActionPossible(new_disk, DeviceAction.TOGGLE_BOOT)
# Even if they are filled with partitions (we resize partitions to fit)
make_partition(
model, new_disk, size=gaps.largest_gap_size(new_disk))
make_partition(model, new_disk, size=gaps.largest_gap_size(new_disk))
self.assertActionPossible(new_disk, DeviceAction.TOGGLE_BOOT)
# A disk with an existing but empty partitions can also be the
# UEFI/PREP boot disk.
old_disk = make_disk(model, preserve=True, ptable='gpt')
old_disk = make_disk(model, preserve=True, ptable="gpt")
self.assertActionPossible(old_disk, DeviceAction.TOGGLE_BOOT)
# If there is an existing partition though, it cannot.
make_partition(model, old_disk, preserve=True)
self.assertActionNotPossible(old_disk, DeviceAction.TOGGLE_BOOT)
# If there is an existing ESP/PReP partition though, fine!
make_partition(model, old_disk, flag=flag, preserve=True,
offset=1 << 20, size=512 << 20)
make_partition(
model, old_disk, flag=flag, preserve=True, offset=1 << 20, size=512 << 20
)
self.assertActionPossible(old_disk, DeviceAction.TOGGLE_BOOT)
def test_disk_action_TOGGLE_BOOT_UEFI(self):
@ -217,7 +226,7 @@ class TestActions(unittest.TestCase):
def test_partition_action_EDIT(self):
model, part = make_model_and_partition()
self.assertActionPossible(part, DeviceAction.EDIT)
model.add_volgroup('vg1', {part})
model.add_volgroup("vg1", {part})
self.assertActionNotPossible(part, DeviceAction.EDIT)
def test_partition_action_REFORMAT(self):
@ -243,20 +252,20 @@ class TestActions(unittest.TestCase):
model = make_model()
part1 = make_partition(model)
self.assertActionPossible(part1, DeviceAction.DELETE)
fs = model.add_filesystem(part1, 'ext4')
fs = model.add_filesystem(part1, "ext4")
self.assertActionPossible(part1, DeviceAction.DELETE)
model.add_mount(fs, '/')
model.add_mount(fs, "/")
self.assertActionPossible(part1, DeviceAction.DELETE)
part2 = make_partition(model)
model.add_volgroup('vg1', {part2})
model.add_volgroup("vg1", {part2})
self.assertActionNotPossible(part2, DeviceAction.DELETE)
part = make_partition(make_model(Bootloader.BIOS), flag='bios_grub')
part = make_partition(make_model(Bootloader.BIOS), flag="bios_grub")
self.assertActionNotPossible(part, DeviceAction.DELETE)
part = make_partition(make_model(Bootloader.UEFI), flag='boot')
part = make_partition(make_model(Bootloader.UEFI), flag="boot")
self.assertActionNotPossible(part, DeviceAction.DELETE)
part = make_partition(make_model(Bootloader.PREP), flag='prep')
part = make_partition(make_model(Bootloader.PREP), flag="prep")
self.assertActionNotPossible(part, DeviceAction.DELETE)
# You cannot delete a partition from a disk that has
@ -277,7 +286,7 @@ class TestActions(unittest.TestCase):
model = make_model()
raid1 = make_raid(model)
self.assertActionPossible(raid1, DeviceAction.EDIT)
model.add_volgroup('vg1', {raid1})
model.add_volgroup("vg1", {raid1})
self.assertActionNotPossible(raid1, DeviceAction.EDIT)
raid2 = make_raid(model)
make_partition(model, raid2)
@ -294,23 +303,23 @@ class TestActions(unittest.TestCase):
self.assertActionNotPossible(raid1, DeviceAction.REFORMAT)
raid1p1 = make_partition(model, raid1)
self.assertActionPossible(raid1, DeviceAction.REFORMAT)
model.add_volgroup('vg0', {raid1p1})
model.add_volgroup("vg0", {raid1p1})
self.assertActionNotPossible(raid1, DeviceAction.REFORMAT)
raid2 = make_raid(model, preserve=True)
self.assertActionNotPossible(raid2, DeviceAction.REFORMAT)
raid2p1 = make_partition(model, raid2, preserve=True)
self.assertActionPossible(raid2, DeviceAction.REFORMAT)
model.add_volgroup('vg1', {raid2p1})
model.add_volgroup("vg1", {raid2p1})
self.assertActionNotPossible(raid2, DeviceAction.REFORMAT)
raid3 = make_raid(model)
model.add_volgroup('vg2', {raid3})
model.add_volgroup("vg2", {raid3})
self.assertActionNotPossible(raid3, DeviceAction.REFORMAT)
raid4 = make_raid(model)
raid4.preserve = True
model.add_volgroup('vg2', {raid4})
model.add_volgroup("vg2", {raid4})
self.assertActionNotPossible(raid4, DeviceAction.REFORMAT)
def test_raid_action_PARTITION(self):
@ -323,7 +332,7 @@ class TestActions(unittest.TestCase):
make_partition(model, raid)
self.assertActionNotPossible(raid, DeviceAction.FORMAT)
raid2 = make_raid(model)
model.add_volgroup('vg1', {raid2})
model.add_volgroup("vg1", {raid2})
self.assertActionNotPossible(raid2, DeviceAction.FORMAT)
def test_raid_action_REMOVE(self):
@ -338,20 +347,20 @@ class TestActions(unittest.TestCase):
self.assertActionPossible(raid1, DeviceAction.DELETE)
part = make_partition(model, raid1)
self.assertActionPossible(raid1, DeviceAction.DELETE)
fs = model.add_filesystem(part, 'ext4')
fs = model.add_filesystem(part, "ext4")
self.assertActionPossible(raid1, DeviceAction.DELETE)
model.add_mount(fs, '/')
model.add_mount(fs, "/")
self.assertActionNotPossible(raid1, DeviceAction.DELETE)
raid2 = make_raid(model)
self.assertActionPossible(raid2, DeviceAction.DELETE)
fs = model.add_filesystem(raid2, 'ext4')
fs = model.add_filesystem(raid2, "ext4")
self.assertActionPossible(raid2, DeviceAction.DELETE)
model.add_mount(fs, '/')
model.add_mount(fs, "/")
self.assertActionPossible(raid2, DeviceAction.DELETE)
raid2 = make_raid(model)
model.add_volgroup('vg0', {raid2})
model.add_volgroup("vg0", {raid2})
self.assertActionNotPossible(raid2, DeviceAction.DELETE)
def test_raid_action_TOGGLE_BOOT(self):
@ -365,7 +374,7 @@ class TestActions(unittest.TestCase):
def test_vg_action_EDIT(self):
model, vg = make_model_and_vg()
self.assertActionPossible(vg, DeviceAction.EDIT)
model.add_logical_volume(vg, 'lv1', size=gaps.largest_gap_size(vg))
model.add_logical_volume(vg, "lv1", size=gaps.largest_gap_size(vg))
self.assertActionNotPossible(vg, DeviceAction.EDIT)
vg2 = make_vg(model)
@ -392,12 +401,11 @@ class TestActions(unittest.TestCase):
model, vg = make_model_and_vg()
self.assertActionPossible(vg, DeviceAction.DELETE)
self.assertActionPossible(vg, DeviceAction.DELETE)
lv = model.add_logical_volume(
vg, 'lv0', size=gaps.largest_gap_size(vg)//2)
lv = model.add_logical_volume(vg, "lv0", size=gaps.largest_gap_size(vg) // 2)
self.assertActionPossible(vg, DeviceAction.DELETE)
fs = model.add_filesystem(lv, 'ext4')
fs = model.add_filesystem(lv, "ext4")
self.assertActionPossible(vg, DeviceAction.DELETE)
model.add_mount(fs, '/')
model.add_mount(fs, "/")
self.assertActionNotPossible(vg, DeviceAction.DELETE)
def test_vg_action_TOGGLE_BOOT(self):
@ -431,9 +439,9 @@ class TestActions(unittest.TestCase):
def test_lv_action_DELETE(self):
model, lv = make_model_and_lv()
self.assertActionPossible(lv, DeviceAction.DELETE)
fs = model.add_filesystem(lv, 'ext4')
fs = model.add_filesystem(lv, "ext4")
self.assertActionPossible(lv, DeviceAction.DELETE)
model.add_mount(fs, '/')
model.add_mount(fs, "/")
self.assertActionPossible(lv, DeviceAction.DELETE)
lv2 = make_lv(model)

View File

@ -13,16 +13,16 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from functools import partial
import unittest
from functools import partial
from unittest import mock
from subiquitycore.tests.parameterized import parameterized
from subiquity.common.filesystem import gaps
from subiquity.common.types import GapUsable
from subiquity.models.filesystem import (
Disk,
LVM_CHUNK_SIZE,
LVM_OVERHEAD,
Disk,
MiB,
PartitionAlignmentData,
Raid,
@ -36,21 +36,18 @@ from subiquity.models.tests.test_filesystem import (
make_partition,
make_vg,
)
from subiquity.common.filesystem import gaps
from subiquity.common.types import GapUsable
from subiquitycore.tests.parameterized import parameterized
class GapTestCase(unittest.TestCase):
def use_alignment_data(self, alignment_data):
m = mock.patch('subiquity.common.filesystem.gaps.parts_and_gaps')
m = mock.patch("subiquity.common.filesystem.gaps.parts_and_gaps")
p = m.start()
self.addCleanup(m.stop)
p.side_effect = partial(
gaps.find_disk_gaps_v2, info=alignment_data)
p.side_effect = partial(gaps.find_disk_gaps_v2, info=alignment_data)
for cls in Disk, Raid:
md = mock.patch.object(cls, 'alignment_data')
md = mock.patch.object(cls, "alignment_data")
p = md.start()
self.addCleanup(md.stop)
p.return_value = alignment_data
@ -86,14 +83,21 @@ class TestSplitGap(GapTestCase):
@parameterized.expand([[1], [10]])
def test_split_in_extended(self, ebr_space):
self.use_alignment_data(PartitionAlignmentData(
part_align=1, min_gap_size=1, min_start_offset=0,
min_end_offset=0, primary_part_limit=4, ebr_space=ebr_space))
self.use_alignment_data(
PartitionAlignmentData(
part_align=1,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=4,
ebr_space=ebr_space,
)
)
m = make_model(storage_version=2)
d = make_disk(m, size=100, ptable='dos')
d = make_disk(m, size=100, ptable="dos")
[gap] = gaps.parts_and_gaps(d)
make_partition(m, d, offset=gap.offset, size=gap.size, flag='extended')
make_partition(m, d, offset=gap.offset, size=gap.size, flag="extended")
[p, g] = gaps.parts_and_gaps(d)
self.assertTrue(g.in_extended)
g1, g2 = g.split(10)
@ -205,125 +209,158 @@ class TestAfter(unittest.TestCase):
class TestDiskGaps(unittest.TestCase):
def test_no_partition_gpt(self):
size = 1 << 30
d = make_disk(size=size, ptable='gpt')
d = make_disk(size=size, ptable="gpt")
self.assertEqual(
gaps.find_disk_gaps_v2(d),
[gaps.Gap(d, MiB, size - 2*MiB, False)])
gaps.find_disk_gaps_v2(d), [gaps.Gap(d, MiB, size - 2 * MiB, False)]
)
def test_no_partition_dos(self):
size = 1 << 30
d = make_disk(size=size, ptable='dos')
d = make_disk(size=size, ptable="dos")
self.assertEqual(
gaps.find_disk_gaps_v2(d),
[gaps.Gap(d, MiB, size - MiB, False)])
gaps.find_disk_gaps_v2(d), [gaps.Gap(d, MiB, size - MiB, False)]
)
def test_all_partition(self):
info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0,
min_end_offset=0, primary_part_limit=10)
part_align=10,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100)
p = make_partition(m, d, offset=0, size=100)
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[p])
self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p])
def test_all_partition_with_min_offsets(self):
info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=10,
min_end_offset=10, primary_part_limit=10)
part_align=10,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100)
p = make_partition(m, d, offset=10, size=80)
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[p])
self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p])
def test_half_partition(self):
info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0,
min_end_offset=0, primary_part_limit=10)
part_align=10,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100)
p = make_partition(m, d, offset=0, size=50)
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[p, gaps.Gap(d, 50, 50)])
self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p, gaps.Gap(d, 50, 50)])
def test_gap_in_middle(self):
info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0,
min_end_offset=0, primary_part_limit=10)
part_align=10,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100)
p1 = make_partition(m, d, offset=0, size=20)
p2 = make_partition(m, d, offset=80, size=20)
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[p1, gaps.Gap(d, 20, 60), p2])
self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p1, gaps.Gap(d, 20, 60), p2])
def test_small_gap(self):
info = PartitionAlignmentData(
part_align=10, min_gap_size=20, min_start_offset=0,
min_end_offset=0, primary_part_limit=10)
part_align=10,
min_gap_size=20,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100)
p1 = make_partition(m, d, offset=0, size=40)
p2 = make_partition(m, d, offset=50, size=50)
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[p1, p2])
self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p1, p2])
def test_align_gap(self):
info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0,
min_end_offset=0, primary_part_limit=10)
part_align=10,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100)
p1 = make_partition(m, d, offset=0, size=17)
p2 = make_partition(m, d, offset=53, size=47)
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[p1, gaps.Gap(d, 20, 30), p2])
self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p1, gaps.Gap(d, 20, 30), p2])
def test_all_extended(self):
info = PartitionAlignmentData(
part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0,
ebr_space=2, primary_part_limit=10)
m, d = make_model_and_disk(size=100, ptable='dos')
p = make_partition(m, d, offset=0, size=100, flag='extended')
part_align=5,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
ebr_space=2,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100, ptable="dos")
p = make_partition(m, d, offset=0, size=100, flag="extended")
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[
p,
gaps.Gap(d, 5, 95, True),
])
],
)
def test_half_extended(self):
info = PartitionAlignmentData(
part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0,
ebr_space=2, primary_part_limit=10)
part_align=5,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
ebr_space=2,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100)
p = make_partition(m, d, offset=0, size=50, flag='extended')
p = make_partition(m, d, offset=0, size=50, flag="extended")
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[p, gaps.Gap(d, 5, 45, True), gaps.Gap(d, 50, 50, False)])
[p, gaps.Gap(d, 5, 45, True), gaps.Gap(d, 50, 50, False)],
)
def test_half_extended_one_logical(self):
info = PartitionAlignmentData(
part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0,
ebr_space=2, primary_part_limit=10)
m, d = make_model_and_disk(size=100, ptable='dos')
p1 = make_partition(m, d, offset=0, size=50, flag='extended')
p2 = make_partition(m, d, offset=5, size=45, flag='logical')
part_align=5,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
ebr_space=2,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100, ptable="dos")
p1 = make_partition(m, d, offset=0, size=50, flag="extended")
p2 = make_partition(m, d, offset=5, size=45, flag="logical")
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[p1, p2, gaps.Gap(d, 50, 50, False)])
gaps.find_disk_gaps_v2(d, info), [p1, p2, gaps.Gap(d, 50, 50, False)]
)
def test_half_extended_half_logical(self):
info = PartitionAlignmentData(
part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0,
ebr_space=2, primary_part_limit=10)
m, d = make_model_and_disk(size=100, ptable='dos')
p1 = make_partition(m, d, offset=0, size=50, flag='extended')
p2 = make_partition(m, d, offset=5, size=25, flag='logical')
part_align=5,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
ebr_space=2,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100, ptable="dos")
p1 = make_partition(m, d, offset=0, size=50, flag="extended")
p2 = make_partition(m, d, offset=5, size=25, flag="logical")
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[
@ -331,159 +368,204 @@ class TestDiskGaps(unittest.TestCase):
p2,
gaps.Gap(d, 35, 15, True),
gaps.Gap(d, 50, 50, False),
])
],
)
def test_gap_before_primary(self):
# 0----10---20---30---40---50---60---70---80---90---100
# [ g1 ][ p1 (primary) ]
info = PartitionAlignmentData(
part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0,
ebr_space=1, primary_part_limit=10)
m, d = make_model_and_disk(size=100, ptable='dos')
part_align=5,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
ebr_space=1,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100, ptable="dos")
p1 = make_partition(m, d, offset=50, size=50)
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[gaps.Gap(d, 0, 50, False), p1])
gaps.find_disk_gaps_v2(d, info), [gaps.Gap(d, 0, 50, False), p1]
)
def test_gap_in_extended_before_logical(self):
# 0----10---20---30---40---50---60---70---80---90---100
# [ p1 (extended) ]
# [ g1 ] [ p5 (logical) ]
info = PartitionAlignmentData(
part_align=5, min_gap_size=1, min_start_offset=0, min_end_offset=0,
ebr_space=1, primary_part_limit=10)
m, d = make_model_and_disk(size=100, ptable='dos')
p1 = make_partition(m, d, offset=0, size=100, flag='extended')
p5 = make_partition(m, d, offset=50, size=50, flag='logical')
part_align=5,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
ebr_space=1,
primary_part_limit=10,
)
m, d = make_model_and_disk(size=100, ptable="dos")
p1 = make_partition(m, d, offset=0, size=100, flag="extended")
p5 = make_partition(m, d, offset=50, size=50, flag="logical")
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[p1, gaps.Gap(d, 5, 40, True), p5])
gaps.find_disk_gaps_v2(d, info), [p1, gaps.Gap(d, 5, 40, True), p5]
)
def test_unusable_gap_primaries(self):
info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0,
min_end_offset=0, primary_part_limit=1)
part_align=10,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=1,
)
m, d = make_model_and_disk(size=100)
p = make_partition(m, d, offset=0, size=90)
g = gaps.Gap(d, offset=90, size=10,
usable=GapUsable.TOO_MANY_PRIMARY_PARTS)
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[p, g])
g = gaps.Gap(d, offset=90, size=10, usable=GapUsable.TOO_MANY_PRIMARY_PARTS)
self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p, g])
def test_usable_gap_primaries(self):
info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0,
min_end_offset=0, primary_part_limit=2)
part_align=10,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=2,
)
m, d = make_model_and_disk(size=100)
p = make_partition(m, d, offset=0, size=90)
g = gaps.Gap(d, offset=90, size=10, usable=GapUsable.YES)
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[p, g])
self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p, g])
def test_usable_gap_extended(self):
info = PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0,
min_end_offset=0, primary_part_limit=1)
m, d = make_model_and_disk(size=100, ptable='dos')
p = make_partition(m, d, offset=0, size=100, flag='extended')
g = gaps.Gap(d, offset=0, size=100,
in_extended=True, usable=GapUsable.YES)
self.assertEqual(
gaps.find_disk_gaps_v2(d, info),
[p, g])
part_align=10,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=1,
)
m, d = make_model_and_disk(size=100, ptable="dos")
p = make_partition(m, d, offset=0, size=100, flag="extended")
g = gaps.Gap(d, offset=0, size=100, in_extended=True, usable=GapUsable.YES)
self.assertEqual(gaps.find_disk_gaps_v2(d, info), [p, g])
class TestMovableTrailingPartitionsAndGapSize(GapTestCase):
def test_largest_non_extended(self):
self.use_alignment_data(PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0,
min_end_offset=0, primary_part_limit=4))
self.use_alignment_data(
PartitionAlignmentData(
part_align=10,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=4,
)
)
# 0----10---20---30---40---50---60---70---80---90---100
# [ p1 ] [ p2 (extended) ]
# [ g1 ][ g2 ]
m, d = make_model_and_disk(size=100, ptable='dos')
m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=0, size=10)
make_partition(m, d, offset=30, size=70, flag='extended')
g = gaps.Gap(d, offset=10, size=20,
in_extended=False, usable=GapUsable.YES)
make_partition(m, d, offset=30, size=70, flag="extended")
g = gaps.Gap(d, offset=10, size=20, in_extended=False, usable=GapUsable.YES)
self.assertEqual(g, gaps.largest_gap(d, in_extended=False))
def test_largest_yes_extended(self):
self.use_alignment_data(PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0,
min_end_offset=0, primary_part_limit=4))
self.use_alignment_data(
PartitionAlignmentData(
part_align=10,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=4,
)
)
# 0----10---20---30---40---50---60---70---80---90---100
# [ p1 ] [ p2 (extended) ]
# [ g1 ][ g2 ]
m, d = make_model_and_disk(size=100, ptable='dos')
m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=0, size=10)
make_partition(m, d, offset=30, size=70, flag='extended')
g = gaps.Gap(d, offset=30, size=70,
in_extended=True, usable=GapUsable.YES)
make_partition(m, d, offset=30, size=70, flag="extended")
g = gaps.Gap(d, offset=30, size=70, in_extended=True, usable=GapUsable.YES)
self.assertEqual(g, gaps.largest_gap(d, in_extended=True))
def test_largest_any_extended(self):
self.use_alignment_data(PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=0,
min_end_offset=0, primary_part_limit=4))
self.use_alignment_data(
PartitionAlignmentData(
part_align=10,
min_gap_size=1,
min_start_offset=0,
min_end_offset=0,
primary_part_limit=4,
)
)
# 0----10---20---30---40---50---60---70---80---90---100
# [ p1 ] [ p2 (extended) ]
# [ g1 ][ g2 ]
m, d = make_model_and_disk(size=100, ptable='dos')
m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=0, size=10)
make_partition(m, d, offset=30, size=70, flag='extended')
g = gaps.Gap(d, offset=30, size=70,
in_extended=True, usable=GapUsable.YES)
make_partition(m, d, offset=30, size=70, flag="extended")
g = gaps.Gap(d, offset=30, size=70, in_extended=True, usable=GapUsable.YES)
self.assertEqual(g, gaps.largest_gap(d))
def test_no_next_gap(self):
self.use_alignment_data(PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=10,
min_end_offset=10, primary_part_limit=10))
self.use_alignment_data(
PartitionAlignmentData(
part_align=10,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
)
)
# 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 ]#####
m, d = make_model_and_disk(size=100)
p = make_partition(m, d, offset=10, size=80)
self.assertEqual(
([], 0),
gaps.movable_trailing_partitions_and_gap_size(p))
self.assertEqual(([], 0), gaps.movable_trailing_partitions_and_gap_size(p))
def test_immediately_trailing_gap(self):
self.use_alignment_data(PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=10,
min_end_offset=10, primary_part_limit=10))
self.use_alignment_data(
PartitionAlignmentData(
part_align=10,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
)
)
# 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 ] [ p2 ] #####
m, d = make_model_and_disk(size=100)
p1 = make_partition(m, d, offset=10, size=20)
p2 = make_partition(m, d, offset=50, size=20)
self.assertEqual(
([], 20),
gaps.movable_trailing_partitions_and_gap_size(p1))
self.assertEqual(
([], 20),
gaps.movable_trailing_partitions_and_gap_size(p2))
self.assertEqual(([], 20), gaps.movable_trailing_partitions_and_gap_size(p1))
self.assertEqual(([], 20), gaps.movable_trailing_partitions_and_gap_size(p2))
def test_one_trailing_movable_partition_and_gap(self):
self.use_alignment_data(PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=10,
min_end_offset=10, primary_part_limit=10))
self.use_alignment_data(
PartitionAlignmentData(
part_align=10,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
)
)
# 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 ][ p2 ] #####
m, d = make_model_and_disk(size=100)
p1 = make_partition(m, d, offset=10, size=40)
p2 = make_partition(m, d, offset=50, size=10)
self.assertEqual(
([p2], 30),
gaps.movable_trailing_partitions_and_gap_size(p1))
self.assertEqual(([p2], 30), gaps.movable_trailing_partitions_and_gap_size(p1))
def test_two_trailing_movable_partitions_and_gap(self):
self.use_alignment_data(PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=10,
min_end_offset=10, primary_part_limit=10))
self.use_alignment_data(
PartitionAlignmentData(
part_align=10,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
)
)
# 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 ][ p2 ][ p3 ] #####
m, d = make_model_and_disk(size=100)
@ -491,122 +573,152 @@ class TestMovableTrailingPartitionsAndGapSize(GapTestCase):
p2 = make_partition(m, d, offset=50, size=10)
p3 = make_partition(m, d, offset=60, size=10)
self.assertEqual(
([p2, p3], 20),
gaps.movable_trailing_partitions_and_gap_size(p1))
([p2, p3], 20), gaps.movable_trailing_partitions_and_gap_size(p1)
)
def test_one_trailing_movable_partition_and_no_gap(self):
self.use_alignment_data(PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=10,
min_end_offset=10, primary_part_limit=10))
self.use_alignment_data(
PartitionAlignmentData(
part_align=10,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
)
)
# 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 ][ p2 ]#####
m, d = make_model_and_disk(size=100)
p1 = make_partition(m, d, offset=10, size=40)
p2 = make_partition(m, d, offset=50, size=40)
self.assertEqual(
([p2], 0),
gaps.movable_trailing_partitions_and_gap_size(p1))
self.assertEqual(([p2], 0), gaps.movable_trailing_partitions_and_gap_size(p1))
def test_full_extended_partition_then_gap(self):
self.use_alignment_data(PartitionAlignmentData(
part_align=1, min_gap_size=1, min_start_offset=10,
min_end_offset=10, primary_part_limit=10, ebr_space=2))
self.use_alignment_data(
PartitionAlignmentData(
part_align=1,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
ebr_space=2,
)
)
# 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 (extended) ] #####
# ######[ p5 (logical) ] #####
m, d = make_model_and_disk(size=100, ptable='dos')
make_partition(m, d, offset=10, size=40, flag='extended')
p5 = make_partition(m, d, offset=12, size=38, flag='logical')
self.assertEqual(
([], 0),
gaps.movable_trailing_partitions_and_gap_size(p5))
m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=10, size=40, flag="extended")
p5 = make_partition(m, d, offset=12, size=38, flag="logical")
self.assertEqual(([], 0), gaps.movable_trailing_partitions_and_gap_size(p5))
def test_full_extended_partition_then_part(self):
self.use_alignment_data(PartitionAlignmentData(
part_align=1, min_gap_size=1, min_start_offset=10,
min_end_offset=10, primary_part_limit=10, ebr_space=2))
self.use_alignment_data(
PartitionAlignmentData(
part_align=1,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
ebr_space=2,
)
)
# 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 (extended) ][ p2 ]#####
# ######[ p5 (logical) ] #####
m, d = make_model_and_disk(size=100, ptable='dos')
make_partition(m, d, offset=10, size=40, flag='extended')
m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=10, size=40, flag="extended")
make_partition(m, d, offset=50, size=40)
p5 = make_partition(m, d, offset=12, size=38, flag='logical')
self.assertEqual(
([], 0),
gaps.movable_trailing_partitions_and_gap_size(p5))
p5 = make_partition(m, d, offset=12, size=38, flag="logical")
self.assertEqual(([], 0), gaps.movable_trailing_partitions_and_gap_size(p5))
def test_gap_in_extended_partition(self):
self.use_alignment_data(PartitionAlignmentData(
part_align=1, min_gap_size=1, min_start_offset=10,
min_end_offset=10, primary_part_limit=10, ebr_space=2))
self.use_alignment_data(
PartitionAlignmentData(
part_align=1,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
ebr_space=2,
)
)
# 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 (extended) ] #####
# ######[ p5 (logical)] #####
m, d = make_model_and_disk(size=100, ptable='dos')
make_partition(m, d, offset=10, size=40, flag='extended')
p5 = make_partition(m, d, offset=12, size=30, flag='logical')
self.assertEqual(
([], 6),
gaps.movable_trailing_partitions_and_gap_size(p5))
m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=10, size=40, flag="extended")
p5 = make_partition(m, d, offset=12, size=30, flag="logical")
self.assertEqual(([], 6), gaps.movable_trailing_partitions_and_gap_size(p5))
def test_trailing_logical_partition_then_gap(self):
self.use_alignment_data(PartitionAlignmentData(
part_align=1, min_gap_size=1, min_start_offset=10,
min_end_offset=10, primary_part_limit=10, ebr_space=2))
self.use_alignment_data(
PartitionAlignmentData(
part_align=1,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
ebr_space=2,
)
)
# 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 (extended) ]#####
# ######[ p5 (logical)] [ p6 (logical)] #####
m, d = make_model_and_disk(size=100, ptable='dos')
make_partition(m, d, offset=10, size=80, flag='extended')
p5 = make_partition(m, d, offset=12, size=30, flag='logical')
p6 = make_partition(m, d, offset=44, size=30, flag='logical')
self.assertEqual(
([p6], 14),
gaps.movable_trailing_partitions_and_gap_size(p5))
m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=10, size=80, flag="extended")
p5 = make_partition(m, d, offset=12, size=30, flag="logical")
p6 = make_partition(m, d, offset=44, size=30, flag="logical")
self.assertEqual(([p6], 14), gaps.movable_trailing_partitions_and_gap_size(p5))
def test_trailing_logical_partition_then_no_gap(self):
self.use_alignment_data(PartitionAlignmentData(
part_align=1, min_gap_size=1, min_start_offset=10,
min_end_offset=10, primary_part_limit=10, ebr_space=2))
self.use_alignment_data(
PartitionAlignmentData(
part_align=1,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
ebr_space=2,
)
)
# 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 (extended) ]#####
# ######[ p5 (logical)] [ p6 (logical) ] #####
m, d = make_model_and_disk(size=100, ptable='dos')
make_partition(m, d, offset=10, size=80, flag='extended')
p5 = make_partition(m, d, offset=12, size=30, flag='logical')
p6 = make_partition(m, d, offset=44, size=44, flag='logical')
self.assertEqual(
([p6], 0),
gaps.movable_trailing_partitions_and_gap_size(p5))
m, d = make_model_and_disk(size=100, ptable="dos")
make_partition(m, d, offset=10, size=80, flag="extended")
p5 = make_partition(m, d, offset=12, size=30, flag="logical")
p6 = make_partition(m, d, offset=44, size=44, flag="logical")
self.assertEqual(([p6], 0), gaps.movable_trailing_partitions_and_gap_size(p5))
def test_trailing_preserved_partition(self):
self.use_alignment_data(PartitionAlignmentData(
part_align=10, min_gap_size=1, min_start_offset=10,
min_end_offset=10, primary_part_limit=10))
self.use_alignment_data(
PartitionAlignmentData(
part_align=10,
min_gap_size=1,
min_start_offset=10,
min_end_offset=10,
primary_part_limit=10,
)
)
# 0----10---20---30---40---50---60---70---80---90---100
# #####[ p1 ][ p2 p ] #####
m, d = make_model_and_disk(size=100)
p1 = make_partition(m, d, offset=10, size=40)
make_partition(m, d, offset=50, size=20, preserve=True)
self.assertEqual(
([], 0),
gaps.movable_trailing_partitions_and_gap_size(p1))
self.assertEqual(([], 0), gaps.movable_trailing_partitions_and_gap_size(p1))
def test_trailing_lvm_logical_volume_no_gap(self):
model, lv = make_model_and_lv()
self.assertEqual(
([], 0),
gaps.movable_trailing_partitions_and_gap_size(lv))
self.assertEqual(([], 0), gaps.movable_trailing_partitions_and_gap_size(lv))
def test_trailing_lvm_logical_volume_with_gap(self):
model, disk = make_model_and_disk(
size=100 * LVM_CHUNK_SIZE + LVM_OVERHEAD)
model, disk = make_model_and_disk(size=100 * LVM_CHUNK_SIZE + LVM_OVERHEAD)
vg = make_vg(model, pvs={disk})
lv = make_lv(model, vg=vg, size=40 * LVM_CHUNK_SIZE)
self.assertEqual(
([], LVM_CHUNK_SIZE * 60),
gaps.movable_trailing_partitions_and_gap_size(lv))
([], LVM_CHUNK_SIZE * 60), gaps.movable_trailing_partitions_and_gap_size(lv)
)
class TestLargestGaps(unittest.TestCase):
@ -667,6 +779,7 @@ class TestLargestGaps(unittest.TestCase):
class TestUsable(unittest.TestCase):
def test_strings(self):
self.assertEqual('YES', GapUsable.YES.name)
self.assertEqual('TOO_MANY_PRIMARY_PARTS',
GapUsable.TOO_MANY_PRIMARY_PARTS.name)
self.assertEqual("YES", GapUsable.YES.name)
self.assertEqual(
"TOO_MANY_PRIMARY_PARTS", GapUsable.TOO_MANY_PRIMARY_PARTS.name
)

View File

@ -16,11 +16,7 @@
import unittest
from subiquity.common.filesystem.labels import (
annotations,
for_client,
usage_labels,
)
from subiquity.common.filesystem.labels import annotations, for_client, usage_labels
from subiquity.models.tests.test_filesystem import (
make_model,
make_model_and_disk,
@ -31,7 +27,6 @@ from subiquity.models.tests.test_filesystem import (
class TestAnnotations(unittest.TestCase):
def test_disk_annotations(self):
# disks never have annotations
model, disk = make_model_and_disk()
@ -42,93 +37,88 @@ class TestAnnotations(unittest.TestCase):
def test_partition_annotations(self):
model = make_model()
part = make_partition(model)
self.assertEqual(annotations(part), ['new'])
self.assertEqual(annotations(part), ["new"])
part.preserve = True
self.assertEqual(annotations(part), ['existing'])
self.assertEqual(annotations(part), ["existing"])
model = make_model()
part = make_partition(model, flag="bios_grub")
self.assertEqual(
annotations(part), ['new', 'BIOS grub spacer'])
self.assertEqual(annotations(part), ["new", "BIOS grub spacer"])
part.preserve = True
self.assertEqual(
annotations(part),
['existing', 'unconfigured', 'BIOS grub spacer'])
annotations(part), ["existing", "unconfigured", "BIOS grub spacer"]
)
part.device.grub_device = True
self.assertEqual(
annotations(part),
['existing', 'configured', 'BIOS grub spacer'])
annotations(part), ["existing", "configured", "BIOS grub spacer"]
)
model = make_model()
part = make_partition(model, flag="boot", grub_device=True)
self.assertEqual(annotations(part), ['new', 'backup ESP'])
self.assertEqual(annotations(part), ["new", "backup ESP"])
fs = model.add_filesystem(part, fstype="fat32")
model.add_mount(fs, "/boot/efi")
self.assertEqual(annotations(part), ['new', 'primary ESP'])
self.assertEqual(annotations(part), ["new", "primary ESP"])
model = make_model()
part = make_partition(model, flag="boot", preserve=True)
self.assertEqual(annotations(part), ['existing', 'unused ESP'])
self.assertEqual(annotations(part), ["existing", "unused ESP"])
part.grub_device = True
self.assertEqual(annotations(part), ['existing', 'backup ESP'])
self.assertEqual(annotations(part), ["existing", "backup ESP"])
fs = model.add_filesystem(part, fstype="fat32")
model.add_mount(fs, "/boot/efi")
self.assertEqual(annotations(part), ['existing', 'primary ESP'])
self.assertEqual(annotations(part), ["existing", "primary ESP"])
model = make_model()
part = make_partition(model, flag="prep", grub_device=True)
self.assertEqual(annotations(part), ['new', 'PReP'])
self.assertEqual(annotations(part), ["new", "PReP"])
model = make_model()
part = make_partition(model, flag="prep", preserve=True)
self.assertEqual(
annotations(part), ['existing', 'PReP', 'unconfigured'])
self.assertEqual(annotations(part), ["existing", "PReP", "unconfigured"])
part.grub_device = True
self.assertEqual(
annotations(part), ['existing', 'PReP', 'configured'])
self.assertEqual(annotations(part), ["existing", "PReP", "configured"])
def test_vg_default_annotations(self):
model, disk = make_model_and_disk()
vg = model.add_volgroup('vg-0', {disk})
self.assertEqual(annotations(vg), ['new'])
vg = model.add_volgroup("vg-0", {disk})
self.assertEqual(annotations(vg), ["new"])
vg.preserve = True
self.assertEqual(annotations(vg), ['existing'])
self.assertEqual(annotations(vg), ["existing"])
def test_vg_encrypted_annotations(self):
model, disk = make_model_and_disk()
dm_crypt = model.add_dm_crypt(disk, key='passw0rd')
vg = model.add_volgroup('vg-0', {dm_crypt})
self.assertEqual(annotations(vg), ['new', 'encrypted'])
dm_crypt = model.add_dm_crypt(disk, key="passw0rd")
vg = model.add_volgroup("vg-0", {dm_crypt})
self.assertEqual(annotations(vg), ["new", "encrypted"])
class TestUsageLabels(unittest.TestCase):
def test_partition_usage_labels(self):
model, partition = make_model_and_partition()
self.assertEqual(usage_labels(partition), ["unused"])
fs = model.add_filesystem(partition, 'ext4')
fs = model.add_filesystem(partition, "ext4")
self.assertEqual(
usage_labels(partition),
["to be formatted as ext4", "not mounted"])
usage_labels(partition), ["to be formatted as ext4", "not mounted"]
)
model._orig_config = model._render_actions()
fs.preserve = True
partition.preserve = True
self.assertEqual(
usage_labels(partition),
["already formatted as ext4", "not mounted"])
usage_labels(partition), ["already formatted as ext4", "not mounted"]
)
model.remove_filesystem(fs)
fs2 = model.add_filesystem(partition, 'ext4')
fs2 = model.add_filesystem(partition, "ext4")
self.assertEqual(
usage_labels(partition),
["to be reformatted as ext4", "not mounted"])
model.add_mount(fs2, '/')
usage_labels(partition), ["to be reformatted as ext4", "not mounted"]
)
model.add_mount(fs2, "/")
self.assertEqual(
usage_labels(partition),
["to be reformatted as ext4", "mounted at /"])
usage_labels(partition), ["to be reformatted as ext4", "mounted at /"]
)
class TestForClient(unittest.TestCase):
def test_for_client_raid_parts(self):
model, raid = make_model_and_raid()
make_partition(model, raid)

View File

@ -19,24 +19,17 @@ from unittest import mock
import attr
from subiquitycore.tests.parameterized import parameterized
from subiquity.common.filesystem.actions import (
DeviceAction,
)
from subiquity.common.filesystem import boot, gaps, sizes
from subiquity.common.filesystem.actions import DeviceAction
from subiquity.common.filesystem.manipulator import FilesystemManipulator
from subiquity.models.filesystem import Bootloader, MiB, Partition
from subiquity.models.tests.test_filesystem import (
make_disk,
make_filesystem,
make_model,
make_partition,
make_filesystem,
)
from subiquity.models.filesystem import (
Bootloader,
MiB,
Partition,
)
from subiquitycore.tests.parameterized import parameterized
def make_manipulator(bootloader=None, storage_version=1):
@ -78,18 +71,16 @@ class Create:
class TestFilesystemManipulator(unittest.TestCase):
def test_delete_encrypted_vg(self):
manipulator, disk = make_manipulator_and_disk()
spec = {
'passphrase': 'passw0rd',
'devices': {disk},
'name': 'vg0',
"passphrase": "passw0rd",
"devices": {disk},
"name": "vg0",
}
vg = manipulator.create_volgroup(spec)
manipulator.delete_volgroup(vg)
dm_crypts = [
a for a in manipulator.model._actions if a.type == 'dm_crypt']
dm_crypts = [a for a in manipulator.model._actions if a.type == "dm_crypt"]
self.assertEqual(dm_crypts, [])
def test_wipe_existing_fs(self):
@ -98,18 +89,19 @@ class TestFilesystemManipulator(unittest.TestCase):
# example: it's ext2, wipe and make it ext2 again
manipulator, d = make_manipulator_and_disk()
p = make_partition(manipulator.model, d, preserve=True)
fs = make_filesystem(manipulator.model, partition=p, fstype='ext2',
preserve=True)
fs = make_filesystem(
manipulator.model, partition=p, fstype="ext2", preserve=True
)
manipulator.model._actions.append(fs)
manipulator.delete_filesystem(fs)
spec = {'wipe': 'random'}
with mock.patch.object(p, 'original_fstype') as original_fstype:
original_fstype.return_value = 'ext2'
spec = {"wipe": "random"}
with mock.patch.object(p, "original_fstype") as original_fstype:
original_fstype.return_value = "ext2"
fs = manipulator.create_filesystem(p, spec)
self.assertTrue(p.preserve)
self.assertEqual('random', p.wipe)
self.assertEqual("random", p.wipe)
self.assertFalse(fs.preserve)
self.assertEqual('ext2', fs.fstype)
self.assertEqual("ext2", fs.fstype)
def test_can_only_add_boot_once(self):
# This is really testing model code but it's much easier to test with a
@ -122,7 +114,8 @@ class TestFilesystemManipulator(unittest.TestCase):
self.assertFalse(
DeviceAction.TOGGLE_BOOT.can(disk)[0],
"add_boot_disk(disk) did not make _can_TOGGLE_BOOT false "
"with bootloader {}".format(bl))
"with bootloader {}".format(bl),
)
def assertIsMountedAtBootEFI(self, device):
efi_mnts = device._m._all(type="mount", path="/boot/efi")
@ -136,20 +129,23 @@ class TestFilesystemManipulator(unittest.TestCase):
def add_existing_boot_partition(self, manipulator, disk):
if manipulator.model.bootloader == Bootloader.BIOS:
part = manipulator.model.add_partition(
disk, size=1 << 20, offset=0, flag="bios_grub")
disk, size=1 << 20, offset=0, flag="bios_grub"
)
elif manipulator.model.bootloader == Bootloader.UEFI:
part = manipulator.model.add_partition(
disk, size=512 << 20, offset=0, flag="boot")
disk, size=512 << 20, offset=0, flag="boot"
)
elif manipulator.model.bootloader == Bootloader.PREP:
part = manipulator.model.add_partition(
disk, size=8 << 20, offset=0, flag="prep")
disk, size=8 << 20, offset=0, flag="prep"
)
part.preserve = True
return part
def assertIsBootDisk(self, manipulator, disk):
if manipulator.model.bootloader == Bootloader.BIOS:
self.assertTrue(disk.grub_device)
if disk.ptable == 'gpt':
if disk.ptable == "gpt":
self.assertEqual(disk.partitions()[0].flag, "bios_grub")
elif manipulator.model.bootloader == Bootloader.UEFI:
for part in disk.partitions():
@ -159,7 +155,7 @@ class TestFilesystemManipulator(unittest.TestCase):
elif manipulator.model.bootloader == Bootloader.PREP:
for part in disk.partitions():
if part.flag == "prep" and part.grub_device:
self.assertEqual(part.wipe, 'zero')
self.assertEqual(part.wipe, "zero")
return
self.fail("{} is not a boot disk".format(disk))
@ -186,7 +182,8 @@ class TestFilesystemManipulator(unittest.TestCase):
disk2 = make_disk(manipulator.model, preserve=False)
gap = gaps.largest_gap(disk2)
disk2p1 = manipulator.model.add_partition(
disk2, size=gap.size, offset=gap.offset)
disk2, size=gap.size, offset=gap.offset
)
manipulator.add_boot_disk(disk1)
self.assertIsBootDisk(manipulator, disk1)
@ -202,8 +199,7 @@ class TestFilesystemManipulator(unittest.TestCase):
self.assertNotMounted(disk2.partitions()[0])
self.assertEqual(len(disk2.partitions()), 2)
self.assertEqual(disk2.partitions()[1], disk2p1)
self.assertEqual(
disk2.partitions()[0].size + disk2p1.size, size_before)
self.assertEqual(disk2.partitions()[0].size + disk2p1.size, size_before)
manipulator.remove_boot_disk(disk1)
self.assertIsNotBootDisk(manipulator, disk1)
@ -228,7 +224,8 @@ class TestFilesystemManipulator(unittest.TestCase):
disk2 = make_disk(manipulator.model, preserve=False)
gap = gaps.largest_gap(disk2)
disk2p1 = manipulator.model.add_partition(
disk2, size=gap.size, offset=gap.offset)
disk2, size=gap.size, offset=gap.offset
)
manipulator.add_boot_disk(disk1)
self.assertIsBootDisk(manipulator, disk1)
@ -243,8 +240,7 @@ class TestFilesystemManipulator(unittest.TestCase):
self.assertIsMountedAtBootEFI(disk2.partitions()[0])
self.assertEqual(len(disk2.partitions()), 2)
self.assertEqual(disk2.partitions()[1], disk2p1)
self.assertEqual(
disk2.partitions()[0].size + disk2p1.size, size_before)
self.assertEqual(disk2.partitions()[0].size + disk2p1.size, size_before)
def test_boot_disk_existing(self):
for bl in Bootloader:
@ -272,13 +268,16 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(Bootloader.UEFI)
disk1 = make_disk(manipulator.model, preserve=True)
disk1p1 = manipulator.model.add_partition(
disk1, size=512 << 20, offset=0, flag="boot")
disk1, size=512 << 20, offset=0, flag="boot"
)
disk1p1.preserve = True
disk1p2 = manipulator.model.add_partition(
disk1, size=8192 << 20, offset=513 << 20)
disk1, size=8192 << 20, offset=513 << 20
)
disk1p2.preserve = True
manipulator.partition_disk_handler(
disk1, {'fstype': 'ext4', 'mount': '/'}, partition=disk1p2)
disk1, {"fstype": "ext4", "mount": "/"}, partition=disk1p2
)
efi_mnt = manipulator.model._mount_for_path("/boot/efi")
self.assertEqual(efi_mnt.device.volume, disk1p1)
@ -295,14 +294,15 @@ class TestFilesystemManipulator(unittest.TestCase):
@contextlib.contextmanager
def assertPartitionOperations(self, disk, *ops):
existing_parts = set(disk.partitions())
part_details = {}
for op in ops:
if isinstance(op, MoveResize):
part_details[op.part] = (
op.part.offset + op.offset, op.part.size + op.size)
op.part.offset + op.offset,
op.part.size + op.size,
)
try:
yield
@ -314,8 +314,8 @@ class TestFilesystemManipulator(unittest.TestCase):
if isinstance(op, MoveResize):
new_parts.remove(op.part)
self.assertEqual(
part_details[op.part],
(op.part.offset, op.part.size))
part_details[op.part], (op.part.offset, op.part.size)
)
else:
for part in created_parts:
lhs = (part.offset, part.size)
@ -339,7 +339,8 @@ class TestFilesystemManipulator(unittest.TestCase):
disk,
Create(
offset=disk.alignment_data().min_start_offset,
size=sizes.BIOS_GRUB_SIZE_BYTES),
size=sizes.BIOS_GRUB_SIZE_BYTES,
),
):
manipulator.add_boot_disk(disk)
@ -349,18 +350,19 @@ class TestFilesystemManipulator(unittest.TestCase):
def test_add_boot_BIOS_full(self, version):
manipulator = make_manipulator(Bootloader.BIOS, version)
disk = make_disk(manipulator.model, preserve=True)
part = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk))
part = make_partition(manipulator.model, disk, size=gaps.largest_gap_size(disk))
with self.assertPartitionOperations(
disk,
Create(
offset=disk.alignment_data().min_start_offset,
size=sizes.BIOS_GRUB_SIZE_BYTES),
size=sizes.BIOS_GRUB_SIZE_BYTES,
),
MoveResize(
part=part,
offset=sizes.BIOS_GRUB_SIZE_BYTES,
size=-sizes.BIOS_GRUB_SIZE_BYTES),
size=-sizes.BIOS_GRUB_SIZE_BYTES,
),
):
manipulator.add_boot_disk(disk)
@ -371,14 +373,15 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(Bootloader.BIOS, version)
disk = make_disk(manipulator.model, preserve=True)
part = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk)//2)
manipulator.model, disk, size=gaps.largest_gap_size(disk) // 2
)
with self.assertPartitionOperations(
disk,
Create(
offset=disk.alignment_data().min_start_offset,
size=sizes.BIOS_GRUB_SIZE_BYTES),
Move(
part=part, offset=sizes.BIOS_GRUB_SIZE_BYTES),
size=sizes.BIOS_GRUB_SIZE_BYTES,
),
Move(part=part, offset=sizes.BIOS_GRUB_SIZE_BYTES),
):
manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk)
@ -390,20 +393,23 @@ class TestFilesystemManipulator(unittest.TestCase):
# divided by 4 is an whole number of megabytes.
disk = make_disk(manipulator.model, preserve=True, size=2002 * MiB)
part_smaller = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk)//4)
manipulator.model, disk, size=gaps.largest_gap_size(disk) // 4
)
part_larger = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk))
manipulator.model, disk, size=gaps.largest_gap_size(disk)
)
with self.assertPartitionOperations(
disk,
Create(
offset=disk.alignment_data().min_start_offset,
size=sizes.BIOS_GRUB_SIZE_BYTES),
Move(
part=part_smaller, offset=sizes.BIOS_GRUB_SIZE_BYTES),
size=sizes.BIOS_GRUB_SIZE_BYTES,
),
Move(part=part_smaller, offset=sizes.BIOS_GRUB_SIZE_BYTES),
MoveResize(
part=part_larger,
offset=sizes.BIOS_GRUB_SIZE_BYTES,
size=-sizes.BIOS_GRUB_SIZE_BYTES),
size=-sizes.BIOS_GRUB_SIZE_BYTES,
),
):
manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk)
@ -413,8 +419,7 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(Bootloader.BIOS, version)
disk = make_disk(manipulator.model, preserve=True)
full_size = gaps.largest_gap_size(disk)
make_partition(
manipulator.model, disk, size=full_size, preserve=True)
make_partition(manipulator.model, disk, size=full_size, preserve=True)
self.assertFalse(boot.can_be_boot_device(disk))
def test_no_add_boot_BIOS_preserved_v1(self):
@ -422,8 +427,8 @@ class TestFilesystemManipulator(unittest.TestCase):
disk = make_disk(manipulator.model, preserve=True)
half_size = gaps.largest_gap_size(disk) // 2
make_partition(
manipulator.model, disk, size=half_size, offset=half_size,
preserve=True)
manipulator.model, disk, size=half_size, offset=half_size, preserve=True
)
self.assertFalse(boot.can_be_boot_device(disk))
def test_add_boot_BIOS_preserved_v2(self):
@ -431,13 +436,14 @@ class TestFilesystemManipulator(unittest.TestCase):
disk = make_disk(manipulator.model, preserve=True)
half_size = gaps.largest_gap_size(disk) // 2
part = make_partition(
manipulator.model, disk, size=half_size, offset=half_size,
preserve=True)
manipulator.model, disk, size=half_size, offset=half_size, preserve=True
)
with self.assertPartitionOperations(
disk,
Create(
offset=disk.alignment_data().min_start_offset,
size=sizes.BIOS_GRUB_SIZE_BYTES),
size=sizes.BIOS_GRUB_SIZE_BYTES,
),
Unchanged(part=part),
):
manipulator.add_boot_disk(disk)
@ -449,19 +455,21 @@ class TestFilesystemManipulator(unittest.TestCase):
# divided by 4 is an whole number of megabytes.
disk = make_disk(manipulator.model, preserve=True, size=2002 * MiB)
avail = gaps.largest_gap_size(disk)
p1_new = make_partition(
manipulator.model, disk, size=avail//4)
p1_new = make_partition(manipulator.model, disk, size=avail // 4)
p2_preserved = make_partition(
manipulator.model, disk, size=3*avail//4, preserve=True)
manipulator.model, disk, size=3 * avail // 4, preserve=True
)
with self.assertPartitionOperations(
disk,
Create(
offset=disk.alignment_data().min_start_offset,
size=sizes.BIOS_GRUB_SIZE_BYTES),
size=sizes.BIOS_GRUB_SIZE_BYTES,
),
MoveResize(
part=p1_new,
offset=sizes.BIOS_GRUB_SIZE_BYTES,
size=-sizes.BIOS_GRUB_SIZE_BYTES),
size=-sizes.BIOS_GRUB_SIZE_BYTES,
),
Unchanged(part=p2_preserved),
):
manipulator.add_boot_disk(disk)
@ -472,9 +480,7 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(Bootloader.BIOS, version)
disk = make_disk(manipulator.model, preserve=True)
half_size = gaps.largest_gap_size(disk) // 2
make_partition(
manipulator.model, disk, size=half_size,
preserve=True)
make_partition(manipulator.model, disk, size=half_size, preserve=True)
self.assertFalse(boot.can_be_boot_device(disk))
@parameterized.expand([(1,), (2,)])
@ -482,16 +488,16 @@ class TestFilesystemManipulator(unittest.TestCase):
# Showed up in VM testing - the ISO created cloud-localds shows up as a
# potential install target, and weird things happen.
manipulator = make_manipulator(Bootloader.BIOS, version)
disk = make_disk(manipulator.model, ptable='gpt', size=1 << 20)
disk = make_disk(manipulator.model, ptable="gpt", size=1 << 20)
self.assertFalse(boot.can_be_boot_device(disk))
@parameterized.expand([(1,), (2,)])
def test_add_boot_BIOS_msdos(self, version):
manipulator = make_manipulator(Bootloader.BIOS, version)
disk = make_disk(manipulator.model, preserve=True, ptable='msdos')
disk = make_disk(manipulator.model, preserve=True, ptable="msdos")
part = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk),
preserve=True)
manipulator.model, disk, size=gaps.largest_gap_size(disk), preserve=True
)
with self.assertPartitionOperations(
disk,
Unchanged(part=part),
@ -512,7 +518,8 @@ class TestFilesystemManipulator(unittest.TestCase):
(Bootloader.UEFI, 1),
(Bootloader.PREP, 1),
(Bootloader.UEFI, 2),
(Bootloader.PREP, 2)]
(Bootloader.PREP, 2),
]
@parameterized.expand(ADD_BOOT_PARAMS)
def test_add_boot_empty(self, bl, version):
@ -522,7 +529,8 @@ class TestFilesystemManipulator(unittest.TestCase):
disk,
Create(
offset=disk.alignment_data().min_start_offset,
size=self.boot_size_for_disk(disk)),
size=self.boot_size_for_disk(disk),
),
):
manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk)
@ -531,18 +539,12 @@ class TestFilesystemManipulator(unittest.TestCase):
def test_add_boot_full(self, bl, version):
manipulator = make_manipulator(bl, version)
disk = make_disk(manipulator.model, preserve=True)
part = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk))
part = make_partition(manipulator.model, disk, size=gaps.largest_gap_size(disk))
size = self.boot_size_for_disk(disk)
with self.assertPartitionOperations(
disk,
Create(
offset=disk.alignment_data().min_start_offset,
size=size),
MoveResize(
part=part,
offset=size,
size=-size),
Create(offset=disk.alignment_data().min_start_offset, size=size),
MoveResize(part=part, offset=size, size=-size),
):
manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk)
@ -552,14 +554,13 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(bl, version)
disk = make_disk(manipulator.model, preserve=True)
part = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk)//2)
manipulator.model, disk, size=gaps.largest_gap_size(disk) // 2
)
size = self.boot_size_for_disk(disk)
with self.assertPartitionOperations(
disk,
Unchanged(part=part),
Create(
offset=part.offset + part.size,
size=size),
Create(offset=part.offset + part.size, size=size),
):
manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk)
@ -571,20 +572,17 @@ class TestFilesystemManipulator(unittest.TestCase):
# divided by 4 is an whole number of megabytes.
disk = make_disk(manipulator.model, preserve=True, size=2002 * MiB)
part_smaller = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk)//4)
manipulator.model, disk, size=gaps.largest_gap_size(disk) // 4
)
part_larger = make_partition(
manipulator.model, disk, size=gaps.largest_gap_size(disk))
manipulator.model, disk, size=gaps.largest_gap_size(disk)
)
size = self.boot_size_for_disk(disk)
with self.assertPartitionOperations(
disk,
Unchanged(part_smaller),
Create(
offset=part_smaller.offset + part_smaller.size,
size=size),
MoveResize(
part=part_larger,
offset=size,
size=-size),
Create(offset=part_smaller.offset + part_smaller.size, size=size),
MoveResize(part=part_larger, offset=size, size=-size),
):
manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk)
@ -594,8 +592,7 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(bl, version)
disk = make_disk(manipulator.model, preserve=True)
full_size = gaps.largest_gap_size(disk)
make_partition(
manipulator.model, disk, size=full_size, preserve=True)
make_partition(manipulator.model, disk, size=full_size, preserve=True)
self.assertFalse(boot.can_be_boot_device(disk))
@parameterized.expand(ADD_BOOT_PARAMS)
@ -603,9 +600,7 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(bl, version)
disk = make_disk(manipulator.model, preserve=True)
half_size = gaps.largest_gap_size(disk) // 2
part = make_partition(
manipulator.model, disk, size=half_size,
preserve=True)
part = make_partition(manipulator.model, disk, size=half_size, preserve=True)
size = self.boot_size_for_disk(disk)
if version == 1:
self.assertFalse(boot.can_be_boot_device(disk))
@ -613,9 +608,7 @@ class TestFilesystemManipulator(unittest.TestCase):
with self.assertPartitionOperations(
disk,
Unchanged(part),
Create(
offset=part.offset + part.size,
size=size),
Create(offset=part.offset + part.size, size=size),
):
manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk)
@ -627,22 +620,15 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(bl, 2)
disk = make_disk(manipulator.model, preserve=True)
half_size = gaps.largest_gap_size(disk) // 2
old_part = make_partition(
manipulator.model, disk, size=half_size)
new_part = make_partition(
manipulator.model, disk, size=half_size)
old_part = make_partition(manipulator.model, disk, size=half_size)
new_part = make_partition(manipulator.model, disk, size=half_size)
old_part.preserve = True
size = self.boot_size_for_disk(disk)
with self.assertPartitionOperations(
disk,
Unchanged(old_part),
Create(
offset=new_part.offset,
size=size),
MoveResize(
part=new_part,
offset=size,
size=-size),
Create(offset=new_part.offset, size=size),
MoveResize(part=new_part, offset=size, size=-size),
):
manipulator.add_boot_disk(disk)
self.assertIsBootDisk(manipulator, disk)
@ -652,12 +638,12 @@ class TestFilesystemManipulator(unittest.TestCase):
manipulator = make_manipulator(bl, version)
disk = make_disk(manipulator.model, preserve=True, size=2002 * MiB)
if bl == Bootloader.UEFI:
flag = 'boot'
flag = "boot"
else:
flag = 'prep'
flag = "prep"
boot_part = make_partition(
manipulator.model, disk, size=self.boot_size_for_disk(disk),
flag=flag)
manipulator.model, disk, size=self.boot_size_for_disk(disk), flag=flag
)
rest_part = make_partition(manipulator.model, disk)
boot_part.preserve = rest_part.preserve = True
with self.assertPartitionOperations(
@ -671,39 +657,39 @@ class TestFilesystemManipulator(unittest.TestCase):
def test_logical_no_esp_in_gap(self):
manipulator = make_manipulator(Bootloader.UEFI, storage_version=2)
m = manipulator.model
d1 = make_disk(m, preserve=True, ptable='msdos')
d1 = make_disk(m, preserve=True, ptable="msdos")
size = gaps.largest_gap_size(d1)
make_partition(m, d1, flag='extended', preserve=True, size=size)
make_partition(m, d1, flag="extended", preserve=True, size=size)
size = gaps.largest_gap_size(d1) // 2
make_partition(m, d1, flag='logical', preserve=True, size=size)
make_partition(m, d1, flag="logical", preserve=True, size=size)
self.assertTrue(gaps.largest_gap_size(d1) > sizes.uefi_scale.maximum)
self.assertFalse(boot.can_be_boot_device(d1))
def test_logical_no_esp_by_resize(self):
manipulator = make_manipulator(Bootloader.UEFI, storage_version=2)
m = manipulator.model
d1 = make_disk(m, preserve=True, ptable='msdos')
d1 = make_disk(m, preserve=True, ptable="msdos")
size = gaps.largest_gap_size(d1)
make_partition(m, d1, flag='extended', preserve=True, size=size)
make_partition(m, d1, flag="extended", preserve=True, size=size)
size = gaps.largest_gap_size(d1)
p5 = make_partition(m, d1, flag='logical', preserve=True, size=size)
p5 = make_partition(m, d1, flag="logical", preserve=True, size=size)
self.assertFalse(boot.can_be_boot_device(d1, resize_partition=p5))
def test_logical_no_esp_in_new_part(self):
manipulator = make_manipulator(Bootloader.UEFI, storage_version=2)
m = manipulator.model
d1 = make_disk(m, preserve=True, ptable='msdos')
d1 = make_disk(m, preserve=True, ptable="msdos")
size = gaps.largest_gap_size(d1)
make_partition(m, d1, flag='extended', size=size)
make_partition(m, d1, flag="extended", size=size)
size = gaps.largest_gap_size(d1)
make_partition(m, d1, flag='logical', size=size)
make_partition(m, d1, flag="logical", size=size)
self.assertFalse(boot.can_be_boot_device(d1))
@parameterized.expand([[Bootloader.UEFI], [Bootloader.PREP]])
def test_too_many_primaries(self, bl):
manipulator = make_manipulator(bl, storage_version=2)
m = manipulator.model
d1 = make_disk(m, preserve=True, ptable='msdos', size=100 << 30)
d1 = make_disk(m, preserve=True, ptable="msdos", size=100 << 30)
tenth = 10 << 30
for i in range(4):
make_partition(m, d1, preserve=True, size=tenth)
@ -715,12 +701,12 @@ class TestFilesystemManipulator(unittest.TestCase):
def test_logical_when_in_extended(self):
manipulator = make_manipulator(storage_version=2)
m = manipulator.model
d1 = make_disk(m, preserve=True, ptable='msdos')
d1 = make_disk(m, preserve=True, ptable="msdos")
size = gaps.largest_gap_size(d1)
make_partition(m, d1, flag='extended', size=size)
make_partition(m, d1, flag="extended", size=size)
gap = gaps.largest_gap(d1)
p = manipulator.create_partition(d1, gap, spec={}, flag='primary')
self.assertEqual(p.flag, 'logical')
p = manipulator.create_partition(d1, gap, spec={}, flag="primary")
self.assertEqual(p.flag, "logical")
class TestReformat(unittest.TestCase):
@ -733,19 +719,19 @@ class TestReformat(unittest.TestCase):
self.assertEqual(None, disk.ptable)
def test_reformat_keep_current(self):
disk = make_disk(self.manipulator.model, ptable='msdos')
disk = make_disk(self.manipulator.model, ptable="msdos")
self.manipulator.reformat(disk)
self.assertEqual('msdos', disk.ptable)
self.assertEqual("msdos", disk.ptable)
def test_reformat_to_gpt(self):
disk = make_disk(self.manipulator.model, ptable=None)
self.manipulator.reformat(disk, 'gpt')
self.assertEqual('gpt', disk.ptable)
self.manipulator.reformat(disk, "gpt")
self.assertEqual("gpt", disk.ptable)
def test_reformat_to_msdos(self):
disk = make_disk(self.manipulator.model, ptable=None)
self.manipulator.reformat(disk, 'msdos')
self.assertEqual('msdos', disk.ptable)
self.manipulator.reformat(disk, "msdos")
self.assertEqual("msdos", disk.ptable)
class TestCanResize(unittest.TestCase):
@ -760,11 +746,11 @@ class TestCanResize(unittest.TestCase):
def test_resize_ext4(self):
disk = make_disk(self.manipulator.model, ptable=None)
part = make_partition(self.manipulator.model, disk, preserve=True)
make_filesystem(self.manipulator.model, partition=part, fstype='ext4')
make_filesystem(self.manipulator.model, partition=part, fstype="ext4")
self.assertTrue(self.manipulator.can_resize_partition(part))
def test_resize_invalid(self):
disk = make_disk(self.manipulator.model, ptable=None)
part = make_partition(self.manipulator.model, disk, preserve=True)
make_filesystem(self.manipulator.model, partition=part, fstype='asdf')
make_filesystem(self.manipulator.model, partition=part, fstype="asdf")
self.assertFalse(self.manipulator.can_resize_partition(part))

View File

@ -18,19 +18,17 @@ import unittest
from unittest import mock
from subiquity.common.filesystem.sizes import (
PartitionScaleFactors,
bootfs_scale,
calculate_guided_resize,
calculate_suggested_install_min,
get_efi_size,
get_bootfs_size,
PartitionScaleFactors,
get_efi_size,
scale_partitions,
uefi_scale,
)
from subiquity.common.types import GuidedResizeValues
from subiquity.models.filesystem import (
MiB,
)
from subiquity.models.filesystem import MiB
class TestPartitionSizeScaling(unittest.TestCase):
@ -87,46 +85,54 @@ class TestPartitionSizeScaling(unittest.TestCase):
class TestCalculateGuidedResize(unittest.TestCase):
def test_ignore_nonresizable(self):
actual = calculate_guided_resize(
part_min=-1, part_size=100 << 30, install_min=10 << 30)
part_min=-1, part_size=100 << 30, install_min=10 << 30
)
self.assertIsNone(actual)
def test_too_small(self):
actual = calculate_guided_resize(
part_min=95 << 30, part_size=100 << 30, install_min=10 << 30)
part_min=95 << 30, part_size=100 << 30, install_min=10 << 30
)
self.assertIsNone(actual)
def test_even_split(self):
# 8 GiB * 1.25 == 10 GiB
size = 10 << 30
actual = calculate_guided_resize(
part_min=8 << 30, part_size=100 << 30, install_min=size)
part_min=8 << 30, part_size=100 << 30, install_min=size
)
expected = GuidedResizeValues(
install_max=(100 << 30) - size,
minimum=size, recommended=50 << 30, maximum=(100 << 30) - size)
minimum=size,
recommended=50 << 30,
maximum=(100 << 30) - size,
)
self.assertEqual(expected, actual)
def test_weighted_split(self):
actual = calculate_guided_resize(
part_min=40 << 30, part_size=240 << 30, install_min=10 << 30)
part_min=40 << 30, part_size=240 << 30, install_min=10 << 30
)
expected = GuidedResizeValues(
install_max=190 << 30,
minimum=50 << 30, recommended=200 << 30, maximum=230 << 30)
minimum=50 << 30,
recommended=200 << 30,
maximum=230 << 30,
)
self.assertEqual(expected, actual)
def test_unaligned_size(self):
def assertAligned(val):
if val % MiB != 0:
self.fail(
"val of %s not aligned (%s) with delta %s" % (
val, val % MiB, delta))
"val of %s not aligned (%s) with delta %s" % (val, val % MiB, delta)
)
for i in range(1000):
delta = random.randrange(MiB)
actual = calculate_guided_resize(
part_min=40 << 30,
part_size=(240 << 30) + delta,
install_min=10 << 30)
part_min=40 << 30, part_size=(240 << 30) + delta, install_min=10 << 30
)
assertAligned(actual.install_max)
assertAligned(actual.minimum)
assertAligned(actual.recommended)
@ -134,8 +140,8 @@ class TestCalculateGuidedResize(unittest.TestCase):
class TestCalculateInstallMin(unittest.TestCase):
@mock.patch('subiquity.common.filesystem.sizes.swap.suggested_swapsize')
@mock.patch('subiquity.common.filesystem.sizes.bootfs_scale')
@mock.patch("subiquity.common.filesystem.sizes.swap.suggested_swapsize")
@mock.patch("subiquity.common.filesystem.sizes.bootfs_scale")
def test_small_setups(self, bootfs_scale, swapsize):
swapsize.return_value = 1 << 30
bootfs_scale.maximum = 1 << 30
@ -143,24 +149,24 @@ class TestCalculateInstallMin(unittest.TestCase):
# with a small source, we hit the default 2GiB padding
self.assertEqual(5 << 30, calculate_suggested_install_min(source_min))
@mock.patch('subiquity.common.filesystem.sizes.swap.suggested_swapsize')
@mock.patch('subiquity.common.filesystem.sizes.bootfs_scale')
@mock.patch("subiquity.common.filesystem.sizes.swap.suggested_swapsize")
@mock.patch("subiquity.common.filesystem.sizes.bootfs_scale")
def test_small_setups_big_swap(self, bootfs_scale, swapsize):
swapsize.return_value = 10 << 30
bootfs_scale.maximum = 1 << 30
source_min = 1 << 30
self.assertEqual(14 << 30, calculate_suggested_install_min(source_min))
@mock.patch('subiquity.common.filesystem.sizes.swap.suggested_swapsize')
@mock.patch('subiquity.common.filesystem.sizes.bootfs_scale')
@mock.patch("subiquity.common.filesystem.sizes.swap.suggested_swapsize")
@mock.patch("subiquity.common.filesystem.sizes.bootfs_scale")
def test_small_setups_big_boot(self, bootfs_scale, swapsize):
swapsize.return_value = 1 << 30
bootfs_scale.maximum = 10 << 30
source_min = 1 << 30
self.assertEqual(14 << 30, calculate_suggested_install_min(source_min))
@mock.patch('subiquity.common.filesystem.sizes.swap.suggested_swapsize')
@mock.patch('subiquity.common.filesystem.sizes.bootfs_scale')
@mock.patch("subiquity.common.filesystem.sizes.swap.suggested_swapsize")
@mock.patch("subiquity.common.filesystem.sizes.bootfs_scale")
def test_big_source(self, bootfs_scale, swapsize):
swapsize.return_value = 1 << 30
bootfs_scale.maximum = 2 << 30

View File

@ -20,7 +20,7 @@ import logging
import os
from typing import List
log = logging.getLogger('subiquity.common.resources')
log = logging.getLogger("subiquity.common.resources")
def resource_path(relative_path):
@ -31,17 +31,17 @@ def get_users_and_groups(chroot_prefix=[]) -> List:
# prevent import when calling just resource_path
from subiquitycore.utils import run_command
users_and_groups_path = resource_path('users-and-groups')
groups = ['admin']
users_and_groups_path = resource_path("users-and-groups")
groups = ["admin"]
if os.path.exists(users_and_groups_path):
with open(users_and_groups_path) as f:
groups = f.read().split()
groups.append('sudo')
groups.append("sudo")
command = chroot_prefix + ['getent', 'group']
command = chroot_prefix + ["getent", "group"]
cp = run_command(command, check=True)
target_groups = set()
for line in cp.stdout.splitlines():
target_groups.add(line.split(':')[0])
target_groups.add(line.split(":")[0])
return list(target_groups.intersection(groups))

View File

@ -15,19 +15,19 @@
import datetime
import enum
import json
import inspect
import json
import typing
import attr
def named_field(name, default=attr.NOTHING):
return attr.ib(metadata={'name': name}, default=default)
return attr.ib(metadata={"name": name}, default=default)
def _field_name(field):
return field.metadata.get('name', field.name)
return field.metadata.get("name", field.name)
class SerializationError(Exception):
@ -39,7 +39,7 @@ class SerializationError(Exception):
def __str__(self):
p = self.path
if not p:
p = 'top-level'
p = "top-level"
return f"processing {self.obj}: at {p}, {self.message}"
@ -53,13 +53,12 @@ class SerializationContext:
@classmethod
def new(cls, obj, *, serializing):
return SerializationContext(obj, obj, '', {}, serializing)
return SerializationContext(obj, obj, "", {}, serializing)
def child(self, path, cur, metadata=None):
if metadata is None:
metadata = self.metadata
return attr.evolve(
self, path=self.path + path, cur=cur, metadata=metadata)
return attr.evolve(self, path=self.path + path, cur=cur, metadata=metadata)
def error(self, message):
raise SerializationError(self.obj, self.path, message)
@ -74,9 +73,9 @@ class SerializationContext:
class Serializer:
def __init__(self, *, compact=False, ignore_unknown_fields=False,
serialize_enums_by="name"):
def __init__(
self, *, compact=False, ignore_unknown_fields=False, serialize_enums_by="name"
):
self.compact = compact
self.ignore_unknown_fields = ignore_unknown_fields
assert serialize_enums_by in ("value", "name")
@ -119,14 +118,14 @@ class Serializer:
if self.compact:
r.insert(0, a.__name__)
else:
r['$type'] = a.__name__
r["$type"] = a.__name__
return r
context.error(f"type of {context.cur} not found in {args}")
else:
if self.compact:
n = context.cur.pop(0)
else:
n = context.cur.pop('$type')
n = context.cur.pop("$type")
for a in args:
if a.__name__ == n:
return meth(a, context)
@ -135,8 +134,7 @@ class Serializer:
def _walk_List(self, meth, args, context):
return [
meth(args[0], context.child(f'[{i}]', v))
for i, v in enumerate(context.cur)
meth(args[0], context.child(f"[{i}]", v)) for i, v in enumerate(context.cur)
]
def _walk_Dict(self, meth, args, context):
@ -145,10 +143,13 @@ class Serializer:
input_items = context.cur
else:
input_items = context.cur.items()
output_items = [[
meth(k_ann, context.child(f'/{k}', k)),
meth(v_ann, context.child(f'[{k}]', v))
] for k, v in input_items]
output_items = [
[
meth(k_ann, context.child(f"/{k}", k)),
meth(v_ann, context.child(f"[{k}]", v)),
]
for k, v in input_items
]
if context.serializing and k_ann is not str:
return output_items
return dict(output_items)
@ -156,12 +157,12 @@ class Serializer:
def _serialize_dict(self, annotation, context):
context.assert_type(annotation)
for k in context.cur:
context.child(f'/{k}', k).assert_type(str)
context.child(f"/{k}", k).assert_type(str)
return context.cur
def _serialize_datetime(self, annotation, context):
context.assert_type(annotation)
fmt = context.metadata.get('time_fmt')
fmt = context.metadata.get("time_fmt")
if fmt is not None:
return context.cur.strftime(fmt)
else:
@ -170,15 +171,19 @@ class Serializer:
def _serialize_attr(self, annotation, context):
serialized = []
for field in attr.fields(annotation):
serialized.append((
serialized.append(
(
_field_name(field),
self._serialize(
field.type,
context.child(
f'.{field.name}',
f".{field.name}",
getattr(context.cur, field.name),
field.metadata)),
))
field.metadata,
),
),
)
)
if self.compact:
return [s[1] for s in serialized]
else:
@ -196,7 +201,7 @@ class Serializer:
return context.cur
if attr.has(annotation):
return self._serialize_attr(annotation, context)
origin = getattr(annotation, '__origin__', None)
origin = getattr(annotation, "__origin__", None)
if origin is not None:
args = annotation.__args__
return self.typing_walkers[origin](self._serialize, args, context)
@ -214,7 +219,7 @@ class Serializer:
return self._serialize(annotation, context)
def _deserialize_datetime(self, annotation, context):
fmt = context.metadata.get('time_fmt')
fmt = context.metadata.get("time_fmt")
if fmt is None:
context.error("cannot serialize datetime without format")
return datetime.datetime.strptime(context.cur, fmt)
@ -224,19 +229,19 @@ class Serializer:
context.assert_type(list)
args = []
for field, value in zip(attr.fields(annotation), context.cur):
args.append(self._deserialize(
args.append(
self._deserialize(
field.type,
context.child(f'[{field.name!r}]', value, field.metadata)))
context.child(f"[{field.name!r}]", value, field.metadata),
)
)
return annotation(*args)
else:
context.assert_type(dict)
args = {}
fields = {
_field_name(field): field for field in attr.fields(annotation)
}
fields = {_field_name(field): field for field in attr.fields(annotation)}
for key, value in context.cur.items():
if key not in fields and (
key == '$type' or self.ignore_unknown_fields):
if key not in fields and (key == "$type" or self.ignore_unknown_fields):
# Union types can contain a '$type' field that is not
# actually one of the keys. This happens if a object is
# serialized as part of a Union, sent to an API caller,
@ -245,8 +250,8 @@ class Serializer:
continue
field = fields[key]
args[field.name] = self._deserialize(
field.type,
context.child(f'[{key!r}]', value, field.metadata))
field.type, context.child(f"[{key!r}]", value, field.metadata)
)
return annotation(**args)
def _deserialize_enum(self, annotation, context):
@ -263,10 +268,11 @@ class Serializer:
return context.cur
if attr.has(annotation):
return self._deserialize_attr(annotation, context)
origin = getattr(annotation, '__origin__', None)
origin = getattr(annotation, "__origin__", None)
if origin is not None:
return self.typing_walkers[origin](
self._deserialize, annotation.__args__, context)
self._deserialize, annotation.__args__, context
)
if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
return self._deserialize_enum(annotation, context)
return self.type_deserializers[annotation](annotation, context)

View File

@ -13,7 +13,6 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import attr
import datetime
import enum
import inspect
@ -22,11 +21,9 @@ import string
import typing
import unittest
from subiquity.common.serialize import (
named_field,
Serializer,
SerializationError,
)
import attr
from subiquity.common.serialize import SerializationError, Serializer, named_field
@attr.s(auto_attribs=True)
@ -36,9 +33,7 @@ class Data:
@staticmethod
def make_random():
return Data(
random.choice(string.ascii_letters),
random.randint(0, 1000))
return Data(random.choice(string.ascii_letters), random.randint(0, 1000))
@attr.s(auto_attribs=True)
@ -50,8 +45,8 @@ class Container:
def make_random():
return Container(
data=Data.make_random(),
data_list=[Data.make_random()
for i in range(random.randint(10, 20))])
data_list=[Data.make_random() for i in range(random.randint(10, 20))],
)
@attr.s(auto_attribs=True)
@ -67,7 +62,6 @@ class MyEnum(enum.Enum):
class CommonSerializerTests:
simple_examples = [
(int, 1),
(str, "v"),
@ -77,17 +71,14 @@ class CommonSerializerTests:
]
def assertSerializesTo(self, annotation, value, expected):
self.assertEqual(
self.serializer.serialize(annotation, value), expected)
self.assertEqual(self.serializer.serialize(annotation, value), expected)
def assertDeserializesTo(self, annotation, value, expected):
self.assertEqual(
self.serializer.deserialize(annotation, value), expected)
self.assertEqual(self.serializer.deserialize(annotation, value), expected)
def assertRoundtrips(self, annotation, value):
serialized = self.serializer.to_json(annotation, value)
self.assertEqual(
self.serializer.from_json(annotation, serialized), value)
self.assertEqual(self.serializer.from_json(annotation, serialized), value)
def assertSerialization(self, annotation, value, expected):
self.assertSerializesTo(annotation, value, expected)
@ -114,8 +105,7 @@ class CommonSerializerTests:
self.assertSerialization(typ, val, val)
def test_non_string_key_dict(self):
self.assertRaises(
Exception, self.serializer.serialize, dict, {1: 2})
self.assertRaises(Exception, self.serializer.serialize, dict, {1: 2})
def test_roundtrip_dict(self):
ann = typing.Dict[int, str]
@ -141,7 +131,8 @@ class CommonSerializerTests:
def test_enums_by_value(self):
self.serializer = type(self.serializer)(
compact=self.serializer.compact, serialize_enums_by="value")
compact=self.serializer.compact, serialize_enums_by="value"
)
self.assertSerialization(MyEnum, MyEnum.name, "value")
def test_serialize_any(self):
@ -151,19 +142,19 @@ class CommonSerializerTests:
class TestSerializer(CommonSerializerTests, unittest.TestCase):
serializer = Serializer()
def test_datetime(self):
@attr.s
class C:
d: datetime.datetime = attr.ib(metadata={'time_fmt': '%Y-%m-%d'})
d: datetime.datetime = attr.ib(metadata={"time_fmt": "%Y-%m-%d"})
c = C(datetime.datetime(2022, 1, 1))
self.assertSerialization(C, c, {"d": "2022-01-01"})
def test_serialize_attr(self):
data = Data.make_random()
expected = {'field1': data.field1, 'field2': data.field2}
expected = {"field1": data.field1, "field2": data.field2}
self.assertSerialization(Data, data, expected)
def test_serialize_container(self):
@ -171,9 +162,9 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
data2 = Data.make_random()
container = Container(data1, [data2])
expected = {
'data': {'field1': data1.field1, 'field2': data1.field2},
'data_list': [
{'field1': data2.field1, 'field2': data2.field2},
"data": {"field1": data1.field1, "field2": data1.field2},
"data_list": [
{"field1": data2.field1, "field2": data2.field2},
],
}
self.assertSerialization(Container, container, expected)
@ -181,9 +172,9 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
def test_serialize_union(self):
data = Data.make_random()
expected = {
'$type': 'Data',
'field1': data.field1,
'field2': data.field2,
"$type": "Data",
"field1": data.field1,
"field2": data.field2,
}
self.assertSerialization(typing.Union[Data, Container], data, expected)
@ -193,18 +184,18 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
# API entrypoint, one that isn't taking a Union, it must be cool with
# the excess $type field.
data = {
'$type': 'Data',
'field1': '1',
'field2': 2,
"$type": "Data",
"field1": "1",
"field2": 2,
}
expected = Data(field1='1', field2=2)
expected = Data(field1="1", field2=2)
self.assertDeserializesTo(Data, data, expected)
def test_reject_unknown_fields_by_default(self):
serializer = Serializer()
data = Data.make_random()
serialized = serializer.serialize(Data, data)
serialized['foobar'] = 'baz'
serialized["foobar"] = "baz"
with self.assertRaises(KeyError):
serializer.deserialize(Data, serialized)
@ -212,11 +203,10 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
serializer = Serializer(ignore_unknown_fields=True)
data = Data.make_random()
serialized = serializer.serialize(Data, data)
serialized['foobar'] = 'baz'
serialized["foobar"] = "baz"
self.assertEqual(serializer.deserialize(Data, serialized), data)
def test_override_field_name(self):
@attr.s(auto_attribs=True)
class Object:
x: int
@ -224,10 +214,10 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
z: int = named_field("field-z", 0)
self.assertSerialization(
Object, Object(1, 2), {"x": 1, "field-y": 2, "field-z": 0})
Object, Object(1, 2), {"x": 1, "field-y": 2, "field-z": 0}
)
def test_embedding(self):
@attr.s(auto_attribs=True)
class Base1:
x: str
@ -249,28 +239,28 @@ class TestSerializer(CommonSerializerTests, unittest.TestCase):
self.assertSerialization(
Derived2,
Derived2(b=Derived1(x="a", y=1), c=2),
{"b": {"x": "a", "y": 1}, "c": 2})
{"b": {"x": "a", "y": 1}, "c": 2},
)
def test_error_paths(self):
with self.assertRaises(SerializationError) as catcher:
self.serializer.serialize(str, 1)
self.assertEqual(catcher.exception.path, '')
self.assertEqual(catcher.exception.path, "")
@attr.s(auto_attribs=True)
class Type:
field1: str = named_field('field-1')
field1: str = named_field("field-1")
field2: int
with self.assertRaises(SerializationError) as catcher:
self.serializer.serialize(Type, Data(2, 3))
self.assertEqual(catcher.exception.path, '.field1')
self.assertEqual(catcher.exception.path, ".field1")
with self.assertRaises(SerializationError) as catcher:
self.serializer.deserialize(Type, {'field-1': 1, 'field2': 2})
self.serializer.deserialize(Type, {"field-1": 1, "field2": 2})
self.assertEqual(catcher.exception.path, "['field-1']")
class TestCompactSerializer(CommonSerializerTests, unittest.TestCase):
serializer = Serializer(compact=True)
def test_serialize_attr(self):
@ -290,70 +280,57 @@ class TestCompactSerializer(CommonSerializerTests, unittest.TestCase):
def test_serialize_union(self):
data = Data.make_random()
expected = ['Data', data.field1, data.field2]
expected = ["Data", data.field1, data.field2]
self.assertSerialization(typing.Union[Data, Container], data, expected)
class TestOptionalAndDefault(CommonSerializerTests, unittest.TestCase):
serializer = Serializer()
def test_happy(self):
data = {
'int': 11,
'optional_int': 12,
'int_default': 3,
'optional_int_default': 4
"int": 11,
"optional_int": 12,
"int_default": 3,
"optional_int_default": 4,
}
expected = OptionalAndDefault(11, 12)
self.assertDeserializesTo(OptionalAndDefault, data, expected)
def test_skip_int(self):
data = {
'optional_int': 12,
'int_default': 13,
'optional_int_default': 14
}
data = {"optional_int": 12, "int_default": 13, "optional_int_default": 14}
with self.assertRaises(TypeError):
self.serializer.deserialize(OptionalAndDefault, data)
def test_skip_optional_int(self):
data = {
'int': 11,
'int_default': 13,
'optional_int_default': 14
}
data = {"int": 11, "int_default": 13, "optional_int_default": 14}
with self.assertRaises(TypeError):
self.serializer.deserialize(OptionalAndDefault, data)
def test_optional_int_none(self):
data = {
'int': 11,
'optional_int': None,
'int_default': 13,
'optional_int_default': 14
"int": 11,
"optional_int": None,
"int_default": 13,
"optional_int_default": 14,
}
expected = OptionalAndDefault(11, None, 13, 14)
self.assertDeserializesTo(OptionalAndDefault, data, expected)
def test_skip_int_default(self):
data = {
'int': 11,
'optional_int': 12,
'optional_int_default': 14
}
data = {"int": 11, "optional_int": 12, "optional_int_default": 14}
expected = OptionalAndDefault(11, 12, 3, 14)
self.assertDeserializesTo(OptionalAndDefault, data, expected)
def test_skip_optional_int_default(self):
data = {
'int': 11,
'optional_int': 12,
'int_default': 13,
"int": 11,
"optional_int": 12,
"int_default": 13,
}
expected = OptionalAndDefault(11, 12, 13, 4)
@ -361,10 +338,10 @@ class TestOptionalAndDefault(CommonSerializerTests, unittest.TestCase):
def test_optional_int_default_none(self):
data = {
'int': 11,
'optional_int': 12,
'int_default': 13,
'optional_int_default': None,
"int": 11,
"optional_int": 12,
"int_default": 13,
"optional_int_default": None,
}
expected = OptionalAndDefault(11, 12, 13, None)
@ -372,11 +349,11 @@ class TestOptionalAndDefault(CommonSerializerTests, unittest.TestCase):
def test_extra(self):
data = {
'int': 11,
'optional_int': 12,
'int_default': 13,
'optional_int_default': 14,
'extra': 15
"int": 11,
"optional_int": 12,
"int_default": 13,
"optional_int_default": 14,
"extra": 15,
}
with self.assertRaises(KeyError):

View File

@ -15,19 +15,16 @@
import unittest
from subiquity.common.types import (
GuidedCapability,
SizingPolicy,
)
from subiquity.common.types import GuidedCapability, SizingPolicy
class TestSizingPolicy(unittest.TestCase):
def test_all(self):
actual = SizingPolicy.from_string('all')
actual = SizingPolicy.from_string("all")
self.assertEqual(SizingPolicy.ALL, actual)
def test_scaled_size(self):
actual = SizingPolicy.from_string('scaled')
actual = SizingPolicy.from_string("scaled")
self.assertEqual(SizingPolicy.SCALED, actual)
def test_default(self):

View File

@ -125,8 +125,8 @@ class RefreshCheckState(enum.Enum):
@attr.s(auto_attribs=True)
class RefreshStatus:
availability: RefreshCheckState
current_snap_version: str = ''
new_snap_version: str = ''
current_snap_version: str = ""
new_snap_version: str = ""
@attr.s(auto_attribs=True)
@ -163,7 +163,7 @@ class KeyboardSetting:
# Ideally, we would internally represent a keyboard setting as a
# toggle + a list of [layout, variant].
layout: str
variant: str = ''
variant: str = ""
toggle: Optional[str] = None
@ -216,7 +216,7 @@ class ZdevInfo:
@classmethod
def from_row(cls, row):
row = dict((k.split('=', 1) for k in shlex.split(row)))
row = dict((k.split("=", 1) for k in shlex.split(row)))
for k, v in row.items():
if v == "yes":
row[k] = True
@ -228,8 +228,8 @@ class ZdevInfo:
@property
def typeclass(self):
if self.type.startswith('zfcp'):
return 'zfcp'
if self.type.startswith("zfcp"):
return "zfcp"
return self.type
@ -352,22 +352,25 @@ class GuidedCapability(enum.Enum):
CORE_BOOT_PREFER_UNENCRYPTED = enum.auto()
def is_lvm(self) -> bool:
return self in [GuidedCapability.LVM,
GuidedCapability.LVM_LUKS]
return self in [GuidedCapability.LVM, GuidedCapability.LVM_LUKS]
def is_core_boot(self) -> bool:
return self in [GuidedCapability.CORE_BOOT_ENCRYPTED,
return self in [
GuidedCapability.CORE_BOOT_ENCRYPTED,
GuidedCapability.CORE_BOOT_UNENCRYPTED,
GuidedCapability.CORE_BOOT_PREFER_ENCRYPTED,
GuidedCapability.CORE_BOOT_PREFER_UNENCRYPTED]
GuidedCapability.CORE_BOOT_PREFER_UNENCRYPTED,
]
def supports_manual_customization(self) -> bool:
# After posting this capability to guided_POST, is it possible
# for the user to customize the layout further?
return self in [GuidedCapability.MANUAL,
return self in [
GuidedCapability.MANUAL,
GuidedCapability.DIRECT,
GuidedCapability.LVM,
GuidedCapability.LVM_LUKS]
GuidedCapability.LVM_LUKS,
]
def is_zfs(self) -> bool:
return self in [GuidedCapability.ZFS]
@ -414,11 +417,11 @@ class SizingPolicy(enum.Enum):
@classmethod
def from_string(cls, value):
if value is None or value == 'scaled':
if value is None or value == "scaled":
return cls.SCALED
if value == 'all':
if value == "all":
return cls.ALL
raise Exception(f'Unknown SizingPolicy value {value}')
raise Exception(f"Unknown SizingPolicy value {value}")
@attr.s(auto_attribs=True)
@ -470,15 +473,16 @@ class GuidedStorageTargetUseGap:
@attr.s(auto_attribs=True)
class GuidedStorageTargetManual:
allowed: List[GuidedCapability] = attr.Factory(
lambda: [GuidedCapability.MANUAL])
allowed: List[GuidedCapability] = attr.Factory(lambda: [GuidedCapability.MANUAL])
disallowed: List[GuidedDisallowedCapability] = attr.Factory(list)
GuidedStorageTarget = Union[GuidedStorageTargetReformat,
GuidedStorageTarget = Union[
GuidedStorageTargetReformat,
GuidedStorageTargetResize,
GuidedStorageTargetUseGap,
GuidedStorageTargetManual]
GuidedStorageTargetManual,
]
@attr.s(auto_attribs=True)
@ -519,10 +523,10 @@ class ReformatDisk:
@attr.s(auto_attribs=True)
class IdentityData:
realname: str = ''
username: str = ''
crypted_password: str = attr.ib(default='', repr=False)
hostname: str = ''
realname: str = ""
username: str = ""
crypted_password: str = attr.ib(default="", repr=False)
hostname: str = ""
class UsernameValidation(enum.Enum):
@ -543,6 +547,7 @@ class SSHData:
@attr.s(auto_attribs=True)
class SSHIdentity:
"""Represents a SSH identity (public key + fingerprint)."""
key_type: str
key: str
key_comment: str
@ -579,19 +584,20 @@ class ChannelSnapInfo:
version: str
size: int
released_at: datetime.datetime = attr.ib(
metadata={'time_fmt': '%Y-%m-%dT%H:%M:%S.%fZ'})
metadata={"time_fmt": "%Y-%m-%dT%H:%M:%S.%fZ"}
)
@attr.s(auto_attribs=True, eq=False)
class SnapInfo:
name: str
summary: str = ''
publisher: str = ''
summary: str = ""
publisher: str = ""
verified: bool = False
starred: bool = False
description: str = ''
confinement: str = ''
license: str = ''
description: str = ""
confinement: str = ""
license: str = ""
channels: List[ChannelSnapInfo] = attr.Factory(list)
@ -606,6 +612,7 @@ class DriversResponse:
:local_only: tells if we are looking for drivers only from the ISO.
:search_drivers: enables or disables drivers listing.
"""
install: bool
drivers: Optional[List[str]]
local_only: bool
@ -655,6 +662,7 @@ class UbuntuProInfo:
@attr.s(auto_attribs=True)
class UbuntuProResponse:
"""Response to GET request to /ubuntu_pro"""
token: str = attr.ib(repr=False)
has_network: bool
@ -669,6 +677,7 @@ class UbuntuProCheckTokenStatus(enum.Enum):
@attr.s(auto_attribs=True)
class UPCSInitiateResponse:
"""Response to Ubuntu Pro contract selection initiate request."""
user_code: str
validity_seconds: int
@ -681,6 +690,7 @@ class UPCSWaitStatus(enum.Enum):
@attr.s(auto_attribs=True)
class UPCSWaitResponse:
"""Response to Ubuntu Pro contract selection wait request."""
status: UPCSWaitStatus
contract_token: Optional[str]
@ -715,8 +725,8 @@ class ShutdownMode(enum.Enum):
@attr.s(auto_attribs=True)
class WSLConfigurationBase:
automount_root: str = attr.ib(default='/mnt/')
automount_options: str = ''
automount_root: str = attr.ib(default="/mnt/")
automount_options: str = ""
network_generatehosts: bool = attr.ib(default=True)
network_generateresolvconf: bool = attr.ib(default=True)
@ -750,7 +760,7 @@ class TaskStatus(enum.Enum):
@attr.s(auto_attribs=True)
class TaskProgress:
label: str = ''
label: str = ""
done: int = 0
total: int = 0
@ -777,10 +787,10 @@ class Change:
class CasperMd5Results(enum.Enum):
UNKNOWN = 'unknown'
FAIL = 'fail'
PASS = 'pass'
SKIP = 'skip'
UNKNOWN = "unknown"
FAIL = "fail"
PASS = "pass"
SKIP = "skip"
class MirrorCheckStatus(enum.Enum):
@ -817,9 +827,9 @@ class MirrorGet:
class MirrorSelectionFallback(enum.Enum):
ABORT = 'abort'
CONTINUE_ANYWAY = 'continue-anyway'
OFFLINE_INSTALL = 'offline-install'
ABORT = "abort"
CONTINUE_ANYWAY = "continue-anyway"
OFFLINE_INSTALL = "offline-install"
@attr.s(auto_attribs=True)
@ -830,32 +840,32 @@ class AdConnectionInfo:
class AdAdminNameValidation(enum.Enum):
OK = 'OK'
EMPTY = 'Empty'
INVALID_CHARS = 'Contains invalid characters'
OK = "OK"
EMPTY = "Empty"
INVALID_CHARS = "Contains invalid characters"
class AdDomainNameValidation(enum.Enum):
OK = 'OK'
EMPTY = 'Empty'
TOO_LONG = 'Too long'
INVALID_CHARS = 'Contains invalid characters'
START_DOT = 'Starts with a dot'
END_DOT = 'Ends with a dot'
START_HYPHEN = 'Starts with a hyphen'
END_HYPHEN = 'Ends with a hyphen'
MULTIPLE_DOTS = 'Contains multiple dots'
REALM_NOT_FOUND = 'Could not find the domain controller'
OK = "OK"
EMPTY = "Empty"
TOO_LONG = "Too long"
INVALID_CHARS = "Contains invalid characters"
START_DOT = "Starts with a dot"
END_DOT = "Ends with a dot"
START_HYPHEN = "Starts with a hyphen"
END_HYPHEN = "Ends with a hyphen"
MULTIPLE_DOTS = "Contains multiple dots"
REALM_NOT_FOUND = "Could not find the domain controller"
class AdPasswordValidation(enum.Enum):
OK = 'OK'
EMPTY = 'Empty'
OK = "OK"
EMPTY = "Empty"
class AdJoinResult(enum.Enum):
OK = 'OK'
JOIN_ERROR = 'Failed to join'
EMPTY_HOSTNAME = 'Target hostname cannot be empty'
PAM_ERROR = 'Failed to update pam-auth'
OK = "OK"
JOIN_ERROR = "Failed to join"
EMPTY_HOSTNAME = "Target hostname cannot be empty"
PAM_ERROR = "Failed to update pam-auth"
UNKNOWN = "Didn't attempt to join yet"

View File

@ -23,7 +23,7 @@ def journald_listen(identifiers, callback, seek=False):
reader = journal.Reader()
args = []
for identifier in identifiers:
if '=' in identifier:
if "=" in identifier:
args.append(identifier)
else:
args.append("SYSLOG_IDENTIFIER={}".format(identifier))
@ -37,6 +37,7 @@ def journald_listen(identifiers, callback, seek=False):
return
for event in reader:
callback(event)
loop = asyncio.get_running_loop()
loop.add_reader(reader.fileno(), watch)
return reader.fileno()

View File

@ -18,11 +18,10 @@ import logging
from subiquitycore.async_helpers import run_in_thread
log = logging.getLogger('subiquity.lockfile')
log = logging.getLogger("subiquity.lockfile")
class _LockContext:
def __init__(self, lockfile, flags):
self.lockfile = lockfile
self.flags = flags
@ -49,10 +48,9 @@ class _LockContext:
class Lockfile:
def __init__(self, path):
self.path = path
self.fp = open(path, 'a+')
self.fp = open(path, "a+")
def read_content(self):
self.fp.seek(0)

View File

@ -18,11 +18,12 @@ from typing import Optional
from subiquity.common.types import AdConnectionInfo
log = logging.getLogger('subiquity.models.ad')
log = logging.getLogger("subiquity.models.ad")
class AdModel:
"""Models the Active Directory feature"""
def __init__(self) -> None:
self.do_join = False
self.conn_info: Optional[AdConnectionInfo] = None

View File

@ -15,7 +15,7 @@
import logging
log = logging.getLogger('subiquity.models.codecs')
log = logging.getLogger("subiquity.models.codecs")
class CodecsModel:

View File

@ -15,7 +15,7 @@
import logging
log = logging.getLogger('subiquity.models.drivers')
log = logging.getLogger("subiquity.models.drivers")
class DriversModel:

File diff suppressed because it is too large Load Diff

View File

@ -17,8 +17,7 @@ import logging
import attr
log = logging.getLogger('subiquity.models.identity')
log = logging.getLogger("subiquity.models.identity")
@attr.s
@ -29,8 +28,7 @@ class User(object):
class IdentityModel(object):
""" Model representing user identity
"""
"""Model representing user identity"""
def __init__(self):
self._user = None
@ -39,11 +37,11 @@ class IdentityModel(object):
def add_user(self, identity_data):
self._hostname = identity_data.hostname
d = {}
d['realname'] = identity_data.realname
d['username'] = identity_data.username
d['password'] = identity_data.crypted_password
if not d['realname']:
d['realname'] = identity_data.username
d["realname"] = identity_data.realname
d["username"] = identity_data.username
d["password"] = identity_data.crypted_password
if not d["realname"]:
d["realname"] = identity_data.username
self._user = User(**d)
@property

View File

@ -17,7 +17,6 @@ from typing import Optional
class KernelModel:
# The name of the kernel metapackage that we intend to install.
metapkg_name: Optional[str] = None
# During the installation, if we detect that a different kernel version is
@ -42,10 +41,10 @@ class KernelModel:
def render(self):
if self.curthooks_no_install:
return {'kernel': None}
return {"kernel": None}
return {
'kernel': {
'package': self.needed_kernel,
"kernel": {
"package": self.needed_kernel,
},
}

View File

@ -14,17 +14,14 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
import re
import os
import re
import yaml
from subiquity.common.resources import resource_path
from subiquity.common.serialize import Serializer
from subiquity.common.types import (
KeyboardLayout,
KeyboardSetting,
)
from subiquity.common.types import KeyboardLayout, KeyboardSetting
log = logging.getLogger("subiquity.models.keyboard")
@ -46,10 +43,11 @@ BACKSPACE="guess"
class InconsistentMultiLayoutError(ValueError):
"""Exception to raise when a multi layout has a different number of
layouts and variants."""
def __init__(self, layouts: str, variants: str) -> None:
super().__init__(
f'inconsistent multi-layout: layouts="{layouts}"'
f' variants="{variants}"')
f'inconsistent multi-layout: layouts="{layouts}"' f' variants="{variants}"'
)
def from_config_file(config_file):
@ -57,10 +55,10 @@ def from_config_file(config_file):
content = fp.read()
def optval(opt, default):
match = re.search(r'(?m)^\s*%s=(.*)$' % (opt,), content)
match = re.search(r"(?m)^\s*%s=(.*)$" % (opt,), content)
if match:
r = match.group(1).strip('"')
if r != '':
if r != "":
return r
return default
@ -68,24 +66,22 @@ def from_config_file(config_file):
XKBVARIANT = optval("XKBVARIANT", "")
XKBOPTIONS = optval("XKBOPTIONS", "")
toggle = None
for option in XKBOPTIONS.split(','):
if option.startswith('grp:'):
for option in XKBOPTIONS.split(","):
if option.startswith("grp:"):
toggle = option[4:]
return KeyboardSetting(layout=XKBLAYOUT, variant=XKBVARIANT, toggle=toggle)
class KeyboardModel:
def __init__(self, root):
self.config_path = os.path.join(
root, 'etc', 'default', 'keyboard')
self.config_path = os.path.join(root, "etc", "default", "keyboard")
self.layout_for_lang = self.load_layout_suggestions()
if os.path.exists(self.config_path):
self.default_setting = from_config_file(self.config_path)
else:
self.default_setting = self.layout_for_lang['en_US.UTF-8']
self.default_setting = self.layout_for_lang["en_US.UTF-8"]
self.keyboard_list = KeyboardList()
self.keyboard_list.load_language('C')
self.keyboard_list.load_language("C")
self._setting = None
@property
@ -105,41 +101,47 @@ class KeyboardModel:
if len(layout_tokens) != len(variant_tokens):
raise InconsistentMultiLayoutError(
layouts=setting.layout, variants=setting.variant)
layouts=setting.layout, variants=setting.variant
)
for layout, variant in zip(layout_tokens, variant_tokens):
kbd_layout = self.keyboard_list.layout_map.get(layout)
if kbd_layout is None:
raise ValueError(f'Unknown keyboard layout "{layout}"')
if not any(kbd_variant.code == variant
for kbd_variant in kbd_layout.variants):
raise ValueError(f'Unknown keyboard variant "{variant}" '
f'for layout "{layout}"')
if not any(
kbd_variant.code == variant for kbd_variant in kbd_layout.variants
):
raise ValueError(
f'Unknown keyboard variant "{variant}" ' f'for layout "{layout}"'
)
def render_config_file(self):
options = ""
if self.setting.toggle:
options = "grp:" + self.setting.toggle
return etc_default_keyboard_template.format(
layout=self.setting.layout,
variant=self.setting.variant,
options=options)
layout=self.setting.layout, variant=self.setting.variant, options=options
)
def render(self):
return {
'write_files': {
'etc_default_keyboard': {
'path': 'etc/default/keyboard',
'content': self.render_config_file(),
'permissions': 0o644,
"write_files": {
"etc_default_keyboard": {
"path": "etc/default/keyboard",
"content": self.render_config_file(),
"permissions": 0o644,
},
},
'curthooks_commands': {
"curthooks_commands": {
# The below command must be run after updating
# etc/default/keyboard on the target so that the initramfs uses
# the keyboard mapping selected by the user. See LP #1894009
'002-setupcon-save-only': [
'curtin', 'in-target', '--', 'setupcon', '--save-only',
"002-setupcon-save-only": [
"curtin",
"in-target",
"--",
"setupcon",
"--save-only",
],
},
}
@ -154,7 +156,7 @@ class KeyboardModel:
def load_layout_suggestions(self, path=None):
if path is None:
path = resource_path('kbds') + '/keyboard-configuration.yaml'
path = resource_path("kbds") + "/keyboard-configuration.yaml"
with open(path) as fp:
data = yaml.safe_load(fp)
@ -166,25 +168,24 @@ class KeyboardModel:
class KeyboardList:
def __init__(self):
self._kbnames_dir = resource_path('kbds')
self._kbnames_dir = resource_path("kbds")
self.serializer = Serializer(compact=True)
self._clear()
def _file_for_lang(self, code):
return os.path.join(self._kbnames_dir, code + '.jsonl')
return os.path.join(self._kbnames_dir, code + ".jsonl")
def _has_language(self, code):
return os.path.exists(self._file_for_lang(code))
def load_language(self, code):
if '.' in code:
code = code.split('.')[0]
if "." in code:
code = code.split(".")[0]
if not self._has_language(code):
code = code.split('_')[0]
code = code.split("_")[0]
if not self._has_language(code):
code = 'C'
code = "C"
if code == self.current_lang:
return

View File

@ -19,7 +19,7 @@ import subprocess
from subiquitycore.utils import arun_command, split_cmd_output
log = logging.getLogger('subiquity.models.locale')
log = logging.getLogger("subiquity.models.locale")
class LocaleModel:
@ -38,11 +38,7 @@ class LocaleModel:
self.selected_language = code
async def localectl_set_locale(self) -> None:
cmd = [
'localectl',
'set-locale',
locale.normalize(self.selected_language)
]
cmd = ["localectl", "set-locale", locale.normalize(self.selected_language)]
await arun_command(cmd, check=True)
async def try_localectl_set_locale(self) -> None:
@ -60,18 +56,19 @@ class LocaleModel:
if self.locale_support == "none":
return {}
locale = self.selected_language
if '.' not in locale and '_' in locale:
locale += '.UTF-8'
return {'locale': locale}
if "." not in locale and "_" in locale:
locale += ".UTF-8"
return {"locale": locale}
async def target_packages(self):
if self.selected_language is None:
return []
if self.locale_support != "langpack":
return []
lang = self.selected_language.split('.')[0]
lang = self.selected_language.split(".")[0]
if lang == "C":
return []
return await split_cmd_output(
self.chroot_prefix + ['check-language-support', '-l', lang], None)
self.chroot_prefix + ["check-language-support", "-l", lang], None
)

View File

@ -72,33 +72,29 @@ primary entry
"""
import abc
import copy
import contextlib
import copy
import logging
from typing import (
Any, Callable, Dict, Iterator, List,
Optional, Sequence, Set, Union,
)
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Union
from urllib import parse
import attr
from subiquity.common.types import MirrorSelectionFallback
from curtin.commands.apt_config import (
get_arch_mirrorconfig,
get_mirror,
PORTS_ARCHES,
PRIMARY_ARCHES,
get_arch_mirrorconfig,
get_mirror,
)
from curtin.config import merge_config
from subiquity.common.types import MirrorSelectionFallback
try:
from curtin.distro import get_architecture
except ImportError:
from curtin.util import get_architecture
log = logging.getLogger('subiquity.models.mirror')
log = logging.getLogger("subiquity.models.mirror")
DEFAULT_SUPPORTED_ARCHES_URI = "http://archive.ubuntu.com/ubuntu"
DEFAULT_PORTS_ARCHES_URI = "http://ports.ubuntu.com/ubuntu-ports"
@ -107,7 +103,8 @@ LEGACY_DEFAULT_PRIMARY_SECTION = [
{
"arches": PRIMARY_ARCHES,
"uri": DEFAULT_SUPPORTED_ARCHES_URI,
}, {
},
{
"arches": ["default"],
"uri": DEFAULT_PORTS_ARCHES_URI,
},
@ -123,6 +120,7 @@ class BasePrimaryEntry(abc.ABC):
"""Base class to represent an entry from the 'primary' autoinstall
section. A BasePrimaryEntry is expected to have a URI and therefore can be
used as a primary candidate."""
parent: "MirrorModel" = attr.ib(kw_only=True)
def stage(self) -> None:
@ -146,6 +144,7 @@ class PrimaryEntry(BasePrimaryEntry):
"""Represents a single primary mirror candidate; which can be converted
to/from an entry of the 'apt->mirror-selection->primary' autoinstall
section."""
# Having uri set to None is only valid for a country mirror.
uri: Optional[str] = None
# When arches is None, it is assumed that the mirror is compatible with the
@ -187,6 +186,7 @@ class LegacyPrimaryEntry(BasePrimaryEntry):
to/from the whole 'apt->primary' autoinstall section (legacy format).
The format is defined by curtin, so we make use of curtin to access
the elements."""
def __init__(self, config: List[Any], *, parent: "MirrorModel") -> None:
self.config = config
super().__init__(parent=parent)
@ -200,8 +200,8 @@ class LegacyPrimaryEntry(BasePrimaryEntry):
@uri.setter
def uri(self, uri: str) -> None:
config = get_arch_mirrorconfig(
{"primary": self.config},
"primary", self.parent.architecture)
{"primary": self.config}, "primary", self.parent.architecture
)
config["uri"] = uri
def mirror_is_default(self) -> bool:
@ -209,8 +209,7 @@ class LegacyPrimaryEntry(BasePrimaryEntry):
@classmethod
def new_from_default(cls, parent: "MirrorModel") -> "LegacyPrimaryEntry":
return cls(copy.deepcopy(LEGACY_DEFAULT_PRIMARY_SECTION),
parent=parent)
return cls(copy.deepcopy(LEGACY_DEFAULT_PRIMARY_SECTION), parent=parent)
def serialize_for_ai(self) -> List[Any]:
return self.config
@ -224,16 +223,16 @@ class LegacyPrimaryEntry(BasePrimaryEntry):
def countrify_uri(uri: str, cc: str) -> str:
"""Return a URL where the host is prefixed with a country code."""
parsed = parse.urlparse(uri)
new = parsed._replace(netloc=cc + '.' + parsed.netloc)
new = parsed._replace(netloc=cc + "." + parsed.netloc)
return parse.urlunparse(new)
CandidateFilter = Callable[[BasePrimaryEntry], bool]
def filter_candidates(candidates: List[BasePrimaryEntry],
*, filters: Sequence[CandidateFilter]) \
-> Iterator[BasePrimaryEntry]:
def filter_candidates(
candidates: List[BasePrimaryEntry], *, filters: Sequence[CandidateFilter]
) -> Iterator[BasePrimaryEntry]:
candidates_iter = iter(candidates)
for filt in filters:
candidates_iter = filter(filt, candidates_iter)
@ -241,21 +240,20 @@ def filter_candidates(candidates: List[BasePrimaryEntry],
class MirrorModel(object):
def __init__(self):
self.config = copy.deepcopy(DEFAULT)
self.legacy_primary = False
self.disabled_components: Set[str] = set()
self.primary_elected: Optional[BasePrimaryEntry] = None
self.primary_candidates: List[BasePrimaryEntry] = \
self._default_primary_entries()
self.primary_candidates: List[
BasePrimaryEntry
] = self._default_primary_entries()
self.primary_staged: Optional[BasePrimaryEntry] = None
self.architecture = get_architecture()
# Only useful for legacy primary sections
self.default_mirror = \
LegacyPrimaryEntry.new_from_default(parent=self).uri
self.default_mirror = LegacyPrimaryEntry.new_from_default(parent=self).uri
# What to do if automatic mirror-selection fails.
self.fallback = MirrorSelectionFallback.ABORT
@ -263,14 +261,17 @@ class MirrorModel(object):
def _default_primary_entries(self) -> List[PrimaryEntry]:
return [
PrimaryEntry(parent=self, country_mirror=True),
PrimaryEntry(uri=DEFAULT_SUPPORTED_ARCHES_URI,
arches=PRIMARY_ARCHES, parent=self),
PrimaryEntry(uri=DEFAULT_PORTS_ARCHES_URI,
arches=PORTS_ARCHES, parent=self),
PrimaryEntry(
uri=DEFAULT_SUPPORTED_ARCHES_URI, arches=PRIMARY_ARCHES, parent=self
),
PrimaryEntry(
uri=DEFAULT_PORTS_ARCHES_URI, arches=PORTS_ARCHES, parent=self
),
]
def get_default_primary_candidates(
self, legacy: Optional[bool] = None) -> Sequence[BasePrimaryEntry]:
self, legacy: Optional[bool] = None
) -> Sequence[BasePrimaryEntry]:
want_legacy = legacy if legacy is not None else self.legacy_primary
if want_legacy:
return [LegacyPrimaryEntry.new_from_default(parent=self)]
@ -282,15 +283,15 @@ class MirrorModel(object):
self.disabled_components = set(data.pop("disable_components"))
if "primary" in data and "mirror-selection" in data:
raise ValueError("apt->primary and apt->mirror-selection are"
" mutually exclusive.")
raise ValueError(
"apt->primary and apt->mirror-selection are" " mutually exclusive."
)
self.legacy_primary = "primary" in data
primary_candidates = self.get_default_primary_candidates()
if "primary" in data:
# Legacy sections only support a single candidate
primary_candidates = \
[LegacyPrimaryEntry(data.pop("primary"), parent=self)]
primary_candidates = [LegacyPrimaryEntry(data.pop("primary"), parent=self)]
if "mirror-selection" in data:
mirror_selection = data.pop("mirror-selection")
if "primary" in mirror_selection:
@ -314,7 +315,8 @@ class MirrorModel(object):
return config
def _get_apt_config_using_candidate(
self, candidate: BasePrimaryEntry) -> Dict[str, Any]:
self, candidate: BasePrimaryEntry
) -> Dict[str, Any]:
config = self._get_apt_config_common()
config["primary"] = candidate.config
return config
@ -327,8 +329,7 @@ class MirrorModel(object):
assert self.primary_elected is not None
return self._get_apt_config_using_candidate(self.primary_elected)
def get_apt_config(
self, final: bool, has_network: bool) -> Dict[str, Any]:
def get_apt_config(self, final: bool, has_network: bool) -> Dict[str, Any]:
if not final:
return self.get_apt_config_staged()
if has_network:
@ -347,8 +348,9 @@ class MirrorModel(object):
lambda c: c.uri is not None,
lambda c: c.supports_arch(self.architecture),
]
candidate = next(filter_candidates(self.primary_candidates,
filters=filters))
candidate = next(
filter_candidates(self.primary_candidates, filters=filters)
)
return self._get_apt_config_using_candidate(candidate)
# Our last resort is to include no primary section. Curtin will use
# its own internal values.
@ -376,16 +378,14 @@ class MirrorModel(object):
self.disabled_components -= comps
def create_primary_candidate(
self, uri: Optional[str],
country_mirror: bool = False) -> BasePrimaryEntry:
self, uri: Optional[str], country_mirror: bool = False
) -> BasePrimaryEntry:
if self.legacy_primary:
entry = LegacyPrimaryEntry.new_from_default(parent=self)
entry.uri = uri
return entry
return PrimaryEntry(uri=uri, country_mirror=country_mirror,
parent=self)
return PrimaryEntry(uri=uri, country_mirror=country_mirror, parent=self)
def wants_geoip(self) -> bool:
"""Tell whether geoip results would be useful."""

View File

@ -20,11 +20,10 @@ from subiquity import cloudinit
from subiquitycore.models.network import NetworkModel as CoreNetworkModel
from subiquitycore.utils import arun_command
log = logging.getLogger('subiquity.models.network')
log = logging.getLogger("subiquity.models.network")
class NetworkModel(CoreNetworkModel):
def __init__(self):
super().__init__("subiquity")
self.override_config = None
@ -46,59 +45,57 @@ class NetworkModel(CoreNetworkModel):
# If host cloud-init version has no readable combined-cloud-config,
# default to False.
cloud_cfg = cloudinit.get_host_combined_cloud_config()
use_cloudinit_net = cloud_cfg.get('features', {}).get(
'NETPLAN_CONFIG_ROOT_READ_ONLY', False
use_cloudinit_net = cloud_cfg.get("features", {}).get(
"NETPLAN_CONFIG_ROOT_READ_ONLY", False
)
if use_cloudinit_net:
r = {
'write_files': {
'etc_netplan_installer': {
'path': (
'etc/cloud/cloud.cfg.d/90-installer-network.cfg'
),
'content': self.stringify_config(netplan),
'permissions': '0600',
"write_files": {
"etc_netplan_installer": {
"path": ("etc/cloud/cloud.cfg.d/90-installer-network.cfg"),
"content": self.stringify_config(netplan),
"permissions": "0600",
},
}
}
else:
# Separate sensitive wifi config from world-readable config
wifis = netplan['network'].pop('wifis', None)
wifis = netplan["network"].pop("wifis", None)
r = {
'write_files': {
"write_files": {
# Disable cloud-init networking
'no_cloudinit_net': {
'path': (
'etc/cloud/cloud.cfg.d/'
'subiquity-disable-cloudinit-networking.cfg'
"no_cloudinit_net": {
"path": (
"etc/cloud/cloud.cfg.d/"
"subiquity-disable-cloudinit-networking.cfg"
),
'content': 'network: {config: disabled}\n',
"content": "network: {config: disabled}\n",
},
# World-readable netplan without sensitive wifi config
'etc_netplan_installer': {
'path': 'etc/netplan/00-installer-config.yaml',
'content': self.stringify_config(netplan),
"etc_netplan_installer": {
"path": "etc/netplan/00-installer-config.yaml",
"content": self.stringify_config(netplan),
},
},
}
if wifis is not None:
netplan_wifi = {
'network': {
'version': 2,
'wifis': wifis,
"network": {
"version": 2,
"wifis": wifis,
},
}
r['write_files']['etc_netplan_installer_wifi'] = {
'path': 'etc/netplan/00-installer-config-wifi.yaml',
'content': self.stringify_config(netplan_wifi),
'permissions': '0600',
r["write_files"]["etc_netplan_installer_wifi"] = {
"path": "etc/netplan/00-installer-config-wifi.yaml",
"content": self.stringify_config(netplan_wifi),
"permissions": "0600",
}
return r
async def target_packages(self):
if self.needs_wpasupplicant:
return ['wpasupplicant']
return ["wpasupplicant"]
else:
return []

View File

@ -18,7 +18,7 @@ from typing import Any, Dict, List, Optional, Union
import attr
log = logging.getLogger('subiquity.models.oem')
log = logging.getLogger("subiquity.models.oem")
@attr.s(auto_attribs=True)

View File

@ -15,17 +15,16 @@
import logging
log = logging.getLogger('subiquity.models.proxy')
log = logging.getLogger("subiquity.models.proxy")
dropin_template = '''\
dropin_template = """\
[Service]
Environment="HTTP_PROXY={proxy}"
Environment="HTTPS_PROXY={proxy}"
'''
"""
class ProxyModel(object):
def __init__(self):
self.proxy = ""
@ -35,8 +34,8 @@ class ProxyModel(object):
def get_apt_config(self, final: bool, has_network: bool):
if self.proxy:
return {
'http_proxy': self.proxy,
'https_proxy': self.proxy,
"http_proxy": self.proxy,
"https_proxy": self.proxy,
}
else:
return {}
@ -44,16 +43,17 @@ class ProxyModel(object):
def render(self):
if self.proxy:
return {
'proxy': {
'http_proxy': self.proxy,
'https_proxy': self.proxy,
"proxy": {
"http_proxy": self.proxy,
"https_proxy": self.proxy,
},
'write_files': {
'snapd_dropin': {
'path': ('etc/systemd/system/'
'snapd.service.d/snap_proxy.conf'),
'content': self.proxy_systemd_dropin(),
'permissions': 0o644,
"write_files": {
"snapd_dropin": {
"path": (
"etc/systemd/system/" "snapd.service.d/snap_proxy.conf"
),
"content": self.proxy_systemd_dropin(),
"permissions": 0o644,
},
},
}

View File

@ -16,11 +16,7 @@
import datetime
import logging
from subiquity.common.types import (
ChannelSnapInfo,
SnapInfo,
)
from subiquity.common.types import ChannelSnapInfo, SnapInfo
log = logging.getLogger("subiquity.models.snaplist")
@ -44,47 +40,49 @@ class SnapListModel:
return s
def load_find_data(self, data):
for info in data['result']:
self.update(self._snap_for_name(info['name']), info)
for info in data["result"]:
self.update(self._snap_for_name(info["name"]), info)
def add_partial_snap(self, name):
self._snaps_for_name(name)
def update(self, snap, data):
snap.summary = data['summary']
snap.publisher = data['developer']
snap.verified = data['publisher']['validation'] == "verified"
snap.starred = data['publisher']['validation'] == "starred"
snap.description = data['description']
snap.confinement = data['confinement']
snap.license = data['license']
snap.summary = data["summary"]
snap.publisher = data["developer"]
snap.verified = data["publisher"]["validation"] == "verified"
snap.starred = data["publisher"]["validation"] == "starred"
snap.description = data["description"]
snap.confinement = data["confinement"]
snap.license = data["license"]
self.complete_snaps.add(snap)
def load_info_data(self, data):
info = data['result'][0]
snap = self._snaps_by_name.get(info['name'])
info = data["result"][0]
snap = self._snaps_by_name.get(info["name"])
if snap is None:
return
if snap not in self.complete_snaps:
self.update_snap(snap, info)
channel_map = info['channels']
for track in info['tracks']:
channel_map = info["channels"]
for track in info["tracks"]:
for risk in risks:
channel_name = '{}/{}'.format(track, risk)
channel_name = "{}/{}".format(track, risk)
if channel_name in channel_map:
channel_data = channel_map[channel_name]
if track == "latest":
channel_name = risk
snap.channels.append(ChannelSnapInfo(
snap.channels.append(
ChannelSnapInfo(
channel_name=channel_name,
revision=channel_data['revision'],
confinement=channel_data['confinement'],
version=channel_data['version'],
size=channel_data['size'],
revision=channel_data["revision"],
confinement=channel_data["confinement"],
version=channel_data["version"],
size=channel_data["size"],
released_at=datetime.datetime.strptime(
channel_data['released-at'],
'%Y-%m-%dT%H:%M:%S.%fZ'),
))
channel_data["released-at"], "%Y-%m-%dT%H:%M:%S.%fZ"
),
)
)
return snap
def get_snap_list(self):
@ -100,9 +98,9 @@ class SnapListModel:
return {}
cmds = []
for selection in self.selections:
cmd = ['snap', 'install', '--channel=' + selection.channel]
cmd = ["snap", "install", "--channel=" + selection.channel]
if selection.classic:
cmd.append('--classic')
cmd.append("--classic")
cmd.append(selection.name)
cmds.append(' '.join(cmd))
return {'snap': {'commands': cmds}}
cmds.append(" ".join(cmd))
return {"snap": {"commands": cmds}}

View File

@ -16,13 +16,13 @@
import logging
import os
import typing
import yaml
import attr
import yaml
from subiquity.common.serialize import Serializer
log = logging.getLogger('subiquity.models.source')
log = logging.getLogger("subiquity.models.source")
@attr.s(auto_attribs=True)
@ -49,33 +49,32 @@ class CatalogEntry:
def __attrs_post_init__(self):
if not self.variations:
self.variations['default'] = CatalogEntryVariation(
self.variations["default"] = CatalogEntryVariation(
path=self.path,
size=self.size,
snapd_system_label=self.snapd_system_label)
snapd_system_label=self.snapd_system_label,
)
legacy_server_entry = CatalogEntry(
variant='server',
id='synthesized',
name={'en': 'Ubuntu Server'},
description={'en': 'the default'},
path='/media/filesystem',
type='cp',
variant="server",
id="synthesized",
name={"en": "Ubuntu Server"},
description={"en": "the default"},
path="/media/filesystem",
type="cp",
default=True,
size=2 << 30,
locale_support="locale-only",
variations={
'default': CatalogEntryVariation(
path='/media/filesystem',
size=2 << 30),
})
"default": CatalogEntryVariation(path="/media/filesystem", size=2 << 30),
},
)
class SourceModel:
def __init__(self):
self._dir = '/cdrom/casper'
self._dir = "/cdrom/casper"
self.current = legacy_server_entry
self.sources = [self.current]
self.lang = None
@ -86,8 +85,8 @@ class SourceModel:
self.sources = []
self.current = None
self.sources = Serializer(ignore_unknown_fields=True).deserialize(
typing.List[CatalogEntry],
yaml.safe_load(fp))
typing.List[CatalogEntry], yaml.safe_load(fp)
)
for entry in self.sources:
if entry.default:
self.current = entry
@ -113,10 +112,10 @@ class SourceModel:
if self.lang in self.current.preinstalled_langs:
suffix = self.lang
else:
suffix = 'no-languages'
path = base + '.' + suffix + ext
suffix = "no-languages"
path = base + "." + suffix + ext
scheme = self.current.type
return f'{scheme}://{path}'
return f"{scheme}://{path}"
def render(self):
return {}

View File

@ -20,7 +20,6 @@ log = logging.getLogger("subiquity.models.ssh")
class SSHModel:
def __init__(self):
self.install_server = False
self.authorized_keys: List[str] = []
@ -28,10 +27,10 @@ class SSHModel:
# Although the generated config just contains the key above,
# we store the imported id so that we can re-fill the form if
# we go back to it.
self.ssh_import_id = ''
self.ssh_import_id = ""
async def target_packages(self):
if self.install_server:
return ['openssh-server']
return ["openssh-server"]
else:
return []

View File

@ -14,37 +14,35 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
from collections import OrderedDict
import functools
import json
import logging
import os
from typing import Any, Dict, Set
import uuid
import yaml
from collections import OrderedDict
from typing import Any, Dict, Set
import yaml
from cloudinit.config.schema import (
SchemaValidationError,
get_schema,
validate_cloudconfig_schema
validate_cloudconfig_schema,
)
try:
from cloudinit.config.schema import SchemaProblem
except ImportError:
def SchemaProblem(x, y): return (x, y) # TODO(drop on cloud-init 22.3 SRU)
def SchemaProblem(x, y):
return (x, y) # TODO(drop on cloud-init 22.3 SRU)
from curtin.config import merge_config
from subiquitycore.file_util import (
generate_timestamped_header,
write_file,
)
from subiquitycore.lsb_release import lsb_release
from subiquity.common.resources import get_users_and_groups
from subiquity.server.types import InstallerChannels
from subiquitycore.file_util import generate_timestamped_header, write_file
from subiquitycore.lsb_release import lsb_release
from .ad import AdModel
from .codecs import CodecsModel
@ -66,8 +64,7 @@ from .timezone import TimeZoneModel
from .ubuntu_pro import UbuntuProModel
from .updates import UpdatesModel
log = logging.getLogger('subiquity.models.subiquity')
log = logging.getLogger("subiquity.models.subiquity")
def merge_cloud_init_config(target, source):
@ -136,7 +133,7 @@ CLOUDINIT_DISABLE_AFTER_INSTALL = {
"Disabled by Ubuntu live installer after first boot.\n"
"To re-enable cloud-init on this image run:\n"
" sudo cloud-init clean --machine-id\n"
)
),
}
@ -156,29 +153,26 @@ class ModelNames:
class DebconfSelectionsModel:
def __init__(self):
self.selections = ''
self.selections = ""
def render(self):
return {}
def get_apt_config(
self, final: bool, has_network: bool) -> Dict[str, Any]:
return {'debconf_selections': {'subiquity': self.selections}}
def get_apt_config(self, final: bool, has_network: bool) -> Dict[str, Any]:
return {"debconf_selections": {"subiquity": self.selections}}
class SubiquityModel:
"""The overall model for subiquity."""
target = '/target'
chroot_prefix = ['chroot', target]
target = "/target"
chroot_prefix = ["chroot", target]
def __init__(self, root, hub, install_model_names,
postinstall_model_names):
def __init__(self, root, hub, install_model_names, postinstall_model_names):
self.root = root
self.hub = hub
if root != '/':
if root != "/":
self.target = root
self.chroot_prefix = []
@ -212,8 +206,7 @@ class SubiquityModel:
self._install_model_names = install_model_names
self._postinstall_model_names = postinstall_model_names
self._cur_install_model_names = install_model_names.default_names
self._cur_postinstall_model_names = \
postinstall_model_names.default_names
self._cur_postinstall_model_names = postinstall_model_names.default_names
self._install_event = asyncio.Event()
self._postinstall_event = asyncio.Event()
all_names = set()
@ -222,15 +215,17 @@ class SubiquityModel:
for name in all_names:
hub.subscribe(
(InstallerChannels.CONFIGURED, name),
functools.partial(self._configured, name))
functools.partial(self._configured, name),
)
def set_source_variant(self, variant):
self._cur_install_model_names = \
self._install_model_names.for_variant(variant)
self._cur_postinstall_model_names = \
self._postinstall_model_names.for_variant(variant)
unconfigured_install_model_names = \
self._cur_install_model_names = self._install_model_names.for_variant(variant)
self._cur_postinstall_model_names = self._postinstall_model_names.for_variant(
variant
)
unconfigured_install_model_names = (
self._cur_install_model_names - self._configured_names
)
if unconfigured_install_model_names:
if self._install_event.is_set():
self._install_event = asyncio.Event()
@ -238,8 +233,9 @@ class SubiquityModel:
self._confirmation_task.cancel()
else:
self._install_event.set()
unconfigured_postinstall_model_names = \
unconfigured_postinstall_model_names = (
self._cur_postinstall_model_names - self._configured_names
)
if unconfigured_postinstall_model_names:
if self._postinstall_event.is_set():
self._postinstall_event = asyncio.Event()
@ -250,24 +246,31 @@ class SubiquityModel:
"""Add the model to the set of models that have been configured. If
there is no more model to configure in the relevant section(s) (i.e.,
INSTALL or POSTINSTALL), we trigger the associated event(s)."""
def log_and_trigger(stage: str, names: Set[str],
event: asyncio.Event) -> None:
def log_and_trigger(stage: str, names: Set[str], event: asyncio.Event) -> None:
unconfigured = names - self._configured_names
log.debug(
"model %s for %s stage is configured, to go %s",
model_name, stage, unconfigured)
model_name,
stage,
unconfigured,
)
if not unconfigured:
event.set()
self._configured_names.add(model_name)
if model_name in self._cur_install_model_names:
log_and_trigger(stage="install",
log_and_trigger(
stage="install",
names=self._cur_install_model_names,
event=self._install_event)
event=self._install_event,
)
if model_name in self._cur_postinstall_model_names:
log_and_trigger(stage="postinstall",
log_and_trigger(
stage="postinstall",
names=self._cur_postinstall_model_names,
event=self._postinstall_event)
event=self._postinstall_event,
)
async def wait_install(self):
if len(self._cur_install_model_names) == 0:
@ -281,8 +284,7 @@ class SubiquityModel:
async def wait_confirmation(self):
if self._confirmation_task is None:
self._confirmation_task = asyncio.create_task(
self._confirmation.wait())
self._confirmation_task = asyncio.create_task(self._confirmation.wait())
try:
await self._confirmation_task
except asyncio.CancelledError:
@ -293,8 +295,10 @@ class SubiquityModel:
self._confirmation_task = None
def is_postinstall_only(self, model_name):
return model_name in self._cur_postinstall_model_names and \
model_name not in self._cur_install_model_names
return (
model_name in self._cur_postinstall_model_names
and model_name not in self._cur_install_model_names
)
async def confirm(self):
self._confirmation.set()
@ -326,9 +330,9 @@ class SubiquityModel:
if warnings:
log.warning(
"The cloud-init configuration for %s contains"
" deprecated values:\n%s", data_source, "\n".join(
warnings
)
" deprecated values:\n%s",
data_source,
"\n".join(warnings),
)
if e.schema_errors:
if data_source == "autoinstall.user-data":
@ -342,50 +346,46 @@ class SubiquityModel:
def _cloud_init_config(self):
config = {
'growpart': {
'mode': 'off',
"growpart": {
"mode": "off",
},
'resize_rootfs': False,
"resize_rootfs": False,
}
if self.identity.hostname is not None:
config['preserve_hostname'] = True
config["preserve_hostname"] = True
user = self.identity.user
if user:
groups = get_users_and_groups(self.chroot_prefix)
user_info = {
'name': user.username,
'gecos': user.realname,
'passwd': user.password,
'shell': '/bin/bash',
'groups': ','.join(sorted(groups)),
'lock_passwd': False,
"name": user.username,
"gecos": user.realname,
"passwd": user.password,
"shell": "/bin/bash",
"groups": ",".join(sorted(groups)),
"lock_passwd": False,
}
if self.ssh.authorized_keys:
user_info['ssh_authorized_keys'] = self.ssh.authorized_keys
config['users'] = [user_info]
user_info["ssh_authorized_keys"] = self.ssh.authorized_keys
config["users"] = [user_info]
else:
if self.ssh.authorized_keys:
config['ssh_authorized_keys'] = self.ssh.authorized_keys
config["ssh_authorized_keys"] = self.ssh.authorized_keys
if self.ssh.install_server:
config['ssh_pwauth'] = self.ssh.pwauth
config["ssh_pwauth"] = self.ssh.pwauth
for model_name in self._postinstall_model_names.all():
model = getattr(self, model_name)
if getattr(model, 'make_cloudconfig', None):
if getattr(model, "make_cloudconfig", None):
merge_config(config, model.make_cloudconfig())
merge_cloud_init_config(config, self.userdata)
if lsb_release()['release'] not in ("20.04", "22.04"):
config.setdefault("write_files", []).append(
CLOUDINIT_DISABLE_AFTER_INSTALL
)
self.validate_cloudconfig_schema(
data=config, data_source="system install"
)
if lsb_release()["release"] not in ("20.04", "22.04"):
config.setdefault("write_files", []).append(CLOUDINIT_DISABLE_AFTER_INSTALL)
self.validate_cloudconfig_schema(data=config, data_source="system install")
return config
async def target_packages(self):
packages = list(self.packages)
for model_name in self._postinstall_model_names.all():
meth = getattr(getattr(self, model_name), 'target_packages', None)
meth = getattr(getattr(self, model_name), "target_packages", None)
if meth is not None:
packages.extend(await meth())
return packages
@ -395,49 +395,55 @@ class SubiquityModel:
# first boot of the target, it reconfigures datasource_list to none
# for subsequent boots.
# (mwhudson does not entirely know what the above means!)
userdata = '#cloud-config\n' + yaml.dump(self._cloud_init_config())
metadata = {'instance-id': str(uuid.uuid4())}
config = yaml.dump({
'datasource_list': ["None"],
'datasource': {
userdata = "#cloud-config\n" + yaml.dump(self._cloud_init_config())
metadata = {"instance-id": str(uuid.uuid4())}
config = yaml.dump(
{
"datasource_list": ["None"],
"datasource": {
"None": {
'userdata_raw': userdata,
'metadata': metadata,
"userdata_raw": userdata,
"metadata": metadata,
},
},
})
}
)
files = [
('etc/cloud/cloud.cfg.d/99-installer.cfg', config, 0o600),
('etc/cloud/ds-identify.cfg', 'policy: enabled\n', 0o644),
("etc/cloud/cloud.cfg.d/99-installer.cfg", config, 0o600),
("etc/cloud/ds-identify.cfg", "policy: enabled\n", 0o644),
]
# Add cloud-init clean hooks to support golden-image creation.
cfg_files = ["/" + path for (path, _content, _cmode) in files]
cfg_files.extend(self.network.rendered_config_paths())
if lsb_release()['release'] not in ("20.04", "22.04"):
if lsb_release()["release"] not in ("20.04", "22.04"):
cfg_files.append("/etc/cloud/cloud-init.disabled")
if self.identity.hostname is not None:
hostname = self.identity.hostname.strip()
files.extend([
('etc/hostname', hostname + "\n", 0o644),
('etc/hosts', HOSTS_CONTENT.format(hostname=hostname), 0o644),
])
files.extend(
[
("etc/hostname", hostname + "\n", 0o644),
("etc/hosts", HOSTS_CONTENT.format(hostname=hostname), 0o644),
]
)
files.append((
'etc/cloud/clean.d/99-installer',
files.append(
(
"etc/cloud/clean.d/99-installer",
CLOUDINIT_CLEAN_FILE_TMPL.format(
header=generate_timestamped_header(),
cfg_files=json.dumps(sorted(cfg_files))
cfg_files=json.dumps(sorted(cfg_files)),
),
0o755
))
0o755,
)
)
return files
def configure_cloud_init(self):
if self.target is None:
# i.e. reset_partition_only
return
if self.source.current.variant == 'core':
if self.source.current.variant == "core":
# can probably be supported but requires changes
return
for path, content, cmode in self._cloud_init_files():
@ -446,59 +452,57 @@ class SubiquityModel:
write_file(path, content, cmode=cmode)
def _media_info(self):
if os.path.exists('/cdrom/.disk/info'):
with open('/cdrom/.disk/info') as fp:
if os.path.exists("/cdrom/.disk/info"):
with open("/cdrom/.disk/info") as fp:
return fp.read()
else:
return "media-info"
def _machine_id(self):
with open('/etc/machine-id') as fp:
with open("/etc/machine-id") as fp:
return fp.read()
def render(self):
config = {
'grub': {
'terminal': 'unmodified',
'probe_additional_os': True,
'reorder_uefi': False,
"grub": {
"terminal": "unmodified",
"probe_additional_os": True,
"reorder_uefi": False,
},
'install': {
'unmount': 'disabled',
'save_install_config': False,
'save_install_log': False,
"install": {
"unmount": "disabled",
"save_install_config": False,
"save_install_log": False,
},
'pollinate': {
'user_agent': {
'subiquity': "%s_%s" % (os.environ.get("SNAP_VERSION",
'dry-run'),
os.environ.get("SNAP_REVISION",
'dry-run')),
"pollinate": {
"user_agent": {
"subiquity": "%s_%s"
% (
os.environ.get("SNAP_VERSION", "dry-run"),
os.environ.get("SNAP_REVISION", "dry-run"),
),
},
},
'write_files': {
'etc_machine_id': {
'path': 'etc/machine-id',
'content': self._machine_id(),
'permissions': 0o444,
"write_files": {
"etc_machine_id": {
"path": "etc/machine-id",
"content": self._machine_id(),
"permissions": 0o444,
},
'media_info': {
'path': 'var/log/installer/media-info',
'content': self._media_info(),
'permissions': 0o644,
"media_info": {
"path": "var/log/installer/media-info",
"content": self._media_info(),
"permissions": 0o644,
},
},
}
if os.path.exists('/run/casper-md5check.json'):
with open('/run/casper-md5check.json') as fp:
config['write_files']['md5check'] = {
'path': 'var/log/installer/casper-md5check.json',
'content': fp.read(),
'permissions': 0o644,
if os.path.exists("/run/casper-md5check.json"):
with open("/run/casper-md5check.json") as fp:
config["write_files"]["md5check"] = {
"path": "var/log/installer/casper-md5check.json",
"content": fp.read(),
"permissions": 0o644,
}
for model_name in self._install_model_names.all():

File diff suppressed because it is too large Load Diff

View File

@ -13,15 +13,10 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from subiquitycore.tests.parameterized import parameterized
from subiquitycore.tests import SubiTestCase
from subiquity.common.types import KeyboardSetting
from subiquity.models.keyboard import (
InconsistentMultiLayoutError,
KeyboardModel,
)
from subiquity.models.keyboard import InconsistentMultiLayoutError, KeyboardModel
from subiquitycore.tests import SubiTestCase
from subiquitycore.tests.parameterized import parameterized
class TestKeyboardModel(SubiTestCase):
@ -30,9 +25,9 @@ class TestKeyboardModel(SubiTestCase):
def testDefaultUS(self):
self.assertIsNone(self.model._setting)
self.assertEqual('us', self.model.setting.layout)
self.assertEqual("us", self.model.setting.layout)
@parameterized.expand((['zz'], ['en']))
@parameterized.expand((["zz"], ["en"]))
def testSetToInvalidLayout(self, layout):
initial = self.model.setting
val = KeyboardSetting(layout=layout)
@ -40,39 +35,41 @@ class TestKeyboardModel(SubiTestCase):
self.model.setting = val
self.assertEqual(initial, self.model.setting)
@parameterized.expand((['zz']))
@parameterized.expand((["zz"]))
def testSetToInvalidVariant(self, variant):
initial = self.model.setting
val = KeyboardSetting(layout='us', variant=variant)
val = KeyboardSetting(layout="us", variant=variant)
with self.assertRaises(ValueError):
self.model.setting = val
self.assertEqual(initial, self.model.setting)
def testMultiLayout(self):
val = KeyboardSetting(layout='us,ara', variant=',')
val = KeyboardSetting(layout="us,ara", variant=",")
self.model.setting = val
self.assertEqual(self.model.setting, val)
def testInconsistentMultiLayout(self):
initial = self.model.setting
val = KeyboardSetting(layout='us,ara', variant='')
val = KeyboardSetting(layout="us,ara", variant="")
with self.assertRaises(InconsistentMultiLayoutError):
self.model.setting = val
self.assertEqual(self.model.setting, initial)
def testInvalidMultiLayout(self):
initial = self.model.setting
val = KeyboardSetting(layout='us,ara', variant='zz,')
val = KeyboardSetting(layout="us,ara", variant="zz,")
with self.assertRaises(ValueError):
self.model.setting = val
self.assertEqual(self.model.setting, initial)
@parameterized.expand([
['ast_ES.UTF-8', 'es', 'ast'],
['de_DE.UTF-8', 'de', ''],
['fr_FR.UTF-8', 'fr', 'latin9'],
['oc', 'us', ''],
])
@parameterized.expand(
[
["ast_ES.UTF-8", "es", "ast"],
["de_DE.UTF-8", "de", ""],
["fr_FR.UTF-8", "fr", "latin9"],
["oc", "us", ""],
]
)
def testSettingForLang(self, lang, layout, variant):
val = self.model.setting_for_lang(lang)
self.assertEqual(layout, val.layout)
@ -81,20 +78,22 @@ class TestKeyboardModel(SubiTestCase):
def testAllLangsHaveKeyboardSuggestion(self):
# every language in the list needs a suggestion,
# even if only the default
with open('languagelist') as fp:
with open("languagelist") as fp:
for line in fp.readlines():
tokens = line.split(':')
tokens = line.split(":")
locale = tokens[1]
self.assertIn(locale, self.model.layout_for_lang.keys())
def testLoadSuggestions(self):
data = self.tmp_path('kbd.yaml')
with open(data, 'w') as fp:
fp.write('''
data = self.tmp_path("kbd.yaml")
with open(data, "w") as fp:
fp.write(
"""
aa_BB.UTF-8:
layout: aa
variant: cc
''')
"""
)
actual = self.model.load_layout_suggestions(data)
expected = {'aa_BB.UTF-8': KeyboardSetting(layout='aa', variant='cc')}
expected = {"aa_BB.UTF-8": KeyboardSetting(layout="aa", variant="cc")}
self.assertEqual(expected, actual)

View File

@ -41,16 +41,16 @@ class TestLocaleModel(unittest.IsolatedAsyncioTestCase):
self.model.selected_language = "fr_FR"
with mock.patch("subiquity.models.locale.arun_command") as arun_cmd:
# Currently, the default for fr_FR is fr_FR.ISO8859-1
with mock.patch("subiquity.models.locale.locale.normalize",
return_value="fr_FR.UTF-8"):
with mock.patch(
"subiquity.models.locale.locale.normalize", return_value="fr_FR.UTF-8"
):
await self.model.localectl_set_locale()
arun_cmd.assert_called_once_with(expected_cmd, check=True)
async def test_try_localectl_set_locale(self):
self.model.selected_language = "fr_FR.UTF-8"
exc = subprocess.CalledProcessError(returncode=1, cmd=["localedef"])
with mock.patch("subiquity.models.locale.arun_command",
side_effect=exc):
with mock.patch("subiquity.models.locale.arun_command", side_effect=exc):
await self.model.try_localectl_set_locale()
with mock.patch("subiquity.models.locale.arun_command"):
await self.model.try_localectl_set_locale()

View File

@ -18,41 +18,37 @@ import unittest
from unittest import mock
from subiquity.models.mirror import (
countrify_uri,
LEGACY_DEFAULT_PRIMARY_SECTION,
LegacyPrimaryEntry,
MirrorModel,
MirrorSelectionFallback,
LegacyPrimaryEntry,
PrimaryEntry,
countrify_uri,
)
class TestCountrifyUrl(unittest.TestCase):
def test_official_archive(self):
self.assertEqual(
countrify_uri(
"http://archive.ubuntu.com/ubuntu",
cc="fr"),
"http://fr.archive.ubuntu.com/ubuntu")
countrify_uri("http://archive.ubuntu.com/ubuntu", cc="fr"),
"http://fr.archive.ubuntu.com/ubuntu",
)
self.assertEqual(
countrify_uri(
"http://archive.ubuntu.com/ubuntu",
cc="us"),
"http://us.archive.ubuntu.com/ubuntu")
countrify_uri("http://archive.ubuntu.com/ubuntu", cc="us"),
"http://us.archive.ubuntu.com/ubuntu",
)
def test_ports_archive(self):
self.assertEqual(
countrify_uri(
"http://ports.ubuntu.com/ubuntu-ports",
cc="fr"),
"http://fr.ports.ubuntu.com/ubuntu-ports")
countrify_uri("http://ports.ubuntu.com/ubuntu-ports", cc="fr"),
"http://fr.ports.ubuntu.com/ubuntu-ports",
)
self.assertEqual(
countrify_uri(
"http://ports.ubuntu.com/ubuntu-ports",
cc="us"),
"http://us.ports.ubuntu.com/ubuntu-ports")
countrify_uri("http://ports.ubuntu.com/ubuntu-ports", cc="us"),
"http://us.ports.ubuntu.com/ubuntu-ports",
)
class TestPrimaryEntry(unittest.TestCase):
@ -78,21 +74,20 @@ class TestPrimaryEntry(unittest.TestCase):
model = MirrorModel()
entry = PrimaryEntry.from_config("country-mirror", parent=model)
self.assertEqual(entry, PrimaryEntry(parent=model,
country_mirror=True))
self.assertEqual(entry, PrimaryEntry(parent=model, country_mirror=True))
with self.assertRaises(ValueError):
entry = PrimaryEntry.from_config({}, parent=model)
entry = PrimaryEntry.from_config(
{"uri": "http://mirror"}, parent=model)
self.assertEqual(entry, PrimaryEntry(
uri="http://mirror", parent=model))
entry = PrimaryEntry.from_config({"uri": "http://mirror"}, parent=model)
self.assertEqual(entry, PrimaryEntry(uri="http://mirror", parent=model))
entry = PrimaryEntry.from_config(
{"uri": "http://mirror", "arches": ["amd64"]}, parent=model)
self.assertEqual(entry, PrimaryEntry(
uri="http://mirror", arches=["amd64"], parent=model))
{"uri": "http://mirror", "arches": ["amd64"]}, parent=model
)
self.assertEqual(
entry, PrimaryEntry(uri="http://mirror", arches=["amd64"], parent=model)
)
class TestLegacyPrimaryEntry(unittest.TestCase):
@ -111,8 +106,8 @@ class TestLegacyPrimaryEntry(unittest.TestCase):
def test_get_uri(self):
self.model.architecture = "amd64"
primary = LegacyPrimaryEntry(
[{"uri": "http://myurl", "arches": "amd64"}],
parent=self.model)
[{"uri": "http://myurl", "arches": "amd64"}], parent=self.model
)
self.assertEqual(primary.uri, "http://myurl")
def test_set_uri(self):
@ -130,8 +125,9 @@ class TestMirrorModel(unittest.TestCase):
self.model_legacy = MirrorModel()
self.model_legacy.legacy_primary = True
self.model_legacy.primary_candidates = [
LegacyPrimaryEntry(copy.deepcopy(
LEGACY_DEFAULT_PRIMARY_SECTION), parent=self.model_legacy),
LegacyPrimaryEntry(
copy.deepcopy(LEGACY_DEFAULT_PRIMARY_SECTION), parent=self.model_legacy
),
]
self.candidate_legacy = self.model_legacy.primary_candidates[0]
self.model_legacy.primary_staged = self.candidate_legacy
@ -152,7 +148,8 @@ class TestMirrorModel(unittest.TestCase):
[
"http://CC.archive.ubuntu.com/ubuntu",
"http://CC.ports.ubuntu.com/ubuntu-ports",
])
],
)
do_test(self.model)
do_test(self.model_legacy)
@ -167,7 +164,7 @@ class TestMirrorModel(unittest.TestCase):
def test_default_disable_components(self):
def do_test(model, candidate):
config = model.get_apt_config_staged()
self.assertEqual([], config['disable_components'])
self.assertEqual([], config["disable_components"])
# The candidate[0] is a country-mirror, skip it.
candidate = self.model.primary_candidates[1]
@ -179,15 +176,14 @@ class TestMirrorModel(unittest.TestCase):
# autoinstall loads to the config directly
model = MirrorModel()
data = {
'disable_components': ['non-free'],
'fallback': 'offline-install',
"disable_components": ["non-free"],
"fallback": "offline-install",
}
model.load_autoinstall_data(data)
self.assertFalse(model.legacy_primary)
model.primary_candidates[0].stage()
self.assertEqual(set(['non-free']), model.disabled_components)
self.assertEqual(model.primary_candidates,
model._default_primary_entries())
self.assertEqual(set(["non-free"]), model.disabled_components)
self.assertEqual(model.primary_candidates, model._default_primary_entries())
def test_from_autoinstall_modern(self):
data = {
@ -202,16 +198,19 @@ class TestMirrorModel(unittest.TestCase):
}
model = MirrorModel()
model.load_autoinstall_data(data)
self.assertEqual(model.primary_candidates, [
self.assertEqual(
model.primary_candidates,
[
PrimaryEntry(parent=model, country_mirror=True),
PrimaryEntry(uri="http://mirror", parent=model),
])
],
)
def test_disable_add(self):
def do_test(model, candidate):
expected = ['things', 'stuff']
expected = ["things", "stuff"]
model.disable_components(expected.copy(), add=True)
actual = model.get_apt_config_staged()['disable_components']
actual = model.get_apt_config_staged()["disable_components"]
actual.sort()
expected.sort()
self.assertEqual(expected, actual)
@ -223,11 +222,11 @@ class TestMirrorModel(unittest.TestCase):
do_test(self.model_legacy, self.candidate_legacy)
def test_disable_remove(self):
self.model.disabled_components = set(['a', 'b', 'things'])
to_remove = ['things', 'stuff']
expected = ['a', 'b']
self.model.disabled_components = set(["a", "b", "things"])
to_remove = ["things", "stuff"]
expected = ["a", "b"]
self.model.disable_components(to_remove, add=False)
actual = self.model.get_apt_config_staged()['disable_components']
actual = self.model.get_apt_config_staged()["disable_components"]
actual.sort()
expected.sort()
self.assertEqual(expected, actual)
@ -242,12 +241,15 @@ class TestMirrorModel(unittest.TestCase):
self.model.legacy_primary = False
self.model.fallback = MirrorSelectionFallback.OFFLINE_INSTALL
self.model.primary_candidates = [
PrimaryEntry(uri=None, arches=None,
country_mirror=True, parent=self.model),
PrimaryEntry(uri="http://mirror.local/ubuntu",
arches=None, parent=self.model),
PrimaryEntry(uri="http://amd64.mirror.local/ubuntu",
arches=["amd64"], parent=self.model),
PrimaryEntry(uri=None, arches=None, country_mirror=True, parent=self.model),
PrimaryEntry(
uri="http://mirror.local/ubuntu", arches=None, parent=self.model
),
PrimaryEntry(
uri="http://amd64.mirror.local/ubuntu",
arches=["amd64"],
parent=self.model,
),
]
cfg = self.model.make_autoinstall()
self.assertEqual(cfg["disable_components"], ["non-free"])
@ -259,8 +261,7 @@ class TestMirrorModel(unittest.TestCase):
self.model.disabled_components = set(["non-free"])
self.model.fallback = MirrorSelectionFallback.ABORT
self.model.legacy_primary = True
self.model.primary_candidates = \
[LegacyPrimaryEntry(primary, parent=self.model)]
self.model.primary_candidates = [LegacyPrimaryEntry(primary, parent=self.model)]
self.model.primary_candidates[0].elect()
cfg = self.model.make_autoinstall()
self.assertEqual(cfg["disable_components"], ["non-free"])
@ -269,22 +270,23 @@ class TestMirrorModel(unittest.TestCase):
def test_create_primary_candidate(self):
self.model.legacy_primary = False
candidate = self.model.create_primary_candidate(
"http://mymirror.valid")
candidate = self.model.create_primary_candidate("http://mymirror.valid")
self.assertEqual(candidate.uri, "http://mymirror.valid")
self.model.legacy_primary = True
candidate = self.model.create_primary_candidate(
"http://mymirror.valid")
candidate = self.model.create_primary_candidate("http://mymirror.valid")
self.assertEqual(candidate.uri, "http://mymirror.valid")
def test_wants_geoip(self):
country_mirror_candidates = mock.patch.object(
self.model, "country_mirror_candidates", return_value=iter([]))
self.model, "country_mirror_candidates", return_value=iter([])
)
with country_mirror_candidates:
self.assertFalse(self.model.wants_geoip())
country_mirror_candidates = mock.patch.object(
self.model, "country_mirror_candidates",
return_value=iter([PrimaryEntry(parent=self.model)]))
self.model,
"country_mirror_candidates",
return_value=iter([PrimaryEntry(parent=self.model)]),
)
with country_mirror_candidates:
self.assertTrue(self.model.wants_geoip())

View File

@ -17,9 +17,7 @@ import subprocess
import unittest
from unittest import mock
from subiquity.models.network import (
NetworkModel,
)
from subiquity.models.network import NetworkModel
class TestNetworkModel(unittest.IsolatedAsyncioTestCase):
@ -42,6 +40,5 @@ class TestNetworkModel(unittest.IsolatedAsyncioTestCase):
self.assertFalse(await self.model.is_nm_enabled())
with mock.patch("subiquity.models.network.arun_command") as arun:
arun.side_effect = subprocess.CalledProcessError(
1, [], None, "error")
arun.side_effect = subprocess.CalledProcessError(1, [], None, "error")
self.assertFalse(await self.model.is_nm_enabled())

View File

@ -18,28 +18,26 @@ import random
import shutil
import tempfile
import unittest
import yaml
from subiquity.models.source import SourceModel
from subiquitycore.tests.util import random_string
from subiquity.models.source import (
SourceModel,
)
def make_entry(**fields):
raw = {
'name': {
'en': random_string(),
"name": {
"en": random_string(),
},
'description': {
'en': random_string(),
"description": {
"en": random_string(),
},
'id': random_string(),
'type': random_string(),
'variant': random_string(),
'path': random_string(),
'size': random.randint(1000000, 2000000),
"id": random_string(),
"type": random_string(),
"variant": random_string(),
"path": random_string(),
"size": random.randint(1000000, 2000000),
}
for k, v in fields.items():
raw[k] = v
@ -47,7 +45,6 @@ def make_entry(**fields):
class TestSourceModel(unittest.TestCase):
def tdir(self):
tdir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, tdir)
@ -56,83 +53,78 @@ class TestSourceModel(unittest.TestCase):
def write_and_load_entries(self, model, entries, dir=None):
if dir is None:
dir = self.tdir()
cat_path = os.path.join(dir, 'catalog.yaml')
with open(cat_path, 'w') as fp:
cat_path = os.path.join(dir, "catalog.yaml")
with open(cat_path, "w") as fp:
yaml.dump(entries, fp)
with open(cat_path) as fp:
model.load_from_file(fp)
def test_initially_server(self):
model = SourceModel()
self.assertEqual(model.current.variant, 'server')
self.assertEqual(model.current.variant, "server")
def test_load_from_file_simple(self):
entry = make_entry(id='id1')
entry = make_entry(id="id1")
model = SourceModel()
self.write_and_load_entries(model, [entry])
self.assertEqual(model.current.id, 'id1')
self.assertEqual(model.current.id, "id1")
def test_load_from_file_ignores_extra_keys(self):
entry = make_entry(id='id1', foobarbaz=random_string())
entry = make_entry(id="id1", foobarbaz=random_string())
model = SourceModel()
self.write_and_load_entries(model, [entry])
self.assertEqual(model.current.id, 'id1')
self.assertEqual(model.current.id, "id1")
def test_default(self):
entries = [
make_entry(id='id1'),
make_entry(id='id2', default=True),
make_entry(id='id3'),
make_entry(id="id1"),
make_entry(id="id2", default=True),
make_entry(id="id3"),
]
model = SourceModel()
self.write_and_load_entries(model, entries)
self.assertEqual(model.current.id, 'id2')
self.assertEqual(model.current.id, "id2")
def test_get_source_absolute(self):
entry = make_entry(
type='scheme',
path='/foo/bar/baz',
type="scheme",
path="/foo/bar/baz",
)
model = SourceModel()
self.write_and_load_entries(model, [entry])
self.assertEqual(
model.get_source(), 'scheme:///foo/bar/baz')
self.assertEqual(model.get_source(), "scheme:///foo/bar/baz")
def test_get_source_relative(self):
dir = self.tdir()
entry = make_entry(
type='scheme',
path='foo/bar/baz',
type="scheme",
path="foo/bar/baz",
)
model = SourceModel()
self.write_and_load_entries(model, [entry], dir)
self.assertEqual(
model.get_source(),
f'scheme://{dir}/foo/bar/baz')
self.assertEqual(model.get_source(), f"scheme://{dir}/foo/bar/baz")
def test_get_source_initial(self):
model = SourceModel()
self.assertEqual(
model.get_source(),
'cp:///media/filesystem')
self.assertEqual(model.get_source(), "cp:///media/filesystem")
def test_render(self):
model = SourceModel()
self.assertEqual(model.render(), {})
def test_canary(self):
with open('examples/sources/install-canary.yaml') as fp:
with open("examples/sources/install-canary.yaml") as fp:
model = SourceModel()
model.load_from_file(fp)
self.assertEqual(2, len(model.sources))
minimal = model.get_matching_source('ubuntu-desktop-minimal')
minimal = model.get_matching_source("ubuntu-desktop-minimal")
self.assertIsNotNone(minimal.variations)
self.assertEqual(minimal.size, minimal.variations['default'].size)
self.assertEqual(minimal.path, minimal.variations['default'].path)
self.assertIsNone(minimal.variations['default'].snapd_system_label)
self.assertEqual(minimal.size, minimal.variations["default"].size)
self.assertEqual(minimal.path, minimal.variations["default"].path)
self.assertIsNone(minimal.variations["default"].snapd_system_label)
standard = model.get_matching_source('ubuntu-desktop')
standard = model.get_matching_source("ubuntu-desktop")
self.assertIsNotNone(standard.variations)
self.assertEqual(2, len(standard.variations))

View File

@ -15,18 +15,20 @@
import fnmatch
import json
import re
import unittest
from unittest import mock
import re
import yaml
import yaml
from cloudinit.config.schema import SchemaValidationError
try:
from cloudinit.config.schema import SchemaProblem
except ImportError:
def SchemaProblem(x, y): return (x, y) # TODO(drop on cloud-init 22.3 SRU)
from subiquitycore.pubsub import MessageHub
def SchemaProblem(x, y):
return (x, y) # TODO(drop on cloud-init 22.3 SRU)
from subiquity.common.types import IdentityData
from subiquity.models.subiquity import (
@ -35,13 +37,11 @@ from subiquity.models.subiquity import (
ModelNames,
SubiquityModel,
)
from subiquity.server.server import (
INSTALL_MODEL_NAMES,
POSTINSTALL_MODEL_NAMES,
)
from subiquity.server.server import INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES
from subiquity.server.types import InstallerChannels
from subiquitycore.pubsub import MessageHub
getent_group_output = '''
getent_group_output = """
root:x:0:
daemon:x:1:
bin:x:2:
@ -57,56 +57,55 @@ man:x:12:
sudo:x:27:
ssh:x:118:
users:x:100:
'''
"""
class TestModelNames(unittest.TestCase):
def test_for_known_variant(self):
model_names = ModelNames({'a'}, var1={'b'}, var2={'c'})
self.assertEqual(model_names.for_variant('var1'), {'a', 'b'})
model_names = ModelNames({"a"}, var1={"b"}, var2={"c"})
self.assertEqual(model_names.for_variant("var1"), {"a", "b"})
def test_for_unknown_variant(self):
model_names = ModelNames({'a'}, var1={'b'}, var2={'c'})
self.assertEqual(model_names.for_variant('var3'), {'a'})
model_names = ModelNames({"a"}, var1={"b"}, var2={"c"})
self.assertEqual(model_names.for_variant("var3"), {"a"})
def test_all(self):
model_names = ModelNames({'a'}, var1={'b'}, var2={'c'})
self.assertEqual(model_names.all(), {'a', 'b', 'c'})
model_names = ModelNames({"a"}, var1={"b"}, var2={"c"})
self.assertEqual(model_names.all(), {"a", "b", "c"})
class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
maxDiff = None
def writtenFiles(self, config):
for k, v in config.get('write_files', {}).items():
for k, v in config.get("write_files", {}).items():
yield v
def assertConfigWritesFile(self, config, path):
self.assertIn(path, [s['path'] for s in self.writtenFiles(config)])
self.assertIn(path, [s["path"] for s in self.writtenFiles(config)])
def writtenFilesMatching(self, config, pattern):
files = list(self.writtenFiles(config))
matching = []
for spec in files:
if fnmatch.fnmatch(spec['path'], pattern):
if fnmatch.fnmatch(spec["path"], pattern):
matching.append(spec)
return matching
def writtenFilesMatchingContaining(self, config, pattern, content):
matching = []
for spec in self.writtenFilesMatching(config, pattern):
if content in spec['content']:
if content in spec["content"]:
matching.append(content)
return matching
def configVal(self, config, path):
cur = config
for component in path.split('.'):
for component in path.split("."):
if not isinstance(cur, dict):
self.fail(
"extracting {} reached non-dict {} too early".format(
path, cur))
"extracting {} reached non-dict {} too early".format(path, cur)
)
if component not in cur:
self.fail("no value found for {}".format(path))
cur = cur[component]
@ -117,11 +116,11 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
def assertConfigDoesNotHaveVal(self, config, path):
cur = config
for component in path.split('.'):
for component in path.split("."):
if not isinstance(cur, dict):
self.fail(
"extracting {} reached non-dict {} too early".format(
path, cur))
"extracting {} reached non-dict {} too early".format(path, cur)
)
if component not in cur:
return
cur = cur[component]
@ -129,157 +128,149 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
async def test_configure(self):
hub = MessageHub()
model = SubiquityModel(
'test', hub, ModelNames({'a', 'b'}), ModelNames(set()))
model.set_source_variant('var')
await hub.abroadcast((InstallerChannels.CONFIGURED, 'a'))
model = SubiquityModel("test", hub, ModelNames({"a", "b"}), ModelNames(set()))
model.set_source_variant("var")
await hub.abroadcast((InstallerChannels.CONFIGURED, "a"))
self.assertFalse(model._install_event.is_set())
await hub.abroadcast((InstallerChannels.CONFIGURED, 'b'))
await hub.abroadcast((InstallerChannels.CONFIGURED, "b"))
self.assertTrue(model._install_event.is_set())
def make_model(self):
return SubiquityModel(
'test', MessageHub(), INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES)
"test", MessageHub(), INSTALL_MODEL_NAMES, POSTINSTALL_MODEL_NAMES
)
def test_proxy_set(self):
model = self.make_model()
proxy_val = 'http://my-proxy'
proxy_val = "http://my-proxy"
model.proxy.proxy = proxy_val
config = model.render()
self.assertConfigHasVal(config, 'proxy.http_proxy', proxy_val)
self.assertConfigHasVal(config, 'proxy.https_proxy', proxy_val)
self.assertConfigHasVal(config, "proxy.http_proxy", proxy_val)
self.assertConfigHasVal(config, "proxy.https_proxy", proxy_val)
confs = self.writtenFilesMatchingContaining(
config,
'etc/systemd/system/snapd.service.d/*.conf',
'HTTP_PROXY=' + proxy_val)
"etc/systemd/system/snapd.service.d/*.conf",
"HTTP_PROXY=" + proxy_val,
)
self.assertTrue(len(confs) > 0)
def test_proxy_notset(self):
model = self.make_model()
config = model.render()
self.assertConfigDoesNotHaveVal(config, 'proxy.http_proxy')
self.assertConfigDoesNotHaveVal(config, 'proxy.https_proxy')
self.assertConfigDoesNotHaveVal(config, 'apt.http_proxy')
self.assertConfigDoesNotHaveVal(config, 'apt.https_proxy')
self.assertConfigDoesNotHaveVal(config, "proxy.http_proxy")
self.assertConfigDoesNotHaveVal(config, "proxy.https_proxy")
self.assertConfigDoesNotHaveVal(config, "apt.http_proxy")
self.assertConfigDoesNotHaveVal(config, "apt.https_proxy")
confs = self.writtenFilesMatchingContaining(
config,
'etc/systemd/system/snapd.service.d/*.conf',
'HTTP_PROXY=')
config, "etc/systemd/system/snapd.service.d/*.conf", "HTTP_PROXY="
)
self.assertTrue(len(confs) == 0)
def test_keyboard(self):
model = self.make_model()
config = model.render()
self.assertConfigWritesFile(config, 'etc/default/keyboard')
self.assertConfigWritesFile(config, "etc/default/keyboard")
def test_writes_machine_id_media_info(self):
model_no_proxy = self.make_model()
model_proxy = self.make_model()
model_proxy.proxy.proxy = 'http://something'
model_proxy.proxy.proxy = "http://something"
for model in model_no_proxy, model_proxy:
config = model.render()
self.assertConfigWritesFile(config, 'etc/machine-id')
self.assertConfigWritesFile(config, 'var/log/installer/media-info')
self.assertConfigWritesFile(config, "etc/machine-id")
self.assertConfigWritesFile(config, "var/log/installer/media-info")
def test_storage_version(self):
model = self.make_model()
config = model.render()
self.assertConfigHasVal(config, 'storage.version', 1)
self.assertConfigHasVal(config, "storage.version", 1)
def test_write_netplan(self):
model = self.make_model()
config = model.render()
netplan_content = None
for fspec in config['write_files'].values():
if fspec['path'].startswith('etc/netplan'):
for fspec in config["write_files"].values():
if fspec["path"].startswith("etc/netplan"):
if netplan_content is not None:
self.fail("writing two files to netplan?")
netplan_content = fspec['content']
netplan_content = fspec["content"]
self.assertIsNot(netplan_content, None)
netplan = yaml.safe_load(netplan_content)
self.assertConfigHasVal(netplan, 'network.version', 2)
self.assertConfigHasVal(netplan, "network.version", 2)
def test_sources(self):
model = self.make_model()
config = model.render()
self.assertNotIn('sources', config)
self.assertNotIn("sources", config)
def test_mirror(self):
model = self.make_model()
mirror_val = 'http://my-mirror'
mirror_val = "http://my-mirror"
model.mirror.create_primary_candidate(mirror_val).elect()
config = model.render()
self.assertNotIn('apt', config)
self.assertNotIn("apt", config)
@mock.patch('subiquitycore.utils.run_command')
@mock.patch("subiquitycore.utils.run_command")
def test_cloud_init_user_list_merge(self, run_cmd):
main_user = IdentityData(
username='mainuser',
crypted_password='sample_value',
hostname='somehost')
secondary_user = {'name': 'user2'}
username="mainuser", crypted_password="sample_value", hostname="somehost"
)
secondary_user = {"name": "user2"}
with self.subTest('Main user + secondary user'):
with self.subTest("Main user + secondary user"):
model = self.make_model()
model.identity.add_user(main_user)
model.userdata = {'users': [secondary_user]}
model.userdata = {"users": [secondary_user]}
run_cmd.return_value.stdout = getent_group_output
cloud_init_config = model._cloud_init_config()
self.assertEqual(len(cloud_init_config['users']), 2)
self.assertEqual(cloud_init_config['users'][0]['name'], 'mainuser')
self.assertEqual(
cloud_init_config['users'][0]['groups'],
'adm,sudo,users'
)
self.assertEqual(cloud_init_config['users'][1]['name'], 'user2')
run_cmd.assert_called_with(['getent', 'group'], check=True)
self.assertEqual(len(cloud_init_config["users"]), 2)
self.assertEqual(cloud_init_config["users"][0]["name"], "mainuser")
self.assertEqual(cloud_init_config["users"][0]["groups"], "adm,sudo,users")
self.assertEqual(cloud_init_config["users"][1]["name"], "user2")
run_cmd.assert_called_with(["getent", "group"], check=True)
with self.subTest('Secondary user only'):
with self.subTest("Secondary user only"):
model = self.make_model()
model.userdata = {'users': [secondary_user]}
model.userdata = {"users": [secondary_user]}
cloud_init_config = model._cloud_init_config()
self.assertEqual(len(cloud_init_config['users']), 1)
self.assertEqual(cloud_init_config['users'][0]['name'], 'user2')
self.assertEqual(len(cloud_init_config["users"]), 1)
self.assertEqual(cloud_init_config["users"][0]["name"], "user2")
with self.subTest('Invalid user-data raises error'):
with self.subTest("Invalid user-data raises error"):
model = self.make_model()
model.userdata = {'bootcmd': "nope"}
model.userdata = {"bootcmd": "nope"}
with self.assertRaises(SchemaValidationError) as ctx:
model._cloud_init_config()
expected_error = (
"Cloud config schema errors: bootcmd: 'nope' is not of type"
" 'array'"
"Cloud config schema errors: bootcmd: 'nope' is not of type" " 'array'"
)
self.assertEqual(expected_error, str(ctx.exception))
@mock.patch('subiquity.models.subiquity.lsb_release')
@mock.patch('subiquitycore.file_util.datetime.datetime')
@mock.patch("subiquity.models.subiquity.lsb_release")
@mock.patch("subiquitycore.file_util.datetime.datetime")
def test_cloud_init_files_emits_datasource_config_and_clean_script(
self, datetime, lsb_release
):
datetime.utcnow.return_value = "2004-03-05 ..."
main_user = IdentityData(
username='mainuser',
crypted_password='sample_pass',
hostname='somehost')
username="mainuser", crypted_password="sample_pass", hostname="somehost"
)
model = self.make_model()
model.identity.add_user(main_user)
model.userdata = {}
expected_files = {
'etc/cloud/cloud.cfg.d/99-installer.cfg': re.compile(
'datasource:\n None:\n metadata:\n instance-id: .*\n userdata_raw: "#cloud-config\\\\ngrowpart:\\\\n mode: \\\'off\\\'\\\\npreserve_hostname: true\\\\n\\\\\n' # noqa
"etc/cloud/cloud.cfg.d/99-installer.cfg": re.compile(
"datasource:\n None:\n metadata:\n instance-id: .*\n userdata_raw: \"#cloud-config\\\\ngrowpart:\\\\n mode: \\'off\\'\\\\npreserve_hostname: true\\\\n\\\\\n" # noqa
),
'etc/hostname': 'somehost\n',
'etc/cloud/ds-identify.cfg': 'policy: enabled\n',
'etc/hosts': HOSTS_CONTENT.format(hostname='somehost'),
"etc/hostname": "somehost\n",
"etc/cloud/ds-identify.cfg": "policy: enabled\n",
"etc/hosts": HOSTS_CONTENT.format(hostname="somehost"),
}
# Avoid removing /etc/hosts and /etc/hostname in cloud-init clean
cfg_files = [
"/" + key for key in expected_files.keys() if "host" not in key
]
cfg_files = ["/" + key for key in expected_files.keys() if "host" not in key]
cfg_files.append(
# Obtained from NetworkModel.render when cloud-init features
# NETPLAN_CONFIG_ROOT_READ_ONLY is True
@ -287,15 +278,15 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
)
header = "# Autogenerated by Subiquity: 2004-03-05 ... UTC\n"
with self.subTest(
'Stable releases Jammy do not disable cloud-init.'
' NETPLAN_ROOT_READ_ONLY=True uses cloud-init networking'
"Stable releases Jammy do not disable cloud-init."
" NETPLAN_ROOT_READ_ONLY=True uses cloud-init networking"
):
lsb_release.return_value = {"release": "22.04"}
expected_files['etc/cloud/clean.d/99-installer'] = (
CLOUDINIT_CLEAN_FILE_TMPL.format(
expected_files[
"etc/cloud/clean.d/99-installer"
] = CLOUDINIT_CLEAN_FILE_TMPL.format(
header=header, cfg_files=json.dumps(sorted(cfg_files))
)
)
with unittest.mock.patch(
"subiquity.cloudinit.open",
mock.mock_open(
@ -304,57 +295,52 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
)
),
):
for (cpath, content, perms) in model._cloud_init_files():
for cpath, content, perms in model._cloud_init_files():
if isinstance(expected_files[cpath], re.Pattern):
self.assertIsNotNone(
expected_files[cpath].match(content)
)
self.assertIsNotNone(expected_files[cpath].match(content))
else:
self.assertEqual(expected_files[cpath], content)
with self.subTest(
'Kinetic++ disables cloud-init post install.'
' NETPLAN_ROOT_READ_ONLY=False avoids cloud-init networking'
"Kinetic++ disables cloud-init post install."
" NETPLAN_ROOT_READ_ONLY=False avoids cloud-init networking"
):
lsb_release.return_value = {"release": "22.10"}
cfg_files.append(
# Added by _cloud_init_files for 22.10 and later releases
'/etc/cloud/cloud-init.disabled',
"/etc/cloud/cloud-init.disabled",
)
# Obtained from NetworkModel.render
cfg_files.remove('/etc/cloud/cloud.cfg.d/90-installer-network.cfg')
cfg_files.append('/etc/netplan/00-installer-config.yaml')
cfg_files.remove("/etc/cloud/cloud.cfg.d/90-installer-network.cfg")
cfg_files.append("/etc/netplan/00-installer-config.yaml")
cfg_files.append(
'/etc/cloud/cloud.cfg.d/'
'subiquity-disable-cloudinit-networking.cfg'
"/etc/cloud/cloud.cfg.d/" "subiquity-disable-cloudinit-networking.cfg"
)
expected_files[
'etc/cloud/clean.d/99-installer'
"etc/cloud/clean.d/99-installer"
] = CLOUDINIT_CLEAN_FILE_TMPL.format(
header=header, cfg_files=json.dumps(sorted(cfg_files))
)
with unittest.mock.patch(
'subiquity.cloudinit.open',
"subiquity.cloudinit.open",
mock.mock_open(
read_data=json.dumps(
{'features': {'NETPLAN_CONFIG_ROOT_READ_ONLY': False}}
{"features": {"NETPLAN_CONFIG_ROOT_READ_ONLY": False}}
)
),
):
for (cpath, content, perms) in model._cloud_init_files():
for cpath, content, perms in model._cloud_init_files():
if isinstance(expected_files[cpath], re.Pattern):
self.assertIsNotNone(
expected_files[cpath].match(content)
)
self.assertIsNotNone(expected_files[cpath].match(content))
else:
self.assertEqual(expected_files[cpath], content)
def test_validatecloudconfig_schema(self):
model = self.make_model()
with self.subTest('Valid cloud-config does not error'):
with self.subTest("Valid cloud-config does not error"):
model.validate_cloudconfig_schema(
data={"ssh_import_id": ["chad.smith"]},
data_source="autoinstall.user-data"
data_source="autoinstall.user-data",
)
# Create our own subclass for focal as schema_deprecations
@ -367,22 +353,18 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
self.schema_deprecations = schema_deprecations
problem = SchemaProblem(
"bogus",
"'bogus' is deprecated, use 'notbogus' instead"
"bogus", "'bogus' is deprecated, use 'notbogus' instead"
)
with self.subTest('Deprecated cloud-config warns'):
with self.subTest("Deprecated cloud-config warns"):
with unittest.mock.patch(
"subiquity.models.subiquity.validate_cloudconfig_schema"
) as validate:
validate.side_effect = SchemaDeprecation(
schema_deprecations=(problem,)
)
validate.side_effect = SchemaDeprecation(schema_deprecations=(problem,))
with self.assertLogs(
"subiquity.models.subiquity", level="INFO"
) as logs:
model.validate_cloudconfig_schema(
data={"bogus": True},
data_source="autoinstall.user-data"
data={"bogus": True}, data_source="autoinstall.user-data"
)
expected = (
"WARNING:subiquity.models.subiquity:The cloud-init"
@ -391,22 +373,20 @@ class TestSubiquityModel(unittest.IsolatedAsyncioTestCase):
)
self.assertEqual(logs.output, [expected])
with self.subTest('Invalid cloud-config schema errors'):
with self.subTest("Invalid cloud-config schema errors"):
with self.assertRaises(SchemaValidationError) as ctx:
model.validate_cloudconfig_schema(
data={"bootcmd": "nope"}, data_source="system info"
)
expected_error = (
"Cloud config schema errors: bootcmd: 'nope' is not of"
" type 'array'"
"Cloud config schema errors: bootcmd: 'nope' is not of" " type 'array'"
)
self.assertEqual(expected_error, str(ctx.exception))
with self.subTest('Prefix autoinstall.user-data cloud-config errors'):
with self.subTest("Prefix autoinstall.user-data cloud-config errors"):
with self.assertRaises(SchemaValidationError) as ctx:
model.validate_cloudconfig_schema(
data={"bootcmd": "nope"},
data_source="autoinstall.user-data"
data={"bootcmd": "nope"}, data_source="autoinstall.user-data"
)
expected_error = (
"Cloud config schema errors:"

View File

@ -15,13 +15,13 @@
import logging
log = logging.getLogger('subiquity.models.timezone')
log = logging.getLogger("subiquity.models.timezone")
class TimeZoneModel(object):
"""Model representing timezone"""
timezone = ''
timezone = ""
def __init__(self):
# This is the raw request from the API / autoinstall.
@ -32,16 +32,16 @@ class TimeZoneModel(object):
self._request = None
# The actually timezone to set, possibly post-geoip lookup or
# possibly manually specified.
self.timezone = ''
self.timezone = ""
self.got_from_geoip = False
def set(self, value):
self._request = value
self.timezone = value if value != 'geoip' else ''
self.timezone = value if value != "geoip" else ""
@property
def detect_with_geoip(self):
return self._request == 'geoip'
return self._request == "geoip"
@property
def should_set_tz(self):
@ -54,10 +54,15 @@ class TimeZoneModel(object):
def make_cloudconfig(self):
if not self.should_set_tz:
return {}
return {'timezone': self.timezone}
return {"timezone": self.timezone}
def __repr__(self):
return ("<TimeZone: request {} detect {} should_set {} " +
"timezone {} gfg {}>").format(
self.request, self.detect_with_geoip, self.should_set_tz,
self.timezone, self.got_from_geoip)
return (
"<TimeZone: request {} detect {} should_set {} " + "timezone {} gfg {}>"
).format(
self.request,
self.detect_with_geoip,
self.should_set_tz,
self.timezone,
self.got_from_geoip,
)

View File

@ -26,6 +26,7 @@ class UbuntuProModel:
the provided token is correct ; nor to retrieve information about the
subscription.
"""
def __init__(self):
"""Initialize the model."""
self.token: str = ""

View File

@ -15,13 +15,13 @@
import logging
log = logging.getLogger('subiquity.models.updates')
log = logging.getLogger("subiquity.models.updates")
class UpdatesModel(object):
"""Model representing updates selection"""
updates = 'security'
updates = "security"
def __repr__(self):
return "<Updates: {}>".format(self.updates)

View File

@ -14,19 +14,17 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
from contextlib import contextmanager
import logging
import os
from contextlib import contextmanager
from socket import gethostname
from subprocess import CalledProcessError
from subiquitycore.utils import arun_command, run_command
from subiquity.server.curtin import run_curtin_command
from subiquity.common.types import (
AdConnectionInfo,
AdJoinResult,
)
log = logging.getLogger('subiquity.server.ad_joiner')
from subiquity.common.types import AdConnectionInfo, AdJoinResult
from subiquity.server.curtin import run_curtin_command
from subiquitycore.utils import arun_command, run_command
log = logging.getLogger("subiquity.server.ad_joiner")
@contextmanager
@ -37,7 +35,7 @@ def joining_context(hostname: str, root_dir: str):
hostname_current = gethostname()
binds = ("/proc", "/sys", "/dev", "/run")
try:
hostname_process = run_command(['hostname', hostname])
hostname_process = run_command(["hostname", hostname])
for bind in binds:
bound_dir = os.path.join(root_dir, bind[1:])
if bound_dir != bind:
@ -45,7 +43,7 @@ def joining_context(hostname: str, root_dir: str):
yield hostname_process
finally:
# Restoring the live session hostname.
hostname_process = run_command(['hostname', hostname_current])
hostname_process = run_command(["hostname", hostname_current])
if hostname_process.returncode:
log.info("Failed to restore live session hostname")
for bind in reversed(binds):
@ -54,15 +52,16 @@ def joining_context(hostname: str, root_dir: str):
run_command(["umount", "-f", bound_dir])
class AdJoinStrategy():
class AdJoinStrategy:
realm = "/usr/sbin/realm"
pam = "/usr/sbin/pam-auth-update"
def __init__(self, app):
self.app = app
async def do_join(self, info: AdConnectionInfo, hostname: str, context) \
-> AdJoinResult:
async def do_join(
self, info: AdConnectionInfo, hostname: str, context
) -> AdJoinResult:
"""This method changes the hostname and perform a real AD join, thus
should only run in a live session."""
root_dir = self.app.base_model.target
@ -73,11 +72,21 @@ class AdJoinStrategy():
log.info("Failed to set live session hostname for adcli")
return AdJoinResult.JOIN_ERROR
cp = await arun_command([self.realm, "join", "--install", root_dir,
"--user", info.admin_name,
"--computer-name", hostname,
"--unattended", info.domain_name],
input=info.password)
cp = await arun_command(
[
self.realm,
"join",
"--install",
root_dir,
"--user",
info.admin_name,
"--computer-name",
hostname,
"--unattended",
info.domain_name,
],
input=info.password,
)
if cp.returncode:
# Try again without the computer name. Lab tests shown more
@ -88,10 +97,19 @@ class AdJoinStrategy():
log.debug(cp.stderr)
log.debug(cp.stdout)
log.debug("Trying again without overriding the computer name:")
cp = await arun_command([self.realm, "join", "--install",
root_dir, "--user", info.admin_name,
"--unattended", info.domain_name],
input=info.password)
cp = await arun_command(
[
self.realm,
"join",
"--install",
root_dir,
"--user",
info.admin_name,
"--unattended",
info.domain_name,
],
input=info.password,
)
if cp.returncode:
log.debug("Joining operation failed:")
@ -102,11 +120,19 @@ class AdJoinStrategy():
# Enable pam_mkhomedir
try:
# The function raises if the process fail.
await run_curtin_command(self.app, context,
"in-target", "-t", root_dir,
"--", self.pam, "--package",
"--enable", "mkhomedir",
private_mounts=False)
await run_curtin_command(
self.app,
context,
"in-target",
"-t",
root_dir,
"--",
self.pam,
"--package",
"--enable",
"mkhomedir",
private_mounts=False,
)
return AdJoinResult.OK
except CalledProcessError:
@ -120,24 +146,25 @@ class AdJoinStrategy():
class StubStrategy(AdJoinStrategy):
async def do_join(self, info: AdConnectionInfo, hostname: str, context) \
-> AdJoinResult:
async def do_join(
self, info: AdConnectionInfo, hostname: str, context
) -> AdJoinResult:
"""Enables testing without real join. The result depends on the
domain name initial character, such that if it is:
- p or P: returns PAM_ERROR.
- j or J: returns JOIN_ERROR.
- returns OK otherwise."""
initial = info.domain_name[0]
if initial in ('j', 'J'):
if initial in ("j", "J"):
return AdJoinResult.JOIN_ERROR
if initial in ('p', 'P'):
if initial in ("p", "P"):
return AdJoinResult.PAM_ERROR
return AdJoinResult.OK
class AdJoiner():
class AdJoiner:
def __init__(self, app):
self._result = AdJoinResult.UNKNOWN
self._completion_event = asyncio.Event()
@ -146,8 +173,9 @@ class AdJoiner():
else:
self.strategy = AdJoinStrategy(app)
async def join_domain(self, info: AdConnectionInfo, hostname: str,
context) -> AdJoinResult:
async def join_domain(
self, info: AdConnectionInfo, hostname: str, context
) -> AdJoinResult:
if hostname:
self._result = await self.strategy.do_join(info, hostname, context)
else:

View File

@ -28,13 +28,8 @@ import tempfile
from typing import List, Optional
import apt_pkg
from curtin.config import merge_config
from curtin.commands.extract import AbstractSourceHandler
from subiquitycore.file_util import write_file, generate_config_yaml
from subiquitycore.lsb_release import lsb_release
from subiquitycore.utils import astart_command, orig_environ
from curtin.config import merge_config
from subiquity.server.curtin import run_curtin_command
from subiquity.server.mounter import (
@ -44,8 +39,11 @@ from subiquity.server.mounter import (
OverlayCleanupError,
OverlayMountpoint,
)
from subiquitycore.file_util import generate_config_yaml, write_file
from subiquitycore.lsb_release import lsb_release
from subiquitycore.utils import astart_command, orig_environ
log = logging.getLogger('subiquity.server.apt')
log = logging.getLogger("subiquity.server.apt")
class AptConfigCheckError(Exception):
@ -108,8 +106,7 @@ class AptConfigurer:
# system, or if it is not, just copy /var/lib/apt/lists from the
# 'configured_tree' overlay.
def __init__(self, app, mounter: Mounter,
source_handler: AbstractSourceHandler):
def __init__(self, app, mounter: Mounter, source_handler: AbstractSourceHandler):
self.app = app
self.mounter = mounter
self.source_handler: AbstractSourceHandler = source_handler
@ -133,22 +130,29 @@ class AptConfigurer:
self.app.base_model.debconf_selections,
]
for model in models:
merge_config(cfg, model.get_apt_config(
final=final, has_network=has_network))
return {'apt': cfg}
merge_config(
cfg, model.get_apt_config(final=final, has_network=has_network)
)
return {"apt": cfg}
async def apply_apt_config(self, context, final: bool):
self.configured_tree = await self.mounter.setup_overlay(
[self.source_path])
self.configured_tree = await self.mounter.setup_overlay([self.source_path])
config_location = os.path.join(
self.app.root, 'var/log/installer/subiquity-curtin-apt.conf')
self.app.root, "var/log/installer/subiquity-curtin-apt.conf"
)
generate_config_yaml(config_location, self.apt_config(final))
self.app.note_data_for_apport("CurtinAptConfig", config_location)
await run_curtin_command(
self.app, context, 'apt-config', '-t', self.configured_tree.p(),
config=config_location, private_mounts=True)
self.app,
context,
"apt-config",
"-t",
self.configured_tree.p(),
config=config_location,
private_mounts=True,
)
async def run_apt_config_check(self, output: io.StringIO) -> None:
"""Run apt-get update (with various options limiting the amount of
@ -173,15 +177,15 @@ class AptConfigurer:
# Need to ensure the "partial" directory exists.
partial_dir = apt_dirs["State::Lists"] / "partial"
partial_dir.mkdir(
parents=True, exist_ok=True)
partial_dir.mkdir(parents=True, exist_ok=True)
try:
shutil.chown(partial_dir, user="_apt")
except (PermissionError, LookupError) as exc:
log.warning("could to set owner of file %s: %r", partial_dir, exc)
apt_cmd = [
"apt-get", "update",
"apt-get",
"update",
"-oAPT::Update::Error-Mode=any",
# Workaround because the default sandbox user (i.e., _apt) does not
# have access to the overlay.
@ -201,8 +205,9 @@ class AptConfigurer:
env["APT_CONFIG"] = config_file.name
config_file.write(apt_config.dump())
config_file.flush()
proc = await astart_command(apt_cmd, stderr=subprocess.STDOUT,
clean_locale=False, env=env)
proc = await astart_command(
apt_cmd, stderr=subprocess.STDOUT, clean_locale=False, env=env
)
async def _reader():
while not proc.stdout.at_eof():
@ -223,42 +228,51 @@ class AptConfigurer:
async def configure_for_install(self, context):
assert self.configured_tree is not None
self.install_tree = await self.mounter.setup_overlay(
[self.configured_tree])
self.install_tree = await self.mounter.setup_overlay([self.configured_tree])
os.mkdir(self.install_tree.p('cdrom'))
await self.mounter.mount(
'/cdrom', self.install_tree.p('cdrom'), options='bind')
os.mkdir(self.install_tree.p("cdrom"))
await self.mounter.mount("/cdrom", self.install_tree.p("cdrom"), options="bind")
if self.app.base_model.network.has_network:
os.rename(
self.install_tree.p('etc/apt/sources.list'),
self.install_tree.p('etc/apt/sources.list.d/original.list'))
self.install_tree.p("etc/apt/sources.list"),
self.install_tree.p("etc/apt/sources.list.d/original.list"),
)
else:
proxy_path = self.install_tree.p(
'etc/apt/apt.conf.d/90curtin-aptproxy')
proxy_path = self.install_tree.p("etc/apt/apt.conf.d/90curtin-aptproxy")
if os.path.exists(proxy_path):
os.unlink(proxy_path)
codename = lsb_release(dry_run=self.app.opts.dry_run)['codename']
codename = lsb_release(dry_run=self.app.opts.dry_run)["codename"]
write_file(
self.install_tree.p('etc/apt/sources.list'),
f'deb [check-date=no] file:///cdrom {codename} main restricted\n')
self.install_tree.p("etc/apt/sources.list"),
f"deb [check-date=no] file:///cdrom {codename} main restricted\n",
)
await run_curtin_command(
self.app, context, "in-target", "-t", self.install_tree.p(),
"--", "apt-get", "update", private_mounts=True)
self.app,
context,
"in-target",
"-t",
self.install_tree.p(),
"--",
"apt-get",
"update",
private_mounts=True,
)
return self.install_tree.p()
@contextlib.asynccontextmanager
async def overlay(self):
overlay = await self.mounter.setup_overlay([
overlay = await self.mounter.setup_overlay(
[
self.install_tree.upperdir,
self.configured_tree.upperdir,
self.source_path,
])
]
)
try:
yield overlay
finally:
@ -269,8 +283,7 @@ class AptConfigurer:
# compares equal to the one we discarded earlier.
# But really, there should be better ways to handle this.
try:
await self.mounter.unmount(
Mountpoint(mountpoint=overlay.mountpoint))
await self.mounter.unmount(Mountpoint(mountpoint=overlay.mountpoint))
except subprocess.CalledProcessError as exc:
raise OverlayCleanupError from exc
@ -286,41 +299,53 @@ class AptConfigurer:
async def _restore_dir(dir):
shutil.rmtree(target_mnt.p(dir))
await self.app.command_runner.run([
'cp', '-aT', self.configured_tree.p(dir), target_mnt.p(dir),
])
await self.app.command_runner.run(
[
"cp",
"-aT",
self.configured_tree.p(dir),
target_mnt.p(dir),
]
)
def _restore_file(path: str) -> None:
shutil.copyfile(self.configured_tree.p(path), target_mnt.p(path))
# The file only exists if we are online
with contextlib.suppress(FileNotFoundError):
os.unlink(target_mnt.p('etc/apt/sources.list.d/original.list'))
_restore_file('etc/apt/sources.list')
os.unlink(target_mnt.p("etc/apt/sources.list.d/original.list"))
_restore_file("etc/apt/sources.list")
with contextlib.suppress(FileNotFoundError):
_restore_file('etc/apt/apt.conf.d/90curtin-aptproxy')
_restore_file("etc/apt/apt.conf.d/90curtin-aptproxy")
if self.app.base_model.network.has_network:
await run_curtin_command(
self.app, context, "in-target", "-t", target_mnt.p(),
"--", "apt-get", "update", private_mounts=True)
self.app,
context,
"in-target",
"-t",
target_mnt.p(),
"--",
"apt-get",
"update",
private_mounts=True,
)
else:
await _restore_dir('var/lib/apt/lists')
await _restore_dir("var/lib/apt/lists")
await self.cleanup()
try:
d = target_mnt.p('cdrom')
d = target_mnt.p("cdrom")
os.rmdir(d)
except OSError as ose:
log.warning(f'failed to rmdir {d}: {ose}')
log.warning(f"failed to rmdir {d}: {ose}")
async def setup_target(self, context, target: str):
# Call this after the rootfs has been extracted to the real target
# system but before any configuration is applied to it.
target_mnt = Mountpoint(mountpoint=target)
await self.mounter.mount(
'/cdrom', target_mnt.p('cdrom'), options='bind')
await self.mounter.mount("/cdrom", target_mnt.p("cdrom"), options="bind")
class DryRunAptConfigurer(AptConfigurer):
@ -354,7 +379,8 @@ class DryRunAptConfigurer(AptConfigurer):
return self.MirrorCheckStrategy(known["strategy"])
return self.MirrorCheckStrategy(
self.app.dr_cfg.apt_mirror_check_default_strategy)
self.app.dr_cfg.apt_mirror_check_default_strategy
)
async def apt_config_check_failure(self, output: io.StringIO) -> None:
"""Pretend that the execution of the apt-get update command results in
@ -363,7 +389,8 @@ class DryRunAptConfigurer(AptConfigurer):
release = lsb_release(dry_run=True)["codename"]
host = url.split("/")[2]
output.write(f"""\
output.write(
f"""\
Ign:1 {url} {release} InRelease
Ign:2 {url} {release}-updates InRelease
Ign:3 {url} {release}-backports InRelease
@ -389,7 +416,8 @@ E: Failed to fetch {url}/dists/{release}-security/InRelease\
Temporary failure resolving '{host}'
E: Some index files failed to download. They have been ignored,
or old ones used instead.
""")
"""
)
raise AptConfigCheckError
async def apt_config_check_success(self, output: io.StringIO) -> None:
@ -398,14 +426,16 @@ E: Some index files failed to download. They have been ignored,
url = self.app.base_model.mirror.primary_staged.uri
release = lsb_release(dry_run=True)["codename"]
output.write(f"""\
output.write(
f"""\
Get:1 {url} {release} InRelease [267 kB]
Get:2 {url} {release}-updates InRelease [109 kB]
Get:3 {url} {release}-backports InRelease [99.9 kB]
Get:4 {url} {release}-security InRelease [109 kB]
Fetched 585 kB in 1s (1057 kB/s)
Reading package lists...
""")
"""
)
async def run_apt_config_check(self, output: io.StringIO) -> None:
"""Dry-run implementation of the Apt config check.

View File

@ -16,17 +16,12 @@
import asyncio
import logging
from typing import Any, List
import time
from typing import Any, List
from subiquity.server.ubuntu_advantage import (
UAInterface,
APIUsageError,
)
from subiquity.server.ubuntu_advantage import APIUsageError, UAInterface
from subiquitycore.async_helpers import schedule_task
log = logging.getLogger("subiquity.server.contract_selection")
@ -36,6 +31,7 @@ class UPCSExpiredError(Exception):
class UnknownError(Exception):
"""Exception to be raised in case of unexpected error."""
def __init__(self, message: str = "", errors: List[Any] = []) -> None:
self.errors = errors
super().__init__(message)
@ -43,12 +39,14 @@ class UnknownError(Exception):
class ContractSelection:
"""Represents an already initiated contract selection."""
def __init__(
self,
client: UAInterface,
magic_token: str,
user_code: str,
validity_seconds: int) -> None:
validity_seconds: int,
) -> None:
"""Initialize the contract selection."""
self.client = client
self.magic_token = magic_token
@ -57,8 +55,7 @@ class ContractSelection:
self.task = asyncio.create_task(self._run_polling())
@classmethod
async def initiate(cls, client: UAInterface) \
-> "ContractSelection":
async def initiate(cls, client: UAInterface) -> "ContractSelection":
"""Initiate a contract selection and return a ContractSelection
request object."""
answer = await client.strategy.magic_initiate_v1()
@ -70,7 +67,8 @@ class ContractSelection:
client=client,
magic_token=answer["data"]["attributes"]["token"],
validity_seconds=answer["data"]["attributes"]["expires_in"],
user_code=answer["data"]["attributes"]["user_code"])
user_code=answer["data"]["attributes"]["user_code"],
)
async def _run_polling(self) -> str:
"""Runs the polling and eventually return a contract token."""
@ -79,8 +77,7 @@ class ContractSelection:
await asyncio.sleep(30 / self.client.strategy.scale_factor)
start_time = time.monotonic()
answer = await self.client.strategy.magic_wait_v1(
magic_token=self.magic_token)
answer = await self.client.strategy.magic_wait_v1(magic_token=self.magic_token)
call_duration = time.monotonic() - start_time
@ -89,13 +86,14 @@ class ContractSelection:
exception = UnknownError()
for error in answer["errors"]:
if error["code"] == "magic-attach-token-error" and \
call_duration >= 60 / self.client.strategy.scale_factor:
if (
error["code"] == "magic-attach-token-error"
and call_duration >= 60 / self.client.strategy.scale_factor
):
# Assume it's a timeout if it lasted more than 1 minute.
exception = UPCSExpiredError(error["title"])
else:
log.warning("magic_wait_v1: %s: %s",
error["code"], error["title"])
log.warning("magic_wait_v1: %s: %s", error["code"], error["title"])
raise exception
@ -108,7 +106,8 @@ class ContractSelection:
"""Release the resource on the server."""
try:
answer = await self.client.strategy.magic_revoke_v1(
magic_token=self.magic_token)
magic_token=self.magic_token
)
except APIUsageError as e:
log.warning("failed to revoke magic-token: %r", e)
@ -117,7 +116,8 @@ class ContractSelection:
log.debug("successfully revoked magic-token")
return
for error in answer["errors"]:
log.warning("magic_revoke_v1: %s: %s",
error["code"], error["title"])
log.warning(
"magic_revoke_v1: %s: %s", error["code"], error["title"]
)
schedule_task(delete_resource())

View File

@ -20,19 +20,15 @@ from typing import Any, Optional
import jsonschema
from subiquitycore.context import with_context
from subiquitycore.controller import (
BaseController,
)
from subiquity.common.api.server import bind
from subiquity.server.types import InstallerChannels
from subiquitycore.context import with_context
from subiquitycore.controller import BaseController
log = logging.getLogger("subiquity.server.controller")
class SubiquityController(BaseController):
autoinstall_key: Optional[str] = None
autoinstall_schema: Any = None
autoinstall_default: Any = None
@ -48,10 +44,9 @@ class SubiquityController(BaseController):
def __init__(self, app):
super().__init__(app)
self.context.set('controller', self)
self.context.set("controller", self)
if self.interactive_for_variants is not None:
self.app.hub.subscribe(
InstallerChannels.INSTALL_CONFIRMED, self._confirmed)
self.app.hub.subscribe(InstallerChannels.INSTALL_CONFIRMED, self._confirmed)
async def _confirmed(self):
variant = self.app.base_model.source.current.variant
@ -102,8 +97,7 @@ class SubiquityController(BaseController):
def interactive(self):
if not self.app.autoinstall_config:
return self._active
i_sections = self.app.autoinstall_config.get(
'interactive-sections', [])
i_sections = self.app.autoinstall_config.get("interactive-sections", [])
if "*" in i_sections:
return self._active
@ -111,21 +105,23 @@ class SubiquityController(BaseController):
if self.autoinstall_key in i_sections:
return self._active
if self.autoinstall_key_alias is not None \
and self.autoinstall_key_alias in i_sections:
if (
self.autoinstall_key_alias is not None
and self.autoinstall_key_alias in i_sections
):
return self._active
async def configured(self):
"""Let the world know that this controller's model is now configured.
"""
with open(self.app.state_path('states', self.name), 'w') as fp:
"""Let the world know that this controller's model is now configured."""
with open(self.app.state_path("states", self.name), "w") as fp:
json.dump(self.serialize(), fp)
if self.model_name is not None:
await self.app.hub.abroadcast(
(InstallerChannels.CONFIGURED, self.model_name))
(InstallerChannels.CONFIGURED, self.model_name)
)
def load_state(self):
state_path = self.app.state_path('states', self.name)
state_path = self.app.state_path("states", self.name)
if not os.path.exists(state_path):
return
with open(state_path) as fp:
@ -143,6 +139,5 @@ class SubiquityController(BaseController):
class NonInteractiveController(SubiquityController):
def interactive(self):
return False

Some files were not shown because too many files have changed in this diff Show More