added rng arguments to HMC function

This commit is contained in:
Nicolas Lang 2023-11-27 13:19:26 +01:00
parent 5f72bc34ea
commit 830163e90f
2 changed files with 7 additions and 7 deletions

View file

@ -58,13 +58,13 @@ function hamiltonian(mom, U, lp, gp, ymws)
return K+V
end
function HMC!(U, int::IntrScheme, lp::SpaceParm, gp::GaugeParm, ymws::YMworkspace{T}; noacc=false) where T
function HMC!(U, int::IntrScheme, lp::SpaceParm, gp::GaugeParm, ymws::YMworkspace{T}; noacc=false, rng=Random.default_rng(), curng=CUDA.default_rng()) where T
@timeit "HMC trayectory" begin
ymws.U1 .= U
randomize!(ymws.mom, lp, ymws)
randomize!(ymws.mom, lp, ymws; curng)
hini = hamiltonian(ymws.mom, U, lp, gp, ymws)
MD!(ymws.mom, U, int, lp, gp, ymws)
@ -78,7 +78,7 @@ function HMC!(U, int::IntrScheme, lp::SpaceParm, gp::GaugeParm, ymws::YMworkspac
end
if (pacc < 1.0)
r = rand()
r = rand(rng)
if (pacc < r)
U .= ymws.U1
acc = false
@ -90,7 +90,7 @@ function HMC!(U, int::IntrScheme, lp::SpaceParm, gp::GaugeParm, ymws::YMworkspac
end
return dh, acc
end
HMC!(U, eps, ns, lp::SpaceParm, gp::GaugeParm, ymws::YMworkspace{T}; noacc=false) where T = HMC!(U, omf4(T, eps, ns), lp, gp, ymws; noacc=noacc)
HMC!(U, eps, ns, lp::SpaceParm, gp::GaugeParm, ymws::YMworkspace{T}; noacc=false, rng=Random.default_rng(), curng=CUDA.default_rng()) where T = HMC!(U, omf4(T, eps, ns), lp, gp, ymws; noacc=noacc, rng, curng)
function MD!(mom, U, int::IntrScheme{NI, T}, lp::SpaceParm, gp::GaugeParm, ymws::YMworkspace{T}) where {NI, T <: AbstractFloat}