diff --git a/corrlib/__init__.py b/corrlib/__init__.py
index 91b07f4..41d8691 100644
--- a/corrlib/__init__.py
+++ b/corrlib/__init__.py
@@ -19,5 +19,6 @@ from .main import *
 from .import input as input
 from .initialization import *
 from .meas_io import *
+from .cache_io import *
 from .find import *
 from .version import __version__
diff --git a/corrlib/cache_io.py b/corrlib/cache_io.py
index 4d1d632..c890164 100644
--- a/corrlib/cache_io.py
+++ b/corrlib/cache_io.py
@@ -1,6 +1,23 @@
 from typing import Union, Optional
 import os
 import shutil
+from .tools import record2name_key
+from pyerrors import dump_object
+import datalad.api as dl
+import sqlite3
+
+
+def get_version_hash(path, record):
+    """Return the stored current_version hash for *record* from backlogger.db."""
+    db = os.path.join(path, "backlogger.db")
+    dl.get(db, dataset=path)
+    conn = sqlite3.connect(db)
+    c = conn.cursor()
+    # Parameterized query -- never interpolate *record* into the SQL string.
+    c.execute("SELECT current_version FROM backlogs WHERE path = ?", (record,))
+    rows = c.fetchall()
+    conn.close()
+    return rows[0][0]
 
 
 def drop_cache_files(path: str, fs: Optional[list[str]]=None):
@@ -19,15 +32,29 @@ def cache_dir(path, file):
     return cache_path
 
 
-def cache_path(path, file, hash, key):
-    cache_path = os.path.join(cache_dir(path, file), hash, key)
+def cache_path(path, file, sha_hash, key):
+    cache_path = os.path.join(cache_dir(path, file), key + "_" + sha_hash)
     return cache_path
 
 
-def is_in_cache(path, record, hash):
-
-    if os.file.exists(cache_path(path, file, hash, key)):
-        return True
-    else:
-        return False
-
-    
\ No newline at end of file
+
+def is_old_version(path, record):
+    """Return True if a cached file for *record* exists under an outdated version hash."""
+    version_hash = get_version_hash(path, record)
+    file, key = record2name_key(record)
+    meas_cache_path = os.path.join(cache_dir(path, file))
+    ls = []
+    for p, ds, fs in os.walk(meas_cache_path):
+        ls.extend(fs)
+    for filename in ls:
+        if key == filename.split("_")[0]:
+            # Cached files are named "<key>_<hash>.p"; drop the ".p" before comparing.
+            return version_hash != filename.split("_")[1][:-2]
+    # No cached file for this key at all -> nothing outdated.
+    return False
+
+
+def is_in_cache(path, record):
+    """Return True if *record* is cached at its current version."""
+    version_hash = get_version_hash(path, record)
+    file, key = record2name_key(record)
+    return os.path.exists(cache_path(path, file, version_hash, key) + ".p")
diff --git a/corrlib/cli.py b/corrlib/cli.py
index b808c13..44ede1b 100644
--- a/corrlib/cli.py
+++ b/corrlib/cli.py
@@ -6,7 +6,7 @@ from .toml import import_tomls, update_project, reimport_project
 from .find import find_record, list_projects
 from .tools import str2list
 from .main import update_aliases
-from .meas_io import drop_cache as mio_drop_cache
+from .cache_io import drop_cache_files as cio_drop_cache_files
 import os
 
 
