Skip to content

Commit

Permalink
Change JIT to store 1-bit values as 8-bits in memory
Browse files Browse the repository at this point in the history
We were creating and storing to 1-bit allocas in the LLVM jit. This could in some situations interact badly with LLVM optimizations causing llvm to create bad code. Avoid this by simply following what clang does and giving bools 8 bits.

PiperOrigin-RevId: 713819834
  • Loading branch information
allight authored and copybara-github committed Jan 9, 2025
1 parent 2cdb89e commit fe94f8f
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 20 deletions.
65 changes: 65 additions & 0 deletions xls/interpreter/ir_evaluator_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1794,6 +1794,71 @@ TEST_P(IrEvaluatorTestBase, InterpretOneHotSelect2DArray) {
"[[bits[4]:12, bits[4]:14], [bits[4]:9, bits[4]:6]]")));
}

TEST_P(IrEvaluatorTestBase, OneBitSignExtend) {
Package package("my_package");

// 1-bit values are handled somewhat specially.
XLS_ASSERT_OK_AND_ASSIGN(Function * function,
ParseAndGetFunction(&package, R"(
fn __sample__main(x1: bits[4] id=1) -> (bits[24], bits[20][1]) {
or_reduce.328: bits[1] = or_reduce(x1, id=328)
not.335: bits[1] = not(or_reduce.328, id=335)
x16: bits[24] = sign_ext(not.335, new_bit_count=24, id=333)
x17: bits[20][1] = literal(value=[0], id=326, pos=[(0,28,36)])
ret tuple.308: (bits[24], bits[20][1]) = tuple(x16, x17, id=308, pos=[(0,38,8)])
}
)"));
XLS_ASSERT_OK_AND_ASSIGN(
Value expected0,
ValueBuilder::Tuple({ValueBuilder::Bits(UBits(0xffffff, 24)),
ValueBuilder::SBitsArray({0}, 20)})
.Build());
EXPECT_THAT(RunWithNoEvents(function, {Value(UBits(0, 4))}),
IsOkAndHolds(Value(expected0)));
XLS_ASSERT_OK_AND_ASSIGN(
Value expected_other,
ValueBuilder::Tuple({ValueBuilder::Bits(UBits(0x0, 24)),
ValueBuilder::SBitsArray({0}, 20)})
.Build());
EXPECT_THAT(RunWithNoEvents(function, {Value(UBits(1, 4))}),
IsOkAndHolds(Value(expected_other)));
}

TEST_P(IrEvaluatorTestBase, InterpretPrioritySelectBus) {
Package package("my_package");
// This was found to cause a bus error on jit.
XLS_ASSERT_OK_AND_ASSIGN(Function * function,
ParseAndGetFunction(&package, R"(
fn __sample__main(x1: bits[4] id=1) -> (bits[24], bits[20][1]) {
literal.299: (bits[1], bits[24]) = literal(value=(0, 0), id=299, pos=[(0,24,45)])
literal.170: (bits[1], bits[24]) = literal(value=(1, 16777215), id=170, pos=[(0,26,17)])
literal.141: bits[20][7] = literal(value=[0, 0, 0, 0, 0, 0, 0], id=141)
priority_sel.71: (bits[1], bits[24]) = priority_sel(x1, cases=[literal.299, literal.299, literal.299, literal.299], default=literal.170, id=71)
x15: bits[1] = tuple_index(priority_sel.71, index=0, id=72, pos=[(0,22,13)])
x16: bits[24] = tuple_index(priority_sel.71, index=1, id=73, pos=[(0,22,18)])
x17: bits[20][1] = array_slice(literal.141, x15, width=1, id=77, pos=[(0,28,36)])
ret tuple.308: (bits[24], bits[20][1]) = tuple(x16, x17, id=308, pos=[(0,38,8)])
}
)"));
XLS_ASSERT_OK_AND_ASSIGN(
Value expected0,
ValueBuilder::Tuple({ValueBuilder::Bits(UBits(0xffffff, 24)),
ValueBuilder::SBitsArray({0}, 20)})
.Build());
EXPECT_THAT(RunWithNoEvents(function, {Value(UBits(0, 4))}),
IsOkAndHolds(Value(expected0)));
XLS_ASSERT_OK_AND_ASSIGN(
Value expected_other,
ValueBuilder::Tuple({ValueBuilder::Bits(UBits(0x0, 24)),
ValueBuilder::SBitsArray({0}, 20)})
.Build());
EXPECT_THAT(RunWithNoEvents(function, {Value(UBits(1, 4))}),
IsOkAndHolds(Value(expected_other)));
}

TEST_P(IrEvaluatorTestBase, InterpretPrioritySelect) {
Package package("my_package");
XLS_ASSERT_OK_AND_ASSIGN(Function * function,
Expand Down
51 changes: 39 additions & 12 deletions xls/jit/ir_builder_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,16 @@ absl::StatusOr<llvm::Function*> CreateLlvmFunction(
.getCallee());
}

// Get the truthiness of the given value (0 is falsy, all other values are
// truty).
llvm::Value* Truthiness(llvm::Value* value, llvm::IRBuilder<>& builder) {
CHECK(value->getType()->isIntegerTy());
return builder.CreateICmpNE(
value, llvm::ConstantInt::get(value->getType(), 0),
value->hasName() ? absl::StrFormat("%s_truthiness", value->getName())
: "");
}

// Abstraction gathering together the necessary context for emitting the LLVM IR
// for a given node. This data structure decouples IR generation for the
// top-level function from the IR generation of each node. This enables, for
Expand Down Expand Up @@ -1081,8 +1091,19 @@ void NodeIrContext::FinalizeWithValue(
std::optional<Type*> return_type) {
llvm::IRBuilder<>* b =
exit_builder.has_value() ? exit_builder.value() : &entry_builder();
result = type_converter().ClearPaddingBits(
result, return_type.value_or(node()->GetType()), *b);
// Special case 0 & 1 bit values to automatically extend them since using them
// as a boolean within ssa registers is a common pattern.
if (result->getType()->isIntegerTy(1)) {
CHECK(node()->GetType()->IsBits() &&
node()->GetType()->GetFlatBitCount() == 1)
<< node();
CHECK_EQ(type_converter().GetLlvmBitCount(node()->GetType()->AsBitsOrDie()),
8);
result = b->CreateZExt(result, b->getInt8Ty());
} else {
result = type_converter().ClearPaddingBits(
result, return_type.value_or(node()->GetType()), *b);
}
if (GetOutputPtrs().empty()) {
b->CreateRet(b->getFalse());
return;
Expand Down Expand Up @@ -1396,7 +1417,9 @@ absl::Status IrBuilderVisitor::HandleRegisterWrite(RegisterWrite* write) {
llvm::IRBuilder<> current_step_builder(current_step);
llvm::IRBuilder<> reset_selected_builder(*reset_selected);
XLS_ASSIGN_OR_RETURN(int64_t op_idx, write->reset_operand_number());
auto reset_state = node_context.LoadOperand(op_idx, &current_step_builder);
auto reset_state =
Truthiness(node_context.LoadOperand(op_idx, &current_step_builder),
current_step_builder);
if (write->GetRegister()->reset()->active_low) {
// current_step_builder.CreateCondBr(reset_state, *reset_selected,
// no_reset_selected);
Expand Down Expand Up @@ -1424,8 +1447,10 @@ absl::Status IrBuilderVisitor::HandleRegisterWrite(RegisterWrite* write) {
llvm::BasicBlock::Create(ctx(), "load_enabled", function);
llvm::IRBuilder<> no_load_enable_builder(*no_load_enable_selected);
llvm::IRBuilder<> current_step_builder(current_step);
auto load_enable_state = node_context.LoadOperand(
write->load_enable_operand_number().value(), &current_step_builder);
auto load_enable_state = Truthiness(
node_context.LoadOperand(write->load_enable_operand_number().value(),
&current_step_builder),
current_step_builder);
// the original value is at operand_count+1
XLS_ASSIGN_OR_RETURN(
no_load_enable_value,
Expand Down Expand Up @@ -1548,8 +1573,9 @@ absl::Status IrBuilderVisitor::HandleAssert(Assert* assert_op) {

fail_builder.CreateBr(after_block);

b.CreateCondBr(node_context.LoadOperand(Assert::kConditionOperand), ok_block,
fail_block);
b.CreateCondBr(
Truthiness(node_context.LoadOperand(Assert::kConditionOperand), b),
ok_block, fail_block);

auto after_builder = std::make_unique<llvm::IRBuilder<>>(after_block);
llvm::Value* token = type_converter()->GetToken();
Expand Down Expand Up @@ -1594,7 +1620,7 @@ absl::Status IrBuilderVisitor::HandleTrace(Trace* trace_op) {
/*include_wrapper_args=*/true));

llvm::IRBuilder<>& b = node_context.entry_builder();
llvm::Value* condition = node_context.LoadOperand(1);
llvm::Value* condition = Truthiness(node_context.LoadOperand(1), b);
llvm::Value* events_ptr = node_context.GetInterpreterEventsArg();
llvm::Value* jit_runtime_ptr = node_context.GetJitRuntimeArg();

Expand Down Expand Up @@ -2341,7 +2367,7 @@ absl::Status IrBuilderVisitor::HandleGate(Gate* gate) {
XLS_ASSIGN_OR_RETURN(NodeIrContext node_context,
NewNodeIrContext(gate, {"condition", "data"}));
llvm::IRBuilder<>& b = node_context.entry_builder();
llvm::Value* condition = node_context.LoadOperand(0);
llvm::Value* condition = Truthiness(node_context.LoadOperand(0), b);
llvm::Value* data = node_context.LoadOperand(1);

// TODO(meheff): 2022/09/09 Replace with a if/then/else block which does a
Expand Down Expand Up @@ -2660,7 +2686,7 @@ absl::Status IrBuilderVisitor::HandleNext(Next* next) {
}

// If the predicate is true, emulate the `next_value` node's effects.
llvm::Value* predicate = node_context.LoadOperand(2);
llvm::Value* predicate = Truthiness(node_context.LoadOperand(2), b);
LlvmIfThen if_then = CreateIfThen(predicate, b, next->GetName());

LlvmMemcpy(node_context.GetOutputPtr(0), value_ptr,
Expand Down Expand Up @@ -3353,7 +3379,8 @@ absl::Status IrBuilderVisitor::HandleReceive(Receive* recv) {
LlvmTypeConverter::ZeroOfType(data_type), data_buffer);

if (recv->predicate().has_value()) {
llvm::Value* predicate = node_context.LoadOperand(1);
llvm::Value* predicate =
Truthiness(node_context.LoadOperand(1), node_context.entry_builder());

// First, declare the join block (so the case blocks can refer to it).
llvm::BasicBlock* join_block =
Expand Down Expand Up @@ -3459,7 +3486,7 @@ absl::Status IrBuilderVisitor::HandleSend(Send* send) {
jit_context_.GetOrAllocateQueueIndex(send->channel_name());

if (send->predicate().has_value()) {
llvm::Value* predicate = node_context.LoadOperand(2);
llvm::Value* predicate = Truthiness(node_context.LoadOperand(2), b);

// First, declare the join block (so the case blocks can refer to it).
llvm::BasicBlock* join_block =
Expand Down
18 changes: 10 additions & 8 deletions xls/jit/llvm_type_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,9 @@ LlvmTypeConverter::LlvmTypeConverter(llvm::LLVMContext* context,
: context_(*context), data_layout_(data_layout) {}

int64_t LlvmTypeConverter::GetLlvmBitCount(int64_t xls_bit_count) const {
// LLVM does not accept 0-bit types, and we want to be able to JIT-compile
// unoptimized IR, so for the time being we make a dummy 1-bit value.
// See https://github.com/google/xls/issues/76
if (xls_bit_count <= 1) {
return 1;
}
// LLVM does not accept 0-bit types and < 8 bit types often have issues, and
// we want to be able to JIT-compile unoptimized IR, so for the time being we
// make a dummy 8-bit value. See https://github.com/google/xls/issues/76
if (xls_bit_count <= 8) {
return 8;
}
Expand Down Expand Up @@ -213,9 +210,14 @@ llvm::Value* LlvmTypeConverter::AsSignedValue(
std::optional<llvm::Type*> dest_type) const {
CHECK(xls_type->IsBits());
int64_t xls_bit_count = xls_type->AsBitsOrDie()->bit_count();
int64_t llvm_bit_count = GetLlvmBitCount(xls_bit_count);
llvm::Value* signed_value;
if (xls_bit_count <= 1) {
if (llvm_bit_count == xls_bit_count || xls_bit_count == 0) {
signed_value = value;
} else if (xls_bit_count == 1) {
// Just for this one case we don't need to do a shift.
signed_value = builder.CreateICmpNE(
value, llvm::ConstantInt::get(value->getType(), 0));
} else {
llvm::Value* sign_bit = builder.CreateTrunc(
builder.CreateLShr(
Expand All @@ -225,7 +227,7 @@ llvm::Value* LlvmTypeConverter::AsSignedValue(
sign_bit,
builder.CreateOr(InvertedPaddingMask(xls_type, builder), value), value);
}
return dest_type.has_value()
return dest_type.has_value() && dest_type.value() != signed_value->getType()
? builder.CreateSExt(signed_value, dest_type.value())
: signed_value;
}
Expand Down

0 comments on commit fe94f8f

Please sign in to comment.