#include <linux/kernel.h>
#include <linux/slab.h>
#include <linux/export.h>
#include <linux/skbuff.h>
#include <linux/list.h>
#include <linux/errqueue.h>

#include "rds.h"

static unsigned int rds_exthdr_size[__RDS_EXTHDR_MAX] = {
        [RDS_EXTHDR_NONE]      = 0,
        [RDS_EXTHDR_VERSION]   = sizeof(struct rds_ext_header_version),
        [RDS_EXTHDR_RDMA]      = sizeof(struct rds_ext_header_rdma),
        [RDS_EXTHDR_RDMA_DEST] = sizeof(struct rds_ext_header_rdma_dest),
        [RDS_EXTHDR_NPATHS]    = sizeof(u16),
        [RDS_EXTHDR_GEN_NUM]   = sizeof(u32),
};

void rds_message_addref(struct rds_message *rm)
{
        rdsdebug("addref rm %p ref %d\n", rm, refcount_read(&rm->m_refcount));
        refcount_inc(&rm->m_refcount);
}
EXPORT_SYMBOL_GPL(rds_message_addref);

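/*
 * Append a zerocopy completion cookie to the batch carried by @info.
 * Returns false when the batch is already full (RDS_MAX_ZCOOKIES).
 */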
static inline bool rds_zcookie_add(struct rds_msg_zcopy_info *info, u32 cookie)
{
        struct rds_zcopy_cookies *ck = &info->zcookies;
        int ncookies = ck->num;

        if (ncookies == RDS_MAX_ZCOOKIES)
                return false;
        ck->cookies[ncookies] = cookie;
        ck->num = ++ncookies;
        return true;
}

static struct rds_msg_zcopy_info *rds_info_from_znotifier(struct rds_znotifier *znotif)
{
        return container_of(znotif, struct rds_msg_zcopy_info, znotif);
}

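/*
 * Drop any zerocopy completion notifications still queued on the socket,
 * e.g. when the socket is torn down before they could be delivered.
 */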
void rds_notify_msg_zcopy_purge(struct rds_msg_zcopy_queue *q)
{
        unsigned long flags;
        LIST_HEAD(copy);
        struct rds_msg_zcopy_info *info, *tmp;

        spin_lock_irqsave(&q->lock, flags);
        list_splice(&q->zcookie_head, &copy);
        INIT_LIST_HEAD(&q->zcookie_head);
        spin_unlock_irqrestore(&q->lock, flags);

        list_for_each_entry_safe(info, tmp, &copy, rs_zcookie_next) {
                list_del(&info->rs_zcookie_next);
                kfree(info);
        }
}

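/*
 * Hand a zerocopy completion cookie to the socket.  Cookies are batched: if
 * the notification at the head of the socket's cookie queue still has room,
 * the cookie is folded into it and this message's notifier is freed;
 * otherwise the notifier itself is reused as a new notification and queued
 * at the tail.
 */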
static void rds_rm_zerocopy_callback(struct rds_sock *rs,
                                     struct rds_znotifier *znotif)
{
        struct rds_msg_zcopy_info *info;
        struct rds_msg_zcopy_queue *q;
        u32 cookie = znotif->z_cookie;
        struct rds_zcopy_cookies *ck;
        struct list_head *head;
        unsigned long flags;

        mm_unaccount_pinned_pages(&znotif->z_mmp);
        q = &rs->rs_zcookie_queue;
        spin_lock_irqsave(&q->lock, flags);
        head = &q->zcookie_head;
        if (!list_empty(head)) {
                info = list_first_entry(head, struct rds_msg_zcopy_info,
                                        rs_zcookie_next);
                if (rds_zcookie_add(info, cookie)) {
                        spin_unlock_irqrestore(&q->lock, flags);
                        kfree(rds_info_from_znotifier(znotif));
                        /* caller invokes rds_wake_sk_sleep() */
                        return;
                }
        }

        info = rds_info_from_znotifier(znotif);
        ck = &info->zcookies;
        memset(ck, 0, sizeof(*ck));
        WARN_ON(!rds_zcookie_add(info, cookie));
        list_add_tail(&info->rs_zcookie_next, &q->zcookie_head);

        spin_unlock_irqrestore(&q->lock, flags);
        /* caller invokes rds_wake_sk_sleep() */
}

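/*
 * Final teardown of a message, called once its refcount drops to zero:
 * deliver any pending zerocopy completion, drop the data pages, and free
 * the RDMA and atomic sub-ops along with their MRs.  Messages built by
 * rds_message_map_pages() (RDS_MSG_PAGEVEC) do not own their pages, so the
 * purge is skipped for them.
 */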
static void rds_message_purge(struct rds_message *rm)
{
        unsigned long i, flags;
        bool zcopy = false;

        if (unlikely(test_bit(RDS_MSG_PAGEVEC, &rm->m_flags)))
                return;

        spin_lock_irqsave(&rm->m_rs_lock, flags);
        if (rm->m_rs) {
                struct rds_sock *rs = rm->m_rs;

                if (rm->data.op_mmp_znotifier) {
                        zcopy = true;
                        rds_rm_zerocopy_callback(rs, rm->data.op_mmp_znotifier);
                        rds_wake_sk_sleep(rs);
                        rm->data.op_mmp_znotifier = NULL;
                }
                sock_put(rds_rs_to_sk(rs));
                rm->m_rs = NULL;
        }
        spin_unlock_irqrestore(&rm->m_rs_lock, flags);

        for (i = 0; i < rm->data.op_nents; i++) {
                /* Zerocopy pages were pinned from user memory, so only the
                 * pin reference is dropped; pages RDS allocated itself are
                 * freed outright.
                 */
                if (!zcopy)
                        __free_page(sg_page(&rm->data.op_sg[i]));
                else
                        put_page(sg_page(&rm->data.op_sg[i]));
        }
        rm->data.op_nents = 0;

        if (rm->rdma.op_active)
                rds_rdma_free_op(&rm->rdma);
        if (rm->rdma.op_rdma_mr)
                kref_put(&rm->rdma.op_rdma_mr->r_kref, __rds_put_mr_final);

        if (rm->atomic.op_active)
                rds_atomic_free_op(&rm->atomic);
        if (rm->atomic.op_rdma_mr)
                kref_put(&rm->atomic.op_rdma_mr->r_kref, __rds_put_mr_final);
}

void rds_message_put(struct rds_message *rm)
{
        rdsdebug("put rm %p ref %d\n", rm, refcount_read(&rm->m_refcount));
        WARN(!refcount_read(&rm->m_refcount), "danger refcount zero on %p\n", rm);
        if (refcount_dec_and_test(&rm->m_refcount)) {
                BUG_ON(!list_empty(&rm->m_sock_item));
                BUG_ON(!list_empty(&rm->m_conn_item));
                rds_message_purge(rm);

                kfree(rm);
        }
}
EXPORT_SYMBOL_GPL(rds_message_put);

void rds_message_populate_header(struct rds_header *hdr, __be16 sport,
                                 __be16 dport, u64 seq)
{
        hdr->h_flags = 0;
        hdr->h_sport = sport;
        hdr->h_dport = dport;
        hdr->h_sequence = cpu_to_be64(seq);
        hdr->h_exthdr[0] = RDS_EXTHDR_NONE;
}
EXPORT_SYMBOL_GPL(rds_message_populate_header);

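/*
 * Machinery for adding extension headers to the RDS header.  An extension is
 * a type byte followed by a fixed, type-sized payload (see rds_exthdr_size),
 * and the extension space is terminated by RDS_EXTHDR_NONE.
 */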
int rds_message_add_extension(struct rds_header *hdr, unsigned int type,
                              const void *data, unsigned int len)
{
        unsigned int ext_len = sizeof(u8) + len;
        unsigned char *dst;

        /* For now, refuse to add more than one extension header */
        if (hdr->h_exthdr[0] != RDS_EXTHDR_NONE)
                return 0;

        if (type >= __RDS_EXTHDR_MAX || len != rds_exthdr_size[type])
                return 0;

        if (ext_len >= RDS_HEADER_EXT_SPACE)
                return 0;
        dst = hdr->h_exthdr;

        *dst++ = type;
        memcpy(dst, data, len);

        dst[len] = RDS_EXTHDR_NONE;
        return 1;
}
EXPORT_SYMBOL_GPL(rds_message_add_extension);

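/*
 * Retrieve the extension headers of an incoming message one at a time.  An
 * illustrative caller loop (sketch only, not taken from an in-tree user):
 *
 *      unsigned int pos = 0, buflen;
 *      u8 buf[RDS_HEADER_EXT_SPACE];
 *      int type;
 *
 *      for (;;) {
 *              buflen = sizeof(buf);
 *              type = rds_message_next_extension(hdr, &pos, buf, &buflen);
 *              if (type == RDS_EXTHDR_NONE)
 *                      break;
 *              ... handle extension 'type', payload in buf[0..buflen) ...
 *      }
 */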
int rds_message_next_extension(struct rds_header *hdr,
                               unsigned int *pos, void *buf, unsigned int *buflen)
{
        unsigned int offset, ext_type, ext_len;
        u8 *src = hdr->h_exthdr;

        offset = *pos;
        if (offset >= RDS_HEADER_EXT_SPACE)
                goto none;

        /* Get the extension type; for now the length is implied by the
         * type (see rds_exthdr_size).
         */
        ext_type = src[offset++];

        if (ext_type == RDS_EXTHDR_NONE || ext_type >= __RDS_EXTHDR_MAX)
                goto none;
        ext_len = rds_exthdr_size[ext_type];
        if (offset + ext_len > RDS_HEADER_EXT_SPACE)
                goto none;

        *pos = offset + ext_len;
        if (ext_len < *buflen)
                *buflen = ext_len;
        memcpy(buf, src + offset, *buflen);
        return ext_type;

none:
        *pos = RDS_HEADER_EXT_SPACE;
        *buflen = 0;
        return RDS_EXTHDR_NONE;
}

int rds_message_add_rdma_dest_extension(struct rds_header *hdr, u32 r_key, u32 offset)
{
        struct rds_ext_header_rdma_dest ext_hdr;

        ext_hdr.h_rdma_rkey = cpu_to_be32(r_key);
        ext_hdr.h_rdma_offset = cpu_to_be32(offset);
        return rds_message_add_extension(hdr, RDS_EXTHDR_RDMA_DEST, &ext_hdr, sizeof(ext_hdr));
}
EXPORT_SYMBOL_GPL(rds_message_add_rdma_dest_extension);

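/*
 * Each rds_message is allocated with extra_len bytes of trailing space so
 * that the scatterlist entries its ops will need come from a single
 * allocation; rds_message_alloc_sgs() below hands them out.
 */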
struct rds_message *rds_message_alloc(unsigned int extra_len, gfp_t gfp)
{
        struct rds_message *rm;

        if (extra_len > KMALLOC_MAX_SIZE - sizeof(struct rds_message))
                return NULL;

        rm = kzalloc(sizeof(struct rds_message) + extra_len, gfp);
        if (!rm)
                goto out;

        rm->m_used_sgs = 0;
        rm->m_total_sgs = extra_len / sizeof(struct scatterlist);

        refcount_set(&rm->m_refcount, 1);
        INIT_LIST_HEAD(&rm->m_sock_item);
        INIT_LIST_HEAD(&rm->m_conn_item);
        spin_lock_init(&rm->m_rs_lock);
        init_waitqueue_head(&rm->m_flush_wait);

out:
        return rm;
}

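/*
 * RDS ops use this to grab scatterlist entries from the pool that was
 * reserved when the message was allocated.
 */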
struct scatterlist *rds_message_alloc_sgs(struct rds_message *rm, int nents)
{
        struct scatterlist *sg_first = (struct scatterlist *) &rm[1];
        struct scatterlist *sg_ret;

        if (nents <= 0) {
                pr_warn("rds: alloc sgs failed! nents <= 0\n");
                return ERR_PTR(-EINVAL);
        }

        if (rm->m_used_sgs + nents > rm->m_total_sgs) {
                pr_warn("rds: alloc sgs failed! total %d used %d nents %d\n",
                        rm->m_total_sgs, rm->m_used_sgs, nents);
                return ERR_PTR(-ENOMEM);
        }

        sg_ret = &sg_first[rm->m_used_sgs];
        sg_init_table(sg_ret, nents);
        rm->m_used_sgs += nents;

        return sg_ret;
}

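/*
 * Build a message whose payload maps a caller-provided array of page
 * addresses.  RDS_MSG_PAGEVEC marks the pages as not owned by the message,
 * so rds_message_purge() leaves them alone.
 */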
struct rds_message *rds_message_map_pages(unsigned long *page_addrs, unsigned int total_len)
{
        struct rds_message *rm;
        unsigned int i;
        int num_sgs = DIV_ROUND_UP(total_len, PAGE_SIZE);
        int extra_bytes = num_sgs * sizeof(struct scatterlist);

        rm = rds_message_alloc(extra_bytes, GFP_NOWAIT);
        if (!rm)
                return ERR_PTR(-ENOMEM);

        set_bit(RDS_MSG_PAGEVEC, &rm->m_flags);
        rm->m_inc.i_hdr.h_len = cpu_to_be32(total_len);
        rm->data.op_nents = DIV_ROUND_UP(total_len, PAGE_SIZE);
        rm->data.op_sg = rds_message_alloc_sgs(rm, num_sgs);
        if (IS_ERR(rm->data.op_sg)) {
                void *err = ERR_CAST(rm->data.op_sg);

                rds_message_put(rm);
                return err;
        }

        for (i = 0; i < rm->data.op_nents; ++i) {
                sg_set_page(&rm->data.op_sg[i],
                            virt_to_page(page_addrs[i]),
                            PAGE_SIZE, 0);
        }

        return rm;
}

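/*
 * Zerocopy send path: pin the user pages backing the iov and reference them
 * from the scatterlist instead of copying.  The attached notifier carries
 * the cookie that is reported back to the socket once the message is purged.
 */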
static int rds_message_zcopy_from_user(struct rds_message *rm, struct iov_iter *from)
{
        struct scatterlist *sg;
        int ret = 0;
        int length = iov_iter_count(from);
        int total_copied = 0;
        struct rds_msg_zcopy_info *info;

        rm->m_inc.i_hdr.h_len = cpu_to_be32(iov_iter_count(from));

        /* now pin the user pages backing the data payload */
        sg = rm->data.op_sg;

        info = kzalloc(sizeof(*info), GFP_KERNEL);
        if (!info)
                return -ENOMEM;
        INIT_LIST_HEAD(&info->rs_zcookie_next);
        rm->data.op_mmp_znotifier = &info->znotif;
        if (mm_account_pinned_pages(&rm->data.op_mmp_znotifier->z_mmp,
                                    length)) {
                ret = -ENOMEM;
                goto err;
        }
        while (iov_iter_count(from)) {
                struct page *pages;
                size_t start;
                ssize_t copied;

                copied = iov_iter_get_pages2(from, &pages, PAGE_SIZE,
                                             1, &start);
                if (copied < 0) {
                        struct mmpin *mmp;
                        int i;

                        /* Unwind: release the pages pinned so far and the
                         * pinned-page accounting before bailing out.
                         */
                        for (i = 0; i < rm->data.op_nents; i++)
                                put_page(sg_page(&rm->data.op_sg[i]));
                        mmp = &rm->data.op_mmp_znotifier->z_mmp;
                        mm_unaccount_pinned_pages(mmp);
                        ret = -EFAULT;
                        goto err;
                }
                total_copied += copied;
                length -= copied;
                sg_set_page(sg, pages, copied, start);
                rm->data.op_nents++;
                sg++;
        }
        WARN_ON_ONCE(length != 0);
        return ret;
err:
        kfree(info);
        rm->data.op_mmp_znotifier = NULL;
        return ret;
}

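/*
 * Copy the payload from the user iov into pages attached to the message,
 * or divert to the zerocopy path when requested.
 */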
int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from,
                               bool zcopy)
{
        unsigned long to_copy, nbytes;
        unsigned long sg_off;
        struct scatterlist *sg;
        int ret = 0;

        rm->m_inc.i_hdr.h_len = cpu_to_be32(iov_iter_count(from));

        /* now allocate and copy in the data payload */
        sg = rm->data.op_sg;
        sg_off = 0;

        if (zcopy)
                return rds_message_zcopy_from_user(rm, from);

        while (iov_iter_count(from)) {
                if (!sg_page(sg)) {
                        ret = rds_page_remainder_alloc(sg, iov_iter_count(from),
                                                       GFP_HIGHUSER);
                        if (ret)
                                return ret;
                        rm->data.op_nents++;
                        sg_off = 0;
                }

                to_copy = min_t(unsigned long, iov_iter_count(from),
                                sg->length - sg_off);

                rds_stats_add(s_copy_from_user, to_copy);
                nbytes = copy_page_from_iter(sg_page(sg), sg->offset + sg_off,
                                             to_copy, from);
                if (nbytes != to_copy)
                        return -EFAULT;

                sg_off += to_copy;

                if (sg_off == sg->length)
                        sg++;
        }

        return ret;
}

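/*
 * Copy a received message's payload to the user iov, bounded by both the
 * iov size and the length advertised in the RDS header.  Returns the number
 * of bytes copied.
 */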
int rds_message_inc_copy_to_user(struct rds_incoming *inc, struct iov_iter *to)
{
        struct rds_message *rm;
        struct scatterlist *sg;
        unsigned long to_copy;
        unsigned long vec_off;
        int copied;
        int ret;
        u32 len;

        rm = container_of(inc, struct rds_message, m_inc);
        len = be32_to_cpu(rm->m_inc.i_hdr.h_len);

        sg = rm->data.op_sg;
        vec_off = 0;
        copied = 0;

        while (iov_iter_count(to) && copied < len) {
                to_copy = min_t(unsigned long, iov_iter_count(to),
                                sg->length - vec_off);
                to_copy = min_t(unsigned long, to_copy, len - copied);

                rds_stats_add(s_copy_to_user, to_copy);
                ret = copy_page_to_iter(sg_page(sg), sg->offset + vec_off,
                                        to_copy, to);
                if (ret != to_copy)
                        return -EFAULT;

                vec_off += to_copy;
                copied += to_copy;

                if (vec_off == sg->length) {
                        vec_off = 0;
                        sg++;
                }
        }

        return copied;
}

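/*
 * Wait until the transport is done with the message, i.e. until
 * RDS_MSG_MAPPED is cleared by rds_message_unmapped() below.
 */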
void rds_message_wait(struct rds_message *rm)
{
        wait_event_interruptible(rm->m_flush_wait,
                                 !test_bit(RDS_MSG_MAPPED, &rm->m_flags));
}

void rds_message_unmapped(struct rds_message *rm)
{
        clear_bit(RDS_MSG_MAPPED, &rm->m_flags);
        wake_up_interruptible(&rm->m_flush_wait);
}
EXPORT_SYMBOL_GPL(rds_message_unmapped);