// SPDX-License-Identifier: GPL-2.0-only
/* Unstable Conntrack Helpers for XDP and TC-BPF hook
 *
 * These are called from the XDP and SCHED_CLS BPF programs. Note that it is
 * allowed to break compatibility for these functions since the interface they
 * are exposed through to BPF programs is explicitly unstable.
 */

#include <linux/bpf.h>
#include <linux/btf.h>
#include <linux/types.h>
#include <linux/btf_ids.h>
#include <linux/net_namespace.h>
#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_bpf.h>
#include <net/netfilter/nf_conntrack_core.h>

/* bpf_ct_opts - Options for CT lookup helpers
 *
 * Members:
 * @netns_id   - Specify the network namespace for lookup
 *               Values:
 *                 BPF_F_CURRENT_NETNS (-1)
 *                   Use namespace associated with ctx (xdp_md, __sk_buff)
 *                 [0, S32_MAX]
 *                   Network Namespace ID
 * @error      - Out parameter, set for any errors encountered
 *               Values:
 *                 -EINVAL - Passed NULL for bpf_tuple pointer
 *                 -EINVAL - opts->reserved is not 0
 *                 -EINVAL - netns_id is less than -1
 *                 -EINVAL - opts__sz isn't NF_BPF_CT_OPTS_SZ (12)
 *                 -EPROTO - l4proto isn't one of IPPROTO_TCP or IPPROTO_UDP
 *                 -ENONET - No network namespace found for netns_id
 *                 -ENOENT - Conntrack lookup could not find entry for tuple
 *                 -EAFNOSUPPORT - tuple__sz isn't one of sizeof(tuple->ipv4)
 *                                 or sizeof(tuple->ipv6)
 * @l4proto    - Layer 4 protocol
 *               Values:
 *                 IPPROTO_TCP, IPPROTO_UDP
 * @dir        - Out parameter, set to the tuple direction on successful lookup
 * @reserved   - Reserved member, must be zero; will be reused for more
 *               options in future
 */
struct bpf_ct_opts {
	s32 netns_id;
	s32 error;
	u8 l4proto;
	u8 dir;
	u8 reserved[2];
};

enum {
	NF_BPF_CT_OPTS_SZ = 12,
};
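
/* Illustrative only: a minimal lookup sequence as it might appear in an XDP
 * BPF program calling into these kfuncs. "iph" and "tcph" are assumed to be
 * already-validated header pointers in the caller; error handling is elided.
 * Note that sizeof(opts) equals NF_BPF_CT_OPTS_SZ (12).
 *
 *	struct bpf_ct_opts opts = {
 *		.netns_id = BPF_F_CURRENT_NETNS,
 *		.l4proto = IPPROTO_TCP,
 *	};
 *	struct bpf_sock_tuple tup = {};
 *	struct nf_conn *ct;
 *
 *	tup.ipv4.saddr = iph->saddr;
 *	tup.ipv4.daddr = iph->daddr;
 *	tup.ipv4.sport = tcph->source;
 *	tup.ipv4.dport = tcph->dest;
 *
 *	ct = bpf_xdp_ct_lookup(ctx, &tup, sizeof(tup.ipv4), &opts, sizeof(opts));
 *	if (ct)
 *		bpf_ct_release(ct);
 *	else if (opts.error == -ENOENT)
 *		;	// no conntrack entry for this tuple
 */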

static int bpf_nf_ct_tuple_parse(struct bpf_sock_tuple *bpf_tuple,
				 u32 tuple_len, u8 protonum, u8 dir,
				 struct nf_conntrack_tuple *tuple)
{
	union nf_inet_addr *src = dir ? &tuple->dst.u3 : &tuple->src.u3;
	union nf_inet_addr *dst = dir ? &tuple->src.u3 : &tuple->dst.u3;
	union nf_conntrack_man_proto *sport = dir ? (void *)&tuple->dst.u
						  : &tuple->src.u;
	union nf_conntrack_man_proto *dport = dir ? &tuple->src.u
						  : (void *)&tuple->dst.u;

	if (unlikely(protonum != IPPROTO_TCP && protonum != IPPROTO_UDP))
		return -EPROTO;

	memset(tuple, 0, sizeof(*tuple));

	switch (tuple_len) {
	case sizeof(bpf_tuple->ipv4):
		tuple->src.l3num = AF_INET;
		src->ip = bpf_tuple->ipv4.saddr;
		sport->tcp.port = bpf_tuple->ipv4.sport;
		dst->ip = bpf_tuple->ipv4.daddr;
		dport->tcp.port = bpf_tuple->ipv4.dport;
		break;
	case sizeof(bpf_tuple->ipv6):
		tuple->src.l3num = AF_INET6;
		memcpy(src->ip6, bpf_tuple->ipv6.saddr, sizeof(bpf_tuple->ipv6.saddr));
		sport->tcp.port = bpf_tuple->ipv6.sport;
		memcpy(dst->ip6, bpf_tuple->ipv6.daddr, sizeof(bpf_tuple->ipv6.daddr));
		dport->tcp.port = bpf_tuple->ipv6.dport;
		break;
	default:
		return -EAFNOSUPPORT;
	}
	tuple->dst.protonum = protonum;
	tuple->dst.dir = dir;

	return 0;
}

static struct nf_conn *
__bpf_nf_ct_alloc_entry(struct net *net, struct bpf_sock_tuple *bpf_tuple,
			u32 tuple_len, struct bpf_ct_opts *opts, u32 opts_len,
			u32 timeout)
{
	struct nf_conntrack_tuple otuple, rtuple;
	struct nf_conn *ct;
	int err;

	if (!opts || !bpf_tuple || opts->reserved[0] || opts->reserved[1] ||
	    opts_len != NF_BPF_CT_OPTS_SZ)
		return ERR_PTR(-EINVAL);

	if (unlikely(opts->netns_id < BPF_F_CURRENT_NETNS))
		return ERR_PTR(-EINVAL);

	err = bpf_nf_ct_tuple_parse(bpf_tuple, tuple_len, opts->l4proto,
				    IP_CT_DIR_ORIGINAL, &otuple);
	if (err < 0)
		return ERR_PTR(err);

	err = bpf_nf_ct_tuple_parse(bpf_tuple, tuple_len, opts->l4proto,
				    IP_CT_DIR_REPLY, &rtuple);
	if (err < 0)
		return ERR_PTR(err);

	if (opts->netns_id >= 0) {
		net = get_net_ns_by_id(net, opts->netns_id);
		if (unlikely(!net))
			return ERR_PTR(-ENONET);
	}

	ct = nf_conntrack_alloc(net, &nf_ct_zone_dflt, &otuple, &rtuple,
				GFP_ATOMIC);
	if (IS_ERR(ct))
		goto out;

	memset(&ct->proto, 0, sizeof(ct->proto));
	/* The entry stays unconfirmed here; IPS_CONFIRMED is set by
	 * bpf_ct_insert_entry() once the entry is actually inserted
	 * into the conntrack table.
	 */
	__nf_ct_set_timeout(ct, timeout * HZ);

out:
	if (opts->netns_id >= 0)
		put_net(net);

	return ct;
}

static struct nf_conn *__bpf_nf_ct_lookup(struct net *net,
					  struct bpf_sock_tuple *bpf_tuple,
					  u32 tuple_len, struct bpf_ct_opts *opts,
					  u32 opts_len)
{
	struct nf_conntrack_tuple_hash *hash;
	struct nf_conntrack_tuple tuple;
	struct nf_conn *ct;
	int err;

	if (!opts || !bpf_tuple || opts->reserved[0] || opts->reserved[1] ||
	    opts_len != NF_BPF_CT_OPTS_SZ)
		return ERR_PTR(-EINVAL);
	if (unlikely(opts->l4proto != IPPROTO_TCP && opts->l4proto != IPPROTO_UDP))
		return ERR_PTR(-EPROTO);
	if (unlikely(opts->netns_id < BPF_F_CURRENT_NETNS))
		return ERR_PTR(-EINVAL);

	err = bpf_nf_ct_tuple_parse(bpf_tuple, tuple_len, opts->l4proto,
				    IP_CT_DIR_ORIGINAL, &tuple);
	if (err < 0)
		return ERR_PTR(err);

	if (opts->netns_id >= 0) {
		net = get_net_ns_by_id(net, opts->netns_id);
		if (unlikely(!net))
			return ERR_PTR(-ENONET);
	}

	hash = nf_conntrack_find_get(net, &nf_ct_zone_dflt, &tuple);
	if (opts->netns_id >= 0)
		put_net(net);
	if (!hash)
		return ERR_PTR(-ENOENT);

	ct = nf_ct_tuplehash_to_ctrack(hash);
	opts->dir = NF_CT_DIRECTION(hash);

	return ct;
}

__diag_push();
__diag_ignore_all("-Wmissing-prototypes",
		  "Global functions as their definitions will be in nf_conntrack BTF");

struct nf_conn___init {
	struct nf_conn ct;
};
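
/* bpf_xdp_ct_alloc - Allocate a new CT entry
 *
 * Parameters:
 * @xdp_ctx	- Pointer to ctx (xdp_md) in XDP program
 *		    Cannot be NULL
 * @bpf_tuple	- Pointer to memory representing the tuple to look up
 *		    Cannot be NULL
 * @tuple__sz	- Size of the tuple structure
 *		    Must be one of sizeof(bpf_tuple->ipv4) or
 *		    sizeof(bpf_tuple->ipv6)
 * @opts	- Additional options for allocation (documented above)
 *		    Cannot be NULL
 * @opts__sz	- Length of the bpf_ct_opts structure
 *		    Must be NF_BPF_CT_OPTS_SZ (12)
 */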
struct nf_conn___init *
bpf_xdp_ct_alloc(struct xdp_md *xdp_ctx, struct bpf_sock_tuple *bpf_tuple,
		 u32 tuple__sz, struct bpf_ct_opts *opts, u32 opts__sz)
{
	struct xdp_buff *ctx = (struct xdp_buff *)xdp_ctx;
	struct nf_conn *nfct;

	nfct = __bpf_nf_ct_alloc_entry(dev_net(ctx->rxq->dev), bpf_tuple, tuple__sz,
				       opts, opts__sz, 10);
	if (IS_ERR(nfct)) {
		if (opts)
			opts->error = PTR_ERR(nfct);
		return NULL;
	}

	return (struct nf_conn___init *)nfct;
}
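
/* bpf_xdp_ct_lookup - Lookup CT entry for the given tuple, and acquire a
 *		       reference to it
 *
 * Parameters:
 * @xdp_ctx	- Pointer to ctx (xdp_md) in XDP program
 *		    Cannot be NULL
 * @bpf_tuple	- Pointer to memory representing the tuple to look up
 *		    Cannot be NULL
 * @tuple__sz	- Size of the tuple structure
 *		    Must be one of sizeof(bpf_tuple->ipv4) or
 *		    sizeof(bpf_tuple->ipv6)
 * @opts	- Additional options for lookup (documented above)
 *		    Cannot be NULL
 * @opts__sz	- Length of the bpf_ct_opts structure
 *		    Must be NF_BPF_CT_OPTS_SZ (12)
 */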
struct nf_conn *
bpf_xdp_ct_lookup(struct xdp_md *xdp_ctx, struct bpf_sock_tuple *bpf_tuple,
		  u32 tuple__sz, struct bpf_ct_opts *opts, u32 opts__sz)
{
	struct xdp_buff *ctx = (struct xdp_buff *)xdp_ctx;
	struct net *caller_net;
	struct nf_conn *nfct;

	caller_net = dev_net(ctx->rxq->dev);
	nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts, opts__sz);
	if (IS_ERR(nfct)) {
		if (opts)
			opts->error = PTR_ERR(nfct);
		return NULL;
	}
	return nfct;
}
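
/* bpf_skb_ct_alloc - Allocate a new CT entry
 *
 * Parameters:
 * @skb_ctx	- Pointer to ctx (__sk_buff) in TC program
 *		    Cannot be NULL
 * @bpf_tuple	- Pointer to memory representing the tuple to look up
 *		    Cannot be NULL
 * @tuple__sz	- Size of the tuple structure
 *		    Must be one of sizeof(bpf_tuple->ipv4) or
 *		    sizeof(bpf_tuple->ipv6)
 * @opts	- Additional options for allocation (documented above)
 *		    Cannot be NULL
 * @opts__sz	- Length of the bpf_ct_opts structure
 *		    Must be NF_BPF_CT_OPTS_SZ (12)
 */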
struct nf_conn___init *
bpf_skb_ct_alloc(struct __sk_buff *skb_ctx, struct bpf_sock_tuple *bpf_tuple,
		 u32 tuple__sz, struct bpf_ct_opts *opts, u32 opts__sz)
{
	struct sk_buff *skb = (struct sk_buff *)skb_ctx;
	struct nf_conn *nfct;
	struct net *net;

	net = skb->dev ? dev_net(skb->dev) : sock_net(skb->sk);
	nfct = __bpf_nf_ct_alloc_entry(net, bpf_tuple, tuple__sz, opts, opts__sz, 10);
	if (IS_ERR(nfct)) {
		if (opts)
			opts->error = PTR_ERR(nfct);
		return NULL;
	}

	return (struct nf_conn___init *)nfct;
}
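
/* bpf_skb_ct_lookup - Lookup CT entry for the given tuple, and acquire a
 *		       reference to it
 *
 * Parameters:
 * @skb_ctx	- Pointer to ctx (__sk_buff) in TC program
 *		    Cannot be NULL
 * @bpf_tuple	- Pointer to memory representing the tuple to look up
 *		    Cannot be NULL
 * @tuple__sz	- Size of the tuple structure
 *		    Must be one of sizeof(bpf_tuple->ipv4) or
 *		    sizeof(bpf_tuple->ipv6)
 * @opts	- Additional options for lookup (documented above)
 *		    Cannot be NULL
 * @opts__sz	- Length of the bpf_ct_opts structure
 *		    Must be NF_BPF_CT_OPTS_SZ (12)
 */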
struct nf_conn *
bpf_skb_ct_lookup(struct __sk_buff *skb_ctx, struct bpf_sock_tuple *bpf_tuple,
		  u32 tuple__sz, struct bpf_ct_opts *opts, u32 opts__sz)
{
	struct sk_buff *skb = (struct sk_buff *)skb_ctx;
	struct net *caller_net;
	struct nf_conn *nfct;

	caller_net = skb->dev ? dev_net(skb->dev) : sock_net(skb->sk);
	nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts, opts__sz);
	if (IS_ERR(nfct)) {
		if (opts)
			opts->error = PTR_ERR(nfct);
		return NULL;
	}
	return nfct;
}
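
/* bpf_ct_insert_entry - Add the provided entry into a CT map
 *
 * This must be invoked for referenced PTR_TO_BTF_ID.
 *
 * @nfct_i	- Pointer to referenced nf_conn___init object, obtained
 *		  using bpf_xdp_ct_alloc or bpf_skb_ct_alloc.
 *
 * Illustrative-only sketch of the allocation flow from a BPF program,
 * assuming "tup" and "opts" are prepared as in the lookup example above.
 * Insertion consumes the nf_conn___init reference and returns a new nf_conn
 * reference that must itself be released:
 *
 *	ct_init = bpf_xdp_ct_alloc(ctx, &tup, sizeof(tup.ipv4), &opts, sizeof(opts));
 *	if (ct_init) {
 *		bpf_ct_set_timeout(ct_init, 30000);
 *		ct = bpf_ct_insert_entry(ct_init);
 *		if (ct)
 *			bpf_ct_release(ct);
 *	}
 */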
struct nf_conn *bpf_ct_insert_entry(struct nf_conn___init *nfct_i)
{
	struct nf_conn *nfct = (struct nf_conn *)nfct_i;
	int err;

	/* Insertion is what confirms the entry: IPS_CONFIRMED is set here,
	 * not at allocation time, so the entry is unconfirmed until it is
	 * actually in the hash table.
	 */
	nfct->status |= IPS_CONFIRMED;
	err = nf_conntrack_hash_check_insert(nfct);
	if (err < 0) {
		nf_conntrack_free(nfct);
		return NULL;
	}
	return nfct;
}
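
/* bpf_ct_release - Release acquired nf_conn object
 *
 * This must be invoked for referenced PTR_TO_BTF_ID, and the verifier rejects
 * the program if any references remain in the program in all of the explored
 * states.
 *
 * @nfct	- Pointer to referenced nf_conn object, obtained using
 *		  bpf_xdp_ct_lookup or bpf_skb_ct_lookup.
 */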
void bpf_ct_release(struct nf_conn *nfct)
{
	if (!nfct)
		return;
	nf_ct_put(nfct);
}
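
/* bpf_ct_set_timeout - Set timeout of allocated nf_conn
 *
 * Sets the default timeout of a newly allocated nf_conn before insertion.
 * This helper must be invoked for a refcounted pointer to nf_conn___init.
 *
 * @nfct	- Pointer to referenced nf_conn object, obtained using
 *		  bpf_xdp_ct_alloc or bpf_skb_ct_alloc.
 * @timeout	- Timeout in msecs.
 */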
void bpf_ct_set_timeout(struct nf_conn___init *nfct, u32 timeout)
{
	__nf_ct_set_timeout((struct nf_conn *)nfct, msecs_to_jiffies(timeout));
}
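
/* bpf_ct_change_timeout - Change timeout of inserted nf_conn
 *
 * Change the timeout associated with an inserted or looked-up nf_conn.
 * This helper must be invoked for a refcounted pointer to nf_conn.
 *
 * @nfct	- Pointer to referenced nf_conn object, obtained using
 *		  bpf_ct_insert_entry, bpf_xdp_ct_lookup or bpf_skb_ct_lookup.
 * @timeout	- New timeout in msecs.
 */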
int bpf_ct_change_timeout(struct nf_conn *nfct, u32 timeout)
{
	return __nf_ct_change_timeout(nfct, msecs_to_jiffies(timeout));
}
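
/* bpf_ct_set_status - Set status field of allocated nf_conn
 *
 * Set the status field of a newly allocated nf_conn before insertion.
 * This must be invoked for a referenced PTR_TO_BTF_ID to nf_conn___init.
 *
 * @nfct	- Pointer to referenced nf_conn object, obtained using
 *		  bpf_xdp_ct_alloc or bpf_skb_ct_alloc.
 * @status	- New status value.
 */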
int bpf_ct_set_status(const struct nf_conn___init *nfct, u32 status)
{
	return nf_ct_change_status_common((struct nf_conn *)nfct, status);
}
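
/* bpf_ct_change_status - Change status of inserted nf_conn
 *
 * Change the status field of the provided connection tracking entry.
 * This must be invoked for a referenced PTR_TO_BTF_ID to nf_conn.
 *
 * @nfct	- Pointer to referenced nf_conn object, obtained using
 *		  bpf_ct_insert_entry, bpf_xdp_ct_lookup or bpf_skb_ct_lookup.
 * @status	- New status value.
 */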
int bpf_ct_change_status(struct nf_conn *nfct, u32 status)
{
	return nf_ct_change_status_common(nfct, status);
}

__diag_pop()

BTF_SET8_START(nf_ct_kfunc_set)
BTF_ID_FLAGS(func, bpf_xdp_ct_alloc, KF_ACQUIRE | KF_RET_NULL)
BTF_ID_FLAGS(func, bpf_xdp_ct_lookup, KF_ACQUIRE | KF_RET_NULL)
BTF_ID_FLAGS(func, bpf_skb_ct_alloc, KF_ACQUIRE | KF_RET_NULL)
BTF_ID_FLAGS(func, bpf_skb_ct_lookup, KF_ACQUIRE | KF_RET_NULL)
BTF_ID_FLAGS(func, bpf_ct_insert_entry, KF_ACQUIRE | KF_RET_NULL | KF_RELEASE)
BTF_ID_FLAGS(func, bpf_ct_release, KF_RELEASE)
BTF_ID_FLAGS(func, bpf_ct_set_timeout, KF_TRUSTED_ARGS)
BTF_ID_FLAGS(func, bpf_ct_change_timeout, KF_TRUSTED_ARGS)
BTF_ID_FLAGS(func, bpf_ct_set_status, KF_TRUSTED_ARGS)
BTF_ID_FLAGS(func, bpf_ct_change_status, KF_TRUSTED_ARGS)
BTF_SET8_END(nf_ct_kfunc_set)

static const struct btf_kfunc_id_set nf_conntrack_kfunc_set = {
	.owner = THIS_MODULE,
	.set   = &nf_ct_kfunc_set,
};

int register_nf_conntrack_bpf(void)
{
	int ret;

	ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_XDP, &nf_conntrack_kfunc_set);
	return ret ?: register_btf_kfunc_id_set(BPF_PROG_TYPE_SCHED_CLS, &nf_conntrack_kfunc_set);
}