|
17 | 17 | import io.opentdf.platform.sdk.nanotdf.NanoTDFType;
|
18 | 18 | import org.bouncycastle.jce.interfaces.ECPublicKey;
|
19 | 19 |
|
20 |
| -import java.io.IOException; |
21 |
| -import java.security.InvalidAlgorithmParameterException; |
22 | 20 | import java.security.MessageDigest;
|
23 | 21 | import java.security.NoSuchAlgorithmException;
|
24 |
| -import java.security.NoSuchProviderException; |
| 22 | +import java.net.MalformedURLException; |
| 23 | +import java.net.URL; |
25 | 24 | import java.time.Duration;
|
26 | 25 | import java.time.Instant;
|
27 | 26 | import java.util.ArrayList;
|
28 | 27 | import java.util.Date;
|
29 | 28 | import java.util.HashMap;
|
30 | 29 | import java.util.function.Function;
|
31 | 30 |
|
| 31 | +import static java.lang.String.format; |
| 32 | + |
32 | 33 | public class KASClient implements SDK.KAS, AutoCloseable {
|
33 | 34 |
|
34 | 35 | private final Function<String, ManagedChannel> channelFactory;
|
@@ -67,6 +68,33 @@ public String getPublicKey(Config.KASInfo kasInfo) {
|
67 | 68 | .getPublicKey();
|
68 | 69 | }
|
69 | 70 |
|
| 71 | + private String normalizeAddress(String urlString) { |
| 72 | + URL url; |
| 73 | + try { |
| 74 | + url = new URL(urlString); |
| 75 | + } catch (MalformedURLException e) { |
| 76 | + // if there is no protocol then they either gave us |
| 77 | + // a correct address or one we don't know how to fix |
| 78 | + return urlString; |
| 79 | + } |
| 80 | + |
| 81 | + // otherwise we take the specified port or default |
| 82 | + // based on whether the URL uses a scheme that |
| 83 | + // implies TLS |
| 84 | + int port; |
| 85 | + if (url.getPort() == -1) { |
| 86 | + if ("http".equals(url.getProtocol())) { |
| 87 | + port = 80; |
| 88 | + } else { |
| 89 | + port = 443; |
| 90 | + } |
| 91 | + } else { |
| 92 | + port = url.getPort(); |
| 93 | + } |
| 94 | + |
| 95 | + return format("%s:%d", url.getHost(), port); |
| 96 | + } |
| 97 | + |
70 | 98 | @Override
|
71 | 99 | public synchronized void close() {
|
72 | 100 | var entries = new ArrayList<>(stubs.values());
|
@@ -188,21 +216,22 @@ public byte[] unwrapNanoTDF(NanoTDFType.ECCurve curve, String header, String kas
|
188 | 216 | private static class CacheEntry {
|
189 | 217 | final ManagedChannel channel;
|
190 | 218 | final AccessServiceGrpc.AccessServiceBlockingStub stub;
|
191 |
| - |
192 | 219 | private CacheEntry(ManagedChannel channel, AccessServiceGrpc.AccessServiceBlockingStub stub) {
|
193 | 220 | this.channel = channel;
|
194 | 221 | this.stub = stub;
|
195 | 222 | }
|
196 | 223 | }
|
197 | 224 |
|
198 |
| - private synchronized AccessServiceGrpc.AccessServiceBlockingStub getStub(String url) { |
199 |
| - if (!stubs.containsKey(url)) { |
200 |
| - var channel = channelFactory.apply(url); |
| 225 | + // make this protected so we can test the address normalization logic |
| 226 | + synchronized AccessServiceGrpc.AccessServiceBlockingStub getStub(String url) { |
| 227 | + var realAddress = normalizeAddress(url); |
| 228 | + if (!stubs.containsKey(realAddress)) { |
| 229 | + var channel = channelFactory.apply(realAddress); |
201 | 230 | var stub = AccessServiceGrpc.newBlockingStub(channel);
|
202 |
| - stubs.put(url, new CacheEntry(channel, stub)); |
| 231 | + stubs.put(realAddress, new CacheEntry(channel, stub)); |
203 | 232 | }
|
204 | 233 |
|
205 |
| - return stubs.get(url).stub; |
| 234 | + return stubs.get(realAddress).stub; |
206 | 235 | }
|
207 | 236 | }
|
208 | 237 |
|
0 commit comments