diff --git a/pyerrors/linalg.py b/pyerrors/linalg.py index 259f070a..bbb59367 100644 --- a/pyerrors/linalg.py +++ b/pyerrors/linalg.py @@ -174,35 +174,6 @@ def matmul(*operands): return derived_array(multi_dot, operands) -def _exp_to_jack(matrix): - base_matrix = np.empty_like(matrix) - for index, entry in np.ndenumerate(matrix): - base_matrix[index] = entry.export_jackknife() - return base_matrix - - -def _imp_from_jack(matrix, name, idl): - base_matrix = np.empty_like(matrix) - for index, entry in np.ndenumerate(matrix): - base_matrix[index] = import_jackknife(entry, name, [idl]) - return base_matrix - - -def _exp_to_jack_c(matrix): - base_matrix = np.empty_like(matrix) - for index, entry in np.ndenumerate(matrix): - base_matrix[index] = entry.real.export_jackknife() + 1j * entry.imag.export_jackknife() - return base_matrix - - -def _imp_from_jack_c(matrix, name, idl): - base_matrix = np.empty_like(matrix) - for index, entry in np.ndenumerate(matrix): - base_matrix[index] = CObs(import_jackknife(entry.real, name, [idl]), - import_jackknife(entry.imag, name, [idl])) - return base_matrix - - def jack_matmul(*operands): """Matrix multiply both operands making use of the jackknife approximation. @@ -215,6 +186,31 @@ def jack_matmul(*operands): For large matrices this is considerably faster compared to matmul. """ + def _exp_to_jack(matrix): + base_matrix = np.empty_like(matrix) + for index, entry in np.ndenumerate(matrix): + base_matrix[index] = entry.export_jackknife() + return base_matrix + + def _imp_from_jack(matrix, name, idl): + base_matrix = np.empty_like(matrix) + for index, entry in np.ndenumerate(matrix): + base_matrix[index] = import_jackknife(entry, name, [idl]) + return base_matrix + + def _exp_to_jack_c(matrix): + base_matrix = np.empty_like(matrix) + for index, entry in np.ndenumerate(matrix): + base_matrix[index] = entry.real.export_jackknife() + 1j * entry.imag.export_jackknife() + return base_matrix + + def _imp_from_jack_c(matrix, name, idl): + base_matrix = np.empty_like(matrix) + for index, entry in np.ndenumerate(matrix): + base_matrix[index] = CObs(import_jackknife(entry.real, name, [idl]), + import_jackknife(entry.imag, name, [idl])) + return base_matrix + if any(isinstance(o.flat[0], CObs) for o in operands): name = operands[0].flat[0].real.names[0] idl = operands[0].flat[0].real.idl[name] @@ -251,12 +247,40 @@ def einsum(subscripts, *operands): Obs valued. """ - if any(isinstance(o.flat[0], CObs) for o in operands): - name = operands[0].flat[0].real.names[0] - idl = operands[0].flat[0].real.idl[name] - else: - name = operands[0].flat[0].names[0] - idl = operands[0].flat[0].idl[name] + def _exp_to_jack(matrix): + base_matrix = [] + for index, entry in np.ndenumerate(matrix): + base_matrix.append(entry.export_jackknife()) + return np.asarray(base_matrix).reshape(matrix.shape + base_matrix[0].shape) + + def _exp_to_jack_c(matrix): + base_matrix = [] + for index, entry in np.ndenumerate(matrix): + base_matrix.append(entry.real.export_jackknife() + 1j * entry.imag.export_jackknife()) + return np.asarray(base_matrix).reshape(matrix.shape + base_matrix[0].shape) + + def _imp_from_jack(matrix, name, idl): + base_matrix = np.empty(shape=matrix.shape[:-1], dtype=object) + for index in np.ndindex(matrix.shape[:-1]): + base_matrix[index] = import_jackknife(matrix[index], name, [idl]) + return base_matrix + + def _imp_from_jack_c(matrix, name, idl): + base_matrix = np.empty(shape=matrix.shape[:-1], dtype=object) + for index in np.ndindex(matrix.shape[:-1]): + base_matrix[index] = CObs(import_jackknife(matrix[index].real, name, [idl]), + import_jackknife(matrix[index].imag, name, [idl])) + return base_matrix + + for op in operands: + if isinstance(op.flat[0], CObs): + name = op.flat[0].real.names[0] + idl = op.flat[0].real.idl[name] + break + elif isinstance(op.flat[0], Obs): + name = op.flat[0].names[0] + idl = op.flat[0].idl[name] + break conv_operands = [] for op in operands: @@ -267,15 +291,22 @@ def einsum(subscripts, *operands): else: conv_operands.append(op) - result = np.einsum(subscripts, *conv_operands) + tmp_subscripts = ','.join([o + '...' for o in subscripts.split(',')]) + extended_subscripts = '->'.join([o + '...' for o in tmp_subscripts.split('->')[:-1]] + [tmp_subscripts.split('->')[-1]]) + jack_einsum = np.einsum(extended_subscripts, *conv_operands) - if result.dtype == complex: - return _imp_from_jack_c(result, name, idl) - elif result.dtype == float: - return _imp_from_jack(result, name, idl) + if jack_einsum.dtype == complex: + result = _imp_from_jack_c(jack_einsum, name, idl) + elif jack_einsum.dtype == float: + result =_imp_from_jack(jack_einsum, name, idl) else: raise Exception("Result has unexpected datatype") + if result.shape == (): + return result.flat[0] + else: + return result + def inv(x): """Inverse of Obs or CObs valued matrices."""