mirror of
https://github.com/fjosw/pyerrors.git
synced 2025-06-29 08:19:27 +02:00
First try at replacing autograd by jax
This commit is contained in:
parent
8d7a5daafa
commit
8fc5d96363
9 changed files with 76 additions and 55 deletions
|
@ -1,7 +1,10 @@
|
|||
import autograd.numpy as np
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import math
|
||||
import pyerrors as pe
|
||||
import pytest
|
||||
from jax.config import config
|
||||
config.update("jax_enable_x64", True)
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
|
@ -14,7 +17,7 @@ def test_matrix_inverse():
|
|||
|
||||
content.append(1.0) # Add 1.0 as a float
|
||||
matrix = np.diag(content)
|
||||
inverse_matrix = pe.linalg.mat_mat_op(np.linalg.inv, matrix)
|
||||
inverse_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix)
|
||||
assert all([o.is_zero() for o in np.diag(matrix) * np.diag(inverse_matrix) - 1])
|
||||
|
||||
|
||||
|
@ -35,7 +38,7 @@ def test_complex_matrix_inverse():
|
|||
matrix[n, m] = entry.real.value + 1j * entry.imag.value
|
||||
|
||||
inverse_matrix = np.linalg.inv(matrix)
|
||||
inverse_obs_matrix = pe.linalg.mat_mat_op(np.linalg.inv, obs_matrix)
|
||||
inverse_obs_matrix = pe.linalg.mat_mat_op(jnp.linalg.inv, obs_matrix)
|
||||
for (n, m), entry in np.ndenumerate(inverse_matrix):
|
||||
assert np.isclose(inverse_matrix[n, m].real, inverse_obs_matrix[n, m].real.value)
|
||||
assert np.isclose(inverse_matrix[n, m].imag, inverse_obs_matrix[n, m].imag.value)
|
||||
|
@ -53,7 +56,7 @@ def test_matrix_functions():
|
|||
matrix = np.array(matrix) @ np.identity(dim)
|
||||
|
||||
# Check inverse of matrix
|
||||
inv = pe.linalg.mat_mat_op(np.linalg.inv, matrix)
|
||||
inv = pe.linalg.mat_mat_op(jnp.linalg.inv, matrix)
|
||||
check_inv = matrix @ inv
|
||||
|
||||
for (i, j), entry in np.ndenumerate(check_inv):
|
||||
|
@ -66,7 +69,7 @@ def test_matrix_functions():
|
|||
|
||||
# Check Cholesky decomposition
|
||||
sym = np.dot(matrix, matrix.T)
|
||||
cholesky = pe.linalg.mat_mat_op(np.linalg.cholesky, sym)
|
||||
cholesky = pe.linalg.mat_mat_op(jnp.linalg.cholesky, sym)
|
||||
check = cholesky @ cholesky.T
|
||||
|
||||
for (i, j), entry in np.ndenumerate(check):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue