Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 /*
0003  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
0004  */
0005 
0006 #include "noise.h"
0007 #include "device.h"
0008 #include "peer.h"
0009 #include "messages.h"
0010 #include "queueing.h"
0011 #include "peerlookup.h"
0012 
0013 #include <linux/rcupdate.h>
0014 #include <linux/slab.h>
0015 #include <linux/bitmap.h>
0016 #include <linux/scatterlist.h>
0017 #include <linux/highmem.h>
0018 #include <crypto/algapi.h>
0019 
0020 /* This implements Noise_IKpsk2:
0021  *
0022  * <- s
0023  * ******
0024  * -> e, es, s, ss, {t}
0025  * <- e, ee, se, psk, {}
0026  */
0027 
/* Protocol name hashed by wg_noise_init() to form the initial chaining key.
 * The array length (37) equals the string length exactly, so no NUL
 * terminator is stored and sizeof() is the number of bytes hashed. Keep the
 * explicit size in sync with the literal if it ever changes.
 */
static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
/* Identifier mixed into the initial hash by wg_noise_init(); likewise stored
 * without a NUL terminator (34 bytes).
 */
static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
/* Starting hash and chaining key shared by every handshake. They are
 * computed once in wg_noise_init() and are read-only afterwards
 * (__ro_after_init).
 */
static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
/* Source of keypair->internal_id values (see keypair_create()); used only
 * to identify keypairs in debug log messages.
 */
static atomic64_t keypair_counter = ATOMIC64_INIT(0);
0033 
0034 void __init wg_noise_init(void)
0035 {
0036     struct blake2s_state blake;
0037 
0038     blake2s(handshake_init_chaining_key, handshake_name, NULL,
0039         NOISE_HASH_LEN, sizeof(handshake_name), 0);
0040     blake2s_init(&blake, NOISE_HASH_LEN);
0041     blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
0042     blake2s_update(&blake, identifier_name, sizeof(identifier_name));
0043     blake2s_final(&blake, handshake_init_hash);
0044 }
0045 
0046 /* Must hold peer->handshake.static_identity->lock */
0047 void wg_noise_precompute_static_static(struct wg_peer *peer)
0048 {
0049     down_write(&peer->handshake.lock);
0050     if (!peer->handshake.static_identity->has_identity ||
0051         !curve25519(peer->handshake.precomputed_static_static,
0052             peer->handshake.static_identity->static_private,
0053             peer->handshake.remote_static))
0054         memset(peer->handshake.precomputed_static_static, 0,
0055                NOISE_PUBLIC_KEY_LEN);
0056     up_write(&peer->handshake.lock);
0057 }
0058 
0059 void wg_noise_handshake_init(struct noise_handshake *handshake,
0060                  struct noise_static_identity *static_identity,
0061                  const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
0062                  const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
0063                  struct wg_peer *peer)
0064 {
0065     memset(handshake, 0, sizeof(*handshake));
0066     init_rwsem(&handshake->lock);
0067     handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
0068     handshake->entry.peer = peer;
0069     memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
0070     if (peer_preshared_key)
0071         memcpy(handshake->preshared_key, peer_preshared_key,
0072                NOISE_SYMMETRIC_KEY_LEN);
0073     handshake->static_identity = static_identity;
0074     handshake->state = HANDSHAKE_ZEROED;
0075     wg_noise_precompute_static_static(peer);
0076 }
0077 
0078 static void handshake_zero(struct noise_handshake *handshake)
0079 {
0080     memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
0081     memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
0082     memset(&handshake->hash, 0, NOISE_HASH_LEN);
0083     memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
0084     handshake->remote_index = 0;
0085     handshake->state = HANDSHAKE_ZEROED;
0086 }
0087 
0088 void wg_noise_handshake_clear(struct noise_handshake *handshake)
0089 {
0090     down_write(&handshake->lock);
0091     wg_index_hashtable_remove(
0092             handshake->entry.peer->device->index_hashtable,
0093             &handshake->entry);
0094     handshake_zero(handshake);
0095     up_write(&handshake->lock);
0096 }
0097 
0098 static struct noise_keypair *keypair_create(struct wg_peer *peer)
0099 {
0100     struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
0101 
0102     if (unlikely(!keypair))
0103         return NULL;
0104     spin_lock_init(&keypair->receiving_counter.lock);
0105     keypair->internal_id = atomic64_inc_return(&keypair_counter);
0106     keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
0107     keypair->entry.peer = peer;
0108     kref_init(&keypair->refcount);
0109     return keypair;
0110 }
0111 
0112 static void keypair_free_rcu(struct rcu_head *rcu)
0113 {
0114     kfree_sensitive(container_of(rcu, struct noise_keypair, rcu));
0115 }
0116 
0117 static void keypair_free_kref(struct kref *kref)
0118 {
0119     struct noise_keypair *keypair =
0120         container_of(kref, struct noise_keypair, refcount);
0121 
0122     net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
0123                 keypair->entry.peer->device->dev->name,
0124                 keypair->internal_id,
0125                 keypair->entry.peer->internal_id);
0126     wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
0127                   &keypair->entry);
0128     call_rcu(&keypair->rcu, keypair_free_rcu);
0129 }
0130 
0131 void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
0132 {
0133     if (unlikely(!keypair))
0134         return;
0135     if (unlikely(unreference_now))
0136         wg_index_hashtable_remove(
0137             keypair->entry.peer->device->index_hashtable,
0138             &keypair->entry);
0139     kref_put(&keypair->refcount, keypair_free_kref);
0140 }
0141 
0142 struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
0143 {
0144     RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
0145         "Taking noise keypair reference without holding the RCU BH read lock");
0146     if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
0147         return NULL;
0148     return keypair;
0149 }
0150 
/* Detach all three keypair slots (next, previous, current) under
 * keypair_update_lock and drop the reference each slot held. Every keypair
 * is released with unreference_now = true, so its index-hashtable entry is
 * removed immediately, even if other references are still outstanding.
 */
