Merge pull request #1559 from ogayot/ssh-import-mock

Stop calling ssh-import-id as part of the CI - use mock responses instead
This commit is contained in:
Olivier Gayot 2023-02-14 11:23:56 +01:00 committed by GitHub
commit 7848ccee46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 260 additions and 116 deletions

View File

@ -23,7 +23,9 @@ Identity:
hostname: ubuntu-server
# ubuntu
password: '$6$wdAcoXrU039hKYPd$508Qvbe7ObUnxoj15DRCkzC3qO7edjH0VV7BPNRDYK4QR8ofJaEEF2heacn0QgD.f8pO8SNp83XNdWG6tocBM1'
ssh-import-id: lp:mwhudson
# We have a specific dry-run behaviour for heracles, making the call to
# ssh-import-id succeed unconditionally.
ssh-import-id: lp:heracles
UbuntuPro:
token: ""
SnapList:

View File

@ -14,11 +14,8 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import logging
import os
import subprocess
from typing import List
from subiquitycore.utils import arun_command
from subiquity.common.apidef import API
from subiquity.common.types import (
SSHData,
@ -27,18 +24,16 @@ from subiquity.common.types import (
SSHIdentity,
)
from subiquity.server.controller import SubiquityController
from subiquity.server.ssh import (
SSHFetchError,
SSHKeyFetcher,
DryRunSSHKeyFetcher,
)
from subiquity.server.types import InstallerChannels
log = logging.getLogger('subiquity.server.controllers.ssh')
class SSHFetchError(Exception):
def __init__(self, status: SSHFetchIdStatus, reason: str) -> None:
self.reason = reason
self.status = status
super().__init__()
class SSHController(SubiquityController):
endpoint = API.ssh
@ -61,6 +56,10 @@ class SSHController(SubiquityController):
self.app.hub.subscribe(
InstallerChannels.INSTALL_CONFIRMED, self._confirmed)
self._active = True
if app.opts.dry_run:
self.fetcher = DryRunSSHKeyFetcher(app)
else:
self.fetcher = SSHKeyFetcher(app)
async def _confirmed(self):
if self.app.base_model.source.current.variant == 'desktop':
@ -98,43 +97,13 @@ class SSHController(SubiquityController):
self.model.pwauth = data.allow_pw
await self.configured()
async def fetch_keys_for_id(self, user_id: str) -> List[str]:
cmd = ('ssh-import-id', '--output', '-', '--', user_id)
env = None
if self.app.base_model.proxy.proxy:
env = os.environ.copy()
env["https_proxy"] = self.app.base_model.proxy.proxy
try:
cp = await arun_command(cmd, check=True, env=env)
except subprocess.CalledProcessError as exc:
log.exception("ssh-import-id failed. stderr: %s", exc.stderr)
raise SSHFetchError(status=SSHFetchIdStatus.IMPORT_ERROR,
reason=exc.stderr)
keys_material: str = cp.stdout.replace('\r', '').strip()
return [mat for mat in keys_material.splitlines() if mat]
async def gen_fingerprint_for_key(self, key: str) -> str:
""" For a given key, generate the fingerprint. """
# ssh-keygen supports multiple keys at once, but it is simpler to
# associate each key with its resulting fingerprint if we call
# ssh-keygen multiple times.
cmd = ('ssh-keygen', '-l', '-f', '-')
try:
cp = await arun_command(cmd, check=True, input=key)
except subprocess.CalledProcessError as exc:
log.exception("ssh-import-id failed. stderr: %s", exc.stderr)
raise SSHFetchError(status=SSHFetchIdStatus.FINGERPRINT_ERROR,
reason=exc.stderr)
return cp.stdout.strip()
async def fetch_id_GET(self, user_id: str) -> SSHFetchIdResponse:
identities: List[SSHIdentity] = []
try:
for key_material in await self.fetch_keys_for_id(user_id):
fingerprint = await self.gen_fingerprint_for_key(key_material)
for key_material in await self.fetcher.fetch_keys_for_id(user_id):
fingerprint = await self.fetcher.gen_fingerprint_for_key(
key_material)
fingerprint = fingerprint.replace(
f'# ssh-import-id {user_id}', ''

View File

@ -15,7 +15,6 @@
import unittest
from unittest import mock
from subprocess import CompletedProcess, CalledProcessError
from subiquity.common.types import SSHFetchIdStatus, SSHIdentity
from subiquity.server.controllers.ssh import (
@ -27,83 +26,19 @@ from subiquitycore.tests.mocks import make_app
class TestSSHController(unittest.IsolatedAsyncioTestCase):
arun_command_sym = "subiquity.server.controllers.ssh.arun_command"
def setUp(self):
self.controller = SSHController(make_app())
async def test_fetch_keys_for_id_one_key_ok(self):
with mock.patch(self.arun_command_sym) as mock_arun:
mock_arun.return_value = CompletedProcess([], 0)
mock_arun.return_value.stdout = """
ssh-rsa AAAAC3NzaC1lZ test@gh/335797 # ssh-import-id gh:test
"""
keys = await self.controller.fetch_keys_for_id(user_id="gh:test")
self.assertEqual(keys, [
"ssh-rsa AAAAC3NzaC1lZ test@gh/335797 # ssh-import-id gh:test",
])
async def test_fetch_keys_for_id_two_key_ok(self):
with mock.patch(self.arun_command_sym) as mock_arun:
mock_arun.return_value = CompletedProcess([], 0)
mock_arun.return_value.stdout = """\
ssh-rsa AAAAC3NzaC1lZ test@host # ssh-import-id lp:test
ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test
"""
keys = await self.controller.fetch_keys_for_id(user_id="lp:test")
self.assertEqual(keys, [
"ssh-rsa AAAAC3NzaC1lZ test@host # ssh-import-id lp:test",
"ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test",
])
async def test_fetch_keys_for_id_error(self):
with mock.patch(self.arun_command_sym) as mock_arun:
stderr = """\
2022-11-22 14:00:12,336 INFO [0] SSH keys [Authorized]
2022-11-22 14:00:12,337 ERROR No matching keys found for [lp:test2]
"""
mock_arun.side_effect = CalledProcessError(1, [], None, stderr)
with self.assertRaises(SSHFetchError) as cm:
await self.controller.fetch_keys_for_id(user_id="lp:test2")
self.assertEqual(cm.exception.reason, stderr)
self.assertEqual(cm.exception.status,
SSHFetchIdStatus.IMPORT_ERROR)
async def test_gen_fingerprint_for_key_ok(self):
with mock.patch(self.arun_command_sym) as mock_arun:
mock_arun.return_value = CompletedProcess([], 0)
mock_arun.return_value.stdout = """\
256 SHA256:rIR9UVRKspl7+KF75s test@host # ssh-import-id lp:test (ED25519)
"""
fp = await self.controller.gen_fingerprint_for_key(
"ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test")
self.assertEqual(fp, """\
256 SHA256:rIR9UVRKspl7+KF75s test@host # ssh-import-id lp:test (ED25519)\
""")
async def test_gen_fingerprint_for_key_error(self):
stderr = "(stdin) is not a public key file.\n"
with mock.patch(self.arun_command_sym) as mock_arun:
mock_arun.side_effect = CalledProcessError(1, [], None, stderr)
with self.assertRaises(SSHFetchError) as cm:
await self.controller.gen_fingerprint_for_key(
"ssh-nsa AAAAAC3N test@host # ssh-import-id lp:test")
self.assertEqual(cm.exception.reason, stderr)
self.assertEqual(cm.exception.status,
SSHFetchIdStatus.FINGERPRINT_ERROR)
async def test_fetch_id_GET_ok(self):
key = "ssh-rsa AAAAA[..] user@host # ssh-import-id lp:user"
mock_fetch_keys = mock.patch.object(
self.controller, "fetch_keys_for_id", return_value=[key])
self.controller.fetcher, "fetch_keys_for_id",
return_value=[key])
fp = "256 SHA256:rIR9[..] user@host # ssh-import-id lp:user (ED25519)"
mock_gen_fingerprint = mock.patch.object(
self.controller, "gen_fingerprint_for_key", return_value=fp)
self.controller.fetcher, "gen_fingerprint_for_key",
return_value=fp)
with mock_fetch_keys, mock_gen_fingerprint:
response = await self.controller.fetch_id_GET(user_id="lp:user")
@ -122,7 +57,7 @@ ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test
stderr = "ERROR No matching keys found for [lp=test2]\n"
mock_fetch_keys = mock.patch.object(
self.controller, "fetch_keys_for_id",
self.controller.fetcher, "fetch_keys_for_id",
side_effect=SSHFetchError(
status=SSHFetchIdStatus.IMPORT_ERROR,
reason=stderr))
@ -137,12 +72,13 @@ ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test
async def test_fetch_id_GET_fingerprint_error(self):
key = "ssh-rsa AAAAA[..] user@host # ssh-import-id lp:user"
mock_fetch_keys = mock.patch.object(
self.controller, "fetch_keys_for_id", return_value=[key])
self.controller.fetcher, "fetch_keys_for_id",
return_value=[key])
stderr = "(stdin) is not a public key file\n"
mock_gen_fingerprint = mock.patch.object(
self.controller, "gen_fingerprint_for_key",
self.controller.fetcher, "gen_fingerprint_for_key",
side_effect=SSHFetchError(
status=SSHFetchIdStatus.FINGERPRINT_ERROR,
reason=stderr))

