0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021 #include <linux/bitops.h>
0022 #include <linux/count_zeros.h>
0023 #include <linux/byteorder/generic.h>
0024 #include <linux/scatterlist.h>
0025 #include <linux/string.h>
0026 #include "mpi-internal.h"
0027
0028 #define MAX_EXTERN_SCAN_BYTES (16*1024*1024)
0029 #define MAX_EXTERN_MPI_BITS 16384
0030
0031
0032
0033
0034
0035
0036 MPI mpi_read_raw_data(const void *xbuffer, size_t nbytes)
0037 {
0038 const uint8_t *buffer = xbuffer;
0039 int i, j;
0040 unsigned nbits, nlimbs;
0041 mpi_limb_t a;
0042 MPI val = NULL;
0043
0044 while (nbytes > 0 && buffer[0] == 0) {
0045 buffer++;
0046 nbytes--;
0047 }
0048
0049 nbits = nbytes * 8;
0050 if (nbits > MAX_EXTERN_MPI_BITS) {
0051 pr_info("MPI: mpi too large (%u bits)\n", nbits);
0052 return NULL;
0053 }
0054 if (nbytes > 0)
0055 nbits -= count_leading_zeros(buffer[0]) - (BITS_PER_LONG - 8);
0056
0057 nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB);
0058 val = mpi_alloc(nlimbs);
0059 if (!val)
0060 return NULL;
0061 val->nbits = nbits;
0062 val->sign = 0;
0063 val->nlimbs = nlimbs;
0064
0065 if (nbytes > 0) {
0066 i = BYTES_PER_MPI_LIMB - nbytes % BYTES_PER_MPI_LIMB;
0067 i %= BYTES_PER_MPI_LIMB;
0068 for (j = nlimbs; j > 0; j--) {
0069 a = 0;
0070 for (; i < BYTES_PER_MPI_LIMB; i++) {
0071 a <<= 8;
0072 a |= *buffer++;
0073 }
0074 i = 0;
0075 val->d[j - 1] = a;
0076 }
0077 }
0078 return val;
0079 }
0080 EXPORT_SYMBOL_GPL(mpi_read_raw_data);
0081
0082 MPI mpi_read_from_buffer(const void *xbuffer, unsigned *ret_nread)
0083 {
0084 const uint8_t *buffer = xbuffer;
0085 unsigned int nbits, nbytes;
0086 MPI val;
0087
0088 if (*ret_nread < 2)
0089 return ERR_PTR(-EINVAL);
0090 nbits = buffer[0] << 8 | buffer[1];
0091
0092 if (nbits > MAX_EXTERN_MPI_BITS) {
0093 pr_info("MPI: mpi too large (%u bits)\n", nbits);
0094 return ERR_PTR(-EINVAL);
0095 }
0096
0097 nbytes = DIV_ROUND_UP(nbits, 8);
0098 if (nbytes + 2 > *ret_nread) {
0099 pr_info("MPI: mpi larger than buffer nbytes=%u ret_nread=%u\n",
0100 nbytes, *ret_nread);
0101 return ERR_PTR(-EINVAL);
0102 }
0103
0104 val = mpi_read_raw_data(buffer + 2, nbytes);
0105 if (!val)
0106 return ERR_PTR(-ENOMEM);
0107
0108 *ret_nread = nbytes + 2;
0109 return val;
0110 }
0111 EXPORT_SYMBOL_GPL(mpi_read_from_buffer);
0112
0113
0114
0115
0116 int mpi_fromstr(MPI val, const char *str)
0117 {
0118 int sign = 0;
0119 int prepend_zero = 0;
0120 int i, j, c, c1, c2;
0121 unsigned int nbits, nbytes, nlimbs;
0122 mpi_limb_t a;
0123
0124 if (*str == '-') {
0125 sign = 1;
0126 str++;
0127 }
0128
0129
0130 if (*str == '0' && str[1] == 'x')
0131 str += 2;
0132
0133 nbits = strlen(str);
0134 if (nbits > MAX_EXTERN_SCAN_BYTES) {
0135 mpi_clear(val);
0136 return -EINVAL;
0137 }
0138 nbits *= 4;
0139 if ((nbits % 8))
0140 prepend_zero = 1;
0141
0142 nbytes = (nbits+7) / 8;
0143 nlimbs = (nbytes+BYTES_PER_MPI_LIMB-1) / BYTES_PER_MPI_LIMB;
0144
0145 if (val->alloced < nlimbs)
0146 mpi_resize(val, nlimbs);
0147
0148 i = BYTES_PER_MPI_LIMB - (nbytes % BYTES_PER_MPI_LIMB);
0149 i %= BYTES_PER_MPI_LIMB;
0150 j = val->nlimbs = nlimbs;
0151 val->sign = sign;
0152 for (; j > 0; j--) {
0153 a = 0;
0154 for (; i < BYTES_PER_MPI_LIMB; i++) {
0155 if (prepend_zero) {
0156 c1 = '0';
0157 prepend_zero = 0;
0158 } else
0159 c1 = *str++;
0160
0161 if (!c1) {
0162 mpi_clear(val);
0163 return -EINVAL;
0164 }
0165 c2 = *str++;
0166 if (!c2) {
0167 mpi_clear(val);
0168 return -EINVAL;
0169 }
0170 if (c1 >= '0' && c1 <= '9')
0171 c = c1 - '0';
0172 else if (c1 >= 'a' && c1 <= 'f')
0173 c = c1 - 'a' + 10;
0174 else if (c1 >= 'A' && c1 <= 'F')
0175 c = c1 - 'A' + 10;
0176 else {
0177 mpi_clear(val);
0178 return -EINVAL;
0179 }
0180 c <<= 4;
0181 if (c2 >= '0' && c2 <= '9')
0182 c |= c2 - '0';
0183 else if (c2 >= 'a' && c2 <= 'f')
0184 c |= c2 - 'a' + 10;
0185 else if (c2 >= 'A' && c2 <= 'F')
0186 c |= c2 - 'A' + 10;
0187 else {
0188 mpi_clear(val);
0189 return -EINVAL;
0190 }
0191 a <<= 8;
0192 a |= c;
0193 }
0194 i = 0;
0195 val->d[j-1] = a;
0196 }
0197
0198 return 0;
0199 }
0200 EXPORT_SYMBOL_GPL(mpi_fromstr);
0201
0202 MPI mpi_scanval(const char *string)
0203 {
0204 MPI a;
0205
0206 a = mpi_alloc(0);
0207 if (!a)
0208 return NULL;
0209
0210 if (mpi_fromstr(a, string)) {
0211 mpi_free(a);
0212 return NULL;
0213 }
0214 mpi_normalize(a);
0215 return a;
0216 }
0217 EXPORT_SYMBOL_GPL(mpi_scanval);
0218
0219 static int count_lzeros(MPI a)
0220 {
0221 mpi_limb_t alimb;
0222 int i, lzeros = 0;
0223
0224 for (i = a->nlimbs - 1; i >= 0; i--) {
0225 alimb = a->d[i];
0226 if (alimb == 0) {
0227 lzeros += sizeof(mpi_limb_t);
0228 } else {
0229 lzeros += count_leading_zeros(alimb) / 8;
0230 break;
0231 }
0232 }
0233 return lzeros;
0234 }
0235
0236
0237
0238
0239
0240
0241
0242
0243
0244
0245
0246
0247
0248
0249
0250 int mpi_read_buffer(MPI a, uint8_t *buf, unsigned buf_len, unsigned *nbytes,
0251 int *sign)
0252 {
0253 uint8_t *p;
0254 #if BYTES_PER_MPI_LIMB == 4
0255 __be32 alimb;
0256 #elif BYTES_PER_MPI_LIMB == 8
0257 __be64 alimb;
0258 #else
0259 #error please implement for this limb size.
0260 #endif
0261 unsigned int n = mpi_get_size(a);
0262 int i, lzeros;
0263
0264 if (!buf || !nbytes)
0265 return -EINVAL;
0266
0267 if (sign)
0268 *sign = a->sign;
0269
0270 lzeros = count_lzeros(a);
0271
0272 if (buf_len < n - lzeros) {
0273 *nbytes = n - lzeros;
0274 return -EOVERFLOW;
0275 }
0276
0277 p = buf;
0278 *nbytes = n - lzeros;
0279
0280 for (i = a->nlimbs - 1 - lzeros / BYTES_PER_MPI_LIMB,
0281 lzeros %= BYTES_PER_MPI_LIMB;
0282 i >= 0; i--) {
0283 #if BYTES_PER_MPI_LIMB == 4
0284 alimb = cpu_to_be32(a->d[i]);
0285 #elif BYTES_PER_MPI_LIMB == 8
0286 alimb = cpu_to_be64(a->d[i]);
0287 #else
0288 #error please implement for this limb size.
0289 #endif
0290 memcpy(p, (u8 *)&alimb + lzeros, BYTES_PER_MPI_LIMB - lzeros);
0291 p += BYTES_PER_MPI_LIMB - lzeros;
0292 lzeros = 0;
0293 }
0294 return 0;
0295 }
0296 EXPORT_SYMBOL_GPL(mpi_read_buffer);
0297
0298
0299
0300
0301
0302
0303
0304
0305
0306
0307
0308
0309
0310 void *mpi_get_buffer(MPI a, unsigned *nbytes, int *sign)
0311 {
0312 uint8_t *buf;
0313 unsigned int n;
0314 int ret;
0315
0316 if (!nbytes)
0317 return NULL;
0318
0319 n = mpi_get_size(a);
0320
0321 if (!n)
0322 n++;
0323
0324 buf = kmalloc(n, GFP_KERNEL);
0325
0326 if (!buf)
0327 return NULL;
0328
0329 ret = mpi_read_buffer(a, buf, n, nbytes, sign);
0330
0331 if (ret) {
0332 kfree(buf);
0333 return NULL;
0334 }
0335 return buf;
0336 }
0337 EXPORT_SYMBOL_GPL(mpi_get_buffer);
0338
0339
0340
0341
0342
0343
0344
0345
0346
0347
0348
0349
0350
0351
0352
0353
/**
 * mpi_write_to_sgl() - Write @a into the scatterlist @sgl as a big-endian
 *			byte string of exactly @nbytes bytes.
 * @a:		the MPI to export
 * @sgl:	destination scatterlist
 * @nbytes:	number of bytes to write; the value is left-padded with zero
 *		bytes up to this length
 * @sign:	if non-NULL, receives a->sign
 *
 * Every limb of @a is copied, including all-zero high limbs. So @nbytes
 * must be at least mpi_get_size(@a), which is presumably
 * nlimbs * BYTES_PER_MPI_LIMB (TODO confirm).
 *
 * Return: 0 on success. -EOVERFLOW if @nbytes < mpi_get_size(@a). -EINVAL
 * if sg_nents_for_len() reports that @sgl cannot hold @nbytes bytes.
 */
