diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 791243f..fbd51ec 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -8,22 +8,21 @@ on: jobs: mypy: runs-on: ubuntu-latest - env: - UV_CACHE_DIR: /tmp/.uv-cache steps: - name: Install git-annex run: | sudo apt-get update - sudo apt-get install -y git-annex + sudo apt-get install -y git-annex - name: Check out the repository uses: https://github.com/RouxAntoine/checkout@v4.1.8 with: show-progress: true - - name: Install uv - uses: astral-sh/setup-uv@v7 + - name: Setup python + uses: https://github.com/actions/setup-python@v5 with: - python-version: ${{ matrix.python-version }} - enable-cache: true + python-version: "3.12" + - name: Install uv + uses: https://github.com/astral-sh/setup-uv@v5 - name: Install corrlib run: uv sync --locked --all-extras --dev --python "3.12" - name: Run tests diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index 1fcb8fe..b1a4d94 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -17,9 +17,11 @@ jobs: - "3.14" runs-on: ubuntu-latest - env: - UV_CACHE_DIR: /tmp/.uv-cache steps: + - name: Setup git + run: | + git config --global user.email "tester@example.com" + git config --global user.name "Tester" - name: Install git-annex run: | sudo apt-get update @@ -28,11 +30,12 @@ jobs: uses: https://github.com/RouxAntoine/checkout@v4.1.8 with: show-progress: true - - name: Install uv - uses: astral-sh/setup-uv@v7 + - name: Setup python + uses: https://github.com/actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - enable-cache: true + - name: Install uv + uses: https://github.com/astral-sh/setup-uv@v5 - name: Install corrlib run: uv sync --locked --all-extras --dev --python ${{ matrix.python-version }} - name: Run tests diff --git a/.github/workflows/ruff.yaml b/.github/workflows/ruff.yaml index 4de4b0b..1da1225 100644 --- a/.github/workflows/ruff.yaml +++ b/.github/workflows/ruff.yaml @@ -9,8 +9,6 @@ jobs: ruff: runs-on: ubuntu-latest - env: - UV_CACHE_DIR: /tmp/.uv-cache steps: - name: Install git-annex run: | @@ -20,10 +18,12 @@ jobs: uses: https://github.com/RouxAntoine/checkout@v4.1.8 with: show-progress: true - - name: Install uv - uses: astral-sh/setup-uv@v7 + - name: Setup python + uses: https://github.com/actions/setup-python@v5 with: - enable-cache: true + python-version: "3.12" + - name: Install uv + uses: https://github.com/astral-sh/setup-uv@v5 - name: Install corrlib run: uv sync --locked --all-extras --dev --python "3.12" - name: Run tests diff --git a/corrlib/cli.py b/corrlib/cli.py index 414fcc4..6c1c3c5 100644 --- a/corrlib/cli.py +++ b/corrlib/cli.py @@ -7,8 +7,11 @@ from .find import find_record, list_projects from .tools import str2list from .main import update_aliases from .meas_io import drop_cache as mio_drop_cache +from .meas_io import load_record as mio_load_record import os +from pyerrors import Corr from importlib.metadata import version +from pathlib import Path app = typer.Typer() @@ -22,8 +25,8 @@ def _version_callback(value: bool) -> None: @app.command() def update( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), @@ -35,10 +38,11 @@ def update( update_project(path, uuid) return + @app.command() -def list( - path: str = typer.Option( - str('./corrlib'), +def lister( + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), @@ -49,8 +53,8 @@ def list( """ if entities in ['ensembles', 'Ensembles','ENSEMBLES']: print("Ensembles:") - for item in os.listdir(path + "/archive"): - if os.path.isdir(os.path.join(path + "/archive", item)): + for item in os.listdir(path / "archive"): + if os.path.isdir(path / "archive" / item): print(item) elif entities == 'projects': results = list_projects(path) @@ -68,8 +72,8 @@ def list( @app.command() def alias_add( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), @@ -86,26 +90,57 @@ def alias_add( @app.command() def find( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), ensemble: str = typer.Argument(), corr: str = typer.Argument(), code: str = typer.Argument(), + arg: str = typer.Option( + str('all'), + "--argument", + "-a", + ), ) -> None: """ Find a record in the backlog at hand. Through specifying it's ensemble and the measured correlator. """ results = find_record(path, ensemble, corr, code) - print(results) + if results.empty: + return + if arg == 'all': + print(results) + else: + for r in results[arg].values: + print(r) + + +@app.command() +def stat( + path: Path = typer.Option( + Path('./corrlib'), + "--dataset", + "-d", + ), + record_id: str = typer.Argument(), + ) -> None: + """ + Show the statistics of a given record. + """ + record = mio_load_record(path, record_id) + if isinstance(record, (list, Corr)): + record = record[0] + statistics = record.idl + print(statistics) + return @app.command() def importer( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), @@ -127,8 +162,8 @@ def importer( @app.command() def reimporter( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), @@ -151,8 +186,8 @@ def reimporter( @app.command() def init( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), @@ -171,8 +206,8 @@ def init( @app.command() def drop_cache( - path: str = typer.Option( - str('./corrlib'), + path: Path = typer.Option( + Path('./corrlib'), "--dataset", "-d", ), diff --git a/corrlib/find.py b/corrlib/find.py index 21063ec..7b07321 100644 --- a/corrlib/find.py +++ b/corrlib/find.py @@ -6,10 +6,15 @@ import numpy as np from .input.implementations import codes from .tools import k2m, get_db_file from .tracker import get +from .integrity import has_valid_times from typing import Any, Optional +from pathlib import Path +import datetime as dt +from collections.abc import Callable +import warnings -def _project_lookup_by_alias(db: str, alias: str) -> str: +def _project_lookup_by_alias(db: Path, alias: str) -> str: """ Lookup a projects UUID by its (human-readable) alias. @@ -27,7 +32,7 @@ def _project_lookup_by_alias(db: str, alias: str) -> str: """ conn = sqlite3.connect(db) c = conn.cursor() - c.execute(f"SELECT * FROM 'projects' WHERE alias = '{alias}'") + c.execute(f"SELECT * FROM 'projects' WHERE aliases = '{alias}'") results = c.fetchall() conn.close() if len(results)>1: @@ -37,7 +42,7 @@ def _project_lookup_by_alias(db: str, alias: str) -> str: return str(results[0][0]) -def _project_lookup_by_id(db: str, uuid: str) -> list[tuple[str, str]]: +def _project_lookup_by_id(db: Path, uuid: str) -> list[tuple[str, ...]]: """ Return the project information available in the database by UUID. @@ -61,8 +66,56 @@ def _project_lookup_by_id(db: str, uuid: str) -> list[tuple[str, str]]: return results -def _db_lookup(db: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None, - created_before: Optional[str]=None, created_after: Optional[Any]=None, updated_before: Optional[Any]=None, updated_after: Optional[Any]=None) -> pd.DataFrame: +def _time_filter(results: pd.DataFrame, created_before: Optional[str]=None, created_after: Optional[Any]=None, updated_before: Optional[Any]=None, updated_after: Optional[Any]=None) -> pd.DataFrame: + """ + Filter the results from the database in terms of the creation and update times. + + Parameters + ---------- + results: pd.DataFrame + The dataframe holding the unfilteres results from the database. + created_before: str + Contraint on the creation date in datetime.datetime.isoformat. Note that this is exclusive. The creation date has to be truly before the date and time given. + created_after: str + Contraint on the creation date in datetime.datetime.isoformat. Note that this is exclusive. The creation date has to be truly after the date and time given. + updated_before: str + Contraint on the creation date in datetime.datetime.isoformat. Note that this is exclusive. The date of the last update has to be truly before the date and time given. + updated_after: str + Contraint on the creation date in datetime.datetime.isoformat. Note that this is exclusive. The date of the last update has to be truly after the date and time given. + """ + drops = [] + for ind in range(len(results)): + result = results.iloc[ind] + created_at = dt.datetime.fromisoformat(result['created_at']) + updated_at = dt.datetime.fromisoformat(result['updated_at']) + db_times_valid = has_valid_times(result) + if not db_times_valid: + raise ValueError('Time stamps not valid for result with path', result["path"]) + + if created_before is not None: + date_created_before = dt.datetime.fromisoformat(created_before) + if date_created_before < created_at: + drops.append(ind) + continue + if created_after is not None: + date_created_after = dt.datetime.fromisoformat(created_after) + if date_created_after > created_at: + drops.append(ind) + continue + if updated_before is not None: + date_updated_before = dt.datetime.fromisoformat(updated_before) + if date_updated_before < updated_at: + drops.append(ind) + continue + if updated_after is not None: + date_updated_after = dt.datetime.fromisoformat(updated_after) + if date_updated_after > updated_at: + drops.append(ind) + continue + return results.drop(drops) + + +def _db_lookup(db: Path, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None) -> pd.DataFrame: """ Look up a correlator record in the database by the data given to the method. @@ -104,22 +157,86 @@ def _db_lookup(db: str, ensemble: str, correlator_name: str, code: str, project: search_expr += f" AND code = '{code}'" if parameters: search_expr += f" AND parameters = '{parameters}'" - if created_before: - search_expr += f" AND created_at < '{created_before}'" - if created_after: - search_expr += f" AND created_at > '{created_after}'" - if updated_before: - search_expr += f" AND updated_at < '{updated_before}'" - if updated_after: - search_expr += f" AND updated_at > '{updated_after}'" conn = sqlite3.connect(db) results = pd.read_sql(search_expr, conn) conn.close() return results +def _sfcf_drop(param: dict[str, Any], **kwargs: Any) -> bool: + if 'offset' in kwargs: + if kwargs.get('offset') != param['offset']: + return True + if 'quark_kappas' in kwargs: + kappas = kwargs['quark_kappas'] + if (not np.isclose(kappas[0], param['quarks'][0]['mass']) or not np.isclose(kappas[1], param['quarks'][1]['mass'])): + return True + if 'quark_masses' in kwargs: + masses = kwargs['quark_masses'] + if (not np.isclose(masses[0], k2m(param['quarks'][0]['mass'])) or not np.isclose(masses[1], k2m(param['quarks'][1]['mass']))): + return True + if 'qk1' in kwargs: + quark_kappa1 = kwargs['qk1'] + if not isinstance(quark_kappa1, list): + if (not np.isclose(quark_kappa1, param['quarks'][0]['mass'])): + return True + else: + if len(quark_kappa1) == 2: + if (quark_kappa1[0] > param['quarks'][0]['mass']) or (quark_kappa1[1] < param['quarks'][0]['mass']): + return True + else: + raise ValueError("quark_kappa1 has to have length 2") + if 'qk2' in kwargs: + quark_kappa2 = kwargs['qk2'] + if not isinstance(quark_kappa2, list): + if (not np.isclose(quark_kappa2, param['quarks'][1]['mass'])): + return True + else: + if len(quark_kappa2) == 2: + if (quark_kappa2[0] > param['quarks'][1]['mass']) or (quark_kappa2[1] < param['quarks'][1]['mass']): + return True + else: + raise ValueError("quark_kappa2 has to have length 2") + if 'qm1' in kwargs: + quark_mass1 = kwargs['qm1'] + if not isinstance(quark_mass1, list): + if (not np.isclose(quark_mass1, k2m(param['quarks'][0]['mass']))): + return True + else: + if len(quark_mass1) == 2: + if (quark_mass1[0] > k2m(param['quarks'][0]['mass'])) or (quark_mass1[1] < k2m(param['quarks'][0]['mass'])): + return True + else: + raise ValueError("quark_mass1 has to have length 2") + if 'qm2' in kwargs: + quark_mass2 = kwargs['qm2'] + if not isinstance(quark_mass2, list): + if (not np.isclose(quark_mass2, k2m(param['quarks'][1]['mass']))): + return True + else: + if len(quark_mass2) == 2: + if (quark_mass2[0] > k2m(param['quarks'][1]['mass'])) or (quark_mass2[1] < k2m(param['quarks'][1]['mass'])): + return True + else: + raise ValueError("quark_mass2 has to have length 2") + if 'quark_thetas' in kwargs: + quark_thetas = kwargs['quark_thetas'] + if (quark_thetas[0] != param['quarks'][0]['thetas'] and quark_thetas[1] != param['quarks'][1]['thetas']) or (quark_thetas[0] != param['quarks'][1]['thetas'] and quark_thetas[1] != param['quarks'][0]['thetas']): + return True + # careful, this is not save, when multiple contributions are present! + if 'wf1' in kwargs: + wf1 = kwargs['wf1'] + if not (np.isclose(wf1[0][0], param['wf1'][0][0], 1e-8) and np.isclose(wf1[0][1][0], param['wf1'][0][1][0], 1e-8) and np.isclose(wf1[0][1][1], param['wf1'][0][1][1], 1e-8)): + return True + if 'wf2' in kwargs: + wf2 = kwargs['wf2'] + if not (np.isclose(wf2[0][0], param['wf2'][0][0], 1e-8) and np.isclose(wf2[0][1][0], param['wf2'][0][1][0], 1e-8) and np.isclose(wf2[0][1][1], param['wf2'][0][1][1], 1e-8)): + return True + return False + + def sfcf_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: - """ + r""" Filter method for the Database entries holding SFCF calculations. Parameters @@ -135,9 +252,9 @@ def sfcf_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: qk2: float, optional Mass parameter $\kappa_2$ of the first quark. qm1: float, optional - Bare quak mass $m_1$ of the first quark. + Bare quark mass $m_1$ of the first quark. qm2: float, optional - Bare quak mass $m_1$ of the first quark. + Bare quark mass $m_2$ of the first quark. quarks_thetas: list[list[float]], optional wf1: optional wf2: optional @@ -147,106 +264,85 @@ def sfcf_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: results: pd.DataFrame The filtered DataFrame, only holding the records that fit to the parameters given. """ + drops = [] for ind in range(len(results)): result = results.iloc[ind] param = json.loads(result['parameters']) - if 'offset' in kwargs: - if kwargs.get('offset') != param['offset']: - drops.append(ind) - continue - if 'quark_kappas' in kwargs: - kappas = kwargs['quark_kappas'] - if (not np.isclose(kappas[0], param['quarks'][0]['mass']) or not np.isclose(kappas[1], param['quarks'][1]['mass'])): - drops.append(ind) - continue - if 'quark_masses' in kwargs: - masses = kwargs['quark_masses'] - if (not np.isclose(masses[0], k2m(param['quarks'][0]['mass'])) or not np.isclose(masses[1], k2m(param['quarks'][1]['mass']))): - drops.append(ind) - continue - if 'qk1' in kwargs: - quark_kappa1 = kwargs['qk1'] - if not isinstance(quark_kappa1, list): - if (not np.isclose(quark_kappa1, param['quarks'][0]['mass'])): - drops.append(ind) - continue - else: - if len(quark_kappa1) == 2: - if (quark_kappa1[0] > param['quarks'][0]['mass']) or (quark_kappa1[1] < param['quarks'][0]['mass']): - drops.append(ind) - continue - if 'qk2' in kwargs: - quark_kappa2 = kwargs['qk2'] - if not isinstance(quark_kappa2, list): - if (not np.isclose(quark_kappa2, param['quarks'][1]['mass'])): - drops.append(ind) - continue - else: - if len(quark_kappa2) == 2: - if (quark_kappa2[0] > param['quarks'][1]['mass']) or (quark_kappa2[1] < param['quarks'][1]['mass']): - drops.append(ind) - continue - if 'qm1' in kwargs: - quark_mass1 = kwargs['qm1'] - if not isinstance(quark_mass1, list): - if (not np.isclose(quark_mass1, k2m(param['quarks'][0]['mass']))): - drops.append(ind) - continue - else: - if len(quark_mass1) == 2: - if (quark_mass1[0] > k2m(param['quarks'][0]['mass'])) or (quark_mass1[1] < k2m(param['quarks'][0]['mass'])): - drops.append(ind) - continue - if 'qm2' in kwargs: - quark_mass2 = kwargs['qm2'] - if not isinstance(quark_mass2, list): - if (not np.isclose(quark_mass2, k2m(param['quarks'][1]['mass']))): - drops.append(ind) - continue - else: - if len(quark_mass2) == 2: - if (quark_mass2[0] > k2m(param['quarks'][1]['mass'])) or (quark_mass2[1] < k2m(param['quarks'][1]['mass'])): - drops.append(ind) - continue - if 'quark_thetas' in kwargs: - quark_thetas = kwargs['quark_thetas'] - if (quark_thetas[0] != param['quarks'][0]['thetas'] and quark_thetas[1] != param['quarks'][1]['thetas']) or (quark_thetas[0] != param['quarks'][1]['thetas'] and quark_thetas[1] != param['quarks'][0]['thetas']): - drops.append(ind) - continue - # careful, this is not save, when multiple contributions are present! - if 'wf1' in kwargs: - wf1 = kwargs['wf1'] - if not (np.isclose(wf1[0][0], param['wf1'][0][0], 1e-8) and np.isclose(wf1[0][1][0], param['wf1'][0][1][0], 1e-8) and np.isclose(wf1[0][1][1], param['wf1'][0][1][1], 1e-8)): - drops.append(ind) - continue - if 'wf2' in kwargs: - wf2 = kwargs['wf2'] - if not (np.isclose(wf2[0][0], param['wf2'][0][0], 1e-8) and np.isclose(wf2[0][1][0], param['wf2'][0][1][0], 1e-8) and np.isclose(wf2[0][1][1], param['wf2'][0][1][1], 1e-8)): - drops.append(ind) - continue + if _sfcf_drop(param, **kwargs): + drops.append(ind) return results.drop(drops) -def find_record(path: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None, - created_before: Optional[str]=None, created_after: Optional[str]=None, updated_before: Optional[str]=None, updated_after: Optional[str]=None, revision: Optional[str]=None, **kwargs: Any) -> pd.DataFrame: +def openQCD_filter(results:pd.DataFrame, **kwargs: Any) -> pd.DataFrame: + """ + Filter for parameters of openQCD. + + Parameters + ---------- + results: pd.DataFrame + The unfiltered list of results from the database. + + Returns + ------- + results: pd.DataFrame + The filtered results. + + """ + warnings.warn("A filter for openQCD parameters is no implemented yet.", Warning) + + return results + + +def _code_filter(results: pd.DataFrame, code: str, **kwargs: Any) -> pd.DataFrame: + """ + Abstraction of the filters for the different codes that are available. + At the moment, only openQCD and SFCF are known. + The possible key words for the parameters can be seen in the descriptionso f the code-specific filters. + + Parameters + ---------- + results: pd.DataFrame + The unfiltered list of results from the database. + code: str + The name of the code that produced the record at hand. + kwargs: + The keyworkd args that are handed over to the code-specific filters. + + Returns + ------- + results: pd.DataFrame + The filtered results. + """ + if code == "sfcf": + return sfcf_filter(results, **kwargs) + elif code == "openQCD": + return openQCD_filter(results, **kwargs) + else: + raise ValueError(f"Code {code} is not known.") + + +def find_record(path: Path, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None, + created_before: Optional[str]=None, created_after: Optional[str]=None, updated_before: Optional[str]=None, updated_after: Optional[str]=None, + revision: Optional[str]=None, + customFilter: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None, + **kwargs: Any) -> pd.DataFrame: db_file = get_db_file(path) - db = os.path.join(path, db_file) + db = path / db_file if code not in codes: raise ValueError("Code " + code + "unknown, take one of the following:" + ", ".join(codes)) get(path, db_file) - results = _db_lookup(db, ensemble, correlator_name,code, project, parameters=parameters, created_before=created_before, created_after=created_after, updated_before=updated_before, updated_after=updated_after) - if code == "sfcf": - results = sfcf_filter(results, **kwargs) - elif code == "openQCD": - pass - else: - raise Exception + results = _db_lookup(db, ensemble, correlator_name,code, project, parameters=parameters) + if any([arg is not None for arg in [created_before, created_after, updated_before, updated_after]]): + results = _time_filter(results, created_before, created_after, updated_before, updated_after) + results = _code_filter(results, code, **kwargs) + if customFilter is not None: + results = customFilter(results) print("Found " + str(len(results)) + " result" + ("s" if len(results)>1 else "")) return results.reset_index() -def find_project(path: str, name: str) -> str: +def find_project(path: Path, name: str) -> str: """ Find a project by it's human readable name. @@ -264,10 +360,10 @@ def find_project(path: str, name: str) -> str: """ db_file = get_db_file(path) get(path, db_file) - return _project_lookup_by_alias(os.path.join(path, db_file), name) + return _project_lookup_by_alias(path / db_file, name) -def list_projects(path: str) -> list[tuple[str, str]]: +def list_projects(path: Path) -> list[tuple[str, str]]: """ List all projects known to the library. diff --git a/corrlib/git_tools.py b/corrlib/git_tools.py index c6e7522..d77f109 100644 --- a/corrlib/git_tools.py +++ b/corrlib/git_tools.py @@ -1,27 +1,28 @@ import os from .tracker import save import git +from pathlib import Path GITMODULES_FILE = '.gitmodules' -def move_submodule(repo_path: str, old_path: str, new_path: str) -> None: +def move_submodule(repo_path: Path, old_path: Path, new_path: Path) -> None: """ Move a submodule to a new location. Parameters ---------- - repo_path: str + repo_path: Path Path to the repository. - old_path: str + old_path: Path The old path of the module. - new_path: str + new_path: Path The new path of the module. """ - os.rename(os.path.join(repo_path, old_path), os.path.join(repo_path, new_path)) + os.rename(repo_path / old_path, repo_path / new_path) - gitmodules_file_path = os.path.join(repo_path, GITMODULES_FILE) + gitmodules_file_path = repo_path / GITMODULES_FILE # update paths in .gitmodules with open(gitmodules_file_path, 'r') as file: @@ -29,8 +30,8 @@ def move_submodule(repo_path: str, old_path: str, new_path: str) -> None: updated_lines = [] for line in lines: - if old_path in line: - line = line.replace(old_path, new_path) + if str(old_path) in line: + line = line.replace(str(old_path), str(new_path)) updated_lines.append(line) with open(gitmodules_file_path, 'w') as file: @@ -40,6 +41,6 @@ def move_submodule(repo_path: str, old_path: str, new_path: str) -> None: repo = git.Repo(repo_path) repo.git.add('.gitmodules') # save new state of the dataset - save(repo_path, message=f"Move module from {old_path} to {new_path}", files=['.gitmodules', repo_path]) + save(repo_path, message=f"Move module from {old_path} to {new_path}", files=[Path('.gitmodules'), repo_path]) return diff --git a/corrlib/initialization.py b/corrlib/initialization.py index bb71db6..c06a201 100644 --- a/corrlib/initialization.py +++ b/corrlib/initialization.py @@ -2,9 +2,10 @@ from configparser import ConfigParser import sqlite3 import os from .tracker import save, init +from pathlib import Path -def _create_db(db: str) -> None: +def _create_db(db: Path) -> None: """ Create the database file and the table. @@ -40,7 +41,7 @@ def _create_db(db: str) -> None: return -def _create_config(path: str, tracker: str, cached: bool) -> ConfigParser: +def _create_config(path: Path, tracker: str, cached: bool) -> ConfigParser: """ Create the config file construction for backlogger. @@ -75,7 +76,7 @@ def _create_config(path: str, tracker: str, cached: bool) -> ConfigParser: return config -def _write_config(path: str, config: ConfigParser) -> None: +def _write_config(path: Path, config: ConfigParser) -> None: """ Write the config file to disk. @@ -91,7 +92,7 @@ def _write_config(path: str, config: ConfigParser) -> None: return -def create(path: str, tracker: str = 'datalad', cached: bool = True) -> None: +def create(path: Path, tracker: str = 'datalad', cached: bool = True) -> None: """ Create folder of backlogs. @@ -107,13 +108,13 @@ def create(path: str, tracker: str = 'datalad', cached: bool = True) -> None: config = _create_config(path, tracker, cached) init(path, tracker) _write_config(path, config) - _create_db(os.path.join(path, config['paths']['db'])) - os.chmod(os.path.join(path, config['paths']['db']), 0o666) - os.makedirs(os.path.join(path, config['paths']['projects_path'])) - os.makedirs(os.path.join(path, config['paths']['archive_path'])) - os.makedirs(os.path.join(path, config['paths']['toml_imports_path'])) - os.makedirs(os.path.join(path, config['paths']['import_scripts_path'], 'template.py')) - with open(os.path.join(path, ".gitignore"), "w") as fp: + _create_db(path / config['paths']['db']) + os.chmod(path / config['paths']['db'], 0o666) + os.makedirs(path / config['paths']['projects_path']) + os.makedirs(path / config['paths']['archive_path']) + os.makedirs(path / config['paths']['toml_imports_path']) + os.makedirs(path / config['paths']['import_scripts_path'] / 'template.py') + with open(path / ".gitignore", "w") as fp: fp.write(".cache") fp.close() save(path, message="Initialized correlator library") diff --git a/corrlib/input/openQCD.py b/corrlib/input/openQCD.py index 71ebec6..879b555 100644 --- a/corrlib/input/openQCD.py +++ b/corrlib/input/openQCD.py @@ -3,9 +3,13 @@ import datalad.api as dl import os import fnmatch from typing import Any, Optional +from pathlib import Path +from ..pars.openQCD import ms1 +from ..pars.openQCD import qcd2 -def read_ms1_param(path: str, project: str, file_in_project: str) -> dict[str, Any]: + +def load_ms1_infile(path: Path, project: str, file_in_project: str) -> dict[str, Any]: """ Read the parameters for ms1 measurements from a parameter file in the project. @@ -69,7 +73,7 @@ def read_ms1_param(path: str, project: str, file_in_project: str) -> dict[str, A return param -def read_ms3_param(path: str, project: str, file_in_project: str) -> dict[str, Any]: +def load_ms3_infile(path: Path, project: str, file_in_project: str) -> dict[str, Any]: """ Read the parameters for ms3 measurements from a parameter file in the project. @@ -103,7 +107,7 @@ def read_ms3_param(path: str, project: str, file_in_project: str) -> dict[str, A return param -def read_rwms(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, postfix: str="ms1", version: str='2.0', names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: +def read_rwms(path: Path, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, postfix: str="ms1", version: str='2.0', names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: """ Read reweighting factor measurements from the project. @@ -160,7 +164,7 @@ def read_rwms(path: str, project: str, dir_in_project: str, param: dict[str, Any return rw_dict -def extract_t0(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str="", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: +def extract_t0(path: Path, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str="", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: """ Extract t0 measurements from the project. @@ -234,7 +238,7 @@ def extract_t0(path: str, project: str, dir_in_project: str, param: dict[str, An return t0_dict -def extract_t1(path: str, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str = "", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: +def extract_t1(path: Path, project: str, dir_in_project: str, param: dict[str, Any], prefix: str, dtr_read: int, xmin: int, spatial_extent: int, fit_range: int = 5, postfix: str = "", names: Optional[list[str]]=None, files: Optional[list[str]]=None) -> dict[str, Any]: """ Extract t1 measurements from the project. @@ -303,3 +307,51 @@ def extract_t1(path: str, project: str, dir_in_project: str, param: dict[str, An t1_dict[param["type"]] = {} t1_dict[param["type"]][pars] = t0 return t1_dict + + +def load_qcd2_pars(path: Path, project: str, file_in_project: str) -> dict[str, Any]: + """ + Thin wrapper around read_qcd2_par_file, getting the file before reading. + + Parameters + ---------- + path: Path + Path of the corrlib repository. + project: str + UUID of the project of the parameter-file. + file_in_project: str + The loaction of the file in the project directory. + + Returns + ------- + par_dict: dict + The dict with the parameters read from the .par-file. + """ + fname = path / "projects" / project / file_in_project + ds = os.path.join(path, "projects", project) + dl.get(fname, dataset=ds) + return qcd2.read_qcd2_par_file(fname) + + +def load_ms1_parfile(path: Path, project: str, file_in_project: str) -> dict[str, Any]: + """ + Thin wrapper around read_qcd2_ms1_par_file, getting the file before reading. + + Parameters + ---------- + path: Path + Path of the corrlib repository. + project: str + UUID of the project of the parameter-file. + file_in_project: str + The loaction of the file in the project directory. + + Returns + ------- + par_dict: dict + The dict with the parameters read from the .par-file. + """ + fname = path / "projects" / project / file_in_project + ds = os.path.join(path, "projects", project) + dl.get(fname, dataset=ds) + return ms1.read_qcd2_ms1_par_file(fname) diff --git a/corrlib/input/sfcf.py b/corrlib/input/sfcf.py index 6a75b72..acd8261 100644 --- a/corrlib/input/sfcf.py +++ b/corrlib/input/sfcf.py @@ -3,6 +3,8 @@ import datalad.api as dl import json import os from typing import Any +from fnmatch import fnmatch +from pathlib import Path bi_corrs: list[str] = ["f_P", "fP", "f_p", @@ -79,7 +81,7 @@ for c in bib_corrs: corr_types[c] = 'bib' -def read_param(path: str, project: str, file_in_project: str) -> dict[str, Any]: +def read_param(path: Path, project: str, file_in_project: str) -> dict[str, Any]: """ Read the parameters from the sfcf file. @@ -95,7 +97,7 @@ def read_param(path: str, project: str, file_in_project: str) -> dict[str, Any]: """ - file = path + "/projects/" + project + '/' + file_in_project + file = path / "projects" / project / file_in_project dl.get(file, dataset=path) with open(file, 'r') as f: lines = f.readlines() @@ -256,7 +258,7 @@ def get_specs(key: str, parameters: dict[str, Any], sep: str = '/') -> str: return s -def read_data(path: str, project: str, dir_in_project: str, prefix: str, param: dict[str, Any], version: str = '1.0c', cfg_seperator: str = 'n', sep: str = '/', **kwargs: Any) -> dict[str, Any]: +def read_data(path: Path, project: str, dir_in_project: str, prefix: str, param: dict[str, Any], version: str = '1.0c', cfg_seperator: str = 'n', sep: str = '/', **kwargs: Any) -> dict[str, Any]: """ Extract the data from the sfcf file. @@ -298,9 +300,10 @@ def read_data(path: str, project: str, dir_in_project: str, prefix: str, param: if not appended: compact = (version[-1] == "c") for i, item in enumerate(ls): - rep_path = directory + '/' + item - sub_ls = pe.input.sfcf._find_files(rep_path, prefix, compact, []) - files_to_get.extend([rep_path + "/" + filename for filename in sub_ls]) + if fnmatch(item, prefix + "*"): + rep_path = directory + '/' + item + sub_ls = pe.input.sfcf._find_files(rep_path, prefix, compact, []) + files_to_get.extend([rep_path + "/" + filename for filename in sub_ls]) print("Getting data, this might take a while...") @@ -318,10 +321,10 @@ def read_data(path: str, project: str, dir_in_project: str, prefix: str, param: if not param['crr'] == []: if names is not None: data_crr = pe.input.sfcf.read_sfcf_multi(directory, prefix, param['crr'], param['mrr'], corr_type_list, range(len(param['wf_offsets'])), - range(len(param['wf_basis'])), range(len(param['wf_basis'])), version, cfg_seperator, keyed_out=True, names=names) + range(len(param['wf_basis'])), range(len(param['wf_basis'])), version, cfg_seperator, keyed_out=True, silent=True, names=names) else: data_crr = pe.input.sfcf.read_sfcf_multi(directory, prefix, param['crr'], param['mrr'], corr_type_list, range(len(param['wf_offsets'])), - range(len(param['wf_basis'])), range(len(param['wf_basis'])), version, cfg_seperator, keyed_out=True) + range(len(param['wf_basis'])), range(len(param['wf_basis'])), version, cfg_seperator, keyed_out=True, silent=True) for key in data_crr.keys(): data[key] = data_crr[key] diff --git a/corrlib/integrity.py b/corrlib/integrity.py new file mode 100644 index 0000000..d865944 --- /dev/null +++ b/corrlib/integrity.py @@ -0,0 +1,45 @@ +import datetime as dt +from pathlib import Path +from .tools import get_db_file +import pandas as pd +import sqlite3 + + +def has_valid_times(result: pd.Series) -> bool: + # we expect created_at <= updated_at <= now + created_at = dt.datetime.fromisoformat(result['created_at']) + updated_at = dt.datetime.fromisoformat(result['updated_at']) + if created_at > updated_at: + return False + if updated_at > dt.datetime.now(): + return False + return True + +def are_keys_unique(db: Path, table: str, col: str) -> bool: + conn = sqlite3.connect(db) + c = conn.cursor() + c.execute(f"SELECT COUNT( DISTINCT CAST(path AS nvarchar(4000))), COUNT({col}) FROM {table};") + results = c.fetchall()[0] + conn.close() + return bool(results[0] == results[1]) + + +def check_db_integrity(path: Path) -> None: + db = get_db_file(path) + + if not are_keys_unique(db, 'backlogs', 'path'): + raise Exception("The paths the backlog table of the database links are not unique.") + + search_expr = "SELECT * FROM 'backlogs'" + conn = sqlite3.connect(db) + results = pd.read_sql(search_expr, conn) + + for _, result in results.iterrows(): + if not has_valid_times(result): + raise ValueError(f"Result with id {result[id]} has wrong time signatures.") + + + +def full_integrity_check(path: Path) -> None: + check_db_integrity(path) + diff --git a/corrlib/main.py b/corrlib/main.py index 88b99b3..831b69d 100644 --- a/corrlib/main.py +++ b/corrlib/main.py @@ -8,9 +8,10 @@ from .find import _project_lookup_by_id from .tools import list2str, str2list, get_db_file from .tracker import get, save, unlock, clone, drop from typing import Union, Optional +from pathlib import Path -def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Union[list[str], None]=None, aliases: Union[list[str], None]=None, code: Union[str, None]=None) -> None: +def create_project(path: Path, uuid: str, owner: Union[str, None]=None, tags: Union[list[str], None]=None, aliases: Union[list[str], None]=None, code: Union[str, None]=None) -> None: """ Create a new project entry in the database. @@ -48,7 +49,7 @@ def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Uni return -def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None] = None) -> None: +def update_project_data(path: Path, uuid: str, prop: str, value: Union[str, None] = None) -> None: """ Update/Edit a project entry in the database. Thin wrapper around sql3 call. @@ -74,9 +75,9 @@ def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None] return -def update_aliases(path: str, uuid: str, aliases: list[str]) -> None: +def update_aliases(path: Path, uuid: str, aliases: list[str]) -> None: db_file = get_db_file(path) - db = os.path.join(path, db_file) + db = path / db_file get(path, db_file) known_data = _project_lookup_by_id(db, uuid)[0] known_aliases = known_data[1] @@ -102,7 +103,7 @@ def update_aliases(path: str, uuid: str, aliases: list[str]) -> None: return -def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Optional[list[str]]=None, aliases: Optional[list[str]]=None, code: Optional[str]=None, isDataset: bool=True) -> str: +def import_project(path: Path, url: str, owner: Union[str, None]=None, tags: Optional[list[str]]=None, aliases: Optional[list[str]]=None, code: Optional[str]=None, isDataset: bool=True) -> str: """ Import a datalad dataset into the backlogger. @@ -134,14 +135,14 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Opti uuid = str(conf.get("datalad.dataset.id")) if not uuid: raise ValueError("The dataset does not have a uuid!") - if not os.path.exists(path + "/projects/" + uuid): + if not os.path.exists(path / "projects" / uuid): db_file = get_db_file(path) get(path, db_file) unlock(path, db_file) create_project(path, uuid, owner, tags, aliases, code) - move_submodule(path, 'projects/tmp', 'projects/' + uuid) - os.mkdir(path + '/import_scripts/' + uuid) - save(path, message="Import project from " + url, files=['projects/' + uuid, db_file]) + move_submodule(path, Path('projects/tmp'), Path('projects') / uuid) + os.mkdir(path / 'import_scripts' / uuid) + save(path, message="Import project from " + url, files=[Path(f'projects/{uuid}'), db_file]) else: dl.drop(tmp_path, reckless='kill') shutil.rmtree(tmp_path) @@ -156,7 +157,7 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Opti return uuid -def drop_project_data(path: str, uuid: str, path_in_project: str = "") -> None: +def drop_project_data(path: Path, uuid: str, path_in_project: str = "") -> None: """ Drop (parts of) a project to free up diskspace @@ -169,6 +170,5 @@ def drop_project_data(path: str, uuid: str, path_in_project: str = "") -> None: path_pn_project: str, optional If set, only the given path within the project is dropped. """ - drop(path + "/projects/" + uuid + "/" + path_in_project) + drop(path / "projects" / uuid / path_in_project) return - diff --git a/corrlib/meas_io.py b/corrlib/meas_io.py index 65a0569..f4e8a83 100644 --- a/corrlib/meas_io.py +++ b/corrlib/meas_io.py @@ -10,9 +10,13 @@ from .tools import get_db_file, cache_enabled from .tracker import get, save, unlock import shutil from typing import Any +from pathlib import Path -def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, dict[str, Any]]], uuid: str, code: str, parameter_file: str) -> None: +CACHE_DIR = ".cache" + + +def write_measurement(path: Path, ensemble: str, measurement: dict[str, dict[str, dict[str, Any]]], uuid: str, code: str, parameter_file: Union[str, None]) -> None: """ Write a measurement to the backlog. If the file for the measurement already exists, update the measurement. @@ -33,25 +37,34 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, The parameter file used for the measurement. """ db_file = get_db_file(path) - db = os.path.join(path, db_file) + db = path / db_file + + files_to_save = [] + get(path, db_file) unlock(path, db_file) + files_to_save.append(db_file) + conn = sqlite3.connect(db) c = conn.cursor() - files = [] for corr in measurement.keys(): - file_in_archive = os.path.join('.', 'archive', ensemble, corr, uuid + '.json.gz') - file = os.path.join(path, file_in_archive) - files.append(file) + file_in_archive = Path('.') / 'archive' / ensemble / corr / str(uuid + '.json.gz') + file = path / file_in_archive known_meas = {} - if not os.path.exists(os.path.join(path, '.', 'archive', ensemble, corr)): - os.makedirs(os.path.join(path, '.', 'archive', ensemble, corr)) + if not os.path.exists(path / 'archive' / ensemble / corr): + os.makedirs(path / 'archive' / ensemble / corr) + files_to_save.append(file_in_archive) else: if os.path.exists(file): - unlock(path, file_in_archive) - known_meas = pj.load_json_dict(file) + if file not in files_to_save: + unlock(path, file_in_archive) + files_to_save.append(file_in_archive) + known_meas = pj.load_json_dict(str(file), verbose=False) if code == "sfcf": - parameters = sfcf.read_param(path, uuid, parameter_file) + if parameter_file is not None: + parameters = sfcf.read_param(path, uuid, parameter_file) + else: + raise Exception("Need parameter file for this code!") pars = {} subkeys = list(measurement[corr].keys()) for subkey in subkeys: @@ -60,7 +73,25 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, elif code == "openQCD": ms_type = list(measurement.keys())[0] if ms_type == 'ms1': - parameters = openQCD.read_ms1_param(path, uuid, parameter_file) + if parameter_file is not None: + if parameter_file.endswith(".ms1.in"): + parameters = openQCD.load_ms1_infile(path, uuid, parameter_file) + elif parameter_file.endswith(".ms1.par"): + parameters = openQCD.load_ms1_parfile(path, uuid, parameter_file) + else: + # Temporary solution + parameters = {} + parameters["rand"] = {} + parameters["rw_fcts"] = [{}] + for nrw in range(1): + if "nsrc" not in parameters["rw_fcts"][nrw]: + parameters["rw_fcts"][nrw]["nsrc"] = 1 + if "mu" not in parameters["rw_fcts"][nrw]: + parameters["rw_fcts"][nrw]["mu"] = "None" + if "np" not in parameters["rw_fcts"][nrw]: + parameters["rw_fcts"][nrw]["np"] = "None" + if "irp" not in parameters["rw_fcts"][nrw]: + parameters["rw_fcts"][nrw]["irp"] = "None" pars = {} subkeys = [] for i in range(len(parameters["rw_fcts"])): @@ -72,7 +103,7 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, pars[subkey] = json.dumps(parameters["rw_fcts"][i]) elif ms_type in ['t0', 't1']: if parameter_file is not None: - parameters = openQCD.read_ms3_param(path, uuid, parameter_file) + parameters = openQCD.load_ms3_infile(path, uuid, parameter_file) else: parameters = {} for rwp in ["integrator", "eps", "ntot", "dnms"]: @@ -87,7 +118,7 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, pars[subkey] = json.dumps(parameters) for subkey in subkeys: parHash = sha256(str(pars[subkey]).encode('UTF-8')).hexdigest() - meas_path = file_in_archive + "::" + parHash + meas_path = str(file_in_archive) + "::" + parHash known_meas[parHash] = measurement[corr][subkey] @@ -97,14 +128,13 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, c.execute("INSERT INTO backlogs (name, ensemble, code, path, project, parameters, parameter_file, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))", (corr, ensemble, code, meas_path, uuid, pars[subkey], parameter_file)) conn.commit() - pj.dump_dict_to_json(known_meas, file) - files.append(os.path.join(path, db_file)) + pj.dump_dict_to_json(known_meas, str(file)) conn.close() - save(path, message="Add measurements to database", files=files) + save(path, message="Add measurements to database", files=files_to_save) return -def load_record(path: str, meas_path: str) -> Union[Corr, Obs]: +def load_record(path: Path, meas_path: str) -> Union[Corr, Obs]: """ Load a list of records by their paths. @@ -123,7 +153,7 @@ def load_record(path: str, meas_path: str) -> Union[Corr, Obs]: return load_records(path, [meas_path])[0] -def load_records(path: str, meas_paths: list[str], preloaded: dict[str, Any] = {}) -> list[Union[Corr, Obs]]: +def load_records(path: Path, meas_paths: list[str], preloaded: dict[str, Any] = {}) -> list[Union[Corr, Obs]]: """ Load a list of records by their paths. @@ -151,20 +181,20 @@ def load_records(path: str, meas_paths: list[str], preloaded: dict[str, Any] = { returned_data: list[Any] = [] for file in needed_data.keys(): for key in list(needed_data[file]): - if os.path.exists(cache_path(path, file, key) + ".p"): - returned_data.append(load_object(cache_path(path, file, key) + ".p")) + if os.path.exists(str(cache_path(path, file, key)) + ".p"): + returned_data.append(load_object(str(cache_path(path, file, key)) + ".p")) else: if file not in preloaded: - preloaded[file] = preload(path, file) + preloaded[file] = preload(path, Path(file)) returned_data.append(preloaded[file][key]) if cache_enabled(path): if not os.path.exists(cache_dir(path, file)): os.makedirs(cache_dir(path, file)) - dump_object(preloaded[file][key], cache_path(path, file, key)) + dump_object(preloaded[file][key], str(cache_path(path, file, key))) return returned_data -def cache_dir(path: str, file: str) -> str: +def cache_dir(path: Path, file: str) -> Path: """ Returns the directory corresponding to the cache for the given file. @@ -179,14 +209,14 @@ def cache_dir(path: str, file: str) -> str: cache_path: str The path holding the cached data for the given file. """ - cache_path_list = [path] - cache_path_list.append(".cache") - cache_path_list.extend(file.split("/")[1:]) - cache_path = "/".join(cache_path_list) + cache_path_list = file.split("/")[1:] + cache_path = Path(path) / CACHE_DIR + for directory in cache_path_list: + cache_path /= directory return cache_path -def cache_path(path: str, file: str, key: str) -> str: +def cache_path(path: Path, file: str, key: str) -> Path: """ Parameters ---------- @@ -202,11 +232,11 @@ def cache_path(path: str, file: str, key: str) -> str: cache_path: str The path at which the measurement of the given file and key is cached. """ - cache_path = os.path.join(cache_dir(path, file), key) + cache_path = cache_dir(path, file) / key return cache_path -def preload(path: str, file: str) -> dict[str, Any]: +def preload(path: Path, file: Path) -> dict[str, Any]: """ Read the contents of a file into a json dictionary with the pyerrors.json.load_json_dict method. @@ -223,12 +253,12 @@ def preload(path: str, file: str) -> dict[str, Any]: The data read from the file. """ get(path, file) - filedict: dict[str, Any] = pj.load_json_dict(os.path.join(path, file)) + filedict: dict[str, Any] = pj.load_json_dict(str(path / file)) print("> read file") return filedict -def drop_record(path: str, meas_path: str) -> None: +def drop_record(path: Path, meas_path: str) -> None: """ Drop a record by it's path. @@ -240,9 +270,9 @@ def drop_record(path: str, meas_path: str) -> None: The measurement path as noted in the database. """ file_in_archive = meas_path.split("::")[0] - file = os.path.join(path, file_in_archive) + file = path / file_in_archive db_file = get_db_file(path) - db = os.path.join(path, db_file) + db = path / db_file get(path, db_file) sub_key = meas_path.split("::")[1] unlock(path, db_file) @@ -254,18 +284,18 @@ def drop_record(path: str, meas_path: str) -> None: raise ValueError("This measurement does not exist as an entry!") conn.commit() - known_meas = pj.load_json_dict(file) + known_meas = pj.load_json_dict(str(file)) if sub_key in known_meas: del known_meas[sub_key] - unlock(path, file_in_archive) - pj.dump_dict_to_json(known_meas, file) + unlock(path, Path(file_in_archive)) + pj.dump_dict_to_json(known_meas, str(file)) save(path, message="Drop measurements to database", files=[db, file]) return else: raise ValueError("This measurement does not exist as a file!") -def drop_cache(path: str) -> None: +def drop_cache(path: Path) -> None: """ Drop the cache directory of the library. @@ -274,7 +304,7 @@ def drop_cache(path: str) -> None: path: str The path of the library. """ - cache_dir = os.path.join(path, ".cache") + cache_dir = path / ".cache" for f in os.listdir(cache_dir): - shutil.rmtree(os.path.join(cache_dir, f)) + shutil.rmtree(cache_dir / f) return diff --git a/corrlib/pars/openQCD/__init__.py b/corrlib/pars/openQCD/__init__.py new file mode 100644 index 0000000..edbac71 --- /dev/null +++ b/corrlib/pars/openQCD/__init__.py @@ -0,0 +1,3 @@ + +from . import ms1 as ms1 +from . import qcd2 as qcd2 diff --git a/corrlib/pars/openQCD/flags.py b/corrlib/pars/openQCD/flags.py new file mode 100644 index 0000000..95be919 --- /dev/null +++ b/corrlib/pars/openQCD/flags.py @@ -0,0 +1,59 @@ +""" +Reconstruct the outputs of flags. +""" + +import struct +from typing import Any, BinaryIO + +# lat_parms.c +def lat_parms_write_lat_parms(fp: BinaryIO) -> dict[str, Any]: + """ + NOTE: This is a duplcation from qcd2. + Unpack the lattice parameters written by write_lat_parms. + """ + lat_pars = {} + t = fp.read(16) + lat_pars["N"] = list(struct.unpack('iiii', t)) # lattice extends + t = fp.read(8) + nk, isw = struct.unpack('ii', t) # number of kappas and isw parameter + lat_pars["nk"] = nk + lat_pars["isw"] = isw + t = fp.read(8) + lat_pars["beta"] = struct.unpack('d', t)[0] # beta + t = fp.read(8) + lat_pars["c0"] = struct.unpack('d', t)[0] + t = fp.read(8) + lat_pars["c1"] = struct.unpack('d', t)[0] + t = fp.read(8) + lat_pars["csw"] = struct.unpack('d', t)[0] # csw factor + kappas = [] + m0s = [] + # read kappas + for ik in range(nk): + t = fp.read(8) + kappas.append(struct.unpack('d', t)[0]) + t = fp.read(8) + m0s.append(struct.unpack('d', t)[0]) + lat_pars["kappas"] = kappas + lat_pars["m0s"] = m0s + return lat_pars + + +def lat_parms_write_bc_parms(fp: BinaryIO) -> dict[str, Any]: + """ + NOTE: This is a duplcation from qcd2. + Unpack the boundary parameters written by write_bc_parms. + """ + bc_pars: dict[str, Any] = {} + t = fp.read(4) + bc_pars["type"] = struct.unpack('i', t)[0] # type of hte boundaries + t = fp.read(104) + bc_parms = struct.unpack('d'*13, t) + bc_pars["cG"] = list(bc_parms[:2]) # boundary gauge field improvement + bc_pars["cF"] = list(bc_parms[2:4]) # boundary fermion field improvement + phi: list[list[float]] = [[], []] + phi[0] = list(bc_parms[4:7]) + phi[1] = list(bc_parms[7:10]) + bc_pars["phi"] = phi + bc_pars["theta"] = list(bc_parms[10:]) + return bc_pars diff --git a/corrlib/pars/openQCD/ms1.py b/corrlib/pars/openQCD/ms1.py new file mode 100644 index 0000000..4c2aed5 --- /dev/null +++ b/corrlib/pars/openQCD/ms1.py @@ -0,0 +1,30 @@ +from . import flags + +from typing import Any +from pathlib import Path + + +def read_qcd2_ms1_par_file(fname: Path) -> dict[str, dict[str, Any]]: + """ + The subroutines written here have names according to the openQCD programs and functions that write out the data. + Parameters + ---------- + fname: Path + Location of the parameter file. + + Returns + ------- + par_dict: dict + Dictionary holding the parameters specified in the given file. + """ + + with open(fname, "rb") as fp: + lat_par_dict = flags.lat_parms_write_lat_parms(fp) + bc_par_dict = flags.lat_parms_write_bc_parms(fp) + fp.close() + par_dict = {} + par_dict["lat"] = lat_par_dict + par_dict["bc"] = bc_par_dict + return par_dict + + diff --git a/corrlib/pars/openQCD/qcd2.py b/corrlib/pars/openQCD/qcd2.py new file mode 100644 index 0000000..e73c156 --- /dev/null +++ b/corrlib/pars/openQCD/qcd2.py @@ -0,0 +1,29 @@ +from . import flags + +from pathlib import Path +from typing import Any + + +def read_qcd2_par_file(fname: Path) -> dict[str, dict[str, Any]]: + """ + The subroutines written here have names according to the openQCD programs and functions that write out the data. + + Parameters + ---------- + fname: Path + Location of the parameter file. + + Returns + ------- + par_dict: dict + Dictionary holding the parameters specified in the given file. + """ + + with open(fname, "rb") as fp: + lat_par_dict = flags.lat_parms_write_lat_parms(fp) + bc_par_dict = flags.lat_parms_write_bc_parms(fp) + fp.close() + par_dict = {} + par_dict["lat"] = lat_par_dict + par_dict["bc"] = bc_par_dict + return par_dict diff --git a/corrlib/toml.py b/corrlib/toml.py index 629a499..0d4dfc8 100644 --- a/corrlib/toml.py +++ b/corrlib/toml.py @@ -19,6 +19,7 @@ from .meas_io import write_measurement import os from .input.implementations import codes as known_codes from typing import Any +from pathlib import Path def replace_string(string: str, name: str, val: str) -> str: @@ -126,7 +127,7 @@ def check_measurement_data(measurements: dict[str, dict[str, str]], code: str) - return -def import_tomls(path: str, files: list[str], copy_files: bool=True) -> None: +def import_tomls(path: Path, files: list[str], copy_files: bool=True) -> None: """ Import multiple toml files. @@ -144,7 +145,7 @@ def import_tomls(path: str, files: list[str], copy_files: bool=True) -> None: return -def import_toml(path: str, file: str, copy_file: bool=True) -> None: +def import_toml(path: Path, file: str, copy_file: bool=True) -> None: """ Import a project decribed by a .toml file. @@ -171,14 +172,16 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None: aliases = project.get('aliases', []) uuid = project.get('uuid', None) if uuid is not None: - if not os.path.exists(path + "/projects/" + uuid): + if not os.path.exists(path / "projects" / uuid): uuid = import_project(path, project['url'], aliases=aliases) else: update_aliases(path, uuid, aliases) else: uuid = import_project(path, project['url'], aliases=aliases) + imeas = 1 + nmeas = len(measurements.keys()) for mname, md in measurements.items(): - print("Import measurement: " + mname) + print(f"Import measurement {imeas}/{nmeas}: {mname}") ensemble = md['ensemble'] if project['code'] == 'sfcf': param = sfcf.read_param(path, uuid, md['param_file']) @@ -189,15 +192,34 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None: measurement = sfcf.read_data(path, uuid, md['path'], md['prefix'], param, version=md['version'], cfg_seperator=md['cfg_seperator'], sep='/') - print(mname + " imported.") elif project['code'] == 'openQCD': if md['measurement'] == 'ms1': - param = openQCD.read_ms1_param(path, uuid, md['param_file']) + if 'param_file' in md.keys(): + parameter_file = md['param_file'] + if parameter_file.endswith(".ms1.in"): + param = openQCD.load_ms1_infile(path, uuid, parameter_file) + elif parameter_file.endswith(".ms1.par"): + param = openQCD.load_ms1_parfile(path, uuid, parameter_file) + else: + # Temporary solution + parameters: dict[str, Any] = {} + parameters["rand"] = {} + parameters["rw_fcts"] = [{}] + for nrw in range(1): + if "nsrc" not in parameters["rw_fcts"][nrw]: + parameters["rw_fcts"][nrw]["nsrc"] = 1 + if "mu" not in parameters["rw_fcts"][nrw]: + parameters["rw_fcts"][nrw]["mu"] = "None" + if "np" not in parameters["rw_fcts"][nrw]: + parameters["rw_fcts"][nrw]["np"] = "None" + if "irp" not in parameters["rw_fcts"][nrw]: + parameters["rw_fcts"][nrw]["irp"] = "None" + param = parameters param['type'] = 'ms1' measurement = openQCD.read_rwms(path, uuid, md['path'], param, md["prefix"], version=md["version"], names=md['names'], files=md['files']) elif md['measurement'] == 't0': if 'param_file' in md: - param = openQCD.read_ms3_param(path, uuid, md['param_file']) + param = openQCD.load_ms3_infile(path, uuid, md['param_file']) else: param = {} for rwp in ["integrator", "eps", "ntot", "dnms"]: @@ -207,25 +229,26 @@ def import_toml(path: str, file: str, copy_file: bool=True) -> None: fit_range=int(md.get('fit_range', 5)), postfix=str(md.get('postfix', '')), names=md.get('names', []), files=md.get('files', [])) elif md['measurement'] == 't1': if 'param_file' in md: - param = openQCD.read_ms3_param(path, uuid, md['param_file']) + param = openQCD.load_ms3_infile(path, uuid, md['param_file']) param['type'] = 't1' measurement = openQCD.extract_t1(path, uuid, md['path'], param, str(md["prefix"]), int(md["dtr_read"]), int(md["xmin"]), int(md["spatial_extent"]), fit_range=int(md.get('fit_range', 5)), postfix=str(md.get('postfix', '')), names=md.get('names', []), files=md.get('files', [])) + write_measurement(path, ensemble, measurement, uuid, project['code'], (md['param_file'] if 'param_file' in md else None)) + imeas += 1 + print(mname + " imported.") - write_measurement(path, ensemble, measurement, uuid, project['code'], (md['param_file'] if 'param_file' in md else '')) - - if not os.path.exists(os.path.join(path, "toml_imports", uuid)): - os.makedirs(os.path.join(path, "toml_imports", uuid)) + if not os.path.exists(path / "toml_imports" / uuid): + os.makedirs(path / "toml_imports" / uuid) if copy_file: - import_file = os.path.join(path, "toml_imports", uuid, file.split("/")[-1]) + import_file = path / "toml_imports" / uuid / file.split("/")[-1] shutil.copy(file, import_file) - save(path, files=[import_file], message="Import using " + import_file) - print("File copied to " + import_file) + save(path, files=[import_file], message=f"Import using {import_file}") + print(f"File copied to {import_file}") print("Imported project.") return -def reimport_project(path: str, uuid: str) -> None: +def reimport_project(path: Path, uuid: str) -> None: """ Reimport an existing project using the files that are already available for this project. @@ -236,14 +259,14 @@ def reimport_project(path: str, uuid: str) -> None: uuid: str uuid of the project that is to be reimported. """ - config_path = "/".join([path, "import_scripts", uuid]) + config_path = path / "import_scripts" / uuid for p, filenames, dirnames in os.walk(config_path): for fname in filenames: import_toml(path, os.path.join(config_path, fname), copy_file=False) return -def update_project(path: str, uuid: str) -> None: +def update_project(path: Path, uuid: str) -> None: """ Update all entries associated with a given project. diff --git a/corrlib/tools.py b/corrlib/tools.py index 118b094..93f0678 100644 --- a/corrlib/tools.py +++ b/corrlib/tools.py @@ -1,6 +1,7 @@ import os from configparser import ConfigParser from typing import Any +from pathlib import Path CONFIG_FILENAME = ".corrlib" cached: bool = True @@ -73,7 +74,7 @@ def k2m(k: float) -> float: return (1/(2*k))-4 -def set_config(path: str, section: str, option: str, value: Any) -> None: +def set_config(path: Path, section: str, option: str, value: Any) -> None: """ Set configuration parameters for the library. @@ -88,7 +89,7 @@ def set_config(path: str, section: str, option: str, value: Any) -> None: value: Any The value we set the option to. """ - config_path = os.path.join(path, '.corrlib') + config_path = os.path.join(path, CONFIG_FILENAME) config = ConfigParser() if os.path.exists(config_path): config.read(config_path) @@ -100,7 +101,7 @@ def set_config(path: str, section: str, option: str, value: Any) -> None: return -def get_db_file(path: str) -> str: +def get_db_file(path: Path) -> Path: """ Get the database file associated with the library at the given path. @@ -118,11 +119,13 @@ def get_db_file(path: str) -> str: config = ConfigParser() if os.path.exists(config_path): config.read(config_path) - db_file = config.get('paths', 'db', fallback='backlogger.db') + else: + raise FileNotFoundError("Configuration file not found.") + db_file = Path(config.get('paths', 'db', fallback='backlogger.db')) return db_file -def cache_enabled(path: str) -> bool: +def cache_enabled(path: Path) -> bool: """ Check, whether the library is cached. Fallback is true. @@ -141,6 +144,10 @@ def cache_enabled(path: str) -> bool: config = ConfigParser() if os.path.exists(config_path): config.read(config_path) + else: + raise FileNotFoundError("Configuration file not found.") cached_str = config.get('core', 'cached', fallback='True') + if cached_str not in ['True', 'False']: + raise ValueError(f"String {cached_str} is not a valid option, only True and False are allowed!") cached_bool = cached_str == ('True') return cached_bool diff --git a/corrlib/tracker.py b/corrlib/tracker.py index 5cc281c..a6e9bf4 100644 --- a/corrlib/tracker.py +++ b/corrlib/tracker.py @@ -4,9 +4,10 @@ import datalad.api as dl from typing import Optional import shutil from .tools import get_db_file +from pathlib import Path -def get_tracker(path: str) -> str: +def get_tracker(path: Path) -> str: """ Get the tracker used in the dataset located at path. @@ -30,7 +31,7 @@ def get_tracker(path: str) -> str: return tracker -def get(path: str, file: str) -> None: +def get(path: Path, file: Path) -> None: """ Wrapper function to get a file from the dataset located at path with the specified tracker. @@ -56,7 +57,7 @@ def get(path: str, file: str) -> None: return -def save(path: str, message: str, files: Optional[list[str]]=None) -> None: +def save(path: Path, message: str, files: Optional[list[Path]]=None) -> None: """ Wrapper function to save a file to the dataset located at path with the specified tracker. @@ -72,7 +73,7 @@ def save(path: str, message: str, files: Optional[list[str]]=None) -> None: tracker = get_tracker(path) if tracker == 'datalad': if files is not None: - files = [os.path.join(path, f) for f in files] + files = [path / f for f in files] dl.save(files, message=message, dataset=path) elif tracker == 'None': Warning("Tracker 'None' does not implement save.") @@ -81,7 +82,7 @@ def save(path: str, message: str, files: Optional[list[str]]=None) -> None: raise ValueError(f"Tracker {tracker} is not supported.") -def init(path: str, tracker: str='datalad') -> None: +def init(path: Path, tracker: str='datalad') -> None: """ Initialize a dataset at the specified path with the specified tracker. @@ -101,7 +102,7 @@ def init(path: str, tracker: str='datalad') -> None: return -def unlock(path: str, file: str) -> None: +def unlock(path: Path, file: Path) -> None: """ Wrapper function to unlock a file in the dataset located at path with the specified tracker. @@ -114,7 +115,7 @@ def unlock(path: str, file: str) -> None: """ tracker = get_tracker(path) if tracker == 'datalad': - dl.unlock(file, dataset=path) + dl.unlock(os.path.join(path, file), dataset=path) elif tracker == 'None': Warning("Tracker 'None' does not implement unlock.") pass @@ -123,7 +124,7 @@ def unlock(path: str, file: str) -> None: return -def clone(path: str, source: str, target: str) -> None: +def clone(path: Path, source: str, target: str) -> None: """ Wrapper function to clone a dataset from source to target with the specified tracker. Parameters @@ -147,7 +148,7 @@ def clone(path: str, source: str, target: str) -> None: return -def drop(path: str, reckless: Optional[str]=None) -> None: +def drop(path: Path, reckless: Optional[str]=None) -> None: """ Wrapper function to drop data from a dataset located at path with the specified tracker. diff --git a/tests/cli_test.py b/tests/cli_test.py index a6b0bd7..cba0a10 100644 --- a/tests/cli_test.py +++ b/tests/cli_test.py @@ -2,18 +2,19 @@ from typer.testing import CliRunner from corrlib.cli import app import os import sqlite3 as sql +from pathlib import Path runner = CliRunner() -def test_version(): +def test_version() -> None: result = runner.invoke(app, ["--version"]) assert result.exit_code == 0 assert "corrlib" in result.output -def test_init_folders(tmp_path): +def test_init_folders(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" result = runner.invoke(app, ["init", "--dataset", str(dataset_path)]) assert result.exit_code == 0 @@ -21,7 +22,7 @@ def test_init_folders(tmp_path): assert os.path.exists(str(dataset_path / "backlogger.db")) -def test_init_db(tmp_path): +def test_init_db(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" result = runner.invoke(app, ["init", "--dataset", str(dataset_path)]) assert result.exit_code == 0 @@ -37,7 +38,7 @@ def test_init_db(tmp_path): table_names = [table[0] for table in tables] for expected_table in expected_tables: assert expected_table in table_names - + cursor.execute("SELECT * FROM projects;") projects = cursor.fetchall() assert len(projects) == 0 @@ -60,7 +61,7 @@ def test_init_db(tmp_path): project_column_names = [col[1] for col in project_columns] for expected_col in expected_project_columns: assert expected_col in project_column_names - + cursor.execute("PRAGMA table_info('backlogs');") backlog_columns = cursor.fetchall() expected_backlog_columns = [ @@ -81,11 +82,11 @@ def test_init_db(tmp_path): assert expected_col in backlog_column_names -def test_list(tmp_path): +def test_list(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" result = runner.invoke(app, ["init", "--dataset", str(dataset_path)]) assert result.exit_code == 0 - result = runner.invoke(app, ["list", "--dataset", str(dataset_path), "ensembles"]) + result = runner.invoke(app, ["lister", "--dataset", str(dataset_path), "ensembles"]) assert result.exit_code == 0 - result = runner.invoke(app, ["list", "--dataset", str(dataset_path), "projects"]) + result = runner.invoke(app, ["lister", "--dataset", str(dataset_path), "projects"]) assert result.exit_code == 0 diff --git a/tests/find_test.py b/tests/find_test.py new file mode 100644 index 0000000..cc455f9 --- /dev/null +++ b/tests/find_test.py @@ -0,0 +1,432 @@ +import corrlib.find as find +import sqlite3 +from pathlib import Path +import corrlib.initialization as cinit +import pytest +import pandas as pd +import datalad.api as dl +import datetime as dt + + +def make_sql(path: Path) -> Path: + db = path / "test.db" + cinit._create_db(db) + return db + + +def test_find_lookup_by_one_alias(tmp_path: Path) -> None: + db = make_sql(tmp_path) + conn = sqlite3.connect(db) + c = conn.cursor() + uuid = "test_uuid" + alias_str = "fun_project" + tag_str = "tt" + owner = "tester" + code = "test_code" + c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", + (uuid, alias_str, tag_str, owner, code)) + conn.commit() + assert uuid == find._project_lookup_by_alias(db, "fun_project") + uuid = "test_uuid2" + alias_str = "fun_project" + c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", + (uuid, alias_str, tag_str, owner, code)) + conn.commit() + with pytest.raises(Exception): + assert uuid == find._project_lookup_by_alias(db, "fun_project") + conn.close() + +def test_find_lookup_by_id(tmp_path: Path) -> None: + db = make_sql(tmp_path) + conn = sqlite3.connect(db) + c = conn.cursor() + uuid = "test_uuid" + alias_str = "fun_project" + tag_str = "tt" + owner = "tester" + code = "test_code" + c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", + (uuid, alias_str, tag_str, owner, code)) + conn.commit() + conn.close() + result = find._project_lookup_by_id(db, uuid)[0] + assert uuid == result[0] + assert alias_str == result[1] + assert tag_str == result[2] + assert owner == result[3] + assert code == result[4] + + +def test_time_filter() -> None: + record_A = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf0", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] # only created + record_B = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf1", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2025-03-26 12:55:18.229966', '2025-04-26 12:55:18.229966'] # created and updated + record_C = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf2", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2026-03-26 12:55:18.229966', '2026-04-14 12:55:18.229966'] # created and updated later + record_D = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf3", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2026-03-26 12:55:18.229966', '2026-03-27 12:55:18.229966'] + record_E = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf4", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2024-03-26 12:55:18.229966', '2024-03-26 12:55:18.229966'] # only created, earlier + record_F = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf5", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2026-03-26 12:55:18.229966', '2024-03-26 12:55:18.229966'] # this is invalid... + record_G = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf2", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2026-03-26 12:55:18.229966', str(dt.datetime.now() + dt.timedelta(days=2, hours=3, minutes=5, seconds=30))] # created and updated later + + data = [record_A, record_B, record_C, record_D, record_E] + cols = ["name", + "ensemble", + "code", + "path", + "project", + "parameters", + "parameter_file", + "created_at", + "updated_at"] + df = pd.DataFrame(data,columns=cols) + + results = find._time_filter(df, created_before='2023-03-26 12:55:18.229966') + assert results.empty + results = find._time_filter(df, created_before='2027-03-26 12:55:18.229966') + assert len(results) == 5 + results = find._time_filter(df, created_before='2026-03-25 12:55:18.229966') + assert len(results) == 3 + results = find._time_filter(df, created_before='2026-03-26 12:55:18.229965') + assert len(results) == 3 + results = find._time_filter(df, created_before='2025-03-04 12:55:18.229965') + assert len(results) == 1 + + results = find._time_filter(df, created_after='2023-03-26 12:55:18.229966') + assert len(results) == 5 + results = find._time_filter(df, created_after='2027-03-26 12:55:18.229966') + assert results.empty + results = find._time_filter(df, created_after='2026-03-25 12:55:18.229966') + assert len(results) == 2 + results = find._time_filter(df, created_after='2026-03-26 12:55:18.229965') + assert len(results) == 2 + results = find._time_filter(df, created_after='2025-03-04 12:55:18.229965') + assert len(results) == 4 + + results = find._time_filter(df, updated_before='2023-03-26 12:55:18.229966') + assert results.empty + results = find._time_filter(df, updated_before='2027-03-26 12:55:18.229966') + assert len(results) == 5 + results = find._time_filter(df, updated_before='2026-03-25 12:55:18.229966') + assert len(results) == 3 + results = find._time_filter(df, updated_before='2026-03-26 12:55:18.229965') + assert len(results) == 3 + results = find._time_filter(df, updated_before='2025-03-04 12:55:18.229965') + assert len(results) == 1 + + results = find._time_filter(df, updated_after='2023-03-26 12:55:18.229966') + assert len(results) == 5 + results = find._time_filter(df, updated_after='2027-03-26 12:55:18.229966') + assert results.empty + results = find._time_filter(df, updated_after='2026-03-25 12:55:18.229966') + assert len(results) == 2 + results = find._time_filter(df, updated_after='2026-03-26 12:55:18.229965') + assert len(results) == 2 + results = find._time_filter(df, updated_after='2025-03-04 12:55:18.229965') + assert len(results) == 4 + + data = [record_A, record_B, record_C, record_D, record_F] + cols = ["name", + "ensemble", + "code", + "path", + "project", + "parameters", + "parameter_file", + "created_at", + "updated_at"] + df = pd.DataFrame(data,columns=cols) + + with pytest.raises(ValueError): + results = find._time_filter(df, created_before='2023-03-26 12:55:18.229966') + + data = [record_A, record_B, record_C, record_D, record_G] + cols = ["name", + "ensemble", + "code", + "path", + "project", + "parameters", + "parameter_file", + "created_at", + "updated_at"] + df = pd.DataFrame(data,columns=cols) + + with pytest.raises(ValueError): + results = find._time_filter(df, created_before='2023-03-26 12:55:18.229966') + + +def test_db_lookup(tmp_path: Path) -> None: + db = make_sql(tmp_path) + conn = sqlite3.connect(db) + c = conn.cursor() + + corr = "f_A" + ensemble = "SF_A" + code = "openQCD" + meas_path = "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf" + uuid = "Project_A" + pars = "{par_A: 3.0, par_B: 5.0}" + parameter_file = "projects/Project_A/myinput.in" + c.execute("INSERT INTO backlogs (name, ensemble, code, path, project, parameters, parameter_file, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))", + (corr, ensemble, code, meas_path, uuid, pars, parameter_file)) + conn.commit() + + results = find._db_lookup(db, ensemble, corr, code) + assert len(results) == 1 + results = find._db_lookup(db, "SF_B", corr, code) + assert results.empty + results = find._db_lookup(db, ensemble, "g_A", code) + assert results.empty + results = find._db_lookup(db, ensemble, corr, "sfcf") + assert results.empty + results = find._db_lookup(db, ensemble, corr, code, project = "Project_A") + assert len(results) == 1 + results = find._db_lookup(db, ensemble, corr, code, project = "Project_B") + assert results.empty + results = find._db_lookup(db, ensemble, corr, code, parameters = pars) + assert len(results) == 1 + results = find._db_lookup(db, ensemble, corr, code, parameters = '{"par_A": 3.0, "par_B": 4.0}') + assert results.empty + + corr = "g_A" + ensemble = "SF_A" + code = "openQCD" + meas_path = "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf" + uuid = "Project_A" + pars = '{"par_A": 3.0, "par_B": 4.0}' + parameter_file = "projects/Project_A/myinput.in" + c.execute("INSERT INTO backlogs (name, ensemble, code, path, project, parameters, parameter_file, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))", + (corr, ensemble, code, meas_path, uuid, pars, parameter_file)) + conn.commit() + + corr = "f_A" + results = find._db_lookup(db, ensemble, corr, code) + assert len(results) == 1 + results = find._db_lookup(db, "SF_B", corr, code) + assert results.empty + results = find._db_lookup(db, ensemble, "g_A", code) + assert len(results) == 1 + results = find._db_lookup(db, ensemble, corr, "sfcf") + assert results.empty + results = find._db_lookup(db, ensemble, corr, code, project = "Project_A") + assert len(results) == 1 + results = find._db_lookup(db, ensemble, "g_A", code, project = "Project_A") + assert len(results) == 1 + results = find._db_lookup(db, ensemble, corr, code, project = "Project_B") + assert results.empty + results = find._db_lookup(db, ensemble, "g_A", code, project = "Project_B") + assert results.empty + results = find._db_lookup(db, ensemble, corr, code, parameters = pars) + assert results.empty + results = find._db_lookup(db, ensemble, "g_A", code, parameters = '{"par_A": 3.0, "par_B": 4.0}') + assert len(results) == 1 + + conn.close() + + +def test_sfcf_drop() -> None: + parameters0 = { + 'offset': [0,0,0], + 'quarks': [{'mass': 1, 'thetas': [0,0,0]}, {'mass': 2, 'thetas': [0,0,1]}], # m0s = -3.5, -3.75 + 'wf1': [[1, [0, 0]], [0.5, [1, 0]], [.75, [.5, .5]]], + 'wf2': [[1, [2, 1]], [2, [0.5, -0.5]], [.5, [.75, .72]]], + } + + assert not find._sfcf_drop(parameters0, offset=[0,0,0]) + assert find._sfcf_drop(parameters0, offset=[1,0,0]) + + assert not find._sfcf_drop(parameters0, quark_kappas = [1, 2]) + assert find._sfcf_drop(parameters0, quark_kappas = [-3.1, -3.72]) + + assert not find._sfcf_drop(parameters0, quark_masses = [-3.5, -3.75]) + assert find._sfcf_drop(parameters0, quark_masses = [-3.1, -3.72]) + + assert not find._sfcf_drop(parameters0, qk1 = 1) + assert not find._sfcf_drop(parameters0, qk2 = 2) + assert find._sfcf_drop(parameters0, qk1 = 2) + assert find._sfcf_drop(parameters0, qk2 = 1) + + assert not find._sfcf_drop(parameters0, qk1 = [0.5,1.5]) + assert not find._sfcf_drop(parameters0, qk2 = [1.5,2.5]) + assert find._sfcf_drop(parameters0, qk1 = 2) + assert find._sfcf_drop(parameters0, qk2 = 1) + with pytest.raises(ValueError): + assert not find._sfcf_drop(parameters0, qk1 = [0.5,1,5]) + with pytest.raises(ValueError): + assert not find._sfcf_drop(parameters0, qk2 = [1,5,2.5]) + + assert find._sfcf_drop(parameters0, qm1 = 1.2) + assert find._sfcf_drop(parameters0, qm2 = 2.2) + assert not find._sfcf_drop(parameters0, qm1 = -3.5) + assert not find._sfcf_drop(parameters0, qm2 = -3.75) + + assert find._sfcf_drop(parameters0, qm2 = 1.2) + assert find._sfcf_drop(parameters0, qm1 = 2.2) + with pytest.raises(ValueError): + assert not find._sfcf_drop(parameters0, qm1 = [0.5,1,5]) + with pytest.raises(ValueError): + assert not find._sfcf_drop(parameters0, qm2 = [1,5,2.5]) + + +def test_openQCD_filter() -> None: + record_0 = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] + record_1 = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] + record_2 = ["f_P", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] + record_3 = ["f_P", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] + data = [ + record_0, + record_1, + record_2, + record_3, + ] + cols = ["name", + "ensemble", + "code", + "path", + "project", + "parameters", + "parameter_file", + "created_at", + "updated_at"] + df = pd.DataFrame(data,columns=cols) + + with pytest.warns(Warning): + find.openQCD_filter(df, a = "asdf") + + +def test_code_filter() -> None: + record_0 = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] + record_1 = ["f_A", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] + record_2 = ["f_P", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] + record_3 = ["f_P", "ensA", "sfcf", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] + record_4 = ["f_A", "ensA", "openQCD", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] + record_5 = ["f_A", "ensA", "openQCD", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] + record_6 = ["f_P", "ensA", "openQCD", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] + record_7 = ["f_P", "ensA", "openQCD", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] + record_8 = ["f_P", "ensA", "openQCD", "archive/SF_A/f_A/Project_A.json.gz::asdfasdfasdf", "SF_A", '{"par_A": 5.0, "par_B": 5.0}', "projects/SF_A/input.in", + '2025-03-26 12:55:18.229966', '2025-03-26 12:55:18.229966'] + data = [ + record_0, + record_1, + record_2, + record_3, + ] + cols = ["name", + "ensemble", + "code", + "path", + "project", + "parameters", + "parameter_file", + "created_at", + "updated_at"] + df = pd.DataFrame(data,columns=cols) + + res = find._code_filter(df, "sfcf") + assert len(res) == 4 + + data = [ + record_4, + record_5, + record_6, + record_7, + record_8, + ] + cols = ["name", + "ensemble", + "code", + "path", + "project", + "parameters", + "parameter_file", + "created_at", + "updated_at"] + df = pd.DataFrame(data,columns=cols) + + res = find._code_filter(df, "openQCD") + assert len(res) == 5 + with pytest.raises(ValueError): + res = find._code_filter(df, "asdf") + + +def test_find_record() -> None: + assert True + + +def test_find_project(tmp_path: Path) -> None: + cinit.create(tmp_path) + db = tmp_path / "backlogger.db" + dl.unlock(str(db), dataset=str(tmp_path)) + conn = sqlite3.connect(db) + c = conn.cursor() + uuid = "test_uuid" + alias_str = "fun_project" + tag_str = "tt" + owner = "tester" + code = "test_code" + c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", + (uuid, alias_str, tag_str, owner, code)) + conn.commit() + + assert uuid == find.find_project(tmp_path, "fun_project") + + uuid = "test_uuid2" + alias_str = "fun_project" + c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", + (uuid, alias_str, tag_str, owner, code)) + conn.commit() + + with pytest.raises(Exception): + assert uuid == find._project_lookup_by_alias(tmp_path, "fun_project") + conn.close() + + +def test_list_projects(tmp_path: Path) -> None: + cinit.create(tmp_path) + db = tmp_path / "backlogger.db" + dl.unlock(str(db), dataset=str(tmp_path)) + conn = sqlite3.connect(db) + c = conn.cursor() + uuid = "test_uuid" + alias_str = "fun_project" + tag_str = "tt" + owner = "tester" + code = "test_code" + + c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", + (uuid, alias_str, tag_str, owner, code)) + uuid = "test_uuid2" + alias_str = "fun_project2" + c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", + (uuid, alias_str, tag_str, owner, code)) + uuid = "test_uuid3" + alias_str = "fun_project3" + c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", + (uuid, alias_str, tag_str, owner, code)) + uuid = "test_uuid4" + alias_str = "fun_project4" + c.execute("INSERT INTO projects (id, aliases, customTags, owner, code, created_at, updated_at) VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'))", + (uuid, alias_str, tag_str, owner, code)) + conn.commit() + conn.close() + results = find.list_projects(tmp_path) + assert len(results) == 4 + for i in range(4): + assert len(results[i]) == 2 diff --git a/tests/import_project_test.py b/tests/import_project_test.py index 2dea06f..685d2cf 100644 --- a/tests/import_project_test.py +++ b/tests/import_project_test.py @@ -1,7 +1,7 @@ import corrlib.toml as t -def test_toml_check_measurement_data(): +def test_toml_check_measurement_data() -> None: measurements = { "a": { diff --git a/tests/test_initialization.py b/tests/initialization_test.py similarity index 88% rename from tests/test_initialization.py rename to tests/initialization_test.py index 1ea0ece..d78fb15 100644 --- a/tests/test_initialization.py +++ b/tests/initialization_test.py @@ -1,24 +1,25 @@ import corrlib.initialization as init import os import sqlite3 as sql +from pathlib import Path -def test_init_folders(tmp_path): +def test_init_folders(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" - init.create(str(dataset_path)) + init.create(dataset_path) assert os.path.exists(str(dataset_path)) assert os.path.exists(str(dataset_path / "backlogger.db")) -def test_init_folders_no_tracker(tmp_path): +def test_init_folders_no_tracker(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" - init.create(str(dataset_path), tracker="None") + init.create(dataset_path, tracker="None") assert os.path.exists(str(dataset_path)) assert os.path.exists(str(dataset_path / "backlogger.db")) -def test_init_config(tmp_path): +def test_init_config(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" - init.create(str(dataset_path), tracker="None") + init.create(dataset_path, tracker="None") config_path = dataset_path / ".corrlib" assert os.path.exists(str(config_path)) from configparser import ConfigParser @@ -34,9 +35,9 @@ def test_init_config(tmp_path): assert config.get("paths", "import_scripts_path") == "import_scripts" -def test_init_db(tmp_path): +def test_init_db(tmp_path: Path) -> None: dataset_path = tmp_path / "test_dataset" - init.create(str(dataset_path)) + init.create(dataset_path) assert os.path.exists(str(dataset_path / "backlogger.db")) conn = sql.connect(str(dataset_path / "backlogger.db")) cursor = conn.cursor() diff --git a/tests/sfcf_in_test.py b/tests/sfcf_in_test.py index 72921e7..7ebc94a 100644 --- a/tests/sfcf_in_test.py +++ b/tests/sfcf_in_test.py @@ -1,7 +1,7 @@ import corrlib.input.sfcf as input import json -def test_get_specs(): +def test_get_specs() -> None: parameters = { 'crr': [ 'f_P', 'f_A' @@ -26,4 +26,4 @@ def test_get_specs(): key = "f_P/q1 q2/1/0/0" specs = json.loads(input.get_specs(key, parameters)) assert specs['quarks'] == ['a', 'b'] - assert specs['wf1'][0] == [1, [0, 0]] \ No newline at end of file + assert specs['wf1'][0] == [1, [0, 0]] diff --git a/tests/tools_test.py b/tests/tools_test.py index ee76f1c..541674f 100644 --- a/tests/tools_test.py +++ b/tests/tools_test.py @@ -1,31 +1,84 @@ - - from corrlib import tools as tl +from configparser import ConfigParser +from pathlib import Path +import pytest -def test_m2k(): +def test_m2k() -> None: for m in [0.1, 0.5, 1.0]: expected_k = 1 / (2 * m + 8) assert tl.m2k(m) == expected_k -def test_k2m(): +def test_k2m() -> None: for m in [0.1, 0.5, 1.0]: assert tl.k2m(m) == (1/(2*m))-4 -def test_k2m_m2k(): +def test_k2m_m2k() -> None: for m in [0.1, 0.5, 1.0]: k = tl.m2k(m) m_converted = tl.k2m(k) assert abs(m - m_converted) < 1e-9 -def test_str2list(): +def test_str2list() -> None: assert tl.str2list("a,b,c") == ["a", "b", "c"] assert tl.str2list("1,2,3") == ["1", "2", "3"] -def test_list2str(): +def test_list2str() -> None: assert tl.list2str(["a", "b", "c"]) == "a,b,c" assert tl.list2str(["1", "2", "3"]) == "1,2,3" + + +def test_set_config(tmp_path: Path) -> None: + section = "core" + option = "test_option" + value = "test_value" + # config is not yet available + tl.set_config(tmp_path, section, option, value) + config_path = tmp_path / '.corrlib' + config = ConfigParser() + config.read(config_path) + assert config.get('core', 'test_option', fallback="not the value") == "test_value" + # now, a config file is already present + section = "core" + option = "test_option2" + value = "test_value2" + tl.set_config(tmp_path, section, option, value) + config_path = tmp_path / '.corrlib' + config = ConfigParser() + config.read(config_path) + assert config.get('core', 'test_option2', fallback="not the value") == "test_value2" + # update option 2 + section = "core" + option = "test_option2" + value = "test_value3" + tl.set_config(tmp_path, section, option, value) + config_path = tmp_path / '.corrlib' + config = ConfigParser() + config.read(config_path) + assert config.get('core', 'test_option2', fallback="not the value") == "test_value3" + + +def test_get_db_file(tmp_path: Path) -> None: + section = "paths" + option = "db" + value = "test_value" + # config is not yet available + tl.set_config(tmp_path, section, option, value) + assert tl.get_db_file(tmp_path) == Path("test_value") + + +def test_cache_enabled(tmp_path: Path) -> None: + section = "core" + option = "cached" + # config is not yet available + tl.set_config(tmp_path, section, option, "True") + assert tl.cache_enabled(tmp_path) + tl.set_config(tmp_path, section, option, "False") + assert not tl.cache_enabled(tmp_path) + tl.set_config(tmp_path, section, option, "lalala") + with pytest.raises(ValueError): + tl.cache_enabled(tmp_path)