diff --git a/subiquity/controllers/installpath.py b/subiquity/controllers/installpath.py index 5ce29151..45a3be4a 100644 --- a/subiquity/controllers/installpath.py +++ b/subiquity/controllers/installpath.py @@ -30,7 +30,7 @@ class InstallpathController(BaseController): self.answers = self.all_answers.get("Installpath", {}) self.release = None - def installpath(self): + def default(self): self.ui.set_body(InstallpathView(self.model, self)) if 'path' in self.answers: path = self.answers['path'] @@ -41,11 +41,16 @@ class InstallpathController(BaseController): self.model.update(self.answers) self.signal.emit_signal('next-screen') - default = installpath - def cancel(self): self.signal.emit_signal('prev-screen') + def serialize(self): + return {'path': self.model.path, 'results': self.model.results} + + def deserialize(self, data): + self.model.path = data['path'] + self.model.results = data['results'] + def choose_path(self, path): self.model.path = path getattr(self, 'install_' + path)() diff --git a/subiquity/controllers/mirror.py b/subiquity/controllers/mirror.py index b2a50600..94c4674c 100644 --- a/subiquity/controllers/mirror.py +++ b/subiquity/controllers/mirror.py @@ -36,6 +36,12 @@ class MirrorController(BaseController): def cancel(self): self.signal.emit_signal('prev-screen') + def serialize(self): + return self.model.mirror + + def deserialize(self, data): + self.model.mirror = data + def done(self, mirror): if mirror != self.model.mirror: self.model.mirror = mirror diff --git a/subiquity/controllers/proxy.py b/subiquity/controllers/proxy.py index 43edd979..c215ec83 100644 --- a/subiquity/controllers/proxy.py +++ b/subiquity/controllers/proxy.py @@ -38,6 +38,12 @@ class ProxyController(BaseController): def cancel(self): self.signal.emit_signal('prev-screen') + def serialize(self): + return self.model.proxy + + def deserialize(self, data): + self.model.proxy = data + def done(self, proxy): if proxy != self.model.proxy: self.model.proxy = proxy diff --git a/subiquity/controllers/welcome.py b/subiquity/controllers/welcome.py index 23a1a00a..05eeac07 100644 --- a/subiquity/controllers/welcome.py +++ b/subiquity/controllers/welcome.py @@ -41,3 +41,9 @@ class WelcomeController(BaseController): def cancel(self): # Can't go back from here! pass + + def serialize(self): + return self.model.selected_language + + def deserialize(self, data): + self.model.switch_language(data) diff --git a/subiquitycore/controller.py b/subiquitycore/controller.py index 595f1ff3..672ed630 100644 --- a/subiquitycore/controller.py +++ b/subiquitycore/controller.py @@ -74,6 +74,13 @@ class BaseController(ABC): def default(self): pass + def serialize(self): + return None + + def deserialize(self, data): + if data is not None: + raise Exception("missing deserialize method on {}".format(self)) + # Stuff for fine grained actions, used by filesystem and network # controller at time of writing this comment. diff --git a/subiquitycore/core.py b/subiquitycore/core.py index c9000233..e39c2cb2 100644 --- a/subiquitycore/core.py +++ b/subiquitycore/core.py @@ -15,6 +15,7 @@ from concurrent import futures import fcntl +import json import logging import os import struct @@ -254,6 +255,12 @@ class Application: opts.project = self.project + root = '/' + if opts.dry_run: + root = '.subiquity' + self.state_dir = os.path.join(root, 'run', self.project) + os.makedirs(os.path.join(self.state_dir, 'states'), exist_ok=True) + answers = {} if opts.answers is not None: answers = yaml.safe_load(open(opts.answers).read()) @@ -268,7 +275,9 @@ class Application: input_filter = DummyKeycodesFilter() scale = float(os.environ.get('SUBIQUITY_REPLAY_TIMESCALE', "1")) + updated = os.path.exists(os.path.join(self.state_dir, 'updating')) self.common = { + "updated": updated, "ui": ui, "opts": opts, "signal": Signal(), @@ -304,6 +313,16 @@ class Application: controller_class.register_signals() log.debug(self.common['signal']) + def save_state(self): + if self.controller_index < 0: + return + cur_controller_name = self.controllers[self.controller_index] + cur_controller = self.common['controllers'][cur_controller_name] + state_path = os.path.join( + self.state_dir, 'states', cur_controller_name) + with open(state_path, 'w') as fp: + json.dump(cur_controller.serialize(), fp) + def select_screen(self, index): self.controller_index = index self.common['ui'].progress_current = index @@ -311,8 +330,12 @@ class Application: log.debug("moving to screen %s", controller_name) controller = self.common['controllers'][controller_name] controller.default() + state_path = os.path.join(self.state_dir, 'last-screen') + with open(state_path, 'w') as fp: + fp.write(controller_name) def next_screen(self, *args): + self.save_state() while True: if self.controller_index == len(self.controllers) - 1: self.exit() @@ -324,6 +347,7 @@ class Application: return def prev_screen(self, *args): + self.save_state() while True: if self.controller_index == 0: self.exit() @@ -429,7 +453,6 @@ class Application: self.common['loop'].event_loop)) self.common['base_model'] = self.model_class(self.common) try: - self.common['loop'].set_alarm_in(0.05, self.next_screen) if self.common['opts'].scripts: self.run_scripts(self.common['opts'].scripts) controllers_mod = __import__('%s.controllers' % self.project, @@ -439,6 +462,36 @@ class Application: klass = getattr(controllers_mod, k+"Controller") self.common['controllers'][k] = klass(self.common) log.debug("*** %s", self.common['controllers']) + + initial_controller_index = 0 + + if self.common['updated']: + for k in self.controllers: + state_path = os.path.join(self.state_dir, 'states', k) + if not os.path.exists(state_path): + continue + with open(state_path) as fp: + self.common['controllers'][k].deserialize( + json.load(fp)) + + last_screen = None + state_path = os.path.join(self.state_dir, 'last-screen') + if os.path.exists(state_path): + with open(state_path) as fp: + last_screen = fp.read().strip() + + if last_screen in self.controllers: + initial_controller_index = self.controllers.index( + last_screen) + + def select_initial_screen(loop, index): + try: + self.select_screen(index) + except Skip: + self.next_screen() + + self.common['loop'].set_alarm_in( + 0.05, select_initial_screen, initial_controller_index) self._connect_base_signals() self.common['loop'].run() except Exception: