Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0
0002 #include <linux/init.h>
0003 #include <linux/static_call.h>
0004 #include <linux/bug.h>
0005 #include <linux/smp.h>
0006 #include <linux/sort.h>
0007 #include <linux/slab.h>
0008 #include <linux/module.h>
0009 #include <linux/cpu.h>
0010 #include <linux/processor.h>
0011 #include <asm/sections.h>
0012 
0013 extern struct static_call_site __start_static_call_sites[],
0014                    __stop_static_call_sites[];
0015 extern struct static_call_tramp_key __start_static_call_tramp_key[],
0016                     __stop_static_call_tramp_key[];
0017 
0018 static bool static_call_initialized;
0019 
0020 /* mutex to protect key modules/sites */
0021 static DEFINE_MUTEX(static_call_mutex);
0022 
0023 static void static_call_lock(void)
0024 {
0025     mutex_lock(&static_call_mutex);
0026 }
0027 
0028 static void static_call_unlock(void)
0029 {
0030     mutex_unlock(&static_call_mutex);
0031 }
0032 
0033 static inline void *static_call_addr(struct static_call_site *site)
0034 {
0035     return (void *)((long)site->addr + (long)&site->addr);
0036 }
0037 
0038 static inline unsigned long __static_call_key(const struct static_call_site *site)
0039 {
0040     return (long)site->key + (long)&site->key;
0041 }
0042 
0043 static inline struct static_call_key *static_call_key(const struct static_call_site *site)
0044 {
0045     return (void *)(__static_call_key(site) & ~STATIC_CALL_SITE_FLAGS);
0046 }
0047 
0048 /* These assume the key is word-aligned. */
0049 static inline bool static_call_is_init(struct static_call_site *site)
0050 {
0051     return __static_call_key(site) & STATIC_CALL_SITE_INIT;
0052 }
0053 
0054 static inline bool static_call_is_tail(struct static_call_site *site)
0055 {
0056     return __static_call_key(site) & STATIC_CALL_SITE_TAIL;
0057 }
0058 
0059 static inline void static_call_set_init(struct static_call_site *site)
0060 {
0061     site->key = (__static_call_key(site) | STATIC_CALL_SITE_INIT) -
0062             (long)&site->key;
0063 }
0064 
/* sort() comparator: orders sites by key address; returns -1, 0 or 1. */
static int static_call_site_cmp(const void *_a, const void *_b)
{
    const struct static_call_key *lhs = static_call_key(_a);
    const struct static_call_key *rhs = static_call_key(_b);

    return (lhs > rhs) - (lhs < rhs);
}
0080 
0081 static void static_call_site_swap(void *_a, void *_b, int size)
0082 {
0083     long delta = (unsigned long)_a - (unsigned long)_b;
0084     struct static_call_site *a = _a;
0085     struct static_call_site *b = _b;
0086     struct static_call_site tmp = *a;
0087 
0088     a->addr = b->addr  - delta;
0089     a->key  = b->key   - delta;
0090 
0091     b->addr = tmp.addr + delta;
0092     b->key  = tmp.key  + delta;
0093 }
0094 
0095 static inline void static_call_sort_entries(struct static_call_site *start,
0096                         struct static_call_site *stop)
0097 {
0098     sort(start, stop - start, sizeof(struct static_call_site),
0099          static_call_site_cmp, static_call_site_swap);
0100 }
0101 
0102 static inline bool static_call_key_has_mods(struct static_call_key *key)
0103 {
0104     return !(key->type & 1);
0105 }
0106 
0107 static inline struct static_call_mod *static_call_key_next(struct static_call_key *key)
0108 {
0109     if (!static_call_key_has_mods(key))
0110         return NULL;
0111 
0112     return key->mods;
0113 }
0114 
0115 static inline struct static_call_site *static_call_key_sites(struct static_call_key *key)
0116 {
0117     if (static_call_key_has_mods(key))
0118         return NULL;
0119 
0120     return (struct static_call_site *)(key->type & ~1);
0121 }
0122 
/*
 * __static_call_update - retarget static call @key to @func
 * @key:   static call key being updated
 * @tramp: the key's trampoline
 * @func:  new call target
 *
 * No-op if @key already targets @func. Otherwise stores @func in
 * key->func, patches @tramp, and, once static_call_init() has run,
 * patches every recorded call site of @key in vmlinux and in modules.
 * Runs under cpus_read_lock() and static_call_mutex.
 */
