diff --git a/vertx-grpc-client/pom.xml b/vertx-grpc-client/pom.xml index ca1797c8..1dfc2363 100644 --- a/vertx-grpc-client/pom.xml +++ b/vertx-grpc-client/pom.xml @@ -39,6 +39,10 @@ io.vertx vertx-grpc-common + + io.vertx + vertx-auth-common + io.grpc grpc-stub diff --git a/vertx-grpc-client/src/main/java/io/vertx/grpc/client/GrpcClient.java b/vertx-grpc-client/src/main/java/io/vertx/grpc/client/GrpcClient.java index 9278c7ca..7b473373 100644 --- a/vertx-grpc-client/src/main/java/io/vertx/grpc/client/GrpcClient.java +++ b/vertx-grpc-client/src/main/java/io/vertx/grpc/client/GrpcClient.java @@ -10,6 +10,8 @@ */ package io.vertx.grpc.client; +import java.util.function.Function; + import io.grpc.MethodDescriptor; import io.vertx.codegen.annotations.GenIgnore; import io.vertx.codegen.annotations.VertxGen; @@ -19,11 +21,9 @@ import io.vertx.core.buffer.Buffer; import io.vertx.core.http.HttpClientOptions; import io.vertx.core.net.SocketAddress; -import io.vertx.core.streams.ReadStream; +import io.vertx.ext.auth.authentication.Credentials; import io.vertx.grpc.client.impl.GrpcClientImpl; -import java.util.function.Function; - /** * A gRPC client for Vert.x * @@ -112,6 +112,15 @@ default Future call(SocketAddress server, MethodDescriptor new GrpcClientRequestImpl<>(request, GrpcMessageEncoder.IDENTITY, GrpcMessageDecoder.IDENTITY)); + .map(request -> { + addCredentials(request); + return new GrpcClientRequestImpl<>(request, GrpcMessageEncoder.IDENTITY, GrpcMessageDecoder.IDENTITY); + }); } @Override public Future> request(SocketAddress server, MethodDescriptor service) { @@ -59,14 +68,32 @@ public GrpcClientImpl(Vertx vertx) { GrpcMessageEncoder messageEncoder = GrpcMessageEncoder.marshaller(service.getRequestMarshaller()); return client.request(options) .map(request -> { + addCredentials(request); GrpcClientRequestImpl call = new GrpcClientRequestImpl<>(request, messageEncoder, messageDecoder); call.fullMethodName(service.getFullMethodName()); return call; }); } + @Override + public GrpcClient credentials(Credentials credentials) { + if (credentials == null) { + throw new NullPointerException("Credentials passed to GrpcClient can not be null"); + } + + this.credentials = credentials; + return this; + } + @Override public Future close() { return client.close(); } + + private void addCredentials(HttpClientRequest request) { + if (credentials == null) { + return; + } + request.headers().add(HttpHeaders.AUTHORIZATION, credentials.toHttpAuthorization()); + } } diff --git a/vertx-grpc-client/src/test/java/io/vertx/grpc/client/ClientRequestTest.java b/vertx-grpc-client/src/test/java/io/vertx/grpc/client/ClientRequestTest.java index b20a16c2..f070dfb8 100644 --- a/vertx-grpc-client/src/test/java/io/vertx/grpc/client/ClientRequestTest.java +++ b/vertx-grpc-client/src/test/java/io/vertx/grpc/client/ClientRequestTest.java @@ -11,6 +11,7 @@ package io.vertx.grpc.client; import io.grpc.*; +import io.grpc.ServerCall.Listener; import io.grpc.examples.helloworld.GreeterGrpc; import io.grpc.examples.helloworld.HelloReply; import io.grpc.examples.helloworld.HelloRequest; @@ -20,9 +21,11 @@ import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import io.vertx.core.http.HttpClientOptions; +import io.vertx.core.http.HttpHeaders; import io.vertx.core.http.StreamResetException; import io.vertx.core.net.SelfSignedCertificate; import io.vertx.core.net.SocketAddress; +import io.vertx.ext.auth.authentication.TokenCredentials; import io.vertx.ext.unit.Async; import io.vertx.ext.unit.TestContext; import io.vertx.grpc.common.GrpcReadStream; @@ -31,9 +34,7 @@ import java.io.File; import java.io.IOException; -import java.util.ArrayList; import java.util.Base64; -import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -74,6 +75,52 @@ protected void testUnary(TestContext should, String requestEncoding, String resp })); } + @Test + public void testJWT(TestContext should) throws IOException { + final Async test = should.async(); + final String clientToken = "ABC"; + + GreeterGrpc.GreeterImplBase called = new GreeterGrpc.GreeterImplBase() { + @Override + public void sayHello(HelloRequest request, StreamObserver responseObserver) { + responseObserver.onNext(HelloReply.newBuilder().setMessage("Hello " + request.getName()).build()); + responseObserver.onCompleted(); + } + }; + + ServerInterceptor interceptor = new ServerInterceptor() { + @Override + public Listener interceptCall( + ServerCall call, Metadata headers, + ServerCallHandler next) { + String serverToken = headers.get(Metadata.Key.of(HttpHeaders.AUTHORIZATION.toString(), Metadata.ASCII_STRING_MARSHALLER)); + serverToken = serverToken.substring("Bearer ".length()); + should.assertEquals(clientToken, serverToken); + return next.startCall(call, headers); + } + }; + + Server server = ServerBuilder.forPort(0).addService(called).intercept(interceptor).build().start(); + + client = GrpcClient.client(vertx).credentials(new TokenCredentials(clientToken)); + client.request(SocketAddress.inetSocketAddress(server.getPort(), "localhost"), GreeterGrpc.getSayHelloMethod()) + .onComplete(should.asyncAssertSuccess(callRequest -> { + callRequest.response().onComplete(should.asyncAssertSuccess(callResponse -> { + AtomicInteger count = new AtomicInteger(); + callResponse.handler(reply -> { + should.assertEquals(1, count.incrementAndGet()); + should.assertEquals("Hello Julien", reply.getMessage()); + }); + callResponse.endHandler(v2 -> { + should.assertEquals(GrpcStatus.OK, callResponse.status()); + should.assertEquals(1, count.get()); + test.complete(); + }); + })); + callRequest.end(HelloRequest.newBuilder().setName("Julien").build()); + })); + } + @Test public void testSSL(TestContext should) throws IOException { @@ -95,7 +142,7 @@ public void sayHello(HelloRequest request, StreamObserver responseOb Async test = should.async(); client = GrpcClient.client(vertx, new HttpClientOptions().setSsl(true) .setUseAlpn(true) - .setPemTrustOptions(cert.trustOptions())); + .setTrustOptions(cert.trustOptions())); client.request(SocketAddress.inetSocketAddress(8443, "localhost"), GreeterGrpc.getSayHelloMethod()) .onComplete(should.asyncAssertSuccess(callRequest -> { callRequest.response().onComplete(should.asyncAssertSuccess(callResponse -> { diff --git a/vertx-grpc-server/pom.xml b/vertx-grpc-server/pom.xml index 4af5769d..3f18edff 100644 --- a/vertx-grpc-server/pom.xml +++ b/vertx-grpc-server/pom.xml @@ -39,6 +39,11 @@ io.vertx vertx-grpc-common + + io.vertx + vertx-auth-jwt + true + io.grpc grpc-stub diff --git a/vertx-grpc-server/src/main/asciidoc/server.adoc b/vertx-grpc-server/src/main/asciidoc/server.adoc index 8f6b238a..15020901 100644 --- a/vertx-grpc-server/src/main/asciidoc/server.adoc +++ b/vertx-grpc-server/src/main/asciidoc/server.adoc @@ -109,6 +109,36 @@ You can compress response messages by setting the response encoding *prior* befo Decompression is done transparently by the server when the client send encoded requests. +=== JWT Authentication + +Service method handlers can be secured by providing an {@link io.vertx.grpc.server.auth.GrpcAuthenticationHandler} in the {@link io.vertx.grpc.server.GrpcServer#callHandler} method. + +For JWT a handler can be created via {@link io.vertx.grpc.server.auth.GrpcJWTAuthenticationHandler#create}. + +Add the required dependency to provide JWT via {@link io.vertx.ext.auth.jwt.JWTAuth}. + +[source,java] +---- + + io.vertx + vertx-auth-jwt + +---- + +Handlers that were registered without the handler will not use authentication. + +[source,java] +---- +{@link examples.GrpcServerExamples#jwtServerAuthExample} +---- + +The JWT can be added for the client call via an interceptor in order to provide the token credentials. + +[source,java] +---- +{@link examples.GrpcServerExamples#jwtClientAuthExample} +---- + === Stub API The Vert.x gRPC Server can bridge a gRPC service to use with a generated server stub in a more traditional fashion diff --git a/vertx-grpc-server/src/main/java/examples/GreeterGrpc.java b/vertx-grpc-server/src/main/java/examples/GreeterGrpc.java index 297ee427..f4436e2e 100644 --- a/vertx-grpc-server/src/main/java/examples/GreeterGrpc.java +++ b/vertx-grpc-server/src/main/java/examples/GreeterGrpc.java @@ -49,6 +49,37 @@ examples.HelloReply> getSayHelloMethod() { return getSayHelloMethod; } + private static volatile io.grpc.MethodDescriptor getSaySecuredHelloMethod; + + @io.grpc.stub.annotations.RpcMethod( + fullMethodName = SERVICE_NAME + '/' + "SaySecuredHello", + requestType = examples.HelloRequest.class, + responseType = examples.HelloReply.class, + methodType = io.grpc.MethodDescriptor.MethodType.UNARY) + public static io.grpc.MethodDescriptor getSaySecuredHelloMethod() { + io.grpc.MethodDescriptor getSaySecuredHelloMethod; + if ((getSaySecuredHelloMethod = GreeterGrpc.getSaySecuredHelloMethod) == null) { + synchronized (GreeterGrpc.class) { + if ((getSaySecuredHelloMethod = GreeterGrpc.getSaySecuredHelloMethod) == null) { + GreeterGrpc.getSaySecuredHelloMethod = getSaySecuredHelloMethod = + io.grpc.MethodDescriptor.newBuilder() + .setType(io.grpc.MethodDescriptor.MethodType.UNARY) + .setFullMethodName(generateFullMethodName(SERVICE_NAME, "SaySecuredHello")) + .setSampledToLocalTracing(true) + .setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + examples.HelloRequest.getDefaultInstance())) + .setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + examples.HelloReply.getDefaultInstance())) + .setSchemaDescriptor(new GreeterMethodDescriptorSupplier("SaySecuredHello")) + .build(); + } + } + } + return getSaySecuredHelloMethod; + } + /** * Creates a new async stub that supports all call types for the service */ @@ -110,6 +141,13 @@ public void sayHello(examples.HelloRequest request, io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall(getSayHelloMethod(), responseObserver); } + /** + */ + public void saySecuredHello(examples.HelloRequest request, + io.grpc.stub.StreamObserver responseObserver) { + io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall(getSaySecuredHelloMethod(), responseObserver); + } + @java.lang.Override public final io.grpc.ServerServiceDefinition bindService() { return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor()) .addMethod( @@ -119,6 +157,13 @@ public void sayHello(examples.HelloRequest request, examples.HelloRequest, examples.HelloReply>( this, METHODID_SAY_HELLO))) + .addMethod( + getSaySecuredHelloMethod(), + io.grpc.stub.ServerCalls.asyncUnaryCall( + new MethodHandlers< + examples.HelloRequest, + examples.HelloReply>( + this, METHODID_SAY_SECURED_HELLO))) .build(); } } @@ -150,6 +195,14 @@ public void sayHello(examples.HelloRequest request, io.grpc.stub.ClientCalls.asyncUnaryCall( getChannel().newCall(getSayHelloMethod(), getCallOptions()), request, responseObserver); } + + /** + */ + public void saySecuredHello(examples.HelloRequest request, + io.grpc.stub.StreamObserver responseObserver) { + io.grpc.stub.ClientCalls.asyncUnaryCall( + getChannel().newCall(getSaySecuredHelloMethod(), getCallOptions()), request, responseObserver); + } } /** @@ -178,6 +231,13 @@ public examples.HelloReply sayHello(examples.HelloRequest request) { return io.grpc.stub.ClientCalls.blockingUnaryCall( getChannel(), getSayHelloMethod(), getCallOptions(), request); } + + /** + */ + public examples.HelloReply saySecuredHello(examples.HelloRequest request) { + return io.grpc.stub.ClientCalls.blockingUnaryCall( + getChannel(), getSaySecuredHelloMethod(), getCallOptions(), request); + } } /** @@ -207,9 +267,18 @@ public com.google.common.util.concurrent.ListenableFuture s return io.grpc.stub.ClientCalls.futureUnaryCall( getChannel().newCall(getSayHelloMethod(), getCallOptions()), request); } + + /** + */ + public com.google.common.util.concurrent.ListenableFuture saySecuredHello( + examples.HelloRequest request) { + return io.grpc.stub.ClientCalls.futureUnaryCall( + getChannel().newCall(getSaySecuredHelloMethod(), getCallOptions()), request); + } } private static final int METHODID_SAY_HELLO = 0; + private static final int METHODID_SAY_SECURED_HELLO = 1; private static final class MethodHandlers implements io.grpc.stub.ServerCalls.UnaryMethod, @@ -232,6 +301,10 @@ public void invoke(Req request, io.grpc.stub.StreamObserver responseObserv serviceImpl.sayHello((examples.HelloRequest) request, (io.grpc.stub.StreamObserver) responseObserver); break; + case METHODID_SAY_SECURED_HELLO: + serviceImpl.saySecuredHello((examples.HelloRequest) request, + (io.grpc.stub.StreamObserver) responseObserver); + break; default: throw new AssertionError(); } @@ -294,6 +367,7 @@ public static io.grpc.ServiceDescriptor getServiceDescriptor() { serviceDescriptor = result = io.grpc.ServiceDescriptor.newBuilder(SERVICE_NAME) .setSchemaDescriptor(new GreeterFileDescriptorSupplier()) .addMethod(getSayHelloMethod()) + .addMethod(getSaySecuredHelloMethod()) .build(); } } diff --git a/vertx-grpc-server/src/main/java/examples/GrpcServerExamples.java b/vertx-grpc-server/src/main/java/examples/GrpcServerExamples.java index 821bffed..e47bc0dd 100644 --- a/vertx-grpc-server/src/main/java/examples/GrpcServerExamples.java +++ b/vertx-grpc-server/src/main/java/examples/GrpcServerExamples.java @@ -1,6 +1,12 @@ package examples; -import io.grpc.stub.ServerCallStreamObserver; +import com.google.common.net.HttpHeaders; + +import examples.GreeterGrpc.GreeterBlockingStub; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.stub.MetadataUtils; import io.grpc.stub.StreamObserver; import io.vertx.core.Future; import io.vertx.core.Vertx; @@ -8,6 +14,10 @@ import io.vertx.core.http.HttpServer; import io.vertx.core.http.HttpServerOptions; import io.vertx.docgen.Source; +import io.vertx.ext.auth.KeyStoreOptions; +import io.vertx.ext.auth.User; +import io.vertx.ext.auth.jwt.JWTAuth; +import io.vertx.ext.auth.jwt.JWTAuthOptions; import io.vertx.grpc.common.GrpcMessage; import io.vertx.grpc.common.GrpcStatus; import io.vertx.grpc.common.ServiceName; @@ -15,6 +25,8 @@ import io.vertx.grpc.server.GrpcServerRequest; import io.vertx.grpc.server.GrpcServerResponse; import io.vertx.grpc.server.GrpcServiceBridge; +import io.vertx.grpc.server.auth.GrpcAuthenticationHandler; +import io.vertx.grpc.server.auth.GrpcJWTAuthenticationHandler; @Source public class GrpcServerExamples { @@ -125,6 +137,43 @@ public void responseCompression(GrpcServerResponse response) { response.write(Item.newBuilder().setValue("item-3").build()); } + + public void jwtServerAuthExample(Vertx vertx, HttpServerOptions options) { + JWTAuthOptions config = new JWTAuthOptions() + .setKeyStore(new KeyStoreOptions() + .setPath("keystore.jceks") + .setPassword("secret") + .setType("jceks")); + + JWTAuth jwtAuth = JWTAuth.create(vertx, config); + GrpcAuthenticationHandler authHandler = GrpcJWTAuthenticationHandler.create(jwtAuth); + GrpcServer server = GrpcServer.server(vertx); + + server.callHandler(authHandler, GreeterGrpc.getSayHelloMethod(), request -> { + + request.handler(hello -> { + User authenticatedUser = request.user(); + + GrpcServerResponse response = request.response(); + + HelloReply reply = HelloReply.newBuilder().setMessage("Hello " + hello.getName() + + " via " + authenticatedUser.subject()).build(); + + response.end(reply); + }); + }); + } + + public void jwtClientAuthExample(String token, int port) { + + ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build(); + + Metadata header = new Metadata(); + header.put(Metadata.Key.of(HttpHeaders.AUTHORIZATION, Metadata.ASCII_STRING_MARSHALLER), "Bearer " + token); + GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(channel) + .withInterceptors(MetadataUtils.newAttachHeadersInterceptor(header)); + } + public void stubExample(Vertx vertx, HttpServerOptions options) { GrpcServer grpcServer = GrpcServer.server(vertx); diff --git a/vertx-grpc-server/src/main/java/examples/HelloWorldProto.java b/vertx-grpc-server/src/main/java/examples/HelloWorldProto.java index 07970283..458d65fb 100644 --- a/vertx-grpc-server/src/main/java/examples/HelloWorldProto.java +++ b/vertx-grpc-server/src/main/java/examples/HelloWorldProto.java @@ -35,10 +35,12 @@ public static void registerAllExtensions( java.lang.String[] descriptorData = { "\n\020helloworld.proto\022\nhelloworld\"\034\n\014HelloR" + "equest\022\014\n\004name\030\001 \001(\t\"\035\n\nHelloReply\022\017\n\007me" + - "ssage\030\001 \001(\t2I\n\007Greeter\022>\n\010SayHello\022\030.hel" + - "loworld.HelloRequest\032\026.helloworld.HelloR" + - "eply\"\000B#\n\010examplesB\017HelloWorldProtoP\001\242\002\003" + - "HLWb\006proto3" + "ssage\030\001 \001(\t2\220\001\n\007Greeter\022>\n\010SayHello\022\030.he" + + "lloworld.HelloRequest\032\026.helloworld.Hello" + + "Reply\"\000\022E\n\017SaySecuredHello\022\030.helloworld." + + "HelloRequest\032\026.helloworld.HelloReply\"\000B#" + + "\n\010examplesB\017HelloWorldProtoP\001\242\002\003HLWb\006pro" + + "to3" }; descriptor = com.google.protobuf.Descriptors.FileDescriptor .internalBuildGeneratedFileFrom(descriptorData, diff --git a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/GrpcException.java b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/GrpcException.java new file mode 100644 index 00000000..6b1d2319 --- /dev/null +++ b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/GrpcException.java @@ -0,0 +1,42 @@ +package io.vertx.grpc.server; + +import io.vertx.core.http.HttpClientResponse; +import io.vertx.grpc.common.GrpcStatus; + +public class GrpcException extends RuntimeException { + + private static final long serialVersionUID = -7838327176604697641L; + + private GrpcStatus status; + + private HttpClientResponse httpResponse; + + public GrpcException(String msg, GrpcStatus status, + HttpClientResponse httpResponse) { + super(msg); + this.status = status; + this.httpResponse = httpResponse; + } + + public GrpcException(GrpcStatus status) { + this.status = status; + } + + public GrpcException(GrpcStatus status, Throwable err) { + super(err); + this.status = status; + } + + public GrpcException(GrpcStatus status, String msg) { + super(msg); + this.status = status; + } + + public GrpcStatus status() { + return status; + } + + public HttpClientResponse response() { + return httpResponse; + } +} diff --git a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/GrpcServer.java b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/GrpcServer.java index 1cdbfe4c..76d2b432 100644 --- a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/GrpcServer.java +++ b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/GrpcServer.java @@ -18,6 +18,7 @@ import io.vertx.core.Vertx; import io.vertx.core.buffer.Buffer; import io.vertx.core.http.HttpServerRequest; +import io.vertx.grpc.server.auth.GrpcAuthenticationHandler; import io.vertx.grpc.server.impl.GrpcServerImpl; /** @@ -43,7 +44,8 @@ public interface GrpcServer extends Handler { /** * Create a blank gRPC server - * + * + * @param vertx Vert.x instance * @return the created server */ static GrpcServer server(Vertx vertx) { @@ -68,4 +70,16 @@ static GrpcServer server(Vertx vertx) { @GenIgnore(GenIgnore.PERMITTED_TYPE) GrpcServer callHandler(MethodDescriptor methodDesc, Handler> handler); + /** + * Set a service method call handler that handles any call call made to the server for the {@link MethodDescriptor} service method and that uses the provided authentication handler to authenticate the method call. + * + * @param authHandler Authentication handler used to authenticate the method call + * @param methodDesc gRPC service method that will be handled + * @param handler the service method call handler + * + * @return a reference to this, so the API can be used fluently + */ + @GenIgnore(GenIgnore.PERMITTED_TYPE) + GrpcServer callHandler(GrpcAuthenticationHandler authHandler, MethodDescriptor methodDesc, Handler> handler); + } diff --git a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/GrpcServerRequest.java b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/GrpcServerRequest.java index 374a2eca..fb769791 100644 --- a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/GrpcServerRequest.java +++ b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/GrpcServerRequest.java @@ -16,6 +16,7 @@ import io.vertx.codegen.annotations.VertxGen; import io.vertx.core.Handler; import io.vertx.core.http.HttpConnection; +import io.vertx.ext.auth.User; import io.vertx.grpc.common.GrpcError; import io.vertx.grpc.common.GrpcMessage; import io.vertx.grpc.common.GrpcReadStream; @@ -78,4 +79,7 @@ public interface GrpcServerRequest extends GrpcReadStream { * @return the underlying HTTP connection */ HttpConnection connection(); + + User user(); + } diff --git a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/auth/GrpcAuthenticationHandler.java b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/auth/GrpcAuthenticationHandler.java new file mode 100644 index 00000000..673f62bd --- /dev/null +++ b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/auth/GrpcAuthenticationHandler.java @@ -0,0 +1,25 @@ +package io.vertx.grpc.server.auth; + +import io.vertx.codegen.annotations.VertxGen; +import io.vertx.core.Future; +import io.vertx.core.http.HttpServerRequest; +import io.vertx.ext.auth.User; +import io.vertx.grpc.server.GrpcServer; + +/** + * Authentication handler for {@link GrpcServer}. + */ +@VertxGen +@FunctionalInterface +public interface GrpcAuthenticationHandler { + + /** + * Authenticate the provided request and return the authenticated user. + * + * @param httpRequest Request to authenticate + * @param requireAuthentication Whether the handler should fail when no authentication is present + * @return + */ + Future authenticate(HttpServerRequest httpRequest, boolean requireAuthentication); + +} diff --git a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/auth/GrpcJWTAuthenticationHandler.java b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/auth/GrpcJWTAuthenticationHandler.java new file mode 100644 index 00000000..a6bf324f --- /dev/null +++ b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/auth/GrpcJWTAuthenticationHandler.java @@ -0,0 +1,19 @@ +package io.vertx.grpc.server.auth; + +import io.vertx.ext.auth.jwt.JWTAuth; +import io.vertx.grpc.server.auth.impl.GrpcJWTAuthenticationHandlerImpl; + +/** + * Authentication handler which provides JWT authentication. + */ +public interface GrpcJWTAuthenticationHandler extends GrpcAuthenticationHandler { + + static GrpcJWTAuthenticationHandler create(JWTAuth authProvider) { + return create(authProvider, ""); + } + + static GrpcJWTAuthenticationHandler create(JWTAuth authProvider, String realm) { + return new GrpcJWTAuthenticationHandlerImpl(authProvider, realm); + } + +} diff --git a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/auth/impl/GrpcJWTAuthenticationHandlerImpl.java b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/auth/impl/GrpcJWTAuthenticationHandlerImpl.java new file mode 100644 index 00000000..9da4c5c2 --- /dev/null +++ b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/auth/impl/GrpcJWTAuthenticationHandlerImpl.java @@ -0,0 +1,137 @@ +package io.vertx.grpc.server.auth.impl; + +import io.vertx.core.Future; +import io.vertx.core.http.HttpHeaders; +import io.vertx.core.http.HttpServerRequest; +import io.vertx.ext.auth.User; +import io.vertx.ext.auth.authentication.AuthenticationProvider; +import io.vertx.ext.auth.authentication.TokenCredentials; +import io.vertx.ext.auth.jwt.JWTAuth; +import io.vertx.grpc.common.GrpcStatus; +import io.vertx.grpc.server.GrpcException; +import io.vertx.grpc.server.auth.GrpcJWTAuthenticationHandler; + +public class GrpcJWTAuthenticationHandlerImpl implements GrpcJWTAuthenticationHandler { + + static final GrpcException UNAUTHENTICATED = new GrpcException( + GrpcStatus.UNAUTHENTICATED); + static final GrpcException UNKNOWN = new GrpcException(GrpcStatus.UNKNOWN); + static final GrpcException UNIMPLEMENTED = new GrpcException( + GrpcStatus.UNIMPLEMENTED); + + public enum Type { + BASIC("Basic"), DIGEST("Digest"), BEARER("Bearer"), + // these have no known implementation + HOBA("HOBA"), MUTUAL("Mutual"), NEGOTIATE("Negotiate"), OAUTH( + "OAuth"), SCRAM_SHA_1("SCRAM-SHA-1"), SCRAM_SHA_256("SCRAM-SHA-256"); + + private final String label; + + Type(String label) { + this.label = label; + } + + public boolean is(String other) { + return label.equalsIgnoreCase(other); + } + + @Override + public String toString() { + return label; + } + } + + private final AuthenticationProvider authProvider; + private final Type type; + private final String realm; + + public GrpcJWTAuthenticationHandlerImpl(JWTAuth authProvider, String realm) { + this(authProvider, Type.BEARER, realm); + } + + public GrpcJWTAuthenticationHandlerImpl(AuthenticationProvider authProvider, + Type type, String realm) { + this.authProvider = authProvider; + this.type = type; + this.realm = realm == null ? null + : realm + // escape quotes + .replaceAll("\"", "\\\""); + + if (this.realm != null && + (this.realm.indexOf('\r') != -1 || this.realm.indexOf('\n') != -1)) { + throw new IllegalArgumentException( + "Not allowed [\\r|\\n] characters detected on realm name"); + } + } + + private Future parseAuthorization(HttpServerRequest req, + boolean optional) { + + final String authorization = req.headers().get(HttpHeaders.AUTHORIZATION); + + if (authorization == null) { + if (optional) { + // this is allowed + return Future.succeededFuture(); + } else { + return Future.failedFuture(UNAUTHENTICATED); + } + } + + try { + int idx = authorization.indexOf(' '); + + if (idx <= 0) { + return Future.failedFuture(UNKNOWN); + } + + if (!type.is(authorization.substring(0, idx))) { + return Future.failedFuture(UNAUTHENTICATED); + } + + return Future.succeededFuture(authorization.substring(idx + 1)); + } catch (RuntimeException e) { + return Future.failedFuture(e); + } + } + + protected Future parseAuthorization(HttpServerRequest req) { + return parseAuthorization(req, false); + } + + public Future authenticate(HttpServerRequest req, boolean requireAuthentication) { + return parseAuthorization(req, !requireAuthentication) + .compose(token -> { + + if (token == null) { + return Future.failedFuture( + new GrpcException(GrpcStatus.UNKNOWN, "Missing token")); + } + int segments = 0; + for (int i = 0; i < token.length(); i++) { + char c = token.charAt(i); + if (c == '.') { + if (++segments == 3) { + return Future.failedFuture(new GrpcException(GrpcStatus.UNKNOWN, + "Too many segments in token")); + } + continue; + } + if (Character.isLetterOrDigit(c) || c == '-' || c == '_') { + continue; + } + // invalid character + return Future.failedFuture(new GrpcException(GrpcStatus.UNKNOWN, + "Invalid character in token: " + (int) c)); + } + + return authProvider.authenticate(new TokenCredentials(token)) + .recover(error -> { + return Future.failedFuture( + new GrpcException(GrpcStatus.UNAUTHENTICATED, error)); + }); + }); + } + +} diff --git a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerImpl.java b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerImpl.java index f664b3ea..af648682 100644 --- a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerImpl.java +++ b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerImpl.java @@ -11,15 +11,20 @@ package io.vertx.grpc.server.impl; import io.grpc.MethodDescriptor; +import io.vertx.core.Future; import io.vertx.core.Handler; import io.vertx.core.Vertx; import io.vertx.core.buffer.Buffer; import io.vertx.core.http.HttpServerRequest; +import io.vertx.ext.auth.User; import io.vertx.grpc.common.GrpcMessageDecoder; import io.vertx.grpc.common.GrpcMessageEncoder; import io.vertx.grpc.common.impl.GrpcMethodCall; +import io.vertx.grpc.server.GrpcException; import io.vertx.grpc.server.GrpcServer; import io.vertx.grpc.server.GrpcServerRequest; +import io.vertx.grpc.server.GrpcServerResponse; +import io.vertx.grpc.server.auth.GrpcAuthenticationHandler; import java.util.HashMap; import java.util.Map; @@ -42,6 +47,7 @@ public void handle(HttpServerRequest httpRequest) { GrpcMethodCall methodCall = new GrpcMethodCall(httpRequest.path()); String fmn = methodCall.fullMethodName(); MethodCallHandler method = methodCallHandlers.get(fmn); + if (method != null) { handle(method, httpRequest, methodCall); } else { @@ -59,6 +65,27 @@ public void handle(HttpServerRequest httpRequest) { private void handle(MethodCallHandler method, HttpServerRequest httpRequest, GrpcMethodCall methodCall) { GrpcServerRequestImpl grpcRequest = new GrpcServerRequestImpl<>(httpRequest, method.messageDecoder, method.messageEncoder, methodCall); grpcRequest.init(); + GrpcAuthenticationHandler authHandler = method.authHandler; + if (authHandler != null) { + Future fut = authHandler.authenticate(httpRequest, true); + + // Handle authentication failures + if (fut.failed()) { + if (fut.cause() instanceof GrpcException) { + GrpcException ex = (GrpcException) fut.cause(); + GrpcServerResponse response = grpcRequest.response(); + response.status(ex.status()).end(); + } else { + httpRequest.response().setStatusCode(500).end(); + } + return; + } else { + User user = fut.result(); + if (user != null) { + grpcRequest.setUser(user); + } + } + } method.handle(grpcRequest); } @@ -68,8 +95,17 @@ public GrpcServer callHandler(Handler> handler } public GrpcServer callHandler(MethodDescriptor methodDesc, Handler> handler) { + return callHandler(methodDesc, handler, null); + } + + @Override + public GrpcServer callHandler(GrpcAuthenticationHandler authHandler, MethodDescriptor methodDesc, Handler> handler) { + return callHandler(methodDesc, handler, authHandler); + } + + private GrpcServer callHandler(MethodDescriptor methodDesc, Handler> handler, GrpcAuthenticationHandler authHandler) { if (handler != null) { - methodCallHandlers.put(methodDesc.getFullMethodName(), new MethodCallHandler<>(methodDesc, GrpcMessageDecoder.unmarshaller(methodDesc.getRequestMarshaller()), GrpcMessageEncoder.marshaller(methodDesc.getResponseMarshaller()), handler)); + methodCallHandlers.put(methodDesc.getFullMethodName(), new MethodCallHandler<>(methodDesc, GrpcMessageDecoder.unmarshaller(methodDesc.getRequestMarshaller()), GrpcMessageEncoder.marshaller(methodDesc.getResponseMarshaller()), handler, authHandler)); } else { methodCallHandlers.remove(methodDesc.getFullMethodName()); } @@ -82,12 +118,14 @@ private static class MethodCallHandler implements Handler messageDecoder; final GrpcMessageEncoder messageEncoder; final Handler> handler; + final GrpcAuthenticationHandler authHandler; - MethodCallHandler(MethodDescriptor def, GrpcMessageDecoder messageDecoder, GrpcMessageEncoder messageEncoder, Handler> handler) { + MethodCallHandler(MethodDescriptor def, GrpcMessageDecoder messageDecoder, GrpcMessageEncoder messageEncoder, Handler> handler, GrpcAuthenticationHandler authHandler) { this.def = def; this.messageDecoder = messageDecoder; this.messageEncoder = messageEncoder; this.handler = handler; + this.authHandler = authHandler; } @Override diff --git a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerRequestImpl.java b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerRequestImpl.java index 4c2e5720..faf9b2a3 100644 --- a/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerRequestImpl.java +++ b/vertx-grpc-server/src/main/java/io/vertx/grpc/server/impl/GrpcServerRequestImpl.java @@ -10,18 +10,18 @@ */ package io.vertx.grpc.server.impl; -import io.vertx.core.Future; import io.vertx.core.Handler; import io.vertx.core.MultiMap; import io.vertx.core.http.HttpConnection; import io.vertx.core.http.HttpServerRequest; import io.vertx.core.http.impl.HttpServerRequestInternal; +import io.vertx.ext.auth.User; import io.vertx.grpc.common.CodecException; import io.vertx.grpc.common.GrpcMessageDecoder; import io.vertx.grpc.common.GrpcMessageEncoder; import io.vertx.grpc.common.ServiceName; -import io.vertx.grpc.common.impl.GrpcReadStreamBase; import io.vertx.grpc.common.impl.GrpcMethodCall; +import io.vertx.grpc.common.impl.GrpcReadStreamBase; import io.vertx.grpc.server.GrpcServerRequest; import io.vertx.grpc.server.GrpcServerResponse; @@ -33,6 +33,7 @@ public class GrpcServerRequestImpl extends GrpcReadStreamBase response; private GrpcMethodCall methodCall; + private User user; public GrpcServerRequestImpl(HttpServerRequest httpRequest, GrpcMessageDecoder messageDecoder, GrpcMessageEncoder messageEncoder, GrpcMethodCall methodCall) { super(((HttpServerRequestInternal) httpRequest).context(), httpRequest, httpRequest.headers().get("grpc-encoding"), messageDecoder); @@ -91,4 +92,15 @@ public GrpcServerResponse response() { public HttpConnection connection() { return httpRequest.connection(); } + + @Override + public User user() { + return user; + } + + public GrpcServerRequest setUser(User user) { + this.user = user; + return this; + } + } diff --git a/vertx-grpc-server/src/main/proto/helloworld.proto b/vertx-grpc-server/src/main/proto/helloworld.proto index 7d9bc184..bc4e80d6 100644 --- a/vertx-grpc-server/src/main/proto/helloworld.proto +++ b/vertx-grpc-server/src/main/proto/helloworld.proto @@ -40,6 +40,8 @@ package helloworld; service Greeter { // Sends a greeting rpc SayHello (HelloRequest) returns (HelloReply) {} + + rpc SaySecuredHello (HelloRequest) returns (HelloReply) {} } // The request message containing the user's name. diff --git a/vertx-grpc-server/src/test/java/io/vertx/grpc/server/ServerJWTAuthTest.java b/vertx-grpc-server/src/test/java/io/vertx/grpc/server/ServerJWTAuthTest.java new file mode 100644 index 00000000..6ad490ab --- /dev/null +++ b/vertx-grpc-server/src/test/java/io/vertx/grpc/server/ServerJWTAuthTest.java @@ -0,0 +1,180 @@ +package io.vertx.grpc.server; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import org.junit.Test; + +import com.google.common.net.HttpHeaders; + +import examples.GreeterGrpc; +import examples.GreeterGrpc.GreeterBlockingStub; +import examples.HelloReply; +import examples.HelloRequest; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.stub.MetadataUtils; +import io.vertx.core.json.JsonObject; +import io.vertx.ext.auth.JWTOptions; +import io.vertx.ext.auth.KeyStoreOptions; +import io.vertx.ext.auth.User; +import io.vertx.ext.auth.jwt.JWTAuth; +import io.vertx.ext.auth.jwt.JWTAuthOptions; +import io.vertx.ext.unit.TestContext; +import io.vertx.grpc.server.auth.GrpcAuthenticationHandler; +import io.vertx.grpc.server.auth.GrpcJWTAuthenticationHandler; + +public class ServerJWTAuthTest extends ServerTestBase { + + private String validToken; + private GrpcServer jwtServer; + private String expiredToken; + private static final String BROKEN_TOKEN = "this-token-value-is-bogus"; + private static final String INVALID_TOKEN = "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJqb2hhbm5lcyIsImlhdCI6MTY3ODgwNDgwN30"; + private static final String NO_TOKEN = null; + + public void setupClientServer(TestContext should, boolean expectUser) { + + // Prepare JWT auth and generate token to be used for the client + JWTAuthOptions config = new JWTAuthOptions() + .setKeyStore(new KeyStoreOptions() + .setPath("keystore.jceks") + .setPassword("secret") + .setType("jceks")); + + JWTAuth authProvider = JWTAuth.create(vertx, config); + GrpcAuthenticationHandler authHandler = GrpcJWTAuthenticationHandler.create(authProvider, ""); + validToken = authProvider.generateToken(new JsonObject().put("sub", "johannes"), new JWTOptions().setIgnoreExpiration(true)); + expiredToken = authProvider.generateToken(new JsonObject().put("sub", "johannes"), new JWTOptions().setExpiresInSeconds(1)); + + jwtServer = GrpcServer.server(vertx); + + // Register the secured + jwtServer.callHandler(authHandler, GreeterGrpc.getSaySecuredHelloMethod(), request -> { + handleHelloRequest(should, request, expectUser); + }); + // And public handler + jwtServer.callHandler(GreeterGrpc.getSayHelloMethod(), request -> { + handleHelloRequest(should, request, expectUser); + }); + + startServer(jwtServer); + } + + private void handleHelloRequest(TestContext should, GrpcServerRequest request, boolean expectUser) { + request.handler(hello -> { + GrpcServerResponse response = request.response(); + HelloReply reply = HelloReply.newBuilder().setMessage("Hello " + hello.getName()).build(); + response.end(reply); + User user = request.user(); + if (expectUser) { + should.assertNotNull(user); + should.assertEquals("johannes", user.subject()); + } else { + should.assertNull(user); + } + }).errorHandler(error -> { + should.fail("Error should not happen " + error); + }); + } + + /** + * Invoke request with a broken JWT in the headers. Request should be rejected by auth handler. + * @param should + */ + @Test + public void testJWTBrokenTokenAuthentication(TestContext should) { + setupClientServer(should, false); + StatusRuntimeException error = invokeSecuredRequest(should, BROKEN_TOKEN); + assertEquals("The token should not have been accepted", Status.UNAUTHENTICATED, error.getStatus()); + } + + /** + * Invoke request with expired JWT. Request should be rejected. + * @param should + */ + @Test + public void testJWTExpiredTokenAuthentication(TestContext should) { + setupClientServer(should, false); + // Let the token expire + try { + Thread.sleep(2000); + } catch (InterruptedException e) { + e.printStackTrace(); + } + + StatusRuntimeException error = invokeSecuredRequest(should, expiredToken); + assertEquals("The request should have failed since the token is expired", Status.UNAUTHENTICATED, error.getStatus()); + + } + + /** + * Invoke request with valid JWT. Request should pass and user should be accessible. + * @param should + */ + @Test + public void testJWTValidAuthentication(TestContext should) { + setupClientServer(should, true); + assertNull("No error should occur", invokeSecuredRequest(should, validToken)); + } + + /** + * Invoke request to the public method with no JWT provided. + * Request should pass. + * @param should + */ + @Test + public void testNoJWTValidAuthentication(TestContext should) { + setupClientServer(should, false); + assertNull("No error should occur", invokePublicRequest(should, NO_TOKEN)); + } + + /** + * Invoke request with invalid token. Request should fail. + * @param should + */ + @Test + public void testJWTInvalidAuthentication(TestContext should) { + setupClientServer(should, true); + StatusRuntimeException error = invokeSecuredRequest(should, INVALID_TOKEN); + assertEquals("The invalid token should have been rejected.", Status.UNAUTHENTICATED, error.getStatus()); + } + + private StatusRuntimeException invokePublicRequest(TestContext should, String token) { + return invokeRequest(should, token, false); + } + + private StatusRuntimeException invokeSecuredRequest(TestContext should, String token) { + return invokeRequest(should, token, true); + } + + private StatusRuntimeException invokeRequest(TestContext should, String token, boolean secured) { + channel = ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build(); + + GreeterBlockingStub stub = GreeterGrpc.newBlockingStub(channel); + if (secured && token != null) { + Metadata header = new Metadata(); + header.put(Metadata.Key.of(HttpHeaders.AUTHORIZATION, Metadata.ASCII_STRING_MARSHALLER), "Bearer " + token); + stub = stub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(header)); + } + + try { + HelloRequest request = HelloRequest.newBuilder().setName("Johannes").build(); + HelloReply res; + if (secured) { + res = stub.saySecuredHello(request); + } else { + res = stub.sayHello(request); + } + + should.assertEquals("Hello Johannes", res.getMessage()); + return null; + } catch (StatusRuntimeException e) { + return e; + } + + } + +} diff --git a/vertx-grpc-server/src/test/java/io/vertx/grpc/server/ServerRequestTest.java b/vertx-grpc-server/src/test/java/io/vertx/grpc/server/ServerRequestTest.java index d6d20e83..77d38945 100644 --- a/vertx-grpc-server/src/test/java/io/vertx/grpc/server/ServerRequestTest.java +++ b/vertx-grpc-server/src/test/java/io/vertx/grpc/server/ServerRequestTest.java @@ -66,7 +66,7 @@ public void testSSL(TestContext should) throws IOException { .setUseAlpn(true) .setPort(8443) .setHost("localhost") - .setPemKeyCertOptions(cert.keyCertOptions()), GrpcServer.server(vertx).callHandler(GreeterGrpc.getSayHelloMethod(), call -> { + .setKeyCertOptions(cert.keyCertOptions()), GrpcServer.server(vertx).callHandler(GreeterGrpc.getSayHelloMethod(), call -> { call.handler(helloRequest -> { HelloReply helloReply = HelloReply.newBuilder().setMessage("Hello " + helloRequest.getName()).build(); GrpcServerResponse response = call.response(); diff --git a/vertx-grpc-server/src/test/resources/keystore.jceks b/vertx-grpc-server/src/test/resources/keystore.jceks new file mode 100644 index 00000000..0db37124 Binary files /dev/null and b/vertx-grpc-server/src/test/resources/keystore.jceks differ