Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: BSD-3-Clause OR GPL-2.0
0002 /* Copyright (c) 2018 Mellanox Technologies. All rights reserved */
0003 
0004 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
0005 
0006 #include <linux/kernel.h>
0007 #include <linux/module.h>
0008 #include <linux/slab.h>
0009 #include <linux/random.h>
0010 #include <linux/objagg.h>
0011 
/*
 * Object key handed to objagg by these tests. Both objagg_ops tables in
 * this file set .obj_size to sizeof(struct tokey).
 */
struct tokey {
    unsigned int id;
};

/* Number of key slots tracked by struct world; key_id_index() accepts 0..NUM_KEYS-1. */
#define NUM_KEYS 32
0017 
0018 static int key_id_index(unsigned int key_id)
0019 {
0020     if (key_id >= NUM_KEYS) {
0021         WARN_ON(1);
0022         return 0;
0023     }
0024     return key_id;
0025 }
0026 
/* Size of the random payload each root object carries. */
#define BUF_LEN 128

/*
 * Test bookkeeping. A pointer to it is passed as the objagg priv and is
 * read back by the root/delta callbacks below.
 */
struct world {
    /* Live roots: incremented by root_create(), decremented by root_destroy(). */
    unsigned int root_count;
    /* Live deltas: incremented by delta_create(), decremented by delta_destroy(). */
    unsigned int delta_count;
    /* Copied into the buf of the next root created by root_create(). */
    char next_root_buf[BUF_LEN];
    /* Object returned by the first world_obj_get() of each key. */
    struct objagg_obj *objagg_objs[NUM_KEYS];
    /* References currently held per key through world_obj_get()/world_obj_put(). */
    unsigned int key_refs[NUM_KEYS];
};

/* Root private data: a copy of the key plus a snapshot of next_root_buf. */
struct root {
    struct tokey key;
    char buf[BUF_LEN];
};

