staload M = "libc/SATS/math.sats"
staload "libc/SATS/pthread.sats"
staload "libats/SATS/parworkshop.sats"
staload _ = "libats/DATS/parworkshop.dats"
// Sieve cell type: one signed byte per candidate number.
typedef two = int8
// Cell markers: every cell starts as [tt] ("not yet struck out");
// the sieve overwrites multiples with [ff] ("composite").
macdef tt = int8_of_int 1
macdef ff = int8_of_int 0
// Strike out every i-th cell of A starting at index j and stopping before
// index [limit]: A[j], A[j+i], A[j+2i], ... are set to [ff].
//
// The array proof [pf] is threaded through unchanged; the static
// constraint [limit <= m] guarantees every access is in bounds.
// The C symbol "sieve_once_safe" is shared with [sieve_once_unsafe].
extern fun sieve_once
  {m,limit:nat | limit <= m} {i,j:nat} {l:addr} (
  pf: !array_v (two, m, l) | A: ptr l, limit: int limit, i: int i, j: int j
) : void = "sieve_once_safe"
implement
sieve_once (pf | A, limit, i, j) =
  // The signature admits i = 0, for which j+i = j; the recursion would then
  // revisit the same cell forever. Treat a zero stride as "nothing to do".
  if i > 0 then begin
    if j < limit then let
      // skip the store when the cell is already struck out
      val () = if A[j] <> ff then A[j] := ff
    in
      sieve_once (pf | A, limit, i, j+i)
    end
  end
// Proof-free twin of [sieve_once]: it binds the very same C function
// ("sieve_once_safe") but takes an untyped pointer and plain ints, so the
// type checker no longer demands the array_v proof or checks bounds.
// It is called from the worker closures built in [nsieve_mt], which capture
// only the raw pointer [p_A]. Callers are responsible for keeping [limit]
// within the array's length.
extern fun sieve_once_unsafe
(A: Ptr, limit: int, i: int, j: int): void = "sieve_once_safe"
// For each index k in [i, m1) whose cell is still [tt], strike out the
// multiples 2k, 3k, ... of k that lie below [m2] (via [sieve_once]).
// Indices are visited in increasing order, so a cell that is still [tt]
// when reached was not struck out by any smaller stride processed here.
extern fun sieve_many
  {m,m1,m2:nat | m1 <= m; m2 <= m} {i:nat} {l:addr}
  (pf: !array_v (two, m, l) | A: ptr l, m1: int m1, m2: int m2, i: int i): void
  = "sieve_many_safe"
implement
sieve_many (pf | p_arr, hi, lim, k) =
  if k >= hi then () else let
    val still_prime = (p_arr->[k] = tt)
    val () = if still_prime then sieve_once (pf | p_arr, lim, k, k + k)
  in
    sieve_many (pf | p_arr, hi, lim, k + 1)
  end
// Integer square root of a nonnegative [m], computed through [double].
// NOTE(review): the +0.5 appears intended to keep floating-point error in
// [sqrt] from truncating an exact root down by one -- confirm.
// The runtime [assert] also lets the type checker accept the [Nat] result.
fn sqrt_int {m:nat} (m: int m): Nat = let
  val approx = $M.sqrt (double_of m + 0.5)
  val root = int1_of_int (int_of_double approx)
  val () = assert (root >= 0)
in
  root
end
// One queued job: a linear closure pointer taking no arguments and
// returning void. Being linear, it must be consumed exactly once; [fwork]
// calls it and then releases it with [cloptr_free].
viewtypedef work = () -<lincloptr1> void
// Handle to a work-shop whose queue holds [work] items.
viewtypedef WSptr (l:addr) = WORKSHOPptr (work, l)
// Worker callback handed to [workshop_make] in [main]. It is applied to a
// queued item, and its integer result is returned to the workshop.
//
// A queued item is either a real [work] closure or a small integer sentinel
// disguised as a pointer (e.g. the "(void*)0" quit item built in
// [nsieve_mt]). [__cast] reinterprets the item as an unsigned integer:
//   * value >= 2: a genuine closure -- run it, free it, and return 1;
//   * value 0 or 1: a sentinel -- nothing to run or free; return its
//     negation (0 or -1).
// NOTE(review): presumably the workshop reads a positive result as
// "continue", 0 as "quit", and a negative one as "pause" -- confirm against
// libats/DATS/parworkshop.dats.
fun fwork {lws:agz}
(ws: !WSptr lws, wk: &work >> work?): int = let
// move the linear value out of the reference; per the signature the
// reference is left uninitialized ([work?]) on return
val wk = wk
// view the closure pointer as an integer; the [opt] type records that it
// is a real closure exactly when that integer is >= 2
val pfun = __cast (wk) where {
extern castfn __cast
(wk: !work >> opt (work, i >= 2)): #[i:nat] uintptr i
} in
if pfun >= (uintptr1_of_uint1)2U then let
prval () = opt_unsome {work} (wk)
// run the job, then release the closure's memory
val () = wk ()
val () = cloptr_free (wk)
in
1 end else let
// sentinel path: recover the small integer that was smuggled in
val u = uint1_of_uintptr1 (pfun)
val i = int_of_uint (u)
// there is no closure to free; just discharge the (empty) linear value
prval () = opt_unnone {work} (wk)
prval () = cleanup_top {work} (wk)
in
~i end end
// Count the primes below [m] using the work-shop [ws] and print the result.
//
// Phases:
//   1. allocate one [two] cell per candidate in [0, m), all set to [tt];
//   2. sequentially sieve the prefix [0, m1), where m1 = min (sqrt_int m + 1, m),
//      which determines every prime that can divide a composite below m;
//   3. enqueue one closure per such prime that strikes out its multiples
//      across the whole array, then wait for the workers;
//   4. insert one quit sentinel per worker and wait for all of them to quit;
//   5. count the cells in [2, m) still marked [tt], free the array, print.
//
// Aborts via [assert_prerrf_bool1] when m < 2 (the same bound [main] checks).
fn nsieve_mt {lws:agz}
  (ws: !WSptr lws, m: int): void = let
  val [m:int] m = int1_of_int m
  // the message states the real requirement enforced by the check
  val () = assert_prerrf_bool1 (
    m >= 2, "nsieve_mt(%i): argument is illegal; it must be at least 2.\n", @(m)
  )
  val msz = size1_of_int1 (m)
  val [la:addr] (pf_gc, pf | p_A) =
    array_ptr_alloc_tsz {two} (msz, sizeof<two>)
  var x0 = tt
  val () = array_ptr_initialize_elt_tsz {two} (!p_A, msz, x0, sizeof<two>)
  // every composite n < m has a prime factor <= sqrt (n), hence below m1
  val m1 = sqrt_int (m)
  val [m1:int] m1 = (if m1 < m then m1 + 1 else m): natLte m
  // sequential phase: find the primes below m1
  val () = sieve_many (pf | p_A, m1, m1, 2)
  // parallel phase: one work item per prime below m1
  val () = split (pf | ws, p_A, m, m1, 2) where {
    fun split {m,m1:nat | m1 <= m} {i:nat} (
      pf: !array_v (two, m, la)
    | ws: !WSptr lws
    , p_A: ptr la
    , m: int m
    , m1: int m1
    , i: int i
    ) : void = begin
      if i < m1 then let
        val () = if p_A->[i] = tt then let
          // the closure captures only the raw pointer, so it must go
          // through the proof-free [sieve_once_unsafe]
          val () = workshop_insert_work (ws, f) where {
            val f = lam (): void =<lincloptr1> sieve_once_unsafe (p_A, m, i, i+i)
          }
        in
        end
      in
        split (pf | ws, p_A, m, m1, i+1)
      end
    end
  }
  // NOTE(review): presumably returns once every worker is blocked waiting
  // on an empty queue, i.e. all sieving jobs have run -- confirm
  val () = workshop_wait_blocked_all (ws)
  // send one quit sentinel ("(void*)0", see [fwork]) per worker
  val nworker = workshop_get_nworker (ws)
  var i: Nat = 0
  val () = while (i < nworker) let
    val _quit = $extval (work, "(void*)0")
    val () = workshop_insert_work (ws, _quit)
  in
    i := i + 1
  end
  val () = workshop_wait_quit_all (ws)
  // count the cells in [2, m) that were never struck out
  val count = loop (pf | 2, 0) where {
    fun loop {i:nat}
      (pf: !array_v (two, m, la) | i: int i, c: int):<cloref1> int =
      if i < m then begin
        if p_A->[i] = tt then loop (pf | i+1, c+1) else loop (pf | i+1, c)
      end else begin
        c
      end
  }
  val () = array_ptr_free {two} (pf_gc, pf | p_A)
in
  printf ("The number of primes < %8i is %8i\n", @(m, count))
end
// queue size passed to [workshop_make]
#define QSZ 1024
// number of workers used when <ncore> is not given
#define NWORKER 1

// Entry point.  Usage: nsieve_mt <integer> [<ncore>]
//   <integer>  count the primes below this bound; must be at least 2
//   <ncore>    optional number of worker threads (default NWORKER); must be > 0
// The workshop is created with [fwork] as its worker callback, populated with
// the requested number of workers, used by [nsieve_mt], and then freed.
implement main (argc, argv) = let
  var nworker: int = NWORKER
  // only <integer> is mandatory; the message now marks <ncore> as optional
  val () = assert_errmsg_bool1 (
    argc >= 2, ": command format: nsieve_mt <integer> [<ncore>]"
  )
  val i = int1_of argv.[1]
  val () = assert_errmsg_bool1 (
    i >= 2, "The input integer needs to be at least 2.\n"
  )
  val () = if argc >= 3 then (nworker := int_of argv.[2])
  val nworker = int1_of_int (nworker)
  val () = assert_errmsg (nworker > 0, #LOCATION)
  val ws = workshop_make<work> (QSZ, fwork)
  val _err = workshop_add_nworker (ws, nworker)
  val () = assert_errmsg (_err = 0, #LOCATION)
  val () = nsieve_mt (ws, i)
  val () = workshop_free_vt_exn (ws)
in
end