Skip to content

Commit

Permalink
Support comparison operators in type_system_v2.
Browse files Browse the repository at this point in the history
This change also fixes the handling of expressions used as explicit parametric invocation arguments.

PiperOrigin-RevId: 713898572
  • Loading branch information
richmckeever authored and copybara-github committed Jan 10, 2025
1 parent fe94f8f commit 376406c
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 39 deletions.
9 changes: 1 addition & 8 deletions xls/dslx/type_system_v2/inference_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,7 @@ class InferenceTableImpl : public InferenceTable {
callee.parametric_bindings();
const std::vector<ExprOrType>& explicit_parametrics =
node.explicit_parametrics();
if (explicit_parametrics.size() > bindings.size()) {
return ArgCountMismatchErrorStatus(
node.span(),
absl::Substitute(
"Too many parametric values supplied; limit: $0 given: $1",
callee.parametric_bindings().size(), explicit_parametrics.size()),
file_table_);
}
CHECK(explicit_parametrics.size() <= bindings.size());
absl::flat_hash_map<const InferenceVariable*, InvocationScopedExpr> values;
for (int i = 0; i < bindings.size(); i++) {
const ParametricBinding* binding = bindings[i];
Expand Down
29 changes: 0 additions & 29 deletions xls/dslx/type_system_v2/inference_table_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -364,34 +364,5 @@ TEST_F(InferenceTableTest, ParametricVariableWithUnsupportedAnnotation) {
HasSubstr("Inference variables of type T are not supported")));
}

