Back to home page

LXR

 
 

    


0001 /*
0002  * algif_skcipher: User-space interface for skcipher algorithms
0003  *
0004  * This file provides the user-space API for symmetric key ciphers.
0005  *
0006  * Copyright (c) 2010 Herbert Xu <herbert@gondor.apana.org.au>
0007  *
0008  * This program is free software; you can redistribute it and/or modify it
0009  * under the terms of the GNU General Public License as published by the Free
0010  * Software Foundation; either version 2 of the License, or (at your option)
0011  * any later version.
0012  *
0013  */
0014 
0015 #include <crypto/scatterwalk.h>
0016 #include <crypto/skcipher.h>
0017 #include <crypto/if_alg.h>
0018 #include <linux/init.h>
0019 #include <linux/list.h>
0020 #include <linux/kernel.h>
0021 #include <linux/mm.h>
0022 #include <linux/module.h>
0023 #include <linux/net.h>
0024 #include <net/sock.h>
0025 
0026 struct skcipher_sg_list {
0027     struct list_head list;
0028 
0029     int cur;
0030 
0031     struct scatterlist sg[0];
0032 };
0033 
/* Parent (bound) socket private data: the cipher handle plus whether a
 * key has been successfully set on it (see skcipher_setkey()). */
struct skcipher_tfm {
    struct crypto_skcipher *skcipher;
    bool has_key;   /* true once crypto_skcipher_setkey() succeeded */
};
0038 
/* Per-accept()ed-socket state for one cipher operation stream. */
struct skcipher_ctx {
    struct list_head tsgl;  /* queued tx data: list of skcipher_sg_list */
    struct af_alg_sgl rsgl; /* rx scatterlist used by the sync recvmsg path */

    void *iv;               /* IV buffer, crypto_skcipher_ivsize() bytes */

    struct af_alg_completion completion;    /* sync request completion */

    atomic_t inflight;      /* async requests submitted but not completed */
    size_t used;            /* bytes currently queued on tsgl */

    unsigned int len;       /* total allocation size of this ctx */
    bool more;              /* MSG_MORE seen: stream not finalized yet */
    bool merge;             /* next sendmsg may append into the tail page */
    bool enc;               /* true = encrypt, false = decrypt */

    /* must remain last: the driver's request context is allocated
     * directly behind this struct (see skcipher_accept_parent_nokey) */
    struct skcipher_request req;
};
0057 
/* One rx scatterlist segment of an async request, linked on
 * skcipher_async_req::list. */
struct skcipher_async_rsgl {
    struct af_alg_sgl sgl;
    struct list_head list;
};
0062 
/* State of one in-flight asynchronous cipher request. */
struct skcipher_async_req {
    struct kiocb *iocb;         /* completed via ki_complete() in the cb */
    struct skcipher_async_rsgl first_sgl;   /* embedded first rx segment */
    struct list_head list;      /* all rx segments, including first_sgl */
    struct scatterlist *tsg;    /* tx sg array taken over from ctx->tsgl */
    atomic_t *inflight;         /* points at ctx->inflight */
    /* must remain last: reqsize bytes of driver context plus the IV
     * copy are allocated directly behind this struct */
    struct skcipher_request req;
};
0071 
/*
 * Data-bearing scatterlist entries per skcipher_sg_list, sized so one
 * list header plus its entries fits a 4096-byte allocation; the "- 1"
 * reserves one slot used to chain to the next list's sg array.
 */
#define MAX_SGL_ENTS ((4096 - sizeof(struct skcipher_sg_list)) / \
              sizeof(struct scatterlist) - 1)
0074 
/*
 * Release everything held by an async request: each rx segment on
 * sreq->list (the embedded first_sgl is not kfree'd), then the page
 * references of the tx scatterlist, then the tx array itself.
 */
static void skcipher_free_async_sgls(struct skcipher_async_req *sreq)
{
    struct skcipher_async_rsgl *rsgl, *tmp;
    struct scatterlist *sgl;
    struct scatterlist *sg;
    int i, n;

    list_for_each_entry_safe(rsgl, tmp, &sreq->list, list) {
        af_alg_free_sg(&rsgl->sgl);
        if (rsgl != &sreq->first_sgl)
            kfree(rsgl);
    }
    sgl = sreq->tsg;
    n = sg_nents(sgl);
    /* drop the references taken over from ctx->tsgl in recvmsg_async */
    for_each_sg(sgl, sg, n, i)
        put_page(sg_page(sg));

    kfree(sreq->tsg);
}
0094 
/*
 * Completion callback for asynchronous cipher requests: drop the
 * inflight count (skcipher_wait() polls it during socket teardown),
 * free all request resources and complete the aio iocb.
 */
static void skcipher_async_cb(struct crypto_async_request *req, int err)
{
    struct skcipher_async_req *sreq = req->data;
    struct kiocb *iocb = sreq->iocb;

    atomic_dec(sreq->inflight);
    skcipher_free_async_sgls(sreq);
    kzfree(sreq);   /* sreq also holds the private IV copy - zeroize */
    iocb->ki_complete(iocb, err, err);
}
0105 
0106 static inline int skcipher_sndbuf(struct sock *sk)
0107 {
0108     struct alg_sock *ask = alg_sk(sk);
0109     struct skcipher_ctx *ctx = ask->private;
0110 
0111     return max_t(int, max_t(int, sk->sk_sndbuf & PAGE_MASK, PAGE_SIZE) -
0112               ctx->used, 0);
0113 }
0114 
0115 static inline bool skcipher_writable(struct sock *sk)
0116 {
0117     return PAGE_SIZE <= skcipher_sndbuf(sk);
0118 }
0119 
0120 static int skcipher_alloc_sgl(struct sock *sk)
0121 {
0122     struct alg_sock *ask = alg_sk(sk);
0123     struct skcipher_ctx *ctx = ask->private;
0124     struct skcipher_sg_list *sgl;
0125     struct scatterlist *sg = NULL;
0126 
0127     sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
0128     if (!list_empty(&ctx->tsgl))
0129         sg = sgl->sg;
0130 
0131     if (!sg || sgl->cur >= MAX_SGL_ENTS) {
0132         sgl = sock_kmalloc(sk, sizeof(*sgl) +
0133                        sizeof(sgl->sg[0]) * (MAX_SGL_ENTS + 1),
0134                    GFP_KERNEL);
0135         if (!sgl)
0136             return -ENOMEM;
0137 
0138         sg_init_table(sgl->sg, MAX_SGL_ENTS + 1);
0139         sgl->cur = 0;
0140 
0141         if (sg)
0142             sg_chain(sg, MAX_SGL_ENTS + 1, sgl->sg);
0143 
0144         list_add_tail(&sgl->list, &ctx->tsgl);
0145     }
0146 
0147     return 0;
0148 }
0149 
/*
 * Consume @used bytes from the front of the tx queue.  Entries are
 * advanced in place (offset/length adjusted); once an entry is fully
 * drained its page reference is dropped iff @put, and fully drained
 * lists are freed.  Returns as soon as an entry still has bytes left.
 */
static void skcipher_pull_sgl(struct sock *sk, size_t used, int put)
{
    struct alg_sock *ask = alg_sk(sk);
    struct skcipher_ctx *ctx = ask->private;
    struct skcipher_sg_list *sgl;
    struct scatterlist *sg;
    int i;

    while (!list_empty(&ctx->tsgl)) {
        sgl = list_first_entry(&ctx->tsgl, struct skcipher_sg_list,
                       list);
        sg = sgl->sg;

        for (i = 0; i < sgl->cur; i++) {
            size_t plen = min_t(size_t, used, sg[i].length);

            /* entry already released on an earlier call */
            if (!sg_page(sg + i))
                continue;

            sg[i].length -= plen;
            sg[i].offset += plen;

            used -= plen;
            ctx->used -= plen;

            /* partially consumed entry: stop here */
            if (sg[i].length)
                return;
            if (put)
                put_page(sg_page(sg + i));
            sg_assign_page(sg + i, NULL);
        }

        /* every entry of this list is drained - release the list */
        list_del(&sgl->list);
        sock_kfree_s(sk, sgl,
                 sizeof(*sgl) + sizeof(sgl->sg[0]) *
                        (MAX_SGL_ENTS + 1));
    }

    if (!ctx->used)
        ctx->merge = 0;
}
0191 
0192 static void skcipher_free_sgl(struct sock *sk)
0193 {
0194     struct alg_sock *ask = alg_sk(sk);
0195     struct skcipher_ctx *ctx = ask->private;
0196 
0197     skcipher_pull_sgl(sk, ctx->used, 1);
0198 }
0199 
/*
 * Block until the socket becomes writable (a page of sndbuf free) or a
 * signal arrives.  Returns 0 on success, -EAGAIN for MSG_DONTWAIT,
 * -ERESTARTSYS when interrupted.
 * NOTE(review): SOCKWQ_ASYNC_NOSPACE is set but never cleared here,
 * unlike the WAITDATA bit in skcipher_wait_for_data() - confirm this
 * asymmetry is intended.
 */