/* Delta private data: key id distance from the parent root's key id. */
struct delta {
    unsigned int key_id_diff;
};
0045 
0046 static struct objagg_obj *world_obj_get(struct world *world,
0047                     struct objagg *objagg,
0048                     unsigned int key_id)
0049 {
0050     struct objagg_obj *objagg_obj;
0051     struct tokey key;
0052     int err;
0053 
0054     key.id = key_id;
0055     objagg_obj = objagg_obj_get(objagg, &key);
0056     if (IS_ERR(objagg_obj)) {
0057         pr_err("Key %u: Failed to get object.\n", key_id);
0058         return objagg_obj;
0059     }
0060     if (!world->key_refs[key_id_index(key_id)]) {
0061         world->objagg_objs[key_id_index(key_id)] = objagg_obj;
0062     } else if (world->objagg_objs[key_id_index(key_id)] != objagg_obj) {
0063         pr_err("Key %u: God another object for the same key.\n",
0064                key_id);
0065         err = -EINVAL;
0066         goto err_key_id_check;
0067     }
0068     world->key_refs[key_id_index(key_id)]++;
0069     return objagg_obj;
0070 
0071 err_key_id_check:
0072     objagg_obj_put(objagg, objagg_obj);
0073     return ERR_PTR(err);
0074 }
0075 
0076 static void world_obj_put(struct world *world, struct objagg *objagg,
0077               unsigned int key_id)
0078 {
0079     struct objagg_obj *objagg_obj;
0080 
0081     if (!world->key_refs[key_id_index(key_id)])
0082         return;
0083     objagg_obj = world->objagg_objs[key_id_index(key_id)];
0084     objagg_obj_put(objagg, objagg_obj);
0085     world->key_refs[key_id_index(key_id)]--;
0086 }
0087 
0088 #define MAX_KEY_ID_DIFF 5
0089 
0090 static bool delta_check(void *priv, const void *parent_obj, const void *obj)
0091 {
0092     const struct tokey *parent_key = parent_obj;
0093     const struct tokey *key = obj;
0094     int diff = key->id - parent_key->id;
0095 
0096     return diff >= 0 && diff <= MAX_KEY_ID_DIFF;
0097 }
0098 
0099 static void *delta_create(void *priv, void *parent_obj, void *obj)
0100 {
0101     struct tokey *parent_key = parent_obj;
0102     struct world *world = priv;
0103     struct tokey *key = obj;
0104     int diff = key->id - parent_key->id;
0105     struct delta *delta;
0106 
0107     if (!delta_check(priv, parent_obj, obj))
0108         return ERR_PTR(-EINVAL);
0109 
0110     delta = kzalloc(sizeof(*delta), GFP_KERNEL);
0111     if (!delta)
0112         return ERR_PTR(-ENOMEM);
0113     delta->key_id_diff = diff;
0114     world->delta_count++;
0115     return delta;
0116 }
0117 
0118 static void delta_destroy(void *priv, void *delta_priv)
0119 {
0120     struct delta *delta = delta_priv;
0121     struct world *world = priv;
0122 
0123     world->delta_count--;
0124     kfree(delta);
0125 }
0126 
0127 static void *root_create(void *priv, void *obj, unsigned int id)
0128 {
0129     struct world *world = priv;
0130     struct tokey *key = obj;
0131     struct root *root;
0132 
0133     root = kzalloc(sizeof(*root), GFP_KERNEL);
0134     if (!root)
0135         return ERR_PTR(-ENOMEM);
0136     memcpy(&root->key, key, sizeof(root->key));
0137     memcpy(root->buf, world->next_root_buf, sizeof(root->buf));
0138     world->root_count++;
0139     return root;
0140 }
0141 
0142 static void root_destroy(void *priv, void *root_priv)
0143 {
0144     struct root *root = root_priv;
0145     struct world *world = priv;
0146 
0147     world->root_count--;
0148     kfree(root);
0149 }
0150 
0151 static int test_nodelta_obj_get(struct world *world, struct objagg *objagg,
0152                 unsigned int key_id, bool should_create_root)
0153 {
0154     unsigned int orig_root_count = world->root_count;
0155     struct objagg_obj *objagg_obj;
0156     const struct root *root;
0157     int err;
0158 
0159     if (should_create_root)
0160         prandom_bytes(world->next_root_buf,
0161                   sizeof(world->next_root_buf));
0162 
0163     objagg_obj = world_obj_get(world, objagg, key_id);
0164     if (IS_ERR(objagg_obj)) {
0165         pr_err("Key %u: Failed to get object.\n", key_id);
0166         return PTR_ERR(objagg_obj);
0167     }
0168     if (should_create_root) {
0169         if (world->root_count != orig_root_count + 1) {
0170             pr_err("Key %u: Root was not created\n", key_id);
0171             err = -EINVAL;
0172             goto err_check_root_count;
0173         }
0174     } else {
0175         if (world->root_count != orig_root_count) {
0176             pr_err("Key %u: Root was incorrectly created\n",
0177                    key_id);
0178             err = -EINVAL;
0179             goto err_check_root_count;
0180         }
0181     }
0182     root = objagg_obj_root_priv(objagg_obj);
0183     if (root->key.id != key_id) {
0184         pr_err("Key %u: Root has unexpected key id\n", key_id);
0185         err = -EINVAL;
0186         goto err_check_key_id;
0187     }
0188     if (should_create_root &&
0189         memcmp(world->next_root_buf, root->buf, sizeof(root->buf))) {
0190         pr_err("Key %u: Buffer does not match the expected content\n",
0191                key_id);
0192         err = -EINVAL;
0193         goto err_check_buf;
0194     }
0195     return 0;
0196 
0197 err_check_buf:
0198 err_check_key_id:
0199 err_check_root_count:
0200     objagg_obj_put(objagg, objagg_obj);
0201     return err;
0202 }
0203 
0204 static int test_nodelta_obj_put(struct world *world, struct objagg *objagg,
0205                 unsigned int key_id, bool should_destroy_root)
0206 {
0207     unsigned int orig_root_count = world->root_count;
0208 
0209     world_obj_put(world, objagg, key_id);
0210 
0211     if (should_destroy_root) {
0212         if (world->root_count != orig_root_count - 1) {
0213             pr_err("Key %u: Root was not destroyed\n", key_id);
0214             return -EINVAL;
0215         }
0216     } else {
0217         if (world->root_count != orig_root_count) {
0218             pr_err("Key %u: Root was incorrectly destroyed\n",
0219                    key_id);
0220             return -EINVAL;
0221         }
0222     }
0223     return 0;
0224 }
0225 
0226 static int check_stats_zero(struct objagg *objagg)
0227 {
0228     const struct objagg_stats *stats;
0229     int err = 0;
0230 
0231     stats = objagg_stats_get(objagg);
0232     if (IS_ERR(stats))
0233         return PTR_ERR(stats);
0234 
0235     if (stats->stats_info_count != 0) {
0236         pr_err("Stats: Object count is not zero while it should be\n");
0237         err = -EINVAL;
0238     }
0239 
0240     objagg_stats_put(stats);
0241     return err;
0242 }
0243 
0244 static int check_stats_nodelta(struct objagg *objagg)
0245 {
0246     const struct objagg_stats *stats;
0247     int i;
0248     int err;
0249 
0250     stats = objagg_stats_get(objagg);
0251     if (IS_ERR(stats))
0252         return PTR_ERR(stats);
0253 
0254     if (stats->stats_info_count != NUM_KEYS) {
0255         pr_err("Stats: Unexpected object count (%u expected, %u returned)\n",
0256                NUM_KEYS, stats->stats_info_count);
0257         err = -EINVAL;
0258         goto stats_put;
0259     }
0260 
0261     for (i = 0; i < stats->stats_info_count; i++) {
0262         if (stats->stats_info[i].stats.user_count != 2) {
0263             pr_err("Stats: incorrect user count\n");
0264             err = -EINVAL;
0265             goto stats_put;
0266         }
0267         if (stats->stats_info[i].stats.delta_user_count != 2) {
0268             pr_err("Stats: incorrect delta user count\n");
0269             err = -EINVAL;
0270             goto stats_put;
0271         }
0272     }
0273     err = 0;
0274 
0275 stats_put:
0276     objagg_stats_put(stats);
0277     return err;
0278 }
0279 
0280 static bool delta_check_dummy(void *priv, const void *parent_obj,
0281                   const void *obj)
0282 {
0283     return false;
0284 }
0285 
0286 static void *delta_create_dummy(void *priv, void *parent_obj, void *obj)
0287 {
0288     return ERR_PTR(-EOPNOTSUPP);
0289 }
0290 
0291 static void delta_destroy_dummy(void *priv, void *delta_priv)
0292 {
0293 }
0294 
0295 static const struct objagg_ops nodelta_ops = {
0296     .obj_size = sizeof(struct tokey),
0297     .delta_check = delta_check_dummy,
0298     .delta_create = delta_create_dummy,
0299     .delta_destroy = delta_destroy_dummy,
0300     .root_create = root_create,
0301     .root_destroy = root_destroy,
0302 };
0303 
/*
 * test_nodelta - exercise objagg with delta creation disabled.
 *
 * Steps:
 * 1. Get every key twice: the first round must create a root per key, the
 *    second round must create none.
 * 2. Check the stats.
 * 3. Put every key twice: the first round must keep the roots, the second
 *    round must destroy them.
 * 4. Check that the stats are empty again.
 *
 * Return: 0 on success, negative errno on the first failure.
 */
