0001
0002
0003
0004
0005
0006 #include "queueing.h"
0007 #include "device.h"
0008 #include "peer.h"
0009 #include "timers.h"
0010 #include "messages.h"
0011 #include "cookie.h"
0012 #include "socket.h"
0013
0014 #include <linux/ip.h>
0015 #include <linux/ipv6.h>
0016 #include <linux/udp.h>
0017 #include <net/ip_tunnels.h>
0018
0019
0020 static void update_rx_stats(struct wg_peer *peer, size_t len)
0021 {
0022 dev_sw_netstats_rx_add(peer->device->dev, len);
0023 peer->rx_bytes += len;
0024 }
0025
/*
 * Message type field of the message_header located at skb->data, as stored
 * on the wire (little-endian; it is compared against cpu_to_le32() values).
 * The caller must ensure a full struct message_header is present at
 * skb->data before using this (validate_header_len() checks skb->len first).
 */
#define SKB_TYPE_LE32(skb) (((struct message_header *)(skb)->data)->type)
0027
0028 static size_t validate_header_len(struct sk_buff *skb)
0029 {
0030 if (unlikely(skb->len < sizeof(struct message_header)))
0031 return 0;
0032 if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_DATA) &&
0033 skb->len >= MESSAGE_MINIMUM_LENGTH)
0034 return sizeof(struct message_data);
0035 if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) &&
0036 skb->len == sizeof(struct message_handshake_initiation))
0037 return sizeof(struct message_handshake_initiation);
0038 if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) &&
0039 skb->len == sizeof(struct message_handshake_response))
0040 return sizeof(struct message_handshake_response);
0041 if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) &&
0042 skb->len == sizeof(struct message_handshake_cookie))
0043 return sizeof(struct message_handshake_cookie);
0044 return 0;
0045 }
0046
/*
 * Check the outer IP/UDP framing of a received packet and advance skb->data
 * to the WireGuard message carried in the UDP payload.
 *
 * @skb: received packet. udp_hdr(skb) must lie at or after skb->data; a
 *       negative difference wraps in the size_t data_offset and fails the
 *       U16_MAX check below.
 * @wg:  not used by this function.
 *
 * On success the skb has been trimmed to the length announced by the UDP
 * header, skb->data points just past the UDP header, skb->len equals the
 * UDP payload length, and the header length reported by
 * validate_header_len() is in the linear area. The outer headers are only
 * pulled past, not removed.
 *
 * Return: 0 on success, -EINVAL if any check fails. The skb is never freed
 * here; the caller is responsible for that.
 */
static int prepare_skb_header(struct sk_buff *skb, struct wg_device *wg)
{
	size_t data_offset, data_len, header_len;
	struct udphdr *udp;

	/* Protocol check failed, or the UDP header is not inside the buffer. */
	if (unlikely(!wg_check_packet_protocol(skb) ||
		     skb_transport_header(skb) < skb->head ||
		     (skb_transport_header(skb) + sizeof(struct udphdr)) >
			     skb_tail_pointer(skb)))
		return -EINVAL;
	udp = udp_hdr(skb);
	data_offset = (u8 *)udp - skb->data;
	if (unlikely(data_offset > U16_MAX ||
		     data_offset + sizeof(struct udphdr) > skb->len))
		/*
		 * The UDP header sits at an impossible offset, or the skb is too
		 * short to contain the UDP header fields.
		 */
		return -EINVAL;
	data_len = ntohs(udp->len);
	if (unlikely(data_len < sizeof(struct udphdr) ||
		     data_len > skb->len - data_offset))
		/*
		 * The UDP length field is smaller than the UDP header itself, or
		 * it claims more bytes than the skb holds.
		 */
		return -EINVAL;
	/* From here on, data_len/data_offset describe the UDP payload. */
	data_len -= sizeof(struct udphdr);
	data_offset = (u8 *)udp + sizeof(struct udphdr) - skb->data;
	/*
	 * Make everything up to and including the message_header linear, and
	 * drop any bytes past the end announced by the UDP length.
	 */
	if (unlikely(!pskb_may_pull(skb,
				    data_offset + sizeof(struct message_header)) ||
		     pskb_trim(skb, data_len + data_offset) < 0))
		return -EINVAL;
	skb_pull(skb, data_offset);
	if (unlikely(skb->len != data_len))
		/* Final length does not agree with the UDP length. */
		return -EINVAL;
	header_len = validate_header_len(skb);
	if (unlikely(!header_len))
		return -EINVAL;
	/*
	 * pskb_may_pull() counts from skb->data, so step back to the original
	 * start and require the outer headers plus header_len bytes of the
	 * message to be linear. Then step forward to the message again.
	 */
	__skb_push(skb, data_offset);
	if (unlikely(!pskb_may_pull(skb, data_offset + header_len)))
		return -EINVAL;
	__skb_pull(skb, data_offset);
	return 0;
}
0091
/*
 * Process one handshake-class message taken off the handshake queue.
 *
 * Cookie replies (MESSAGE_HANDSHAKE_COOKIE) are passed straight to
 * wg_cookie_message_consume() and nothing else is done with them.
 *
 * For handshake initiations and responses the MACs are checked with
 * wg_cookie_validate_packet(), which is told whether the device is
 * currently "under load". A packet is processed when:
 *   - under load and mac_state is VALID_MAC_WITH_COOKIE, or
 *   - not under load and mac_state is VALID_MAC_BUT_NO_COOKIE.
 * Under load, a packet with VALID_MAC_BUT_NO_COOKIE is answered with a
 * cookie reply instead of being processed. Every other combination is
 * dropped.
 *
 * A consumed initiation is answered with a handshake response. A consumed
 * response that successfully begins a session calls
 * wg_packet_send_keepalive(). In both cases the peer's endpoint, rx stats
 * and timers are updated.
 *
 * The skb is not freed here; wg_packet_handshake_receive_worker() frees it
 * after this returns.
 */
static void wg_receive_handshake_packet(struct wg_device *wg,
					struct sk_buff *skb)
{
	enum cookie_mac_state mac_state;
	struct wg_peer *peer = NULL;
	/*
	 * Function-scope static: one load timestamp shared by every device and
	 * every worker invocation. It is read and written here without any
	 * locking, so concurrent workers may race on it; presumably acceptable
	 * because it only feeds a heuristic.
	 */
	static u64 last_under_load;
	bool packet_needs_cookie;
	bool under_load;

	if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE)) {
		net_dbg_skb_ratelimited("%s: Receiving cookie response from %pISpfsc\n",
					wg->dev->name, skb);
		wg_cookie_message_consume(
			(struct message_handshake_cookie *)skb->data, wg);
		return;
	}

	/*
	 * Under load once the queued-handshake count reaches 1/8 of
	 * MAX_QUEUED_INCOMING_HANDSHAKES. After that, stay "under load" until
	 * wg_birthdate_has_expired(last_under_load, 1) reports expiry (a
	 * 1-unit hold-off, presumably seconds).
	 */
	under_load = atomic_read(&wg->handshake_queue_len) >=
			MAX_QUEUED_INCOMING_HANDSHAKES / 8;
	if (under_load) {
		last_under_load = ktime_get_coarse_boottime_ns();
	} else if (last_under_load) {
		under_load = !wg_birthdate_has_expired(last_under_load, 1);
		if (!under_load)
			last_under_load = 0;
	}
	mac_state = wg_cookie_validate_packet(&wg->cookie_checker, skb,
					      under_load);
	if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) ||
	    (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)) {
		packet_needs_cookie = false;
	} else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE) {
		/* Valid MAC but no cookie while under load: send a cookie reply. */
		packet_needs_cookie = true;
	} else {
		net_dbg_skb_ratelimited("%s: Invalid MAC of handshake, dropping packet from %pISpfsc\n",
					wg->dev->name, skb);
		return;
	}

	switch (SKB_TYPE_LE32(skb)) {
	case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION): {
		struct message_handshake_initiation *message =
			(struct message_handshake_initiation *)skb->data;

		if (packet_needs_cookie) {
			wg_packet_send_handshake_cookie(wg, skb,
							message->sender_index);
			return;
		}
		peer = wg_noise_handshake_consume_initiation(message, wg);
		if (unlikely(!peer)) {
			net_dbg_skb_ratelimited("%s: Invalid handshake initiation from %pISpfsc\n",
						wg->dev->name, skb);
			return;
		}
		wg_socket_set_peer_endpoint_from_skb(peer, skb);
		net_dbg_ratelimited("%s: Receiving handshake initiation from peer %llu (%pISpfsc)\n",
				    wg->dev->name, peer->internal_id,
				    &peer->endpoint.addr);
		wg_packet_send_handshake_response(peer);
		break;
	}
	case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE): {
		struct message_handshake_response *message =
			(struct message_handshake_response *)skb->data;

		if (packet_needs_cookie) {
			wg_packet_send_handshake_cookie(wg, skb,
							message->sender_index);
			return;
		}
		peer = wg_noise_handshake_consume_response(message, wg);
		if (unlikely(!peer)) {
			net_dbg_skb_ratelimited("%s: Invalid handshake response from %pISpfsc\n",
						wg->dev->name, skb);
			return;
		}
		wg_socket_set_peer_endpoint_from_skb(peer, skb);
		net_dbg_ratelimited("%s: Receiving handshake response from peer %llu (%pISpfsc)\n",
				    wg->dev->name, peer->internal_id,
				    &peer->endpoint.addr);
		if (wg_noise_handshake_begin_session(&peer->handshake,
						     &peer->keypairs)) {
			wg_timers_session_derived(peer);
			wg_timers_handshake_complete(peer);
			/*
			 * Give the other side prompt confirmation of the new
			 * session. Presumably this sends any already-staged
			 * packets, or a keepalive when nothing is queued;
			 * confirm in wg_packet_send_keepalive().
			 */
			wg_packet_send_keepalive(peer);
		}
		break;
	}
	}

	/* peer is only assigned by the initiation/response cases above. */
	if (unlikely(!peer)) {
		WARN(1, "Somehow a wrong type of packet wound up in the handshake queue!\n");
		return;
	}

	/*
	 * This runs from a workqueue; BH is disabled around the stats update,
	 * presumably because dev_sw_netstats_rx_add() updates per-CPU counters
	 * that are also touched from softirq context.
	 */
	local_bh_disable();
	update_rx_stats(peer, skb->len);
	local_bh_enable();

	wg_timers_any_authenticated_packet_received(peer);
	wg_timers_any_authenticated_packet_traversal(peer);
	/* Release the peer reference (presumably taken by the consume_* call). */
	wg_peer_put(peer);
}
0205
0206 void wg_packet_handshake_receive_worker(struct work_struct *work)
0207 {
0208 struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr;
0209 struct wg_device *wg = container_of(queue, struct wg_device, handshake_queue);
0210 struct sk_buff *skb;
0211
0212 while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
0213 wg_receive_handshake_packet(wg, skb);
0214 dev_kfree_skb(skb);
0215 atomic_dec(&wg->handshake_queue_len);
0216 cond_resched();
0217 }
0218 }
0219
0220 static void keep_key_fresh(struct wg_peer *peer)
0221 {
0222 struct noise_keypair *keypair;
0223 bool send;
0224
0225 if (peer->sent_lastminute_handshake)
0226 return;
0227
0228 rcu_read_lock_bh();
0229 keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
0230 send = keypair && READ_ONCE(keypair->sending.is_valid) &&
0231 keypair->i_am_the_initiator &&
0232 wg_birthdate_has_expired(keypair->sending.birthdate,
0233 REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT);
0234 rcu_read_unlock_bh();
0235
0236 if (unlikely(send)) {
0237 peer->sent_lastminute_handshake = true;
0238 wg_packet_send_queued_handshake_initiation(peer, false);
0239 }
0240 }
0241
0242 static bool decrypt_packet(struct sk_buff *skb, struct noise_keypair *keypair)
0243 {
0244 struct scatterlist sg[MAX_SKB_FRAGS + 8];
0245 struct sk_buff *trailer;
0246 unsigned int offset;
0247 int num_frags;
0248
0249 if (unlikely(!keypair))
0250 return false;
0251
0252 if (unlikely(!READ_ONCE(keypair->receiving.is_valid) ||
0253 wg_birthdate_has_expired(keypair->receiving.birthdate, REJECT_AFTER_TIME) ||
0254 keypair->receiving_counter.counter >= REJECT_AFTER_MESSAGES)) {
0255 WRITE_ONCE(keypair->receiving.is_valid, false);
0256 return false;
0257 }
0258
0259 PACKET_CB(skb)->nonce =
0260 le64_to_cpu(((struct message_data *)skb->data)->counter);
0261
0262
0263
0264
0265
0266 offset = skb->data - skb_network_header(skb);
0267 skb_push(skb, offset);
0268 num_frags = skb_cow_data(skb, 0, &trailer);
0269 offset += sizeof(struct message_data);
0270 skb_pull(skb, offset);
0271 if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg)))
0272 return false;
0273
0274 sg_init_table(sg, num_frags);
0275 if (skb_to_sgvec(skb, sg, 0, skb->len) <= 0)
0276 return false;
0277
0278 if (!chacha20poly1305_decrypt_sg_inplace(sg, skb->len, NULL, 0,
0279 PACKET_CB(skb)->nonce,
0280 keypair->receiving.key))
0281 return false;
0282
0283
0284
0285
0286 skb_push(skb, offset);
0287 if (pskb_trim(skb, skb->len - noise_encrypted_len(0)))
0288 return false;
0289 skb_pull(skb, offset);
0290
0291 return true;
0292 }
0293
0294
0295 static bool counter_validate(struct noise_replay_counter *counter, u64 their_counter)
0296 {
0297 unsigned long index, index_current, top, i;
0298 bool ret = false;
0299
0300 spin_lock_bh(&counter->lock);
0301
0302 if (unlikely(counter->counter >= REJECT_AFTER_MESSAGES + 1 ||
0303 their_counter >= REJECT_AFTER_MESSAGES))
0304 goto out;
0305
0306 ++their_counter;
0307
0308 if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
0309 counter->counter))
0310 goto out;
0311
0312 index = their_counter >> ilog2(BITS_PER_LONG);
0313
0314 if (likely(their_counter > counter->counter)) {
0315 index_current = counter->counter >> ilog2(BITS_PER_LONG);
0316 top = min_t(unsigned long, index - index_current,
0317 COUNTER_BITS_TOTAL / BITS_PER_LONG);
0318 for (i = 1; i <= top; ++i)
0319 counter->backtrack[(i + index_current) &
0320 ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
0321 counter->counter = their_counter;
0322 }
0323
0324 index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
0325 ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
0326 &counter->backtrack[index]);
0327
0328 out:
0329 spin_unlock_bh(&counter->lock);
0330 return ret;
0331 }
0332
0333 #include "selftest/counter.c"
0334
/*
 * Deliver a decrypted, replay-checked data packet to the network stack, or
 * account for it and drop it.
 *
 * First the peer endpoint is set from @endpoint. Handshake completion is
 * signalled when wg_noise_received_with_keypair() returns true (presumably
 * the first packet on a newly confirmed keypair). Then keys and timers are
 * refreshed.
 *
 * An empty payload is counted as a keepalive and freed. Otherwise the inner
 * packet must be IPv4 or IPv6 with a plausible length, and its source
 * address must route back to @peer through the allowed-IPs table. Such a
 * packet goes to napi_gro_receive(); anything else is counted as an error
 * and freed. Every path consumes the skb.
 */
static void wg_packet_consume_data_done(struct wg_peer *peer,
					struct sk_buff *skb,
					struct endpoint *endpoint)
{
	struct net_device *dev = peer->device->dev;
	unsigned int len, len_before_trim;
	struct wg_peer *routed_peer;

	wg_socket_set_peer_endpoint(peer, endpoint);

	if (unlikely(wg_noise_received_with_keypair(&peer->keypairs,
						    PACKET_CB(skb)->keypair))) {
		wg_timers_handshake_complete(peer);
		wg_packet_send_staged_packets(peer);
	}

	keep_key_fresh(peer);

	wg_timers_any_authenticated_packet_received(peer);
	wg_timers_any_authenticated_packet_traversal(peer);

	/* A zero-length plaintext is a keepalive: counted, never delivered. */
	if (unlikely(!skb->len)) {
		update_rx_stats(peer, message_data_len(0));
		net_dbg_ratelimited("%s: Receiving keepalive packet from peer %llu (%pISpfsc)\n",
				    dev->name, peer->internal_id,
				    &peer->endpoint.addr);
		goto packet_processed;
	}

	wg_timers_data_received(peer);

	/* The inner network header must lie inside the buffer. */
	if (unlikely(skb_network_header(skb) < skb->head))
		goto dishonest_packet_size;
	/*
	 * Require a linear iphdr with version 4, or version 6 with a linear
	 * ipv6hdr.
	 */
	if (unlikely(!(pskb_network_may_pull(skb, sizeof(struct iphdr)) &&
		       (ip_hdr(skb)->version == 4 ||
			(ip_hdr(skb)->version == 6 &&
			 pskb_network_may_pull(skb, sizeof(struct ipv6hdr)))))))
		goto dishonest_packet_type;