static int skcipher_wait_for_wmem(struct sock *sk, unsigned flags)
{
    DEFINE_WAIT_FUNC(wait, woken_wake_function);
    int err = -ERESTARTSYS;
    long timeout;

    if (flags & MSG_DONTWAIT)
        return -EAGAIN;

    sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk);

    add_wait_queue(sk_sleep(sk), &wait);
    for (;;) {
        if (signal_pending(current))
            break;
        timeout = MAX_SCHEDULE_TIMEOUT;
        if (sk_wait_event(sk, &timeout, skcipher_writable(sk), &wait)) {
            err = 0;
            break;
        }
    }
    remove_wait_queue(sk_sleep(sk), &wait);

    return err;
}
0225 
/*
 * Notify waiters after tx data has been consumed and send-buffer space
 * is available again.
 * NOTE(review): the poll key is POLLIN|POLLRDNORM|POLLRDBAND although
 * this reports writability - verify against the other af_alg front
 * ends before changing.
 */
static void skcipher_wmem_wakeup(struct sock *sk)
{
    struct socket_wq *wq;

    if (!skcipher_writable(sk))
        return;

    rcu_read_lock();
    wq = rcu_dereference(sk->sk_wq);
    if (skwq_has_sleeper(wq))
        wake_up_interruptible_sync_poll(&wq->wait, POLLIN |
                               POLLRDNORM |
                               POLLRDBAND);
    sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
    rcu_read_unlock();
}
0242 
/*
 * Block until the sender has queued data (ctx->used != 0) or a signal
 * arrives.  Returns 0 on success, -EAGAIN for MSG_DONTWAIT,
 * -ERESTARTSYS when interrupted.
 */
static int skcipher_wait_for_data(struct sock *sk, unsigned flags)
{
    DEFINE_WAIT_FUNC(wait, woken_wake_function);
    struct alg_sock *ask = alg_sk(sk);
    struct skcipher_ctx *ctx = ask->private;
    long timeout;
    int err = -ERESTARTSYS;

    if (flags & MSG_DONTWAIT) {
        return -EAGAIN;
    }

    sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);

    add_wait_queue(sk_sleep(sk), &wait);
    for (;;) {
        if (signal_pending(current))
            break;
        timeout = MAX_SCHEDULE_TIMEOUT;
        if (sk_wait_event(sk, &timeout, ctx->used, &wait)) {
            err = 0;
            break;
        }
    }
    remove_wait_queue(sk_sleep(sk), &wait);

    sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);

    return err;
}
0273 
0274 static void skcipher_data_wakeup(struct sock *sk)
0275 {
0276     struct alg_sock *ask = alg_sk(sk);
0277     struct skcipher_ctx *ctx = ask->private;
0278     struct socket_wq *wq;
0279 
0280     if (!ctx->used)
0281         return;
0282 
0283     rcu_read_lock();
0284     wq = rcu_dereference(sk->sk_wq);
0285     if (skwq_has_sleeper(wq))
0286         wake_up_interruptible_sync_poll(&wq->wait, POLLOUT |
0287                                POLLRDNORM |
0288                                POLLRDBAND);
0289     sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT);
0290     rcu_read_unlock();
0291 }
0292 
/*
 * sendmsg() on an skcipher socket: queue plaintext/ciphertext into
 * ctx->tsgl for a later recvmsg() to process.  An optional cmsg selects
 * encrypt/decrypt and supplies the IV; data is copied into freshly
 * allocated pages, merging into the last partially-filled page when
 * possible.  Returns bytes copied, or a negative errno if nothing was
 * copied.
 */