static int test_nodelta(void)
{
    struct world world = {};
    struct objagg *objagg;
    int i;
    int err;

    objagg = objagg_create(&nodelta_ops, NULL, &world);
    if (IS_ERR(objagg))
        return PTR_ERR(objagg);

    err = check_stats_zero(objagg);
    if (err)
        goto err_stats_first_zero;

    /* First round of gets, the root objects should be created */
    for (i = 0; i < NUM_KEYS; i++) {
        err = test_nodelta_obj_get(&world, objagg, i, true);
        if (err)
            goto err_obj_first_get;
    }

    /* Do the second round of gets, all roots are already created,
     * make sure that no new root is created
     */
    for (i = 0; i < NUM_KEYS; i++) {
        err = test_nodelta_obj_get(&world, objagg, i, false);
        if (err)
            goto err_obj_second_get;
    }

    err = check_stats_nodelta(objagg);
    if (err)
        goto err_stats_nodelta;

    /* First round of puts (highest key first): roots must survive */
    for (i = NUM_KEYS - 1; i >= 0; i--) {
        err = test_nodelta_obj_put(&world, objagg, i, false);
        if (err)
            goto err_obj_first_put;
    }
    /* Second round of puts: every root must be destroyed */
    for (i = NUM_KEYS - 1; i >= 0; i--) {
        err = test_nodelta_obj_put(&world, objagg, i, true);
        if (err)
            goto err_obj_second_put;
    }

    err = check_stats_zero(objagg);
    if (err)
        goto err_stats_second_zero;

    objagg_destroy(objagg);
    return 0;

    /*
     * Cleanup falls through from label to label. At each label, i is the
     * position of the failing step (NUM_KEYS after a completed get loop).
     * world_obj_put() ignores keys with no recorded references, so extra
     * puts are harmless.
     */
err_stats_nodelta:
err_obj_first_put:
err_obj_second_get:
    /*
     * Keys below i still hold both references: drop one here. The loop
     * below then drops the remaining one from every key.
     */
    for (i--; i >= 0; i--)
        world_obj_put(&world, objagg, i);

    i = NUM_KEYS;
err_obj_first_get:
err_obj_second_put:
    /* Keys below i hold at most one reference each: drop it. */
    for (i--; i >= 0; i--)
        world_obj_put(&world, objagg, i);
err_stats_first_zero:
err_stats_second_zero:
    objagg_destroy(objagg);
    return err;
}
0373 
0374 static const struct objagg_ops delta_ops = {
0375     .obj_size = sizeof(struct tokey),
0376     .delta_check = delta_check,
0377     .delta_create = delta_create,
0378     .delta_destroy = delta_destroy,
0379     .root_create = root_create,
0380     .root_destroy = root_destroy,
0381 };
0382 
/* Operation a delta-test step performs on its key. */
enum action {
    ACTION_GET,
    ACTION_PUT,
};

/* Expected change of world->delta_count caused by a step. */
enum expect_delta {
    EXPECT_DELTA_SAME,
    EXPECT_DELTA_INC,
    EXPECT_DELTA_DEC,
};

/* Expected change of world->root_count caused by a step. */
enum expect_root {
    EXPECT_ROOT_SAME,
    EXPECT_ROOT_INC,
    EXPECT_ROOT_DEC,
};

/* Expected stats of one object, compared with a struct objagg_obj_stats_info. */
struct expect_stats_info {
    struct objagg_obj_stats stats;
    bool is_root;
    unsigned int key_id;
};

/* Expected stats of a whole objagg: the first info_count entries of info[]. */
struct expect_stats {
    unsigned int info_count;
    struct expect_stats_info info[NUM_KEYS];
};

