Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Triton] Verify DesciptorLoad/StoreOp block types #5566

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 77 additions & 75 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1240,58 +1240,60 @@ def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable


def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [MemoryEffects<[MemRead<GlobalMemory>]>]> {
let summary = "Load from descriptor";
let description = [{
This operation will be lowered to Nvidia TMA load operation on targets supporting it.
`desc` is a tensor descriptor object.
The destination tensor type and shape must match the descriptor otherwise the result is undefined.
let summary = "Load from descriptor";
let description = [{
This operation will be lowered to Nvidia TMA load operation on targets supporting it.
`desc` is a tensor descriptor object.
The destination tensor type and shape must match the descriptor otherwise the result is undefined.

This is an escape hatch and is only there for testing/experimenting.
This op will be removed in the future.
}];
let arguments = (
ins
TT_TensorDescType:$desc,
Variadic<I32>:$indices,
DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict
);
This is an escape hatch and is only there for testing/experimenting.
This op will be removed in the future.
}];
let arguments = (ins
TT_TensorDescType:$desc,
Variadic<I32>:$indices,
DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict
);

let results = (outs TT_Tensor:$result);
let results = (outs TT_Tensor:$result);

let assemblyFormat = [{
$desc `[` $indices `]`
oilist(
`cacheModifier` `=` $cache |
`evictionPolicy` `=` $evict
)
attr-dict `:` qualified(type($desc)) `->` type($result)
}];
let assemblyFormat = [{
$desc `[` $indices `]`
oilist(
`cacheModifier` `=` $cache |
`evictionPolicy` `=` $evict
)
attr-dict `:` qualified(type($desc)) `->` type($result)
}];

let hasVerifier = 1;
}

def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [
MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
]> {
let summary = "store value based on descriptor";
let description = [{
This operation will be lowered to Nvidia TMA store operation on targets supporting it.
`desc` is a tensor descriptor object.
The shape and types of `src` must match the descriptor otherwise the result is undefined.
let summary = "store value based on descriptor";
let description = [{
This operation will be lowered to Nvidia TMA store operation on targets supporting it.
`desc` is a tensor descriptor object.
The shape and types of `src` must match the descriptor otherwise the result is undefined.

This is an escape hatch and is only there for testing/experimenting.
This op will be removed in the future.
}];
let arguments = (
ins
TT_TensorDescType:$desc,
TT_Tensor:$src,
Variadic<I32>:$indices
);
This is an escape hatch and is only there for testing/experimenting.
This op will be removed in the future.
}];
let arguments = (ins
TT_TensorDescType:$desc,
TT_Tensor:$src,
Variadic<I32>:$indices
);

let assemblyFormat = [{
$desc `[` $indices `]` `,` $src
attr-dict `:` qualified(type($desc)) `,` type($src)
}];
let assemblyFormat = [{
$desc `[` $indices `]` `,` $src
attr-dict `:` qualified(type($desc)) `,` type($src)
}];

let hasVerifier = 1;
}

def TT_ExperimentalTensormapCreateOp: TT_Op<
Expand All @@ -1301,46 +1303,46 @@ def TT_ExperimentalTensormapCreateOp: TT_Op<
AttrSizedOperandSegments,
]
> {
let summary = "Create a new TMA descriptor on device";
let arguments = (
ins
TT_PtrType:$desc_ptr,
TT_PtrType:$global_address,
Variadic<I32>:$box_dim,
Variadic<I32>:$global_dim,
Variadic<I64>:$global_stride,
Variadic<I32>:$element_stride,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<12>]>:$elem_type,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<2>]>:$interleave_layout,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$swizzle_mode,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$fill_mode
);
let extraClassDeclaration = [{
int32_t getRank() {
return getBoxDim().size();
}
}];
let assemblyFormat = [{
$desc_ptr `,` $global_address `,`
`[` $box_dim `]` `,`
`[` $global_dim `]` `,`
`[` $global_stride `]` `,`
`[` $element_stride `]`
attr-dict `:` functional-type(operands, results)
}];
let summary = "Create a new TMA descriptor on device";
let arguments = (
ins
TT_PtrType:$desc_ptr,
TT_PtrType:$global_address,
Variadic<I32>:$box_dim,
Variadic<I32>:$global_dim,
Variadic<I64>:$global_stride,
Variadic<I32>:$element_stride,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<12>]>:$elem_type,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<2>]>:$interleave_layout,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$swizzle_mode,
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$fill_mode
);
let extraClassDeclaration = [{
int32_t getRank() {
return getBoxDim().size();
}
}];
let assemblyFormat = [{
$desc_ptr `,` $global_address `,`
`[` $box_dim `]` `,`
`[` $global_dim `]` `,`
`[` $global_stride `]` `,`
`[` $element_stride `]`
attr-dict `:` functional-type(operands, results)
}];

let hasVerifier = 1;
let hasVerifier = 1;
}

def TT_ExperimentalTensormapFenceproxyAcquireOp: TT_Op<
"experimental_tensormap_fenceproxy_acquire",
[MemoryEffects<[MemWrite<GlobalMemory>]>]
> {
let summary = "Acquire fence on a tensormap object";
let arguments = (ins TT_PtrType:$desc_ptr);
let assemblyFormat = [{
$desc_ptr attr-dict `:` qualified(type($desc_ptr))
}];
let summary = "Acquire fence on a tensormap object";
let arguments = (ins TT_PtrType:$desc_ptr);
let assemblyFormat = [{
$desc_ptr attr-dict `:` qualified(type($desc_ptr))
}];
}


Expand Down
21 changes: 21 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,27 @@ LogicalResult GatherOp::inferReturnTypes(
return success();
}

// -- ExperimentalDesciptorLoadOp --
static LogicalResult verifyDesciptorLoadStoreType(Operation *op,
TensorDescType desc,
RankedTensorType tensor) {
RankedTensorType block = desc.getBlockType();
if (block.getShape() == tensor.getShape() &&
block.getElementType() == tensor.getElementType())
return success();
return op->emitOpError("tensor desciptor block and tensor types must match");
}

LogicalResult ExperimentalDescriptorLoadOp::verify() {
return verifyDesciptorLoadStoreType(*this, getDesc().getType(), getType());
}

// -- ExperimentalDesciptorStoreOp --
LogicalResult ExperimentalDescriptorStoreOp::verify() {
return verifyDesciptorLoadStoreType(*this, getDesc().getType(),
getSrc().getType());
}

// -- ExperimentalTensormapCreateOp --
LogicalResult ExperimentalTensormapCreateOp::verify() {
auto rank = getBoxDim().size();
Expand Down
27 changes: 27 additions & 0 deletions test/Triton/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,30 @@ tt.func @gather_op(%arg0: tensor<128x16xf32>, %arg1: tensor<512x4xi32>) {
%0 = tt.gather %arg0[%arg1] {axis = 3 : i32} : (tensor<128x16xf32>, tensor<512x4xi32>) -> tensor<512x4xf32>
tt.return
}

// -----

tt.func @invalid_desc_load(%arg0: !tt.tensordesc<tensor<16x16xf32>>) {
%c = arith.constant 0 : i32
// expected-error @below {{tensor desciptor block and tensor types must match}}
tt.experimental_descriptor_load %arg0[%c, %c] : !tt.tensordesc<tensor<16x16xf32>> -> tensor<16xf32>
tt.return
}

// -----

tt.func @invalid_desc_load(%arg0: !tt.tensordesc<tensor<16x16xf32>>) {
%c = arith.constant 0 : i32
// expected-error @below {{tensor desciptor block and tensor types must match}}
tt.experimental_descriptor_load %arg0[%c, %c] : !tt.tensordesc<tensor<16x16xf32>> -> tensor<16x16xf16>
tt.return
}

// -----

tt.func @invalid_desc_store(%arg0: !tt.tensordesc<tensor<16x16xf32>>, %arg1: tensor<32x16xf32>) {
%c = arith.constant 0 : i32
// expected-error @below {{tensor desciptor block and tensor types must match}}
tt.experimental_descriptor_store %arg0[%c, %c], %arg1 : !tt.tensordesc<tensor<16x16xf32>>, tensor<32x16xf32>
tt.return
}
Loading