static int skcipher_sendmsg(struct socket *sock, struct msghdr *msg,
                size_t size)
{
    struct sock *sk = sock->sk;
    struct alg_sock *ask = alg_sk(sk);
    struct sock *psk = ask->parent;
    struct alg_sock *pask = alg_sk(psk);
    struct skcipher_ctx *ctx = ask->private;
    struct skcipher_tfm *skc = pask->private;
    struct crypto_skcipher *tfm = skc->skcipher;
    unsigned ivsize = crypto_skcipher_ivsize(tfm);
    struct skcipher_sg_list *sgl;
    struct af_alg_control con = {};
    long copied = 0;
    bool enc = 0;
    bool init = 0;
    int err;
    int i;

    if (msg->msg_controllen) {
        err = af_alg_cmsg_send(msg, &con);
        if (err)
            return err;

        init = 1;
        switch (con.op) {
        case ALG_OP_ENCRYPT:
            enc = 1;
            break;
        case ALG_OP_DECRYPT:
            enc = 0;
            break;
        default:
            return -EINVAL;
        }

        /* a supplied IV must match the cipher's IV size exactly */
        if (con.iv && con.iv->ivlen != ivsize)
            return -EINVAL;
    }

    err = -EINVAL;

    lock_sock(sk);
    /* previous operation was finalized but still has unread data */
    if (!ctx->more && ctx->used)
        goto unlock;

    if (init) {
        ctx->enc = enc;
        if (con.iv)
            memcpy(ctx->iv, con.iv->iv, ivsize);
    }

    while (size) {
        struct scatterlist *sg;
        unsigned long len = size;
        size_t plen;

        if (ctx->merge) {
            /* append into the partially filled tail page */
            sgl = list_entry(ctx->tsgl.prev,
                     struct skcipher_sg_list, list);
            sg = sgl->sg + sgl->cur - 1;
            len = min_t(unsigned long, len,
                    PAGE_SIZE - sg->offset - sg->length);

            err = memcpy_from_msg(page_address(sg_page(sg)) +
                          sg->offset + sg->length,
                          msg, len);
            if (err)
                goto unlock;

            sg->length += len;
            /* keep merging until the page boundary is reached */
            ctx->merge = (sg->offset + sg->length) &
                     (PAGE_SIZE - 1);

            ctx->used += len;
            copied += len;
            size -= len;
            continue;
        }

        if (!skcipher_writable(sk)) {
            err = skcipher_wait_for_wmem(sk, msg->msg_flags);
            if (err)
                goto unlock;
        }

        /* never queue more than the free send-buffer space */
        len = min_t(unsigned long, len, skcipher_sndbuf(sk));

        err = skcipher_alloc_sgl(sk);
        if (err)
            goto unlock;

        sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
        sg = sgl->sg;
        if (sgl->cur)
            sg_unmark_end(sg + sgl->cur - 1);
        do {
            i = sgl->cur;
            plen = min_t(size_t, len, PAGE_SIZE);

            sg_assign_page(sg + i, alloc_page(GFP_KERNEL));
            err = -ENOMEM;
            if (!sg_page(sg + i))
                goto unlock;

            err = memcpy_from_msg(page_address(sg_page(sg + i)),
                          msg, plen);
            if (err) {
                __free_page(sg_page(sg + i));
                sg_assign_page(sg + i, NULL);
                goto unlock;
            }

            sg[i].length = plen;
            len -= plen;
            ctx->used += plen;
            copied += plen;
            size -= plen;
            sgl->cur++;
        } while (len && sgl->cur < MAX_SGL_ENTS);

        if (!size)
            sg_mark_end(sg + sgl->cur - 1);

        /* a partially filled last page can absorb the next send */
        ctx->merge = plen & (PAGE_SIZE - 1);
    }

    err = 0;

    ctx->more = msg->msg_flags & MSG_MORE;

unlock:
    skcipher_data_wakeup(sk);
    release_sock(sk);

    return copied ?: err;
}
0430 
/*
 * sendpage() on an skcipher socket: queue an entire caller-supplied
 * page (zero-copy, with an extra page reference) onto ctx->tsgl.
 * Returns @size on success or a negative errno.
 */
static ssize_t skcipher_sendpage(struct socket *sock, struct page *page,
                 int offset, size_t size, int flags)
{
    struct sock *sk = sock->sk;
    struct alg_sock *ask = alg_sk(sk);
    struct skcipher_ctx *ctx = ask->private;
    struct skcipher_sg_list *sgl;
    int err = -EINVAL;

    if (flags & MSG_SENDPAGE_NOTLAST)
        flags |= MSG_MORE;

    lock_sock(sk);
    /* previous operation was finalized but still has unread data */
    if (!ctx->more && ctx->used)
        goto unlock;

    if (!size)
        goto done;

    if (!skcipher_writable(sk)) {
        err = skcipher_wait_for_wmem(sk, flags);
        if (err)
            goto unlock;
    }

    err = skcipher_alloc_sgl(sk);
    if (err)
        goto unlock;

    /* never let a later sendmsg() merge into a caller-owned page */
    ctx->merge = 0;
    sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);

    if (sgl->cur)
        sg_unmark_end(sgl->sg + sgl->cur - 1);

    sg_mark_end(sgl->sg + sgl->cur);
    get_page(page);
    sg_set_page(sgl->sg + sgl->cur, page, size, offset);
    sgl->cur++;
    ctx->used += size;

done:
    ctx->more = flags & MSG_MORE;

unlock:
    skcipher_data_wakeup(sk);
    release_sock(sk);

    return err ?: size;
}
0481 
0482 static int skcipher_all_sg_nents(struct skcipher_ctx *ctx)
0483 {
0484     struct skcipher_sg_list *sgl;
0485     struct scatterlist *sg;
0486     int nents = 0;
0487 
0488     list_for_each_entry(sgl, &ctx->tsgl, list) {
0489         sg = sgl->sg;
0490 
0491         while (!sg->length)
0492             sg++;
0493 
0494         nents += sg_nents(sg);
0495     }
0496     return nents;
0497 }
0498 
/*
 * Asynchronous recvmsg(): build a self-contained skcipher_async_req
 * that takes over the queued tx pages and pins the caller's rx pages,
 * then submit it.  If the driver returns -EINPROGRESS the request is
 * owned by skcipher_async_cb() and -EIOCBQUEUED is returned to the aio
 * layer; any other outcome is cleaned up here.
 */
static int skcipher_recvmsg_async(struct socket *sock, struct msghdr *msg,
                  int flags)
{
    struct sock *sk = sock->sk;
    struct alg_sock *ask = alg_sk(sk);
    struct sock *psk = ask->parent;
    struct alg_sock *pask = alg_sk(psk);
    struct skcipher_ctx *ctx = ask->private;
    struct skcipher_tfm *skc = pask->private;
    struct crypto_skcipher *tfm = skc->skcipher;
    struct skcipher_sg_list *sgl;
    struct scatterlist *sg;
    struct skcipher_async_req *sreq;
    struct skcipher_request *req;
    struct skcipher_async_rsgl *last_rsgl = NULL;
    unsigned int txbufs = 0, len = 0, tx_nents;
    unsigned int reqsize = crypto_skcipher_reqsize(tfm);
    unsigned int ivsize = crypto_skcipher_ivsize(tfm);
    int err = -ENOMEM;
    bool mark = false;
    char *iv;

    /* one allocation: sreq + driver request context + private IV copy */
    sreq = kzalloc(sizeof(*sreq) + reqsize + ivsize, GFP_KERNEL);
    if (unlikely(!sreq))
        goto out;

    req = &sreq->req;
    iv = (char *)(req + 1) + reqsize;
    sreq->iocb = msg->msg_iocb;
    INIT_LIST_HEAD(&sreq->list);
    sreq->inflight = &ctx->inflight;

    lock_sock(sk);
    tx_nents = skcipher_all_sg_nents(ctx);
    sreq->tsg = kcalloc(tx_nents, sizeof(*sg), GFP_KERNEL);
    if (unlikely(!sreq->tsg))
        goto unlock;
    sg_init_table(sreq->tsg, tx_nents);
    /* private IV copy, so the socket IV stays usable for new requests */
    memcpy(iv, ctx->iv, ivsize);
    skcipher_request_set_tfm(req, tfm);
    skcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP,
                      skcipher_async_cb, sreq);

    while (iov_iter_count(&msg->msg_iter)) {
        struct skcipher_async_rsgl *rsgl;
        int used;

        if (!ctx->used) {
            err = skcipher_wait_for_data(sk, flags);
            if (err)
                goto free;
        }
        sgl = list_first_entry(&ctx->tsgl,
                       struct skcipher_sg_list, list);
        sg = sgl->sg;

        /* skip entries already drained by skcipher_pull_sgl() */
        while (!sg->length)
            sg++;

        used = min_t(unsigned long, ctx->used,
                 iov_iter_count(&msg->msg_iter));
        used = min_t(unsigned long, used, sg->length);

        if (txbufs == tx_nents) {
            struct scatterlist *tmp;
            int x;
            /* Ran out of tx slots in async request
             * need to expand */
            tmp = kcalloc(tx_nents * 2, sizeof(*tmp),
                      GFP_KERNEL);
            if (!tmp) {
                err = -ENOMEM;
                goto free;
            }

            sg_init_table(tmp, tx_nents * 2);
            for (x = 0; x < tx_nents; x++)
                sg_set_page(&tmp[x], sg_page(&sreq->tsg[x]),
                        sreq->tsg[x].length,
                        sreq->tsg[x].offset);
            kfree(sreq->tsg);
            sreq->tsg = tmp;
            tx_nents *= 2;
            mark = true;
        }
        /* Need to take over the tx sgl from ctx
         * to the asynch req - these sgls will be freed later */
        sg_set_page(sreq->tsg + txbufs++, sg_page(sg), sg->length,
                sg->offset);

        if (list_empty(&sreq->list)) {
            rsgl = &sreq->first_sgl;
            list_add_tail(&rsgl->list, &sreq->list);
        } else {
            rsgl = kmalloc(sizeof(*rsgl), GFP_KERNEL);
            if (!rsgl) {
                err = -ENOMEM;
                goto free;
            }
            list_add_tail(&rsgl->list, &sreq->list);
        }

        used = af_alg_make_sg(&rsgl->sgl, &msg->msg_iter, used);
        err = used;
        if (used < 0)
            goto free;
        if (last_rsgl)
            af_alg_link_sg(&last_rsgl->sgl, &rsgl->sgl);

        last_rsgl = rsgl;
        len += used;
        /* drop ctx bookkeeping only; pages stay referenced via tsg */
        skcipher_pull_sgl(sk, used, 0);
        iov_iter_advance(&msg->msg_iter, used);
    }

    if (mark)
        sg_mark_end(sreq->tsg + txbufs - 1);

    skcipher_request_set_crypt(req, sreq->tsg, sreq->first_sgl.sgl.sg,
                   len, iv);
    err = ctx->enc ? crypto_skcipher_encrypt(req) :
             crypto_skcipher_decrypt(req);
    if (err == -EINPROGRESS) {
        atomic_inc(&ctx->inflight);
        err = -EIOCBQUEUED;
        sreq = NULL;    /* now owned by skcipher_async_cb() */
        goto unlock;
    }
