Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JWT support for gRPC server #75

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions vertx-grpc-client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
<groupId>io.vertx</groupId>
<artifactId>vertx-grpc-common</artifactId>
</dependency>
<dependency>
<groupId>io.vertx</groupId>
<artifactId>vertx-auth-common</artifactId>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-stub</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
*/
package io.vertx.grpc.client;

import java.util.function.Function;

import io.grpc.MethodDescriptor;
import io.vertx.codegen.annotations.GenIgnore;
import io.vertx.codegen.annotations.VertxGen;
Expand All @@ -19,11 +21,9 @@
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpClientOptions;
import io.vertx.core.net.SocketAddress;
import io.vertx.core.streams.ReadStream;
import io.vertx.ext.auth.authentication.Credentials;
import io.vertx.grpc.client.impl.GrpcClientImpl;

import java.util.function.Function;

/**
* A gRPC client for Vert.x
*
Expand Down Expand Up @@ -112,6 +112,15 @@ default <Req, Resp, T> Future<T> call(SocketAddress server, MethodDescriptor<Req
});
}

/**
* Mark that request should be dispatched with authentication obtained from passed {@code JWTAuth} provider
*
* @param credentials
* @return a reference to this, so the API can be used fluently
*/
@GenIgnore(GenIgnore.PERMITTED_TYPE)
GrpcClient credentials(Credentials credentials);

/**
* Close this client.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,21 @@
*/
package io.vertx.grpc.client.impl;

import com.google.common.net.HttpHeaders;

import io.grpc.MethodDescriptor;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpClient;
import io.vertx.core.http.HttpClientOptions;
import io.vertx.core.http.HttpClientRequest;
import io.vertx.core.http.HttpMethod;
import io.vertx.core.http.HttpVersion;
import io.vertx.core.http.RequestOptions;
import io.vertx.core.net.SocketAddress;
import io.vertx.ext.auth.authentication.Credentials;
import io.vertx.ext.auth.authentication.TokenCredentials;
import io.vertx.grpc.client.GrpcClient;
import io.vertx.grpc.client.GrpcClientRequest;
import io.vertx.grpc.common.GrpcMessageDecoder;
Expand All @@ -32,6 +37,7 @@ public class GrpcClientImpl implements GrpcClient {

private final Vertx vertx;
private HttpClient client;
private Credentials credentials;

public GrpcClientImpl(HttpClientOptions options, Vertx vertx) {
this.vertx = vertx;
Expand All @@ -48,7 +54,10 @@ public GrpcClientImpl(Vertx vertx) {
.setMethod(HttpMethod.POST)
.setServer(server);
return client.request(options)
.map(request -> new GrpcClientRequestImpl<>(request, GrpcMessageEncoder.IDENTITY, GrpcMessageDecoder.IDENTITY));
.map(request -> {
addCredentials(request);
return new GrpcClientRequestImpl<>(request, GrpcMessageEncoder.IDENTITY, GrpcMessageDecoder.IDENTITY);
});
}

@Override public <Req, Resp> Future<GrpcClientRequest<Req, Resp>> request(SocketAddress server, MethodDescriptor<Req, Resp> service) {
Expand All @@ -59,14 +68,32 @@ public GrpcClientImpl(Vertx vertx) {
GrpcMessageEncoder<Req> messageEncoder = GrpcMessageEncoder.marshaller(service.getRequestMarshaller());
return client.request(options)
.map(request -> {
addCredentials(request);
GrpcClientRequestImpl<Req, Resp> call = new GrpcClientRequestImpl<>(request, messageEncoder, messageDecoder);
call.fullMethodName(service.getFullMethodName());
return call;
});
}

@Override
public GrpcClient credentials(Credentials credentials) {
if (credentials == null) {
Copy link
Member

@vietj vietj Dec 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think null could be accepted to null out credentials

throw new NullPointerException("Credentials passed to GrpcClient can not be null");
}

this.credentials = credentials;
return this;
}

@Override
public Future<Void> close() {
return client.close();
}

private void addCredentials(HttpClientRequest request) {
if (credentials == null) {
return;
}
request.headers().add(HttpHeaders.AUTHORIZATION, credentials.toHttpAuthorization());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
package io.vertx.grpc.client;

import io.grpc.*;
import io.grpc.ServerCall.Listener;
import io.grpc.examples.helloworld.GreeterGrpc;
import io.grpc.examples.helloworld.HelloReply;
import io.grpc.examples.helloworld.HelloRequest;
Expand All @@ -20,9 +21,11 @@
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import io.vertx.core.http.HttpClientOptions;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.StreamResetException;
import io.vertx.core.net.SelfSignedCertificate;
import io.vertx.core.net.SocketAddress;
import io.vertx.ext.auth.authentication.TokenCredentials;
import io.vertx.ext.unit.Async;
import io.vertx.ext.unit.TestContext;
import io.vertx.grpc.common.GrpcReadStream;
Expand All @@ -31,9 +34,7 @@

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -74,6 +75,52 @@ protected void testUnary(TestContext should, String requestEncoding, String resp
}));
}

@Test
public void testJWT(TestContext should) throws IOException {
final Async test = should.async();
final String clientToken = "ABC";

GreeterGrpc.GreeterImplBase called = new GreeterGrpc.GreeterImplBase() {
@Override
public void sayHello(HelloRequest request, StreamObserver<HelloReply> responseObserver) {
responseObserver.onNext(HelloReply.newBuilder().setMessage("Hello " + request.getName()).build());
responseObserver.onCompleted();
}
};

ServerInterceptor interceptor = new ServerInterceptor() {
@Override
public <ReqT, RespT> Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call, Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
String serverToken = headers.get(Metadata.Key.of(HttpHeaders.AUTHORIZATION.toString(), Metadata.ASCII_STRING_MARSHALLER));
serverToken = serverToken.substring("Bearer ".length());
should.assertEquals(clientToken, serverToken);
return next.startCall(call, headers);
}
};

Server server = ServerBuilder.forPort(0).addService(called).intercept(interceptor).build().start();

client = GrpcClient.client(vertx).credentials(new TokenCredentials(clientToken));
client.request(SocketAddress.inetSocketAddress(server.getPort(), "localhost"), GreeterGrpc.getSayHelloMethod())
.onComplete(should.asyncAssertSuccess(callRequest -> {
callRequest.response().onComplete(should.asyncAssertSuccess(callResponse -> {
AtomicInteger count = new AtomicInteger();
callResponse.handler(reply -> {
should.assertEquals(1, count.incrementAndGet());
should.assertEquals("Hello Julien", reply.getMessage());
});
callResponse.endHandler(v2 -> {
should.assertEquals(GrpcStatus.OK, callResponse.status());
should.assertEquals(1, count.get());
test.complete();
});
}));
callRequest.end(HelloRequest.newBuilder().setName("Julien").build());
}));
}

