Merge pull request #1257 from ogayot/mypy-fixes

Some more code cleanup
This commit is contained in:
Michael Hudson-Doyle 2022-04-13 12:09:04 +12:00 committed by GitHub
commit b07405a744
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 42 additions and 41 deletions

View File

@ -21,7 +21,7 @@ import os
import signal import signal
import sys import sys
import traceback import traceback
from typing import Optional from typing import Dict, List, Optional
import aiohttp import aiohttp
@ -118,7 +118,7 @@ class SubiquityClient(TuiApplication):
"Progress", "Progress",
] ]
variant_to_controllers = {} variant_to_controllers: Dict[str, List[str]] = {}
def __init__(self, opts): def __init__(self, opts):
if is_linux_tty(): if is_linux_tty():

View File

@ -15,7 +15,7 @@
import abc import abc
import functools import functools
from typing import Optional from typing import Any, Optional
import attr import attr
@ -125,7 +125,7 @@ class SetAttrPlan(MakeBootDevicePlan):
device: object device: object
attr: str attr: str
val: str val: Any
def apply(self, manipulator): def apply(self, manipulator):
setattr(self.device, self.attr, self.val) setattr(self.device, self.attr, self.val)

View File

@ -729,7 +729,7 @@ class Partition(_Formattable):
@fsobj("raid") @fsobj("raid")
class Raid(_Device): class Raid(_Device):
name = attr.ib() name = attr.ib()
raidlevel = attr.ib(converter=lambda x: raidlevels_by_value[x].value) raidlevel: str = attr.ib(converter=lambda x: raidlevels_by_value[x].value)
devices = attributes.reflist( devices = attributes.reflist(
backlink="_constructed_device", default=attr.Factory(set)) backlink="_constructed_device", default=attr.Factory(set))

View File

@ -15,12 +15,12 @@
import logging import logging
from subiquitycore.models.network import NetworkModel from subiquitycore.models.network import NetworkModel as CoreNetworkModel
log = logging.getLogger('subiquity.models.network') log = logging.getLogger('subiquity.models.network')
class NetworkModel(NetworkModel): class NetworkModel(CoreNetworkModel):
def __init__(self): def __init__(self):
super().__init__("subiquity") super().__init__("subiquity")

View File

@ -134,8 +134,8 @@ class AptConfigurer:
self.configured_tree: Optional[OverlayMountpoint] = None self.configured_tree: Optional[OverlayMountpoint] = None
self.install_tree: Optional[OverlayMountpoint] = None self.install_tree: Optional[OverlayMountpoint] = None
self.install_mount = None self.install_mount = None
self._mounts = [] self._mounts: List[Mountpoint] = []
self._tdirs = [] self._tdirs: List[str] = []
def tdir(self): def tdir(self):
d = tempfile.mkdtemp() d = tempfile.mkdtemp()

View File

@ -291,7 +291,7 @@ class NetworkController(BaseNetworkController, SubiquityController):
await self.configured() await self.configured()
async def global_addresses_GET(self) -> List[str]: async def global_addresses_GET(self) -> List[str]:
ips = [] ips: List[str] = []
for dev in self.model.get_all_netdevs(): for dev in self.model.get_all_netdevs():
ips.extend(map(str, dev.actual_global_ip_addresses)) ips.extend(map(str, dev.actual_global_ip_addresses))
return ips return ips

View File

@ -25,14 +25,12 @@ from curtin.reporter.events import (
report_start_event, report_start_event,
status, status,
) )
from curtin.reporter.handlers import ( from curtin.reporter.handlers import LogHandler as CurtinLogHandler
LogHandler,
)
from subiquity.server.controller import NonInteractiveController from subiquity.server.controller import NonInteractiveController
class LogHandler(LogHandler): class LogHandler(CurtinLogHandler):
def publish_event(self, event): def publish_event(self, event):
level = getattr(logging, event.level) level = getattr(logging, event.level)
logger = logging.getLogger('') logger = logging.getLogger('')

View File

