// SPDX-License-Identifier: GPL-2.0 or BSD-3-Clause
/*
 * Copyright(c) 2016 Intel Corporation.
 */

#include <linux/slab.h>
#include <linux/sched.h>
#include <linux/rculist.h>
#include <rdma/rdma_vt.h>
#include <rdma/rdmavt_qp.h>

#include "mcast.h"

/**
 * rvt_driver_mcast_init - init resources for multicast
 * @rdi: rvt dev struct
 *
 * Called once for each device that registers with rdmavt.
 */
void rvt_driver_mcast_init(struct rvt_dev_info *rdi)
{
	/*
	 * Anything that needs per-driver or per-rdi multicast setup
	 * belongs here; today that is only the group-count lock.
	 */
	spin_lock_init(&rdi->n_mcast_grps_lock);
}

/**
 * rvt_mcast_qp_alloc - alloc a struct to link a QP to a mcast GID struct
 * @qp: the QP to link
 *
 * Takes a reference on @qp.  Return: NULL on allocation failure.
 */
static struct rvt_mcast_qp *rvt_mcast_qp_alloc(struct rvt_qp *qp)
{
	struct rvt_mcast_qp *mqp;

	mqp = kmalloc(sizeof(*mqp), GFP_KERNEL);
	if (!mqp)
		goto bail;

	mqp->qp = qp;
	rvt_get_qp(qp);

bail:
	return mqp;
}

static void rvt_mcast_qp_free(struct rvt_mcast_qp *mqp)
{
	struct rvt_qp *qp = mqp->qp;

	/* Drop the reference taken in rvt_mcast_qp_alloc(). */
	rvt_put_qp(qp);

	kfree(mqp);
}

/**
 * rvt_mcast_alloc - allocate the multicast GID structure
 * @mgid: the multicast GID
 * @lid: the multicast LID (host order)
 *
 * A list of QPs will be attached to this structure.
 */
static struct rvt_mcast *rvt_mcast_alloc(union ib_gid *mgid, u16 lid)
{
	struct rvt_mcast *mcast;

	mcast = kzalloc(sizeof(*mcast), GFP_KERNEL);
	if (!mcast)
		goto bail;

	mcast->mcast_addr.mgid = *mgid;
	mcast->mcast_addr.lid = lid;

	INIT_LIST_HEAD(&mcast->qp_list);
	init_waitqueue_head(&mcast->wait);
	atomic_set(&mcast->refcount, 0);

bail:
	return mcast;
}

static void rvt_mcast_free(struct rvt_mcast *mcast)
{
	struct rvt_mcast_qp *p, *tmp;

	list_for_each_entry_safe(p, tmp, &mcast->qp_list, list)
		rvt_mcast_qp_free(p);

	kfree(mcast);
}

/**
 * rvt_mcast_find - search the port's table for the given multicast GID/LID
 * @ibp: the IB port structure
 * @mgid: the multicast GID to search for
 * @lid: the multicast LID portion of the multicast address (host order)
 *
 * It is valid to have one MLID with multiple MGIDs, but not one MGID with
 * multiple MLIDs.
 *
 * The caller is responsible for decrementing the reference count if found.
 *
 * Return: the multicast group with its reference count incremented, or
 * NULL if not found.
 */
struct rvt_mcast *rvt_mcast_find(struct rvt_ibport *ibp, union ib_gid *mgid,
				 u16 lid)
{
	struct rb_node *n;
	unsigned long flags;
	struct rvt_mcast *found = NULL;

	spin_lock_irqsave(&ibp->lock, flags);
	n = ibp->mcast_tree.rb_node;
	while (n) {
		int ret;
		struct rvt_mcast *mcast;

		mcast = rb_entry(n, struct rvt_mcast, rb_node);

		ret = memcmp(mgid->raw, mcast->mcast_addr.mgid.raw,
			     sizeof(*mgid));
		if (ret < 0) {
			n = n->rb_left;
		} else if (ret > 0) {
			n = n->rb_right;
		} else {
			/* The MGID matches; the MLID must match too. */
			if (mcast->mcast_addr.lid == lid) {
				atomic_inc(&mcast->refcount);
				found = mcast;
			}
			break;
		}
	}
	spin_unlock_irqrestore(&ibp->lock, flags);
	return found;
}
EXPORT_SYMBOL(rvt_mcast_find);
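
/*
 * Illustrative sketch only (not part of the upstream file): how a
 * driver's receive path typically consumes rvt_mcast_find().  The
 * "deliver" callback is a hypothetical stand-in for the driver's real
 * per-QP receive routine.
 */
static inline void rvt_mcast_deliver_sketch(struct rvt_ibport *ibp,
					    union ib_gid *dgid, u16 dlid,
					    void (*deliver)(struct rvt_qp *qp))
{
	struct rvt_mcast *mcast;
	struct rvt_mcast_qp *p;

	/*
	 * On success this takes a group reference, which keeps the
	 * qp_list entries alive until we drop it below.
	 */
	mcast = rvt_mcast_find(ibp, dgid, dlid);
	if (!mcast)
		return;

	rcu_read_lock();
	list_for_each_entry_rcu(p, &mcast->qp_list, list)
		deliver(p->qp);
	rcu_read_unlock();

	/*
	 * Drop the reference and wake rvt_detach_mcast() if it is
	 * waiting for walkers to finish.
	 */
	if (atomic_dec_return(&mcast->refcount) <= 1)
		wake_up(&mcast->wait);
}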

/**
 * rvt_mcast_add - insert mcast GID into table and attach QP struct
 * @rdi: rvt dev struct
 * @ibp: the IB port structure
 * @mcast: the mcast GID table entry to insert
 * @mqp: the QP to attach
 *
 * Return: zero if both were added.  Return EEXIST if the GID was already
 * in the table but the QP was added.  Return ESRCH if the QP was already
 * attached and neither structure was added.  Return EINVAL if the MGID
 * was found but the MLID did not match.  Return ENOMEM if a per-group or
 * per-device limit was hit.  Note these are positive errno values; the
 * caller translates them to the usual negative returns.
 */
static int rvt_mcast_add(struct rvt_dev_info *rdi, struct rvt_ibport *ibp,
			 struct rvt_mcast *mcast, struct rvt_mcast_qp *mqp)
{
	struct rb_node **n = &ibp->mcast_tree.rb_node;
	struct rb_node *pn = NULL;
	int ret;

	spin_lock_irq(&ibp->lock);

	while (*n) {
		struct rvt_mcast *tmcast;
		struct rvt_mcast_qp *p;

		pn = *n;
		tmcast = rb_entry(pn, struct rvt_mcast, rb_node);

		ret = memcmp(mcast->mcast_addr.mgid.raw,
			     tmcast->mcast_addr.mgid.raw,
			     sizeof(mcast->mcast_addr.mgid));
		if (ret < 0) {
			n = &pn->rb_left;
			continue;
		}
		if (ret > 0) {
			n = &pn->rb_right;
			continue;
		}

		/* The MGID matches; the MLID must match too. */
		if (tmcast->mcast_addr.lid != mcast->mcast_addr.lid) {
			ret = EINVAL;
			goto bail;
		}

		/* Search the QP list to see if this is already there. */
		list_for_each_entry_rcu(p, &tmcast->qp_list, list) {
			if (p->qp == mqp->qp) {
				ret = ESRCH;
				goto bail;
			}
		}
		if (tmcast->n_attached ==
		    rdi->dparms.props.max_mcast_qp_attach) {
			ret = ENOMEM;
			goto bail;
		}

		tmcast->n_attached++;

		list_add_tail_rcu(&mqp->list, &tmcast->qp_list);
		ret = EEXIST;
		goto bail;
	}

	spin_lock(&rdi->n_mcast_grps_lock);
	if (rdi->n_mcast_grps_allocated == rdi->dparms.props.max_mcast_grp) {
		spin_unlock(&rdi->n_mcast_grps_lock);
		ret = ENOMEM;
		goto bail;
	}

	rdi->n_mcast_grps_allocated++;
	spin_unlock(&rdi->n_mcast_grps_lock);

	mcast->n_attached++;

	list_add_tail_rcu(&mqp->list, &mcast->qp_list);

	atomic_inc(&mcast->refcount);
	rb_link_node(&mcast->rb_node, pn, n);
	rb_insert_color(&mcast->rb_node, &ibp->mcast_tree);

	ret = 0;

bail:
	spin_unlock_irq(&ibp->lock);

	return ret;
}

/**
 * rvt_attach_mcast - attach a qp to a multicast group
 * @ibqp: Infiniband qp
 * @gid: multicast group gid
 * @lid: multicast group lid in host order
 *
 * Return: 0 on success
 */
int rvt_attach_mcast(struct ib_qp *ibqp, union ib_gid *gid, u16 lid)
{
	struct rvt_qp *qp = ibqp_to_rvtqp(ibqp);
	struct rvt_dev_info *rdi = ib_to_rvt(ibqp->device);
	struct rvt_ibport *ibp = rdi->ports[qp->port_num - 1];
	struct rvt_mcast *mcast;
	struct rvt_mcast_qp *mqp;
	int ret = -ENOMEM;

	if (ibqp->qp_num <= 1 || qp->state == IB_QPS_RESET)
		return -EINVAL;

	/*
	 * Allocate the data structures up front since it's better to do
	 * this outside of spin locks and they will most likely be needed.
	 */
	mcast = rvt_mcast_alloc(gid, lid);
	if (!mcast)
		return -ENOMEM;

	mqp = rvt_mcast_qp_alloc(qp);
	if (!mqp)
		goto bail_mcast;

	switch (rvt_mcast_add(rdi, ibp, mcast, mqp)) {
	case ESRCH:
		/* Neither was used: OK to attach the same QP twice. */
		ret = 0;
		goto bail_mqp;
	case EEXIST: /* The mcast wasn't used */
		ret = 0;
		goto bail_mcast;
	case ENOMEM:
		/* Exceeded the maximum number of mcast groups. */
		ret = -ENOMEM;
		goto bail_mqp;
	case EINVAL:
		/* Invalid MGID/MLID pair */
		ret = -EINVAL;
		goto bail_mqp;
	default:
		break;
	}

	return 0;

bail_mqp:
	rvt_mcast_qp_free(mqp);

bail_mcast:
	rvt_mcast_free(mcast);

	return ret;
}
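
/*
 * How this is reached (a sketch, not code from this file): a verbs
 * consumer joins the group via the SA and then calls the core verb,
 * which lands here through the device's attach_mcast op.  "mgid" and
 * "mlid" are whatever group address the consumer joined:
 *
 *	ret = ib_attach_mcast(ibqp, &mgid, mlid);
 */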

/**
 * rvt_detach_mcast - remove a qp from a multicast group
 * @ibqp: Infiniband qp
 * @gid: multicast group gid
 * @lid: multicast group lid in host order
 *
 * Return: 0 on success
 */
int rvt_detach_mcast(struct ib_qp *ibqp, union ib_gid *gid, u16 lid)
{
	struct rvt_qp *qp = ibqp_to_rvtqp(ibqp);
	struct rvt_dev_info *rdi = ib_to_rvt(ibqp->device);
	struct rvt_ibport *ibp = rdi->ports[qp->port_num - 1];
	struct rvt_mcast *mcast = NULL;
	struct rvt_mcast_qp *p, *tmp, *delp = NULL;
	struct rb_node *n;
	int last = 0;
	int ret = 0;

	if (ibqp->qp_num <= 1)
		return -EINVAL;

	spin_lock_irq(&ibp->lock);

	/* Find the GID in the mcast table. */
	n = ibp->mcast_tree.rb_node;
	while (1) {
		if (!n) {
			spin_unlock_irq(&ibp->lock);
			return -EINVAL;
		}

		mcast = rb_entry(n, struct rvt_mcast, rb_node);
		ret = memcmp(gid->raw, mcast->mcast_addr.mgid.raw,
			     sizeof(*gid));
		if (ret < 0) {
			n = n->rb_left;
		} else if (ret > 0) {
			n = n->rb_right;
		} else {
			/* The MGID matches; the MLID must match too. */
			if (mcast->mcast_addr.lid != lid) {
				spin_unlock_irq(&ibp->lock);
				return -EINVAL;
			}
			break;
		}
	}

	/* Search the QP list. */
	list_for_each_entry_safe(p, tmp, &mcast->qp_list, list) {
		if (p->qp != qp)
			continue;
		/*
		 * We found it, so remove it, but don't poison the forward
		 * link until we are sure there are no list walkers.
		 */
		list_del_rcu(&p->list);
		mcast->n_attached--;
		delp = p;

		/* If this was the last attached QP, remove the GID too. */
		if (list_empty(&mcast->qp_list)) {
			rb_erase(&mcast->rb_node, &ibp->mcast_tree);
			last = 1;
		}
		break;
	}

	spin_unlock_irq(&ibp->lock);

	/* The QP was never attached to this group. */
	if (!delp)
		return -EINVAL;

	/*
	 * Wait for any list walkers to finish before freeing the
	 * list element.
	 */
	wait_event(mcast->wait, atomic_read(&mcast->refcount) <= 1);
	rvt_mcast_qp_free(delp);

	if (last) {
		atomic_dec(&mcast->refcount);
		wait_event(mcast->wait, !atomic_read(&mcast->refcount));
		rvt_mcast_free(mcast);
		spin_lock_irq(&rdi->n_mcast_grps_lock);
		rdi->n_mcast_grps_allocated--;
		spin_unlock_irq(&rdi->n_mcast_grps_lock);
	}

	return 0;
}
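
/*
 * A sketch of how rdmavt exposes the attach/detach pair to the IB core
 * (the real hookup lives in rdmavt's vt.c, via struct ib_device_ops):
 *
 *	.attach_mcast = rvt_attach_mcast,
 *	.detach_mcast = rvt_detach_mcast,
 */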

/**
 * rvt_mcast_tree_empty - determine if any qps are attached to any mcast group
 * @rdi: rvt dev struct
 *
 * Return: the number of ports whose mcast tree is still in use; zero means
 * no multicast groups remain on any port.
 */
int rvt_mcast_tree_empty(struct rvt_dev_info *rdi)
{
	int i;
	int in_use = 0;

	for (i = 0; i < rdi->dparms.nports; i++)
		if (rdi->ports[i]->mcast_tree.rb_node)
			in_use++;
	return in_use;
}
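
/*
 * Illustrative use at device teardown (a sketch; the exact message is
 * hypothetical): despite the name, a non-zero return means some port
 * still has multicast groups in its tree, so unregistering now would
 * leak them.
 *
 *	if (rvt_mcast_tree_empty(rdi))
 *		rvt_pr_err(rdi, "QPs still attached to mcast groups\n");
 */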