mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-06-29 16:29:27 +02:00
First try at replacing autograd by jax
This commit is contained in:
parent
8d7a5daafa
commit
8fc5d96363
9 changed files with 76 additions and 55 deletions
|
@ -1,9 +1,12 @@
|
|||
import autograd.numpy as np
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import math
|
||||
import scipy.optimize
|
||||
from scipy.odr import ODR, Model, RealData
|
||||
import pyerrors as pe
|
||||
import pytest
|
||||
from jax.config import config
|
||||
config.update("jax_enable_x64", True)
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
@ -24,7 +27,7 @@ def test_standard_fit():
|
|||
popt, pcov = scipy.optimize.curve_fit(f, x, y, sigma=[o.dvalue for o in oy], absolute_sigma=True)
|
||||
|
||||
def func(a, x):
|
||||
y = a[0] * np.exp(-a[1] * x)
|
||||
y = a[0] * jnp.exp(-a[1] * x)
|
||||
return y
|
||||
|
||||
beta = pe.fits.standard_fit(x, oy, func)
|
||||
|
@ -61,7 +64,7 @@ def test_odr_fit():
|
|||
return a * np.exp(-b * x)
|
||||
|
||||
def func(a, x):
|
||||
y = a[0] * np.exp(-a[1] * x)
|
||||
y = a[0] * jnp.exp(-a[1] * x)
|
||||
return y
|
||||
|
||||
data = RealData([o.value for o in ox], [o.value for o in oy], sx=[o.dvalue for o in ox], sy=[o.dvalue for o in oy])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue