Adding support for local_store lowering with stmatrix (#5556)
Unifying the lowering path of `local_alloc` with `local_store` for the case where the shared memory layout has `hasLeadingOffset` set.
pawelszczerbuk authored Jan 8, 2025
1 parent 70359fa commit f436c9e
Showing 2 changed files with 133 additions and 64 deletions.
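
The unified path only takes effect when the shared layout's `(perPhase, maxPhase)` pair corresponds to a 32-, 64-, or 128-byte swizzle (and `canUseStMatrix` accepts the shape); otherwise both ops fall back to the generic lowering. A minimal standalone sketch of that mapping follows; the function name and driver are illustrative only and not part of this change:

```cpp
#include <cstdio>

// Mirrors the (perPhase, maxPhase) -> swizzle-byte-size mapping used by the
// stmatrix lowering path; 0 means "unsupported, use the generic path".
static int swizzleByteSizeFor(int perPhase, int maxPhase) {
  if (perPhase == 4 && maxPhase == 2)
    return 32;
  if (perPhase == 2 && maxPhase == 4)
    return 64;
  if (perPhase == 1 && maxPhase == 8)
    return 128;
  return 0;
}

int main() {
  // The test added below uses perPhase = 1, maxPhase = 8, i.e. 128-byte swizzling.
  std::printf("swizzle bytes: %d\n", swizzleByteSizeFor(1, 8));
  return 0;
}
```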
16 changes: 16 additions & 0 deletions test/Conversion/tritongpu_to_llvm_hopper.mlir
@@ -255,6 +255,22 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-

// -----

#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
#smem = #ttg.shared_memory
// CHECK-LABEL: distribute_to_shared_st_matrix_local_store
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @distribute_to_shared_st_matrix_local_store(%a: tensor<128x128xf16, #mma>) {
// CHECK-COUNT-16: nvgpu.stmatrix
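// With 128x128 f16 split across 4 warps, each warp covers 4096 elements;
// assuming the x4 form of stmatrix (four 8x8 f16 tiles, 256 elements per
// warp-wide op), that is 4096 / 256 = 16 ops, matching the count above.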
// CHECK: llvm.return
%b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
ttg.local_store %a, %b : tensor<128x128xf16, #mma> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
tt.return
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func @fp8_const(%arg0: tensor<1024xi1, #blocked>, %arg1: tensor<1024xf8E4M3FNUZ, #blocked>) attributes {noinline = false} {
181 changes: 117 additions & 64 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp
@@ -123,6 +123,78 @@ struct LocalLoadOpConversion
}
};

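// Lowers a distributed-to-shared copy through nvgpu.stmatrix. Shared by the
// local_alloc and local_store conversions below. Returns failure() when the
// source is not NVIDIA-MMA encoded, the shared layout has no leading offset,
// its (perPhase, maxPhase) does not map to a supported 32/64/128-byte swizzle,
// or canUseStMatrix rejects the shape, so callers can fall back to the
// generic lowering.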
LogicalResult lowerDistributedToSharedStmatrix(
Location loc, TypedValue<RankedTensorType> src, MemDescType memDescType,
Value adaptorSrc, Value smemBase, const TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo,
std::pair<size_t, Type> *const llvmOpCount = nullptr) {
auto mmaEncoding =
dyn_cast<triton::gpu::NvidiaMmaEncodingAttr>(src.getType().getEncoding());
if (!mmaEncoding)
return failure();
auto sharedLayout =
cast<triton::gpu::SharedEncodingAttr>(memDescType.getEncoding());
if (!sharedLayout.getHasLeadingOffset())
return failure();
int swizzleByteSize = 0;
if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2)
swizzleByteSize = 32;
else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4)
swizzleByteSize = 64;
else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8)
swizzleByteSize = 128;
else
return failure();

RankedTensorType srcTy = src.getType();
SmallVector<unsigned> shape =
convertType<unsigned, int64_t>(srcTy.getShape());
auto order = sharedLayout.getOrder();
if (!targetInfo.canUseStMatrix(srcTy, shape, shape, order, swizzleByteSize)) {
return failure();
}

auto *ctx = rewriter.getContext();

auto layout = chooseStMatrixLayout(rewriter.getContext(), srcTy, shape, shape,
order, swizzleByteSize);
auto llvmElemTy = typeConverter->convertType(memDescType.getElementType());
auto smemPtrTy = ptr_ty(ctx, 3);

auto kRegister = str_attr("register");
auto kLane = str_attr("lane");
auto kWarp = str_attr("warp");
auto kBlock = str_attr("block");

Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = i32_val(layout.getInDimSize(kLane));
Value laneId = urem(threadId, threadsPerWarp);
Value warpId = udiv(threadId, threadsPerWarp);

auto regBase = applyLinearLayout(loc, rewriter, layout,
{{kRegister, i32_val(0)},
{kLane, laneId},
{kWarp, warpId},
{kBlock, i32_val(0)}})[0]
.second;
auto srcVals = unpackLLElements(loc, adaptorSrc, rewriter);
auto srcVec = layout.getNumConsecutiveInOut();
for (int i = 0; i < srcVals.size(); i += srcVec) {
auto regIdx =
layout.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0]
.second;
Value offset = xor_(regBase, i32_val(regIdx));
auto vecAddr = gep(smemPtrTy, llvmElemTy, smemBase, offset);
vecAddr.setInbounds(true);
SmallVector<Value> inValsVec;
for (int j = 0; j < srcVec; j++)
inValsVec.push_back(srcVals[i + j]);
Value valsVec = packLLVector(loc, inValsVec, rewriter);
targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec);
}
return success();
}

struct LocalAllocOpConversion
: public ConvertOpToLLVMPattern<triton::gpu::LocalAllocOp> {
LocalAllocOpConversion(const LLVMTypeConverter &converter,
@@ -136,82 +208,61 @@ struct LocalAllocOpConversion
ConversionPatternRewriter &rewriter) const override {
if (!op.getSrc())
return failure();
auto mmaEncoding = dyn_cast<triton::gpu::NvidiaMmaEncodingAttr>(
op.getSrc().getType().getEncoding());
if (!mmaEncoding)
return failure();
MemDescType memDescType = op.getType();
auto sharedLayout =
cast<triton::gpu::SharedEncodingAttr>(op.getType().getEncoding());
if (!sharedLayout.getHasLeadingOffset())
return failure();
int swizzleByteSize = 0;
if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2)
swizzleByteSize = 32;
else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4)
swizzleByteSize = 64;
else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8)
swizzleByteSize = 128;
else
return failure();

auto *ctx = rewriter.getContext();
Location loc = op->getLoc();

cast<triton::gpu::SharedEncodingAttr>(memDescType.getEncoding());
RankedTensorType srcTy = op.getSrc().getType();
SmallVector<unsigned> shape =
convertType<unsigned, int64_t>(srcTy.getShape());
auto order = sharedLayout.getOrder();
if (!targetInfo.canUseStMatrix(srcTy, shape, shape, order,
swizzleByteSize)) {
return failure();
}
auto layout = chooseStMatrixLayout(rewriter.getContext(), srcTy, shape,
shape, order, swizzleByteSize);
Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op);
auto smemPtrTy = ptr_ty(ctx, 3);

auto kRegister = str_attr("register");
auto kLane = str_attr("lane");
auto kWarp = str_attr("warp");
auto kBlock = str_attr("block");

Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = i32_val(layout.getInDimSize(kLane));
Value laneId = urem(threadId, threadsPerWarp);
Value warpId = udiv(threadId, threadsPerWarp);

auto regBase = applyLinearLayout(loc, rewriter, layout,
{{kRegister, i32_val(0)},
{kLane, laneId},
{kWarp, warpId},
{kBlock, i32_val(0)}})[0]
.second;
auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
auto srcVec = layout.getNumConsecutiveInOut();
Type llvmElemTy = typeConverter->convertType(srcTy.getElementType());
for (int i = 0; i < srcVals.size(); i += srcVec) {
auto regIdx =
layout.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0]
.second;
Value offset = xor_(regBase, i32_val(regIdx));
auto vecAddr = gep(smemPtrTy, llvmElemTy, smemBase, offset);
vecAddr.setInbounds(true);
SmallVector<Value> inValsVec;
for (int j = 0; j < srcVec; j++)
inValsVec.push_back(srcVals[i + j]);
Value valsVec = packLLVector(loc, inValsVec, rewriter);
targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec);
Value smemBase =
LLVM::getSharedMemoryBase(op.getLoc(), rewriter, targetInfo, op);

if (lowerDistributedToSharedStmatrix(op.getLoc(), op.getSrc(), memDescType,
adaptor.getSrc(), smemBase,
typeConverter, rewriter, targetInfo)
.failed()) {
return failure();
}

auto resultTy = cast<MemDescType>(op.getType());
auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape());
auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA,
sharedLayout, loc, rewriter);
auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
sharedLayout, op.getLoc(), rewriter);
auto retVal =
getStructFromSharedMemoryObject(op.getLoc(), smemObj, rewriter);
rewriter.replaceOp(op, retVal);
return success();
}

private:
const NVIDIA::TargetInfo &targetInfo;
};

struct LocalStoreOpConversion
: public ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp> {
LocalStoreOpConversion(const LLVMTypeConverter &converter,
const NVIDIA::TargetInfo &targetInfo,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp>(converter, benefit),
targetInfo(targetInfo) {}

LogicalResult
matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type llvmElemTy =
getTypeConverter()->convertType(op.getDst().getType().getElementType());
SharedMemoryObject smemObj = LLVM::getSharedMemoryObjectFromStruct(
op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter);
MemDescType memDescType = op.getDst().getType();
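// Try the stmatrix fast path; returning failure() here lets the generic,
// lower-benefit local_store lowering handle the remaining cases.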
if (lowerDistributedToSharedStmatrix(
op.getLoc(), op.getSrc(), memDescType, adaptor.getSrc(),
smemObj.getBase(), getTypeConverter(), rewriter, targetInfo)
.failed()) {
return failure();
}
rewriter.eraseOp(op);
return success();
}

private:
const NVIDIA::TargetInfo &targetInfo;
};
Expand All @@ -223,6 +274,8 @@ void mlir::triton::NVIDIA::populateMemoryOpToLLVMPatterns(
// Backend optimized memory ops get higher benefit
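// (benefit + 1) so these stmatrix-capable patterns are preferred over the
// generic memory-op patterns registered at the base benefit below.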
patterns.add<LocalAllocOpConversion>(typeConverter, targetInfo,
benefit.getBenefit() + 1);
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo,
benefit.getBenefit() + 1);
patterns.add<LocalLoadOpConversion>(typeConverter, benefit.getBenefit() + 1);
mlir::triton::populateMemoryOpToLLVMPatterns(typeConverter, targetInfo,
patterns, benefit);