refactor how controllers are stored a bit

This commit is contained in:
Michael Hudson-Doyle 2019-11-18 13:55:43 +13:00
parent 9e0c0bf106
commit 64378cdb95
8 changed files with 110 additions and 101 deletions

View File

@ -236,7 +236,7 @@ class IdentityController(BaseController):
login_details_path = '/run/console-conf/login-details.txt' login_details_path = '/run/console-conf/login-details.txt'
self.model.add_user(result) self.model.add_user(result)
ips = [] ips = []
net_model = self.app.controller_instances['Network'].model net_model = self.app.base_model.network
for dev in net_model.get_all_netdevs(): for dev in net_model.get_all_netdevs():
ips.extend(dev.actual_global_ip_addresses) ips.extend(dev.actual_global_ip_addresses)
with open(login_details_path, 'w') as fp: with open(login_details_path, 'w') as fp:
@ -252,7 +252,7 @@ class IdentityController(BaseController):
title = "Configuration Complete" title = "Configuration Complete"
self.ui.set_header(title) self.ui.set_header(title)
net_model = self.app.controller_instances['Network'].model net_model = self.app.base_model.network
ifaces = net_model.get_all_netdevs() ifaces = net_model.get_all_netdevs()
login_view = LoginView(self.opts, self.model, self, ifaces) login_view = LoginView(self.opts, self.model, self, ifaces)

View File

@ -13,6 +13,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from .error import ErrorController
from .filesystem import FilesystemController from .filesystem import FilesystemController
from .identity import IdentityController from .identity import IdentityController
from .installprogress import InstallProgressController from .installprogress import InstallProgressController
@ -26,6 +27,7 @@ from .ssh import SSHController
from .welcome import WelcomeController from .welcome import WelcomeController
from .zdev import ZdevController from .zdev import ZdevController
__all__ = [ __all__ = [
'ErrorController',
'FilesystemController', 'FilesystemController',
'IdentityController', 'IdentityController',
'InstallProgressController', 'InstallProgressController',

View File

@ -32,6 +32,7 @@ import requests
import urwid import urwid
from subiquitycore.controller import BaseController from subiquitycore.controller import BaseController
from subiquitycore.core import Skip
from subiquity.async_helpers import ( from subiquity.async_helpers import (
run_in_thread, run_in_thread,
@ -357,7 +358,7 @@ class ErrorController(BaseController):
return r return r
def start_ui(self): def start_ui(self):
pass raise Skip
def cancel(self): def cancel(self):
pass pass

View File

@ -24,7 +24,6 @@ import apport.hookutils
from subiquitycore.core import Application from subiquitycore.core import Application
from subiquity.controllers.error import ( from subiquity.controllers.error import (
ErrorController,
ErrorReportKind, ErrorReportKind,
) )
from subiquity.models.subiquity import SubiquityModel from subiquity.models.subiquity import SubiquityModel
@ -83,6 +82,7 @@ class Subiquity(Application):
"SSH", "SSH",
"SnapList", "SnapList",
"InstallProgress", "InstallProgress",
"Error", # does not have a UI
] ]
def __init__(self, opts, block_log_dir): def __init__(self, opts, block_log_dir):
@ -119,7 +119,7 @@ class Subiquity(Application):
def select_initial_screen(self, index): def select_initial_screen(self, index):
super().select_initial_screen(index) super().select_initial_screen(index)
for report in self.error_controller.reports: for report in self.controllers.Error.reports:
if report.kind == ErrorReportKind.UI and not report.seen: if report.kind == ErrorReportKind.UI and not report.seen:
log.debug("showing new error %r", report.base) log.debug("showing new error %r", report.base)
self.show_error_report(report) self.show_error_report(report)
@ -163,14 +163,6 @@ class Subiquity(Application):
self.run_command_in_foreground( self.run_command_in_foreground(
"bash", before_hook=_before, cwd='/') "bash", before_hook=_before, cwd='/')
def load_controllers(self):
super().load_controllers()
self.error_controller = ErrorController(self)
def start_controllers(self):
super().start_controllers()
self.error_controller.start()
def note_file_for_apport(self, key, path): def note_file_for_apport(self, key, path):
self._apport_files.append((key, path)) self._apport_files.append((key, path))
@ -181,7 +173,7 @@ class Subiquity(Application):
log.debug("generating crash report") log.debug("generating crash report")
try: try:
report = self.error_controller.create_report(kind) report = self.controllers.Error.create_report(kind)
except Exception: except Exception:
log.exception("creating crash report failed") log.exception("creating crash report failed")
return return

View File

@ -280,7 +280,7 @@ class ErrorReportListStretchy(Stretchy):
Text(""), Text(""),
])] ])]
self.report_to_row = {} self.report_to_row = {}
for report in self.app.error_controller.reports: for report in self.app.controllers.Error.reports:
connect_signal(report, "changed", self._report_changed, report) connect_signal(report, "changed", self._report_changed, report)
r = self.report_to_row[report] = self.row_for_report(report) r = self.report_to_row[report] = self.row_for_report(report)
rows.append(r) rows.append(r)

View File

@ -195,7 +195,7 @@ class HelpMenu(WidgetWrap):
local = Text( local = Text(
('info_minor header', " " + _("Help on this screen") + " ")) ('info_minor header', " " + _("Help on this screen") + " "))
if self.parent.app.error_controller.reports: if self.parent.app.controllers.Error.reports:
view_errors = menu_item( view_errors = menu_item(
_("View error reports").format(local_title), _("View error reports").format(local_title),
on_press=self._show_errors) on_press=self._show_errors)

View File

@ -26,11 +26,8 @@ class BaseController(ABC):
signals = [] signals = []
@classmethod
def _controller_name(cls):
return cls.__name__[:-len("Controller")]
def __init__(self, app): def __init__(self, app):
self.name = type(self).__name__[:-len("Controller")]
self.ui = app.ui self.ui = app.ui
self.signal = app.signal self.signal = app.signal
self.opts = app.opts self.opts = app.opts
@ -47,7 +44,7 @@ class BaseController(ABC):
self.loop = app.loop self.loop = app.loop
self.run_in_bg = app.run_in_bg self.run_in_bg = app.run_in_bg
self.app = app self.app = app
self.answers = app.answers.get(self._controller_name(), {}) self.answers = app.answers.get(self.name, {})
def register_signals(self): def register_signals(self):
"""Defines signals associated with controller from model.""" """Defines signals associated with controller from model."""
@ -72,10 +69,7 @@ class BaseController(ABC):
@property @property
def showing(self): def showing(self):
cur_controller = self.app.cur_controller return self.app.controllers.cur is self
while isinstance(cur_controller, RepeatedController):
cur_controller = cur_controller.orig
return cur_controller is self
@abstractmethod @abstractmethod
def start_ui(self): def start_ui(self):
@ -144,6 +138,7 @@ class BaseController(ABC):
class RepeatedController(BaseController): class RepeatedController(BaseController):
def __init__(self, orig, index): def __init__(self, orig, index):
self.name = "{}-{}".format(orig.name, index)
self.orig = orig self.orig = orig
self.index = index self.index = index

View File