@ -20,13 +20,13 @@ import logging
import os import os
import subprocess import subprocess
import sys import sys
from typing import List from typing import Dict, List, Type
from curtin.commands.install import ( from curtin.commands.install import (
INSTALL_LOG, INSTALL_LOG,
) )
from subiquitycore.context import Status from subiquitycore.context import Context, Status
from subiquity.journald import ( from subiquity.journald import (
journald_listen, journald_listen,
@ -43,7 +43,7 @@ class _CurtinCommand:
config=None, private_mounts: bool): config=None, private_mounts: bool):
self.opts = opts self.opts = opts
self.runner = runner self.runner = runner
self._event_contexts = {} self._event_contexts: Dict[str, Context] = {}
_CurtinCommand._count += 1 _CurtinCommand._count += 1
self._event_syslog_id = 'curtin_event.%s.%s' % ( self._event_syslog_id = 'curtin_event.%s.%s' % (
os.getpid(), _CurtinCommand._count) os.getpid(), _CurtinCommand._count)
@ -155,7 +155,8 @@ class _FailingDryRunCurtinCommand(_DryRunCurtinCommand):
async def start_curtin_command(app, context, async def start_curtin_command(app, context,
command: str, *args: str, command: str, *args: str,
config=None, private_mounts: bool, config=None, private_mounts: bool,
**opts): **opts) -> _CurtinCommand:
cls: Type[_CurtinCommand]
if app.opts.dry_run: if app.opts.dry_run:
if 'install-fail' in app.debug_flags: if 'install-fail' in app.debug_flags:
cls = _FailingDryRunCurtinCommand cls = _FailingDryRunCurtinCommand

View File

@ -78,6 +78,8 @@ class LoggedCommandRunner:
async def wait(self, proc: asyncio.subprocess.Process) \ async def wait(self, proc: asyncio.subprocess.Process) \
-> subprocess.CompletedProcess: -> subprocess.CompletedProcess:
stdout, stderr = await proc.communicate() stdout, stderr = await proc.communicate()
# .communicate() forces returncode to be set to a value
assert(proc.returncode is not None)
if proc.returncode != 0: if proc.returncode != 0:
raise subprocess.CalledProcessError(proc.returncode, proc.args) raise subprocess.CalledProcessError(proc.returncode, proc.args)
else: else:

View File

@ -143,7 +143,7 @@ class MetaController:
return self.app.variant return self.app.variant
async def ssh_info_GET(self) -> Optional[LiveSessionSSHInfo]: async def ssh_info_GET(self) -> Optional[LiveSessionSSHInfo]:
ips = [] ips: List[str] = []
if self.app.base_model.network: if self.app.base_model.network:
for dev in self.app.base_model.network.get_all_netdevs(): for dev in self.app.base_model.network.get_all_netdevs():
ips.extend(map(str, dev.actual_global_ip_addresses)) ips.extend(map(str, dev.actual_global_ip_addresses))

View File

@ -57,7 +57,7 @@ class TestMockedUAInterfaceStrategy(unittest.TestCase):
class TestUAClientUAInterfaceStrategy(unittest.TestCase): class TestUAClientUAInterfaceStrategy(unittest.TestCase):
arun_command = "subiquity.server.ubuntu_advantage.utils.arun_command" arun_command_sym = "subiquity.server.ubuntu_advantage.utils.arun_command"
def test_init(self): def test_init(self):
# Default initializer. # Default initializer.
@ -84,7 +84,7 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
"--simulate-with-token", "123456789", "--simulate-with-token", "123456789",
) )
with patch(self.arun_command) as mock_arun: with patch(self.arun_command_sym) as mock_arun:
mock_arun.return_value = CompletedProcess([], 0) mock_arun.return_value = CompletedProcess([], 0)
mock_arun.return_value.stdout = "{}" mock_arun.return_value.stdout = "{}"
run_coro(strategy.query_info(token="123456789")) run_coro(strategy.query_info(token="123456789"))
@ -99,7 +99,7 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
"--simulate-with-token", "123456789", "--simulate-with-token", "123456789",
) )
with patch(self.arun_command) as mock_arun: with patch(self.arun_command_sym) as mock_arun:
mock_arun.side_effect = CalledProcessError(returncode=1, mock_arun.side_effect = CalledProcessError(returncode=1,
cmd=command) cmd=command)
mock_arun.return_value.stdout = "{}" mock_arun.return_value.stdout = "{}"
@ -116,7 +116,7 @@ class TestUAClientUAInterfaceStrategy(unittest.TestCase):
"--simulate-with-token", "123456789", "--simulate-with-token", "123456789",
) )
with patch(self.arun_command) as mock_arun: with patch(self.arun_command_sym) as mock_arun:
mock_arun.return_value = CompletedProcess([], 0) mock_arun.return_value = CompletedProcess([], 0)
mock_arun.return_value.stdout = "invalid-json" mock_arun.return_value.stdout = "invalid-json"
with self.assertRaises(CheckSubscriptionError): with self.assertRaises(CheckSubscriptionError):