int mpi_write_to_sgl(MPI a, struct scatterlist *sgl, unsigned nbytes,
		     int *sign)
{
	u8 *p, *p2;
#if BYTES_PER_MPI_LIMB == 4
	__be32 alimb;
#elif BYTES_PER_MPI_LIMB == 8
	__be64 alimb;
#else
#error please implement for this limb size.
#endif
	unsigned int n = mpi_get_size(a);
	struct sg_mapping_iter miter;
	int i, x, buf_len;
	int nents;

	if (sign)
		*sign = a->sign;

	if (nbytes < n)
		return -EOVERFLOW;

	nents = sg_nents_for_len(sgl, nbytes);
	if (nents < 0)
		return -EINVAL;

	/* Map the first entry; p2/buf_len track the current write window. */
	sg_miter_start(&miter, sgl, nents, SG_MITER_ATOMIC | SG_MITER_TO_SG);
	sg_miter_next(&miter);
	buf_len = miter.length;
	p2 = miter.addr;

	/* Zero-fill the leading (nbytes - n) bytes, possibly across entries. */
	while (nbytes > n) {
		i = min_t(unsigned, nbytes - n, buf_len);
		memset(p2, 0, i);
		p2 += i;
		nbytes -= i;

		buf_len -= i;
		/* Current entry exhausted: map the next one. */
		if (!buf_len) {
			sg_miter_next(&miter);
			buf_len = miter.length;
			p2 = miter.addr;
		}
	}

	/*
	 * Copy the limbs from most to least significant, each converted to
	 * big-endian, one byte at a time so a limb may straddle two entries.
	 */
	for (i = a->nlimbs - 1; i >= 0; i--) {
#if BYTES_PER_MPI_LIMB == 4
		alimb = a->d[i] ? cpu_to_be32(a->d[i]) : 0;
#elif BYTES_PER_MPI_LIMB == 8
		alimb = a->d[i] ? cpu_to_be64(a->d[i]) : 0;
#else
#error please implement for this limb size.
#endif
		p = (u8 *)&alimb;

		for (x = 0; x < sizeof(alimb); x++) {
			*p2++ = *p++;
			if (!--buf_len) {
				sg_miter_next(&miter);
				buf_len = miter.length;
				p2 = miter.addr;
			}
		}
	}

	sg_miter_stop(&miter);
	return 0;
}
EXPORT_SYMBOL_GPL(mpi_write_to_sgl);
0423
0424
0425
0426
0427
0428
0429
0430
0431
0432
0433
0434
0435
0436
0437 MPI mpi_read_raw_from_sgl(struct scatterlist *sgl, unsigned int nbytes)
0438 {
0439 struct sg_mapping_iter miter;
0440 unsigned int nbits, nlimbs;
0441 int x, j, z, lzeros, ents;
0442 unsigned int len;
0443 const u8 *buff;
0444 mpi_limb_t a;
0445 MPI val = NULL;
0446
0447 ents = sg_nents_for_len(sgl, nbytes);
0448 if (ents < 0)
0449 return NULL;
0450
0451 sg_miter_start(&miter, sgl, ents, SG_MITER_ATOMIC | SG_MITER_FROM_SG);
0452
0453 lzeros = 0;
0454 len = 0;
0455 while (nbytes > 0) {
0456 while (len && !*buff) {
0457 lzeros++;
0458 len--;
0459 buff++;
0460 }
0461
0462 if (len && *buff)
0463 break;
0464
0465 sg_miter_next(&miter);
0466 buff = miter.addr;
0467 len = miter.length;
0468
0469 nbytes -= lzeros;
0470 lzeros = 0;
0471 }
0472
0473 miter.consumed = lzeros;
0474
0475 nbytes -= lzeros;
0476 nbits = nbytes * 8;
0477 if (nbits > MAX_EXTERN_MPI_BITS) {
0478 sg_miter_stop(&miter);
0479 pr_info("MPI: mpi too large (%u bits)\n", nbits);
0480 return NULL;
0481 }
0482
0483 if (nbytes > 0)
0484 nbits -= count_leading_zeros(*buff) - (BITS_PER_LONG - 8);
0485
0486 sg_miter_stop(&miter);
0487
0488 nlimbs = DIV_ROUND_UP(nbytes, BYTES_PER_MPI_LIMB);
0489 val = mpi_alloc(nlimbs);
0490 if (!val)
0491 return NULL;
0492
0493 val->nbits = nbits;
0494 val->sign = 0;
0495 val->nlimbs = nlimbs;
0496
0497 if (nbytes == 0)
0498 return val;
0499
0500 j = nlimbs - 1;
0501 a = 0;
0502 z = BYTES_PER_MPI_LIMB - nbytes % BYTES_PER_MPI_LIMB;
0503 z %= BYTES_PER_MPI_LIMB;
0504
0505 while (sg_miter_next(&miter)) {
0506 buff = miter.addr;
0507 len = miter.length;
0508
0509 for (x = 0; x < len; x++) {
0510 a <<= 8;
0511 a |= *buff++;
0512 if (((z + x + 1) % BYTES_PER_MPI_LIMB) == 0) {
0513 val->d[j--] = a;
0514 a = 0;
0515 }
0516 }
0517 z += x;
0518 }
0519
0520 return val;
0521 }
0522 EXPORT_SYMBOL_GPL(mpi_read_raw_from_sgl);
0523
0524
/*
 * twocompl - Negate, in place, a big-endian magnitude of @n bytes so that
 * @p holds its two's-complement representation.
 *
 * Trailing (least significant) zero bytes stay zero. The lowest non-zero
 * byte is replaced by its 8-bit negation, and every more significant byte
 * is bit-inverted. An all-zero buffer is left untouched.
 */
