make SingleInstanceTask.start_sync set self.task synchronously

This commit is contained in:
Michael Hudson-Doyle 2019-12-16 12:13:37 +13:00
parent b639d3a4dd
commit de18cc977f
2 changed files with 28 additions and 14 deletions

View File

@ -13,6 +13,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import enum
import logging
import os
@ -57,22 +58,23 @@ class RefreshController(BaseController):
def start(self):
self.configure_task = schedule_task(self.configure_snapd())
self.check_task_starter = SingleInstanceTask(
self.check_task = SingleInstanceTask(
self.check_for_update, propagate_errors=False)
self.check_task_starter.start_sync()
self.check_task.start_sync()
@property
def check_state(self):
task = self.check_task_starter.task
if task is None:
return CheckState.UNKNOWN
task = self.check_task.task
if not task.done():
return CheckState.UNKNOWN
if task.cancelled():
return CheckState.UNKNOWN
if task.exception():
return CheckState.UNAVAILABLE
return task.result()
async def configure_snapd(self):
log.debug("configure_snapd")
try:
r = await self.app.snapd.get(
'v2/snaps/{snap_name}'.format(snap_name=self.snap_name))
@ -130,10 +132,10 @@ class RefreshController(BaseController):
def snapd_network_changed(self):
if self.check_state == CheckState.UNKNOWN:
self.check_task_starter.start_sync()
self.check_task.start_sync()
async def check_for_update(self):
await self.configure_task
await asyncio.shield(self.configure_task)
# If we restarted into this version, don't check for a new version.
if self.app.updated:
return CheckState.UNAVAILABLE

View File

@ -14,6 +14,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
log = logging.getLogger("subiquitycore.async_helpers")
def _done(fut):
@ -47,16 +51,24 @@ class SingleInstanceTask:
self.propagate_errors = propagate_errors
self.task = None
async def start(self, *args, **kw):
if self.task is not None:
self.task.cancel()
async def _start(self, old):
if old is not None:
old.cancel()
try:
await self.task
await old
except BaseException:
pass
self.task = schedule_task(
self.func(*args, **kw), self.propagate_errors)
schedule_task(self.task, self.propagate_errors)
async def start(self, *args, **kw):
await self.start_sync(*args, **kw)
return self.task
def start_sync(self, *args, **kw):
return schedule_task(self.start(*args, **kw))
old = self.task
coro = self.func(*args, **kw)
if asyncio.iscoroutine(coro):
self.task = asyncio.Task(coro)
else:
self.task = coro
return schedule_task(self._start(old))