View File

@ -1,11 +1,12 @@
# This file is part of subiquity. See LICENSE file for license information. # This file is part of subiquity. See LICENSE file for license information.
import shlex import shlex
from typing import Dict
LSB_RELEASE_FILE = "/etc/lsb-release" LSB_RELEASE_FILE = "/etc/lsb-release"
LSB_RELEASE_EXAMPLE = "examples/lsb-release-focal" LSB_RELEASE_EXAMPLE = "examples/lsb-release-focal"
def lsb_release(path=None, dry_run: bool = False): def lsb_release(path=None, dry_run: bool = False) -> Dict[str, str]:
"""return a dictionary of values from /etc/lsb-release. """return a dictionary of values from /etc/lsb-release.
keys are lower case with DISTRIB_ prefix removed.""" keys are lower case with DISTRIB_ prefix removed."""
if dry_run and path is not None: if dry_run and path is not None:
@ -14,7 +15,7 @@ def lsb_release(path=None, dry_run: bool = False):
if path is None: if path is None:
path = LSB_RELEASE_EXAMPLE if dry_run else LSB_RELEASE_FILE path = LSB_RELEASE_EXAMPLE if dry_run else LSB_RELEASE_FILE
ret = {} ret: Dict[str, str] = {}
try: try:
with open(path, "r") as fp: with open(path, "r") as fp:
content = fp.read() content = fp.read()

View File

@ -20,7 +20,7 @@ import logging
import yaml import yaml
from socket import AF_INET, AF_INET6 from socket import AF_INET, AF_INET6
import attr import attr
from typing import List, Optional from typing import Dict, List, Optional
from subiquitycore import netplan from subiquitycore import netplan
@ -208,6 +208,7 @@ class NetworkDev(object):
if self.name in dev2.config.get('interfaces', []): if self.name in dev2.config.get('interfaces', []):
bond_master = dev2.name bond_master = dev2.name
break break
bond: Optional[BondConfig] = None
if self.type == 'bond' and self.config is not None: if self.type == 'bond' and self.config is not None:
params = self.config['parameters'] params = self.config['parameters']
bond = BondConfig( bond = BondConfig(
@ -215,26 +216,22 @@ class NetworkDev(object):
mode=params['mode'], mode=params['mode'],
xmit_hash_policy=params.get('xmit-hash-policy'), xmit_hash_policy=params.get('xmit-hash-policy'),
lacp_rate=params.get('lacp-rate')) lacp_rate=params.get('lacp-rate'))
else: vlan: Optional[VLANConfig] = None
bond = None
if self.type == 'vlan' and self.config is not None: if self.type == 'vlan' and self.config is not None:
vlan = VLANConfig(id=self.config['id'], link=self.config['link']) vlan = VLANConfig(id=self.config['id'], link=self.config['link'])
else: wlan: Optional[WLANStatus] = None
vlan = None
if self.type == 'wlan': if self.type == 'wlan':
ssid, psk = self.configured_ssid ssid, psk = self.configured_ssid
wlan = WLANStatus( wlan = WLANStatus(
config=WLANConfig(ssid=ssid, psk=psk), config=WLANConfig(ssid=ssid, psk=psk),
scan_state=self.info.wlan['scan_state'], scan_state=self.info.wlan['scan_state'],
visible_ssids=self.info.wlan['visible_ssids']) visible_ssids=self.info.wlan['visible_ssids'])
else:
wlan = None
dhcp_addresses = self.dhcp_addresses() dhcp_addresses = self.dhcp_addresses()
configured_addresseses = {4: [], 6: []} configured_addresses: Dict[int, List[str]] = {4: [], 6: []}
if self.config is not None: if self.config is not None:
for addr in self.config.get('addresses', []): for addr in self.config.get('addresses', []):
configured_addresseses[addr_version(addr)].append(addr) configured_addresses[addr_version(addr)].append(addr)
ns = self.config.get('nameservers', {}) ns = self.config.get('nameservers', {})
else: else:
ns = {} ns = {}
@ -250,7 +247,7 @@ class NetworkDev(object):
else: else:
gateway = None gateway = None
static_configs[v] = StaticConfig( static_configs[v] = StaticConfig(
addresses=configured_addresseses[v], addresses=configured_addresses[v],
gateway=gateway, gateway=gateway,
nameservers=ns.get('nameservers', []), nameservers=ns.get('nameservers', []),
searchdomains=ns.get('search', [])) searchdomains=ns.get('search', []))

