Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-only
0002 #include <linux/types.h>
0003 #include <linux/skbuff.h>
0004 #include <linux/socket.h>
0005 #include <linux/sysctl.h>
0006 #include <linux/net.h>
0007 #include <linux/module.h>
0008 #include <linux/if_arp.h>
0009 #include <linux/ipv6.h>
0010 #include <linux/mpls.h>
0011 #include <linux/netconf.h>
0012 #include <linux/nospec.h>
0013 #include <linux/vmalloc.h>
0014 #include <linux/percpu.h>
0015 #include <net/ip.h>
0016 #include <net/dst.h>
0017 #include <net/sock.h>
0018 #include <net/arp.h>
0019 #include <net/ip_fib.h>
0020 #include <net/netevent.h>
0021 #include <net/ip_tunnels.h>
0022 #include <net/netns/generic.h>
0023 #if IS_ENABLED(CONFIG_IPV6)
0024 #include <net/ipv6.h>
0025 #endif
0026 #include <net/ipv6_stubs.h>
0027 #include <net/rtnh.h>
0028 #include "internal.h"
0029 
/* Upper bound, in bytes, on the memory used for one mpls_route
 * (the route header plus all of its nexthops); enforced in mpls_rt_alloc().
 */
#define MAX_MPLS_ROUTE_MEM  4096

/* Maximum number of labels to look ahead at when selecting a path of
 * a multipath route
 */
#define MAX_MP_SELECT_LABELS 4

/* Sentinel stored in nh_via_table when a nexthop was configured without
 * RTA_VIA; mpls_forward() then transmits using the output device's own
 * address via NEIGH_LINK_TABLE.
 */
#define MPLS_NEIGH_TABLE_UNSPEC (NEIGH_LINK_TABLE + 1)

/* (1 << 20) - 1 is the largest value that fits in 20 bits.
 * NOTE(review): presumably the upper bound for a sysctl defined outside
 * this excerpt - confirm.
 */
static int label_limit = (1 << 20) - 1;
/* NOTE(review): presumably the upper bound for a TTL sysctl defined
 * outside this excerpt - confirm.
 */
static int ttl_max = 255;
0042 
0043 #if IS_ENABLED(CONFIG_NET_IP_TUNNEL)
0044 static size_t ipgre_mpls_encap_hlen(struct ip_tunnel_encap *e)
0045 {
0046     return sizeof(struct mpls_shim_hdr);
0047 }
0048 
0049 static const struct ip_tunnel_encap_ops mpls_iptun_ops = {
0050     .encap_hlen = ipgre_mpls_encap_hlen,
0051 };
0052 
0053 static int ipgre_tunnel_encap_add_mpls_ops(void)
0054 {
0055     return ip_tunnel_encap_add_ops(&mpls_iptun_ops, TUNNEL_ENCAP_MPLS);
0056 }
0057 
0058 static void ipgre_tunnel_encap_del_mpls_ops(void)
0059 {
0060     ip_tunnel_encap_del_ops(&mpls_iptun_ops, TUNNEL_ENCAP_MPLS);
0061 }
0062 #else
0063 static int ipgre_tunnel_encap_add_mpls_ops(void)
0064 {
0065     return 0;
0066 }
0067 
0068 static void ipgre_tunnel_encap_del_mpls_ops(void)
0069 {
0070 }
0071 #endif
0072 
/* Forward declaration: mpls_notify_route() calls this before its
 * definition appears.  It receives the netlink event, the label, the
 * affected route, the originating request header, the namespace, the
 * requester's port id and extra NLM_F_* flags.
 */
static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
               struct nlmsghdr *nlh, struct net *net, u32 portid,
               unsigned int nlm_flags);
0076 
0077 static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index)
0078 {
0079     struct mpls_route *rt = NULL;
0080 
0081     if (index < net->mpls.platform_labels) {
0082         struct mpls_route __rcu **platform_label =
0083             rcu_dereference(net->mpls.platform_label);
0084         rt = rcu_dereference(platform_label[index]);
0085     }
0086     return rt;
0087 }
0088 
0089 bool mpls_output_possible(const struct net_device *dev)
0090 {
0091     return dev && (dev->flags & IFF_UP) && netif_carrier_ok(dev);
0092 }
0093 EXPORT_SYMBOL_GPL(mpls_output_possible);
0094 
0095 static u8 *__mpls_nh_via(struct mpls_route *rt, struct mpls_nh *nh)
0096 {
0097     return (u8 *)nh + rt->rt_via_offset;
0098 }
0099 
0100 static const u8 *mpls_nh_via(const struct mpls_route *rt,
0101                  const struct mpls_nh *nh)
0102 {
0103     return __mpls_nh_via((struct mpls_route *)rt, (struct mpls_nh *)nh);
0104 }
0105 
0106 static unsigned int mpls_nh_header_size(const struct mpls_nh *nh)
0107 {
0108     /* The size of the layer 2.5 labels to be added for this route */
0109     return nh->nh_labels * sizeof(struct mpls_shim_hdr);
0110 }
0111 
0112 unsigned int mpls_dev_mtu(const struct net_device *dev)
0113 {
0114     /* The amount of data the layer 2 frame can hold */
0115     return dev->mtu;
0116 }
0117 EXPORT_SYMBOL_GPL(mpls_dev_mtu);
0118 
0119 bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
0120 {
0121     if (skb->len <= mtu)
0122         return false;
0123 
0124     if (skb_is_gso(skb) && skb_gso_validate_network_len(skb, mtu))
0125         return false;
0126 
0127     return true;
0128 }
0129 EXPORT_SYMBOL_GPL(mpls_pkt_too_big);
0130 
0131 void mpls_stats_inc_outucastpkts(struct net_device *dev,
0132                  const struct sk_buff *skb)
0133 {
0134     struct mpls_dev *mdev;
0135 
0136     if (skb->protocol == htons(ETH_P_MPLS_UC)) {
0137         mdev = mpls_dev_get(dev);
0138         if (mdev)
0139             MPLS_INC_STATS_LEN(mdev, skb->len,
0140                        tx_packets,
0141                        tx_bytes);
0142     } else if (skb->protocol == htons(ETH_P_IP)) {
0143         IP_UPD_PO_STATS(dev_net(dev), IPSTATS_MIB_OUT, skb->len);
0144 #if IS_ENABLED(CONFIG_IPV6)
0145     } else if (skb->protocol == htons(ETH_P_IPV6)) {
0146         struct inet6_dev *in6dev = __in6_dev_get(dev);
0147 
0148         if (in6dev)
0149             IP6_UPD_PO_STATS(dev_net(dev), in6dev,
0150                      IPSTATS_MIB_OUT, skb->len);
0151 #endif
0152     }
0153 }
0154 EXPORT_SYMBOL_GPL(mpls_stats_inc_outucastpkts);
0155 
0156 static u32 mpls_multipath_hash(struct mpls_route *rt, struct sk_buff *skb)
0157 {
0158     struct mpls_entry_decoded dec;
0159     unsigned int mpls_hdr_len = 0;
0160     struct mpls_shim_hdr *hdr;
0161     bool eli_seen = false;
0162     int label_index;
0163     u32 hash = 0;
0164 
0165     for (label_index = 0; label_index < MAX_MP_SELECT_LABELS;
0166          label_index++) {
0167         mpls_hdr_len += sizeof(*hdr);
0168         if (!pskb_may_pull(skb, mpls_hdr_len))
0169             break;
0170 
0171         /* Read and decode the current label */
0172         hdr = mpls_hdr(skb) + label_index;
0173         dec = mpls_entry_decode(hdr);
0174 
0175         /* RFC6790 - reserved labels MUST NOT be used as keys
0176          * for the load-balancing function
0177          */
0178         if (likely(dec.label >= MPLS_LABEL_FIRST_UNRESERVED)) {
0179             hash = jhash_1word(dec.label, hash);
0180 
0181             /* The entropy label follows the entropy label
0182              * indicator, so this means that the entropy
0183              * label was just added to the hash - no need to
0184              * go any deeper either in the label stack or in the
0185              * payload
0186              */
0187             if (eli_seen)
0188                 break;
0189         } else if (dec.label == MPLS_LABEL_ENTROPY) {
0190             eli_seen = true;
0191         }
0192 
0193         if (!dec.bos)
0194             continue;
0195 
0196         /* found bottom label; does skb have room for a header? */
0197         if (pskb_may_pull(skb, mpls_hdr_len + sizeof(struct iphdr))) {
0198             const struct iphdr *v4hdr;
0199 
0200             v4hdr = (const struct iphdr *)(hdr + 1);
0201             if (v4hdr->version == 4) {
0202                 hash = jhash_3words(ntohl(v4hdr->saddr),
0203                             ntohl(v4hdr->daddr),
0204                             v4hdr->protocol, hash);
0205             } else if (v4hdr->version == 6 &&
0206                    pskb_may_pull(skb, mpls_hdr_len +
0207                          sizeof(struct ipv6hdr))) {
0208                 const struct ipv6hdr *v6hdr;
0209 
0210                 v6hdr = (const struct ipv6hdr *)(hdr + 1);
0211                 hash = __ipv6_addr_jhash(&v6hdr->saddr, hash);
0212                 hash = __ipv6_addr_jhash(&v6hdr->daddr, hash);
0213                 hash = jhash_1word(v6hdr->nexthdr, hash);
0214             }
0215         }
0216 
0217         break;
0218     }
0219 
0220     return hash;
0221 }
0222 
0223 static struct mpls_nh *mpls_get_nexthop(struct mpls_route *rt, u8 index)
0224 {
0225     return (struct mpls_nh *)((u8 *)rt->rt_nh + index * rt->rt_nh_size);
0226 }
0227 
0228 /* number of alive nexthops (rt->rt_nhn_alive) and the flags for
0229  * a next hop (nh->nh_flags) are modified by netdev event handlers.
0230  * Since those fields can change at any moment, use READ_ONCE to
0231  * access both.
0232  */
0233 static const struct mpls_nh *mpls_select_multipath(struct mpls_route *rt,
0234                            struct sk_buff *skb)
0235 {
0236     u32 hash = 0;
0237     int nh_index = 0;
0238     int n = 0;
0239     u8 alive;
0240 
0241     /* No need to look further into packet if there's only
0242      * one path
0243      */
0244     if (rt->rt_nhn == 1)
0245         return rt->rt_nh;
0246 
0247     alive = READ_ONCE(rt->rt_nhn_alive);
0248     if (alive == 0)
0249         return NULL;
0250 
0251     hash = mpls_multipath_hash(rt, skb);
0252     nh_index = hash % alive;
0253     if (alive == rt->rt_nhn)
0254         goto out;
0255     for_nexthops(rt) {
0256         unsigned int nh_flags = READ_ONCE(nh->nh_flags);
0257 
0258         if (nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN))
0259             continue;
0260         if (n == nh_index)
0261             return nh;
0262         n++;
0263     } endfor_nexthops(rt);
0264 
0265 out:
0266     return mpls_get_nexthop(rt, nh_index);
0267 }
0268 
0269 static bool mpls_egress(struct net *net, struct mpls_route *rt,
0270             struct sk_buff *skb, struct mpls_entry_decoded dec)
0271 {
0272     enum mpls_payload_type payload_type;
0273     bool success = false;
0274 
0275     /* The IPv4 code below accesses through the IPv4 header
0276      * checksum, which is 12 bytes into the packet.
0277      * The IPv6 code below accesses through the IPv6 hop limit
0278      * which is 8 bytes into the packet.
0279      *
0280      * For all supported cases there should always be at least 12
0281      * bytes of packet data present.  The IPv4 header is 20 bytes
0282      * without options and the IPv6 header is always 40 bytes
0283      * long.
0284      */
0285     if (!pskb_may_pull(skb, 12))
0286         return false;
0287 
0288     payload_type = rt->rt_payload_type;
0289     if (payload_type == MPT_UNSPEC)
0290         payload_type = ip_hdr(skb)->version;
0291 
0292     switch (payload_type) {
0293     case MPT_IPV4: {
0294         struct iphdr *hdr4 = ip_hdr(skb);
0295         u8 new_ttl;
0296         skb->protocol = htons(ETH_P_IP);
0297 
0298         /* If propagating TTL, take the decremented TTL from
0299          * the incoming MPLS header, otherwise decrement the
0300          * TTL, but only if not 0 to avoid underflow.
0301          */
0302         if (rt->rt_ttl_propagate == MPLS_TTL_PROP_ENABLED ||
0303             (rt->rt_ttl_propagate == MPLS_TTL_PROP_DEFAULT &&
0304              net->mpls.ip_ttl_propagate))
0305             new_ttl = dec.ttl;
0306         else
0307             new_ttl = hdr4->ttl ? hdr4->ttl - 1 : 0;
0308 
0309         csum_replace2(&hdr4->check,
0310                   htons(hdr4->ttl << 8),
0311                   htons(new_ttl << 8));
0312         hdr4->ttl = new_ttl;
0313         success = true;
0314         break;
0315     }
0316     case MPT_IPV6: {
0317         struct ipv6hdr *hdr6 = ipv6_hdr(skb);
0318         skb->protocol = htons(ETH_P_IPV6);
0319 
0320         /* If propagating TTL, take the decremented TTL from
0321          * the incoming MPLS header, otherwise decrement the
0322          * hop limit, but only if not 0 to avoid underflow.
0323          */
0324         if (rt->rt_ttl_propagate == MPLS_TTL_PROP_ENABLED ||
0325             (rt->rt_ttl_propagate == MPLS_TTL_PROP_DEFAULT &&
0326              net->mpls.ip_ttl_propagate))
0327             hdr6->hop_limit = dec.ttl;
0328         else if (hdr6->hop_limit)
0329             hdr6->hop_limit = hdr6->hop_limit - 1;
0330         success = true;
0331         break;
0332     }
0333     case MPT_UNSPEC:
0334         /* Should have decided which protocol it is by now */
0335         break;
0336     }
0337 
0338     return success;
0339 }
0340 
0341 static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
0342             struct packet_type *pt, struct net_device *orig_dev)
0343 {
0344     struct net *net = dev_net(dev);
0345     struct mpls_shim_hdr *hdr;
0346     const struct mpls_nh *nh;
0347     struct mpls_route *rt;
0348     struct mpls_entry_decoded dec;
0349     struct net_device *out_dev;
0350     struct mpls_dev *out_mdev;
0351     struct mpls_dev *mdev;
0352     unsigned int hh_len;
0353     unsigned int new_header_size;
0354     unsigned int mtu;
0355     int err;
0356 
0357     /* Careful this entire function runs inside of an rcu critical section */
0358 
0359     mdev = mpls_dev_get(dev);
0360     if (!mdev)
0361         goto drop;
0362 
0363     MPLS_INC_STATS_LEN(mdev, skb->len, rx_packets,
0364                rx_bytes);
0365 
0366     if (!mdev->input_enabled) {
0367         MPLS_INC_STATS(mdev, rx_dropped);
0368         goto drop;
0369     }
0370 
0371     if (skb->pkt_type != PACKET_HOST)
0372         goto err;
0373 
0374     if ((skb = skb_share_check(skb, GFP_ATOMIC)) == NULL)
0375         goto err;
0376 
0377     if (!pskb_may_pull(skb, sizeof(*hdr)))
0378         goto err;
0379 
0380     skb_dst_drop(skb);
0381 
0382     /* Read and decode the label */
0383     hdr = mpls_hdr(skb);
0384     dec = mpls_entry_decode(hdr);
0385 
0386     rt = mpls_route_input_rcu(net, dec.label);
0387     if (!rt) {
0388         MPLS_INC_STATS(mdev, rx_noroute);
0389         goto drop;
0390     }
0391 
0392     nh = mpls_select_multipath(rt, skb);
0393     if (!nh)
0394         goto err;
0395 
0396     /* Pop the label */
0397     skb_pull(skb, sizeof(*hdr));
0398     skb_reset_network_header(skb);
0399 
0400     skb_orphan(skb);
0401 
0402     if (skb_warn_if_lro(skb))
0403         goto err;
0404 
0405     skb_forward_csum(skb);
0406 
0407     /* Verify ttl is valid */
0408     if (dec.ttl <= 1)
0409         goto err;
0410 
0411     /* Find the output device */
0412     out_dev = nh->nh_dev;
0413     if (!mpls_output_possible(out_dev))
0414         goto tx_err;
0415 
0416     /* Verify the destination can hold the packet */
0417     new_header_size = mpls_nh_header_size(nh);
0418     mtu = mpls_dev_mtu(out_dev);
0419     if (mpls_pkt_too_big(skb, mtu - new_header_size))
0420         goto tx_err;
0421 
0422     hh_len = LL_RESERVED_SPACE(out_dev);
0423     if (!out_dev->header_ops)
0424         hh_len = 0;
0425 
0426     /* Ensure there is enough space for the headers in the skb */
0427     if (skb_cow(skb, hh_len + new_header_size))
0428         goto tx_err;
0429 
0430     skb->dev = out_dev;
0431     skb->protocol = htons(ETH_P_MPLS_UC);
0432 
0433     dec.ttl -= 1;
0434     if (unlikely(!new_header_size && dec.bos)) {
0435         /* Penultimate hop popping */
0436         if (!mpls_egress(dev_net(out_dev), rt, skb, dec))
0437             goto err;
0438     } else {
0439         bool bos;
0440         int i;
0441         skb_push(skb, new_header_size);
0442         skb_reset_network_header(skb);
0443         /* Push the new labels */
0444         hdr = mpls_hdr(skb);
0445         bos = dec.bos;
0446         for (i = nh->nh_labels - 1; i >= 0; i--) {
0447             hdr[i] = mpls_entry_encode(nh->nh_label[i],
0448                            dec.ttl, 0, bos);
0449             bos = false;
0450         }
0451     }
0452 
0453     mpls_stats_inc_outucastpkts(out_dev, skb);
0454 
0455     /* If via wasn't specified then send out using device address */
0456     if (nh->nh_via_table == MPLS_NEIGH_TABLE_UNSPEC)
0457         err = neigh_xmit(NEIGH_LINK_TABLE, out_dev,
0458                  out_dev->dev_addr, skb);
0459     else
0460         err = neigh_xmit(nh->nh_via_table, out_dev,
0461                  mpls_nh_via(rt, nh), skb);
0462     if (err)
0463         net_dbg_ratelimited("%s: packet transmission failed: %d\n",
0464                     __func__, err);
0465     return 0;
0466 
0467 tx_err:
0468     out_mdev = out_dev ? mpls_dev_get(out_dev) : NULL;
0469     if (out_mdev)
0470         MPLS_INC_STATS(out_mdev, tx_errors);
0471     goto drop;
0472 err:
0473     MPLS_INC_STATS(mdev, rx_errors);
0474 drop:
0475     kfree_skb(skb);
0476     return NET_RX_DROP;
0477 }
0478 
/* Packet type descriptor that binds ETH_P_MPLS_UC frames to
 * mpls_forward().
 */
static struct packet_type mpls_packet_type __read_mostly = {
    .type = cpu_to_be16(ETH_P_MPLS_UC),
    .func = mpls_forward,
};
0483 
/* Netlink attribute policy for MPLS route messages: RTA_DST and RTA_OIF
 * are 32-bit values, RTA_TTL_PROPAGATE is an 8-bit value.  Attributes not
 * listed here carry no type constraint in this table.
 */
static const struct nla_policy rtm_mpls_policy[RTA_MAX+1] = {
    [RTA_DST]       = { .type = NLA_U32 },
    [RTA_OIF]       = { .type = NLA_U32 },
    [RTA_TTL_PROPAGATE] = { .type = NLA_U8 },
};
0489 
/* Parsed form of a route add/delete request, consumed by
 * mpls_route_add() and mpls_route_del().
 */
struct mpls_route_config {
    /* copied into rt_protocol of the new route */
    u32         rc_protocol;
    /* output interface index for a single-nexthop route; 0 means the
     * device is derived from the via address
     */
    u32         rc_ifindex;
    /* neighbour table, length and bytes of the via address of a
     * single-nexthop route
     */
    u8          rc_via_table;
    u8          rc_via_alen;
    u8          rc_via[MAX_VIA_ALEN];
    /* incoming label, or LABEL_NOT_SPECIFIED to have one picked */
    u32         rc_label;
    /* copied into rt_ttl_propagate of the new route */
    u8          rc_ttl_propagate;
    /* number of valid entries in rc_output_label[] */
    u8          rc_output_labels;
    u32         rc_output_label[MAX_NEW_LABELS];
    /* NLM_F_* flags of the request (CREATE, EXCL, REPLACE, APPEND) */
    u32         rc_nlflags;
    /* copied into rt_payload_type of the new route */
    enum mpls_payload_type  rc_payload_type;
    /* namespace and requester details, also used for notifications */
    struct nl_info      rc_nlinfo;
    /* multipath nexthop list and its length in bytes; rc_mp is NULL
     * for a single-nexthop route
     */
    struct rtnexthop    *rc_mp;
    int         rc_mp_len;
};
0506 
0507 /* all nexthops within a route have the same size based on max
0508  * number of labels and max via length for a hop
0509  */
0510 static struct mpls_route *mpls_rt_alloc(u8 num_nh, u8 max_alen, u8 max_labels)
0511 {
0512     u8 nh_size = MPLS_NH_SIZE(max_labels, max_alen);
0513     struct mpls_route *rt;
0514     size_t size;
0515 
0516     size = sizeof(*rt) + num_nh * nh_size;
0517     if (size > MAX_MPLS_ROUTE_MEM)
0518         return ERR_PTR(-EINVAL);
0519 
0520     rt = kzalloc(size, GFP_KERNEL);
0521     if (!rt)
0522         return ERR_PTR(-ENOMEM);
0523 
0524     rt->rt_nhn = num_nh;
0525     rt->rt_nhn_alive = num_nh;
0526     rt->rt_nh_size = nh_size;
0527     rt->rt_via_offset = MPLS_NH_VIA_OFF(max_labels);
0528 
0529     return rt;
0530 }
0531 
0532 static void mpls_rt_free(struct mpls_route *rt)
0533 {
0534     if (rt)
0535         kfree_rcu(rt, rt_rcu);
0536 }
0537 
0538 static void mpls_notify_route(struct net *net, unsigned index,
0539                   struct mpls_route *old, struct mpls_route *new,
0540                   const struct nl_info *info)
0541 {
0542     struct nlmsghdr *nlh = info ? info->nlh : NULL;
0543     unsigned portid = info ? info->portid : 0;
0544     int event = new ? RTM_NEWROUTE : RTM_DELROUTE;
0545     struct mpls_route *rt = new ? new : old;
0546     unsigned nlm_flags = (old && new) ? NLM_F_REPLACE : 0;
0547     /* Ignore reserved labels for now */
0548     if (rt && (index >= MPLS_LABEL_FIRST_UNRESERVED))
0549         rtmsg_lfib(event, index, rt, nlh, net, portid, nlm_flags);
0550 }
0551 
0552 static void mpls_route_update(struct net *net, unsigned index,
0553                   struct mpls_route *new,
0554                   const struct nl_info *info)
0555 {
0556     struct mpls_route __rcu **platform_label;
0557     struct mpls_route *rt;
0558 
0559     ASSERT_RTNL();
0560 
0561     platform_label = rtnl_dereference(net->mpls.platform_label);
0562     rt = rtnl_dereference(platform_label[index]);
0563     rcu_assign_pointer(platform_label[index], new);
0564 
0565     mpls_notify_route(net, index, rt, new, info);
0566 
0567     /* If we removed a route free it now */
0568     mpls_rt_free(rt);
0569 }
0570 
0571 static unsigned find_free_label(struct net *net)
0572 {
0573     struct mpls_route __rcu **platform_label;
0574     size_t platform_labels;
0575     unsigned index;
0576 
0577     platform_label = rtnl_dereference(net->mpls.platform_label);
0578     platform_labels = net->mpls.platform_labels;
0579     for (index = MPLS_LABEL_FIRST_UNRESERVED; index < platform_labels;
0580          index++) {
0581         if (!rtnl_dereference(platform_label[index]))
0582             return index;
0583     }
0584     return LABEL_NOT_SPECIFIED;
0585 }
0586 
0587 #if IS_ENABLED(CONFIG_INET)
0588 static struct net_device *inet_fib_lookup_dev(struct net *net,
0589                           const void *addr)
0590 {
0591     struct net_device *dev;
0592     struct rtable *rt;
0593     struct in_addr daddr;
0594 
0595     memcpy(&daddr, addr, sizeof(struct in_addr));
0596     rt = ip_route_output(net, daddr.s_addr, 0, 0, 0);
0597     if (IS_ERR(rt))
0598         return ERR_CAST(rt);
0599 
0600     dev = rt->dst.dev;
0601     dev_hold(dev);
0602 
0603     ip_rt_put(rt);
0604 
0605     return dev;
0606 }
0607 #else
0608 static struct net_device *inet_fib_lookup_dev(struct net *net,
0609                           const void *addr)
0610 {
0611     return ERR_PTR(-EAFNOSUPPORT);
0612 }
0613 #endif
0614 
0615 #if IS_ENABLED(CONFIG_IPV6)
0616 static struct net_device *inet6_fib_lookup_dev(struct net *net,
0617                            const void *addr)
0618 {
0619     struct net_device *dev;
0620     struct dst_entry *dst;
0621     struct flowi6 fl6;
0622 
0623     if (!ipv6_stub)
0624         return ERR_PTR(-EAFNOSUPPORT);
0625 
0626     memset(&fl6, 0, sizeof(fl6));
0627     memcpy(&fl6.daddr, addr, sizeof(struct in6_addr));
0628     dst = ipv6_stub->ipv6_dst_lookup_flow(net, NULL, &fl6, NULL);
0629     if (IS_ERR(dst))
0630         return ERR_CAST(dst);
0631 
0632     dev = dst->dev;
0633     dev_hold(dev);
0634     dst_release(dst);
0635 
0636     return dev;
0637 }
0638 #else
0639 static struct net_device *inet6_fib_lookup_dev(struct net *net,
0640                            const void *addr)
0641 {
0642     return ERR_PTR(-EAFNOSUPPORT);
0643 }
0644 #endif
0645 
0646 static struct net_device *find_outdev(struct net *net,
0647                       struct mpls_route *rt,
0648                       struct mpls_nh *nh, int oif)
0649 {
0650     struct net_device *dev = NULL;
0651 
0652     if (!oif) {
0653         switch (nh->nh_via_table) {
0654         case NEIGH_ARP_TABLE:
0655             dev = inet_fib_lookup_dev(net, mpls_nh_via(rt, nh));
0656             break;
0657         case NEIGH_ND_TABLE:
0658             dev = inet6_fib_lookup_dev(net, mpls_nh_via(rt, nh));
0659             break;
0660         case NEIGH_LINK_TABLE:
0661             break;
0662         }
0663     } else {
0664         dev = dev_get_by_index(net, oif);
0665     }
0666 
0667     if (!dev)
0668         return ERR_PTR(-ENODEV);
0669 
0670     if (IS_ERR(dev))
0671         return dev;
0672 
0673     /* The caller is holding rtnl anyways, so release the dev reference */
0674     dev_put(dev);
0675 
0676     return dev;
0677 }
0678 
0679 static int mpls_nh_assign_dev(struct net *net, struct mpls_route *rt,
0680                   struct mpls_nh *nh, int oif)
0681 {
0682     struct net_device *dev = NULL;
0683     int err = -ENODEV;
0684 
0685     dev = find_outdev(net, rt, nh, oif);
0686     if (IS_ERR(dev)) {
0687         err = PTR_ERR(dev);
0688         dev = NULL;
0689         goto errout;
0690     }
0691 
0692     /* Ensure this is a supported device */
0693     err = -EINVAL;
0694     if (!mpls_dev_get(dev))
0695         goto errout;
0696 
0697     if ((nh->nh_via_table == NEIGH_LINK_TABLE) &&
0698         (dev->addr_len != nh->nh_via_alen))
0699         goto errout;
0700 
0701     nh->nh_dev = dev;
0702 
0703     if (!(dev->flags & IFF_UP)) {
0704         nh->nh_flags |= RTNH_F_DEAD;
0705     } else {
0706         unsigned int flags;
0707 
0708         flags = dev_get_flags(dev);
0709         if (!(flags & (IFF_RUNNING | IFF_LOWER_UP)))
0710             nh->nh_flags |= RTNH_F_LINKDOWN;
0711     }
0712 
0713     return 0;
0714 
0715 errout:
0716     return err;
0717 }
0718 
0719 static int nla_get_via(const struct nlattr *nla, u8 *via_alen, u8 *via_table,
0720                u8 via_addr[], struct netlink_ext_ack *extack)
0721 {
0722     struct rtvia *via = nla_data(nla);
0723     int err = -EINVAL;
0724     int alen;
0725 
0726     if (nla_len(nla) < offsetof(struct rtvia, rtvia_addr)) {
0727         NL_SET_ERR_MSG_ATTR(extack, nla,
0728                     "Invalid attribute length for RTA_VIA");
0729         goto errout;
0730     }
0731     alen = nla_len(nla) -
0732             offsetof(struct rtvia, rtvia_addr);
0733     if (alen > MAX_VIA_ALEN) {
0734         NL_SET_ERR_MSG_ATTR(extack, nla,
0735                     "Invalid address length for RTA_VIA");
0736         goto errout;
0737     }
0738 
0739     /* Validate the address family */
0740     switch (via->rtvia_family) {
0741     case AF_PACKET:
0742         *via_table = NEIGH_LINK_TABLE;
0743         break;
0744     case AF_INET:
0745         *via_table = NEIGH_ARP_TABLE;
0746         if (alen != 4)
0747             goto errout;
0748         break;
0749     case AF_INET6:
0750         *via_table = NEIGH_ND_TABLE;
0751         if (alen != 16)
0752             goto errout;
0753         break;
0754     default:
0755         /* Unsupported address family */
0756         goto errout;
0757     }
0758 
0759     memcpy(via_addr, via->rtvia_addr, alen);
0760     *via_alen = alen;
0761     err = 0;
0762 
0763 errout:
0764     return err;
0765 }
0766 
0767 static int mpls_nh_build_from_cfg(struct mpls_route_config *cfg,
0768                   struct mpls_route *rt)
0769 {
0770     struct net *net = cfg->rc_nlinfo.nl_net;
0771     struct mpls_nh *nh = rt->rt_nh;
0772     int err;
0773     int i;
0774 
0775     if (!nh)
0776         return -ENOMEM;
0777 
0778     nh->nh_labels = cfg->rc_output_labels;
0779     for (i = 0; i < nh->nh_labels; i++)
0780         nh->nh_label[i] = cfg->rc_output_label[i];
0781 
0782     nh->nh_via_table = cfg->rc_via_table;
0783     memcpy(__mpls_nh_via(rt, nh), cfg->rc_via, cfg->rc_via_alen);
0784     nh->nh_via_alen = cfg->rc_via_alen;
0785 
0786     err = mpls_nh_assign_dev(net, rt, nh, cfg->rc_ifindex);
0787     if (err)
0788         goto errout;
0789 
0790     if (nh->nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN))
0791         rt->rt_nhn_alive--;
0792 
0793     return 0;
0794 
0795 errout:
0796     return err;
0797 }
0798 
0799 static int mpls_nh_build(struct net *net, struct mpls_route *rt,
0800              struct mpls_nh *nh, int oif, struct nlattr *via,
0801              struct nlattr *newdst, u8 max_labels,
0802              struct netlink_ext_ack *extack)
0803 {
0804     int err = -ENOMEM;
0805 
0806     if (!nh)
0807         goto errout;
0808 
0809     if (newdst) {
0810         err = nla_get_labels(newdst, max_labels, &nh->nh_labels,
0811                      nh->nh_label, extack);
0812         if (err)
0813             goto errout;
0814     }
0815 
0816     if (via) {
0817         err = nla_get_via(via, &nh->nh_via_alen, &nh->nh_via_table,
0818                   __mpls_nh_via(rt, nh), extack);
0819         if (err)
0820             goto errout;
0821     } else {
0822         nh->nh_via_table = MPLS_NEIGH_TABLE_UNSPEC;
0823     }
0824 
0825     err = mpls_nh_assign_dev(net, rt, nh, oif);
0826     if (err)
0827         goto errout;
0828 
0829     return 0;
0830 
0831 errout:
0832     return err;
0833 }
0834 
0835 static u8 mpls_count_nexthops(struct rtnexthop *rtnh, int len,
0836                   u8 cfg_via_alen, u8 *max_via_alen,
0837                   u8 *max_labels)
0838 {
0839     int remaining = len;
0840     u8 nhs = 0;
0841 
0842     *max_via_alen = 0;
0843     *max_labels = 0;
0844 
0845     while (rtnh_ok(rtnh, remaining)) {
0846         struct nlattr *nla, *attrs = rtnh_attrs(rtnh);
0847         int attrlen;
0848         u8 n_labels = 0;
0849 
0850         attrlen = rtnh_attrlen(rtnh);
0851         nla = nla_find(attrs, attrlen, RTA_VIA);
0852         if (nla && nla_len(nla) >=
0853             offsetof(struct rtvia, rtvia_addr)) {
0854             int via_alen = nla_len(nla) -
0855                 offsetof(struct rtvia, rtvia_addr);
0856 
0857             if (via_alen <= MAX_VIA_ALEN)
0858                 *max_via_alen = max_t(u16, *max_via_alen,
0859                               via_alen);
0860         }
0861 
0862         nla = nla_find(attrs, attrlen, RTA_NEWDST);
0863         if (nla &&
0864             nla_get_labels(nla, MAX_NEW_LABELS, &n_labels,
0865                    NULL, NULL) != 0)
0866             return 0;
0867 
0868         *max_labels = max_t(u8, *max_labels, n_labels);
0869 
0870         /* number of nexthops is tracked by a u8.
0871          * Check for overflow.
0872          */
0873         if (nhs == 255)
0874             return 0;
0875         nhs++;
0876 
0877         rtnh = rtnh_next(rtnh, &remaining);
0878     }
0879 
0880     /* leftover implies invalid nexthop configuration, discard it */
0881     return remaining > 0 ? 0 : nhs;
0882 }
0883 
0884 static int mpls_nh_build_multi(struct mpls_route_config *cfg,
0885                    struct mpls_route *rt, u8 max_labels,
0886                    struct netlink_ext_ack *extack)
0887 {
0888     struct rtnexthop *rtnh = cfg->rc_mp;
0889     struct nlattr *nla_via, *nla_newdst;
0890     int remaining = cfg->rc_mp_len;
0891     int err = 0;
0892     u8 nhs = 0;
0893 
0894     change_nexthops(rt) {
0895         int attrlen;
0896 
0897         nla_via = NULL;
0898         nla_newdst = NULL;
0899 
0900         err = -EINVAL;
0901         if (!rtnh_ok(rtnh, remaining))
0902             goto errout;
0903 
0904         /* neither weighted multipath nor any flags
0905          * are supported
0906          */
0907         if (rtnh->rtnh_hops || rtnh->rtnh_flags)
0908             goto errout;
0909 
0910         attrlen = rtnh_attrlen(rtnh);
0911         if (attrlen > 0) {
0912             struct nlattr *attrs = rtnh_attrs(rtnh);
0913 
0914             nla_via = nla_find(attrs, attrlen, RTA_VIA);
0915             nla_newdst = nla_find(attrs, attrlen, RTA_NEWDST);
0916         }
0917 
0918         err = mpls_nh_build(cfg->rc_nlinfo.nl_net, rt, nh,
0919                     rtnh->rtnh_ifindex, nla_via, nla_newdst,
0920                     max_labels, extack);
0921         if (err)
0922             goto errout;
0923 
0924         if (nh->nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN))
0925             rt->rt_nhn_alive--;
0926 
0927         rtnh = rtnh_next(rtnh, &remaining);
0928         nhs++;
0929     } endfor_nexthops(rt);
0930 
0931     rt->rt_nhn = nhs;
0932 
0933     return 0;
0934 
0935 errout:
0936     return err;
0937 }
0938 
0939 static bool mpls_label_ok(struct net *net, unsigned int *index,
0940               struct netlink_ext_ack *extack)
0941 {
0942     bool is_ok = true;
0943 
0944     /* Reserved labels may not be set */
0945     if (*index < MPLS_LABEL_FIRST_UNRESERVED) {
0946         NL_SET_ERR_MSG(extack,
0947                    "Invalid label - must be MPLS_LABEL_FIRST_UNRESERVED or higher");
0948         is_ok = false;
0949     }
0950 
0951     /* The full 20 bit range may not be supported. */
0952     if (is_ok && *index >= net->mpls.platform_labels) {
0953         NL_SET_ERR_MSG(extack,
0954                    "Label >= configured maximum in platform_labels");
0955         is_ok = false;
0956     }
0957 
0958     *index = array_index_nospec(*index, net->mpls.platform_labels);
0959     return is_ok;
0960 }
0961 
0962 static int mpls_route_add(struct mpls_route_config *cfg,
0963               struct netlink_ext_ack *extack)
0964 {
0965     struct mpls_route __rcu **platform_label;
0966     struct net *net = cfg->rc_nlinfo.nl_net;
0967     struct mpls_route *rt, *old;
0968     int err = -EINVAL;
0969     u8 max_via_alen;
0970     unsigned index;
0971     u8 max_labels;
0972     u8 nhs;
0973 
0974     index = cfg->rc_label;
0975 
0976     /* If a label was not specified during insert pick one */
0977     if ((index == LABEL_NOT_SPECIFIED) &&
0978         (cfg->rc_nlflags & NLM_F_CREATE)) {
0979         index = find_free_label(net);
0980     }
0981 
0982     if (!mpls_label_ok(net, &index, extack))
0983         goto errout;
0984 
0985     /* Append makes no sense with mpls */
0986     err = -EOPNOTSUPP;
0987     if (cfg->rc_nlflags & NLM_F_APPEND) {
0988         NL_SET_ERR_MSG(extack, "MPLS does not support route append");
0989         goto errout;
0990     }
0991 
0992     err = -EEXIST;
0993     platform_label = rtnl_dereference(net->mpls.platform_label);
0994     old = rtnl_dereference(platform_label[index]);
0995     if ((cfg->rc_nlflags & NLM_F_EXCL) && old)
0996         goto errout;
0997 
0998     err = -EEXIST;
0999     if (!(cfg->rc_nlflags & NLM_F_REPLACE) && old)
1000         goto errout;
1001 
1002     err = -ENOENT;
1003     if (!(cfg->rc_nlflags & NLM_F_CREATE) && !old)
1004         goto errout;
1005 
1006     err = -EINVAL;
1007     if (cfg->rc_mp) {
1008         nhs = mpls_count_nexthops(cfg->rc_mp, cfg->rc_mp_len,
1009                       cfg->rc_via_alen, &max_via_alen,
1010                       &max_labels);
1011     } else {
1012         max_via_alen = cfg->rc_via_alen;
1013         max_labels = cfg->rc_output_labels;
1014         nhs = 1;
1015     }
1016 
1017     if (nhs == 0) {
1018         NL_SET_ERR_MSG(extack, "Route does not contain a nexthop");
1019         goto errout;
1020     }
1021 
1022     rt = mpls_rt_alloc(nhs, max_via_alen, max_labels);
1023     if (IS_ERR(rt)) {
1024         err = PTR_ERR(rt);
1025         goto errout;
1026     }
1027 
1028     rt->rt_protocol = cfg->rc_protocol;
1029     rt->rt_payload_type = cfg->rc_payload_type;
1030     rt->rt_ttl_propagate = cfg->rc_ttl_propagate;
1031 
1032     if (cfg->rc_mp)
1033         err = mpls_nh_build_multi(cfg, rt, max_labels, extack);
1034     else
1035         err = mpls_nh_build_from_cfg(cfg, rt);
1036     if (err)
1037         goto freert;
1038 
1039     mpls_route_update(net, index, rt, &cfg->rc_nlinfo);
1040 
1041     return 0;
1042 
1043 freert:
1044     mpls_rt_free(rt);
1045 errout:
1046     return err;
1047 }
1048 
1049 static int mpls_route_del(struct mpls_route_config *cfg,
1050               struct netlink_ext_ack *extack)
1051 {
1052     struct net *net = cfg->rc_nlinfo.nl_net;
1053     unsigned index;
1054     int err = -EINVAL;
1055 
1056     index = cfg->rc_label;
1057 
1058     if (!mpls_label_ok(net, &index, extack))
1059         goto errout;
1060 
1061     mpls_route_update(net, index, NULL, &cfg->rc_nlinfo);
1062 
1063     err = 0;
1064 errout:
1065     return err;
1066 }
1067 
1068 static void mpls_get_stats(struct mpls_dev *mdev,
1069                struct mpls_link_stats *stats)
1070 {
1071     struct mpls_pcpu_stats *p;
1072     int i;
1073 
1074     memset(stats, 0, sizeof(*stats));
1075 
1076     for_each_possible_cpu(i) {
1077         struct mpls_link_stats local;
1078         unsigned int start;
1079 
1080         p = per_cpu_ptr(mdev->stats, i);
1081         do {
1082             start = u64_stats_fetch_begin_irq(&p->syncp);
1083             local = p->stats;
1084         } while (u64_stats_fetch_retry_irq(&p->syncp, start));
1085 
1086         stats->rx_packets   += local.rx_packets;
1087         stats->rx_bytes     += local.rx_bytes;
1088         stats->tx_packets   += local.tx_packets;
1089         stats->tx_bytes     += local.tx_bytes;
1090         stats->rx_errors    += local.rx_errors;
1091         stats->tx_errors    += local.tx_errors;
1092         stats->rx_dropped   += local.rx_dropped;
1093         stats->tx_dropped   += local.tx_dropped;
1094         stats->rx_noroute   += local.rx_noroute;
1095     }
1096 }
1097 
1098 static int mpls_fill_stats_af(struct sk_buff *skb,
1099                   const struct net_device *dev)
1100 {
1101     struct mpls_link_stats *stats;
1102     struct mpls_dev *mdev;
1103     struct nlattr *nla;
1104 
1105     mdev = mpls_dev_get(dev);
1106     if (!mdev)
1107         return -ENODATA;
1108 
1109     nla = nla_reserve_64bit(skb, MPLS_STATS_LINK,
1110                 sizeof(struct mpls_link_stats),
1111                 MPLS_STATS_UNSPEC);
1112     if (!nla)
1113         return -EMSGSIZE;
1114 
1115     stats = nla_data(nla);
1116     mpls_get_stats(mdev, stats);
1117 
1118     return 0;
1119 }
1120 
1121 static size_t mpls_get_stats_af_size(const struct net_device *dev)
1122 {
1123     struct mpls_dev *mdev;
1124 
1125     mdev = mpls_dev_get(dev);
1126     if (!mdev)
1127         return 0;
1128 
1129     return nla_total_size_64bit(sizeof(struct mpls_link_stats));
1130 }
1131 
1132 static int mpls_netconf_fill_devconf(struct sk_buff *skb, struct mpls_dev *mdev,
1133                      u32 portid, u32 seq, int event,
1134                      unsigned int flags, int type)
1135 {
1136     struct nlmsghdr  *nlh;
1137     struct netconfmsg *ncm;
1138     bool all = false;
1139 
1140     nlh = nlmsg_put(skb, portid, seq, event, sizeof(struct netconfmsg),
1141             flags);
1142     if (!nlh)
1143         return -EMSGSIZE;
1144 
1145     if (type == NETCONFA_ALL)
1146         all = true;
1147 
1148     ncm = nlmsg_data(nlh);
1149     ncm->ncm_family = AF_MPLS;
1150 
1151     if (nla_put_s32(skb, NETCONFA_IFINDEX, mdev->dev->ifindex) < 0)
1152         goto nla_put_failure;
1153 
1154     if ((all || type == NETCONFA_INPUT) &&
1155         nla_put_s32(skb, NETCONFA_INPUT,
1156             mdev->input_enabled) < 0)
1157         goto nla_put_failure;
1158 
1159     nlmsg_end(skb, nlh);
1160     return 0;
1161 
1162 nla_put_failure:
1163     nlmsg_cancel(skb, nlh);
1164     return -EMSGSIZE;
1165 }
1166 
1167 static int mpls_netconf_msgsize_devconf(int type)
1168 {
1169     int size = NLMSG_ALIGN(sizeof(struct netconfmsg))
1170             + nla_total_size(4); /* NETCONFA_IFINDEX */
1171     bool all = false;
1172 
1173     if (type == NETCONFA_ALL)
1174         all = true;
1175 
1176     if (all || type == NETCONFA_INPUT)
1177         size += nla_total_size(4);
1178 
1179     return size;
1180 }
1181 
1182 static void mpls_netconf_notify_devconf(struct net *net, int event,
1183                     int type, struct mpls_dev *mdev)
1184 {
1185     struct sk_buff *skb;
1186     int err = -ENOBUFS;
1187 
1188     skb = nlmsg_new(mpls_netconf_msgsize_devconf(type), GFP_KERNEL);
1189     if (!skb)
1190         goto errout;
1191 
1192     err = mpls_netconf_fill_devconf(skb, mdev, 0, 0, event, 0, type);
1193     if (err < 0) {
1194         /* -EMSGSIZE implies BUG in mpls_netconf_msgsize_devconf() */
1195         WARN_ON(err == -EMSGSIZE);
1196         kfree_skb(skb);
1197         goto errout;
1198     }
1199 
1200     rtnl_notify(skb, net, 0, RTNLGRP_MPLS_NETCONF, NULL, GFP_KERNEL);
1201     return;
1202 errout:
1203     if (err < 0)
1204         rtnl_set_sk_err(net, RTNLGRP_MPLS_NETCONF, err);
1205 }
1206 
/* Netlink attribute policy used when parsing AF_MPLS netconf get requests.
 * Only NETCONFA_IFINDEX carries a constraint: a length of sizeof(int).
 */
