Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions sdk/src/main/java/io/opentdf/platform/sdk/AttributesClient.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package io.opentdf.platform.sdk;

import java.net.MalformedURLException;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.function.Function;

import io.grpc.ManagedChannel;
import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsRequest;
import io.opentdf.platform.policy.attributes.AttributesServiceGrpc;
import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsResponse;

import static java.lang.String.format;


public class AttributesClient implements SDK.AttributesService {

private final ManagedChannel channel;

/***
* A client that communicates with KAS
* @param channelFactory A function that produces channels that can be used to communicate
* @param dpopKey
*/
public AttributesClient(ManagedChannel channel) {
this.channel = channel;
}


@Override
public synchronized void close() {
var entries = new ArrayList<>(stubs.values());
stubs.clear();
for (var entry: entries) {
entry.channel.shutdownNow();
}
this.channel.shutdownNow();
}

private String normalizeAddress(String urlString) {
URL url;
try {
url = new URL(urlString);
} catch (MalformedURLException e) {
// if there is no protocol then they either gave us
// a correct address or one we don't know how to fix
return urlString;
}

// otherwise we take the specified port or default
// based on whether the URL uses a scheme that
// implies TLS
int port;
if (url.getPort() == -1) {
if ("http".equals(url.getProtocol())) {
port = 80;
} else {
port = 443;
}
} else {
port = url.getPort();
}

return format("%s:%d", url.getHost(), port);
}


private final HashMap<String, CacheEntry> stubs = new HashMap<>();
private static class CacheEntry {
final ManagedChannel channel;
final AttributesServiceGrpc.AttributesServiceBlockingStub stub;
private CacheEntry(ManagedChannel channel, AttributesServiceGrpc.AttributesServiceBlockingStub stub) {
this.channel = channel;
this.stub = stub;
}
}

// make this protected so we can test the address normalization logic
synchronized AttributesServiceGrpc.AttributesServiceBlockingStub getStub() {
return AttributesServiceGrpc.newBlockingStub(channel);
}


@Override
public GetAttributeValuesByFqnsResponse getAttributeValuesByFqn(GetAttributeValuesByFqnsRequest request) {
return getStub().getAttributeValuesByFqns(request);
}

}
16 changes: 12 additions & 4 deletions sdk/src/main/java/io/opentdf/platform/sdk/SDK.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io.opentdf.platform.authorization.AuthorizationServiceGrpc;
import io.opentdf.platform.authorization.AuthorizationServiceGrpc.AuthorizationServiceFutureStub;
import io.opentdf.platform.policy.attributes.AttributesServiceGrpc;
import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsRequest;
import io.opentdf.platform.policy.attributes.AttributesServiceGrpc.AttributesServiceFutureStub;
import io.opentdf.platform.policy.namespaces.NamespaceServiceGrpc;
import io.opentdf.platform.policy.namespaces.NamespaceServiceGrpc.NamespaceServiceFutureStub;
Expand All @@ -13,6 +14,9 @@
import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceGrpc;
import io.opentdf.platform.policy.subjectmapping.SubjectMappingServiceGrpc.SubjectMappingServiceFutureStub;
import io.opentdf.platform.sdk.nanotdf.NanoTDFType;
import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsResponse;

import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -45,17 +49,20 @@ public interface KAS extends AutoCloseable {
byte[] unwrapNanoTDF(NanoTDFType.ECCurve curve, String header, String kasURL);
}

public interface AttributesService extends AutoCloseable {
GetAttributeValuesByFqnsResponse getAttributeValuesByFqn(GetAttributeValuesByFqnsRequest request);
}

