Skip to content

Commit 603edd6

Browse files
mkleenesujankota
authored andcommitted
fix(sdk): allow SDK to handle protocols in addresses (#70)
TDFs contain embedded URLs, some of which contain protocols. In order for them to work with GRPC we need to strip off the protocol. The logic for ports is to use one if it is specified, otherwise we use 80 if the protocol is `http`, otherwise use `443`.
1 parent 1d3aeee commit 603edd6

File tree

2 files changed

+62
-9
lines changed

2 files changed

+62
-9
lines changed

sdk/src/main/java/io/opentdf/platform/sdk/KASClient.java

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,19 @@
1717
import io.opentdf.platform.sdk.nanotdf.NanoTDFType;
1818
import org.bouncycastle.jce.interfaces.ECPublicKey;
1919

20-
import java.io.IOException;
21-
import java.security.InvalidAlgorithmParameterException;
2220
import java.security.MessageDigest;
2321
import java.security.NoSuchAlgorithmException;
24-
import java.security.NoSuchProviderException;
22+
import java.net.MalformedURLException;
23+
import java.net.URL;
2524
import java.time.Duration;
2625
import java.time.Instant;
2726
import java.util.ArrayList;
2827
import java.util.Date;
2928
import java.util.HashMap;
3029
import java.util.function.Function;
3130

31+
import static java.lang.String.format;
32+
3233
public class KASClient implements SDK.KAS, AutoCloseable {
3334

3435
private final Function<String, ManagedChannel> channelFactory;
@@ -67,6 +68,33 @@ public String getPublicKey(Config.KASInfo kasInfo) {
6768
.getPublicKey();
6869
}
6970

71+
private String normalizeAddress(String urlString) {
72+
URL url;
73+
try {
74+
url = new URL(urlString);
75+
} catch (MalformedURLException e) {
76+
// if there is no protocol then they either gave us
77+
// a correct address or one we don't know how to fix
78+
return urlString;
79+
}
80+
81+
// otherwise we take the specified port or default
82+
// based on whether the URL uses a scheme that
83+
// implies TLS
84+
int port;
85+
if (url.getPort() == -1) {
86+
if ("http".equals(url.getProtocol())) {
87+
port = 80;
88+
} else {
89+
port = 443;
90+
}
91+
} else {
92+
port = url.getPort();
93+
}
94+
95+
return format("%s:%d", url.getHost(), port);
96+
}
97+
7098
@Override
7199
public synchronized void close() {
72100
var entries = new ArrayList<>(stubs.values());
@@ -188,21 +216,22 @@ public byte[] unwrapNanoTDF(NanoTDFType.ECCurve curve, String header, String kas
188216
private static class CacheEntry {
189217
final ManagedChannel channel;
190218
final AccessServiceGrpc.AccessServiceBlockingStub stub;
191-
192219
private CacheEntry(ManagedChannel channel, AccessServiceGrpc.AccessServiceBlockingStub stub) {
193220
this.channel = channel;
194221
this.stub = stub;
195222
}
196223
}
197224

198-
private synchronized AccessServiceGrpc.AccessServiceBlockingStub getStub(String url) {
199-
if (!stubs.containsKey(url)) {
200-
var channel = channelFactory.apply(url);
225+
// make this protected so we can test the address normalization logic
226+
synchronized AccessServiceGrpc.AccessServiceBlockingStub getStub(String url) {
227+
var realAddress = normalizeAddress(url);
228+
if (!stubs.containsKey(realAddress)) {
229+
var channel = channelFactory.apply(realAddress);
201230
var stub = AccessServiceGrpc.newBlockingStub(channel);
202-
stubs.put(url, new CacheEntry(channel, stub));
231+
stubs.put(realAddress, new CacheEntry(channel, stub));
203232
}
204233

205-
return stubs.get(url).stub;
234+
return stubs.get(realAddress).stub;
206235
}
207236
}
208237

sdk/src/test/java/io/opentdf/platform/sdk/KASClientTest.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.text.ParseException;
2525
import java.util.Base64;
2626
import java.util.Random;
27+
import java.util.concurrent.atomic.AtomicReference;
2728
import java.util.function.Function;
2829

2930
import static io.opentdf.platform.sdk.SDKBuilderTest.getRandomPort;
@@ -136,6 +137,29 @@ public void rewrap(RewrapRequest request, StreamObserver<RewrapResponse> respons
136137
}
137138
}
138139

140+
@Test
141+
public void testAddressNormalization() {
142+
var lastAddress = new AtomicReference<String>();
143+
var dpopKeypair = CryptoUtils.generateRSAKeypair();
144+
var dpopKey = new RSAKey.Builder((RSAPublicKey)dpopKeypair.getPublic()).privateKey(dpopKeypair.getPrivate()).build();
145+
var kasClient = new KASClient(addr -> {
146+
lastAddress.set(addr);
147+
return ManagedChannelBuilder.forTarget(addr).build();
148+
}, dpopKey);
149+
150+
var stub = kasClient.getStub("http://localhost:8080");
151+
assertThat(lastAddress.get()).isEqualTo("localhost:8080");
152+
var otherStub = kasClient.getStub("https://localhost:8080");
153+
assertThat(lastAddress.get()).isEqualTo("localhost:8080");
154+
assertThat(stub).isSameAs(otherStub);
155+
156+
kasClient.getStub("https://example.org");
157+
assertThat(lastAddress.get()).isEqualTo("example.org:443");
158+
159+
kasClient.getStub("http://example.org");
160+
assertThat(lastAddress.get()).isEqualTo("example.org:80");
161+
}
162+
139163
private static Server startServer(AccessServiceGrpc.AccessServiceImplBase accessService) throws IOException {
140164
return ServerBuilder
141165
.forPort(getRandomPort())

0 commit comments

Comments
 (0)