@Test
public void testSSL(TestContext should) throws IOException {

Expand All @@ -95,7 +142,7 @@ public void sayHello(HelloRequest request, StreamObserver<HelloReply> responseOb
Async test = should.async();
client = GrpcClient.client(vertx, new HttpClientOptions().setSsl(true)
.setUseAlpn(true)
.setPemTrustOptions(cert.trustOptions()));
.setTrustOptions(cert.trustOptions()));
client.request(SocketAddress.inetSocketAddress(8443, "localhost"), GreeterGrpc.getSayHelloMethod())
.onComplete(should.asyncAssertSuccess(callRequest -> {
callRequest.response().onComplete(should.asyncAssertSuccess(callResponse -> {
Expand Down
5 changes: 5 additions & 0 deletions vertx-grpc-server/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@
<groupId>io.vertx</groupId>
<artifactId>vertx-grpc-common</artifactId>
</dependency>
<dependency>
<groupId>io.vertx</groupId>
<artifactId>vertx-auth-jwt</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-stub</artifactId>
Expand Down
30 changes: 30 additions & 0 deletions vertx-grpc-server/src/main/asciidoc/server.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,36 @@ You can compress response messages by setting the response encoding *prior* befo

Decompression is done transparently by the server when the client send encoded requests.

=== JWT Authentication

Service method handlers can be secured by providing an {@link io.vertx.grpc.server.auth.GrpcAuthenticationHandler} in the {@link io.vertx.grpc.server.GrpcServer#callHandler} method.

For JWT a handler can be created via {@link io.vertx.grpc.server.auth.GrpcJWTAuthenticationHandler#create}.

Add the required dependency to provide JWT via {@link io.vertx.ext.auth.jwt.JWTAuth}.

[source,java]
----
<dependency>
<groupId>io.vertx</groupId>
<artifactId>vertx-auth-jwt</artifactId>
</dependency>
----

Handlers that were registered without the handler will not use authentication.

[source,java]
----
{@link examples.GrpcServerExamples#jwtServerAuthExample}
----

The JWT can be added for the client call via an interceptor in order to provide the token credentials.

[source,java]
----
{@link examples.GrpcServerExamples#jwtClientAuthExample}
----

=== Stub API

The Vert.x gRPC Server can bridge a gRPC service to use with a generated server stub in a more traditional fashion
Expand Down
74 changes: 74 additions & 0 deletions vertx-grpc-server/src/main/java/examples/GreeterGrpc.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,37 @@ examples.HelloReply> getSayHelloMethod() {
return getSayHelloMethod;
}

private static volatile io.grpc.MethodDescriptor<examples.HelloRequest,
examples.HelloReply> getSaySecuredHelloMethod;

@io.grpc.stub.annotations.RpcMethod(
fullMethodName = SERVICE_NAME + '/' + "SaySecuredHello",
requestType = examples.HelloRequest.class,
responseType = examples.HelloReply.class,
methodType = io.grpc.MethodDescriptor.MethodType.UNARY)
public static io.grpc.MethodDescriptor<examples.HelloRequest,
examples.HelloReply> getSaySecuredHelloMethod() {
io.grpc.MethodDescriptor<examples.HelloRequest, examples.HelloReply> getSaySecuredHelloMethod;
if ((getSaySecuredHelloMethod = GreeterGrpc.getSaySecuredHelloMethod) == null) {
synchronized (GreeterGrpc.class) {
if ((getSaySecuredHelloMethod = GreeterGrpc.getSaySecuredHelloMethod) == null) {
GreeterGrpc.getSaySecuredHelloMethod = getSaySecuredHelloMethod =
io.grpc.MethodDescriptor.<examples.HelloRequest, examples.HelloReply>newBuilder()
.setType(io.grpc.MethodDescriptor.MethodType.UNARY)
.setFullMethodName(generateFullMethodName(SERVICE_NAME, "SaySecuredHello"))
.setSampledToLocalTracing(true)
.setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller(
examples.HelloRequest.getDefaultInstance()))
.setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller(
examples.HelloReply.getDefaultInstance()))
.setSchemaDescriptor(new GreeterMethodDescriptorSupplier("SaySecuredHello"))
.build();
}
}
}
return getSaySecuredHelloMethod;
}

/**
* Creates a new async stub that supports all call types for the service
*/
Expand Down Expand Up @@ -110,6 +141,13 @@ public void sayHello(examples.HelloRequest request,
io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall(getSayHelloMethod(), responseObserver);
}

/**
*/
public void saySecuredHello(examples.HelloRequest request,
io.grpc.stub.StreamObserver<examples.HelloReply> responseObserver) {
io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall(getSaySecuredHelloMethod(), responseObserver);
}

@java.lang.Override public final io.grpc.ServerServiceDefinition bindService() {
return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor())
.addMethod(
Expand All @@ -119,6 +157,13 @@ public void sayHello(examples.HelloRequest request,
examples.HelloRequest,
examples.HelloReply>(
this, METHODID_SAY_HELLO)))
.addMethod(
getSaySecuredHelloMethod(),
io.grpc.stub.ServerCalls.asyncUnaryCall(
new MethodHandlers<
examples.HelloRequest,
examples.HelloReply>(
this, METHODID_SAY_SECURED_HELLO)))
.build();
}
}
Expand Down Expand Up @@ -150,6 +195,14 @@ public void sayHello(examples.HelloRequest request,
io.grpc.stub.ClientCalls.asyncUnaryCall(
getChannel().newCall(getSayHelloMethod(), getCallOptions()), request, responseObserver);
}

/**
*/
public void saySecuredHello(examples.HelloRequest request,
io.grpc.stub.StreamObserver<examples.HelloReply> responseObserver) {
io.grpc.stub.ClientCalls.asyncUnaryCall(
getChannel().newCall(getSaySecuredHelloMethod(), getCallOptions()), request, responseObserver);
}
}

/**
Expand Down Expand Up @@ -178,6 +231,13 @@ public examples.HelloReply sayHello(examples.HelloRequest request) {
return io.grpc.stub.ClientCalls.blockingUnaryCall(
getChannel(), getSayHelloMethod(), getCallOptions(), request);
}

/**
*/
public examples.HelloReply saySecuredHello(examples.HelloRequest request) {
return io.grpc.stub.ClientCalls.blockingUnaryCall(
getChannel(), getSaySecuredHelloMethod(), getCallOptions(), request);
}
}

/**
Expand Down Expand Up @@ -207,9 +267,18 @@ public com.google.common.util.concurrent.ListenableFuture<examples.HelloReply> s
return io.grpc.stub.ClientCalls.futureUnaryCall(
getChannel().newCall(getSayHelloMethod(), getCallOptions()), request);
}

/**
*/
public com.google.common.util.concurrent.ListenableFuture<examples.HelloReply> saySecuredHello(
examples.HelloRequest request) {
return io.grpc.stub.ClientCalls.futureUnaryCall(
getChannel().newCall(getSaySecuredHelloMethod(), getCallOptions()), request);
}
}

private static final int METHODID_SAY_HELLO = 0;
private static final int METHODID_SAY_SECURED_HELLO = 1;

private static final class MethodHandlers<Req, Resp> implements
io.grpc.stub.ServerCalls.UnaryMethod<Req, Resp>,
Expand All @@ -232,6 +301,10 @@ public void invoke(Req request, io.grpc.stub.StreamObserver<Resp> responseObserv
serviceImpl.sayHello((examples.HelloRequest) request,
(io.grpc.stub.StreamObserver<examples.HelloReply>) responseObserver);
break;
case METHODID_SAY_SECURED_HELLO:
serviceImpl.saySecuredHello((examples.HelloRequest) request,
(io.grpc.stub.StreamObserver<examples.HelloReply>) responseObserver);
break;
default:
throw new AssertionError();
}
Expand Down Expand Up @@ -294,6 +367,7 @@ public static io.grpc.ServiceDescriptor getServiceDescriptor() {
serviceDescriptor = result = io.grpc.ServiceDescriptor.newBuilder(SERVICE_NAME)
.setSchemaDescriptor(new GreeterFileDescriptorSupplier())
.addMethod(getSayHelloMethod())
.addMethod(getSaySecuredHelloMethod())
.build();
}
}
Expand Down
Loading