diff --git a/test/TritonGPU/amd/amd-block-pingpong.mlir b/test/TritonGPU/amd/amd-block-pingpong.mlir
index 031bf9ff102b..a761cac37666 100644
--- a/test/TritonGPU/amd/amd-block-pingpong.mlir
+++ b/test/TritonGPU/amd/amd-block-pingpong.mlir
@@ -86,43 +86,43 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 // CHECK: tt.load
 // CHECK: %[[SLICEA0:.+]] = ttg.local_load
 // CHECK: %[[SLICEB0:.+]] = ttg.local_load
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: rocdl.s.setprio 1
 // CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]]
 // CHECK: rocdl.s.setprio 0
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: tt.load
 // CHECK: %[[SLICEA1:.+]] = ttg.local_load
 // CHECK: %[[SLICEB1:.+]] = ttg.local_load
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: rocdl.s.setprio 1
 // CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]]
 // CHECK: rocdl.s.setprio 0
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: %[[SLICEA2:.+]] = ttg.local_load
 // CHECK: %[[SLICEB2:.+]] = ttg.local_load
 // CHECK: %[[SLICEA3:.+]] = ttg.local_load
 // CHECK: %[[SLICEB3:.+]] = ttg.local_load
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: rocdl.s.setprio 1
 // CHECK: %[[DOT2:.+]] = tt.dot %[[SLICEA2]], %[[SLICEB2]], %[[DOT1]]
 // CHECK: rocdl.s.setprio 0
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: ttg.local_store
 // CHECK: ttg.local_store
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: rocdl.s.setprio 1
 // CHECK: tt.dot %[[SLICEA3]], %[[SLICEB3]], %[[DOT2]]
 // CHECK: rocdl.s.setprio 0
-// CHECK: rocdl.sched.barrier 0
 // CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
 // CHECK: scf.yield
 // CHECK: amdgpu.cond_barrier %[[WARPLOW]]
@@ -169,9 +169,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
       %27 = tt.load %26 : tensor<256x64x!tt.ptr<f16>, #blocked1>
       %28 = tt.addptr %arg13, %cst_0 : tensor<64x256x!tt.ptr<f16>, #blocked>, tensor<64x256xi32, #blocked>
       %29 = tt.load %28 : tensor<64x256x!tt.ptr<f16>, #blocked>
-      %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
-      %31 = ttg.local_load %arg16 : !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
-      %32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<256x256xf32, #mma>
+      %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+      %31 = ttg.local_load %arg16 : !ttg.memdesc<64x256xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x256xf32, #mma>
       %33 = arith.addi %arg14, %c1_i32 : i32
       %34 = arith.cmpi slt, %33, %c1_i32 : i32
       %35 = arith.select %34, %33, %c0_i32 : i32
@@ -189,6 +189,105 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
 
 // -----
 
+// CHECK: gpu.barrier
+// CHECK: %[[IDX:.+]] = rocdl.workitem.id.x
+// CHECK: %[[XDIV:.+]] = arith.divsi %[[IDX]]
+// CHECK: %[[WARPLOW:.+]] = arith.cmpi eq, %[[XDIV]]
+// CHECK: %[[WARPHIGH:.+]] = arith.cmpi ne, %[[XDIV]]
+// CHECK: amdgpu.cond_barrier %[[WARPHIGH]]
+// CHECK: scf.for
+
+// CHECK: %[[SLICEA0:.+]] = ttg.local_load
+// CHECK: %[[SLICEB0:.+]] = ttg.local_load
+// CHECK: rocdl.sched.barrier 0
+// CHECK: tt.load
+// CHECK: rocdl.sched.barrier 0
+// CHECK: %[[SLICEA1:.+]] = ttg.local_load
+// CHECK: %[[SLICEB1:.+]] = ttg.local_load
+// CHECK: rocdl.sched.barrier 0
+// CHECK: tt.load
+// CHECK: rocdl.s.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: rocdl.s.setprio 1
+// CHECK: %[[DOT0:.+]] = tt.dot %[[SLICEA0]], %[[SLICEB0]]
+// CHECK: rocdl.s.setprio 0
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: ttg.local_store
+// CHECK: ttg.local_store
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: rocdl.s.setprio 1
+// CHECK: %[[DOT1:.+]] = tt.dot %[[SLICEA1]], %[[SLICEB1]], %[[DOT0]]
+// CHECK: rocdl.s.setprio 0
+// CHECK: gpu.barrier
+// CHECK: rocdl.sched.barrier 0
+// CHECK: scf.yield
+// CHECK: amdgpu.cond_barrier %[[WARPLOW]]
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
+#loc = loc("/home/jung/rocm/triton/python/perf-kernels/tools/tune_gemm/matmul_kernel.py":6:0)
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>
+#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
+#shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @pingpong_medium(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
+    %c1_i32 = arith.constant 1 : i32
+    %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked>
+    %cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1>
+    %c0_i32 = arith.constant 0 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked1>
+    %1 = tt.get_program_id x : i32
+    %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
+    %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1>
+    %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1>
+    %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1>
+    %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr<f16>, #blocked1>, tensor<256x1xi32, #blocked1>
+    %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr<f16>, #blocked1> -> tensor<256x64x!tt.ptr<f16>, #blocked1>
+    %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
+    %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
+    %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1>
+    %13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
+    %14 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x1x!tt.ptr<f16>, #blocked>
+    %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr<f16>, #blocked>, tensor<64x1xi32, #blocked>
+    %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr<f16>, #blocked> -> tensor<64x128x!tt.ptr<f16>, #blocked>
+    %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked>
+    %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
+    %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
+    %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+    %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>) : i32 {
+      %26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
+      %27 = tt.load %26 : tensor<256x64x!tt.ptr<f16>, #blocked1>
+      %28 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr<f16>, #blocked>, tensor<64x128xi32, #blocked>
+      %29 = tt.load %28 : tensor<64x128x!tt.ptr<f16>, #blocked>
+      %30 = ttg.local_load %arg15 : !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+      %31 = ttg.local_load %arg16 : !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %32 = tt.dot %30, %31, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma>
+      %33 = arith.addi %arg14, %c1_i32 : i32
+      %34 = arith.cmpi slt, %33, %c1_i32 : i32
+      %35 = arith.select %34, %33, %c0_i32 : i32
+      %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+      ttg.local_store %27, %36 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
+      %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+      ttg.local_store %29, %37 : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+      scf.yield %32, %26, %28, %35, %36, %37 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<64x128x!tt.ptr<f16>, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    }
+    ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
+    ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xf16, #shared1, #ttg.shared_memory, mutable>
+    tt.return
+  }
+}
+
+// -----
+
 // CHECK-LABEL: pingpong_reject
 // CHECK-COUNT-2: local_load
 // CHECK-NOT: local_load
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp
index c25ebd54e981..b5aebe007cd7 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp
@@ -62,8 +62,11 @@ class Pingponger {
   LogicalResult genLocalSlice(OpBuilder &builder, Value v,
                               Attribute dotEncoding, unsigned opIdx,
                               unsigned numSlices, int64_t sliceWidth);
+  LogicalResult sliceDot(OpBuilder &builder, Location loc, tt::DotOp op,
+                         unsigned numSlices);
   void transformOnePPClusters(OpBuilder &builder, Location loc);
   LogicalResult transformFourPPClusters(OpBuilder &builder, Location loc);
+  LogicalResult transformTwoPPClusters(OpBuilder &builder, Location loc);
   void addAsymmetricSyncToLoop(OpBuilder &builder, Location loc);
   void updateOpInsertion(Operation *Op);
   void appendOp(Operation *Op);
@@ -92,10 +95,10 @@ void Pingponger::appendSlicedLoadAB(int slice) {
 // Also, SchedBarrier with `0` is set here to tell compiler backend not to
 // reorder any instruction across this point.
 void Pingponger::appendClusterBarrier(OpBuilder &builder, Location loc) {
-  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
-  // MembarAnalysis can recognize gpu::BarrierOp and skip inserting additional
-  // barrier
+  // MembarAnalysis can recognize gpu::BarrierOp and skip inserting an
+  // additional barrier
   appendOp(builder.create<gpu::BarrierOp>(loc));
+  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
 }
 void Pingponger::appendOpWithPrio(OpBuilder &builder, Operation *op,
                                   Location loc) {
@@ -150,8 +153,6 @@ LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v,
   SmallVector<Operation *> subviews;
   auto memDesc = v.getDefiningOp()->getOperand(0);
   auto type = cast<ttg::MemDescType>(memDesc.getType());
-  auto encoding = cast<RankedTensorType>(v.getType()).getEncoding();
-  auto srcEncoding = cast<ttg::DotOperandEncodingAttr>(encoding);
   SmallVector<int64_t> shape = llvm::to_vector(type.getShape());
   Type elementType = type.getElementType();
   int64_t kIdx = opIdx == 0 ? 1 : 0;
@@ -160,7 +161,7 @@ LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v,
   if (sliceWidth < 16)
     return failure();
   auto dotOperandEnc = ttg::DotOperandEncodingAttr::get(
-      builder.getContext(), opIdx, dotEncoding, srcEncoding.getKWidth());
+      builder.getContext(), opIdx, dotEncoding, kWidth);
   auto subviewDescType = ttg::MemDescType::get(
       shape, elementType, type.getEncoding(), type.getMemorySpace());
   for (int i = 0; i < numSlices; i++) {
@@ -183,32 +184,11 @@ LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v,
   return success();
 }
 
-// Transform a loop into four Dot - Memory (ping - pong) clusters
-// This transfrom is useful when the original dot tile is too large that there's
-// no enough register to hold data for a Dot cluster. This path slices the dot
-// into four pieces and pair with four clusters of reordered memory operations.
-// There are multiple guards at the boundary of each cluster.
-// (1) sched.barrier : with mask0 to prevent compiler backed from reroder
-// instructions across the boundary
-// (2) gpu.barrier : ensures asymmetric synchronization at each point
-// (3) setprio (1->0) : in order to avoid incomming warp overtaking resource
-// while the other warp is actively using it.
-//
-// Here's overview of the instruction clusters
-// mem0: global load A, local load A(1/4), local load B(1/4)
-// dot0: dot A(1/4) * B(1/4)
-// mem1: global load B, local load A(2/4), local load B(2/4)
-// dot1: dot A(2/4) * B(2/4)
-// mem2: local load A(3/4, 4/4), local load B(3/4, 4/4)
-// dot2: dot A(3/4) * B(3/4)
-// mem3: local store A and B
-// dot3: dot A(4/4) * B(4/4)
-
-LogicalResult Pingponger::transformFourPPClusters(OpBuilder &builder,
-                                                  Location loc) {
-  // First, slice local_loads and dot into 4 parts
-  unsigned numSlices = 4;
-  auto op = cast<tt::DotOp>(dotOps[0]);
+// Split the dot into 'numSlices' pieces. This is required by pingpong
+// scheduling when it needs to schedule multiple dot clusters. Calls
+// genLocalSlice to create the corresponding local_load slices.
+LogicalResult Pingponger::sliceDot(OpBuilder &builder, Location loc,
+                                   tt::DotOp op, unsigned numSlices) {
   builder.setInsertionPointToStart(forOp.getBody());
   auto typeB = op.getB().getType();
   auto shapeB = typeB.getShape();
@@ -224,7 +204,7 @@ LogicalResult Pingponger::transformFourPPClusters(OpBuilder &builder,
           .failed())
     return failure();
 
-  // Clone dots four times to consume the slices
+  // Clone dots to consume all the slices
   Operation *prevDot = op;
   for (int i = 0; i < numSlices; i++) {
     IRMapping mapping;
@@ -240,11 +220,41 @@ LogicalResult Pingponger::transformFourPPClusters(OpBuilder &builder,
   op->erase();
   for (auto loads : lLoadOps)
     loads->erase();
+  return success();
+}
+// Transform a loop into four Dot - Memory (ping - pong) clusters.
+// This transform is useful when the original dot tile is so large that there
+// are not enough registers to hold data for a Dot cluster. This path slices
+// the dot into four pieces and pairs them with four clusters of reordered
+// memory operations. There are multiple guards at the boundary of each cluster:
+// (1) sched.barrier : with mask0, to prevent the compiler backend from
+//     reordering instructions across the boundary
+// (2) gpu.barrier : ensures asymmetric synchronization at each point
+// (3) setprio (1->0) : to keep an incoming warp from overtaking resources
+//     while the other warp is actively using them
+//
+// Here's an overview of the instruction clusters:
+// mem0: global load A, local load A(1/4), local load B(1/4)
+// dot0: dot A(1/4) * B(1/4)
+// mem1: global load B, local load A(2/4), local load B(2/4)
+// dot1: dot A(2/4) * B(2/4)
+// mem2: local load A(3/4, 4/4), local load B(3/4, 4/4)
+// dot2: dot A(3/4) * B(3/4)
+// mem3: local store A and B
+// dot3: dot A(4/4) * B(4/4)
+
+LogicalResult Pingponger::transformFourPPClusters(OpBuilder &builder,
+                                                  Location loc) {
+  // First, slice local_loads and dot into 4 parts
+  if (sliceDot(builder, loc, dotOps[0], 4).failed())
+    return failure();
   builder.setInsertionPointAfter(gLoadOps[1]);
 
   // Reorder operations into four mem/dot clusters
   // mem0: global load A, local load A(1/4), local load B(1/4)
+  // Set the insertion point at the last global_load, where all the addresses
+  // are ready to be used.
   updateOpInsertion(gLoadOps[1]);
   appendSlicedLoadAB(/*slice=*/0);
   appendClusterBarrier(builder, loc);
@@ -282,6 +292,51 @@ LogicalResult Pingponger::transformFourPPClusters(OpBuilder &builder,
   return success();
 }
+// Transform a loop into two Dot - Memory (ping - pong) clusters.
+// This is useful for medium-sized tiles that fit neither the one-cluster nor
+// the four-cluster scheduling.
+LogicalResult Pingponger::transformTwoPPClusters(OpBuilder &builder,
+                                                 Location loc) {
+  // First, slice local_loads and dot into 2 parts
+  if (sliceDot(builder, loc, dotOps[0], 2).failed())
+    return failure();
+  // Reorder operations into two mem/dot clusters
+
+  // Memory cluster #0
+  // Interleave local_loads and global_loads to minimize stall cycles;
+  // sched.barrier prevents the backend from undoing the interleaved order.
+  updateOpInsertion(gLoadOps[1]);
+  appendSlicedLoadAB(/*slice=*/0);
+  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
+  appendOp(gLoadOps[0]);
+  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
+  appendSlicedLoadAB(/*slice=*/1);
+  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
+  appendOp(gLoadOps[1]);
+  // The first cluster just fits into the two-cluster pingpong and cannot
+  // include the local_load wait that gpu.barrier would insert, so use
+  // s.barrier instead. The backend will schedule the local memory fences
+  // later, in the dot0 cluster.
+  appendOp(builder.create<ROCDL::SBarrierOp>(loc));
+  appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
+
+  // dot0 (1/2)
+  appendOpWithPrio(builder, dotSliceOps[0], loc);
+  appendClusterBarrier(builder, loc);
+
+  // mem1: local store A and B
+  // No need to reorder local_stores, just add a cluster barrier after the
+  // last store.
+  updateOpInsertion(lStoreOps[1]);
+  appendClusterBarrier(builder, loc);
+
+  // dot1 (2/2)
+  appendOpWithPrio(builder, dotSliceOps[1], loc);
+  appendClusterBarrier(builder, loc);
+
+  return success();
+}
+
 // This function wraps forOp with cond_barrier. First, hold half of the warps
 // (warpHigh) in a block before the loop so the barriers in the loop synchronize
 // warps at the different point per the warp groups. After the loop, hold
@@ -317,7 +372,6 @@ void Pingponger::getDotPingponged() {
   OpBuilder builder(forOp);
   MLIRContext *ctx = forOp.getContext();
   Location loc = forOp.getLoc();
-  auto f16ty = builder.getF16Type();
 
   forOp->walk([&](Operation *op) {
     if (auto gLoad = dyn_cast<tt::LoadOp>(op))
@@ -364,17 +418,57 @@
   // GPU to hold all the data for the calculation. Such large tile size
   // exceeds the amount of register GPU has so, we need to split the dot
   // into several pieces.
+  //
+  // (3) Two Dot-Memory (ping-pong x2) clusters
+  //  : Covers medium-sized tiles, e.g., 256x128x64_FP16. Different tile sizes
+  //    may require different scheduling patterns because the loop consists of
+  //    different amounts of memory transfer and dot operations. This
+  //    scheduling supports the tile sizes not covered by the two methods above.
+  //
+  // N.B., Tiles smaller than 128x128x64_FP16 are likely not compute-bound,
+  //    so pingpong scheduling doesn't help much.
+
+  auto dotType = dotOps[0].getType();
+  auto dotShape = dotType.getShape();
+  auto aType = dotOps[0].getA().getType();
+  auto aShape = aType.getShape();
+  auto elemWidth = aType.getElementTypeBitWidth();
+  int64_t tileSize = dotShape[0] * dotShape[1] * aShape[1] * elemWidth;
+
+  const int64_t smallTile = 16777216;  // e.g. 128x128x64x16bit
+  const int64_t mediumTile = 33554432; // smallTile x 2
+  const int64_t largeTile = 67108864;  // e.g. 256x256x64x16bit
 
-  // TODO:
-  // - Add transformTwoPPClusters for the medium size tiles
-  // - Add definition of small/medium/large tile size considering data-type
-  // so we can choose the transfrom per given tile size.
-  if (numWarps == 4) { // pingpong between warps from different blocks
-    // transfor a loop with small tile size
-    transformOnePPClusters(builder, loc);
-  } else if (numWarps == 8) { // pingpong between warps from the same block
-    // transfor a loop with large tile size which requires dots to be sliced
-    if (transformFourPPClusters(builder, dotOps[0]->getLoc()).failed())
+  auto encoding = cast<RankedTensorType>(aType).getEncoding();
+  auto srcEncoding = cast<ttg::DotOperandEncodingAttr>(encoding);
+  kWidth = srcEncoding.getKWidth();
+  auto mfmaEncoding = cast<ttg::AMDMfmaEncodingAttr>(srcEncoding.getParent());
+  SmallVector<unsigned> intShape;
+  intShape.push_back(mfmaEncoding.getMDim());
+  intShape.push_back(mfmaEncoding.getNDim());
+
+  if (numWarps == 4) { // Pingpong between warps from different blocks
+    // Transform a loop with a small tile size.
+    // We've observed that this small tile size spends roughly the same number
+    // of cycles issuing memory operations as issuing dot operations; smaller
+    // tile sizes are not likely to get any advantage from the current
+    // dot-centric pingpong scheduling.
+    if (tileSize == smallTile)
+      transformOnePPClusters(builder, loc);
+    // numWarps=4 doesn't need asymmetric sync, return.
+    return;
+  } else if (numWarps == 8) { // Pingpong between warps from the same block
+    // Transform a loop where the tile size requires dots to be sliced
+    if (tileSize == mediumTile) {
+      if (transformTwoPPClusters(builder, dotOps[0]->getLoc()).failed())
+        return;
+    } else if (tileSize >= largeTile) {
+      // Avoid known register spilling, i.e., mfma16x16x16 & large tile & kpack > 1
+      if (intShape[0] == 16 && intShape[1] == 16 && kWidth == 8)
+        return;
+      if (transformFourPPClusters(builder, dotOps[0]->getLoc()).failed())
+        return;
+    } else
       return;
 
   // Let half of the warps start the loop first and the others follow later