void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
{
    struct noise_keypair *old;

    spin_lock_bh(&keypairs->keypair_update_lock);

    /* We clear next_keypair before the others. A concurrent
     * wg_noise_received_with_keypair() only acts when the received keypair
     * equals next_keypair, so it returns early instead of promoting a
     * keypair while the remaining slots are being cleared.
     */
    old = rcu_dereference_protected(keypairs->next_keypair,
        lockdep_is_held(&keypairs->keypair_update_lock));
    RCU_INIT_POINTER(keypairs->next_keypair, NULL);
    wg_noise_keypair_put(old, true);

    old = rcu_dereference_protected(keypairs->previous_keypair,
        lockdep_is_held(&keypairs->keypair_update_lock));
    RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
    wg_noise_keypair_put(old, true);

    old = rcu_dereference_protected(keypairs->current_keypair,
        lockdep_is_held(&keypairs->keypair_update_lock));
    RCU_INIT_POINTER(keypairs->current_keypair, NULL);
    wg_noise_keypair_put(old, true);

    spin_unlock_bh(&keypairs->keypair_update_lock);
}
0178 
0179 void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
0180 {
0181     struct noise_keypair *keypair;
0182 
0183     wg_noise_handshake_clear(&peer->handshake);
0184     wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
0185 
0186     spin_lock_bh(&peer->keypairs.keypair_update_lock);
0187     keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
0188             lockdep_is_held(&peer->keypairs.keypair_update_lock));
0189     if (keypair)
0190         keypair->sending.is_valid = false;
0191     keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
0192             lockdep_is_held(&peer->keypairs.keypair_update_lock));
0193     if (keypair)
0194         keypair->sending.is_valid = false;
0195     spin_unlock_bh(&peer->keypairs.keypair_update_lock);
0196 }
0197 
/* Install a freshly derived @new_keypair into the peer's keypair slots,
 * under keypair_update_lock. The slot updates are:
 *
 *   initiator, next present: previous := old next, next := NULL,
 *                            drop old current and old previous,
 *                            current := new
 *   initiator, no next:      previous := old current, drop old previous,
 *                            current := new
 *   responder:               next := new, drop old next,
 *                            previous := NULL, drop old previous
 *                            (current is left unchanged)
 *
 * Every dropped keypair is released with unreference_now = true, which
 * removes its index-hashtable entry immediately.
 */
static void add_new_keypair(struct noise_keypairs *keypairs,
                struct noise_keypair *new_keypair)
{
    struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;