@@ -171,7 +171,7 @@ def drop_cache(
     """
     Drop the currect cache directory of the dataset.
     """
-    mio_drop_cache(path)
+    cio_drop_cache_files(path)
     return
 
 
diff --git a/corrlib/initialization.py b/corrlib/initialization.py
index f6ef5aa..e5c0ede 100644
--- a/corrlib/initialization.py
+++ b/corrlib/initialization.py
@@ -21,7 +21,8 @@ def _create_db(db):
                  parameters TEXT,
                  parameter_file TEXT,
                  created_at TEXT,
-                 updated_at TEXT)''')
+                 updated_at TEXT,
+                 current_version TEXT)''')
     c.execute('''CREATE TABLE IF NOT EXISTS projects
                  (id TEXT PRIMARY KEY,
                  aliases TEXT,
diff --git a/corrlib/meas_io.py b/corrlib/meas_io.py
index ad9a6e8..ff7cdc8 100644
--- a/corrlib/meas_io.py
+++ b/corrlib/meas_io.py
@@ -5,11 +5,10 @@ import sqlite3
 from .input import sfcf,openQCD
 import json
 from typing import Union, Optional
-from pyerrors import Obs, Corr, dump_object, load_object
-from hashlib import sha256, sha1
-from .tools import cached, record2name_key
-import shutil
-from .caching import cache_path, cache_dir
+from pyerrors import Obs, Corr, load_object, dump_object
+from hashlib import sha256
+from .tools import cached, record2name_key, make_version_hash
+from .cache_io import is_in_cache, cache_path, cache_dir, get_version_hash
 
 
 def write_measurement(path, ensemble, measurement, uuid, code, parameter_file: Optional[str]=None):
@@ -79,11 +78,13 @@ def write_measurement(path, ensemble, measurement, uuid, code, parameter_file: O
                 subkey = "/".join(par_list)
                 subkeys = [subkey]
             pars[subkey] = json.dumps(parameters)
-        for subkey in subkeys:
-            parHash = sha256(str(pars[subkey]).encode('UTF-8')).hexdigest()
-            meas_path = file_in_archive + "::" + parHash
-            known_meas[parHash] = measurement[corr][subkey]
+        meas_paths = []
+        for subkey in subkeys:
+            par_hash = sha256(str(pars[subkey]).encode('UTF-8')).hexdigest()
+            meas_path = file_in_archive + "::" + par_hash
+            meas_paths.append(meas_path)
+            known_meas[par_hash] = measurement[corr][subkey]
 
             if c.execute("SELECT * FROM backlogs WHERE path = ?", (meas_path,)).fetchone() is not None:
                 c.execute("UPDATE backlogs SET updated_at = datetime('now') WHERE path = ?", (meas_path, ))
@@ -92,7 +93,12 @@ def write_measurement(path, ensemble, measurement, uuid, code, parameter_file: O
                           (corr, ensemble, code, meas_path, uuid, pars[subkey], parameter_file))
         conn.commit()
         pj.dump_dict_to_json(known_meas, file)
-        files.append(path + '/backlogger.db')
+        for meas_path in meas_paths:
+            version_hash = make_version_hash(path, meas_path)
+            # NOTE(review): WHERE matches every row of this corr, not just meas_path -- confirm intended.
+            c.execute("UPDATE backlogs SET current_version = ? WHERE project = ? AND code = ? and name = ?", (version_hash, uuid, code, corr))
+            conn.commit()
+        files.append(db)
     conn.close()
     dl.save(files, message="Add measurements to database", dataset=path)
 
@@ -140,16 +146,21 @@ def load_records(path: str, record_paths: list[str], preloaded = {}) -> list[Uni
     returned_data: list = []
     for file in needed_data.keys():
         for key in list(needed_data[file]):
-            if os.path.exists(cache_path(path, file, key) + ".p"):
-                returned_data.append(load_object(cache_path(path, file, key) + ".p"))
+            record = file + "::" + key
+            current_version = get_version_hash(path, record)
+            if is_in_cache(path, record):
+                returned_data.append(load_object(cache_path(path, file, current_version, key) + ".p"))
             else:
                 if file not in preloaded:
                     preloaded[file] = preload(path, file)
                 returned_data.append(preloaded[file][key])
                 if cached:
-                    if not os.path.exists(cache_dir(path, file)):
-                        os.makedirs(cache_dir(path, file))
-                    dump_object(preloaded[file][key], cache_path(path, file, key))
+                    if not is_in_cache(path, record):
+                        file, key = record2name_key(record)
+                        if not os.path.exists(cache_dir(path, file)):
+                            os.makedirs(cache_dir(path, file))
+                        current_version = get_version_hash(path, record)
+                        dump_object(preloaded[file][key], cache_path(path, file, current_version, key))
     return returned_data
 
 