[FRONTEND][BACKEND] tl.atomic_load and tl.atomic_store with codegen for PTX

PTX ld and st support the scopes { .cta, .cluster, .gpu, .sys } and the
memory-ordering qualifiers relaxed/acquire (ld) and relaxed/release (st).

This commit adds TT::AtomicLoadOp/AtomicStoreOp to support
tl.atomic_load/tl.atomic_store and implements codegen for the MemSemantic
and MemSyncScope flags on the atomic load and store ops.

The purpose is to generate:

ld{ .relaxed, .acquire }{ .cta, .cluster, .gpu, .sys }
st{ .relaxed, .release }{ .cta, .cluster, .gpu, .sys }

in order to synchronize groups of threads during a cooperative thread
launch.
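
As an illustration of the intended use, here is a minimal sketch (not part of
this commit) of a flag-based handoff between programs in a cooperative launch;
`flag_ptr` and `data_ptr` are hypothetical single-element pointers:

    import triton
    import triton.language as tl

    @triton.jit
    def handoff(flag_ptr, data_ptr):
        pid = tl.program_id(0)
        if pid == 0:
            tl.store(data_ptr, 42.0)
            # Release-store: orders the data write before the flag update
            # for any observer at gpu scope.
            tl.atomic_store(flag_ptr, 1, sem='release', scope='gpu')
        else:
            # Acquire-load: once the flag is observed, the data read below
            # is ordered after the producer's store.
            flag = tl.atomic_load(flag_ptr, sem='acquire', scope='gpu')
            while flag == 0:
                flag = tl.atomic_load(flag_ptr, sem='acquire', scope='gpu')
            val = tl.load(data_ptr)  # observes 42.0
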
plotfi committed Dec 18, 2024
1 parent 80e2abd commit 7585650
Showing 9 changed files with 551 additions and 0 deletions.
65 changes: 65 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
@@ -347,6 +347,71 @@ def TT_StoreOp : TT_Op<"store", [
let hasCanonicalizer = 1;
}

//
// Atomic Load/Store Ops
//
def TT_AtomicLoadOp : TT_Op<"atomic_load", [
SameLoadStoreOperandsAndResultShape,
SameLoadStoreOperandsAndResultEncoding,
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
TypesMatchWith<"result matches ptr type", "ptr", "result", "getPointeeType($_self)">,
TypesMatchWith<"mask type matches ptr type", "ptr", "mask", "getI1SameShape(getPointeeType($_self))",
"($_op.getOperands().size() <= 1) || std::equal_to<>()">,
TypesMatchWith<"other matches ptr type", "ptr", "other", "getPointeeType($_self)",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">
]> {
  let summary = "Atomically load a scalar using scoped (cta, gpu, sys) memory ordering semantics (acquire, relaxed)";

let arguments = (
ins
AnyTypeOf<[TT_PtrLike]>:$ptr,
TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope,
Optional<TT_BoolLike>:$mask,
Optional<TT_Type>:$other
);

let results = (outs TT_Type:$result);

let assemblyFormat = [{
$ptr
`,` `memSemantic` `=` $sem
`,` `memSyncScope` `=` $scope
(`,` $mask^)? (`,` $other^)?
attr-dict `:` type($ptr)
}];
}

def TT_AtomicStoreOp : TT_Op<"atomic_store", [
SameLoadStoreOperandsShape,
SameLoadStoreOperandsEncoding,
MemoryEffects<[MemWrite<GlobalMemory>]>,
TypesMatchWith<"value type matches ptr type", "ptr", "value",
"getPointeeType($_self)">,
TypesMatchWith<"mask type matches ptr type", "ptr", "mask",
"getI1SameShape(getPointeeType($_self))",
"($_op.getOperands().size() <= 2) || std::equal_to<>()">
]> {
  let summary = "Atomically store a scalar using scoped (cta, gpu, sys) memory ordering semantics (release, relaxed)";

let arguments = (
ins
AnyTypeOf<[TT_PtrLike]>:$ptr,
TT_Type:$value,
TT_MemSemanticAttr:$sem, TT_MemSyncScopeAttr:$scope,
Optional<TT_BoolLike>:$mask
);

let assemblyFormat = [{
$ptr `,` $value
`,` `memSemantic` `=` $sem
`,` `memSyncScope` `=` $scope
(`,` $mask^)?
attr-dict `:` type($ptr)
}];
}

//
// Atomic Ops
//
10 changes: 10 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -23,6 +23,16 @@ void LoadOp::getEffects(
SideEffects::DefaultResource::get());
}

void AtomicLoadOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable(),
triton::GlobalMemory::get());
// Always assume volatile for now:
effects.emplace_back(MemoryEffects::Write::get(),
SideEffects::DefaultResource::get());
}

} // namespace triton
} // namespace mlir

25 changes: 25 additions & 0 deletions python/src/ir.cc
@@ -1459,6 +1459,31 @@ void init_triton_ir(py::module &&m) {
return self.create<AtomicRMWOp>(dstType, rmwOp, ptr, val, mask,
sem, scope);
})
.def("create_atomic_load",
[](TritonOpBuilder &self, Value &ptrs, MemSemantic sem,
MemSyncScope scope) -> Value {
Value mask = {};
Value other = {};
return self.create<AtomicLoadOp>(ptrs, sem, scope, mask, other);
})
.def("create_atomic_store",
[](TritonOpBuilder &self, Value &ptrs, Value &value, MemSemantic sem,
MemSyncScope scope) -> void {
Value mask = {};
self.create<AtomicStoreOp>(ptrs, value, sem, scope, mask);
})
.def("create_masked_atomic_load",
[](TritonOpBuilder &self, Value &ptrs, Value &mask,
std::optional<Value> &other, MemSemantic sem,
MemSyncScope scope) -> Value {
return self.create<AtomicLoadOp>(ptrs, sem, scope, mask,
other.value_or(Value()));
})
.def("create_masked_atomic_store",
[](TritonOpBuilder &self, Value &ptrs, Value &value, Value &mask,
MemSemantic sem, MemSyncScope scope) -> void {
self.create<AtomicStoreOp>(ptrs, value, sem, scope, mask);
})
// External
.def("create_extern_elementwise",
[](TritonOpBuilder &self, const std::string &libName,
86 changes: 86 additions & 0 deletions python/test/unit/language/test_core.py
@@ -1634,6 +1634,92 @@ def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr):
assert (torch.equal(X, Y))


@pytest.mark.interpreter
def test_load_scope_sem(device):

@triton.jit
def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr):
numel = 512
offset = tl.program_id(0) * BLOCK_SIZE
index = offset
mask = index < numel

a = tl.atomic_load(ptrs, sem='acquire', scope='gpu', mask=mask)
tl.atomic_store(ptrs, a, sem='release', scope='gpu')

a = tl.atomic_load(ptrs, sem='acquire', scope='cta')
tl.atomic_store(ptrs, a + 1, sem='release', scope='cta', mask=mask)

# cluster not handled in Triton yet
        # a = a + tl.atomic_load(ptrs, sem='acquire', scope='cluster')
        # tl.atomic_store(ptrs, a + 1, sem='release', scope='cluster')

a = a + tl.atomic_load(ptrs, sem='acquire', scope='sys', mask=mask)
tl.atomic_store(ptrs, a + 1, sem='release', scope='sys')

########### relaxed:

a = tl.atomic_load(ptrs, sem='relaxed', scope='gpu', mask=mask)
tl.atomic_store(ptrs, a, sem='relaxed', scope='gpu')

a = tl.atomic_load(ptrs, sem='relaxed', scope='cta')
tl.atomic_store(ptrs, a + 1, sem='relaxed', scope='cta')

# cluster not handled in Triton yet
        # a = a + tl.atomic_load(ptrs, sem='relaxed', scope='cluster')
        # tl.atomic_store(ptrs, a + 1, sem='relaxed', scope='cluster')

a = a + tl.atomic_load(ptrs, sem='relaxed', scope='sys', mask=mask)
tl.atomic_store(ptrs, a + 1, sem='relaxed', scope='sys')

block_size = 128
data = torch.zeros((128, ), device=device, dtype=torch.float32)

out = kernel_r[(2, )](data, BLOCK_SIZE=block_size)

asm = out.asm['ttir']
assert len(re.findall("atomic_load .*, memSemantic = acquire, memSyncScope = gpu", asm)) == 1
assert len(re.findall("atomic_store .*, memSemantic = release, memSyncScope = gpu", asm)) == 1

assert len(re.findall("atomic_load .*, memSemantic = acquire, memSyncScope = cta", asm)) == 1
assert len(re.findall("atomic_store .*, memSemantic = release, memSyncScope = cta", asm)) == 1

assert len(re.findall("atomic_load .*, memSemantic = acquire, memSyncScope = sys", asm)) == 1
assert len(re.findall("atomic_store .*, memSemantic = release, memSyncScope = sys", asm)) == 1

########### relaxed:

assert len(re.findall("atomic_load .*, memSemantic = relaxed, memSyncScope = gpu", asm)) == 1
assert len(re.findall("atomic_store .*, memSemantic = relaxed, memSyncScope = gpu", asm)) == 1

assert len(re.findall("atomic_load .*, memSemantic = relaxed, memSyncScope = cta", asm)) == 1
assert len(re.findall("atomic_store .*, memSemantic = relaxed, memSyncScope = cta", asm)) == 1

assert len(re.findall("atomic_load .*, memSemantic = relaxed, memSyncScope = sys", asm)) == 1
assert len(re.findall("atomic_store .*, memSemantic = relaxed, memSyncScope = sys", asm)) == 1

asm = out.asm['ptx']
assert len(re.findall("ld.global.gpu.acquire", asm)) == 1
assert len(re.findall("st.global.gpu.release", asm)) == 1

assert len(re.findall("ld.global.cta.acquire", asm)) == 1
assert len(re.findall("st.global.cta.release", asm)) == 1

assert len(re.findall("ld.global.sys.acquire", asm)) == 1
assert len(re.findall("st.global.sys.release", asm)) == 1

########### relaxed:

assert len(re.findall("ld.global.gpu.relaxed", asm)) == 1
assert len(re.findall("st.global.gpu.relaxed", asm)) == 1

assert len(re.findall("ld.global.cta.relaxed", asm)) == 1
assert len(re.findall("st.global.cta.relaxed", asm)) == 1

assert len(re.findall("ld.global.sys.relaxed", asm)) == 1
assert len(re.findall("st.global.sys.relaxed", asm)) == 1


# ---------------
# test cast
# ---------------
4 changes: 4 additions & 0 deletions python/triton/language/__init__.py
@@ -80,6 +80,7 @@
int8,
join,
load,
atomic_load,
make_block_ptr,
max_constancy,
max_contiguous,
@@ -101,6 +102,7 @@
static_print,
static_range,
store,
atomic_store,
tensor,
trans,
tuple,
@@ -203,6 +205,7 @@
"ir",
"join",
"load",
"atomic_load",
"log",
"log2",
"make_block_ptr",
@@ -246,6 +249,7 @@
"static_print",
"static_range",
"store",
"atomic_store",
"sum",
"swizzle2d",
"tensor",
46 changes: 46 additions & 0 deletions python/triton/language/core.py
@@ -1821,6 +1821,30 @@ def load(pointer, mask=None, other=None, boundary_check=(), padding_option="", c
volatile, _builder)


@builtin
def atomic_load(pointer, sem, scope, mask=None, other=None, _builder=None):
"""
Return a scalar atomically loaded from memory at location defined by `pointer`:
`pointer` must be a single element pointer (a scalar is loaded)
- `mask` and `other` must also be scalars,
- `other` is implicitly typecast to `pointer.dtype.element_ty`, and
:param pointer: Pointer to the data to be loaded
:type pointer: `triton.PointerType`
:param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]`
:param other: if `mask[idx]` is false, return `other[idx]`
"""
# `mask` and `other` can be constexpr
mask = _constexpr_to_value(mask)
other = _constexpr_to_value(other)
if mask is not None:
mask = semantic.to_tensor(mask, _builder)
if other is not None:
other = semantic.to_tensor(other, _builder)
return semantic.atomic_load(pointer, sem, scope, mask, other, _builder)
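
# A hedged usage sketch (not part of this commit; `flag_ptr` and `out_ptr`
# are hypothetical kernel arguments):
#
#     @triton.jit
#     def peek_flag(flag_ptr, out_ptr):
#         # Only program 0 performs the acquire load; the other programs
#         # receive the `other` fallback value instead.
#         is_leader = tl.program_id(0) == 0
#         flag = tl.atomic_load(flag_ptr, sem='acquire', scope='gpu',
#                               mask=is_leader, other=0)
#         tl.store(out_ptr + tl.program_id(0), flag)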


@builtin
def _experimental_reinterpret_tensor_descriptor(desc_ptr, block_shape, dtype,
_builder=None) -> _experimental_tensor_descriptor_base:
@@ -1906,6 +1930,28 @@ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", evict
return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder)


@_tensor_member_fn
@builtin
def atomic_store(pointer, value, sem, scope, mask=None, _builder=None):
"""
Atomically store scalar data into memory location defined by `pointer`.
`pointer` must be a single element pointer (a scalar is stored)
- `mask` must also be scalar, and
:param pointer: The memory location where the elements of `value` are stored
:type pointer: `triton.PointerType`
:param value: The scalar element to be stored
:param mask: If `mask` is false, do not store `value` at `pointer`
"""
# `value` can be constexpr
value = semantic.to_tensor(value, _builder)
mask = _constexpr_to_value(mask)
if mask is not None:
mask = semantic.to_tensor(mask, _builder)
return semantic.atomic_store(pointer, value, sem, scope, mask, _builder)
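
# A hedged usage sketch for the store side (not part of this commit;
# `flag_ptr` is a hypothetical kernel argument); the scalar mask limits
# the release store to a single program:
#
#     @triton.jit
#     def set_flag(flag_ptr):
#         # Release-store at sys scope: prior writes by this program are
#         # made visible to any thread that later acquire-loads the flag.
#         only_first = tl.program_id(0) == 0
#         tl.atomic_store(flag_ptr, 1, sem='release', scope='sys',
#                         mask=only_first)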


@builtin
def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None):
"""
80 changes: 80 additions & 0 deletions python/triton/language/semantic.py
@@ -1142,6 +1142,55 @@ def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor],
return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder)


def is_ptr_to_scalar(val: Optional[tl.tensor]):
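    """Return True if `val` is None or a pointer to a single scalar element."""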
if val is None:
return True
if val.type.is_block():
return False
if val.type.is_ptr() and val.type.element_ty.is_block():
return False
if not val.type.scalar.is_ptr():
return False
return True


def atomic_load(ptr: tl.tensor, sem: str, scope: str, mask: Optional[tl.tensor], other: Optional[tl.tensor],
builder: ir.builder) -> tl.tensor:
if not is_ptr_to_scalar(ptr):
raise ValueError(f"Only pointers to scalars are supported for `tl.atomic_load`: {ptr.type.__repr__()}")
if mask is None and other is not None:
raise ValueError("`other` cannot be provided without `mask`")
    if (mask is not None and mask.type.is_block()) or (other is not None and other.type.is_block()):
raise ValueError(f"`tl.atomic_load`'s Mask and Other arguments cannot be block type: {mask}, {other}")

sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
ptr_ty = ptr.type.scalar
elt_ty = ptr_ty.element_ty

# Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
is_bool = elt_ty == tl.int1
if is_bool:
elt_ty = tl.int8
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
ptr = cast(ptr, ptr_ty, builder)

# Cast `other` into `elt_ty` type
if other is not None:
other = cast(other, elt_ty, builder)

dst_ty = elt_ty
if mask is None:
ret = tl.tensor(builder.create_atomic_load(ptr.handle, sem, scope), dst_ty)
else:
ret = tl.tensor(
            builder.create_masked_atomic_load(ptr.handle, mask.handle,
                                              other.handle if other is not None else None, sem, scope),
dst_ty)
if is_bool:
ret = cast(ret, tl.int1, builder)
return ret


def reinterpret_tensor_descriptor(desc_ptr: tl.tensor, block_ty: tl.block_type, builder: ir.builder):
handle = builder.create_reinterpret_tensor_descriptor(desc_ptr.handle, block_ty.to_ir(builder))
return tl._experimental_tensor_descriptor_base(handle, block_ty)
@@ -1287,6 +1336,37 @@ def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_ch
return _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder)


def atomic_store(ptr: tl.tensor, val: tl.tensor, sem: str, scope: str, mask: Optional[tl.tensor],
builder: ir.builder) -> tl.tensor:
if ptr.type.is_const() or ptr.type.scalar.is_const():
raise ValueError("Cannot `tl.atomic_store` to a constant pointer")
if not is_ptr_to_scalar(ptr):
raise ValueError(f"Only scalars are supported for `tl.atomic_store`: {ptr.type.__repr__()}")
    if val.type.is_block() or (mask is not None and mask.type.is_block()):
raise ValueError(f"`tl.atomic_store`'s Value and Mask arguments cannot be block type: {val}, {mask}")

sem = _str_to_sem(sem)
scope = _str_to_scope(scope)
ptr_ty = ptr.type.scalar
elt_ty = ptr_ty.element_ty

# Treat `pointer_type<tl.int1>` as `pointer_type<tl.int8>`
if elt_ty == tl.int1:
elt_ty = tl.int8
ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space)
ptr = cast(ptr, ptr_ty, builder)

# Cast to target data type
val = cast(val, elt_ty, builder)

# Build IR
if mask is None:
return tl.tensor(builder.create_atomic_store(ptr.handle, val.handle, sem, scope), tl.void)
if not mask.type.scalar.is_bool():
raise ValueError("Mask must have boolean scalar type")
return tl.tensor(builder.create_masked_atomic_store(ptr.handle, val.handle, mask.handle, sem, scope), tl.void)


#########
# atomic
#########