diff options
Diffstat (limited to 'arith.c')
-rw-r--r-- | arith.c | 445 |
1 files changed, 445 insertions, 0 deletions
@@ -27,6 +27,7 @@ #include <stdio.h> #include <stdlib.h> +#include <stddef.h> #include <string.h> #include <wctype.h> #include <stdarg.h> @@ -49,6 +50,8 @@ #include "txr.h" #include "arith.h" +#define max(a, b) ((a) > (b) ? (a) : (b)) + #define TAG_PAIR(A, B) ((A) << TAG_SHIFT | (B)) #define NOOP(A, B) #define CNUM_BIT ((int) sizeof (cnum) * CHAR_BIT) @@ -3416,6 +3419,417 @@ static val flo_set_round_mode(val mode) #endif +val num(cnum n) +{ + return (n >= NUM_MIN && n <= NUM_MAX) ? num_fast(n) : bignum(n); +} + +cnum c_num(val n) +{ + switch (type(n)) { + case CHR: case NUM: + return coerce(cnum, n) >> TAG_SHIFT; + case BGNUM: + if (mp_in_intptr_range(mp(n))) { + int_ptr_t out; + mp_get_intptr(mp(n), &out); + return out; + } + uw_throwf(error_s, lit("~s is out of allowed range [~s, ~s]"), + n, num(INT_PTR_MIN), num(INT_PTR_MAX), nao); + default: + type_mismatch(lit("~s is not an integer"), n, nao); + } +} + +cnum c_fixnum(val num, val self) +{ + switch (type(num)) { + case CHR: case NUM: + return coerce(cnum, num) >> TAG_SHIFT; + default: + type_mismatch(lit("~a: ~s is not fixnum integer or character"), + self, num, nao); + } +} + +#if HAVE_FPCLASSIFY +INLINE int bad_float(double d) +{ + switch (fpclassify(d)) { + case FP_ZERO: + case FP_NORMAL: + case FP_SUBNORMAL: + return 0; + default: + return 1; + } +} +#else +#define bad_float(d) (0) +#endif + +val flo(double n) +{ + if (bad_float(n)) { + uw_throw(numeric_error_s, lit("out-of-range floating-point result")); + } else { + val obj = make_obj(); + obj->fl.type = FLNUM; + obj->fl.n = n; + return obj; + } +} + +double c_flo(val num, val self) +{ + type_check(self, num, FLNUM); + return num->fl.n; +} + +val fixnump(val num) +{ + return (is_num(num)) ? t : nil; +} + +val bignump(val num) +{ + return (type(num) == BGNUM) ? t : nil; +} + +val integerp(val num) +{ + switch (tag(num)) { + case TAG_NUM: + return t; + case TAG_PTR: + if (num == nil) + return nil; + if (num->t.type == BGNUM) + return t; + /* fallthrough */ + default: + return nil; + } +} + +val floatp(val num) +{ + return (type(num) == FLNUM) ? t : nil; +} + +val numberp(val num) +{ + switch (tag(num)) { + case TAG_NUM: + return t; + case TAG_PTR: + if (num == nil) + return nil; + if (num->t.type == BGNUM || num->t.type == FLNUM) + return t; + /* fallthrough */ + default: + return nil; + } +} + +val nary_op(val self, val (*bfun)(val, val), + val (*ufun)(val self, val), + struct args *args, val emptyval) +{ + val acc, next; + cnum index = 0; + + if (!args_more(args, index)) + return emptyval; + + acc = args_get(args, &index); + + if (!args_more(args, index)) + return ufun(self, acc); + + do { + next = args_get(args, &index); + acc = bfun(acc, next); + } while (args_more(args, index)); + + return acc; +} + +static val nary_op_keyfun(val self, val (*bfun)(val, val), + val (*ufun)(val self, val), + struct args *args, val emptyval, + val keyfun) +{ + val acc, next; + cnum index = 0; + + if (!args_more(args, index)) + return emptyval; + + acc = funcall1(keyfun, args_get(args, &index)); + + if (!args_more(args, index)) + return ufun(self, acc); + + do { + next = funcall1(keyfun, args_get(args, &index)); + acc = bfun(acc, next); + } while (args_more(args, index)); + + return acc; +} + + +val nary_simple_op(val self, val (*bfun)(val, val), + struct args *args, val firstval) +{ + val acc = firstval, next; + cnum index = 0; + + while (args_more(args, index)) { + next = args_get(args, &index); + acc = bfun(acc, next); + } + + return acc; +} + +static val unary_num(val self, val arg) +{ + if (!numberp(arg)) + uw_throwf(error_s, lit("~a: ~s isn't a number"), self, arg, nao); + return arg; +} + +static val unary_arith(val self, val arg) +{ + switch (type(arg)) { + case NUM: + case CHR: + case BGNUM: + case FLNUM: + return arg; + default: + uw_throwf(error_s, lit("~a: invalid argument ~s"), self, arg, nao); + } +} + +static val unary_int(val self, val arg) +{ + if (!integerp(arg)) + uw_throwf(error_s, lit("~a: ~s isn't an integer"), self, arg, nao); + return arg; +} + +val plusv(struct args *nlist) +{ + return nary_op(lit("+"), plus, unary_arith, nlist, zero); +} + +val minusv(val minuend, struct args *nlist) +{ + val acc = minuend, next; + cnum index = 0; + + if (!args_more(nlist, index)) + return neg(acc); + + do { + next = args_get(nlist, &index); + acc = minus(acc, next); + } while (args_more(nlist, index)); + + return acc; +} + +val mulv(struct args *nlist) +{ + return nary_op(lit("*"), mul, unary_num, nlist, one); +} + +val divv(val dividend, struct args *nlist) +{ + val acc = dividend, next; + cnum index = 0; + + if (!args_more(nlist, index)) + return divi(one, acc); + + do { + next = args_get(nlist, &index); + acc = divi(acc, next); + } while (args_more(nlist, index)); + + return acc; +} + +val logandv(struct args *nlist) +{ + return nary_op(lit("logand"), logand, unary_int, nlist, negone); +} + +val logiorv(struct args *nlist) +{ + return nary_op(lit("logior"), logior, unary_int, nlist, zero); +} + +val gtv(val first, struct args *rest) +{ + cnum index = 0; + + while (args_more(rest, index)) { + val elem = args_get(rest, &index); + if (!gt(first, elem)) + return nil; + first = elem; + } + + if (index == 0) + (void) unary_arith(lit(">"), first); + + return t; +} + +val ltv(val first, struct args *rest) +{ + cnum index = 0; + + while (args_more(rest, index)) { + val elem = args_get(rest, &index); + if (!lt(first, elem)) + return nil; + first = elem; + } + + if (index == 0) + (void) unary_arith(lit("<"), first); + + return t; +} + +val gev(val first, struct args *rest) +{ + cnum index = 0; + + while (args_more(rest, index)) { + val elem = args_get(rest, &index); + if (!ge(first, elem)) + return nil; + first = elem; + } + + if (index == 0) + (void) unary_arith(lit(">="), first); + + return t; +} + +val lev(val first, struct args *rest) +{ + cnum index = 0; + + while (args_more(rest, index)) { + val elem = args_get(rest, &index); + if (!le(first, elem)) + return nil; + first = elem; + } + + if (index == 0) + (void) unary_arith(lit("<="), first); + + return t; +} + +val numeqv(val first, struct args *rest) +{ + cnum index = 0; + + while (args_more(rest, index)) { + val elem = args_get(rest, &index); + if (!numeq(first, elem)) + return nil; + first = elem; + } + + if (index == 0) + (void) unary_arith(lit("="), first); + + return t; +} + +val numneqv(struct args *args) +{ + val i, j; + val list = args_get_list(args); + + if (list && !cdr(list)) { + (void) unary_arith(lit("/="), car(list)); + return t; + } + + for (i = list; i; i = cdr(i)) + for (j = cdr(i); j; j = cdr(j)) + if (numeq(car(i), car(j))) + return nil; + + return t; +} + +static val sumv(struct args *nlist, val keyfun) +{ + return nary_op_keyfun(lit("+"), plus, unary_arith, nlist, zero, keyfun); +} + +val sum(val seq, val keyfun) +{ + args_decl_list(args, ARGS_MIN, tolist(seq)); + return if3(missingp(keyfun), plusv(args), sumv(args, keyfun)); +} + +static val prodv(struct args *nlist, val keyfun) +{ + return nary_op_keyfun(lit("*"), mul, unary_num, nlist, one, keyfun); +} + +val prod(val seq, val keyfun) +{ + args_decl_list(args, ARGS_MIN, tolist(seq)); + return if3(missingp(keyfun), mulv(args), prodv(args, keyfun)); +} + +static val rexpt(val right, val left) +{ + return expt(left, right); +} + +val exptv(struct args *nlist) +{ + cnum nargs = args_count(nlist); + args_decl(rnlist, max(ARGS_MIN, nargs)); + args_copy_reverse(rnlist, nlist, nargs); + return nary_op(lit("expt"), rexpt, unary_num, rnlist, one); +} + +static val abso_self(val self, val arg) +{ + (void) self; + return abso(arg); +} + +val gcdv(struct args *nlist) +{ + return nary_op(lit("gcd"), gcd, abso_self, nlist, zero); +} + +val lcmv(struct args *nlist) +{ + return nary_op(lit("lcm"), lcm, abso_self, nlist, zero); +} + + void arith_init(void) { log2_init(); @@ -3457,6 +3871,24 @@ void arith_init(void) reg_fun(intern(lit("abs"), user_package), func_n1(abso)); reg_fun(intern(lit("trunc"), user_package), func_n2o(trunc, 1)); reg_fun(intern(lit("mod"), user_package), func_n2(mod)); + reg_fun(intern(lit("zerop"), user_package), func_n1(zerop)); + reg_fun(intern(lit("nzerop"), user_package), func_n1(nzerop)); + reg_fun(intern(lit("plusp"), user_package), func_n1(plusp)); + reg_fun(intern(lit("minusp"), user_package), func_n1(minusp)); + reg_fun(intern(lit("evenp"), user_package), func_n1(evenp)); + reg_fun(intern(lit("oddp"), user_package), func_n1(oddp)); + reg_fun(intern(lit("succ"), user_package), func_n1(succ)); + reg_fun(intern(lit("ssucc"), user_package), func_n1(ssucc)); + reg_fun(intern(lit("sssucc"), user_package), func_n1(sssucc)); + reg_fun(intern(lit("pred"), user_package), func_n1(pred)); + reg_fun(intern(lit("ppred"), user_package), func_n1(ppred)); + reg_fun(intern(lit("pppred"), user_package), func_n1(pppred)); + reg_fun(intern(lit(">"), user_package), func_n1v(gtv)); + reg_fun(intern(lit("<"), user_package), func_n1v(ltv)); + reg_fun(intern(lit(">="), user_package), func_n1v(gev)); + reg_fun(intern(lit("<="), user_package), func_n1v(lev)); + reg_fun(intern(lit("="), user_package), func_n1v(numeqv)); + reg_fun(intern(lit("/="), user_package), func_n0v(numneqv)); reg_fun(intern(lit("wrap"), user_package), func_n3(wrap)); reg_fun(intern(lit("wrap*"), user_package), func_n3(wrap_star)); reg_fun(intern(lit("/"), user_package), func_n1v(divv)); @@ -3485,6 +3917,19 @@ void arith_init(void) reg_fun(intern(lit("log2"), user_package), func_n1(logtwo)); reg_fun(intern(lit("exp"), user_package), func_n1(expo)); reg_fun(intern(lit("sqrt"), user_package), func_n1(sqroot)); + reg_fun(intern(lit("logand"), user_package), func_n0v(logandv)); + reg_fun(intern(lit("logior"), user_package), func_n0v(logiorv)); + reg_fun(intern(lit("logxor"), user_package), + func_n2(if3(opt_compat && opt_compat <= 202, logxor_old, logxor))); + reg_fun(intern(lit("logtest"), user_package), func_n2(logtest)); + reg_fun(intern(lit("lognot"), user_package), func_n2o(lognot, 1)); + reg_fun(intern(lit("logtrunc"), user_package), func_n2(logtrunc)); + reg_fun(intern(lit("sign-extend"), user_package), func_n2(sign_extend)); + reg_fun(intern(lit("ash"), user_package), func_n2(ash)); + reg_fun(intern(lit("bit"), user_package), func_n2(bit)); + reg_fun(intern(lit("mask"), user_package), func_n0v(maskv)); + reg_fun(intern(lit("width"), user_package), func_n1(width)); + reg_fun(intern(lit("logcount"), user_package), func_n1(logcount)); reg_fun(intern(lit("cum-norm-dist"), user_package), func_n1(cum_norm_dist)); reg_fun(intern(lit("inv-cum-norm"), user_package), func_n1(inv_cum_norm)); reg_fun(intern(lit("n-choose-k"), user_package), func_n2(n_choose_k)); |