summaryrefslogtreecommitdiffstats
path: root/mpi-patches/faster-square-root
blob: 30f3da912ee70864a32ba9bd6d20bffa41133ce8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
Index: mpi-1.8.6/mpi.c
===================================================================
--- mpi-1.8.6.orig/mpi.c	2012-03-04 11:45:43.071884757 -0800
+++ mpi-1.8.6/mpi.c	2012-03-04 11:45:43.556157007 -0800
@@ -158,6 +158,9 @@
 mp_err   s_mp_grow(mp_int *mp, mp_size min);   /* increase allocated size */
 mp_err   s_mp_pad(mp_int *mp, mp_size min);    /* left pad with zeroes    */
 
+int      s_highest_bit_mp(mp_int *a);          /* highest bit, all digits */
+mp_err   s_mp_set_bit(mp_int *a, int bit);     /* set one bit, pad to fit */
+
 void     s_mp_clamp(mp_int *mp);               /* clip leading zeroes     */
 
 void     s_mp_exch(mp_int *a, mp_int *b);      /* swap a and b in place   */
@@ -1535,77 +1538,71 @@
 
 /* {{{ mp_sqrt(a, b) */
 
-/*
-  mp_sqrt(a, b)
-
-  Compute the integer square root of a, and store the result in b.
-  Uses an integer-arithmetic version of Newton's iterative linear
-  approximation technique to determine this value; the result has the
-  following two properties:
-
-     b^2 <= a
-     (b+1)^2 >= a
-
-  It is a range error to pass a negative value.
- */
+/*
+  mp_sqrt(a, b)
+
+  Compute the integer square root of a, and store the result in b.
+  The root is built one bit at a time, from the highest candidate
+  bit (half of s_highest_bit_mp(a)) down to bit zero: each bit is
+  tentatively set in a copy of the root found so far, and is kept
+  only if the square of that guess does not exceed a, so that
+
+     b^2 <= a < (b+1)^2
+
+  It is a range error to pass a negative value.
+ */
 mp_err mp_sqrt(mp_int *a, mp_int *b)
 {
-  mp_int   x, t;
-  mp_err   res;
+  int mask_shift;
+  mp_int root, guess, *proot = &root, *pguess = &guess;
+  mp_int guess_sqr;
+  mp_err err = MP_MEM;
 
   ARGCHK(a != NULL && b != NULL, MP_BADARG);
 
   /* Cannot take square root of a negative value */
-  if(SIGN(a) == MP_NEG)
+  if (mp_cmp_z(a) == MP_LT)
     return MP_RANGE;
 
-  /* Special cases for zero and one, trivial     */
-  if(mp_cmp_d(a, 0) == MP_EQ || mp_cmp_d(a, 1) == MP_EQ) 
-    return mp_copy(a, b);
-    
-  /* Initialize the temporaries we'll use below  */
-  if((res = mp_init_size(&t, USED(a))) != MP_OKAY)
-    return res;
-
-  /* Compute an initial guess for the iteration as a itself */
-  if((res = mp_init_copy(&x, a)) != MP_OKAY)
-    goto X;
-
-  for(;;) {
-    /* t = (x * x) - a */
-    mp_copy(&x, &t);      /* can't fail, t is big enough for original x */
-    if((res = mp_sqr(&t, &t)) != MP_OKAY ||
-       (res = mp_sub(&t, a, &t)) != MP_OKAY)
-      goto CLEANUP;
-
-    /* t = t / 2x       */
-    s_mp_mul_2(&x);
-    if((res = mp_div(&t, &x, &t, NULL)) != MP_OKAY)
-      goto CLEANUP;
-    s_mp_div_2(&x);
-
-    /* Terminate the loop, if the quotient is zero */
-    if(mp_cmp_z(&t) == MP_EQ)
-      break;
-
-    /* x = x - t       */
-    if((res = mp_sub(&x, &t, &x)) != MP_OKAY)
-      goto CLEANUP;
-
+  if ((err = mp_init(&root)))
+    goto out;
+  if ((err = mp_init(&guess)))
+    goto cleanup_root;
+  if ((err = mp_init(&guess_sqr)))
+    goto cleanup_guess;
+
+  /* Try every candidate bit, including bit zero, from the top down */
+  for (mask_shift = s_highest_bit_mp(a) / 2; mask_shift >= 0; mask_shift--) {
+    mp_int *temp;
+
+    if ((err = mp_copy(proot, pguess)))
+	goto cleanup;
+    if ((err = s_mp_set_bit(pguess, mask_shift)))
+	goto cleanup;
+    if ((err = mp_copy(pguess, &guess_sqr)))
+	goto cleanup;
+    if ((err = s_mp_sqr(&guess_sqr)))
+	goto cleanup;
+
+    /* Keep the bit when guess^2 <= a by swapping the root/guess roles */
+    if (s_mp_cmp(&guess_sqr, a) <= 0) {
+	temp = proot;
+	proot = pguess;
+	pguess = temp;
+    }
   }
 
-  /* Copy result to output parameter */
-  mp_sub_d(&x, 1, &x);
-  s_mp_exch(&x, b);
-
- CLEANUP:
-  mp_clear(&x);
- X:
-  mp_clear(&t); 
-
-  return res;
-
-} /* end mp_sqrt() */
+  err = mp_copy(proot, b);
+
+cleanup:
+  mp_clear(&guess_sqr);
+cleanup_guess:
+  mp_clear(&guess);
+cleanup_root:
+  mp_clear(&root);
+out:
+  return err;
+}
 
 /* }}} */
 
@@ -2554,21 +2535,9 @@
 
 int    mp_count_bits(mp_int *mp)
 {
-  int      len;
-  mp_digit d;
-
   ARGCHK(mp != NULL, MP_BADARG);
 
-  len = DIGIT_BIT * (USED(mp) - 1);
-  d = DIGIT(mp, USED(mp) - 1);
-
-  while(d != 0) {
-    ++len;
-    d >>= 1;
-  }
-
-  return len;
-  
+  return s_highest_bit_mp(mp); /* count now comes from the shared helper */
 } /* end mp_count_bits() */
 
 /* }}} */
@@ -3132,6 +3101,27 @@
   abort();
 }
 
+int s_highest_bit_mp(mp_int *a) /* top digit's s_highest_bit + offset */
+{
+  int top = USED(a) - 1;        /* index of the top used digit */
+  return top * MP_DIGIT_BIT + s_highest_bit(DIGIT(a, top));
+}
+
+/* Set bit number `bit' (counting from zero) in a, first padding a via
+   s_mp_pad so the digit holding that bit exists.  Returns MP_OKAY, or
+   the error from s_mp_pad; a computed digit count of zero is a no-op. */
+mp_err s_mp_set_bit(mp_int *a, int bit)
+{
+  /* digits needed to hold the bit, and its offset in the top one */
+  int ndigits = (bit + MP_DIGIT_BIT) / MP_DIGIT_BIT;
+  int offset = bit - (ndigits - 1) * MP_DIGIT_BIT;
+  mp_err status = MP_OKAY;
+
+  if (ndigits != 0 && (status = s_mp_pad(a, ndigits)) == MP_OKAY)
+    DIGIT(a, ndigits - 1) |= ((mp_digit) 1 << offset);
+
+  return status;
+}
 
 /* {{{ s_mp_exch(a, b) */