Back to home page

LXR

 
 

    


0001 #include <linux/module.h>
0002 #include <linux/rbtree_augmented.h>
0003 #include <linux/random.h>
0004 #include <asm/timex.h>
0005 
/* Test sizing knobs. */
enum {
    NODES       = 100,      /* nodes inserted/erased per round */
    PERF_LOOPS  = 100000,   /* rounds timed by each benchmark phase */
    CHECK_LOOPS = 100,      /* rounds run with invariant checking */
};
0009 
0010 struct test_node {
0011     u32 key;
0012     struct rb_node rb;
0013 
0014     /* following fields used for testing augmented rbtree functionality */
0015     u32 val;
0016     u32 augmented;
0017 };
0018 
0019 static struct rb_root root = RB_ROOT;
0020 static struct test_node nodes[NODES];
0021 
0022 static struct rnd_state rnd;
0023 
0024 static void insert(struct test_node *node, struct rb_root *root)
0025 {
0026     struct rb_node **new = &root->rb_node, *parent = NULL;
0027     u32 key = node->key;
0028 
0029     while (*new) {
0030         parent = *new;
0031         if (key < rb_entry(parent, struct test_node, rb)->key)
0032             new = &parent->rb_left;
0033         else
0034             new = &parent->rb_right;
0035     }
0036 
0037     rb_link_node(&node->rb, parent, new);
0038     rb_insert_color(&node->rb, root);
0039 }
0040 
0041 static inline void erase(struct test_node *node, struct rb_root *root)
0042 {
0043     rb_erase(&node->rb, root);
0044 }
0045 
0046 static inline u32 augment_recompute(struct test_node *node)
0047 {
0048     u32 max = node->val, child_augmented;
0049     if (node->rb.rb_left) {
0050         child_augmented = rb_entry(node->rb.rb_left, struct test_node,
0051                        rb)->augmented;
0052         if (max < child_augmented)
0053             max = child_augmented;
0054     }
0055     if (node->rb.rb_right) {
0056         child_augmented = rb_entry(node->rb.rb_right, struct test_node,
0057                        rb)->augmented;
0058         if (max < child_augmented)
0059             max = child_augmented;
0060     }
0061     return max;
0062 }
0063 
/*
 * Define the static augment_callbacks passed to rb_insert_augmented() and
 * rb_erase_augmented() in this file.  They maintain the u32 field
 * test_node.augmented (reached via the embedded rb member) using
 * augment_recompute() while the tree is rebalanced.
 * NOTE(review): the exact helpers generated depend on the
 * RB_DECLARE_CALLBACKS definition in <linux/rbtree_augmented.h>.
 */
RB_DECLARE_CALLBACKS(static, augment_callbacks, struct test_node, rb,
             u32, augmented, augment_recompute)
0066 
0067 static void insert_augmented(struct test_node *node, struct rb_root *root)
0068 {
0069     struct rb_node **new = &root->rb_node, *rb_parent = NULL;
0070     u32 key = node->key;
0071     u32 val = node->val;
0072     struct test_node *parent;
0073 
0074     while (*new) {
0075         rb_parent = *new;
0076         parent = rb_entry(rb_parent, struct test_node, rb);
0077         if (parent->augmented < val)
0078             parent->augmented = val;
0079         if (key < parent->key)
0080             new = &parent->rb.rb_left;
0081         else
0082             new = &parent->rb.rb_right;
0083     }
0084 
0085     node->augmented = val;
0086     rb_link_node(&node->rb, rb_parent, new);
0087     rb_insert_augmented(&node->rb, root, &augment_callbacks);
0088 }
0089 
0090 static void erase_augmented(struct test_node *node, struct rb_root *root)
0091 {
0092     rb_erase_augmented(&node->rb, root, &augment_callbacks);
0093 }
0094 
0095 static void init(void)
0096 {
0097     int i;
0098     for (i = 0; i < NODES; i++) {
0099         nodes[i].key = prandom_u32_state(&rnd);
0100         nodes[i].val = prandom_u32_state(&rnd);
0101     }
0102 }
0103 
0104 static bool is_red(struct rb_node *rb)
0105 {
0106     return !(rb->__rb_parent_color & 1);
0107 }
0108 
/*
 * Count the black nodes on the path from @rb up to the root,
 * inclusive.  Returns 0 when @rb is NULL.
 */
