Skip to content

Commit

Permalink
[XLS] Allow conditional specialization to recognize basic implications
Browse files Browse the repository at this point in the history
By recognizing certain special cases (equality, NOT, etc.), we can lift some conditions to apply to earlier nodes, letting conditional specialization infer more about the context in which each node is used.

PiperOrigin-RevId: 710811273
  • Loading branch information
ericastor authored and copybara-github committed Dec 30, 2024
1 parent 47bdc59 commit d41fd86
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 31 deletions.
1 change: 1 addition & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1764,6 +1764,7 @@ cc_library(
"//xls/common:module_initializer",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/data_structures:leaf_type_tree",
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:bits_ops",
Expand Down
179 changes: 148 additions & 31 deletions xls/passes/conditional_specialization_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "xls/common/module_initializer.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/data_structures/leaf_type_tree.h"
#include "xls/ir/bits.h"
#include "xls/ir/bits_ops.h"
#include "xls/ir/function_base.h"
Expand Down Expand Up @@ -187,6 +188,110 @@ class ConditionSet {
CHECK_LE(conditions_.size(), kMaxConditions);
}

void AddImpliedConditions(const Condition& condition,
QueryEngine& query_engine) {
AddCondition(condition);

if (condition.node->op() == Op::kNot &&
!ternary_ops::AllUnknown(condition.value) &&
!condition.node->operand(0)->Is<Literal>()) {
Node* operand = condition.node->operand(0);

VLOG(4) << "Lifting a known negated value: not(" << operand->GetName()
<< ") == " << xls::ToString(condition.value);

TernaryVector negated = condition.value;
for (int64_t i = 0; i < negated.size(); ++i) {
if (negated[i] == TernaryValue::kKnownOne) {
negated[i] = TernaryValue::kKnownZero;
} else if (negated[i] == TernaryValue::kKnownZero) {
negated[i] = TernaryValue::kKnownOne;
}
}
AddCondition(Condition{.node = operand, .value = negated});
}

if (condition.node->op() == Op::kOr &&
absl::c_any_of(condition.value, [](TernaryValue v) {
return v == TernaryValue::kKnownZero;
})) {
VLOG(4) << "Lifting known bits through an OR; or("
<< absl::StrJoin(condition.node->operands(), ", ",
[](std::string* out, Node* node) {
absl::StrAppend(out, node->GetName());
})
<< ") == " << xls::ToString(condition.value);
TernaryVector lifted = condition.value;
for (int64_t i = 0; i < lifted.size(); ++i) {
if (lifted[i] == TernaryValue::kKnownOne) {
lifted[i] = TernaryValue::kUnknown;
}
}
for (Node* operand : condition.node->operands()) {
if (operand->Is<Literal>()) {
continue;
}
AddImpliedConditions(Condition{.node = operand, .value = lifted},
query_engine);
}
}

if (condition.node->op() == Op::kAnd &&
absl::c_any_of(condition.value, [](TernaryValue v) {
return v == TernaryValue::kKnownOne;
})) {
VLOG(4) << "Lifting known bits through an AND; and("
<< absl::StrJoin(condition.node->operands(), ", ",
[](std::string* out, Node* node) {
absl::StrAppend(out, node->GetName());
})
<< ") == " << xls::ToString(condition.value);
TernaryVector lifted = condition.value;
for (int64_t i = 0; i < lifted.size(); ++i) {
if (lifted[i] == TernaryValue::kKnownZero) {
lifted[i] = TernaryValue::kUnknown;
}
}
for (Node* operand : condition.node->operands()) {
if (operand->Is<Literal>()) {
continue;
}
AddImpliedConditions(Condition{.node = operand, .value = lifted},
query_engine);
}
}

if ((condition.node->op() == Op::kEq &&
ternary_ops::IsKnownOne(condition.value)) ||
(condition.node->op() == Op::kNe &&
ternary_ops::IsKnownZero(condition.value))) {
Node* lhs = condition.node->operand(0);
Node* rhs = condition.node->operand(1);

VLOG(4) << "Converting a known equality to direct conditions: "
<< lhs->GetName() << " == " << rhs->GetName();

if (std::optional<SharedLeafTypeTree<TernaryVector>> lhs_ternary =
query_engine.GetTernary(lhs);
!rhs->Is<Literal>() && rhs->GetType()->IsBits() &&
lhs_ternary.has_value() &&
!ternary_ops::AllUnknown(lhs_ternary->Get({}))) {
AddImpliedConditions(
Condition{.node = rhs, .value = lhs_ternary->Get({})},
query_engine);
}
if (std::optional<SharedLeafTypeTree<TernaryVector>> rhs_ternary =
query_engine.GetTernary(rhs);
!lhs->Is<Literal>() && lhs->GetType()->IsBits() &&
rhs_ternary.has_value() &&
!ternary_ops::AllUnknown(rhs_ternary->Get({}))) {
AddImpliedConditions(
Condition{.node = lhs, .value = rhs_ternary->Get({})},
query_engine);
}
}
}

absl::Span<const Condition> conditions() const { return conditions_; }

std::string ToString() const {
Expand Down Expand Up @@ -571,11 +676,13 @@ absl::StatusOr<bool> ConditionalSpecializationPass::RunOnFunctionBaseInternal(
ConditionSet edge_set = set;
// If this case is selected, we know the selector is exactly
// `case_no`.
edge_set.AddCondition(Condition{
.node = select->selector(),
.value = ternary_ops::BitsToTernary(
UBits(case_no, select->selector()->BitCountOrDie())),
});
edge_set.AddImpliedConditions(
Condition{
.node = select->selector(),
.value = ternary_ops::BitsToTernary(
UBits(case_no, select->selector()->BitCountOrDie())),
},
query_engine);
condition_map.SetEdgeConditionSet(node, case_no + 1,
std::move(edge_set));
}
Expand All @@ -591,10 +698,12 @@ absl::StatusOr<bool> ConditionalSpecializationPass::RunOnFunctionBaseInternal(
TernaryVector selector_value(select->selector()->BitCountOrDie(),
TernaryValue::kUnknown);
selector_value[case_no] = TernaryValue::kKnownOne;
edge_set.AddCondition(Condition{
.node = select->selector(),
.value = selector_value,
});
edge_set.AddImpliedConditions(
Condition{
.node = select->selector(),
.value = selector_value,
},
query_engine);
condition_map.SetEdgeConditionSet(node, case_no + 1,
std::move(edge_set));
}
Expand All @@ -612,22 +721,26 @@ absl::StatusOr<bool> ConditionalSpecializationPass::RunOnFunctionBaseInternal(
known_bits.SetRange(0, case_no + 1);
Bits known_bits_values =
Bits::PowerOfTwo(case_no, select->selector()->BitCountOrDie());
edge_set.AddCondition(Condition{
.node = select->selector(),
.value =
ternary_ops::FromKnownBits(known_bits, known_bits_values),
});
edge_set.AddImpliedConditions(
Condition{
.node = select->selector(),
.value =
ternary_ops::FromKnownBits(known_bits, known_bits_values),
},
query_engine);
condition_map.SetEdgeConditionSet(node, case_no + 1,
std::move(edge_set));
}
ConditionSet edge_set = set;
// If the default value is selected, we know all the bits of the
// selector are zero.
edge_set.AddCondition(Condition{
.node = select->selector(),
.value = TernaryVector(select->selector()->BitCountOrDie(),
TernaryValue::kKnownZero),
});
edge_set.AddImpliedConditions(
Condition{
.node = select->selector(),
.value = TernaryVector(select->selector()->BitCountOrDie(),
TernaryValue::kKnownZero),
},
query_engine);
condition_map.SetEdgeConditionSet(node, select->cases().size() + 1,
std::move(edge_set));
}
Expand All @@ -649,14 +762,16 @@ absl::StatusOr<bool> ConditionalSpecializationPass::RunOnFunctionBaseInternal(

// ArrayUpdate is a no-op if any index is out of range; as such, it only
// cares about the update value if all indices are in range.
edge_set.AddCondition(Condition{
.node = index,
.value =
TernaryVector(index->BitCountOrDie(), TernaryValue::kUnknown),
.range = IntervalSet::Of(
{Interval::RightOpen(UBits(0, index->BitCountOrDie()),
UBits(array_type->AsArrayOrDie()->size(),
index->BitCountOrDie()))})});
edge_set.AddImpliedConditions(
Condition{
.node = index,
.value = TernaryVector(index->BitCountOrDie(),
TernaryValue::kUnknown),
.range = IntervalSet::Of({Interval::RightOpen(
UBits(0, index->BitCountOrDie()),
UBits(array_size, index->BitCountOrDie()))}),
},
query_engine);

array_type = array_type->AsArrayOrDie()->element_type();
}
Expand All @@ -680,8 +795,9 @@ absl::StatusOr<bool> ConditionalSpecializationPass::RunOnFunctionBaseInternal(
node->GetName(), predicate->GetName(), send->data()->GetName());

ConditionSet edge_set = set;
edge_set.AddCondition(
Condition{.node = predicate, .value = {TernaryValue::kKnownOne}});
edge_set.AddImpliedConditions(
Condition{.node = predicate, .value = {TernaryValue::kKnownOne}},
query_engine);
condition_map.SetEdgeConditionSet(node, Send::kDataOperand,
std::move(edge_set));
}
Expand All @@ -703,8 +819,9 @@ absl::StatusOr<bool> ConditionalSpecializationPass::RunOnFunctionBaseInternal(
node->GetName(), predicate->GetName(), next->value()->GetName());

