Skip to content

Commit 9dccd5d

Browse files
committed
Further Implement Power of Two Optimization
1 parent 971e37f commit 9dccd5d

File tree

3 files changed

+350
-185
lines changed

3 files changed

+350
-185
lines changed

library/core/src/num/int_macros.rs

+186-73
Original file line numberDiff line numberDiff line change
@@ -901,26 +901,59 @@ macro_rules! int_impl {
901901
#[rustc_const_stable(feature = "const_int_pow", since = "1.50.0")]
902902
#[must_use = "this returns the result of the operation, \
903903
without modifying the original"]
904+
#[rustc_allow_const_fn_unstable(is_val_statically_known, const_int_unchecked_arith)]
904905
#[inline]
905906
pub const fn checked_pow(self, mut exp: u32) -> Option<Self> {
906-
if exp == 0 {
907-
return Some(1);
908-
}
909-
let mut base = self;
910-
let mut acc: Self = 1;
907+
// SAFETY: This path has the same behavior as the other.
908+
if unsafe { intrinsics::is_val_statically_known(self) }
909+
&& self.unsigned_abs().is_power_of_two()
910+
{
911+
if self == 1 { // Avoid divide by zero
912+
return Some(1);
913+
}
914+
if self == -1 { // Avoid divide by zero
915+
return Some(if exp & 1 != 0 { -1 } else { 1 });
916+
}
917+
// SAFETY: We just checked this is a power of two. and above zero.
918+
let power_used = unsafe { intrinsics::cttz_nonzero(self.wrapping_abs()) as u32 };
919+
if exp > Self::BITS / power_used { return None; } // Division of constants is free
920+
921+
// SAFETY: exp <= Self::BITS / power_used
922+
let res = unsafe { intrinsics::unchecked_shl(
923+
1 as Self,
924+
intrinsics::unchecked_mul(power_used, exp) as Self
925+
)};
926+
// LLVM doesn't always optimize out the checks
927+
// at the ir level.
928+
929+
let sign = self.is_negative() && exp & 1 != 0;
930+
if !sign && res == Self::MIN {
931+
None
932+
} else if sign {
933+
Some(res.wrapping_neg())
934+
} else {
935+
Some(res)
936+
}
937+
} else {
938+
if exp == 0 {
939+
return Some(1);
940+
}
941+
let mut base = self;
942+
let mut acc: Self = 1;
911943

912-
while exp > 1 {
913-
if (exp & 1) == 1 {
914-
acc = try_opt!(acc.checked_mul(base));
944+
while exp > 1 {
945+
if (exp & 1) == 1 {
946+
acc = try_opt!(acc.checked_mul(base));
947+
}
948+
exp /= 2;
949+
base = try_opt!(base.checked_mul(base));
915950
}
916-
exp /= 2;
917-
base = try_opt!(base.checked_mul(base));
951+
// since exp!=0, finally the exp must be 1.
952+
// Deal with the final bit of the exponent separately, since
953+
// squaring the base afterwards is not necessary and may cause a
954+
// needless overflow.
955+
acc.checked_mul(base)
918956
}
919-
// since exp!=0, finally the exp must be 1.
920-
// Deal with the final bit of the exponent separately, since
921-
// squaring the base afterwards is not necessary and may cause a
922-
// needless overflow.
923-
acc.checked_mul(base)
924957
}
925958

926959
/// Returns the square root of the number, rounded down.
@@ -1537,27 +1570,58 @@ macro_rules! int_impl {
15371570
#[rustc_const_stable(feature = "const_int_pow", since = "1.50.0")]
15381571
#[must_use = "this returns the result of the operation, \
15391572
without modifying the original"]
1573+
#[rustc_allow_const_fn_unstable(is_val_statically_known, const_int_unchecked_arith)]
15401574
#[inline]
15411575
pub const fn wrapping_pow(self, mut exp: u32) -> Self {
1542-
if exp == 0 {
1543-
return 1;
1544-
}
1545-
let mut base = self;
1546-
let mut acc: Self = 1;
1576+
// SAFETY: This path has the same behavior as the other.
1577+
if unsafe { intrinsics::is_val_statically_known(self) }
1578+
&& self.unsigned_abs().is_power_of_two()
1579+
{
1580+
if self == 1 { // Avoid divide by zero
1581+
return 1;
1582+
}
1583+
if self == -1 { // Avoid divide by zero
1584+
return if exp & 1 != 0 { -1 } else { 1 };
1585+
}
1586+
// SAFETY: We just checked this is a power of two. and above zero.
1587+
let power_used = unsafe { intrinsics::cttz_nonzero(self.wrapping_abs()) as u32 };
1588+
if exp > Self::BITS / power_used { return 0; } // Division of constants is free
1589+
1590+
// SAFETY: exp <= Self::BITS / power_used
1591+
let res = unsafe { intrinsics::unchecked_shl(
1592+
1 as Self,
1593+
intrinsics::unchecked_mul(power_used, exp) as Self
1594+
)};
1595+
// LLVM doesn't always optimize out the checks
1596+
// at the ir level.
1597+
1598+
let sign = self.is_negative() && exp & 1 != 0;
1599+
if sign {
1600+
res.wrapping_neg()
1601+
} else {
1602+
res
1603+
}
1604+
} else {
1605+
if exp == 0 {
1606+
return 1;
1607+
}
1608+
let mut base = self;
1609+
let mut acc: Self = 1;
15471610

1548-
while exp > 1 {
1549-
if (exp & 1) == 1 {
1550-
acc = acc.wrapping_mul(base);
1611+
while exp > 1 {
1612+
if (exp & 1) == 1 {
1613+
acc = acc.wrapping_mul(base);
1614+
}
1615+
exp /= 2;
1616+
base = base.wrapping_mul(base);
15511617
}
1552-
exp /= 2;
1553-
base = base.wrapping_mul(base);
1554-
}
15551618

1556-
// since exp!=0, finally the exp must be 1.
1557-
// Deal with the final bit of the exponent separately, since
1558-
// squaring the base afterwards is not necessary and may cause a
1559-
// needless overflow.
1560-
acc.wrapping_mul(base)
1619+
// since exp!=0, finally the exp must be 1.
1620+
// Deal with the final bit of the exponent separately, since
1621+
// squaring the base afterwards is not necessary and may cause a
1622+
// needless overflow.
1623+
acc.wrapping_mul(base)
1624+
}
15611625
}
15621626

15631627
/// Calculates `self` + `rhs`
@@ -2039,36 +2103,68 @@ macro_rules! int_impl {
20392103
#[rustc_const_stable(feature = "const_int_pow", since = "1.50.0")]
20402104
#[must_use = "this returns the result of the operation, \
20412105
without modifying the original"]
2106+
#[rustc_allow_const_fn_unstable(is_val_statically_known, const_int_unchecked_arith)]
20422107
#[inline]
20432108
pub const fn overflowing_pow(self, mut exp: u32) -> (Self, bool) {
2044-
if exp == 0 {
2045-
return (1,false);
2046-
}
2047-
let mut base = self;
2048-
let mut acc: Self = 1;
2049-
let mut overflown = false;
2050-
// Scratch space for storing results of overflowing_mul.
2051-
let mut r;
2052-
2053-
while exp > 1 {
2054-
if (exp & 1) == 1 {
2055-
r = acc.overflowing_mul(base);
2056-
acc = r.0;
2109+
// SAFETY: This path has the same behavior as the other.
2110+
if unsafe { intrinsics::is_val_statically_known(self) }
2111+
&& self.unsigned_abs().is_power_of_two()
2112+
{
2113+
if self == 1 { // Avoid divide by zero
2114+
return (1, false);
2115+
}
2116+
if self == -1 { // Avoid divide by zero
2117+
return (if exp & 1 != 0 { -1 } else { 1 }, false);
2118+
}
2119+
// SAFETY: We just checked this is a power of two. and above zero.
2120+
let power_used = unsafe { intrinsics::cttz_nonzero(self.wrapping_abs()) as u32 };
2121+
if exp > Self::BITS / power_used { return (0, true); } // Division of constants is free
2122+
2123+
// SAFETY: exp <= Self::BITS / power_used
2124+
let res = unsafe { intrinsics::unchecked_shl(
2125+
1 as Self,
2126+
intrinsics::unchecked_mul(power_used, exp) as Self
2127+
)};
2128+
// LLVM doesn't always optimize out the checks
2129+
// at the ir level.
2130+
2131+
let sign = self.is_negative() && exp & 1 != 0;
2132+
let overflow = res == Self::MIN;
2133+
if sign {
2134+
(res.wrapping_neg(), overflow)
2135+
} else {
2136+
(res, overflow)
2137+
}
2138+
} else {
2139+
if exp == 0 {
2140+
return (1,false);
2141+
}
2142+
let mut base = self;
2143+
let mut acc: Self = 1;
2144+
let mut overflown = false;
2145+
// Scratch space for storing results of overflowing_mul.
2146+
let mut r;
2147+
2148+
while exp > 1 {
2149+
if (exp & 1) == 1 {
2150+
r = acc.overflowing_mul(base);
2151+
acc = r.0;
2152+
overflown |= r.1;
2153+
}
2154+
exp /= 2;
2155+
r = base.overflowing_mul(base);
2156+
base = r.0;
20572157
overflown |= r.1;
20582158
}
2059-
exp /= 2;
2060-
r = base.overflowing_mul(base);
2061-
base = r.0;
2062-
overflown |= r.1;
2063-
}
20642159

2065-
// since exp!=0, finally the exp must be 1.
2066-
// Deal with the final bit of the exponent separately, since
2067-
// squaring the base afterwards is not necessary and may cause a
2068-
// needless overflow.
2069-
r = acc.overflowing_mul(base);
2070-
r.1 |= overflown;
2071-
r
2160+
// since exp!=0, finally the exp must be 1.
2161+
// Deal with the final bit of the exponent separately, since
2162+
// squaring the base afterwards is not necessary and may cause a
2163+
// needless overflow.
2164+
r = acc.overflowing_mul(base);
2165+
r.1 |= overflown;
2166+
r
2167+
}
20722168
}
20732169

20742170
/// Raises self to the power of `exp`, using exponentiation by squaring.
@@ -2086,30 +2182,47 @@ macro_rules! int_impl {
20862182
#[rustc_const_stable(feature = "const_int_pow", since = "1.50.0")]
20872183
#[must_use = "this returns the result of the operation, \
20882184
without modifying the original"]
2185+
#[rustc_allow_const_fn_unstable(is_val_statically_known, const_int_unchecked_arith)]
20892186
#[inline]
20902187
#[rustc_inherit_overflow_checks]
2091-
#[rustc_allow_const_fn_unstable(is_val_statically_known)]
2188+
#[track_caller] // Hides the hackish overflow check for powers of two.
20922189
pub const fn pow(self, mut exp: u32) -> Self {
20932190
// SAFETY: This path has the same behavior as the other.
20942191
if unsafe { intrinsics::is_val_statically_known(self) }
2095-
&& self > 0
2096-
&& (self & (self - 1) == 0)
2192+
&& self.unsigned_abs().is_power_of_two()
20972193
{
2098-
let power_used = match self.checked_ilog2() {
2099-
Some(v) => v,
2100-
// SAFETY: We just checked this is a power of two. and above zero.
2101-
None => unsafe { core::hint::unreachable_unchecked() },
2102-
};
2103-
// So it panics. Have to use `overflowing_mul` to efficiently set the
2104-
// result to 0 if not.
2105-
#[cfg(debug_assertions)]
2106-
{
2107-
_ = power_used * exp;
2194+
if self == 1 { // Avoid divide by zero
2195+
return 1;
2196+
}
2197+
if self == -1 { // Avoid divide by zero
2198+
return if exp & 1 != 0 { -1 } else { 1 };
2199+
}
2200+
// SAFETY: We just checked this is a power of two. and above zero.
2201+
let power_used = unsafe { intrinsics::cttz_nonzero(self.wrapping_abs()) as u32 };
2202+
if exp > Self::BITS / power_used { // Division of constants is free
2203+
#[allow(arithmetic_overflow)]
2204+
return Self::MAX * Self::MAX * 0;
2205+
}
2206+
2207+
// SAFETY: exp <= Self::BITS / power_used
2208+
let res = unsafe { intrinsics::unchecked_shl(
2209+
1 as Self,
2210+
intrinsics::unchecked_mul(power_used, exp) as Self
2211+
)};
2212+
// LLVM doesn't always optimize out the checks
2213+
// at the ir level.
2214+
2215+
let sign = self.is_negative() && exp & 1 != 0;
2216+
#[allow(arithmetic_overflow)]
2217+
if !sign && res == Self::MIN {
2218+
// So it panics.
2219+
_ = Self::MAX * Self::MAX;
2220+
}
2221+
if sign {
2222+
res.wrapping_neg()
2223+
} else {
2224+
res
21082225
}
2109-
let (num_shl, overflowed) = power_used.overflowing_mul(exp);
2110-
let fine = !overflowed
2111-
& (num_shl < (mem::size_of::<Self>() * 8) as u32);
2112-
(1 << num_shl) * fine as Self
21132226
} else {
21142227
if exp == 0 {
21152228
return 1;

0 commit comments

Comments
 (0)