First try at replacing autograd with jax

This commit is contained in:
Fabian Joswig 2021-10-18 12:53:17 +01:00
parent 8d7a5daafa
commit 8fc5d96363
9 changed files with 76 additions and 55 deletions

View file

@ -1,9 +1,12 @@
import autograd.numpy as np
import numpy as np
import jax.numpy as jnp
import math
import scipy.optimize
from scipy.odr import ODR, Model, RealData
import pyerrors as pe
import pytest
from jax.config import config
config.update("jax_enable_x64", True)
np.random.seed(0)
@ -24,7 +27,7 @@ def test_standard_fit():
popt, pcov = scipy.optimize.curve_fit(f, x, y, sigma=[o.dvalue for o in oy], absolute_sigma=True)
def func(a, x):
y = a[0] * np.exp(-a[1] * x)
y = a[0] * jnp.exp(-a[1] * x)
return y
beta = pe.fits.standard_fit(x, oy, func)
@ -61,7 +64,7 @@ def test_odr_fit():
return a * np.exp(-b * x)
def func(a, x):
y = a[0] * np.exp(-a[1] * x)
y = a[0] * jnp.exp(-a[1] * x)
return y
data = RealData([o.value for o in ox], [o.value for o in oy], sx=[o.dvalue for o in ox], sy=[o.dvalue for o in oy])