Skip to content

feat(sdk): provide access tokens dynamically to KAS #51

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package io.opentdf.platform.sdk;

import io.grpc.Channel;
import io.opentdf.platform.kas.AccessServiceGrpc;
import io.opentdf.platform.kas.PublicKeyRequest;
import io.opentdf.platform.kas.RewrapRequest;

import java.util.HashMap;
import java.util.function.Function;

public class KASClient implements SDK.KAS {

private final Function<SDK.KASInfo, Channel> channelFactory;

public KASClient(Function <SDK.KASInfo, Channel> channelFactory) {
this.channelFactory = channelFactory;
}

@Override
public String getPublicKey(SDK.KASInfo kasInfo) {
return getStub(kasInfo).publicKey(PublicKeyRequest.getDefaultInstance()).getPublicKey();
}

@Override
public byte[] unwrap(SDK.KASInfo kasInfo, SDK.Policy policy) {
// this is obviously wrong. we still have to generate a correct request and decrypt the payload
return getStub(kasInfo).rewrap(RewrapRequest.getDefaultInstance()).getEntityWrappedKey().toByteArray();
}

private final HashMap<SDK.KASInfo, AccessServiceGrpc.AccessServiceBlockingStub> stubs = new HashMap<>();

private synchronized AccessServiceGrpc.AccessServiceBlockingStub getStub(SDK.KASInfo kasInfo) {
if (!stubs.containsKey(kasInfo)) {
var channel = channelFactory.apply(kasInfo);
var stub = AccessServiceGrpc.newBlockingStub(channel);
stubs.put(kasInfo, stub);
}

return stubs.get(kasInfo);
}
}
22 changes: 19 additions & 3 deletions sdk/src/main/java/io/opentdf/platform/sdk/SDK.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,25 @@
public class SDK {
private final Services services;

public interface KASInfo{
String getAddress();
}
public interface Policy{}

interface KAS {
String getPublicKey(KASInfo kasInfo);
byte[] unwrap(KASInfo kasInfo, Policy policy);
}

// TODO: add KAS
public interface Services {
interface Services {
AttributesServiceFutureStub attributes();
NamespaceServiceFutureStub namespaces();
SubjectMappingServiceFutureStub subjectMappings();
ResourceMappingServiceFutureStub resourceMappings();
KAS kas();

static Services newServices(Channel channel) {
static Services newServices(Channel channel, KAS kas) {
var attributeService = AttributesServiceGrpc.newFutureStub(channel);
var namespaceService = NamespaceServiceGrpc.newFutureStub(channel);
var subjectMappingService = SubjectMappingServiceGrpc.newFutureStub(channel);
Expand All @@ -50,11 +61,16 @@ public SubjectMappingServiceFutureStub subjectMappings() {
public ResourceMappingServiceFutureStub resourceMappings() {
return resourceMappingService;
}

@Override
public KAS kas() {
return kas;
}
};
}
}

public SDK(Services services) {
SDK(Services services) {
this.services = services;
}
}
47 changes: 38 additions & 9 deletions sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.id.Issuer;
import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata;
import io.grpc.Channel;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Status;
Expand All @@ -23,6 +24,7 @@

import java.io.IOException;
import java.util.UUID;
import java.util.function.Function;

/**
* A builder class for creating instances of the SDK class.
Expand All @@ -33,9 +35,13 @@ public class SDKBuilder {
private ClientAuthentication clientAuth = null;
private Boolean usePlainText;

private static final Logger logger = LoggerFactory.getLogger(SDKBuilder.class);

public static SDKBuilder newBuilder() {
SDKBuilder builder = new SDKBuilder();
builder.usePlainText = false;
builder.clientAuth = null;
builder.platformEndpoint = null;

return builder;
}
Expand All @@ -57,8 +63,16 @@ public SDKBuilder useInsecurePlaintextConnection(Boolean usePlainText) {
return this;
}

// this is not exposed publicly so that it can be tested
ManagedChannel buildChannel() {
private GRPCAuthInterceptor getGrpcAuthInterceptor() {
if (platformEndpoint == null) {
throw new SDKException("cannot build an SDK without specifying the platform endpoint");
}

if (clientAuth == null) {
// this simplifies things for now, if we need to support this case we can revisit
throw new SDKException("cannot build an SDK without specifying OAuth credentials");
}

// we don't add the auth listener to this channel since it is only used to call the
// well known endpoint
ManagedChannel bootstrapChannel = null;
Expand Down Expand Up @@ -107,24 +121,39 @@ ManagedChannel buildChannel() {
throw new SDKException("Error generating DPoP key", e);
}

GRPCAuthInterceptor interceptor = new GRPCAuthInterceptor(clientAuth, rsaKey, providerMetadata.getTokenEndpointURI());
return new GRPCAuthInterceptor(clientAuth, rsaKey, providerMetadata.getTokenEndpointURI());
}

return getManagedChannelBuilder()
.intercept(interceptor)
.build();
SDK.Services buildServices() {
var authInterceptor = getGrpcAuthInterceptor();
var channel = getManagedChannelBuilder().intercept(authInterceptor).build();
var client = new KASClient(getChannelFactory(authInterceptor));
return SDK.Services.newServices(channel, client);
}

public SDK build() {
return new SDK(SDK.Services.newServices(buildChannel()));
return new SDK(buildServices());
}

private ManagedChannelBuilder<?> getManagedChannelBuilder() {
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder
.forTarget(platformEndpoint);
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forTarget(platformEndpoint);

if (usePlainText) {
channelBuilder = channelBuilder.usePlaintext();
}
return channelBuilder;
}

Function<SDK.KASInfo, Channel> getChannelFactory(GRPCAuthInterceptor authInterceptor) {
var pt = usePlainText; // no need to have the builder be able to influence things from beyond the grave
return (SDK.KASInfo kasInfo) -> {
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder
.forTarget(kasInfo.getAddress())
.intercept(authInterceptor);
if (pt) {
channelBuilder = channelBuilder.usePlaintext();
}
return channelBuilder.build();
};
}
}
4 changes: 4 additions & 0 deletions sdk/src/main/java/io/opentdf/platform/sdk/SDKException.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,8 @@ public class SDKException extends RuntimeException {
public SDKException(String message, Exception reason) {
super(message, reason);
}

public SDKException(String message) {
super(message);
}
}
101 changes: 76 additions & 25 deletions sdk/src/test/java/io/opentdf/platform/sdk/SDKBuilderTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,19 @@
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.stub.StreamObserver;
import io.opentdf.platform.kas.AccessServiceGrpc;
import io.opentdf.platform.kas.RewrapRequest;
import io.opentdf.platform.kas.RewrapResponse;
import io.opentdf.platform.policy.namespaces.GetNamespaceRequest;
import io.opentdf.platform.policy.namespaces.GetNamespaceResponse;
import io.opentdf.platform.policy.namespaces.NamespaceServiceGrpc;
import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationRequest;
import io.opentdf.platform.wellknownconfiguration.GetWellKnownConfigurationResponse;
import io.opentdf.platform.wellknownconfiguration.WellKnownServiceGrpc;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

import java.io.IOException;
Expand All @@ -30,8 +38,9 @@
public class SDKBuilderTest {

@Test
void testCreatingSDKChannel() throws IOException, InterruptedException {
Server wellknownServer = null;
void testCreatingSDKServices() throws IOException, InterruptedException {
Server platformServicesServer = null;
Server kasServer = null;
// we use the HTTP server for two things:
// * it returns the OIDC configuration we use at bootstrapping time
// * it fakes out being an IDP and returns an access token when need to retrieve an access token
Expand All @@ -51,6 +60,8 @@ void testCreatingSDKChannel() throws IOException, InterruptedException {
.setHeader("Content-type", "application/json")
);

// this service returns the platform_issuer url to the SDK during bootstrapping. This
// tells the SDK where to download the OIDC discovery document from (our test webserver!)
WellKnownServiceGrpc.WellKnownServiceImplBase wellKnownService = new WellKnownServiceGrpc.WellKnownServiceImplBase() {
@Override
public void getWellKnownConfiguration(GetWellKnownConfigurationRequest request, StreamObserver<GetWellKnownConfigurationResponse> responseObserver) {
Expand All @@ -65,55 +76,76 @@ public void getWellKnownConfiguration(GetWellKnownConfigurationRequest request,
}
};

AtomicReference<String> authHeaderFromRequest = new AtomicReference<>(null);
AtomicReference<String> dpopHeaderFromRequest = new AtomicReference<>(null);
// remember the auth headers that we received during GRPC calls to platform services
AtomicReference<String> servicesAuthHeader = new AtomicReference<>(null);
AtomicReference<String> servicesDPoPHeader = new AtomicReference<>(null);

// remember the auth headers that we received during GRPC calls to KAS
AtomicReference<String> kasAuthHeader = new AtomicReference<>(null);
AtomicReference<String> kasDPoPHeader = new AtomicReference<>(null);
// we use the server in two different ways. the first time we use it to actually return
// issuer for bootstrapping. the second time we use the interception functionality in order
// to make sure that we are including a DPoP proof and an auth header
int randomPort;
try (ServerSocket socket = new ServerSocket(0)) {
randomPort = socket.getLocalPort();
}
wellknownServer = ServerBuilder
.forPort(randomPort)
platformServicesServer = ServerBuilder
.forPort(getRandomPort())
.directExecutor()
.addService(wellKnownService)
.addService(new NamespaceServiceGrpc.NamespaceServiceImplBase() {})
.intercept(new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
servicesAuthHeader.set(headers.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)));
servicesDPoPHeader.set(headers.get(Metadata.Key.of("DPoP", Metadata.ASCII_STRING_MARSHALLER)));
return next.startCall(call, headers);
}
})
.build()
.start();


kasServer = ServerBuilder
.forPort(getRandomPort())
.directExecutor()
.addService(new AccessServiceGrpc.AccessServiceImplBase() {
@Override
public void rewrap(RewrapRequest request, StreamObserver<RewrapResponse> responseObserver) {
responseObserver.onNext(RewrapResponse.getDefaultInstance());
responseObserver.onCompleted();
}
})
.intercept(new ServerInterceptor() {
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
authHeaderFromRequest.set(headers.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)));
dpopHeaderFromRequest.set(headers.get(Metadata.Key.of("DPoP", Metadata.ASCII_STRING_MARSHALLER)));
kasAuthHeader.set(headers.get(Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER)));
kasDPoPHeader.set(headers.get(Metadata.Key.of("DPoP", Metadata.ASCII_STRING_MARSHALLER)));
return next.startCall(call, headers);
}
})
.build()
.start();

ManagedChannel channel = SDKBuilder
SDK.Services services = SDKBuilder
.newBuilder()
.clientSecret("client-id", "client-secret")
.platformEndpoint("localhost:" + wellknownServer.getPort())
.platformEndpoint("localhost:" + platformServicesServer.getPort())
.useInsecurePlaintextConnection(true)
.buildChannel();

assertThat(channel).isNotNull();
assertThat(channel.getState(false)).isEqualTo(ConnectivityState.IDLE);
.buildServices();

var wellKnownStub = WellKnownServiceGrpc.newBlockingStub(channel);
assertThat(services).isNotNull();

httpServer.enqueue(new MockResponse()
.setBody("{\"access_token\": \"hereisthetoken\", \"token_type\": \"Bearer\"}")
.setHeader("Content-Type", "application/json"));

var ignored = wellKnownStub.getWellKnownConfiguration(GetWellKnownConfigurationRequest.getDefaultInstance());
channel.shutdownNow();
var ignored = services.namespaces().getNamespace(GetNamespaceRequest.getDefaultInstance());

// we've now made two requests. one to get the bootstrapping info and one
// call that should activate the token fetching logic
assertThat(httpServer.getRequestCount()).isEqualTo(2);

httpServer.takeRequest();

// validate that we made a reasonable request to our fake IdP to get an access token
var accessTokenRequest = httpServer.takeRequest();
assertThat(accessTokenRequest).isNotNull();
var authHeader = accessTokenRequest.getHeader("Authorization");
Expand All @@ -124,16 +156,35 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
var usernameAndPassword = new String(Base64.getDecoder().decode(authHeaderParts[1]), StandardCharsets.UTF_8);
assertThat(usernameAndPassword).isEqualTo("client-id:client-secret");

assertThat(dpopHeaderFromRequest.get()).isNotNull();
assertThat(authHeaderFromRequest.get()).isEqualTo("DPoP hereisthetoken");
// validate that during the request to the namespace service we supplied a valid token
assertThat(servicesDPoPHeader.get()).isNotNull();
assertThat(servicesAuthHeader.get()).isEqualTo("DPoP hereisthetoken");

var body = new String(accessTokenRequest.getBody().readByteArray(), StandardCharsets.UTF_8);
assertThat(body).contains("grant_type=client_credentials");

// now call KAS _on a different server_ and make sure that the interceptors provide us with auth tokens
int kasPort = kasServer.getPort();
SDK.KASInfo kasInfo = () -> "localhost:" + kasPort;
services.kas().unwrap(kasInfo, new SDK.Policy() {});

assertThat(kasDPoPHeader.get()).isNotNull();
assertThat(kasAuthHeader.get()).isEqualTo("DPoP hereisthetoken");
} finally {
if (wellknownServer != null) {
wellknownServer.shutdownNow();
if (platformServicesServer != null) {
platformServicesServer.shutdownNow();
}
if (kasServer != null) {
kasServer.shutdownNow();
}
}
}

private static int getRandomPort() throws IOException {
int randomPort;
try (ServerSocket socket = new ServerSocket(0)) {
randomPort = socket.getLocalPort();
}
return randomPort;
}
}
Loading