From 140268c1c93190d1e8754ad276355e95249e833b Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Wed, 8 Dec 2021 15:17:32 +0000 Subject: [PATCH] refactor: two loops over new_sample_names merged. --- pyerrors/obs.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/pyerrors/obs.py b/pyerrors/obs.py index 681a398b..60441456 100644 --- a/pyerrors/obs.py +++ b/pyerrors/obs.py @@ -1075,16 +1075,6 @@ def derived_observable(func, data, array_mode=False, **kwargs): is_merged = {name: (len(list(filter(lambda o: o.is_merged.get(name, False) is True, raveled_data))) > 0) for name in new_sample_names} reweighted = len(list(filter(lambda o: o.reweighted is True, raveled_data))) > 0 - new_idl_d = {} - for name in new_sample_names: - idl = [] - for i_data in raveled_data: - tmp_idl = i_data.idl.get(name) - if tmp_idl is not None: - idl.append(tmp_idl) - new_idl_d[name] = _merge_idx(idl) - if not is_merged[name]: - is_merged[name] = (1 != len(set([len(idx) for idx in [*idl, new_idl_d[name]]]))) if data.ndim == 1: values = np.array([o.value for o in data]) @@ -1098,13 +1088,21 @@ def derived_observable(func, data, array_mode=False, **kwargs): multi = 1 new_r_values = {} + new_idl_d = {} for name in new_sample_names: + idl = [] tmp_values = np.zeros(n_obs) for i, item in enumerate(raveled_data): tmp_values[i] = item.r_values.get(name, item.value) + tmp_idl = item.idl.get(name) + if tmp_idl is not None: + idl.append(tmp_idl) if multi > 0: tmp_values = np.array(tmp_values).reshape(data.shape) new_r_values[name] = func(tmp_values, **kwargs) + new_idl_d[name] = _merge_idx(idl) + if not is_merged[name]: + is_merged[name] = (1 != len(set([len(idx) for idx in [*idl, new_idl_d[name]]]))) if 'man_grad' in kwargs: deriv = np.asarray(kwargs.get('man_grad'))