diff --git a/pyerrors/correlators.py b/pyerrors/correlators.py index d14166a3..27282986 100644 --- a/pyerrors/correlators.py +++ b/pyerrors/correlators.py @@ -10,8 +10,9 @@ from .misc import dump_object, _assert_equal_properties from .fits import least_squares, Fit_result from .roots import find_root from . import linalg +from .input.json import dump_to_json from numpy import ndarray, ufunc -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, Literal class Corr: @@ -45,7 +46,7 @@ class Corr: __slots__ = ["content", "N", "T", "tag", "prange"] - def __init__(self, data_input: list[Obs, CObs], padding: list[int]=[0, 0], prange: Optional[list[int]]=None): + def __init__(self, data_input: Union[list[Obs, CObs], list[ndarray[ndarray[Obs, CObs]]], ndarray[ndarray[Corr]]], padding: list[int]=[0, 0], prange: Optional[list[int]]=None): """ Initialize a Corr object. Parameters @@ -303,7 +304,7 @@ class Corr: transposed = [None if _check_for_none(self, G) else G.T for G in self.content] return 0.5 * (Corr(transposed) + self) - def GEVP(self, t0: int, ts: Optional[int]=None, sort: Optional[str]="Eigenvalue", vector_obs: bool=False, **kwargs) -> Union[list[list[Optional[ndarray]]], ndarray, list[Optional[ndarray]]]: + def GEVP(self, t0: int, ts: Optional[int]=None, sort: Optional[Literal["Eigenvalue", "Eigenvector"]]="Eigenvalue", vector_obs: bool=False, **kwargs) -> Union[list[list[Optional[ndarray]]], ndarray, list[Optional[ndarray]]]: r'''Solve the generalized eigenvalue problem on the correlator matrix and returns the corresponding eigenvectors. The eigenvectors are sorted according to the descending eigenvalues, the zeroth eigenvector(s) correspond to the @@ -409,7 +410,7 @@ class Corr: else: return reordered_vecs - def Eigenvalue(self, t0: int, ts: None=None, state: int=0, sort: str="Eigenvalue", **kwargs) -> "Corr": + def Eigenvalue(self, t0: int, ts: Optional[int]=None, state: int=0, sort: Optional[Literal["Eigenvalue", "Eigenvector"]]="Eigenvalue", **kwargs) -> "Corr": """Determines the eigenvalue of the GEVP by solving and projecting the correlator Parameters @@ -495,7 +496,7 @@ class Corr: new_content.append(self.content[t]) return Corr(new_content) - def correlate(self, partner: Union[Corr, float, Obs]) -> "Corr": + def correlate(self, partner: Union[Corr, Obs]) -> "Corr": """Correlate the correlator with another correlator or Obs Parameters @@ -577,14 +578,14 @@ class Corr: return (self + T_partner) / 2 - def deriv(self, variant: Optional[str]="symmetric") -> "Corr": + def deriv(self, variant: Literal["symmetric", "forward", "backward", "improved", "log"]="symmetric") -> "Corr": """Return the first derivative of the correlator with respect to x0. Parameters ---------- variant : str decides which definition of the finite differences derivative is used. - Available choice: symmetric, forward, backward, improved, log, default: symmetric + Available choices: symmetric, forward, backward, improved, log, default: symmetric """ if self.N != 1: raise ValueError("deriv only implemented for one-dimensional correlators.") @@ -638,7 +639,7 @@ class Corr: else: raise ValueError("Unknown variant.") - def second_deriv(self, variant: Optional[str]="symmetric") -> "Corr": + def second_deriv(self, variant: Literal["symmetric", "big_symmetric", "improved", "log"]="symmetric") -> "Corr": r"""Return the second derivative of the correlator with respect to x0. Parameters @@ -698,7 +699,7 @@ class Corr: else: raise ValueError("Unknown variant.") - def m_eff(self, variant: str='log', guess: float=1.0) -> "Corr": + def m_eff(self, variant: Literal["log", "cosh", "periodic", "sinh", "arccosh", "logsym"]='log', guess: float=1.0) -> "Corr": """Returns the effective mass of the correlator as correlator object Parameters @@ -813,7 +814,7 @@ class Corr: result = least_squares(xs, ys, function, silent=silent, **kwargs) return result - def plateau(self, plateau_range: Optional[list[int]]=None, method: str="fit", auto_gamma: bool=False) -> Obs: + def plateau(self, plateau_range: Optional[list[int]]=None, method: Literal['fit', 'avg']="fit", auto_gamma: bool=False) -> Obs: """ Extract a plateau value from a Corr object Parameters @@ -862,7 +863,7 @@ class Corr: self.prange = prange return - def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: Optional[list[int, float]]=None, logscale: bool=False, plateau: Optional[Obs, float, int]=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: Optional[str]=None, save: Optional[str]=None, auto_gamma: bool=False, hide_sigma: Optional[int, float]=None, references: Optional[list[float]]=None, title: Optional[str]=None): + def show(self, x_range: Optional[list[int]]=None, comp: Optional[Corr]=None, y_range: Optional[list[int, float]]=None, logscale: bool=False, plateau: Union[Obs, float, int, None]=None, fit_res: Optional[Fit_result]=None, fit_key: Optional[str]=None, ylabel: Optional[str]=None, save: Optional[str]=None, auto_gamma: bool=False, hide_sigma: Union[int, float, None]=None, references: Optional[list[float]]=None, title: Optional[str]=None): """Plots the correlator using the tag of the correlator as label if available. Parameters @@ -1029,11 +1030,8 @@ class Corr: specifies a custom path for the file (default '.') """ if datatype == "json.gz": - from .input.json import dump_to_json - if 'path' in kwargs: - file_name = kwargs.get('path') + '/' + filename - else: - file_name = filename + path = kwargs.get("path", ".") + file_name = path + '/' + filename dump_to_json(self, file_name) elif datatype == "pickle": dump_object(self, filename, **kwargs) @@ -1078,7 +1076,7 @@ class Corr: __array_priority__ = 10000 - def __eq__(self, y: Any) -> ndarray: + def __eq__(self, y: Any) -> ndarray[bool, None]: if isinstance(y, Corr): comp = np.asarray(y.content, dtype=object) else: