From ac215696203afcc1ea43115156e0418eb1ab1063 Mon Sep 17 00:00:00 2001 From: Fabian Joswig Date: Wed, 3 Nov 2021 14:01:52 +0000 Subject: [PATCH] Corr.correlate implemented --- pyerrors/correlators.py | 18 +++++++++++++++++- pyerrors/pyerrors.py | 7 +++++++ tests/pyerrors_test.py | 8 ++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/pyerrors/correlators.py b/pyerrors/correlators.py index b6c7706e..18643422 100644 --- a/pyerrors/correlators.py +++ b/pyerrors/correlators.py @@ -3,7 +3,7 @@ import numpy as np import autograd.numpy as anp import matplotlib.pyplot as plt import scipy.linalg -from .pyerrors import Obs, dump_object, reweight +from .pyerrors import Obs, dump_object, reweight, correlate from .fits import least_squares from .linalg import eigh, inv, cholesky from .roots import find_root @@ -237,6 +237,22 @@ class Corr: """Reverse the time ordering of the Corr""" return Corr(self.content[::-1]) + def correlate(self, partner): + """Correlate the correlator with another correlator or Obs""" + new_content = [] + for x0, t_slice in enumerate(self.content): + if t_slice is None: + new_content.append(None) + else: + if isinstance(partner, Corr): + new_content.append(np.array([correlate(o, partner.content[x0][0]) for o in t_slice])) + elif isinstance(partner, Obs): + new_content.append(np.array([correlate(o, partner) for o in t_slice])) + else: + raise Exception("Can only correlate with an Obs or a Corr.") + + return Corr(new_content) + def reweight(self, weight, **kwargs): """Reweight the correlator. diff --git a/pyerrors/pyerrors.py b/pyerrors/pyerrors.py index 481b6593..e7ecfe8b 100644 --- a/pyerrors/pyerrors.py +++ b/pyerrors/pyerrors.py @@ -1145,6 +1145,13 @@ def reweight(weight, obs, **kwargs): def correlate(obs_a, obs_b): """Correlate two observables. + Attributes: + ----------- + obs_a : Obs + First observable + obs_b : Obs + Second observable + Keep in mind to only correlate primary observables which have not been reweighted yet. The reweighting has to be applied after correlating the observables. Currently only works if ensembles are identical. This is not really necessary. diff --git a/tests/pyerrors_test.py b/tests/pyerrors_test.py index 7943df63..687144cc 100644 --- a/tests/pyerrors_test.py +++ b/tests/pyerrors_test.py @@ -290,6 +290,14 @@ def test_merge_obs(): assert diff == -(my_obs1.value + my_obs2.value) / 2 +def test_correlate(): + my_obs1 = pe.Obs([np.random.rand(100)], ['t']) + my_obs2 = pe.Obs([np.random.rand(100)], ['t']) + corr1 = pe.correlate(my_obs1, my_obs2) + corr2 = pe.correlate(my_obs2, my_obs1) + assert corr1 == corr2 + + def test_irregular_error_propagation(): obs_list = [pe.Obs([np.random.rand(100)], ['t']), pe.Obs([np.random.rand(50)], ['t'], idl=[range(1, 100, 2)]),