diff --git a/subiquity/server/server.py b/subiquity/server/server.py index a4af100c..fb60cfe0 100644 --- a/subiquity/server/server.py +++ b/subiquity/server/server.py @@ -17,7 +17,6 @@ import asyncio import logging import os import shlex -import shutil import sys import time from typing import List, Optional @@ -37,7 +36,10 @@ import yaml from subiquitycore.async_helpers import run_in_thread from subiquitycore.context import with_context from subiquitycore.core import Application -from subiquitycore.file_util import write_file +from subiquitycore.file_util import ( + copy_file_if_exists, + write_file, + ) from subiquitycore.prober import Prober from subiquitycore.ssh import ( host_key_fingerprints, @@ -567,19 +569,9 @@ class SubiquityServer(Application): return None isopath = self.base_relative(iso_autoinstall_path) - self.copy_autoinstall(loc, isopath) + copy_file_if_exists(loc, isopath) return isopath - def copy_autoinstall(self, source, target): - if source is None or not os.path.exists(source): - return - dirname = os.path.dirname(target) - os.makedirs(dirname, exist_ok=True) - try: - shutil.copyfile(source, target) - except shutil.SameFileError: - pass - def _user_has_password(self, username): with open('/etc/shadow') as fp: for line in fp: @@ -657,9 +649,10 @@ class SubiquityServer(Application): open(stamp_file, 'w').close() await asyncio.sleep(1) self.load_autoinstall_config(only_early=False) - self.copy_autoinstall( - self.autoinstall, - self.base_relative(reload_autoinstall_path)) + if self.autoinstall is not None: + copy_file_if_exists( + self.autoinstall, + self.base_relative(reload_autoinstall_path)) if self.autoinstall_config: self.interactive = bool( self.autoinstall_config.get('interactive-sections')) diff --git a/subiquity/server/tests/test_server.py b/subiquity/server/tests/test_server.py index 19a1fd7d..e4f47180 100644 --- a/subiquity/server/tests/test_server.py +++ b/subiquity/server/tests/test_server.py @@ -27,10 +27,6 @@ from subiquity.server.server import ( class TestAutoinstallLoad(SubiTestCase): - def assertContents(self, path, expected): - with open(path, 'r') as fp: - self.assertEqual(expected, fp.read()) - def setUp(self): self.tempdir = self.tmp_dir() opts = Mock() @@ -66,7 +62,7 @@ class TestAutoinstallLoad(SubiTestCase): self.create(cloud_autoinstall_path, 'cloud') iso = self.create(iso_autoinstall_path, 'iso') self.assertEqual(iso, self.server.select_autoinstall()) - self.assertContents(iso, 'reload') + self.assert_contents(iso, 'reload') def test_arg_wins(self): arg = self.create(self.path('arg.autoinstall.yaml'), 'arg') @@ -74,29 +70,22 @@ class TestAutoinstallLoad(SubiTestCase): self.create(cloud_autoinstall_path, 'cloud') iso = self.create(iso_autoinstall_path, 'iso') self.assertEqual(iso, self.server.select_autoinstall()) - self.assertContents(iso, 'arg') + self.assert_contents(iso, 'arg') def test_cloud_wins(self): self.create(cloud_autoinstall_path, 'cloud') iso = self.create(iso_autoinstall_path, 'iso') self.assertEqual(iso, self.server.select_autoinstall()) - self.assertContents(iso, 'cloud') + self.assert_contents(iso, 'cloud') def test_iso_wins(self): iso = self.create(iso_autoinstall_path, 'iso') self.assertEqual(iso, self.server.select_autoinstall()) - self.assertContents(iso, 'iso') + self.assert_contents(iso, 'iso') def test_nobody_wins(self): self.assertIsNone(self.server.select_autoinstall()) - def test_copied_to_reload(self): - data = 'stuff things' - src = self.create(self.path('test.yaml'), data) - tgt = self.path(reload_autoinstall_path) - self.server.copy_autoinstall(src, tgt) - self.assertContents(tgt, data) - def test_bogus_autoinstall_argument(self): self.server.opts.autoinstall = self.path('nonexistant.yaml') with self.assertRaises(Exception): diff --git a/subiquitycore/file_util.py b/subiquitycore/file_util.py index 0af78d2b..0c5fe9e4 100644 --- a/subiquitycore/file_util.py +++ b/subiquitycore/file_util.py @@ -18,6 +18,7 @@ import datetime import grp import logging import os +import shutil import tempfile import yaml @@ -73,3 +74,16 @@ def generate_config_yaml(filename, content, **kwargs): now = datetime.datetime.utcnow() tf.write(f'# Autogenerated by Subiquity: {now} UTC\n') tf.write(yaml.dump(content)) + + +def copy_file_if_exists(source: str, target: str): + """If source exists, copy to destination. Ignore error that dest may be a + duplicate. Create destination parent dirs as needed.""" + if not os.path.exists(source): + return + dirname = os.path.dirname(target) + os.makedirs(dirname, exist_ok=True) + try: + shutil.copyfile(source, target) + except shutil.SameFileError: + pass diff --git a/subiquitycore/tests/__init__.py b/subiquitycore/tests/__init__.py index da500c3f..0d117de5 100644 --- a/subiquitycore/tests/__init__.py +++ b/subiquitycore/tests/__init__.py @@ -25,6 +25,10 @@ class SubiTestCase(TestCase): dir = self.tmp_dir() return os.path.normpath(os.path.abspath(os.path.join(dir, path))) + def assert_contents(self, path, expected_contents): + with open(path, 'r') as fp: + self.assertEqual(expected_contents, fp.read()) + def populate_dir(path, files): if not os.path.exists(path): diff --git a/subiquitycore/tests/test_file_util.py b/subiquitycore/tests/test_file_util.py new file mode 100644 index 00000000..ae0a5b43 --- /dev/null +++ b/subiquitycore/tests/test_file_util.py @@ -0,0 +1,31 @@ +# Copyright 2022 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from subiquitycore.file_util import copy_file_if_exists +from subiquitycore.tests import SubiTestCase + + +class TestCopy(SubiTestCase): + def test_copied_to_non_exist_dir(self): + data = 'stuff things' + src = self.tmp_path('src') + tgt = self.tmp_path('create-me/target') + with open(src, 'w') as fp: + fp.write(data) + copy_file_if_exists(src, tgt) + self.assert_contents(tgt, data) + + def test_copied_non_exist_src(self): + copy_file_if_exists('/does/not/exist', '/ditto')