diff options
author | Kaz Kylheku <kaz@kylheku.com> | 2017-06-18 08:34:08 -0700 |
---|---|---|
committer | Kaz Kylheku <kaz@kylheku.com> | 2017-06-18 10:14:34 -0700 |
commit | f9964fe5c922535d5284ad22d62fddbdca315e97 (patch) | |
tree | 48aa53103215b6434c90eeec1606f16f04503f94 | |
parent | 1e35fa11db662e6237e76f800a9294809bcb1660 (diff) | |
download | txr-f9964fe5c922535d5284ad22d62fddbdca315e97.tar.gz txr-f9964fe5c922535d5284ad22d62fddbdca315e97.tar.bz2 txr-f9964fe5c922535d5284ad22d62fddbdca315e97.zip |
Handle returns of MPI functions that return MP_TOOBIG.
* arith.c (do_mp_error): New function.
(num_from_buffer, plus, minus, mul, floordiv, expt, exptmod,
logtrunc, sign_extend, ash, bit): Handle errors from select
MPI functions: those that have the mp_ign attribute.
* ffi.c (unum_carray, num_carray): Likewise.
* rand.c (random): Likewise.
-rw-r--r-- | arith.c | 202 | ||||
-rw-r--r-- | arith.h | 1 | ||||
-rw-r--r-- | ffi.c | 12 | ||||
-rw-r--r-- | rand.c | 20 |
4 files changed, 157 insertions, 78 deletions
@@ -95,7 +95,9 @@ val bignum_from_uintptr(uint_ptr_t u) val num_from_buffer(mem_t *buf, int bytes) { val n = make_bignum(); - mp_read_unsigned_bin(mp(n), buf, bytes); + mp_err mpe = mp_read_unsigned_bin(mp(n), buf, bytes); + if (mpe != MP_OKAY) + do_mp_error(lit("buffer to number conversion"), mpe); return normalize(n); } @@ -370,8 +372,16 @@ static int highest_significant_bit(int_ptr_t n) return highest_bit(n ^ INT_PTR_MAX); } +void do_mp_error(val self, mp_err code) +{ + val errstr = string_utf8(mp_strerror(code)); + uw_throwf(numeric_error_s, lit("~a: ~a"), self, errstr, nao); +} + val plus(val anum, val bnum) { + val self = lit("+"); + tail: switch (TAG_PAIR(tag(anum), tag(bnum))) { case TAG_PAIR(TAG_NUM, TAG_NUM): @@ -389,6 +399,7 @@ tail: case BGNUM: { val n; + mp_err mpe; if (anum == zero) return bnum; n = make_bignum(); @@ -396,16 +407,18 @@ tail: cnum a = c_num(anum); cnum ap = ABS(a); if (a > 0) - mp_add_d(mp(bnum), ap, mp(n)); + mpe = mp_add_d(mp(bnum), ap, mp(n)); else - mp_sub_d(mp(bnum), ap, mp(n)); + mpe = mp_sub_d(mp(bnum), ap, mp(n)); } else { mp_int tmp; mp_init(&tmp); mp_set_intptr(&tmp, c_num(anum)); - mp_add(mp(bnum), &tmp, mp(n)); + mpe = mp_add(mp(bnum), &tmp, mp(n)); mp_clear(&tmp); } + if (mpe != MP_OKAY) + do_mp_error(self, mpe); return normalize(n); } case FLNUM: @@ -421,6 +434,7 @@ tail: case BGNUM: { val n; + mp_err mpe; n = make_bignum(); if (bnum == zero) return anum; @@ -428,16 +442,18 @@ tail: cnum b = c_num(bnum); cnum bp = ABS(b); if (b > 0) - mp_add_d(mp(anum), bp, mp(n)); + mpe = mp_add_d(mp(anum), bp, mp(n)); else - mp_sub_d(mp(anum), bp, mp(n)); + mpe = mp_sub_d(mp(anum), bp, mp(n)); } else { mp_int tmp; mp_init(&tmp); mp_set_intptr(&tmp, c_num(bnum)); - mp_add(mp(anum), &tmp, mp(n)); + mpe = mp_add(mp(anum), &tmp, mp(n)); mp_clear(&tmp); } + if (mpe != MP_OKAY) + do_mp_error(self, mpe); return normalize(n); } case FLNUM: @@ -453,8 +469,11 @@ tail: case TYPE_PAIR(BGNUM, BGNUM): { val n; + mp_err mpe; n = make_bignum(); - mp_add(mp(anum), mp(bnum), mp(n)); + mpe = mp_add(mp(anum), mp(bnum), mp(n)); + if (mpe != MP_OKAY) + do_mp_error(self, mpe); return normalize(n); } case TYPE_PAIR(FLNUM, FLNUM): @@ -513,6 +532,8 @@ char_range: val minus(val anum, val bnum) { + val self = lit("-"); + tail: switch (TAG_PAIR(tag(anum), tag(bnum))) { case TAG_PAIR(TAG_NUM, TAG_NUM): @@ -531,6 +552,7 @@ tail: case BGNUM: { val n; + mp_err mpe; n = make_bignum(); if (anum == zero) { mp_neg(mp(bnum), mp(n)); @@ -540,17 +562,20 @@ tail: cnum a = c_num(anum); cnum ap = ABS(a); if (ap > 0) - mp_sub_d(mp(bnum), ap, mp(n)); + mpe = mp_sub_d(mp(bnum), ap, mp(n)); else - mp_add_d(mp(bnum), ap, mp(n)); - mp_neg(mp(n), mp(n)); + mpe = mp_add_d(mp(bnum), ap, mp(n)); + if (mpe == MP_OKAY) + mp_neg(mp(n), mp(n)); } else { mp_int tmp; mp_init(&tmp); mp_set_intptr(&tmp, c_num(anum)); - mp_sub(mp(bnum), &tmp, mp(n)); + mpe = mp_sub(mp(bnum), &tmp, mp(n)); mp_clear(&tmp); } + if (mpe != MP_OKAY) + do_mp_error(self, mpe); return normalize(n); } case FLNUM: @@ -566,6 +591,7 @@ tail: case BGNUM: { val n; + mp_err mpe; if (bnum == zero) return anum; n = make_bignum(); @@ -573,16 +599,18 @@ tail: cnum b = c_num(bnum); cnum bp = ABS(b); if (b > 0) - mp_sub_d(mp(anum), bp, mp(n)); + mpe = mp_sub_d(mp(anum), bp, mp(n)); else - mp_add_d(mp(anum), bp, mp(n)); + mpe = mp_add_d(mp(anum), bp, mp(n)); } else { mp_int tmp; mp_init(&tmp); mp_set_intptr(&tmp, c_num(bnum)); - mp_sub(mp(anum), &tmp, mp(n)); + mpe = mp_sub(mp(anum), &tmp, mp(n)); mp_clear(&tmp); } + if (mpe != MP_OKAY) + do_mp_error(self, mpe); return normalize(n); } case FLNUM: @@ -598,8 +626,11 @@ tail: case TYPE_PAIR(BGNUM, BGNUM): { val n; + mp_err mpe; n = make_bignum(); - mp_sub(mp(anum), mp(bnum), mp(n)); + mpe = mp_sub(mp(anum), mp(bnum), mp(n)); + if (mpe != MP_OKAY) + do_mp_error(self, mpe); return normalize(n); } case TYPE_PAIR(FLNUM, FLNUM): @@ -689,6 +720,8 @@ val abso(val anum) val mul(val anum, val bnum) { + val self = lit("*"); + tail: switch (TAG_PAIR(tag(anum), tag(bnum))) { case TAG_PAIR(TAG_NUM, TAG_NUM): @@ -725,22 +758,25 @@ tail: case BGNUM: { val n; + mp_err mpe; if (anum == one) return bnum; n = make_bignum(); if (sizeof (int_ptr_t) <= sizeof (mp_digit)) { cnum a = c_num(anum); cnum ap = ABS(a); - mp_mul_d(mp(bnum), ap, mp(n)); - if (ap < 0) + mpe = mp_mul_d(mp(bnum), ap, mp(n)); + if (ap < 0 && mpe == MP_OKAY) mp_neg(mp(n), mp(n)); } else { mp_int tmp; mp_init(&tmp); mp_set_intptr(&tmp, c_num(anum)); - mp_mul(mp(bnum), &tmp, mp(n)); + mpe = mp_mul(mp(bnum), &tmp, mp(n)); mp_clear(&tmp); } + if (mpe != MP_OKAY) + do_mp_error(self, mpe); return n; } case FLNUM: @@ -756,22 +792,25 @@ tail: case BGNUM: { val n; + mp_err mpe; if (bnum == one) return anum; n = make_bignum(); if (sizeof (int_ptr_t) <= sizeof (mp_digit)) { cnum b = c_num(bnum); cnum bp = ABS(b); - mp_mul_d(mp(anum), bp, mp(n)); - if (b < 0) + mpe = mp_mul_d(mp(anum), bp, mp(n)); + if (b < 0 && mpe == MP_OKAY) mp_neg(mp(n), mp(n)); } else { mp_int tmp; mp_init(&tmp); mp_set_intptr(&tmp, c_num(bnum)); - mp_mul(mp(anum), &tmp, mp(n)); + mpe = mp_mul(mp(anum), &tmp, mp(n)); mp_clear(&tmp); } + if (mpe != MP_OKAY) + do_mp_error(self, mpe); return n; } case FLNUM: @@ -787,8 +826,11 @@ tail: case TYPE_PAIR(BGNUM, BGNUM): { val n; + mp_err mpe; n = make_bignum(); - mp_mul(mp(anum), mp(bnum), mp(n)); + mpe = mp_mul(mp(anum), mp(bnum), mp(n)); + if (mpe != MP_OKAY) + do_mp_error(self, mpe); return n; } case TYPE_PAIR(FLNUM, FLNUM): @@ -1103,6 +1145,8 @@ divzero: val floordiv(val anum, val bnum) { + val self = lit("floor"); + if (missingp(bnum)) return floorf(anum); tail: @@ -1165,31 +1209,36 @@ tail: cnum b = c_num(bnum); cnum bp = ABS(b); mp_digit rem; + mp_err mpe = MP_OKAY; if (mp_div_d(mp(anum), bp, mp(n), &rem) != MP_OKAY) goto divzero; if (b < 0) mp_neg(mp(n), mp(n)); if (rem && ((ISNEG(mp(anum)) && b > 0) || (!ISNEG(mp(anum)) && b < 0))) - mp_sub_d(mp(n), 1, mp(n)); + mpe = mp_sub_d(mp(n), 1, mp(n)); + if (mpe != MP_OKAY) + do_mp_error(self, mpe); } else { - int err; + mp_err mpe; cnum b = c_num(bnum); mp_int tmp, rem; mp_init(&tmp); mp_init(&rem); mp_set_intptr(&tmp, b); - err = mp_div(mp(anum), &tmp, mp(n), 0); + mpe = mp_div(mp(anum), &tmp, mp(n), 0); mp_clear(&tmp); - if (err != MP_OKAY) { + if (mpe != MP_OKAY) { mp_clear(&rem); goto divzero; } if (mp_cmp_z(&rem) != MP_EQ && ((ISNEG(mp(anum)) && b > 0) || (!ISNEG(mp(anum)) && b < 0))) - mp_sub_d(mp(n), 1, mp(n)); + mpe = mp_sub_d(mp(n), 1, mp(n)); mp_clear(&rem); + if (mpe != MP_OKAY) + do_mp_error(self, mpe); } return normalize(n); } @@ -1213,6 +1262,7 @@ tail: { val n = make_bignum(); mp_int rem; + mp_err mpe = MP_OKAY; mp_init(&rem); if (mp_div(mp(anum), mp(bnum), mp(n), &rem) != MP_OKAY) { mp_clear(&rem); @@ -1221,8 +1271,10 @@ tail: if (mp_cmp_z(&rem) != MP_EQ && ((ISNEG(mp(anum)) && !ISNEG(mp(bnum))) || (!ISNEG(mp(anum)) && ISNEG(mp(bnum))))) - mp_sub_d(mp(n), 1, mp(n)); + mpe = mp_sub_d(mp(n), 1, mp(n)); mp_clear(&rem); + if (mpe != MP_OKAY) + do_mp_error(self, mpe); return normalize(n); } case TYPE_PAIR(FLNUM, FLNUM): @@ -1745,6 +1797,8 @@ tail: val expt(val anum, val bnum) { + val self = lit("expt"); + tail: switch (TYPE_PAIR(type(anum), type(bnum))) { case TYPE_PAIR(NUM, NUM): @@ -1753,6 +1807,7 @@ tail: cnum b = c_num(bnum); mp_int tmpa; val n; + mp_err mpe = MP_OKAY; if (b < 0) goto negexp; if (bnum == zero) @@ -1763,15 +1818,17 @@ tail: mp_init(&tmpa); mp_set_intptr(&tmpa, a); if (sizeof (int_ptr_t) <= sizeof (mp_digit)) { - mp_expt_d(&tmpa, b, mp(n)); + mpe = mp_expt_d(&tmpa, b, mp(n)); } else { mp_int tmpb; mp_init(&tmpb); mp_set_intptr(&tmpb, b); - mp_expt(&tmpa, &tmpb, mp(n)); + mpe = mp_expt(&tmpa, &tmpb, mp(n)); mp_clear(&tmpb); } mp_clear(&tmpa); + if (mpe != MP_OKAY) + do_mp_error(self, mpe); return normalize(n); } case TYPE_PAIR(NUM, BGNUM): @@ -1779,19 +1836,23 @@ tail: cnum a = c_num(anum); mp_int tmpa; val n; + mp_err mpe = MP_OKAY; if (mp_cmp_z(mp(bnum)) == MP_LT) goto negexp; n = make_bignum(); mp_init(&tmpa); mp_set_intptr(&tmpa, a); - mp_expt(&tmpa, mp(bnum), mp(n)); + mpe = mp_expt(&tmpa, mp(bnum), mp(n)); mp_clear(&tmpa); + if (mpe != MP_OKAY) + do_mp_error(self, mpe); return normalize(n); } case TYPE_PAIR(BGNUM, NUM): { cnum b = c_num(bnum); val n; + mp_err mpe = MP_OKAY; if (b < 0) goto negexp; if (bnum == zero) @@ -1800,23 +1861,28 @@ tail: return anum; n = make_bignum(); if (sizeof (int_ptr_t) <= sizeof (mp_digit)) { - mp_expt_d(mp(anum), b, mp(n)); + mpe = mp_expt_d(mp(anum), b, mp(n)); } else { mp_int tmpb; mp_init(&tmpb); mp_set_intptr(&tmpb, b); - mp_expt(mp(anum), &tmpb, mp(n)); + mpe = mp_expt(mp(anum), &tmpb, mp(n)); mp_clear(&tmpb); } + if (mpe != MP_OKAY) + do_mp_error(self, mpe); return normalize(n); } case TYPE_PAIR(BGNUM, BGNUM): { val n; + mp_err mpe = MP_OKAY; if (mp_cmp_z(mp(bnum)) == MP_LT) goto negexp; n = make_bignum(); - mp_expt(mp(anum), mp(bnum), mp(n)); + mpe = mp_expt(mp(anum), mp(bnum), mp(n)); + if (mpe != MP_OKAY) + do_mp_error(self, mpe); normalize(n); return n; } @@ -1842,6 +1908,8 @@ negexp: val exptmod(val base, val exp, val mod) { + val self = lit("exptmod"); + mp_err mpe = MP_OKAY; val n; if (!integerp(base) || !integerp(exp) || !integerp(mod)) @@ -1858,16 +1926,15 @@ val exptmod(val base, val exp, val mod) n = make_bignum(); - if (mp_exptmod(mp(base), mp(exp), mp(mod), mp(n)) != MP_OKAY) + if ((mpe = mp_exptmod(mp(base), mp(exp), mp(mod), mp(n))) != MP_OKAY) goto bad; return normalize(n); inval: - uw_throwf(error_s, lit("exptmod: non-integral operands ~s ~s ~s"), - base, exp, mod, nao); + uw_throwf(error_s, lit("~a: non-integral operands ~s ~s ~s"), + self, base, exp, mod, nao); bad: - uw_throwf(error_s, lit("exptmod: bad operands ~s ~s ~s"), - base, exp, mod, nao); + do_mp_error(self, mpe); } static int_ptr_t isqrt_fixnum(int_ptr_t a) @@ -2344,6 +2411,7 @@ bad: val logtrunc(val a, val bits) { + val self = lit("logtrunc"); cnum an, bn; val b; const cnum num_mask = (NUM_MAX << 1) | 1; @@ -2358,6 +2426,7 @@ val logtrunc(val a, val bits) goto bad4; switch (type(a)) { + mp_err mpe; case NUM: an = c_num(a); if (bn <= num_bits) { @@ -2368,24 +2437,21 @@ val logtrunc(val a, val bits) /* fallthrough */ case BGNUM: b = make_ubignum(); - if (mp_trunc(mp(a), mp(b), bn) != MP_OKAY) - goto bad; + if ((mpe = mp_trunc(mp(a), mp(b), bn)) != MP_OKAY) + do_mp_error(self, mpe); return normalize(b); default: goto bad3; } -bad: - uw_throwf(error_s, lit("logtrunc: operation failed on ~s"), a, nao); - bad2: - uw_throwf(error_s, lit("logtrunc: bits value ~s is not a fixnum"), bits, nao); + uw_throwf(error_s, lit("~a: bits value ~s is not a fixnum"), self, bits, nao); bad3: - uw_throwf(error_s, lit("logtrunc: non-integral operand ~s"), a, nao); + uw_throwf(error_s, lit("~a: non-integral operand ~s"), self, a, nao); -bad4: - uw_throwf(error_s, lit("logtrunc: negative bits value ~s"), bits, nao); +bad4:; + uw_throwf(error_s, lit("~a: negative bits value ~s"), self, bits, nao); } val sign_extend(val n, val nbits) @@ -2404,8 +2470,10 @@ val sign_extend(val n, val nbits) case BGNUM: { val out = make_ubignum(); + mp_err mpe; mp_2comp(mp(ntrunc), mp(out), mp(ntrunc)->used); - mp_trunc(mp(out), mp(out), c_num(nbits)); + if ((mpe = mp_trunc(mp(out), mp(out), c_num(nbits))) != MP_OKAY) + do_mp_error(lit("sign-extend"), mpe); mp_neg(mp(out), mp(out)); return normalize(out); } @@ -2418,10 +2486,12 @@ val sign_extend(val n, val nbits) val ash(val a, val bits) { + val self = lit("ash"); cnum an, bn; val b; int hb; const int num_bits = CHAR_BIT * sizeof (cnum) - TAG_SHIFT; + mp_err mpe = MP_OKAY; if (!fixnump(bits)) goto bad2; @@ -2449,8 +2519,8 @@ val ash(val a, val bits) if (bn < INT_MIN || bn > INT_MAX) goto bad4; b = make_bignum(); - if (mp_shift(mp(a), mp(b), bn) != MP_OKAY) - goto bad; + if ((mpe = mp_shift(mp(a), mp(b), bn)) != MP_OKAY) + break; return normalize(b); default: goto bad3; @@ -2465,31 +2535,31 @@ val ash(val a, val bits) return num_fast(an >> num_bits); case BGNUM: b = make_bignum(); - if (mp_shift(mp(a), mp(b), bn) != MP_OKAY) - goto bad; + if ((mpe = mp_shift(mp(a), mp(b), bn)) != MP_OKAY) + break; return normalize(b); default: goto bad3; } - } -bad: - uw_throwf(error_s, lit("ash: operation failed on ~s"), a, nao); + do_mp_error(self, mpe); bad2: - uw_throwf(error_s, lit("ash: bits value ~s is not a fixnum"), bits, nao); + uw_throwf(error_s, lit("~a: bits value ~s is not a fixnum"), self, bits, nao); bad3: - uw_throwf(error_s, lit("ash: non-integral operand ~s"), a, nao); + uw_throwf(error_s, lit("~a: non-integral operand ~s"), self, a, nao); bad4: - uw_throwf(error_s, lit("ash: bit value too large ~s"), bits, nao); + uw_throwf(error_s, lit("~a: bit value too large ~s"), self, bits, nao); } val bit(val a, val bit) { + val self = lit("bit"); cnum bn; + mp_err mpe = MP_OKAY; if (!fixnump(bit)) goto bad; @@ -2509,9 +2579,9 @@ val bit(val a, val bit) } case BGNUM: { - mp_err res = mp_bit(mp(a), bn); + mpe = mp_bit(mp(a), bn); - switch (res) { + switch (mpe) { case MP_YES: return t; case MP_NO: @@ -2525,16 +2595,16 @@ val bit(val a, val bit) } bad: - uw_throwf(error_s, lit("bit: bit position ~s is not a fixnum"), bit, nao); + uw_throwf(error_s, lit("~a: bit position ~s is not a fixnum"), self, bit, nao); bad2: - uw_throwf(error_s, lit("bit: bit position ~s is negative"), bit, nao); + uw_throwf(error_s, lit("~a: bit position ~s is negative"), self, bit, nao); bad3: - uw_throwf(error_s, lit("bit: non-integral operand ~s"), a, nao); + uw_throwf(error_s, lit("~a: non-integral operand ~s"), self, a, nao); bad4: - uw_throwf(error_s, lit("bit: operation failed on ~s, bit ~s"), a, bit, nao); + do_mp_error(self, mpe); } val maskv(struct args *bits) @@ -45,5 +45,6 @@ val tofloatz(val obj); val tointz(val obj, val base); val width(val num); val bits(val obj); +noreturn void do_mp_error(val self, mp_err code); void arith_init(void); void arith_free_all(void); @@ -4887,9 +4887,9 @@ val unum_carray(val carray) struct txr_ffi_type *etft = scry->eltft; ucnum size = (ucnum) etft->size * (ucnum) scry->nelem; val ubn = make_bignum(); - if ((ucnum) (int) size != size) - uw_throwf(error_s, lit("~a: bignum size overflow"), self, nao); - mp_read_unsigned_bin(mp(ubn), scry->data, size); + mp_err mpe = mp_read_unsigned_bin(mp(ubn), scry->data, size); + if (mpe != MP_OKAY) + do_mp_error(self, mpe); return normalize(ubn); } @@ -4901,9 +4901,9 @@ val num_carray(val carray) ucnum size = (ucnum) etft->size * (ucnum) scry->nelem; ucnum bits = size * 8; val ubn = make_bignum(); - if ((ucnum) (int) size != size || bits / 8 != size) - uw_throwf(error_s, lit("~a: bignum size overflow"), self, nao); - mp_read_unsigned_bin(mp(ubn), scry->data, size); + mp_err mpe = mp_read_unsigned_bin(mp(ubn), scry->data, size); + if (mpe != MP_OKAY) + do_mp_error(self, mpe); return sign_extend(normalize(ubn), unum(bits)); } @@ -217,6 +217,7 @@ val random_fixnum(val state) val random(val state, val modulus) { + val self = lit("random"); struct rand_state *r = coerce(struct rand_state *, cobj_handle(state, random_state_s)); @@ -233,21 +234,28 @@ val random(val state, val modulus) ucnum i; for (i = 0; i < rands_needed; i++) { rand32_t rnd = rand32(r); + mp_err mpe = MP_OKAY; #if MP_DIGIT_SIZE >= 4 if (i > 0) - mp_mul_2d(om, 32, om); + mpe = mp_mul_2d(om, 32, om); else rnd &= msb_rand_mask; - mp_add_d(om, rnd, om); + if (mpe == MP_OKAY) + mpe = mp_add_d(om, rnd, om); #else if (i > 0) - mp_mul_2d(om, 16, om); + mpe = mp_mul_2d(om, 16, om); else rnd &= msb_rand_mask; - mp_add_d(om, rnd & 0xFFFF, om); - mp_mul_2d(om, 16, om); - mp_add_d(om, rnd >> 16, om); + if (mpe == MP_OKAY) + mpe = mp_add_d(om, rnd & 0xFFFF, om); + if (mpe == MP_OKAY) + mpe = mp_mul_2d(om, 16, om); + if (mpe == MP_OKAY) + mp_add_d(om, rnd >> 16, om); #endif + if (mpe != MP_OKAY) + do_mp_error(self, mpe); } if (mp_cmp(om, m) != MP_LT) { mp_zero(om); |