Skip to content

Commit f5406a5

Browse files
committed
Improve autovectorization of to_lowercase / to_uppercase functions
Refactor the code in the `convert_while_ascii` helper function to make it more suitable for auto-vectorization and also process the full ascii prefix of the string. The generic case conversion logic will only be invoked starting from the first non-ascii character. The runtime on microbenchmarks with ascii-only inputs improves between 1.5x for short and 4x for long inputs on x86_64 and aarch64. The new implementation also encapsulates all unsafe inside the `convert_while_ascii` function. Fixes rust-lang#123712
1 parent ba6158c commit f5406a5

File tree

3 files changed

+82
-52
lines changed

3 files changed

+82
-52
lines changed

alloc/benches/str.rs

+2
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,5 @@ make_test!(rsplitn_space_char, s, s.rsplitn(10, ' ').count());
347347

348348
make_test!(split_space_str, s, s.split(" ").count());
349349
make_test!(split_ad_str, s, s.split("ad").count());
350+
351+
make_test!(to_lowercase, s, s.to_lowercase());

alloc/src/str.rs

+77-52
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
use core::borrow::{Borrow, BorrowMut};
1111
use core::iter::FusedIterator;
12+
use core::mem::MaybeUninit;
1213
#[stable(feature = "encode_utf16", since = "1.8.0")]
1314
pub use core::str::EncodeUtf16;
1415
#[stable(feature = "split_ascii_whitespace", since = "1.34.0")]
@@ -365,14 +366,9 @@ impl str {
365366
without modifying the original"]
366367
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
367368
pub fn to_lowercase(&self) -> String {
368-
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_lowercase);
369+
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_lowercase);
369370

370-
// Safety: we know this is a valid char boundary since
371-
// out.len() is only progressed if ascii bytes are found
372-
let rest = unsafe { self.get_unchecked(out.len()..) };
373-
374-
// Safety: We have written only valid ASCII to our vec
375-
let mut s = unsafe { String::from_utf8_unchecked(out) };
371+
let prefix_len = s.len();
376372