ConditionSet edge_set = set;
edge_set.AddCondition(
Condition{.node = predicate, .value = {TernaryValue::kKnownOne}});
edge_set.AddImpliedConditions(
Condition{.node = predicate, .value = {TernaryValue::kKnownOne}},
query_engine);
condition_map.SetEdgeConditionSet(node, Next::kValueOperand,
std::move(edge_set));
}
Expand Down
80 changes: 80 additions & 0 deletions xls/passes/conditional_specialization_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1072,5 +1072,85 @@ TEST_F(ConditionalSpecializationPassTest, NextValueChange) {
m::StateRead("value2"), m::Eq())));
}

TEST_F(ConditionalSpecializationPassTest, ImpliedConditionThroughNot) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* u1 = p->GetBitsType(1);
BValue a = fb.Param("a", u1);
BValue b = fb.Param("b", u1);
BValue s = fb.Not(a);
BValue result = fb.Select(s, {a, b});
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(result));

solvers::z3::ScopedVerifyEquivalence sve{f};
EXPECT_THAT(Run(f, /*use_bdd=*/false), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(),
m::Select(m::Not(m::Param("a")), {m::Literal(1), m::Param("b")}));
}

TEST_F(ConditionalSpecializationPassTest, ImpliedConditionThroughEq) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* u1 = p->GetBitsType(1);
BValue a = fb.Param("a", u1);
BValue b = fb.Param("b", u1);
BValue s = fb.Eq(b, fb.Literal(UBits(1, 1)));
BValue result = fb.Select(s, {a, b});
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(result));