1207 static const struct nla_policy devconf_mpls_policy[NETCONFA_MAX + 1] = {
1208     [NETCONFA_IFINDEX]  = { .len = sizeof(int) },
1209 };
1210 
1211 static int mpls_netconf_valid_get_req(struct sk_buff *skb,
1212                       const struct nlmsghdr *nlh,
1213                       struct nlattr **tb,
1214                       struct netlink_ext_ack *extack)
1215 {
1216     int i, err;
1217 
1218     if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(struct netconfmsg))) {
1219         NL_SET_ERR_MSG_MOD(extack,
1220                    "Invalid header for netconf get request");
1221         return -EINVAL;
1222     }
1223 
1224     if (!netlink_strict_get_check(skb))
1225         return nlmsg_parse_deprecated(nlh, sizeof(struct netconfmsg),
1226                           tb, NETCONFA_MAX,
1227                           devconf_mpls_policy, extack);
1228 
1229     err = nlmsg_parse_deprecated_strict(nlh, sizeof(struct netconfmsg),
1230                         tb, NETCONFA_MAX,
1231                         devconf_mpls_policy, extack);
1232     if (err)
1233         return err;
1234 
1235     for (i = 0; i <= NETCONFA_MAX; i++) {
1236         if (!tb[i])
1237             continue;
1238 
1239         switch (i) {
1240         case NETCONFA_IFINDEX:
1241             break;
1242         default:
1243             NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in netconf get request");
1244             return -EINVAL;
1245         }
1246     }
1247 
1248     return 0;
1249 }
1250 
1251 static int mpls_netconf_get_devconf(struct sk_buff *in_skb,
1252                     struct nlmsghdr *nlh,
1253                     struct netlink_ext_ack *extack)
1254 {
1255     struct net *net = sock_net(in_skb->sk);
1256     struct nlattr *tb[NETCONFA_MAX + 1];
1257     struct net_device *dev;
1258     struct mpls_dev *mdev;
1259     struct sk_buff *skb;
1260     int ifindex;
1261     int err;
1262 
1263     err = mpls_netconf_valid_get_req(in_skb, nlh, tb, extack);
1264     if (err < 0)
1265         goto errout;
1266 
1267     err = -EINVAL;
1268     if (!tb[NETCONFA_IFINDEX])
1269         goto errout;
1270 
1271     ifindex = nla_get_s32(tb[NETCONFA_IFINDEX]);
1272     dev = __dev_get_by_index(net, ifindex);
1273     if (!dev)
1274         goto errout;
1275 
1276     mdev = mpls_dev_get(dev);
1277     if (!mdev)
1278         goto errout;
1279 
1280     err = -ENOBUFS;
1281     skb = nlmsg_new(mpls_netconf_msgsize_devconf(NETCONFA_ALL), GFP_KERNEL);
1282     if (!skb)
1283         goto errout;
1284 
1285     err = mpls_netconf_fill_devconf(skb, mdev,
1286                     NETLINK_CB(in_skb).portid,
1287                     nlh->nlmsg_seq, RTM_NEWNETCONF, 0,
1288                     NETCONFA_ALL);
1289     if (err < 0) {
1290         /* -EMSGSIZE implies BUG in mpls_netconf_msgsize_devconf() */
1291         WARN_ON(err == -EMSGSIZE);
1292         kfree_skb(skb);
1293         goto errout;
1294     }
1295     err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid);
1296 errout:
1297     return err;
1298 }
1299 
/*
 * mpls_netconf_dump_devconf - netlink dump callback listing the netconf state
 * of every device in the namespace that has MPLS state.
 *
 * The dump is resumable: cb->args[0] holds the device-index hash bucket to
 * continue from and cb->args[1] the position inside that bucket. Both are
 * written back before returning so the next invocation picks up where this
 * one stopped.
 *
 * Returns -EINVAL for a malformed request under strict checking, otherwise
 * skb->len.
 */
