0001
0002
0003
0004
0005
0006 #define _GNU_SOURCE
0007 #include <assert.h>
0008 #include <errno.h>
0009 #include <limits.h>
0010 #include <stddef.h>
0011 #include <stdio.h>
0012 #include <stdlib.h>
0013 #include <string.h>
0014 #include <getopt.h>
0015 #include <unistd.h>
0016 #include <sys/auxv.h>
0017 #include <sys/prctl.h>
0018 #include <asm/hwcap.h>
0019 #include <asm/sigcontext.h>
0020
/* Command-line state, filled in by parse_options(). */
static int force;		/* --force: continue even if the hwcap check fails */
static int inherit;		/* -i / --inherit */
static int no_inherit;		/* --no-inherit; written directly by getopt_long() */
static unsigned long vl;	/* requested vector length; 0 until one is given */

/* prctl() requests to issue; parse_options() swaps in the SME ones for --sme. */
static int set_ctl = PR_SVE_SET_VL;
static int get_ctl = PR_SVE_GET_VL;

/* Long options understood by getopt_long(); terminated by an all-zero entry. */
static const struct option options[] = {
	{ .name = "force",      .has_arg = no_argument, .flag = NULL,        .val = 'f' },
	{ .name = "inherit",    .has_arg = no_argument, .flag = NULL,        .val = 'i' },
	{ .name = "max",        .has_arg = no_argument, .flag = NULL,        .val = 'M' },
	{ .name = "no-inherit", .has_arg = no_argument, .flag = &no_inherit, .val = 1 },
	{ .name = "sme",        .has_arg = no_argument, .flag = NULL,        .val = 's' },
	{ .name = "help",       .has_arg = no_argument, .flag = NULL,        .val = '?' },
	{ NULL, 0, NULL, 0 }
};

/* Basename of argv[0]; prefixes every diagnostic. */
static char const *program_name;
0039
0040 static int parse_options(int argc, char **argv)
0041 {
0042 int c;
0043 char *rest;
0044
0045 program_name = strrchr(argv[0], '/');
0046 if (program_name)
0047 ++program_name;
0048 else
0049 program_name = argv[0];
0050
0051 while ((c = getopt_long(argc, argv, "Mfhi", options, NULL)) != -1)
0052 switch (c) {
0053 case 'M': vl = SVE_VL_MAX; break;
0054 case 'f': force = 1; break;
0055 case 'i': inherit = 1; break;
0056 case 's': set_ctl = PR_SME_SET_VL;
0057 get_ctl = PR_SME_GET_VL;
0058 break;
0059 case 0: break;
0060 default: goto error;
0061 }
0062
0063 if (inherit && no_inherit)
0064 goto error;
0065
0066 if (!vl) {
0067
0068 if (optind >= argc)
0069 goto error;
0070
0071 errno = 0;
0072 vl = strtoul(argv[optind], &rest, 0);
0073 if (*rest) {
0074 vl = ULONG_MAX;
0075 errno = EINVAL;
0076 }
0077 if (vl == ULONG_MAX && errno) {
0078 fprintf(stderr, "%s: %s: %s\n",
0079 program_name, argv[optind], strerror(errno));
0080 goto error;
0081 }
0082
0083 ++optind;
0084 }
0085
0086
0087 if (optind >= argc)
0088 goto error;
0089
0090 return 0;
0091
0092 error:
0093 fprintf(stderr,
0094 "Usage: %s [-f | --force] "
0095 "[-i | --inherit | --no-inherit] "
0096 "{-M | --max | <vector length>} "
0097 "<command> [<arguments> ...]\n",
0098 program_name);
0099 return -1;
0100 }
0101
/*
 * Request the SVE (default) or SME (--sme) vector length to take effect at
 * the next exec, then exec <command> with its arguments.
 *
 * Exit status: 2 for usage or vector-length errors; 126 if the extension
 * is absent (without --force), a prctl() fails, or exec fails; 127 if the
 * command was not found.
 */
int main(int argc, char **argv)
{
	int ret = 126;
	int sme, present;
	const char *ext_name, *set_name, *get_name;
	unsigned long len_mask, flags;
	char *path;
	int t, e;

	if (parse_options(argc, argv))
		return 2;

	/* parse_options() selects PR_SME_SET_VL for --sme. */
	sme = (set_ctl == PR_SME_SET_VL);
	ext_name = sme ? "Scalable Matrix Extension"
		       : "Scalable Vector Extension";
	set_name = sme ? "PR_SME_SET_VL" : "PR_SVE_SET_VL";
	get_name = sme ? "PR_SME_GET_VL" : "PR_SVE_GET_VL";
	len_mask = sme ? PR_SME_VL_LEN_MASK : PR_SVE_VL_LEN_MASK;

	/* Only the length field may be set by the user; flags are added below. */
	if (vl & ~len_mask) {
		fprintf(stderr, "%s: Invalid vector length %lu\n",
			program_name, vl);
		return 2;
	}

	/* SME is advertised in AT_HWCAP2, SVE in AT_HWCAP. */
	if (sme)
		present = (getauxval(AT_HWCAP2) & HWCAP2_SME) != 0;
	else
		present = (getauxval(AT_HWCAP) & HWCAP_SVE) != 0;

	if (!present) {
		fprintf(stderr, "%s: %s not present\n", program_name, ext_name);

		if (!force)
			goto error;

		fputs("Going ahead anyway (--force): "
		      "This is a debug option. Don't rely on it.\n",
		      stderr);
	}

	/* Defer the change to the next exec rather than this process. */
	flags = sme ? PR_SME_SET_VL_ONEXEC : PR_SVE_SET_VL_ONEXEC;
	if (inherit)
		flags |= sme ? PR_SME_VL_INHERIT : PR_SVE_VL_INHERIT;

	t = prctl(set_ctl, vl | flags);
	if (t < 0) {
		fprintf(stderr, "%s: %s: %s\n",
			program_name, set_name, strerror(errno));
		goto error;
	}

	/* Sanity-check that the matching GET_VL request works; value unused. */
	t = prctl(get_ctl);
	if (t < 0) {
		fprintf(stderr, "%s: %s: %s\n",
			program_name, get_name, strerror(errno));
		goto error;
	}

	assert(optind < argc);
	path = argv[optind];

	execvp(path, &argv[optind]);
	/* execvp() only returns on failure; save errno before it is clobbered. */
	e = errno;
	if (e == ENOENT)
		ret = 127;
	fprintf(stderr, "%s: %s: %s\n", program_name, path, strerror(e));

error:
	return ret;
}