free:
    skcipher_free_async_sgls(sreq);
unlock:
    skcipher_wmem_wakeup(sk);
    release_sock(sk);
    kzfree(sreq);   /* no-op when ownership passed to the callback */
out:
    return err;
}
0636 
/*
 * Synchronous recvmsg(): loop mapping the caller's rx buffer, running
 * the cipher over the head of the queued tx data and waiting for each
 * request to complete.  Unless this is the final portion of the stream,
 * only whole cipher blocks are processed.  Returns bytes produced or a
 * negative errno.
 */
static int skcipher_recvmsg_sync(struct socket *sock, struct msghdr *msg,
                 int flags)
{
    struct sock *sk = sock->sk;
    struct alg_sock *ask = alg_sk(sk);
    struct sock *psk = ask->parent;
    struct alg_sock *pask = alg_sk(psk);
    struct skcipher_ctx *ctx = ask->private;
    struct skcipher_tfm *skc = pask->private;
    struct crypto_skcipher *tfm = skc->skcipher;
    unsigned bs = crypto_skcipher_blocksize(tfm);
    struct skcipher_sg_list *sgl;
    struct scatterlist *sg;
    int err = -EAGAIN;
    int used;
    long copied = 0;

    lock_sock(sk);
    while (msg_data_left(msg)) {
        if (!ctx->used) {
            err = skcipher_wait_for_data(sk, flags);
            if (err)
                goto unlock;
        }

        used = min_t(unsigned long, ctx->used, msg_data_left(msg));

        used = af_alg_make_sg(&ctx->rsgl, &msg->msg_iter, used);
        err = used;
        if (err < 0)
            goto unlock;

        /* trim to whole blocks unless this is the final chunk */
        if (ctx->more || used < ctx->used)
            used -= used % bs;

        err = -EINVAL;
        if (!used)
            goto free;

        sgl = list_first_entry(&ctx->tsgl,
                       struct skcipher_sg_list, list);
        sg = sgl->sg;

        /* skip entries already drained by skcipher_pull_sgl() */
        while (!sg->length)
            sg++;

        skcipher_request_set_crypt(&ctx->req, sg, ctx->rsgl.sg, used,
                       ctx->iv);

        err = af_alg_wait_for_completion(
                ctx->enc ?
                    crypto_skcipher_encrypt(&ctx->req) :
                    crypto_skcipher_decrypt(&ctx->req),
                &ctx->completion);

free:
        af_alg_free_sg(&ctx->rsgl);

        if (err)
            goto unlock;

        copied += used;
        skcipher_pull_sgl(sk, used, 1);
        iov_iter_advance(&msg->msg_iter, used);
    }

    err = 0;

unlock:
    skcipher_wmem_wakeup(sk);
    release_sock(sk);

    return copied ?: err;
}
0711 
0712 static int skcipher_recvmsg(struct socket *sock, struct msghdr *msg,
0713                 size_t ignored, int flags)
0714 {
0715     return (msg->msg_iocb && !is_sync_kiocb(msg->msg_iocb)) ?
0716         skcipher_recvmsg_async(sock, msg, flags) :
0717         skcipher_recvmsg_sync(sock, msg, flags);
0718 }
0719 
0720 static unsigned int skcipher_poll(struct file *file, struct socket *sock,
0721                   poll_table *wait)
0722 {
0723     struct sock *sk = sock->sk;
0724     struct alg_sock *ask = alg_sk(sk);
0725     struct skcipher_ctx *ctx = ask->private;
0726     unsigned int mask;
0727 
0728     sock_poll_wait(file, sk_sleep(sk), wait);
0729     mask = 0;
0730 
0731     if (ctx->used)
0732         mask |= POLLIN | POLLRDNORM;
0733 
0734     if (skcipher_writable(sk))
0735         mask |= POLLOUT | POLLWRNORM | POLLWRBAND;
0736 
0737     return mask;
0738 }
0739 
/* Socket operations once a key is present (or none is required). */
static struct proto_ops algif_skcipher_ops = {
    .family     =   PF_ALG,

    /* AF_ALG data sockets: everything except the data path is a stub */
    .connect    =   sock_no_connect,
    .socketpair =   sock_no_socketpair,
    .getname    =   sock_no_getname,
    .ioctl      =   sock_no_ioctl,
    .listen     =   sock_no_listen,
    .shutdown   =   sock_no_shutdown,
    .getsockopt =   sock_no_getsockopt,
    .mmap       =   sock_no_mmap,
    .bind       =   sock_no_bind,
    .accept     =   sock_no_accept,
    .setsockopt =   sock_no_setsockopt,

    .release    =   af_alg_release,
    .sendmsg    =   skcipher_sendmsg,
    .sendpage   =   skcipher_sendpage,
    .recvmsg    =   skcipher_recvmsg,
    .poll       =   skcipher_poll,
};
0761 
/*
 * Called from the *_nokey entry points: verify the parent tfm has a key
 * before allowing data transfer.  On the first successful check the
 * child records the fact in ask->refcnt and adjusts the parent's
 * reference accounting so the tfm stays valid for this socket.
 */
static int skcipher_check_key(struct socket *sock)
{
    int err = 0;
    struct sock *psk;
    struct alg_sock *pask;
    struct skcipher_tfm *tfm;
    struct sock *sk = sock->sk;
    struct alg_sock *ask = alg_sk(sk);

    lock_sock(sk);
    /* already validated once - fast path out */
    if (ask->refcnt)
        goto unlock_child;

    psk = ask->parent;
    pask = alg_sk(ask->parent);
    tfm = pask->private;

    err = -ENOKEY;
    lock_sock_nested(psk, SINGLE_DEPTH_NESTING);
    if (!tfm->has_key)
        goto unlock;

    /* first child to pass the check pins the parent socket */
    if (!pask->refcnt++)
        sock_hold(psk);

    ask->refcnt = 1;
    sock_put(psk);

    err = 0;

unlock:
    release_sock(psk);
unlock_child:
    release_sock(sk);

    return err;
}
0799 
0800 static int skcipher_sendmsg_nokey(struct socket *sock, struct msghdr *msg,
0801                   size_t size)
0802 {
0803     int err;
0804 
0805     err = skcipher_check_key(sock);
0806     if (err)
0807         return err;
0808 
0809     return skcipher_sendmsg(sock, msg, size);
0810 }
0811 
0812 static ssize_t skcipher_sendpage_nokey(struct socket *sock, struct page *page,
0813                        int offset, size_t size, int flags)
0814 {
0815     int err;
0816 
0817     err = skcipher_check_key(sock);
0818     if (err)
0819         return err;
0820 
0821     return skcipher_sendpage(sock, page, offset, size, flags);
0822 }
0823 
0824 static int skcipher_recvmsg_nokey(struct socket *sock, struct msghdr *msg,
0825                   size_t ignored, int flags)
0826 {
0827     int err;
0828 
0829     err = skcipher_check_key(sock);
0830     if (err)
0831         return err;
0832 
0833     return skcipher_recvmsg(sock, msg, ignored, flags);
0834 }
0835 
/* Socket operations installed before a key has been set: the data-path
 * entries funnel through the *_nokey wrappers, which refuse to run
 * until skcipher_check_key() succeeds. */
static struct proto_ops algif_skcipher_ops_nokey = {
    .family     =   PF_ALG,

    .connect    =   sock_no_connect,
    .socketpair =   sock_no_socketpair,
    .getname    =   sock_no_getname,
    .ioctl      =   sock_no_ioctl,
    .listen     =   sock_no_listen,
    .shutdown   =   sock_no_shutdown,
    .getsockopt =   sock_no_getsockopt,
    .mmap       =   sock_no_mmap,
    .bind       =   sock_no_bind,
    .accept     =   sock_no_accept,
    .setsockopt =   sock_no_setsockopt,

    .release    =   af_alg_release,
    .sendmsg    =   skcipher_sendmsg_nokey,
    .sendpage   =   skcipher_sendpage_nokey,
    .recvmsg    =   skcipher_recvmsg_nokey,
    .poll       =   skcipher_poll,
};
0857 
0858 static void *skcipher_bind(const char *name, u32 type, u32 mask)
0859 {
0860     struct skcipher_tfm *tfm;
0861     struct crypto_skcipher *skcipher;
0862 
0863     tfm = kzalloc(sizeof(*tfm), GFP_KERNEL);
0864     if (!tfm)
0865         return ERR_PTR(-ENOMEM);
0866 
0867     skcipher = crypto_alloc_skcipher(name, type, mask);
0868     if (IS_ERR(skcipher)) {
0869         kfree(tfm);
0870         return ERR_CAST(skcipher);
0871     }
0872 
0873     tfm->skcipher = skcipher;
0874 
0875     return tfm;
0876 }
0877 
/* af_alg_type .release: undo skcipher_bind(). */
static void skcipher_release(void *private)
{
    struct skcipher_tfm *tfm = private;

    crypto_free_skcipher(tfm->skcipher);
    kfree(tfm);
}
0885 
0886 static int skcipher_setkey(void *private, const u8 *key, unsigned int keylen)
0887 {
0888     struct skcipher_tfm *tfm = private;
0889     int err;
0890 
0891     err = crypto_skcipher_setkey(tfm->skcipher, key, keylen);
0892     tfm->has_key = !err;
0893 
0894     return err;
0895 }
0896 
0897 static void skcipher_wait(struct sock *sk)
0898 {
0899     struct alg_sock *ask = alg_sk(sk);
0900     struct skcipher_ctx *ctx = ask->private;
0901     int ctr = 0;
0902 
0903     while (atomic_read(&ctx->inflight) && ctr++ < 100)
0904         msleep(100);
0905 }
0906 
/*
 * Socket destructor: wait (bounded) for async requests to drain, then
 * free all queued tx pages, the IV and the context before dropping the
 * parent reference.
 */