/* One step of test_delta(): an action on key_id plus its expected outcome. */
struct action_item {
    unsigned int key_id;
    enum action action;
    enum expect_delta expect_delta;
    enum expect_root expect_root;
    struct expect_stats expect_stats;
};

/* Initializer for struct expect_stats with @count ROOT()/DELTA() entries. */
#define EXPECT_STATS(count, ...)        \
{                       \
    .info_count = count,            \
    .info = { __VA_ARGS__ }         \
}

/* Expected root entry; user and delta user counts are given separately. */
#define ROOT(key_id, user_count, delta_user_count)  \
    {{user_count, delta_user_count}, true, key_id}

/* Expected delta entry; its delta user count equals its user count. */
#define DELTA(key_id, user_count)           \
    {{user_count, user_count}, false, key_id}
0430 
/*
 * Script for test_delta(); steps run in order. The trailing comment after
 * each step sketches the resulting state:
 * - "r:" lists the users of root objects;
 * - "d:" lists the users of delta objects as key^parent-root-key.
 */
static const struct action_item action_items[] = {
    {
        1, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
        EXPECT_STATS(1, ROOT(1, 1, 1)),
    },  /* r: 1         d: */
    {
        7, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
        EXPECT_STATS(2, ROOT(1, 1, 1), ROOT(7, 1, 1)),
    },  /* r: 1, 7      d: */
    {
        3, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
        EXPECT_STATS(3, ROOT(1, 1, 2), ROOT(7, 1, 1),
                DELTA(3, 1)),
    },  /* r: 1, 7      d: 3^1 */
    {
        5, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
        EXPECT_STATS(4, ROOT(1, 1, 3), ROOT(7, 1, 1),
                DELTA(3, 1), DELTA(5, 1)),
    },  /* r: 1, 7      d: 3^1, 5^1 */
    {
        3, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
        EXPECT_STATS(4, ROOT(1, 1, 4), ROOT(7, 1, 1),
                DELTA(3, 2), DELTA(5, 1)),
    },  /* r: 1, 7      d: 3^1, 3^1, 5^1 */
    {
        1, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
        EXPECT_STATS(4, ROOT(1, 2, 5), ROOT(7, 1, 1),
                DELTA(3, 2), DELTA(5, 1)),
    },  /* r: 1, 1, 7       d: 3^1, 3^1, 5^1 */
    {
        30, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
        EXPECT_STATS(5, ROOT(1, 2, 5), ROOT(7, 1, 1), ROOT(30, 1, 1),
                DELTA(3, 2), DELTA(5, 1)),
    },  /* r: 1, 1, 7, 30   d: 3^1, 3^1, 5^1 */
    {
        8, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
        EXPECT_STATS(6, ROOT(1, 2, 5), ROOT(7, 1, 2), ROOT(30, 1, 1),
                DELTA(3, 2), DELTA(5, 1), DELTA(8, 1)),
    },  /* r: 1, 1, 7, 30   d: 3^1, 3^1, 5^1, 8^7 */
    {
        8, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
        EXPECT_STATS(6, ROOT(1, 2, 5), ROOT(7, 1, 3), ROOT(30, 1, 1),
                DELTA(3, 2), DELTA(8, 2), DELTA(5, 1)),
    },  /* r: 1, 1, 7, 30   d: 3^1, 3^1, 5^1, 8^7, 8^7 */
    {
        3, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
        EXPECT_STATS(6, ROOT(1, 2, 4), ROOT(7, 1, 3), ROOT(30, 1, 1),
                DELTA(8, 2), DELTA(3, 1), DELTA(5, 1)),
    },  /* r: 1, 1, 7, 30   d: 3^1, 5^1, 8^7, 8^7 */
    {
        3, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
        EXPECT_STATS(5, ROOT(1, 2, 3), ROOT(7, 1, 3), ROOT(30, 1, 1),
                DELTA(8, 2), DELTA(5, 1)),
    },  /* r: 1, 1, 7, 30   d: 5^1, 8^7, 8^7 */
    {
        1, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
        EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(1, 1, 2), ROOT(30, 1, 1),
                DELTA(8, 2), DELTA(5, 1)),
    },  /* r: 1, 7, 30      d: 5^1, 8^7, 8^7 */
    {
        1, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
        EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(30, 1, 1), ROOT(1, 0, 1),
                DELTA(8, 2), DELTA(5, 1)),
    },  /* r: 7, 30     d: 5^1, 8^7, 8^7 */
    {
        5, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_DEC,
        EXPECT_STATS(3, ROOT(7, 1, 3), ROOT(30, 1, 1),
                DELTA(8, 2)),
    },  /* r: 7, 30     d: 8^7, 8^7 */
    {
        5, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_INC,
        EXPECT_STATS(4, ROOT(7, 1, 3), ROOT(30, 1, 1), ROOT(5, 1, 1),
                DELTA(8, 2)),
    },  /* r: 7, 30, 5      d: 8^7, 8^7 */
    {
        6, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
        EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(5, 1, 2), ROOT(30, 1, 1),
                DELTA(8, 2), DELTA(6, 1)),
    },  /* r: 7, 30, 5      d: 8^7, 8^7, 6^5 */
    {
        8, ACTION_GET, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
        EXPECT_STATS(5, ROOT(7, 1, 4), ROOT(5, 1, 2), ROOT(30, 1, 1),
                DELTA(8, 3), DELTA(6, 1)),
    },  /* r: 7, 30, 5      d: 8^7, 8^7, 8^7, 6^5 */
    {
        8, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
        EXPECT_STATS(5, ROOT(7, 1, 3), ROOT(5, 1, 2), ROOT(30, 1, 1),
                DELTA(8, 2), DELTA(6, 1)),
    },  /* r: 7, 30, 5      d: 8^7, 8^7, 6^5 */
    {
        8, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
        EXPECT_STATS(5, ROOT(7, 1, 2), ROOT(5, 1, 2), ROOT(30, 1, 1),
                DELTA(8, 1), DELTA(6, 1)),
    },  /* r: 7, 30, 5      d: 8^7, 6^5 */
    {
        8, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
        EXPECT_STATS(4, ROOT(5, 1, 2), ROOT(7, 1, 1), ROOT(30, 1, 1),
                DELTA(6, 1)),
    },  /* r: 7, 30, 5      d: 6^5 */
    {
        8, ACTION_GET, EXPECT_DELTA_INC, EXPECT_ROOT_SAME,
        EXPECT_STATS(5, ROOT(5, 1, 3), ROOT(7, 1, 1), ROOT(30, 1, 1),
                DELTA(6, 1), DELTA(8, 1)),
    },  /* r: 7, 30, 5      d: 6^5, 8^5 */
    {
        7, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_DEC,
        EXPECT_STATS(4, ROOT(5, 1, 3), ROOT(30, 1, 1),
                DELTA(6, 1), DELTA(8, 1)),
    },  /* r: 30, 5     d: 6^5, 8^5 */
    {
        30, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_DEC,
        EXPECT_STATS(3, ROOT(5, 1, 3),
                DELTA(6, 1), DELTA(8, 1)),
    },  /* r: 5         d: 6^5, 8^5 */
    {
        5, ACTION_PUT, EXPECT_DELTA_SAME, EXPECT_ROOT_SAME,
        EXPECT_STATS(3, ROOT(5, 0, 2),
                DELTA(6, 1), DELTA(8, 1)),
    },  /* r:           d: 6^5, 8^5 */
    {
        6, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_SAME,
        EXPECT_STATS(2, ROOT(5, 0, 1),
                DELTA(8, 1)),
    },  /* r:           d: 8^5 */
    {
        8, ACTION_PUT, EXPECT_DELTA_DEC, EXPECT_ROOT_DEC,
        EXPECT_STATS(0, ),
    },  /* r:           d: */
};
0560 
0561 static int check_expect(struct world *world,
0562             const struct action_item *action_item,
0563             unsigned int orig_delta_count,
0564             unsigned int orig_root_count)
0565 {
0566     unsigned int key_id = action_item->key_id;
0567 
0568     switch (action_item->expect_delta) {
0569     case EXPECT_DELTA_SAME:
0570         if (orig_delta_count != world->delta_count) {
0571             pr_err("Key %u: Delta count changed while expected to remain the same.\n",
0572                    key_id);
0573             return -EINVAL;
0574         }
0575         break;
0576     case EXPECT_DELTA_INC:
0577         if (WARN_ON(action_item->action == ACTION_PUT))
0578             return -EINVAL;
0579         if (orig_delta_count + 1 != world->delta_count) {
0580             pr_err("Key %u: Delta count was not incremented.\n",
0581                    key_id);
0582             return -EINVAL;
0583         }
0584         break;
0585     case EXPECT_DELTA_DEC:
0586         if (WARN_ON(action_item->action == ACTION_GET))
0587             return -EINVAL;
0588         if (orig_delta_count - 1 != world->delta_count) {
0589             pr_err("Key %u: Delta count was not decremented.\n",
0590                    key_id);
0591             return -EINVAL;
0592         }
0593         break;
0594     }
0595 
0596     switch (action_item->expect_root) {
0597     case EXPECT_ROOT_SAME:
0598         if (orig_root_count != world->root_count) {
0599             pr_err("Key %u: Root count changed while expected to remain the same.\n",
0600                    key_id);
0601             return -EINVAL;
0602         }
0603         break;
0604     case EXPECT_ROOT_INC:
0605         if (WARN_ON(action_item->action == ACTION_PUT))
0606             return -EINVAL;
0607         if (orig_root_count + 1 != world->root_count) {
0608             pr_err("Key %u: Root count was not incremented.\n",
0609                    key_id);
0610             return -EINVAL;
0611         }
0612         break;
0613     case EXPECT_ROOT_DEC:
0614         if (WARN_ON(action_item->action == ACTION_GET))
0615             return -EINVAL;
0616         if (orig_root_count - 1 != world->root_count) {
0617             pr_err("Key %u: Root count was not decremented.\n",
0618                    key_id);
0619             return -EINVAL;
0620         }
0621     }
0622 
0623     return 0;
0624 }
0625 
0626 static unsigned int obj_to_key_id(struct objagg_obj *objagg_obj)
0627 {
0628     const struct tokey *root_key;
0629     const struct delta *delta;
0630     unsigned int key_id;
0631 
0632     root_key = objagg_obj_root_priv(objagg_obj);
0633     key_id = root_key->id;
0634     delta = objagg_obj_delta_priv(objagg_obj);
0635     if (delta)
0636         key_id += delta->key_id_diff;
0637     return key_id;
0638 }
0639 
0640 static int
0641 check_expect_stats_nums(const struct objagg_obj_stats_info *stats_info,
0642             const struct expect_stats_info *expect_stats_info,
0643             const char **errmsg)
0644 {
0645     if (stats_info->is_root != expect_stats_info->is_root) {
0646         if (errmsg)
0647             *errmsg = "Incorrect root/delta indication";
0648         return -EINVAL;
0649     }
0650     if (stats_info->stats.user_count !=
0651         expect_stats_info->stats.user_count) {
0652         if (errmsg)
0653             *errmsg = "Incorrect user count";
0654         return -EINVAL;
0655     }
0656     if (stats_info->stats.delta_user_count !=
0657         expect_stats_info->stats.delta_user_count) {
0658         if (errmsg)
0659             *errmsg = "Incorrect delta user count";
0660         return -EINVAL;
0661     }
0662     return 0;
0663 }
0664 
0665 static int
0666 check_expect_stats_key_id(const struct objagg_obj_stats_info *stats_info,
0667               const struct expect_stats_info *expect_stats_info,
0668               const char **errmsg)
0669 {
0670     if (obj_to_key_id(stats_info->objagg_obj) !=
0671         expect_stats_info->key_id) {
0672         if (errmsg)
0673             *errmsg = "incorrect key id";
0674         return -EINVAL;
0675     }
0676     return 0;
0677 }
0678 
0679 static int check_expect_stats_neigh(const struct objagg_stats *stats,
0680                     const struct expect_stats *expect_stats,
0681                     int pos)
0682 {
0683     int i;
0684     int err;
0685 
0686     for (i = pos - 1; i >= 0; i--) {
0687         err = check_expect_stats_nums(&stats->stats_info[i],
0688                           &expect_stats->info[pos], NULL);
0689         if (err)
0690             break;
0691         err = check_expect_stats_key_id(&stats->stats_info[i],
0692                         &expect_stats->info[pos], NULL);
0693         if (!err)
0694             return 0;
0695     }
0696     for (i = pos + 1; i < stats->stats_info_count; i++) {
0697         err = check_expect_stats_nums(&stats->stats_info[i],
0698                           &expect_stats->info[pos], NULL);
0699         if (err)
0700             break;
0701         err = check_expect_stats_key_id(&stats->stats_info[i],
0702                         &expect_stats->info[pos], NULL);
0703         if (!err)
0704             return 0;
0705     }
0706     return -EINVAL;
0707 }
0708 
0709 static int __check_expect_stats(const struct objagg_stats *stats,
0710                 const struct expect_stats *expect_stats,
0711                 const char **errmsg)
0712 {
0713     int i;
0714     int err;
0715 
0716     if (stats->stats_info_count != expect_stats->info_count) {
0717         *errmsg = "Unexpected object count";
0718         return -EINVAL;
0719     }
0720 
0721     for (i = 0; i < stats->stats_info_count; i++) {
0722         err = check_expect_stats_nums(&stats->stats_info[i],
0723                           &expect_stats->info[i], errmsg);
0724         if (err)
0725             return err;
0726         err = check_expect_stats_key_id(&stats->stats_info[i],
0727                         &expect_stats->info[i], errmsg);
0728         if (err) {
0729             /* It is possible that one of the neighbor stats with
0730              * same numbers have the correct key id, so check it
0731              */
0732             err = check_expect_stats_neigh(stats, expect_stats, i);
0733             if (err)
0734                 return err;
0735         }
0736     }
0737     return 0;
0738 }
0739 
/*
 * check_expect_stats - compare the current objagg stats with an expectation.
 * @objagg: objagg instance to inspect
 * @expect_stats: expected stats
 * @errmsg: receives a description of any failure
 *
 * Return: 0 on match, negative errno otherwise.
 */