    spin_lock_bh(&keypairs->keypair_update_lock);
    previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
        lockdep_is_held(&keypairs->keypair_update_lock));
    next_keypair = rcu_dereference_protected(keypairs->next_keypair,
        lockdep_is_held(&keypairs->keypair_update_lock));
    current_keypair = rcu_dereference_protected(keypairs->current_keypair,
        lockdep_is_held(&keypairs->keypair_update_lock));
    if (new_keypair->i_am_the_initiator) {
        /* If we're the initiator, it means we've sent a handshake, and
         * received a confirmation response, which means this new
         * keypair can now be used.
         */
        if (next_keypair) {
            /* If there already was a next keypair pending, we
             * demote it to be the previous keypair, and free the
             * existing current. Note that this means KCI can result
             * in this transition. It would perhaps be more sound to
             * always just get rid of the unused next keypair
             * instead of putting it in the previous slot, but this
             * might be a bit less robust. Something to think about
             * for the future.
             */
            RCU_INIT_POINTER(keypairs->next_keypair, NULL);
            rcu_assign_pointer(keypairs->previous_keypair,
                       next_keypair);
            wg_noise_keypair_put(current_keypair, true);
        } else /* If there wasn't an existing next keypair, we replace
            * the previous with the current one.
            */
            rcu_assign_pointer(keypairs->previous_keypair,
                       current_keypair);
        /* At this point we can get rid of the old previous keypair, and
         * set up the new keypair.
         */
        wg_noise_keypair_put(previous_keypair, true);
        rcu_assign_pointer(keypairs->current_keypair, new_keypair);
    } else {
        /* If we're the responder, it means we can't use the new keypair
         * until we receive confirmation via the first data packet, so
         * we get rid of the existing previous one, the possibly
         * existing next one, and slide in the new next one.
         */
        rcu_assign_pointer(keypairs->next_keypair, new_keypair);
        wg_noise_keypair_put(next_keypair, true);
        RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
        wg_noise_keypair_put(previous_keypair, true);
    }
    spin_unlock_bh(&keypairs->keypair_update_lock);
}
0252 
/* Note that @received_keypair was used by the remote side (presumably called
 * from the receive path after a packet authenticated under it -- confirm at
 * call sites). If it is the pending next keypair, promote it:
 * next -> current, current -> previous, and drop the old previous.
 *
 * Returns true if a promotion happened, false otherwise. The common case
 * (keypair is not next) is decided locklessly; rcu_access_pointer() is
 * enough there because the pointer is only compared, never dereferenced.
 */
bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
                    struct noise_keypair *received_keypair)
{
    struct noise_keypair *old_keypair;
    bool key_is_new;

    /* We first check without taking the spinlock. */
    key_is_new = received_keypair ==
             rcu_access_pointer(keypairs->next_keypair);
    if (likely(!key_is_new))
        return false;

    spin_lock_bh(&keypairs->keypair_update_lock);
    /* After locking, we double check that things didn't change from
     * beneath us.
     */
    if (unlikely(received_keypair !=
            rcu_dereference_protected(keypairs->next_keypair,
                lockdep_is_held(&keypairs->keypair_update_lock)))) {
        spin_unlock_bh(&keypairs->keypair_update_lock);
        return false;
    }

