diff --git a/pyerrors/linalg.py b/pyerrors/linalg.py
index 259f070a..bb44870d 100644
--- a/pyerrors/linalg.py
+++ b/pyerrors/linalg.py
@@ -174,35 +174,6 @@ def matmul(*operands):
     return derived_array(multi_dot, operands)
 
 
-def _exp_to_jack(matrix):
-    base_matrix = np.empty_like(matrix)
-    for index, entry in np.ndenumerate(matrix):
-        base_matrix[index] = entry.export_jackknife()
-    return base_matrix
-
-
-def _imp_from_jack(matrix, name, idl):
-    base_matrix = np.empty_like(matrix)
-    for index, entry in np.ndenumerate(matrix):
-        base_matrix[index] = import_jackknife(entry, name, [idl])
-    return base_matrix
-
-
-def _exp_to_jack_c(matrix):
-    base_matrix = np.empty_like(matrix)
-    for index, entry in np.ndenumerate(matrix):
-        base_matrix[index] = entry.real.export_jackknife() + 1j * entry.imag.export_jackknife()
-    return base_matrix
-
-
-def _imp_from_jack_c(matrix, name, idl):
-    base_matrix = np.empty_like(matrix)
-    for index, entry in np.ndenumerate(matrix):
-        base_matrix[index] = CObs(import_jackknife(entry.real, name, [idl]),
-                                  import_jackknife(entry.imag, name, [idl]))
-    return base_matrix
-
-
 def jack_matmul(*operands):
     """Matrix multiply both operands making use of the jackknife approximation.
 
@@ -215,6 +186,31 @@
     For large matrices this is considerably faster compared to matmul.
     """
 
+    def _exp_to_jack(matrix):
+        base_matrix = np.empty_like(matrix)
+        for index, entry in np.ndenumerate(matrix):
+            base_matrix[index] = entry.export_jackknife()
+        return base_matrix
+
+    def _imp_from_jack(matrix, name, idl):
+        base_matrix = np.empty_like(matrix)
+        for index, entry in np.ndenumerate(matrix):
+            base_matrix[index] = import_jackknife(entry, name, [idl])
+        return base_matrix
+
+    def _exp_to_jack_c(matrix):
+        base_matrix = np.empty_like(matrix)
+        for index, entry in np.ndenumerate(matrix):
+            base_matrix[index] = entry.real.export_jackknife() + 1j * entry.imag.export_jackknife()
+        return base_matrix
+
+    def _imp_from_jack_c(matrix, name, idl):
+        base_matrix = np.empty_like(matrix)
+        for index, entry in np.ndenumerate(matrix):
+            base_matrix[index] = CObs(import_jackknife(entry.real, name, [idl]),
+                                      import_jackknife(entry.imag, name, [idl]))
+        return base_matrix
+
     if any(isinstance(o.flat[0], CObs) for o in operands):
         name = operands[0].flat[0].real.names[0]
         idl = operands[0].flat[0].real.idl[name]
@@ -251,12 +247,40 @@ def einsum(subscripts, *operands):
         Obs valued.
     """
 
-    if any(isinstance(o.flat[0], CObs) for o in operands):
-        name = operands[0].flat[0].real.names[0]
-        idl = operands[0].flat[0].real.idl[name]
-    else:
-        name = operands[0].flat[0].names[0]
-        idl = operands[0].flat[0].idl[name]
+    def _exp_to_jack(matrix):
+        base_matrix = []
+        for index, entry in np.ndenumerate(matrix):
+            base_matrix.append(entry.export_jackknife())
+        return np.asarray(base_matrix).reshape(matrix.shape + base_matrix[0].shape)
+
+    def _exp_to_jack_c(matrix):
+        base_matrix = []
+        for index, entry in np.ndenumerate(matrix):
+            base_matrix.append(entry.real.export_jackknife() + 1j * entry.imag.export_jackknife())
+        return np.asarray(base_matrix).reshape(matrix.shape + base_matrix[0].shape)
+
+    def _imp_from_jack(matrix, name, idl):
+        base_matrix = np.empty(shape=matrix.shape[:-1], dtype=object)
+        for index in np.ndindex(matrix.shape[:-1]):
+            base_matrix[index] = import_jackknife(matrix[index], name, [idl])
+        return base_matrix
+
+    def _imp_from_jack_c(matrix, name, idl):
+        base_matrix = np.empty(shape=matrix.shape[:-1], dtype=object)
+        for index in np.ndindex(matrix.shape[:-1]):
+            base_matrix[index] = CObs(import_jackknife(matrix[index].real, name, [idl]),
+                                      import_jackknife(matrix[index].imag, name, [idl]))
+        return base_matrix
+
+    for op in operands:
+        if isinstance(op.flat[0], CObs):
+            name = op.flat[0].real.names[0]
+            idl = op.flat[0].real.idl[name]
+            break
+        elif isinstance(op.flat[0], Obs):
+            name = op.flat[0].names[0]
+            idl = op.flat[0].idl[name]
+            break
 
     conv_operands = []
     for op in operands:
@@ -267,15 +291,22 @@ def einsum(subscripts, *operands):
         else:
            conv_operands.append(op)
 
-    result = np.einsum(subscripts, *conv_operands)
+    tmp_subscripts = ','.join([o + '...' for o in subscripts.split(',')])
+    extended_subscripts = '->'.join([o + '...' for o in tmp_subscripts.split('->')[:-1]] + [tmp_subscripts.split('->')[-1]])
+    jack_einsum = np.einsum(extended_subscripts, *conv_operands)
 
-    if result.dtype == complex:
-        return _imp_from_jack_c(result, name, idl)
-    elif result.dtype == float:
-        return _imp_from_jack(result, name, idl)
+    if jack_einsum.dtype == complex:
+        result = _imp_from_jack_c(jack_einsum, name, idl)
+    elif jack_einsum.dtype == float:
+        result = _imp_from_jack(jack_einsum, name, idl)
     else:
         raise Exception("Result has unexpected datatype")
 
+    if result.shape == ():
+        return result.flat[0]
+    else:
+        return result
+
 
 def inv(x):
     """Inverse of Obs or CObs valued matrices."""
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 46ee6c89..d34da29f 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -93,6 +93,59 @@ def test_jack_matmul():
     assert trace4.real.dvalue < 0.001
     assert trace4.imag.dvalue < 0.001
 
+
+def test_einsum():
+
+    def _perform_real_check(arr):
+        [o.gamma_method() for o in arr]
+        assert np.all([o.is_zero_within_error(0.001) for o in arr])
+        assert np.all([o.dvalue < 0.001 for o in arr])
+
+    def _perform_complex_check(arr):
+        [o.gamma_method() for o in arr]
+        assert np.all([o.real.is_zero_within_error(0.001) for o in arr])
+        assert np.all([o.real.dvalue < 0.001 for o in arr])
+        assert np.all([o.imag.is_zero_within_error(0.001) for o in arr])
+        assert np.all([o.imag.dvalue < 0.001 for o in arr])
+
+
+    tt = [get_real_matrix(4), get_real_matrix(3)]
+    q = np.tensordot(tt[0], tt[1], 0)
+    c1 = tt[1] @ q
+    c2 = pe.linalg.einsum('ij,abjd->abid', tt[1], q)
+    check1 = c1 - c2
+    _perform_real_check(check1.ravel())
+    check2 = np.trace(tt[0]) - pe.linalg.einsum('ii', tt[0])
+    _perform_real_check([check2])
+    check3 = np.trace(tt[1]) - pe.linalg.einsum('ii', tt[1])
+    _perform_real_check([check3])
+
+    tt = [get_real_matrix(4), np.random.random((3, 3))]
+    q = np.tensordot(tt[0], tt[1], 0)
+    c1 = tt[1] @ q
+    c2 = pe.linalg.einsum('ij,abjd->abid', tt[1], q)
+    check1 = c1 - c2
+    _perform_real_check(check1.ravel())
+
+    tt = [get_complex_matrix(4), get_complex_matrix(3)]
+    q = np.tensordot(tt[0], tt[1], 0)
+    c1 = tt[1] @ q
+    c2 = pe.linalg.einsum('ij,abjd->abid', tt[1], q)
+    check1 = c1 - c2
+    _perform_complex_check(check1.ravel())
+    check2 = np.trace(tt[0]) - pe.linalg.einsum('ii', tt[0])
+    _perform_complex_check([check2])
+    check3 = np.trace(tt[1]) - pe.linalg.einsum('ii', tt[1])
+    _perform_complex_check([check3])
+
+    tt = [get_complex_matrix(4), np.random.random((3, 3))]
+    q = np.tensordot(tt[0], tt[1], 0)
+    c1 = tt[1] @ q
+    c2 = pe.linalg.einsum('ij,abjd->abid', tt[1], q)
+    check1 = c1 - c2
+    _perform_complex_check(check1.ravel())
+
+
 def test_multi_dot():
     for dim in [4, 6]:
         my_list = []
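Note (not part of the diff): a minimal usage sketch of the jackknife-based `pe.linalg.einsum` wrapper introduced above. The Obs-valued matrix is built here with `pe.pseudo_Obs` purely for illustration; the test suite instead relies on its own `get_real_matrix`/`get_complex_matrix` helpers, which are not shown in this diff.

```python
import numpy as np
import pyerrors as pe

# Build a 3x3 matrix of Obs entries on a single ensemble 'e1'
# (illustrative stand-in for the get_real_matrix helper used in linalg_test.py).
matrix = np.array([[pe.pseudo_Obs(np.random.rand(), 0.1, 'e1') for _ in range(3)]
                   for _ in range(3)])

# Trace computed through the jackknife approximation; returns a single Obs
# because the zero-dimensional result is unwrapped via result.flat[0].
trace = pe.linalg.einsum('ii', matrix)
trace.gamma_method()

# Matrix product expressed via einsum; the result is again an Obs-valued array.
prod = pe.linalg.einsum('ij,jk->ik', matrix, matrix)
```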