From b77d01edca43500f267c4b43333f645b84a4fcf0 Mon Sep 17 00:00:00 2001 From: Liu Shuang Date: Tue, 10 Oct 2023 10:29:41 -0500 Subject: [PATCH] Support get http header and x-ratelimit-* headers (#507) * feat: add headers to http response * feat: support rate limit headers * fix: go lint * fix: test coverage * refactor streamReader * refactor streamReader * refactor: NewRateLimitHeaders to newRateLimitHeaders * refactor: RateLimitHeaders Resets filed * refactor: move RateLimitHeaders struct --- chat_stream_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++-- chat_test.go | 53 +++++++++++++++++++++++++++ client.go | 9 ++++- ratelimit.go | 43 ++++++++++++++++++++++ stream_reader.go | 2 + 5 files changed, 191 insertions(+), 5 deletions(-) create mode 100644 ratelimit.go diff --git a/chat_stream_test.go b/chat_stream_test.go index 5fc70b032..2c109d454 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1,15 +1,17 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "errors" + "fmt" "io" "net/http" + "strconv" "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestChatCompletionsStreamWrongModel(t *testing.T) { @@ -178,6 +180,87 @@ func TestCreateChatCompletionStreamError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set(xCustomHeader, xCustomHeaderValue) + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + value := stream.Header().Get(xCustomHeader) + if value != xCustomHeaderValue { + t.Errorf("expected %s to be %s", xCustomHeaderValue, value) + } +} + +func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + headers := stream.GetRateLimitHeaders() + bs1, _ := json.Marshal(headers) + bs2, _ := json.Marshal(rateLimitHeaders) + if string(bs1) != string(bs2) { + t.Errorf("expected rate limit header %s to be %s", bs2, bs1) + } +} + func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/chat_test.go b/chat_test.go index 52cd0bdef..329b2b9cb 100644 --- a/chat_test.go +++ b/chat_test.go @@ -21,6 +21,17 @@ const ( xCustomHeaderValue = "test" ) +var ( + rateLimitHeaders = map[string]any{ + "x-ratelimit-limit-requests": 60, + "x-ratelimit-limit-tokens": 150000, + "x-ratelimit-remaining-requests": 59, + "x-ratelimit-remaining-tokens": 149984, + "x-ratelimit-reset-requests": "1s", + "x-ratelimit-reset-tokens": "6m0s", + } +) + func TestChatCompletionsWrongModel(t *testing.T) { config := DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" @@ -97,6 +108,40 @@ func TestChatCompletionsWithHeaders(t *testing.T) { } } +// TestChatCompletionsWithRateLimitHeaders Tests the completions endpoint of the API using the mocked server. +func TestChatCompletionsWithRateLimitHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") + + headers := resp.GetRateLimitHeaders() + resetRequests := headers.ResetRequests.String() + if resetRequests != rateLimitHeaders["x-ratelimit-reset-requests"] { + t.Errorf("expected resetRequests %s to be %s", resetRequests, rateLimitHeaders["x-ratelimit-reset-requests"]) + } + resetRequestsTime := headers.ResetRequests.Time() + if resetRequestsTime.Before(time.Now()) { + t.Errorf("unexpected reset requetsts: %v", resetRequestsTime) + } + + bs1, _ := json.Marshal(headers) + bs2, _ := json.Marshal(rateLimitHeaders) + if string(bs1) != string(bs2) { + t.Errorf("expected rate limit header %s to be %s", bs2, bs1) + } +} + // TestChatCompletionsFunctions tests including a function call. func TestChatCompletionsFunctions(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -311,6 +356,14 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } resBytes, _ = json.Marshal(res) w.Header().Set(xCustomHeader, xCustomHeaderValue) + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } fmt.Fprintln(w, string(resBytes)) } diff --git a/client.go b/client.go index 19902285b..65ece812f 100644 --- a/client.go +++ b/client.go @@ -30,8 +30,12 @@ func (h *httpHeader) SetHeader(header http.Header) { *h = httpHeader(header) } -func (h httpHeader) Header() http.Header { - return http.Header(h) +func (h *httpHeader) Header() http.Header { + return http.Header(*h) +} + +func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders { + return newRateLimitHeaders(h.Header()) } // NewClient creates new OpenAI API client. @@ -156,6 +160,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream response: resp, errAccumulator: utils.NewErrorAccumulator(), unmarshaler: &utils.JSONUnmarshaler{}, + httpHeader: httpHeader(resp.Header), }, nil } diff --git a/ratelimit.go b/ratelimit.go new file mode 100644 index 000000000..e8953f716 --- /dev/null +++ b/ratelimit.go @@ -0,0 +1,43 @@ +package openai + +import ( + "net/http" + "strconv" + "time" +) + +// RateLimitHeaders struct represents Openai rate limits headers. +type RateLimitHeaders struct { + LimitRequests int `json:"x-ratelimit-limit-requests"` + LimitTokens int `json:"x-ratelimit-limit-tokens"` + RemainingRequests int `json:"x-ratelimit-remaining-requests"` + RemainingTokens int `json:"x-ratelimit-remaining-tokens"` + ResetRequests ResetTime `json:"x-ratelimit-reset-requests"` + ResetTokens ResetTime `json:"x-ratelimit-reset-tokens"` +} + +type ResetTime string + +func (r ResetTime) String() string { + return string(r) +} + +func (r ResetTime) Time() time.Time { + d, _ := time.ParseDuration(string(r)) + return time.Now().Add(d) +} + +func newRateLimitHeaders(h http.Header) RateLimitHeaders { + limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests")) + limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens")) + remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests")) + remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens")) + return RateLimitHeaders{ + LimitRequests: limitReq, + LimitTokens: limitTokens, + RemainingRequests: remainingReq, + RemainingTokens: remainingTokens, + ResetRequests: ResetTime(h.Get("x-ratelimit-reset-requests")), + ResetTokens: ResetTime(h.Get("x-ratelimit-reset-tokens")), + } +} diff --git a/stream_reader.go b/stream_reader.go index 87e59e0ca..d17412591 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -27,6 +27,8 @@ type streamReader[T streamable] struct { response *http.Response errAccumulator utils.ErrorAccumulator unmarshaler utils.Unmarshaler + + httpHeader } func (stream *streamReader[T]) Recv() (response T, err error) {