static int check_expect_stats(struct objagg *objagg,
                  const struct expect_stats *expect_stats,
                  const char **errmsg)
{
    const struct objagg_stats *snapshot = objagg_stats_get(objagg);
    int ret;

    if (IS_ERR(snapshot)) {
        *errmsg = "objagg_stats_get() failed.";
        return PTR_ERR(snapshot);
    }

    ret = __check_expect_stats(snapshot, expect_stats, errmsg);
    objagg_stats_put(snapshot);
    return ret;
}
0756 
0757 static int test_delta_action_item(struct world *world,
0758                   struct objagg *objagg,
0759                   const struct action_item *action_item,
0760                   bool inverse)
0761 {
0762     unsigned int orig_delta_count = world->delta_count;
0763     unsigned int orig_root_count = world->root_count;
0764     unsigned int key_id = action_item->key_id;
0765     enum action action = action_item->action;
0766     struct objagg_obj *objagg_obj;
0767     const char *errmsg;
0768     int err;
0769 
0770     if (inverse)
0771         action = action == ACTION_GET ? ACTION_PUT : ACTION_GET;
0772 
0773     switch (action) {
0774     case ACTION_GET:
0775         objagg_obj = world_obj_get(world, objagg, key_id);
0776         if (IS_ERR(objagg_obj))
0777             return PTR_ERR(objagg_obj);
0778         break;
0779     case ACTION_PUT:
0780         world_obj_put(world, objagg, key_id);
0781         break;
0782     }
0783 
0784     if (inverse)
0785         return 0;
0786     err = check_expect(world, action_item,
0787                orig_delta_count, orig_root_count);
0788     if (err)
0789         goto errout;
0790 
0791     err = check_expect_stats(objagg, &action_item->expect_stats, &errmsg);
0792     if (err) {
0793         pr_err("Key %u: Stats: %s\n", action_item->key_id, errmsg);
0794         goto errout;
0795     }
0796 
0797     return 0;
0798 
0799 errout:
0800     /* This can only happen when action is not inversed.
0801      * So in case of an error, cleanup by doing inverse action.
0802      */
0803     test_delta_action_item(world, objagg, action_item, true);
0804     return err;
0805 }
0806 
0807 static int test_delta(void)
0808 {
0809     struct world world = {};
0810     struct objagg *objagg;
0811     int i;
0812     int err;
0813 
0814     objagg = objagg_create(&delta_ops, NULL, &world);
0815     if (IS_ERR(objagg))
0816         return PTR_ERR(objagg);
0817 
0818     for (i = 0; i < ARRAY_SIZE(action_items); i++) {
0819         err = test_delta_action_item(&world, objagg,
0820                          &action_items[i], false);
0821         if (err)
0822             goto err_do_action_item;
0823     }
0824 
0825     objagg_destroy(objagg);
0826     return 0;
0827 
0828 err_do_action_item:
0829     for (i--; i >= 0; i--)
0830         test_delta_action_item(&world, objagg, &action_items[i], true);
0831 
0832     objagg_destroy(objagg);
0833     return err;
0834 }
0835 
/* Input and expectations for test_hints_case(). */
struct hints_case {
    /* Key ids to get, in this order. */
    const unsigned int *key_ids;
    size_t key_ids_count;
    /* Stats of the objagg after all gets. */
    struct expect_stats expect_stats;
    /* Stats of the computed hints, and of the objagg re-created from them. */
    struct expect_stats expect_stats_hints;
};

