Back to home page

LXR

 
 

    


0001 /*
0002  * Resizable, Scalable, Concurrent Hash Table
0003  *
0004  * Copyright (c) 2014-2015 Thomas Graf <tgraf@suug.ch>
0005  * Copyright (c) 2008-2014 Patrick McHardy <kaber@trash.net>
0006  *
0007  * This program is free software; you can redistribute it and/or modify
0008  * it under the terms of the GNU General Public License version 2 as
0009  * published by the Free Software Foundation.
0010  */
0011 
0012 /**************************************************************************
0013  * Self Test
0014  **************************************************************************/
0015 
0016 #include <linux/init.h>
0017 #include <linux/jhash.h>
0018 #include <linux/kernel.h>
0019 #include <linux/kthread.h>
0020 #include <linux/module.h>
0021 #include <linux/rcupdate.h>
0022 #include <linux/rhashtable.h>
0023 #include <linux/semaphore.h>
0024 #include <linux/slab.h>
0025 #include <linux/sched.h>
0026 #include <linux/vmalloc.h>
0027 
/*
 * MAX_ENTRIES sizes the static array used by the single-threaded test and
 * caps the "entries" parameter. TEST_INSERT_FAIL is the marker stored in
 * test_obj.value for an object the checks must treat as absent from the table.
 * Enumeration constants have type int, exactly like the former #defines.
 */
enum {
    MAX_ENTRIES      = 1000000,
    TEST_INSERT_FAIL = INT_MAX,
};
0030 
0031 static int entries = 50000;
0032 module_param(entries, int, 0);
0033 MODULE_PARM_DESC(entries, "Number of entries to add (default: 50000)");
0034 
0035 static int runs = 4;
0036 module_param(runs, int, 0);
0037 MODULE_PARM_DESC(runs, "Number of test runs per variant (default: 4)");
0038 
0039 static int max_size = 0;
0040 module_param(max_size, int, 0);
0041 MODULE_PARM_DESC(max_size, "Maximum table size (default: calculated)");
0042 
0043 static bool shrinking = false;
0044 module_param(shrinking, bool, 0);
0045 MODULE_PARM_DESC(shrinking, "Enable automatic shrinking (default: off)");
0046 
0047 static int size = 8;
0048 module_param(size, int, 0);
0049 MODULE_PARM_DESC(size, "Initial size hint of table (default: 8)");
0050 
0051 static int tcount = 10;
0052 module_param(tcount, int, 0);
0053 MODULE_PARM_DESC(tcount, "Number of threads to spawn (default: 10)");
0054 
0055 static bool enomem_retry = false;
0056 module_param(enomem_retry, bool, 0);
0057 MODULE_PARM_DESC(enomem_retry, "Retry insert even if -ENOMEM was returned (default: off)");
0058 
/*
 * struct test_obj - object stored in the table under test.
 * @value: lookup key; test_rht_params.key_offset points here and key_len is
 *         sizeof(int). The checks treat TEST_INSERT_FAIL in this field as
 *         "this object is not in the table".
 * @node:  table linkage; test_rht_params.head_offset points here.
 */
0059 struct test_obj {
0060     int         value;
0061     struct rhash_head   node;
0062 };
0063 
/*
 * struct thread_data - per-thread state for the concurrent test.
 * @id:   thread index; each key is built as (id << 16) | object index.
 * @task: result of kthread_run(); may be an ERR_PTR if creation failed.
 * @objs: this thread's slice of "entries" objects in the shared allocation.
 */
0064 struct thread_data {
0065     int id;
0066     struct task_struct *task;
0067     struct test_obj *objs;
0068 };
0069 
/* Backing storage for the single-threaded test; zeroed before every run. */
0070 static struct test_obj array[MAX_ENTRIES];
0071 
0072 static struct rhashtable_params test_rht_params = {
0073     .head_offset = offsetof(struct test_obj, node),
0074     .key_offset = offsetof(struct test_obj, value),
0075     .key_len = sizeof(int),
0076     .hashfn = jhash,
0077     .nulls_base = (3U << RHT_BASE_SHIFT),
0078 };
0079 
/*
 * Start-up handshake for the concurrent test. Each worker thread up()s
 * prestart_sem and then blocks on startup_sem until test_rht_init() up()s it.
 * prestart_sem is initialised at runtime with sema_init() in test_rht_init();
 * startup_sem starts with a count of 0.
 */
0080 static struct semaphore prestart_sem;
0081 static struct semaphore startup_sem = __SEMAPHORE_INITIALIZER(startup_sem, 0);
0082 
0083 static int insert_retry(struct rhashtable *ht, struct rhash_head *obj,
0084                         const struct rhashtable_params params)
0085 {
0086     int err, retries = -1, enomem_retries = 0;
0087 
0088     do {
0089         retries++;
0090         cond_resched();
0091         err = rhashtable_insert_fast(ht, obj, params);
0092         if (err == -ENOMEM && enomem_retry) {
0093             enomem_retries++;
0094             err = -EBUSY;
0095         }
0096     } while (err == -EBUSY);
0097 
0098     if (enomem_retries)
0099         pr_info(" %u insertions retried after -ENOMEM\n",
0100             enomem_retries);
0101 
0102     return err ? : retries;
0103 }
0104 
0105 static int __init test_rht_lookup(struct rhashtable *ht)
0106 {
0107     unsigned int i;
0108 
0109     for (i = 0; i < entries * 2; i++) {
0110         struct test_obj *obj;
0111         bool expected = !(i % 2);
0112         u32 key = i;
0113 
0114         if (array[i / 2].value == TEST_INSERT_FAIL)
0115             expected = false;
0116 
0117         obj = rhashtable_lookup_fast(ht, &key, test_rht_params);
0118 
0119         if (expected && !obj) {
0120             pr_warn("Test failed: Could not find key %u\n", key);
0121             return -ENOENT;
0122         } else if (!expected && obj) {
0123             pr_warn("Test failed: Unexpected entry found for key %u\n",
0124                 key);
0125             return -EEXIST;
0126         } else if (expected && obj) {
0127             if (obj->value != i) {
0128                 pr_warn("Test failed: Lookup value mismatch %u!=%u\n",
0129                     obj->value, i);
0130                 return -EINVAL;
0131             }
0132         }
0133 
0134         cond_resched_rcu();
0135     }
0136 
0137     return 0;
0138 }
0139 
0140 static void test_bucket_stats(struct rhashtable *ht)
0141 {
0142     unsigned int err, total = 0, chain_len = 0;
0143     struct rhashtable_iter hti;
0144     struct rhash_head *pos;
0145 
0146     err = rhashtable_walk_init(ht, &hti, GFP_KERNEL);
0147     if (err) {
0148         pr_warn("Test failed: allocation error");
0149         return;
0150     }
0151 
0152     err = rhashtable_walk_start(&hti);
0153     if (err && err != -EAGAIN) {
0154         pr_warn("Test failed: iterator failed: %d\n", err);
0155         return;
0156     }
0157 
0158     while ((pos = rhashtable_walk_next(&hti))) {
0159         if (PTR_ERR(pos) == -EAGAIN) {
0160             pr_info("Info: encountered resize\n");
0161             chain_len++;
0162             continue;
0163         } else if (IS_ERR(pos)) {
0164             pr_warn("Test failed: rhashtable_walk_next() error: %ld\n",
0165                 PTR_ERR(pos));
0166             break;
0167         }
0168 
0169         total++;
0170     }
0171 
0172     rhashtable_walk_stop(&hti);
0173     rhashtable_walk_exit(&hti);
0174 
0175     pr_info("  Traversal complete: counted=%u, nelems=%u, entries=%d, table-jumps=%u\n",
0176         total, atomic_read(&ht->nelems), entries, chain_len);
0177 
0178     if (total != atomic_read(&ht->nelems) || total != entries)
0179         pr_warn("Test failed: Total count mismatch ^^^");
0180 }
0181 
0182 static s64 __init test_rhashtable(struct rhashtable *ht)
0183 {
0184     struct test_obj *obj;
0185     int err;
0186     unsigned int i, insert_retries = 0;
0187     s64 start, end;
0188 
0189     /*
0190      * Insertion Test:
0191      * Insert entries into table with all keys even numbers
0192      */
0193     pr_info("  Adding %d keys\n", entries);
0194     start = ktime_get_ns();
0195     for (i = 0; i < entries; i++) {
0196         struct test_obj *obj = &array[i];
0197 
0198         obj->value = i * 2;
0199         err = insert_retry(ht, &obj->node, test_rht_params);
0200         if (err > 0)
0201             insert_retries += err;
0202         else if (err)
0203             return err;
0204     }
0205 
0206     if (insert_retries)
0207         pr_info("  %u insertions retried due to memory pressure\n",
0208             insert_retries);
0209 
0210     test_bucket_stats(ht);
0211     rcu_read_lock();
0212     test_rht_lookup(ht);
0213     rcu_read_unlock();
0214 
0215     test_bucket_stats(ht);
0216 
0217     pr_info("  Deleting %d keys\n", entries);
0218     for (i = 0; i < entries; i++) {
0219         u32 key = i * 2;
0220 
0221         if (array[i].value != TEST_INSERT_FAIL) {
0222             obj = rhashtable_lookup_fast(ht, &key, test_rht_params);
0223             BUG_ON(!obj);
0224 
0225             rhashtable_remove_fast(ht, &obj->node, test_rht_params);
0226         }
0227 
0228         cond_resched();
0229     }
0230 
0231     end = ktime_get_ns();
0232     pr_info("  Duration of test: %lld ns\n", end - start);
0233 
0234     return end - start;
0235 }
0236 
/*
 * Table used by test_rht_init(): re-initialised for each single-threaded run
 * and then shared by all worker threads in the concurrent test.
 */