1300 static int mpls_netconf_dump_devconf(struct sk_buff *skb,
1301                      struct netlink_callback *cb)
1302 {
1303     const struct nlmsghdr *nlh = cb->nlh;
1304     struct net *net = sock_net(skb->sk);
1305     struct hlist_head *head;
1306     struct net_device *dev;
1307     struct mpls_dev *mdev;
1308     int idx, s_idx;
1309     int h, s_h;
1310 
     /* Strict checking: the request must be exactly a netconfmsg header
      * with no attributes after it.
      */
1311     if (cb->strict_check) {
1312         struct netlink_ext_ack *extack = cb->extack;
1313         struct netconfmsg *ncm;
1314 
1315         if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ncm))) {
1316             NL_SET_ERR_MSG_MOD(extack, "Invalid header for netconf dump request");
1317             return -EINVAL;
1318         }
1319 
1320         if (nlmsg_attrlen(nlh, sizeof(*ncm))) {
1321             NL_SET_ERR_MSG_MOD(extack, "Invalid data after header in netconf dump request");
1322             return -EINVAL;
1323         }
1324     }
1325 
     /* Restore the resume position saved by the previous invocation. */
1326     s_h = cb->args[0];
1327     s_idx = idx = cb->args[1];
1328 
     /* s_idx only applies to the first bucket visited; it is reset to 0
      * for every following bucket.
      */
1329     for (h = s_h; h < NETDEV_HASHENTRIES; h++, s_idx = 0) {
1330         idx = 0;
1331         head = &net->dev_index_head[h];
1332         rcu_read_lock();
1333         cb->seq = net->dev_base_seq;
1334         hlist_for_each_entry_rcu(dev, head, index_hlist) {
             /* Skip entries already reported in an earlier pass. */
1335             if (idx < s_idx)
1336                 goto cont;
1337             mdev = mpls_dev_get(dev);
1338             if (!mdev)
1339                 goto cont;
             /* A fill failure ends this pass; idx is not advanced, so
              * the same device is retried on the next invocation.
              */
1340             if (mpls_netconf_fill_devconf(skb, mdev,
1341                               NETLINK_CB(cb->skb).portid,
1342                               nlh->nlmsg_seq,
1343                               RTM_NEWNETCONF,
1344                               NLM_F_MULTI,
1345                               NETCONFA_ALL) < 0) {
1346                 rcu_read_unlock();
1347                 goto done;
1348             }
1349             nl_dump_check_consistent(cb, nlmsg_hdr(skb));
1350 cont:
1351             idx++;
1352         }
1353         rcu_read_unlock();
1354     }
1355 done:
     /* Save the resume position for the next invocation. */
1356     cb->args[0] = h;
1357     cb->args[1] = idx;
1358 
1359     return skb->len;
1360 }
1361 
/* Address of @field inside a struct mpls_dev located at address 0, i.e. the
 * field's offset expressed as a pointer. Used as a placeholder .data value in
 * the sysctl template; mpls_dev_sysctl_register() adds the real mpls_dev base
 * address to it.
 */
1362 #define MPLS_PERDEV_SYSCTL_OFFSET(field)    \
1363     (&((struct mpls_dev *)0)->field)
1364 
1365 static int mpls_conf_proc(struct ctl_table *ctl, int write,
1366               void *buffer, size_t *lenp, loff_t *ppos)
1367 {
1368     int oval = *(int *)ctl->data;
1369     int ret = proc_dointvec(ctl, write, buffer, lenp, ppos);
1370 
1371     if (write) {
1372         struct mpls_dev *mdev = ctl->extra1;
1373         int i = (int *)ctl->data - (int *)mdev;
1374         struct net *net = ctl->extra2;
1375         int val = *(int *)ctl->data;
1376 
1377         if (i == offsetof(struct mpls_dev, input_enabled) &&
1378             val != oval) {
1379             mpls_netconf_notify_devconf(net, RTM_NEWNETCONF,
1380                             NETCONFA_INPUT, mdev);
1381         }
1382     }
1383 
1384     return ret;
1385 }
1386 
/* Template for the per-device sysctl directory. It is duplicated for each
 * device by mpls_dev_sysctl_register(), which turns the offset stored in
 * .data into an absolute pointer. The empty last entry ends the table.
 */