static const unsigned int hints_case_key_ids[] = {
    1, 7, 3, 5, 3, 1, 30, 8, 8, 5, 6, 8,
};

static const struct hints_case hints_case = {
    .key_ids = hints_case_key_ids,
    .key_ids_count = ARRAY_SIZE(hints_case_key_ids),
    .expect_stats =
        EXPECT_STATS(7, ROOT(1, 2, 7), ROOT(7, 1, 4), ROOT(30, 1, 1),
                DELTA(8, 3), DELTA(3, 2),
                DELTA(5, 2), DELTA(6, 1)),
    .expect_stats_hints =
        EXPECT_STATS(7, ROOT(3, 2, 9), ROOT(1, 2, 2), ROOT(30, 1, 1),
                DELTA(8, 3), DELTA(5, 2),
                DELTA(6, 1), DELTA(7, 1)),
};
0859 
0860 static void __pr_debug_stats(const struct objagg_stats *stats)
0861 {
0862     int i;
0863 
0864     for (i = 0; i < stats->stats_info_count; i++)
0865         pr_debug("Stat index %d key %u: u %d, d %d, %s\n", i,
0866              obj_to_key_id(stats->stats_info[i].objagg_obj),
0867              stats->stats_info[i].stats.user_count,
0868              stats->stats_info[i].stats.delta_user_count,
0869              stats->stats_info[i].is_root ? "root" : "noroot");
0870 }
0871 
/*
 * pr_debug_stats - dump the current stats of @objagg.
 * @objagg: objagg instance
 *
 * Silently does nothing if the stats cannot be obtained.
 */
