From b906abe2b19cd39a7ed4156576975b848f134ea1 Mon Sep 17 00:00:00 2001 From: io Date: Fri, 17 Sep 2021 06:34:44 +0000 Subject: [PATCH] add basic migration support --- fetch_posts.py | 20 +++++++++++++++++++- schema.sql | 4 ++++ utils.py | 18 ++++++++++++++++++ 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/fetch_posts.py b/fetch_posts.py index ac6f04a..adfcee7 100755 --- a/fetch_posts.py +++ b/fetch_posts.py @@ -9,10 +9,10 @@ import pendulum import operator import aiosqlite import contextlib -from utils import shield from pleroma import Pleroma from bs4 import BeautifulSoup from functools import partial +from utils import shield, suppress from typing import Iterable, NewType from third_party.utils import extract_post_content @@ -26,6 +26,8 @@ UTC = pendulum.timezone('UTC') JSON_CONTENT_TYPE = 'application/json' ACTIVITYPUB_CONTENT_TYPE = 'application/activity+json' +MIGRATION_VERSION = 1 + class PostFetcher: def __init__(self, *, config): self.config = config @@ -47,10 +49,26 @@ class PostFetcher: ), ) self._db = await stack.enter_async_context(aiosqlite.connect(self.config['db_path'])) + await self._maybe_run_migrations() self._db.row_factory = aiosqlite.Row self._ctx_stack = stack return self + async def _maybe_run_migrations(self): + async with self._db.cursor() as cur, suppress(aiosqlite.OperationalError): + if await (await cur.execute('SELECT migration_version FROM migrations')).fetchone(): return + + await self._run_migrations() + + async def _run_migrations(self): + # TODO proper migrations, not just "has the schema ever been run" migrations + async with await (anyio.Path(__file__).parent/'schema.sql').open() as f: + schema = await f.read() + + async with self._db.cursor() as cur: + await cur.executescript(schema) + await cur.execute('INSERT INTO migrations (migration_version) VALUES (?)', (MIGRATION_VERSION,)) + async def __aexit__(self, *excinfo): return await self._ctx_stack.__aexit__(*excinfo) diff --git a/schema.sql b/schema.sql index b716719..663b81d 100644 --- a/schema.sql +++ b/schema.sql @@ -6,3 +6,7 @@ CREATE TABLE posts ( -- UTC Unix timestamp in seconds published_at REAL NOT NULL ); + +CREATE TABLE migrations ( + migration_version INTEGER NOT NULL +); diff --git a/utils.py b/utils.py index 085b536..ecc773d 100644 --- a/utils.py +++ b/utils.py @@ -1,7 +1,25 @@ # SPDX-License-Identifier: AGPL-3.0-only import anyio +import contextlib from functools import wraps +from datetime import datetime, timezone + +def as_corofunc(f): + @wraps(f) + async def wrapped(*args, **kwargs): + # can't decide if i want an `anyio.sleep(0)` here. + return f(*args, **kwargs) + return wrapped + +def as_async_cm(cls): + @wraps(cls, updated=()) # cls.__dict__ doesn't support .update() + class wrapped(cls, contextlib.AbstractAsyncContextManager): + __aenter__ = as_corofunc(cls.__enter__) + __aexit__ = as_corofunc(cls.__exit__) + return wrapped + +suppress = as_async_cm(contextlib.suppress) def shield(f): @wraps(f)