summaryrefslogtreecommitdiffstats
path: root/rand.c
diff options
context:
space:
mode:
authorKaz Kylheku <kaz@kylheku.com>2016-01-18 06:24:09 -0800
committerKaz Kylheku <kaz@kylheku.com>2016-01-18 06:24:09 -0800
commitcc8f11bf43842e38f0a515b8070f4a7afe9a716d (patch)
treec5110ac4e87d541fca5405e78ae29c585242af62 /rand.c
parenta1c7cc2f9faf4463722e78f24b9433dc9cf0bbf7 (diff)
downloadtxr-cc8f11bf43842e38f0a515b8070f4a7afe9a716d.tar.gz
txr-cc8f11bf43842e38f0a515b8070f4a7afe9a716d.tar.bz2
txr-cc8f11bf43842e38f0a515b8070f4a7afe9a716d.zip
Don't allow non-positive modulus in rand and random.
* rand.c (random): In fixnum case, allow only m >= 1. The code is restructured so that this check is done before we do some arithmetic with derived values, where the behavior can become undefined. * txr.1: Document the restriction on modulus range for rand and random.
Diffstat (limited to 'rand.c')
-rw-r--r--rand.c42
1 files changed, 22 insertions, 20 deletions
diff --git a/rand.c b/rand.c
index 480c22df..a2215124 100644
--- a/rand.c
+++ b/rand.c
@@ -223,34 +223,36 @@ val random(val state, val modulus)
} else if (fixnump(modulus)) {
cnum m = c_num(modulus);
int bits = highest_bit(m);
+ if (m == 1) {
+ return zero;
+ } else if (m > 1) {
#if SIZEOF_PTR >= 8
- int rands_needed = (bits + 32 - 1) / 32;
+ int rands_needed = (bits + 32 - 1) / 32;
#endif
- int msb_rand_bits = bits % 32;
- rand32_t msb_rand_mask = convert(rand32_t, -1) >> (32 - msb_rand_bits);
- if (m <= 0)
- goto invalid;
- for (;;) {
- cnum out = 0;
+ int msb_rand_bits = bits % 32;
+ rand32_t msb_rand_mask = convert(rand32_t, -1) >> (32 - msb_rand_bits);
+ for (;;) {
+ cnum out = 0;
#if SIZEOF_PTR >= 8
- int i;
+ int i;
- for (i = 0; i < rands_needed; i++) {
- rand32_t rnd = rand32(r);
- out <<= 32;
- if (i == 0)
- rnd &= msb_rand_mask;
- out |= rnd;
- }
+ for (i = 0; i < rands_needed; i++) {
+ rand32_t rnd = rand32(r);
+ out <<= 32;
+ if (i == 0)
+ rnd &= msb_rand_mask;
+ out |= rnd;
+ }
#else
- out = rand32(r) & msb_rand_mask;
+ out = rand32(r) & msb_rand_mask;
#endif
- if (out >= m)
- continue;
- return num(out);
+ if (out >= m)
+ continue;
+ return num(out);
+ }
}
}
-invalid:
+
uw_throwf(numeric_error_s, lit("random: invalid modulus ~s"),
modulus, nao);
}