diff --git a/.gitignore b/.gitignore index 22957fb..f97ff98 100644 --- a/.gitignore +++ b/.gitignore @@ -2,8 +2,7 @@ pyerrors_corrlib.egg-info __pycache__ *.egg-info test.ipynb -test_ds .vscode .venv .pytest_cache -.coverage +.coverage \ No newline at end of file diff --git a/README.md b/README.md index 0f6c9a3..976ae57 100644 --- a/README.md +++ b/README.md @@ -5,12 +5,3 @@ This is done in a reproducible way using `datalad`. In principle, a dataset is created, that is automatically administered by the backlogger, in which data from differnt projects are held together. Everything is catalogued by a searchable SQL database, which holds the paths to the respective measurements. The original projects can be linked to the dataset and the data may be imported using wrapper functions around the read methonds of pyerrors. - -We work with the following nomenclature in this project: -- Measurement - A setis of Observables, including the appropriate metadata. -- Project - A series of measurements that was done by one person as part of their research. -- Record - An entry of a single Correlator in the database of the backlogger. -- \ No newline at end of file diff --git a/TODO.md b/TODO.md index ba32ec9..4153fc3 100644 --- a/TODO.md +++ b/TODO.md @@ -1,21 +1,14 @@ # TODO ## Features -- [ ] implement import of non-datalad projects -- [ ] implement a way to use another backlog repo as a project -- [ ] make cache deadlock resistent (no read while writing) -- [ ] find a way to convey the mathematical structure of what EXACTLY is the form of the correlator in a specific project - - [ ] this could e.g. be done along the lines of mandatory documentation -- [ ] keep better track of the versions of the code, that was used for a specific measurement. - - [ ] maybe let this be an input in the project file? - - [ ] git repo and commit hash/version tag - - [ ] implement a code table? -- [ ] parallel processing of measurements -- [ ] extra SQL table for ensembles with UUID and aliases +- implement import of non-datalad projects +- implement a way to use another backlog repo as a project + +- find a way to convey the mathematical structure of what EXACTLY is the form of the correlator in a specific project + - this could e.g. be done along the lines of mandatory documentation +- keep better track of the versions of the code, that was used for a specific measurement. + - maybe let this be an input in the project file? + - git repo and commit hash/version tag + ## Bugfixes - [ ] revisit the reimport function for single files -- [ ] drop record needs to look if no records are left in a json file. - -## Rough Ideas -- [ ] multitable could provide a high speed implementation of an HDF5 based format -- [ ] implement also a way to include compiled binaries in the archives. diff --git a/corrlib/__init__.py b/corrlib/__init__.py index 448b4d5..4e1b364 100644 --- a/corrlib/__init__.py +++ b/corrlib/__init__.py @@ -22,4 +22,3 @@ from .meas_io import load_records as load_records from .find import find_project as find_project from .find import find_record as find_record from .find import list_projects as list_projects -from .tools import * diff --git a/corrlib/cache_io.py b/corrlib/cache_io.py deleted file mode 100644 index 63d2e68..0000000 --- a/corrlib/cache_io.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Optional -import os -import shutil -from .tools import record2name_key -import datalad.api as dl -import sqlite3 -from tools import db_filename - - -def get_version_hash(path: str, record: str) -> str: - db = os.path.join(path, db_filename(path)) - dl.get(db, dataset=path) - conn = sqlite3.connect(db) - c = conn.cursor() - c.execute(f"SELECT current_version FROM 'backlogs' WHERE path = '{record}'") - return str(c.fetchall()[0][0]) - - -def drop_cache_files(path: str, fs: Optional[list[str]]=None) -> None: - cache_dir = os.path.join(path, ".cache") - if fs is None: - fs = os.listdir(cache_dir) - for f in fs: - shutil.rmtree(os.path.join(cache_dir, f)) - - -def cache_dir(path: str, file: str) -> str: - cache_path_list = [path] - cache_path_list.append(".cache") - cache_path_list.extend(file.split("/")[1:]) - cache_path = "/".join(cache_path_list) - return cache_path - - -def cache_path(path: str, file: str, sha_hash: str, key: str) -> str: - cache_path = os.path.join(cache_dir(path, file), key + "_" + sha_hash) - return cache_path - - -def is_old_version(path: str, record: str) -> bool: - version_hash = get_version_hash(path, record) - file, key = record2name_key(record) - meas_cache_path = os.path.join(cache_dir(path, file)) - ls = [] - is_old = True - for p, ds, fs in os.walk(meas_cache_path): - ls.extend(fs) - for filename in ls: - if key == filename.split("_")[0]: - if version_hash == filename.split("_")[1][:-2]: - is_old = False - return is_old - - -def is_in_cache(path: str, record: str) -> bool: - version_hash = get_version_hash(path, record) - file, key = record2name_key(record) - return os.path.exists(cache_path(path, file, version_hash, key) + ".p") diff --git a/corrlib/cli.py b/corrlib/cli.py index c4e1e4b..f205026 100644 --- a/corrlib/cli.py +++ b/corrlib/cli.py @@ -6,7 +6,8 @@ from .toml import import_tomls, update_project, reimport_project from .find import find_record, list_projects from .tools import str2list from .main import update_aliases -from .cache_io import drop_cache_files as cio_drop_cache_files +from .meas_io import drop_cache as mio_drop_cache +from .meas_io import load_record as mio_load_record import os from importlib.metadata import version @@ -35,6 +36,7 @@ def update( update_project(path, uuid) return + @app.command() def list( path: str = typer.Option( @@ -94,12 +96,39 @@ def find( ensemble: str = typer.Argument(), corr: str = typer.Argument(), code: str = typer.Argument(), + arg: str = typer.Option( + str('all'), + "--argument", + "-a", + ), ) -> None: """ Find a record in the backlog at hand. Through specifying it's ensemble and the measured correlator. """ results = find_record(path, ensemble, corr, code) - print(results) + if arg == 'all': + print(results) + else: + for r in results[arg].values: + print(r) + + +@app.command() +def stat( + path: str = typer.Option( + str('./corrlib'), + "--dataset", + "-d", + ), + record: str = typer.Argument(), + ) -> None: + """ + Show the statistics of a given record. + """ + record = mio_load_record(path, record)[0] + statistics = record.idl + print(statistics) + return @app.command() @@ -180,7 +209,7 @@ def drop_cache( """ Drop the currect cache directory of the dataset. """ - cio_drop_cache_files(path) + mio_drop_cache(path) return diff --git a/corrlib/find.py b/corrlib/find.py index 5d0a678..21063ec 100644 --- a/corrlib/find.py +++ b/corrlib/find.py @@ -4,7 +4,7 @@ import json import pandas as pd import numpy as np from .input.implementations import codes -from .tools import k2m, db_filename +from .tools import k2m, get_db_file from .tracker import get from typing import Any, Optional @@ -230,7 +230,7 @@ def sfcf_filter(results: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: def find_record(path: str, ensemble: str, correlator_name: str, code: str, project: Optional[str]=None, parameters: Optional[str]=None, created_before: Optional[str]=None, created_after: Optional[str]=None, updated_before: Optional[str]=None, updated_after: Optional[str]=None, revision: Optional[str]=None, **kwargs: Any) -> pd.DataFrame: - db_file = db_filename(path) + db_file = get_db_file(path) db = os.path.join(path, db_file) if code not in codes: raise ValueError("Code " + code + "unknown, take one of the following:" + ", ".join(codes)) @@ -262,7 +262,7 @@ def find_project(path: str, name: str) -> str: uuid: str The uuid of the project in question. """ - db_file = db_filename(path) + db_file = get_db_file(path) get(path, db_file) return _project_lookup_by_alias(os.path.join(path, db_file), name) @@ -281,7 +281,7 @@ def list_projects(path: str) -> list[tuple[str, str]]: results: list[Any] The projects known to the library. """ - db_file = db_filename(path) + db_file = get_db_file(path) get(path, db_file) conn = sqlite3.connect(os.path.join(path, db_file)) c = conn.cursor() diff --git a/corrlib/initialization.py b/corrlib/initialization.py index 0b7be48..bb71db6 100644 --- a/corrlib/initialization.py +++ b/corrlib/initialization.py @@ -26,8 +26,7 @@ def _create_db(db: str) -> None: parameters TEXT, parameter_file TEXT, created_at TEXT, - updated_at TEXT, - current_version TEXT)''') + updated_at TEXT)''') c.execute('''CREATE TABLE IF NOT EXISTS projects (id TEXT PRIMARY KEY, aliases TEXT, @@ -72,7 +71,6 @@ def _create_config(path: str, tracker: str, cached: bool) -> ConfigParser: 'archive_path': 'archive', 'toml_imports_path': 'toml_imports', 'import_scripts_path': 'import_scripts', - 'cache_path': '.cache', } return config diff --git a/corrlib/input/sfcf.py b/corrlib/input/sfcf.py index 6a75b72..621f736 100644 --- a/corrlib/input/sfcf.py +++ b/corrlib/input/sfcf.py @@ -3,6 +3,7 @@ import datalad.api as dl import json import os from typing import Any +from fnmatch import fnmatch bi_corrs: list[str] = ["f_P", "fP", "f_p", @@ -298,9 +299,10 @@ def read_data(path: str, project: str, dir_in_project: str, prefix: str, param: if not appended: compact = (version[-1] == "c") for i, item in enumerate(ls): - rep_path = directory + '/' + item - sub_ls = pe.input.sfcf._find_files(rep_path, prefix, compact, []) - files_to_get.extend([rep_path + "/" + filename for filename in sub_ls]) + if fnmatch(item, prefix + "*"): + rep_path = directory + '/' + item + sub_ls = pe.input.sfcf._find_files(rep_path, prefix, compact, []) + files_to_get.extend([rep_path + "/" + filename for filename in sub_ls]) print("Getting data, this might take a while...") diff --git a/corrlib/main.py b/corrlib/main.py index df0cd7a..88b99b3 100644 --- a/corrlib/main.py +++ b/corrlib/main.py @@ -5,7 +5,7 @@ import os from .git_tools import move_submodule import shutil from .find import _project_lookup_by_id -from .tools import list2str, str2list, db_filename +from .tools import list2str, str2list, get_db_file from .tracker import get, save, unlock, clone, drop from typing import Union, Optional @@ -25,7 +25,7 @@ def create_project(path: str, uuid: str, owner: Union[str, None]=None, tags: Uni code: str (optional) The code that was used to create the measurements. """ - db_file = db_filename(path) + db_file = get_db_file(path) db = os.path.join(path, db_file) get(path, db_file) conn = sqlite3.connect(db) @@ -64,7 +64,7 @@ def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None] value: str or None Value to se `prop` to. """ - db_file = db_filename(path) + db_file = get_db_file(path) get(path, db_file) conn = sqlite3.connect(os.path.join(path, db_file)) c = conn.cursor() @@ -75,7 +75,7 @@ def update_project_data(path: str, uuid: str, prop: str, value: Union[str, None] def update_aliases(path: str, uuid: str, aliases: list[str]) -> None: - db_file = db_filename(path) + db_file = get_db_file(path) db = os.path.join(path, db_file) get(path, db_file) known_data = _project_lookup_by_id(db, uuid)[0] @@ -135,7 +135,7 @@ def import_project(path: str, url: str, owner: Union[str, None]=None, tags: Opti if not uuid: raise ValueError("The dataset does not have a uuid!") if not os.path.exists(path + "/projects/" + uuid): - db_file = db_filename(path) + db_file = get_db_file(path) get(path, db_file) unlock(path, db_file) create_project(path, uuid, owner, tags, aliases, code) diff --git a/corrlib/meas_io.py b/corrlib/meas_io.py index 3344efb..65a0569 100644 --- a/corrlib/meas_io.py +++ b/corrlib/meas_io.py @@ -3,14 +3,13 @@ import os import sqlite3 from .input import sfcf,openQCD import json -from typing import Union, Any -from pyerrors import Obs, Corr, load_object, dump_object +from typing import Union +from pyerrors import Obs, Corr, dump_object, load_object from hashlib import sha256 -from .tools import record2name_key, name_key2record, make_version_hash -from .cache_io import is_in_cache, cache_path, cache_dir, get_version_hash -from .tools import db_filename, cache_enabled +from .tools import get_db_file, cache_enabled from .tracker import get, save, unlock import shutil +from typing import Any def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, dict[str, Any]]], uuid: str, code: str, parameter_file: str) -> None: @@ -33,7 +32,7 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, parameter_file: str The parameter file used for the measurement. """ - db_file = db_filename(path) + db_file = get_db_file(path) db = os.path.join(path, db_file) get(path, db_file) unlock(path, db_file) @@ -86,18 +85,18 @@ def write_measurement(path: str, ensemble: str, measurement: dict[str, dict[str, subkey = "/".join(par_list) subkeys = [subkey] pars[subkey] = json.dumps(parameters) - - meas_paths = [] for subkey in subkeys: - par_hash = sha256(str(pars[subkey]).encode('UTF-8')).hexdigest() - meas_path = name_key2record(file_in_archive, par_hash) - meas_paths.append(meas_path) - known_meas[par_hash] = measurement[corr][subkey] - data_hash = make_version_hash(path, meas_path) - if c.execute("SELECT * FROM backlogs WHERE path = ?", (meas_path,)).fetchone() is None: - c.execute("INSERT INTO backlogs (name, ensemble, code, path, project, parameters, parameter_file, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'))", + parHash = sha256(str(pars[subkey]).encode('UTF-8')).hexdigest() + meas_path = file_in_archive + "::" + parHash + + known_meas[parHash] = measurement[corr][subkey] + + if c.execute("SELECT * FROM backlogs WHERE path = ?", (meas_path,)).fetchone() is not None: + c.execute("UPDATE backlogs SET updated_at = datetime('now') WHERE path = ?", (meas_path, )) + else: + c.execute("INSERT INTO backlogs (name, ensemble, code, path, project, parameters, parameter_file, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, datetime('now'), datetime('now'))", (corr, ensemble, code, meas_path, uuid, pars[subkey], parameter_file)) - c.execute("UPDATE backlogs SET current_version = ?, updated_at = datetime('now') WHERE path = ?", (data_hash, meas_path)) + conn.commit() pj.dump_dict_to_json(known_meas, file) files.append(os.path.join(path, db_file)) conn.close() @@ -124,7 +123,7 @@ def load_record(path: str, meas_path: str) -> Union[Corr, Obs]: return load_records(path, [meas_path])[0] -def load_records(path: str, record_paths: list[str], preloaded: dict[str, Any] = {}) -> list[Union[Corr, Obs]]: +def load_records(path: str, meas_paths: list[str], preloaded: dict[str, Any] = {}) -> list[Union[Corr, Obs]]: """ Load a list of records by their paths. @@ -143,32 +142,70 @@ def load_records(path: str, record_paths: list[str], preloaded: dict[str, Any] = The loaded records. """ needed_data: dict[str, list[str]] = {} - for rpath in record_paths: - file, key = record2name_key(rpath) + for mpath in meas_paths: + file = mpath.split("::")[0] if file not in needed_data.keys(): needed_data[file] = [] + key = mpath.split("::")[1] needed_data[file].append(key) returned_data: list[Any] = [] for file in needed_data.keys(): for key in list(needed_data[file]): - record = name_key2record(file, key) - current_version = get_version_hash(path, record) - if is_in_cache(path, record): - returned_data.append(load_object(cache_path(path, file, current_version, key) + ".p")) + if os.path.exists(cache_path(path, file, key) + ".p"): + returned_data.append(load_object(cache_path(path, file, key) + ".p")) else: if file not in preloaded: preloaded[file] = preload(path, file) returned_data.append(preloaded[file][key]) if cache_enabled(path): - if not is_in_cache(path, record): - file, key = record2name_key(record) - if not os.path.exists(cache_dir(path, file)): - os.makedirs(cache_dir(path, file)) - current_version = get_version_hash(path, record) - dump_object(preloaded[file][key], cache_path(path, file, current_version, key)) + if not os.path.exists(cache_dir(path, file)): + os.makedirs(cache_dir(path, file)) + dump_object(preloaded[file][key], cache_path(path, file, key)) return returned_data +def cache_dir(path: str, file: str) -> str: + """ + Returns the directory corresponding to the cache for the given file. + + Parameters + ---------- + path: str + The path of the library. + file: str + The file in the library that we want to access the cached data of. + Returns + ------- + cache_path: str + The path holding the cached data for the given file. + """ + cache_path_list = [path] + cache_path_list.append(".cache") + cache_path_list.extend(file.split("/")[1:]) + cache_path = "/".join(cache_path_list) + return cache_path + + +def cache_path(path: str, file: str, key: str) -> str: + """ + Parameters + ---------- + path: str + The path of the library. + file: str + The file in the library that we want to access the cached data of. + key: str + The key within the archive file. + + Returns + ------- + cache_path: str + The path at which the measurement of the given file and key is cached. + """ + cache_path = os.path.join(cache_dir(path, file), key) + return cache_path + + def preload(path: str, file: str) -> dict[str, Any]: """ Read the contents of a file into a json dictionary with the pyerrors.json.load_json_dict method. @@ -204,7 +241,7 @@ def drop_record(path: str, meas_path: str) -> None: """ file_in_archive = meas_path.split("::")[0] file = os.path.join(path, file_in_archive) - db_file = db_filename(path) + db_file = get_db_file(path) db = os.path.join(path, db_file) get(path, db_file) sub_key = meas_path.split("::")[1] diff --git a/corrlib/tools.py b/corrlib/tools.py index e46ce0a..118b094 100644 --- a/corrlib/tools.py +++ b/corrlib/tools.py @@ -1,7 +1,6 @@ import os -import hashlib from configparser import ConfigParser -from typing import Any, Union +from typing import Any CONFIG_FILENAME = ".corrlib" cached: bool = True @@ -23,7 +22,6 @@ def str2list(string: str) -> list[str]: """ return string.split(",") - def list2str(mylist: list[str]) -> str: """ Convert a list to a comma-separated string. @@ -41,7 +39,6 @@ def list2str(mylist: list[str]) -> str: s = ",".join(mylist) return s - def m2k(m: float) -> float: """ Convert to bare quark mas $m$ to inverse mass parameter $kappa$. @@ -76,47 +73,6 @@ def k2m(k: float) -> float: return (1/(2*k))-4 -def record2name_key(record_path: str) -> tuple[str, str]: - """ - Convert a record to a pair of name and key. - - Parameters - ---------- - record: str - - Returns - ------- - name: str - key: str - """ - file = record_path.split("::")[0] - key = record_path.split("::")[1] - return file, key - - -def name_key2record(name: str, key: str) -> str: - """ - Convert a name and a key to a record name. - - Parameters - ---------- - name: str - key: str - - Returns - ------- - record: str - """ - return name + "::" + key - - -def make_version_hash(path: str, record: str) -> str: - file, key = record2name_key(record) - with open(os.path.join(path, file), 'rb') as fp: - file_hash = hashlib.file_digest(fp, 'sha1').hexdigest() - return file_hash - - def set_config(path: str, section: str, option: str, value: Any) -> None: """ Set configuration parameters for the library. @@ -144,7 +100,7 @@ def set_config(path: str, section: str, option: str, value: Any) -> None: return -def db_filename(path: str) -> str: +def get_db_file(path: str) -> str: """ Get the database file associated with the library at the given path. @@ -188,28 +144,3 @@ def cache_enabled(path: str) -> bool: cached_str = config.get('core', 'cached', fallback='True') cached_bool = cached_str == ('True') return cached_bool - - -def cache_dir_name(path: str) -> Union[str, None]: - """ - Get the database file associated with the library at the given path. - - Parameters - ---------- - path: str - The path of the library. - - Returns - ------- - db_file: str - The file holding the database. - """ - config_path = os.path.join(path, CONFIG_FILENAME) - config = ConfigParser() - if os.path.exists(config_path): - config.read(config_path) - if cache_enabled(path): - cache = config.get('paths', 'cache', fallback='.cache') - else: - cache = None - return cache diff --git a/corrlib/tracker.py b/corrlib/tracker.py index 63aabf2..5cc281c 100644 --- a/corrlib/tracker.py +++ b/corrlib/tracker.py @@ -3,7 +3,7 @@ from configparser import ConfigParser import datalad.api as dl from typing import Optional import shutil -from .tools import db_filename +from .tools import get_db_file def get_tracker(path: str) -> str: @@ -43,7 +43,7 @@ def get(path: str, file: str) -> None: """ tracker = get_tracker(path) if tracker == 'datalad': - if file == db_filename(path): + if file == get_db_file(path): print("Downloading database...") else: print("Downloading data...") diff --git a/setup.py b/setup.py deleted file mode 100644 index 6b8794e..0000000 --- a/setup.py +++ /dev/null @@ -1,18 +0,0 @@ -from setuptools import setup -from distutils.util import convert_path - - -version = {} -with open(convert_path('corrlib/version.py')) as ver_file: - exec(ver_file.read(), version) - -setup(name='pycorrlib', - version=version['__version__'], - author='Justus Kuhlmann', - author_email='j_kuhl19@uni-muenster.de', - install_requires=['pyerrors>=2.11.1', 'datalad>=1.1.0', 'typer>=0.12.5', 'gitpython>=3.1.45'], - entry_points = { - 'console_scripts': ['pcl=corrlib.cli:app'], - }, - packages=['corrlib', 'corrlib.input'] - ) diff --git a/tests/cli_test.py b/tests/cli_test.py index f1678c6..a6b0bd7 100644 --- a/tests/cli_test.py +++ b/tests/cli_test.py @@ -2,19 +2,18 @@ from typer.testing import CliRunner from corrlib.cli import app import os import sqlite3 as sql -from pathlib import Path runner = CliRunner() -def test_version() -> None: +def test_version(): result = runner.invoke(app, ["--version"]) assert result.exit_code == 0 assert "corrlib" in result.output -def test_init_folders(tmp_path: Path) -> None: +def test_init_folders(tmp_path): dataset_path = tmp_path / "test_dataset" result = runner.invoke(app, ["init", "--dataset", str(dataset_path)]) assert result.exit_code == 0 @@ -22,7 +21,7 @@ def test_init_folders(tmp_path: Path) -> None: assert os.path.exists(str(dataset_path / "backlogger.db")) -def test_init_db(tmp_path: Path) -> None: +def test_init_db(tmp_path): dataset_path = tmp_path / "test_dataset" result = runner.invoke(app, ["init", "--dataset", str(dataset_path)]) assert result.exit_code == 0 @@ -38,7 +37,7 @@ def test_init_db(tmp_path: Path) -> None: table_names = [table[0] for table in tables] for expected_table in expected_tables: assert expected_table in table_names - + cursor.execute("SELECT * FROM projects;") projects = cursor.fetchall() assert len(projects) == 0 @@ -61,7 +60,7 @@ def test_init_db(tmp_path: Path) -> None: project_column_names = [col[1] for col in project_columns] for expected_col in expected_project_columns: assert expected_col in project_column_names - + cursor.execute("PRAGMA table_info('backlogs');") backlog_columns = cursor.fetchall() expected_backlog_columns = [ @@ -82,7 +81,7 @@ def test_init_db(tmp_path: Path) -> None: assert expected_col in backlog_column_names -def test_list(tmp_path: Path) -> None: +def test_list(tmp_path): dataset_path = tmp_path / "test_dataset" result = runner.invoke(app, ["init", "--dataset", str(dataset_path)]) assert result.exit_code == 0 diff --git a/tests/import_project_test.py b/tests/import_project_test.py index 685d2cf..2dea06f 100644 --- a/tests/import_project_test.py +++ b/tests/import_project_test.py @@ -1,7 +1,7 @@ import corrlib.toml as t -def test_toml_check_measurement_data() -> None: +def test_toml_check_measurement_data(): measurements = { "a": { diff --git a/tests/sfcf_in_test.py b/tests/sfcf_in_test.py index 5e4ff83..72921e7 100644 --- a/tests/sfcf_in_test.py +++ b/tests/sfcf_in_test.py @@ -1,7 +1,7 @@ import corrlib.input.sfcf as input import json -def test_get_specs() -> None: +def test_get_specs(): parameters = { 'crr': [ 'f_P', 'f_A' diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 9284c82..1ea0ece 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -1,23 +1,22 @@ import corrlib.initialization as init import os import sqlite3 as sql -from pathlib import Path -def test_init_folders(tmp_path: Path) -> None: +def test_init_folders(tmp_path): dataset_path = tmp_path / "test_dataset" init.create(str(dataset_path)) assert os.path.exists(str(dataset_path)) assert os.path.exists(str(dataset_path / "backlogger.db")) -def test_init_folders_no_tracker(tmp_path: Path) -> None: +def test_init_folders_no_tracker(tmp_path): dataset_path = tmp_path / "test_dataset" init.create(str(dataset_path), tracker="None") assert os.path.exists(str(dataset_path)) assert os.path.exists(str(dataset_path / "backlogger.db")) -def test_init_config(tmp_path: Path) -> None: +def test_init_config(tmp_path): dataset_path = tmp_path / "test_dataset" init.create(str(dataset_path), tracker="None") config_path = dataset_path / ".corrlib" @@ -35,7 +34,7 @@ def test_init_config(tmp_path: Path) -> None: assert config.get("paths", "import_scripts_path") == "import_scripts" -def test_init_db(tmp_path: Path) -> None: +def test_init_db(tmp_path): dataset_path = tmp_path / "test_dataset" init.create(str(dataset_path)) assert os.path.exists(str(dataset_path / "backlogger.db")) diff --git a/tests/tools_test.py b/tests/tools_test.py index 88dbffa..ee76f1c 100644 --- a/tests/tools_test.py +++ b/tests/tools_test.py @@ -3,29 +3,29 @@ from corrlib import tools as tl -def test_m2k() -> None: +def test_m2k(): for m in [0.1, 0.5, 1.0]: expected_k = 1 / (2 * m + 8) assert tl.m2k(m) == expected_k -def test_k2m() -> None: +def test_k2m(): for m in [0.1, 0.5, 1.0]: assert tl.k2m(m) == (1/(2*m))-4 -def test_k2m_m2k() -> None: +def test_k2m_m2k(): for m in [0.1, 0.5, 1.0]: k = tl.m2k(m) m_converted = tl.k2m(k) assert abs(m - m_converted) < 1e-9 -def test_str2list() -> None: +def test_str2list(): assert tl.str2list("a,b,c") == ["a", "b", "c"] assert tl.str2list("1,2,3") == ["1", "2", "3"] -def test_list2str() -> None: +def test_list2str(): assert tl.list2str(["a", "b", "c"]) == "a,b,c" assert tl.list2str(["1", "2", "3"]) == "1,2,3"