Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-only
0002 #include <linux/module.h>
0003 #include <linux/moduleparam.h>
0004 #include <linux/rbtree_augmented.h>
0005 #include <linux/random.h>
0006 #include <linux/slab.h>
0007 #include <asm/timex.h>
0008 
0009 #define __param(type, name, init, msg)      \
0010     static type name = init;        \
0011     module_param(name, type, 0444);     \
0012     MODULE_PARM_DESC(name, msg);
0013 
0014 __param(int, nnodes, 100, "Number of nodes in the rb-tree");
0015 __param(int, perf_loops, 1000, "Number of iterations modifying the rb-tree");
0016 __param(int, check_loops, 100, "Number of iterations modifying and verifying the rb-tree");
0017 
0018 struct test_node {
0019     u32 key;
0020     struct rb_node rb;
0021 
0022     /* following fields used for testing augmented rbtree functionality */
0023     u32 val;
0024     u32 augmented;
0025 };
0026 
0027 static struct rb_root_cached root = RB_ROOT_CACHED;
0028 static struct test_node *nodes = NULL;
0029 
0030 static struct rnd_state rnd;
0031 
0032 static void insert(struct test_node *node, struct rb_root_cached *root)
0033 {
0034     struct rb_node **new = &root->rb_root.rb_node, *parent = NULL;
0035     u32 key = node->key;
0036 
0037     while (*new) {
0038         parent = *new;
0039         if (key < rb_entry(parent, struct test_node, rb)->key)
0040             new = &parent->rb_left;
0041         else
0042             new = &parent->rb_right;
0043     }
0044 
0045     rb_link_node(&node->rb, parent, new);
0046     rb_insert_color(&node->rb, &root->rb_root);
0047 }
0048 
0049 static void insert_cached(struct test_node *node, struct rb_root_cached *root)
0050 {
0051     struct rb_node **new = &root->rb_root.rb_node, *parent = NULL;
0052     u32 key = node->key;
0053     bool leftmost = true;
0054 
0055     while (*new) {
0056         parent = *new;
0057         if (key < rb_entry(parent, struct test_node, rb)->key)
0058             new = &parent->rb_left;
0059         else {
0060             new = &parent->rb_right;
0061             leftmost = false;
0062         }
0063     }
0064 
0065     rb_link_node(&node->rb, parent, new);
0066     rb_insert_color_cached(&node->rb, root, leftmost);
0067 }
0068 
0069 static inline void erase(struct test_node *node, struct rb_root_cached *root)
0070 {
0071     rb_erase(&node->rb, &root->rb_root);
0072 }
0073 
0074 static inline void erase_cached(struct test_node *node, struct rb_root_cached *root)
0075 {
0076     rb_erase_cached(&node->rb, root);
0077 }
0078 
0079 
0080 #define NODE_VAL(node) ((node)->val)
0081 
0082 RB_DECLARE_CALLBACKS_MAX(static, augment_callbacks,
0083              struct test_node, rb, u32, augmented, NODE_VAL)
0084 
0085 static void insert_augmented(struct test_node *node,
0086                  struct rb_root_cached *root)
0087 {
0088     struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL;
0089     u32 key = node->key;
0090     u32 val = node->val;
0091     struct test_node *parent;
0092 
0093     while (*new) {
0094         rb_parent = *new;
0095         parent = rb_entry(rb_parent, struct test_node, rb);
0096         if (parent->augmented < val)
0097             parent->augmented = val;
0098         if (key < parent->key)
0099             new = &parent->rb.rb_left;
0100         else
0101             new = &parent->rb.rb_right;
0102     }
0103 
0104     node->augmented = val;
0105     rb_link_node(&node->rb, rb_parent, new);
0106     rb_insert_augmented(&node->rb, &root->rb_root, &augment_callbacks);
0107 }
0108 
0109 static void insert_augmented_cached(struct test_node *node,
0110                     struct rb_root_cached *root)
0111 {
0112     struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL;
0113     u32 key = node->key;
0114     u32 val = node->val;
0115     struct test_node *parent;
0116     bool leftmost = true;
0117 
0118     while (*new) {
0119         rb_parent = *new;
0120         parent = rb_entry(rb_parent, struct test_node, rb);
0121         if (parent->augmented < val)
0122             parent->augmented = val;
0123         if (key < parent->key)
0124             new = &parent->rb.rb_left;
0125         else {
0126             new = &parent->rb.rb_right;
0127             leftmost = false;
0128         }
0129     }
0130 
0131     node->augmented = val;
0132     rb_link_node(&node->rb, rb_parent, new);
0133     rb_insert_augmented_cached(&node->rb, root,
0134                    leftmost, &augment_callbacks);
0135 }
0136 
0137 
0138 static void erase_augmented(struct test_node *node, struct rb_root_cached *root)
0139 {
0140     rb_erase_augmented(&node->rb, &root->rb_root, &augment_callbacks);
0141 }
0142 
0143 static void erase_augmented_cached(struct test_node *node,
0144                    struct rb_root_cached *root)
0145 {
0146     rb_erase_augmented_cached(&node->rb, root, &augment_callbacks);
0147 }
0148 
0149 static void init(void)
0150 {
0151     int i;
0152     for (i = 0; i < nnodes; i++) {
0153         nodes[i].key = prandom_u32_state(&rnd);
0154         nodes[i].val = prandom_u32_state(&rnd);
0155     }
0156 }
0157 
0158 static bool is_red(struct rb_node *rb)
0159 {
0160     return !(rb->__rb_parent_color & 1);
0161 }
0162 
/*
 * Count the non-red nodes on the path from @rb up to the root, @rb itself
 * included.  Returns 0 when @rb is NULL.
 */