static void twocompl(unsigned char *p, unsigned int n)
{
	int pos = (int)n - 1;

	while (pos >= 0 && p[pos] == 0)
		pos--;
	if (pos < 0)
		return;

	p[pos] = (unsigned char)(0x100 - p[pos]);
	while (--pos >= 0)
		p[pos] = (unsigned char)~p[pos];
}
0553
/**
 * mpi_print - Serialize @a into @buffer using the external @format.
 * @format: one of GCRYMPI_FMT_STD, _USG, _PGP, _SSH or _HEX
 * @buffer: destination; if NULL, only the required length is computed
 * @buflen: size of @buffer in bytes
 * @nwritten: if non-NULL, receives the number of bytes written, or needed
 *            when @buffer is NULL
 * @a: the MPI to print
 *
 * Formats, as implemented below:
 *  STD - big-endian two's complement. A 0x00 (positive) or 0xff (negative)
 *        byte is prepended when needed to keep the sign bit correct.
 *  USG - unsigned big-endian magnitude; the sign is ignored.
 *  PGP - 2-byte big-endian bit count followed by the magnitude. Negative
 *        values are rejected.
 *  SSH - 4-byte big-endian length followed by STD-style bytes.
 *  HEX - NUL-terminated upper-case hex string of the magnitude, with a
 *        leading '-' for negative values. "00" is prepended when the value
 *        is empty or its top bit is set. The NUL is counted in @nwritten.
 *
 * Return: 0 on success. -E2BIG if @buffer is too small. -EINVAL for an
 * unknown format, a negative value in PGP format, or a failed
 * mpi_get_buffer().
 */
