From 52705d8fcdbf320c02ce69a52a1a0bd1d1d24db0 Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Wed, 8 Dec 2021 15:26:27 +0000 Subject: [PATCH] refactor: minor simplifications in derived_observable --- pyerrors/obs.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pyerrors/obs.py b/pyerrors/obs.py index 60441456..a2deda6d 100644 --- a/pyerrors/obs.py +++ b/pyerrors/obs.py @@ -1057,7 +1057,7 @@ def derived_observable(func, data, array_mode=False, **kwargs): if not all(isinstance(x, Obs) for x in raveled_data): for i in range(len(raveled_data)): if isinstance(raveled_data[i], (int, float)): - raveled_data[i] = cov_Obs(raveled_data[i], 0.0, "###dummy_entry###") + raveled_data[i] = cov_Obs(raveled_data[i], 0.0, "###dummy_covobs###") allcov = {} for o in raveled_data: @@ -1083,9 +1083,7 @@ def derived_observable(func, data, array_mode=False, **kwargs): new_values = func(values, **kwargs) - multi = 0 - if isinstance(new_values, np.ndarray): - multi = 1 + multi = int(isinstance(new_values, np.ndarray)) new_r_values = {} new_idl_d = {} @@ -1137,13 +1135,11 @@ def derived_observable(func, data, array_mode=False, **kwargs): if array_mode is True: - new_covobs_lengths = dict(set([y for x in [[(n, o.covobs[n].N) for n in o.cov_names] for o in raveled_data] for y in x])) - class _Zero_grad(): def __init__(self, N): - # self.grad = np.zeros(N) self.grad = np.zeros((N, 1)) + new_covobs_lengths = dict(set([y for x in [[(n, o.covobs[n].N) for n in o.cov_names] for o in raveled_data] for y in x])) d_extracted = {} g_extracted = {} for name in new_sample_names: