diff --git a/pyerrors/linalg.py b/pyerrors/linalg.py index e6ee9cef..259f070a 100644 --- a/pyerrors/linalg.py +++ b/pyerrors/linalg.py @@ -239,6 +239,44 @@ def jack_matmul(*operands): return _imp_from_jack(r, name, idl) +def einsum(subscripts, *operands): + """Wrapper for numpy.einsum + + Parameters + ---------- + subscripts : str + Subscripts for summation (see numpy documentation for details) + operands : numpy.ndarray + Arbitrary number of 2d-numpy arrays which can be real or complex + Obs valued. + """ + + if any(isinstance(o.flat[0], CObs) for o in operands): + name = operands[0].flat[0].real.names[0] + idl = operands[0].flat[0].real.idl[name] + else: + name = operands[0].flat[0].names[0] + idl = operands[0].flat[0].idl[name] + + conv_operands = [] + for op in operands: + if isinstance(op.flat[0], CObs): + conv_operands.append(_exp_to_jack_c(op)) + elif isinstance(op.flat[0], Obs): + conv_operands.append(_exp_to_jack(op)) + else: + conv_operands.append(op) + + result = np.einsum(subscripts, *conv_operands) + + if result.dtype == complex: + return _imp_from_jack_c(result, name, idl) + elif result.dtype == float: + return _imp_from_jack(result, name, idl) + else: + raise Exception("Result has unexpected datatype") + + def inv(x): """Inverse of Obs or CObs valued matrices.""" return _mat_mat_op(anp.linalg.inv, x)