Merge pull request #1179 from ogayot/FR-2025

Server-side implementation of third-party drivers
This commit is contained in:
Olivier Gayot 2022-03-01 10:21:59 +01:00 committed by GitHub
commit 2e6fffe02b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 567 additions and 30 deletions

View File

@ -395,6 +395,14 @@
"additionalProperties": false "additionalProperties": false
} }
}, },
"drivers": {
"type": "object",
"properties": {
"install": {
"type": "boolean"
}
}
},
"timezone": { "timezone": {
"type": "string" "type": "string"
}, },

View File

@ -39,6 +39,8 @@ from subiquity.common.types import (
ModifyPartitionV2, ModifyPartitionV2,
RefreshStatus, RefreshStatus,
ShutdownMode, ShutdownMode,
DriversResponse,
DriversPayload,
SnapInfo, SnapInfo,
SnapListResponse, SnapListResponse,
SnapSelection, SnapSelection,
@ -293,6 +295,10 @@ class API:
def POST(data: Payload[ModifyPartitionV2]) \ def POST(data: Payload[ModifyPartitionV2]) \
-> StorageResponseV2: ... -> StorageResponseV2: ...
class drivers:
def GET(wait: bool = False) -> DriversResponse: ...
def POST(data: Payload[DriversPayload]) -> None: ...
class snaplist: class snaplist:
def GET(wait: bool = False) -> SnapListResponse: ... def GET(wait: bool = False) -> SnapListResponse: ...
def POST(data: Payload[List[SnapSelection]]): ... def POST(data: Payload[List[SnapSelection]]): ...

View File

@ -375,6 +375,24 @@ class SnapInfo:
channels: List[ChannelSnapInfo] = attr.Factory(list) channels: List[ChannelSnapInfo] = attr.Factory(list)
@attr.s(auto_attribs=True)
class DriversResponse:
""" Response to GET request to drivers.
:install: tells whether third-party drivers will be installed (if any is
available).
:drivers: tells what third-party drivers will be installed should we decide
to do it. It will bet set to None until we figure out what drivers are
available.
"""
install: bool
drivers: Optional[List[str]]
@attr.s(auto_attribs=True)
class DriversPayload:
install: bool
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class SnapSelection: class SnapSelection:
name: str name: str

View File

@ -0,0 +1,22 @@
# Copyright 2021 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
log = logging.getLogger('subiquity.models.drivers')
class DriversModel:
do_install = False

View File

@ -30,6 +30,7 @@ from subiquitycore.file_util import write_file
from subiquity.common.resources import get_users_and_groups from subiquity.common.resources import get_users_and_groups
from subiquity.server.types import InstallerChannels from subiquity.server.types import InstallerChannels
from .drivers import DriversModel
from .filesystem import FilesystemModel from .filesystem import FilesystemModel
from .identity import IdentityModel from .identity import IdentityModel
from .kernel import KernelModel from .kernel import KernelModel
@ -132,6 +133,7 @@ class SubiquityModel:
self.chroot_prefix = [] self.chroot_prefix = []
self.debconf_selections = DebconfSelectionsModel() self.debconf_selections = DebconfSelectionsModel()
self.drivers = DriversModel()
self.filesystem = FilesystemModel() self.filesystem = FilesystemModel()
self.identity = IdentityModel() self.identity = IdentityModel()
self.kernel = KernelModel() self.kernel = KernelModel()

View File

@ -13,6 +13,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import contextlib
import functools import functools
import logging import logging
import os import os
@ -153,8 +154,10 @@ class AptConfigurer:
self._mounts.append(m) self._mounts.append(m)
return m return m
async def unmount(self, mountpoint: str): async def unmount(self, mountpoint: Mountpoint, remove=True):
await self.app.command_runner.run(['umount', mountpoint]) if remove:
self._mounts.remove(mountpoint)
await self.app.command_runner.run(['umount', mountpoint.mountpoint])
async def setup_overlay(self, lowers: List[Lower]) -> OverlayMountpoint: async def setup_overlay(self, lowers: List[Lower]) -> OverlayMountpoint:
tdir = self.tdir() tdir = self.tdir()
@ -224,9 +227,27 @@ class AptConfigurer:
return self.install_tree.p() return self.install_tree.p()
@contextlib.asynccontextmanager
async def overlay(self):
overlay = await self.setup_overlay([
self.install_tree.upperdir,
self.configured_tree.upperdir,
self.source
])
try:
yield overlay
finally:
# TODO self.unmount expects a Mountpoint object. Unfortunately, the
# one we created in setup_overlay was discarded and replaced by an
# OverlayMountPoint object instead. Here we re-create a new
# Mountpoint object and (thanks to attr.s) make sure that it
# compares equal to the one we discarded earlier.
# But really, there should be better ways to handle this.
await self.unmount(Mountpoint(mountpoint=overlay.mountpoint))
async def cleanup(self): async def cleanup(self):
for m in reversed(self._mounts): for m in reversed(self._mounts):
await self.unmount(m.mountpoint) await self.unmount(m, remove=False)
for d in self._tdirs: for d in self._tdirs:
shutil.rmtree(d) shutil.rmtree(d)
@ -239,7 +260,9 @@ class AptConfigurer:
'cp', '-aT', self.configured_tree.p(dir), target_mnt.p(dir), 'cp', '-aT', self.configured_tree.p(dir), target_mnt.p(dir),
]) ])
await self.unmount(target_mnt.p('cdrom')) await self.unmount(
Mountpoint(mountpoint=target_mnt.p('cdrom')),
remove=False)
os.rmdir(target_mnt.p('cdrom')) os.rmdir(target_mnt.p('cdrom'))
await _restore_dir('etc/apt') await _restore_dir('etc/apt')
@ -253,14 +276,13 @@ class AptConfigurer:
await self.cleanup() await self.cleanup()
if self.app.base_model.network.has_network:
await run_curtin_command(
self.app, context, "in-target", "-t", target_mnt.p(),
"--", "apt-get", "update")
class DryRunAptConfigurer(AptConfigurer): class DryRunAptConfigurer(AptConfigurer):
async def unmount(self, mountpoint: Mountpoint, remove=True):
if remove:
self._mounts.remove(mountpoint)
async def setup_overlay(self, lowers: List[Lower]) -> OverlayMountpoint: async def setup_overlay(self, lowers: List[Lower]) -> OverlayMountpoint:
# XXX This implementation expects that: # XXX This implementation expects that:
# - on first invocation, the lowers list contains a single string # - on first invocation, the lowers list contains a single string
@ -285,8 +307,12 @@ class DryRunAptConfigurer(AptConfigurer):
mountpoint=target, mountpoint=target,
upperdir=None) upperdir=None)
@contextlib.asynccontextmanager
async def overlay(self):
yield await self.setup_overlay(self.install_tree.mountpoint)
async def deconfigure(self, context, target): async def deconfigure(self, context, target):
return await self.cleanup()
def get_apt_configurer(app, source: str): def get_apt_configurer(app, source: str):