1387 static const struct ctl_table mpls_dev_table[] = {
1388     {
1389         .procname   = "input",
1390         .maxlen     = sizeof(int),
1391         .mode       = 0644,
1392         .proc_handler   = mpls_conf_proc,
1393         .data       = MPLS_PERDEV_SYSCTL_OFFSET(input_enabled),
1394     },
1395     { }
1396 };
1397 
1398 static int mpls_dev_sysctl_register(struct net_device *dev,
1399                     struct mpls_dev *mdev)
1400 {
1401     char path[sizeof("net/mpls/conf/") + IFNAMSIZ];
1402     struct net *net = dev_net(dev);
1403     struct ctl_table *table;
1404     int i;
1405 
1406     table = kmemdup(&mpls_dev_table, sizeof(mpls_dev_table), GFP_KERNEL);
1407     if (!table)
1408         goto out;
1409 
1410     /* Table data contains only offsets relative to the base of
1411      * the mdev at this point, so make them absolute.
1412      */
1413     for (i = 0; i < ARRAY_SIZE(mpls_dev_table); i++) {
1414         table[i].data = (char *)mdev + (uintptr_t)table[i].data;
1415         table[i].extra1 = mdev;
1416         table[i].extra2 = net;
1417     }
1418 
1419     snprintf(path, sizeof(path), "net/mpls/conf/%s", dev->name);
1420 
1421     mdev->sysctl = register_net_sysctl(net, path, table);
1422     if (!mdev->sysctl)
1423         goto free;
1424 
1425     mpls_netconf_notify_devconf(net, RTM_NEWNETCONF, NETCONFA_ALL, mdev);
1426     return 0;
1427 
1428 free:
1429     kfree(table);
1430 out:
1431     return -ENOBUFS;
1432 }
1433 
1434 static void mpls_dev_sysctl_unregister(struct net_device *dev,
1435                        struct mpls_dev *mdev)
1436 {
1437     struct net *net = dev_net(dev);
1438     struct ctl_table *table;
1439 
1440     table = mdev->sysctl->ctl_table_arg;
1441     unregister_net_sysctl_table(mdev->sysctl);
1442     kfree(table);
1443 
1444     mpls_netconf_notify_devconf(net, RTM_DELNETCONF, 0, mdev);
1445 }
1446 
1447 static struct mpls_dev *mpls_add_dev(struct net_device *dev)
1448 {
1449     struct mpls_dev *mdev;
1450     int err = -ENOMEM;
1451     int i;
1452 
1453     ASSERT_RTNL();
1454 
1455     mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
1456     if (!mdev)
1457         return ERR_PTR(err);
1458 
1459     mdev->stats = alloc_percpu(struct mpls_pcpu_stats);
1460     if (!mdev->stats)
1461         goto free;
1462 
1463     for_each_possible_cpu(i) {
1464         struct mpls_pcpu_stats *mpls_stats;
1465 
1466         mpls_stats = per_cpu_ptr(mdev->stats, i);
1467         u64_stats_init(&mpls_stats->syncp);
1468     }
1469 
1470     mdev->dev = dev;
1471 
1472     err = mpls_dev_sysctl_register(dev, mdev);
1473     if (err)
1474         goto free;
1475 
1476     rcu_assign_pointer(dev->mpls_ptr, mdev);
1477 
1478     return mdev;
1479 
1480 free:
1481     free_percpu(mdev->stats);
1482     kfree(mdev);
1483     return ERR_PTR(err);
1484 }
1485 
1486 static void mpls_dev_destroy_rcu(struct rcu_head *head)
1487 {
1488     struct mpls_dev *mdev = container_of(head, struct mpls_dev, rcu);
1489 
1490     free_percpu(mdev->stats);
1491     kfree(mdev);
1492 }
1493 
/*
 * mpls_ifdown - update every route in the namespace after @dev went down,
 * lost its link or is being unregistered.
 * @dev:   device the event is about
 * @event: NETDEV_DOWN, NETDEV_CHANGE or NETDEV_UNREGISTER
 *
 * Nexthops using @dev get RTNH_F_LINKDOWN, plus RTNH_F_DEAD for DOWN and
 * UNREGISTER, and each route's rt_nhn_alive count is recomputed. On
 * UNREGISTER, nexthops on @dev also lose their device pointer, and a route
 * left without any usable nexthop is deleted.
 *
 * The label table is read with rtnl_dereference(), so RTNL is expected to be
 * held.
 *
 * Returns 0, or -ENOMEM when a route copy cannot be allocated.
 */
1494 static int mpls_ifdown(struct net_device *dev, int event)
1495 {
1496     struct mpls_route __rcu **platform_label;
1497     struct net *net = dev_net(dev);
1498     unsigned index;
1499 
1500     platform_label = rtnl_dereference(net->mpls.platform_label);
1501     for (index = 0; index < net->mpls.platform_labels; index++) {
1502         struct mpls_route *rt = rtnl_dereference(platform_label[index]);
1503         bool nh_del = false;
1504         u8 alive = 0;
1505 
1506         if (!rt)
1507             continue;
1508 
1509         if (event == NETDEV_UNREGISTER) {
1510             u8 deleted = 0;
1511 
             /* Count nexthops that have no device already or that sit
              * on the device going away; remember whether any of them
              * is on @dev.
              */
1512             for_nexthops(rt) {
1513                 if (!nh->nh_dev || nh->nh_dev == dev)
1514                     deleted++;
1515                 if (nh->nh_dev == dev)
1516                     nh_del = true;
1517             } endfor_nexthops(rt);
1518 
1519             /* if there are no more nexthops, delete the route */
1520             if (deleted == rt->rt_nhn) {
1521                 mpls_route_update(net, index, NULL, NULL);
1522                 continue;
1523             }
1524 
             /* A nexthop is about to lose its device: work on a private
              * copy of the route and install the copy afterwards rather
              * than editing the installed route in place.
              * NOTE(review): presumably so concurrent readers never see
              * nh_dev change underneath them - confirm.
              */
1525             if (nh_del) {
1526                 size_t size = sizeof(*rt) + rt->rt_nhn *
1527                     rt->rt_nh_size;
1528                 struct mpls_route *orig = rt;
1529 
1530                 rt = kmemdup(orig, size, GFP_KERNEL);
1531                 if (!rt)
1532                     return -ENOMEM;
1533             }
1534         }
1535 
1536         change_nexthops(rt) {
1537             unsigned int nh_flags = nh->nh_flags;
1538 
             /* Nexthops on other devices keep their flags but still
              * take part in the alive count below.
              */
1539             if (nh->nh_dev != dev)
1540                 goto next;
1541 
             /* DOWN and UNREGISTER set DEAD and then fall through to
              * also set LINKDOWN; CHANGE sets LINKDOWN only.
              */
1542             switch (event) {
1543             case NETDEV_DOWN:
1544             case NETDEV_UNREGISTER:
1545                 nh_flags |= RTNH_F_DEAD;
1546                 fallthrough;
1547             case NETDEV_CHANGE:
1548                 nh_flags |= RTNH_F_LINKDOWN;
1549                 break;
1550             }
1551             if (event == NETDEV_UNREGISTER)
1552                 nh->nh_dev = NULL;
1553 
1554             if (nh->nh_flags != nh_flags)
1555                 WRITE_ONCE(nh->nh_flags, nh_flags);
1556 next:
1557             if (!(nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN)))
1558                 alive++;
1559         } endfor_nexthops(rt);
1560 
1561         WRITE_ONCE(rt->rt_nhn_alive, alive);
1562 
         /* Install the modified copy made above in place of the original. */
1563         if (nh_del)
1564             mpls_route_update(net, index, rt, NULL);
1565     }
1566 
1567     return 0;
1568 }
1569 
1570 static void mpls_ifup(struct net_device *dev, unsigned int flags)
1571 {
1572     struct mpls_route __rcu **platform_label;
1573     struct net *net = dev_net(dev);
1574     unsigned index;
1575     u8 alive;
1576 
1577     platform_label = rtnl_dereference(net->mpls.platform_label);
1578     for (index = 0; index < net->mpls.platform_labels; index++) {
1579         struct mpls_route *rt = rtnl_dereference(platform_label[index]);
1580 
1581         if (!rt)
1582             continue;
1583 
1584         alive = 0;
1585         change_nexthops(rt) {
1586             unsigned int nh_flags = nh->nh_flags;
1587 
1588             if (!(nh_flags & flags)) {
1589                 alive++;
1590                 continue;
1591             }
1592             if (nh->nh_dev != dev)
1593                 continue;
1594             alive++;
1595             nh_flags &= ~flags;
1596             WRITE_ONCE(nh->nh_flags, nh_flags);
1597         } endfor_nexthops(rt);
1598 
1599         WRITE_ONCE(rt->rt_nhn_alive, alive);
1600     }
1601 }
1602 
1603 static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
1604                void *ptr)
1605 {
1606     struct net_device *dev = netdev_notifier_info_to_dev(ptr);
1607     struct mpls_dev *mdev;
1608     unsigned int flags;
1609     int err;
1610 
1611     if (event == NETDEV_REGISTER) {
1612         mdev = mpls_add_dev(dev);
1613         if (IS_ERR(mdev))
1614             return notifier_from_errno(PTR_ERR(mdev));
1615 
1616         return NOTIFY_OK;
1617     }
1618 
1619     mdev = mpls_dev_get(dev);
1620     if (!mdev)
1621         return NOTIFY_OK;
1622 
1623     switch (event) {
1624 
1625     case NETDEV_DOWN:
1626         err = mpls_ifdown(dev, event);
1627         if (err)
1628             return notifier_from_errno(err);
1629         break;
1630     case NETDEV_UP:
1631         flags = dev_get_flags(dev);
1632         if (flags & (IFF_RUNNING | IFF_LOWER_UP))
1633             mpls_ifup(dev, RTNH_F_DEAD | RTNH_F_LINKDOWN);
1634         else
1635             mpls_ifup(dev, RTNH_F_DEAD);
1636         break;
1637     case NETDEV_CHANGE:
1638         flags = dev_get_flags(dev);
1639         if (flags & (IFF_RUNNING | IFF_LOWER_UP)) {
1640             mpls_ifup(dev, RTNH_F_DEAD | RTNH_F_LINKDOWN);
1641         } else {
1642             err = mpls_ifdown(dev, event);
1643             if (err)
1644                 return notifier_from_errno(err);
1645         }
1646         break;
1647     case NETDEV_UNREGISTER:
1648         err = mpls_ifdown(dev, event);
1649         if (err)
1650             return notifier_from_errno(err);
1651         mdev = mpls_dev_get(dev);
1652         if (mdev) {
1653             mpls_dev_sysctl_unregister(dev, mdev);
1654             RCU_INIT_POINTER(dev->mpls_ptr, NULL);
1655             call_rcu(&mdev->rcu, mpls_dev_destroy_rcu);
1656         }
1657         break;
1658     case NETDEV_CHANGENAME:
1659         mdev = mpls_dev_get(dev);
1660         if (mdev) {
1661             mpls_dev_sysctl_unregister(dev, mdev);
1662             err = mpls_dev_sysctl_register(dev, mdev);
1663             if (err)
1664                 return notifier_from_errno(err);
1665         }
1666         break;
1667     }
1668     return NOTIFY_OK;
1669 }
1670 
/* Notifier block that routes netdevice events to mpls_dev_notify(). */
1671 static struct notifier_block mpls_dev_notifier = {
1672     .notifier_call = mpls_dev_notify,
1673 };
1674 
1675 static int nla_put_via(struct sk_buff *skb,
1676                u8 table, const void *addr, int alen)
1677 {
1678     static const int table_to_family[NEIGH_NR_TABLES + 1] = {
1679         AF_INET, AF_INET6, AF_DECnet, AF_PACKET,
1680     };
1681     struct nlattr *nla;
1682     struct rtvia *via;
1683     int family = AF_UNSPEC;
1684 
1685     nla = nla_reserve(skb, RTA_VIA, alen + 2);
1686     if (!nla)
1687         return -EMSGSIZE;
1688 
1689     if (table <= NEIGH_NR_TABLES)
1690         family = table_to_family[table];
1691 
1692     via = nla_data(nla);
1693     via->rtvia_family = family;
1694     memcpy(via->rtvia_addr, addr, alen);
1695     return 0;
1696 }
1697 
1698 int nla_put_labels(struct sk_buff *skb, int attrtype,
1699            u8 labels, const u32 label[])
1700 {
1701     struct nlattr *nla;
1702     struct mpls_shim_hdr *nla_label;
1703     bool bos;
1704     int i;
1705     nla = nla_reserve(skb, attrtype, labels*4);
1706     if (!nla)
1707         return -EMSGSIZE;
1708 
1709     nla_label = nla_data(nla);
1710     bos = true;
1711     for (i = labels - 1; i >= 0; i--) {
1712         nla_label[i] = mpls_entry_encode(label[i], 0, 0, bos);
1713         bos = false;
1714     }
1715 
1716     return 0;
1717 }
1718 EXPORT_SYMBOL_GPL(nla_put_labels);
1719 
1720 int nla_get_labels(const struct nlattr *nla, u8 max_labels, u8 *labels,
1721            u32 label[], struct netlink_ext_ack *extack)
1722 {
1723     unsigned len = nla_len(nla);
1724     struct mpls_shim_hdr *nla_label;
1725     u8 nla_labels;
1726     bool bos;
1727     int i;
1728 
1729     /* len needs to be an even multiple of 4 (the label size). Number
1730      * of labels is a u8 so check for overflow.
1731      */
1732     if (len & 3 || len / 4 > 255) {
1733         NL_SET_ERR_MSG_ATTR(extack, nla,
1734                     "Invalid length for labels attribute");
1735         return -EINVAL;
1736     }
1737 
1738     /* Limit the number of new labels allowed */
1739     nla_labels = len/4;
1740     if (nla_labels > max_labels) {
1741         NL_SET_ERR_MSG(extack, "Too many labels");
1742         return -EINVAL;
1743     }
1744 
1745     /* when label == NULL, caller wants number of labels */
1746     if (!label)
1747         goto out;
1748 
1749     nla_label = nla_data(nla);
1750     bos = true;
1751     for (i = nla_labels - 1; i >= 0; i--, bos = false) {
1752         struct mpls_entry_decoded dec;
1753         dec = mpls_entry_decode(nla_label + i);
1754 
1755         /* Ensure the bottom of stack flag is properly set
1756          * and ttl and tc are both clear.
1757          */
1758         if (dec.ttl) {
1759             NL_SET_ERR_MSG_ATTR(extack, nla,
1760                         "TTL in label must be 0");
1761             return -EINVAL;
1762         }
1763 
1764         if (dec.tc) {
1765             NL_SET_ERR_MSG_ATTR(extack, nla,
1766                         "Traffic class in label must be 0");
1767             return -EINVAL;
1768         }
1769 
1770         if (dec.bos != bos) {
1771             NL_SET_BAD_ATTR(extack, nla);
1772             if (bos) {
1773                 NL_SET_ERR_MSG(extack,
1774                            "BOS bit must be set in first label");
1775             } else {
1776                 NL_SET_ERR_MSG(extack,
1777                            "BOS bit can only be set in first label");
1778             }
1779             return -EINVAL;
1780         }
1781 
1782         switch (dec.label) {
1783         case MPLS_LABEL_IMPLNULL:
1784             /* RFC3032: This is a label that an LSR may
1785              * assign and distribute, but which never
1786              * actually appears in the encapsulation.
1787              */
1788             NL_SET_ERR_MSG_ATTR(extack, nla,
1789                         "Implicit NULL Label (3) can not be used in encapsulation");
1790             return -EINVAL;
1791         }
1792 
1793         label[i] = dec.label;
1794     }
1795 out:
1796     *labels = nla_labels;
1797     return 0;
1798 }
1799 EXPORT_SYMBOL_GPL(nla_get_labels);
1800 
/*
 * rtm_to_route_config - translate an RTM_NEWROUTE/RTM_DELROUTE request into
 * a struct mpls_route_config.
 * @skb:    request buffer; supplies the sender's port id and namespace
 * @nlh:    netlink header of the request
 * @cfg:    output configuration. Only the fields named below are written
 *          here, so the caller is expected to pass it zero-initialised.
 * @extack: netlink extended ack for error reporting
 *
 * The rtmsg header must describe an AF_MPLS unicast route in the main table
 * with a 20 bit destination, universe scope and no source, tos or flags.
 * Supported attributes are RTA_OIF, RTA_NEWDST, RTA_DST, RTA_VIA,
 * RTA_MULTIPATH and RTA_TTL_PROPAGATE; any other attribute is an error.
 *
 * Returns 0 on success, the parse error from nlmsg_parse_deprecated(), or
 * -EINVAL for every other failure.
 */