    /* When we've finally received the confirmation, we slide the next
     * into the current, the current into the previous, and get rid of
     * the old previous.
     */
    old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
        lockdep_is_held(&keypairs->keypair_update_lock));
    rcu_assign_pointer(keypairs->previous_keypair,
        rcu_dereference_protected(keypairs->current_keypair,
            lockdep_is_held(&keypairs->keypair_update_lock)));
    wg_noise_keypair_put(old_keypair, true);
    rcu_assign_pointer(keypairs->current_keypair, received_keypair);
    RCU_INIT_POINTER(keypairs->next_keypair, NULL);

    spin_unlock_bh(&keypairs->keypair_update_lock);
    return true;
}
0292 
0293 /* Must hold static_identity->lock */
0294 void wg_noise_set_static_identity_private_key(
0295     struct noise_static_identity *static_identity,
0296     const u8 private_key[NOISE_PUBLIC_KEY_LEN])
0297 {
0298     memcpy(static_identity->static_private, private_key,
0299            NOISE_PUBLIC_KEY_LEN);
0300     curve25519_clamp_secret(static_identity->static_private);
0301     static_identity->has_identity = curve25519_generate_public(
0302         static_identity->static_public, private_key);
0303 }
0304 
0305 static void hmac(u8 *out, const u8 *in, const u8 *key, const size_t inlen, const size_t keylen)
0306 {
0307     struct blake2s_state state;
0308     u8 x_key[BLAKE2S_BLOCK_SIZE] __aligned(__alignof__(u32)) = { 0 };
0309     u8 i_hash[BLAKE2S_HASH_SIZE] __aligned(__alignof__(u32));
0310     int i;
0311 
0312     if (keylen > BLAKE2S_BLOCK_SIZE) {
0313         blake2s_init(&state, BLAKE2S_HASH_SIZE);
0314         blake2s_update(&state, key, keylen);
0315         blake2s_final(&state, x_key);
0316     } else
0317         memcpy(x_key, key, keylen);
0318 
0319     for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
0320         x_key[i] ^= 0x36;
0321 
0322     blake2s_init(&state, BLAKE2S_HASH_SIZE);
0323     blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
0324     blake2s_update(&state, in, inlen);
0325     blake2s_final(&state, i_hash);
0326 
0327     for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
0328         x_key[i] ^= 0x5c ^ 0x36;
0329 
0330     blake2s_init(&state, BLAKE2S_HASH_SIZE);
0331     blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
0332     blake2s_update(&state, i_hash, BLAKE2S_HASH_SIZE);
0333     blake2s_final(&state, i_hash);
0334 
0335     memcpy(out, i_hash, BLAKE2S_HASH_SIZE);
0336     memzero_explicit(x_key, BLAKE2S_BLOCK_SIZE);
0337     memzero_explicit(i_hash, BLAKE2S_HASH_SIZE);
0338 }
0339 
/* This is Hugo Krawczyk's HKDF:
 *  - https://eprint.iacr.org/2010/264.pdf
 *  - https://tools.ietf.org/html/rfc5869
 *
 * Extract: secret = HMAC(key = chaining_key, data[0..data_len)).
 * Expand:  T1 = HMAC(secret, 0x1)
 *          T2 = HMAC(secret, T1 || 0x2)
 *          T3 = HMAC(secret, T2 || 0x3)
 * The first first_len / second_len / third_len bytes of T1 / T2 / T3 are
 * copied to first_dst / second_dst / third_dst.
 *
 * Expansion stops at the first NULL destination or zero length, so outputs
 * must be requested in order. @data may be NULL when @data_len is 0 (as in
 * derive_keys()). Any destination may alias @chaining_key (mix_dh() and
 * friends pass the same buffer), because @chaining_key is only read by the
 * extract step, before any output is written. Each length must be at most
 * BLAKE2S_HASH_SIZE; the WARN_ON below checks this only when DEBUG is
 * enabled.
 */
static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
        size_t first_len, size_t second_len, size_t third_len,
        size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
{
    /* Holds T_i plus one trailing counter byte for the next round. */
    u8 output[BLAKE2S_HASH_SIZE + 1];
    u8 secret[BLAKE2S_HASH_SIZE];

    WARN_ON(IS_ENABLED(DEBUG) &&
        (first_len > BLAKE2S_HASH_SIZE ||
         second_len > BLAKE2S_HASH_SIZE ||
         third_len > BLAKE2S_HASH_SIZE ||
         ((second_len || second_dst || third_len || third_dst) &&
          (!first_len || !first_dst)) ||
         ((third_len || third_dst) && (!second_len || !second_dst))));

    /* Extract entropy from data into secret */
    hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);

    if (!first_dst || !first_len)
        goto out;

    /* Expand first key: key = secret, data = 0x1 */
    output[0] = 1;
    hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
    memcpy(first_dst, output, first_len);

    if (!second_dst || !second_len)
        goto out;

    /* Expand second key: key = secret, data = first-key || 0x2
     * (output[0..HASH_SIZE) still holds T1; hmac() allows in/out aliasing).
     */
    output[BLAKE2S_HASH_SIZE] = 2;
    hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
    memcpy(second_dst, output, second_len);

    if (!third_dst || !third_len)
        goto out;

    /* Expand third key: key = secret, data = second-key || 0x3 */
    output[BLAKE2S_HASH_SIZE] = 3;
    hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1, BLAKE2S_HASH_SIZE);
    memcpy(third_dst, output, third_len);

