diff options
-rw-r--r-- | ChangeLog | 24 | ||||
-rw-r--r-- | arith.c | 123 | ||||
-rw-r--r-- | arith.txr | 115 | ||||
-rw-r--r-- | lib.c | 31 | ||||
-rw-r--r-- | mpi-patches/add-mp-set-intptr | 22 | ||||
-rw-r--r-- | mpi-patches/fix-mult-bug | 13 | ||||
-rw-r--r-- | mpi-patches/mpi-set-double-intptr | 10 |
7 files changed, 271 insertions, 67 deletions
@@ -1,5 +1,29 @@ 2011-12-10 Kaz Kylheku <kaz@kylheku.com> + Bignum support in mult function. + + * arith.c: Regenerated. + + * arith.txr (CNUM_BIT): New constant. + (bignum, bignum_dbl_ipt): New static functions. + (@{add-fname}): Use bignum function. + (mul): New functions, rewrite of mul from lib.c. + + * lib.c (mul): Function removed. + + * mpi-patches/add-mp-set-intptr (mp_set_intptr): Revised patch. + Local variable v should be int_ptr_t not unsigned long. + Also, the mp_set interface doesn't set the sign; it's an unsigned + interface. We must do that ourselves. + + * mpi-patches/fix-mult-bug: The main multiplication function is + also broken in the same way, requiring the cast. + + * mpi-patches/mpi-set-double-intptr: Fixed use of wrong type for + local variable v. + +2011-12-10 Kaz Kylheku <kaz@kylheku.com> + * mpi-patches/mpi-set-mpi-word: Bugfix and refresh. * mpi-patches/mpi-set-double-intptr: New file. @@ -37,6 +37,7 @@ #include <dirent.h> #include <setjmp.h> #include <wchar.h> +#include <limits.h> #include "config.h" #include "lib.h" #include "unwind.h" @@ -45,6 +46,7 @@ #define TAG_PAIR(A, B) ((A) << TAG_SHIFT | (B)) #define NOOP(A, B) +#define CNUM_BIT ((int) sizeof (cnum) * CHAR_BIT) static mp_int NUM_MAX_MP; @@ -56,6 +58,20 @@ val make_bignum(void) return n; } +static val bignum(cnum cn) +{ + val n = make_bignum(); + mp_set_intptr(mp(n), cn); + return n; +} + +static val bignum_dbl_ipt(double_intptr_t di) +{ + val n = make_bignum(); + mp_set_double_intptr(mp(n), di); + return n; +} + static val normalize(val bignum) { switch (mp_cmp_mag(mp(bignum), &NUM_MAX_MP)) { @@ -83,12 +99,8 @@ val plus(val anum, val bnum) cnum b = c_num(bnum); cnum sum = a + b; - if (sum < NUM_MIN || sum > NUM_MAX) { - val n = make_bignum(); - mp_set_intptr(mp(n), sum); - return n; - } - + if (sum < NUM_MIN || sum > NUM_MAX) + return bignum(sum); return num(sum); } case TAG_PAIR(TAG_NUM, TAG_PTR): @@ -148,12 +160,8 @@ val minus(val anum, val bnum) cnum b = c_num(bnum); cnum sum = a - b; - if (sum < NUM_MIN || sum > NUM_MAX) { - val n = make_bignum(); - mp_set_intptr(mp(n), sum); - return n; - } - + if (sum < NUM_MIN || sum > NUM_MAX) + return bignum(sum); return num(sum); } case TAG_PAIR(TAG_NUM, TAG_PTR): @@ -213,6 +221,97 @@ val neg(val anum) } } +val mul(val anum, val bnum) +{ + int tag_a = tag(anum); + int tag_b = tag(bnum); + + switch (TAG_PAIR(tag_a, tag_b)) { + case TAG_PAIR(TAG_NUM, TAG_NUM): + { + cnum a = c_num(anum); + cnum b = c_num(bnum); +#if HAVE_DOUBLE_INTPTR_T + double_intptr_t product = a * (double_intptr_t) b; + if (product < NUM_MIN || product > NUM_MAX) + return bignum_dbl_ipt(product); + return num(product); +#else + cnum ap = (a < 0) ? -a : a; + cnum bp = (b < 0) ? -b : b; + int bit = CNUM_BIT - 3, amaxbit = 0, bmaxbit = 0; + cnum mask = (cnum) 1 << (CNUM_BIT - 4); + for (; mask && (ap || bp); mask >>= 1, bit--) { + if ((ap & mask)) { + amaxbit = bit; + ap = 0; + } + if ((bp & mask)) { + bmaxbit = bit; + bp = 0; + } + } + if (amaxbit + bmaxbit < CNUM_BIT - 1) { + cnum product = a * b; + if (product >= NUM_MIN && product <= NUM_MAX) + return num(a * b); + return bignum(a * b); + } else { + val n = make_bignum(); + mp_int tmpb; + mp_init(&tmpb); + mp_set_intptr(&tmpb, b); + mp_set_intptr(mp(n), a); + mp_mul(mp(n), &tmpb, mp(n)); + mp_clear(&tmpb); + return n; + } +#endif + } + case TAG_PAIR(TAG_NUM, TAG_PTR): + { + val n; + type_check(bnum, BGNUM); + n = make_bignum(); + if (sizeof (int_ptr_t) <= sizeof (mp_digit)) { + mp_mul_d(mp(bnum), c_num(anum), mp(n)); + } else { + mp_int tmp; + mp_init(&tmp); + mp_set_intptr(&tmp, c_num(anum)); + mp_mul(mp(bnum), &tmp, mp(n)); + } + return n; + } + case TAG_PAIR(TAG_PTR, TAG_NUM): + { + val n; + type_check(bnum, BGNUM); + n = make_bignum(); + if (sizeof (int_ptr_t) <= sizeof (mp_digit)) { + mp_mul_d(mp(anum), c_num(bnum), mp(n)); + } else { + mp_int tmp; + mp_init(&tmp); + mp_set_intptr(&tmp, c_num(bnum)); + mp_mul(mp(anum), &tmp, mp(n)); + } + return n; + } + case TAG_PAIR(TAG_PTR, TAG_PTR): + { + val n; + type_check(anum, BGNUM); + type_check(bnum, BGNUM); + n = make_bignum(); + mp_mul(mp(anum), mp(bnum), mp(n)); + return n; + } + } + uw_throwf(error_s, lit("mul: invalid operands ~s ~s"), anum, bnum, nao); + abort(); +} + void arith_init(void) { mp_init(&NUM_MAX_MP); @@ -42,6 +42,7 @@ #include <dirent.h> #include <setjmp.h> #include <wchar.h> +#include <limits.h> #include "config.h" #include "lib.h" #include "unwind.h" @@ -50,6 +51,7 @@ #define TAG_PAIR(A, B) ((A) << TAG_SHIFT | (B)) #define NOOP(A, B) +#define CNUM_BIT ((int) sizeof (cnum) * CHAR_BIT) static mp_int NUM_MAX_MP; @@ -61,6 +63,20 @@ val make_bignum(void) return n; } +static val bignum(cnum cn) +{ + val n = make_bignum(); + mp_set_intptr(mp(n), cn); + return n; +} + +static val bignum_dbl_ipt(double_intptr_t di) +{ + val n = make_bignum(); + mp_set_double_intptr(mp(n), di); + return n; +} + static val normalize(val bignum) { switch (mp_cmp_mag(mp(bignum), &NUM_MAX_MP)) { @@ -89,12 +105,8 @@ val @{add-fname}(val anum, val bnum) cnum b = c_num(bnum); cnum sum = a @{add-c-op} b; - if (sum < NUM_MIN || sum > NUM_MAX) { - val n = make_bignum(); - mp_set_intptr(mp(n), sum); - return n; - } - + if (sum < NUM_MIN || sum > NUM_MAX) + return bignum(sum); return num(sum); } case TAG_PAIR(TAG_NUM, TAG_PTR): @@ -155,6 +167,97 @@ val neg(val anum) } } +val mul(val anum, val bnum) +{ + int tag_a = tag(anum); + int tag_b = tag(bnum); + + switch (TAG_PAIR(tag_a, tag_b)) { + case TAG_PAIR(TAG_NUM, TAG_NUM): + { + cnum a = c_num(anum); + cnum b = c_num(bnum); +#if HAVE_DOUBLE_INTPTR_T + double_intptr_t product = a * (double_intptr_t) b; + if (product < NUM_MIN || product > NUM_MAX) + return bignum_dbl_ipt(product); + return num(product); +#else + cnum ap = (a < 0) ? -a : a; + cnum bp = (b < 0) ? -b : b; + int bit = CNUM_BIT - 3, amaxbit = 0, bmaxbit = 0; + cnum mask = (cnum) 1 << (CNUM_BIT - 4); + for (; mask && (ap || bp); mask >>= 1, bit--) { + if ((ap & mask)) { + amaxbit = bit; + ap = 0; + } + if ((bp & mask)) { + bmaxbit = bit; + bp = 0; + } + } + if (amaxbit + bmaxbit < CNUM_BIT - 1) { + cnum product = a * b; + if (product >= NUM_MIN && product <= NUM_MAX) + return num(a * b); + return bignum(a * b); + } else { + val n = make_bignum(); + mp_int tmpb; + mp_init(&tmpb); + mp_set_intptr(&tmpb, b); + mp_set_intptr(mp(n), a); + mp_mul(mp(n), &tmpb, mp(n)); + mp_clear(&tmpb); + return n; + } +#endif + } + case TAG_PAIR(TAG_NUM, TAG_PTR): + { + val n; + type_check(bnum, BGNUM); + n = make_bignum(); + if (sizeof (int_ptr_t) <= sizeof (mp_digit)) { + mp_mul_d(mp(bnum), c_num(anum), mp(n)); + } else { + mp_int tmp; + mp_init(&tmp); + mp_set_intptr(&tmp, c_num(anum)); + mp_mul(mp(bnum), &tmp, mp(n)); + } + return n; + } + case TAG_PAIR(TAG_PTR, TAG_NUM): + { + val n; + type_check(bnum, BGNUM); + n = make_bignum(); + if (sizeof (int_ptr_t) <= sizeof (mp_digit)) { + mp_mul_d(mp(anum), c_num(bnum), mp(n)); + } else { + mp_int tmp; + mp_init(&tmp); + mp_set_intptr(&tmp, c_num(bnum)); + mp_mul(mp(anum), &tmp, mp(n)); + } + return n; + } + case TAG_PAIR(TAG_PTR, TAG_PTR): + { + val n; + type_check(anum, BGNUM); + type_check(bnum, BGNUM); + n = make_bignum(); + mp_mul(mp(anum), mp(bnum), mp(n)); + return n; + } + } + uw_throwf(error_s, lit("mul: invalid operands ~s ~s"), anum, bnum, nao); + abort(); +} + void arith_init(void) { mp_init(&NUM_MAX_MP); @@ -848,37 +848,6 @@ val minusv(val minuend, val nlist) return neg(minuend); } -val mul(val anum, val bnum) -{ - cnum a = c_num(anum); - cnum b = c_num(bnum); - -#ifdef HAVE_LONGLONG_T - if (sizeof (longlong_t) >= 2 * sizeof (cnum)) { - longlong_t product = a * b; - numeric_assert (product >= NUM_MIN && product <= NUM_MAX); - return num(product); - } else -#endif - { - if (a > 0){ - if (b > 0) { - numeric_assert (a <= (NUM_MAX / b)); - } else { - numeric_assert (b >= (NUM_MIN / a)); - } - } else { - if (b > 0) { - numeric_assert (a >= (NUM_MIN / b)); - } else { - numeric_assert ((a == 0) || (b >= (NUM_MIN / a))); - } - } - - return num(a * b); - } -} - val mulv(val nlist) { if (!nlist) diff --git a/mpi-patches/add-mp-set-intptr b/mpi-patches/add-mp-set-intptr index a5d50a33..87e4ebb4 100644 --- a/mpi-patches/add-mp-set-intptr +++ b/mpi-patches/add-mp-set-intptr @@ -1,16 +1,17 @@ Index: mpi-1.8.6/mpi.c =================================================================== ---- mpi-1.8.6.orig/mpi.c 2011-12-09 13:52:26.000000000 -0800 -+++ mpi-1.8.6/mpi.c 2011-12-09 13:56:19.000000000 -0800 +--- mpi-1.8.6.orig/mpi.c 2011-12-10 18:20:55.000000000 -0800 ++++ mpi-1.8.6/mpi.c 2011-12-10 19:40:53.000000000 -0800 @@ -528,6 +528,59 @@ /* }}} */ +mp_err mp_set_intptr(mp_int *mp, int_ptr_t z) +{ ++ int_ptr_t v = z > 0 ? z : -z; ++ + if (sizeof z > sizeof (mp_digit)) { + int ix, shift; -+ unsigned long v = z > 0 ? z : -z; + const int nd = (sizeof v + sizeof (mp_digit) - 1) / sizeof (mp_digit); + + ARGCHK(mp != NULL, MP_BADARG); @@ -28,14 +29,13 @@ Index: mpi-1.8.6/mpi.c + { + DIGIT(mp, ix) = (v >> shift) & MP_DIGIT_MAX; + } -+ -+ if(z < 0) -+ SIGN(mp) = MP_NEG; -+ -+ return MP_OKAY; ++ } else { ++ mp_set(mp, v); + } + -+ mp_set(mp, z); ++ if(z < 0) ++ SIGN(mp) = MP_NEG; ++ + return MP_OKAY; +} + @@ -64,8 +64,8 @@ Index: mpi-1.8.6/mpi.c Index: mpi-1.8.6/mpi.h =================================================================== ---- mpi-1.8.6.orig/mpi.h 2011-12-09 13:49:20.000000000 -0800 -+++ mpi-1.8.6/mpi.h 2011-12-09 13:56:19.000000000 -0800 +--- mpi-1.8.6.orig/mpi.h 2011-12-10 18:19:39.000000000 -0800 ++++ mpi-1.8.6/mpi.h 2011-12-10 19:39:58.000000000 -0800 @@ -94,6 +94,8 @@ void mp_zero(mp_int *mp); void mp_set(mp_int *mp, mp_digit d); diff --git a/mpi-patches/fix-mult-bug b/mpi-patches/fix-mult-bug index e86d0363..bb8b0f0d 100644 --- a/mpi-patches/fix-mult-bug +++ b/mpi-patches/fix-mult-bug @@ -1,7 +1,7 @@ Index: mpi-1.8.6/mpi.c =================================================================== ---- mpi-1.8.6.orig/mpi.c 2011-12-10 12:05:39.000000000 -0800 -+++ mpi-1.8.6/mpi.c 2011-12-10 12:05:43.000000000 -0800 +--- mpi-1.8.6.orig/mpi.c 2011-12-10 19:41:00.000000000 -0800 ++++ mpi-1.8.6/mpi.c 2011-12-10 19:43:09.000000000 -0800 @@ -3263,7 +3263,7 @@ } @@ -11,3 +11,12 @@ Index: mpi-1.8.6/mpi.c dp[ix] = ACCUM(w); k = CARRYOUT(w); } +@@ -3480,7 +3480,7 @@ + pa = DIGITS(a); + for(jx = 0; jx < ua; ++jx, ++pa) { + pt = pbt + ix + jx; +- w = *pb * *pa + k + *pt; ++ w = *pb * (mp_word) *pa + k + *pt; + *pt = ACCUM(w); + k = CARRYOUT(w); + } diff --git a/mpi-patches/mpi-set-double-intptr b/mpi-patches/mpi-set-double-intptr index 1c834966..fb5dc52c 100644 --- a/mpi-patches/mpi-set-double-intptr +++ b/mpi-patches/mpi-set-double-intptr @@ -1,7 +1,7 @@ Index: mpi-1.8.6/mpi.c =================================================================== ---- mpi-1.8.6.orig/mpi.c 2011-12-10 18:21:53.000000000 -0800 -+++ mpi-1.8.6/mpi.c 2011-12-10 18:24:07.000000000 -0800 +--- mpi-1.8.6.orig/mpi.c 2011-12-10 19:13:25.000000000 -0800 ++++ mpi-1.8.6/mpi.c 2011-12-10 19:16:43.000000000 -0800 @@ -573,6 +573,36 @@ return MP_OKAY; } @@ -10,7 +10,7 @@ Index: mpi-1.8.6/mpi.c +mp_err mp_set_double_intptr(mp_int *mp, double_intptr_t z) +{ + int ix, shift; -+ unsigned long v = z > 0 ? z : -z; ++ double_intptr_t v = z > 0 ? z : -z; + const int nd = (sizeof v + sizeof (mp_digit) - 1) / sizeof (mp_digit); + + ARGCHK(mp != NULL, MP_BADARG); @@ -41,8 +41,8 @@ Index: mpi-1.8.6/mpi.c USED(mp) = 2; Index: mpi-1.8.6/mpi.h =================================================================== ---- mpi-1.8.6.orig/mpi.h 2011-12-10 18:21:53.000000000 -0800 -+++ mpi-1.8.6/mpi.h 2011-12-10 18:22:56.000000000 -0800 +--- mpi-1.8.6.orig/mpi.h 2011-12-10 19:13:25.000000000 -0800 ++++ mpi-1.8.6/mpi.h 2011-12-10 19:14:04.000000000 -0800 @@ -100,6 +100,9 @@ mp_err mp_set_int(mp_int *mp, long z); mp_err mp_set_intptr(mp_int *mp, int_ptr_t z); |