refactor/data_backend #12

Merged
jkuhl merged 36 commits from refactor/data_backend into develop 2025-12-04 15:47:45 +01:00
26 changed files with 3012 additions and 110 deletions
Showing only changes of commit 641c612a59

Merge branch 'develop' into refactor/data_backend
Some checks failed
Mypy / mypy (push) Failing after 34s
Pytest / pytest (3.12) (push) Failing after 47s
Pytest / pytest (3.13) (push) Failing after 44s
Pytest / pytest (3.14) (push) Failing after 46s
Ruff / ruff (push) Failing after 33s

Justus Kuhlmann 2025-12-04 11:16:23 +01:00
Signed by: jkuhl
GPG key ID: 00ED992DD79B85A6

30
.github/workflows/mypy.yaml vendored Normal file
View file

@ -0,0 +1,30 @@
name: Mypy
on:
push:
pull_request:
workflow_dispatch:
jobs:
mypy:
runs-on: ubuntu-latest
env:
UV_CACHE_DIR: /tmp/.uv-cache
steps:
- name: Install git-annex
run: |
sudo apt-get update
sudo apt-get install -y git-annex
- name: Check out the repository
uses: https://github.com/RouxAntoine/checkout@v4.1.8
with:
show-progress: true
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
python-version: "3.12"
enable-cache: true
- name: Install corrlib
run: uv sync --locked --all-extras --dev --python "3.12"
- name: Run mypy
run: uv run mypy corrlib

39
.github/workflows/pytest.yaml vendored Normal file
View file

@ -0,0 +1,39 @@
name: Pytest
on:
push:
pull_request:
workflow_dispatch:
schedule:
- cron: '0 4 1 * *'
jobs:
pytest:
strategy:
matrix:
python-version:
- "3.12"
- "3.13"
- "3.14"
runs-on: ubuntu-latest
env:
UV_CACHE_DIR: /tmp/.uv-cache
steps:
- name: Install git-annex
run: |
sudo apt-get update
sudo apt-get install -y git-annex
- name: Check out the repository
uses: https://github.com/RouxAntoine/checkout@v4.1.8
with:
show-progress: true
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
python-version: ${{ matrix.python-version }}
enable-cache: true
- name: Install corrlib
run: uv sync --locked --all-extras --dev --python ${{ matrix.python-version }}
- name: Run tests
run: uv run pytest --cov=corrlib tests

30
.github/workflows/ruff.yaml vendored Normal file
View file

@ -0,0 +1,30 @@
name: Ruff
on:
push:
pull_request:
workflow_dispatch:
jobs:
ruff:
runs-on: ubuntu-latest
env:
UV_CACHE_DIR: /tmp/.uv-cache
steps:
- name: Install git-annex
run: |
sudo apt-get update
sudo apt-get install -y git-annex
- name: Check out the repository
uses: https://github.com/RouxAntoine/checkout@v4.1.8
with:
show-progress: true
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
- name: Install corrlib
run: uv sync --locked --all-extras --dev --python "3.12"
- name: Run ruff
run: uv run ruff check corrlib
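
The three workflows run the same uv-based toolchain; assuming uv and git-annex are installed locally, the checks can be reproduced with the commands the jobs invoke:

uv sync --locked --all-extras --dev    # install corrlib plus the dev dependency group
uv run pytest --cov=corrlib tests      # Pytest job
uv run mypy corrlib                    # Mypy job
uv run ruff check corrlib              # Ruff job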

4
.gitignore vendored
View file

@ -2,3 +2,7 @@ pyerrors_corrlib.egg-info
__pycache__
*.egg-info
test.ipynb
.vscode
.venv
.pytest_cache
.coverage

5
.gitmodules vendored
View file

@ -1,5 +0,0 @@
[submodule "projects/tmp"]
path = projects/tmp
url = git@kuhl-mann.de:lattice/charm_SF_data.git
datalad-id = 5f402163-77f2-470e-b6f1-64d7bf9f87d4
datalad-url = git@kuhl-mann.de:lattice/charm_SF_data.git

View file

@ -1,5 +1,5 @@
"""
The aim of this project is to extend pyerrors to be able to collect measurements from different projects and make them easily accessible to
The aim of this project is to extend pyerrors to be able to collect measurements from different projects and make them easily accessible to
the research group. The idea is to build a database, in which the researcher can easily search for measurements on a correlator basis,
which may be reusable.
As a standard to store the measurements, we will use the .json.gz format from pyerrors.
@ -15,10 +15,11 @@ For now, we are interested in collecting primary IObservables only, as these are
__app_name__ = "corrlib"
from .main import *
from .import input as input
from .initialization import *
from .meas_io import *
from .find import *
from .version import __version__
from .initialization import create as create
from .meas_io import load_record as load_record
from .meas_io import load_records as load_records
from .find import find_project as find_project
from .find import find_record as find_record
from .find import list_projects as list_projects
from .config import *
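
The package docstring above outlines the intended workflow: initialize a library, import projects, then search and load measurements through the names re-exported here. A minimal sketch of that workflow, assuming a local library path and placeholder ensemble/correlator names (not values from this repository):

import corrlib

lib = "/path/to/correlator_library"        # placeholder location
corrlib.create(lib)                        # set up backlogger.db and the folder layout

# after importing projects, search the backlog ...
results = corrlib.find_record(lib, ensemble="A653", correlator_name="f_P", code="sfcf")
# ... and load a record by its backlog path, assuming the returned frame exposes the
# backlogs "path" column ("<file in archive>::<parameter hash>")
corr = corrlib.load_record(lib, results["path"][0])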

View file

@ -1,8 +1,9 @@
from corrlib import cli, __app_name__
def main():
def main() -> None:
cli.app(prog_name=__app_name__)
return
if __name__ == "__main__":

View file

@ -1,6 +1,6 @@
from typing import Optional
import typer
from corrlib import __app_name__, __version__
from corrlib import __app_name__
from .initialization import create
from .toml import import_tomls, update_project, reimport_project
from .find import find_record, list_projects
@ -8,6 +8,7 @@ from .tools import str2list
from .main import update_aliases
from .meas_io import drop_cache as mio_drop_cache
import os
from importlib.metadata import version
app = typer.Typer()
@ -15,7 +16,7 @@ app = typer.Typer()
def _version_callback(value: bool) -> None:
if value:
typer.echo(f"{__app_name__} v{__version__}")
print(__app_name__, version(__app_name__))
raise typer.Exit()
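
The callback now asks importlib.metadata for the installed version instead of importing the setuptools-scm generated module; a minimal illustration of that standard-library call:

from importlib.metadata import version, PackageNotFoundError

try:
    print("corrlib", version("corrlib"))       # same lookup the CLI callback performs
except PackageNotFoundError:
    print("corrlib is not installed in this environment")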

View file

@ -1,5 +1,4 @@
import sqlite3
import datalad.api as dl
import os
import json
import pandas as pd
@ -7,24 +6,25 @@ import numpy as np
from .input.implementations import codes
from .tools import k2m
from .tracker import get
from typing import Any, Optional
# this will implement the search functionality
def _project_lookup_by_alias(db, alias):
def _project_lookup_by_alias(db: str, alias: str) -> str:
# this will lookup the project name based on the alias
conn = sqlite3.connect(db)
c = conn.cursor()
c.execute(f"SELECT * FROM 'projects' WHERE alias = '{alias}'")
results = c.fetchall()
conn.close()
if len(results) > 1:
if len(results)>1:
print("Error: multiple projects found with alias " + alias)
elif len(results) == 0:
raise Exception("Error: no project found with alias " + alias)
return results[0][0]
return str(results[0][0])
def _project_lookup_by_id(db, uuid):
def _project_lookup_by_id(db: str, uuid: str) -> list[tuple[str, str]]:
conn = sqlite3.connect(db)
c = conn.cursor()
c.execute(f"SELECT * FROM 'projects' WHERE id = '{uuid}'")
@ -33,7 +33,8 @@ def _project_lookup_by_id(db, uuid):
return results
def _db_lookup(db, ensemble, correlator_name,code, project=None, parameters=None, created_before=None, created_after=None, updated_before=None, updated_after=None, revision=None):
def _db_lookup(db: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None,
created_before: Optional[str]=None, created_after: Optional[Any]=None, updated_before: Optional[Any]=None, updated_after: Optional[Any]=None) -> pd.DataFrame:
project_str = project
search_expr = f"SELECT * FROM 'backlogs' WHERE name = '{correlator_name}' AND ensemble = '{ensemble}'"
@ -57,7 +58,7 @@ def _db_lookup(db, ensemble, correlator_name,code, project=None, parameters=Non
return results
def sfcf_filter(results, **kwargs):
def sfcf_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame:
drops = []
for ind in range(len(results)):
result = results.iloc[ind]
@ -140,24 +141,25 @@ def sfcf_filter(results, **kwargs):
return results.drop(drops)
def find_record(path, ensemble, correlator_name, code, project=None, parameters=None, created_before=None, created_after=None, updated_before=None, updated_after=None, revision=None, **kwargs):
def find_record(path: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None,
created_before: Optional[str]=None, created_after: Optional[str]=None, updated_before: Optional[str]=None, updated_after: Optional[str]=None, revision: Optional[str]=None, **kwargs: Any) -> pd.DataFrame:
db = path + '/backlogger.db'
if code not in codes:
raise ValueError("Code " + code + "unknown, take one of the following:" + ", ".join(codes))
get(path, "backlogger.db")
results = _db_lookup(db, ensemble, correlator_name,code, project, parameters=parameters, created_before=created_before, created_after=created_after, updated_before=updated_before, updated_after=updated_after, revision=revision)
results = _db_lookup(db, ensemble, correlator_name,code, project, parameters=parameters, created_before=created_before, created_after=created_after, updated_before=updated_before, updated_after=updated_after)
if code == "sfcf":
results = sfcf_filter(results, **kwargs)
print("Found " + str(len(results)) + " result" + ("s" if len(results)>1 else ""))
return results.reset_index()
def find_project(path, name):
def find_project(path: str, name: str) -> str:
get(path, "backlogger.db")
return _project_lookup_by_alias(os.path.join(path, "backlogger.db"), name)
def list_projects(path):
def list_projects(path: str) -> list[tuple[str, str]]:
db = path + '/backlogger.db'
get(path, "backlogger.db")
conn = sqlite3.connect(db)
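
find_project and list_projects wrap the alias lookup shown above; a short hedged example, with library path and alias as placeholders:

from corrlib import find_project, list_projects

lib = "/path/to/correlator_library"
uuid = find_project(lib, "my_alias")       # resolves the alias to the project's dataset id
print(uuid, list_projects(lib))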

View file

@ -5,7 +5,7 @@ import git
GITMODULES_FILE = '.gitmodules'
def move_submodule(repo_path, old_path, new_path):
def move_submodule(repo_path: str, old_path: str, new_path: str) -> None:
"""
Move a submodule to a new location.
@ -41,3 +41,4 @@ def move_submodule(repo_path, old_path, new_path):
repo.git.add('.gitmodules')
# save new state of the dataset
dl.save(repo_path, message=f"Move module from {old_path} to {new_path}", dataset=repo_path)
return

View file

@ -5,7 +5,7 @@ import os
import tracker as tr
def _create_db(db):
def _create_db(db: str) -> None:
"""
Create the database file and the table.
@ -34,6 +34,7 @@ def _create_db(db):
updated_at TEXT)''')
conn.commit()
conn.close()
return
def _create_config(path):
@ -56,7 +57,28 @@ def _create_config(path):
config.write(configfile)
def create(path):
def _create_config(path: str) -> None:
"""
Create the config file for backlogger.
"""
config = ConfigParser()
config['core'] = {
'version': '1.0',
'db_path': os.path.join(path, 'backlogger.db'),
'projects_path': os.path.join(path, 'projects'),
'archive_path': os.path.join(path, 'archive'),
'toml_imports_path': os.path.join(path, 'toml_imports'),
'import_scripts_path': os.path.join(path, 'import_scripts'),
'tracker': 'datalad',
'cached': True,
}
with open(os.path.join(path, '.corrlib'), 'w') as configfile:
config.write(configfile)
return
def create(path: str) -> None:
"""
Create folder of backlogs.
@ -73,3 +95,4 @@ def create(path):
fp.write(".cache")
fp.close()
tr.save(path, message="Initialized correlator library", dataset=path)
return
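
For reference, _create_config writes a plain ConfigParser INI file named .corrlib; for a library created at /data/corrlib it would look roughly like this (paths illustrative):

[core]
version = 1.0
db_path = /data/corrlib/backlogger.db
projects_path = /data/corrlib/projects
archive_path = /data/corrlib/archive
toml_imports_path = /data/corrlib/toml_imports
import_scripts_path = /data/corrlib/import_scripts
tracker = datalad
cached = True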

View file

@ -2,6 +2,6 @@
Import functions for different codes.
"""
from . import sfcf
from . import openQCD
from . import implementations
from . import sfcf as sfcf
from . import openQCD as openQCD
from . import implementations as implementations

View file

@ -2,7 +2,7 @@ import pyerrors.input.openQCD as input
import datalad.api as dl
import os
import fnmatch
from typing import Any
from typing import Any, Optional
def read_ms1_param(path: str, project: str, file_in_project: str) -> dict[str, Any]:
@ -67,7 +67,7 @@ def read_ms3_param(path: str, project: str, file_in_project: str) -> dict[str, A
return param
def read_rwms(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, postfix: str="ms1", version: str='2.0', names: list[str]=None, files: list[str]=None) -> dict[str, Any]:
def read_rwms(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, postfix: str="ms1", version: str='2.0', names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]:
dataset = os.path.join(path, "projects", project)
directory = os.path.join(dataset, dir_in_project)
if files is None:
@ -94,7 +94,7 @@ def read_rwms(path: str, project: str, dir_in_project: str, param: dict[str, Any
return rw_dict
def extract_t0(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str=None, names: list[str]=None, files: list[str]=None) -> dict[str, Any]:
def extract_t0(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str="", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]:
dataset = os.path.join(path, "projects", project)
directory = os.path.join(dataset, dir_in_project)
if files is None:
@ -132,7 +132,7 @@ def extract_t0(path: str, project: str, dir_in_project: str, param: dict[str, An
return t0_dict
def extract_t1(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str = None, names: list[str]=None, files: list[str]=None) -> dict[str, Any]:
def extract_t1(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str = "", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]:
directory = os.path.join(path, "projects", project, dir_in_project)
if files is None:
files = []

View file

@ -5,7 +5,7 @@ import os
from typing import Any
bi_corrs: list = ["f_P", "fP", "f_p",
bi_corrs: list[str] = ["f_P", "fP", "f_p",
"g_P", "gP", "g_p",
"fA0", "f_A", "f_a",
"gA0", "g_A", "g_a",
@ -43,7 +43,7 @@ bi_corrs: list = ["f_P", "fP", "f_p",
"l3A2", "l3_A2", "g_av23",
]
bb_corrs: list = [
bb_corrs: list[str] = [
'F1',
'F_1',
'f_1',
@ -64,7 +64,7 @@ bb_corrs: list = [
'F_sPdP_d',
]
bib_corrs: list = [
bib_corrs: list[str] = [
'F_V0',
'K_V0',
]
@ -184,7 +184,7 @@ def read_param(path: str, project: str, file_in_project: str) -> dict[str, Any]:
return params
def _map_params(params: dict, spec_list: list) -> dict[str, Any]:
def _map_params(params: dict[str, Any], spec_list: list[str]) -> dict[str, Any]:
"""
Map the extracted parameters to the extracted data.
@ -194,7 +194,7 @@ def _map_params(params: dict, spec_list: list) -> dict[str, Any]:
The parameters extracted from the parameter (input) file. in the dict form given by read_param.
spec_list: list
The list of specifications that belongs to the correlator in question.
Returns
-------
new_specs: dict
@ -228,7 +228,7 @@ def _map_params(params: dict, spec_list: list) -> dict[str, Any]:
return new_specs
def get_specs(key, parameters, sep='/') -> str:
def get_specs(key: str, parameters: dict[str, Any], sep: str = '/') -> str:
key_parts = key.split(sep)
if corr_types[key_parts[0]] == 'bi':
param = _map_params(parameters, key_parts[1:-1])
@ -238,7 +238,7 @@ def get_specs(key, parameters, sep='/') -> str:
return s
def read_data(path, project, dir_in_project, prefix, param, version='1.0c', cfg_seperator='n', sep='/', **kwargs) -> dict:
def read_data(path: str, project: str, dir_in_project: str, prefix: str, param: dict[str, Any], version: str = '1.0c', cfg_seperator: str = 'n', sep: str = '/', **kwargs: Any) -> dict[str, Any]:
"""
Extract the data from the sfcf file.

View file

@ -7,10 +7,10 @@ import shutil
from .find import _project_lookup_by_id
from .tools import list2str, str2list
from .tracker import get_file
from typing import Union
from typing import Union, Optional
def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Union[str, None]=None, aliases: Union[str, None]=None, code: Union[str, None]=None):
def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Union[list[str], None]=None, aliases: Union[list[str], None]=None, code: Union[str, None]=None) -> None:
"""
Create a new project entry in the database.
@ -34,10 +34,10 @@ def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Uni
raise ValueError("Project already imported, use update_project() instead.")
dl.unlock(db, dataset=path)
alias_str = None
alias_str = ""
if aliases is not None:
alias_str = list2str(aliases)
tag_str = None
tag_str = ""
if tags is not None:
tag_str = list2str(tags)
c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", (uuid, alias_str, tag_str, owner, code))
@ -46,7 +46,7 @@ def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Uni
dl.save(db, message="Added entry for project " + uuid + " to database", dataset=path)
def update_project_data(path, uuid, prop, value = None):
def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None] = None) -> None:
get_file(path, "backlogger.db")
conn = sqlite3.connect(os.path.join(path, "backlogger.db"))
c = conn.cursor()
@ -56,7 +56,7 @@ def update_project_data(path, uuid, prop, value = None):
return
def update_aliases(path: str, uuid: str, aliases: list[str]):
def update_aliases(path: str, uuid: str, aliases: list[str]) -> None:
db = os.path.join(path, "backlogger.db")
get_file(path, "backlogger.db")
known_data = _project_lookup_by_id(db, uuid)[0]
@ -83,7 +83,7 @@ def update_aliases(path: str, uuid: str, aliases: list[str]):
return
def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Union[str, None]=None, aliases: Union[str, None]=None, code: Union[str, None]=None, isDataset: bool=True):
def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Optional[list[str]]=None, aliases: Optional[list[str]]=None, code: Optional[str]=None, isDataset: bool=True) -> str:
"""
Parameters
----------
@ -118,7 +118,7 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Unio
dl.install(path=tmp_path, source=url, dataset=path)
tmp_ds = dl.Dataset(tmp_path)
conf = dlc.ConfigManager(tmp_ds)
uuid = conf.get("datalad.dataset.id")
uuid = str(conf.get("datalad.dataset.id"))
if not uuid:
raise ValueError("The dataset does not have a uuid!")
if not os.path.exists(path + "/projects/" + uuid):
@ -131,21 +131,22 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Unio
dl.save([db, path + '/projects/' + uuid], message="Import project from " + url, dataset=path)
else:
dl.drop(tmp_path, reckless='kill')
shutil.rmtree(tmp_path)
shutil.rmtree(tmp_path)
if aliases is not None:
if isinstance(aliases, str):
alias_list = [aliases]
else:
alias_list = aliases
update_aliases(path, uuid, alias_list)
# make this more concrete
return uuid
def drop_project_data(path: str, uuid: str, path_in_project: str = ""):
def drop_project_data(path: str, uuid: str, path_in_project: str = "") -> None:
"""
Drop (parts of) a project to free up disk space
"""
dl.drop(path + "/projects/" + uuid + "/" + path_in_project)
return
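
import_project installs the dataset into a temporary path, reads its datalad dataset id, files it under projects/<uuid>, and registers any aliases, so find_project can map a readable name back to the uuid. A hedged usage sketch with placeholder URL, owner and alias:

from corrlib.main import import_project
from corrlib.find import find_project

lib = "/path/to/correlator_library"
uuid = import_project(lib, "git@example.org:lattice/some_data.git",
                      owner="someone", aliases=["some_data"], code="sfcf")
assert find_project(lib, "some_data") == uuid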

View file

@ -10,9 +10,10 @@ from hashlib import sha256
from .tools import cached
from .tracker import get
import shutil
from typing import Any
def write_measurement(path, ensemble, measurement, uuid, code, parameter_file=None):
def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, dict[str, Any]]], uuid: str, code: str, parameter_file: str) -> None:
"""
Write a measurement to the backlog.
If the file for the measurement already exists, update the measurement.
@ -59,7 +60,7 @@ def write_measurement(path, ensemble, measurement, uuid, code, parameter_file=No
pars = {}
subkeys = []
for i in range(len(parameters["rw_fcts"])):
par_list = []
par_list = []
for k in parameters["rw_fcts"][i].keys():
par_list.append(str(parameters["rw_fcts"][i][k]))
subkey = "/".join(par_list)
@ -80,12 +81,12 @@ def write_measurement(path, ensemble, measurement, uuid, code, parameter_file=No
subkey = "/".join(par_list)
subkeys = [subkey]
pars[subkey] = json.dumps(parameters)
for subkey in subkeys:
for subkey in subkeys:
parHash = sha256(str(pars[subkey]).encode('UTF-8')).hexdigest()
meas_path = file_in_archive + "::" + parHash
known_meas[parHash] = measurement[corr][subkey]
if c.execute("SELECT * FROM backlogs WHERE path = ?", (meas_path,)).fetchone() is not None:
c.execute("UPDATE backlogs SET updated_at = datetime('now') WHERE path = ?", (meas_path, ))
else:
@ -98,7 +99,7 @@ def write_measurement(path, ensemble, measurement, uuid, code, parameter_file=No
dl.save(files, message="Add measurements to database", dataset=path)
def load_record(path: str, meas_path: str):
def load_record(path: str, meas_path: str) -> Union[Corr, Obs]:
"""
Load a list of records by their paths.
@ -108,7 +109,7 @@ def load_record(path: str, meas_path: str):
Path of the correlator library.
meas_path: str
The path to the correlator in the backlog system.
Returns
-------
co : Corr or Obs
@ -117,7 +118,7 @@ def load_record(path: str, meas_path: str):
return load_records(path, [meas_path])[0]
def load_records(path: str, meas_paths: list[str], preloaded = {}) -> list[Union[Corr, Obs]]:
def load_records(path: str, meas_paths: list[str], preloaded: dict[str, Any] = {}) -> list[Union[Corr, Obs]]:
"""
Load a list of records by their paths.
@ -127,7 +128,7 @@ def load_records(path: str, meas_paths: list[str], preloaded = {}) -> list[Union
Path of the correlator library.
meas_paths: list[str]
A list of the paths to the correlator in the backlog system.
Returns
-------
List
@ -139,7 +140,7 @@ def load_records(path: str, meas_paths: list[str], preloaded = {}) -> list[Union
needed_data[file] = []
key = mpath.split("::")[1]
needed_data[file].append(key)
returned_data: list = []
returned_data: list[Any] = []
for file in needed_data.keys():
for key in list(needed_data[file]):
if os.path.exists(cache_path(path, file, key) + ".p"):
@ -155,7 +156,7 @@ def load_records(path: str, meas_paths: list[str], preloaded = {}) -> list[Union
return returned_data
def cache_dir(path, file):
def cache_dir(path: str, file: str) -> str:
cache_path_list = [path]
cache_path_list.append(".cache")
cache_path_list.extend(file.split("/")[1:])
@ -163,19 +164,19 @@ def cache_dir(path, file):
return cache_path
def cache_path(path, file, key):
def cache_path(path: str, file: str, key: str) -> str:
cache_path = os.path.join(cache_dir(path, file), key)
return cache_path
def preload(path: str, file: str):
def preload(path: str, file: str) -> dict[str, Any]:
get(path, file)
filedict = pj.load_json_dict(os.path.join(path, file))
filedict: dict[str, Any] = pj.load_json_dict(os.path.join(path, file))
print("> read file")
return filedict
def drop_record(path: str, meas_path: str):
def drop_record(path: str, meas_path: str) -> None:
file_in_archive = meas_path.split("::")[0]
file = os.path.join(path, file_in_archive)
db = os.path.join(path, 'backlogger.db')
@ -200,7 +201,9 @@ def drop_record(path: str, meas_path: str):
else:
raise ValueError("This measurement does not exist as a file!")
def drop_cache(path: str):
def drop_cache(path: str) -> None:
cache_dir = os.path.join(path, ".cache")
for f in os.listdir(cache_dir):
shutil.rmtree(os.path.join(cache_dir, f))
return
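
Records are addressed by a path of the form "<file in archive>::<sha256 of the parameter string>"; load_records groups the requested keys by file, serves what it can from .cache, and otherwise fetches and parses each .json.gz file once. A brief illustration of the addressing convention (file name and parameters are hypothetical):

from hashlib import sha256

file_in_archive = "archive/A653/f_P.json.gz"                       # hypothetical archive file
par_hash = sha256("example parameter string".encode("UTF-8")).hexdigest()
meas_path = file_in_archive + "::" + par_hash

file_part, key = meas_path.split("::")     # the same split load_records and drop_record perform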

View file

@ -16,15 +16,16 @@ from .meas_io import write_measurement
import datalad.api as dl
import os
from .input.implementations import codes as known_codes
from typing import Any
def replace_string(string: str, name: str, val: str):
def replace_string(string: str, name: str, val: str) -> str:
if '{' + name + '}' in string:
n = string.replace('{' + name + '}', val)
return n
else:
return string
def replace_in_meas(measurements: dict, vars: dict[str, str]):
def replace_in_meas(measurements: dict[str, dict[str, Any]], vars: dict[str, str]) -> dict[str, dict[str, Any]]:
# replace global variables
for name, value in vars.items():
for m in measurements.keys():
@ -36,7 +37,7 @@ def replace_in_meas(measurements: dict, vars: dict[str, str]):
measurements[m][key][i] = replace_string(measurements[m][key][i], name, value)
return measurements
def fill_cons(measurements, constants):
def fill_cons(measurements: dict[str, dict[str, Any]], constants: dict[str, str]) -> dict[str, dict[str, Any]]:
for m in measurements.keys():
for name, val in constants.items():
if name not in measurements[m].keys():
@ -44,7 +45,7 @@ def fill_cons(measurements, constants):
return measurements
def check_project_data(d: dict) -> None:
def check_project_data(d: dict[str, dict[str, str]]) -> None:
if 'project' not in d.keys() or 'measurements' not in d.keys() or len(list(d.keys())) > 4:
raise ValueError('There should be at most four keys on the top level; "project" and "measurements" are mandatory, "constants" and "replace" are optional!')
project_data = d['project']
@ -57,7 +58,7 @@ def check_project_data(d: dict) -> None:
return
def check_measurement_data(measurements: dict, code: str) -> None:
def check_measurement_data(measurements: dict[str, dict[str, str]], code: str) -> None:
var_names: list[str] = []
if code == "sfcf":
var_names = ["path", "ensemble", "param_file", "version", "prefix", "cfg_seperator", "names"]
@ -91,14 +92,14 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None:
with open(file, 'rb') as fp:
toml_dict = toml.load(fp)
check_project_data(toml_dict)
project: dict = toml_dict['project']
project: dict[str, Any] = toml_dict['project']
if project['code'] not in known_codes:
raise ValueError('Code ' + project['code'] + ' has no import implementation!')
measurements: dict = toml_dict['measurements']
measurements: dict[str, dict[str, Any]] = toml_dict['measurements']
measurements = fill_cons(measurements, toml_dict['constants'] if 'constants' in toml_dict else {})
measurements = replace_in_meas(measurements, toml_dict['replace'] if 'replace' in toml_dict else {})
check_measurement_data(measurements, project['code'])
aliases = project.get('aliases', None)
aliases = project.get('aliases', [])
uuid = project.get('uuid', None)
if uuid is not None:
if not os.path.exists(path + "/projects/" + uuid):
@ -133,16 +134,16 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None:
for rwp in ["integrator", "eps", "ntot", "dnms"]:
param[rwp] = "Unknown"
param['type'] = 't0'
measurement = openQCD.extract_t0(path, uuid, md['path'], param, md["prefix"], md["dtr_read"], md["xmin"], md["spatial_extent"],
fit_range=md.get('fit_range', 5), postfix=md.get('postfix', None), names=md.get('names', None), files=md.get('files', None))
measurement = openQCD.extract_t0(path, uuid, md['path'], param, str(md["prefix"]), int(md["dtr_read"]), int(md["xmin"]), int(md["spatial_extent"]),
fit_range=int(md.get('fit_range', 5)), postfix=str(md.get('postfix', '')), names=md.get('names', []), files=md.get('files', []))
elif md['measurement'] == 't1':
if 'param_file' in md:
param = openQCD.read_ms3_param(path, uuid, md['param_file'])
param['type'] = 't1'
measurement = openQCD.extract_t1(path, uuid, md['path'], param, md["prefix"], md["dtr_read"], md["xmin"], md["spatial_extent"],
fit_range=md.get('fit_range', 5), postfix=md.get('postfix', None), names=md.get('names', None), files=md.get('files', None))
measurement = openQCD.extract_t1(path, uuid, md['path'], param, str(md["prefix"]), int(md["dtr_read"]), int(md["xmin"]), int(md["spatial_extent"]),
fit_range=int(md.get('fit_range', 5)), postfix=str(md.get('postfix', '')), names=md.get('names', []), files=md.get('files', []))
write_measurement(path, ensemble, measurement, uuid, project['code'], (md['param_file'] if 'param_file' in md else None))
write_measurement(path, ensemble, measurement, uuid, project['code'], (md['param_file'] if 'param_file' in md else ''))
if not os.path.exists(os.path.join(path, "toml_imports", uuid)):
os.makedirs(os.path.join(path, "toml_imports", uuid))
@ -155,7 +156,7 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None:
return
def reimport_project(path, uuid):
def reimport_project(path: str, uuid: str) -> None:
"""
Reimport an existing project using the files that are already available for this project.
@ -173,6 +174,7 @@ def reimport_project(path, uuid):
return
def update_project(path, uuid):
def update_project(path: str, uuid: str) -> None:
dl.update(how='merge', follow='sibling', dataset=os.path.join(path, "projects", uuid))
# reimport_project(path, uuid)
return
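
check_project_data and check_measurement_data define the shape of an import file: the top-level 'project' and 'measurements' tables are mandatory, 'constants' and 'replace' are optional, and an sfcf measurement must end up with path, ensemble, param_file, version, prefix, cfg_seperator and names. A hedged example of such a file, with all values as placeholders:

[project]
code = "sfcf"
aliases = ["my_project"]

[constants]
version = "2.0c"
cfg_seperator = "n"

[replace]
ens = "A653"

[measurements.run1]
path = "{ens}/run1"
ensemble = "A653"
param_file = "{ens}/run1/parameters.in"
prefix = "my_prefix"
names = ["A653r000"]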

View file

@ -2,20 +2,20 @@ import os
from configparser import ConfigParser
def str2list(string):
def str2list(string: str) -> list[str]:
return string.split(",")
def list2str(mylist):
def list2str(mylist: list[str]) -> str:
s = ",".join(mylist)
return s
cached = True
cached: bool = True
def m2k(m):
def m2k(m: float) -> float:
return 1/(2*m+8)
def k2m(k):
def k2m(k: float) -> float:
return (1/(2*k))-4
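
m2k and k2m implement the standard Wilson-fermion relation between the hopping parameter and the bare mass, kappa = 1/(2m + 8) and m = 1/(2*kappa) - 4, so the two are exact inverses; a quick round-trip check (mirroring tests/tools_test.py):

from corrlib.tools import m2k, k2m

assert abs(k2m(m2k(0.5)) - 0.5) < 1e-9     # mass -> kappa -> mass round trip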

View file

@ -1 +1,34 @@
__version__ = "0.2.3"
# file generated by setuptools-scm
# don't change, don't track in version control
__all__ = [
"__version__",
"__version_tuple__",
"version",
"version_tuple",
"__commit_id__",
"commit_id",
]
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple
from typing import Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
COMMIT_ID = Union[str, None]
else:
VERSION_TUPLE = object
COMMIT_ID = object
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID
__version__ = version = '0.2.4.dev14+g602324f84.d20251202'
__version_tuple__ = version_tuple = (0, 2, 4, 'dev14', 'g602324f84.d20251202')
__commit_id__ = commit_id = 'g602324f84'

View file

@ -1,6 +1,52 @@
[build-system]
requires = ["setuptools >= 63.0.0", "wheel"]
requires = ["setuptools >= 63.0.0", "wheel", "setuptools-scm"]
build-backend = "setuptools.build_meta"
[project]
requires-python = ">=3.10"
name = "corrlib"
dynamic = ["version"]
dependencies = [
"gitpython>=3.1.45",
'pyerrors>=2.11.1',
"datalad>=1.1.0",
'typer>=0.12.5',
]
description = "Python correlation library"
authors = [
{ name = 'Justus Kuhlmann', email = 'j_kuhl19@uni-muenster.de'}
]
[project.scripts]
pcl = "corrlib.cli:app"
[tool.setuptools.packages.find]
include = ["corrlib", "corrlib.*"]
[tool.setuptools_scm]
write_to = "corrlib/version.py"
[tool.ruff.lint]
ignore = ["F403"]
ignore = ["E501"]
extend-select = [
"YTT",
"E",
"W",
"F",
]
[tool.mypy]
strict = true
implicit_reexport = false
follow_untyped_imports = false
ignore_missing_imports = true
[dependency-groups]
dev = [
"mypy>=1.19.0",
"pandas-stubs>=2.3.3.251201",
"pytest>=9.0.1",
"pytest-cov>=7.0.0",
"pytest-pretty>=1.3.0",
"ruff>=0.14.7",
]

View file

@ -1,18 +0,0 @@
from setuptools import setup
from distutils.util import convert_path
version = {}
with open(convert_path('corrlib/version.py')) as ver_file:
exec(ver_file.read(), version)
setup(name='pycorrlib',
version=version['__version__'],
author='Justus Kuhlmann',
author_email='j_kuhl19@uni-muenster.de',
install_requires=['pyerrors>=2.11.1', 'datalad>=1.1.0', 'typer>=0.12.5'],
entry_points = {
'console_scripts': ['pcl=corrlib.cli:app'],
},
packages=['corrlib', 'corrlib.input']
)

91
tests/cli_test.py Normal file
View file

@ -0,0 +1,91 @@
from typer.testing import CliRunner
from corrlib.cli import app
import os
import sqlite3 as sql
runner = CliRunner()
def test_version():
result = runner.invoke(app, ["--version"])
assert result.exit_code == 0
assert "corrlib" in result.output
def test_init_folders(tmp_path):
dataset_path = tmp_path / "test_dataset"
result = runner.invoke(app, ["init", "--dataset", str(dataset_path)])
assert result.exit_code == 0
assert os.path.exists(str(dataset_path))
assert os.path.exists(str(dataset_path / "backlogger.db"))
def test_init_db(tmp_path):
dataset_path = tmp_path / "test_dataset"
result = runner.invoke(app, ["init", "--dataset", str(dataset_path)])
assert result.exit_code == 0
assert os.path.exists(str(dataset_path / "backlogger.db"))
conn = sql.connect(str(dataset_path / "backlogger.db"))
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
expected_tables = [
'projects',
'backlogs',
]
table_names = [table[0] for table in tables]
for expected_table in expected_tables:
assert expected_table in table_names
cursor.execute("SELECT * FROM projects;")
projects = cursor.fetchall()
assert len(projects) == 0
cursor.execute("SELECT * FROM backlogs;")
backlogs = cursor.fetchall()
assert len(backlogs) == 0
cursor.execute("PRAGMA table_info('projects');")
project_columns = cursor.fetchall()
expected_project_columns = [
"id",
"aliases",
"customTags",
"owner",
"code",
"created_at",
"updated_at"
]
project_column_names = [col[1] for col in project_columns]
for expected_col in expected_project_columns:
assert expected_col in project_column_names
cursor.execute("PRAGMA table_info('backlogs');")
backlog_columns = cursor.fetchall()
expected_backlog_columns = [
"id",
"name",
"ensemble",
"code",
"path",
"project",
"customTags",
"parameters",
"parameter_file",
"created_at",
"updated_at"
]
backlog_column_names = [col[1] for col in backlog_columns]
for expected_col in expected_backlog_columns:
assert expected_col in backlog_column_names
def test_list(tmp_path):
dataset_path = tmp_path / "test_dataset"
result = runner.invoke(app, ["init", "--dataset", str(dataset_path)])
assert result.exit_code == 0
result = runner.invoke(app, ["list", "--dataset", str(dataset_path), "ensembles"])
assert result.exit_code == 0
result = runner.invoke(app, ["list", "--dataset", str(dataset_path), "projects"])
assert result.exit_code == 0

View file

@ -14,4 +14,4 @@ def test_toml_check_measurement_data():
"names": ['list', 'of', 'names']
}
}
t.check_measurement_data(measurements)
t.check_measurement_data(measurements, "sfcf")

View file

@ -0,0 +1,68 @@
import corrlib.initialization as init
import os
import sqlite3 as sql
def test_init_folders(tmp_path):
dataset_path = tmp_path / "test_dataset"
init.create(str(dataset_path))
assert os.path.exists(str(dataset_path))
assert os.path.exists(str(dataset_path / "backlogger.db"))
def test_init_db(tmp_path):
dataset_path = tmp_path / "test_dataset"
init.create(str(dataset_path))
assert os.path.exists(str(dataset_path / "backlogger.db"))
conn = sql.connect(str(dataset_path / "backlogger.db"))
cursor = conn.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()
expected_tables = [
'projects',
'backlogs',
]
table_names = [table[0] for table in tables]
for expected_table in expected_tables:
assert expected_table in table_names
cursor.execute("SELECT * FROM projects;")
projects = cursor.fetchall()
assert len(projects) == 0
cursor.execute("SELECT * FROM backlogs;")
backlogs = cursor.fetchall()
assert len(backlogs) == 0
cursor.execute("PRAGMA table_info('projects');")
project_columns = cursor.fetchall()
expected_project_columns = [
"id",
"aliases",
"customTags",
"owner",
"code",
"created_at",
"updated_at"
]
project_column_names = [col[1] for col in project_columns]
for expected_col in expected_project_columns:
assert expected_col in project_column_names
cursor.execute("PRAGMA table_info('backlogs');")
backlog_columns = cursor.fetchall()
expected_backlog_columns = [
"id",
"name",
"ensemble",
"code",
"path",
"project",
"customTags",
"parameters",
"parameter_file",
"created_at",
"updated_at"
]
backlog_column_names = [col[1] for col in backlog_columns]
for expected_col in expected_backlog_columns:
assert expected_col in backlog_column_names

31
tests/tools_test.py Normal file
View file

@ -0,0 +1,31 @@
from corrlib import tools as tl
def test_m2k():
for m in [0.1, 0.5, 1.0]:
expected_k = 1 / (2 * m + 8)
assert tl.m2k(m) == expected_k
def test_k2m():
for m in [0.1, 0.5, 1.0]:
assert tl.k2m(m) == (1/(2*m))-4
def test_k2m_m2k():
for m in [0.1, 0.5, 1.0]:
k = tl.m2k(m)
m_converted = tl.k2m(k)
assert abs(m - m_converted) < 1e-9
def test_str2list():
assert tl.str2list("a,b,c") == ["a", "b", "c"]
assert tl.str2list("1,2,3") == ["1", "2", "3"]
def test_list2str():
assert tl.list2str(["a", "b", "c"]) == "a,b,c"
assert tl.list2str(["1", "2", "3"]) == "1,2,3"

2518
uv.lock generated Normal file

File diff suppressed because it is too large