int mpi_print(enum gcry_mpi_format format, unsigned char *buffer,
	      size_t buflen, size_t *nwritten, MPI a)
{
	unsigned int nbits = mpi_get_nbits(a);
	size_t len;
	size_t dummy_nwritten;
	int negative;

	if (!nwritten)
		nwritten = &dummy_nwritten;

	/*
	 * Treat the value as negative only if the sign flag is set AND the
	 * value does not compare equal to 0. This assumes mpi_cmp_ui()
	 * returns 0 on equality, so a "negative zero" prints as zero.
	 */
	if (a->sign && mpi_cmp_ui(a, 0))
		negative = 1;
	else
		negative = 0;

	len = buflen;
	*nwritten = 0;
	if (format == GCRYMPI_FMT_STD) {
		unsigned char *tmp;
		int extra = 0;
		unsigned int n;

		tmp = mpi_get_buffer(a, &n, NULL);
		if (!tmp)
			return -EINVAL;

		if (negative) {
			twocompl(tmp, n);
			if (!(*tmp & 0x80)) {
				/* Need an extra 0xff byte to keep the sign bit set. */
				n++;
				extra = 2;
			}
		} else if (n && (*tmp & 0x80)) {
			/*
			 * Positive value with the high bit set: prepend a 0x00
			 * byte so it is not read back as negative.
			 */
			n++;
			extra = 1;
		}

		if (buffer && n > len) {
			/* The caller's buffer is too short. */
			kfree(tmp);
			return -E2BIG;
		}
		if (buffer) {
			unsigned char *s = buffer;

			if (extra == 1)
				*s++ = 0;
			else if (extra)
				*s++ = 0xff;
			/* n already counts the extra byte, if any. */
			memcpy(s, tmp, n-!!extra);
		}
		kfree(tmp);
		*nwritten = n;
		return 0;
	} else if (format == GCRYMPI_FMT_USG) {
		unsigned int n = (nbits + 7)/8;

		/*
		 * The sign is ignored for this format. n is overwritten by
		 * mpi_get_buffer() below; the size check uses the bit-count
		 * estimate, which presumably matches (TODO confirm).
		 */
		if (buffer && n > len)
			return -E2BIG;
		if (buffer) {
			unsigned char *tmp;

			tmp = mpi_get_buffer(a, &n, NULL);
			if (!tmp)
				return -EINVAL;
			memcpy(buffer, tmp, n);
			kfree(tmp);
		}
		*nwritten = n;
		return 0;
	} else if (format == GCRYMPI_FMT_PGP) {
		unsigned int n = (nbits + 7)/8;

		/* This format can only represent non-negative values. */
		if (negative)
			return -EINVAL;

		/* 2-byte bit-count header plus the magnitude. */
		if (buffer && n+2 > len)
			return -E2BIG;

		if (buffer) {
			unsigned char *tmp;
			unsigned char *s = buffer;

			s[0] = nbits >> 8;
			s[1] = nbits;

			tmp = mpi_get_buffer(a, &n, NULL);
			if (!tmp)
				return -EINVAL;
			memcpy(s+2, tmp, n);
			kfree(tmp);
		}
		*nwritten = n+2;
		return 0;
	} else if (format == GCRYMPI_FMT_SSH) {
		unsigned char *tmp;
		int extra = 0;
		unsigned int n;

		tmp = mpi_get_buffer(a, &n, NULL);
		if (!tmp)
			return -EINVAL;

		/* Same sign-byte handling as GCRYMPI_FMT_STD. */
		if (negative) {
			twocompl(tmp, n);
			if (!(*tmp & 0x80)) {
				/* Need an extra 0xff byte to keep the sign bit set. */
				n++;
				extra = 2;
			}
		} else if (n && (*tmp & 0x80)) {
			n++;
			extra = 1;
		}

		/* 4-byte big-endian length prefix plus n payload bytes. */
		if (buffer && n+4 > len) {
			kfree(tmp);
			return -E2BIG;
		}

		if (buffer) {
			unsigned char *s = buffer;

			*s++ = n >> 24;
			*s++ = n >> 16;
			*s++ = n >> 8;
			*s++ = n;
			if (extra == 1)
				*s++ = 0;
			else if (extra)
				*s++ = 0xff;
			memcpy(s, tmp, n-!!extra);
		}
		kfree(tmp);
		*nwritten = 4+n;
		return 0;
	} else if (format == GCRYMPI_FMT_HEX) {
		unsigned char *tmp;
		int i;
		int extra = 0;
		unsigned int n = 0;

		tmp = mpi_get_buffer(a, &n, NULL);
		if (!tmp)
			return -EINVAL;
		/* Emit "00" for an empty value or when the top bit is set. */
		if (!n || (*tmp & 0x80))
			extra = 2;

		/* Two digits per byte, prefix, optional '-', trailing NUL. */
		if (buffer && 2*n + extra + negative + 1 > len) {
			kfree(tmp);
			return -E2BIG;
		}
		if (buffer) {
			unsigned char *s = buffer;

			if (negative)
				*s++ = '-';
			if (extra) {
				*s++ = '0';
				*s++ = '0';
			}

			/* Upper-case hex of the magnitude (not two's complement). */
			for (i = 0; i < n; i++) {
				unsigned int c = tmp[i];

				*s++ = (c >> 4) < 10 ? '0'+(c>>4) : 'A'+(c>>4)-10;
				c &= 15;
				*s++ = c < 10 ? '0'+c : 'A'+c-10;
			}
			*s++ = 0;
			*nwritten = s - buffer;
		} else {
			*nwritten = 2*n + extra + negative + 1;
		}
		kfree(tmp);
		return 0;
	} else
		return -EINVAL;
}
EXPORT_SYMBOL_GPL(mpi_print);