TEST_F(InferenceTableTest, TooManyParametricsInInvocation) {
ParseAndInitModuleAndTable(R"(
fn foo<N: u32>(a: uN[N]) -> uN[N] { a }
fn bar() {
foo<u32:4, u32:5>(u4:1);
}
)");

XLS_ASSERT_OK_AND_ASSIGN(const Function* foo,
module_->GetMemberOrError<Function>("foo"));
ASSERT_EQ(foo->parametric_bindings().size(), 1);
ASSERT_EQ(foo->params().size(), 1);
XLS_ASSERT_OK(
table_->DefineParametricVariable(*foo->parametric_bindings()[0]));
for (const Param* param : foo->params()) {
XLS_ASSERT_OK(table_->SetTypeAnnotation(param, param->type_annotation()));
}
XLS_ASSERT_OK_AND_ASSIGN(const Function* bar,
module_->GetMemberOrError<Function>("bar"));
ASSERT_EQ(bar->body()->statements().size(), 1);
const Invocation* invocation = down_cast<const Invocation*>(
ToAstNode(bar->body()->statements().at(0)->wrapped()));
EXPECT_THAT(
table_->AddParametricInvocation(*invocation, *foo, bar,
/*caller_invocation=*/std::nullopt),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("Too many parametric values supplied")));
}

} // namespace
} // namespace xls::dslx
6 changes: 6 additions & 0 deletions xls/dslx/type_system_v2/inference_table_to_type_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ class ConversionOrderVisitor : public AstNodeVisitorWithDefault {

absl::Status HandleParametricBindingExprsInternal(
const ParametricInvocation* parametric_invocation) {
for (ExprOrType explicit_parametric :
parametric_invocation->node().explicit_parametrics()) {
if (std::holds_alternative<Expr*>(explicit_parametric)) {
XLS_RETURN_IF_ERROR(std::get<Expr*>(explicit_parametric)->Accept(this));
}
}
parametric_invocation_stack_.push(parametric_invocation);
for (const ParametricBinding* binding :
parametric_invocation->callee().parametric_bindings()) {
Expand Down
64 changes: 62 additions & 2 deletions xls/dslx/type_system_v2/typecheck_module_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,27 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault {
// should have a type that was set when its parent was visited.
const NameRef* type_variable = *table_.GetTypeVariable(node);
if (GetBinopSameTypeKinds().contains(node->binop_kind())) {
// In the example `const C = a + b;`, the `ConstantDef` establishes a type
// variable that is just propagated down to `a` and `b` here, meaning that
// `a`, `b`, and the result must ultimately be the same type.
XLS_RETURN_IF_ERROR(table_.SetTypeVariable(node->lhs(), type_variable));
XLS_RETURN_IF_ERROR(table_.SetTypeVariable(node->rhs(), type_variable));
} else if (GetBinopComparisonKinds().contains(node->binop_kind())) {
// In a comparison example, like `const C = a > b;`, the `>` establishes a
// new type variable for `a` and `b` (meaning the two of them must be the
// same type), and attaches a bool annotation to the overall expression,
// which will then be assumed by the type variable for the `ConstantDef`.
XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(
node, CreateBoolAnnotation(module_, node->span())));
XLS_ASSIGN_OR_RETURN(
const NameRef* operand_variable,
table_.DefineInternalVariable(
InferenceVariableKind::kType, const_cast<Binop*>(node),
GenerateInternalTypeVariableName(node)));
XLS_RETURN_IF_ERROR(
table_.SetTypeVariable(node->lhs(), operand_variable));
XLS_RETURN_IF_ERROR(
table_.SetTypeVariable(node->rhs(), operand_variable));
} else {
return absl::UnimplementedError(
absl::StrCat("Type inference version 2 is a work in progress and "
Expand Down Expand Up @@ -527,6 +546,9 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault {
VLOG(5) << "HandleParametricInvocation: " << node->ToString()
<< ", fn: " << fn.identifier();
CHECK(fn.IsParametric());
const std::vector<ParametricBinding*>& bindings = fn.parametric_bindings();
const std::vector<ExprOrType>& explicit_parametrics =
node->explicit_parametrics();
const std::optional<const Function*> caller = GetCurrentFunction();
current_function_stack_.push(&fn);
const bool function_processed_before =
Expand All @@ -537,22 +559,54 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault {
// The bindings need to be defined in the table up front, because the rest
// of the header may depend on them, and we can't even create a
// `ParametricInvocation` without them being registered.
for (const ParametricBinding* binding : fn.parametric_bindings()) {
for (const ParametricBinding* binding : bindings) {
XLS_RETURN_IF_ERROR(binding->Accept(this));
}
}

if (explicit_parametrics.size() > bindings.size()) {
return ArgCountMismatchErrorStatus(
node->span(),
absl::Substitute(
"Too many parametric values supplied; limit: $0 given: $1",
bindings.size(), explicit_parametrics.size()),
file_table_);
}

// Type-check the subtrees for any explicit parametric values. Note that the
// addition of the invocation above will have verified that a valid number
// of explicit parametrics was passed in.
for (int i = 0; i < explicit_parametrics.size(); i++) {
ExprOrType explicit_parametric = explicit_parametrics[i];
const ParametricBinding* formal_parametric = bindings[i];
if (std::holds_alternative<Expr*>(explicit_parametric)) {
const Expr* parametric_value_expr =
std::get<Expr*>(explicit_parametric);
XLS_ASSIGN_OR_RETURN(
const NameRef* type_variable,
table_.DefineInternalVariable(
InferenceVariableKind::kType,
const_cast<Expr*>(parametric_value_expr),
GenerateInternalTypeVariableName(parametric_value_expr)));
XLS_RETURN_IF_ERROR(
table_.SetTypeVariable(parametric_value_expr, type_variable));
XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(
parametric_value_expr, formal_parametric->type_annotation()));
XLS_RETURN_IF_ERROR(parametric_value_expr->Accept(this));
}
}

// Register the parametric invocation in the table, regardless of whether
// we have seen the function before.
XLS_ASSIGN_OR_RETURN(
const ParametricInvocation* parametric_invocation,
table_.AddParametricInvocation(*node, fn, caller,
GetCurrentParametricInvocation()));
parametric_invocation_stack_.push(parametric_invocation);

// We don't need to process the entire function multiple times, if it's
// used in multiple contexts. Only the invocation nodes in it need to be
// dealt with multiple times.
parametric_invocation_stack_.push(parametric_invocation);
if (function_processed_before) {
VLOG(5) << "Reprocessing outbound invocations in this context from: "
<< fn.identifier();
Expand Down Expand Up @@ -618,6 +672,12 @@ class PopulateInferenceTableVisitor : public AstNodeVisitorWithDefault {
return absl::StrCat("internal_type_actual_member_", formal_member->name(),
"_at_", actual_member->span().ToString(file_table_));
}
// Variant for operands of a binary operator.
std::string GenerateInternalTypeVariableName(const Binop* binop) {
return absl::StrCat("internal_type_operand_",
BinopKindToString(binop->binop_kind()), "_at_",
binop->span().ToString(file_table_));
}

// Propagates the type from the def for `ref`, to `ref` itself in the
// inference table. This may result in a `TypeAnnotation` being added to the
Expand Down
125 changes: 125 additions & 0 deletions xls/dslx/type_system_v2/typecheck_module_v2_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,123 @@ const Z = X % 1 % Y % 2;
HasSubstr("node: `const Z = X % 1 % Y % 2;`, type: uN[32]")));
}

TEST(TypecheckV2Test, GlobalBoolConstantEqualsComparisonOfUntypedLiterals) {
EXPECT_THAT("const Z = 4 > 1;", TopNodeHasType("uN[1]"));
}

TEST(TypecheckV2Test, GlobalBoolConstantEqualsComparisonOfTypedLiterals) {
EXPECT_THAT("const Z = u32:4 < u32:1;", TopNodeHasType("uN[1]"));
}

TEST(TypecheckV2Test, GlobalBoolConstantEqualsComparisonOfLiteralsWithOneType) {
EXPECT_THAT("const Z = 4 < s32:1;", TopNodeHasType("uN[1]"));
}

TEST(TypecheckV2Test, GlobalBoolConstantEqualsComparisonOfVariables) {
XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"(
const X = u32:3;
const Y = u32:4;
const Z = Y >= X;
)"));
XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string,
TypeInfoToString(result.tm));
EXPECT_THAT(type_info_string, HasSubstr("node: `Z`, type: uN[1]"));
}

TEST(TypecheckV2Test, GlobalBoolConstantEqualsComparisonOfExprs) {
XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"(
const X = s24:3;
const Y = s24:4;
const Z = (Y + X * 2) == (1 - Y);
)"));
XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string,
TypeInfoToString(result.tm));
EXPECT_THAT(type_info_string,
AllOf(HasSubstr("node: `(Y + X * 2)`, type: sN[24]"),
HasSubstr("node: `(1 - Y)`, type: sN[24]"),
HasSubstr("node: `Z`, type: uN[1]")));
}

TEST(TypecheckV2Test, ComparisonAsFunctionArgument) {
XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"(
fn foo(a: bool) -> bool { a }
const Y = foo(1 != 2);
)"));
XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string,
TypeInfoToString(result.tm));
EXPECT_THAT(type_info_string, AllOf(HasSubstr("node: `1 != 2`, type: uN[1]"),
HasSubstr("node: `1`, type: uN[2]"),
HasSubstr("node: `2`, type: uN[2]")));
}

