0001
0002
0003
0004
0005
0006
0007
0008 #include <linux/kernel.h>
0009 #include <linux/slab.h>
0010 #include "internal.h"
0011
0012 struct afs_vlserver *afs_alloc_vlserver(const char *name, size_t name_len,
0013 unsigned short port)
0014 {
0015 struct afs_vlserver *vlserver;
0016
0017 vlserver = kzalloc(struct_size(vlserver, name, name_len + 1),
0018 GFP_KERNEL);
0019 if (vlserver) {
0020 refcount_set(&vlserver->ref, 1);
0021 rwlock_init(&vlserver->lock);
0022 init_waitqueue_head(&vlserver->probe_wq);
0023 spin_lock_init(&vlserver->probe_lock);
0024 vlserver->rtt = UINT_MAX;
0025 vlserver->name_len = name_len;
0026 vlserver->port = port;
0027 memcpy(vlserver->name, name, name_len);
0028 }
0029 return vlserver;
0030 }
0031
0032 static void afs_vlserver_rcu(struct rcu_head *rcu)
0033 {
0034 struct afs_vlserver *vlserver = container_of(rcu, struct afs_vlserver, rcu);
0035
0036 afs_put_addrlist(rcu_access_pointer(vlserver->addresses));
0037 kfree_rcu(vlserver, rcu);
0038 }
0039
0040 void afs_put_vlserver(struct afs_net *net, struct afs_vlserver *vlserver)
0041 {
0042 if (vlserver &&
0043 refcount_dec_and_test(&vlserver->ref))
0044 call_rcu(&vlserver->rcu, afs_vlserver_rcu);
0045 }
0046
0047 struct afs_vlserver_list *afs_alloc_vlserver_list(unsigned int nr_servers)
0048 {
0049 struct afs_vlserver_list *vllist;
0050
0051 vllist = kzalloc(struct_size(vllist, servers, nr_servers), GFP_KERNEL);
0052 if (vllist) {
0053 refcount_set(&vllist->ref, 1);
0054 rwlock_init(&vllist->lock);
0055 }
0056
0057 return vllist;
0058 }
0059
0060 void afs_put_vlserverlist(struct afs_net *net, struct afs_vlserver_list *vllist)
0061 {
0062 if (vllist) {
0063 if (refcount_dec_and_test(&vllist->ref)) {
0064 int i;
0065
0066 for (i = 0; i < vllist->nr_servers; i++) {
0067 afs_put_vlserver(net, vllist->servers[i].server);
0068 }
0069 kfree_rcu(vllist, rcu);
0070 }
0071 }
0072 }
0073
0074 static u16 afs_extract_le16(const u8 **_b)
0075 {
0076 u16 val;
0077
0078 val = (u16)*(*_b)++ << 0;
0079 val |= (u16)*(*_b)++ << 8;
0080 return val;
0081 }
0082
0083
0084
0085
0086 static struct afs_addr_list *afs_extract_vl_addrs(const u8 **_b, const u8 *end,
0087 u8 nr_addrs, u16 port)
0088 {
0089 struct afs_addr_list *alist;
0090 const u8 *b = *_b;
0091 int ret = -EINVAL;
0092
0093 alist = afs_alloc_addrlist(nr_addrs, VL_SERVICE, port);
0094 if (!alist)
0095 return ERR_PTR(-ENOMEM);
0096 if (nr_addrs == 0)
0097 return alist;
0098
0099 for (; nr_addrs > 0 && end - b >= nr_addrs; nr_addrs--) {
0100 struct dns_server_list_v1_address hdr;
0101 __be32 x[4];
0102
0103 hdr.address_type = *b++;
0104
0105 switch (hdr.address_type) {
0106 case DNS_ADDRESS_IS_IPV4:
0107 if (end - b < 4) {
0108 _leave(" = -EINVAL [short inet]");
0109 goto error;
0110 }
0111 memcpy(x, b, 4);
0112 afs_merge_fs_addr4(alist, x[0], port);
0113 b += 4;
0114 break;
0115
0116 case DNS_ADDRESS_IS_IPV6:
0117 if (end - b < 16) {
0118 _leave(" = -EINVAL [short inet6]");
0119 goto error;
0120 }
0121 memcpy(x, b, 16);
0122 afs_merge_fs_addr6(alist, x, port);
0123 b += 16;
0124 break;
0125
0126 default:
0127 _leave(" = -EADDRNOTAVAIL [unknown af %u]",
0128 hdr.address_type);
0129 ret = -EADDRNOTAVAIL;
0130 goto error;
0131 }
0132 }
0133
0134
0135 if (alist->nr_ipv4 < alist->nr_addrs)
0136 alist->preferred = alist->nr_ipv4;
0137
0138 *_b = b;
0139 return alist;
0140
0141 error:
0142 *_b = b;
0143 afs_put_addrlist(alist);
0144 return ERR_PTR(ret);
0145 }
0146
0147
0148
0149
0150 struct afs_vlserver_list *afs_extract_vlserver_list(struct afs_cell *cell,
0151 const void *buffer,
0152 size_t buffer_size)
0153 {
0154 const struct dns_server_list_v1_header *hdr = buffer;
0155 struct dns_server_list_v1_server bs;
0156 struct afs_vlserver_list *vllist, *previous;
0157 struct afs_addr_list *addrs;
0158 struct afs_vlserver *server;
0159 const u8 *b = buffer, *end = buffer + buffer_size;
0160 int ret = -ENOMEM, nr_servers, i, j;
0161
0162 _enter("");
0163
0164
0165 if (end - b < sizeof(*hdr) ||
0166 hdr->hdr.content != DNS_PAYLOAD_IS_SERVER_LIST ||
0167 hdr->hdr.version != 1) {
0168 pr_notice("kAFS: Got DNS record [%u,%u] len %zu\n",
0169 hdr->hdr.content, hdr->hdr.version, end - b);
0170 ret = -EDESTADDRREQ;
0171 goto dump;
0172 }
0173
0174 nr_servers = hdr->nr_servers;
0175
0176 vllist = afs_alloc_vlserver_list(nr_servers);
0177 if (!vllist)
0178 return ERR_PTR(-ENOMEM);
0179
0180 vllist->source = (hdr->source < NR__dns_record_source) ?
0181 hdr->source : NR__dns_record_source;
0182 vllist->status = (hdr->status < NR__dns_lookup_status) ?
0183 hdr->status : NR__dns_lookup_status;
0184
0185 read_lock(&cell->vl_servers_lock);
0186 previous = afs_get_vlserverlist(
0187 rcu_dereference_protected(cell->vl_servers,
0188 lockdep_is_held(&cell->vl_servers_lock)));
0189 read_unlock(&cell->vl_servers_lock);
0190
0191 b += sizeof(*hdr);
0192 while (end - b >= sizeof(bs)) {
0193 bs.name_len = afs_extract_le16(&b);
0194 bs.priority = afs_extract_le16(&b);
0195 bs.weight = afs_extract_le16(&b);
0196 bs.port = afs_extract_le16(&b);
0197 bs.source = *b++;
0198 bs.status = *b++;
0199 bs.protocol = *b++;
0200 bs.nr_addrs = *b++;
0201
0202 _debug("extract %u %u %u %u %u %u %*.*s",
0203 bs.name_len, bs.priority, bs.weight,
0204 bs.port, bs.protocol, bs.nr_addrs,
0205 bs.name_len, bs.name_len, b);
0206
0207 if (end - b < bs.name_len)
0208 break;
0209
0210 ret = -EPROTONOSUPPORT;
0211 if (bs.protocol == DNS_SERVER_PROTOCOL_UNSPECIFIED) {
0212 bs.protocol = DNS_SERVER_PROTOCOL_UDP;
0213 } else if (bs.protocol != DNS_SERVER_PROTOCOL_UDP) {
0214 _leave(" = [proto %u]", bs.protocol);
0215 goto error;
0216 }
0217
0218 if (bs.port == 0)
0219 bs.port = AFS_VL_PORT;
0220 if (bs.source > NR__dns_record_source)
0221 bs.source = NR__dns_record_source;
0222 if (bs.status > NR__dns_lookup_status)
0223 bs.status = NR__dns_lookup_status;
0224
0225
0226 server = NULL;
0227 for (i = 0; i < previous->nr_servers; i++) {
0228 struct afs_vlserver *p = previous->servers[i].server;
0229
0230 if (p->name_len == bs.name_len &&
0231 p->port == bs.port &&
0232 strncasecmp(b, p->name, bs.name_len) == 0) {
0233 server = afs_get_vlserver(p);
0234 break;
0235 }
0236 }
0237
0238 if (!server) {
0239 ret = -ENOMEM;
0240 server = afs_alloc_vlserver(b, bs.name_len, bs.port);
0241 if (!server)
0242 goto error;
0243 }
0244
0245 b += bs.name_len;
0246
0247
0248
0249
0250 addrs = afs_extract_vl_addrs(&b, end, bs.nr_addrs, bs.port);
0251 if (IS_ERR(addrs)) {
0252 ret = PTR_ERR(addrs);
0253 goto error_2;
0254 }
0255
0256 if (vllist->nr_servers >= nr_servers) {
0257 _debug("skip %u >= %u", vllist->nr_servers, nr_servers);
0258 afs_put_addrlist(addrs);
0259 afs_put_vlserver(cell->net, server);
0260 continue;
0261 }
0262
0263 addrs->source = bs.source;
0264 addrs->status = bs.status;
0265
0266 if (addrs->nr_addrs == 0) {
0267 afs_put_addrlist(addrs);
0268 if (!rcu_access_pointer(server->addresses)) {
0269 afs_put_vlserver(cell->net, server);
0270 continue;
0271 }
0272 } else {
0273 struct afs_addr_list *old = addrs;
0274
0275 write_lock(&server->lock);
0276 old = rcu_replace_pointer(server->addresses, old,
0277 lockdep_is_held(&server->lock));
0278 write_unlock(&server->lock);
0279 afs_put_addrlist(old);
0280 }
0281
0282
0283
0284
0285
0286 for (j = 0; j < vllist->nr_servers; j++) {
0287 if (bs.priority < vllist->servers[j].priority)
0288 break;
0289 if (bs.priority == vllist->servers[j].priority &&
0290 bs.weight > vllist->servers[j].weight)
0291 break;
0292 }
0293
0294 if (j < vllist->nr_servers) {
0295 memmove(vllist->servers + j + 1,
0296 vllist->servers + j,
0297 (vllist->nr_servers - j) * sizeof(struct afs_vlserver_entry));
0298 }
0299
0300 clear_bit(AFS_VLSERVER_FL_PROBED, &server->flags);
0301
0302 vllist->servers[j].priority = bs.priority;
0303 vllist->servers[j].weight = bs.weight;
0304 vllist->servers[j].server = server;
0305 vllist->nr_servers++;
0306 }
0307
0308 if (b != end) {
0309 _debug("parse error %zd", b - end);
0310 goto error;
0311 }
0312
0313 afs_put_vlserverlist(cell->net, previous);
0314 _leave(" = ok [%u]", vllist->nr_servers);
0315 return vllist;
0316
0317 error_2:
0318 afs_put_vlserver(cell->net, server);
0319 error:
0320 afs_put_vlserverlist(cell->net, vllist);
0321 afs_put_vlserverlist(cell->net, previous);
0322 dump:
0323 if (ret != -ENOMEM) {
0324 printk(KERN_DEBUG "DNS: at %zu\n", (const void *)b - buffer);
0325 print_hex_dump_bytes("DNS: ", DUMP_PREFIX_NONE, buffer, buffer_size);
0326 }
0327 return ERR_PTR(ret);
0328 }