diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index fbd51ec..791243f 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -8,21 +8,22 @@ on: jobs: mypy: runs-on: ubuntu-latest + env: + UV_CACHE_DIR: /tmp/.uv-cache steps: - name: Install git-annex run: | sudo apt-get update - sudo apt-get install -y git-annex + sudo apt-get install -y git-annex - name: Check out the repository uses: https://github.com/RouxAntoine/checkout@v4.1.8 with: show-progress: true - - name: Setup python - uses: https://github.com/actions/setup-python@v5 - with: - python-version: "3.12" - name: Install uv - uses: https://github.com/astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@v7 + with: + python-version: ${{ matrix.python-version }} + enable-cache: true - name: Install corrlib run: uv sync --locked --all-extras --dev --python "3.12" - name: Run tests diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index b1a4d94..1fcb8fe 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -17,11 +17,9 @@ jobs: - "3.14" runs-on: ubuntu-latest + env: + UV_CACHE_DIR: /tmp/.uv-cache steps: - - name: Setup git - run: | - git config --global user.email "tester@example.com" - git config --global user.name "Tester" - name: Install git-annex run: | sudo apt-get update @@ -30,12 +28,11 @@ jobs: uses: https://github.com/RouxAntoine/checkout@v4.1.8 with: show-progress: true - - name: Setup python - uses: https://github.com/actions/setup-python@v5 + - name: Install uv + uses: astral-sh/setup-uv@v7 with: python-version: ${{ matrix.python-version }} - - name: Install uv - uses: https://github.com/astral-sh/setup-uv@v5 + enable-cache: true - name: Install corrlib run: uv sync --locked --all-extras --dev --python ${{ matrix.python-version }} - name: Run tests diff --git a/.github/workflows/ruff.yaml b/.github/workflows/ruff.yaml index 1da1225..4de4b0b 100644 --- a/.github/workflows/ruff.yaml +++ b/.github/workflows/ruff.yaml @@ -9,6 +9,8 @@ jobs: ruff: runs-on: ubuntu-latest + env: + UV_CACHE_DIR: /tmp/.uv-cache steps: - name: Install git-annex run: | @@ -18,12 +20,10 @@ jobs: uses: https://github.com/RouxAntoine/checkout@v4.1.8 with: show-progress: true - - name: Setup python - uses: https://github.com/actions/setup-python@v5 - with: - python-version: "3.12" - name: Install uv - uses: https://github.com/astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@v7 + with: + enable-cache: true - name: Install corrlib run: uv sync --locked --all-extras --dev --python "3.12" - name: Run tests diff --git a/.gitignore b/.gitignore index f97ff98..22957fb 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,8 @@ pyerrors_corrlib.egg-info __pycache__ *.egg-info test.ipynb +test_ds .vscode .venv .pytest_cache -.coverage \ No newline at end of file +.coverage diff --git a/README.md b/README.md index 976ae57..0f6c9a3 100644 --- a/README.md +++ b/README.md @@ -5,3 +5,12 @@ This is done in a reproducible way using `datalad`. In principle, a dataset is created, that is automatically administered by the backlogger, in which data from differnt projects are held together. Everything is catalogued by a searchable SQL database, which holds the paths to the respective measurements. The original projects can be linked to the dataset and the data may be imported using wrapper functions around the read methonds of pyerrors. + +We work with the following nomenclature in this project: +- Measurement + A setis of Observables, including the appropriate metadata. +- Project + A series of measurements that was done by one person as part of their research. +- Record + An entry of a single Correlator in the database of the backlogger. +- \ No newline at end of file diff --git a/TODO.md b/TODO.md index 4153fc3..ba32ec9 100644 --- a/TODO.md +++ b/TODO.md @@ -1,14 +1,21 @@ # TODO ## Features -- implement import of non-datalad projects -- implement a way to use another backlog repo as a project - -- find a way to convey the mathematical structure of what EXACTLY is the form of the correlator in a specific project - - this could e.g. be done along the lines of mandatory documentation -- keep better track of the versions of the code, that was used for a specific measurement. - - maybe let this be an input in the project file? - - git repo and commit hash/version tag - +- [ ] implement import of non-datalad projects +- [ ] implement a way to use another backlog repo as a project +- [ ] make cache deadlock resistent (no read while writing) +- [ ] find a way to convey the mathematical structure of what EXACTLY is the form of the correlator in a specific project + - [ ] this could e.g. be done along the lines of mandatory documentation +- [ ] keep better track of the versions of the code, that was used for a specific measurement. + - [ ] maybe let this be an input in the project file? + - [ ] git repo and commit hash/version tag + - [ ] implement a code table? +- [ ] parallel processing of measurements +- [ ] extra SQL table for ensembles with UUID and aliases ## Bugfixes - [ ] revisit the reimport function for single files +- [ ] drop record needs to look if no records are left in a json file. + +## Rough Ideas +- [ ] multitable could provide a high speed implementation of an HDF5 based format +- [ ] implement also a way to include compiled binaries in the archives. diff --git a/corrlib/__init__.py b/corrlib/__init__.py index 4e1b364..448b4d5 100644 --- a/corrlib/__init__.py +++ b/corrlib/__init__.py @@ -22,3 +22,4 @@ from .meas_io import load_records as load_records from .find import find_project as find_project from .find import find_record as find_record from .find import list_projects as list_projects +from .tools import * diff --git a/corrlib/cache_io.py b/corrlib/cache_io.py new file mode 100644 index 0000000..63d2e68 --- /dev/null +++ b/corrlib/cache_io.py @@ -0,0 +1,58 @@ +from typing import Optional +import os +import shutil +from .tools import record2name_key +import datalad.api as dl +import sqlite3 +from tools import db_filename + + +def get_version_hash(path: str, record: str) -> str: + db = os.path.join(path, db_filename(path)) + dl.get(db, dataset=path) + conn = sqlite3.connect(db) + c = conn.cursor() + c.execute(f"SELECT current_version FROM 'backlogs' WHERE path = '{record}'") + return str(c.fetchall()[0][0]) + + +def drop_cache_files(path: str, fs: Optional[list[str]]=None) -> None: + cache_dir = os.path.join(path, ".cache") + if fs is None: + fs = os.listdir(cache_dir) + for f in fs: + shutil.rmtree(os.path.join(cache_dir, f)) + + +def cache_dir(path: str, file: str) -> str: + cache_path_list = [path] + cache_path_list.append(".cache") + cache_path_list.extend(file.split("/")[1:]) + cache_path = "/".join(cache_path_list) + return cache_path + + +def cache_path(path: str, file: str, sha_hash: str, key: str) -> str: + cache_path = os.path.join(cache_dir(path, file), key + "_" + sha_hash) + return cache_path + + +def is_old_version(path: str, record: str) -> bool: + version_hash = get_version_hash(path, record) + file, key = record2name_key(record) + meas_cache_path = os.path.join(cache_dir(path, file)) + ls = [] + is_old = True + for p, ds, fs in os.walk(meas_cache_path): + ls.extend(fs) + for filename in ls: + if key == filename.split("_")[0]: + if version_hash == filename.split("_")[1][:-2]: + is_old = False + return is_old + + +def is_in_cache(path: str, record: str) -> bool: + version_hash = get_version_hash(path, record) + file, key = record2name_key(record) + return os.path.exists(cache_path(path, file, version_hash, key) + ".p") diff --git a/corrlib/cli.py b/corrlib/cli.py index d24d8ef..c4e1e4b 100644 --- a/corrlib/cli.py +++ b/corrlib/cli.py @@ -1,20 +1,14 @@ from typing import Optional import typer from corrlib import __app_name__ - from .initialization import create from .toml import import_tomls, update_project, reimport_project from .find import find_record, list_projects from .tools import str2list from .main import update_aliases -from .meas_io import drop_cache as mio_drop_cache -from .meas_io import load_record as mio_load_record -from .integrity import full_integrity_check - +from .cache_io import drop_cache_files as cio_drop_cache_files import os -from pyerrors import Corr from importlib.metadata import version -from pathlib import Path app = typer.Typer() @@ -28,8 +22,8 @@ def _version_callback(value: bool) -> None: @app.command() def update( - path: Path = typer.Option( - Path('./corrlib'), + path: str = typer.Option( + str('./corrlib'), "--dataset", "-d", ), @@ -41,11 +35,10 @@ def update( update_project(path, uuid) return - @app.command() -def lister( - path: Path = typer.Option( - Path('./corrlib'), +def list( + path: str = typer.Option( + str('./corrlib'), "--dataset", "-d", ), @@ -56,8 +49,8 @@ def lister( """ if entities in ['ensembles', 'Ensembles','ENSEMBLES']: print("Ensembles:") - for item in os.listdir(path / "archive"): - if os.path.isdir(path / "archive" / item): + for item in os.listdir(path + "/archive"): + if os.path.isdir(os.path.join(path + "/archive", item)): print(item) elif entities == 'projects': results = list_projects(path) @@ -75,8 +68,8 @@ def lister( @app.command() def alias_add( - path: Path = typer.Option( - Path('./corrlib'), + path: str = typer.Option( + str('./corrlib'), "--dataset", "-d", ), @@ -93,67 +86,26 @@ def alias_add( @app.command() def find( - path: Path = typer.Option( - Path('./corrlib'), + path: str = typer.Option( + str('./corrlib'), "--dataset", "-d", ), ensemble: str = typer.Argument(), corr: str = typer.Argument(), code: str = typer.Argument(), - arg: str = typer.Option( - str('all'), - "--argument", - "-a", - ), ) -> None: """ Find a record in the backlog at hand. Through specifying it's ensemble and the measured correlator. """ results = find_record(path, ensemble, corr, code) - if results.empty: - return - if arg == 'all': - print(results) - else: - for r in results[arg].values: - print(r) - - -@app.command() -def stat( - path: Path = typer.Option( - Path('./corrlib'), - "--dataset", - "-d", - ), - record_id: str = typer.Argument(), - ) -> None: - """ - Show the statistics of a given record. - """ - record = mio_load_record(path, record_id) - if isinstance(record, (list, Corr)): - record = record[0] - statistics = record.idl - print(statistics) - return - - -@app.command() -def check(path: Path = typer.Option( - Path('./corrlib'), - "--dataset", - "-d", - ), - ) -> None: - full_integrity_check(path) + print(results) @app.command() def importer( - path: Path = typer.Option( - Path('./corrlib'), + path: str = typer.Option( + str('./corrlib'), "--dataset", "-d", ), @@ -175,8 +127,8 @@ def importer( @app.command() def reimporter( - path: Path = typer.Option( - Path('./corrlib'), + path: str = typer.Option( + str('./corrlib'), "--dataset", "-d", ), @@ -199,8 +151,8 @@ def reimporter( @app.command() def init( - path: Path = typer.Option( - Path('./corrlib'), + path: str = typer.Option( + str('./corrlib'), "--dataset", "-d", ), @@ -219,8 +171,8 @@ def init( @app.command() def drop_cache( - path: Path = typer.Option( - Path('./corrlib'), + path: str = typer.Option( + str('./corrlib'), "--dataset", "-d", ), @@ -228,7 +180,7 @@ def drop_cache( """ Drop the currect cache directory of the dataset. """ - mio_drop_cache(path) + cio_drop_cache_files(path) return diff --git a/corrlib/find.py b/corrlib/find.py index 7b07321..5d0a678 100644 --- a/corrlib/find.py +++ b/corrlib/find.py @@ -4,17 +4,12 @@ import json import pandas as pd import numpy as np from .input.implementations import codes -from .tools import k2m, get_db_file +from .tools import k2m, db_filename from .tracker import get -from .integrity import has_valid_times from typing import Any, Optional -from pathlib import Path -import datetime as dt -from collections.abc import Callable -import warnings -def _project_lookup_by_alias(db: Path, alias: str) -> str: +def _project_lookup_by_alias(db: str, alias: str) -> str: """ Lookup a projects UUID by its (human-readable) alias. @@ -32,7 +27,7 @@ def _project_lookup_by_alias(db: Path, alias: str) -> str: """ conn = sqlite3.connect(db) c = conn.cursor() - c.execute(f"SELECT * FROM 'projects' WHERE aliases = '{alias}'") + c.execute(f"SELECT * FROM 'projects' WHERE alias = '{alias}'") results = c.fetchall() conn.close() if len(results)>1: @@ -42,7 +37,7 @@ def _project_lookup_by_alias(db: Path, alias: str) -> str: return str(results[0][0]) -def _project_lookup_by_id(db: Path, uuid: str) -> list[tuple[str, ...]]: +def _project_lookup_by_id(db: str, uuid: str) -> list[tuple[str, str]]: """ Return the project information available in the database by UUID. @@ -66,56 +61,8 @@ def _project_lookup_by_id(db: Path, uuid: str) -> list[tuple[str, ...]]: return results -def _time_filter(results: pd.DataFrame, created_before: Optional[str]=None, created_after: Optional[Any]=None, updated_before: Optional[Any]=None, updated_after: Optional[Any]=None) -> pd.DataFrame: - """ - Filter the results from the database in terms of the creation and update times. - - Parameters - ---------- - results: pd.DataFrame - The dataframe holding the unfilteres results from the database. - created_before: str - Contraint on the creation date in datetime.datetime.isoformat. Note that this is exclusive. The creation date has to be truly before the date and time given. - created_after: str - Contraint on the creation date in datetime.datetime.isoformat. Note that this is exclusive. The creation date has to be truly after the date and time given. - updated_before: str - Contraint on the creation date in datetime.datetime.isoformat. Note that this is exclusive. The date of the last update has to be truly before the date and time given. - updated_after: str - Contraint on the creation date in datetime.datetime.isoformat. Note that this is exclusive. The date of the last update has to be truly after the date and time given. - """ - drops = [] - for ind in range(len(results)): - result = results.iloc[ind] - created_at = dt.datetime.fromisoformat(result['created_at']) - updated_at = dt.datetime.fromisoformat(result['updated_at']) - db_times_valid = has_valid_times(result) - if not db_times_valid: - raise ValueError('Time stamps not valid for result with path', result["path"]) - - if created_before is not None: - date_created_before = dt.datetime.fromisoformat(created_before) - if date_created_before < created_at: - drops.append(ind) - continue - if created_after is not None: - date_created_after = dt.datetime.fromisoformat(created_after) - if date_created_after > created_at: - drops.append(ind) - continue - if updated_before is not None: - date_updated_before = dt.datetime.fromisoformat(updated_before) - if date_updated_before < updated_at: - drops.append(ind) - continue - if updated_after is not None: - date_updated_after = dt.datetime.fromisoformat(updated_after) - if date_updated_after > updated_at: - drops.append(ind) - continue - return results.drop(drops) - - -def _db_lookup(db: Path, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None) -> pd.DataFrame: +def _db_lookup(db: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None, + created_before: Optional[str]=None, created_after: Optional[Any]=None, updated_before: Optional[Any]=None, updated_after: Optional[Any]=None) -> pd.DataFrame: """ Look up a correlator record in the database by the data given to the method. @@ -157,86 +104,22 @@ def _db_lookup(db: Path, ensemble: str, correlator_name: str, code: str, project search_expr += f" AND code = '{code}'" if parameters: search_expr += f" AND parameters = '{parameters}'" + if created_before: + search_expr += f" AND created_at < '{created_before}'" + if created_after: + search_expr += f" AND created_at > '{created_after}'" + if updated_before: + search_expr += f" AND updated_at < '{updated_before}'" + if updated_after: + search_expr += f" AND updated_at > '{updated_after}'" conn = sqlite3.connect(db) results = pd.read_sql(search_expr, conn) conn.close() return results -def _sfcf_drop(param: dict[str, Any], **kwargs: Any) -> bool: - if 'offset' in kwargs: - if kwargs.get('offset') != param['offset']: - return True - if 'quark_kappas' in kwargs: - kappas = kwargs['quark_kappas'] - if (not np.isclose(kappas[0], param['quarks'][0]['mass']) or not np.isclose(kappas[1], param['quarks'][1]['mass'])): - return True - if 'quark_masses' in kwargs: - masses = kwargs['quark_masses'] - if (not np.isclose(masses[0], k2m(param['quarks'][0]['mass'])) or not np.isclose(masses[1], k2m(param['quarks'][1]['mass']))): - return True - if 'qk1' in kwargs: - quark_kappa1 = kwargs['qk1'] - if not isinstance(quark_kappa1, list): - if (not np.isclose(quark_kappa1, param['quarks'][0]['mass'])): - return True - else: - if len(quark_kappa1) == 2: - if (quark_kappa1[0] > param['quarks'][0]['mass']) or (quark_kappa1[1] < param['quarks'][0]['mass']): - return True - else: - raise ValueError("quark_kappa1 has to have length 2") - if 'qk2' in kwargs: - quark_kappa2 = kwargs['qk2'] - if not isinstance(quark_kappa2, list): - if (not np.isclose(quark_kappa2, param['quarks'][1]['mass'])): - return True - else: - if len(quark_kappa2) == 2: - if (quark_kappa2[0] > param['quarks'][1]['mass']) or (quark_kappa2[1] < param['quarks'][1]['mass']): - return True - else: - raise ValueError("quark_kappa2 has to have length 2") - if 'qm1' in kwargs: - quark_mass1 = kwargs['qm1'] - if not isinstance(quark_mass1, list): - if (not np.isclose(quark_mass1, k2m(param['quarks'][0]['mass']))): - return True - else: - if len(quark_mass1) == 2: - if (quark_mass1[0] > k2m(param['quarks'][0]['mass'])) or (quark_mass1[1] < k2m(param['quarks'][0]['mass'])): - return True - else: - raise ValueError("quark_mass1 has to have length 2") - if 'qm2' in kwargs: - quark_mass2 = kwargs['qm2'] - if not isinstance(quark_mass2, list): - if (not np.isclose(quark_mass2, k2m(param['quarks'][1]['mass']))): - return True - else: - if len(quark_mass2) == 2: - if (quark_mass2[0] > k2m(param['quarks'][1]['mass'])) or (quark_mass2[1] < k2m(param['quarks'][1]['mass'])): - return True - else: - raise ValueError("quark_mass2 has to have length 2") - if 'quark_thetas' in kwargs: - quark_thetas = kwargs['quark_thetas'] - if (quark_thetas[0] != param['quarks'][0]['thetas'] and quark_thetas[1] != param['quarks'][1]['thetas']) or (quark_thetas[0] != param['quarks'][1]['thetas'] and quark_thetas[1] != param['quarks'][0]['thetas']): - return True - # careful, this is not save, when multiple contributions are present! - if 'wf1' in kwargs: - wf1 = kwargs['wf1'] - if not (np.isclose(wf1[0][0], param['wf1'][0][0], 1e-8) and np.isclose(wf1[0][1][0], param['wf1'][0][1][0], 1e-8) and np.isclose(wf1[0][1][1], param['wf1'][0][1][1], 1e-8)): - return True - if 'wf2' in kwargs: - wf2 = kwargs['wf2'] - if not (np.isclose(wf2[0][0], param['wf2'][0][0], 1e-8) and np.isclose(wf2[0][1][0], param['wf2'][0][1][0], 1e-8) and np.isclose(wf2[0][1][1], param['wf2'][0][1][1], 1e-8)): - return True - return False - - def sfcf_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: - r""" + """ Filter method for the Database entries holding SFCF calculations. Parameters @@ -252,9 +135,9 @@ def sfcf_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: qk2: float, optional Mass parameter $\kappa_2$ of the first quark. qm1: float, optional - Bare quark mass $m_1$ of the first quark. + Bare quak mass $m_1$ of the first quark. qm2: float, optional - Bare quark mass $m_2$ of the first quark. + Bare quak mass $m_1$ of the first quark. quarks_thetas: list[list[float]], optional wf1: optional wf2: optional @@ -264,85 +147,106 @@ def sfcf_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: results: pd.DataFrame The filtered DataFrame, only holding the records that fit to the parameters given. """ - drops = [] for ind in range(len(results)): result = results.iloc[ind] param = json.loads(result['parameters']) - if _sfcf_drop(param, **kwargs): - drops.append(ind) + if 'offset' in kwargs: + if kwargs.get('offset') != param['offset']: + drops.append(ind) + continue + if 'quark_kappas' in kwargs: + kappas = kwargs['quark_kappas'] + if (not np.isclose(kappas[0], param['quarks'][0]['mass']) or not np.isclose(kappas[1], param['quarks'][1]['mass'])): + drops.append(ind) + continue + if 'quark_masses' in kwargs: + masses = kwargs['quark_masses'] + if (not np.isclose(masses[0], k2m(param['quarks'][0]['mass'])) or not np.isclose(masses[1], k2m(param['quarks'][1]['mass']))): + drops.append(ind) + continue + if 'qk1' in kwargs: + quark_kappa1 = kwargs['qk1'] + if not isinstance(quark_kappa1, list): + if (not np.isclose(quark_kappa1, param['quarks'][0]['mass'])): + drops.append(ind) + continue + else: + if len(quark_kappa1) == 2: + if (quark_kappa1[0] > param['quarks'][0]['mass']) or (quark_kappa1[1] < param['quarks'][0]['mass']): + drops.append(ind) + continue + if 'qk2' in kwargs: + quark_kappa2 = kwargs['qk2'] + if not isinstance(quark_kappa2, list): + if (not np.isclose(quark_kappa2, param['quarks'][1]['mass'])): + drops.append(ind) + continue + else: + if len(quark_kappa2) == 2: + if (quark_kappa2[0] > param['quarks'][1]['mass']) or (quark_kappa2[1] < param['quarks'][1]['mass']): + drops.append(ind) + continue + if 'qm1' in kwargs: + quark_mass1 = kwargs['qm1'] + if not isinstance(quark_mass1, list): + if (not np.isclose(quark_mass1, k2m(param['quarks'][0]['mass']))): + drops.append(ind) + continue + else: + if len(quark_mass1) == 2: + if (quark_mass1[0] > k2m(param['quarks'][0]['mass'])) or (quark_mass1[1] < k2m(param['quarks'][0]['mass'])): + drops.append(ind) + continue + if 'qm2' in kwargs: + quark_mass2 = kwargs['qm2'] + if not isinstance(quark_mass2, list): + if (not np.isclose(quark_mass2, k2m(param['quarks'][1]['mass']))): + drops.append(ind) + continue + else: + if len(quark_mass2) == 2: + if (quark_mass2[0] > k2m(param['quarks'][1]['mass'])) or (quark_mass2[1] < k2m(param['quarks'][1]['mass'])): + drops.append(ind) + continue + if 'quark_thetas' in kwargs: + quark_thetas = kwargs['quark_thetas'] + if (quark_thetas[0] != param['quarks'][0]['thetas'] and quark_thetas[1] != param['quarks'][1]['thetas']) or (quark_thetas[0] != param['quarks'][1]['thetas'] and quark_thetas[1] != param['quarks'][0]['thetas']): + drops.append(ind) + continue + # careful, this is not save, when multiple contributions are present! + if 'wf1' in kwargs: + wf1 = kwargs['wf1'] + if not (np.isclose(wf1[0][0], param['wf1'][0][0], 1e-8) and np.isclose(wf1[0][1][0], param['wf1'][0][1][0], 1e-8) and np.isclose(wf1[0][1][1], param['wf1'][0][1][1], 1e-8)): + drops.append(ind) + continue + if 'wf2' in kwargs: + wf2 = kwargs['wf2'] + if not (np.isclose(wf2[0][0], param['wf2'][0][0], 1e-8) and np.isclose(wf2[0][1][0], param['wf2'][0][1][0], 1e-8) and np.isclose(wf2[0][1][1], param['wf2'][0][1][1], 1e-8)): + drops.append(ind) + continue return results.drop(drops) -def openQCD_filter(results:pd.DataFrame, **kwargs: Any) -> pd.DataFrame: - """ - Filter for parameters of openQCD. - - Parameters - ---------- - results: pd.DataFrame - The unfiltered list of results from the database. - - Returns - ------- - results: pd.DataFrame - The filtered results. - - """ - warnings.warn("A filter for openQCD parameters is no implemented yet.", Warning) - - return results - - -def _code_filter(results: pd.DataFrame, code: str, **kwargs: Any) -> pd.DataFrame: - """ - Abstraction of the filters for the different codes that are available. - At the moment, only openQCD and SFCF are known. - The possible key words for the parameters can be seen in the descriptionso f the code-specific filters. - - Parameters - ---------- - results: pd.DataFrame - The unfiltered list of results from the database. - code: str - The name of the code that produced the record at hand. - kwargs: - The keyworkd args that are handed over to the code-specific filters. - - Returns - ------- - results: pd.DataFrame - The filtered results. - """ - if code == "sfcf": - return sfcf_filter(results, **kwargs) - elif code == "openQCD": - return openQCD_filter(results, **kwargs) - else: - raise ValueError(f"Code {code} is not known.") - - -def find_record(path: Path, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None, - created_before: Optional[str]=None, created_after: Optional[str]=None, updated_before: Optional[str]=None, updated_after: Optional[str]=None, - revision: Optional[str]=None, - customFilter: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None, - **kwargs: Any) -> pd.DataFrame: - db_file = get_db_file(path) - db = path / db_file +def find_record(path: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None, + created_before: Optional[str]=None, created_after: Optional[str]=None, updated_before: Optional[str]=None, updated_after: Optional[str]=None, revision: Optional[str]=None, **kwargs: Any) -> pd.DataFrame: + db_file = db_filename(path) + db = os.path.join(path, db_file) if code not in codes: raise ValueError("Code " + code + "unknown, take one of the following:" + ", ".join(codes)) get(path, db_file) - results = _db_lookup(db, ensemble, correlator_name,code, project, parameters=parameters) - if any([arg is not None for arg in [created_before, created_after, updated_before, updated_after]]): - results = _time_filter(results, created_before, created_after, updated_before, updated_after) - results = _code_filter(results, code, **kwargs) - if customFilter is not None: - results = customFilter(results) + results = _db_lookup(db, ensemble, correlator_name,code, project, parameters=parameters, created_before=created_before, created_after=created_after, updated_before=updated_before, updated_after=updated_after) + if code == "sfcf": + results = sfcf_filter(results, **kwargs) + elif code == "openQCD": + pass + else: + raise Exception print("Found " + str(len(results)) + " result" + ("s" if len(results)>1 else "")) return results.reset_index() -def find_project(path: Path, name: str) -> str: +def find_project(path: str, name: str) -> str: """ Find a project by it's human readable name. @@ -358,12 +262,12 @@ def find_project(path: Path, name: str) -> str: uuid: str The uuid of the project in question. """ - db_file = get_db_file(path) + db_file = db_filename(path) get(path, db_file) - return _project_lookup_by_alias(path / db_file, name) + return _project_lookup_by_alias(os.path.join(path, db_file), name) -def list_projects(path: Path) -> list[tuple[str, str]]: +def list_projects(path: str) -> list[tuple[str, str]]: """ List all projects known to the library. @@ -377,7 +281,7 @@ def list_projects(path: Path) -> list[tuple[str, str]]: results: list[Any] The projects known to the library. """ - db_file = get_db_file(path) + db_file = db_filename(path) get(path, db_file) conn = sqlite3.connect(os.path.join(path, db_file)) c = conn.cursor() diff --git a/corrlib/git_tools.py b/corrlib/git_tools.py index d77f109..c6e7522 100644 --- a/corrlib/git_tools.py +++ b/corrlib/git_tools.py @@ -1,28 +1,27 @@ import os from .tracker import save import git -from pathlib import Path GITMODULES_FILE = '.gitmodules' -def move_submodule(repo_path: Path, old_path: Path, new_path: Path) -> None: +def move_submodule(repo_path: str, old_path: str, new_path: str) -> None: """ Move a submodule to a new location. Parameters ---------- - repo_path: Path + repo_path: str Path to the repository. - old_path: Path + old_path: str The old path of the module. - new_path: Path + new_path: str The new path of the module. """ - os.rename(repo_path / old_path, repo_path / new_path) + os.rename(os.path.join(repo_path, old_path), os.path.join(repo_path, new_path)) - gitmodules_file_path = repo_path / GITMODULES_FILE + gitmodules_file_path = os.path.join(repo_path, GITMODULES_FILE) # update paths in .gitmodules with open(gitmodules_file_path, 'r') as file: @@ -30,8 +29,8 @@ def move_submodule(repo_path: Path, old_path: Path, new_path: Path) -> None: updated_lines = [] for line in lines: - if str(old_path) in line: - line = line.replace(str(old_path), str(new_path)) + if old_path in line: + line = line.replace(old_path, new_path) updated_lines.append(line) with open(gitmodules_file_path, 'w') as file: @@ -41,6 +40,6 @@ def move_submodule(repo_path: Path, old_path: Path, new_path: Path) -> None: repo = git.Repo(repo_path) repo.git.add('.gitmodules') # save new state of the dataset - save(repo_path, message=f"Move module from {old_path} to {new_path}", files=[Path('.gitmodules'), repo_path]) + save(repo_path, message=f"Move module from {old_path} to {new_path}", files=['.gitmodules', repo_path]) return diff --git a/corrlib/initialization.py b/corrlib/initialization.py index c06a201..0b7be48 100644 --- a/corrlib/initialization.py +++ b/corrlib/initialization.py @@ -2,10 +2,9 @@ from configparser import ConfigParser import sqlite3 import os from .tracker import save, init -from pathlib import Path -def _create_db(db: Path) -> None: +def _create_db(db: str) -> None: """ Create the database file and the table. @@ -27,7 +26,8 @@ def _create_db(db: Path) -> None: parameters TEXT, parameter_file TEXT, created_at TEXT, - updated_at TEXT)''') + updated_at TEXT, + current_version TEXT)''') c.execute('''CREATE TABLE IF NOT EXISTS projects (id TEXT PRIMARY KEY, aliases TEXT, @@ -41,7 +41,7 @@ def _create_db(db: Path) -> None: return -def _create_config(path: Path, tracker: str, cached: bool) -> ConfigParser: +def _create_config(path: str, tracker: str, cached: bool) -> ConfigParser: """ Create the config file construction for backlogger. @@ -72,11 +72,12 @@ def _create_config(path: Path, tracker: str, cached: bool) -> ConfigParser: 'archive_path': 'archive', 'toml_imports_path': 'toml_imports', 'import_scripts_path': 'import_scripts', + 'cache_path': '.cache', } return config -def _write_config(path: Path, config: ConfigParser) -> None: +def _write_config(path: str, config: ConfigParser) -> None: """ Write the config file to disk. @@ -92,7 +93,7 @@ def _write_config(path: Path, config: ConfigParser) -> None: return -def create(path: Path, tracker: str = 'datalad', cached: bool = True) -> None: +def create(path: str, tracker: str = 'datalad', cached: bool = True) -> None: """ Create folder of backlogs. @@ -108,13 +109,13 @@ def create(path: Path, tracker: str = 'datalad', cached: bool = True) -> None: config = _create_config(path, tracker, cached) init(path, tracker) _write_config(path, config) - _create_db(path / config['paths']['db']) - os.chmod(path / config['paths']['db'], 0o666) - os.makedirs(path / config['paths']['projects_path']) - os.makedirs(path / config['paths']['archive_path']) - os.makedirs(path / config['paths']['toml_imports_path']) - os.makedirs(path / config['paths']['import_scripts_path'] / 'template.py') - with open(path / ".gitignore", "w") as fp: + _create_db(os.path.join(path, config['paths']['db'])) + os.chmod(os.path.join(path, config['paths']['db']), 0o666) + os.makedirs(os.path.join(path, config['paths']['projects_path'])) + os.makedirs(os.path.join(path, config['paths']['archive_path'])) + os.makedirs(os.path.join(path, config['paths']['toml_imports_path'])) + os.makedirs(os.path.join(path, config['paths']['import_scripts_path'], 'template.py')) + with open(os.path.join(path, ".gitignore"), "w") as fp: fp.write(".cache") fp.close() save(path, message="Initialized correlator library") diff --git a/corrlib/input/openQCD.py b/corrlib/input/openQCD.py index 879b555..71ebec6 100644 --- a/corrlib/input/openQCD.py +++ b/corrlib/input/openQCD.py @@ -3,13 +3,9 @@ import datalad.api as dl import os import fnmatch from typing import Any, Optional -from pathlib import Path -from ..pars.openQCD import ms1 -from ..pars.openQCD import qcd2 - -def load_ms1_infile(path: Path, project: str, file_in_project: str) -> dict[str, Any]: +def read_ms1_param(path: str, project: str, file_in_project: str) -> dict[str, Any]: """ Read the parameters for ms1 measurements from a parameter file in the project. @@ -73,7 +69,7 @@ def load_ms1_infile(path: Path, project: str, file_in_project: str) -> dict[str, return param -def load_ms3_infile(path: Path, project: str, file_in_project: str) -> dict[str, Any]: +def read_ms3_param(path: str, project: str, file_in_project: str) -> dict[str, Any]: """ Read the parameters for ms3 measurements from a parameter file in the project. @@ -107,7 +103,7 @@ def load_ms3_infile(path: Path, project: str, file_in_project: str) -> dict[str, return param -def read_rwms(path: Path, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, postfix: str="ms1", version: str='2.0', names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: +def read_rwms(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, postfix: str="ms1", version: str='2.0', names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: """ Read reweighting factor measurements from the project. @@ -164,7 +160,7 @@ def read_rwms(path: Path, project: str, dir_in_project: str, param: dict[str, An return rw_dict -def extract_t0(path: Path, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str="", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: +def extract_t0(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str="", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: """ Extract t0 measurements from the project. @@ -238,7 +234,7 @@ def extract_t0(path: Path, project: str, dir_in_project: str, param: dict[str, A return t0_dict -def extract_t1(path: Path, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str = "", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: +def extract_t1(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str = "", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: """ Extract t1 measurements from the project. @@ -307,51 +303,3 @@ def extract_t1(path: Path, project: str, dir_in_project: str, param: dict[str, A t1_dict[param["type"]] = {} t1_dict[param["type"]][pars] = t0 return t1_dict - - -def load_qcd2_pars(path: Path, project: str, file_in_project: str) -> dict[str, Any]: - """ - Thin wrapper around read_qcd2_par_file, getting the file before reading. - - Parameters - ---------- - path: Path - Path of the corrlib repository. - project: str - UUID of the project of the parameter-file. - file_in_project: str - The loaction of the file in the project directory. - - Returns - ------- - par_dict: dict - The dict with the parameters read from the .par-file. - """ - fname = path / "projects" / project / file_in_project - ds = os.path.join(path, "projects", project) - dl.get(fname, dataset=ds) - return qcd2.read_qcd2_par_file(fname) - - -def load_ms1_parfile(path: Path, project: str, file_in_project: str) -> dict[str, Any]: - """ - Thin wrapper around read_qcd2_ms1_par_file, getting the file before reading. - - Parameters - ---------- - path: Path - Path of the corrlib repository. - project: str - UUID of the project of the parameter-file. - file_in_project: str - The loaction of the file in the project directory. - - Returns - ------- - par_dict: dict - The dict with the parameters read from the .par-file. - """ - fname = path / "projects" / project / file_in_project - ds = os.path.join(path, "projects", project) - dl.get(fname, dataset=ds) - return ms1.read_qcd2_ms1_par_file(fname) diff --git a/corrlib/input/sfcf.py b/corrlib/input/sfcf.py index acd8261..6a75b72 100644 --- a/corrlib/input/sfcf.py +++ b/corrlib/input/sfcf.py @@ -3,8 +3,6 @@ import datalad.api as dl import json import os from typing import Any -from fnmatch import fnmatch -from pathlib import Path bi_corrs: list[str] = ["f_P", "fP", "f_p", @@ -81,7 +79,7 @@ for c in bib_corrs: corr_types[c] = 'bib' -def read_param(path: Path, project: str, file_in_project: str) -> dict[str, Any]: +def read_param(path: str, project: str, file_in_project: str) -> dict[str, Any]: """ Read the parameters from the sfcf file. @@ -97,7 +95,7 @@ def read_param(path: Path, project: str, file_in_project: str) -> dict[str, Any] """ - file = path / "projects" / project / file_in_project + file = path + "/projects/" + project + '/' + file_in_project dl.get(file, dataset=path) with open(file, 'r') as f: lines = f.readlines() @@ -258,7 +256,7 @@ def get_specs(key: str, parameters: dict[str, Any], sep: str = '/') -> str: return s -def read_data(path: Path, project: str, dir_in_project: str, prefix: str, param: dict[str, Any], version: str = '1.0c', cfg_seperator: str = 'n', sep: str = '/', **kwargs: Any) -> dict[str, Any]: +def read_data(path: str, project: str, dir_in_project: str, prefix: str, param: dict[str, Any], version: str = '1.0c', cfg_seperator: str = 'n', sep: str = '/', **kwargs: Any) -> dict[str, Any]: """ Extract the data from the sfcf file. @@ -300,10 +298,9 @@ def read_data(path: Path, project: str, dir_in_project: str, prefix: str, param: if not appended: compact = (version[-1] == "c") for i, item in enumerate(ls): - if fnmatch(item, prefix + "*"): - rep_path = directory + '/' + item - sub_ls = pe.input.sfcf._find_files(rep_path, prefix, compact, []) - files_to_get.extend([rep_path + "/" + filename for filename in sub_ls]) + rep_path = directory + '/' + item + sub_ls = pe.input.sfcf._find_files(rep_path, prefix, compact, []) + files_to_get.extend([rep_path + "/" + filename for filename in sub_ls]) print("Getting data, this might take a while...") @@ -321,10 +318,10 @@ def read_data(path: Path, project: str, dir_in_project: str, prefix: str, param: if not param['crr'] == []: if names is not None: data_crr = pe.input.sfcf.read_sfcf_multi(directory, prefix, param['crr'], param['mrr'], corr_type_list, range(len(param['wf_offsets'])), - range(len(param['wf_basis'])), range(len(param['wf_basis'])), version, cfg_seperator, keyed_out=True, silent=True, names=names) + range(len(param['wf_basis'])), range(len(param['wf_basis'])), version, cfg_seperator, keyed_out=True, names=names) else: data_crr = pe.input.sfcf.read_sfcf_multi(directory, prefix, param['crr'], param['mrr'], corr_type_list, range(len(param['wf_offsets'])), - range(len(param['wf_basis'])), range(len(param['wf_basis'])), version, cfg_seperator, keyed_out=True, silent=True) + range(len(param['wf_basis'])), range(len(param['wf_basis'])), version, cfg_seperator, keyed_out=True) for key in data_crr.keys(): data[key] = data_crr[key] diff --git a/corrlib/integrity.py b/corrlib/integrity.py deleted file mode 100644 index 5f80aa3..0000000 --- a/corrlib/integrity.py +++ /dev/null @@ -1,87 +0,0 @@ -import datetime as dt -from pathlib import Path -from .tools import get_db_file -import pandas as pd -import sqlite3 -from .tracker import get -import pyerrors.input.json as pj - -from typing import Any - - -def has_valid_times(result: pd.Series) -> bool: - # we expect created_at <= updated_at <= now - created_at = dt.datetime.fromisoformat(result['created_at']) - updated_at = dt.datetime.fromisoformat(result['updated_at']) - if created_at > updated_at: - return False - if updated_at > dt.datetime.now(): - return False - return True - -def are_keys_unique(db: Path, table: str, col: str) -> bool: - conn = sqlite3.connect(db) - c = conn.cursor() - c.execute(f"SELECT COUNT( DISTINCT CAST(path AS nvarchar(4000))), COUNT({col}) FROM {table};") - results = c.fetchall()[0] - conn.close() - return bool(results[0] == results[1]) - - -def check_db_integrity(path: Path) -> None: - db = get_db_file(path) - - if not are_keys_unique(path / db, 'backlogs', 'path'): - raise Exception("The paths the backlog table of the database links are not unique.") - - search_expr = "SELECT * FROM 'backlogs'" - conn = sqlite3.connect(path / db) - results = pd.read_sql(search_expr, conn) - - for _, result in results.iterrows(): - if not has_valid_times(result): - raise ValueError(f"Result with id {result[id]} has wrong time signatures.") - print("DB:\t✅") - return - - -def _check_db2paths(path: Path, meas_paths: list[str]) -> None: - needed_data: dict[str, list[str]] = {} - for mpath in meas_paths: - file = mpath.split("::")[0] - if file not in needed_data.keys(): - needed_data[file] = [] - key = mpath.split("::")[1] - needed_data[file].append(key) - - totf = len(needed_data.keys()) - for i, file in enumerate(needed_data.keys()): - print(f"Check against file {i}/{totf}: {file}") - get(path, Path(file)) - filedict: dict[str, Any] = pj.load_json_dict(str(path / file)) - if not set(filedict.keys()).issubset(needed_data[file]): - for key in filedict.keys(): - if key not in needed_data[file]: - raise ValueError(f"Found unintended key {key} in file {file}.") - if not set(needed_data[file]).issubset(filedict.keys()): - for key in needed_data[file]: - if key not in filedict.keys(): - raise ValueError(f"Did not find data for key {key} that should be in file {file}.") - print("Links:\t✅") - return - - -def check_db_file_links(path: Path) -> None: - db = get_db_file(path) - search_expr = "SELECT path FROM 'backlogs'" - conn = sqlite3.connect(path / db) - results = pd.read_sql(search_expr, conn)['path'].values - _check_db2paths(path, list(results)) - - -def full_integrity_check(path: Path) -> None: - check_db_integrity(path) - check_db_file_links(path) - print("Full:\t✅") - - diff --git a/corrlib/main.py b/corrlib/main.py index 831b69d..df0cd7a 100644 --- a/corrlib/main.py +++ b/corrlib/main.py @@ -5,13 +5,12 @@ import os from .git_tools import move_submodule import shutil from .find import _project_lookup_by_id -from .tools import list2str, str2list, get_db_file +from .tools import list2str, str2list, db_filename from .tracker import get, save, unlock, clone, drop from typing import Union, Optional -from pathlib import Path -def create_project(path: Path, uuid: str, owner: Union[str, None]=None, tags: Union[list[str], None]=None, aliases: Union[list[str], None]=None, code: Union[str, None]=None) -> None: +def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Union[list[str], None]=None, aliases: Union[list[str], None]=None, code: Union[str, None]=None) -> None: """ Create a new project entry in the database. @@ -26,7 +25,7 @@ def create_project(path: Path, uuid: str, owner: Union[str, None]=None, tags: Un code: str (optional) The code that was used to create the measurements. """ - db_file = get_db_file(path) + db_file = db_filename(path) db = os.path.join(path, db_file) get(path, db_file) conn = sqlite3.connect(db) @@ -49,7 +48,7 @@ def create_project(path: Path, uuid: str, owner: Union[str, None]=None, tags: Un return -def update_project_data(path: Path, uuid: str, prop: str, value: Union[str, None] = None) -> None: +def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None] = None) -> None: """ Update/Edit a project entry in the database. Thin wrapper around sql3 call. @@ -65,7 +64,7 @@ def update_project_data(path: Path, uuid: str, prop: str, value: Union[str, None value: str or None Value to se `prop` to. """ - db_file = get_db_file(path) + db_file = db_filename(path) get(path, db_file) conn = sqlite3.connect(os.path.join(path, db_file)) c = conn.cursor() @@ -75,9 +74,9 @@ def update_project_data(path: Path, uuid: str, prop: str, value: Union[str, None return -def update_aliases(path: Path, uuid: str, aliases: list[str]) -> None: - db_file = get_db_file(path) - db = path / db_file +def update_aliases(path: str, uuid: str, aliases: list[str]) -> None: + db_file = db_filename(path) + db = os.path.join(path, db_file) get(path, db_file) known_data = _project_lookup_by_id(db, uuid)[0] known_aliases = known_data[1] @@ -103,7 +102,7 @@ def update_aliases(path: Path, uuid: str, aliases: list[str]) -> None: return -def import_project(path: Path, url: str, owner: Union[str, None]=None, tags: Optional[list[str]]=None, aliases: Optional[list[str]]=None, code: Optional[str]=None, isDataset: bool=True) -> str: +def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Optional[list[str]]=None, aliases: Optional[list[str]]=None, code: Optional[str]=None, isDataset: bool=True) -> str: """ Import a datalad dataset into the backlogger. @@ -135,14 +134,14 @@ def import_project(path: Path, url: str, owner: Union[str, None]=None, tags: Opt uuid = str(conf.get("datalad.dataset.id")) if not uuid: raise ValueError("The dataset does not have a uuid!") - if not os.path.exists(path / "projects" / uuid): - db_file = get_db_file(path) + if not os.path.exists(path + "/projects/" + uuid): + db_file = db_filename(path) get(path, db_file) unlock(path, db_file) create_project(path, uuid, owner, tags, aliases, code) - move_submodule(path, Path('projects/tmp'), Path('projects') / uuid) - os.mkdir(path / 'import_scripts' / uuid) - save(path, message="Import project from " + url, files=[Path(f'projects/{uuid}'), db_file]) + move_submodule(path, 'projects/tmp', 'projects/' + uuid) + os.mkdir(path + '/import_scripts/' + uuid) + save(path, message="Import project from " + url, files=['projects/' + uuid, db_file]) else: dl.drop(tmp_path, reckless='kill') shutil.rmtree(tmp_path) @@ -157,7 +156,7 @@ def import_project(path: Path, url: str, owner: Union[str, None]=None, tags: Opt return uuid -def drop_project_data(path: Path, uuid: str, path_in_project: str = "") -> None: +def drop_project_data(path: str, uuid: str, path_in_project: str = "") -> None: """ Drop (parts of) a project to free up diskspace @@ -170,5 +169,6 @@ def drop_project_data(path: Path, uuid: str, path_in_project: str = "") -> None: path_pn_project: str, optional If set, only the given path within the project is dropped. """ - drop(path / "projects" / uuid / path_in_project) + drop(path + "/projects/" + uuid + "/" + path_in_project) return + diff --git a/corrlib/meas_io.py b/corrlib/meas_io.py index cbd9386..3344efb 100644 --- a/corrlib/meas_io.py +++ b/corrlib/meas_io.py @@ -3,21 +3,17 @@ import os import sqlite3 from .input import sfcf,openQCD import json -from typing import Union -from pyerrors import Obs, Corr, dump_object, load_object +from typing import Union, Any +from pyerrors import Obs, Corr, load_object, dump_object from hashlib import sha256 -from .tools import get_db_file, cache_enabled +from .tools import record2name_key, name_key2record, make_version_hash +from .cache_io import is_in_cache, cache_path, cache_dir, get_version_hash +from .tools import db_filename, cache_enabled from .tracker import get, save, unlock import shutil -from typing import Any -from pathlib import Path -from .integrity import _check_db2paths -CACHE_DIR = ".cache" - - -def write_measurement(path: Path, ensemble: str, measurement: dict[str, dict[str, dict[str, Any]]], uuid: str, code: str, parameter_file: Union[str, None]) -> None: +def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, dict[str, Any]]], uuid: str, code: str, parameter_file: str) -> None: """ Write a measurement to the backlog. If the file for the measurement already exists, update the measurement. @@ -37,35 +33,26 @@ def write_measurement(path: Path, ensemble: str, measurement: dict[str, dict[str parameter_file: str The parameter file used for the measurement. """ - db_file = get_db_file(path) - db = path / db_file - - files_to_save = [] - + db_file = db_filename(path) + db = os.path.join(path, db_file) get(path, db_file) unlock(path, db_file) - files_to_save.append(db_file) - conn = sqlite3.connect(db) c = conn.cursor() + files = [] for corr in measurement.keys(): - file_in_archive = Path('.') / 'archive' / ensemble / corr / str(uuid + '.json.gz') - file = path / file_in_archive + file_in_archive = os.path.join('.', 'archive', ensemble, corr, uuid + '.json.gz') + file = os.path.join(path, file_in_archive) + files.append(file) known_meas = {} - if not os.path.exists(path / 'archive' / ensemble / corr): - os.makedirs(path / 'archive' / ensemble / corr) - files_to_save.append(file_in_archive) + if not os.path.exists(os.path.join(path, '.', 'archive', ensemble, corr)): + os.makedirs(os.path.join(path, '.', 'archive', ensemble, corr)) else: if os.path.exists(file): - if file not in files_to_save: - unlock(path, file_in_archive) - files_to_save.append(file_in_archive) - known_meas = pj.load_json_dict(str(file), verbose=False) + unlock(path, file_in_archive) + known_meas = pj.load_json_dict(file) if code == "sfcf": - if parameter_file is not None: - parameters = sfcf.read_param(path, uuid, parameter_file) - else: - raise Exception("Need parameter file for this code!") + parameters = sfcf.read_param(path, uuid, parameter_file) pars = {} subkeys = list(measurement[corr].keys()) for subkey in subkeys: @@ -74,25 +61,7 @@ def write_measurement(path: Path, ensemble: str, measurement: dict[str, dict[str elif code == "openQCD": ms_type = list(measurement.keys())[0] if ms_type == 'ms1': - if parameter_file is not None: - if parameter_file.endswith(".ms1.in"): - parameters = openQCD.load_ms1_infile(path, uuid, parameter_file) - elif parameter_file.endswith(".ms1.par"): - parameters = openQCD.load_ms1_parfile(path, uuid, parameter_file) - else: - # Temporary solution - parameters = {} - parameters["rand"] = {} - parameters["rw_fcts"] = [{}] - for nrw in range(1): - if "nsrc" not in parameters["rw_fcts"][nrw]: - parameters["rw_fcts"][nrw]["nsrc"] = 1 - if "mu" not in parameters["rw_fcts"][nrw]: - parameters["rw_fcts"][nrw]["mu"] = "None" - if "np" not in parameters["rw_fcts"][nrw]: - parameters["rw_fcts"][nrw]["np"] = "None" - if "irp" not in parameters["rw_fcts"][nrw]: - parameters["rw_fcts"][nrw]["irp"] = "None" + parameters = openQCD.read_ms1_param(path, uuid, parameter_file) pars = {} subkeys = [] for i in range(len(parameters["rw_fcts"])): @@ -104,7 +73,7 @@ def write_measurement(path: Path, ensemble: str, measurement: dict[str, dict[str pars[subkey] = json.dumps(parameters["rw_fcts"][i]) elif ms_type in ['t0', 't1']: if parameter_file is not None: - parameters = openQCD.load_ms3_infile(path, uuid, parameter_file) + parameters = openQCD.read_ms3_param(path, uuid, parameter_file) else: parameters = {} for rwp in ["integrator", "eps", "ntot", "dnms"]: @@ -117,25 +86,26 @@ def write_measurement(path: Path, ensemble: str, measurement: dict[str, dict[str subkey = "/".join(par_list) subkeys = [subkey] pars[subkey] = json.dumps(parameters) + + meas_paths = [] for subkey in subkeys: - parHash = sha256(str(pars[subkey]).encode('UTF-8')).hexdigest() - meas_path = str(file_in_archive) + "::" + parHash - - known_meas[parHash] = measurement[corr][subkey] - - if c.execute("SELECT * FROM backlogs WHERE path = ?", (meas_path,)).fetchone() is not None: - c.execute("UPDATE backlogs SET updated_at = datetime('now') WHERE path = ?", (meas_path, )) - else: - c.execute("INSERT INTO backlogs (name, ensemble, code, path, project, parameters, parameter_file, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))", + par_hash = sha256(str(pars[subkey]).encode('UTF-8')).hexdigest() + meas_path = name_key2record(file_in_archive, par_hash) + meas_paths.append(meas_path) + known_meas[par_hash] = measurement[corr][subkey] + data_hash = make_version_hash(path, meas_path) + if c.execute("SELECT * FROM backlogs WHERE path = ?", (meas_path,)).fetchone() is None: + c.execute("INSERT INTO backlogs (name, ensemble, code, path, project, parameters, parameter_file, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))", (corr, ensemble, code, meas_path, uuid, pars[subkey], parameter_file)) - conn.commit() - pj.dump_dict_to_json(known_meas, str(file)) + c.execute("UPDATE backlogs SET current_version = ?, updated_at = datetime('now') WHERE path = ?", (data_hash, meas_path)) + pj.dump_dict_to_json(known_meas, file) + files.append(os.path.join(path, db_file)) conn.close() - save(path, message="Add measurements to database", files=files_to_save) + save(path, message="Add measurements to database", files=files) return -def load_record(path: Path, meas_path: str) -> Union[Corr, Obs]: +def load_record(path: str, meas_path: str) -> Union[Corr, Obs]: """ Load a list of records by their paths. @@ -154,7 +124,7 @@ def load_record(path: Path, meas_path: str) -> Union[Corr, Obs]: return load_records(path, [meas_path])[0] -def load_records(path: Path, meas_paths: list[str], preloaded: dict[str, Any] = {}, dry_run: bool = False) -> list[Union[Corr, Obs]]: +def load_records(path: str, record_paths: list[str], preloaded: dict[str, Any] = {}) -> list[Union[Corr, Obs]]: """ Load a list of records by their paths. @@ -164,85 +134,42 @@ def load_records(path: Path, meas_paths: list[str], preloaded: dict[str, Any] = Path of the correlator library. meas_paths: list[str] A list of the paths to the correlator in the backlog system. - preloaded: dict[str, Any] - The data that is already preloaded. Of interest if data has alread been loaded in the same script. - dry_run: bool - Do not load datda, just check whether we can reach the data we are interested in. + perloaded: dict[str, Any] + The data that is already prelaoded. Of interest if data has alread been loaded in the same script. Returns ------- - returned_data: list + retruned_data: list The loaded records. """ - if dry_run: - _check_db2paths(path, meas_paths) - return [] needed_data: dict[str, list[str]] = {} - for mpath in meas_paths: - file = mpath.split("::")[0] + for rpath in record_paths: + file, key = record2name_key(rpath) if file not in needed_data.keys(): needed_data[file] = [] - key = mpath.split("::")[1] needed_data[file].append(key) returned_data: list[Any] = [] for file in needed_data.keys(): for key in list(needed_data[file]): - if os.path.exists(str(cache_path(path, file, key)) + ".p"): - returned_data.append(load_object(str(cache_path(path, file, key)) + ".p")) + record = name_key2record(file, key) + current_version = get_version_hash(path, record) + if is_in_cache(path, record): + returned_data.append(load_object(cache_path(path, file, current_version, key) + ".p")) else: if file not in preloaded: - preloaded[file] = preload(path, Path(file)) + preloaded[file] = preload(path, file) returned_data.append(preloaded[file][key]) if cache_enabled(path): - if not os.path.exists(cache_dir(path, file)): - os.makedirs(cache_dir(path, file)) - dump_object(preloaded[file][key], str(cache_path(path, file, key))) + if not is_in_cache(path, record): + file, key = record2name_key(record) + if not os.path.exists(cache_dir(path, file)): + os.makedirs(cache_dir(path, file)) + current_version = get_version_hash(path, record) + dump_object(preloaded[file][key], cache_path(path, file, current_version, key)) return returned_data -def cache_dir(path: Path, file: str) -> Path: - """ - Returns the directory corresponding to the cache for the given file. - - Parameters - ---------- - path: str - The path of the library. - file: str - The file in the library that we want to access the cached data of. - Returns - ------- - cache_path: str - The path holding the cached data for the given file. - """ - cache_path_list = file.split("/")[1:] - cache_path = Path(path) / CACHE_DIR - for directory in cache_path_list: - cache_path /= directory - return cache_path - - -def cache_path(path: Path, file: str, key: str) -> Path: - """ - Parameters - ---------- - path: str - The path of the library. - file: str - The file in the library that we want to access the cached data of. - key: str - The key within the archive file. - - Returns - ------- - cache_path: str - The path at which the measurement of the given file and key is cached. - """ - cache_path = cache_dir(path, file) / key - return cache_path - - -def preload(path: Path, file: Path) -> dict[str, Any]: +def preload(path: str, file: str) -> dict[str, Any]: """ Read the contents of a file into a json dictionary with the pyerrors.json.load_json_dict method. @@ -259,12 +186,12 @@ def preload(path: Path, file: Path) -> dict[str, Any]: The data read from the file. """ get(path, file) - filedict: dict[str, Any] = pj.load_json_dict(str(path / file)) + filedict: dict[str, Any] = pj.load_json_dict(os.path.join(path, file)) print("> read file") return filedict -def drop_record(path: Path, meas_path: str) -> None: +def drop_record(path: str, meas_path: str) -> None: """ Drop a record by it's path. @@ -276,9 +203,9 @@ def drop_record(path: Path, meas_path: str) -> None: The measurement path as noted in the database. """ file_in_archive = meas_path.split("::")[0] - file = path / file_in_archive - db_file = get_db_file(path) - db = path / db_file + file = os.path.join(path, file_in_archive) + db_file = db_filename(path) + db = os.path.join(path, db_file) get(path, db_file) sub_key = meas_path.split("::")[1] unlock(path, db_file) @@ -290,18 +217,18 @@ def drop_record(path: Path, meas_path: str) -> None: raise ValueError("This measurement does not exist as an entry!") conn.commit() - known_meas = pj.load_json_dict(str(file)) + known_meas = pj.load_json_dict(file) if sub_key in known_meas: del known_meas[sub_key] - unlock(path, Path(file_in_archive)) - pj.dump_dict_to_json(known_meas, str(file)) + unlock(path, file_in_archive) + pj.dump_dict_to_json(known_meas, file) save(path, message="Drop measurements to database", files=[db, file]) return else: raise ValueError("This measurement does not exist as a file!") -def drop_cache(path: Path) -> None: +def drop_cache(path: str) -> None: """ Drop the cache directory of the library. @@ -310,7 +237,7 @@ def drop_cache(path: Path) -> None: path: str The path of the library. """ - cache_dir = path / ".cache" + cache_dir = os.path.join(path, ".cache") for f in os.listdir(cache_dir): - shutil.rmtree(cache_dir / f) + shutil.rmtree(os.path.join(cache_dir, f)) return diff --git a/corrlib/pars/openQCD/__init__.py b/corrlib/pars/openQCD/__init__.py deleted file mode 100644 index edbac71..0000000 --- a/corrlib/pars/openQCD/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ - -from . import ms1 as ms1 -from . import qcd2 as qcd2 diff --git a/corrlib/pars/openQCD/flags.py b/corrlib/pars/openQCD/flags.py deleted file mode 100644 index 95be919..0000000 --- a/corrlib/pars/openQCD/flags.py +++ /dev/null @@ -1,59 +0,0 @@ -""" -Reconstruct the outputs of flags. -""" - -import struct -from typing import Any, BinaryIO - -# lat_parms.c -def lat_parms_write_lat_parms(fp: BinaryIO) -> dict[str, Any]: - """ - NOTE: This is a duplcation from qcd2. - Unpack the lattice parameters written by write_lat_parms. - """ - lat_pars = {} - t = fp.read(16) - lat_pars["N"] = list(struct.unpack('iiii', t)) # lattice extends - t = fp.read(8) - nk, isw = struct.unpack('ii', t) # number of kappas and isw parameter - lat_pars["nk"] = nk - lat_pars["isw"] = isw - t = fp.read(8) - lat_pars["beta"] = struct.unpack('d', t)[0] # beta - t = fp.read(8) - lat_pars["c0"] = struct.unpack('d', t)[0] - t = fp.read(8) - lat_pars["c1"] = struct.unpack('d', t)[0] - t = fp.read(8) - lat_pars["csw"] = struct.unpack('d', t)[0] # csw factor - kappas = [] - m0s = [] - # read kappas - for ik in range(nk): - t = fp.read(8) - kappas.append(struct.unpack('d', t)[0]) - t = fp.read(8) - m0s.append(struct.unpack('d', t)[0]) - lat_pars["kappas"] = kappas - lat_pars["m0s"] = m0s - return lat_pars - - -def lat_parms_write_bc_parms(fp: BinaryIO) -> dict[str, Any]: - """ - NOTE: This is a duplcation from qcd2. - Unpack the boundary parameters written by write_bc_parms. - """ - bc_pars: dict[str, Any] = {} - t = fp.read(4) - bc_pars["type"] = struct.unpack('i', t)[0] # type of hte boundaries - t = fp.read(104) - bc_parms = struct.unpack('d'*13, t) - bc_pars["cG"] = list(bc_parms[:2]) # boundary gauge field improvement - bc_pars["cF"] = list(bc_parms[2:4]) # boundary fermion field improvement - phi: list[list[float]] = [[], []] - phi[0] = list(bc_parms[4:7]) - phi[1] = list(bc_parms[7:10]) - bc_pars["phi"] = phi - bc_pars["theta"] = list(bc_parms[10:]) - return bc_pars diff --git a/corrlib/pars/openQCD/ms1.py b/corrlib/pars/openQCD/ms1.py deleted file mode 100644 index 4c2aed5..0000000 --- a/corrlib/pars/openQCD/ms1.py +++ /dev/null @@ -1,30 +0,0 @@ -from . import flags - -from typing import Any -from pathlib import Path - - -def read_qcd2_ms1_par_file(fname: Path) -> dict[str, dict[str, Any]]: - """ - The subroutines written here have names according to the openQCD programs and functions that write out the data. - Parameters - ---------- - fname: Path - Location of the parameter file. - - Returns - ------- - par_dict: dict - Dictionary holding the parameters specified in the given file. - """ - - with open(fname, "rb") as fp: - lat_par_dict = flags.lat_parms_write_lat_parms(fp) - bc_par_dict = flags.lat_parms_write_bc_parms(fp) - fp.close() - par_dict = {} - par_dict["lat"] = lat_par_dict - par_dict["bc"] = bc_par_dict - return par_dict - - diff --git a/corrlib/pars/openQCD/qcd2.py b/corrlib/pars/openQCD/qcd2.py deleted file mode 100644 index e73c156..0000000 --- a/corrlib/pars/openQCD/qcd2.py +++ /dev/null @@ -1,29 +0,0 @@ -from . import flags - -from pathlib import Path -from typing import Any - - -def read_qcd2_par_file(fname: Path) -> dict[str, dict[str, Any]]: - """ - The subroutines written here have names according to the openQCD programs and functions that write out the data. - - Parameters - ---------- - fname: Path - Location of the parameter file. - - Returns - ------- - par_dict: dict - Dictionary holding the parameters specified in the given file. - """ - - with open(fname, "rb") as fp: - lat_par_dict = flags.lat_parms_write_lat_parms(fp) - bc_par_dict = flags.lat_parms_write_bc_parms(fp) - fp.close() - par_dict = {} - par_dict["lat"] = lat_par_dict - par_dict["bc"] = bc_par_dict - return par_dict diff --git a/corrlib/toml.py b/corrlib/toml.py index 0d4dfc8..629a499 100644 --- a/corrlib/toml.py +++ b/corrlib/toml.py @@ -19,7 +19,6 @@ from .meas_io import write_measurement import os from .input.implementations import codes as known_codes from typing import Any -from pathlib import Path def replace_string(string: str, name: str, val: str) -> str: @@ -127,7 +126,7 @@ def check_measurement_data(measurements: dict[str, dict[str, str]], code: str) - return -def import_tomls(path: Path, files: list[str], copy_files: bool=True) -> None: +def import_tomls(path: str, files: list[str], copy_files: bool=True) -> None: """ Import multiple toml files. @@ -145,7 +144,7 @@ def import_tomls(path: Path, files: list[str], copy_files: bool=True) -> None: return -def import_toml(path: Path, file: str, copy_file: bool=True) -> None: +def import_toml(path: str, file: str, copy_file: bool=True) -> None: """ Import a project decribed by a .toml file. @@ -172,16 +171,14 @@ def import_toml(path: Path, file: str, copy_file: bool=True) -> None: aliases = project.get('aliases', []) uuid = project.get('uuid', None) if uuid is not None: - if not os.path.exists(path / "projects" / uuid): + if not os.path.exists(path + "/projects/" + uuid): uuid = import_project(path, project['url'], aliases=aliases) else: update_aliases(path, uuid, aliases) else: uuid = import_project(path, project['url'], aliases=aliases) - imeas = 1 - nmeas = len(measurements.keys()) for mname, md in measurements.items(): - print(f"Import measurement {imeas}/{nmeas}: {mname}") + print("Import measurement: " + mname) ensemble = md['ensemble'] if project['code'] == 'sfcf': param = sfcf.read_param(path, uuid, md['param_file']) @@ -192,34 +189,15 @@ def import_toml(path: Path, file: str, copy_file: bool=True) -> None: measurement = sfcf.read_data(path, uuid, md['path'], md['prefix'], param, version=md['version'], cfg_seperator=md['cfg_seperator'], sep='/') + print(mname + " imported.") elif project['code'] == 'openQCD': if md['measurement'] == 'ms1': - if 'param_file' in md.keys(): - parameter_file = md['param_file'] - if parameter_file.endswith(".ms1.in"): - param = openQCD.load_ms1_infile(path, uuid, parameter_file) - elif parameter_file.endswith(".ms1.par"): - param = openQCD.load_ms1_parfile(path, uuid, parameter_file) - else: - # Temporary solution - parameters: dict[str, Any] = {} - parameters["rand"] = {} - parameters["rw_fcts"] = [{}] - for nrw in range(1): - if "nsrc" not in parameters["rw_fcts"][nrw]: - parameters["rw_fcts"][nrw]["nsrc"] = 1 - if "mu" not in parameters["rw_fcts"][nrw]: - parameters["rw_fcts"][nrw]["mu"] = "None" - if "np" not in parameters["rw_fcts"][nrw]: - parameters["rw_fcts"][nrw]["np"] = "None" - if "irp" not in parameters["rw_fcts"][nrw]: - parameters["rw_fcts"][nrw]["irp"] = "None" - param = parameters + param = openQCD.read_ms1_param(path, uuid, md['param_file']) param['type'] = 'ms1' measurement = openQCD.read_rwms(path, uuid, md['path'], param, md["prefix"], version=md["version"], names=md['names'], files=md['files']) elif md['measurement'] == 't0': if 'param_file' in md: - param = openQCD.load_ms3_infile(path, uuid, md['param_file']) + param = openQCD.read_ms3_param(path, uuid, md['param_file']) else: param = {} for rwp in ["integrator", "eps", "ntot", "dnms"]: @@ -229,26 +207,25 @@ def import_toml(path: Path, file: str, copy_file: bool=True) -> None: fit_range=int(md.get('fit_range', 5)), postfix=str(md.get('postfix', '')), names=md.get('names', []), files=md.get('files', [])) elif md['measurement'] == 't1': if 'param_file' in md: - param = openQCD.load_ms3_infile(path, uuid, md['param_file']) + param = openQCD.read_ms3_param(path, uuid, md['param_file']) param['type'] = 't1' measurement = openQCD.extract_t1(path, uuid, md['path'], param, str(md["prefix"]), int(md["dtr_read"]), int(md["xmin"]), int(md["spatial_extent"]), fit_range=int(md.get('fit_range', 5)), postfix=str(md.get('postfix', '')), names=md.get('names', []), files=md.get('files', [])) - write_measurement(path, ensemble, measurement, uuid, project['code'], (md['param_file'] if 'param_file' in md else None)) - imeas += 1 - print(mname + " imported.") - if not os.path.exists(path / "toml_imports" / uuid): - os.makedirs(path / "toml_imports" / uuid) + write_measurement(path, ensemble, measurement, uuid, project['code'], (md['param_file'] if 'param_file' in md else '')) + + if not os.path.exists(os.path.join(path, "toml_imports", uuid)): + os.makedirs(os.path.join(path, "toml_imports", uuid)) if copy_file: - import_file = path / "toml_imports" / uuid / file.split("/")[-1] + import_file = os.path.join(path, "toml_imports", uuid, file.split("/")[-1]) shutil.copy(file, import_file) - save(path, files=[import_file], message=f"Import using {import_file}") - print(f"File copied to {import_file}") + save(path, files=[import_file], message="Import using " + import_file) + print("File copied to " + import_file) print("Imported project.") return -def reimport_project(path: Path, uuid: str) -> None: +def reimport_project(path: str, uuid: str) -> None: """ Reimport an existing project using the files that are already available for this project. @@ -259,14 +236,14 @@ def reimport_project(path: Path, uuid: str) -> None: uuid: str uuid of the project that is to be reimported. """ - config_path = path / "import_scripts" / uuid + config_path = "/".join([path, "import_scripts", uuid]) for p, filenames, dirnames in os.walk(config_path): for fname in filenames: import_toml(path, os.path.join(config_path, fname), copy_file=False) return -def update_project(path: Path, uuid: str) -> None: +def update_project(path: str, uuid: str) -> None: """ Update all entries associated with a given project. diff --git a/corrlib/tools.py b/corrlib/tools.py index 93f0678..e46ce0a 100644 --- a/corrlib/tools.py +++ b/corrlib/tools.py @@ -1,7 +1,7 @@ import os +import hashlib from configparser import ConfigParser -from typing import Any -from pathlib import Path +from typing import Any, Union CONFIG_FILENAME = ".corrlib" cached: bool = True @@ -23,6 +23,7 @@ def str2list(string: str) -> list[str]: """ return string.split(",") + def list2str(mylist: list[str]) -> str: """ Convert a list to a comma-separated string. @@ -40,6 +41,7 @@ def list2str(mylist: list[str]) -> str: s = ",".join(mylist) return s + def m2k(m: float) -> float: """ Convert to bare quark mas $m$ to inverse mass parameter $kappa$. @@ -74,7 +76,48 @@ def k2m(k: float) -> float: return (1/(2*k))-4 -def set_config(path: Path, section: str, option: str, value: Any) -> None: +def record2name_key(record_path: str) -> tuple[str, str]: + """ + Convert a record to a pair of name and key. + + Parameters + ---------- + record: str + + Returns + ------- + name: str + key: str + """ + file = record_path.split("::")[0] + key = record_path.split("::")[1] + return file, key + + +def name_key2record(name: str, key: str) -> str: + """ + Convert a name and a key to a record name. + + Parameters + ---------- + name: str + key: str + + Returns + ------- + record: str + """ + return name + "::" + key + + +def make_version_hash(path: str, record: str) -> str: + file, key = record2name_key(record) + with open(os.path.join(path, file), 'rb') as fp: + file_hash = hashlib.file_digest(fp, 'sha1').hexdigest() + return file_hash + + +def set_config(path: str, section: str, option: str, value: Any) -> None: """ Set configuration parameters for the library. @@ -89,7 +132,7 @@ def set_config(path: Path, section: str, option: str, value: Any) -> None: value: Any The value we set the option to. """ - config_path = os.path.join(path, CONFIG_FILENAME) + config_path = os.path.join(path, '.corrlib') config = ConfigParser() if os.path.exists(config_path): config.read(config_path) @@ -101,7 +144,7 @@ def set_config(path: Path, section: str, option: str, value: Any) -> None: return -def get_db_file(path: Path) -> Path: +def db_filename(path: str) -> str: """ Get the database file associated with the library at the given path. @@ -119,13 +162,11 @@ def get_db_file(path: Path) -> Path: config = ConfigParser() if os.path.exists(config_path): config.read(config_path) - else: - raise FileNotFoundError("Configuration file not found.") - db_file = Path(config.get('paths', 'db', fallback='backlogger.db')) + db_file = config.get('paths', 'db', fallback='backlogger.db') return db_file -def cache_enabled(path: Path) -> bool: +def cache_enabled(path: str) -> bool: """ Check, whether the library is cached. Fallback is true. @@ -144,10 +185,31 @@ def cache_enabled(path: Path) -> bool: config = ConfigParser() if os.path.exists(config_path): config.read(config_path) - else: - raise FileNotFoundError("Configuration file not found.") cached_str = config.get('core', 'cached', fallback='True') - if cached_str not in ['True', 'False']: - raise ValueError(f"String {cached_str} is not a valid option, only True and False are allowed!") cached_bool = cached_str == ('True') return cached_bool + + +def cache_dir_name(path: str) -> Union[str, None]: + """ + Get the database file associated with the library at the given path. + + Parameters + ---------- + path: str + The path of the library. + + Returns + ------- + db_file: str + The file holding the database. + """ + config_path = os.path.join(path, CONFIG_FILENAME) + config = ConfigParser() + if os.path.exists(config_path): + config.read(config_path) + if cache_enabled(path): + cache = config.get('paths', 'cache', fallback='.cache') + else: + cache = None + return cache diff --git a/corrlib/tracker.py b/corrlib/tracker.py index a6e9bf4..63aabf2 100644 --- a/corrlib/tracker.py +++ b/corrlib/tracker.py @@ -3,11 +3,10 @@ from configparser import ConfigParser import datalad.api as dl from typing import Optional import shutil -from .tools import get_db_file -from pathlib import Path +from .tools import db_filename -def get_tracker(path: Path) -> str: +def get_tracker(path: str) -> str: """ Get the tracker used in the dataset located at path. @@ -31,7 +30,7 @@ def get_tracker(path: Path) -> str: return tracker -def get(path: Path, file: Path) -> None: +def get(path: str, file: str) -> None: """ Wrapper function to get a file from the dataset located at path with the specified tracker. @@ -44,7 +43,7 @@ def get(path: Path, file: Path) -> None: """ tracker = get_tracker(path) if tracker == 'datalad': - if file == get_db_file(path): + if file == db_filename(path): print("Downloading database...") else: print("Downloading data...") @@ -57,7 +56,7 @@ def get(path: Path, file: Path) -> None: return -def save(path: Path, message: str, files: Optional[list[Path]]=None) -> None: +def save(path: str, message: str, files: Optional[list[str]]=None) -> None: """ Wrapper function to save a file to the dataset located at path with the specified tracker. @@ -73,7 +72,7 @@ def save(path: Path, message: str, files: Optional[list[Path]]=None) -> None: tracker = get_tracker(path) if tracker == 'datalad': if files is not None: - files = [path / f for f in files] + files = [os.path.join(path, f) for f in files] dl.save(files, message=message, dataset=path) elif tracker == 'None': Warning("Tracker 'None' does not implement save.") @@ -82,7 +81,7 @@ def save(path: Path, message: str, files: Optional[list[Path]]=None) -> None: raise ValueError(f"Tracker {tracker} is not supported.") -def init(path: Path, tracker: str='datalad') -> None: +def init(path: str, tracker: str='datalad') -> None: """ Initialize a dataset at the specified path with the specified tracker. @@ -102,7 +101,7 @@ def init(path: Path, tracker: str='datalad') -> None: return -def unlock(path: Path, file: Path) -> None: +def unlock(path: str, file: str) -> None: """ Wrapper function to unlock a file in the dataset located at path with the specified tracker. @@ -115,7 +114,7 @@ def unlock(path: Path, file: Path) -> None: """ tracker = get_tracker(path) if tracker == 'datalad': - dl.unlock(os.path.join(path, file), dataset=path) + dl.unlock(file, dataset=path) elif tracker == 'None': Warning("Tracker 'None' does not implement unlock.") pass @@ -124,7 +123,7 @@ def unlock(path: Path, file: Path) -> None: return -def clone(path: Path, source: str, target: str) -> None: +def clone(path: str, source: str, target: str) -> None: """ Wrapper function to clone a dataset from source to target with the specified tracker. Parameters @@ -148,7 +147,7 @@ def clone(path: Path, source: str, target: str) -> None: return -def drop(path: Path, reckless: Optional[str]=None) -> None: +def drop(path: str, reckless: Optional[str]=None) -> None: """ Wrapper function to drop data from a dataset located at path with the specified tracker. diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..6b8794e --- /dev/null +++ b/setup.py @@ -0,0 +1,18 @@ +from setuptools import setup +from distutils.util import convert_path + + +version = {} +with open(convert_path('corrlib/version.py')) as ver_file: + exec(ver_file.read(), version) + +setup(name='pycorrlib', + version=version['__version__'], + author='Justus Kuhlmann', + author_email='j_kuhl19@uni-muenster.de', + install_requires=['pyerrors>=2.11.1', 'datalad>=1.1.0', 'typer>=0.12.5', 'gitpython>=3.1.45'], + entry_points = { + 'console_scripts': ['pcl=corrlib.cli:app'], + }, + packages=['corrlib', 'corrlib.input'] + ) diff --git a/tests/cli_test.py b/tests/cli_test.py index cba0a10..f1678c6 100644 --- a/tests/cli_test.py +++ b/tests/cli_test.py @@ -86,7 +86,7 @@ def test_list(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" result = runner.invoke(app, ["init", "--dataset", str(dataset_path)]) assert result.exit_code == 0 - result = runner.invoke(app, ["lister", "--dataset", str(dataset_path), "ensembles"]) + result = runner.invoke(app, ["list", "--dataset", str(dataset_path), "ensembles"]) assert result.exit_code == 0 - result = runner.invoke(app, ["lister", "--dataset", str(dataset_path), "projects"]) + result = runner.invoke(app, ["list", "--dataset", str(dataset_path), "projects"]) assert result.exit_code == 0 diff --git a/tests/find_test.py b/tests/find_test.py deleted file mode 100644 index cc455f9..0000000 --- a/tests/find_test.py +++ /dev/null @@ -1,432 +0,0 @@ -import corrlib.find as find -import sqlite3 -from pathlib import Path -import corrlib.initialization as cinit -import pytest -import pandas as pd -import datalad.api as dl -import datetime as dt - - -def make_sql(path: Path) -> Path: - db = path / "test.db" - cinit._create_db(db) - return db - - -def test_find_lookup_by_one_alias(tmp_path: Path) -> None: - db = make_sql(tmp_path) - conn = sqlite3.connect(db) - c = conn.cursor() - uuid = "test_uuid" - alias_str = "fun_project" - tag_str = "tt" - owner = "tester" - code = "test_code" - c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", - (uuid, alias_str, tag_str, owner, code)) - conn.commit() - assert uuid == find._project_lookup_by_alias(db, "fun_project") - uuid = "test_uuid2" - alias_str = "fun_project" - c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", - (uuid, alias_str, tag_str, owner, code)) - conn.commit() - with pytest.raises(Exception): - assert uuid == find._project_lookup_by_alias(db, "fun_project") - conn.close() - -def test_find_lookup_by_id(tmp_path: Path) -> None: - db = make_sql(tmp_path) - conn = sqlite3.connect(db) - c = conn.cursor() - uuid = "test_uuid" - alias_str = "fun_project" - tag_str = "tt" - owner = "tester" - code = "test_code" - c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", - (uuid, alias_str, tag_str, owner, code)) - conn.commit() - conn.close() - result = find._project_lookup_by_id(db, uuid)[0] - assert uuid == result[0] - assert alias_str == result[1] - assert tag_str == result[2] - assert owner == result[3] - assert code == result[4] - - -def test_time_filter() -> None: - record_A = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf0", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] # only created - record_B = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf1", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2025-03-26 12:55:18.229966', '2025-04-26 12:55:18.229966'] # created and updated - record_C = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf2", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2026-03-26 12:55:18.229966', '2026-04-14 12:55:18.229966'] # created and updated later - record_D = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf3", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2026-03-26 12:55:18.229966', '2026-03-27 12:55:18.229966'] - record_E = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf4", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2024-03-26 12:55:18.229966', '2024-03-26 12:55:18.229966'] # only created, earlier - record_F = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf5", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2026-03-26 12:55:18.229966', '2024-03-26 12:55:18.229966'] # this is invalid... - record_G = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf2", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2026-03-26 12:55:18.229966', str(dt.datetime.now() + dt.timedelta(days=2, hours=3, minutes=5, seconds=30))] # created and updated later - - data = [record_A, record_B, record_C, record_D, record_E] - cols = ["name", - "ensemble", - "code", - "path", - "project", - "parameters", - "parameter_file", - "created_at", - "updated_at"] - df = pd.DataFrame(data,columns=cols) - - results = find._time_filter(df, created_before='2023-03-26 12:55:18.229966') - assert results.empty - results = find._time_filter(df, created_before='2027-03-26 12:55:18.229966') - assert len(results) == 5 - results = find._time_filter(df, created_before='2026-03-25 12:55:18.229966') - assert len(results) == 3 - results = find._time_filter(df, created_before='2026-03-26 12:55:18.229965') - assert len(results) == 3 - results = find._time_filter(df, created_before='2025-03-04 12:55:18.229965') - assert len(results) == 1 - - results = find._time_filter(df, created_after='2023-03-26 12:55:18.229966') - assert len(results) == 5 - results = find._time_filter(df, created_after='2027-03-26 12:55:18.229966') - assert results.empty - results = find._time_filter(df, created_after='2026-03-25 12:55:18.229966') - assert len(results) == 2 - results = find._time_filter(df, created_after='2026-03-26 12:55:18.229965') - assert len(results) == 2 - results = find._time_filter(df, created_after='2025-03-04 12:55:18.229965') - assert len(results) == 4 - - results = find._time_filter(df, updated_before='2023-03-26 12:55:18.229966') - assert results.empty - results = find._time_filter(df, updated_before='2027-03-26 12:55:18.229966') - assert len(results) == 5 - results = find._time_filter(df, updated_before='2026-03-25 12:55:18.229966') - assert len(results) == 3 - results = find._time_filter(df, updated_before='2026-03-26 12:55:18.229965') - assert len(results) == 3 - results = find._time_filter(df, updated_before='2025-03-04 12:55:18.229965') - assert len(results) == 1 - - results = find._time_filter(df, updated_after='2023-03-26 12:55:18.229966') - assert len(results) == 5 - results = find._time_filter(df, updated_after='2027-03-26 12:55:18.229966') - assert results.empty - results = find._time_filter(df, updated_after='2026-03-25 12:55:18.229966') - assert len(results) == 2 - results = find._time_filter(df, updated_after='2026-03-26 12:55:18.229965') - assert len(results) == 2 - results = find._time_filter(df, updated_after='2025-03-04 12:55:18.229965') - assert len(results) == 4 - - data = [record_A, record_B, record_C, record_D, record_F] - cols = ["name", - "ensemble", - "code", - "path", - "project", - "parameters", - "parameter_file", - "created_at", - "updated_at"] - df = pd.DataFrame(data,columns=cols) - - with pytest.raises(ValueError): - results = find._time_filter(df, created_before='2023-03-26 12:55:18.229966') - - data = [record_A, record_B, record_C, record_D, record_G] - cols = ["name", - "ensemble", - "code", - "path", - "project", - "parameters", - "parameter_file", - "created_at", - "updated_at"] - df = pd.DataFrame(data,columns=cols) - - with pytest.raises(ValueError): - results = find._time_filter(df, created_before='2023-03-26 12:55:18.229966') - - -def test_db_lookup(tmp_path: Path) -> None: - db = make_sql(tmp_path) - conn = sqlite3.connect(db) - c = conn.cursor() - - corr = "f_A" - ensemble = "SF_A" - code = "openQCD" - meas_path = "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf" - uuid = "Project_A" - pars = "{par_A: 3.0, par_B: 5.0}" - parameter_file = "projects/Project_A/myinput.in" - c.execute("INSERT INTO backlogs (name, ensemble, code, path, project, parameters, parameter_file, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))", - (corr, ensemble, code, meas_path, uuid, pars, parameter_file)) - conn.commit() - - results = find._db_lookup(db, ensemble, corr, code) - assert len(results) == 1 - results = find._db_lookup(db, "SF_B", corr, code) - assert results.empty - results = find._db_lookup(db, ensemble, "g_A", code) - assert results.empty - results = find._db_lookup(db, ensemble, corr, "sfcf") - assert results.empty - results = find._db_lookup(db, ensemble, corr, code, project = "Project_A") - assert len(results) == 1 - results = find._db_lookup(db, ensemble, corr, code, project = "Project_B") - assert results.empty - results = find._db_lookup(db, ensemble, corr, code, parameters = pars) - assert len(results) == 1 - results = find._db_lookup(db, ensemble, corr, code, parameters = '{"par_A": 3.0, "par_B": 4.0}') - assert results.empty - - corr = "g_A" - ensemble = "SF_A" - code = "openQCD" - meas_path = "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf" - uuid = "Project_A" - pars = '{"par_A": 3.0, "par_B": 4.0}' - parameter_file = "projects/Project_A/myinput.in" - c.execute("INSERT INTO backlogs (name, ensemble, code, path, project, parameters, parameter_file, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))", - (corr, ensemble, code, meas_path, uuid, pars, parameter_file)) - conn.commit() - - corr = "f_A" - results = find._db_lookup(db, ensemble, corr, code) - assert len(results) == 1 - results = find._db_lookup(db, "SF_B", corr, code) - assert results.empty - results = find._db_lookup(db, ensemble, "g_A", code) - assert len(results) == 1 - results = find._db_lookup(db, ensemble, corr, "sfcf") - assert results.empty - results = find._db_lookup(db, ensemble, corr, code, project = "Project_A") - assert len(results) == 1 - results = find._db_lookup(db, ensemble, "g_A", code, project = "Project_A") - assert len(results) == 1 - results = find._db_lookup(db, ensemble, corr, code, project = "Project_B") - assert results.empty - results = find._db_lookup(db, ensemble, "g_A", code, project = "Project_B") - assert results.empty - results = find._db_lookup(db, ensemble, corr, code, parameters = pars) - assert results.empty - results = find._db_lookup(db, ensemble, "g_A", code, parameters = '{"par_A": 3.0, "par_B": 4.0}') - assert len(results) == 1 - - conn.close() - - -def test_sfcf_drop() -> None: - parameters0 = { - 'offset': [0,0,0], - 'quarks': [{'mass': 1, 'thetas': [0,0,0]}, {'mass': 2, 'thetas': [0,0,1]}], # m0s = -3.5, -3.75 - 'wf1': [[1, [0, 0]], [0.5, [1, 0]], [.75, [.5, .5]]], - 'wf2': [[1, [2, 1]], [2, [0.5, -0.5]], [.5, [.75, .72]]], - } - - assert not find._sfcf_drop(parameters0, offset=[0,0,0]) - assert find._sfcf_drop(parameters0, offset=[1,0,0]) - - assert not find._sfcf_drop(parameters0, quark_kappas = [1, 2]) - assert find._sfcf_drop(parameters0, quark_kappas = [-3.1, -3.72]) - - assert not find._sfcf_drop(parameters0, quark_masses = [-3.5, -3.75]) - assert find._sfcf_drop(parameters0, quark_masses = [-3.1, -3.72]) - - assert not find._sfcf_drop(parameters0, qk1 = 1) - assert not find._sfcf_drop(parameters0, qk2 = 2) - assert find._sfcf_drop(parameters0, qk1 = 2) - assert find._sfcf_drop(parameters0, qk2 = 1) - - assert not find._sfcf_drop(parameters0, qk1 = [0.5,1.5]) - assert not find._sfcf_drop(parameters0, qk2 = [1.5,2.5]) - assert find._sfcf_drop(parameters0, qk1 = 2) - assert find._sfcf_drop(parameters0, qk2 = 1) - with pytest.raises(ValueError): - assert not find._sfcf_drop(parameters0, qk1 = [0.5,1,5]) - with pytest.raises(ValueError): - assert not find._sfcf_drop(parameters0, qk2 = [1,5,2.5]) - - assert find._sfcf_drop(parameters0, qm1 = 1.2) - assert find._sfcf_drop(parameters0, qm2 = 2.2) - assert not find._sfcf_drop(parameters0, qm1 = -3.5) - assert not find._sfcf_drop(parameters0, qm2 = -3.75) - - assert find._sfcf_drop(parameters0, qm2 = 1.2) - assert find._sfcf_drop(parameters0, qm1 = 2.2) - with pytest.raises(ValueError): - assert not find._sfcf_drop(parameters0, qm1 = [0.5,1,5]) - with pytest.raises(ValueError): - assert not find._sfcf_drop(parameters0, qm2 = [1,5,2.5]) - - -def test_openQCD_filter() -> None: - record_0 = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] - record_1 = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] - record_2 = ["f_P", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] - record_3 = ["f_P", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] - data = [ - record_0, - record_1, - record_2, - record_3, - ] - cols = ["name", - "ensemble", - "code", - "path", - "project", - "parameters", - "parameter_file", - "created_at", - "updated_at"] - df = pd.DataFrame(data,columns=cols) - - with pytest.warns(Warning): - find.openQCD_filter(df, a = "asdf") - - -def test_code_filter() -> None: - record_0 = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] - record_1 = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] - record_2 = ["f_P", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] - record_3 = ["f_P", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] - record_4 = ["f_A", "ensA", "openQCD", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] - record_5 = ["f_A", "ensA", "openQCD", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] - record_6 = ["f_P", "ensA", "openQCD", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] - record_7 = ["f_P", "ensA", "openQCD", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] - record_8 = ["f_P", "ensA", "openQCD", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", - '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] - data = [ - record_0, - record_1, - record_2, - record_3, - ] - cols = ["name", - "ensemble", - "code", - "path", - "project", - "parameters", - "parameter_file", - "created_at", - "updated_at"] - df = pd.DataFrame(data,columns=cols) - - res = find._code_filter(df, "sfcf") - assert len(res) == 4 - - data = [ - record_4, - record_5, - record_6, - record_7, - record_8, - ] - cols = ["name", - "ensemble", - "code", - "path", - "project", - "parameters", - "parameter_file", - "created_at", - "updated_at"] - df = pd.DataFrame(data,columns=cols) - - res = find._code_filter(df, "openQCD") - assert len(res) == 5 - with pytest.raises(ValueError): - res = find._code_filter(df, "asdf") - - -def test_find_record() -> None: - assert True - - -def test_find_project(tmp_path: Path) -> None: - cinit.create(tmp_path) - db = tmp_path / "backlogger.db" - dl.unlock(str(db), dataset=str(tmp_path)) - conn = sqlite3.connect(db) - c = conn.cursor() - uuid = "test_uuid" - alias_str = "fun_project" - tag_str = "tt" - owner = "tester" - code = "test_code" - c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", - (uuid, alias_str, tag_str, owner, code)) - conn.commit() - - assert uuid == find.find_project(tmp_path, "fun_project") - - uuid = "test_uuid2" - alias_str = "fun_project" - c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", - (uuid, alias_str, tag_str, owner, code)) - conn.commit() - - with pytest.raises(Exception): - assert uuid == find._project_lookup_by_alias(tmp_path, "fun_project") - conn.close() - - -def test_list_projects(tmp_path: Path) -> None: - cinit.create(tmp_path) - db = tmp_path / "backlogger.db" - dl.unlock(str(db), dataset=str(tmp_path)) - conn = sqlite3.connect(db) - c = conn.cursor() - uuid = "test_uuid" - alias_str = "fun_project" - tag_str = "tt" - owner = "tester" - code = "test_code" - - c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", - (uuid, alias_str, tag_str, owner, code)) - uuid = "test_uuid2" - alias_str = "fun_project2" - c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", - (uuid, alias_str, tag_str, owner, code)) - uuid = "test_uuid3" - alias_str = "fun_project3" - c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", - (uuid, alias_str, tag_str, owner, code)) - uuid = "test_uuid4" - alias_str = "fun_project4" - c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", - (uuid, alias_str, tag_str, owner, code)) - conn.commit() - conn.close() - results = find.list_projects(tmp_path) - assert len(results) == 4 - for i in range(4): - assert len(results[i]) == 2 diff --git a/tests/sfcf_in_test.py b/tests/sfcf_in_test.py index 7ebc94a..5e4ff83 100644 --- a/tests/sfcf_in_test.py +++ b/tests/sfcf_in_test.py @@ -26,4 +26,4 @@ def test_get_specs() -> None: key = "f_P/q1 q2/1/0/0" specs = json.loads(input.get_specs(key, parameters)) assert specs['quarks'] == ['a', 'b'] - assert specs['wf1'][0] == [1, [0, 0]] + assert specs['wf1'][0] == [1, [0, 0]] \ No newline at end of file diff --git a/tests/initialization_test.py b/tests/test_initialization.py similarity index 94% rename from tests/initialization_test.py rename to tests/test_initialization.py index d78fb15..9284c82 100644 --- a/tests/initialization_test.py +++ b/tests/test_initialization.py @@ -5,21 +5,21 @@ from pathlib import Path def test_init_folders(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" - init.create(dataset_path) + init.create(str(dataset_path)) assert os.path.exists(str(dataset_path)) assert os.path.exists(str(dataset_path / "backlogger.db")) def test_init_folders_no_tracker(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" - init.create(dataset_path, tracker="None") + init.create(str(dataset_path), tracker="None") assert os.path.exists(str(dataset_path)) assert os.path.exists(str(dataset_path / "backlogger.db")) def test_init_config(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" - init.create(dataset_path, tracker="None") + init.create(str(dataset_path), tracker="None") config_path = dataset_path / ".corrlib" assert os.path.exists(str(config_path)) from configparser import ConfigParser @@ -37,7 +37,7 @@ def test_init_config(tmp_path: Path) -> None: def test_init_db(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" - init.create(dataset_path) + init.create(str(dataset_path)) assert os.path.exists(str(dataset_path / "backlogger.db")) conn = sql.connect(str(dataset_path / "backlogger.db")) cursor = conn.cursor() diff --git a/tests/tools_test.py b/tests/tools_test.py index 541674f..88dbffa 100644 --- a/tests/tools_test.py +++ b/tests/tools_test.py @@ -1,7 +1,6 @@ + + from corrlib import tools as tl -from configparser import ConfigParser -from pathlib import Path -import pytest def test_m2k() -> None: @@ -30,55 +29,3 @@ def test_str2list() -> None: def test_list2str() -> None: assert tl.list2str(["a", "b", "c"]) == "a,b,c" assert tl.list2str(["1", "2", "3"]) == "1,2,3" - - -def test_set_config(tmp_path: Path) -> None: - section = "core" - option = "test_option" - value = "test_value" - # config is not yet available - tl.set_config(tmp_path, section, option, value) - config_path = tmp_path / '.corrlib' - config = ConfigParser() - config.read(config_path) - assert config.get('core', 'test_option', fallback="not the value") == "test_value" - # now, a config file is already present - section = "core" - option = "test_option2" - value = "test_value2" - tl.set_config(tmp_path, section, option, value) - config_path = tmp_path / '.corrlib' - config = ConfigParser() - config.read(config_path) - assert config.get('core', 'test_option2', fallback="not the value") == "test_value2" - # update option 2 - section = "core" - option = "test_option2" - value = "test_value3" - tl.set_config(tmp_path, section, option, value) - config_path = tmp_path / '.corrlib' - config = ConfigParser() - config.read(config_path) - assert config.get('core', 'test_option2', fallback="not the value") == "test_value3" - - -def test_get_db_file(tmp_path: Path) -> None: - section = "paths" - option = "db" - value = "test_value" - # config is not yet available - tl.set_config(tmp_path, section, option, value) - assert tl.get_db_file(tmp_path) == Path("test_value") - - -def test_cache_enabled(tmp_path: Path) -> None: - section = "core" - option = "cached" - # config is not yet available - tl.set_config(tmp_path, section, option, "True") - assert tl.cache_enabled(tmp_path) - tl.set_config(tmp_path, section, option, "False") - assert not tl.cache_enabled(tmp_path) - tl.set_config(tmp_path, section, option, "lalala") - with pytest.raises(ValueError): - tl.cache_enabled(tmp_path)