void __static_call_update(struct static_call_key *key, void *tramp, void *func)
{
    struct static_call_site *site, *stop;
    struct static_call_mod *site_mod, first;

    cpus_read_lock();
    static_call_lock();

    /* Target unchanged: nothing to patch. */
    if (key->func == func)
        goto done;

    key->func = func;

    /* Patch the trampoline itself (no site address). */
    arch_static_call_transform(NULL, tramp, func, false);

    /*
     * If uninitialized, we'll not update the callsites, but they still
     * point to the trampoline and we just patched that.
     */
    if (WARN_ON_ONCE(!static_call_initialized))
        goto done;

    /*
     * On-stack list head so that both key representations are walked by
     * the same loop. A key with a direct sites pointer gives first.sites
     * != NULL and no successor. A key with a mods list gives
     * first.sites == NULL, chained to that list.
     */
    first = (struct static_call_mod){
        .next = static_call_key_next(key),
        .mod = NULL,
        .sites = static_call_key_sites(key),
    };

    for (site_mod = &first; site_mod; site_mod = site_mod->next) {
        /*
         * Whether sites in init text may still be patched. For vmlinux
         * this holds until system_state reaches SYSTEM_RUNNING. For a
         * module it holds while the module is MODULE_STATE_COMING
         * (set below).
         */
        bool init = system_state < SYSTEM_RUNNING;
        struct module *mod = site_mod->mod;

        if (!site_mod->sites) {
            /*
             * This can happen if the static call key is defined in
             * a module which doesn't use it.
             *
             * It also happens in the has_mods case, where the
             * 'first' entry has no sites associated with it.
             */
            continue;
        }

        /* Bound the walk by the end of the site table that owns these sites. */
        stop = __stop_static_call_sites;

        if (mod) {
#ifdef CONFIG_MODULES
            stop = mod->static_call_sites +
                   mod->num_static_call_sites;
            init = mod->state == MODULE_STATE_COMING;
#endif
        }

        /*
         * Site tables are sorted by key in __static_call_init(), so this
         * key's sites are the contiguous run starting at site_mod->sites.
         */
        for (site = site_mod->sites;
             site < stop && static_call_key(site) == key; site++) {
            void *site_addr = static_call_addr(site);

            /* Init-text sites are left alone once init is over. */
            if (!init && static_call_is_init(site))
                continue;

            if (!kernel_text_address((unsigned long)site_addr)) {
                /*
                 * This skips patching built-in __exit, which
                 * is part of init_section_contains() but is
                 * not part of kernel_text_address().
                 *
                 * Skipping built-in __exit is fine since it
                 * will never be executed.
                 */
                WARN_ONCE(!static_call_is_init(site),
                      "can't patch static call site at %pS",
                      site_addr);
                continue;
            }

            arch_static_call_transform(site_addr, NULL, func,
                           static_call_is_tail(site));
        }
    }

done:
    static_call_unlock();
    cpus_read_unlock();
}
EXPORT_SYMBOL_GPL(__static_call_update);
0208 
/*
 * __static_call_init - register and patch the call sites in [@start, @stop)
 * @mod:   owning module, or NULL for the built-in (vmlinux) sites
 * @start: first site
 * @stop:  one past the last site
 *
 * Steps:
 *  - sort the sites by key;
 *  - flag sites located in init text with STATIC_CALL_SITE_INIT;
 *  - attach each distinct key to its run of sites;
 *  - patch every site to call key->func.
 *
 * For vmlinux the sites pointer is stored directly in the key, with no
 * allocation. For a module, a static_call_mod entry is pushed onto the
 * key's list.
 *
 * Returns 0, or -ENOMEM if an allocation fails. In that case the keys not
 * yet reached are left untouched.
 */
static int __static_call_init(struct module *mod,
                  struct static_call_site *start,
                  struct static_call_site *stop)
{
    struct static_call_site *site;
    struct static_call_key *key, *prev_key = NULL;
    struct static_call_mod *site_mod;

    if (start == stop)
        return 0;

    /* Group sites by key; the per-key walks in this file rely on it. */
    static_call_sort_entries(start, stop);