out:
    /* Clear sensitive data from stack */
    memzero_explicit(secret, BLAKE2S_HASH_SIZE);
    memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
}
0391 
0392 static void derive_keys(struct noise_symmetric_key *first_dst,
0393             struct noise_symmetric_key *second_dst,
0394             const u8 chaining_key[NOISE_HASH_LEN])
0395 {
0396     u64 birthdate = ktime_get_coarse_boottime_ns();
0397     kdf(first_dst->key, second_dst->key, NULL, NULL,
0398         NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
0399         chaining_key);
0400     first_dst->birthdate = second_dst->birthdate = birthdate;
0401     first_dst->is_valid = second_dst->is_valid = true;
0402 }
0403 
0404 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
0405                 u8 key[NOISE_SYMMETRIC_KEY_LEN],
0406                 const u8 private[NOISE_PUBLIC_KEY_LEN],
0407                 const u8 public[NOISE_PUBLIC_KEY_LEN])
0408 {
0409     u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
0410 
0411     if (unlikely(!curve25519(dh_calculation, private, public)))
0412         return false;
0413     kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
0414         NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
0415     memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
0416     return true;
0417 }
0418 
0419 static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
0420                         u8 key[NOISE_SYMMETRIC_KEY_LEN],
0421                         const u8 precomputed[NOISE_PUBLIC_KEY_LEN])
0422 {
0423     static u8 zero_point[NOISE_PUBLIC_KEY_LEN];
0424     if (unlikely(!crypto_memneq(precomputed, zero_point, NOISE_PUBLIC_KEY_LEN)))
0425         return false;
0426     kdf(chaining_key, key, NULL, precomputed, NOISE_HASH_LEN,
0427         NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
0428         chaining_key);
0429     return true;
0430 }
0431 
0432 static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
0433 {
0434     struct blake2s_state blake;
0435 
0436     blake2s_init(&blake, NOISE_HASH_LEN);
0437     blake2s_update(&blake, hash, NOISE_HASH_LEN);
0438     blake2s_update(&blake, src, src_len);
0439     blake2s_final(&blake, hash);
0440 }
0441 
0442 static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
0443             u8 key[NOISE_SYMMETRIC_KEY_LEN],
0444             const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
0445 {
0446     u8 temp_hash[NOISE_HASH_LEN];
0447 
0448     kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
0449         NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
0450     mix_hash(hash, temp_hash, NOISE_HASH_LEN);
0451     memzero_explicit(temp_hash, NOISE_HASH_LEN);
0452 }
0453 
0454 static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
0455                u8 hash[NOISE_HASH_LEN],
0456                const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
0457 {
0458     memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
0459     memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
0460     mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
0461 }
0462 
0463 static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
0464                 size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
0465                 u8 hash[NOISE_HASH_LEN])
0466 {
0467     chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
0468                  NOISE_HASH_LEN,
0469                  0 /* Always zero for Noise_IK */, key);
0470     mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
0471 }
0472 
0473 static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
0474                 size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
0475                 u8 hash[NOISE_HASH_LEN])
0476 {
0477     if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
0478                       hash, NOISE_HASH_LEN,
0479                       0 /* Always zero for Noise_IK */, key))
0480         return false;
0481     mix_hash(hash, src_ciphertext, src_len);
0482     return true;
0483 }
0484 
0485 static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
0486                   const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
0487                   u8 chaining_key[NOISE_HASH_LEN],
0488                   u8 hash[NOISE_HASH_LEN])
0489 {
0490     if (ephemeral_dst != ephemeral_src)
0491         memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
0492     mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
0493     kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
0494         NOISE_PUBLIC_KEY_LEN, chaining_key);
0495 }
0496 
0497 static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
0498 {
0499     struct timespec64 now;
0500 
0501     ktime_get_real_ts64(&now);
0502 
0503     /* In order to prevent some sort of infoleak from precise timers, we
0504      * round down the nanoseconds part to the closest rounded-down power of
0505      * two to the maximum initiations per second allowed anyway by the
0506      * implementation.
0507      */
0508     now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
0509         rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
0510 
0511     /* https://cr.yp.to/libtai/tai64.html */
0512     *(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
0513     *(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
0514 }
0515 
/* Build the first handshake message (initiator -> responder) into @dst.
 * Pattern: e, es, s, ss, {t}.
 *
 * Holds static_identity->lock for reading and handshake->lock for writing
 * throughout. The running chaining key and hash are kept in @handshake.
 *
 * On success, the handshake is inserted into the index hashtable, its index
 * is sent as dst->sender_index, the state becomes
 * HANDSHAKE_CREATED_INITIATION, and true is returned.
 *
 * Returns false if we have no static identity or a DH step fails. On the DH
 * failure paths, @handshake's chaining key/hash/ephemeral may already have
 * been overwritten, while its state is left unchanged. The derived @key is
 * wiped on every path.
 */
bool
wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
                     struct noise_handshake *handshake)
{
    u8 timestamp[NOISE_TIMESTAMP_LEN];
    u8 key[NOISE_SYMMETRIC_KEY_LEN];
    bool ret = false;

    /* We need to wait for crng _before_ taking any locks, since
     * curve25519_generate_secret uses get_random_bytes_wait.
     */
    wait_for_random_bytes();

    down_read(&handshake->static_identity->lock);
    down_write(&handshake->lock);

    if (unlikely(!handshake->static_identity->has_identity))
        goto out;

    dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);

    handshake_init(handshake->chaining_key, handshake->hash,
               handshake->remote_static);

    /* e: fresh ephemeral keypair; public half goes out in the clear */
    curve25519_generate_secret(handshake->ephemeral_private);
    if (!curve25519_generate_public(dst->unencrypted_ephemeral,
                    handshake->ephemeral_private))
        goto out;
    message_ephemeral(dst->unencrypted_ephemeral,
              dst->unencrypted_ephemeral, handshake->chaining_key,
              handshake->hash);

    /* es: DH(our ephemeral private, responder static public) */
    if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
            handshake->remote_static))
        goto out;

    /* s: our static public key, encrypted */
    message_encrypt(dst->encrypted_static,
            handshake->static_identity->static_public,
            NOISE_PUBLIC_KEY_LEN, key, handshake->hash);

    /* ss: static-static DH, cached by wg_noise_precompute_static_static() */
    if (!mix_precomputed_dh(handshake->chaining_key, key,
                handshake->precomputed_static_static))
        goto out;

    /* {t}: encrypted TAI64N timestamp, checked by the responder for replay */
    tai64n_now(timestamp);
    message_encrypt(dst->encrypted_timestamp, timestamp,
            NOISE_TIMESTAMP_LEN, key, handshake->hash);

    /* Publish the handshake so the response (which echoes this index as
     * its receiver_index) can be matched back to it.
     */
    dst->sender_index = wg_index_hashtable_insert(
        handshake->entry.peer->device->index_hashtable,
        &handshake->entry);

    handshake->state = HANDSHAKE_CREATED_INITIATION;
    ret = true;

