Skip to content

Commit

Permalink
apply audit suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
Dreamer committed Oct 17, 2023
1 parent efde3fb commit 1d33cd9
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 34 deletions.
2 changes: 1 addition & 1 deletion keeper/grpc_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (k Keeper) ClassTrace(c context.Context,

hash, err := types.ParseHexHash(strings.TrimPrefix(req.Hash, "ibc/"))
if err != nil {
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid class trace hash: %s, error: %s", hash.String(), err))
return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid class trace hash: %s, error: %s", req.Hash, err))
}

ctx := sdk.UnwrapSDKContext(c)
Expand Down
31 changes: 9 additions & 22 deletions keeper/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,10 @@ func (k Keeper) SendTransfer(
packet, err := k.createOutgoingPacket(ctx,
sourcePort,
sourceChannel,
destinationPort,
destinationChannel,
classID,
tokenIDs,
sender,
receiver,
timeoutHeight,
timeoutTimestamp,
memo,
)
if err != nil {
Expand Down Expand Up @@ -170,8 +166,8 @@ func (k Keeper) refundPacketToken(ctx sdk.Context, packet channeltypes.Packet, d
classTrace := types.ParseClassTrace(data.ClassId)
voucherClassID := classTrace.IBCClassID()
if types.IsAwayFromOrigin(packet.GetSourcePort(), packet.GetSourceChannel(), data.ClassId) {
for _, tokenID := range data.TokenIds {
if err := k.nftKeeper.Transfer(ctx, voucherClassID, tokenID, "", sender); err != nil {
for i, tokenID := range data.TokenIds {
if err := k.nftKeeper.Transfer(ctx, voucherClassID, tokenID, types.GetIfExist(i, data.TokenData), sender); err != nil {
return err
}
}
Expand Down Expand Up @@ -199,14 +195,10 @@ func (k Keeper) refundPacketToken(ctx sdk.Context, packet channeltypes.Packet, d
func (k Keeper) createOutgoingPacket(ctx sdk.Context,
sourcePort,
sourceChannel,
destinationPort,
destinationChannel,
classID string,
tokenIDs []string,
sender sdk.AccAddress,
receiver string,
timeoutHeight clienttypes.Height,
timeoutTimestamp uint64,
memo string,
) (types.NonFungibleTokenPacketData, error) {
class, exist := k.nftKeeper.GetClass(ctx, classID)
Expand All @@ -218,8 +210,8 @@ func (k Keeper) createOutgoingPacket(ctx sdk.Context,
// NOTE: class and hex hash correctness checked during msg.ValidateBasic
fullClassPath = classID
err error
tokenURIs []string
tokenData []string
tokenURIs = make([]string, len(tokenIDs))
tokenData = make([]string, len(tokenIDs))
)

// deconstruct the token denomination into the denomination trace info
Expand All @@ -233,7 +225,7 @@ func (k Keeper) createOutgoingPacket(ctx sdk.Context,

isAwayFromOrigin := types.IsAwayFromOrigin(sourcePort,
sourceChannel, fullClassPath)
for _, tokenID := range tokenIDs {
for i, tokenID := range tokenIDs {
nft, exist := k.nftKeeper.GetNFT(ctx, classID, tokenID)
if !exist {
return types.NonFungibleTokenPacketData{}, errorsmod.Wrap(types.ErrInvalidTokenID, "tokenId not exist")
Expand All @@ -244,13 +236,13 @@ func (k Keeper) createOutgoingPacket(ctx sdk.Context,
return types.NonFungibleTokenPacketData{}, errorsmod.Wrap(sdkerrors.ErrUnauthorized, "not token owner")
}

tokenURIs = append(tokenURIs, nft.GetURI())
tokenData = append(tokenData, nft.GetData())
tokenURIs[i] = nft.GetURI()
tokenData[i] = nft.GetData()

if isAwayFromOrigin {
// create the escrow address for the tokens
escrowAddress := types.GetEscrowAddress(sourcePort, sourceChannel)
if err := k.nftKeeper.Transfer(ctx, classID, tokenID, "", escrowAddress); err != nil {
if err := k.nftKeeper.Transfer(ctx, classID, tokenID, nft.GetData(), escrowAddress); err != nil {
return types.NonFungibleTokenPacketData{}, err
}
} else {
Expand All @@ -271,12 +263,7 @@ func (k Keeper) createOutgoingPacket(ctx sdk.Context,
tokenData,
memo,
)

// check packet
if err := packetData.ValidateBasic(); err != nil {
return types.NonFungibleTokenPacketData{}, err
}
return packetData, nil
return packetData, packetData.ValidateBasic()
}

// processReceivedPacket will mint the tokens to receiver account
Expand Down
6 changes: 5 additions & 1 deletion types/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@ import (
// on the provided LegacyAmino codec. These types are used for Amino JSON serialization.
func RegisterLegacyAminoCodec(cdc *codec.LegacyAmino) {
cdc.RegisterConcrete(&MsgTransfer{}, "cosmos-sdk/MsgTransferNFT", nil)
cdc.RegisterConcrete(&MsgUpdateParams{}, "cosmos-sdk/MsgUpdateParams", nil)
}

// RegisterInterfaces register the ibc nft-transfer module interfaces to protobuf
// Any.
func RegisterInterfaces(registry codectypes.InterfaceRegistry) {
registry.RegisterImplementations((*sdk.Msg)(nil), &MsgTransfer{})
registry.RegisterImplementations((*sdk.Msg)(nil),
&MsgTransfer{},
&MsgUpdateParams{},
)
msgservice.RegisterMsgServiceDesc(registry, &_Msg_serviceDesc)
}

Expand Down
14 changes: 11 additions & 3 deletions types/msgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,15 @@ func (msg MsgTransfer) ValidateBasic() error {
return errorsmod.Wrap(ErrInvalidTokenID, "tokenId cannot be blank")
}

for _, tokenID := range msg.TokenIds {
if strings.TrimSpace(tokenID) == "" {
seen := make(map[string]int64)
for i, id := range msg.TokenIds {
if strings.TrimSpace(id) == "" {
return errorsmod.Wrap(ErrInvalidTokenID, "tokenId cannot be blank")
}
if j, exist := seen[id]; exist {
return errorsmod.Wrapf(ErrInvalidTokenID, "the tokenId at positions %d and %d in the array are repeated", i, j)
}
seen[id] = int64(i)
}

// NOTE: sender format must be validated as it is required by the GetSigners function.
Expand Down Expand Up @@ -116,6 +121,9 @@ func (msg MsgUpdateParams) GetSignBytes() []byte {

// GetSigners returns the expected signers for a MsgUpdateParams.
func (msg MsgUpdateParams) GetSigners() []sdk.AccAddress {
authority, _ := sdk.AccAddressFromBech32(msg.Authority)
authority, err := sdk.AccAddressFromBech32(msg.Authority)
if err != nil {
panic(err)
}
return []sdk.AccAddress{authority}
}
7 changes: 6 additions & 1 deletion types/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,15 @@ func (nftpd NonFungibleTokenPacketData) ValidateBasic() error {
return errorsmod.Wrap(ErrInvalidTokenID, "tokenId cannot be empty")
}

for _, id := range nftpd.TokenIds {
seen := make(map[string]int64)
for i, id := range nftpd.TokenIds {
if strings.TrimSpace(id) == "" {
return errorsmod.Wrap(ErrInvalidTokenID, "tokenId cannot be blank")
}
if j, exist := seen[id]; exist {
return errorsmod.Wrapf(ErrInvalidTokenID, "the tokenId at positions %d and %d in the array are repeated", i, j)
}
seen[id] = int64(i)
}

if (len(nftpd.TokenUris) != 0) && len(nftpd.TokenIds) != len(nftpd.TokenUris) {
Expand Down
10 changes: 4 additions & 6 deletions types/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@ import (

// ParseHexHash parses a hex hash in string format to bytes and validates its correctness.
func ParseHexHash(hexHash string) (tmbytes.HexBytes, error) {
if strings.TrimSpace(hexHash) == "" {
return nil, fmt.Errorf("empty hex hash")
}
hash, err := hex.DecodeString(hexHash)
if err != nil {
return nil, err
}

if err := tmtypes.ValidateHash(hash); err != nil {
return nil, err
}

return hash, nil
return hash, tmtypes.ValidateHash(hash)
}

// GetClassPrefix returns the receiving class prefix
Expand Down

0 comments on commit 1d33cd9

Please sign in to comment.