From d6de8e6387f9bd6aea96082b122c5f1ad2a848d5 Mon Sep 17 00:00:00 2001 From: Justus Kuhlmann Date: Tue, 21 Apr 2026 10:22:46 +0200 Subject: [PATCH] Introduce thin wrapper for SQL calls --- corrlib/find.py | 21 ++++++++------------- corrlib/sql.py | 17 +++++++++++++++++ tests/find_test.py | 13 ++++++++++--- 3 files changed, 35 insertions(+), 16 deletions(-) create mode 100644 corrlib/sql.py diff --git a/corrlib/find.py b/corrlib/find.py index 7b07321..1e6b4bf 100644 --- a/corrlib/find.py +++ b/corrlib/find.py @@ -7,6 +7,7 @@ from .input.implementations import codes from .tools import k2m, get_db_file from .tracker import get from .integrity import has_valid_times +from .sql import thin_sql_wrapper from typing import Any, Optional from pathlib import Path import datetime as dt @@ -14,7 +15,7 @@ from collections.abc import Callable import warnings -def _project_lookup_by_alias(db: Path, alias: str) -> str: +def _project_lookup_by_alias(path: Path, alias: str) -> str: """ Lookup a projects UUID by its (human-readable) alias. @@ -30,11 +31,8 @@ def _project_lookup_by_alias(db: Path, alias: str) -> str: uuid: str The UUID of the project with the given alias. """ - conn = sqlite3.connect(db) - c = conn.cursor() - c.execute(f"SELECT * FROM 'projects' WHERE aliases = '{alias}'") - results = c.fetchall() - conn.close() + stmt = f"SELECT * FROM 'projects' WHERE aliases = '{alias}'" + results = thin_sql_wrapper(path, stmt) if len(results)>1: print("Error: multiple projects found with alias " + alias) elif len(results) == 0: @@ -42,7 +40,7 @@ def _project_lookup_by_alias(db: Path, alias: str) -> str: return str(results[0][0]) -def _project_lookup_by_id(db: Path, uuid: str) -> list[tuple[str, ...]]: +def _project_lookup_by_id(path: Path, uuid: str) -> list[tuple[str, ...]]: """ Return the project information available in the database by UUID. @@ -58,11 +56,8 @@ def _project_lookup_by_id(db: Path, uuid: str) -> list[tuple[str, ...]]: results: list The row of the project in the database. """ - conn = sqlite3.connect(db) - c = conn.cursor() - c.execute(f"SELECT * FROM 'projects' WHERE id = '{uuid}'") - results = c.fetchall() - conn.close() + stmt = f"SELECT * FROM 'projects' WHERE id = '{uuid}'" + results = thin_sql_wrapper(path, stmt) return results @@ -360,7 +355,7 @@ def find_project(path: Path, name: str) -> str: """ db_file = get_db_file(path) get(path, db_file) - return _project_lookup_by_alias(path / db_file, name) + return _project_lookup_by_alias(path, name) def list_projects(path: Path) -> list[tuple[str, str]]: diff --git a/corrlib/sql.py b/corrlib/sql.py new file mode 100644 index 0000000..f45ce31 --- /dev/null +++ b/corrlib/sql.py @@ -0,0 +1,17 @@ +import sqlite3 +from .tools import get_db_file +from pathlib import Path +from typing import Any + + +def thin_sql_wrapper(path: Path, stmt: str) -> list[Any]: + db_file = get_db_file(path) + db = path / db_file + conn = sqlite3.connect(db) + c = conn.cursor() + + c.execute(stmt) + results = c.fetchall() + conn.commit() + conn.close() + return results diff --git a/tests/find_test.py b/tests/find_test.py index cc455f9..2144001 100644 --- a/tests/find_test.py +++ b/tests/find_test.py @@ -9,12 +9,17 @@ import datetime as dt def make_sql(path: Path) -> Path: - db = path / "test.db" + db = path / "backlogger.db" cinit._create_db(db) return db +def make_config(path: Path) -> None: + cinit._write_config(path, cinit._create_config(path, "datalad", False)) + + def test_find_lookup_by_one_alias(tmp_path: Path) -> None: + make_config(tmp_path) db = make_sql(tmp_path) conn = sqlite3.connect(db) c = conn.cursor() @@ -26,7 +31,7 @@ def test_find_lookup_by_one_alias(tmp_path: Path) -> None: c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", (uuid, alias_str, tag_str, owner, code)) conn.commit() - assert uuid == find._project_lookup_by_alias(db, "fun_project") + assert uuid == find._project_lookup_by_alias(tmp_path, "fun_project") uuid = "test_uuid2" alias_str = "fun_project" c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", @@ -36,7 +41,9 @@ def test_find_lookup_by_one_alias(tmp_path: Path) -> None: assert uuid == find._project_lookup_by_alias(db, "fun_project") conn.close() + def test_find_lookup_by_id(tmp_path: Path) -> None: + make_config(tmp_path) db = make_sql(tmp_path) conn = sqlite3.connect(db) c = conn.cursor() @@ -49,7 +56,7 @@ def test_find_lookup_by_id(tmp_path: Path) -> None: (uuid, alias_str, tag_str, owner, code)) conn.commit() conn.close() - result = find._project_lookup_by_id(db, uuid)[0] + result = find._project_lookup_by_id(tmp_path, uuid)[0] assert uuid == result[0] assert alias_str == result[1] assert tag_str == result[2] -- 2.43.0