out:
    up_write(&handshake->lock);
    up_read(&handshake->static_identity->lock);
    memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
    return ret;
}
0582 
/* Process an incoming handshake initiation on the responder side.
 *
 * The steps are:
 *  1. Replay the initiator's transcript on local buffers, using our static
 *     key.
 *  2. Decrypt the initiator's static public key and use it to look up the
 *     peer.
 *  3. Decrypt and verify the timestamp.
 *  4. Reject a replay (timestamp not strictly newer than
 *     handshake->latest_timestamp) or a flood (less than
 *     NSEC_PER_SEC / INITIATIONS_PER_SECOND since the last accepted
 *     initiation).
 *
 * On success, the transcript state and the initiator's ephemeral and index
 * are stored in peer->handshake, the state becomes
 * HANDSHAKE_CONSUMED_INITIATION, and the peer is returned. It carries the
 * reference presumably taken by wg_pubkey_hashtable_lookup().
 *
 * On failure, NULL is returned and that reference (if any) is dropped.
 * Secret stack buffers are wiped on every path.
 */
struct wg_peer *
wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
                      struct wg_device *wg)
{
    struct wg_peer *peer = NULL, *ret_peer = NULL;
    struct noise_handshake *handshake;
    bool replay_attack, flood_attack;
    u8 key[NOISE_SYMMETRIC_KEY_LEN];
    u8 chaining_key[NOISE_HASH_LEN];
    u8 hash[NOISE_HASH_LEN];
    u8 s[NOISE_PUBLIC_KEY_LEN];
    u8 e[NOISE_PUBLIC_KEY_LEN];
    u8 t[NOISE_TIMESTAMP_LEN];
    u64 initiation_consumption;

    down_read(&wg->static_identity.lock);
    if (unlikely(!wg->static_identity.has_identity))
        goto out;

    handshake_init(chaining_key, hash, wg->static_identity.static_public);

