diff --git a/bench_test.go b/bench_test.go index 9ecc3914..718899c8 100644 --- a/bench_test.go +++ b/bench_test.go @@ -48,25 +48,143 @@ func BenchmarkConnect(b *testing.B) { assert.True(b, ok) httpTransport.DisableCompression = true - client := pingv1connect.NewPingServiceClient( - httpClient, - server.URL, - connect.WithGRPC(), - connect.WithSendGzip(), - ) + clients := []struct { + name string + opts []connect.ClientOption + }{{ + name: "connect", + opts: []connect.ClientOption{}, + }, { + name: "grpc", + opts: []connect.ClientOption{ + connect.WithGRPC(), + }, + }, { + name: "grpcweb", + opts: []connect.ClientOption{ + connect.WithGRPCWeb(), + }, + }} + twoMiB := strings.Repeat("a", 2*1024*1024) - b.ResetTimer() + for _, client := range clients { + b.Run(client.name, func(b *testing.B) { + client := pingv1connect.NewPingServiceClient( + httpClient, + server.URL, + connect.WithSendGzip(), + connect.WithClientOptions(client.opts...), + ) - b.Run("unary", func(b *testing.B) { - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - _, _ = client.Ping( - context.Background(), - connect.NewRequest(&pingv1.PingRequest{Text: twoMiB}), - ) - } + ctx := context.Background() + b.Run("unary_big", func(b *testing.B) { + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := client.Ping( + ctx, connect.NewRequest(&pingv1.PingRequest{Text: twoMiB}), + ); err != nil { + b.Error(err) + } + } + }) + }) + b.Run("unary_small", func(b *testing.B) { + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + response, err := client.Ping( + ctx, connect.NewRequest(&pingv1.PingRequest{Number: 42}), + ) + if err != nil { + b.Error(err) + } else if response.Msg.Number != 42 { + b.Errorf("expected 42, got %d", response.Msg.Number) + } + } + }) + }) + b.Run("client_stream", func(b *testing.B) { + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + const ( + upTo = 1 + expect = 1 + ) + stream := client.Sum(ctx) + for number := int64(1); number <= upTo; number++ { + if err := stream.Send(&pingv1.SumRequest{Number: number}); err != nil { + b.Error(err) + } + } + response, err := stream.CloseAndReceive() + if err != nil { + b.Error(err) + } else if response.Msg.Sum != expect { + b.Errorf("expected %d, got %d", expect, response.Msg.Sum) + } + } + }) + }) + b.Run("server_stream", func(b *testing.B) { + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + const ( + upTo = 1 + ) + request := connect.NewRequest(&pingv1.CountUpRequest{Number: upTo}) + stream, err := client.CountUp(ctx, request) + if err != nil { + b.Error(err) + return + } + number := int64(1) + for ; stream.Receive(); number++ { + if stream.Msg().Number != number { + b.Errorf("expected %d, got %d", number, stream.Msg().Number) + } + } + if number != upTo+1 { + b.Errorf("expected %d, got %d", upTo+1, number) + } + } + }) + }) + b.Run("bidi_stream", func(b *testing.B) { + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + const ( + upTo = 1 + ) + stream := client.CumSum(ctx) + number := int64(1) + for ; number <= upTo; number++ { + if err := stream.Send(&pingv1.CumSumRequest{Number: number}); err != nil { + b.Error(err) + } + + msg, err := stream.Receive() + if err != nil { + b.Error(err) + } + if msg.Sum != number*(number+1)/2 { + b.Errorf("expected %d, got %d", number*(number+1)/2, msg.Sum) + } + } + if err := stream.CloseRequest(); err != nil { + b.Error(err) + } + if err := stream.CloseResponse(); err != nil { + b.Error(err) + } + } + }) + }) }) - }) + } } type ping struct { diff --git a/handler_example_test.go b/handler_example_test.go index 68b57084..504b31e3 100644 --- a/handler_example_test.go +++ b/handler_example_test.go @@ -16,6 +16,8 @@ package connect_test import ( "context" + "errors" + "io" "net/http" connect "connectrpc.com/connect" @@ -42,6 +44,45 @@ func (*ExamplePingServer) Ping( ), nil } +// Sum implements pingv1connect.PingServiceHandler. +func (p *ExamplePingServer) Sum(ctx context.Context, stream *connect.ClientStream[pingv1.SumRequest]) (*connect.Response[pingv1.SumResponse], error) { + var sum int64 + for stream.Receive() { + sum += stream.Msg().Number + } + if stream.Err() != nil { + return nil, stream.Err() + } + return connect.NewResponse(&pingv1.SumResponse{Sum: sum}), nil +} + +// CountUp implements pingv1connect.PingServiceHandler. +func (p *ExamplePingServer) CountUp(ctx context.Context, request *connect.Request[pingv1.CountUpRequest], stream *connect.ServerStream[pingv1.CountUpResponse]) error { + for number := int64(1); number <= request.Msg.Number; number++ { + if err := stream.Send(&pingv1.CountUpResponse{Number: number}); err != nil { + return err + } + } + return nil +} + +// CumSum implements pingv1connect.PingServiceHandler. +func (p *ExamplePingServer) CumSum(ctx context.Context, stream *connect.BidiStream[pingv1.CumSumRequest, pingv1.CumSumResponse]) error { + var sum int64 + for { + msg, err := stream.Receive() + if errors.Is(err, io.EOF) { + return nil + } else if err != nil { + return err + } + sum += msg.Number + if err := stream.Send(&pingv1.CumSumResponse{Sum: sum}); err != nil { + return err + } + } +} + func Example_handler() { // protoc-gen-connect-go generates constructors that return plain net/http // Handlers, so they're compatible with most Go HTTP routers and middleware