diff --git a/xls/interpreter/ir_evaluator_test_base.cc b/xls/interpreter/ir_evaluator_test_base.cc index be11226c8c..cc8914f682 100644 --- a/xls/interpreter/ir_evaluator_test_base.cc +++ b/xls/interpreter/ir_evaluator_test_base.cc @@ -1794,6 +1794,71 @@ TEST_P(IrEvaluatorTestBase, InterpretOneHotSelect2DArray) { "[[bits[4]:12, bits[4]:14], [bits[4]:9, bits[4]:6]]"))); } +TEST_P(IrEvaluatorTestBase, OneBitSignExtend) { + Package package("my_package"); + + // 1-bit values are handled somewhat specially. + XLS_ASSERT_OK_AND_ASSIGN(Function * function, + ParseAndGetFunction(&package, R"( +fn __sample__main(x1: bits[4] id=1) -> (bits[24], bits[20][1]) { + or_reduce.328: bits[1] = or_reduce(x1, id=328) + not.335: bits[1] = not(or_reduce.328, id=335) + x16: bits[24] = sign_ext(not.335, new_bit_count=24, id=333) + x17: bits[20][1] = literal(value=[0], id=326, pos=[(0,28,36)]) + ret tuple.308: (bits[24], bits[20][1]) = tuple(x16, x17, id=308, pos=[(0,38,8)]) +} + )")); + XLS_ASSERT_OK_AND_ASSIGN( + Value expected0, + ValueBuilder::Tuple({ValueBuilder::Bits(UBits(0xffffff, 24)), + ValueBuilder::SBitsArray({0}, 20)}) + .Build()); + EXPECT_THAT(RunWithNoEvents(function, {Value(UBits(0, 4))}), + IsOkAndHolds(Value(expected0))); + XLS_ASSERT_OK_AND_ASSIGN( + Value expected_other, + ValueBuilder::Tuple({ValueBuilder::Bits(UBits(0x0, 24)), + ValueBuilder::SBitsArray({0}, 20)}) + .Build()); + EXPECT_THAT(RunWithNoEvents(function, {Value(UBits(1, 4))}), + IsOkAndHolds(Value(expected_other))); +} + +TEST_P(IrEvaluatorTestBase, InterpretPrioritySelectBus) { + Package package("my_package"); + // This was found to cause a bus error on jit. + XLS_ASSERT_OK_AND_ASSIGN(Function * function, + ParseAndGetFunction(&package, R"( +fn __sample__main(x1: bits[4] id=1) -> (bits[24], bits[20][1]) { + literal.299: (bits[1], bits[24]) = literal(value=(0, 0), id=299, pos=[(0,24,45)]) + literal.170: (bits[1], bits[24]) = literal(value=(1, 16777215), id=170, pos=[(0,26,17)]) + literal.141: bits[20][7] = literal(value=[0, 0, 0, 0, 0, 0, 0], id=141) + + priority_sel.71: (bits[1], bits[24]) = priority_sel(x1, cases=[literal.299, literal.299, literal.299, literal.299], default=literal.170, id=71) + + x15: bits[1] = tuple_index(priority_sel.71, index=0, id=72, pos=[(0,22,13)]) + x16: bits[24] = tuple_index(priority_sel.71, index=1, id=73, pos=[(0,22,18)]) + x17: bits[20][1] = array_slice(literal.141, x15, width=1, id=77, pos=[(0,28,36)]) + + ret tuple.308: (bits[24], bits[20][1]) = tuple(x16, x17, id=308, pos=[(0,38,8)]) +} + )")); + XLS_ASSERT_OK_AND_ASSIGN( + Value expected0, + ValueBuilder::Tuple({ValueBuilder::Bits(UBits(0xffffff, 24)), + ValueBuilder::SBitsArray({0}, 20)}) + .Build()); + EXPECT_THAT(RunWithNoEvents(function, {Value(UBits(0, 4))}), + IsOkAndHolds(Value(expected0))); + XLS_ASSERT_OK_AND_ASSIGN( + Value expected_other, + ValueBuilder::Tuple({ValueBuilder::Bits(UBits(0x0, 24)), + ValueBuilder::SBitsArray({0}, 20)}) + .Build()); + EXPECT_THAT(RunWithNoEvents(function, {Value(UBits(1, 4))}), + IsOkAndHolds(Value(expected_other))); +} + TEST_P(IrEvaluatorTestBase, InterpretPrioritySelect) { Package package("my_package"); XLS_ASSERT_OK_AND_ASSIGN(Function * function, diff --git a/xls/jit/ir_builder_visitor.cc b/xls/jit/ir_builder_visitor.cc index 28503b5315..7cc1415ad5 100644 --- a/xls/jit/ir_builder_visitor.cc +++ b/xls/jit/ir_builder_visitor.cc @@ -684,6 +684,16 @@ absl::StatusOr CreateLlvmFunction( .getCallee()); } +// Get the truthiness of the given value (0 is falsy, all other values are +// truty). +llvm::Value* Truthiness(llvm::Value* value, llvm::IRBuilder<>& builder) { + CHECK(value->getType()->isIntegerTy()); + return builder.CreateICmpNE( + value, llvm::ConstantInt::get(value->getType(), 0), + value->hasName() ? absl::StrFormat("%s_truthiness", value->getName()) + : ""); +} + // Abstraction gathering together the necessary context for emitting the LLVM IR // for a given node. This data structure decouples IR generation for the // top-level function from the IR generation of each node. This enables, for @@ -1081,8 +1091,19 @@ void NodeIrContext::FinalizeWithValue( std::optional return_type) { llvm::IRBuilder<>* b = exit_builder.has_value() ? exit_builder.value() : &entry_builder(); - result = type_converter().ClearPaddingBits( - result, return_type.value_or(node()->GetType()), *b); + // Special case 0 & 1 bit values to automatically extend them since using them + // as a boolean within ssa registers is a common pattern. + if (result->getType()->isIntegerTy(1)) { + CHECK(node()->GetType()->IsBits() && + node()->GetType()->GetFlatBitCount() == 1) + << node(); + CHECK_EQ(type_converter().GetLlvmBitCount(node()->GetType()->AsBitsOrDie()), + 8); + result = b->CreateZExt(result, b->getInt8Ty()); + } else { + result = type_converter().ClearPaddingBits( + result, return_type.value_or(node()->GetType()), *b); + } if (GetOutputPtrs().empty()) { b->CreateRet(b->getFalse()); return; @@ -1396,7 +1417,9 @@ absl::Status IrBuilderVisitor::HandleRegisterWrite(RegisterWrite* write) { llvm::IRBuilder<> current_step_builder(current_step); llvm::IRBuilder<> reset_selected_builder(*reset_selected); XLS_ASSIGN_OR_RETURN(int64_t op_idx, write->reset_operand_number()); - auto reset_state = node_context.LoadOperand(op_idx, ¤t_step_builder); + auto reset_state = + Truthiness(node_context.LoadOperand(op_idx, ¤t_step_builder), + current_step_builder); if (write->GetRegister()->reset()->active_low) { // current_step_builder.CreateCondBr(reset_state, *reset_selected, // no_reset_selected); @@ -1424,8 +1447,10 @@ absl::Status IrBuilderVisitor::HandleRegisterWrite(RegisterWrite* write) { llvm::BasicBlock::Create(ctx(), "load_enabled", function); llvm::IRBuilder<> no_load_enable_builder(*no_load_enable_selected); llvm::IRBuilder<> current_step_builder(current_step); - auto load_enable_state = node_context.LoadOperand( - write->load_enable_operand_number().value(), ¤t_step_builder); + auto load_enable_state = Truthiness( + node_context.LoadOperand(write->load_enable_operand_number().value(), + ¤t_step_builder), + current_step_builder); // the original value is at operand_count+1 XLS_ASSIGN_OR_RETURN( no_load_enable_value, @@ -1548,8 +1573,9 @@ absl::Status IrBuilderVisitor::HandleAssert(Assert* assert_op) { fail_builder.CreateBr(after_block); - b.CreateCondBr(node_context.LoadOperand(Assert::kConditionOperand), ok_block, - fail_block); + b.CreateCondBr( + Truthiness(node_context.LoadOperand(Assert::kConditionOperand), b), + ok_block, fail_block); auto after_builder = std::make_unique>(after_block); llvm::Value* token = type_converter()->GetToken(); @@ -1594,7 +1620,7 @@ absl::Status IrBuilderVisitor::HandleTrace(Trace* trace_op) { /*include_wrapper_args=*/true)); llvm::IRBuilder<>& b = node_context.entry_builder(); - llvm::Value* condition = node_context.LoadOperand(1); + llvm::Value* condition = Truthiness(node_context.LoadOperand(1), b); llvm::Value* events_ptr = node_context.GetInterpreterEventsArg(); llvm::Value* jit_runtime_ptr = node_context.GetJitRuntimeArg(); @@ -2341,7 +2367,7 @@ absl::Status IrBuilderVisitor::HandleGate(Gate* gate) { XLS_ASSIGN_OR_RETURN(NodeIrContext node_context, NewNodeIrContext(gate, {"condition", "data"})); llvm::IRBuilder<>& b = node_context.entry_builder(); - llvm::Value* condition = node_context.LoadOperand(0); + llvm::Value* condition = Truthiness(node_context.LoadOperand(0), b); llvm::Value* data = node_context.LoadOperand(1); // TODO(meheff): 2022/09/09 Replace with a if/then/else block which does a @@ -2660,7 +2686,7 @@ absl::Status IrBuilderVisitor::HandleNext(Next* next) { } // If the predicate is true, emulate the `next_value` node's effects. - llvm::Value* predicate = node_context.LoadOperand(2); + llvm::Value* predicate = Truthiness(node_context.LoadOperand(2), b); LlvmIfThen if_then = CreateIfThen(predicate, b, next->GetName()); LlvmMemcpy(node_context.GetOutputPtr(0), value_ptr, @@ -3353,7 +3379,8 @@ absl::Status IrBuilderVisitor::HandleReceive(Receive* recv) { LlvmTypeConverter::ZeroOfType(data_type), data_buffer); if (recv->predicate().has_value()) { - llvm::Value* predicate = node_context.LoadOperand(1); + llvm::Value* predicate = + Truthiness(node_context.LoadOperand(1), node_context.entry_builder()); // First, declare the join block (so the case blocks can refer to it). llvm::BasicBlock* join_block = @@ -3459,7 +3486,7 @@ absl::Status IrBuilderVisitor::HandleSend(Send* send) { jit_context_.GetOrAllocateQueueIndex(send->channel_name()); if (send->predicate().has_value()) { - llvm::Value* predicate = node_context.LoadOperand(2); + llvm::Value* predicate = Truthiness(node_context.LoadOperand(2), b); // First, declare the join block (so the case blocks can refer to it). llvm::BasicBlock* join_block = diff --git a/xls/jit/llvm_type_converter.cc b/xls/jit/llvm_type_converter.cc index 94d3e7c2e9..32cc2e56a9 100644 --- a/xls/jit/llvm_type_converter.cc +++ b/xls/jit/llvm_type_converter.cc @@ -48,12 +48,9 @@ LlvmTypeConverter::LlvmTypeConverter(llvm::LLVMContext* context, : context_(*context), data_layout_(data_layout) {} int64_t LlvmTypeConverter::GetLlvmBitCount(int64_t xls_bit_count) const { - // LLVM does not accept 0-bit types, and we want to be able to JIT-compile - // unoptimized IR, so for the time being we make a dummy 1-bit value. - // See https://github.com/google/xls/issues/76 - if (xls_bit_count <= 1) { - return 1; - } + // LLVM does not accept 0-bit types and < 8 bit types often have issues, and + // we want to be able to JIT-compile unoptimized IR, so for the time being we + // make a dummy 8-bit value. See https://github.com/google/xls/issues/76 if (xls_bit_count <= 8) { return 8; } @@ -213,9 +210,14 @@ llvm::Value* LlvmTypeConverter::AsSignedValue( std::optional dest_type) const { CHECK(xls_type->IsBits()); int64_t xls_bit_count = xls_type->AsBitsOrDie()->bit_count(); + int64_t llvm_bit_count = GetLlvmBitCount(xls_bit_count); llvm::Value* signed_value; - if (xls_bit_count <= 1) { + if (llvm_bit_count == xls_bit_count || xls_bit_count == 0) { signed_value = value; + } else if (xls_bit_count == 1) { + // Just for this one case we don't need to do a shift. + signed_value = builder.CreateICmpNE( + value, llvm::ConstantInt::get(value->getType(), 0)); } else { llvm::Value* sign_bit = builder.CreateTrunc( builder.CreateLShr( @@ -225,7 +227,7 @@ llvm::Value* LlvmTypeConverter::AsSignedValue( sign_bit, builder.CreateOr(InvertedPaddingMask(xls_type, builder), value), value); } - return dest_type.has_value() + return dest_type.has_value() && dest_type.value() != signed_value->getType() ? builder.CreateSExt(signed_value, dest_type.value()) : signed_value; }