Back to home page

OSCL-LXR

 
 

    


0001 // SPDX-License-Identifier: GPL-2.0-only
0002 // Copyright (C) 2019-2020 Arm Ltd.
0003 
0004 #include <linux/compiler.h>
0005 #include <linux/kasan-checks.h>
0006 #include <linux/kernel.h>
0007 
0008 #include <net/checksum.h>
0009 
0010 /* Looks dumb, but generates nice-ish code */
0011 static u64 accumulate(u64 sum, u64 data)
0012 {
0013     __uint128_t tmp = (__uint128_t)sum + data;
0014     return tmp + (tmp >> 64);
0015 }
0016 
0017 /*
0018  * We over-read the buffer and this makes KASAN unhappy. Instead, disable
0019  * instrumentation and call kasan explicitly.
0020  */
0021 unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
0022 {
0023     unsigned int offset, shift, sum;
0024     const u64 *ptr;
0025     u64 data, sum64 = 0;
0026 
0027     if (unlikely(len == 0))
0028         return 0;
0029 
0030     offset = (unsigned long)buff & 7;
0031     /*
0032      * This is to all intents and purposes safe, since rounding down cannot
0033      * result in a different page or cache line being accessed, and @buff
0034      * should absolutely not be pointing to anything read-sensitive. We do,
0035      * however, have to be careful not to piss off KASAN, which means using
0036      * unchecked reads to accommodate the head and tail, for which we'll
0037      * compensate with an explicit check up-front.
0038      */
0039     kasan_check_read(buff, len);
0040     ptr = (u64 *)(buff - offset);
0041     len = len + offset - 8;
0042 
0043     /*
0044      * Head: zero out any excess leading bytes. Shifting back by the same
0045      * amount should be at least as fast as any other way of handling the
0046      * odd/even alignment, and means we can ignore it until the very end.
0047      */
0048     shift = offset * 8;
0049     data = *ptr++;
0050 #ifdef __LITTLE_ENDIAN
0051     data = (data >> shift) << shift;
0052 #else
0053     data = (data << shift) >> shift;
0054 #endif
0055 
0056     /*
0057      * Body: straightforward aligned loads from here on (the paired loads
0058      * underlying the quadword type still only need dword alignment). The
0059      * main loop strictly excludes the tail, so the second loop will always
0060      * run at least once.
0061      */
0062     while (unlikely(len > 64)) {
0063         __uint128_t tmp1, tmp2, tmp3, tmp4;
0064 
0065         tmp1 = *(__uint128_t *)ptr;
0066         tmp2 = *(__uint128_t *)(ptr + 2);
0067         tmp3 = *(__uint128_t *)(ptr + 4);
0068         tmp4 = *(__uint128_t *)(ptr + 6);
0069 
0070         len -= 64;
0071         ptr += 8;
0072 
0073         /* This is the "don't dump the carry flag into a GPR" idiom */
0074         tmp1 += (tmp1 >> 64) | (tmp1 << 64);
0075         tmp2 += (tmp2 >> 64) | (tmp2 << 64);
0076         tmp3 += (tmp3 >> 64) | (tmp3 << 64);
0077         tmp4 += (tmp4 >> 64) | (tmp4 << 64);
0078         tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
0079         tmp1 += (tmp1 >> 64) | (tmp1 << 64);
0080         tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
0081         tmp3 += (tmp3 >> 64) | (tmp3 << 64);
0082         tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
0083         tmp1 += (tmp1 >> 64) | (tmp1 << 64);
0084         tmp1 = ((tmp1 >> 64) << 64) | sum64;
0085         tmp1 += (tmp1 >> 64) | (tmp1 << 64);
0086         sum64 = tmp1 >> 64;
0087     }
0088     while (len > 8) {
0089         __uint128_t tmp;
0090 
0091         sum64 = accumulate(sum64, data);
0092         tmp = *(__uint128_t *)ptr;
0093 
0094         len -= 16;
0095         ptr += 2;
0096 
0097 #ifdef __LITTLE_ENDIAN
0098         data = tmp >> 64;
0099         sum64 = accumulate(sum64, tmp);
0100 #else
0101         data = tmp;
0102         sum64 = accumulate(sum64, tmp >> 64);
0103 #endif
0104     }
0105     if (len > 0) {
0106         sum64 = accumulate(sum64, data);
0107         data = *ptr;
0108         len -= 8;
0109     }
0110     /*
0111      * Tail: zero any over-read bytes similarly to the head, again
0112      * preserving odd/even alignment.
0113      */
0114     shift = len * -8;
0115 #ifdef __LITTLE_ENDIAN
0116     data = (data << shift) >> shift;
0117 #else
0118     data = (data >> shift) << shift;
0119 #endif
0120     sum64 = accumulate(sum64, data);
0121 
0122     /* Finally, folding */
0123     sum64 += (sum64 >> 32) | (sum64 << 32);
0124     sum = sum64 >> 32;
0125     sum += (sum >> 16) | (sum << 16);
0126     if (offset & 1)
0127         return (u16)swab32(sum);
0128 
0129     return sum >> 16;
0130 }
0131 
0132 __sum16 csum_ipv6_magic(const struct in6_addr *saddr,
0133             const struct in6_addr *daddr,
0134             __u32 len, __u8 proto, __wsum csum)
0135 {
0136     __uint128_t src, dst;
0137     u64 sum = (__force u64)csum;
0138 
0139     src = *(const __uint128_t *)saddr->s6_addr;
0140     dst = *(const __uint128_t *)daddr->s6_addr;
0141 
0142     sum += (__force u32)htonl(len);
0143 #ifdef __LITTLE_ENDIAN
0144     sum += (u32)proto << 24;
0145 #else
0146     sum += proto;
0147 #endif
0148     src += (src >> 64) | (src << 64);
0149     dst += (dst >> 64) | (dst << 64);
0150 
0151     sum = accumulate(sum, src >> 64);
0152     sum = accumulate(sum, dst >> 64);
0153 
0154     sum += ((sum >> 32) | (sum << 32));
0155     return csum_fold((__force __wsum)(sum >> 32));
0156 }
0157 EXPORT_SYMBOL(csum_ipv6_magic);