static void pr_debug_stats(struct objagg *objagg)
{
    const struct objagg_stats *snapshot = objagg_stats_get(objagg);

    if (!IS_ERR(snapshot)) {
        __pr_debug_stats(snapshot);
        objagg_stats_put(snapshot);
    }
}
0882 
/*
 * pr_debug_hints_stats - dump the stats of a hints object.
 * @objagg_hints: hints returned by objagg_hints_get()
 *
 * Silently does nothing if the stats cannot be obtained.
 */
static void pr_debug_hints_stats(struct objagg_hints *objagg_hints)
{
    const struct objagg_stats *snapshot =
        objagg_hints_stats_get(objagg_hints);

    if (!IS_ERR(snapshot)) {
        __pr_debug_stats(snapshot);
        objagg_stats_put(snapshot);
    }
}
0893 
/*
 * check_expect_hints_stats - compare the stats of a hints object with an expectation.
 * @objagg_hints: hints returned by objagg_hints_get()
 * @expect_stats: expected stats
 * @errmsg: receives a description of any failure
 *
 * Like check_expect_stats(), *errmsg is set on every failure path. The
 * caller prints it with %s, so leaving it unset on a stats-get failure
 * would read an uninitialized pointer.
 *
 * Return: 0 on match, negative errno otherwise.
 */
