/* Copyright (c) 2018, Mellanox Technologies. All rights reserved. */

#include <net/tls.h>
#include <crypto/aead.h>
#include <crypto/scatterwalk.h>
#include <net/ip6_checksum.h>

#include "tls.h"

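/* Start the output scatterlist entry at the walk's current position and
 * chain the remainder of the walked scatterlist after it.
 */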
static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk)
{
	struct scatterlist *src = walk->sg;
	int diff = walk->offset - src->offset;

	sg_set_page(sg, sg_page(src),
		    src->length - diff, walk->offset);

	scatterwalk_crypto_chain(sg, sg_next(src), 2);
}

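/* Encrypt a single TLS record: copy the record header from @in to @out,
 * rebuild the AAD and IV for @rcd_sn, and encrypt the payload from the
 * input walk into the output walk. @in_len is reduced by the number of
 * input bytes consumed.
 */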
static int tls_enc_record(struct aead_request *aead_req,
			  struct crypto_aead *aead, char *aad,
			  char *iv, __be64 rcd_sn,
			  struct scatter_walk *in,
			  struct scatter_walk *out, int *in_len,
			  struct tls_prot_info *prot)
{
	unsigned char buf[TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE];
	struct scatterlist sg_in[3];
	struct scatterlist sg_out[3];
	u16 len;
	int rc;

	len = min_t(int, *in_len, ARRAY_SIZE(buf));

	scatterwalk_copychunks(buf, in, len, 0);
	scatterwalk_copychunks(buf, out, len, 1);

	*in_len -= len;
	if (!*in_len)
		return 0;

	scatterwalk_pagedone(in, 0, 1);
	scatterwalk_pagedone(out, 1, 1);

	len = buf[4] | (buf[3] << 8);
	len -= TLS_CIPHER_AES_GCM_128_IV_SIZE;

	tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE,
		     (char *)&rcd_sn, buf[0], prot);

	memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE,
	       TLS_CIPHER_AES_GCM_128_IV_SIZE);

	sg_init_table(sg_in, ARRAY_SIZE(sg_in));
	sg_init_table(sg_out, ARRAY_SIZE(sg_out));
	sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE);
	sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE);
	chain_to_walk(sg_in + 1, in);
	chain_to_walk(sg_out + 1, out);

	*in_len -= len;
	if (*in_len < 0) {
		*in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE;
		/* The input buffer doesn't contain the entire record:
		 * trim len accordingly. The resulting authentication tag
		 * will contain garbage, but we don't care, so we won't
		 * include any of it in the output skb.
		 * Note that we assume the output buffer length
		 * is larger than the input buffer length + tag size.
		 */
		if (*in_len < 0)
			len += *in_len;

		*in_len = 0;
	}

	if (*in_len) {
		scatterwalk_copychunks(NULL, in, len, 2);
		scatterwalk_pagedone(in, 0, 1);
		scatterwalk_copychunks(NULL, out, len, 2);
		scatterwalk_pagedone(out, 1, 1);
	}

	len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE;
	aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv);

	rc = crypto_aead_encrypt(aead_req);

	return rc;
}

static void tls_init_aead_request(struct aead_request *aead_req,
				  struct crypto_aead *aead)
{
	aead_request_set_tfm(aead_req, aead);
	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
}

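/* Allocate an AEAD request big enough for @aead and initialize it with the
 * transform and the TLS AAD size.
 */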
static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead,
						   gfp_t flags)
{
	unsigned int req_size = sizeof(struct aead_request) +
		crypto_aead_reqsize(aead);
	struct aead_request *aead_req;

	aead_req = kzalloc(req_size, flags);
	if (aead_req)
		tls_init_aead_request(aead_req, aead);
	return aead_req;
}

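/* Walk @sg_in and @sg_out, encrypting one TLS record at a time until @len
 * bytes have been processed or an error occurs. @rcd_sn is the record
 * sequence number of the first record and is incremented per record.
 */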
static int tls_enc_records(struct aead_request *aead_req,
			   struct crypto_aead *aead, struct scatterlist *sg_in,
			   struct scatterlist *sg_out, char *aad, char *iv,
			   u64 rcd_sn, int len, struct tls_prot_info *prot)
{
	struct scatter_walk out, in;
	int rc;

	scatterwalk_start(&in, sg_in);
	scatterwalk_start(&out, sg_out);

	do {
		rc = tls_enc_record(aead_req, aead, aad, iv,
				    cpu_to_be64(rcd_sn), &in, &out, &len, prot);
		rcd_sn++;

	} while (rc == 0 && len);

	scatterwalk_done(&in, 0, 0);
	scatterwalk_done(&out, 1, 0);

	return rc;
}

/* Can't use icsk->icsk_af_ops->send_check here because the ip addresses
 * might have been changed by NAT.
 */
static void update_chksum(struct sk_buff *skb, int headln)
{
	struct tcphdr *th = tcp_hdr(skb);
	int datalen = skb->len - headln;
	const struct ipv6hdr *ipv6h;
	const struct iphdr *iph;

	/* We only changed the payload, so if the skb already carries a
	 * partial checksum there is nothing to update.
	 */
	if (likely(skb->ip_summed == CHECKSUM_PARTIAL))
		return;

	skb->ip_summed = CHECKSUM_PARTIAL;
	skb->csum_start = skb_transport_header(skb) - skb->head;
	skb->csum_offset = offsetof(struct tcphdr, check);

	if (skb->sk->sk_family == AF_INET6) {
		ipv6h = ipv6_hdr(skb);
		th->check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr,
					     datalen, IPPROTO_TCP, 0);
	} else {
		iph = ip_hdr(skb);
		th->check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen,
					       IPPROTO_TCP, 0);
	}
}

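/* Copy the headers and socket ownership of @skb into @nskb, then fix up the
 * TCP checksum and the socket write-memory accounting for the new skb.
 */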
static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln)
{
	struct sock *sk = skb->sk;
	int delta;

	skb_copy_header(nskb, skb);

	skb_put(nskb, skb->len);
	memcpy(nskb->data, skb->data, headln);

	nskb->destructor = skb->destructor;
	nskb->sk = sk;
	skb->destructor = NULL;
	skb->sk = NULL;

	update_chksum(nskb, headln);

	/* sock_efree means the skb has gone through skb_orphan_partial() */
	if (nskb->destructor == sock_efree)
		return;

	delta = nskb->truesize - skb->truesize;
	if (likely(delta < 0))
		WARN_ON_ONCE(refcount_sub_and_test(-delta, &sk->sk_wmem_alloc));
	else if (delta)
		refcount_add(delta, &sk->sk_wmem_alloc);
}

/* This function may be called after the user socket is already
 * closed so make sure we don't use anything freed during
 * tls_sk_proto_close here.
 */

static int fill_sg_in(struct scatterlist *sg_in,
		      struct sk_buff *skb,
		      struct tls_offload_context_tx *ctx,
		      u64 *rcd_sn,
		      s32 *sync_size,
		      int *resync_sgs)
{
	int tcp_payload_offset = skb_tcp_all_headers(skb);
	int payload_len = skb->len - tcp_payload_offset;
	u32 tcp_seq = ntohl(tcp_hdr(skb)->seq);
	struct tls_record_info *record;
	unsigned long flags;
	int remaining;
	int i;

	spin_lock_irqsave(&ctx->lock, flags);
	record = tls_get_record(ctx, tcp_seq, rcd_sn);
	if (!record) {
		spin_unlock_irqrestore(&ctx->lock, flags);
		return -EINVAL;
	}

	*sync_size = tcp_seq - tls_record_start_seq(record);
	if (*sync_size < 0) {
		int is_start_marker = tls_record_is_start_marker(record);

		spin_unlock_irqrestore(&ctx->lock, flags);
		/* This should only occur if the relevant record was
		 * already acked. In that case it should be ok
		 * to drop the packet and avoid retransmission.
		 *
		 * There is a corner case where the packet contains
		 * both an acked and a non-acked record.
		 * We currently don't handle that case and rely
		 * on TCP to retransmit a packet that doesn't contain
		 * already acked payload.
		 */
		if (!is_start_marker)
			*sync_size = 0;
		return -EINVAL;
	}

	remaining = *sync_size;
	for (i = 0; remaining > 0; i++) {
		skb_frag_t *frag = &record->frags[i];

		__skb_frag_ref(frag);
		sg_set_page(sg_in + i, skb_frag_page(frag),
			    skb_frag_size(frag), skb_frag_off(frag));

		remaining -= skb_frag_size(frag);

		if (remaining < 0)
			sg_in[i].length += remaining;
	}
	*resync_sgs = i;

	spin_unlock_irqrestore(&ctx->lock, flags);
	if (skb_to_sgvec(skb, &sg_in[i], tcp_payload_offset, payload_len) < 0)
		return -EINVAL;

	return 0;
}

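/* Set up the output scatterlist: scratch space for the already-transmitted
 * part of the record, the payload area of @nskb, and room for the
 * authentication tag.
 */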
static void fill_sg_out(struct scatterlist sg_out[3], void *buf,
			struct tls_context *tls_ctx,
			struct sk_buff *nskb,
			int tcp_payload_offset,
			int payload_len,
			int sync_size,
			void *dummy_buf)
{
	sg_set_buf(&sg_out[0], dummy_buf, sync_size);
	sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset, payload_len);

	dummy_buf += sync_size;
	sg_set_buf(&sg_out[2], dummy_buf, TLS_CIPHER_AES_GCM_128_TAG_SIZE);
}

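/* Allocate a new skb plus the scratch buffers needed to re-encrypt the TLS
 * record(s) overlapping @skb in software, run the encryption, and transfer
 * the result into the new skb. Returns NULL on failure.
 */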
static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
				   struct scatterlist sg_out[3],
				   struct scatterlist *sg_in,
				   struct sk_buff *skb,
				   s32 sync_size, u64 rcd_sn)
{
	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
	int tcp_payload_offset = skb_tcp_all_headers(skb);
	int payload_len = skb->len - tcp_payload_offset;
	void *buf, *iv, *aad, *dummy_buf;
	struct aead_request *aead_req;
	struct sk_buff *nskb = NULL;
	int buf_len;

	aead_req = tls_alloc_aead_request(ctx->aead_send, GFP_ATOMIC);
	if (!aead_req)
		return NULL;

	buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE +
		  TLS_CIPHER_AES_GCM_128_IV_SIZE +
		  TLS_AAD_SPACE_SIZE +
		  sync_size +
		  TLS_CIPHER_AES_GCM_128_TAG_SIZE;
	buf = kmalloc(buf_len, GFP_ATOMIC);
	if (!buf)
		goto free_req;

	iv = buf;
	memcpy(iv, tls_ctx->crypto_send.aes_gcm_128.salt,
	       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
	aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE +
	      TLS_CIPHER_AES_GCM_128_IV_SIZE;
	dummy_buf = aad + TLS_AAD_SPACE_SIZE;

	nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC);
	if (!nskb)
		goto free_buf;

	skb_reserve(nskb, skb_headroom(skb));

	fill_sg_out(sg_out, buf, tls_ctx, nskb, tcp_payload_offset,
		    payload_len, sync_size, dummy_buf);

	if (tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv,
			    rcd_sn, sync_size + payload_len,
			    &tls_ctx->prot_info) < 0)
		goto free_nskb;

	complete_skb(nskb, skb, tcp_payload_offset);

	/* validate_xmit_skb_list assumes that if the skb wasn't segmented,
	 * nskb->prev will point to the skb itself.
	 */
	nskb->prev = nskb;

free_buf:
	kfree(buf);
free_req:
	kfree(aead_req);
	return nskb;
free_nskb:
	kfree_skb(nskb);
	nskb = NULL;
	goto free_buf;
}

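/* Software fallback for a TX packet that missed hardware offload: rebuild
 * the overlapping TLS record(s) from the offload context, encrypt them with
 * the fallback AEAD and return a new skb carrying the encrypted payload.
 * The original skb is consumed either way.
 */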
static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb)
{
	int tcp_payload_offset = skb_tcp_all_headers(skb);
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
	int payload_len = skb->len - tcp_payload_offset;
	struct scatterlist *sg_in, sg_out[3];
	struct sk_buff *nskb = NULL;
	int sg_in_max_elements;
	int resync_sgs = 0;
	s32 sync_size = 0;
	u64 rcd_sn;

	/* Worst case:
	 * MAX_SKB_FRAGS in tls_record_info
	 * MAX_SKB_FRAGS + 1 in SKB head and frags.
	 */
	sg_in_max_elements = 2 * MAX_SKB_FRAGS + 1;

	if (!payload_len)
		return skb;

	sg_in = kmalloc_array(sg_in_max_elements, sizeof(*sg_in), GFP_ATOMIC);
	if (!sg_in)
		goto free_orig;

	sg_init_table(sg_in, sg_in_max_elements);
	sg_init_table(sg_out, ARRAY_SIZE(sg_out));

	if (fill_sg_in(sg_in, skb, ctx, &rcd_sn, &sync_size, &resync_sgs)) {
		/* Bypass packets sent before the kernel TLS socket option was set */
		if (sync_size < 0 && payload_len <= -sync_size)
			nskb = skb_get(skb);
		goto put_sg;
	}

	nskb = tls_enc_skb(tls_ctx, sg_out, sg_in, skb, sync_size, rcd_sn);

put_sg:
	while (resync_sgs)
		put_page(sg_page(&sg_in[--resync_sgs]));
	kfree(sg_in);
free_orig:
	if (nskb)
		consume_skb(skb);
	else
		kfree_skb(skb);
	return nskb;
}

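/* Let the skb pass through untouched when it is headed to the device that
 * owns the TLS offload context (or to a bonding master); otherwise encrypt
 * it in software before transmission.
 */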
struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
				      struct net_device *dev,
				      struct sk_buff *skb)
{
	if (dev == rcu_dereference_bh(tls_get_ctx(sk)->netdev) ||
	    netif_is_bond_master(dev))
		return skb;

	return tls_sw_fallback(sk, skb);
}
EXPORT_SYMBOL_GPL(tls_validate_xmit_skb);

struct sk_buff *tls_validate_xmit_skb_sw(struct sock *sk,
					 struct net_device *dev,
					 struct sk_buff *skb)
{
	return tls_sw_fallback(sk, skb);
}

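/* Software-encrypt @skb on behalf of a driver, using the TLS context of the
 * owning socket.
 */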
struct sk_buff *tls_encrypt_skb(struct sk_buff *skb)
{
	return tls_sw_fallback(skb->sk, skb);
}
EXPORT_SYMBOL_GPL(tls_encrypt_skb);

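/* Allocate the gcm(aes) AEAD used by the software fallback path of a TX
 * offload context and program its key and authentication tag size.
 */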
int tls_sw_fallback_init(struct sock *sk,
			 struct tls_offload_context_tx *offload_ctx,
			 struct tls_crypto_info *crypto_info)
{
	const u8 *key;
	int rc;

	offload_ctx->aead_send =
		crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC);
	if (IS_ERR(offload_ctx->aead_send)) {
		rc = PTR_ERR(offload_ctx->aead_send);
		pr_err_ratelimited("crypto_alloc_aead failed rc=%d\n", rc);
		offload_ctx->aead_send = NULL;
		goto err_out;
	}

	key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key;

	rc = crypto_aead_setkey(offload_ctx->aead_send, key,
				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
	if (rc)
		goto free_aead;

	rc = crypto_aead_setauthsize(offload_ctx->aead_send,
				     TLS_CIPHER_AES_GCM_128_TAG_SIZE);
	if (rc)
		goto free_aead;

	return 0;
free_aead:
	crypto_free_aead(offload_ctx->aead_send);
err_out:
	return rc;
}