1801 static int rtm_to_route_config(struct sk_buff *skb,
1802                    struct nlmsghdr *nlh,
1803                    struct mpls_route_config *cfg,
1804                    struct netlink_ext_ack *extack)
1805 {
1806     struct rtmsg *rtm;
1807     struct nlattr *tb[RTA_MAX+1];
1808     int index;
1809     int err;
1810 
1811     err = nlmsg_parse_deprecated(nlh, sizeof(*rtm), tb, RTA_MAX,
1812                      rtm_mpls_policy, extack);
1813     if (err < 0)
1814         goto errout;
1815 
     /* From here on every failure path reports -EINVAL. */
1816     err = -EINVAL;
1817     rtm = nlmsg_data(nlh);
1818 
1819     if (rtm->rtm_family != AF_MPLS) {
1820         NL_SET_ERR_MSG(extack, "Invalid address family in rtmsg");
1821         goto errout;
1822     }
1823     if (rtm->rtm_dst_len != 20) {
1824         NL_SET_ERR_MSG(extack, "rtm_dst_len must be 20 for MPLS");
1825         goto errout;
1826     }
1827     if (rtm->rtm_src_len != 0) {
1828         NL_SET_ERR_MSG(extack, "rtm_src_len must be 0 for MPLS");
1829         goto errout;
1830     }
1831     if (rtm->rtm_tos != 0) {
1832         NL_SET_ERR_MSG(extack, "rtm_tos must be 0 for MPLS");
1833         goto errout;
1834     }
1835     if (rtm->rtm_table != RT_TABLE_MAIN) {
1836         NL_SET_ERR_MSG(extack,
1837                    "MPLS only supports the main route table");
1838         goto errout;
1839     }
1840     /* Any value is acceptable for rtm_protocol */
1841 
1842     /* As mpls uses destination specific addresses
1843      * (or source specific address in the case of multicast)
1844      * all addresses have universal scope.
1845      */
1846     if (rtm->rtm_scope != RT_SCOPE_UNIVERSE) {
1847         NL_SET_ERR_MSG(extack,
1848                    "Invalid route scope  - MPLS only supports UNIVERSE");
1849         goto errout;
1850     }
1851     if (rtm->rtm_type != RTN_UNICAST) {
1852         NL_SET_ERR_MSG(extack,
1853                    "Invalid route type - MPLS only supports UNICAST");
1854         goto errout;
1855     }
1856     if (rtm->rtm_flags != 0) {
1857         NL_SET_ERR_MSG(extack, "rtm_flags must be 0 for MPLS");
1858         goto errout;
1859     }
1860 
     /* Defaults that apply when the matching attribute is absent. */
1861     cfg->rc_label       = LABEL_NOT_SPECIFIED;
1862     cfg->rc_protocol    = rtm->rtm_protocol;
1863     cfg->rc_via_table   = MPLS_NEIGH_TABLE_UNSPEC;
1864     cfg->rc_ttl_propagate   = MPLS_TTL_PROP_DEFAULT;
1865     cfg->rc_nlflags     = nlh->nlmsg_flags;
1866     cfg->rc_nlinfo.portid   = NETLINK_CB(skb).portid;
1867     cfg->rc_nlinfo.nlh  = nlh;
1868     cfg->rc_nlinfo.nl_net   = sock_net(skb->sk);
1869 
     /* Walk every possible attribute type so that attributes which were
      * parsed but are not supported here are rejected explicitly.
      */
1870     for (index = 0; index <= RTA_MAX; index++) {
1871         struct nlattr *nla = tb[index];
1872         if (!nla)
1873             continue;
1874 
1875         switch (index) {
1876         case RTA_OIF:
1877             cfg->rc_ifindex = nla_get_u32(nla);
1878             break;
1879         case RTA_NEWDST:
             /* Outgoing label stack. */
1880             if (nla_get_labels(nla, MAX_NEW_LABELS,
1881                        &cfg->rc_output_labels,
1882                        cfg->rc_output_label, extack))
1883                 goto errout;
1884             break;
1885         case RTA_DST:
1886         {
             /* The destination is exactly one label, which must also
              * pass mpls_label_ok() for this namespace.
              */
1887             u8 label_count;
1888             if (nla_get_labels(nla, 1, &label_count,
1889                        &cfg->rc_label, extack))
1890                 goto errout;
1891 
1892             if (!mpls_label_ok(cfg->rc_nlinfo.nl_net,
1893                        &cfg->rc_label, extack))
1894                 goto errout;
1895             break;
1896         }
1897         case RTA_GATEWAY:
1898             NL_SET_ERR_MSG(extack, "MPLS does not support RTA_GATEWAY attribute");
1899             goto errout;
1900         case RTA_VIA:
1901         {
1902             if (nla_get_via(nla, &cfg->rc_via_alen,
1903                     &cfg->rc_via_table, cfg->rc_via,
1904                     extack))
1905                 goto errout;
1906             break;
1907         }
1908         case RTA_MULTIPATH:
1909         {
             /* Only the location and length of the nexthop list are
              * recorded; its contents are not parsed here.
              */
1910             cfg->rc_mp = nla_data(nla);
1911             cfg->rc_mp_len = nla_len(nla);
1912             break;
1913         }
1914         case RTA_TTL_PROPAGATE:
1915         {
1916             u8 ttl_propagate = nla_get_u8(nla);
1917 
1918             if (ttl_propagate > 1) {
1919                 NL_SET_ERR_MSG_ATTR(extack, nla,
1920                             "RTA_TTL_PROPAGATE can only be 0 or 1");
1921                 goto errout;
1922             }
1923             cfg->rc_ttl_propagate = ttl_propagate ?
1924                 MPLS_TTL_PROP_ENABLED :
1925                 MPLS_TTL_PROP_DISABLED;
1926             break;
1927         }
1928         default:
1929             NL_SET_ERR_MSG_ATTR(extack, nla, "Unknown attribute");
1930             /* Unsupported attribute */
1931             goto errout;
1932         }
1933     }
1934 
1935     err = 0;
1936 errout:
1937     return err;
1938 }
1939 
1940 static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
1941                  struct netlink_ext_ack *extack)
1942 {
1943     struct mpls_route_config *cfg;
1944     int err;
1945 
1946     cfg = kzalloc(sizeof(*cfg), GFP_KERNEL);
1947     if (!cfg)
1948         return -ENOMEM;
1949 
1950     err = rtm_to_route_config(skb, nlh, cfg, extack);
1951     if (err < 0)
1952         goto out;
1953 
1954     err = mpls_route_del(cfg, extack);
1955 out:
1956     kfree(cfg);
1957 
1958     return err;
1959 }
1960 
1961 
1962 static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
1963                  struct netlink_ext_ack *extack)
1964 {
1965     struct mpls_route_config *cfg;
1966     int err;
1967 
1968     cfg = kzalloc(sizeof(*cfg), GFP_KERNEL);
1969     if (!cfg)
1970         return -ENOMEM;
1971 
1972     err = rtm_to_route_config(skb, nlh, cfg, extack);
1973     if (err < 0)
1974         goto out;
1975 
1976     err = mpls_route_add(cfg, extack);
1977 out:
1978     kfree(cfg);
1979 
1980     return err;
1981 }
1982 
1983 static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event,
1984                u32 label, struct mpls_route *rt, int flags)
1985 {
1986     struct net_device *dev;
1987     struct nlmsghdr *nlh;
1988     struct rtmsg *rtm;
1989 
1990     nlh = nlmsg_put(skb, portid, seq, event, sizeof(*rtm), flags);
1991     if (nlh == NULL)
1992         return -EMSGSIZE;
1993 
1994     rtm = nlmsg_data(nlh);
1995     rtm->rtm_family = AF_MPLS;
1996     rtm->rtm_dst_len = 20;
1997     rtm->rtm_src_len = 0;
1998     rtm->rtm_tos = 0;
1999     rtm->rtm_table = RT_TABLE_MAIN;
2000     rtm->rtm_protocol = rt->rt_protocol;
2001     rtm->rtm_scope = RT_SCOPE_UNIVERSE;
2002     rtm->rtm_type = RTN_UNICAST;
2003     rtm->rtm_flags = 0;
2004 
2005     if (nla_put_labels(skb, RTA_DST, 1, &label))
2006         goto nla_put_failure;
2007 
2008     if (rt->rt_ttl_propagate != MPLS_TTL_PROP_DEFAULT) {
2009         bool ttl_propagate =
2010             rt->rt_ttl_propagate == MPLS_TTL_PROP_ENABLED;
2011 
2012         if (nla_put_u8(skb, RTA_TTL_PROPAGATE,
2013                    ttl_propagate))
2014             goto nla_put_failure;
2015     }
2016     if (rt->rt_nhn == 1) {
2017         const struct mpls_nh *nh = rt->rt_nh;
2018 
2019         if (nh->nh_labels &&
2020             nla_put_labels(skb, RTA_NEWDST, nh->nh_labels,
2021                    nh->nh_label))
2022             goto nla_put_failure;
2023         if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC &&
2024             nla_put_via(skb, nh->nh_via_table, mpls_nh_via(rt, nh),
2025                 nh->nh_via_alen))
2026             goto nla_put_failure;
2027         dev = nh->nh_dev;
2028         if (dev && nla_put_u32(skb, RTA_OIF, dev->ifindex))
2029             goto nla_put_failure;
2030         if (nh->nh_flags & RTNH_F_LINKDOWN)
2031             rtm->rtm_flags |= RTNH_F_LINKDOWN;
2032         if (nh->nh_flags & RTNH_F_DEAD)
2033             rtm->rtm_flags |= RTNH_F_DEAD;
2034     } else {
2035         struct rtnexthop *rtnh;
2036         struct nlattr *mp;
2037         u8 linkdown = 0;
2038         u8 dead = 0;
2039 
2040         mp = nla_nest_start_noflag(skb, RTA_MULTIPATH);
2041         if (!mp)
2042             goto nla_put_failure;
2043 
2044         for_nexthops(rt) {
2045             dev = nh->nh_dev;
2046             if (!dev)
2047                 continue;
2048 
2049             rtnh = nla_reserve_nohdr(skb, sizeof(*rtnh));
2050             if (!rtnh)
2051                 goto nla_put_failure;
2052 
2053             rtnh->rtnh_ifindex = dev->ifindex;
2054             if (nh->nh_flags & RTNH_F_LINKDOWN) {
2055                 rtnh->rtnh_flags |= RTNH_F_LINKDOWN;
2056                 linkdown++;
2057             }
2058             if (nh->nh_flags & RTNH_F_DEAD) {
2059                 rtnh->rtnh_flags |= RTNH_F_DEAD;
2060                 dead++;
2061             }
2062 
2063             if (nh->nh_labels && nla_put_labels(skb, RTA_NEWDST,
2064                                 nh->nh_labels,
2065                                 nh->nh_label))
2066                 goto nla_put_failure;
2067             if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC &&
2068                 nla_put_via(skb, nh->nh_via_table,
2069                     mpls_nh_via(rt, nh),
2070                     nh->nh_via_alen))
2071                 goto nla_put_failure;
2072 
2073             /* length of rtnetlink header + attributes */
2074             rtnh->rtnh_len = nlmsg_get_pos(skb) - (void *)rtnh;
2075         } endfor_nexthops(rt);
2076 
2077         if (linkdown == rt->rt_nhn)
2078             rtm->rtm_flags |= RTNH_F_LINKDOWN;
2079         if (dead == rt->rt_nhn)
2080             rtm->rtm_flags |= RTNH_F_DEAD;
2081 
2082         nla_nest_end(skb, mp);
2083     }
2084 
2085     nlmsg_end(skb, nlh);
2086     return 0;
2087 
2088 nla_put_failure:
2089     nlmsg_cancel(skb, nlh);
2090     return -EMSGSIZE;
2091 }
2092 
2093 #if IS_ENABLED(CONFIG_INET)
/* Strict-mode validation of a route dump request.  In this build
 * variant the work is delegated entirely to ip_valid_fib_dump_req(),
 * whose result is returned unchanged.
 */