    /* e */
    message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);

    /* es: DH(our static private, initiator ephemeral public) */
    if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
        goto out;

    /* s */
    if (!message_decrypt(s, src->encrypted_static,
                 sizeof(src->encrypted_static), key, hash))
        goto out;

    /* Lookup which peer we're actually talking to */
    peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
    if (!peer)
        goto out;
    handshake = &peer->handshake;

    /* ss */
    if (!mix_precomputed_dh(chaining_key, key,
                handshake->precomputed_static_static))
        goto out;

    /* {t} */
    if (!message_decrypt(t, src->encrypted_timestamp,
                 sizeof(src->encrypted_timestamp), key, hash))
        goto out;

    down_read(&handshake->lock);
    replay_attack = memcmp(t, handshake->latest_timestamp,
                   NOISE_TIMESTAMP_LEN) <= 0;
    flood_attack = (s64)handshake->last_initiation_consumption +
                   NSEC_PER_SEC / INITIATIONS_PER_SECOND >
               (s64)ktime_get_coarse_boottime_ns();
    up_read(&handshake->lock);
    if (replay_attack || flood_attack)
        goto out;

    /* Success! Copy everything to peer */
    down_write(&handshake->lock);
    memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
    /* The check above and this update use separate lock acquisitions, so
     * another initiation may have been consumed in between. The conditional
     * updates below keep latest_timestamp and last_initiation_consumption
     * from ever moving backwards.
     */
    if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
        memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
    memcpy(handshake->hash, hash, NOISE_HASH_LEN);
    memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
    handshake->remote_index = src->sender_index;
    initiation_consumption = ktime_get_coarse_boottime_ns();
    if ((s64)(handshake->last_initiation_consumption - initiation_consumption) < 0)
        handshake->last_initiation_consumption = initiation_consumption;
    handshake->state = HANDSHAKE_CONSUMED_INITIATION;
    up_write(&handshake->lock);
    ret_peer = peer;

out:
    memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
    memzero_explicit(hash, NOISE_HASH_LEN);
    memzero_explicit(chaining_key, NOISE_HASH_LEN);
    up_read(&wg->static_identity.lock);
    if (!ret_peer)
        wg_peer_put(peer);
    return ret_peer;
}
0666 
/* Build the second handshake message (responder -> initiator) into @dst.
 * Pattern: e, ee, se, psk, {}. It continues from the transcript stored by
 * wg_noise_handshake_consume_initiation().
 *
 * Holds static_identity->lock for reading and handshake->lock for writing.
 * Requires state HANDSHAKE_CONSUMED_INITIATION; otherwise returns false.
 *
 * On success, the handshake is inserted into the index hashtable (its index
 * is sent as dst->sender_index), the state becomes
 * HANDSHAKE_CREATED_RESPONSE, and true is returned. On a DH failure,
 * @handshake's chaining key/hash/ephemeral may already be modified while the
 * state is left unchanged. @key is wiped on every path.
 */
bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
                    struct noise_handshake *handshake)
{
    u8 key[NOISE_SYMMETRIC_KEY_LEN];
    bool ret = false;

    /* We need to wait for crng _before_ taking any locks, since
     * curve25519_generate_secret uses get_random_bytes_wait.
     */
    wait_for_random_bytes();

    down_read(&handshake->static_identity->lock);
    down_write(&handshake->lock);

    if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
        goto out;

    dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
    dst->receiver_index = handshake->remote_index;

    /* e */
    curve25519_generate_secret(handshake->ephemeral_private);
    if (!curve25519_generate_public(dst->unencrypted_ephemeral,
                    handshake->ephemeral_private))
        goto out;
    message_ephemeral(dst->unencrypted_ephemeral,
              dst->unencrypted_ephemeral, handshake->chaining_key,
              handshake->hash);

    /* ee: DH(our ephemeral private, initiator ephemeral public) */
    if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
            handshake->remote_ephemeral))
        goto out;

    /* se: DH(our ephemeral private, initiator static public) */
    if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
            handshake->remote_static))
        goto out;

    /* psk */
    mix_psk(handshake->chaining_key, handshake->hash, key,
        handshake->preshared_key);

    /* {}: empty payload, so only the authentication tag is sent */
    message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);

    dst->sender_index = wg_index_hashtable_insert(
        handshake->entry.peer->device->index_hashtable,
        &handshake->entry);

    handshake->state = HANDSHAKE_CREATED_RESPONSE;
    ret = true;

