mirror of
https://github.com/fjosw/pyerrors.git
synced 2026-05-13 16:46:52 +02:00
[Fix] Migrate to odrpack because of scipy.odr deprecation (#279)
* [Fix] Migrate to odrpack because of scipy.odr deprecation in recent release * [Fix] Fix behaviour for rank deficient fits. Add test * [Fix] Relax test_merge_obs tolerance to machine epsilon and update ODR docstring to reference odrpack * [ci] Re-add -Werror to pytest workflow * [Fix] Handle platform-dependent rank-deficient warning in ODR tests * [Fix] Improve rank-deficient detection and bump odrpack to >=0.5 Fix incorrect ODRPACK95 info code parsing: rank deficiency is encoded in the tens digit (info // 10 % 10), not the hundreds digit. Add irank and inv_condnum to the warning message for diagnostics.
This commit is contained in:
parent
b180dff020
commit
b28c2f0b6f
6 changed files with 119 additions and 44 deletions
|
|
@ -1,9 +1,10 @@
|
|||
import warnings
|
||||
import numpy as np
|
||||
import autograd.numpy as anp
|
||||
import matplotlib.pyplot as plt
|
||||
import math
|
||||
import scipy.optimize
|
||||
from scipy.odr import ODR, Model, RealData
|
||||
from odrpack import odr_fit
|
||||
from scipy.linalg import cholesky
|
||||
from scipy.stats import norm
|
||||
import iminuit
|
||||
|
|
@ -397,11 +398,21 @@ def test_total_least_squares():
|
|||
y = a[0] * anp.exp(-a[1] * x)
|
||||
return y
|
||||
|
||||
data = RealData([o.value for o in ox], [o.value for o in oy], sx=[o.dvalue for o in ox], sy=[o.dvalue for o in oy])
|
||||
model = Model(func)
|
||||
odr = ODR(data, model, [0, 0], partol=np.finfo(np.float64).eps)
|
||||
odr.set_job(fit_type=0, deriv=1)
|
||||
output = odr.run()
|
||||
# odrpack expects f(x, beta), but pyerrors convention is f(beta, x)
|
||||
def wrapped_func(x, beta):
|
||||
return func(beta, x)
|
||||
|
||||
output = odr_fit(
|
||||
wrapped_func,
|
||||
np.array([o.value for o in ox]),
|
||||
np.array([o.value for o in oy]),
|
||||
beta0=np.array([0.0, 0.0]),
|
||||
weight_x=1.0 / np.array([o.dvalue for o in ox]) ** 2,
|
||||
weight_y=1.0 / np.array([o.dvalue for o in oy]) ** 2,
|
||||
partol=np.finfo(np.float64).eps,
|
||||
task='explicit-ODR',
|
||||
diff_scheme='central'
|
||||
)
|
||||
|
||||
out = pe.total_least_squares(ox, oy, func, expected_chisquare=True)
|
||||
beta = out.fit_parameters
|
||||
|
|
@ -458,6 +469,21 @@ def test_total_least_squares():
|
|||
assert((outc.fit_parameters[1] - betac[1]).is_zero())
|
||||
|
||||
|
||||
def test_total_least_squares_vanishing_chisquare():
|
||||
"""Test that a saturated fit (n_obs == n_parms) works without exception."""
|
||||
def func(a, x):
|
||||
return a[0] + a[1] * x
|
||||
|
||||
x = [pe.pseudo_Obs(1.0, 0.1, 'x0'), pe.pseudo_Obs(2.0, 0.1, 'x1')]
|
||||
y = [pe.pseudo_Obs(1.0, 0.1, 'y0'), pe.pseudo_Obs(2.0, 0.1, 'y1')]
|
||||
|
||||
# Rank-deficient warning may or may not fire depending on platform/solver numerics.
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", message="ODR fit is rank deficient", category=RuntimeWarning)
|
||||
out = pe.total_least_squares(x, y, func, silent=True)
|
||||
assert len(out.fit_parameters) == 2
|
||||
|
||||
|
||||
def test_odr_derivatives():
|
||||
x = []
|
||||
y = []
|
||||
|
|
@ -502,7 +528,10 @@ def test_r_value_persistence():
|
|||
assert np.isclose(fitp[1].value, fitp[1].r_values['a'])
|
||||
assert np.isclose(fitp[1].value, fitp[1].r_values['b'])
|
||||
|
||||
fitp = pe.fits.total_least_squares(y, y, f)
|
||||
# Rank-deficient warning may or may not fire depending on platform/solver numerics.
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", message="ODR fit is rank deficient", category=RuntimeWarning)
|
||||
fitp = pe.fits.total_least_squares(y, y, f)
|
||||
|
||||
assert np.isclose(fitp[0].value, fitp[0].r_values['a'])
|
||||
assert np.isclose(fitp[0].value, fitp[0].r_values['b'])
|
||||
|
|
@ -1431,11 +1460,11 @@ def fit_general(x, y, func, silent=False, **kwargs):
|
|||
global print_output, beta0
|
||||
print_output = 1
|
||||
if 'initial_guess' in kwargs:
|
||||
beta0 = kwargs.get('initial_guess')
|
||||
beta0 = np.asarray(kwargs.get('initial_guess'), dtype=np.float64)
|
||||
if len(beta0) != n_parms:
|
||||
raise Exception('Initial guess does not have the correct length.')
|
||||
else:
|
||||
beta0 = np.arange(n_parms)
|
||||
beta0 = np.arange(n_parms, dtype=np.float64)
|
||||
|
||||
if len(x) != len(y):
|
||||
raise Exception('x and y have to have the same length')
|
||||
|
|
@ -1463,23 +1492,45 @@ def fit_general(x, y, func, silent=False, **kwargs):
|
|||
|
||||
xerr = kwargs.get('xerr')
|
||||
|
||||
# odrpack expects f(x, beta), but pyerrors convention is f(beta, x)
|
||||
def wrapped_func(x, beta):
|
||||
return func(beta, x)
|
||||
|
||||
if length == len(obs):
|
||||
assert 'x_constants' in kwargs
|
||||
data = RealData(kwargs.get('x_constants'), obs, sy=yerr)
|
||||
fit_type = 2
|
||||
x_data = np.asarray(kwargs.get('x_constants'))
|
||||
y_data = np.asarray(obs)
|
||||
# Ordinary least squares (no x errors)
|
||||
output = odr_fit(
|
||||
wrapped_func,
|
||||
x_data,
|
||||
y_data,
|
||||
beta0=beta0,
|
||||
weight_y=1.0 / np.asarray(yerr) ** 2,
|
||||
partol=np.finfo(np.float64).eps,
|
||||
task='explicit-OLS',
|
||||
diff_scheme='central'
|
||||
)
|
||||
elif length == len(obs) // 2:
|
||||
data = RealData(obs[:length], obs[length:], sx=xerr, sy=yerr)
|
||||
fit_type = 0
|
||||
x_data = np.asarray(obs[:length])
|
||||
y_data = np.asarray(obs[length:])
|
||||
# ODR with x errors
|
||||
output = odr_fit(
|
||||
wrapped_func,
|
||||
x_data,
|
||||
y_data,
|
||||
beta0=beta0,
|
||||
weight_x=1.0 / np.asarray(xerr) ** 2,
|
||||
weight_y=1.0 / np.asarray(yerr) ** 2,
|
||||
partol=np.finfo(np.float64).eps,
|
||||
task='explicit-ODR',
|
||||
diff_scheme='central'
|
||||
)
|
||||
else:
|
||||
raise Exception('x and y do not fit together.')
|
||||
|
||||
model = Model(func)
|
||||
|
||||
odr = ODR(data, model, beta0, partol=np.finfo(np.float64).eps)
|
||||
odr.set_job(fit_type=fit_type, deriv=1)
|
||||
output = odr.run()
|
||||
if print_output and not silent:
|
||||
print(*output.stopreason)
|
||||
print(output.stopreason)
|
||||
print('chisquare/d.o.f.:', output.res_var)
|
||||
print_output = 0
|
||||
beta0 = output.beta
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue