First try at replacing autograd by jax

Fabian Joswig 2021-10-18 12:53:17 +01:00
parent 8d7a5daafa
commit 8fc5d96363
9 changed files with 76 additions and 55 deletions

@@ -1,10 +1,13 @@
-import autograd.numpy as np
+import numpy as np
+import jax.numpy as jnp
 import os
 import random
 import string
 import copy
 import pyerrors as pe
 import pytest
+from jax.config import config
+config.update("jax_enable_x64", True)
 np.random.seed(0)
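The config.update("jax_enable_x64", True) line matters: JAX computes in single precision by default, which would not reach tolerances like the 1e-14 used on the deltas further down in this file. A minimal sketch of the effect, not part of this commit (jax.config.update is the same switch as the from jax.config import config form used above):

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # enable double precision before building arrays

x = jnp.array(1.0)
print(x.dtype)               # float64 with x64 enabled, float32 otherwise
print(jax.grad(jnp.sin)(x))  # derivatives are evaluated in double precision as well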
@@ -29,21 +32,31 @@ def test_comparison():
 def test_function_overloading():
-    a = pe.pseudo_Obs(17, 2.9, 'e1')
+    a = pe.pseudo_Obs(2, 2.9, 'e1')
     b = pe.pseudo_Obs(4, 0.8, 'e1')
     fs = [lambda x: x[0] + x[1], lambda x: x[1] + x[0], lambda x: x[0] - x[1], lambda x: x[1] - x[0],
-          lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0],
-          lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]),
-          lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])),
-          lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])]
+          lambda x: x[0] * x[1], lambda x: x[1] * x[0], lambda x: x[0] / x[1], lambda x: x[1] / x[0]]
     for i, f in enumerate(fs):
         t1 = f([a, b])
         t2 = pe.derived_observable(f, [a, b])
         c = t2 - t1
-        assert c.value == 0.0, str(i)
-        assert np.all(np.abs(c.deltas['e1']) < 1e-14), str(i)
+        assert c.is_zero()
+    f_np = [lambda x: np.exp(x[0]), lambda x: np.sin(x[0]), lambda x: np.cos(x[0]), lambda x: np.tan(x[0]),
+            lambda x: np.log(x[0]), lambda x: np.sqrt(np.abs(x[0])),
+            lambda x: np.sinh(x[0]), lambda x: np.cosh(x[0]), lambda x: np.tanh(x[0])]
+    f_jnp = [lambda x: jnp.exp(x[0]), lambda x: jnp.sin(x[0]), lambda x: jnp.cos(x[0]), lambda x: jnp.tan(x[0]),
+             lambda x: jnp.log(x[0]), lambda x: jnp.sqrt(jnp.abs(x[0])),
+             lambda x: jnp.sinh(x[0]), lambda x: jnp.cosh(x[0]), lambda x: jnp.tanh(x[0])]
+    for i, (f1, f2) in enumerate(zip(f_np, f_jnp)):
+        t1 = f1([a])
+        t2 = pe.derived_observable(f2, [a])
+        c = t2 - t1
+        assert c.is_zero()
 def test_overloading_vectorization():
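The split into f_np and f_jnp reflects the new division of labour: the plain numpy variants are applied directly to Obs objects (t1 = f1([a]) goes through pyerrors' overloaded math functions), while anything handed to pe.derived_observable now has to be written with jax.numpy so that JAX can trace and differentiate it. Below is a rough sketch of the swap on the autodiff side, assuming derived_observable previously relied on autograd's jacobian; this is illustrative and not code from the commit:

import jax
import jax.numpy as jnp
import numpy as np

def func(x):
    # same kind of test function as in the next hunk, written with jnp primitives
    return x[0] * x[1] * jnp.sin(x[0] * x[1])

values = np.array([2.0, 4.0])
# autograd formulation (conceptually): autograd.jacobian(func)(values)
jacobian = jax.jacobian(func)(jnp.asarray(values))  # gradient of func at (2.0, 4.0)
print(np.asarray(jacobian))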
@@ -121,7 +134,7 @@ def test_derived_observables():
     test_obs = pe.pseudo_Obs(2, 0.1 * (1 + np.random.rand()), 't', int(1000 * (1 + np.random.rand())))
     # Check if autograd and numgrad give the same result
-    d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs])
+    d_Obs_ad = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * jnp.sin(x[0] * x[1]), [test_obs, test_obs])
     d_Obs_ad.gamma_method()
     d_Obs_fd = pe.derived_observable(lambda x, **kwargs: x[0] * x[1] * np.sin(x[0] * x[1]), [test_obs, test_obs], num_grad=True)
     d_Obs_fd.gamma_method()
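With jnp.sin in the automatic-differentiation path and np.sin kept for the num_grad=True finite-difference path, the two derived observables can still be checked against each other as before. A hedged sketch of how such a consistency check might conclude (value and dvalue are the pyerrors attributes for mean and error; the tolerance is an assumption, not taken from the file):

assert np.isclose(d_Obs_ad.value, d_Obs_fd.value)               # central values must agree
assert np.isclose(d_Obs_ad.dvalue, d_Obs_fd.dvalue, rtol=1e-6)  # assumed tolerance on the errors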