out:
    up_write(&handshake->lock);
    up_read(&handshake->static_identity->lock);
    memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
    return ret;
}
0726 
0727 struct wg_peer *
0728 wg_noise_handshake_consume_response(struct message_handshake_response *src,
0729                     struct wg_device *wg)
0730 {
0731     enum noise_handshake_state state = HANDSHAKE_ZEROED;
0732     struct wg_peer *peer = NULL, *ret_peer = NULL;
0733     struct noise_handshake *handshake;
0734     u8 key[NOISE_SYMMETRIC_KEY_LEN];
0735     u8 hash[NOISE_HASH_LEN];
0736     u8 chaining_key[NOISE_HASH_LEN];
0737     u8 e[NOISE_PUBLIC_KEY_LEN];
0738     u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
0739     u8 static_private[NOISE_PUBLIC_KEY_LEN];
0740     u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN];
0741 
0742     down_read(&wg->static_identity.lock);
0743 
0744     if (unlikely(!wg->static_identity.has_identity))
0745         goto out;
0746 
0747     handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
0748         wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
0749         src->receiver_index, &peer);
0750     if (unlikely(!handshake))
0751         goto out;
0752 
0753     down_read(&handshake->lock);
0754     state = handshake->state;
0755     memcpy(hash, handshake->hash, NOISE_HASH_LEN);
0756     memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
0757     memcpy(ephemeral_private, handshake->ephemeral_private,
0758            NOISE_PUBLIC_KEY_LEN);
0759     memcpy(preshared_key, handshake->preshared_key,
0760            NOISE_SYMMETRIC_KEY_LEN);
0761     up_read(&handshake->lock);
0762 
0763     if (state != HANDSHAKE_CREATED_INITIATION)
0764         goto fail;
0765 
0766     /* e */
0767     message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
0768 
0769     /* ee */
0770     if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
0771         goto fail;
0772 
0773     /* se */
0774     if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
0775         goto fail;
0776 
0777     /* psk */
0778     mix_psk(chaining_key, hash, key, preshared_key);
0779 
0780     /* {} */
0781     if (!message_decrypt(NULL, src->encrypted_nothing,
0782                  sizeof(src->encrypted_nothing), key, hash))
0783         goto fail;
0784 
0785     /* Success! Copy everything to peer */
0786     down_write(&handshake->lock);
0787     /* It's important to check that the state is still the same, while we
0788      * have an exclusive lock.
0789      */
0790     if (handshake->state != state) {
0791         up_write(&handshake->lock);
0792         goto fail;
0793     }
0794     memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
0795     memcpy(handshake->hash, hash, NOISE_HASH_LEN);
0796     memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
0797     handshake->remote_index = src->sender_index;
0798     handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
0799     up_write(&handshake->lock);
0800     ret_peer = peer;
0801     goto out;
0802 
0803 fail:
0804     wg_peer_put(peer);
0805 out:
0806     memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
0807     memzero_explicit(hash, NOISE_HASH_LEN);
0808     memzero_explicit(chaining_key, NOISE_HASH_LEN);
0809     memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
0810     memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
0811     memzero_explicit(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
0812     up_read(&wg->static_identity.lock);
0813     return ret_peer;
0814 }
0815 
/* Turn a completed handshake into a new transport keypair.
 *
 * Applies to HANDSHAKE_CREATED_RESPONSE (we are the responder) or
 * HANDSHAKE_CONSUMED_RESPONSE (we are the initiator); for any other state it
 * returns false. Holds handshake->lock for writing.
 *
 * Both sides call derive_keys() with the same chaining key but with swapped
 * destinations. As a result, the initiator's sending key equals the
 * responder's receiving key, and vice versa.
 *
 * The handshake's ephemeral state is then wiped. If the owning peer is not
 * marked dead:
 *  - the keypair is installed via add_new_keypair();
 *  - its index entry replaces the handshake's in the index hashtable;
 *  - the return value is wg_index_hashtable_replace()'s result.
 * Otherwise the keypair, never published, is freed directly and false is
 * returned.
 */
bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
                      struct noise_keypairs *keypairs)
{
    struct noise_keypair *new_keypair;
    bool ret = false;

    down_write(&handshake->lock);
    if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
        handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
        goto out;

    new_keypair = keypair_create(handshake->entry.peer);
    if (!new_keypair)
        goto out;
    new_keypair->i_am_the_initiator = handshake->state ==
                      HANDSHAKE_CONSUMED_RESPONSE;
    new_keypair->remote_index = handshake->remote_index;

    if (new_keypair->i_am_the_initiator)
        derive_keys(&new_keypair->sending, &new_keypair->receiving,
                handshake->chaining_key);
    else
        derive_keys(&new_keypair->receiving, &new_keypair->sending,
                handshake->chaining_key);

    handshake_zero(handshake);
    rcu_read_lock_bh();
    if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
                       handshake)->is_dead))) {
        add_new_keypair(keypairs, new_keypair);
        net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
                    handshake->entry.peer->device->dev->name,
                    new_keypair->internal_id,
                    handshake->entry.peer->internal_id);
        ret = wg_index_hashtable_replace(
            handshake->entry.peer->device->index_hashtable,
            &handshake->entry, &new_keypair->entry);
    } else {
        /* Never added to the keypair slots or the hashtable, so no other
         * reference can exist; free it directly and wipe the keys.
         */
        kfree_sensitive(new_keypair);
    }
    rcu_read_unlock_bh();

out:
    up_write(&handshake->lock);
    return ret;
}