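/*
 * rbtree_test: performance and correctness tests for the kernel's
 * red-black trees (<linux/rbtree_augmented.h>).
 *
 * Illustrative usage, assuming the object is built as rbtree_test.ko:
 *
 *	insmod rbtree_test.ko nnodes=1000 perf_loops=100000
 *
 * Results go to the kernel log; the init function deliberately returns
 * an error so the module never stays loaded after the tests complete.
 */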
#include <linux/module.h>
#include <linux/moduleparam.h>
#include <linux/rbtree_augmented.h>
#include <linux/random.h>
#include <linux/slab.h>
#include <asm/timex.h>

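/*
 * Test parameters, settable when loading the module; mode 0444 also
 * exposes them read-only under /sys/module/<name>/parameters/.
 */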
#define __param(type, name, init, msg)		\
	static type name = init;		\
	module_param(name, type, 0444);		\
	MODULE_PARM_DESC(name, msg);

__param(int, nnodes, 100, "Number of nodes in the rb-tree");
__param(int, perf_loops, 1000, "Number of iterations modifying the rb-tree");
__param(int, check_loops, 100, "Number of iterations modifying and verifying the rb-tree");

struct test_node {
	u32 key;
	struct rb_node rb;

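	/* following fields used for testing augmented rbtree functionality */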
	u32 val;
	u32 augmented;
};

static struct rb_root_cached root = RB_ROOT_CACHED;
static struct test_node *nodes = NULL;

static struct rnd_state rnd;

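/*
 * The kernel rbtree leaves lookups to the caller: walk down from the root
 * to find the leaf slot for the new key, then rb_link_node() links the
 * node in place and rb_insert_color() rebalances.
 */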
static void insert(struct test_node *node, struct rb_root_cached *root)
{
	struct rb_node **new = &root->rb_root.rb_node, *parent = NULL;
	u32 key = node->key;

	while (*new) {
		parent = *new;
		if (key < rb_entry(parent, struct test_node, rb)->key)
			new = &parent->rb_left;
		else
			new = &parent->rb_right;
	}

	rb_link_node(&node->rb, parent, new);
	rb_insert_color(&node->rb, &root->rb_root);
}

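/*
 * Cached variant: also track whether the new node ends up leftmost, so
 * rb_insert_color_cached() can maintain root->rb_leftmost and keep
 * rb_first_cached() O(1).
 */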
static void insert_cached(struct test_node *node, struct rb_root_cached *root)
{
	struct rb_node **new = &root->rb_root.rb_node, *parent = NULL;
	u32 key = node->key;
	bool leftmost = true;

	while (*new) {
		parent = *new;
		if (key < rb_entry(parent, struct test_node, rb)->key)
			new = &parent->rb_left;
		else {
			new = &parent->rb_right;
			leftmost = false;
		}
	}

	rb_link_node(&node->rb, parent, new);
	rb_insert_color_cached(&node->rb, root, leftmost);
}

static inline void erase(struct test_node *node, struct rb_root_cached *root)
{
	rb_erase(&node->rb, &root->rb_root);
}

static inline void erase_cached(struct test_node *node, struct rb_root_cached *root)
{
	rb_erase_cached(&node->rb, root);
}

#define NODE_VAL(node) ((node)->val)

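/*
 * Generate the callbacks needed by the augmented rbtree code: each node's
 * ->augmented field caches the maximum NODE_VAL() over its subtree.
 */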
RB_DECLARE_CALLBACKS_MAX(static, augment_callbacks,
			 struct test_node, rb, u32, augmented, NODE_VAL)

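/*
 * While walking down to the insertion point, every visited ancestor gains
 * the new node in its subtree, so its cached maximum can only grow and is
 * updated on the way; the callbacks fix up whatever rebalancing moves.
 */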
static void insert_augmented(struct test_node *node,
			     struct rb_root_cached *root)
{
	struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL;
	u32 key = node->key;
	u32 val = node->val;
	struct test_node *parent;

	while (*new) {
		rb_parent = *new;
		parent = rb_entry(rb_parent, struct test_node, rb);
		if (parent->augmented < val)
			parent->augmented = val;
		if (key < parent->key)
			new = &parent->rb.rb_left;
		else
			new = &parent->rb.rb_right;
	}

	node->augmented = val;
	rb_link_node(&node->rb, rb_parent, new);
	rb_insert_augmented(&node->rb, &root->rb_root, &augment_callbacks);
}

static void insert_augmented_cached(struct test_node *node,
				    struct rb_root_cached *root)
{
	struct rb_node **new = &root->rb_root.rb_node, *rb_parent = NULL;
	u32 key = node->key;
	u32 val = node->val;
	struct test_node *parent;
	bool leftmost = true;

	while (*new) {
		rb_parent = *new;
		parent = rb_entry(rb_parent, struct test_node, rb);
		if (parent->augmented < val)
			parent->augmented = val;
		if (key < parent->key)
			new = &parent->rb.rb_left;
		else {
			new = &parent->rb.rb_right;
			leftmost = false;
		}
	}

	node->augmented = val;
	rb_link_node(&node->rb, rb_parent, new);
	rb_insert_augmented_cached(&node->rb, root,
				   leftmost, &augment_callbacks);
}

static void erase_augmented(struct test_node *node, struct rb_root_cached *root)
{
	rb_erase_augmented(&node->rb, &root->rb_root, &augment_callbacks);
}

static void erase_augmented_cached(struct test_node *node,
				   struct rb_root_cached *root)
{
	rb_erase_augmented_cached(&node->rb, root, &augment_callbacks);
}

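/*
 * Give every node a pseudo-random key and value; the generator is seeded
 * with a fixed constant in rbtree_test_init(), so runs are reproducible.
 */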
static void init(void)
{
	int i;
	for (i = 0; i < nnodes; i++) {
		nodes[i].key = prandom_u32_state(&rnd);
		nodes[i].val = prandom_u32_state(&rnd);
	}
}

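/* The color is stored in the low bit of __rb_parent_color; RB_RED is 0. */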
static bool is_red(struct rb_node *rb)
{
	return !(rb->__rb_parent_color & 1);
}

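/* Count the black nodes from @rb up to and including the root. */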
static int black_path_count(struct rb_node *rb)
{
	int count;
	for (count = 0; rb; rb = rb_parent(rb))
		count += !is_red(rb);
	return count;
}

static void check_postorder_foreach(int nr_nodes)
{
	struct test_node *cur, *n;
	int count = 0;
	rbtree_postorder_for_each_entry_safe(cur, n, &root.rb_root, rb)
		count++;

	WARN_ON_ONCE(count != nr_nodes);
}

static void check_postorder(int nr_nodes)
{
	struct rb_node *rb;
	int count = 0;
	for (rb = rb_first_postorder(&root.rb_root); rb; rb = rb_next_postorder(rb))
		count++;

	WARN_ON_ONCE(count != nr_nodes);
}

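/*
 * Check the red-black invariants: keys are nondecreasing in an in-order
 * walk, the root is black, no red node has a red parent, every path to a
 * leaf slot sees the same number of black nodes, and the node count is
 * at least 2^(black height) - 1.
 */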
static void check(int nr_nodes)
{
	struct rb_node *rb;
	int count = 0, blacks = 0;
	u32 prev_key = 0;

	for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) {
		struct test_node *node = rb_entry(rb, struct test_node, rb);
		WARN_ON_ONCE(node->key < prev_key);
		WARN_ON_ONCE(is_red(rb) &&
			     (!rb_parent(rb) || is_red(rb_parent(rb))));
		if (!count)
			blacks = black_path_count(rb);
		else
			WARN_ON_ONCE((!rb->rb_left || !rb->rb_right) &&
				     blacks != black_path_count(rb));
		prev_key = node->key;
		count++;
	}

	WARN_ON_ONCE(count != nr_nodes);
	WARN_ON_ONCE(count < (1 << black_path_count(rb_last(&root.rb_root))) - 1);

	check_postorder(nr_nodes);
	check_postorder_foreach(nr_nodes);
}

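/*
 * On top of the plain rbtree checks, verify that each node's ->augmented
 * value equals the maximum ->val found anywhere in its subtree.
 */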
static void check_augmented(int nr_nodes)
{
	struct rb_node *rb;

	check(nr_nodes);
	for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) {
		struct test_node *node = rb_entry(rb, struct test_node, rb);
		u32 subtree, max = node->val;
		if (node->rb.rb_left) {
			subtree = rb_entry(node->rb.rb_left, struct test_node,
					   rb)->augmented;
			if (max < subtree)
				max = subtree;
		}
		if (node->rb.rb_right) {
			subtree = rb_entry(node->rb.rb_right, struct test_node,
					   rb)->augmented;
			if (max < subtree)
				max = subtree;
		}
		WARN_ON_ONCE(node->augmented != max);
	}
}

static int __init rbtree_test_init(void)
{
	int i, j;
	cycles_t time1, time2, time;
	struct rb_node *node;

	nodes = kmalloc_array(nnodes, sizeof(*nodes), GFP_KERNEL);
	if (!nodes)
		return -ENOMEM;

	printk(KERN_ALERT "rbtree testing");

	prandom_seed_state(&rnd, 3141592653589793238ULL);
	init();

	time1 = get_cycles();

	for (i = 0; i < perf_loops; i++) {
		for (j = 0; j < nnodes; j++)
			insert(nodes + j, &root);
		for (j = 0; j < nnodes; j++)
			erase(nodes + j, &root);
	}

	time2 = get_cycles();
	time = time2 - time1;

	time = div_u64(time, perf_loops);
	printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n",
	       (unsigned long long)time);

	time1 = get_cycles();

	for (i = 0; i < perf_loops; i++) {
		for (j = 0; j < nnodes; j++)
			insert_cached(nodes + j, &root);
		for (j = 0; j < nnodes; j++)
			erase_cached(nodes + j, &root);
	}

	time2 = get_cycles();
	time = time2 - time1;

	time = div_u64(time, perf_loops);
	printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n",
	       (unsigned long long)time);

	for (i = 0; i < nnodes; i++)
		insert(nodes + i, &root);

	time1 = get_cycles();

	for (i = 0; i < perf_loops; i++) {
		for (node = rb_first(&root.rb_root); node; node = rb_next(node))
			;
	}

	time2 = get_cycles();
	time = time2 - time1;

	time = div_u64(time, perf_loops);
	printk(" -> test 3 (latency of inorder traversal): %llu cycles\n",
	       (unsigned long long)time);

	time1 = get_cycles();

	for (i = 0; i < perf_loops; i++)
		node = rb_first(&root.rb_root);

	time2 = get_cycles();
	time = time2 - time1;

	time = div_u64(time, perf_loops);
	printk(" -> test 4 (latency to fetch first node)\n");
	printk("        non-cached: %llu cycles\n", (unsigned long long)time);

	time1 = get_cycles();

	for (i = 0; i < perf_loops; i++)
		node = rb_first_cached(&root);

	time2 = get_cycles();
	time = time2 - time1;

	time = div_u64(time, perf_loops);
	printk("        cached: %llu cycles\n", (unsigned long long)time);

	for (i = 0; i < nnodes; i++)
		erase(nodes + i, &root);

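	/* correctness pass: validate the tree before every insert and erase */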
	for (i = 0; i < check_loops; i++) {
		init();
		for (j = 0; j < nnodes; j++) {
			check(j);
			insert(nodes + j, &root);
		}
		for (j = 0; j < nnodes; j++) {
			check(nnodes - j);
			erase(nodes + j, &root);
		}
		check(0);
	}

	printk(KERN_ALERT "augmented rbtree testing");

	init();

	time1 = get_cycles();

	for (i = 0; i < perf_loops; i++) {
		for (j = 0; j < nnodes; j++)
			insert_augmented(nodes + j, &root);
		for (j = 0; j < nnodes; j++)
			erase_augmented(nodes + j, &root);
	}

	time2 = get_cycles();
	time = time2 - time1;

	time = div_u64(time, perf_loops);
	printk(" -> test 1 (latency of nnodes insert+delete): %llu cycles\n", (unsigned long long)time);

	time1 = get_cycles();

	for (i = 0; i < perf_loops; i++) {
		for (j = 0; j < nnodes; j++)
			insert_augmented_cached(nodes + j, &root);
		for (j = 0; j < nnodes; j++)
			erase_augmented_cached(nodes + j, &root);
	}

	time2 = get_cycles();
	time = time2 - time1;

	time = div_u64(time, perf_loops);
	printk(" -> test 2 (latency of nnodes cached insert+delete): %llu cycles\n", (unsigned long long)time);

	for (i = 0; i < check_loops; i++) {
		init();
		for (j = 0; j < nnodes; j++) {
			check_augmented(j);
			insert_augmented(nodes + j, &root);
		}
		for (j = 0; j < nnodes; j++) {
			check_augmented(nnodes - j);
			erase_augmented(nodes + j, &root);
		}
		check_augmented(0);
	}

	kfree(nodes);

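	/* Fail the load on purpose: the tests have already run. */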
	return -EAGAIN;
}

static void __exit rbtree_test_exit(void)
{
	printk(KERN_ALERT "test exit\n");
}

module_init(rbtree_test_init)
module_exit(rbtree_test_exit)

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Michel Lespinasse");
MODULE_DESCRIPTION("Red Black Tree test");