diff --git a/cmd/indexer.go b/cmd/indexer.go index 3cdc828..f9a9adf 100644 --- a/cmd/indexer.go +++ b/cmd/indexer.go @@ -3,7 +3,7 @@ package cmd import ( "time" - "github.com/0glabs/0g-storage-client/common/util" + "github.com/0glabs/0g-storage-client/common/rpc" "github.com/0glabs/0g-storage-client/indexer" "github.com/0glabs/0g-storage-client/indexer/gateway" "github.com/sirupsen/logrus" @@ -79,7 +79,7 @@ func startIndexer(*cobra.Command, []string) { Endpoint: indexerArgs.endpoint, Nodes: indexerArgs.nodes.TrustedNodes, MaxDownloadFileSize: indexerArgs.maxDownloadFileSize, - RPCHandler: util.MustNewRPCHandler(map[string]interface{}{ + RPCHandler: rpc.MustNewHandler(map[string]interface{}{ api.Namespace: api, }), }) diff --git a/common/rpc/batch.go b/common/rpc/batch.go new file mode 100644 index 0000000..a4e063e --- /dev/null +++ b/common/rpc/batch.go @@ -0,0 +1,49 @@ +package rpc + +import ( + "context" + + "github.com/openweb3/go-rpc-provider" + "github.com/openweb3/go-rpc-provider/interfaces" +) + +type Request struct { + Method string + Args []any +} + +type Response[T any] struct { + Data T + Error error +} + +// BatchCall is a generic method to call RPC in batch. +func BatchCall[T any](provider interfaces.Provider, requests ...Request) ([]Response[T], error) { + return BatchCallContext[T](provider, context.Background(), requests...) +} + +// BatchCallContext is a generic method to call RPC with context in batch. +func BatchCallContext[T any](provider interfaces.Provider, ctx context.Context, requests ...Request) ([]Response[T], error) { + batch := make([]rpc.BatchElem, 0, len(requests)) + responses := make([]Response[T], len(requests)) + + for i, v := range requests { + batch = append(batch, rpc.BatchElem{ + Method: v.Method, + Args: v.Args, + Result: &responses[i].Data, + }) + } + + if err := provider.BatchCallContext(ctx, batch); err != nil { + return nil, err + } + + for i, v := range batch { + if v.Error != nil { + responses[i].Error = v.Error + } + } + + return responses, nil +} diff --git a/common/rpc/client.go b/common/rpc/client.go new file mode 100644 index 0000000..5496c1c --- /dev/null +++ b/common/rpc/client.go @@ -0,0 +1,45 @@ +package rpc + +import ( + "context" + + "github.com/openweb3/go-rpc-provider/interfaces" + providers "github.com/openweb3/go-rpc-provider/provider_wrapper" +) + +// Client is a base class of any RPC client. +type Client struct { + *providers.MiddlewarableProvider + url string +} + +// NewClient creates a new client instance. +func NewClient(url string, option ...providers.Option) (*Client, error) { + var opt providers.Option + if len(option) > 0 { + opt = option[0] + } + + provider, err := providers.NewProviderWithOption(url, opt) + if err != nil { + return nil, err + } + + return &Client{providers.NewMiddlewarableProvider(provider), url}, nil +} + +// URL Get the RPC server URL the client connected to. +func (c *Client) URL() string { + return c.url +} + +// Call is a generic method to call RPC. +func Call[T any](provider interfaces.Provider, method string, args ...any) (result T, err error) { + return CallContext[T](provider, context.Background(), method, args...) +} + +// CallContext is a generic method to call RPC with context. +func CallContext[T any](provider interfaces.Provider, ctx context.Context, method string, args ...any) (result T, err error) { + err = provider.CallContext(ctx, &result, method, args...) + return +} diff --git a/common/util/rpc.go b/common/rpc/server.go similarity index 52% rename from common/util/rpc.go rename to common/rpc/server.go index 0ba2874..eafb26c 100644 --- a/common/util/rpc.go +++ b/common/rpc/server.go @@ -1,15 +1,15 @@ -package util +package rpc import ( "net/http" + "net/rpc" "github.com/ethereum/go-ethereum/node" - "github.com/openweb3/go-rpc-provider" "github.com/sirupsen/logrus" ) -// MustNewRPCHandler creates a http.Handler for the specified RPC apis. -func MustNewRPCHandler(apis map[string]interface{}) http.Handler { +// MustNewHandler creates a http.Handler for the specified RPC apis. +func MustNewHandler(apis map[string]interface{}) http.Handler { handler := rpc.NewServer() for namespace, impl := range apis { @@ -22,8 +22,8 @@ func MustNewRPCHandler(apis map[string]interface{}) http.Handler { return node.NewHTTPHandlerStack(handler, []string{"*"}, []string{"*"}, []byte{}) } -// MustServe starts a HTTP service util shutdown. -func MustServe(endpoint string, handler http.Handler) { +// Start starts a HTTP service util shutdown. +func Start(endpoint string, handler http.Handler) { server := http.Server{ Addr: endpoint, Handler: handler, @@ -32,8 +32,8 @@ func MustServe(endpoint string, handler http.Handler) { server.ListenAndServe() } -// MustServeRPC starts RPC service until shutdown. -func MustServeRPC(endpoint string, apis map[string]interface{}) { - rpcHandler := MustNewRPCHandler(apis) - MustServe(endpoint, rpcHandler) +// MustServe starts RPC service until shutdown. +func MustServe(endpoint string, apis map[string]interface{}) { + rpcHandler := MustNewHandler(apis) + Start(endpoint, rpcHandler) } diff --git a/indexer/client.go b/indexer/client.go index 04afa65..78aa504 100644 --- a/indexer/client.go +++ b/indexer/client.go @@ -7,12 +7,12 @@ import ( "time" "github.com/0glabs/0g-storage-client/common" + "github.com/0glabs/0g-storage-client/common/rpc" "github.com/0glabs/0g-storage-client/common/shard" "github.com/0glabs/0g-storage-client/core" "github.com/0glabs/0g-storage-client/node" "github.com/0glabs/0g-storage-client/transfer" eth_common "github.com/ethereum/go-ethereum/common" - "github.com/openweb3/go-rpc-provider/interfaces" providers "github.com/openweb3/go-rpc-provider/provider_wrapper" "github.com/openweb3/web3go" "github.com/pkg/errors" @@ -28,7 +28,7 @@ var ( // Client indexer client type Client struct { - interfaces.Provider + *rpc.Client option IndexerClientOption logger *logrus.Logger } @@ -46,34 +46,31 @@ func NewClient(url string, option ...IndexerClientOption) (*Client, error) { opt = option[0] } - provider, err := providers.NewProviderWithOption(url, opt.ProviderOption) + client, err := rpc.NewClient(url, opt.ProviderOption) if err != nil { return nil, err } return &Client{ - Provider: provider, - option: opt, - logger: common.NewLogger(opt.LogOption), + Client: client, + option: opt, + logger: common.NewLogger(opt.LogOption), }, nil } -// GetNodes get node list from indexer service -func (c *Client) GetShardedNodes(ctx context.Context) (nodes ShardedNodes, err error) { - err = c.Provider.CallContext(ctx, &nodes, "indexer_getShardedNodes") - return +// GetShardedNodes get node list from indexer service +func (c *Client) GetShardedNodes(ctx context.Context) (ShardedNodes, error) { + return rpc.CallContext[ShardedNodes](c, ctx, "indexer_getShardedNodes") } -// GetNodes return storage nodes with IP location information. -func (c *Client) GetNodeLocations(ctx context.Context) (locations map[string]*IPLocation, err error) { - err = c.Provider.CallContext(ctx, &locations, "indexer_getNodeLocations") - return +// GetNodeLocations return storage nodes with IP location information. +func (c *Client) GetNodeLocations(ctx context.Context) (map[string]*IPLocation, error) { + return rpc.CallContext[map[string]*IPLocation](c, ctx, "indexer_getNodeLocations") } // GetFileLocations return locations info of given file. -func (c *Client) GetFileLocations(ctx context.Context, root string) (locations []*shard.ShardedNode, err error) { - err = c.Provider.CallContext(ctx, &locations, "indexer_getFileLocations", root) - return +func (c *Client) GetFileLocations(ctx context.Context, root string) ([]*shard.ShardedNode, error) { + return rpc.CallContext[[]*shard.ShardedNode](c, ctx, "indexer_getFileLocations", root) } // SelectNodes get node list from indexer service and select a subset of it, which is sufficient to store expected number of replications. diff --git a/indexer/gateway/server.go b/indexer/gateway/server.go index 89195e6..dd5cffb 100644 --- a/indexer/gateway/server.go +++ b/indexer/gateway/server.go @@ -3,7 +3,7 @@ package gateway import ( "net/http" - "github.com/0glabs/0g-storage-client/common/util" + "github.com/0glabs/0g-storage-client/common/rpc" "github.com/0glabs/0g-storage-client/node" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" @@ -34,7 +34,7 @@ func MustServeWithRPC(config Config) { router.POST("/", gin.WrapH(config.RPCHandler)) } - util.MustServe(config.Endpoint, router) + rpc.Start(config.Endpoint, router) } func newRouter() *gin.Engine { diff --git a/node/client_admin.go b/node/client_admin.go index 152ee26..3c04a06 100644 --- a/node/client_admin.go +++ b/node/client_admin.go @@ -3,6 +3,7 @@ package node import ( "context" + "github.com/0glabs/0g-storage-client/common/rpc" providers "github.com/openweb3/go-rpc-provider/provider_wrapper" "github.com/sirupsen/logrus" ) @@ -33,66 +34,55 @@ func NewAdminClient(url string, option ...providers.Option) (*AdminClient, error } // FindFile Call find_file to update file location cache -func (c *AdminClient) FindFile(ctx context.Context, txSeq uint64) (ret int, err error) { - err = c.wrapError(c.CallContext(ctx, &ret, "admin_findFile", txSeq), "admin_findFile") - return +func (c *AdminClient) FindFile(ctx context.Context, txSeq uint64) (int, error) { + return rpc.CallContext[int](c, ctx, "admin_findFile", txSeq) } // Shutdown Call admin_shutdown to shutdown the node. -func (c *AdminClient) Shutdown(ctx context.Context) (ret int, err error) { - err = c.wrapError(c.CallContext(ctx, &ret, "admin_shutdown"), "admin_shutdown") - return +func (c *AdminClient) Shutdown(ctx context.Context) (int, error) { + return rpc.CallContext[int](c, ctx, "admin_shutdown") } // StartSyncFile Call admin_startSyncFile to request synchronization of a file. -func (c *AdminClient) StartSyncFile(ctx context.Context, txSeq uint64) (ret int, err error) { - err = c.wrapError(c.CallContext(ctx, &ret, "admin_startSyncFile", txSeq), "admin_startSyncFile") - return +func (c *AdminClient) StartSyncFile(ctx context.Context, txSeq uint64) (int, error) { + return rpc.CallContext[int](c, ctx, "admin_startSyncFile", txSeq) } // StartSyncChunks Call admin_startSyncChunks to request synchronization of specified chunks. -func (c *AdminClient) StartSyncChunks(ctx context.Context, txSeq, startIndex, endIndex uint64) (ret int, err error) { - err = c.wrapError(c.CallContext(ctx, &ret, "admin_startSyncChunks", txSeq, startIndex, endIndex), "admin_startSyncChunks") - return +func (c *AdminClient) StartSyncChunks(ctx context.Context, txSeq, startIndex, endIndex uint64) (int, error) { + return rpc.CallContext[int](c, ctx, "admin_startSyncChunks", txSeq, startIndex, endIndex) } // TerminateSync Call admin_terminateSync to terminate a file sync. -func (c *AdminClient) TerminateSync(ctx context.Context, txSeq uint64) (terminated bool, err error) { - err = c.wrapError(c.CallContext(ctx, &terminated, "admin_terminateSync", txSeq), "admin_terminateSync") - return +func (c *AdminClient) TerminateSync(ctx context.Context, txSeq uint64) (bool, error) { + return rpc.CallContext[bool](c, ctx, "admin_terminateSync", txSeq) } // GetSyncStatus Call admin_getSyncStatus to retrieve the sync status of specified file. -func (c *AdminClient) GetSyncStatus(ctx context.Context, txSeq uint64) (status string, err error) { - err = c.wrapError(c.CallContext(ctx, &status, "admin_getSyncStatus", txSeq), "admin_getSyncStatus") - return +func (c *AdminClient) GetSyncStatus(ctx context.Context, txSeq uint64) (string, error) { + return rpc.CallContext[string](c, ctx, "admin_getSyncStatus", txSeq) } // GetSyncInfo Call admin_getSyncInfo to retrieve the sync status of specified file or all files. -func (c *AdminClient) GetSyncInfo(ctx context.Context, tx_seq ...uint64) (files map[uint64]FileSyncInfo, err error) { +func (c *AdminClient) GetSyncInfo(ctx context.Context, tx_seq ...uint64) (map[uint64]FileSyncInfo, error) { if len(tx_seq) > 0 { - err = c.wrapError(c.CallContext(ctx, &files, "admin_getSyncInfo", tx_seq[0]), "admin_getSyncInfo") - } else { - err = c.wrapError(c.CallContext(ctx, &files, "admin_getSyncInfo"), "admin_getSyncInfo") + return rpc.CallContext[map[uint64]FileSyncInfo](c, ctx, "admin_getSyncInfo", tx_seq[0]) } - return + return rpc.CallContext[map[uint64]FileSyncInfo](c, ctx, "admin_getSyncInfo") } // GetNetworkInfo Call admin_getNetworkInfo to retrieve the network information. -func (c *AdminClient) GetNetworkInfo(ctx context.Context) (info NetworkInfo, err error) { - err = c.wrapError(c.CallContext(ctx, &info, "admin_getNetworkInfo"), "admin_getNetworkInfo") - return +func (c *AdminClient) GetNetworkInfo(ctx context.Context) (NetworkInfo, error) { + return rpc.CallContext[NetworkInfo](c, ctx, "admin_getNetworkInfo") } // GetPeers Call admin_getPeers to retrieve all discovered network peers. -func (c *AdminClient) GetPeers(ctx context.Context) (peers map[string]*PeerInfo, err error) { - err = c.wrapError(c.CallContext(ctx, &peers, "admin_getPeers"), "admin_getPeers") - return +func (c *AdminClient) GetPeers(ctx context.Context) (map[string]*PeerInfo, error) { + return rpc.CallContext[map[string]*PeerInfo](c, ctx, "admin_getPeers") } -// getFileLocation Get file location -func (c *AdminClient) GetFileLocation(ctx context.Context, txSeq uint64, allShards bool) (locations []LocationInfo, err error) { - err = c.wrapError(c.CallContext(ctx, &locations, "admin_getFileLocation", txSeq, allShards), "admin_getFileLocation") - return +// GetFileLocation Get file location +func (c *AdminClient) GetFileLocation(ctx context.Context, txSeq uint64, allShards bool) ([]LocationInfo, error) { + return rpc.CallContext[[]LocationInfo](c, ctx, "admin_getFileLocation", txSeq, allShards) } diff --git a/node/client_kv.go b/node/client_kv.go index bc7b8e2..8aceaaf 100644 --- a/node/client_kv.go +++ b/node/client_kv.go @@ -3,6 +3,7 @@ package node import ( "context" + "github.com/0glabs/0g-storage-client/common/rpc" "github.com/ethereum/go-ethereum/common" providers "github.com/openweb3/go-rpc-provider/provider_wrapper" "github.com/sirupsen/logrus" @@ -34,113 +35,101 @@ func NewKvClient(url string, option ...providers.Option) (*KvClient, error) { } // GetValue Call kv_getValue RPC to query the value of a key. -func (c *KvClient) GetValue(ctx context.Context, streamId common.Hash, key []byte, startIndex, length uint64, version ...uint64) (val *Value, err error) { +func (c *KvClient) GetValue(ctx context.Context, streamId common.Hash, key []byte, startIndex, length uint64, version ...uint64) (*Value, error) { args := []interface{}{streamId, key, startIndex, length} if len(version) > 0 { args = append(args, version[0]) } - err = c.wrapError(c.CallContext(ctx, &val, "kv_getValue", args...), "kv_getValue") - return + return rpc.CallContext[*Value](c, ctx, "kv_getValue", args...) } // GetNext Call kv_getNext RPC to query the next key of a given key. -func (c *KvClient) GetNext(ctx context.Context, streamId common.Hash, key []byte, startIndex, length uint64, inclusive bool, version ...uint64) (val *KeyValue, err error) { +func (c *KvClient) GetNext(ctx context.Context, streamId common.Hash, key []byte, startIndex, length uint64, inclusive bool, version ...uint64) (*KeyValue, error) { args := []interface{}{streamId, key, startIndex, length, inclusive} if len(version) > 0 { args = append(args, version[0]) } - err = c.wrapError(c.CallContext(ctx, &val, "kv_getNext", args...), "kv_getNext") - return + return rpc.CallContext[*KeyValue](c, ctx, "kv_getNext", args...) } // GetPrev Call kv_getNext RPC to query the prev key of a given key. -func (c *KvClient) GetPrev(ctx context.Context, streamId common.Hash, key []byte, startIndex, length uint64, inclusive bool, version ...uint64) (val *KeyValue, err error) { +func (c *KvClient) GetPrev(ctx context.Context, streamId common.Hash, key []byte, startIndex, length uint64, inclusive bool, version ...uint64) (*KeyValue, error) { args := []interface{}{streamId, key, startIndex, length, inclusive} if len(version) > 0 { args = append(args, version[0]) } - err = c.wrapError(c.CallContext(ctx, &val, "kv_getPrev", args...), "kv_getPrev") - return + return rpc.CallContext[*KeyValue](c, ctx, "kv_getPrev", args...) } // GetFirst Call kv_getFirst RPC to query the first key. -func (c *KvClient) GetFirst(ctx context.Context, streamId common.Hash, startIndex, length uint64, version ...uint64) (val *KeyValue, err error) { +func (c *KvClient) GetFirst(ctx context.Context, streamId common.Hash, startIndex, length uint64, version ...uint64) (*KeyValue, error) { args := []interface{}{streamId, startIndex, length} if len(version) > 0 { args = append(args, version[0]) } - err = c.wrapError(c.CallContext(ctx, &val, "kv_getFirst", args...), "kv_getFirst") - return + return rpc.CallContext[*KeyValue](c, ctx, "kv_getFirst", args...) } // GetLast Call kv_getLast RPC to query the last key. -func (c *KvClient) GetLast(ctx context.Context, streamId common.Hash, startIndex, length uint64, version ...uint64) (val *KeyValue, err error) { +func (c *KvClient) GetLast(ctx context.Context, streamId common.Hash, startIndex, length uint64, version ...uint64) (*KeyValue, error) { args := []interface{}{streamId, startIndex, length} if len(version) > 0 { args = append(args, version[0]) } - err = c.wrapError(c.CallContext(ctx, &val, "kv_getLast", args...), "kv_getLast") - return + return rpc.CallContext[*KeyValue](c, ctx, "kv_getLast", args...) } // GetTransactionResult Call kv_getTransactionResult RPC to query the kv replay status of a given file. -func (c *KvClient) GetTransactionResult(ctx context.Context, txSeq uint64) (result string, err error) { - err = c.wrapError(c.CallContext(ctx, &result, "kv_getTransactionResult", txSeq), "kv_getTransactionResult") - return +func (c *KvClient) GetTransactionResult(ctx context.Context, txSeq uint64) (string, error) { + return rpc.CallContext[string](c, ctx, "kv_getTransactionResult", txSeq) } // GetHoldingStreamIds Call kv_getHoldingStreamIds RPC to query the stream ids monitered by the kv node. -func (c *KvClient) GetHoldingStreamIds(ctx context.Context) (streamIds []common.Hash, err error) { - err = c.wrapError(c.CallContext(ctx, &streamIds, "kv_getHoldingStreamIds"), "kv_getHoldingStreamIds") - return +func (c *KvClient) GetHoldingStreamIds(ctx context.Context) ([]common.Hash, error) { + return rpc.CallContext[[]common.Hash](c, ctx, "kv_getHoldingStreamIds") } // HasWritePermission Call kv_hasWritePermission RPC to check if the account is able to write the stream. -func (c *KvClient) HasWritePermission(ctx context.Context, account common.Address, streamId common.Hash, key []byte, version ...uint64) (hasPermission bool, err error) { +func (c *KvClient) HasWritePermission(ctx context.Context, account common.Address, streamId common.Hash, key []byte, version ...uint64) (bool, error) { args := []interface{}{account, streamId, key} if len(version) > 0 { args = append(args, version[0]) } - err = c.wrapError(c.CallContext(ctx, &hasPermission, "kv_hasWritePermission", args...), "kv_hasWritePermission") - return + return rpc.CallContext[bool](c, ctx, "kv_hasWritePermission", args...) } // IsAdmin Call kv_isAdmin RPC to check if the account is the admin of the stream. -func (c *KvClient) IsAdmin(ctx context.Context, account common.Address, streamId common.Hash, version ...uint64) (isAdmin bool, err error) { +func (c *KvClient) IsAdmin(ctx context.Context, account common.Address, streamId common.Hash, version ...uint64) (bool, error) { args := []interface{}{account, streamId} if len(version) > 0 { args = append(args, version[0]) } - err = c.wrapError(c.CallContext(ctx, &isAdmin, "kv_isAdmin", args...), "kv_isAdmin") - return + return rpc.CallContext[bool](c, ctx, "kv_isAdmin", args...) } // IsSpecialKey Call kv_isSpecialKey RPC to check if the key has unique access control. -func (c *KvClient) IsSpecialKey(ctx context.Context, streamId common.Hash, key []byte, version ...uint64) (isSpecialKey bool, err error) { +func (c *KvClient) IsSpecialKey(ctx context.Context, streamId common.Hash, key []byte, version ...uint64) (bool, error) { args := []interface{}{streamId, key} if len(version) > 0 { args = append(args, version[0]) } - err = c.wrapError(c.CallContext(ctx, &isSpecialKey, "kv_isSpecialKey", args...), "kv_isSpecialKey") - return + return rpc.CallContext[bool](c, ctx, "kv_isSpecialKey", args...) } // IsWriterOfKey Call kv_isWriterOfKey RPC to check if the account can write the special key. -func (c *KvClient) IsWriterOfKey(ctx context.Context, account common.Address, streamId common.Hash, key []byte, version ...uint64) (isWriter bool, err error) { +func (c *KvClient) IsWriterOfKey(ctx context.Context, account common.Address, streamId common.Hash, key []byte, version ...uint64) (bool, error) { args := []interface{}{account, streamId, key} if len(version) > 0 { args = append(args, version[0]) } - err = c.wrapError(c.CallContext(ctx, &isWriter, "kv_isWriterOfKey", args...), "kv_isWriterOfKey") - return + return rpc.CallContext[bool](c, ctx, "kv_isWriterOfKey", args...) } // IsWriterOfStream Call kv_isWriterOfStream RPC to check if the account is the writer of the stream. -func (c *KvClient) IsWriterOfStream(ctx context.Context, account common.Address, streamId common.Hash, version ...uint64) (isWriter bool, err error) { +func (c *KvClient) IsWriterOfStream(ctx context.Context, account common.Address, streamId common.Hash, version ...uint64) (bool, error) { args := []interface{}{account, streamId} if len(version) > 0 { args = append(args, version[0]) } - err = c.wrapError(c.CallContext(ctx, &isWriter, "kv_isWriterOfStream", args...), "kv_isWriterOfStream") - return + return rpc.CallContext[bool](c, ctx, "kv_isWriterOfStream", args...) } diff --git a/node/client_zgs.go b/node/client_zgs.go index 9c9bd3a..c29ee21 100644 --- a/node/client_zgs.go +++ b/node/client_zgs.go @@ -3,6 +3,7 @@ package node import ( "context" + "github.com/0glabs/0g-storage-client/common/rpc" "github.com/0glabs/0g-storage-client/common/shard" "github.com/ethereum/go-ethereum/common" providers "github.com/openweb3/go-rpc-provider/provider_wrapper" @@ -47,59 +48,52 @@ func MustNewZgsClients(urls []string, option ...providers.Option) []*ZgsClient { } // GetStatus Call zgs_getStatus RPC to get sync status of the node. -func (c *ZgsClient) GetStatus(ctx context.Context) (status Status, err error) { - err = c.wrapError(c.CallContext(ctx, &status, "zgs_getStatus"), "zgs_getStatus") - return +func (c *ZgsClient) GetStatus(ctx context.Context) (Status, error) { + return rpc.CallContext[Status](c, ctx, "zgs_getStatus") } // CheckFileFinalized Call zgs_checkFileFinalized to check if specified file is finalized. // Returns nil if file not available on storage node. -func (c *ZgsClient) CheckFileFinalized(ctx context.Context, txSeqOrRoot TxSeqOrRoot) (finalized *bool, err error) { - err = c.wrapError(c.CallContext(ctx, &finalized, "zgs_checkFileFinalized", txSeqOrRoot), "zgs_checkFileFinalized") - return +func (c *ZgsClient) CheckFileFinalized(ctx context.Context, txSeqOrRoot TxSeqOrRoot) (*bool, error) { + return rpc.CallContext[*bool](c, ctx, "zgs_checkFileFinalized", txSeqOrRoot) } // GetFileInfo Call zgs_getFileInfo RPC to get the information of a file by file data root from the node. -func (c *ZgsClient) GetFileInfo(ctx context.Context, root common.Hash) (file *FileInfo, err error) { - err = c.wrapError(c.CallContext(ctx, &file, "zgs_getFileInfo", root), "zgs_getFileInfo") - return +func (c *ZgsClient) GetFileInfo(ctx context.Context, root common.Hash) (*FileInfo, error) { + return rpc.CallContext[*FileInfo](c, ctx, "zgs_getFileInfo", root) } // GetFileInfoByTxSeq Call zgs_getFileInfoByTxSeq RPC to get the information of a file by file sequence id from the node. -func (c *ZgsClient) GetFileInfoByTxSeq(ctx context.Context, txSeq uint64) (file *FileInfo, err error) { - err = c.wrapError(c.CallContext(ctx, &file, "zgs_getFileInfoByTxSeq", txSeq), "zgs_getFileInfoByTxSeq") - return +func (c *ZgsClient) GetFileInfoByTxSeq(ctx context.Context, txSeq uint64) (*FileInfo, error) { + return rpc.CallContext[*FileInfo](c, ctx, "zgs_getFileInfoByTxSeq", txSeq) } // UploadSegment Call zgs_uploadSegment RPC to upload a segment to the node. -func (c *ZgsClient) UploadSegment(ctx context.Context, segment SegmentWithProof) (ret int, err error) { - err = c.wrapError(c.CallContext(ctx, &ret, "zgs_uploadSegment", segment), "zgs_uploadSegment") - return +func (c *ZgsClient) UploadSegment(ctx context.Context, segment SegmentWithProof) (int, error) { + return rpc.CallContext[int](c, ctx, "zgs_uploadSegment", segment) } // UploadSegments Call zgs_uploadSegments RPC to upload a slice of segments to the node. -func (c *ZgsClient) UploadSegments(ctx context.Context, segments []SegmentWithProof) (ret int, err error) { - err = c.wrapError(c.CallContext(ctx, &ret, "zgs_uploadSegments", segments), "zgs_uploadSegments") - return +func (c *ZgsClient) UploadSegments(ctx context.Context, segments []SegmentWithProof) (int, error) { + return rpc.CallContext[int](c, ctx, "zgs_uploadSegments", segments) } // DownloadSegment Call zgs_downloadSegment RPC to download a segment from the node. -func (c *ZgsClient) DownloadSegment(ctx context.Context, root common.Hash, startIndex, endIndex uint64) (data []byte, err error) { - err = c.wrapError(c.CallContext(ctx, &data, "zgs_downloadSegment", root, startIndex, endIndex), "zgs_downloadSegment") +func (c *ZgsClient) DownloadSegment(ctx context.Context, root common.Hash, startIndex, endIndex uint64) ([]byte, error) { + data, err := rpc.CallContext[[]byte](c, ctx, "zgs_downloadSegment", root, startIndex, endIndex) if len(data) == 0 { return nil, err } - return + + return data, err } // DownloadSegmentWithProof Call zgs_downloadSegmentWithProof RPC to download a segment along with its merkle proof from the node. -func (c *ZgsClient) DownloadSegmentWithProof(ctx context.Context, root common.Hash, index uint64) (segment *SegmentWithProof, err error) { - err = c.wrapError(c.CallContext(ctx, &segment, "zgs_downloadSegmentWithProof", root, index), "zgs_downloadSegmentWithProof") - return +func (c *ZgsClient) DownloadSegmentWithProof(ctx context.Context, root common.Hash, index uint64) (*SegmentWithProof, error) { + return rpc.CallContext[*SegmentWithProof](c, ctx, "zgs_downloadSegmentWithProof", root, index) } // GetShardConfig Call zgs_getShardConfig RPC to get the current shard configuration of the node. -func (c *ZgsClient) GetShardConfig(ctx context.Context) (shardConfig shard.ShardConfig, err error) { - err = c.wrapError(c.CallContext(ctx, &shardConfig, "zgs_getShardConfig"), "zgs_getShardConfig") - return +func (c *ZgsClient) GetShardConfig(ctx context.Context) (shard.ShardConfig, error) { + return rpc.CallContext[shard.ShardConfig](c, ctx, "zgs_getShardConfig") } diff --git a/node/rpc_client.go b/node/rpc_client.go index 26fc94e..b106f09 100644 --- a/node/rpc_client.go +++ b/node/rpc_client.go @@ -1,32 +1,26 @@ package node import ( - "github.com/openweb3/go-rpc-provider/interfaces" + "context" + + "github.com/0glabs/0g-storage-client/common/rpc" providers "github.com/openweb3/go-rpc-provider/provider_wrapper" ) type rpcClient struct { - interfaces.Provider - url string + *rpc.Client } func newRpcClient(url string, option ...providers.Option) (*rpcClient, error) { - var opt providers.Option - if len(option) > 0 { - opt = option[0] - } - - provider, err := providers.NewProviderWithOption(url, opt) + inner, err := rpc.NewClient(url, option...) if err != nil { return nil, err } - return &rpcClient{provider, url}, nil -} + client := rpcClient{inner} + client.HookCallContext(client.rpcErrorMiddleware) -// URL Get the RPC server URL the client connected to. -func (c *rpcClient) URL() string { - return c.url + return &client, nil } func (c *rpcClient) wrapError(e error, method string) error { @@ -39,3 +33,10 @@ func (c *rpcClient) wrapError(e error, method string) error { URL: c.URL(), } } + +func (c *rpcClient) rpcErrorMiddleware(handler providers.CallContextFunc) providers.CallContextFunc { + return func(ctx context.Context, result interface{}, method string, args ...interface{}) error { + err := handler(ctx, result, method, args...) + return c.wrapError(err, method) + } +}