	skb->dev = dev;
	/*
	 * The inner packet was authenticated by the AEAD in decrypt_packet(),
	 * so report every checksum at every encapsulation level
	 * (csum_level = ~0) as already verified.
	 */
	skb->ip_summed = CHECKSUM_UNNECESSARY;
	skb->csum_level = ~0;
	skb->protocol = ip_tunnel_parse_protocol(skb);
	/*
	 * Take the length the inner IP header claims, and apply ECN
	 * decapsulation using the outer DS field saved in PACKET_CB(skb)->ds.
	 */
	if (skb->protocol == htons(ETH_P_IP)) {
		len = ntohs(ip_hdr(skb)->tot_len);
		if (unlikely(len < sizeof(struct iphdr)))
			goto dishonest_packet_size;
		INET_ECN_decapsulate(skb, PACKET_CB(skb)->ds, ip_hdr(skb)->tos);
	} else if (skb->protocol == htons(ETH_P_IPV6)) {
		len = ntohs(ipv6_hdr(skb)->payload_len) +
		      sizeof(struct ipv6hdr);
		INET_ECN_decapsulate(skb, PACKET_CB(skb)->ds, ipv6_get_dsfield(ipv6_hdr(skb)));
	} else {
		goto dishonest_packet_type;
	}

	/*
	 * Reject a claimed length beyond the payload, then trim any trailing
	 * bytes past it. len_before_trim keeps the pre-trim size for rx stats.
	 */
	if (unlikely(len > skb->len))
		goto dishonest_packet_size;
	len_before_trim = skb->len;
	if (unlikely(pskb_trim(skb, len)))
		goto packet_processed;

	/*
	 * The source address must map back to this peer via allowed IPs. The
	 * lookup's reference is dropped immediately; only the pointer value is
	 * compared below.
	 */
	routed_peer = wg_allowedips_lookup_src(&peer->device->peer_allowedips,
					       skb);
	wg_peer_put(routed_peer);

	if (unlikely(routed_peer != peer))
		goto dishonest_packet_peer;

	napi_gro_receive(&peer->napi, skb);
	update_rx_stats(peer, message_data_len(len_before_trim));
	return;

	/*
	 * NOTE(review): the ++dev->stats.* increments below are plain,
	 * non-atomic read-modify-writes. If several peers' NAPI contexts can
	 * run this concurrently for one device, updates may be lost; consider
	 * DEV_STATS_INC() where the target kernel provides it.
	 */
dishonest_packet_peer:
	net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n",
				dev->name, skb, peer->internal_id,
				&peer->endpoint.addr);
	++dev->stats.rx_errors;
	++dev->stats.rx_frame_errors;
	goto packet_processed;
dishonest_packet_type:
	net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n",
			    dev->name, peer->internal_id, &peer->endpoint.addr);
	++dev->stats.rx_errors;
	++dev->stats.rx_frame_errors;
	goto packet_processed;
dishonest_packet_size:
	net_dbg_ratelimited("%s: Packet has incorrect size from peer %llu (%pISpfsc)\n",
			    dev->name, peer->internal_id, &peer->endpoint.addr);
	++dev->stats.rx_errors;
	++dev->stats.rx_length_errors;
	goto packet_processed;
packet_processed:
	dev_kfree_skb(skb);
}
0437
0438 int wg_packet_rx_poll(struct napi_struct *napi, int budget)
0439 {
0440 struct wg_peer *peer = container_of(napi, struct wg_peer, napi);
0441 struct noise_keypair *keypair;
0442 struct endpoint endpoint;
0443 enum packet_state state;
0444 struct sk_buff *skb;
0445 int work_done = 0;
0446 bool free;
0447
0448 if (unlikely(budget <= 0))
0449 return 0;
0450
0451 while ((skb = wg_prev_queue_peek(&peer->rx_queue)) != NULL &&
0452 (state = atomic_read_acquire(&PACKET_CB(skb)->state)) !=
0453 PACKET_STATE_UNCRYPTED) {
0454 wg_prev_queue_drop_peeked(&peer->rx_queue);
0455 keypair = PACKET_CB(skb)->keypair;
0456 free = true;
0457
0458 if (unlikely(state != PACKET_STATE_CRYPTED))
0459 goto next;
0460
0461 if (unlikely(!counter_validate(&keypair->receiving_counter,
0462 PACKET_CB(skb)->nonce))) {
0463 net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
0464 peer->device->dev->name,
0465 PACKET_CB(skb)->nonce,
0466 keypair->receiving_counter.counter);
0467 goto next;
0468 }
0469
0470 if (unlikely(wg_socket_endpoint_from_skb(&endpoint, skb)))
0471 goto next;
0472
0473 wg_reset_packet(skb, false);
0474 wg_packet_consume_data_done(peer, skb, &endpoint);
0475 free = false;
0476
0477 next:
0478 wg_noise_keypair_put(keypair, false);
0479 wg_peer_put(peer);
0480 if (unlikely(free))
0481 dev_kfree_skb(skb);
0482
0483 if (++work_done >= budget)
0484 break;
0485 }
0486
0487 if (work_done < budget)
0488 napi_complete_done(napi, work_done);
0489
0490 return work_done;
0491 }
0492
0493 void wg_packet_decrypt_worker(struct work_struct *work)
0494 {
0495 struct crypt_queue *queue = container_of(work, struct multicore_worker,
0496 work)->ptr;
0497 struct sk_buff *skb;
0498
0499 while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
0500 enum packet_state state =
0501 likely(decrypt_packet(skb, PACKET_CB(skb)->keypair)) ?
0502 PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
0503 wg_queue_enqueue_per_peer_rx(skb, state);
0504 if (need_resched())
0505 cond_resched();
0506 }
0507 }
0508
/*
 * Queue a received data message for decryption.
 *
 * The message's key_idx selects a keypair (and its peer) from
 * wg->index_hashtable under rcu_read_lock_bh(). A keypair reference is taken
 * and stored in PACKET_CB(skb)->keypair, then the skb is handed to
 * wg_queue_enqueue_per_device_and_peer() with the device decrypt queue and
 * the peer's rx_queue.
 *
 * If that call returns -EPIPE, the skb is instead put on the peer's rx
 * queue as PACKET_STATE_DEAD. Presumably this lets wg_packet_rx_poll()
 * release its references and free it in order.
 *
 * On every other failure the references taken here are dropped and the skb
 * is freed. wg_peer_put(peer) runs on both error paths, which suggests the
 * lookup returns @peer with a reference held (and that wg_peer_put(NULL) is
 * safe); confirm in the index hashtable code.
 */
static void wg_packet_consume_data(struct wg_device *wg, struct sk_buff *skb)
{
	__le32 idx = ((struct message_data *)skb->data)->key_idx;
	struct wg_peer *peer = NULL;
	int ret;

	rcu_read_lock_bh();
	PACKET_CB(skb)->keypair =
		(struct noise_keypair *)wg_index_hashtable_lookup(
			wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx,
			&peer);
	if (unlikely(!wg_noise_keypair_get(PACKET_CB(skb)->keypair)))
		goto err_keypair;

	if (unlikely(READ_ONCE(peer->is_dead)))
		goto err;

	ret = wg_queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &peer->rx_queue, skb,
						   wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu);
	if (unlikely(ret == -EPIPE))
		wg_queue_enqueue_per_peer_rx(skb, PACKET_STATE_DEAD);
	/* On success or -EPIPE the queued skb keeps its references. */
	if (likely(!ret || ret == -EPIPE)) {
		rcu_read_unlock_bh();
		return;
	}
err:
	wg_noise_keypair_put(PACKET_CB(skb)->keypair, false);
err_keypair:
	rcu_read_unlock_bh();
	wg_peer_put(peer);
	dev_kfree_skb(skb);
}
0541
/*
 * Entry point for a received WireGuard packet whose outer IP/UDP headers are
 * still attached.
 *
 * prepare_skb_header() validates the framing and moves skb->data to the
 * message. Handshake initiations, responses and cookie replies are pushed
 * onto the per-device handshake ring and a worker is queued on the next
 * online CPU. Data messages record the outer DS field and go to
 * wg_packet_consume_data().
 *
 * The skb is always consumed: it is either queued or freed here.
 */
void wg_packet_receive(struct wg_device *wg, struct sk_buff *skb)
{
	if (unlikely(prepare_skb_header(skb, wg) < 0))
		goto err;
	switch (SKB_TYPE_LE32(skb)) {
	case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION):
	case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE):
	case cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE): {
		int cpu, ret = -EBUSY;

		/* Handshakes are dropped until the kernel RNG is initialized. */
		if (unlikely(!rng_is_initialized()))
			goto drop;
		/*
		 * When more than half of MAX_QUEUED_INCOMING_HANDSHAKES are
		 * queued, only enqueue if the producer lock is free right away
		 * (trylock); on contention ret stays -EBUSY and the packet is
		 * dropped instead of spinning. Otherwise enqueue normally.
		 */
		if (atomic_read(&wg->handshake_queue_len) > MAX_QUEUED_INCOMING_HANDSHAKES / 2) {
			if (spin_trylock_bh(&wg->handshake_queue.ring.producer_lock)) {
				ret = __ptr_ring_produce(&wg->handshake_queue.ring, skb);
				spin_unlock_bh(&wg->handshake_queue.ring.producer_lock);
			}
		} else
			ret = ptr_ring_produce_bh(&wg->handshake_queue.ring, skb);
		if (ret) {
	/* Also reached from the RNG check above. */
	drop:
			net_dbg_skb_ratelimited("%s: Dropping handshake packet from %pISpfsc\n",
						wg->dev->name, skb);
			goto err;
		}
		/* Decremented by the handshake worker after it frees the skb. */
		atomic_inc(&wg->handshake_queue_len);
		cpu = wg_cpumask_next_online(&wg->handshake_queue.last_cpu);
		/*
		 * Kick the per-CPU handshake worker, presumably running
		 * wg_packet_handshake_receive_worker().
		 */
		queue_work_on(cpu, wg->handshake_receive_wq,
			      &per_cpu_ptr(wg->handshake_queue.worker, cpu)->work);
		break;
	}
	case cpu_to_le32(MESSAGE_DATA):
		/* Save the outer DS field for ECN decapsulation after decryption. */
		PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb);
		wg_packet_consume_data(wg, skb);
		break;
	default:
		/* validate_header_len() only accepts the four types above. */
		WARN(1, "Non-exhaustive parsing of packet header lead to unknown packet type!\n");
		goto err;
	}
	return;

err:
	dev_kfree_skb(skb);
}