diff --git a/subiquity/common/apidef.py b/subiquity/common/apidef.py index e75da3c4..f0f0d0ab 100644 --- a/subiquity/common/apidef.py +++ b/subiquity/common/apidef.py @@ -28,6 +28,7 @@ from subiquity.common.types import ( ADConnectionInfo, AdAdminNameValidation, AdDomainNameValidation, + AdJoinResult, AdPasswordValidation, AddPartitionV2, AnyStep, @@ -434,6 +435,9 @@ class API: class check_password: def POST(password: Payload[str]) -> AdPasswordValidation: ... + class join_result: + def GET(wait: bool = True) -> AdJoinResult: ... + class LinkAction(enum.Enum): NEW = enum.auto() diff --git a/subiquity/common/types.py b/subiquity/common/types.py index 732408b1..e33cfc7f 100644 --- a/subiquity/common/types.py +++ b/subiquity/common/types.py @@ -801,3 +801,9 @@ class AdDomainNameValidation(enum.Enum): class AdPasswordValidation(enum.Enum): OK = 'OK' EMPTY = 'Empty' + + +class AdJoinResult(enum.Enum): + OK = 'OK' + JOIN_ERROR = 'Failed to join' + UNKNOWN = "Didn't attempt to join yet" diff --git a/subiquity/server/ad_joiner.py b/subiquity/server/ad_joiner.py new file mode 100644 index 00000000..2801c0f2 --- /dev/null +++ b/subiquity/server/ad_joiner.py @@ -0,0 +1,41 @@ +# Copyright 2023 Canonical, Ltd. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import asyncio +from subiquity.common.types import ( + ADConnectionInfo, + AdJoinResult, +) + + +class AdJoiner(): + def __init__(self): + self._result = AdJoinResult.UNKNOWN + self.join_task = None + + async def join_domain(self, info: ADConnectionInfo) -> AdJoinResult: + self.join_task = asyncio.create_task(self.async_join(info)) + self._result = await self.join_task + return self._result + + async def async_join(self, info: ADConnectionInfo) -> AdJoinResult: + # TODO: Join. + return await asyncio.sleep(3, result=AdJoinResult.JOIN_ERROR) + + async def join_result(self): + if self.join_task is None: + return AdJoinResult.UNKNOWN + + return await self.join_task diff --git a/subiquity/server/controllers/ad.py b/subiquity/server/controllers/ad.py index 346f6210..89f2e8d6 100644 --- a/subiquity/server/controllers/ad.py +++ b/subiquity/server/controllers/ad.py @@ -25,8 +25,10 @@ from subiquity.common.types import ( ADConnectionInfo, AdAdminNameValidation, AdDomainNameValidation, + AdJoinResult, AdPasswordValidation ) +from subiquity.server.ad_joiner import AdJoiner from subiquity.server.controller import SubiquityController log = logging.getLogger('subiquity.server.controllers.ad') @@ -95,6 +97,8 @@ class ADController(SubiquityController): def __init__(self, app): super().__init__(app) + self.ad_joiner = None + self.join_result = AdJoinResult.UNKNOWN if self.app.opts.dry_run: self.ping_strgy = StubDcPingStrategy() else: @@ -141,6 +145,23 @@ class ADController(SubiquityController): to configure AD are present in the live system.""" return self.ping_strgy.has_support() + async def join_result_GET(self, wait: bool = True) -> AdJoinResult: + # Enables testing the API without the need for the install controller + if self.app.opts.dry_run and self.ad_joiner is None: + await self.join_domain() + + if wait and self.ad_joiner: + self.join_result = await self.ad_joiner.join_result() + + return self.join_result + + async def join_domain(self) -> None: + """To be called from the install controller if joining was requested""" + if self.ad_joiner is None: + self.ad_joiner = AdJoiner() + + await self.ad_joiner.join_domain(self.model.conn_info) + # Helper out-of-class functions grouped. class AdValidators: diff --git a/subiquity/server/controllers/install.py b/subiquity/server/controllers/install.py index 8f99322f..55fbabb1 100644 --- a/subiquity/server/controllers/install.py +++ b/subiquity/server/controllers/install.py @@ -406,6 +406,8 @@ class InstallController(SubiquityController): policy = self.model.updates.updates await self.run_unattended_upgrades(context=context, policy=policy) await self.restore_apt_config(context=context) + if self.model.ad.do_join: + await self.app.controllers.Ad.join_domain() @with_context(description="configuring cloud-init") async def configure_cloud_init(self, context): diff --git a/subiquity/tests/api/test_api.py b/subiquity/tests/api/test_api.py index 77c2d4a0..b501e7a1 100644 --- a/subiquity/tests/api/test_api.py +++ b/subiquity/tests/api/test_api.py @@ -1696,3 +1696,6 @@ class TestActiveDirectory(TestAPI): result = await instance.post(endpoint + '/check_admin_name', data='$Ubuntu') self.assertEqual('OK', result) + # Attempts to join with the info supplied above. + join_result = await instance.get(endpoint + '/join_result') + self.assertEqual('JOIN_ERROR', join_result)