static int black_path_count(struct rb_node *rb)
{
    int blacks = 0;

    while (rb) {
        if (!is_red(rb))
            blacks++;
        rb = rb_parent(rb);
    }

    return blacks;
}
0170 
0171 static void check_postorder_foreach(int nr_nodes)
0172 {
0173     struct test_node *cur, *n;
0174     int count = 0;
0175     rbtree_postorder_for_each_entry_safe(cur, n, &root.rb_root, rb)
0176         count++;
0177 
0178     WARN_ON_ONCE(count != nr_nodes);
0179 }
0180 
0181 static void check_postorder(int nr_nodes)
0182 {
0183     struct rb_node *rb;
0184     int count = 0;
0185     for (rb = rb_first_postorder(&root.rb_root); rb; rb = rb_next_postorder(rb))
0186         count++;
0187 
0188     WARN_ON_ONCE(count != nr_nodes);
0189 }
0190 
0191 static void check(int nr_nodes)
0192 {
0193     struct rb_node *rb;
0194     int count = 0, blacks = 0;
0195     u32 prev_key = 0;
0196 
0197     for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) {
0198         struct test_node *node = rb_entry(rb, struct test_node, rb);
0199         WARN_ON_ONCE(node->key < prev_key);
0200         WARN_ON_ONCE(is_red(rb) &&
0201                  (!rb_parent(rb) || is_red(rb_parent(rb))));
0202         if (!count)
0203             blacks = black_path_count(rb);
0204         else
0205             WARN_ON_ONCE((!rb->rb_left || !rb->rb_right) &&
0206                      blacks != black_path_count(rb));
0207         prev_key = node->key;
0208         count++;
0209     }
0210 
0211     WARN_ON_ONCE(count != nr_nodes);
0212     WARN_ON_ONCE(count < (1 << black_path_count(rb_last(&root.rb_root))) - 1);
0213 
0214     check_postorder(nr_nodes);
0215     check_postorder_foreach(nr_nodes);
0216 }
0217 
0218 static void check_augmented(int nr_nodes)
0219 {
0220     struct rb_node *rb;
0221 
0222     check(nr_nodes);
0223     for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) {
0224         struct test_node *node = rb_entry(rb, struct test_node, rb);
0225         u32 subtree, max = node->val;
0226         if (node->rb.rb_left) {
0227             subtree = rb_entry(node->rb.rb_left, struct test_node,
0228                        rb)->augmented;
0229             if (max < subtree)
0230                 max = subtree;
0231         }
0232         if (node->rb.rb_right) {
0233             subtree = rb_entry(node->rb.rb_right, struct test_node,
0234                        rb)->augmented;
0235             if (max < subtree)
0236                 max = subtree;
0237         }
0238         WARN_ON_ONCE(node->augmented != max);
0239     }
0240 }
0241 
0242 static int __init rbtree_test_init(void)
0243 {
0244     int i, j;
0245     cycles_t time1, time2, time;
0246     struct rb_node *node;
0247 
0248     nodes = kmalloc_array(nnodes, sizeof(*nodes), GFP_KERNEL);
0249     if (!nodes)
0250         return -ENOMEM;
0251 
0252     printk(KERN_ALERT "rbtree testing");
0253 
0254     prandom_seed_state(&rnd, 3141592653589793238ULL);
0255     init();
0256 
0257     time1 = get_cycles();
0258 
0259     for (i = 0; i < perf_loops; i++) {
0260         for (j = 0; j < nnodes; j++)
0261             insert(nodes + j, &root);
0262         for (j = 0; j < nnodes; j++)
0263             erase(nodes + j, &root);
0264     }
0265 
0266     time2 = get_cycles();
0267     time = time2 - time1;
0268 
0269     time = div_u64(time, perf_loops);
0270     printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n",
0271            (unsigned long long)time);
0272 
0273     time1 = get_cycles();
0274 
0275     for (i = 0; i < perf_loops; i++) {
0276         for (j = 0; j < nnodes; j++)
0277             insert_cached(nodes + j, &root);
0278         for (j = 0; j < nnodes; j++)
0279             erase_cached(nodes + j, &root);
0280     }
0281 
0282     time2 = get_cycles();
0283     time = time2 - time1;
0284 
0285     time = div_u64(time, perf_loops);
0286     printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n",
0287            (unsigned long long)time);
0288 
0289     for (i = 0; i < nnodes; i++)
0290         insert(nodes + i, &root);
0291 
0292     time1 = get_cycles();
0293 
0294     for (i = 0; i < perf_loops; i++) {
0295         for (node = rb_first(&root.rb_root); node; node = rb_next(node))
0296             ;
0297     }
0298 
0299     time2 = get_cycles();
0300     time = time2 - time1;
0301 
0302     time = div_u64(time, perf_loops);
0303     printk(" -> test 3 (latency of inorder traversal): %llu cycles\n",
0304            (unsigned long long)time);
0305 
0306     time1 = get_cycles();
0307 
0308     for (i = 0; i < perf_loops; i++)
0309         node = rb_first(&root.rb_root);
0310 
0311     time2 = get_cycles();
0312     time = time2 - time1;
0313 
0314     time = div_u64(time, perf_loops);
0315     printk(" -> test 4 (latency to fetch first node)\n");
0316     printk("        non-cached: %llu cycles\n", (unsigned long long)time);
0317 
0318     time1 = get_cycles();
0319 
0320     for (i = 0; i < perf_loops; i++)
0321         node = rb_first_cached(&root);
0322 
0323     time2 = get_cycles();
0324     time = time2 - time1;
0325 
0326     time = div_u64(time, perf_loops);
0327     printk("        cached: %llu cycles\n", (unsigned long long)time);
0328 
0329     for (i = 0; i < nnodes; i++)
0330         erase(nodes + i, &root);
0331 
0332     /* run checks */
0333     for (i = 0; i < check_loops; i++) {
0334         init();
0335         for (j = 0; j < nnodes; j++) {
0336             check(j);
0337             insert(nodes + j, &root);
0338         }
0339         for (j = 0; j < nnodes; j++) {
0340             check(nnodes - j);
0341             erase(nodes + j, &root);
0342         }
0343         check(0);
0344     }
0345 
0346     printk(KERN_ALERT "augmented rbtree testing");
0347 
0348     init();
0349 
0350     time1 = get_cycles();
0351 
0352     for (i = 0; i < perf_loops; i++) {
0353         for (j = 0; j < nnodes; j++)
0354             insert_augmented(nodes + j, &root);
0355         for (j = 0; j < nnodes; j++)
0356             erase_augmented(nodes + j, &root);
0357     }
0358 
0359     time2 = get_cycles();
0360     time = time2 - time1;
0361 
0362     time = div_u64(time, perf_loops);
0363     printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n", (unsigned long long)time);
0364 
0365     time1 = get_cycles();
0366 
0367     for (i = 0; i < perf_loops; i++) {
0368         for (j = 0; j < nnodes; j++)
0369             insert_augmented_cached(nodes + j, &root);
0370         for (j = 0; j < nnodes; j++)
0371             erase_augmented_cached(nodes + j, &root);
0372     }
0373 
0374     time2 = get_cycles();
0375     time = time2 - time1;
0376 
0377     time = div_u64(time, perf_loops);
0378     printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n", (unsigned long long)time);
0379 
0380     for (i = 0; i < check_loops; i++) {
0381         init();
0382         for (j = 0; j < nnodes; j++) {
0383             check_augmented(j);
0384             insert_augmented(nodes + j, &root);
0385         }
0386         for (j = 0; j < nnodes; j++) {
0387             check_augmented(nnodes - j);
0388             erase_augmented(nodes + j, &root);
0389         }
0390         check_augmented(0);
0391     }
0392 
0393     kfree(nodes);
0394 
0395     return -EAGAIN; /* Fail will directly unload the module */
0396 }
0397 
0398 static void __exit rbtree_test_exit(void)
0399 {
0400     printk(KERN_ALERT "test exit\n");
0401 }
0402 
0403 module_init(rbtree_test_init)
0404 module_exit(rbtree_test_exit)
0405 
0406 MODULE_LICENSE("GPL");
0407 MODULE_AUTHOR("Michel Lespinasse");
0408 MODULE_DESCRIPTION("Red Black Tree test");