static int check_expect_hints_stats(struct objagg_hints *objagg_hints,
                    const struct expect_stats *expect_stats,
                    const char **errmsg)
{
    const struct objagg_stats *stats;
    int err;

    stats = objagg_hints_stats_get(objagg_hints);
    if (IS_ERR(stats)) {
        *errmsg = "objagg_hints_stats_get() failed.";
        return PTR_ERR(stats);
    }
    err = __check_expect_stats(stats, expect_stats, errmsg);
    objagg_stats_put(stats);
    return err;
}
0908 
0909 static int test_hints_case(const struct hints_case *hints_case)
0910 {
0911     struct objagg_obj *objagg_obj;
0912     struct objagg_hints *hints;
0913     struct world world2 = {};
0914     struct world world = {};
0915     struct objagg *objagg2;
0916     struct objagg *objagg;
0917     const char *errmsg;
0918     int i;
0919     int err;
0920 
0921     objagg = objagg_create(&delta_ops, NULL, &world);
0922     if (IS_ERR(objagg))
0923         return PTR_ERR(objagg);
0924 
0925     for (i = 0; i < hints_case->key_ids_count; i++) {
0926         objagg_obj = world_obj_get(&world, objagg,
0927                        hints_case->key_ids[i]);
0928         if (IS_ERR(objagg_obj)) {
0929             err = PTR_ERR(objagg_obj);
0930             goto err_world_obj_get;
0931         }
0932     }
0933 
0934     pr_debug_stats(objagg);
0935     err = check_expect_stats(objagg, &hints_case->expect_stats, &errmsg);
0936     if (err) {
0937         pr_err("Stats: %s\n", errmsg);
0938         goto err_check_expect_stats;
0939     }
0940 
0941     hints = objagg_hints_get(objagg, OBJAGG_OPT_ALGO_SIMPLE_GREEDY);
0942     if (IS_ERR(hints)) {
0943         err = PTR_ERR(hints);
0944         goto err_hints_get;
0945     }
0946 
0947     pr_debug_hints_stats(hints);
0948     err = check_expect_hints_stats(hints, &hints_case->expect_stats_hints,
0949                        &errmsg);
0950     if (err) {
0951         pr_err("Hints stats: %s\n", errmsg);
0952         goto err_check_expect_hints_stats;
0953     }
0954 
0955     objagg2 = objagg_create(&delta_ops, hints, &world2);
0956     if (IS_ERR(objagg2))
0957         return PTR_ERR(objagg2);
0958 
0959     for (i = 0; i < hints_case->key_ids_count; i++) {
0960         objagg_obj = world_obj_get(&world2, objagg2,
0961                        hints_case->key_ids[i]);
0962         if (IS_ERR(objagg_obj)) {
0963             err = PTR_ERR(objagg_obj);
0964             goto err_world2_obj_get;
0965         }
0966     }
0967 
0968     pr_debug_stats(objagg2);
0969     err = check_expect_stats(objagg2, &hints_case->expect_stats_hints,
0970                  &errmsg);
0971     if (err) {
0972         pr_err("Stats2: %s\n", errmsg);
0973         goto err_check_expect_stats2;
0974     }
0975 
0976     err = 0;
0977 
0978 err_check_expect_stats2:
0979 err_world2_obj_get:
0980     for (i--; i >= 0; i--)
0981         world_obj_put(&world2, objagg, hints_case->key_ids[i]);
0982     i = hints_case->key_ids_count;
0983     objagg_destroy(objagg2);
0984 err_check_expect_hints_stats:
0985     objagg_hints_put(hints);
0986 err_hints_get:
0987 err_check_expect_stats:
0988 err_world_obj_get:
0989     for (i--; i >= 0; i--)
0990         world_obj_put(&world, objagg, hints_case->key_ids[i]);
0991 
0992     objagg_destroy(objagg);
0993     return err;
0994 }
0995 static int test_hints(void)
0996 {
0997     return test_hints_case(&hints_case);
0998 }
0999 
1000 static int __init test_objagg_init(void)
1001 {
1002     int err;
1003 
1004     err = test_nodelta();
1005     if (err)
1006         return err;
1007     err = test_delta();
1008     if (err)
1009         return err;
1010     return test_hints();
1011 }
1012 
/*
 * Module exit hook. The tests run entirely from test_objagg_init() and keep
 * no mutable module-level state, so there is nothing to release here.
 */
static void __exit test_objagg_exit(void)
{
}

module_init(test_objagg_init);
module_exit(test_objagg_exit);
MODULE_LICENSE("Dual BSD/GPL");
MODULE_AUTHOR("Jiri Pirko <jiri@mellanox.com>");
MODULE_DESCRIPTION("Test module for objagg");