Skip to content

Commit

Permalink
feat: add complex_constant_
Browse files Browse the repository at this point in the history
  • Loading branch information
ouioui199 committed Dec 5, 2024
1 parent e54f2c4 commit 0729da6
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/torchcvnn/nn/modules/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,21 @@ def complex_xavier_normal_(
std = (gain * math.sqrt(2.0 / float(fan_in + fan_out))) / math.sqrt(2)

return nn.init._no_grad_normal_(tensor, 0.0, std)


def complex_constant_(
tensor: torch.Tensor,
val: float
) -> torch.Tensor:
r"""Fill the input Tensor with the value :math:`\text{val}`.
Args:
tensor: an n-dimensional `torch.Tensor`
val: the value to fill the tensor with
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.constant_(w, 0.3)
"""
val = (val + 1j * val) / math.sqrt(2)
return nn.init._no_grad_fill_(tensor.to(torch.complex64), val)

0 comments on commit 0729da6

Please sign in to comment.