@ -282,6 +282,59 @@ class AsyncioEventLoop(urwid.AsyncioEventLoop):
loop.default_exception_handler(context) loop.default_exception_handler(context)
class ControllerSet:
def __init__(self, app, names):
self.app = app
self.controller_names = names
self.index = -1
self.instances = []
def load(self):
controllers_mod = __import__(
'{}.controllers'.format(self.app.project), None, None, [''])
for name in self.controller_names:
log.debug("Importing controller: %s", name)
klass = getattr(controllers_mod, name+"Controller")
if hasattr(self, name):
c = 0
for instance in self.instances:
if isinstance(instance, klass):
c += 1
inst = RepeatedController(getattr(self, name), c)
name = inst.name
else:
inst = klass(self.app)
setattr(self, name, inst)
self.instances.append(inst)
@property
def cur(self):
if self.index >= 0:
inst = self.instances[self.index]
while isinstance(inst, RepeatedController):
inst = inst.orig
return inst
else:
return None
@property
def at_end(self):
return self.index == len(self.instances) - 1
@property
def at_start(self):
return self.index == 0
def advance(self):
self.index += 1
return self.cur
def back_up(self):
self.index -= 1
return self.cur
class Application: class Application:
# A concrete subclass must set project and controllers attributes, e.g.: # A concrete subclass must set project and controllers attributes, e.g.:
@ -295,7 +348,7 @@ class Application:
# "InstallProgress", # "InstallProgress",
# ] # ]
# The 'next-screen' and 'prev-screen' signals move through the list of # The 'next-screen' and 'prev-screen' signals move through the list of
# controllers in order, calling the default method on the controller # controllers in order, calling the start_ui method on the controller
# instance. # instance.
make_ui = SubiquityCoreUI make_ui = SubiquityCoreUI
@ -337,21 +390,7 @@ class Application:
self.prober = prober self.prober = prober
self.loop = None self.loop = None
self.pool = futures.ThreadPoolExecutor(10) self.pool = futures.ThreadPoolExecutor(10)
if opts.screens: self.controllers = ControllerSet(self, self.controllers)
self.controllers = [c for c in self.controllers
if c in opts.screens]
else:
self.controllers = self.controllers[:]
self.ui.progress_completion = len(self.controllers)
self.controller_instances = dict.fromkeys(self.controllers)
self.controller_index = -1
@property
def cur_controller(self):
if self.controller_index < 0:
return None
controller_name = self.controllers[self.controller_index]
return self.controller_instances[controller_name]
def run_in_bg(self, func, callback): def run_in_bg(self, func, callback):
"""Run func() in a thread and call callback on UI thread. """Run func() in a thread and call callback on UI thread.
@ -423,58 +462,58 @@ class Application:
self.signal.connect_signals(signals) self.signal.connect_signals(signals)
# Registers signals from each controller # Registers signals from each controller
for controller_class in self.controller_instances.values(): for controller in self.controllers.instances:
controller_class.register_signals() controller.register_signals()
log.debug("known signals: %s", self.signal.known_signals) log.debug("known signals: %s", self.signal.known_signals)
def save_state(self): def save_state(self):
cur_controller = self.cur_controller cur = self.controllers.cur
if cur_controller is None: if cur is None:
return return
state_path = os.path.join( state_path = os.path.join(
self.state_dir, 'states', cur_controller._controller_name()) self.state_dir, 'states', cur.name)
with open(state_path, 'w') as fp: with open(state_path, 'w') as fp:
json.dump(cur_controller.serialize(), fp) json.dump(cur.serialize(), fp)
def select_screen(self, index): def start_screen(self, new):
if self.cur_controller is not None: old = self.controllers.cur
self.cur_controller.end_ui() if old is not None:
self.controller_index = index old.end_ui()
self.ui.progress_current = index log.debug("moving to screen %s", new.name)
log.debug( if self.opts.screens and new.name not in self.opts.screens:
"moving to screen %s", self.cur_controller._controller_name()) raise Skip
self.cur_controller.start_ui() new.start_ui()
state_path = os.path.join(self.state_dir, 'last-screen') state_path = os.path.join(self.state_dir, 'last-screen')
with open(state_path, 'w') as fp: with open(state_path, 'w') as fp:
fp.write(self.cur_controller._controller_name()) fp.write(new.name)
def next_screen(self, *args): def next_screen(self, *args):
self.save_state() self.save_state()
while True: while True:
if self.controller_index == len(self.controllers) - 1: if self.controllers.at_end:
self.exit() self.exit()
controller = self.controllers.advance()
try: try:
self.select_screen(self.controller_index + 1) self.start_screen(controller)
except Skip: except Skip:
controller_name = self.controllers[self.controller_index] log.debug("skipping screen %s", controller.name)
log.debug("skipping screen %s", controller_name)
continue continue
else: else:
return break
def prev_screen(self, *args): def prev_screen(self, *args):
self.save_state() self.save_state()
while True: while True:
if self.controller_index == 0: if self.controllers.at_start:
self.exit() self.exit()
controller = self.controllers.back_up()
try: try:
self.select_screen(self.controller_index - 1) self.start_screen(controller)
except Skip: except Skip:
controller_name = self.controllers[self.controller_index] log.debug("skipping screen %s", controller.name)
log.debug("skipping screen %s", controller_name)
continue continue
else: else:
return break
# EventLoop ------------------------------------------------------------------- # EventLoop -------------------------------------------------------------------
@ -566,56 +605,36 @@ class Application:
elif key in ['ctrl t', 'f4']: elif key in ['ctrl t', 'f4']:
self.toggle_color() self.toggle_color()
def load_controllers(self):
log.debug("load_controllers")
controllers_mod = __import__(
'{}.controllers'.format(self.project), None, None, [''])
for i, k in enumerate(self.controllers):
if self.controller_instances[k] is None:
log.debug("Importing controller: {}".format(k))
klass = getattr(controllers_mod, k+"Controller")
self.controller_instances[k] = klass(self)
else:
count = 1
for k2 in self.controllers[:i]:
if k2 == k or k2.startswith(k + '-'):
count += 1
orig = self.controller_instances[k]
k += '-' + str(count)
self.controllers[i] = k
self.controller_instances[k] = RepeatedController(
orig, count)
log.debug("load_controllers done")
def start_controllers(self): def start_controllers(self):
log.debug("starting controllers") log.debug("starting controllers")
for k in self.controllers: for controller in self.controllers.instances:
self.controller_instances[k].start() controller.start()
log.debug("controllers started") log.debug("controllers started")
def load_serialized_state(self): def load_serialized_state(self):
for k in self.controllers: for controller in self.controllers.instances:
state_path = os.path.join(self.state_dir, 'states', k) state_path = os.path.join(
self.state_dir, 'states', controller.name)
if not os.path.exists(state_path): if not os.path.exists(state_path):
continue continue
with open(state_path) as fp: with open(state_path) as fp:
self.controller_instances[k].deserialize( controller.deserialize(json.load(fp))
json.load(fp))
last_screen = None last_screen = None
state_path = os.path.join(self.state_dir, 'last-screen') state_path = os.path.join(self.state_dir, 'last-screen')
if os.path.exists(state_path): if os.path.exists(state_path):
with open(state_path) as fp: with open(state_path) as fp:
last_screen = fp.read().strip() last_screen = fp.read().strip()
controller_index = 0
for i, controller in enumerate(self.controllers.instances):
if controller.name == last_screen:
controller_index = i
return controller_index
if last_screen in self.controllers: def select_initial_screen(self, controller_index):
return self.controllers.index(last_screen) self.controllers.index = controller_index
else:
return 0
def select_initial_screen(self, index):
try: try:
self.select_screen(index) self.start_screen(self.controllers.cur)
except Skip: except Skip:
self.next_screen() self.next_screen()
@ -642,7 +661,7 @@ class Application:
if self.opts.scripts: if self.opts.scripts:
self.run_scripts(self.opts.scripts) self.run_scripts(self.opts.scripts)
self.load_controllers() self.controllers.load()
initial_controller_index = 0 initial_controller_index = 0