TEST(TypecheckV2Test, ComparisonOfReturnValues) {
XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"(
fn foo(a: u32) -> u32 { a }
const Y = foo(1) > foo(2);
)"));
XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string,
TypeInfoToString(result.tm));
EXPECT_THAT(type_info_string,
AllOf(HasSubstr("node: `Y`, type: uN[1]"),
HasSubstr("node: `foo(1)`, type: uN[32]"),
HasSubstr("node: `foo(2)`, type: uN[32]")));
}

TEST(TypecheckV2Test, ComparisonAsParametricArgument) {
XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"(
fn foo<S: bool>(a: xN[S][32]) -> xN[S][32] { a }
const Y = foo<{2 > 1}>(s32:5);
)"));
XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string,
TypeInfoToString(result.tm));
EXPECT_THAT(type_info_string, AllOf(HasSubstr("node: `Y`, type: sN[32]"),
HasSubstr("node: `2`, type: uN[2]"),
HasSubstr("node: `1`, type: uN[2]")));
}

TEST(TypecheckV2Test, ComparisonAsParametricArgumentWithConflictFails) {
EXPECT_THAT(R"(
fn foo<S: bool>(a: xN[S][32]) -> xN[S][32] { a }
const Y = foo<{2 > 1}>(u32:5);
)",
TypecheckFails(HasSignednessMismatch("uN[32]", "sN[32]")));
}

TEST(TypecheckV2Test, ComparisonAndSumAsParametricArguments) {
XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"(
const X = u32:1;
fn foo<S: bool, N: u32>(a: xN[S][N]) -> xN[S][N] { a }
const Y = foo<{X == 1}, {X + 3}>(s4:3);
)"));
XLS_ASSERT_OK_AND_ASSIGN(std::string type_info_string,
TypeInfoToString(result.tm));
EXPECT_THAT(type_info_string, HasSubstr("node: `Y`, type: sN[4]"));
}

TEST(TypecheckV2Test, ComparisonAndSumParametricArgumentsWithConflictFails) {
EXPECT_THAT(R"(
const X = u32:1;
fn foo<S: bool, N: u32>(a: xN[S][N]) -> xN[S][N] { a }
const Y = foo<{X == 1}, {X + 4}>(s4:3);
)",
TypecheckFails(HasSizeMismatch("sN[4]", "sN[5]")));
}

TEST(TypecheckV2Test, ExplicitParametricExpressionMismatchingBindingTypeFails) {
EXPECT_THAT(R"(
const X = u32:1;
fn foo<N: u32>(a: uN[N]) -> uN[N] { a }
const Y = foo<{X == 1}>(s4:3);
)",
TypecheckFails(HasSizeMismatch("bool", "u32")));
}

TEST(TypecheckV2Test,
GlobalBoolConstantEqualsComparisonOfConflictingTypedLiteralsFails) {
EXPECT_THAT("const Z = u32:4 >= s32:1;",
TypecheckFails(HasSignednessMismatch("s32", "u32")));
}

TEST(TypecheckV2Test,
GlobalIntegerConstantEqualsAnotherConstantWithAnnotationOnName) {
XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"(
Expand Down Expand Up @@ -1067,6 +1184,14 @@ const Y = X(1);
TypecheckFails(HasSubstr("callee `X` is not a function")));
}

TEST(TypecheckV2Test, ParametricFunctionCallWithTooManyParametricsFails) {
EXPECT_THAT(R"(
fn foo<N: u32>() -> u32 { N }
const X = foo<3, 4>();
)",
TypecheckFails(HasSubstr("Too many parametric values supplied")));
}

TEST(TypecheckV2Test, ParametricFunctionReturningIntegerParameter) {
XLS_ASSERT_OK_AND_ASSIGN(TypecheckResult result, TypecheckV2(R"(
fn foo<N: u32>() -> u32 { N }
Expand Down

0 comments on commit 376406c

Please sign in to comment.