Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-only
0002 /*
0003  * Copyright (C) 2015-2019 ARM Limited.
0004  * Original author: Dave Martin <Dave.Martin@arm.com>
0005  */
0006 #define _GNU_SOURCE
0007 #include <assert.h>
0008 #include <errno.h>
0009 #include <limits.h>
0010 #include <stddef.h>
0011 #include <stdio.h>
0012 #include <stdlib.h>
0013 #include <string.h>
0014 #include <getopt.h>
0015 #include <unistd.h>
0016 #include <sys/auxv.h>
0017 #include <sys/prctl.h>
0018 #include <asm/hwcap.h>
0019 #include <asm/sigcontext.h>
0020 
/*
 * Option state shared by parse_options() and main().  Everything starts
 * out zero; vl == 0 means no vector length has been chosen yet.
 */
static int force;		/* -f / --force */
static int inherit;		/* -i / --inherit */
static int no_inherit;		/* --no-inherit, stored through options[].flag */
static unsigned long vl;	/* requested vector length, 0 = not yet given */

/* prctl() requests for setting/reading the vector length (SVE unless --sme). */
static int set_ctl = PR_SVE_SET_VL;
static int get_ctl = PR_SVE_GET_VL;

/*
 * Long options for getopt_long().  Entries with a NULL flag make
 * getopt_long() return .val; --no-inherit instead stores 1 in no_inherit.
 */
static const struct option options[] = {
	{ .name = "force",      .has_arg = no_argument, .flag = NULL,        .val = 'f' },
	{ .name = "inherit",    .has_arg = no_argument, .flag = NULL,        .val = 'i' },
	{ .name = "max",        .has_arg = no_argument, .flag = NULL,        .val = 'M' },
	{ .name = "no-inherit", .has_arg = no_argument, .flag = &no_inherit, .val = 1 },
	{ .name = "sme",        .has_arg = no_argument, .flag = NULL,        .val = 's' },
	{ .name = "help",       .has_arg = no_argument, .flag = NULL,        .val = '?' },
	{ 0 }	/* all-zero terminator required by getopt_long() */
};

/* argv[0] with any leading directory stripped; prefixes every diagnostic. */
static char const *program_name;
0039 
0040 static int parse_options(int argc, char **argv)
0041 {
0042     int c;
0043     char *rest;
0044 
0045     program_name = strrchr(argv[0], '/');
0046     if (program_name)
0047         ++program_name;
0048     else
0049         program_name = argv[0];
0050 
0051     while ((c = getopt_long(argc, argv, "Mfhi", options, NULL)) != -1)
0052         switch (c) {
0053         case 'M':   vl = SVE_VL_MAX; break;
0054         case 'f':   force = 1; break;
0055         case 'i':   inherit = 1; break;
0056         case 's':   set_ctl = PR_SME_SET_VL;
0057                 get_ctl = PR_SME_GET_VL;
0058                 break;
0059         case 0:     break;
0060         default:    goto error;
0061         }
0062 
0063     if (inherit && no_inherit)
0064         goto error;
0065 
0066     if (!vl) {
0067         /* vector length */
0068         if (optind >= argc)
0069             goto error;
0070 
0071         errno = 0;
0072         vl = strtoul(argv[optind], &rest, 0);
0073         if (*rest) {
0074             vl = ULONG_MAX;
0075             errno = EINVAL;
0076         }
0077         if (vl == ULONG_MAX && errno) {
0078             fprintf(stderr, "%s: %s: %s\n",
0079                 program_name, argv[optind], strerror(errno));
0080             goto error;
0081         }
0082 
0083         ++optind;
0084     }
0085 
0086     /* command */
0087     if (optind >= argc)
0088         goto error;
0089 
0090     return 0;
0091 
0092 error:
0093     fprintf(stderr,
0094         "Usage: %s [-f | --force] "
0095         "[-i | --inherit | --no-inherit] "
0096         "{-M | --max | <vector length>} "
0097         "<command> [<arguments> ...]\n",
0098         program_name);
0099     return -1;
0100 }
0101 
/*
 * Arrange for the SVE (or, with --sme, SME) vector length to be applied at
 * the next exec, then exec <command> with its arguments.
 *
 * Exit status follows sh(1): 2 for a usage error, 126 if the vector length
 * cannot be configured or the command cannot be executed, 127 if the
 * command is not found.  Does not return if execvp() succeeds.
 */
int main(int argc, char **argv)
{
	int ret = 126;	/* same as sh(1) command-not-executable error */
	int sme;
	unsigned long len_mask;
	long flags;
	char *path;
	int t, e;

	if (parse_options(argc, argv))
		return 2;	/* same as sh(1) builtin incorrect-usage */

	/* parse_options() selects the SME requests for -s/--sme. */
	sme = (set_ctl == PR_SME_SET_VL);
	len_mask = sme ? PR_SME_VL_LEN_MASK : PR_SVE_VL_LEN_MASK;

	/* vl is OR-ed with flag bits below, so it must fit in the length field. */
	if (vl & ~len_mask) {
		fprintf(stderr, "%s: Invalid vector length %lu\n",
			program_name, vl);
		return 2;	/* same as sh(1) builtin incorrect-usage */
	}

	/*
	 * Check for the extension whose prctl() is about to be used.  SME is
	 * reported in AT_HWCAP2, so testing HWCAP_SVE for --sme would check
	 * the wrong feature.
	 */
	if (sme ? !(getauxval(AT_HWCAP2) & HWCAP2_SME)
		: !(getauxval(AT_HWCAP) & HWCAP_SVE)) {
		fprintf(stderr, "%s: %s not present\n", program_name,
			sme ? "Scalable Matrix Extension"
			    : "Scalable Vector Extension");

		if (!force)
			goto error;

		fputs("Going ahead anyway (--force):  "
		      "This is a debug option.  Don't rely on it.\n",
		      stderr);
	}

	flags = sme ? PR_SME_SET_VL_ONEXEC : PR_SVE_SET_VL_ONEXEC;
	if (inherit)
		flags |= sme ? PR_SME_VL_INHERIT : PR_SVE_VL_INHERIT;

	t = prctl(set_ctl, vl | flags);
	if (t < 0) {
		fprintf(stderr, "%s: %s: %s\n", program_name,
			sme ? "PR_SME_SET_VL" : "PR_SVE_SET_VL",
			strerror(errno));
		goto error;
	}

	/* Read the setting back; only a failure of the request is acted on. */
	t = prctl(get_ctl);
	if (t == -1) {
		fprintf(stderr, "%s: %s: %s\n", program_name,
			sme ? "PR_SME_GET_VL" : "PR_SVE_GET_VL",
			strerror(errno));
		goto error;
	}

	assert(optind < argc);
	path = argv[optind];

	execvp(path, &argv[optind]);
	e = errno;	/* execvp() only returns on failure */
	if (e == ENOENT)
		ret = 127;	/* same as sh(1) not-found error */
	fprintf(stderr, "%s: %s: %s\n", program_name, path, strerror(e));

error:
	return ret;	/* 126 (not executable) or 127 (not found), as sh(1) */
}