static int black_path_count(struct rb_node *rb)
{
    int blacks = 0;

    while (rb) {
        if (!is_red(rb))
            blacks++;
        rb = rb_parent(rb);
    }
    return blacks;
}
0116 
0117 static void check_postorder_foreach(int nr_nodes)
0118 {
0119     struct test_node *cur, *n;
0120     int count = 0;
0121     rbtree_postorder_for_each_entry_safe(cur, n, &root, rb)
0122         count++;
0123 
0124     WARN_ON_ONCE(count != nr_nodes);
0125 }
0126 
0127 static void check_postorder(int nr_nodes)
0128 {
0129     struct rb_node *rb;
0130     int count = 0;
0131     for (rb = rb_first_postorder(&root); rb; rb = rb_next_postorder(rb))
0132         count++;
0133 
0134     WARN_ON_ONCE(count != nr_nodes);
0135 }
0136 
0137 static void check(int nr_nodes)
0138 {
0139     struct rb_node *rb;
0140     int count = 0, blacks = 0;
0141     u32 prev_key = 0;
0142 
0143     for (rb = rb_first(&root); rb; rb = rb_next(rb)) {
0144         struct test_node *node = rb_entry(rb, struct test_node, rb);
0145         WARN_ON_ONCE(node->key < prev_key);
0146         WARN_ON_ONCE(is_red(rb) &&
0147                  (!rb_parent(rb) || is_red(rb_parent(rb))));
0148         if (!count)
0149             blacks = black_path_count(rb);
0150         else
0151             WARN_ON_ONCE((!rb->rb_left || !rb->rb_right) &&
0152                      blacks != black_path_count(rb));
0153         prev_key = node->key;
0154         count++;
0155     }
0156 
0157     WARN_ON_ONCE(count != nr_nodes);
0158     WARN_ON_ONCE(count < (1 << black_path_count(rb_last(&root))) - 1);
0159 
0160     check_postorder(nr_nodes);
0161     check_postorder_foreach(nr_nodes);
0162 }
0163 
0164 static void check_augmented(int nr_nodes)
0165 {
0166     struct rb_node *rb;
0167 
0168     check(nr_nodes);
0169     for (rb = rb_first(&root); rb; rb = rb_next(rb)) {
0170         struct test_node *node = rb_entry(rb, struct test_node, rb);
0171         WARN_ON_ONCE(node->augmented != augment_recompute(node));
0172     }
0173 }
0174 
0175 static int __init rbtree_test_init(void)
0176 {
0177     int i, j;
0178     cycles_t time1, time2, time;
0179 
0180     printk(KERN_ALERT "rbtree testing");
0181 
0182     prandom_seed_state(&rnd, 3141592653589793238ULL);
0183     init();
0184 
0185     time1 = get_cycles();
0186 
0187     for (i = 0; i < PERF_LOOPS; i++) {
0188         for (j = 0; j < NODES; j++)
0189             insert(nodes + j, &root);
0190         for (j = 0; j < NODES; j++)
0191             erase(nodes + j, &root);
0192     }
0193 
0194     time2 = get_cycles();
0195     time = time2 - time1;
0196 
0197     time = div_u64(time, PERF_LOOPS);
0198     printk(" -> %llu cycles\n", (unsigned long long)time);
0199 
0200     for (i = 0; i < CHECK_LOOPS; i++) {
0201         init();
0202         for (j = 0; j < NODES; j++) {
0203             check(j);
0204             insert(nodes + j, &root);
0205         }
0206         for (j = 0; j < NODES; j++) {
0207             check(NODES - j);
0208             erase(nodes + j, &root);
0209         }
0210         check(0);
0211     }
0212 
0213     printk(KERN_ALERT "augmented rbtree testing");
0214 
0215     init();
0216 
0217     time1 = get_cycles();
0218 
0219     for (i = 0; i < PERF_LOOPS; i++) {
0220         for (j = 0; j < NODES; j++)
0221             insert_augmented(nodes + j, &root);
0222         for (j = 0; j < NODES; j++)
0223             erase_augmented(nodes + j, &root);
0224     }
0225 
0226     time2 = get_cycles();
0227     time = time2 - time1;
0228 
0229     time = div_u64(time, PERF_LOOPS);
0230     printk(" -> %llu cycles\n", (unsigned long long)time);
0231 
0232     for (i = 0; i < CHECK_LOOPS; i++) {
0233         init();
0234         for (j = 0; j < NODES; j++) {
0235             check_augmented(j);
0236             insert_augmented(nodes + j, &root);
0237         }
0238         for (j = 0; j < NODES; j++) {
0239             check_augmented(NODES - j);
0240             erase_augmented(nodes + j, &root);
0241         }
0242         check_augmented(0);
0243     }
0244 
0245     return -EAGAIN; /* Fail will directly unload the module */
0246 }
0247 
0248 static void __exit rbtree_test_exit(void)
0249 {
0250     printk(KERN_ALERT "test exit\n");
0251 }
0252 
/*
 * Register the entry/exit hooks and module metadata.
 * NOTE(review): rbtree_test_init() always returns -EAGAIN, so the load
 * fails and the exit hook presumably never runs. Verify if that changes.
 */
module_init(rbtree_test_init)
module_exit(rbtree_test_exit)

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Michel Lespinasse");
MODULE_DESCRIPTION("Red Black Tree test");