View File

@ -38,6 +38,12 @@ class KnownMirror(TypedDict, total=False):
strategy: str
class SSHImport(TypedDict, total=True):
""" Dictionary type hints for a SSH key import. """
strategy: str
username: str
@attr.s(auto_attribs=True)
class DRConfig:
""" Configuration for dry-run-only executions.
@ -65,6 +71,12 @@ class DRConfig:
{"pattern": r"/fail(ed)?/?$", "strategy": "failure"},
]
ssh_import_default_strategy: str = "run-on-host"
ssh_imports: List[SSHImport] = [
{"username": "heracles", "strategy": "success"},
{"username": "sisyphus", "strategy": "failure"},
]
@classmethod
def load(cls, stream):
data = yaml.safe_load(stream)

110
subiquity/server/ssh.py Normal file
View File

@ -0,0 +1,110 @@
# Copyright 2023 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import enum
import logging
import os
import subprocess
from typing import List
from subiquitycore.utils import arun_command
from subiquity.common.types import (
SSHFetchIdStatus,
)
log = logging.getLogger("subiquity.server.ssh")
class SSHFetchError(Exception):
def __init__(self, status: SSHFetchIdStatus, reason: str) -> None:
self.reason = reason
self.status = status
super().__init__()
class SSHKeyFetcher:
def __init__(self, app):
self.app = app
async def fetch_keys_for_id(self, user_id: str) -> List[str]:
cmd = ('ssh-import-id', '--output', '-', '--', user_id)
env = None
if self.app.base_model.proxy.proxy:
env = os.environ.copy()
env["https_proxy"] = self.app.base_model.proxy.proxy
try:
cp = await arun_command(cmd, check=True, env=env)
except subprocess.CalledProcessError as exc:
log.exception("ssh-import-id failed. stderr: %s", exc.stderr)
raise SSHFetchError(status=SSHFetchIdStatus.IMPORT_ERROR,
reason=exc.stderr)
keys_material: str = cp.stdout.replace('\r', '').strip()
return [mat for mat in keys_material.splitlines() if mat]
async def gen_fingerprint_for_key(self, key: str) -> str:
""" For a given key, generate the fingerprint. """
# ssh-keygen supports multiple keys at once, but it is simpler to
# associate each key with its resulting fingerprint if we call
# ssh-keygen multiple times.
cmd = ('ssh-keygen', '-l', '-f', '-')
try:
cp = await arun_command(cmd, check=True, input=key)
except subprocess.CalledProcessError as exc:
log.exception("ssh-import-id failed. stderr: %s", exc.stderr)
raise SSHFetchError(status=SSHFetchIdStatus.FINGERPRINT_ERROR,
reason=exc.stderr)
return cp.stdout.strip()
class DryRunSSHKeyFetcher(SSHKeyFetcher):
class SSHImportStrategy(enum.Enum):
SUCCESS = "success"
FAILURE = "failure"
RUN_ON_HOST = "run-on-host"
async def fetch_keys_fake_success(self, user_id: str) -> List[str]:
unused, username = user_id.split(":", maxsplit=1)
return [f"""\
ssh-ed25519\
AAAAC3NzaC1lZDI1NTE5AAAAIMM/qhS3hS3+IjpJBYXZWCqPKPH9Zag8QYbS548iEjoZ\
{username}@earth # ssh-import-id {user_id}"""]
async def fetch_keys_fake_failure(self, user_id: str) -> List[str]:
unused, username = user_id.split(":", maxsplit=1)
raise SSHFetchError(status=SSHFetchIdStatus.IMPORT_ERROR,
reason=f"ERROR Username {username} not found.")
async def fetch_keys_for_id(self, user_id: str) -> List[str]:
service, username = user_id.split(":", maxsplit=1)
strategy = self.SSHImportStrategy(
self.app.dr_cfg.ssh_import_default_strategy)
for entry in self.app.dr_cfg.ssh_imports:
if entry["username"] != username:
continue
strategy = self.SSHImportStrategy(entry["strategy"])
break
strategies_mapping = {
self.SSHImportStrategy.RUN_ON_HOST: super().fetch_keys_for_id,
self.SSHImportStrategy.SUCCESS: self.fetch_keys_fake_success,
self.SSHImportStrategy.FAILURE: self.fetch_keys_fake_failure,
}
coroutine = strategies_mapping[strategy](user_id)
return await coroutine