static int mpls_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh,
                                   struct fib_dump_filter *filter,
                                   struct netlink_callback *cb)
{
    int rc = ip_valid_fib_dump_req(net, nlh, filter, cb);

    return rc;
}
2100 #else
2101 static int mpls_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh,
2102                    struct fib_dump_filter *filter,
2103                    struct netlink_callback *cb)
2104 {
2105     struct netlink_ext_ack *extack = cb->extack;
2106     struct nlattr *tb[RTA_MAX + 1];
2107     struct rtmsg *rtm;
2108     int err, i;
2109 
2110     if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) {
2111         NL_SET_ERR_MSG_MOD(extack, "Invalid header for FIB dump request");
2112         return -EINVAL;
2113     }
2114 
2115     rtm = nlmsg_data(nlh);
2116     if (rtm->rtm_dst_len || rtm->rtm_src_len  || rtm->rtm_tos   ||
2117         rtm->rtm_table   || rtm->rtm_scope    || rtm->rtm_type  ||
2118         rtm->rtm_flags) {
2119         NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for FIB dump request");
2120         return -EINVAL;
2121     }
2122 
2123     if (rtm->rtm_protocol) {
2124         filter->protocol = rtm->rtm_protocol;
2125         filter->filter_set = 1;
2126         cb->answer_flags = NLM_F_DUMP_FILTERED;
2127     }
2128 
2129     err = nlmsg_parse_deprecated_strict(nlh, sizeof(*rtm), tb, RTA_MAX,
2130                         rtm_mpls_policy, extack);
2131     if (err < 0)
2132         return err;
2133 
2134     for (i = 0; i <= RTA_MAX; ++i) {
2135         int ifindex;
2136 
2137         if (i == RTA_OIF) {
2138             ifindex = nla_get_u32(tb[i]);
2139             filter->dev = __dev_get_by_index(net, ifindex);
2140             if (!filter->dev)
2141                 return -ENODEV;
2142             filter->filter_set = 1;
2143         } else if (tb[i]) {
2144             NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in dump request");
2145             return -EINVAL;
2146         }
2147     }
2148 
2149     return 0;
2150 }
2151 #endif
2152 
2153 static bool mpls_rt_uses_dev(struct mpls_route *rt,
2154                  const struct net_device *dev)
2155 {
2156     if (rt->rt_nhn == 1) {
2157         struct mpls_nh *nh = rt->rt_nh;
2158 
2159         if (nh->nh_dev == dev)
2160             return true;
2161     } else {
2162         for_nexthops(rt) {
2163             if (nh->nh_dev == dev)
2164                 return true;
2165         } endfor_nexthops(rt);
2166     }
2167 
2168     return false;
2169 }
2170 
2171 static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
2172 {
2173     const struct nlmsghdr *nlh = cb->nlh;
2174     struct net *net = sock_net(skb->sk);
2175     struct mpls_route __rcu **platform_label;
2176     struct fib_dump_filter filter = {};
2177     unsigned int flags = NLM_F_MULTI;
2178     size_t platform_labels;
2179     unsigned int index;
2180 
2181     ASSERT_RTNL();
2182 
2183     if (cb->strict_check) {
2184         int err;
2185 
2186         err = mpls_valid_fib_dump_req(net, nlh, &filter, cb);
2187         if (err < 0)
2188             return err;
2189 
2190         /* for MPLS, there is only 1 table with fixed type and flags.
2191          * If either are set in the filter then return nothing.
2192          */
2193         if ((filter.table_id && filter.table_id != RT_TABLE_MAIN) ||
2194             (filter.rt_type && filter.rt_type != RTN_UNICAST) ||
2195              filter.flags)
2196             return skb->len;
2197     }
2198 
2199     index = cb->args[0];
2200     if (index < MPLS_LABEL_FIRST_UNRESERVED)
2201         index = MPLS_LABEL_FIRST_UNRESERVED;
2202 
2203     platform_label = rtnl_dereference(net->mpls.platform_label);
2204     platform_labels = net->mpls.platform_labels;
2205 
2206     if (filter.filter_set)
2207         flags |= NLM_F_DUMP_FILTERED;
2208 
2209     for (; index < platform_labels; index++) {
2210         struct mpls_route *rt;
2211 
2212         rt = rtnl_dereference(platform_label[index]);
2213         if (!rt)
2214             continue;
2215 
2216         if ((filter.dev && !mpls_rt_uses_dev(rt, filter.dev)) ||
2217             (filter.protocol && rt->rt_protocol != filter.protocol))
2218             continue;
2219 
2220         if (mpls_dump_route(skb, NETLINK_CB(cb->skb).portid,
2221                     cb->nlh->nlmsg_seq, RTM_NEWROUTE,
2222                     index, rt, flags) < 0)
2223             break;
2224     }
2225     cb->args[0] = index;
2226 
2227     return skb->len;
2228 }
2229 
2230 static inline size_t lfib_nlmsg_size(struct mpls_route *rt)
2231 {
2232     size_t payload =
2233         NLMSG_ALIGN(sizeof(struct rtmsg))
2234         + nla_total_size(4)         /* RTA_DST */
2235         + nla_total_size(1);            /* RTA_TTL_PROPAGATE */
2236 
2237     if (rt->rt_nhn == 1) {
2238         struct mpls_nh *nh = rt->rt_nh;
2239 
2240         if (nh->nh_dev)
2241             payload += nla_total_size(4); /* RTA_OIF */
2242         if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC) /* RTA_VIA */
2243             payload += nla_total_size(2 + nh->nh_via_alen);
2244         if (nh->nh_labels) /* RTA_NEWDST */
2245             payload += nla_total_size(nh->nh_labels * 4);
2246     } else {
2247         /* each nexthop is packed in an attribute */
2248         size_t nhsize = 0;
2249 
2250         for_nexthops(rt) {
2251             if (!nh->nh_dev)
2252                 continue;
2253             nhsize += nla_total_size(sizeof(struct rtnexthop));
2254             /* RTA_VIA */
2255             if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC)
2256                 nhsize += nla_total_size(2 + nh->nh_via_alen);
2257             if (nh->nh_labels)
2258                 nhsize += nla_total_size(nh->nh_labels * 4);
2259         } endfor_nexthops(rt);
2260         /* nested attribute */
2261         payload += nla_total_size(nhsize);
2262     }
2263 
2264     return payload;
2265 }
2266 
2267 static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
2268                struct nlmsghdr *nlh, struct net *net, u32 portid,
2269                unsigned int nlm_flags)
2270 {
2271     struct sk_buff *skb;
2272     u32 seq = nlh ? nlh->nlmsg_seq : 0;
2273     int err = -ENOBUFS;
2274 
2275     skb = nlmsg_new(lfib_nlmsg_size(rt), GFP_KERNEL);
2276     if (skb == NULL)
2277         goto errout;
2278 
2279     err = mpls_dump_route(skb, portid, seq, event, label, rt, nlm_flags);
2280     if (err < 0) {
2281         /* -EMSGSIZE implies BUG in lfib_nlmsg_size */
2282         WARN_ON(err == -EMSGSIZE);
2283         kfree_skb(skb);
2284         goto errout;
2285     }
2286     rtnl_notify(skb, net, portid, RTNLGRP_MPLS_ROUTE, nlh, GFP_KERNEL);
2287 
2288     return;
2289 errout:
2290     if (err < 0)
2291         rtnl_set_sk_err(net, RTNLGRP_MPLS_ROUTE, err);
2292 }
2293 
2294 static int mpls_valid_getroute_req(struct sk_buff *skb,
2295                    const struct nlmsghdr *nlh,
2296                    struct nlattr **tb,
2297                    struct netlink_ext_ack *extack)
2298 {
2299     struct rtmsg *rtm;
2300     int i, err;
2301 
2302     if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) {
2303         NL_SET_ERR_MSG_MOD(extack,
2304                    "Invalid header for get route request");
2305         return -EINVAL;
2306     }
2307 
2308     if (!netlink_strict_get_check(skb))
2309         return nlmsg_parse_deprecated(nlh, sizeof(*rtm), tb, RTA_MAX,
2310                           rtm_mpls_policy, extack);
2311 
2312     rtm = nlmsg_data(nlh);
2313     if ((rtm->rtm_dst_len && rtm->rtm_dst_len != 20) ||
2314         rtm->rtm_src_len || rtm->rtm_tos || rtm->rtm_table ||
2315         rtm->rtm_protocol || rtm->rtm_scope || rtm->rtm_type) {
2316         NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for get route request");
2317         return -EINVAL;
2318     }
2319     if (rtm->rtm_flags & ~RTM_F_FIB_MATCH) {
2320         NL_SET_ERR_MSG_MOD(extack,
2321                    "Invalid flags for get route request");
2322         return -EINVAL;
2323     }
2324 
2325     err = nlmsg_parse_deprecated_strict(nlh, sizeof(*rtm), tb, RTA_MAX,
2326                         rtm_mpls_policy, extack);
2327     if (err)
2328         return err;
2329 
2330     if ((tb[RTA_DST] || tb[RTA_NEWDST]) && !rtm->rtm_dst_len) {
2331         NL_SET_ERR_MSG_MOD(extack, "rtm_dst_len must be 20 for MPLS");
2332         return -EINVAL;
2333     }
2334 
2335     for (i = 0; i <= RTA_MAX; i++) {
2336         if (!tb[i])
2337             continue;
2338 
2339         switch (i) {
2340         case RTA_DST:
2341         case RTA_NEWDST:
2342             break;
2343         default:
2344             NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in get route request");
2345             return -EINVAL;
2346         }
2347     }
2348 
2349     return 0;
2350 }
2351 
2352 static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
2353              struct netlink_ext_ack *extack)
2354 {
2355     struct net *net = sock_net(in_skb->sk);
2356     u32 portid = NETLINK_CB(in_skb).portid;
2357     u32 in_label = LABEL_NOT_SPECIFIED;
2358     struct nlattr *tb[RTA_MAX + 1];
2359     u32 labels[MAX_NEW_LABELS];
2360     struct mpls_shim_hdr *hdr;
2361     unsigned int hdr_size = 0;
2362     const struct mpls_nh *nh;
2363     struct net_device *dev;
2364     struct mpls_route *rt;
2365     struct rtmsg *rtm, *r;
2366     struct nlmsghdr *nlh;
2367     struct sk_buff *skb;
2368     u8 n_labels;
2369     int err;
2370 
2371     err = mpls_valid_getroute_req(in_skb, in_nlh, tb, extack);
2372     if (err < 0)
2373         goto errout;
2374 
2375     rtm = nlmsg_data(in_nlh);
2376 
2377     if (tb[RTA_DST]) {
2378         u8 label_count;
2379 
2380         if (nla_get_labels(tb[RTA_DST], 1, &label_count,
2381                    &in_label, extack)) {
2382             err = -EINVAL;
2383             goto errout;
2384         }
2385 
2386         if (!mpls_label_ok(net, &in_label, extack)) {
2387             err = -EINVAL;
2388             goto errout;
2389         }
2390     }
2391 
2392     rt = mpls_route_input_rcu(net, in_label);
2393     if (!rt) {
2394         err = -ENETUNREACH;
2395         goto errout;
2396     }
2397 
2398     if (rtm->rtm_flags & RTM_F_FIB_MATCH) {
2399         skb = nlmsg_new(lfib_nlmsg_size(rt), GFP_KERNEL);
2400         if (!skb) {
2401             err = -ENOBUFS;
2402             goto errout;
2403         }
2404 
2405         err = mpls_dump_route(skb, portid, in_nlh->nlmsg_seq,
2406                       RTM_NEWROUTE, in_label, rt, 0);
2407         if (err < 0) {
2408             /* -EMSGSIZE implies BUG in lfib_nlmsg_size */
2409             WARN_ON(err == -EMSGSIZE);
2410             goto errout_free;
2411         }
2412 
2413         return rtnl_unicast(skb, net, portid);
2414     }
2415 
2416     if (tb[RTA_NEWDST]) {
2417         if (nla_get_labels(tb[RTA_NEWDST], MAX_NEW_LABELS, &n_labels,
2418                    labels, extack) != 0) {
2419             err = -EINVAL;
2420             goto errout;
2421         }
2422 
2423         hdr_size = n_labels * sizeof(struct mpls_shim_hdr);
2424     }
2425 
2426     skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL);
2427     if (!skb) {
2428         err = -ENOBUFS;
2429         goto errout;
2430     }
2431 
2432     skb->protocol = htons(ETH_P_MPLS_UC);
2433 
2434     if (hdr_size) {
2435         bool bos;
2436         int i;
2437 
2438         if (skb_cow(skb, hdr_size)) {
2439             err = -ENOBUFS;
2440             goto errout_free;
2441         }
2442 
2443         skb_reserve(skb, hdr_size);
2444         skb_push(skb, hdr_size);
2445         skb_reset_network_header(skb);
2446 
2447         /* Push new labels */
2448         hdr = mpls_hdr(skb);
2449         bos = true;
2450         for (i = n_labels - 1; i >= 0; i--) {
2451             hdr[i] = mpls_entry_encode(labels[i],
2452                            1, 0, bos);
2453             bos = false;
2454         }
2455     }
2456 
2457     nh = mpls_select_multipath(rt, skb);
2458     if (!nh) {
2459         err = -ENETUNREACH;
2460         goto errout_free;
2461     }
2462 
2463     if (hdr_size) {
2464         skb_pull(skb, hdr_size);
2465         skb_reset_network_header(skb);
2466     }
2467 
2468     nlh = nlmsg_put(skb, portid, in_nlh->nlmsg_seq,
2469             RTM_NEWROUTE, sizeof(*r), 0);
2470     if (!nlh) {
2471         err = -EMSGSIZE;
2472         goto errout_free;
2473     }
2474 
2475     r = nlmsg_data(nlh);
2476     r->rtm_family    = AF_MPLS;
2477     r->rtm_dst_len  = 20;
2478     r->rtm_src_len  = 0;
2479     r->rtm_table    = RT_TABLE_MAIN;
2480     r->rtm_type = RTN_UNICAST;
2481     r->rtm_scope    = RT_SCOPE_UNIVERSE;
2482     r->rtm_protocol = rt->rt_protocol;
2483     r->rtm_flags    = 0;
2484 
2485     if (nla_put_labels(skb, RTA_DST, 1, &in_label))
2486         goto nla_put_failure;
2487 
2488     if (nh->nh_labels &&
2489         nla_put_labels(skb, RTA_NEWDST, nh->nh_labels,
2490                nh->nh_label))
2491         goto nla_put_failure;
2492 
2493     if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC &&
2494         nla_put_via(skb, nh->nh_via_table, mpls_nh_via(rt, nh),
2495             nh->nh_via_alen))
2496         goto nla_put_failure;
2497     dev = nh->nh_dev;
2498     if (dev && nla_put_u32(skb, RTA_OIF, dev->ifindex))
2499         goto nla_put_failure;
2500 
2501     nlmsg_end(skb, nlh);
2502 
2503     err = rtnl_unicast(skb, net, portid);
2504 errout:
2505     return err;
2506 
2507 nla_put_failure:
2508     nlmsg_cancel(skb, nlh);
2509     err = -EMSGSIZE;
2510 errout_free:
2511     kfree_skb(skb);
2512     return err;
2513 }
2514 
2515 static int resize_platform_label_table(struct net *net, size_t limit)
2516 {
2517     size_t size = sizeof(struct mpls_route *) * limit;
2518     size_t old_limit;
2519     size_t cp_size;
2520     struct mpls_route __rcu **labels = NULL, **old;
2521     struct mpls_route *rt0 = NULL, *rt2 = NULL;
2522     unsigned index;
2523 
2524     if (size) {
2525         labels = kvzalloc(size, GFP_KERNEL);
2526         if (!labels)
2527             goto nolabels;
2528     }
2529 
2530     /* In case the predefined labels need to be populated */
2531     if (limit > MPLS_LABEL_IPV4NULL) {
2532         struct net_device *lo = net->loopback_dev;
2533         rt0 = mpls_rt_alloc(1, lo->addr_len, 0);
2534         if (IS_ERR(rt0))
2535             goto nort0;
2536         rt0->rt_nh->nh_dev = lo;
2537         rt0->rt_protocol = RTPROT_KERNEL;
2538         rt0->rt_payload_type = MPT_IPV4;
2539         rt0->rt_ttl_propagate = MPLS_TTL_PROP_DEFAULT;
2540         rt0->rt_nh->nh_via_table = NEIGH_LINK_TABLE;
2541         rt0->rt_nh->nh_via_alen = lo->addr_len;
2542         memcpy(__mpls_nh_via(rt0, rt0->rt_nh), lo->dev_addr,
2543                lo->addr_len);
2544     }
2545     if (limit > MPLS_LABEL_IPV6NULL) {
2546         struct net_device *lo = net->loopback_dev;
2547         rt2 = mpls_rt_alloc(1, lo->addr_len, 0);
2548         if (IS_ERR(rt2))
2549             goto nort2;
2550         rt2->rt_nh->nh_dev = lo;
2551         rt2->rt_protocol = RTPROT_KERNEL;
2552         rt2->rt_payload_type = MPT_IPV6;
2553         rt2->rt_ttl_propagate = MPLS_TTL_PROP_DEFAULT;
2554         rt2->rt_nh->nh_via_table = NEIGH_LINK_TABLE;
2555         rt2->rt_nh->nh_via_alen = lo->addr_len;
2556         memcpy(__mpls_nh_via(rt2, rt2->rt_nh), lo->dev_addr,
2557                lo->addr_len);
2558     }
2559 
2560     rtnl_lock();
2561     /* Remember the original table */
2562     old = rtnl_dereference(net->mpls.platform_label);
2563     old_limit = net->mpls.platform_labels;
2564 
2565     /* Free any labels beyond the new table */
2566     for (index = limit; index < old_limit; index++)
2567         mpls_route_update(net, index, NULL, NULL);
2568 
2569     /* Copy over the old labels */
2570     cp_size = size;
2571     if (old_limit < limit)
2572         cp_size = old_limit * sizeof(struct mpls_route *);
2573 
2574     memcpy(labels, old, cp_size);
2575 
2576     /* If needed set the predefined labels */
2577     if ((old_limit <= MPLS_LABEL_IPV6NULL) &&
2578         (limit > MPLS_LABEL_IPV6NULL)) {
2579         RCU_INIT_POINTER(labels[MPLS_LABEL_IPV6NULL], rt2);
2580         rt2 = NULL;
2581     }
2582 
2583     if ((old_limit <= MPLS_LABEL_IPV4NULL) &&
2584         (limit > MPLS_LABEL_IPV4NULL)) {
2585         RCU_INIT_POINTER(labels[MPLS_LABEL_IPV4NULL], rt0);
2586         rt0 = NULL;
2587     }
2588 
2589     /* Update the global pointers */
2590     net->mpls.platform_labels = limit;
2591     rcu_assign_pointer(net->mpls.platform_label, labels);
2592 
2593     rtnl_unlock();
2594 
2595     mpls_rt_free(rt2);
2596     mpls_rt_free(rt0);
2597 
2598     if (old) {
2599         synchronize_rcu();
2600         kvfree(old);
2601     }
2602     return 0;
2603 
2604 nort2:
2605     mpls_rt_free(rt0);
2606 nort0:
2607     kvfree(labels);
2608 nolabels:
2609     return -ENOMEM;
2610 }
2611 
2612 static int mpls_platform_labels(struct ctl_table *table, int write,
2613                 void *buffer, size_t *lenp, loff_t *ppos)
2614 {
2615     struct net *net = table->data;
2616     int platform_labels = net->mpls.platform_labels;
2617     int ret;
2618     struct ctl_table tmp = {
2619         .procname   = table->procname,
2620         .data       = &platform_labels,
2621         .maxlen     = sizeof(int),
2622         .mode       = table->mode,
2623         .extra1     = SYSCTL_ZERO,
2624         .extra2     = &label_limit,
2625     };
2626 
2627     ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos);
2628 
2629     if (write && ret == 0)
2630         ret = resize_platform_label_table(net, platform_labels);
2631 
2632     return ret;
2633 }
2634 
2635 #define MPLS_NS_SYSCTL_OFFSET(field)        \
2636     (&((struct net *)0)->field)
2637 
2638 static const struct ctl_table mpls_table[] = {
2639     {
2640         .procname   = "platform_labels",
2641         .data       = NULL,
2642         .maxlen     = sizeof(int),
2643         .mode       = 0644,
2644         .proc_handler   = mpls_platform_labels,
2645     },
2646     {
2647         .procname   = "ip_ttl_propagate",
2648         .data       = MPLS_NS_SYSCTL_OFFSET(mpls.ip_ttl_propagate),
2649         .maxlen     = sizeof(int),
2650         .mode       = 0644,
2651         .proc_handler   = proc_dointvec_minmax,
2652         .extra1     = SYSCTL_ZERO,
2653         .extra2     = SYSCTL_ONE,
2654     },
2655     {
2656         .procname   = "default_ttl",
2657         .data       = MPLS_NS_SYSCTL_OFFSET(mpls.default_ttl),
2658         .maxlen     = sizeof(int),
2659         .mode       = 0644,
2660         .proc_handler   = proc_dointvec_minmax,
2661         .extra1     = SYSCTL_ONE,
2662         .extra2     = &ttl_max,
2663     },
2664     { }
2665 };
2666 
2667 static int mpls_net_init(struct net *net)
2668 {
2669     struct ctl_table *table;
2670     int i;
2671 
2672     net->mpls.platform_labels = 0;
2673     net->mpls.platform_label = NULL;
2674     net->mpls.ip_ttl_propagate = 1;
2675     net->mpls.default_ttl = 255;
2676 
2677     table = kmemdup(mpls_table, sizeof(mpls_table), GFP_KERNEL);
2678     if (table == NULL)
2679         return -ENOMEM;
2680 
2681     /* Table data contains only offsets relative to the base of
2682      * the mdev at this point, so make them absolute.
2683      */
2684     for (i = 0; i < ARRAY_SIZE(mpls_table) - 1; i++)
2685         table[i].data = (char *)net + (uintptr_t)table[i].data;
2686 
2687     net->mpls.ctl = register_net_sysctl(net, "net/mpls", table);
2688     if (net->mpls.ctl == NULL) {
2689         kfree(table);
2690         return -ENOMEM;
2691     }
2692 
2693     return 0;
2694 }
2695 
2696 static void mpls_net_exit(struct net *net)
2697 {
2698     struct mpls_route __rcu **platform_label;
2699     size_t platform_labels;
2700     struct ctl_table *table;
2701     unsigned int index;
2702 
2703     table = net->mpls.ctl->ctl_table_arg;
2704     unregister_net_sysctl_table(net->mpls.ctl);
2705     kfree(table);
2706 
2707     /* An rcu grace period has passed since there was a device in
2708      * the network namespace (and thus the last in flight packet)
2709      * left this network namespace.  This is because
2710      * unregister_netdevice_many and netdev_run_todo has completed
2711      * for each network device that was in this network namespace.
2712      *
2713      * As such no additional rcu synchronization is necessary when
2714      * freeing the platform_label table.
2715      */
2716     rtnl_lock();
2717     platform_label = rtnl_dereference(net->mpls.platform_label);
2718     platform_labels = net->mpls.platform_labels;
2719     for (index = 0; index < platform_labels; index++) {
2720         struct mpls_route *rt = rtnl_dereference(platform_label[index]);
2721         RCU_INIT_POINTER(platform_label[index], NULL);
2722         mpls_notify_route(net, index, rt, NULL, NULL);
2723         mpls_rt_free(rt);
2724     }
2725     rtnl_unlock();
2726 
2727     kvfree(platform_label);
2728 }
2729 
2730 static struct pernet_operations mpls_net_ops = {
2731     .init = mpls_net_init,
2732     .exit = mpls_net_exit,
2733 };
2734 
2735 static struct rtnl_af_ops mpls_af_ops __read_mostly = {
2736     .family        = AF_MPLS,
2737     .fill_stats_af     = mpls_fill_stats_af,
2738     .get_stats_af_size = mpls_get_stats_af_size,
2739 };
2740 
2741 static int __init mpls_init(void)
2742 {
2743     int err;
2744 
2745     BUILD_BUG_ON(sizeof(struct mpls_shim_hdr) != 4);
2746 
2747     err = register_pernet_subsys(&mpls_net_ops);
2748     if (err)
2749         goto out;
2750 
2751     err = register_netdevice_notifier(&mpls_dev_notifier);
2752     if (err)
2753         goto out_unregister_pernet;
2754 
2755     dev_add_pack(&mpls_packet_type);
2756 
2757     rtnl_af_register(&mpls_af_ops);
2758 
2759     rtnl_register_module(THIS_MODULE, PF_MPLS, RTM_NEWROUTE,
2760                  mpls_rtm_newroute, NULL, 0);
2761     rtnl_register_module(THIS_MODULE, PF_MPLS, RTM_DELROUTE,
2762                  mpls_rtm_delroute, NULL, 0);
2763     rtnl_register_module(THIS_MODULE, PF_MPLS, RTM_GETROUTE,
2764                  mpls_getroute, mpls_dump_routes, 0);
2765     rtnl_register_module(THIS_MODULE, PF_MPLS, RTM_GETNETCONF,
2766                  mpls_netconf_get_devconf,
2767                  mpls_netconf_dump_devconf, 0);
2768     err = ipgre_tunnel_encap_add_mpls_ops();
2769     if (err)
2770         pr_err("Can't add mpls over gre tunnel ops\n");
2771 
2772     err = 0;
2773 out:
2774     return err;
2775 
2776 out_unregister_pernet:
2777     unregister_pernet_subsys(&mpls_net_ops);
2778     goto out;
2779 }
2780 module_init(mpls_init);
2781 
2782 static void __exit mpls_exit(void)
2783 {
2784     rtnl_unregister_all(PF_MPLS);
2785     rtnl_af_unregister(&mpls_af_ops);
2786     dev_remove_pack(&mpls_packet_type);
2787     unregister_netdevice_notifier(&mpls_dev_notifier);
2788     unregister_pernet_subsys(&mpls_net_ops);
2789     ipgre_tunnel_encap_del_mpls_ops();
2790 }
2791 module_exit(mpls_exit);
2792 
2793 MODULE_DESCRIPTION("MultiProtocol Label Switching");
2794 MODULE_LICENSE("GPL v2");
2795 MODULE_ALIAS_NETPROTO(PF_MPLS);