static void skcipher_sock_destruct(struct sock *sk)
{
    struct alg_sock *ask = alg_sk(sk);
    struct skcipher_ctx *ctx = ask->private;
    struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(&ctx->req);

    if (atomic_read(&ctx->inflight))
        skcipher_wait(sk);

    skcipher_free_sgl(sk);
    /* the IV may hold sensitive data - zeroize on free */
    sock_kzfree_s(sk, ctx->iv, crypto_skcipher_ivsize(tfm));
    sock_kfree_s(sk, ctx, ctx->len);
    af_alg_release_parent(sk);
}
0921 
/*
 * Set up the per-connection context on an accept()ed socket: allocate
 * ctx (with trailing driver request context) and a zeroed IV, and prime
 * the embedded synchronous request.  Returns 0 or -ENOMEM.
 */
static int skcipher_accept_parent_nokey(void *private, struct sock *sk)
{
    struct skcipher_ctx *ctx;
    struct alg_sock *ask = alg_sk(sk);
    struct skcipher_tfm *tfm = private;
    struct crypto_skcipher *skcipher = tfm->skcipher;
    /* ctx->req's driver context lives directly behind the struct */
    unsigned int len = sizeof(*ctx) + crypto_skcipher_reqsize(skcipher);

    ctx = sock_kmalloc(sk, len, GFP_KERNEL);
    if (!ctx)
        return -ENOMEM;

    ctx->iv = sock_kmalloc(sk, crypto_skcipher_ivsize(skcipher),
                   GFP_KERNEL);
    if (!ctx->iv) {
        sock_kfree_s(sk, ctx, len);
        return -ENOMEM;
    }

    memset(ctx->iv, 0, crypto_skcipher_ivsize(skcipher));

    INIT_LIST_HEAD(&ctx->tsgl);
    ctx->len = len;
    ctx->used = 0;
    ctx->more = 0;
    ctx->merge = 0;
    ctx->enc = 0;
    atomic_set(&ctx->inflight, 0);
    af_alg_init_completion(&ctx->completion);

    ask->private = ctx;

    skcipher_request_set_tfm(&ctx->req, skcipher);
    skcipher_request_set_callback(&ctx->req, CRYPTO_TFM_REQ_MAY_SLEEP |
                         CRYPTO_TFM_REQ_MAY_BACKLOG,
                      af_alg_complete, &ctx->completion);

    sk->sk_destruct = skcipher_sock_destruct;

    return 0;
}
0963 
0964 static int skcipher_accept_parent(void *private, struct sock *sk)
0965 {
0966     struct skcipher_tfm *tfm = private;
0967 
0968     if (!tfm->has_key && crypto_skcipher_has_setkey(tfm->skcipher))
0969         return -ENOKEY;
0970 
0971     return skcipher_accept_parent_nokey(private, sk);
0972 }
0973 
/* AF_ALG "skcipher" algorithm type: glue between the af_alg core and
 * the callbacks/proto_ops defined above. */
static const struct af_alg_type algif_type_skcipher = {
    .bind       =   skcipher_bind,
    .release    =   skcipher_release,
    .setkey     =   skcipher_setkey,
    .accept     =   skcipher_accept_parent,
    .accept_nokey   =   skcipher_accept_parent_nokey,
    .ops        =   &algif_skcipher_ops,
    .ops_nokey  =   &algif_skcipher_ops_nokey,
    .name       =   "skcipher",
    .owner      =   THIS_MODULE
};
0985 
/* Register the "skcipher" algorithm type with the AF_ALG core. */
static int __init algif_skcipher_init(void)
{
    return af_alg_register_type(&algif_type_skcipher);
}
0990 
/* Unregister on module unload; failure indicates internal inconsistency. */
static void __exit algif_skcipher_exit(void)
{
    int err = af_alg_unregister_type(&algif_type_skcipher);
    BUG_ON(err);
}
0996 
0997 module_init(algif_skcipher_init);
0998 module_exit(algif_skcipher_exit);
0999 MODULE_LICENSE("GPL");