solvers::z3::ScopedVerifyEquivalence sve{f};
EXPECT_THAT(Run(f, /*use_bdd=*/false), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(), m::Select(m::Eq(m::Param("b"), m::Literal(1)),
{m::Param("a"), m::Literal(1)}));
}

TEST_F(ConditionalSpecializationPassTest, ImpliedConditionThroughNe) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* u1 = p->GetBitsType(1);
BValue a = fb.Param("a", u1);
BValue b = fb.Param("b", u1);
BValue s = fb.Ne(a, fb.Literal(UBits(1, 1)));
BValue result = fb.Select(s, {a, b});
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(result));

solvers::z3::ScopedVerifyEquivalence sve{f};
EXPECT_THAT(Run(f, /*use_bdd=*/false), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(), m::Select(m::Ne(m::Param("a"), m::Literal(1)),
{m::Literal(1), m::Param("b")}));
}

TEST_F(ConditionalSpecializationPassTest, ImpliedConditionThroughOr) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* u1 = p->GetBitsType(1);
BValue a = fb.Param("a", u1);
BValue b = fb.Param("b", u1);
BValue s = fb.Or(a, b);
BValue result = fb.Select(s, {a, b});
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(result));

solvers::z3::ScopedVerifyEquivalence sve{f};
EXPECT_THAT(Run(f, /*use_bdd=*/false), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(), m::Select(m::Or(m::Param("a"), m::Param("b")),
{m::Literal(0), m::Param("b")}));
}

TEST_F(ConditionalSpecializationPassTest, ImpliedConditionThroughAnd) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* u1 = p->GetBitsType(1);
BValue a = fb.Param("a", u1);
BValue b = fb.Param("b", u1);
BValue s = fb.And(a, b);
BValue result = fb.Select(s, {a, b});
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(result));

solvers::z3::ScopedVerifyEquivalence sve{f};
EXPECT_THAT(Run(f, /*use_bdd=*/false), IsOkAndHolds(true));
EXPECT_THAT(f->return_value(), m::Select(m::And(m::Param("a"), m::Param("b")),
{m::Param("a"), m::Literal(1)}));
}

} // namespace
} // namespace xls

0 comments on commit d41fd86

Please sign in to comment.