Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add utils for example #1128

Merged
merged 1 commit into from
Jan 31, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions examples/hfl/proto/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import argparse
import pickle
import socket
import struct

from google.protobuf.message import Message
from singa import tensor


def receive_all(conn: socket.socket, size: int) -> bytes:
"""Receive a given length of bytes from socket.

Args:
conn (socket.socket): Socket connection.
size (int): Length of bytes to receive.

Raises:
RuntimeError: If connection closed before chunk was read, it will raise an error.

Returns:
bytes: Received bytes.
"""
buffer = b""
while size > 0:
chunk = conn.recv(size)
if not chunk:
raise RuntimeError("connection closed before chunk was read")
buffer += chunk
size -= len(chunk)
return buffer


def send_int(conn: socket.socket, i: int, pack_format: str = "Q") -> None:
"""Send an integer from socket.

Args:
conn (socket.socket): Socket connection.
i (int): Integer to send.
pack_format (str, optional): Pack format. Defaults to "Q", which means unsigned long long.
"""
data = struct.pack(f"!{pack_format}", i)
conn.sendall(data)


def receive_int(conn: socket.socket, pack_format: str = "Q") -> int:
"""Receive an integer from socket.

Args:
conn (socket.socket): Socket connection.
pack_format (str, optional): Pack format. Defaults to "Q", which means unsigned long long.

Returns:
int: Received integer.
"""
buffer_size = struct.Struct(pack_format).size
data = receive_all(conn, buffer_size)
(data,) = struct.unpack(f"!{pack_format}", data)
return data


def send_message(conn: socket.socket, data: Message, pack_format: str = "Q") -> None:
"""Send protobuf message from socket. First the length of protobuf message will be sent. Then the message is sent.

Args:
conn (socket.socket): Socket connection.
data (Message): Protobuf message to send.
pack_format (str, optional): Length of protobuf message pack format. Defaults to "Q", which means unsigned long long.
"""
send_int(conn, data.ByteSize(), pack_format)
conn.sendall(data.SerializePartialToString())


def receive_message(conn: socket.socket, data: Message, pack_format: str = "Q") -> Message:
"""Receive protobuf message from socket

Args:
conn (socket.socket): Socket connection.
data (Message): Placehold for protobuf message.
pack_format (str, optional): Length of protobuf message pack format. Defaults to "Q", which means unsigned long long.

Returns:
Message: The protobuf message.
"""
data_len = receive_int(conn, pack_format)
data.ParseFromString(receive_all(conn, data_len))
return data


def serialize_tensor(t: tensor.Tensor) -> bytes:
"""Serialize a singa tensor to bytes.

Args:
t (tensor.Tensor): The singa tensor.

Returns:
bytes: The serialized tensor.
"""
return pickle.dumps(tensor.to_numpy(t), protocol=0)


def deserialize_tensor(t: bytes) -> tensor.Tensor:
"""Recover singa tensor from bytes.

Args:
t (bytes): The serialized tensor.

Returns:
tensor.Tensor: The singa tensor.
"""
return tensor.from_numpy(pickle.loads(t))


def parseargs(arg=None) -> argparse.Namespace:
"""Parse command line arguments

Returns:
argparse.Namespace: parsed arguments
"""

parser = argparse.ArgumentParser(description="Training using the autograd and graph.")
parser.add_argument(
"--model", choices=["cnn", "resnet", "xceptionnet", "mlp", "alexnet"], default="mlp"
)
parser.add_argument("--data", choices=["mnist", "cifar10", "cifar100", "bank"], default="mnist")
parser.add_argument("-p", choices=["float32", "float16"], default="float32", dest="precision")
parser.add_argument(
"-m", "--max-epoch", default=10, type=int, help="maximum epochs", dest="max_epoch"
)
parser.add_argument(
"-b", "--batch-size", default=64, type=int, help="batch size", dest="batch_size"
)
parser.add_argument(
"-l", "--learning-rate", default=0.005, type=float, help="initial learning rate", dest="lr"
)
# Determine which gpu to use
parser.add_argument(
"-i", "--device-id", default=0, type=int, help="which GPU to use", dest="device_id"
)
parser.add_argument(
"-g",
"--disable-graph",
default="True",
action="store_false",
help="disable graph",
dest="graph",
)
parser.add_argument(
"-v", "--log-verbosity", default=0, type=int, help="logging verbosity", dest="verbosity"
)
parser.add_argument(
"-d",
"--data-distribution",
choices=["iid", "non-iid"],
default="iid",
help="data distribution",
dest="data_dist",
)
parser.add_argument("--num_clients", default=10, type=int)
parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", default=1234)

args = parser.parse_args(arg)
return args
Loading