0237 static struct rhashtable ht;
0238 
0239 static int thread_lookup_test(struct thread_data *tdata)
0240 {
0241     int i, err = 0;
0242 
0243     for (i = 0; i < entries; i++) {
0244         struct test_obj *obj;
0245         int key = (tdata->id << 16) | i;
0246 
0247         obj = rhashtable_lookup_fast(&ht, &key, test_rht_params);
0248         if (obj && (tdata->objs[i].value == TEST_INSERT_FAIL)) {
0249             pr_err("  found unexpected object %d\n", key);
0250             err++;
0251         } else if (!obj && (tdata->objs[i].value != TEST_INSERT_FAIL)) {
0252             pr_err("  object %d not found!\n", key);
0253             err++;
0254         } else if (obj && (obj->value != key)) {
0255             pr_err("  wrong object returned (got %d, expected %d)\n",
0256                    obj->value, key);
0257             err++;
0258         }
0259 
0260         cond_resched();
0261     }
0262     return err;
0263 }
0264 
0265 static int threadfunc(void *data)
0266 {
0267     int i, step, err = 0, insert_retries = 0;
0268     struct thread_data *tdata = data;
0269 
0270     up(&prestart_sem);
0271     if (down_interruptible(&startup_sem))
0272         pr_err("  thread[%d]: down_interruptible failed\n", tdata->id);
0273 
0274     for (i = 0; i < entries; i++) {
0275         tdata->objs[i].value = (tdata->id << 16) | i;
0276         err = insert_retry(&ht, &tdata->objs[i].node, test_rht_params);
0277         if (err > 0) {
0278             insert_retries += err;
0279         } else if (err) {
0280             pr_err("  thread[%d]: rhashtable_insert_fast failed\n",
0281                    tdata->id);
0282             goto out;
0283         }
0284     }
0285     if (insert_retries)
0286         pr_info("  thread[%d]: %u insertions retried due to memory pressure\n",
0287             tdata->id, insert_retries);
0288 
0289     err = thread_lookup_test(tdata);
0290     if (err) {
0291         pr_err("  thread[%d]: rhashtable_lookup_test failed\n",
0292                tdata->id);
0293         goto out;
0294     }
0295 
0296     for (step = 10; step > 0; step--) {
0297         for (i = 0; i < entries; i += step) {
0298             if (tdata->objs[i].value == TEST_INSERT_FAIL)
0299                 continue;
0300             err = rhashtable_remove_fast(&ht, &tdata->objs[i].node,
0301                                          test_rht_params);
0302             if (err) {
0303                 pr_err("  thread[%d]: rhashtable_remove_fast failed\n",
0304                        tdata->id);
0305                 goto out;
0306             }
0307             tdata->objs[i].value = TEST_INSERT_FAIL;
0308 
0309             cond_resched();
0310         }
0311         err = thread_lookup_test(tdata);
0312         if (err) {
0313             pr_err("  thread[%d]: rhashtable_lookup_test (2) failed\n",
0314                    tdata->id);
0315             goto out;
0316         }
0317     }
0318 out:
0319     while (!kthread_should_stop()) {
0320         set_current_state(TASK_INTERRUPTIBLE);
0321         schedule();
0322     }
0323     return err;
0324 }
0325 
0326 static int __init test_rht_init(void)
0327 {
0328     int i, err, started_threads = 0, failed_threads = 0;
0329     u64 total_time = 0;
0330     struct thread_data *tdata;
0331     struct test_obj *objs;
0332 
0333     entries = min(entries, MAX_ENTRIES);
0334 
0335     test_rht_params.automatic_shrinking = shrinking;
0336     test_rht_params.max_size = max_size ? : roundup_pow_of_two(entries);
0337     test_rht_params.nelem_hint = size;
0338 
0339     pr_info("Running rhashtable test nelem=%d, max_size=%d, shrinking=%d\n",
0340         size, max_size, shrinking);
0341 
0342     for (i = 0; i < runs; i++) {
0343         s64 time;
0344 
0345         pr_info("Test %02d:\n", i);
0346         memset(&array, 0, sizeof(array));
0347         err = rhashtable_init(&ht, &test_rht_params);
0348         if (err < 0) {
0349             pr_warn("Test failed: Unable to initialize hashtable: %d\n",
0350                 err);
0351             continue;
0352         }
0353 
0354         time = test_rhashtable(&ht);
0355         rhashtable_destroy(&ht);
0356         if (time < 0) {
0357             pr_warn("Test failed: return code %lld\n", time);
0358             return -EINVAL;
0359         }
0360 
0361         total_time += time;
0362     }
0363 
0364     do_div(total_time, runs);
0365     pr_info("Average test time: %llu\n", total_time);
0366 
0367     if (!tcount)
0368         return 0;
0369 
0370     pr_info("Testing concurrent rhashtable access from %d threads\n",
0371             tcount);
0372     sema_init(&prestart_sem, 1 - tcount);
0373     tdata = vzalloc(tcount * sizeof(struct thread_data));
0374     if (!tdata)
0375         return -ENOMEM;
0376     objs  = vzalloc(tcount * entries * sizeof(struct test_obj));
0377     if (!objs) {
0378         vfree(tdata);
0379         return -ENOMEM;
0380     }
0381 
0382     test_rht_params.max_size = max_size ? :
0383                                roundup_pow_of_two(tcount * entries);
0384     err = rhashtable_init(&ht, &test_rht_params);
0385     if (err < 0) {
0386         pr_warn("Test failed: Unable to initialize hashtable: %d\n",
0387             err);
0388         vfree(tdata);
0389         vfree(objs);
0390         return -EINVAL;
0391     }
0392     for (i = 0; i < tcount; i++) {
0393         tdata[i].id = i;
0394         tdata[i].objs = objs + i * entries;
0395         tdata[i].task = kthread_run(threadfunc, &tdata[i],
0396                                     "rhashtable_thrad[%d]", i);
0397         if (IS_ERR(tdata[i].task))
0398             pr_err(" kthread_run failed for thread %d\n", i);
0399         else
0400             started_threads++;
0401     }
0402     if (down_interruptible(&prestart_sem))
0403         pr_err("  down interruptible failed\n");
0404     for (i = 0; i < tcount; i++)
0405         up(&startup_sem);
0406     for (i = 0; i < tcount; i++) {
0407         if (IS_ERR(tdata[i].task))
0408             continue;
0409         if ((err = kthread_stop(tdata[i].task))) {
0410             pr_warn("Test failed: thread %d returned: %d\n",
0411                     i, err);
0412             failed_threads++;
0413         }
0414     }
0415     pr_info("Started %d threads, %d failed\n",
0416             started_threads, failed_threads);
0417     rhashtable_destroy(&ht);
0418     vfree(tdata);
0419     vfree(objs);
0420     return 0;
0421 }
0422 
/*
 * Module exit hook. Intentionally empty: test_rht_init() runs every test
 * synchronously, stops its threads, and destroys its tables and frees its
 * buffers before it returns.
 */
0423 static void __exit test_rht_exit(void)
0424 {
0425 }
0426 
/* Register the entry/exit hooks; all testing happens at module load time. */
0427 module_init(test_rht_init);
0428 module_exit(test_rht_exit);
0429 
0430 MODULE_LICENSE("GPL v2");