Skip to content

Commit

Permalink
API key support via credentials abstraction (#1409)
Browse files Browse the repository at this point in the history
Fixes #1401
  • Loading branch information
cretz authored Mar 8, 2024
1 parent efabf46 commit 1052374
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 3 deletions.
42 changes: 42 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ package client

import (
"context"
"crypto/tls"
"io"

commonpb "go.temporal.io/api/common/v1"
Expand Down Expand Up @@ -90,6 +91,9 @@ type (
// ConnectionOptions are optional parameters that can be specified in ClientOptions
ConnectionOptions = internal.ConnectionOptions

// Credentials are optional credentials that can be specified in ClientOptions.
Credentials = internal.Credentials

// StartWorkflowOptions configuration parameters for starting a workflow execution.
StartWorkflowOptions = internal.StartWorkflowOptions

Expand Down Expand Up @@ -752,3 +756,41 @@ type HistoryJSONOptions struct {
func HistoryFromJSON(r io.Reader, options HistoryJSONOptions) (*historypb.History, error) {
return internal.HistoryFromJSON(r, options.LastEventID)
}

// NewAPIKeyStaticCredentials creates credentials that can be provided to
// ClientOptions to use a fixed API key.
//
// This is the equivalent of providing a headers provider that sets the
// "Authorization" header with "Bearer " + the given key. This will overwrite
// any "Authorization" header that may be on the context or from existing header
// provider.
//
// Note, this uses a fixed header value for authentication. Many users that want
// to rotate this value without reconnecting should use
// [NewAPIKeyDynamicCredentials].
func NewAPIKeyStaticCredentials(apiKey string) Credentials {
return internal.NewAPIKeyStaticCredentials(apiKey)
}

// NewAPIKeyDynamicCredentials creates credentials powered by a callback that
// is invoked on each request. The callback accepts the context that is given by
// the calling user and can return a key or an error. When error is non-nil, the
// client call is failed with that error. When string is non-empty, it is used
// as the API key. When string is empty, nothing is set/overridden.
//
// This is the equivalent of providing a headers provider that returns the
// "Authorization" header with "Bearer " + the given function result. If the
// resulting string is non-empty, it will overwrite any "Authorization" header
// that may be on the context or from existing header provider.
func NewAPIKeyDynamicCredentials(apiKeyCallback func(context.Context) (string, error)) Credentials {
return internal.NewAPIKeyDynamicCredentials(apiKeyCallback)
}

// NewMTLSCredentials creates credentials that use TLS with the client
// certificate as the given one. If the client options do not already enable
// TLS, this enables it. If the client options' TLS configuration is present and
// already has a client certificate, client creation will fail when applying
// these credentials.
func NewMTLSCredentials(certificate tls.Certificate) Credentials {
return internal.NewMTLSCredentials(certificate)
}
72 changes: 71 additions & 1 deletion internal/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"go.temporal.io/api/operatorservice/v1"
"go.temporal.io/api/workflowservice/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

"go.temporal.io/sdk/converter"
"go.temporal.io/sdk/internal/common/metrics"
Expand Down Expand Up @@ -421,6 +422,9 @@ type (
// default: default
Namespace string

// Optional: Set the credentials for this client.
Credentials Credentials

// Optional: Logger framework can use to log.
// default: default logger provided.
Logger log.Logger
Expand Down Expand Up @@ -473,7 +477,6 @@ type (
HeadersProvider interface {
GetHeaders(ctx context.Context) (map[string]string, error)
}

// TrafficController is getting called in the interceptor chain with API invocation parameters.
// Result is either nil if API call is allowed or an error, in which case request would be interrupted and
// the error will be propagated back through the interceptor chain.
Expand Down Expand Up @@ -704,6 +707,13 @@ type (
}
)

// Credentials are optional credentials that can be specified in ClientOptions.
type Credentials interface {
applyToOptions(*ClientOptions) error
// Can return nil to have no interceptor
gRPCInterceptor() grpc.UnaryClientInterceptor
}

// DialClient creates a client and attempts to connect to the server.
func DialClient(options ClientOptions) (Client, error) {
options.ConnectionOptions.disableEagerConnection = false
Expand Down Expand Up @@ -753,6 +763,12 @@ func newClient(options ClientOptions, existing *WorkflowClient) (Client, error)
options.Logger.Info("No logger configured for temporal client. Created default one.")
}

if options.Credentials != nil {
if err := options.Credentials.applyToOptions(&options); err != nil {
return nil, err
}
}

// Dial or use existing connection
var connection *grpc.ClientConn
var err error
Expand Down Expand Up @@ -800,6 +816,7 @@ func newDialParameters(options *ClientOptions, excludeInternalFromRetry *atomic.
options.HeadersProvider,
options.TrafficController,
excludeInternalFromRetry,
options.Credentials,
),
DefaultServiceConfig: defaultServiceConfig,
}
Expand Down Expand Up @@ -923,3 +940,56 @@ func NewValue(data *commonpb.Payloads) converter.EncodedValue {
func NewValues(data *commonpb.Payloads) converter.EncodedValues {
return newEncodedValues(data, nil)
}

type apiKeyCredentials func(context.Context) (string, error)

func NewAPIKeyStaticCredentials(apiKey string) Credentials {
return NewAPIKeyDynamicCredentials(func(ctx context.Context) (string, error) { return apiKey, nil })
}

func NewAPIKeyDynamicCredentials(apiKeyCallback func(context.Context) (string, error)) Credentials {
return apiKeyCredentials(apiKeyCallback)
}

func (apiKeyCredentials) applyToOptions(*ClientOptions) error { return nil }

func (a apiKeyCredentials) gRPCInterceptor() grpc.UnaryClientInterceptor { return a.gRPCIntercept }

func (a apiKeyCredentials) gRPCIntercept(
ctx context.Context,
method string,
req any,
reply any,
cc *grpc.ClientConn,
invoker grpc.UnaryInvoker,
opts ...grpc.CallOption,
) error {
if apiKey, err := a(ctx); err != nil {
return err
} else if apiKey != "" {
// Do from-add-new instead of append to overwrite anything there
md, _ := metadata.FromOutgoingContext(ctx)
if md == nil {
md = metadata.MD{}
}
md["authorization"] = []string{"Bearer " + apiKey}
ctx = metadata.NewOutgoingContext(ctx, md)
}
return invoker(ctx, method, req, reply, cc, opts...)
}

type mTLSCredentials tls.Certificate

func NewMTLSCredentials(certificate tls.Certificate) Credentials { return mTLSCredentials(certificate) }

func (m mTLSCredentials) applyToOptions(opts *ClientOptions) error {
if opts.ConnectionOptions.TLS == nil {
opts.ConnectionOptions.TLS = &tls.Config{}
} else if len(opts.ConnectionOptions.TLS.Certificates) != 0 {
return fmt.Errorf("cannot apply mTLS credentials, certificates already exist on TLS options")
}
opts.ConnectionOptions.TLS.Certificates = append(opts.ConnectionOptions.TLS.Certificates, tls.Certificate(m))
return nil
}

func (mTLSCredentials) gRPCInterceptor() grpc.UnaryClientInterceptor { return nil }
8 changes: 8 additions & 0 deletions internal/grpc_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ func requiredInterceptors(
headersProvider HeadersProvider,
controller TrafficController,
excludeInternalFromRetry *atomic.Bool,
credentials Credentials,
) []grpc.UnaryClientInterceptor {
interceptors := []grpc.UnaryClientInterceptor{
errorInterceptor,
Expand All @@ -168,6 +169,13 @@ func requiredInterceptors(
if controller != nil {
interceptors = append(interceptors, trafficControllerInterceptor(controller))
}
// Add credentials interceptor. This is intentionally added after headers
// provider to overwrite anything set there.
if credentials != nil {
if interceptor := credentials.gRPCInterceptor(); interceptor != nil {
interceptors = append(interceptors, interceptor)
}
}
return interceptors
}

Expand Down
67 changes: 65 additions & 2 deletions internal/grpc_dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ package internal

import (
"context"
"crypto/tls"
"errors"
"fmt"
"log"
Expand Down Expand Up @@ -127,13 +128,13 @@ func TestHeadersProvider_Error(t *testing.T) {
}

func TestHeadersProvider_NotIncludedWhenNil(t *testing.T) {
interceptors := requiredInterceptors(nil, nil, nil, nil)
interceptors := requiredInterceptors(nil, nil, nil, nil, nil)
require.Equal(t, 5, len(interceptors))
}

func TestHeadersProvider_IncludedWithHeadersProvider(t *testing.T) {
interceptors := requiredInterceptors(nil,
authHeadersProvider{token: "test-auth-token"}, nil, nil)
authHeadersProvider{token: "test-auth-token"}, nil, nil, nil)
require.Equal(t, 6, len(interceptors))
}

Expand Down Expand Up @@ -438,12 +439,73 @@ func TestResourceExhaustedCause(t *testing.T) {
assert.True(t, foundWithoutCause)
}

func TestCredentialsAPIKey(t *testing.T) {
srv, err := startTestGRPCServer()
require.NoError(t, err)
defer srv.Stop()

// Fixed string
client, err := DialClient(ClientOptions{
HostPort: srv.addr,
Credentials: NewAPIKeyStaticCredentials("my-api-key"),
})
require.NoError(t, err)
defer client.Close()
require.Equal(
t,
[]string{"Bearer my-api-key"},
metadata.ValueFromIncomingContext(srv.getSystemInfoRequestContext, "Authorization"),
)

// Callback
client, err = DialClient(ClientOptions{
HostPort: srv.addr,
Credentials: NewAPIKeyDynamicCredentials(func(ctx context.Context) (string, error) {
return "my-callback-api-key", nil
}),
})
require.NoError(t, err)
defer client.Close()
require.Equal(
t,
[]string{"Bearer my-callback-api-key"},
metadata.ValueFromIncomingContext(srv.getSystemInfoRequestContext, "Authorization"),
)
}

func TestCredentialsMTLS(t *testing.T) {
// Just confirming option is set, not full end-to-end mTLS test

// No TLS set
var clientOptions ClientOptions
creds := NewMTLSCredentials(tls.Certificate{Certificate: [][]byte{[]byte("somedata1")}})
require.NoError(t, creds.applyToOptions(&clientOptions))
require.Equal(t, "somedata1", string(clientOptions.ConnectionOptions.TLS.Certificates[0].Certificate[0]))

// TLS already set
clientOptions = ClientOptions{}
clientOptions.ConnectionOptions.TLS = &tls.Config{ServerName: "my-server-name"}
creds = NewMTLSCredentials(tls.Certificate{Certificate: [][]byte{[]byte("somedata2")}})
require.NoError(t, creds.applyToOptions(&clientOptions))
require.Equal(t, "my-server-name", clientOptions.ConnectionOptions.TLS.ServerName)
require.Equal(t, "somedata2", string(clientOptions.ConnectionOptions.TLS.Certificates[0].Certificate[0]))

// Fail with existing cert
clientOptions = ClientOptions{}
clientOptions.ConnectionOptions.TLS = &tls.Config{
Certificates: []tls.Certificate{{Certificate: [][]byte{[]byte("somedata3")}}},
}
creds = NewMTLSCredentials(tls.Certificate{Certificate: [][]byte{[]byte("somedata4")}})
require.Error(t, creds.applyToOptions(&clientOptions))
}

type testGRPCServer struct {
workflowservice.UnimplementedWorkflowServiceServer
*grpc.Server
addr string
healthServer *health.Server
sigWfCount int32
getSystemInfoRequestContext context.Context
getSystemInfoResponse workflowservice.GetSystemInfoResponse
getSystemInfoResponseError error
signalWorkflowExecutionResponse workflowservice.SignalWorkflowExecutionResponse
Expand Down Expand Up @@ -500,6 +562,7 @@ func (t *testGRPCServer) GetSystemInfo(
ctx context.Context,
req *workflowservice.GetSystemInfoRequest,
) (*workflowservice.GetSystemInfoResponse, error) {
t.getSystemInfoRequestContext = ctx
return &t.getSystemInfoResponse, t.getSystemInfoResponseError
}

Expand Down

0 comments on commit 1052374

Please sign in to comment.