From 5789c0cef6953c754aa5e8907534a27fff07dad9 Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Thu, 2 Dec 2021 16:54:51 +0000 Subject: [PATCH] feat: new_cov_names and new_sample_names added to derived_array --- pyerrors/obs.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pyerrors/obs.py b/pyerrors/obs.py index 77f55cff..3edd9292 100644 --- a/pyerrors/obs.py +++ b/pyerrors/obs.py @@ -1048,6 +1048,7 @@ def derived_observable(func, data, array_mode=False, **kwargs): raveled_data = data.ravel() # Workaround for matrix operations containing non Obs data + # TODO: Find more elegant solution here. for i_data in raveled_data: if isinstance(i_data, Obs): first_name = i_data.names[0] @@ -1070,11 +1071,13 @@ def derived_observable(func, data, array_mode=False, **kwargs): n_obs = len(raveled_data) new_names = sorted(set([y for x in [o.names for o in raveled_data] for y in x])) + new_cov_names = sorted(set([y for x in [o.cov_names for o in raveled_data] for y in x])) + new_sample_names = sorted(set(new_names) - set(new_cov_names)) - is_merged = {name: (len(list(filter(lambda o: o.is_merged.get(name, False) is True, raveled_data))) > 0) for name in new_names} + is_merged = {name: (len(list(filter(lambda o: o.is_merged.get(name, False) is True, raveled_data))) > 0) for name in new_sample_names} reweighted = len(list(filter(lambda o: o.reweighted is True, raveled_data))) > 0 new_idl_d = {} - for name in new_names: + for name in new_sample_names: idl = [] for i_data in raveled_data: tmp = i_data.idl.get(name) @@ -1096,7 +1099,7 @@ def derived_observable(func, data, array_mode=False, **kwargs): multi = 1 new_r_values = {} - for name in new_names: + for name in new_sample_names: tmp_values = np.zeros(n_obs) for i, item in enumerate(raveled_data): tmp = item.r_values.get(name) @@ -1140,7 +1143,7 @@ def derived_observable(func, data, array_mode=False, **kwargs): if array_mode is True: d_extracted = {} - for name in new_names: + for name in new_sample_names: d_extracted[name] = [] for i_dat, dat in enumerate(data): ens_length = len(new_idl_d[name]) @@ -1150,7 +1153,7 @@ def derived_observable(func, data, array_mode=False, **kwargs): new_deltas = {} new_grad = {} if array_mode is True: - for name in new_names: + for name in new_sample_names: ens_length = d_extracted[name][0].shape[-1] new_deltas[name] = np.zeros(ens_length) for i_dat, dat in enumerate(d_extracted[name]):