diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 56e078463d20..3363b5c9a1fa 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -255,6 +255,22 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- // ----- +#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#smem = #ttg.shared_memory +// CHECK-LABEL: distribute_to_shared_st_matrix_local_store +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + tt.func @distribute_to_shared_st_matrix_local_store(%a: tensor<128x128xf16, #mma>) { + // CHECK-COUNT-16: nvgpu.stmatrix + // CHECK: llvm.return + %b = ttg.local_alloc {allocation.offset = 0 : i32} : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> + ttg.local_store %a, %b : tensor<128x128xf16, #mma> -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> + tt.return + } +} + +// ----- + #blocked = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func @fp8_const(%arg0: tensor<1024xi1, #blocked>, %arg1: tensor<1024xf8E4M3FNUZ, #blocked>) attributes {noinline = false} { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp index 380f549cc677..749031d5538e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -123,6 +123,78 @@ struct LocalLoadOpConversion } }; +LogicalResult lowerDistributedToSharedStmatrix( + 
Location loc, TypedValue<RankedTensorType> src, MemDescType memDescType, + Value adaptorSrc, Value smemBase, const TypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo, + std::pair<size_t, Type> *const llvmOpCount = nullptr) { + auto mmaEncoding = + dyn_cast<NvidiaMmaEncodingAttr>(src.getType().getEncoding()); + if (!mmaEncoding) + return failure(); + auto sharedLayout = + cast<SharedEncodingAttr>(memDescType.getEncoding()); + if (!sharedLayout.getHasLeadingOffset()) + return failure(); + int swizzleByteSize = 0; + if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2) + swizzleByteSize = 32; + else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4) + swizzleByteSize = 64; + else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8) + swizzleByteSize = 128; + else + return failure(); + + RankedTensorType srcTy = src.getType(); + SmallVector<unsigned> shape = + convertType<unsigned, int64_t>(srcTy.getShape()); + auto order = sharedLayout.getOrder(); + if (!targetInfo.canUseStMatrix(srcTy, shape, shape, order, swizzleByteSize)) { + return failure(); + } + + auto *ctx = rewriter.getContext(); + + auto layout = chooseStMatrixLayout(rewriter.getContext(), srcTy, shape, shape, + order, swizzleByteSize); + auto llvmElemTy = typeConverter->convertType(memDescType.getElementType()); + auto smemPtrTy = ptr_ty(ctx, 3); + + auto kRegister = str_attr("register"); + auto kLane = str_attr("lane"); + auto kWarp = str_attr("warp"); + auto kBlock = str_attr("block"); + + Value threadId = getThreadId(rewriter, loc); + Value threadsPerWarp = i32_val(layout.getInDimSize(kLane)); + Value laneId = urem(threadId, threadsPerWarp); + Value warpId = udiv(threadId, threadsPerWarp); + + auto regBase = applyLinearLayout(loc, rewriter, layout, + {{kRegister, i32_val(0)}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, i32_val(0)}})[0] + .second; + auto srcVals = unpackLLElements(loc, adaptorSrc, rewriter); + auto srcVec = layout.getNumConsecutiveInOut(); + for (int i = 0; i < srcVals.size(); i
+= srcVec) { + auto regIdx = + layout.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0] + .second; + Value offset = xor_(regBase, i32_val(regIdx)); + auto vecAddr = gep(smemPtrTy, llvmElemTy, smemBase, offset); + vecAddr.setInbounds(true); + SmallVector<Value> inValsVec; + for (int j = 0; j < srcVec; j++) + inValsVec.push_back(srcVals[i + j]); + Value valsVec = packLLVector(loc, inValsVec, rewriter); + targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec); + } + return success(); +} + struct LocalAllocOpConversion : public ConvertOpToLLVMPattern<triton::gpu::LocalAllocOp> { LocalAllocOpConversion(const LLVMTypeConverter &converter, @@ -136,82 +208,61 @@ struct LocalAllocOpConversion ConversionPatternRewriter &rewriter) const override { if (!op.getSrc()) return failure(); - auto mmaEncoding = dyn_cast<NvidiaMmaEncodingAttr>( - op.getSrc().getType().getEncoding()); - if (!mmaEncoding) - return failure(); + MemDescType memDescType = op.getType(); auto sharedLayout = - cast<SharedEncodingAttr>(op.getType().getEncoding()); - if (!sharedLayout.getHasLeadingOffset()) - return failure(); - int swizzleByteSize = 0; - if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2) - swizzleByteSize = 32; - else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4) - swizzleByteSize = 64; - else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8) - swizzleByteSize = 128; - else - return failure(); - - auto *ctx = rewriter.getContext(); - Location loc = op->getLoc(); - + cast<SharedEncodingAttr>(memDescType.getEncoding()); RankedTensorType srcTy = op.getSrc().getType(); - SmallVector<unsigned> shape = - convertType<unsigned, int64_t>(srcTy.getShape()); - auto order = sharedLayout.getOrder(); - if (!targetInfo.canUseStMatrix(srcTy, shape, shape, order, - swizzleByteSize)) { - return failure(); - } - auto layout = chooseStMatrixLayout(rewriter.getContext(), srcTy, shape, - shape, order, swizzleByteSize); - Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op); - auto smemPtrTy = ptr_ty(ctx, 3); - - auto kRegister =
str_attr("register"); - auto kLane = str_attr("lane"); - auto kWarp = str_attr("warp"); - auto kBlock = str_attr("block"); - - Value threadId = getThreadId(rewriter, loc); - Value threadsPerWarp = i32_val(layout.getInDimSize(kLane)); - Value laneId = urem(threadId, threadsPerWarp); - Value warpId = udiv(threadId, threadsPerWarp); - - auto regBase = applyLinearLayout(loc, rewriter, layout, - {{kRegister, i32_val(0)}, - {kLane, laneId}, - {kWarp, warpId}, - {kBlock, i32_val(0)}})[0] - .second; - auto srcVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); - auto srcVec = layout.getNumConsecutiveInOut(); Type llvmElemTy = typeConverter->convertType(srcTy.getElementType()); - for (int i = 0; i < srcVals.size(); i += srcVec) { - auto regIdx = - layout.apply({{kRegister, i}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}})[0] - .second; - Value offset = xor_(regBase, i32_val(regIdx)); - auto vecAddr = gep(smemPtrTy, llvmElemTy, smemBase, offset); - vecAddr.setInbounds(true); - SmallVector<Value> inValsVec; - for (int j = 0; j < srcVec; j++) - inValsVec.push_back(srcVals[i + j]); - Value valsVec = packLLVector(loc, inValsVec, rewriter); - targetInfo.storeMatrixShared(rewriter, loc, vecAddr, valsVec); + Value smemBase = + LLVM::getSharedMemoryBase(op.getLoc(), rewriter, targetInfo, op); + + if (lowerDistributedToSharedStmatrix(op.getLoc(), op.getSrc(), memDescType, + adaptor.getSrc(), smemBase, + typeConverter, rewriter, targetInfo) + .failed()) { + return failure(); } auto resultTy = cast<MemDescType>(op.getType()); auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape()); auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA, - sharedLayout, loc, rewriter); - auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter); + sharedLayout, op.getLoc(), rewriter); + auto retVal = + getStructFromSharedMemoryObject(op.getLoc(), smemObj, rewriter); rewriter.replaceOp(op, retVal); return success(); } +private: + const NVIDIA::TargetInfo &targetInfo; +}; + +struct
LocalStoreOpConversion + : public ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp> { + LocalStoreOpConversion(const LLVMTypeConverter &converter, + const NVIDIA::TargetInfo &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp>(converter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type llvmElemTy = + getTypeConverter()->convertType(op.getDst().getType().getElementType()); + SharedMemoryObject smemObj = LLVM::getSharedMemoryObjectFromStruct( + op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter); + MemDescType memDescType = op.getDst().getType(); + if (lowerDistributedToSharedStmatrix( + op.getLoc(), op.getSrc(), memDescType, adaptor.getSrc(), + smemObj.getBase(), getTypeConverter(), rewriter, targetInfo) + .failed()) { + return failure(); + } + rewriter.eraseOp(op); + return success(); + } + private: const NVIDIA::TargetInfo &targetInfo; }; @@ -223,6 +274,8 @@ void mlir::triton::NVIDIA::populateMemoryOpToLLVMPatterns( // Backend optimized memory ops get higher benefit patterns.add<LocalAllocOpConversion>(typeConverter, targetInfo, benefit.getBenefit() + 1); + patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, + benefit.getBenefit() + 1); patterns.add<LocalLoadOpConversion>(typeConverter, benefit.getBenefit() + 1); mlir::triton::populateMemoryOpToLLVMPatterns(typeConverter, targetInfo, patterns, benefit);