    for (site = start; site < stop; site++) {
        void *site_addr = static_call_addr(site);

        /* Remember which sites live in init text. */
        if ((mod && within_module_init((unsigned long)site_addr, mod)) ||
            (!mod && init_section_contains(site_addr, 1)))
            static_call_set_init(site);

        /* First site of a new key's run: link the key to its sites. */
        key = static_call_key(site);
        if (key != prev_key) {
            prev_key = key;

            /*
             * For vmlinux (!mod) avoid the allocation by storing
             * the sites pointer in the key itself. Also see
             * __static_call_update()'s @first.
             *
             * This allows architectures (eg. x86) to call
             * static_call_init() before memory allocation works.
             */
            if (!mod) {
                key->sites = site;
                key->type |= 1;
                goto do_transform;
            }

            site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
            if (!site_mod)
                return -ENOMEM;

            /*
             * When the key has a direct sites pointer, extract
             * that into an explicit struct static_call_mod, so we
             * can have a list of modules.
             */
            if (static_call_key_sites(key)) {
                site_mod->mod = NULL;
                site_mod->next = NULL;
                site_mod->sites = static_call_key_sites(key);

                key->mods = site_mod;

                site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
                if (!site_mod)
                    return -ENOMEM;
            }

            /* Push this module's entry onto the head of the key's list. */
            site_mod->mod = mod;
            site_mod->sites = site;
            site_mod->next = static_call_key_next(key);
            key->mods = site_mod;
        }

do_transform:
        /* Every site is patched to the key's current target. */
        arch_static_call_transform(site_addr, NULL, key->func,
                static_call_is_tail(site));
    }

    return 0;
}
0281 
0282 static int addr_conflict(struct static_call_site *site, void *start, void *end)
0283 {
0284     unsigned long addr = (unsigned long)static_call_addr(site);
0285 
0286     if (addr <= (unsigned long)end &&
0287         addr + CALL_INSN_SIZE > (unsigned long)start)
0288         return 1;
0289 
0290     return 0;
0291 }
0292 
0293 static int __static_call_text_reserved(struct static_call_site *iter_start,
0294                        struct static_call_site *iter_stop,
0295                        void *start, void *end, bool init)
0296 {
0297     struct static_call_site *iter = iter_start;
0298 
0299     while (iter < iter_stop) {
0300         if (init || !static_call_is_init(iter)) {
0301             if (addr_conflict(iter, start, end))
0302                 return 1;
0303         }
0304         iter++;
0305     }
0306 
0307     return 0;
0308 }
0309 
0310 #ifdef CONFIG_MODULES
0311 
0312 static int __static_call_mod_text_reserved(void *start, void *end)
0313 {
0314     struct module *mod;
0315     int ret;
0316 
0317     preempt_disable();
0318     mod = __module_text_address((unsigned long)start);
0319     WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod);
0320     if (!try_module_get(mod))
0321         mod = NULL;
0322     preempt_enable();
0323 
0324     if (!mod)
0325         return 0;
0326 
0327     ret = __static_call_text_reserved(mod->static_call_sites,
0328             mod->static_call_sites + mod->num_static_call_sites,
0329             start, end, mod->state == MODULE_STATE_COMING);
0330 
0331     module_put(mod);
0332 
0333     return ret;
0334 }
0335 
0336 static unsigned long tramp_key_lookup(unsigned long addr)
0337 {
0338     struct static_call_tramp_key *start = __start_static_call_tramp_key;
0339     struct static_call_tramp_key *stop = __stop_static_call_tramp_key;
0340     struct static_call_tramp_key *tramp_key;
0341 
0342     for (tramp_key = start; tramp_key != stop; tramp_key++) {
0343         unsigned long tramp;
0344 
0345         tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp;
0346         if (tramp == addr)
0347             return (long)tramp_key->key + (long)&tramp_key->key;
0348     }
0349 
0350     return 0;
0351 }
0352 
0353 static int static_call_add_module(struct module *mod)
0354 {
0355     struct static_call_site *start = mod->static_call_sites;
0356     struct static_call_site *stop = start + mod->num_static_call_sites;
0357     struct static_call_site *site;
0358 
0359     for (site = start; site != stop; site++) {
0360         unsigned long s_key = __static_call_key(site);
0361         unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS;
0362         unsigned long key;
0363 
0364         /*
0365          * Is the key is exported, 'addr' points to the key, which
0366          * means modules are allowed to call static_call_update() on
0367          * it.
0368          *
0369          * Otherwise, the key isn't exported, and 'addr' points to the
0370          * trampoline so we need to lookup the key.
0371          *
0372          * We go through this dance to prevent crazy modules from
0373          * abusing sensitive static calls.
0374          */
0375         if (!kernel_text_address(addr))
0376             continue;
0377 
0378         key = tramp_key_lookup(addr);
0379         if (!key) {
0380             pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n",
0381                 static_call_addr(site));
0382             return -EINVAL;
0383         }
0384 
0385         key |= s_key & STATIC_CALL_SITE_FLAGS;
0386         site->key = key - (long)&site->key;
0387     }
0388 
0389     return __static_call_init(mod, start, stop);
0390 }
0391 
0392 static void static_call_del_module(struct module *mod)
0393 {
0394     struct static_call_site *start = mod->static_call_sites;
0395     struct static_call_site *stop = mod->static_call_sites +
0396                     mod->num_static_call_sites;
0397     struct static_call_key *key, *prev_key = NULL;
0398     struct static_call_mod *site_mod, **prev;
0399     struct static_call_site *site;
0400 
0401     for (site = start; site < stop; site++) {
0402         key = static_call_key(site);
0403         if (key == prev_key)
0404             continue;
0405 
0406         prev_key = key;
0407 
0408         for (prev = &key->mods, site_mod = key->mods;
0409              site_mod && site_mod->mod != mod;
0410              prev = &site_mod->next, site_mod = site_mod->next)
0411             ;
0412 
0413         if (!site_mod)
0414             continue;
0415 
0416         *prev = site_mod->next;
0417         kfree(site_mod);
0418     }
0419 }
0420 
0421 static int static_call_module_notify(struct notifier_block *nb,
0422                      unsigned long val, void *data)
0423 {
0424     struct module *mod = data;
0425     int ret = 0;
0426 
0427     cpus_read_lock();
0428     static_call_lock();
0429 
0430     switch (val) {
0431     case MODULE_STATE_COMING:
0432         ret = static_call_add_module(mod);
0433         if (ret) {
0434             WARN(1, "Failed to allocate memory for static calls");
0435             static_call_del_module(mod);
0436         }
0437         break;
0438     case MODULE_STATE_GOING:
0439         static_call_del_module(mod);
0440         break;
0441     }
0442 
0443     static_call_unlock();
0444     cpus_read_unlock();
0445 
0446     return notifier_from_errno(ret);
0447 }
0448 
0449 static struct notifier_block static_call_module_nb = {
0450     .notifier_call = static_call_module_notify,
0451 };
0452 
0453 #else
0454 
/* !CONFIG_MODULES: no module call sites exist, so nothing is ever reserved. */
static inline int __static_call_mod_text_reserved(void *start, void *end)
{
    int reserved = 0;

    return reserved;
}
0459 
0460 #endif /* CONFIG_MODULES */
0461 
0462 int static_call_text_reserved(void *start, void *end)
0463 {
0464     bool init = system_state < SYSTEM_RUNNING;
0465     int ret = __static_call_text_reserved(__start_static_call_sites,
0466             __stop_static_call_sites, start, end, init);
0467 
0468     if (ret)
0469         return ret;
0470 
0471     return __static_call_mod_text_reserved(start, end);
0472 }
0473 
0474 int __init static_call_init(void)
0475 {
0476     int ret;
0477 
0478     if (static_call_initialized)
0479         return 0;
0480 
0481     cpus_read_lock();
0482     static_call_lock();
0483     ret = __static_call_init(NULL, __start_static_call_sites,
0484                  __stop_static_call_sites);
0485     static_call_unlock();
0486     cpus_read_unlock();
0487 
0488     if (ret) {
0489         pr_err("Failed to allocate memory for static_call!\n");
0490         BUG();
0491     }
0492 
0493     static_call_initialized = true;
0494 
0495 #ifdef CONFIG_MODULES
0496     register_module_notifier(&static_call_module_nb);
0497 #endif
0498     return 0;
0499 }
0500 early_initcall(static_call_init);
0501 
0502 #ifdef CONFIG_STATIC_CALL_SELFTEST
0503 
/* Self-test targets: each adds a distinct constant to its argument. */
static int func_a(int x)
{
    int result = x + 1;

    return result;
}

static int func_b(int x)
{
    int result = x + 2;

    return result;
}
0513 
/* Self-test static call; it starts out targeting func_a. */
DEFINE_STATIC_CALL(sc_selftest, func_a);

/*
 * Self-test script. For each entry, sc_selftest is retargeted to ->func
 * (NULL keeps the current target), called with ->val, and the result is
 * expected to equal ->expect.
 */
static struct static_call_data {
    int (*func)(int);
    int val;
    int expect;
} static_call_data [] __initdata = {
    { .func = NULL,   .val = 2, .expect = 3 },
    { .func = func_b, .val = 2, .expect = 4 },
    { .func = func_a, .val = 2, .expect = 3 },
};
0525 
0526 static int __init test_static_call_init(void)
0527 {
0528       int i;
0529 
0530       for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) {
0531           struct static_call_data *scd = &static_call_data[i];
0532 
0533               if (scd->func)
0534                       static_call_update(sc_selftest, scd->func);
0535 
0536               WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect);
0537       }
0538 
0539       return 0;
0540 }
0541 early_initcall(test_static_call_init);
0542 
0543 #endif /* CONFIG_STATIC_CALL_SELFTEST */