diff options
author | Kaz Kylheku <kaz@kylheku.com> | 2017-08-05 21:46:08 -0700 |
---|---|---|
committer | Kaz Kylheku <kaz@kylheku.com> | 2017-08-05 21:46:08 -0700 |
commit | 2e0acd6057bd65b0b872356c51c1b7b1a06c89b9 (patch) | |
tree | a4483a81092a80fda41c624165586371d32a4dce | |
parent | 2844bb73b485450660d70de2de489590d0995d9e (diff) | |
download | txr-2e0acd6057bd65b0b872356c51c1b7b1a06c89b9.tar.gz txr-2e0acd6057bd65b0b872356c51c1b7b1a06c89b9.tar.bz2 txr-2e0acd6057bd65b0b872356c51c1b7b1a06c89b9.zip |
bugfix: n-ary arith functions must check single arg.
We are allowing calls like (* "a") and (+ "a")
without diagnosing that the argument isn't of a valid
type. Note that (max "a") is fine beacause min and
max use the less function; they are not strictly numeric.
* lib.c (nary_op): Beef up function with additional argument
for type checking the unary case.
(unary_num, unary_arith, unary_int): New static functions.
(plusv, mulv, logandv, logiorv): Use new nary_op interface.
(gtv, ltv, gev, lev, numeqv, numneq): Check the
first number.
* lib.c (nary_op): Declaration updated.
-rw-r--r-- | lib.c | 65 | ||||
-rw-r--r-- | lib.h | 4 |
2 files changed, 60 insertions, 9 deletions
@@ -3117,7 +3117,9 @@ val numberp(val num) } } -val nary_op(val (*cfunc)(val, val), struct args *args, val emptyval) +val nary_op(val self, val (*bfun)(val, val), + val (*ufun)(val self, val), + struct args *args, val emptyval) { val fi, se, re; cnum index = 0; @@ -3128,21 +3130,48 @@ val nary_op(val (*cfunc)(val, val), struct args *args, val emptyval) fi = args_get(args, &index); if (!args_more(args, index)) - return fi; + return ufun(self, fi); se = args_get(args, &index); if (!args_more(args, index)) - return cfunc(fi, se); + return bfun(fi, se); re = args_get_rest(args, index); - return reduce_left(func_n2(cfunc), re, cfunc(fi, se), nil); + return reduce_left(func_n2(bfun), re, bfun(fi, se), nil); +} + +static val unary_num(val self, val arg) +{ + if (!numberp(arg)) + uw_throwf(error_s, lit("~a: ~s isn't a number"), self, arg, nao); + return arg; +} + +static val unary_arith(val self, val arg) +{ + switch (type(arg)) { + case NUM: + case CHR: + case BGNUM: + case FLNUM: + return arg; + default: + uw_throwf(error_s, lit("~a: invalid argument ~s"), self, arg, nao); + } +} + +static val unary_int(val self, val arg) +{ + if (!integerp(arg)) + uw_throwf(error_s, lit("~a: ~s isn't an integer"), self, arg, nao); + return arg; } val plusv(struct args *nlist) { - return nary_op(plus, nlist, zero); + return nary_op(lit("+"), plus, unary_arith, nlist, zero); } val minusv(val minuend, struct args *nlist) @@ -3164,7 +3193,7 @@ val minusv(val minuend, struct args *nlist) val mulv(struct args *nlist) { - return nary_op(mul, nlist, one); + return nary_op(lit("*"), mul, unary_num, nlist, one); } val divv(val dividend, struct args *nlist) @@ -3186,12 +3215,12 @@ val divv(val dividend, struct args *nlist) val logandv(struct args *nlist) { - return nary_op(logand, nlist, negone); + return nary_op(lit("logand"), logand, unary_int, nlist, negone); } val logiorv(struct args *nlist) { - return nary_op(logior, nlist, zero); + return nary_op(lit("logior"), logior, unary_int, nlist, zero); } val gtv(val first, struct args *rest) @@ -3205,6 +3234,9 @@ val gtv(val first, struct args *rest) first = elem; } + if (index == 0) + (void) unary_arith(lit(">"), first); + return t; } @@ -3219,6 +3251,9 @@ val ltv(val first, struct args *rest) first = elem; } + if (index == 0) + (void) unary_arith(lit("<"), first); + return t; } @@ -3233,6 +3268,9 @@ val gev(val first, struct args *rest) first = elem; } + if (index == 0) + (void) unary_arith(lit(">="), first); + return t; } @@ -3247,6 +3285,9 @@ val lev(val first, struct args *rest) first = elem; } + if (index == 0) + (void) unary_arith(lit("<="), first); + return t; } @@ -3261,6 +3302,9 @@ val numeqv(val first, struct args *rest) first = elem; } + if (index == 0) + (void) unary_arith(lit("="), first); + return t; } @@ -3269,6 +3313,11 @@ val numneqv(struct args *args) val i, j; val list = args_get_list(args); + if (list && !cdr(list)) { + (void) unary_arith(lit("/="), car(list)); + return t; + } + for (i = list; i; i = cdr(i)) for (j = cdr(i); j; j = cdr(j)) if (numeq(car(i), car(j))) @@ -654,7 +654,9 @@ val bignump(val num); val floatp(val num); val integerp(val num); val numberp(val num); -val nary_op(val (*cfunc)(val, val), struct args *args, val emptyval); +val nary_op(val self, val (*bfun)(val, val), + val (*ufun)(val self, val), + struct args *args, val emptyval); val plus(val anum, val bnum); val plusv(struct args *); val minus(val anum, val bnum); |