From fe94f8fc14df11ef467d8fe6344e2a238987782a Mon Sep 17 00:00:00 2001 From: Alex Light Date: Thu, 9 Jan 2025 15:43:13 -0800 Subject: [PATCH] Change JIT to store 1-bit values as 8-bits in memory We were creating and storing to 1-bit allocas in the LLVM jit. This could in some situations interact badly with LLVM optimizations causing llvm to create bad code. Avoid this by simply following what clang does and giving bools 8 bits. PiperOrigin-RevId: 713819834 --- xls/interpreter/ir_evaluator_test_base.cc | 65 +++++++++++++++++++++++ xls/jit/ir_builder_visitor.cc | 51 +++++++++++++----- xls/jit/llvm_type_converter.cc | 18 ++++--- 3 files changed, 114 insertions(+), 20 deletions(-) diff --git a/xls/interpreter/ir_evaluator_test_base.cc b/xls/interpreter/ir_evaluator_test_base.cc index be11226c8c..cc8914f682 100644 --- a/xls/interpreter/ir_evaluator_test_base.cc +++ b/xls/interpreter/ir_evaluator_test_base.cc @@ -1794,6 +1794,71 @@ TEST_P(IrEvaluatorTestBase, InterpretOneHotSelect2DArray) { "[[bits[4]:12, bits[4]:14], [bits[4]:9, bits[4]:6]]"))); } +TEST_P(IrEvaluatorTestBase, OneBitSignExtend) { + Package package("my_package"); + + // 1-bit values are handled somewhat specially. + XLS_ASSERT_OK_AND_ASSIGN(Function * function, + ParseAndGetFunction(&package, R"( +fn __sample__main(x1: bits[4] id=1) -> (bits[24], bits[20][1]) { + or_reduce.328: bits[1] = or_reduce(x1, id=328) + not.335: bits[1] = not(or_reduce.328, id=335) + x16: bits[24] = sign_ext(not.335, new_bit_count=24, id=333) + x17: bits[20][1] = literal(value=[0], id=326, pos=[(0,28,36)]) + ret tuple.308: (bits[24], bits[20][1]) = tuple(x16, x17, id=308, pos=[(0,38,8)]) +} + )")); + XLS_ASSERT_OK_AND_ASSIGN( + Value expected0, + ValueBuilder::Tuple({ValueBuilder::Bits(UBits(0xffffff, 24)), + ValueBuilder::SBitsArray({0}, 20)}) + .Build()); + EXPECT_THAT(RunWithNoEvents(function, {Value(UBits(0, 4))}), + IsOkAndHolds(Value(expected0))); + XLS_ASSERT_OK_AND_ASSIGN( + Value expected_other, + ValueBuilder::Tuple({ValueBuilder::Bits(UBits(0x0, 24)), + ValueBuilder::SBitsArray({0}, 20)}) + .Build()); + EXPECT_THAT(RunWithNoEvents(function, {Value(UBits(1, 4))}), + IsOkAndHolds(Value(expected_other))); +} + +TEST_P(IrEvaluatorTestBase, InterpretPrioritySelectBus) { + Package package("my_package"); + // This was found to cause a bus error on jit. + XLS_ASSERT_OK_AND_ASSIGN(Function * function, + ParseAndGetFunction(&package, R"( +fn __sample__main(x1: bits[4] id=1) -> (bits[24], bits[20][1]) { + literal.299: (bits[1], bits[24]) = literal(value=(0, 0), id=299, pos=[(0,24,45)]) + literal.170: (bits[1], bits[24]) = literal(value=(1, 16777215), id=170, pos=[(0,26,17)]) + literal.141: bits[20][7] = literal(value=[0, 0, 0, 0, 0, 0, 0], id=141) + + priority_sel.71: (bits[1], bits[24]) = priority_sel(x1, cases=[literal.299, literal.299, literal.299, literal.299], default=literal.170, id=71) + + x15: bits[1] = tuple_index(priority_sel.71, index=0, id=72, pos=[(0,22,13)]) + x16: bits[24] = tuple_index(priority_sel.71, index=1, id=73, pos=[(0,22,18)]) + x17: bits[20][1] = array_slice(literal.141, x15, width=1, id=77, pos=[(0,28,36)]) + + ret tuple.308: (bits[24], bits[20][1]) = tuple(x16, x17, id=308, pos=[(0,38,8)]) +} + )")); + XLS_ASSERT_OK_AND_ASSIGN( + Value expected0, + ValueBuilder::Tuple({ValueBuilder::Bits(UBits(0xffffff, 24)), + ValueBuilder::SBitsArray({0}, 20)}) + .Build()); + EXPECT_THAT(RunWithNoEvents(function, {Value(UBits(0, 4))}), + IsOkAndHolds(Value(expected0))); + XLS_ASSERT_OK_AND_ASSIGN( + Value expected_other, + ValueBuilder::Tuple({ValueBuilder::Bits(UBits(0x0, 24)), + ValueBuilder::SBitsArray({0}, 20)}) + .Build()); + EXPECT_THAT(RunWithNoEvents(function, {Value(UBits(1, 4))}), + IsOkAndHolds(Value(expected_other))); +} + TEST_P(IrEvaluatorTestBase, InterpretPrioritySelect) { Package package("my_package"); XLS_ASSERT_OK_AND_ASSIGN(Function * function, diff --git a/xls/jit/ir_builder_visitor.cc b/xls/jit/ir_builder_visitor.cc index 28503b5315..7cc1415ad5 100644 --- a/xls/jit/ir_builder_visitor.cc +++ b/xls/jit/ir_builder_visitor.cc @@ -684,6 +684,16 @@ absl::StatusOr CreateLlvmFunction( .getCallee()); } +// Get the truthiness of the given value (0 is falsy, all other values are +// truty). +llvm::Value* Truthiness(llvm::Value* value, llvm::IRBuilder<>& builder) { + CHECK(value->getType()->isIntegerTy()); + return builder.CreateICmpNE( + value, llvm::ConstantInt::get(value->getType(), 0), + value->hasName() ? absl::StrFormat("%s_truthiness", value->getName()) + : ""); +} + // Abstraction gathering together the necessary context for emitting the LLVM IR // for a given node. This data structure decouples IR generation for the // top-level function from the IR generation of each node. This enables, for @@ -1081,8 +1091,19 @@ void NodeIrContext::FinalizeWithValue( std::optional return_type) { llvm::IRBuilder<>* b = exit_builder.has_value() ? exit_builder.value() : &entry_builder(); - result = type_converter().ClearPaddingBits( - result, return_type.value_or(node()->GetType()), *b); + // Special case 0 & 1 bit values to automatically extend them since using them + // as a boolean within ssa registers is a common pattern. + if (result->getType()->isIntegerTy(1)) { + CHECK(node()->GetType()->IsBits() && + node()->GetType()->GetFlatBitCount() == 1) + << node(); + CHECK_EQ(type_converter().GetLlvmBitCount(node()->GetType()->AsBitsOrDie()), + 8); + result = b->CreateZExt(result, b->getInt8Ty()); + } else { + result = type_converter().ClearPaddingBits( + result, return_type.value_or(node()->GetType()), *b); + } if (GetOutputPtrs().empty()) { b->CreateRet(b->getFalse()); return; @@ -1396,7 +1417,9 @@ absl::Status IrBuilderVisitor::HandleRegisterWrite(RegisterWrite* write) { llvm::IRBuilder<> current_step_builder(current_step); llvm::IRBuilder<> reset_selected_builder(*reset_selected); XLS_ASSIGN_OR_RETURN(int64_t op_idx, write->reset_operand_number()); - auto reset_state = node_context.LoadOperand(op_idx, ¤t_step_builder); + auto reset_state = + Truthiness(node_context.LoadOperand(op_idx, ¤t_step_builder), + current_step_builder); if (write->GetRegister()->reset()->active_low) { // current_step_builder.CreateCondBr(reset_state, *reset_selected, // no_reset_selected); @@ -1424,8 +1447,10 @@ absl::Status IrBuilderVisitor::HandleRegisterWrite(RegisterWrite* write) { llvm::BasicBlock::Create(ctx(), "load_enabled", function); llvm::IRBuilder<> no_load_enable_builder(*no_load_enable_selected); llvm::IRBuilder<> current_step_builder(current_step); - auto load_enable_state = node_context.LoadOperand( - write->load_enable_operand_number().value(), ¤t_step_builder); + auto load_enable_state = Truthiness( + node_context.LoadOperand(write->load_enable_operand_number().value(), + ¤t_step_builder), + current_step_builder); // the original value is at operand_count+1 XLS_ASSIGN_OR_RETURN( no_load_enable_value, @@ -1548,8 +1573,9 @@ absl::Status IrBuilderVisitor::HandleAssert(Assert* assert_op) { fail_builder.CreateBr(after_block); - b.CreateCondBr(node_context.LoadOperand(Assert::kConditionOperand), ok_block, - fail_block); + b.CreateCondBr( + Truthiness(node_context.LoadOperand(Assert::kConditionOperand), b), + ok_block, fail_block); auto after_builder = std::make_unique>(after_block); llvm::Value* token = type_converter()->GetToken(); @@ -1594,7 +1620,7 @@ absl::Status IrBuilderVisitor::HandleTrace(Trace* trace_op) { /*include_wrapper_args=*/true)); llvm::IRBuilder<>& b = node_context.entry_builder(); - llvm::Value* condition = node_context.LoadOperand(1); + llvm::Value* condition = Truthiness(node_context.LoadOperand(1), b); llvm::Value* events_ptr = node_context.GetInterpreterEventsArg(); llvm::Value* jit_runtime_ptr = node_context.GetJitRuntimeArg(); @@ -2341,7 +2367,7 @@ absl::Status IrBuilderVisitor::HandleGate(Gate* gate) { XLS_ASSIGN_OR_RETURN(NodeIrContext node_context, NewNodeIrContext(gate, {"condition", "data"})); llvm::IRBuilder<>& b = node_context.entry_builder(); - llvm::Value* condition = node_context.LoadOperand(0); + llvm::Value* condition = Truthiness(node_context.LoadOperand(0), b); llvm::Value* data = node_context.LoadOperand(1); // TODO(meheff): 2022/09/09 Replace with a if/then/else block which does a @@ -2660,7 +2686,7 @@ absl::Status IrBuilderVisitor::HandleNext(Next* next) { } // If the predicate is true, emulate the `next_value` node's effects. - llvm::Value* predicate = node_context.LoadOperand(2); + llvm::Value* predicate = Truthiness(node_context.LoadOperand(2), b); LlvmIfThen if_then = CreateIfThen(predicate, b, next->GetName()); LlvmMemcpy(node_context.GetOutputPtr(0), value_ptr, @@ -3353,7 +3379,8 @@ absl::Status IrBuilderVisitor::HandleReceive(Receive* recv) { LlvmTypeConverter::ZeroOfType(data_type), data_buffer); if (recv->predicate().has_value()) { - llvm::Value* predicate = node_context.LoadOperand(1); + llvm::Value* predicate = + Truthiness(node_context.LoadOperand(1), node_context.entry_builder()); // First, declare the join block (so the case blocks can refer to it). llvm::BasicBlock* join_block = @@ -3459,7 +3486,7 @@ absl::Status IrBuilderVisitor::HandleSend(Send* send) { jit_context_.GetOrAllocateQueueIndex(send->channel_name()); if (send->predicate().has_value()) { - llvm::Value* predicate = node_context.LoadOperand(2); + llvm::Value* predicate = Truthiness(node_context.LoadOperand(2), b); // First, declare the join block (so the case blocks can refer to it). llvm::BasicBlock* join_block = diff --git a/xls/jit/llvm_type_converter.cc b/xls/jit/llvm_type_converter.cc index 94d3e7c2e9..32cc2e56a9 100644 --- a/xls/jit/llvm_type_converter.cc +++ b/xls/jit/llvm_type_converter.cc @@ -48,12 +48,9 @@ LlvmTypeConverter::LlvmTypeConverter(llvm::LLVMContext* context, : context_(*context), data_layout_(data_layout) {} int64_t LlvmTypeConverter::GetLlvmBitCount(int64_t xls_bit_count) const { - // LLVM does not accept 0-bit types, and we want to be able to JIT-compile - // unoptimized IR, so for the time being we make a dummy 1-bit value. - // See https://github.com/google/xls/issues/76 - if (xls_bit_count <= 1) { - return 1; - } + // LLVM does not accept 0-bit types and < 8 bit types often have issues, and + // we want to be able to JIT-compile unoptimized IR, so for the time being we + // make a dummy 8-bit value. See https://github.com/google/xls/issues/76 if (xls_bit_count <= 8) { return 8; } @@ -213,9 +210,14 @@ llvm::Value* LlvmTypeConverter::AsSignedValue( std::optional dest_type) const { CHECK(xls_type->IsBits()); int64_t xls_bit_count = xls_type->AsBitsOrDie()->bit_count(); + int64_t llvm_bit_count = GetLlvmBitCount(xls_bit_count); llvm::Value* signed_value; - if (xls_bit_count <= 1) { + if (llvm_bit_count == xls_bit_count || xls_bit_count == 0) { signed_value = value; + } else if (xls_bit_count == 1) { + // Just for this one case we don't need to do a shift. + signed_value = builder.CreateICmpNE( + value, llvm::ConstantInt::get(value->getType(), 0)); } else { llvm::Value* sign_bit = builder.CreateTrunc( builder.CreateLShr( @@ -225,7 +227,7 @@ llvm::Value* LlvmTypeConverter::AsSignedValue( sign_bit, builder.CreateOr(InvertedPaddingMask(xls_type, builder), value), value); } - return dest_type.has_value() + return dest_type.has_value() && dest_type.value() != signed_value->getType() ? builder.CreateSExt(signed_value, dest_type.value()) : signed_value; }