View File

@ -15,6 +15,7 @@
from .cmdlist import EarlyController, LateController, ErrorController from .cmdlist import EarlyController, LateController, ErrorController
from .debconf import DebconfController from .debconf import DebconfController
from .drivers import DriversController
from .filesystem import FilesystemController from .filesystem import FilesystemController
from .identity import IdentityController from .identity import IdentityController
from .install import InstallController from .install import InstallController
@ -39,6 +40,7 @@ from .zdev import ZdevController
__all__ = [ __all__ = [
'DebconfController', 'DebconfController',
'DriversController',
'EarlyController', 'EarlyController',
'ErrorController', 'ErrorController',
'FilesystemController', 'FilesystemController',

View File

@ -0,0 +1,103 @@
# Copyright 2022 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
from typing import List, Optional
from subiquitycore.context import with_context
from subiquity.common.apidef import API
from subiquity.common.types import DriversPayload, DriversResponse
from subiquity.server.controller import SubiquityController
from subiquity.server.types import InstallerChannels
from subiquity.server.ubuntu_drivers import (
CommandNotFoundError,
UbuntuDriversInterface,
get_ubuntu_drivers_interface,
)
log = logging.getLogger('subiquity.server.controllers.drivers')
class DriversController(SubiquityController):
endpoint = API.drivers
autoinstall_key = model_name = "drivers"
autoinstall_schema = {
'type': 'object',
'properties': {
'install': {
'type': 'boolean',
},
},
}
autoinstall_default = {"install": False}
drivers: Optional[List[str]] = None
def __init__(self, app) -> None:
super().__init__(app)
self.ubuntu_drivers: UbuntuDriversInterface = \
get_ubuntu_drivers_interface(app)
def make_autoinstall(self):
return {
"install": self.model.do_install,
}
def load_autoinstall_data(self, data):
if data is not None and "install" in data:
self.model.do_install = data["install"]
def start(self):
self._wait_apt = asyncio.Event()
self.app.hub.subscribe(
InstallerChannels.APT_CONFIGURED,
self._wait_apt.set)
self._drivers_task = asyncio.create_task(self._list_drivers())
@with_context()
async def _list_drivers(self, context):
with context.child("wait_apt"):
await self._wait_apt.wait()
apt = self.app.controllers.Mirror.apt_configurer
async with apt.overlay() as d:
try:
# Make sure ubuntu-drivers is available.
self.ubuntu_drivers.ensure_cmd_exists(d.mountpoint)
except CommandNotFoundError:
self.drivers = []
else:
self.drivers = await self.ubuntu_drivers.list_drivers(
root_dir=d.mountpoint,
context=context)
log.debug("Available drivers to install: %s", self.drivers)
if not self.drivers:
await self.configured()
else:
# TODO Remove this once we have the GUI controller.
await self.POST(install=True)
async def GET(self, wait: bool = False) -> DriversResponse:
if wait:
await self._drivers_task
return DriversResponse(install=self.model.do_install,
drivers=self.drivers)
async def POST(self, data: DriversPayload) -> None:
self.model.do_install = data.install
await self.configured()

View File

@ -45,6 +45,10 @@ from subiquity.server.curtin import (
run_curtin_command, run_curtin_command,
start_curtin_command, start_curtin_command,
) )
from subiquity.server.types import (
InstallerChannels,
)
log = logging.getLogger("subiquity.server.controllers.install") log = logging.getLogger("subiquity.server.controllers.install")
@ -153,6 +157,8 @@ class InstallController(SubiquityController):
for_install_path = await self.configure_apt(context=context) for_install_path = await self.configure_apt(context=context)
await self.app.hub.abroadcast(InstallerChannels.APT_CONFIGURED)
if os.path.exists(self.model.target): if os.path.exists(self.model.target):
await self.unmount_target( await self.unmount_target(
context=context, target=self.model.target) context=context, target=self.model.target)
@ -190,6 +196,13 @@ class InstallController(SubiquityController):
packages = await self.get_target_packages(context=context) packages = await self.get_target_packages(context=context)
for package in packages: for package in packages:
await self.install_package(context=context, package=package) await self.install_package(context=context, package=package)
if self.model.drivers.do_install:
with context.child(
"ubuntu-drivers-install",
"installing third-party drivers") as child:
ubuntu_drivers = self.app.controllers.Drivers.ubuntu_drivers
await ubuntu_drivers.install_drivers(root_dir=self.tpath(),
context=child)
if self.model.network.has_network: if self.model.network.has_network:
self.app.update_state(ApplicationState.UU_RUNNING) self.app.update_state(ApplicationState.UU_RUNNING)

View File

@ -100,17 +100,17 @@ class _CurtinCommand:
cmd.extend(args) cmd.extend(args)
return cmd return cmd
async def start(self, context): async def start(self, context, **opts):
self._fd = journald_listen( self._fd = journald_listen(
asyncio.get_event_loop(), [self._event_syslog_id], self._event) asyncio.get_event_loop(), [self._event_syslog_id], self._event)
# Yield to the event loop before starting curtin to avoid missing the # Yield to the event loop before starting curtin to avoid missing the
# first couple of events. # first couple of events.
await asyncio.sleep(0) await asyncio.sleep(0)
self._event_contexts[''] = context self._event_contexts[''] = context
self.proc = await self.runner.start(self._cmd) self.proc = await self.runner.start(self._cmd, **opts)
async def wait(self): async def wait(self):
await self.runner.wait(self.proc) result = await self.runner.wait(self.proc)
waited = 0.0 waited = 0.0
while len(self._event_contexts) > 1 and waited < 5.0: while len(self._event_contexts) > 1 and waited < 5.0:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
@ -118,10 +118,11 @@ class _CurtinCommand:
log.debug("waited %s seconds for events to drain", waited) log.debug("waited %s seconds for events to drain", waited)
self._event_contexts.pop('', None) self._event_contexts.pop('', None)
asyncio.get_event_loop().remove_reader(self._fd) asyncio.get_event_loop().remove_reader(self._fd)
return result
async def run(self, context): async def run(self, context):
await self.start(context) await self.start(context)
await self.wait() return await self.wait()
class _DryRunCurtinCommand(_CurtinCommand): class _DryRunCurtinCommand(_CurtinCommand):
@ -146,7 +147,8 @@ class _FailingDryRunCurtinCommand(_DryRunCurtinCommand):
event_file = 'examples/curtin-events-fail.json' event_file = 'examples/curtin-events-fail.json'
async def start_curtin_command(app, context, command, *args, config=None): async def start_curtin_command(
app, context, command, *args, config=None, **opts):
if app.opts.dry_run: if app.opts.dry_run:
if 'install-fail' in app.debug_flags: if 'install-fail' in app.debug_flags:
cls = _FailingDryRunCurtinCommand cls = _FailingDryRunCurtinCommand
@ -156,11 +158,12 @@ async def start_curtin_command(app, context, command, *args, config=None):
cls = _CurtinCommand cls = _CurtinCommand
curtin_cmd = cls(app.opts, app.command_runner, command, *args, curtin_cmd = cls(app.opts, app.command_runner, command, *args,
config=config) config=config)
await curtin_cmd.start(context) await curtin_cmd.start(context, **opts)
return curtin_cmd return curtin_cmd
async def run_curtin_command(app, context, command, *args, config=None): async def run_curtin_command(
app, context, command, *args, config=None, **opts):
cmd = await start_curtin_command( cmd = await start_curtin_command(
app, context, command, *args, config=config) app, context, command, *args, config=config, **opts)
await cmd.wait() return await cmd.wait()

View File

@ -24,22 +24,27 @@ class LoggedCommandRunner:
def __init__(self, ident): def __init__(self, ident):
self.ident = ident self.ident = ident
async def start(self, cmd): async def start(self, cmd, *, capture=False):
proc = await astart_command([ if not capture:
'systemd-cat', '--level-prefix=false', '--identifier='+self.ident, cmd = [
] + cmd) 'systemd-cat',
'--level-prefix=false',
'--identifier='+self.ident,
] + cmd
proc = await astart_command(cmd)
proc.args = cmd proc.args = cmd
return proc return proc
async def wait(self, proc): async def wait(self, proc):
await proc.communicate() stdout, stderr = await proc.communicate()
if proc.returncode != 0: if proc.returncode != 0:
raise subprocess.CalledProcessError(proc.returncode, proc.args) raise subprocess.CalledProcessError(proc.returncode, proc.args)
else: else:
return subprocess.CompletedProcess(proc.args, proc.returncode) return subprocess.CompletedProcess(
proc.args, proc.returncode, stdout=stdout, stderr=stderr)
async def run(self, cmd): async def run(self, cmd, **opts):
proc = await self.start(cmd) proc = await self.start(cmd, **opts)
return await self.wait(proc) return await self.wait(proc)
@ -49,7 +54,7 @@ class DryRunCommandRunner(LoggedCommandRunner):
super().__init__(ident) super().__init__(ident)
self.delay = delay self.delay = delay
async def start(self, cmd): async def start(self, cmd, *, capture=False):
if 'scripts/replay-curtin-log.py' in cmd: if 'scripts/replay-curtin-log.py' in cmd:
delay = 0 delay = 0
else: else:
@ -58,7 +63,7 @@ class DryRunCommandRunner(LoggedCommandRunner):
delay = 3*self.delay delay = 3*self.delay
else: else:
delay = self.delay delay = self.delay
proc = await super().start(cmd) proc = await super().start(cmd, capture=capture)
await asyncio.sleep(delay) await asyncio.sleep(delay)
return proc return proc

View File

@ -200,6 +200,7 @@ INSTALL_MODEL_NAMES = ModelNames({
}) })
POSTINSTALL_MODEL_NAMES = ModelNames({ POSTINSTALL_MODEL_NAMES = ModelNames({
"drivers",
"identity", "identity",
"locale", "locale",
"network", "network",
@ -252,6 +253,7 @@ class SubiquityServer(Application):
"Identity", "Identity",
"SSH", "SSH",
"SnapList", "SnapList",
"Drivers",
"TimeZone", "TimeZone",
"Install", "Install",
"Updates", "Updates",

View File

@ -13,10 +13,11 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from unittest.mock import Mock from unittest.mock import Mock, patch, AsyncMock
from subiquitycore.tests import SubiTestCase from subiquitycore.tests import SubiTestCase
from subiquitycore.tests.mocks import make_app from subiquitycore.tests.mocks import make_app
from subiquitycore.tests.util import run_coro
from subiquity.server.apt import ( from subiquity.server.apt import (
AptConfigurer, AptConfigurer,
Mountpoint, Mountpoint,
@ -50,6 +51,34 @@ class TestAptConfigurer(SubiTestCase):
expected['apt']['https_proxy'] = proxy expected['apt']['https_proxy'] = proxy
self.assertEqual(expected, self.configurer.apt_config()) self.assertEqual(expected, self.configurer.apt_config())
def test_mount_unmount(self):
# Make sure we can unmount something that we mounted before.
with patch.object(self.app, "command_runner",
create=True, new_callable=AsyncMock):
m = run_coro(self.configurer.mount("/dev/cdrom", "/target"))
run_coro(self.configurer.unmount(m))
def test_overlay(self):
self.configurer.install_tree = OverlayMountpoint(
upperdir="upperdir-install-tree",
lowers=["lowers1-install-tree"],
mountpoint="mountpoint-install-tree",
)
self.configurer.configured_tree = OverlayMountpoint(
upperdir="upperdir-install-tree",
lowers=["lowers1-install-tree"],
mountpoint="mountpoint-install-tree",
)
self.source = "source"
async def coro():
async with self.configurer.overlay():
pass
with patch.object(self.app, "command_runner",
create=True, new_callable=AsyncMock):
run_coro(coro())
class TestLowerDirFor(SubiTestCase): class TestLowerDirFor(SubiTestCase):
def test_lowerdir_for_str(self): def test_lowerdir_for_str(self):

View File

@ -0,0 +1,135 @@
# Copyright 2022 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from subprocess import CalledProcessError
import unittest
from unittest.mock import patch, AsyncMock, Mock
from subiquitycore.tests.mocks import make_app
from subiquitycore.tests.util import run_coro
from subiquity.server.ubuntu_drivers import (
UbuntuDriversInterface,
UbuntuDriversClientInterface,
CommandNotFoundError,
)
class TestUbuntuDriversInterface(unittest.TestCase):
def setUp(self):
self.app = make_app()
@patch.multiple(UbuntuDriversInterface, __abstractmethods__=set())
def test_init(self):
ubuntu_drivers = UbuntuDriversInterface(self.app, gpgpu=False)
self.assertEqual(ubuntu_drivers.app, self.app)
self.assertEqual(ubuntu_drivers.list_drivers_cmd, [
"ubuntu-drivers", "list",
"--recommended",
])
self.assertEqual(ubuntu_drivers.install_drivers_cmd, [
"ubuntu-drivers", "install"])
ubuntu_drivers = UbuntuDriversInterface(self.app, gpgpu=True)
self.assertEqual(ubuntu_drivers.list_drivers_cmd, [
"ubuntu-drivers", "list",
"--recommended", "--gpgpu",
])
self.assertEqual(ubuntu_drivers.install_drivers_cmd, [
"ubuntu-drivers", "install", "--gpgpu"])
@patch.multiple(UbuntuDriversInterface, __abstractmethods__=set())
@patch("subiquity.server.ubuntu_drivers.run_curtin_command")
def test_install_drivers(self, mock_run_curtin_command):
ubuntu_drivers = UbuntuDriversInterface(self.app, gpgpu=False)
run_coro(ubuntu_drivers.install_drivers(
root_dir="/target",
context="installing third-party drivers"))
mock_run_curtin_command.assert_called_once_with(
self.app, "installing third-party drivers",
"in-target", "-t", "/target",
"--",
"ubuntu-drivers", "install"
)
@patch.multiple(UbuntuDriversInterface, __abstractmethods__=set())
def test_drivers_from_output(self):
ubuntu_drivers = UbuntuDriversInterface(self.app, gpgpu=False)
output = """\
nvidia-driver-470 linux-modules-nvidia-470-generic-hwe-20.04
"""
self.assertEqual(
ubuntu_drivers._drivers_from_output(output=output),
["nvidia-driver-470"])
# Make sure empty lines are discarded
output = """
nvidia-driver-470 linux-modules-nvidia-470-generic-hwe-20.04
nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
"""
self.assertEqual(
ubuntu_drivers._drivers_from_output(output=output),
["nvidia-driver-470", "nvidia-driver-510"])
class TestUbuntuDriversClientInterface(unittest.TestCase):
def setUp(self):
self.app = make_app()
self.ubuntu_drivers = UbuntuDriversClientInterface(
self.app, gpgpu=False)
def test_ensure_cmd_exists(self):
with patch.object(
self.app, "command_runner",
create=True, new_callable=AsyncMock) as mock_runner:
# On success
run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target"))
mock_runner.run.assert_called_once_with(
[
"chroot", "/target",
"sh", "-c", "command -v ubuntu-drivers",
])
# On process failure
mock_runner.run.side_effect = CalledProcessError(
returncode=1,
cmd=["sh", "-c", "command -v ubuntu-drivers"])
with self.assertRaises(CommandNotFoundError):
run_coro(self.ubuntu_drivers.ensure_cmd_exists("/target"))
@patch("subiquity.server.ubuntu_drivers.run_curtin_command")
def test_list_drivers(self, mock_run_curtin_command):
# Make sure this gets decoded as utf-8.
mock_run_curtin_command.return_value = Mock(stdout=b"""\
nvidia-driver-510 linux-modules-nvidia-510-generic-hwe-20.04
""")
drivers = run_coro(self.ubuntu_drivers.list_drivers(
root_dir="/target",
context="listing third-party drivers"))
mock_run_curtin_command.assert_called_once_with(
self.app, "listing third-party drivers",
"in-target", "-t", "/target",
"--",
"ubuntu-drivers", "list", "--recommended",
capture=True)
self.assertEqual(drivers, ["nvidia-driver-510"])

View File

@ -21,3 +21,4 @@ class InstallerChannels(CoreChannels):
SNAPD_NETWORK_CHANGE = 'snapd-network-change' SNAPD_NETWORK_CHANGE = 'snapd-network-change'
GEOIP = 'geoip' GEOIP = 'geoip'
CONFIGURED = 'configured' CONFIGURED = 'configured'
APT_CONFIGURED = 'apt-configured'

View File

@ -0,0 +1,162 @@
# Copyright 2022 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
""" Module that defines helpers to use the ubuntu-drivers command. """
from abc import ABC, abstractmethod
import logging
import subprocess
from typing import List, Type
from subiquity.server.curtin import run_curtin_command
from subiquitycore.utils import arun_command
log = logging.getLogger("subiquity.server.ubuntu_drivers")
class CommandNotFoundError(Exception):
""" Exception to be raised when the ubuntu-drivers command is not
available.
"""
class UbuntuDriversInterface(ABC):
def __init__(self, app, gpgpu: bool) -> None:
self.app = app
self.list_drivers_cmd = [
"ubuntu-drivers", "list",
"--recommended",
]
self.install_drivers_cmd = [
"ubuntu-drivers", "install",
]
if gpgpu:
self.list_drivers_cmd.append("--gpgpu")
self.install_drivers_cmd.append("--gpgpu")
@abstractmethod
async def ensure_cmd_exists(self, root_dir: str) -> None:
pass
@abstractmethod
async def list_drivers(self, root_dir: str, context) -> List[str]:
pass
async def install_drivers(self, root_dir: str, context) -> None:
await run_curtin_command(
self.app, context,
"in-target", "-t", root_dir, "--", *self.install_drivers_cmd)
def _drivers_from_output(self, output: str) -> List[str]:
""" Parse the output of ubuntu-drivers list --recommended and return a
list of drivers. """
drivers: List[str] = []
# Drivers are listed one per line, but each driver is followed by a
# linux-modules-* package (which we are not interested in showing).
# e.g.,:
# $ ubuntu-drivers list --recommended
# nvidia-driver-470 linux-modules-nvidia-470-generic-hwe-20.04
for line in [x.strip() for x in output.split("\n")]:
if not line:
continue
drivers.append(line.split(" ", maxsplit=1)[0])
return drivers
class UbuntuDriversClientInterface(UbuntuDriversInterface):
""" UbuntuDrivers interface that uses the ubuntu-drivers command from the
specified root directory. """
async def ensure_cmd_exists(self, root_dir: str) -> None:
# TODO This does not tell us if the "--recommended" option is
# available.
try:
await self.app.command_runner.run(
['chroot', root_dir,
'sh', '-c',
"command -v ubuntu-drivers"])
except subprocess.CalledProcessError:
raise CommandNotFoundError(
f"Command ubuntu-drivers is not available in {root_dir}")
async def list_drivers(self, root_dir: str, context) -> List[str]:
result = await run_curtin_command(
self.app, context,
"in-target", "-t", root_dir, "--", *self.list_drivers_cmd,
capture=True)
# Currently we have no way to specify universal_newlines=True or
# encoding="utf-8" to run_curtin_command so we need to decode the
# output.
return self._drivers_from_output(result.stdout.decode("utf-8"))
class UbuntuDriversHasDriversInterface(UbuntuDriversInterface):
""" A dry-run implementation of ubuntu-drivers that returns a hard-coded
list of drivers. """
gpgpu_drivers: List[str] = ["nvidia-driver-470-server"]
not_gpgpu_drivers: List[str] = ["nvidia-driver-510"]
def __init__(self, app, gpgpu: bool) -> None:
super().__init__(app, gpgpu)
self.drivers = self.gpgpu_drivers if gpgpu else self.not_gpgpu_drivers
async def ensure_cmd_exists(self, root_dir: str) -> None:
pass
async def list_drivers(self, root_dir: str, context) -> List[str]:
return self.drivers
class UbuntuDriversNoDriversInterface(UbuntuDriversHasDriversInterface):
""" A dry-run implementation of ubuntu-drivers that returns a hard-coded
empty list of drivers. """
gpgpu_drivers: List[str] = []
not_gpgpu_drivers: List[str] = []
class UbuntuDriversRunDriversInterface(UbuntuDriversInterface):
""" A dry-run implementation of ubuntu-drivers that actually runs the
ubuntu-drivers command but locally. """
async def ensure_cmd_exists(self, root_dir: str) -> None:
# TODO This does not tell us if the "--recommended" option is
# available.
try:
await arun_command(["command", "-v", "ubuntu-drivers"])
except subprocess.CalledProcessError:
raise CommandNotFoundError(
"Command ubuntu-drivers is not available in this system")
async def list_drivers(self, root_dir: str, context) -> List[str]:
# We run the command locally - ignoring the root_dir.
result = await arun_command(self.list_drivers_cmd)
return self._drivers_from_output(result.stdout)
def get_ubuntu_drivers_interface(app) -> UbuntuDriversInterface:
is_server = app.base_model.source.current.variant == "server"
cls: Type[UbuntuDriversInterface] = UbuntuDriversClientInterface
if app.opts.dry_run:
if 'has-drivers' in app.debug_flags:
cls = UbuntuDriversHasDriversInterface
elif 'run-drivers' in app.debug_flags:
cls = UbuntuDriversRunDriversInterface
else:
cls = UbuntuDriversNoDriversInterface
return cls(app, gpgpu=is_server)