Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-or-later
0002 /* AFS vlserver list management.
0003  *
0004  * Copyright (C) 2018 Red Hat, Inc. All Rights Reserved.
0005  * Written by David Howells (dhowells@redhat.com)
0006  */
0007 
0008 #include <linux/kernel.h>
0009 #include <linux/slab.h>
0010 #include "internal.h"
0011 
0012 struct afs_vlserver *afs_alloc_vlserver(const char *name, size_t name_len,
0013                     unsigned short port)
0014 {
0015     struct afs_vlserver *vlserver;
0016 
0017     vlserver = kzalloc(struct_size(vlserver, name, name_len + 1),
0018                GFP_KERNEL);
0019     if (vlserver) {
0020         refcount_set(&vlserver->ref, 1);
0021         rwlock_init(&vlserver->lock);
0022         init_waitqueue_head(&vlserver->probe_wq);
0023         spin_lock_init(&vlserver->probe_lock);
0024         vlserver->rtt = UINT_MAX;
0025         vlserver->name_len = name_len;
0026         vlserver->port = port;
0027         memcpy(vlserver->name, name, name_len);
0028     }
0029     return vlserver;
0030 }
0031 
0032 static void afs_vlserver_rcu(struct rcu_head *rcu)
0033 {
0034     struct afs_vlserver *vlserver = container_of(rcu, struct afs_vlserver, rcu);
0035 
0036     afs_put_addrlist(rcu_access_pointer(vlserver->addresses));
0037     kfree_rcu(vlserver, rcu);
0038 }
0039 
0040 void afs_put_vlserver(struct afs_net *net, struct afs_vlserver *vlserver)
0041 {
0042     if (vlserver &&
0043         refcount_dec_and_test(&vlserver->ref))
0044         call_rcu(&vlserver->rcu, afs_vlserver_rcu);
0045 }
0046 
0047 struct afs_vlserver_list *afs_alloc_vlserver_list(unsigned int nr_servers)
0048 {
0049     struct afs_vlserver_list *vllist;
0050 
0051     vllist = kzalloc(struct_size(vllist, servers, nr_servers), GFP_KERNEL);
0052     if (vllist) {
0053         refcount_set(&vllist->ref, 1);
0054         rwlock_init(&vllist->lock);
0055     }
0056 
0057     return vllist;
0058 }
0059 
0060 void afs_put_vlserverlist(struct afs_net *net, struct afs_vlserver_list *vllist)
0061 {
0062     if (vllist) {
0063         if (refcount_dec_and_test(&vllist->ref)) {
0064             int i;
0065 
0066             for (i = 0; i < vllist->nr_servers; i++) {
0067                 afs_put_vlserver(net, vllist->servers[i].server);
0068             }
0069             kfree_rcu(vllist, rcu);
0070         }
0071     }
0072 }
0073 
0074 static u16 afs_extract_le16(const u8 **_b)
0075 {
0076     u16 val;
0077 
0078     val  = (u16)*(*_b)++ << 0;
0079     val |= (u16)*(*_b)++ << 8;
0080     return val;
0081 }
0082 
0083 /*
0084  * Build a VL server address list from a DNS queried server list.
0085  */
0086 static struct afs_addr_list *afs_extract_vl_addrs(const u8 **_b, const u8 *end,
0087                           u8 nr_addrs, u16 port)
0088 {
0089     struct afs_addr_list *alist;
0090     const u8 *b = *_b;
0091     int ret = -EINVAL;
0092 
0093     alist = afs_alloc_addrlist(nr_addrs, VL_SERVICE, port);
0094     if (!alist)
0095         return ERR_PTR(-ENOMEM);
0096     if (nr_addrs == 0)
0097         return alist;
0098 
0099     for (; nr_addrs > 0 && end - b >= nr_addrs; nr_addrs--) {
0100         struct dns_server_list_v1_address hdr;
0101         __be32 x[4];
0102 
0103         hdr.address_type = *b++;
0104 
0105         switch (hdr.address_type) {
0106         case DNS_ADDRESS_IS_IPV4:
0107             if (end - b < 4) {
0108                 _leave(" = -EINVAL [short inet]");
0109                 goto error;
0110             }
0111             memcpy(x, b, 4);
0112             afs_merge_fs_addr4(alist, x[0], port);
0113             b += 4;
0114             break;
0115 
0116         case DNS_ADDRESS_IS_IPV6:
0117             if (end - b < 16) {
0118                 _leave(" = -EINVAL [short inet6]");
0119                 goto error;
0120             }
0121             memcpy(x, b, 16);
0122             afs_merge_fs_addr6(alist, x, port);
0123             b += 16;
0124             break;
0125 
0126         default:
0127             _leave(" = -EADDRNOTAVAIL [unknown af %u]",
0128                    hdr.address_type);
0129             ret = -EADDRNOTAVAIL;
0130             goto error;
0131         }
0132     }
0133 
0134     /* Start with IPv6 if available. */
0135     if (alist->nr_ipv4 < alist->nr_addrs)
0136         alist->preferred = alist->nr_ipv4;
0137 
0138     *_b = b;
0139     return alist;
0140 
0141 error:
0142     *_b = b;
0143     afs_put_addrlist(alist);
0144     return ERR_PTR(ret);
0145 }
0146 
0147 /*
0148  * Build a VL server list from a DNS queried server list.
0149  */
0150 struct afs_vlserver_list *afs_extract_vlserver_list(struct afs_cell *cell,
0151                             const void *buffer,
0152                             size_t buffer_size)
0153 {
0154     const struct dns_server_list_v1_header *hdr = buffer;
0155     struct dns_server_list_v1_server bs;
0156     struct afs_vlserver_list *vllist, *previous;
0157     struct afs_addr_list *addrs;
0158     struct afs_vlserver *server;
0159     const u8 *b = buffer, *end = buffer + buffer_size;
0160     int ret = -ENOMEM, nr_servers, i, j;
0161 
0162     _enter("");
0163 
0164     /* Check that it's a server list, v1 */
0165     if (end - b < sizeof(*hdr) ||
0166         hdr->hdr.content != DNS_PAYLOAD_IS_SERVER_LIST ||
0167         hdr->hdr.version != 1) {
0168         pr_notice("kAFS: Got DNS record [%u,%u] len %zu\n",
0169               hdr->hdr.content, hdr->hdr.version, end - b);
0170         ret = -EDESTADDRREQ;
0171         goto dump;
0172     }
0173 
0174     nr_servers = hdr->nr_servers;
0175 
0176     vllist = afs_alloc_vlserver_list(nr_servers);
0177     if (!vllist)
0178         return ERR_PTR(-ENOMEM);
0179 
0180     vllist->source = (hdr->source < NR__dns_record_source) ?
0181         hdr->source : NR__dns_record_source;
0182     vllist->status = (hdr->status < NR__dns_lookup_status) ?
0183         hdr->status : NR__dns_lookup_status;
0184 
0185     read_lock(&cell->vl_servers_lock);
0186     previous = afs_get_vlserverlist(
0187         rcu_dereference_protected(cell->vl_servers,
0188                       lockdep_is_held(&cell->vl_servers_lock)));
0189     read_unlock(&cell->vl_servers_lock);
0190 
0191     b += sizeof(*hdr);
0192     while (end - b >= sizeof(bs)) {
0193         bs.name_len = afs_extract_le16(&b);
0194         bs.priority = afs_extract_le16(&b);
0195         bs.weight   = afs_extract_le16(&b);
0196         bs.port     = afs_extract_le16(&b);
0197         bs.source   = *b++;
0198         bs.status   = *b++;
0199         bs.protocol = *b++;
0200         bs.nr_addrs = *b++;
0201 
0202         _debug("extract %u %u %u %u %u %u %*.*s",
0203                bs.name_len, bs.priority, bs.weight,
0204                bs.port, bs.protocol, bs.nr_addrs,
0205                bs.name_len, bs.name_len, b);
0206 
0207         if (end - b < bs.name_len)
0208             break;
0209 
0210         ret = -EPROTONOSUPPORT;
0211         if (bs.protocol == DNS_SERVER_PROTOCOL_UNSPECIFIED) {
0212             bs.protocol = DNS_SERVER_PROTOCOL_UDP;
0213         } else if (bs.protocol != DNS_SERVER_PROTOCOL_UDP) {
0214             _leave(" = [proto %u]", bs.protocol);
0215             goto error;
0216         }
0217 
0218         if (bs.port == 0)
0219             bs.port = AFS_VL_PORT;
0220         if (bs.source > NR__dns_record_source)
0221             bs.source = NR__dns_record_source;
0222         if (bs.status > NR__dns_lookup_status)
0223             bs.status = NR__dns_lookup_status;
0224 
0225         /* See if we can update an old server record */
0226         server = NULL;
0227         for (i = 0; i < previous->nr_servers; i++) {
0228             struct afs_vlserver *p = previous->servers[i].server;
0229 
0230             if (p->name_len == bs.name_len &&
0231                 p->port == bs.port &&
0232                 strncasecmp(b, p->name, bs.name_len) == 0) {
0233                 server = afs_get_vlserver(p);
0234                 break;
0235             }
0236         }
0237 
0238         if (!server) {
0239             ret = -ENOMEM;
0240             server = afs_alloc_vlserver(b, bs.name_len, bs.port);
0241             if (!server)
0242                 goto error;
0243         }
0244 
0245         b += bs.name_len;
0246 
0247         /* Extract the addresses - note that we can't skip this as we
0248          * have to advance the payload pointer.
0249          */
0250         addrs = afs_extract_vl_addrs(&b, end, bs.nr_addrs, bs.port);
0251         if (IS_ERR(addrs)) {
0252             ret = PTR_ERR(addrs);
0253             goto error_2;
0254         }
0255 
0256         if (vllist->nr_servers >= nr_servers) {
0257             _debug("skip %u >= %u", vllist->nr_servers, nr_servers);
0258             afs_put_addrlist(addrs);
0259             afs_put_vlserver(cell->net, server);
0260             continue;
0261         }
0262 
0263         addrs->source = bs.source;
0264         addrs->status = bs.status;
0265 
0266         if (addrs->nr_addrs == 0) {
0267             afs_put_addrlist(addrs);
0268             if (!rcu_access_pointer(server->addresses)) {
0269                 afs_put_vlserver(cell->net, server);
0270                 continue;
0271             }
0272         } else {
0273             struct afs_addr_list *old = addrs;
0274 
0275             write_lock(&server->lock);
0276             old = rcu_replace_pointer(server->addresses, old,
0277                           lockdep_is_held(&server->lock));
0278             write_unlock(&server->lock);
0279             afs_put_addrlist(old);
0280         }
0281 
0282 
0283         /* TODO: Might want to check for duplicates */
0284 
0285         /* Insertion-sort by priority and weight */
0286         for (j = 0; j < vllist->nr_servers; j++) {
0287             if (bs.priority < vllist->servers[j].priority)
0288                 break; /* Lower preferable */
0289             if (bs.priority == vllist->servers[j].priority &&
0290                 bs.weight > vllist->servers[j].weight)
0291                 break; /* Higher preferable */
0292         }
0293 
0294         if (j < vllist->nr_servers) {
0295             memmove(vllist->servers + j + 1,
0296                 vllist->servers + j,
0297                 (vllist->nr_servers - j) * sizeof(struct afs_vlserver_entry));
0298         }
0299 
0300         clear_bit(AFS_VLSERVER_FL_PROBED, &server->flags);
0301 
0302         vllist->servers[j].priority = bs.priority;
0303         vllist->servers[j].weight = bs.weight;
0304         vllist->servers[j].server = server;
0305         vllist->nr_servers++;
0306     }
0307 
0308     if (b != end) {
0309         _debug("parse error %zd", b - end);
0310         goto error;
0311     }
0312 
0313     afs_put_vlserverlist(cell->net, previous);
0314     _leave(" = ok [%u]", vllist->nr_servers);
0315     return vllist;
0316 
0317 error_2:
0318     afs_put_vlserver(cell->net, server);
0319 error:
0320     afs_put_vlserverlist(cell->net, vllist);
0321     afs_put_vlserverlist(cell->net, previous);
0322 dump:
0323     if (ret != -ENOMEM) {
0324         printk(KERN_DEBUG "DNS: at %zu\n", (const void *)b - buffer);
0325         print_hex_dump_bytes("DNS: ", DUMP_PREFIX_NONE, buffer, buffer_size);
0326     }
0327     return ERR_PTR(ret);
0328 }