377373
for (i, c) in rest.char_indices() {
378374
if c == 'Σ' {
@@ -381,8 +377,7 @@ impl str {
381377
// in `SpecialCasing.txt`,
382378
// so hard-code it rather than have a generic "condition" mechanism.
383379
// See https://github.com/rust-lang/rust/issues/26035
384-
let out_len = self.len() - rest.len();
385-
let sigma_lowercase = map_uppercase_sigma(&self, i + out_len);
380+
let sigma_lowercase = map_uppercase_sigma(self, prefix_len + i);
386381
s.push(sigma_lowercase);
387382
} else {
388383
match conversions::to_lower(c) {
@@ -458,14 +453,7 @@ impl str {
458453
without modifying the original"]
459454
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
460455
pub fn to_uppercase(&self) -> String {
461-
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_uppercase);
462-
463-
// Safety: we know this is a valid char boundary since
464-
// out.len() is only progressed if ascii bytes are found
465-
let rest = unsafe { self.get_unchecked(out.len()..) };
466-
467-
// Safety: We have written only valid ASCII to our vec
468-
let mut s = unsafe { String::from_utf8_unchecked(out) };
456+
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_uppercase);
469457

470458
for c in rest.chars() {
471459
match conversions::to_upper(c) {
@@ -614,50 +602,87 @@ pub unsafe fn from_boxed_utf8_unchecked(v: Box<[u8]>) -> Box<str> {
614602
unsafe { Box::from_raw(Box::into_raw(v) as *mut str) }
615603
}
616604

617-
/// Converts the bytes while the bytes are still ascii.
605+
/// Converts leading ascii bytes in `s` by calling the `convert` function.
606+
///
618607
/// For better average performance, this happens in chunks of `2*size_of::<usize>()`.
619-
/// Returns a vec with the converted bytes.
608+
///
609+
/// Returns a tuple of the converted prefix and the remainder starting from
610+
/// the first non-ascii character.
611+
///
612+
/// This function is only public so that it can be verified in a codegen test,
613+
/// see `issue-123712-str-to-lower-autovectorization.rs`.
614+
#[unstable(feature = "str_internals", issue = "none")]
615+
#[doc(hidden)]
620616
#[inline]
621617
#[cfg(not(test))]
622618
#[cfg(not(no_global_oom_handling))]
623-
fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8) -> Vec<u8> {
624-
let mut out = Vec::with_capacity(b.len());
619+
pub fn convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> (String, &str) {
620+
// Process the input in chunks of 16 bytes to enable auto-vectorization.
621+
// Previously the chunk size depended on the size of `usize`,
622+
// but on 32-bit platforms with sse or neon is also the better choice.
623+
// The only downside on other platforms would be a bit more loop-unrolling.
624+
const N: usize = 16;
625+
626+
let mut slice = s.as_bytes();
627+
let mut out = Vec::with_capacity(slice.len());
628+
let mut out_slice = out.spare_capacity_mut();
629+
630+
let mut ascii_prefix_len = 0_usize;
631+
let mut is_ascii = [false; N];
632+
633+
while slice.len() >= N {
634+
// SAFETY: checked in loop condition
635+
let chunk = unsafe { slice.get_unchecked(..N) };
636+
// SAFETY: out_slice has at least same length as input slice and gets sliced with the same offsets
637+
let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) };
638+
639+
for j in 0..N {
640+
is_ascii[j] = chunk[j] <= 127;
641+
}
625642

626-
const USIZE_SIZE: usize = mem::size_of::<usize>();
627-
const MAGIC_UNROLL: usize = 2;
628-
const N: usize = USIZE_SIZE * MAGIC_UNROLL;
629-
const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; USIZE_SIZE]);
643+
// Auto-vectorization for this check is a bit fragile, sum and comparing against the chunk
644+
// size gives the best result, specifically a pmovmsk instruction on x86.
645+
// See https://github.com/llvm/llvm-project/issues/96395 for why llvm currently does not
646+
// currently recognize other similar idioms.
647+
if is_ascii.iter().map(|x| *x as u8).sum::<u8>() as usize != N {
648+
break;
649+
}
630650

631-
let mut i = 0;
632-
unsafe {
633-
while i + N <= b.len() {
634-
// Safety: we have checks the sizes `b` and `out` to know that our
635-
let in_chunk = b.get_unchecked(i..i + N);
636-
let out_chunk = out.spare_capacity_mut().get_unchecked_mut(i..i + N);
637-
638-
let mut bits = 0;
639-
for j in 0..MAGIC_UNROLL {
640-
// read the bytes 1 usize at a time (unaligned since we haven't checked the alignment)
641-
// safety: in_chunk is valid bytes in the range
642-
bits |= in_chunk.as_ptr().cast::<usize>().add(j).read_unaligned();
643-
}
644-
// if our chunks aren't ascii, then return only the prior bytes as init
645-
if bits & NONASCII_MASK != 0 {
646-
break;
647-
}
651+
for j in 0..N {
652+
out_chunk[j] = MaybeUninit::new(convert(&chunk[j]));
653+
}
648654

649-
// perform the case conversions on N bytes (gets heavily autovec'd)
650-
for j in 0..N {
651-
// safety: in_chunk and out_chunk is valid bytes in the range
652-
let out = out_chunk.get_unchecked_mut(j);
653-
out.write(convert(in_chunk.get_unchecked(j)));
654-
}
655+
ascii_prefix_len += N;
656+
slice = unsafe { slice.get_unchecked(N..) };
657+
out_slice = unsafe { out_slice.get_unchecked_mut(N..) };
658+
}
655659

656-
// mark these bytes as initialised
657-
i += N;
660+
// handle the remainder as individual bytes
661+
while slice.len() > 0 {
662+
let byte = slice[0];
663+
if byte > 127 {
664+
break;
665+
}
666+
// SAFETY: out_slice has at least same length as input slice
667+
unsafe {
668+
*out_slice.get_unchecked_mut(0) = MaybeUninit::new(convert(&byte));
658669
}
659-
out.set_len(i);
670+
ascii_prefix_len += 1;
671+
slice = unsafe { slice.get_unchecked(1..) };
672+
out_slice = unsafe { out_slice.get_unchecked_mut(1..) };
660673
}
661674

662-
out
675+
unsafe {
676+
// SAFETY: ascii_prefix_len bytes have been initialized above
677+
out.set_len(ascii_prefix_len);
678+
679+
// SAFETY: We have written only valid ascii to the output vec
680+
let ascii_string = String::from_utf8_unchecked(out);
681+
682+
// SAFETY: we know this is a valid char boundary
683+
// since we only skipped over leading ascii bytes
684+
let rest = core::str::from_utf8_unchecked(slice);
685+
686+
(ascii_string, rest)
687+
}
663688
}

alloc/tests/str.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1854,7 +1854,10 @@ fn to_lowercase() {
18541854
assert_eq!("ΑΣ''Α".to_lowercase(), "ασ''α");
18551855

18561856
// https://github.com/rust-lang/rust/issues/124714
1857+
// input lengths around the boundary of the chunk size used by the ascii prefix optimization
1858+
assert_eq!("abcdefghijklmnoΣ".to_lowercase(), "abcdefghijklmnoς");
18571859
assert_eq!("abcdefghijklmnopΣ".to_lowercase(), "abcdefghijklmnopς");
1860+
assert_eq!("abcdefghijklmnopqΣ".to_lowercase(), "abcdefghijklmnopqς");
18581861

18591862
// a really long string that has it's lowercase form
18601863
// even longer. this tests that implementations don't assume

0 commit comments

Comments
 (0)