Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support get http header and x-ratelimit-* headers #507

Merged
merged 10 commits into from
Oct 10, 2023
89 changes: 86 additions & 3 deletions chat_stream_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package openai_test

import (
. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"

"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"testing"

. "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/internal/test/checks"
)

func TestChatCompletionsStreamWrongModel(t *testing.T) {
Expand Down Expand Up @@ -178,6 +180,87 @@ func TestCreateChatCompletionStreamError(t *testing.T) {
t.Logf("%+v\n", apiErr)
}

func TestCreateChatCompletionStreamWithHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set(xCustomHeader, xCustomHeaderValue)

// Send test responses
//nolint:lll
dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`)
dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)

_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
})

stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()

value := stream.Header().Get(xCustomHeader)
if value != xCustomHeaderValue {
t.Errorf("expected %s to be %s", xCustomHeaderValue, value)
}
}

func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
for k, v := range rateLimitHeaders {
switch val := v.(type) {
case int:
w.Header().Set(k, strconv.Itoa(val))
default:
w.Header().Set(k, fmt.Sprintf("%s", v))
}
}

// Send test responses
//nolint:lll
dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`)
dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)

_, err := w.Write(dataBytes)
checks.NoError(t, err, "Write error")
})

stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
Stream: true,
})
checks.NoError(t, err, "CreateCompletionStream returned error")
defer stream.Close()

headers := stream.GetRateLimitHeaders()
bs1, _ := json.Marshal(headers)
bs2, _ := json.Marshal(rateLimitHeaders)
if string(bs1) != string(bs2) {
t.Errorf("expected rate limit header %s to be %s", bs2, bs1)
}
}

func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
Expand Down
55 changes: 55 additions & 0 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ const (
xCustomHeaderValue = "test"
)

var (
rateLimitHeaders = map[string]any{
"x-ratelimit-limit-requests": 60,
"x-ratelimit-limit-tokens": 150000,
"x-ratelimit-remaining-requests": 59,
"x-ratelimit-remaining-tokens": 149984,
"x-ratelimit-reset-requests": "1s",
"x-ratelimit-reset-tokens": "6m0s",
}
)

func TestChatCompletionsWrongModel(t *testing.T) {
config := DefaultConfig("whatever")
config.BaseURL = "http://localhost/v1"
Expand Down Expand Up @@ -97,6 +108,42 @@ func TestChatCompletionsWithHeaders(t *testing.T) {
}
}

// TestChatCompletionsWithRateLimitHeaders Tests the completions endpoint of the API using the mocked server.
func TestChatCompletionsWithRateLimitHeaders(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint)
resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{
MaxTokens: 5,
Model: GPT3Dot5Turbo,
Messages: []ChatCompletionMessage{
{
Role: ChatMessageRoleUser,
Content: "Hello!",
},
},
})
checks.NoError(t, err, "CreateChatCompletion error")

headers := resp.GetRateLimitHeaders()
resetReqTime, err := headers.ParseResetRequestsTime()
checks.NoError(t, err, "ParseResetRequestsTime error")
resetTokensTime, err := headers.ParseResetTokensTime()
checks.NoError(t, err, "ParseResetTokensTime error")
t.Logf("reset requests time: %s, reset tokens time: %s", resetReqTime, resetTokensTime)
bs1, _ := json.Marshal(headers)
bs2, _ := json.Marshal(rateLimitHeaders)
if string(bs1) != string(bs2) {
t.Errorf("expected rate limit header %s to be %s", bs2, bs1)
}
headers.ResetRequests = "xxx"
headers.ResetTokens = "xxx"
_, err = headers.ParseResetRequestsTime()
checks.HasError(t, err, "ParseResetRequestsTime not error")
_, err = headers.ParseResetTokensTime()
checks.HasError(t, err, "ParseResetTokensTime not error")
}

// TestChatCompletionsFunctions tests including a function call.
func TestChatCompletionsFunctions(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
Expand Down Expand Up @@ -311,6 +358,14 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
}
resBytes, _ = json.Marshal(res)
w.Header().Set(xCustomHeader, xCustomHeaderValue)
for k, v := range rateLimitHeaders {
switch val := v.(type) {
case int:
w.Header().Set(k, strconv.Itoa(val))
default:
w.Header().Set(k, fmt.Sprintf("%s", v))
}
}
fmt.Fprintln(w, string(resBytes))
}

Expand Down
9 changes: 7 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ func (h *httpHeader) SetHeader(header http.Header) {
*h = httpHeader(header)
}

func (h httpHeader) Header() http.Header {
return http.Header(h)
func (h *httpHeader) Header() http.Header {
sashabaranov marked this conversation as resolved.
Show resolved Hide resolved
return http.Header(*h)
}

func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders {
return newRateLimitHeaders(h.Header())
}

// NewClient creates new OpenAI API client.
Expand Down Expand Up @@ -156,6 +160,7 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream
response: resp,
errAccumulator: utils.NewErrorAccumulator(),
unmarshaler: &utils.JSONUnmarshaler{},
httpHeader: httpHeader(resp.Header),
}, nil
}

Expand Down
45 changes: 45 additions & 0 deletions models.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"fmt"
"net/http"
"strconv"
"time"
)

// Model struct represents an OpenAPI model.
Expand Down Expand Up @@ -88,3 +90,46 @@ func (c *Client) DeleteFineTuneModel(ctx context.Context, modelID string) (
err = c.sendRequest(req, &response)
return
}

// RateLimitHeaders struct represents Openai rate limits headers.
sashabaranov marked this conversation as resolved.
Show resolved Hide resolved
type RateLimitHeaders struct {
createTime time.Time
LimitRequests int `json:"x-ratelimit-limit-requests"`
LimitTokens int `json:"x-ratelimit-limit-tokens"`
RemainingRequests int `json:"x-ratelimit-remaining-requests"`
RemainingTokens int `json:"x-ratelimit-remaining-tokens"`
ResetRequests string `json:"x-ratelimit-reset-requests"`
ResetTokens string `json:"x-ratelimit-reset-tokens"`
}

func newRateLimitHeaders(h http.Header) RateLimitHeaders {
limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests"))
limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens"))
remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests"))
remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens"))
return RateLimitHeaders{
createTime: time.Now(),
LimitRequests: limitReq,
LimitTokens: limitTokens,
RemainingRequests: remainingReq,
RemainingTokens: remainingTokens,
ResetRequests: h.Get("x-ratelimit-reset-requests"),
ResetTokens: h.Get("x-ratelimit-reset-tokens"),
}
}

func (r RateLimitHeaders) ParseResetRequestsTime() (time.Time, error) {
d, err := time.ParseDuration(r.ResetRequests)
if err != nil {
return time.Time{}, err
}
return r.createTime.Add(d), nil
}

func (r RateLimitHeaders) ParseResetTokensTime() (time.Time, error) {
d, err := time.ParseDuration(r.ResetTokens)
if err != nil {
return time.Time{}, err
}
return r.createTime.Add(d), nil
}
2 changes: 2 additions & 0 deletions stream_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ type streamReader[T streamable] struct {
response *http.Response
errAccumulator utils.ErrorAccumulator
unmarshaler utils.Unmarshaler

httpHeader
}

func (stream *streamReader[T]) Recv() (response T, err error) {
Expand Down
Loading