/*
 * Copyright (c) 2006 Intel Corporation.  All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/completion.h>
#include <linux/dma-mapping.h>
#include <linux/err.h>
#include <linux/interrupt.h>
#include <linux/export.h>
#include <linux/slab.h>
#include <linux/bitops.h>
#include <linux/random.h>

#include <rdma/ib_cache.h>
#include "sa.h"

static int mcast_add_one(struct ib_device *device);
static void mcast_remove_one(struct ib_device *device, void *client_data);

static struct ib_client mcast_client = {
    .name   = "ib_multicast",
    .add    = mcast_add_one,
    .remove = mcast_remove_one
};

static struct ib_sa_client  sa_client;
static struct workqueue_struct  *mcast_wq;
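/*
 * mgid0 is the all-zero MGID (static storage, so zero-initialized).  A join
 * with it asks the SA to assign an MGID, so such groups are allowed to
 * coexist in the port table until join_handler() re-keys them.
 */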
static union ib_gid mgid0;

struct mcast_device;

struct mcast_port {
    struct mcast_device *dev;
    spinlock_t      lock;
    struct rb_root      table;
    refcount_t      refcount;
    struct completion   comp;
    u32         port_num;
};

struct mcast_device {
    struct ib_device    *device;
    struct ib_event_handler event_handler;
    int         start_port;
    int         end_port;
    struct mcast_port   port[];
};

enum mcast_state {
    MCAST_JOINING,
    MCAST_MEMBER,
    MCAST_ERROR,
};

enum mcast_group_state {
    MCAST_IDLE,
    MCAST_BUSY,
    MCAST_GROUP_ERROR,
    MCAST_PKEY_EVENT
};

enum {
    MCAST_INVALID_PKEY_INDEX = 0xFFFF
};

struct mcast_member;

struct mcast_group {
    struct ib_sa_mcmember_rec rec;
    struct rb_node      node;
    struct mcast_port   *port;
    spinlock_t      lock;
    struct work_struct  work;
    struct list_head    pending_list;
    struct list_head    active_list;
    struct mcast_member *last_join;
    int         members[NUM_JOIN_MEMBERSHIP_TYPES];
    atomic_t        refcount;
    enum mcast_group_state  state;
    struct ib_sa_query  *query;
    u16         pkey_index;
    u8          leave_state;
    int         retries;
};

struct mcast_member {
    struct ib_sa_multicast  multicast;
    struct ib_sa_client *client;
    struct mcast_group  *group;
    struct list_head    list;
    enum mcast_state    state;
    refcount_t      refcount;
    struct completion   comp;
};

static void join_handler(int status, struct ib_sa_mcmember_rec *rec,
             void *context);
static void leave_handler(int status, struct ib_sa_mcmember_rec *rec,
              void *context);

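/*
 * Find the group with the given MGID in the port's red-black tree.
 * The caller must hold port->lock.
 */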
static struct mcast_group *mcast_find(struct mcast_port *port,
                      union ib_gid *mgid)
{
    struct rb_node *node = port->table.rb_node;
    struct mcast_group *group;
    int ret;

    while (node) {
        group = rb_entry(node, struct mcast_group, node);
        ret = memcmp(mgid->raw, group->rec.mgid.raw, sizeof *mgid);
        if (!ret)
            return group;

        if (ret < 0)
            node = node->rb_left;
        else
            node = node->rb_right;
    }
    return NULL;
}

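/*
 * Insert a group into the port table, keyed by MGID.  Returns NULL on
 * success, or the existing group if the MGID is already present and
 * allow_duplicates is not set.  Duplicates are permitted only for mgid0
 * groups still waiting for an SA-assigned MGID.  The caller must hold
 * port->lock.
 */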
static struct mcast_group *mcast_insert(struct mcast_port *port,
                    struct mcast_group *group,
                    int allow_duplicates)
{
    struct rb_node **link = &port->table.rb_node;
    struct rb_node *parent = NULL;
    struct mcast_group *cur_group;
    int ret;

    while (*link) {
        parent = *link;
        cur_group = rb_entry(parent, struct mcast_group, node);

        ret = memcmp(group->rec.mgid.raw, cur_group->rec.mgid.raw,
                 sizeof group->rec.mgid);
        if (ret < 0)
            link = &(*link)->rb_left;
        else if (ret > 0)
            link = &(*link)->rb_right;
        else if (allow_duplicates)
            link = &(*link)->rb_left;
        else
            return cur_group;
    }
    rb_link_node(&group->node, parent, link);
    rb_insert_color(&group->node, &port->table);
    return NULL;
}

static void deref_port(struct mcast_port *port)
{
    if (refcount_dec_and_test(&port->refcount))
        complete(&port->comp);
}

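/*
 * Drop a reference on the group.  The last reference removes the group from
 * the port table, frees it, and releases the port reference taken when the
 * group was inserted.
 */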
static void release_group(struct mcast_group *group)
{
    struct mcast_port *port = group->port;
    unsigned long flags;

    spin_lock_irqsave(&port->lock, flags);
    if (atomic_dec_and_test(&group->refcount)) {
        rb_erase(&group->node, &port->table);
        spin_unlock_irqrestore(&port->lock, flags);
        kfree(group);
        deref_port(port);
    } else
        spin_unlock_irqrestore(&port->lock, flags);
}

static void deref_member(struct mcast_member *member)
{
    if (refcount_dec_and_test(&member->refcount))
        complete(&member->comp);
}

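/*
 * Queue a join request on the group's pending list and, if the group is
 * idle, schedule the work handler.  The group reference taken here is
 * dropped by mcast_work_handler() when the group goes idle again.
 */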
static void queue_join(struct mcast_member *member)
{
    struct mcast_group *group = member->group;
    unsigned long flags;

    spin_lock_irqsave(&group->lock, flags);
    list_add_tail(&member->list, &group->pending_list);
    if (group->state == MCAST_IDLE) {
        group->state = MCAST_BUSY;
        atomic_inc(&group->refcount);
        queue_work(mcast_wq, &group->work);
    }
    spin_unlock_irqrestore(&group->lock, flags);
}

/*
 * A multicast group has four types of members: full member, non member,
 * sendonly non member, and sendonly full member.  join_state is a bit mask
 * with one bit per type, in that order.  We need to keep track of the
 * number of members of each type based on their join state.  Adjust the
 * number of members that belong to the specified join states.
 */
static void adjust_membership(struct mcast_group *group, u8 join_state, int inc)
{
    int i;

    for (i = 0; i < NUM_JOIN_MEMBERSHIP_TYPES; i++, join_state >>= 1)
        if (join_state & 0x1)
            group->members[i] += inc;
}

/*
 * If a multicast group has zero members left for a particular join state, but
 * the group is still a member with the SA, we need to leave that join state.
 * Determine which join states we still belong to, but that do not have any
 * active members.
 */
static u8 get_leave_state(struct mcast_group *group)
{
    u8 leave_state = 0;
    int i;

    for (i = 0; i < NUM_JOIN_MEMBERSHIP_TYPES; i++)
        if (!group->members[i])
            leave_state |= (0x1 << i);

    return leave_state & group->rec.join_state;
}

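/*
 * Validate one selector/value pair against the group's record.  Returns 0
 * if the pair is absent from comp_mask or if src_value satisfies the
 * requested selector (IB_SA_GT, IB_SA_LT, or IB_SA_EQ) relative to
 * dst_value; returns nonzero on a mismatch.
 */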
static int check_selector(ib_sa_comp_mask comp_mask,
              ib_sa_comp_mask selector_mask,
              ib_sa_comp_mask value_mask,
              u8 selector, u8 src_value, u8 dst_value)
{
    int err;

    if (!(comp_mask & selector_mask) || !(comp_mask & value_mask))
        return 0;

    switch (selector) {
    case IB_SA_GT:
        err = (src_value <= dst_value);
        break;
    case IB_SA_LT:
        err = (src_value >= dst_value);
        break;
    case IB_SA_EQ:
        err = (src_value != dst_value);
        break;
    default:
        err = 0;
        break;
    }

    return err;
}

static int cmp_rec(struct ib_sa_mcmember_rec *src,
           struct ib_sa_mcmember_rec *dst, ib_sa_comp_mask comp_mask)
{
    /* MGID must already match */

    if (comp_mask & IB_SA_MCMEMBER_REC_PORT_GID &&
        memcmp(&src->port_gid, &dst->port_gid, sizeof src->port_gid))
        return -EINVAL;
    if (comp_mask & IB_SA_MCMEMBER_REC_QKEY && src->qkey != dst->qkey)
        return -EINVAL;
    if (comp_mask & IB_SA_MCMEMBER_REC_MLID && src->mlid != dst->mlid)
        return -EINVAL;
    if (check_selector(comp_mask, IB_SA_MCMEMBER_REC_MTU_SELECTOR,
               IB_SA_MCMEMBER_REC_MTU, dst->mtu_selector,
               src->mtu, dst->mtu))
        return -EINVAL;
    if (comp_mask & IB_SA_MCMEMBER_REC_TRAFFIC_CLASS &&
        src->traffic_class != dst->traffic_class)
        return -EINVAL;
    if (comp_mask & IB_SA_MCMEMBER_REC_PKEY && src->pkey != dst->pkey)
        return -EINVAL;
    if (check_selector(comp_mask, IB_SA_MCMEMBER_REC_RATE_SELECTOR,
               IB_SA_MCMEMBER_REC_RATE, dst->rate_selector,
               src->rate, dst->rate))
        return -EINVAL;
    if (check_selector(comp_mask,
               IB_SA_MCMEMBER_REC_PACKET_LIFE_TIME_SELECTOR,
               IB_SA_MCMEMBER_REC_PACKET_LIFE_TIME,
               dst->packet_life_time_selector,
               src->packet_life_time, dst->packet_life_time))
        return -EINVAL;
    if (comp_mask & IB_SA_MCMEMBER_REC_SL && src->sl != dst->sl)
        return -EINVAL;
    if (comp_mask & IB_SA_MCMEMBER_REC_FLOW_LABEL &&
        src->flow_label != dst->flow_label)
        return -EINVAL;
    if (comp_mask & IB_SA_MCMEMBER_REC_HOP_LIMIT &&
        src->hop_limit != dst->hop_limit)
        return -EINVAL;
    if (comp_mask & IB_SA_MCMEMBER_REC_SCOPE && src->scope != dst->scope)
        return -EINVAL;

    /* join_state checked separately, proxy_join ignored */

    return 0;
}

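/*
 * Send an SA SET request to join the group on behalf of the given member.
 * ib_sa_mcmember_rec_query() returns a positive query id on success, which
 * is mapped to 0 here; join_handler() runs when the response arrives.
 */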
static int send_join(struct mcast_group *group, struct mcast_member *member)
{
    struct mcast_port *port = group->port;
    int ret;

    group->last_join = member;
    ret = ib_sa_mcmember_rec_query(&sa_client, port->dev->device,
                       port->port_num, IB_MGMT_METHOD_SET,
                       &member->multicast.rec,
                       member->multicast.comp_mask,
                       3000, GFP_KERNEL, join_handler, group,
                       &group->query);
    return (ret > 0) ? 0 : ret;
}

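/*
 * Send an SA DELETE request dropping the join states in leave_state;
 * leave_handler() retries on failure before letting the group go idle.
 */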
static int send_leave(struct mcast_group *group, u8 leave_state)
{
    struct mcast_port *port = group->port;
    struct ib_sa_mcmember_rec rec;
    int ret;

    rec = group->rec;
    rec.join_state = leave_state;
    group->leave_state = leave_state;

    ret = ib_sa_mcmember_rec_query(&sa_client, port->dev->device,
                       port->port_num, IB_SA_METHOD_DELETE, &rec,
                       IB_SA_MCMEMBER_REC_MGID     |
                       IB_SA_MCMEMBER_REC_PORT_GID |
                       IB_SA_MCMEMBER_REC_JOIN_STATE,
                       3000, GFP_KERNEL, leave_handler,
                       group, &group->query);
    return (ret > 0) ? 0 : ret;
}

static void join_group(struct mcast_group *group, struct mcast_member *member,
               u8 join_state)
{
    member->state = MCAST_MEMBER;
    adjust_membership(group, join_state, 1);
    group->rec.join_state |= join_state;
    member->multicast.rec = group->rec;
    member->multicast.rec.join_state = join_state;
    list_move(&member->list, &group->active_list);
}

static int fail_join(struct mcast_group *group, struct mcast_member *member,
             int status)
{
    spin_lock_irq(&group->lock);
    list_del_init(&member->list);
    spin_unlock_irq(&group->lock);
    return member->multicast.callback(status, &member->multicast);
}

static void process_group_error(struct mcast_group *group)
{
    struct mcast_member *member;
    int ret = 0;
    u16 pkey_index;

    if (group->state == MCAST_PKEY_EVENT)
        ret = ib_find_pkey(group->port->dev->device,
                   group->port->port_num,
                   be16_to_cpu(group->rec.pkey), &pkey_index);

    spin_lock_irq(&group->lock);
    if (group->state == MCAST_PKEY_EVENT && !ret &&
        group->pkey_index == pkey_index)
        goto out;

    while (!list_empty(&group->active_list)) {
        member = list_entry(group->active_list.next,
                    struct mcast_member, list);
        refcount_inc(&member->refcount);
        list_del_init(&member->list);
        adjust_membership(group, member->multicast.rec.join_state, -1);
        member->state = MCAST_ERROR;
        spin_unlock_irq(&group->lock);

        ret = member->multicast.callback(-ENETRESET,
                         &member->multicast);
        deref_member(member);
        if (ret)
            ib_sa_free_multicast(&member->multicast);
        spin_lock_irq(&group->lock);
    }

    group->rec.join_state = 0;
out:
    group->state = MCAST_BUSY;
    spin_unlock_irq(&group->lock);
}

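/*
 * Per-group state machine.  Serially processes pending join requests:
 * requests whose join state is already held by the group are matched
 * against the group record with cmp_rec(); anything else triggers an SA
 * join via send_join().  Group error and pkey events are funneled through
 * process_group_error().  When join states lose their last member, a leave
 * is sent; once all work is done, the group returns to MCAST_IDLE and the
 * work item's group reference is dropped.
 */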
static void mcast_work_handler(struct work_struct *work)
{
    struct mcast_group *group;
    struct mcast_member *member;
    struct ib_sa_multicast *multicast;
    int status, ret;
    u8 join_state;

    group = container_of(work, typeof(*group), work);
retest:
    spin_lock_irq(&group->lock);
    while (!list_empty(&group->pending_list) ||
           (group->state != MCAST_BUSY)) {

        if (group->state != MCAST_BUSY) {
            spin_unlock_irq(&group->lock);
            process_group_error(group);
            goto retest;
        }

        member = list_entry(group->pending_list.next,
                    struct mcast_member, list);
        multicast = &member->multicast;
        join_state = multicast->rec.join_state;
        refcount_inc(&member->refcount);

        if (join_state == (group->rec.join_state & join_state)) {
            status = cmp_rec(&group->rec, &multicast->rec,
                     multicast->comp_mask);
            if (!status)
                join_group(group, member, join_state);
            else
                list_del_init(&member->list);
            spin_unlock_irq(&group->lock);
            ret = multicast->callback(status, multicast);
        } else {
            spin_unlock_irq(&group->lock);
            status = send_join(group, member);
            if (!status) {
                deref_member(member);
                return;
            }
            ret = fail_join(group, member, status);
        }

        deref_member(member);
        if (ret)
            ib_sa_free_multicast(&member->multicast);
        spin_lock_irq(&group->lock);
    }

    join_state = get_leave_state(group);
    if (join_state) {
        group->rec.join_state &= ~join_state;
        spin_unlock_irq(&group->lock);
        if (send_leave(group, join_state))
            goto retest;
    } else {
        group->state = MCAST_IDLE;
        spin_unlock_irq(&group->lock);
        release_group(group);
    }
}

/*
 * Fail a join request if it is still active, i.e., if it is still the
 * request at the head of the pending queue.
 */
static void process_join_error(struct mcast_group *group, int status)
{
    struct mcast_member *member;
    int ret;

    spin_lock_irq(&group->lock);
    member = list_entry(group->pending_list.next,
                struct mcast_member, list);
    if (group->last_join == member) {
        refcount_inc(&member->refcount);
        list_del_init(&member->list);
        spin_unlock_irq(&group->lock);
        ret = member->multicast.callback(status, &member->multicast);
        deref_member(member);
        if (ret)
            ib_sa_free_multicast(&member->multicast);
    } else
        spin_unlock_irq(&group->lock);
}

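/*
 * SA join completion.  On success, adopt the record returned by the SA; if
 * the SA assigned a new MGID (the group was created with mgid0), re-key the
 * group in the port table.  The work handler is then invoked directly to
 * continue processing queued requests.
 */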
static void join_handler(int status, struct ib_sa_mcmember_rec *rec,
             void *context)
{
    struct mcast_group *group = context;
    u16 pkey_index = MCAST_INVALID_PKEY_INDEX;

    if (status)
        process_join_error(group, status);
    else {
        int mgids_changed, is_mgid0;

        if (ib_find_pkey(group->port->dev->device,
                 group->port->port_num, be16_to_cpu(rec->pkey),
                 &pkey_index))
            pkey_index = MCAST_INVALID_PKEY_INDEX;

        spin_lock_irq(&group->port->lock);
        if (group->state == MCAST_BUSY &&
            group->pkey_index == MCAST_INVALID_PKEY_INDEX)
            group->pkey_index = pkey_index;
        mgids_changed = memcmp(&rec->mgid, &group->rec.mgid,
                       sizeof(group->rec.mgid));
        group->rec = *rec;
        if (mgids_changed) {
            rb_erase(&group->node, &group->port->table);
            is_mgid0 = !memcmp(&mgid0, &group->rec.mgid,
                       sizeof(mgid0));
            mcast_insert(group->port, group, is_mgid0);
        }
        spin_unlock_irq(&group->port->lock);
    }
    mcast_work_handler(&group->work);
}

static void leave_handler(int status, struct ib_sa_mcmember_rec *rec,
              void *context)
{
    struct mcast_group *group = context;

    if (status && group->retries > 0 &&
        !send_leave(group, group->leave_state))
        group->retries--;
    else
        mcast_work_handler(&group->work);
}

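/*
 * Look up the group for the given MGID, allocating one if it does not
 * exist, and return it with a reference held.  Joins with the zero MGID
 * always allocate a new group, since the SA will assign the real MGID once
 * the join completes.
 */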
static struct mcast_group *acquire_group(struct mcast_port *port,
                     union ib_gid *mgid, gfp_t gfp_mask)
{
    struct mcast_group *group, *cur_group;
    unsigned long flags;
    int is_mgid0;

    is_mgid0 = !memcmp(&mgid0, mgid, sizeof mgid0);
    if (!is_mgid0) {
        spin_lock_irqsave(&port->lock, flags);
        group = mcast_find(port, mgid);
        if (group)
            goto found;
        spin_unlock_irqrestore(&port->lock, flags);
    }

    group = kzalloc(sizeof *group, gfp_mask);
    if (!group)
        return NULL;

    group->retries = 3;
    group->port = port;
    group->rec.mgid = *mgid;
    group->pkey_index = MCAST_INVALID_PKEY_INDEX;
    INIT_LIST_HEAD(&group->pending_list);
    INIT_LIST_HEAD(&group->active_list);
    INIT_WORK(&group->work, mcast_work_handler);
    spin_lock_init(&group->lock);

    spin_lock_irqsave(&port->lock, flags);
    cur_group = mcast_insert(port, group, is_mgid0);
    if (cur_group) {
        kfree(group);
        group = cur_group;
    } else
        refcount_inc(&port->refcount);
found:
    atomic_inc(&group->refcount);
    spin_unlock_irqrestore(&port->lock, flags);
    return group;
}

/*
 * We serialize all join requests to a single group to make our lives much
 * easier.  Otherwise, two users could try to join the same group
 * simultaneously, with different configurations, one could leave while the
 * join is in progress, etc., which makes locking around error recovery
 * difficult.
 */
struct ib_sa_multicast *
ib_sa_join_multicast(struct ib_sa_client *client,
             struct ib_device *device, u32 port_num,
             struct ib_sa_mcmember_rec *rec,
             ib_sa_comp_mask comp_mask, gfp_t gfp_mask,
             int (*callback)(int status,
                     struct ib_sa_multicast *multicast),
             void *context)
{
    struct mcast_device *dev;
    struct mcast_member *member;
    struct ib_sa_multicast *multicast;
    int ret;

    dev = ib_get_client_data(device, &mcast_client);
    if (!dev)
        return ERR_PTR(-ENODEV);

    member = kmalloc(sizeof *member, gfp_mask);
    if (!member)
        return ERR_PTR(-ENOMEM);

    ib_sa_client_get(client);
    member->client = client;
    member->multicast.rec = *rec;
    member->multicast.comp_mask = comp_mask;
    member->multicast.callback = callback;
    member->multicast.context = context;
    init_completion(&member->comp);
    refcount_set(&member->refcount, 1);
    member->state = MCAST_JOINING;

    member->group = acquire_group(&dev->port[port_num - dev->start_port],
                      &rec->mgid, gfp_mask);
    if (!member->group) {
        ret = -ENOMEM;
        goto err;
    }

    /*
     * The user will get the multicast structure in their callback.  They
     * could then free the multicast structure before we can return from
     * this routine.  So we save the pointer to return before queuing
     * any callback.
     */
    multicast = &member->multicast;
    queue_join(member);
    return multicast;

err:
    ib_sa_client_put(client);
    kfree(member);
    return ERR_PTR(ret);
}
EXPORT_SYMBOL(ib_sa_join_multicast);

void ib_sa_free_multicast(struct ib_sa_multicast *multicast)
{
    struct mcast_member *member;
    struct mcast_group *group;

    member = container_of(multicast, struct mcast_member, multicast);
    group = member->group;

    spin_lock_irq(&group->lock);
    if (member->state == MCAST_MEMBER)
        adjust_membership(group, multicast->rec.join_state, -1);

    list_del_init(&member->list);

    if (group->state == MCAST_IDLE) {
        group->state = MCAST_BUSY;
        spin_unlock_irq(&group->lock);
        /* Continue to hold reference on group until callback */
        queue_work(mcast_wq, &group->work);
    } else {
        spin_unlock_irq(&group->lock);
        release_group(group);
    }

    deref_member(member);
    wait_for_completion(&member->comp);
    ib_sa_client_put(member->client);
    kfree(member);
}
EXPORT_SYMBOL(ib_sa_free_multicast);
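
/*
 * Illustrative usage sketch, not part of this file.  A consumer registers
 * an ib_sa_client, fills in a member record, and joins; note the join
 * callback may run before ib_sa_join_multicast() returns.  Returning
 * nonzero from the callback tells this core to call ib_sa_free_multicast()
 * on the consumer's behalf, as seen in mcast_work_handler() above.  Names
 * such as my_join_done, my_sa_client, and my_ctx are hypothetical.
 *
 *	static int my_join_done(int status, struct ib_sa_multicast *mc)
 *	{
 *		if (status)
 *			return status;	// nonzero: core frees the membership
 *		// joined: mc->rec now holds the SA's view of the group
 *		return 0;
 *	}
 *
 *	mc = ib_sa_join_multicast(&my_sa_client, device, port_num, &rec,
 *				  IB_SA_MCMEMBER_REC_MGID |
 *				  IB_SA_MCMEMBER_REC_PORT_GID |
 *				  IB_SA_MCMEMBER_REC_JOIN_STATE,
 *				  GFP_KERNEL, my_join_done, my_ctx);
 *	if (IS_ERR(mc))
 *		return PTR_ERR(mc);
 *	...
 *	ib_sa_free_multicast(mc);	// leave the group when done
 */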

int ib_sa_get_mcmember_rec(struct ib_device *device, u32 port_num,
               union ib_gid *mgid, struct ib_sa_mcmember_rec *rec)
{
    struct mcast_device *dev;
    struct mcast_port *port;
    struct mcast_group *group;
    unsigned long flags;
    int ret = 0;

    dev = ib_get_client_data(device, &mcast_client);
    if (!dev)
        return -ENODEV;

    port = &dev->port[port_num - dev->start_port];
    spin_lock_irqsave(&port->lock, flags);
    group = mcast_find(port, mgid);
    if (group)
        *rec = group->rec;
    else
        ret = -EADDRNOTAVAIL;
    spin_unlock_irqrestore(&port->lock, flags);

    return ret;
}
EXPORT_SYMBOL(ib_sa_get_mcmember_rec);

/**
 * ib_init_ah_from_mcmember - Initialize AH attribute from a multicast
 *   member record and the device's GID.
 * @device: RDMA device
 * @port_num:   Port of the RDMA device to consider
 * @rec:    Multicast member record to use
 * @ndev:   Optional netdevice, applicable only for RoCE
 * @gid_type:   GID type to consider
 * @ah_attr:    AH attribute to fill in on successful completion
 *
 * ib_init_ah_from_mcmember() initializes an AH attribute based on the
 * multicast member record and other device properties.  On success the
 * caller is responsible for calling rdma_destroy_ah_attr() on the ah_attr.
 * Returns 0 on success or an appropriate error code.
 */
int ib_init_ah_from_mcmember(struct ib_device *device, u32 port_num,
                 struct ib_sa_mcmember_rec *rec,
                 struct net_device *ndev,
                 enum ib_gid_type gid_type,
                 struct rdma_ah_attr *ah_attr)
{
    const struct ib_gid_attr *sgid_attr;

    /* GID table is not based on the netdevice for IB link layer,
     * so ignore ndev during search.
     */
    if (rdma_protocol_ib(device, port_num))
        ndev = NULL;
    else if (!rdma_protocol_roce(device, port_num))
        return -EINVAL;

    sgid_attr = rdma_find_gid_by_port(device, &rec->port_gid,
                      gid_type, port_num, ndev);
    if (IS_ERR(sgid_attr))
        return PTR_ERR(sgid_attr);

    memset(ah_attr, 0, sizeof(*ah_attr));
    ah_attr->type = rdma_ah_find_type(device, port_num);

    rdma_ah_set_dlid(ah_attr, be16_to_cpu(rec->mlid));
    rdma_ah_set_sl(ah_attr, rec->sl);
    rdma_ah_set_port_num(ah_attr, port_num);
    rdma_ah_set_static_rate(ah_attr, rec->rate);
    rdma_move_grh_sgid_attr(ah_attr, &rec->mgid,
                be32_to_cpu(rec->flow_label),
                rec->hop_limit, rec->traffic_class,
                sgid_attr);
    return 0;
}
EXPORT_SYMBOL(ib_init_ah_from_mcmember);

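/*
 * Propagate a port event to every group on the port: queue idle groups so
 * the work handler runs, and mark each group with the event state so
 * process_group_error() handles it (an existing MCAST_GROUP_ERROR is not
 * overwritten).
 */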
static void mcast_groups_event(struct mcast_port *port,
                   enum mcast_group_state state)
{
    struct mcast_group *group;
    struct rb_node *node;
    unsigned long flags;

    spin_lock_irqsave(&port->lock, flags);
    for (node = rb_first(&port->table); node; node = rb_next(node)) {
        group = rb_entry(node, struct mcast_group, node);
        spin_lock(&group->lock);
        if (group->state == MCAST_IDLE) {
            atomic_inc(&group->refcount);
            queue_work(mcast_wq, &group->work);
        }
        if (group->state != MCAST_GROUP_ERROR)
            group->state = state;
        spin_unlock(&group->lock);
    }
    spin_unlock_irqrestore(&port->lock, flags);
}

static void mcast_event_handler(struct ib_event_handler *handler,
                struct ib_event *event)
{
    struct mcast_device *dev;
    int index;

    dev = container_of(handler, struct mcast_device, event_handler);
    if (!rdma_cap_ib_mcast(dev->device, event->element.port_num))
        return;

    index = event->element.port_num - dev->start_port;

    switch (event->event) {
    case IB_EVENT_PORT_ERR:
    case IB_EVENT_LID_CHANGE:
    case IB_EVENT_CLIENT_REREGISTER:
        mcast_groups_event(&dev->port[index], MCAST_GROUP_ERROR);
        break;
    case IB_EVENT_PKEY_CHANGE:
        mcast_groups_event(&dev->port[index], MCAST_PKEY_EVENT);
        break;
    default:
        break;
    }
}

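/*
 * ib_client add callback: allocate per-port state for every port that
 * supports IB multicast and register for device events.  Fails with
 * -EOPNOTSUPP if no port on the device supports multicast.
 */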
static int mcast_add_one(struct ib_device *device)
{
    struct mcast_device *dev;
    struct mcast_port *port;
    int i;
    int count = 0;

    dev = kmalloc(struct_size(dev, port, device->phys_port_cnt),
              GFP_KERNEL);
    if (!dev)
        return -ENOMEM;

    dev->start_port = rdma_start_port(device);
    dev->end_port = rdma_end_port(device);

    for (i = 0; i <= dev->end_port - dev->start_port; i++) {
        if (!rdma_cap_ib_mcast(device, dev->start_port + i))
            continue;
        port = &dev->port[i];
        port->dev = dev;
        port->port_num = dev->start_port + i;
        spin_lock_init(&port->lock);
        port->table = RB_ROOT;
        init_completion(&port->comp);
        refcount_set(&port->refcount, 1);
        ++count;
    }

    if (!count) {
        kfree(dev);
        return -EOPNOTSUPP;
    }

    dev->device = device;
    ib_set_client_data(device, &mcast_client, dev);

    INIT_IB_EVENT_HANDLER(&dev->event_handler, device, mcast_event_handler);
    ib_register_event_handler(&dev->event_handler);
    return 0;
}

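/*
 * ib_client remove callback: drop each port's initial reference and wait
 * for all of its groups to be released before freeing the device state.
 */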
static void mcast_remove_one(struct ib_device *device, void *client_data)
{
    struct mcast_device *dev = client_data;
    struct mcast_port *port;
    int i;

    ib_unregister_event_handler(&dev->event_handler);
    flush_workqueue(mcast_wq);

    for (i = 0; i <= dev->end_port - dev->start_port; i++) {
        if (rdma_cap_ib_mcast(device, dev->start_port + i)) {
            port = &dev->port[i];
            deref_port(port);
            wait_for_completion(&port->comp);
        }
    }

    kfree(dev);
}

int mcast_init(void)
{
    int ret;

    mcast_wq = alloc_ordered_workqueue("ib_mcast", WQ_MEM_RECLAIM);
    if (!mcast_wq)
        return -ENOMEM;

    ib_sa_register_client(&sa_client);

    ret = ib_register_client(&mcast_client);
    if (ret)
        goto err;
    return 0;

err:
    ib_sa_unregister_client(&sa_client);
    destroy_workqueue(mcast_wq);
    return ret;
}

void mcast_cleanup(void)
{
    ib_unregister_client(&mcast_client);
    ib_sa_unregister_client(&sa_client);
    destroy_workqueue(mcast_wq);
}