forked from OSchip/llvm-project
[libc][math] Implement full multiplication and quick_mul_hi for UInt class.
Implement full multiplication `UInt<A> * UInt<B> -> UInt<A + B>` and `quick_mul_hi`, which returns the higher half of the product `UInt<A> * UInt<A>`. These two functions will be used by the dyadic floating point class. Reviewed By: sivachandra Differential Revision: https://reviews.llvm.org/D138541
This commit is contained in:
parent
0c2b7fa869
commit
b80f535879
|
@ -191,6 +191,78 @@ template <size_t Bits> struct UInt {
|
|||
}
|
||||
}
|
||||
|
||||
// Return the full product of *this and `other`, widened to Bits + OtherBits
// bits. The product is built one output word at a time: for output word i,
// every 64 x 64 -> 128-bit partial product val[j] * other.val[i - j] is summed
// into a 128-bit accumulator.
|
||||
template <size_t OtherBits>
|
||||
constexpr UInt<Bits + OtherBits> ful_mul(const UInt<OtherBits> &other) const {
|
||||
UInt<Bits + OtherBits> result(0);
|
||||
// 128-bit accumulator for the output word currently being computed.
UInt<128> partial_sum(0);
|
||||
// Sum of the values returned by partial_sum.add() for the current output
// word (presumably the carry-out of each 128-bit addition -- confirm against
// add()).
uint64_t carry = 0;
|
||||
constexpr size_t OtherWordCount = UInt<OtherBits>::WordCount;
|
||||
// Handle every output word except the topmost one, which is written after
// the loop.
for (size_t i = 0; i <= WordCount + OtherWordCount - 2; ++i) {
|
||||
// Clamp j so that both val[j] and other.val[i - j] stay within their
// respective word arrays.
const size_t lower_idx = i < OtherWordCount ? 0 : i - OtherWordCount + 1;
|
||||
const size_t upper_idx = i < WordCount ? i : WordCount - 1;
|
||||
for (size_t j = lower_idx; j <= upper_idx; ++j) {
|
||||
NumberPair<uint64_t> prod = full_mul(val[j], other.val[i - j]);
|
||||
UInt<128> tmp({prod.lo, prod.hi});
|
||||
carry += partial_sum.add(tmp);
|
||||
}
|
||||
// Emit the low accumulator word as output word i, then shift the
// accumulator down by one word: the old high word becomes the new low word
// and the collected carry becomes the new high word.
result.val[i] = partial_sum.val[0];
|
||||
partial_sum.val[0] = partial_sum.val[1];
|
||||
partial_sum.val[1] = carry;
|
||||
carry = 0;
|
||||
}
|
||||
// The low accumulator word left over after the last shift is stored as the
// most significant word of the result.
result.val[WordCount + OtherWordCount - 1] = partial_sum.val[0];
|
||||
return result;
|
||||
}
|
||||
|
||||
// Fast hi part of the full product. The normal product `operator*` returns
|
||||
// `Bits` least significant bits of the full product, while this function will
|
||||
// approximate `Bits` most significant bits of the full product with errors
|
||||
// bounded by:
|
||||
// 0 <= (a.ful_mul(b) >> Bits) - a.quick_mul_hi(b) <= WordCount - 1.
|
||||
//
|
||||
// An example usage of this is to quickly (but less accurately) compute the
|
||||
// product of (normalized) mantissas of floating point numbers:
|
||||
// (mant_1, mant_2) -> quick_mul_hi -> normalize leading bit
|
||||
// is much more efficient than:
|
||||
// (mant_1, mant_2) -> ful_mul -> normalize leading bit
|
||||
// -> convert back to same Bits width by shifting/rounding,
|
||||
// especially for higher precisions.
|
||||
//
|
||||
// Performance summary:
|
||||
// Number of 64-bit x 64-bit -> 128-bit multiplications performed.
|
||||
// Bits WordCount ful_mul quick_mul_hi Error bound
|
||||
// 128 2 4 3 1
|
||||
// 192 3 9 6 2
|
||||
// 256 4 16 10 3
|
||||
// 512 8 64 36 7
|
||||
constexpr UInt<Bits> quick_mul_hi(const UInt<Bits> &other) const {
|
||||
UInt<Bits> result(0);
|
||||
// 128-bit accumulator for the product word currently being computed.
UInt<128> partial_sum(0);
|
||||
// Sum of the values returned by partial_sum.add() for the current product
// word (presumably the carry-out of each 128-bit addition -- confirm against
// add()).
uint64_t carry = 0;
|
||||
// First round of accumulation: sum the partial products that land on word
|
||||
// WordCount - 1 of the full product. Lower words are never computed, and the
// low word of this sum is dropped below without being stored; only its high
// word and carry feed into the result.
for (size_t i = 0; i < WordCount; ++i) {
|
||||
NumberPair<uint64_t> prod =
|
||||
full_mul(val[i], other.val[WordCount - 1 - i]);
|
||||
UInt<128> tmp({prod.lo, prod.hi});
|
||||
carry += partial_sum.add(tmp);
|
||||
}
|
||||
// Remaining words WordCount .. 2 * WordCount - 2 of the full product.
for (size_t i = WordCount; i < 2 * WordCount - 1; ++i) {
|
||||
// Shift the accumulator down by one word before adding this word's
// partial products; the previous low word is overwritten here.
partial_sum.val[0] = partial_sum.val[1];
|
||||
partial_sum.val[1] = carry;
|
||||
carry = 0;
|
||||
// j starts at i - WordCount + 1 so that other.val[i - j] stays in range.
for (size_t j = i - WordCount + 1; j < WordCount; ++j) {
|
||||
NumberPair<uint64_t> prod = full_mul(val[j], other.val[i - j]);
|
||||
UInt<128> tmp({prod.lo, prod.hi});
|
||||
carry += partial_sum.add(tmp);
|
||||
}
|
||||
// Word i of the full product is word i - WordCount of the high half.
result.val[i - WordCount] = partial_sum.val[0];
|
||||
}
|
||||
// The high accumulator word is stored as the top word of the result.
result.val[WordCount - 1] = partial_sum.val[1];
|
||||
return result;
|
||||
}
|
||||
|
||||
// pow_n takes a power and sets this to its starting value raised to that power. Zero
|
||||
// to the zeroth power returns 1.
|
||||
constexpr void pow_n(uint64_t power) {
|
||||
|
|
|
@ -64,11 +64,11 @@ add_libc_unittest(
|
|||
)
|
||||
|
||||
add_libc_unittest(
|
||||
uint128_test
|
||||
uint_test
|
||||
SUITE
|
||||
libc_support_unittests
|
||||
SRCS
|
||||
uint128_test.cpp
|
||||
uint_test.cpp
|
||||
DEPENDS
|
||||
libc.src.__support.uint
|
||||
libc.src.__support.CPP.optional
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===-- Unittests for the 128 bit integer class ---------------------------===//
|
||||
//===-- Unittests for the UInt integer class ------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -17,15 +17,18 @@
|
|||
using LL_UInt128 = __llvm_libc::cpp::UInt<128>;
|
||||
using LL_UInt192 = __llvm_libc::cpp::UInt<192>;
|
||||
using LL_UInt256 = __llvm_libc::cpp::UInt<256>;
|
||||
using LL_UInt320 = __llvm_libc::cpp::UInt<320>;
|
||||
using LL_UInt512 = __llvm_libc::cpp::UInt<512>;
|
||||
using LL_UInt1024 = __llvm_libc::cpp::UInt<1024>;
|
||||
|
||||
TEST(LlvmLibcUInt128ClassTest, BasicInit) {
|
||||
TEST(LlvmLibcUIntClassTest, BasicInit) {
|
||||
LL_UInt128 empty;
|
||||
LL_UInt128 half_val(12345);
|
||||
LL_UInt128 full_val({12345, 67890});
|
||||
ASSERT_TRUE(half_val != full_val);
|
||||
}
|
||||
|
||||
TEST(LlvmLibcUInt128ClassTest, AdditionTests) {
|
||||
TEST(LlvmLibcUIntClassTest, AdditionTests) {
|
||||
LL_UInt128 val1(12345);
|
||||
LL_UInt128 val2(54321);
|
||||
LL_UInt128 result1(66666);
|
||||
|
@ -65,7 +68,7 @@ TEST(LlvmLibcUInt128ClassTest, AdditionTests) {
|
|||
EXPECT_EQ(val9 + val10, val10 + val9);
|
||||
}
|
||||
|
||||
TEST(LlvmLibcUInt128ClassTest, SubtractionTests) {
|
||||
TEST(LlvmLibcUIntClassTest, SubtractionTests) {
|
||||
LL_UInt128 val1(12345);
|
||||
LL_UInt128 val2(54321);
|
||||
LL_UInt128 result1({0xffffffffffff5c08, 0xffffffffffffffff});
|
||||
|
@ -94,7 +97,7 @@ TEST(LlvmLibcUInt128ClassTest, SubtractionTests) {
|
|||
EXPECT_EQ(val6, val5 + result6);
|
||||
}
|
||||
|
||||
TEST(LlvmLibcUInt128ClassTest, MultiplicationTests) {
|
||||
TEST(LlvmLibcUIntClassTest, MultiplicationTests) {
|
||||
LL_UInt128 val1({5, 0});
|
||||
LL_UInt128 val2({10, 0});
|
||||
LL_UInt128 result1({50, 0});
|
||||
|
@ -154,7 +157,7 @@ TEST(LlvmLibcUInt128ClassTest, MultiplicationTests) {
|
|||
EXPECT_EQ((val13 * val14), (val14 * val13));
|
||||
}
|
||||
|
||||
TEST(LlvmLibcUInt128ClassTest, DivisionTests) {
|
||||
TEST(LlvmLibcUIntClassTest, DivisionTests) {
|
||||
LL_UInt128 val1({10, 0});
|
||||
LL_UInt128 val2({5, 0});
|
||||
LL_UInt128 result1({2, 0});
|
||||
|
@ -201,7 +204,7 @@ TEST(LlvmLibcUInt128ClassTest, DivisionTests) {
|
|||
EXPECT_FALSE(val13.div(val14).has_value());
|
||||
}
|
||||
|
||||
TEST(LlvmLibcUInt128ClassTest, ModuloTests) {
|
||||
TEST(LlvmLibcUIntClassTest, ModuloTests) {
|
||||
LL_UInt128 val1({10, 0});
|
||||
LL_UInt128 val2({5, 0});
|
||||
LL_UInt128 result1({0, 0});
|
||||
|
@ -248,7 +251,7 @@ TEST(LlvmLibcUInt128ClassTest, ModuloTests) {
|
|||
EXPECT_EQ((val17 % val18), result9);
|
||||
}
|
||||
|
||||
TEST(LlvmLibcUInt128ClassTest, PowerTests) {
|
||||
TEST(LlvmLibcUIntClassTest, PowerTests) {
|
||||
LL_UInt128 val1({10, 0});
|
||||
val1.pow_n(30);
|
||||
LL_UInt128 result1({5076944270305263616, 54210108624}); // (10 ^ 30)
|
||||
|
@ -299,7 +302,7 @@ TEST(LlvmLibcUInt128ClassTest, PowerTests) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST(LlvmLibcUInt128ClassTest, ShiftLeftTests) {
|
||||
TEST(LlvmLibcUIntClassTest, ShiftLeftTests) {
|
||||
LL_UInt128 val1(0x0123456789abcdef);
|
||||
LL_UInt128 result1(0x123456789abcdef0);
|
||||
EXPECT_EQ((val1 << 4), result1);
|
||||
|
@ -325,7 +328,7 @@ TEST(LlvmLibcUInt128ClassTest, ShiftLeftTests) {
|
|||
EXPECT_EQ((val2 << 256), result6);
|
||||
}
|
||||
|
||||
TEST(LlvmLibcUInt128ClassTest, ShiftRightTests) {
|
||||
TEST(LlvmLibcUIntClassTest, ShiftRightTests) {
|
||||
LL_UInt128 val1(0x0123456789abcdef);
|
||||
LL_UInt128 result1(0x00123456789abcde);
|
||||
EXPECT_EQ((val1 >> 4), result1);
|
||||
|
@ -351,7 +354,7 @@ TEST(LlvmLibcUInt128ClassTest, ShiftRightTests) {
|
|||
EXPECT_EQ((val2 >> 256), result6);
|
||||
}
|
||||
|
||||
TEST(LlvmLibcUInt128ClassTest, AndTests) {
|
||||
TEST(LlvmLibcUIntClassTest, AndTests) {
|
||||
LL_UInt128 base({0xffff00000000ffff, 0xffffffff00000000});
|
||||
LL_UInt128 val128({0xf0f0f0f00f0f0f0f, 0xff00ff0000ff00ff});
|
||||
uint64_t val64 = 0xf0f0f0f00f0f0f0f;
|
||||
|
@ -364,7 +367,7 @@ TEST(LlvmLibcUInt128ClassTest, AndTests) {
|
|||
EXPECT_EQ((base & val32), result32);
|
||||
}
|
||||
|
||||
TEST(LlvmLibcUInt128ClassTest, OrTests) {
|
||||
TEST(LlvmLibcUIntClassTest, OrTests) {
|
||||
LL_UInt128 base({0xffff00000000ffff, 0xffffffff00000000});
|
||||
LL_UInt128 val128({0xf0f0f0f00f0f0f0f, 0xff00ff0000ff00ff});
|
||||
uint64_t val64 = 0xf0f0f0f00f0f0f0f;
|
||||
|
@ -377,7 +380,7 @@ TEST(LlvmLibcUInt128ClassTest, OrTests) {
|
|||
EXPECT_EQ((base | val32), result32);
|
||||
}
|
||||
|
||||
TEST(LlvmLibcUInt128ClassTest, CompoundAssignments) {
|
||||
TEST(LlvmLibcUIntClassTest, CompoundAssignments) {
|
||||
LL_UInt128 x({0xffff00000000ffff, 0xffffffff00000000});
|
||||
LL_UInt128 b({0xf0f0f0f00f0f0f0f, 0xff00ff0000ff00ff});
|
||||
|
||||
|
@ -419,7 +422,7 @@ TEST(LlvmLibcUInt128ClassTest, CompoundAssignments) {
|
|||
EXPECT_EQ(a, mul_result);
|
||||
}
|
||||
|
||||
TEST(LlvmLibcUInt128ClassTest, UnaryPredecrement) {
|
||||
TEST(LlvmLibcUIntClassTest, UnaryPredecrement) {
|
||||
LL_UInt128 a = LL_UInt128({0x1111111111111111, 0x1111111111111111});
|
||||
++a;
|
||||
EXPECT_EQ(a, LL_UInt128({0x1111111111111112, 0x1111111111111111}));
|
||||
|
@ -433,7 +436,7 @@ TEST(LlvmLibcUInt128ClassTest, UnaryPredecrement) {
|
|||
EXPECT_EQ(a, LL_UInt128({0x0, 0x0}));
|
||||
}
|
||||
|
||||
TEST(LlvmLibcUInt128ClassTest, EqualsTests) {
|
||||
TEST(LlvmLibcUIntClassTest, EqualsTests) {
|
||||
LL_UInt128 a1({0xffffffff00000000, 0xffff00000000ffff});
|
||||
LL_UInt128 a2({0xffffffff00000000, 0xffff00000000ffff});
|
||||
LL_UInt128 b({0xff00ff0000ff00ff, 0xf0f0f0f00f0f0f0f});
|
||||
|
@ -449,7 +452,7 @@ TEST(LlvmLibcUInt128ClassTest, EqualsTests) {
|
|||
ASSERT_TRUE(a_lower != a_upper);
|
||||
}
|
||||
|
||||
TEST(LlvmLibcUInt128ClassTest, ComparisonTests) {
|
||||
TEST(LlvmLibcUIntClassTest, ComparisonTests) {
|
||||
LL_UInt128 a({0xffffffff00000000, 0xffff00000000ffff});
|
||||
LL_UInt128 b({0xff00ff0000ff00ff, 0xf0f0f0f00f0f0f0f});
|
||||
EXPECT_GT(a, b);
|
||||
|
@ -467,3 +470,40 @@ TEST(LlvmLibcUInt128ClassTest, ComparisonTests) {
|
|||
EXPECT_LE(a, a);
|
||||
EXPECT_GE(a, a);
|
||||
}
|
||||
|
||||
// Checks ful_mul and quick_mul_hi against hard-coded expected values.
TEST(LlvmLibcUIntClassTest, FullMulTests) {
|
||||
// a has every bit set in both 64-bit words.
LL_UInt128 a({0xffffffffffffffffULL, 0xffffffffffffffffULL});
|
||||
LL_UInt128 b({0xfedcba9876543210ULL, 0xfefdfcfbfaf9f8f7ULL});
|
||||
// Expected 256-bit value of a.ful_mul(b).
LL_UInt256 r({0x0123456789abcdf0ULL, 0x0102030405060708ULL,
|
||||
0xfedcba987654320fULL, 0xfefdfcfbfaf9f8f7ULL});
|
||||
// Expected quick_mul_hi result. Its low word is one less than the third word
// of r, so the approximation is off by one from the upper half of r here.
LL_UInt128 r_hi({0xfedcba987654320eULL, 0xfefdfcfbfaf9f8f7ULL});
|
||||
|
||||
EXPECT_EQ(a.ful_mul(b), r);
|
||||
EXPECT_EQ(a.quick_mul_hi(b), r_hi);
|
||||
|
||||
// Mixed-width case: a 128-bit operand times a 192-bit operand compared
// against a 320-bit expected value.
LL_UInt192 c(
|
||||
{0x7766554433221101ULL, 0xffeeddccbbaa9988ULL, 0x1f2f3f4f5f6f7f8fULL});
|
||||
LL_UInt320 rr({0x8899aabbccddeeffULL, 0x0011223344556677ULL,
|
||||
0x583715f4d3b29171ULL, 0xffeeddccbbaa9988ULL,
|
||||
0x1f2f3f4f5f6f7f8fULL});
|
||||
|
||||
EXPECT_EQ(a.ful_mul(c), rr);
|
||||
// The result must not depend on which operand is the receiver.
EXPECT_EQ(a.ful_mul(c), c.ful_mul(a));
|
||||
}
|
||||
|
||||
// Checks quick_mul_hi for the LL_UInt<Bits> value with all bits set,
// multiplied by itself. `trunc` is a.ful_mul(a) shifted right by Bits and
// cast back to LL_UInt<Bits>; `hi` is the quick_mul_hi result.
// trunc.sub(hi) is expected to return 0 (presumably meaning the subtraction
// did not underflow, i.e. hi <= trunc -- confirm against sub()), and trunc
// converted to uint64_t after that call must be at most Error.
#define TEST_QUICK_MUL_HI(Bits, Error) \
|
||||
do { \
|
||||
LL_UInt##Bits a = ~LL_UInt##Bits(0); \
|
||||
LL_UInt##Bits hi = a.quick_mul_hi(a); \
|
||||
LL_UInt##Bits trunc = static_cast<LL_UInt##Bits>(a.ful_mul(a) >> Bits); \
|
||||
uint64_t overflow = trunc.sub(hi); \
|
||||
EXPECT_EQ(overflow, uint64_t(0)); \
|
||||
EXPECT_LE(uint64_t(trunc), uint64_t(Error)); \
|
||||
} while (0)
|
||||
|
||||
// Runs TEST_QUICK_MUL_HI for widths of 2, 3, 4 and 8 64-bit words. Each Error
// argument is one less than that word count.
TEST(LlvmLibcUIntClassTest, QuickMulHiTests) {
|
||||
TEST_QUICK_MUL_HI(128, 1);
|
||||
TEST_QUICK_MUL_HI(192, 2);
|
||||
TEST_QUICK_MUL_HI(256, 3);
|
||||
TEST_QUICK_MUL_HI(512, 7);
|
||||
}
|
|
@ -288,6 +288,11 @@ template bool test<__llvm_libc::cpp::UInt<256>>(
|
|||
__llvm_libc::cpp::UInt<256> RHS, const char *LHSStr, const char *RHSStr,
|
||||
const char *File, unsigned long Line);
|
||||
|
||||
// Explicit instantiation of test<> for UInt<320>, alongside the UInt<256>
// instantiation above (presumably needed by the unit-test comparison macros
// applied to UInt<320> values -- confirm against the macro definitions).
template bool test<__llvm_libc::cpp::UInt<320>>(
|
||||
RunContext *Ctx, TestCondition Cond, __llvm_libc::cpp::UInt<320> LHS,
|
||||
__llvm_libc::cpp::UInt<320> RHS, const char *LHSStr, const char *RHSStr,
|
||||
const char *File, unsigned long Line);
|
||||
|
||||
template bool test<__llvm_libc::cpp::string_view>(
|
||||
RunContext *Ctx, TestCondition Cond, __llvm_libc::cpp::string_view LHS,
|
||||
__llvm_libc::cpp::string_view RHS, const char *LHSStr, const char *RHSStr,
|
||||
|
|
Loading…
Reference in New Issue