Skip to content

Commit

Permalink
Add log1p elementwise op
Browse files Browse the repository at this point in the history
Summary:
`log1p(x)` is more precise than `log(1+x)` when `x` is close to 0. We utilize cuda `log1pf` implementation for fp32. For other precision types, input is first converted to float, then `log1pf` is computed, finally output is converted back to original precision.

CUDA log1pf function for float and double: https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__SINGLE.html

Differential Revision: D54176180
  • Loading branch information
22quinn authored and facebook-github-bot committed Mar 1, 2024
1 parent a99c753 commit 63ad837
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 17 deletions.
24 changes: 7 additions & 17 deletions fx2ait/fx2ait/acc_tracer/acc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,23 +1251,6 @@ def softsign(*, input):
return nn.functional.softsign(input=input)


@register_custom_acc_mapper_fn(
op_and_target=("call_function", torch.log1p),
arg_replacement_tuples=[
("input", "input"),
],
)
def torch_log1p_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node:
with node.graph.inserting_before(node):
add_kwargs = {"input": node.kwargs["input"], "other": 1.0}
add_node = node.graph.call_function(add, kwargs=add_kwargs)
add_node.meta = node.meta.copy()
log_kwargs = {"input": add_node}
log_node = node.graph.call_function(log, kwargs=log_kwargs)
log_node.meta = node.meta.copy()
return log_node


def reduce_op_mapper(
node: torch.fx.Node, mod: torch.fx.GraphModule, func
) -> torch.fx.Node:
Expand Down Expand Up @@ -1782,6 +1765,13 @@ def log(*, input):
return torch.log(input=input)


@register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary)
@register_acc_op_mapping(op_and_target=("call_function", torch.log1p))
@register_acc_op
def log1p(*, input):
return torch.log1p(input=input)


@register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary)
@register_acc_op_mapping(op_and_target=("call_function", torch.sqrt))
@register_acc_op_mapping(op_and_target=("call_method", "sqrt"))
Expand Down
14 changes: 14 additions & 0 deletions fx2ait/fx2ait/converters/ait_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,20 @@ def acc_ops_log(
return elementwise(FuncEnum.LOGE)(input_val)


@ait_converter(acc_ops.log1p)
def acc_ops_log1p(
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> ConverterOutput:
input_val = kwargs["input"]
if not isinstance(input_val, AITTensor):
raise RuntimeError(f"Unexpected input for {name}: {input_val}")

return elementwise(FuncEnum.LOG1P)(input_val)


@ait_converter(acc_ops.var)
def acc_ops_var(
target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument], name: str
Expand Down
7 changes: 7 additions & 0 deletions python/aitemplate/backend/backend_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,13 @@ class GPUBackendSpec(BackendSpec):
"bfloat16": "hlog",
"float": "logf",
},
FuncEnum.LOG1P: {
"half2": "h2log1p",
"bfloat16_2": "h2log1p",
"half": "hlog1p",
"bfloat16": "hlog1p",
"float": "log1pf",
},
FuncEnum.EXP: {
"half2": "h2exp",
"bfloat16_2": "h2exp",
Expand Down
24 changes: 24 additions & 0 deletions python/aitemplate/backend/cuda/elementwise/custom_math.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1048,4 +1048,28 @@ __device__ bfloat16_2 h2celu(const bfloat16_2 a, const bfloat16_2 alpha) {
#endif
}

__device__ half hlog1p(const half a) {
return half(log1pf(float(a)));
}

__device__ bfloat16 hlog1p(const bfloat16 a) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
return bfloat16(log1pf(float(a)));
#else
NOT_IMPLEMENTED();
#endif
}

__device__ half2 h2log1p(const half2 a) {
return half2(log1pf(float(a.x)), log1pf(float(a.y)));
}

__device__ bfloat16_2 h2log1p(const bfloat16_2 a) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
return bfloat16_2(log1pf(float(a.x)), log1pf(float(a.y)));
#else
NOT_IMPLEMENTED();
#endif
}

#endif
1 change: 1 addition & 0 deletions python/aitemplate/compiler/ops/common/epilogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,4 @@ class FuncEnum(Enum):
SOFTSIGN = 27
FLOOR_DIV = 28
CELU = 29
LOG1P = 30
4 changes: 4 additions & 0 deletions python/aitemplate/compiler/ops/common/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def log(tensor: Any) -> Tensor:
return OP_REGISTRY.get("LOGE")(tensor)


def log1p(tensor: Any) -> Tensor:
return OP_REGISTRY.get("LOG1P")(tensor)


def exp(tensor: Any) -> Tensor:
return OP_REGISTRY.get("EXP")(tensor)

Expand Down

0 comments on commit 63ad837

Please sign in to comment.