From e6cdae6a4c3e90fa6c6cdaa7a2580dea6143afa2 Mon Sep 17 00:00:00 2001 From: liushuang Date: Mon, 9 Oct 2023 11:47:51 +0800 Subject: [PATCH 1/9] feat: add headers to http response --- audio.go | 19 ++++++++++++++++++- chat.go | 2 ++ chat_test.go | 30 ++++++++++++++++++++++++++++++ client.go | 20 +++++++++++++++++++- completion.go | 2 ++ edits.go | 2 ++ embeddings.go | 4 ++++ engines.go | 4 ++++ files.go | 4 ++++ fine_tunes.go | 8 ++++++++ fine_tuning_job.go | 4 ++++ image.go | 2 ++ models.go | 6 ++++++ moderation.go | 2 ++ stream_reader.go | 7 +++++++ 15 files changed, 114 insertions(+), 2 deletions(-) diff --git a/audio.go b/audio.go index 9f469159d..4cbe4fe64 100644 --- a/audio.go +++ b/audio.go @@ -63,6 +63,21 @@ type AudioResponse struct { Transient bool `json:"transient"` } `json:"segments"` Text string `json:"text"` + + httpHeader +} + +type audioTextResponse struct { + Text string `json:"text"` + + httpHeader +} + +func (r *audioTextResponse) ToAudioResponse() AudioResponse { + return AudioResponse{ + Text: r.Text, + httpHeader: r.httpHeader, + } } // CreateTranscription — API call to create a transcription. Returns transcribed text. @@ -104,7 +119,9 @@ func (c *Client) callAudioAPI( if request.HasJSONResponse() { err = c.sendRequest(req, &response) } else { - err = c.sendRequest(req, &response.Text) + var textResponse audioTextResponse + err = c.sendRequest(req, &textResponse) + response = textResponse.ToAudioResponse() } if err != nil { return AudioResponse{}, err diff --git a/chat.go b/chat.go index 8d29b3237..df0e5f970 100644 --- a/chat.go +++ b/chat.go @@ -142,6 +142,8 @@ type ChatCompletionResponse struct { Model string `json:"model"` Choices []ChatCompletionChoice `json:"choices"` Usage Usage `json:"usage"` + + httpHeader } // CreateChatCompletion — API call to Create a completion for the chat message. diff --git a/chat_test.go b/chat_test.go index 38d66fa64..52cd0bdef 100644 --- a/chat_test.go +++ b/chat_test.go @@ -16,6 +16,11 @@ import ( "github.com/sashabaranov/go-openai/jsonschema" ) +const ( + xCustomHeader = "X-CUSTOM-HEADER" + xCustomHeaderValue = "test" +) + func TestChatCompletionsWrongModel(t *testing.T) { config := DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" @@ -68,6 +73,30 @@ func TestChatCompletions(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") } +// TestCompletions Tests the completions endpoint of the API using the mocked server. +func TestChatCompletionsWithHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") + + a := resp.Header().Get(xCustomHeader) + _ = a + if resp.Header().Get(xCustomHeader) != xCustomHeaderValue { + t.Errorf("expected header %s to be %s", xCustomHeader, xCustomHeaderValue) + } +} + // TestChatCompletionsFunctions tests including a function call. func TestChatCompletionsFunctions(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -281,6 +310,7 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { TotalTokens: inputTokens + completionTokens, } resBytes, _ = json.Marshal(res) + w.Header().Set(xCustomHeader, xCustomHeaderValue) fmt.Fprintln(w, string(resBytes)) } diff --git a/client.go b/client.go index 5779a8e1c..244a7c7cb 100644 --- a/client.go +++ b/client.go @@ -20,6 +20,20 @@ type Client struct { createFormBuilder func(io.Writer) utils.FormBuilder } +type Response interface { + SetHeader(http.Header) +} + +type httpHeader http.Header + +func (h *httpHeader) SetHeader(header http.Header) { + *h = httpHeader(header) +} + +func (h *httpHeader) Header() http.Header { + return http.Header(*h) +} + // NewClient creates new OpenAI API client. func NewClient(authToken string) *Client { config := DefaultConfig(authToken) @@ -82,7 +96,7 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ... return req, nil } -func (c *Client) sendRequest(req *http.Request, v any) error { +func (c *Client) sendRequest(req *http.Request, v Response) error { req.Header.Set("Accept", "application/json; charset=utf-8") // Check whether Content-Type is already set, Upload Files API requires @@ -103,6 +117,10 @@ func (c *Client) sendRequest(req *http.Request, v any) error { return c.handleErrorResp(res) } + if v != nil { + v.SetHeader(res.Header) + } + return decodeResponse(res.Body, v) } diff --git a/completion.go b/completion.go index 7b9ae89e7..c7ff94afc 100644 --- a/completion.go +++ b/completion.go @@ -154,6 +154,8 @@ type CompletionResponse struct { Model string `json:"model"` Choices []CompletionChoice `json:"choices"` Usage Usage `json:"usage"` + + httpHeader } // CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well diff --git a/edits.go b/edits.go index 831aade2f..97d026029 100644 --- a/edits.go +++ b/edits.go @@ -28,6 +28,8 @@ type EditsResponse struct { Created int64 `json:"created"` Usage Usage `json:"usage"` Choices []EditsChoice `json:"choices"` + + httpHeader } // Edits Perform an API call to the Edits endpoint. diff --git a/embeddings.go b/embeddings.go index 660bc24c3..7e2aa7eb0 100644 --- a/embeddings.go +++ b/embeddings.go @@ -150,6 +150,8 @@ type EmbeddingResponse struct { Data []Embedding `json:"data"` Model EmbeddingModel `json:"model"` Usage Usage `json:"usage"` + + httpHeader } type base64String string @@ -182,6 +184,8 @@ type EmbeddingResponseBase64 struct { Data []Base64Embedding `json:"data"` Model EmbeddingModel `json:"model"` Usage Usage `json:"usage"` + + httpHeader } // ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse. diff --git a/engines.go b/engines.go index adf6025c2..5a0dba858 100644 --- a/engines.go +++ b/engines.go @@ -12,11 +12,15 @@ type Engine struct { Object string `json:"object"` Owner string `json:"owner"` Ready bool `json:"ready"` + + httpHeader } // EnginesList is a list of engines. type EnginesList struct { Engines []Engine `json:"data"` + + httpHeader } // ListEngines Lists the currently available engines, and provides basic diff --git a/files.go b/files.go index 8b933c362..9e521fbbe 100644 --- a/files.go +++ b/files.go @@ -25,11 +25,15 @@ type File struct { Status string `json:"status"` Purpose string `json:"purpose"` StatusDetails string `json:"status_details"` + + httpHeader } // FilesList is a list of files that belong to the user or organization. type FilesList struct { Files []File `json:"data"` + + httpHeader } // CreateFile uploads a jsonl file to GPT3 diff --git a/fine_tunes.go b/fine_tunes.go index 7d3b59dbd..ca840781c 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -41,6 +41,8 @@ type FineTune struct { ValidationFiles []File `json:"validation_files"` TrainingFiles []File `json:"training_files"` UpdatedAt int64 `json:"updated_at"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. @@ -69,6 +71,8 @@ type FineTuneHyperParams struct { type FineTuneList struct { Object string `json:"object"` Data []FineTune `json:"data"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. @@ -77,6 +81,8 @@ type FineTuneList struct { type FineTuneEventList struct { Object string `json:"object"` Data []FineTuneEvent `json:"data"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. @@ -86,6 +92,8 @@ type FineTuneDeleteResponse struct { ID string `json:"id"` Object string `json:"object"` Deleted bool `json:"deleted"` + + httpHeader } // Deprecated: On August 22nd, 2023, OpenAI announced the deprecation of the /v1/fine-tunes API. diff --git a/fine_tuning_job.go b/fine_tuning_job.go index 07b0c337c..9dcb49de1 100644 --- a/fine_tuning_job.go +++ b/fine_tuning_job.go @@ -21,6 +21,8 @@ type FineTuningJob struct { ValidationFile string `json:"validation_file,omitempty"` ResultFiles []string `json:"result_files"` TrainedTokens int `json:"trained_tokens"` + + httpHeader } type Hyperparameters struct { @@ -39,6 +41,8 @@ type FineTuningJobEventList struct { Object string `json:"object"` Data []FineTuneEvent `json:"data"` HasMore bool `json:"has_more"` + + httpHeader } type FineTuningJobEvent struct { diff --git a/image.go b/image.go index cb96f4f5e..4addcdb1e 100644 --- a/image.go +++ b/image.go @@ -33,6 +33,8 @@ type ImageRequest struct { type ImageResponse struct { Created int64 `json:"created,omitempty"` Data []ImageResponseDataInner `json:"data,omitempty"` + + httpHeader } // ImageResponseDataInner represents a response data structure for image API. diff --git a/models.go b/models.go index c207f0a86..d94f98836 100644 --- a/models.go +++ b/models.go @@ -15,6 +15,8 @@ type Model struct { Permission []Permission `json:"permission"` Root string `json:"root"` Parent string `json:"parent"` + + httpHeader } // Permission struct represents an OpenAPI permission. @@ -38,11 +40,15 @@ type FineTuneModelDeleteResponse struct { ID string `json:"id"` Object string `json:"object"` Deleted bool `json:"deleted"` + + httpHeader } // ModelsList is a list of models, including those that belong to the user or organization. type ModelsList struct { Models []Model `json:"data"` + + httpHeader } // ListModels Lists the currently available models, diff --git a/moderation.go b/moderation.go index a32f123f3..f8d20ee51 100644 --- a/moderation.go +++ b/moderation.go @@ -69,6 +69,8 @@ type ModerationResponse struct { ID string `json:"id"` Model string `json:"model"` Results []Result `json:"results"` + + httpHeader } // Moderations — perform a moderation api call over a string. diff --git a/stream_reader.go b/stream_reader.go index 87e59e0ca..c4cd9ac5d 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -109,3 +109,10 @@ func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { func (stream *streamReader[T]) Close() { stream.response.Body.Close() } + +func (stream *streamReader[T]) Header() http.Header { + if stream.response != nil { + return stream.response.Header + } + return map[string][]string{} +} From 321a2ba7ed22d6faf93f764dd9a15967a469ca93 Mon Sep 17 00:00:00 2001 From: liushuang Date: Mon, 9 Oct 2023 14:45:30 +0800 Subject: [PATCH 2/9] feat: support rate limit headers --- chat_stream_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++-- chat_test.go | 49 +++++++++++++++++++++++++ client.go | 4 ++ models.go | 45 +++++++++++++++++++++++ stream_reader.go | 4 ++ 5 files changed, 188 insertions(+), 3 deletions(-) diff --git a/chat_stream_test.go b/chat_stream_test.go index 5fc70b032..2c109d454 100644 --- a/chat_stream_test.go +++ b/chat_stream_test.go @@ -1,15 +1,17 @@ package openai_test import ( - . "github.com/sashabaranov/go-openai" - "github.com/sashabaranov/go-openai/internal/test/checks" - "context" "encoding/json" "errors" + "fmt" "io" "net/http" + "strconv" "testing" + + . "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestChatCompletionsStreamWrongModel(t *testing.T) { @@ -178,6 +180,87 @@ func TestCreateChatCompletionStreamError(t *testing.T) { t.Logf("%+v\n", apiErr) } +func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set(xCustomHeader, xCustomHeaderValue) + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + value := stream.Header().Get(xCustomHeader) + if value != xCustomHeaderValue { + t.Errorf("expected %s to be %s", xCustomHeaderValue, value) + } +} + +func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } + + // Send test responses + //nolint:lll + dataBytes := []byte(`data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + checks.NoError(t, err, "Write error") + }) + + stream, err := client.CreateChatCompletionStream(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }) + checks.NoError(t, err, "CreateCompletionStream returned error") + defer stream.Close() + + headers := stream.GetRateLimitHeaders() + bs1, _ := json.Marshal(headers) + bs2, _ := json.Marshal(rateLimitHeaders) + if string(bs1) != string(bs2) { + t.Errorf("expected rate limit header %s to be %s", bs2, bs1) + } +} + func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown() diff --git a/chat_test.go b/chat_test.go index 52cd0bdef..675284af4 100644 --- a/chat_test.go +++ b/chat_test.go @@ -21,6 +21,17 @@ const ( xCustomHeaderValue = "test" ) +var ( + rateLimitHeaders = map[string]any{ + "x-ratelimit-limit-requests": 60, + "x-ratelimit-limit-tokens": 150000, + "x-ratelimit-remaining-requests": 59, + "x-ratelimit-remaining-tokens": 149984, + "x-ratelimit-reset-requests": "1s", + "x-ratelimit-reset-tokens": "6m0s", + } +) + func TestChatCompletionsWrongModel(t *testing.T) { config := DefaultConfig("whatever") config.BaseURL = "http://localhost/v1" @@ -97,6 +108,36 @@ func TestChatCompletionsWithHeaders(t *testing.T) { } } +// TestChatCompletionsWithRateLimitHeaders Tests the completions endpoint of the API using the mocked server. +func TestChatCompletionsWithRateLimitHeaders(t *testing.T) { + client, server, teardown := setupOpenAITestServer() + defer teardown() + server.RegisterHandler("/v1/chat/completions", handleChatCompletionEndpoint) + resp, err := client.CreateChatCompletion(context.Background(), ChatCompletionRequest{ + MaxTokens: 5, + Model: GPT3Dot5Turbo, + Messages: []ChatCompletionMessage{ + { + Role: ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }) + checks.NoError(t, err, "CreateChatCompletion error") + + headers := resp.GetRateLimitHeaders() + resetReqTime, err := headers.ParseResetRequestsTime() + checks.NoError(t, err, "ParseResetRequestsTime error") + resetTokensTime, err := headers.ParseResetTokensTime() + checks.NoError(t, err, "ParseResetTokensTime error") + t.Logf("reset requests time: %s, reset tokens time: %s", resetReqTime, resetTokensTime) + bs1, _ := json.Marshal(headers) + bs2, _ := json.Marshal(rateLimitHeaders) + if string(bs1) != string(bs2) { + t.Errorf("expected rate limit header %s to be %s", bs2, bs1) + } +} + // TestChatCompletionsFunctions tests including a function call. func TestChatCompletionsFunctions(t *testing.T) { client, server, teardown := setupOpenAITestServer() @@ -311,6 +352,14 @@ func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { } resBytes, _ = json.Marshal(res) w.Header().Set(xCustomHeader, xCustomHeaderValue) + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } fmt.Fprintln(w, string(resBytes)) } diff --git a/client.go b/client.go index 244a7c7cb..75d391cc8 100644 --- a/client.go +++ b/client.go @@ -34,6 +34,10 @@ func (h *httpHeader) Header() http.Header { return http.Header(*h) } +func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders { + return NewRateLimitHeaders(h.Header()) +} + // NewClient creates new OpenAI API client. func NewClient(authToken string) *Client { config := DefaultConfig(authToken) diff --git a/models.go b/models.go index d94f98836..a1d8492e1 100644 --- a/models.go +++ b/models.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "net/http" + "strconv" + "time" ) // Model struct represents an OpenAPI model. @@ -88,3 +90,46 @@ func (c *Client) DeleteFineTuneModel(ctx context.Context, modelID string) ( err = c.sendRequest(req, &response) return } + +// RateLimitHeaders struct represents Openai rate limits headers +type RateLimitHeaders struct { + createTime time.Time + LimitRequests int `json:"x-ratelimit-limit-requests"` + LimitTokens int `json:"x-ratelimit-limit-tokens"` + RemainingRequests int `json:"x-ratelimit-remaining-requests"` + RemainingTokens int `json:"x-ratelimit-remaining-tokens"` + ResetRequests string `json:"x-ratelimit-reset-requests"` + ResetTokens string `json:"x-ratelimit-reset-tokens"` +} + +func NewRateLimitHeaders(h http.Header) RateLimitHeaders { + limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests")) + limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens")) + remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests")) + remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens")) + return RateLimitHeaders{ + createTime: time.Now(), + LimitRequests: limitReq, + LimitTokens: limitTokens, + RemainingRequests: remainingReq, + RemainingTokens: remainingTokens, + ResetRequests: h.Get("x-ratelimit-reset-requests"), + ResetTokens: h.Get("x-ratelimit-reset-tokens"), + } +} + +func (r RateLimitHeaders) ParseResetRequestsTime() (time.Time, error) { + d, err := time.ParseDuration(r.ResetRequests) + if err != nil { + return time.Time{}, err + } + return r.createTime.Add(d), nil +} + +func (r RateLimitHeaders) ParseResetTokensTime() (time.Time, error) { + d, err := time.ParseDuration(r.ResetTokens) + if err != nil { + return time.Time{}, err + } + return r.createTime.Add(d), nil +} diff --git a/stream_reader.go b/stream_reader.go index c4cd9ac5d..a5f4376c6 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -116,3 +116,7 @@ func (stream *streamReader[T]) Header() http.Header { } return map[string][]string{} } + +func (stream *streamReader[T]) GetRateLimitHeaders() RateLimitHeaders { + return NewRateLimitHeaders(stream.Header()) +} From f6e93e77d0e3f72a69fc69030205e298e6c56356 Mon Sep 17 00:00:00 2001 From: liushuang Date: Mon, 9 Oct 2023 14:49:47 +0800 Subject: [PATCH 3/9] fix: go lint --- models.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models.go b/models.go index a1d8492e1..0afceeb7b 100644 --- a/models.go +++ b/models.go @@ -91,7 +91,7 @@ func (c *Client) DeleteFineTuneModel(ctx context.Context, modelID string) ( return } -// RateLimitHeaders struct represents Openai rate limits headers +// RateLimitHeaders struct represents Openai rate limits headers. type RateLimitHeaders struct { createTime time.Time LimitRequests int `json:"x-ratelimit-limit-requests"` From c99a010887d01c5da698922038d9a96279ae37db Mon Sep 17 00:00:00 2001 From: liushuang Date: Mon, 9 Oct 2023 15:04:42 +0800 Subject: [PATCH 4/9] fix: test coverage --- chat_test.go | 6 ++++++ stream_reader.go | 5 +---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/chat_test.go b/chat_test.go index 675284af4..2693a5fa9 100644 --- a/chat_test.go +++ b/chat_test.go @@ -136,6 +136,12 @@ func TestChatCompletionsWithRateLimitHeaders(t *testing.T) { if string(bs1) != string(bs2) { t.Errorf("expected rate limit header %s to be %s", bs2, bs1) } + headers.ResetRequests = "xxx" + headers.ResetTokens = "xxx" + _, err = headers.ParseResetRequestsTime() + checks.HasError(t, err, "ParseResetRequestsTime not error") + _, err = headers.ParseResetTokensTime() + checks.HasError(t, err, "ParseResetTokensTime not error") } // TestChatCompletionsFunctions tests including a function call. diff --git a/stream_reader.go b/stream_reader.go index a5f4376c6..63bec6a19 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -111,10 +111,7 @@ func (stream *streamReader[T]) Close() { } func (stream *streamReader[T]) Header() http.Header { - if stream.response != nil { - return stream.response.Header - } - return map[string][]string{} + return stream.response.Header } func (stream *streamReader[T]) GetRateLimitHeaders() RateLimitHeaders { From 54c904d889645584dce08c1b58082fc829af261c Mon Sep 17 00:00:00 2001 From: liushuang Date: Tue, 10 Oct 2023 15:28:06 +0800 Subject: [PATCH 5/9] refactor streamReader --- client.go | 6 ++++-- stream_reader.go | 10 ++-------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index 75d391cc8..f3212ac78 100644 --- a/client.go +++ b/client.go @@ -154,13 +154,15 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream if isFailureStatusCode(resp) { return new(streamReader[T]), client.handleErrorResp(resp) } - return &streamReader[T]{ + reader := &streamReader[T]{ emptyMessagesLimit: client.config.EmptyMessagesLimit, reader: bufio.NewReader(resp.Body), response: resp, errAccumulator: utils.NewErrorAccumulator(), unmarshaler: &utils.JSONUnmarshaler{}, - }, nil + } + reader.SetHeader(resp.Header) + return reader, nil } func (c *Client) setCommonHeaders(req *http.Request) { diff --git a/stream_reader.go b/stream_reader.go index 63bec6a19..d17412591 100644 --- a/stream_reader.go +++ b/stream_reader.go @@ -27,6 +27,8 @@ type streamReader[T streamable] struct { response *http.Response errAccumulator utils.ErrorAccumulator unmarshaler utils.Unmarshaler + + httpHeader } func (stream *streamReader[T]) Recv() (response T, err error) { @@ -109,11 +111,3 @@ func (stream *streamReader[T]) unmarshalError() (errResp *ErrorResponse) { func (stream *streamReader[T]) Close() { stream.response.Body.Close() } - -func (stream *streamReader[T]) Header() http.Header { - return stream.response.Header -} - -func (stream *streamReader[T]) GetRateLimitHeaders() RateLimitHeaders { - return NewRateLimitHeaders(stream.Header()) -} From 1b9af61e166c44df7025b07fca0a1e64467b6dde Mon Sep 17 00:00:00 2001 From: liushuang Date: Tue, 10 Oct 2023 16:18:54 +0800 Subject: [PATCH 6/9] refactor streamReader --- client.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index f3212ac78..ac7f6a4b0 100644 --- a/client.go +++ b/client.go @@ -154,15 +154,14 @@ func sendRequestStream[T streamable](client *Client, req *http.Request) (*stream if isFailureStatusCode(resp) { return new(streamReader[T]), client.handleErrorResp(resp) } - reader := &streamReader[T]{ + return &streamReader[T]{ emptyMessagesLimit: client.config.EmptyMessagesLimit, reader: bufio.NewReader(resp.Body), response: resp, errAccumulator: utils.NewErrorAccumulator(), unmarshaler: &utils.JSONUnmarshaler{}, - } - reader.SetHeader(resp.Header) - return reader, nil + httpHeader: httpHeader(resp.Header), + }, nil } func (c *Client) setCommonHeaders(req *http.Request) { From 703eec0886d155c340a4d0aababfe91290e663a9 Mon Sep 17 00:00:00 2001 From: liushuang Date: Tue, 10 Oct 2023 16:29:24 +0800 Subject: [PATCH 7/9] refactor: NewRateLimitHeaders to newRateLimitHeaders --- client.go | 2 +- models.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index ac7f6a4b0..65ece812f 100644 --- a/client.go +++ b/client.go @@ -35,7 +35,7 @@ func (h *httpHeader) Header() http.Header { } func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders { - return NewRateLimitHeaders(h.Header()) + return newRateLimitHeaders(h.Header()) } // NewClient creates new OpenAI API client. diff --git a/models.go b/models.go index 0afceeb7b..cb551d045 100644 --- a/models.go +++ b/models.go @@ -102,7 +102,7 @@ type RateLimitHeaders struct { ResetTokens string `json:"x-ratelimit-reset-tokens"` } -func NewRateLimitHeaders(h http.Header) RateLimitHeaders { +func newRateLimitHeaders(h http.Header) RateLimitHeaders { limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests")) limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens")) remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests")) From e83a111b89a039f86ebbbe16c58969e13a6d61c8 Mon Sep 17 00:00:00 2001 From: liushuang Date: Tue, 10 Oct 2023 19:23:23 +0800 Subject: [PATCH 8/9] refactor: RateLimitHeaders Resets filed --- chat_test.go | 20 +++++++++----------- models.go | 43 +++++++++++++++++++------------------------ 2 files changed, 28 insertions(+), 35 deletions(-) diff --git a/chat_test.go b/chat_test.go index 2693a5fa9..329b2b9cb 100644 --- a/chat_test.go +++ b/chat_test.go @@ -126,22 +126,20 @@ func TestChatCompletionsWithRateLimitHeaders(t *testing.T) { checks.NoError(t, err, "CreateChatCompletion error") headers := resp.GetRateLimitHeaders() - resetReqTime, err := headers.ParseResetRequestsTime() - checks.NoError(t, err, "ParseResetRequestsTime error") - resetTokensTime, err := headers.ParseResetTokensTime() - checks.NoError(t, err, "ParseResetTokensTime error") - t.Logf("reset requests time: %s, reset tokens time: %s", resetReqTime, resetTokensTime) + resetRequests := headers.ResetRequests.String() + if resetRequests != rateLimitHeaders["x-ratelimit-reset-requests"] { + t.Errorf("expected resetRequests %s to be %s", resetRequests, rateLimitHeaders["x-ratelimit-reset-requests"]) + } + resetRequestsTime := headers.ResetRequests.Time() + if resetRequestsTime.Before(time.Now()) { + t.Errorf("unexpected reset requetsts: %v", resetRequestsTime) + } + bs1, _ := json.Marshal(headers) bs2, _ := json.Marshal(rateLimitHeaders) if string(bs1) != string(bs2) { t.Errorf("expected rate limit header %s to be %s", bs2, bs1) } - headers.ResetRequests = "xxx" - headers.ResetTokens = "xxx" - _, err = headers.ParseResetRequestsTime() - checks.HasError(t, err, "ParseResetRequestsTime not error") - _, err = headers.ParseResetTokensTime() - checks.HasError(t, err, "ParseResetTokensTime not error") } // TestChatCompletionsFunctions tests including a function call. diff --git a/models.go b/models.go index cb551d045..6e3419524 100644 --- a/models.go +++ b/models.go @@ -94,12 +94,23 @@ func (c *Client) DeleteFineTuneModel(ctx context.Context, modelID string) ( // RateLimitHeaders struct represents Openai rate limits headers. type RateLimitHeaders struct { createTime time.Time - LimitRequests int `json:"x-ratelimit-limit-requests"` - LimitTokens int `json:"x-ratelimit-limit-tokens"` - RemainingRequests int `json:"x-ratelimit-remaining-requests"` - RemainingTokens int `json:"x-ratelimit-remaining-tokens"` - ResetRequests string `json:"x-ratelimit-reset-requests"` - ResetTokens string `json:"x-ratelimit-reset-tokens"` + LimitRequests int `json:"x-ratelimit-limit-requests"` + LimitTokens int `json:"x-ratelimit-limit-tokens"` + RemainingRequests int `json:"x-ratelimit-remaining-requests"` + RemainingTokens int `json:"x-ratelimit-remaining-tokens"` + ResetRequests ResetTime `json:"x-ratelimit-reset-requests"` + ResetTokens ResetTime `json:"x-ratelimit-reset-tokens"` +} + +type ResetTime string + +func (r ResetTime) String() string { + return string(r) +} + +func (r ResetTime) Time() time.Time { + d, _ := time.ParseDuration(string(r)) + return time.Now().Add(d) } func newRateLimitHeaders(h http.Header) RateLimitHeaders { @@ -113,23 +124,7 @@ func newRateLimitHeaders(h http.Header) RateLimitHeaders { LimitTokens: limitTokens, RemainingRequests: remainingReq, RemainingTokens: remainingTokens, - ResetRequests: h.Get("x-ratelimit-reset-requests"), - ResetTokens: h.Get("x-ratelimit-reset-tokens"), - } -} - -func (r RateLimitHeaders) ParseResetRequestsTime() (time.Time, error) { - d, err := time.ParseDuration(r.ResetRequests) - if err != nil { - return time.Time{}, err - } - return r.createTime.Add(d), nil -} - -func (r RateLimitHeaders) ParseResetTokensTime() (time.Time, error) { - d, err := time.ParseDuration(r.ResetTokens) - if err != nil { - return time.Time{}, err + ResetRequests: ResetTime(h.Get("x-ratelimit-reset-requests")), + ResetTokens: ResetTime(h.Get("x-ratelimit-reset-tokens")), } - return r.createTime.Add(d), nil } From e61d41f0739a883f21ddc0b2778a588700bf960e Mon Sep 17 00:00:00 2001 From: liushuang Date: Tue, 10 Oct 2023 22:42:47 +0800 Subject: [PATCH 9/9] refactor: move RateLimitHeaders struct --- models.go | 40 ---------------------------------------- ratelimit.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 40 deletions(-) create mode 100644 ratelimit.go diff --git a/models.go b/models.go index 6e3419524..d94f98836 100644 --- a/models.go +++ b/models.go @@ -4,8 +4,6 @@ import ( "context" "fmt" "net/http" - "strconv" - "time" ) // Model struct represents an OpenAPI model. @@ -90,41 +88,3 @@ func (c *Client) DeleteFineTuneModel(ctx context.Context, modelID string) ( err = c.sendRequest(req, &response) return } - -// RateLimitHeaders struct represents Openai rate limits headers. -type RateLimitHeaders struct { - createTime time.Time - LimitRequests int `json:"x-ratelimit-limit-requests"` - LimitTokens int `json:"x-ratelimit-limit-tokens"` - RemainingRequests int `json:"x-ratelimit-remaining-requests"` - RemainingTokens int `json:"x-ratelimit-remaining-tokens"` - ResetRequests ResetTime `json:"x-ratelimit-reset-requests"` - ResetTokens ResetTime `json:"x-ratelimit-reset-tokens"` -} - -type ResetTime string - -func (r ResetTime) String() string { - return string(r) -} - -func (r ResetTime) Time() time.Time { - d, _ := time.ParseDuration(string(r)) - return time.Now().Add(d) -} - -func newRateLimitHeaders(h http.Header) RateLimitHeaders { - limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests")) - limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens")) - remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests")) - remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens")) - return RateLimitHeaders{ - createTime: time.Now(), - LimitRequests: limitReq, - LimitTokens: limitTokens, - RemainingRequests: remainingReq, - RemainingTokens: remainingTokens, - ResetRequests: ResetTime(h.Get("x-ratelimit-reset-requests")), - ResetTokens: ResetTime(h.Get("x-ratelimit-reset-tokens")), - } -} diff --git a/ratelimit.go b/ratelimit.go new file mode 100644 index 000000000..e8953f716 --- /dev/null +++ b/ratelimit.go @@ -0,0 +1,43 @@ +package openai + +import ( + "net/http" + "strconv" + "time" +) + +// RateLimitHeaders struct represents Openai rate limits headers. +type RateLimitHeaders struct { + LimitRequests int `json:"x-ratelimit-limit-requests"` + LimitTokens int `json:"x-ratelimit-limit-tokens"` + RemainingRequests int `json:"x-ratelimit-remaining-requests"` + RemainingTokens int `json:"x-ratelimit-remaining-tokens"` + ResetRequests ResetTime `json:"x-ratelimit-reset-requests"` + ResetTokens ResetTime `json:"x-ratelimit-reset-tokens"` +} + +type ResetTime string + +func (r ResetTime) String() string { + return string(r) +} + +func (r ResetTime) Time() time.Time { + d, _ := time.ParseDuration(string(r)) + return time.Now().Add(d) +} + +func newRateLimitHeaders(h http.Header) RateLimitHeaders { + limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests")) + limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens")) + remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests")) + remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens")) + return RateLimitHeaders{ + LimitRequests: limitReq, + LimitTokens: limitTokens, + RemainingRequests: remainingReq, + RemainingTokens: remainingTokens, + ResetRequests: ResetTime(h.Get("x-ratelimit-reset-requests")), + ResetTokens: ResetTime(h.Get("x-ratelimit-reset-tokens")), + } +}