View File

@ -19,7 +19,7 @@ import logging
import os import os
import random import random
import subprocess import subprocess
from typing import List from typing import List, Sequence
log = logging.getLogger("subiquitycore.utils") log = logging.getLogger("subiquitycore.utils")
@ -35,7 +35,7 @@ def _clean_env(env):
return env return env
def run_command(cmd: List[str], *, input=None, stdout=subprocess.PIPE, def run_command(cmd: Sequence[str], *, input=None, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, encoding='utf-8', errors='replace', stderr=subprocess.PIPE, encoding='utf-8', errors='replace',
env=None, **kw) -> subprocess.CompletedProcess: env=None, **kw) -> subprocess.CompletedProcess:
"""A wrapper around subprocess.run with logging and different defaults. """A wrapper around subprocess.run with logging and different defaults.
@ -63,7 +63,7 @@ def run_command(cmd: List[str], *, input=None, stdout=subprocess.PIPE,
return cp return cp
async def arun_command(cmd: List[str], *, async def arun_command(cmd: Sequence[str], *,
stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
encoding='utf-8', input=None, errors='replace', encoding='utf-8', input=None, errors='replace',
env=None, check=False, **kw) \ env=None, check=False, **kw) \
@ -84,6 +84,8 @@ async def arun_command(cmd: List[str], *,
if stderr is not None: if stderr is not None:
stderr = stderr.decode(encoding) stderr = stderr.decode(encoding)
log.debug("arun_command %s exited with code %s", cmd, proc.returncode) log.debug("arun_command %s exited with code %s", cmd, proc.returncode)
# .communicate() forces returncode to be set to a value
assert(proc.returncode is not None)
if check and proc.returncode != 0: if check and proc.returncode != 0:
raise subprocess.CalledProcessError(proc.returncode, cmd) raise subprocess.CalledProcessError(proc.returncode, cmd)
else: else:
@ -91,7 +93,7 @@ async def arun_command(cmd: List[str], *,
cmd, proc.returncode, stdout, stderr) cmd, proc.returncode, stdout, stderr)
async def astart_command(cmd: List[str], *, stdout=subprocess.PIPE, async def astart_command(cmd: Sequence[str], *, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stdin=subprocess.DEVNULL, stderr=subprocess.PIPE, stdin=subprocess.DEVNULL,
env=None, **kw) -> asyncio.subprocess.Process: env=None, **kw) -> asyncio.subprocess.Process:
log.debug("astart_command called: %s", cmd) log.debug("astart_command called: %s", cmd)
@ -100,12 +102,12 @@ async def astart_command(cmd: List[str], *, stdout=subprocess.PIPE,
env=_clean_env(env), **kw) env=_clean_env(env), **kw)
async def split_cmd_output(cmd: List[str], split_on: str) -> List[str]: async def split_cmd_output(cmd: Sequence[str], split_on: str) -> List[str]:
cp = await arun_command(cmd, check=True) cp = await arun_command(cmd, check=True)
return cp.stdout.split(split_on) return cp.stdout.split(split_on)
def start_command(cmd: List[str], *, def start_command(cmd: Sequence[str], *,
stdin=subprocess.DEVNULL, stdout=subprocess.PIPE, stdin=subprocess.DEVNULL, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, encoding='utf-8', errors='replace', stderr=subprocess.PIPE, encoding='utf-8', errors='replace',
env=None, **kw) -> subprocess.Popen: env=None, **kw) -> subprocess.Popen: