staload M = "libc/SATS/math.sats"
staload "libc/SATS/pthread.sats"
staload "libats/SATS/parworkshop.sats"
staload _ = "libats/DATS/parworkshop.dats"
// Sieve cell flags stored as int8: [tt] (1) = not (yet) crossed out,
// [ff] (0) = crossed out as composite.  The array starts all-tt and the
// final prime count is the number of tt cells from index 2 upward.
macdef tt = int8_of_int (1)
macdef ff = int8_of_int (0)
// Element type of the sieve array: one int8 flag per candidate integer.
typedef two = int8
%{^
/*
** ceil2 (k, i): round k up to the nearest multiple of i.
** Preconditions (enforced statically by the ATS declaration below):
** k >= 0 and i > 0.  The result r satisfies r >= k and r % i == 0.
**
** The sum is written k + (i - r) instead of k + i - r: the latter
** forms the intermediate k + i, which overflows a signed int (undefined
** behaviour in C) when k is close to INT_MAX even though the rounded
** result itself is still representable.
*/
ATSinline()
ats_int_type
ceil2 (ats_int_type k, ats_int_type i) {
  int r = k % i ;
  return (r == 0) ? k : k + (i - r) ;
} // end of [ceil2]
%}
// ATS-side declaration of the C function above; the result type records
// that the rounded value is never smaller than k.
extern
fun ceil2 {k,i:nat | i > 0} (k: int k, i: int i): intGte k = "ceil2"
// sieve_once (pf | A, limit, i, j): set A[j], A[j+i], A[j+2i], ... to ff
// for every index strictly below [limit].  The static constraint
// limit <= m keeps every access inside the m-element array described by
// the view [pf].  The function is given the external C name
// "sieve_once_safe".
// NOTE(review): i = 0 with j < limit would recurse forever; every caller
// in this file passes i >= 2 (the loop counters are typed intGte 2).
extern fun sieve_once
{m,limit:nat | limit <= m} {i,j:nat} {l:addr}
(pf: !array_v (two, m, l) | A: ptr l, limit: int limit, i: int i, j: int j)
: void = "sieve_once_safe"
implement sieve_once (pf | A, limit, i, j) = begin
if (j < limit) then begin
// Store ff only when the cell is not already ff, skipping redundant
// writes; the recursive call advances to the next multiple either way.
(if A[j] <> ff then A[j] := ff; sieve_once (pf | A, limit, i, j+i))
end end
// sieve_many (pf | p_A, m1, m2): for every index i in [2, m1) whose flag
// is still tt, cross out its multiples 2i, 3i, ... that lie below m2.
// nsieve_mt calls this with m1 = m2, so every composite below m1 is
// crossed out by one of its smaller prime factors.  Afterwards the flags
// in [2, m1) are final.
extern fun sieve_many
{m,m1,m2:nat | m1 <= m; m2 <= m} {l:addr}
(pf: !array_v (two, m, l) | A: ptr l, m1: int m1, m2: int m2): void
implement sieve_many (pf | p_A, m1, m2) = () where {
// Candidate divisor; starts at 2, the smallest prime.
var i: intGte 2 = 2
val () = while (i < m1) let
// Only indices still flagged tt (primes) need their multiples crossed out.
val () = if p_A->[i] = tt then sieve_once (pf | p_A, m2, i, i+i)
in
i := i + 1
end }
// sieve_many_seg (pf | p_A, m1, m2, m3): using the indices in [2, m1)
// still flagged tt, cross out their multiples that lie in the segment
// [m2, m3).  The implementation is given the C name "sieve_many_seg_safe".
extern fun sieve_many_seg
{m,m1,m2,m3:nat | m1 <= m; m2 <= m; m3 <= m} {l:addr}
(pf: !array_v (two, m, l) | A: ptr l, m1: int m1, m2: int m2, m3: int m3): void
= "sieve_many_seg_safe"
// sieve_many_seg_unsafe binds the SAME C symbol as sieve_many_seg but has
// neither the array view nor the bound constraints.  nsieve_mt calls it
// from its work closures, which capture only the raw pointer.
// (Presumably this is because the linear view cannot be shared among
// concurrently running closures.)  Callers must ensure themselves that
// m1, m2 and m3 do not exceed the array size.
extern fun sieve_many_seg_unsafe
{m1,m2,m3:nat} {l:addr} (A: ptr l, m1: int m1, m2: int m2, m3: int m3): void
= "sieve_many_seg_safe"
implement sieve_many_seg (pf | p_A, m1, m2, m3) = () where {
var i: intGte 2 = 2
val () = while (i < m1) let
val () = if p_A->[i] = tt then let
// First index to cross out: the smallest multiple of i that is >= m2.
// NOTE(review): if m2 <= i this is i itself, which would wrongly mark
// the prime i composite.  nsieve_mt only uses segment starts >= m1 > i.
val j = ceil2 (m2, i) in sieve_once (pf | p_A, m3, i, j)
end in
i := i + 1
end }
// sqrt_int (m): integer square root of m.  It takes libm sqrt of
// (double)m + 0.5 and converts the result with int_of_double.
// The +0.5 presumably guards against a perfect square's floating-point
// root landing just below the integer and being truncated one too low.
// The runtime assert is what lets the result be typed as Nat.
fn sqrt_int {m:nat} (m: int m): Nat = let
val m_sqrt = int_of_double ($M.sqrt (double_of m + 0.5))
val m_sqrt = int1_of_int m_sqrt
val () = assert (m_sqrt >= 0) in
m_sqrt
end
// A unit of work: a linear closure pointer.  fwork calls it once and then
// frees it with cloptr_free.  A null pointer (see nsieve_mt) is used as a
// quit message instead of a real closure.
viewtypedef work = () -<lincloptr1> void
// Handle to a libats parworkshop whose queue carries [work] items.
viewtypedef WSptr (l:addr) = WORKSHOPptr (work, l)
// fwork (ws, wk): the per-item handler passed to workshop_make in main.
// It takes ownership of the work item in [wk] (leaving [wk] uninitialized)
// and inspects its raw pointer value:
//   - value >= 2: a real closure; it is called, freed, and 1 is returned;
//   - value 0 or 1: a sentinel; nothing is called and the negated value
//     (0 or -1) is returned.  nsieve_mt enqueues (void*)0 as a quit item.
// NOTE(review): the meaning of these return codes is defined by libats
// parworkshop (presumably positive = continue, 0 = quit).  Confirm there.
fun fwork {lws:agz}
(ws: !WSptr lws, wk: &work >> work?): int = let
// Move the linear item out of the reference, as required by work?.
val wk = wk
// Reinterpret the closure pointer as an integer.  The opt view records
// whether it is a genuine closure (value >= 2) or a sentinel.
val pfun = __cast (wk) where {
extern castfn __cast
(wk: !work >> opt (work, i >= 2)): #[i:nat] uintptr i
} in
if pfun >= (uintptr1_of_uint1)2U then let
prval () = opt_unsome {work} (wk)
val () = wk ()
val () = cloptr_free (wk)
in
1 end else let
val u = uint1_of_uintptr1 (pfun)
val i = int_of_uint (u)
// Sentinel: there is no closure to call or free; discharge the proofs.
prval () = opt_unnone {work} (wk)
prval () = cleanup_top {work} (wk)
in
~i end end
// nsieve_mt (ws, m): count the primes below m with a segmented Sieve of
// Eratosthenes, parallelized over the workers of [ws], and print the count.
//
// Phases:
//   1. allocate an m-element int8 array and set every cell to tt;
//   2. sieve the prefix [0, m1), where m1 is about sqrt(m) + 1,
//      sequentially; afterwards the flags in [2, m1) are final;
//   3. split [m1, m) into chunks of INC cells and enqueue one work item
//      per chunk.  Each item crosses out multiples of the prefix primes
//      inside its own chunk only.  Chunks are disjoint, and the prefix is
//      only read, so no two work items write the same cell;
//   4. wait for the work to drain, enqueue one quit item (a null work
//      pointer) per worker, and wait for the workers to quit;
//   5. count the tt cells in [2, m), free the array and print the result.
fn nsieve_mt {lws:agz}
  {m:int | m >= 2} (ws: !WSptr lws, m: int m): void = let
  // Already implied by the static index {m >= 2}; kept as a runtime guard.
  val () = assert_prerrf_bool1 (
    m >= 2, "nsieve_mt(%i): argument is illegal; it must be at least 2.\n", @(m)
  )
  val msz = size1_of_int1 (m)
  val [la:addr] (pf_gc, pf | p_A) =
    array_ptr_alloc_tsz {two} (msz, sizeof<two>)
  var x0 = tt
  val () = array_ptr_initialize_elt_tsz {two} (!p_A, msz, x0, sizeof<two>)
  val m1 = sqrt_int (m)
  val [m1:int] m1 = (if m1 < m then m1 + 1 else m): natLte m
  // Sequential phase: sieve [0, m1) completely so that the segment
  // workers can read the prefix flags without further synchronization.
  val () = sieve_many (pf | p_A, m1, m1)
  // Number of array cells handled by a single work item.
  #define INC %(1 << 16)
  val () = split (pf | ws, p_A, m1, m, m1) where {
    // Enqueue one work item per chunk [j, j + INC) covering [j, m).
    // The last chunk is clipped so that it ends exactly at m.
    fun split {m1,m:nat | m1 <= m} {j:nat} (
      pf: !array_v (two, m, la)
    | ws: !WSptr lws
    , p_A: ptr la
    , m1: int m1
    , m: int m
    , j: int j
    ) : void =
      // Test against m - INC instead of forming j + INC first, so the
      // comparison cannot overflow int when m is close to the int maximum.
      if j < m - INC then let
        val j1 = j + INC
        val () = workshop_insert_work (ws, f) where {
          val f = lam (): void =<lincloptr1> sieve_many_seg_unsafe (p_A, m1, j, j1)
        }
      in
        split (pf | ws, p_A, m1, m, j1)
      end else let
        val () = workshop_insert_work (ws, f) where {
          val f = lam (): void =<lincloptr1> sieve_many_seg_unsafe (p_A, m1, j, m)
        }
      in
        // the final (clipped) chunk has been queued; nothing left to do
      end
  }
  // Presumably blocks until every worker is idle on an empty queue,
  // i.e. all segment work has completed.
  val () = workshop_wait_blocked_all (ws)
  val nworker = workshop_get_nworker (ws)
  var i: Nat = 0
  // One quit item per worker: a null work pointer, for which fwork
  // returns 0 without calling anything.
  val () = while (i < nworker) let
    val _quit = $extval (work, "(void*)0")
    val () = workshop_insert_work (ws, _quit)
  in
    i := i + 1
  end
  val () = workshop_wait_quit_all (ws)
  // Count the cells still flagged tt in [2, m); 0 and 1 are not primes.
  val count = loop (pf | 2, 0) where {
    fun loop {i:nat}
      (pf: !array_v (two, m, la) | i: int i, c: int):<cloref1> int =
      if i < m then begin
        if p_A->[i] = tt then loop (pf | i+1, c+1) else loop (pf | i+1, c)
      end else begin
        c
      end
  }
  val () = array_ptr_free {two} (pf_gc, pf | p_A)
in
  printf ("The number of primes < %8i is %8i\n", @(m, count))
end
// QSZ: first argument to workshop_make in main (presumably the work-queue
// capacity).  NWORKER: default worker count when argv[2] is not given.
// NOTE(review): both #define directives share one line here; putting each
// on its own line would be clearer.
#define QSZ 1024 #define NWORKER 1
// Program entry.  Usage: <prog> <integer> [<ncore>]
// Counts the primes below argv[1], which must be >= 2.  It uses a
// parworkshop with argv[2] workers (NWORKER when omitted), then frees the
// workshop.
implement main (argc, argv) = let
var nworker: int = NWORKER
// The argc >= 2 check is what makes the argv.[1] access below legal.
// NOTE(review): the usage text names "nsieve2_mt"; confirm it matches the
// built program's name.
val () = assert_errmsg_bool1 (
argc >= 2, ": command format: nsieve2_mt <integer> <ncore>"
)
// Convert argv[1] to an indexed int.  The next assertion refines it to
// >= 2, which nsieve_mt's static precondition requires.
val i = int1_of argv.[1]
val () = assert_errmsg_bool1 (
i >= 2, "The input integer needs to be at least 2.\n"
) val () = if argc >= 3 then (nworker := int_of argv.[2])
// Give the worker count a static index so it can be checked positive.
val nworker = int1_of_int (nworker)
val () = assert_errmsg (nworker > 0, #LOCATION)
val ws = workshop_make<work> (QSZ, fwork)
// Start the workers; any nonzero status is treated as a fatal error.
val _err = workshop_add_nworker (ws, nworker)
val () = assert_errmsg (_err = 0, #LOCATION)
// nsieve_mt also sends quit items to all workers and waits for them.
val () = nsieve_mt (ws, i)
val () = workshop_free_vt_exn (ws)
in
end