// TODO: add KAS
public interface Services extends AutoCloseable {
AuthorizationServiceFutureStub authorization();
AttributesServiceFutureStub attributes();
AttributesService attributes();
NamespaceServiceFutureStub namespaces();
SubjectMappingServiceFutureStub subjectMappings();
ResourceMappingServiceFutureStub resourceMappings();
KAS kas();

static Services newServices(ManagedChannel channel, KAS kas) {
var attributeService = AttributesServiceGrpc.newFutureStub(channel);
static Services newServices(ManagedChannel channel, KAS kas, AttributesService attributeService) {
var namespaceService = NamespaceServiceGrpc.newFutureStub(channel);
var subjectMappingService = SubjectMappingServiceGrpc.newFutureStub(channel);
var resourceMappingService = ResourceMappingServiceGrpc.newFutureStub(channel);
Expand All @@ -65,11 +72,12 @@ static Services newServices(ManagedChannel channel, KAS kas) {
@Override
public void close() throws Exception {
channel.shutdownNow();
attributeService.close();
kas.close();
}

@Override
public AttributesServiceFutureStub attributes() {
public AttributesService attributes() {
return attributeService;
}

Expand Down
8 changes: 6 additions & 2 deletions sdk/src/main/java/io/opentdf/platform/sdk/SDKBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -193,20 +193,24 @@ ServicesAndInternals buildServices() {

var authInterceptor = getGrpcAuthInterceptor(dpopKey);
ManagedChannel channel;
ManagedChannel attributesChannel;
Function<String, ManagedChannel> managedChannelFactory;
if (authInterceptor == null) {
channel = getManagedChannelBuilder(platformEndpoint).build();
attributesChannel = getManagedChannelBuilder(platformEndpoint).build();
managedChannelFactory = (String endpoint) -> getManagedChannelBuilder(endpoint).build();

} else {
channel = getManagedChannelBuilder(platformEndpoint).intercept(authInterceptor).build();
attributesChannel = getManagedChannelBuilder(platformEndpoint).intercept(authInterceptor).build();
managedChannelFactory = (String endpoint) -> getManagedChannelBuilder(endpoint).intercept(authInterceptor).build();
}
var client = new KASClient(managedChannelFactory, dpopKey);
var kasclient = new KASClient(managedChannelFactory, dpopKey);
var attrclient = new AttributesClient(attributesChannel);
return new ServicesAndInternals(
authInterceptor,
sslFactory == null ? null : sslFactory.getTrustManager().orElse(null),
SDK.Services.newServices(channel, client)
SDK.Services.newServices(channel, kasclient, attrclient)
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package io.opentdf.platform.sdk;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import org.junit.jupiter.api.Test;

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.opentdf.platform.policy.attributes.AttributesServiceGrpc;
import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsRequest;
import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsResponse;
import io.opentdf.platform.policy.attributes.GetAttributeValuesByFqnsResponse.AttributeAndValue;
import io.opentdf.platform.policy.Attribute;
import io.opentdf.platform.policy.Namespace;
import io.opentdf.platform.policy.Value;
import io.opentdf.platform.policy.AttributeRuleTypeEnum;

import static io.opentdf.platform.sdk.SDKBuilderTest.getRandomPort;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;


public class AttributeClientTest {
@Test
void testGettingAttributeByFqn() throws IOException {
AttributesServiceGrpc.AttributesServiceImplBase attributesService = new AttributesServiceGrpc.AttributesServiceImplBase() {
@Override
public void getAttributeValuesByFqns(GetAttributeValuesByFqnsRequest request,
io.grpc.stub.StreamObserver<GetAttributeValuesByFqnsResponse> responseObserver) {
Attribute attribute1 = Attribute.newBuilder().setId("CLS").setNamespace(
Namespace.newBuilder().setId("v").setName("virtru.com").setFqn("https://virtru.com").build())
.setName("Classification").setRule(AttributeRuleTypeEnum.ATTRIBUTE_RULE_TYPE_ENUM_HIERARCHY).setFqn("https://virtru.com/attr/classification").build();

Value attributeValue1 = Value.newBuilder()
.setValue("value1")
.build();

// Create a sample AttributeValues object
AttributeAndValue attributeAndValues = AttributeAndValue.newBuilder().setAttribute(attribute1)
.setValue(attributeValue1)
.build();
GetAttributeValuesByFqnsResponse response = GetAttributeValuesByFqnsResponse.newBuilder()
.putFqnAttributeValues("https://virtru.com/attr/classification/value/value1",attributeAndValues)
.build();
responseObserver.onNext(response);
responseObserver.onCompleted();

}
};

Server attrServer = null;
try {
attrServer = startServer(attributesService);
String attrServerUrl = "localhost:" + attrServer.getPort();
ManagedChannel channel = ManagedChannelBuilder
.forTarget(attrServerUrl)
.usePlaintext()
.build();
try (var attr = new AttributesClient(channel)) {
GetAttributeValuesByFqnsResponse resp = attr.getAttributeValuesByFqn(GetAttributeValuesByFqnsRequest.newBuilder().build());
Set<String> fqnSet = new HashSet<>(Arrays.asList("https://virtru.com/attr/classification/value/value1"));
assertThat(resp.getFqnAttributeValuesMap().keySet()).isEqualTo(fqnSet);
assertThat(resp.getFqnAttributeValuesCount()).isEqualTo(1);
}
} finally {
if (attrServer != null) {
attrServer.shutdownNow();
}
}
}
private static Server startServer(AttributesServiceGrpc.AttributesServiceImplBase attrService) throws IOException {
return ServerBuilder
.forPort(getRandomPort())
.directExecutor()
.addService(attrService)
.build()
.start();
}

}