View File

@ -0,0 +1,115 @@
# Copyright 2023 Canonical, Ltd.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from subprocess import CompletedProcess, CalledProcessError
import unittest
from unittest import mock
from subiquity.common.types import SSHFetchIdStatus
from subiquity.server.ssh import (
DryRunSSHKeyFetcher,
SSHFetchError,
SSHKeyFetcher,
)
from subiquitycore.tests.mocks import make_app
class TestSSHKeyFetcher(unittest.IsolatedAsyncioTestCase):
arun_command_sym = "subiquity.server.ssh.arun_command"
def setUp(self):
self.fetcher = SSHKeyFetcher(make_app())
async def test_fetch_keys_for_id_one_key_ok(self):
with mock.patch(self.arun_command_sym) as mock_arun:
mock_arun.return_value = CompletedProcess([], 0)
mock_arun.return_value.stdout = """
ssh-rsa AAAAC3NzaC1lZ test@gh/335797 # ssh-import-id gh:test
"""
keys = await self.fetcher.fetch_keys_for_id(user_id="gh:test")
self.assertEqual(keys, [
"ssh-rsa AAAAC3NzaC1lZ test@gh/335797 # ssh-import-id gh:test",
])
async def test_fetch_keys_for_id_two_key_ok(self):
with mock.patch(self.arun_command_sym) as mock_arun:
mock_arun.return_value = CompletedProcess([], 0)
mock_arun.return_value.stdout = """\
ssh-rsa AAAAC3NzaC1lZ test@host # ssh-import-id lp:test
ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test
"""
keys = await self.fetcher.fetch_keys_for_id(user_id="lp:test")
self.assertEqual(keys, [
"ssh-rsa AAAAC3NzaC1lZ test@host # ssh-import-id lp:test",
"ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test",
])
async def test_fetch_keys_for_id_error(self):
with mock.patch(self.arun_command_sym) as mock_arun:
stderr = """\
2022-11-22 14:00:12,336 INFO [0] SSH keys [Authorized]
2022-11-22 14:00:12,337 ERROR No matching keys found for [lp:test2]
"""
mock_arun.side_effect = CalledProcessError(1, [], None, stderr)
with self.assertRaises(SSHFetchError) as cm:
await self.fetcher.fetch_keys_for_id(user_id="lp:test2")
self.assertEqual(cm.exception.reason, stderr)
self.assertEqual(cm.exception.status,
SSHFetchIdStatus.IMPORT_ERROR)
async def test_gen_fingerprint_for_key_ok(self):
with mock.patch(self.arun_command_sym) as mock_arun:
mock_arun.return_value = CompletedProcess([], 0)
mock_arun.return_value.stdout = """\
256 SHA256:rIR9UVRKspl7+KF75s test@host # ssh-import-id lp:test (ED25519)
"""
fp = await self.fetcher.gen_fingerprint_for_key(
"ssh-ed25519 AAAAAC3N test@host # ssh-import-id lp:test")
self.assertEqual(fp, """\
256 SHA256:rIR9UVRKspl7+KF75s test@host # ssh-import-id lp:test (ED25519)\
""")
async def test_gen_fingerprint_for_key_error(self):
stderr = "(stdin) is not a public key file.\n"
with mock.patch(self.arun_command_sym) as mock_arun:
mock_arun.side_effect = CalledProcessError(1, [], None, stderr)
with self.assertRaises(SSHFetchError) as cm:
await self.fetcher.gen_fingerprint_for_key(
"ssh-nsa AAAAAC3N test@host # ssh-import-id lp:test")
self.assertEqual(cm.exception.reason, stderr)
self.assertEqual(cm.exception.status,
SSHFetchIdStatus.FINGERPRINT_ERROR)
class TestDryRunSSHKeyFetcher(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.fetcher = DryRunSSHKeyFetcher(make_app())
async def test_fetch_keys_fake_success(self):
result = await self.fetcher.fetch_keys_fake_success("lp:test")
expected = ["""\
ssh-ed25519\
AAAAC3NzaC1lZDI1NTE5AAAAIMM/qhS3hS3+IjpJBYXZWCqPKPH9Zag8QYbS548iEjoZ\
test@earth # ssh-import-id lp:test"""]
self.assertEqual(result, expected)
async def test_fetch_keys_fake_failure(self):
with self.assertRaises(SSHFetchError):
await self.fetcher.fetch_keys_fake_failure("lp:test")