diff --git a/xls/ir/BUILD b/xls/ir/BUILD index 35915903cd..498a815f9b 100644 --- a/xls/ir/BUILD +++ b/xls/ir/BUILD @@ -208,6 +208,7 @@ cc_test( deps = [ ":bits", ":bits_ops", + ":bits_test_utils", ":interval", ":interval_test_utils", "//xls/common:xls_gunit", @@ -249,6 +250,7 @@ cc_test( ":interval", ":interval_ops", ":interval_set", + ":interval_test_utils", ":ir", ":ir_test_base", ":ternary", diff --git a/xls/ir/interval.cc b/xls/ir/interval.cc index 7e8ca955fb..c9c78e7be7 100644 --- a/xls/ir/interval.cc +++ b/xls/ir/interval.cc @@ -292,12 +292,13 @@ bool Interval::IsMaximal() const { bool Interval::IsTrueWhenAndWith(const Bits& value) const { CHECK_EQ(value.bit_count(), BitCount()); - int64_t right_index = std::min(LowerBound().CountTrailingZeros(), - UpperBound().CountTrailingZeros()); - int64_t left_index = BitCount() - UpperBound().CountLeadingZeros(); - Bits interval_mask_value(BitCount()); - interval_mask_value.SetRange(right_index, left_index); - return !bits_ops::And(interval_mask_value, value).IsZero(); + BitsRope interval_mask_value(BitCount()); + Bits common_prefix = + bits_ops::LongestCommonPrefixMSB({LowerBound(), UpperBound()}); + interval_mask_value.push_back( + Bits::AllOnes(BitCount() - common_prefix.bit_count())); + interval_mask_value.push_back(common_prefix); + return !bits_ops::And(interval_mask_value.Build(), value).IsZero(); } bool Interval::Covers(const Bits& point) const { diff --git a/xls/ir/interval.h b/xls/ir/interval.h index 3c71619704..8e074bfb6f 100644 --- a/xls/ir/interval.h +++ b/xls/ir/interval.h @@ -222,6 +222,11 @@ class Interval { interval.upper_bound_); } + template + friend void AbslStringify(Sink& sink, const Interval& interval) { + absl::Format(&sink, "%s", interval.ToString()); + } + private: void EnsureValid() const { CHECK(is_valid_); } diff --git a/xls/ir/interval_ops.cc b/xls/ir/interval_ops.cc index afa7a6db94..846746ed3e 100644 --- a/xls/ir/interval_ops.cc +++ b/xls/ir/interval_ops.cc @@ -21,7 +21,6 @@ #include #include #include -#include #include "absl/algorithm/container.h" #include "absl/log/check.h" @@ -30,7 +29,6 @@ #include "xls/ir/bits_ops.h" #include "xls/ir/interval.h" #include "xls/ir/interval_set.h" -#include "xls/ir/lsb_or_msb.h" #include "xls/ir/node.h" #include "xls/ir/ternary.h" #include "xls/passes/ternary_evaluator.h" @@ -119,6 +117,130 @@ IntervalSet FromTernary(TernarySpan tern, int64_t max_interval_bits) { return is; } +bool CoversTernary(const Interval& interval, TernarySpan ternary) { + if (interval.BitCount() != ternary.size()) { + return false; + } + if (ternary_ops::IsFullyKnown(ternary)) { + return interval.Covers(ternary_ops::ToKnownBitsValues(ternary)); + } + if (interval.IsPrecise()) { + return ternary_ops::IsCompatible(ternary, interval.LowerBound()); + } + + Bits lcp = bits_ops::LongestCommonPrefixMSB( + {interval.LowerBound(), interval.UpperBound()}); + + // We know the next bit of the bounds of `interval` differs, and the interval + // is proper iff the upper bound has a 1 there. + const bool proper = interval.UpperBound().GetFromMsb(lcp.bit_count()); + + TernarySpan prefix = ternary.subspan(ternary.size() - lcp.bit_count()); + + // If the interval is proper, then the interval only contains things with + // this least-common prefix. + if (proper && !ternary_ops::IsCompatible(prefix, lcp)) { + return false; + } + + // If the interval is improper, then it contains everything that doesn't share + // this prefix. Therefore, unless `prefix` is fully-known and matches the + // least-common prefix, `ternary` can definitely represent something in the + // interval. + if (!proper && !(ternary_ops::IsFullyKnown(prefix) && + ternary_ops::ToKnownBitsValues(prefix) == lcp)) { + return true; + } + + // Take the leading value in `ternary`. + TernaryValue x = ternary[ternary.size() - lcp.bit_count() - 1]; + + // Drop all the bits we've already confirmed match, plus one more. + Bits L = interval.LowerBound().Slice(0, ternary.size() - lcp.bit_count() - 1); + Bits U = interval.UpperBound().Slice(0, ternary.size() - lcp.bit_count() - 1); + TernarySpan t = ternary.subspan(0, ternary.size() - lcp.bit_count() - 1); + + auto could_be_le = [](TernarySpan t, const Bits& L) { + for (int64_t i = t.size() - 1; i >= 0; --i) { + if (L.Get(i)) { + if (t[i] != TernaryValue::kKnownOne) { + // If this bit is zero, it will make t < L. + return true; + } + } else if (t[i] == TernaryValue::kKnownOne) { + // We know t > L. + return false; + } + } + return true; + }; + auto could_be_ge = [](TernarySpan t, const Bits& U) { + for (int64_t i = t.size() - 1; i >= 0; --i) { + if (U.Get(i)) { + if (t[i] == TernaryValue::kKnownZero) { + // We know t < U. + return false; + } + } else if (t[i] != TernaryValue::kKnownZero) { + // If this bit is one, it will make t > L. + return true; + } + } + return true; + }; + + // NOTE: At this point, we want to know: + // + // if improper, whether it's possible to have: + // xt <= 0U || 1L <= xt, which is true iff + // (x == 0 && t <= U) || (x == 1 && L <= t). + // + // if proper, whether it's possible to have: + // 0L <= xt && xt <= 1U, which is true iff + // (x == 1 || L <= t) && (x == 0 || t <= U). + // + // If x is known, then this is easy: + // if x == 0 && proper: check if it's possible to have L <= t. + // if x == 1 && improper: check if it's possible to have L <= t. + // if x == 0 && improper: check if it's possible to have t <= U. + // if x == 1 && proper: check if it's possible to have t <= U. + // In other words: + // if (x == 0) == proper, check if it's possible to have L <= t. + // Otherwise, check if it's possible to have t <= U. + if (ternary_ops::IsKnown(x)) { + if ((x == TernaryValue::kKnownZero) == proper) { + return could_be_ge(t, L); + } + return could_be_le(t, U); + } + + // If x is unknown, then we can choose whichever value we want. Therefore, we + // just need to know: + // if improper, whether it's possible to have... well. + // if we take x == 0, then we just need to check if we can have t <= U. + // if we take x == 1, then we just need to check if we can have L <= t. + // Therefore, we just need to check whether it's possible to have: + // t <= U || L <= t. + // if proper, whether it's possible to have... well. + // If we take x == 1, then we just need to check if we can have t <= U. + // If we take x == 0, then we just need to check if we can have L <= t. + // Therefore, we just need to check whether it's possible to have: + // t <= U || L <= t. + // The conclusion is the same whether the interval is proper or improper, so + // we check this and we're done. + return could_be_le(t, U) || could_be_ge(t, L); +} + +bool CoversTernary(const IntervalSet& intervals, TernarySpan ternary) { + if (intervals.BitCount() != ternary.size()) { + return false; + } + return absl::c_any_of(intervals.Intervals(), + [&ternary](const Interval& interval) { + return CoversTernary(interval, ternary); + }); +} + namespace { enum class Tonicity : bool { Monotone, Antitone }; // What sort of behavior the argument exhibits diff --git a/xls/ir/interval_ops.h b/xls/ir/interval_ops.h index 2b777b6dae..578797f95c 100644 --- a/xls/ir/interval_ops.h +++ b/xls/ir/interval_ops.h @@ -50,6 +50,11 @@ IntervalSet FromTernary(TernarySpan ternary, int64_t max_interval_bits = 4); TernaryVector ExtractTernaryVector(const IntervalSet& intervals, std::optional source = std::nullopt); +// Determine whether the given `intervals` include any element that matches the +// given `ternary` span. +bool CoversTernary(const Interval& interval, TernarySpan ternary); +bool CoversTernary(const IntervalSet& intervals, TernarySpan ternary); + struct KnownBits { Bits known_bits; Bits known_bit_values; @@ -82,6 +87,9 @@ IntervalSet UMul(const IntervalSet& a, const IntervalSet& b, int64_t output_bitwidth); IntervalSet UDiv(const IntervalSet& a, const IntervalSet& b); +// Encode/decode +IntervalSet Decode(const IntervalSet& a, int64_t width); + // Bit ops. IntervalSet Not(const IntervalSet& a); IntervalSet And(const IntervalSet& a, const IntervalSet& b); diff --git a/xls/ir/interval_ops_test.cc b/xls/ir/interval_ops_test.cc index 4e9da56e3d..79decb1070 100644 --- a/xls/ir/interval_ops_test.cc +++ b/xls/ir/interval_ops_test.cc @@ -35,6 +35,7 @@ #include "xls/ir/function_builder.h" #include "xls/ir/interval.h" #include "xls/ir/interval_set.h" +#include "xls/ir/interval_test_utils.h" #include "xls/ir/ir_test_base.h" #include "xls/ir/nodes.h" #include "xls/ir/package.h" @@ -868,5 +869,27 @@ FUZZ_TEST(MinimizeIntervalsTest, MinimizeIntervalsGeneratesSuperset) fuzztest::NonNegative())), fuzztest::InRange(1, 256)); +void CoversTernaryWorksForIntervals(const Interval& interval, + TernarySpan ternary) { + EXPECT_EQ(interval_ops::CoversTernary(interval, ternary), + interval.ForEachElement([&](const Bits& element) { + return ternary == + ternary_ops::Intersection( + ternary_ops::BitsToTernary(element), ternary); + })) + << "interval: " + << absl::StrFormat("[%s, %s]", interval.LowerBound().ToDebugString(), + interval.UpperBound().ToDebugString()) + << ", ternary: " << ToString(ternary); +} +FUZZ_TEST(IntervalOpsFuzzTest, CoversTernaryWorksForIntervals) + .WithDomains(ArbitraryInterval(8), + fuzztest::VectorOf(fuzztest::ElementOf({ + TernaryValue::kKnownZero, + TernaryValue::kKnownOne, + TernaryValue::kUnknown, + })) + .WithSize(8)); + } // namespace } // namespace xls::interval_ops diff --git a/xls/ir/interval_set_test.cc b/xls/ir/interval_set_test.cc index ad75996477..a69172be3a 100644 --- a/xls/ir/interval_set_test.cc +++ b/xls/ir/interval_set_test.cc @@ -272,16 +272,19 @@ TEST(IntervalTest, Size) { } TEST(IntervalTest, IsTrueWhenMaskWith) { - IntervalSet example(3); - example.AddInterval(MakeInterval(0, 0, 3)); - for (int64_t value = 0; value < 8; ++value) { - EXPECT_FALSE(example.IsTrueWhenMaskWith(UBits(value, 3))); + IntervalSet example(4); + example.AddInterval(MakeInterval(0, 0, 4)); + for (int64_t value = 0; value < 16; ++value) { + EXPECT_FALSE(example.IsTrueWhenMaskWith(UBits(value, 4))); + } + example.AddInterval(MakeInterval(2, 4, 4)); + EXPECT_FALSE(example.IsTrueWhenMaskWith(UBits(0, 4))); + for (int64_t value = 1; value < 8; ++value) { + EXPECT_TRUE(example.IsTrueWhenMaskWith(UBits(value, 4))); } - example.AddInterval(MakeInterval(2, 4, 3)); - EXPECT_FALSE(example.IsTrueWhenMaskWith(UBits(0, 3))); - EXPECT_FALSE(example.IsTrueWhenMaskWith(UBits(1, 3))); - for (int64_t value = 2; value < 8; ++value) { - EXPECT_TRUE(example.IsTrueWhenMaskWith(UBits(value, 3))); + EXPECT_FALSE(example.IsTrueWhenMaskWith(UBits(8, 4))); + for (int64_t value = 9; value < 16; ++value) { + EXPECT_TRUE(example.IsTrueWhenMaskWith(UBits(value, 4))); } } diff --git a/xls/ir/interval_test.cc b/xls/ir/interval_test.cc index 95843a5c88..7726d658df 100644 --- a/xls/ir/interval_test.cc +++ b/xls/ir/interval_test.cc @@ -29,6 +29,7 @@ #include "xls/common/status/matchers.h" #include "xls/ir/bits.h" #include "xls/ir/bits_ops.h" +#include "xls/ir/bits_test_utils.h" #include "xls/ir/interval_test_utils.h" using ::testing::ElementsAre; @@ -373,24 +374,15 @@ TEST(IntervalTest, NonZeroStartingValueIsTrueWhenMaskWith) { // Test the IsTrueWhenMaskWith with an interval starting at zero. TEST(IntervalTest, ZeroStartingValueIsTrueWhenMaskWith) { - Interval interval(UBits(0, 3), UBits(4, 3)); + Interval interval(UBits(0, 4), UBits(4, 4)); - EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(0, 3))); - EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(1, 3))); - EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(2, 3))); - EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(3, 3))); - EXPECT_TRUE(interval.IsTrueWhenAndWith(UBits(4, 3))); - EXPECT_TRUE(interval.IsTrueWhenAndWith(UBits(5, 3))); - EXPECT_TRUE(interval.IsTrueWhenAndWith(UBits(6, 3))); - EXPECT_TRUE(interval.IsTrueWhenAndWith(UBits(7, 3))); -} - -// Test the IsTrueWhenMaskWith with an interval that does not overlap. -TEST(IntervalTest, NoOverlappingIntervalIsTrueWhenMaskWith) { - Interval interval(UBits(4, 3), UBits(0, 3)); - - for (int64_t value = 0; value < 8; ++value) { - EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(value, 3))); + EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(0, 4))); + for (int64_t value = 1; value < 7; ++value) { + EXPECT_TRUE(interval.IsTrueWhenAndWith(UBits(value, 4))); + } + EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(8, 4))); + for (int64_t value = 9; value < 16; ++value) { + EXPECT_TRUE(interval.IsTrueWhenAndWith(UBits(value, 4))); } } @@ -405,6 +397,27 @@ TEST(IntervalTest, OverlappingBitsIsTrueWhenMaskWith) { } } +// Test the IsTrueWhenMaskWith with an interval containing overlapping bits, but +// not overlapping for every bit at the ends of the interval. +TEST(IntervalTest, OverlappingBitsIsTrueWhenMaskWith2) { + Interval interval(UBits(2, 3), UBits(6, 3)); + + EXPECT_FALSE(interval.IsTrueWhenAndWith(UBits(0, 3))); + for (int64_t value = 1; value < 8; ++value) { + EXPECT_TRUE(interval.IsTrueWhenAndWith(UBits(value, 3))) + << "didn't match with value: " << value; + } +} + +void IsTrueWhenAndWith(const Interval& interval, const Bits& value) { + EXPECT_EQ(interval.IsTrueWhenAndWith(value), + interval.ForEachElement([&](const Bits& bits) -> bool { + return bits_ops::OrReduce(bits_ops::And(bits, value)).IsAllOnes(); + })); +} +FUZZ_TEST(IntervalFuzzTest, IsTrueWhenAndWith) + .WithDomains(ProperInterval(12), ArbitraryBits(12)); + TEST(IntervalTest, Covers) { Bits thirty_two = Bits::PowerOfTwo(5, 12); Bits sixty_four = Bits::PowerOfTwo(6, 12); diff --git a/xls/ir/ternary.cc b/xls/ir/ternary.cc index 479021aa7f..8f266d6c91 100644 --- a/xls/ir/ternary.cc +++ b/xls/ir/ternary.cc @@ -221,6 +221,22 @@ TernaryVector Intersection(TernarySpan lhs, TernarySpan rhs) { return result; } +bool IsCompatible(TernarySpan pattern, const Bits& bits) { + if (pattern.size() != bits.bit_count()) { + return false; + } + + for (int64_t i = 0; i < pattern.size(); ++i) { + if (pattern[i] == TernaryValue::kUnknown) { + continue; + } + if (bits.Get(i) != (pattern[i] == TernaryValue::kKnownOne)) { + return false; + } + } + return true; +} + void UpdateWithIntersection(TernaryVector& lhs, TernarySpan rhs) { CHECK_EQ(lhs.size(), rhs.size()); diff --git a/xls/ir/ternary.h b/xls/ir/ternary.h index ea19e7fd7c..385a31119c 100644 --- a/xls/ir/ternary.h +++ b/xls/ir/ternary.h @@ -102,6 +102,9 @@ absl::Status UpdateWithUnion(TernaryVector& lhs, TernarySpan rhs); // lengths. TernaryVector Intersection(TernarySpan lhs, TernarySpan rhs); +// Returns true if `bits` is a possible value for `pattern`. +bool IsCompatible(TernarySpan pattern, const Bits& bits); + // Updates `lhs`, turning it into a vector of bits known to have the same value // in both `lhs` and `rhs`. CHECK fails if `lhs` and `rhs` have different // lengths. diff --git a/xls/passes/range_query_engine.cc b/xls/passes/range_query_engine.cc index 053596693b..ad39d05bf4 100644 --- a/xls/passes/range_query_engine.cc +++ b/xls/passes/range_query_engine.cc @@ -770,11 +770,15 @@ absl::Status RangeQueryVisitor::HandlePrioritySel(PrioritySelect* sel) { lhs = IntervalSet::Combine(lhs, rhs); }); } + TernaryVector case_pattern(sel->cases().size(), TernaryValue::kUnknown); for (int64_t i = 0; i < sel->cases().size(); ++i) { // TODO(vmirian): Make implementation more efficient by considering only the // ranges of interest. - if (selector_intervals.IsTrueWhenMaskWith( - bits_ops::ShiftLeftLogical(UBits(1, sel->cases().size()), i))) { + case_pattern[i] = TernaryValue::kKnownOne; + if (i > 0) { + case_pattern[i - 1] = TernaryValue::kKnownZero; + } + if (interval_ops::CoversTernary(selector_intervals, case_pattern)) { leaf_type_tree::SimpleUpdateFrom( result.AsMutableView(), GetIntervalSetTree(sel->cases()[i]).AsView(), [](IntervalSet& lhs, const IntervalSet& rhs) { diff --git a/xls/passes/range_query_engine_test.cc b/xls/passes/range_query_engine_test.cc index 6f25d467f7..96eabc9038 100644 --- a/xls/passes/range_query_engine_test.cc +++ b/xls/passes/range_query_engine_test.cc @@ -911,6 +911,17 @@ TEST_F(RangeQueryEngineTest, PrioritySel) { EXPECT_EQ(IntervalSet::Combine(x_ist.Get({}), y_ist.Get({})), engine.GetIntervalSetTree(expr.node()).Get({})); + engine = RangeQueryEngine(); + engine.SetIntervalSetTree( + selector.node(), + BitsLTT(selector.node(), {Interval::Precise(UBits(4, 3))})); + engine.SetIntervalSetTree(x.node(), x_ist); + engine.SetIntervalSetTree(y.node(), y_ist); + engine.SetIntervalSetTree(z.node(), z_ist); + XLS_ASSERT_OK(engine.Populate(f)); + + EXPECT_EQ(z_ist.Get({}), engine.GetIntervalSetTree(expr.node()).Get({})); + engine = RangeQueryEngine(); engine.SetIntervalSetTree( selector.node(), @@ -926,16 +937,17 @@ TEST_F(RangeQueryEngineTest, PrioritySel) { engine.GetIntervalSetTree(expr.node()).Get({})); // Test case with overlapping bits for selector. + // TODO(epastor): Fix test once this is better supported. engine = RangeQueryEngine(); engine.SetIntervalSetTree( selector.node(), - BitsLTT(selector.node(), {Interval(UBits(2, 3), UBits(6, 3))})); + BitsLTT(selector.node(), {Interval(UBits(5, 3), UBits(7, 3))})); engine.SetIntervalSetTree(x.node(), x_ist); engine.SetIntervalSetTree(y.node(), y_ist); engine.SetIntervalSetTree(z.node(), z_ist); XLS_ASSERT_OK(engine.Populate(f)); - EXPECT_EQ(IntervalSet::Combine(y_ist.Get({}), z_ist.Get({})), + EXPECT_EQ(IntervalSet::Combine(x_ist.Get({}), y_ist.Get({})), engine.GetIntervalSetTree(expr.node()).Get({})); // Test case where default is covered.