Skip to content

Commit 473430e

Browse files
rmehta19lqiu96
authored andcommitted
feat: Enable MTLS_S2A bound token by default for gRPC S2A enabled flows (#3591)
Similar to implementation for DirectPath in #3572. This is part of the experimental S2A feature (see #3400)
1 parent 5225529 commit 473430e

File tree

2 files changed

+136
-3
lines changed

2 files changed

+136
-3
lines changed

gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProvider.java

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import com.google.common.io.Files;
5757
import io.grpc.CallCredentials;
5858
import io.grpc.ChannelCredentials;
59+
import io.grpc.CompositeChannelCredentials;
5960
import io.grpc.Grpc;
6061
import io.grpc.InsecureChannelCredentials;
6162
import io.grpc.ManagedChannel;
@@ -69,6 +70,7 @@
6970
import java.nio.charset.StandardCharsets;
7071
import java.security.GeneralSecurityException;
7172
import java.security.KeyStore;
73+
import java.util.ArrayList;
7274
import java.util.HashMap;
7375
import java.util.List;
7476
import java.util.Map;
@@ -139,14 +141,15 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
139141
@Nullable private final Boolean keepAliveWithoutCalls;
140142
private final ChannelPoolSettings channelPoolSettings;
141143
@Nullable private final Credentials credentials;
144+
@Nullable private final CallCredentials mtlsS2ACallCredentials;
142145
@Nullable private final ChannelPrimer channelPrimer;
143146
@Nullable private final Boolean attemptDirectPath;
144147
@Nullable private final Boolean attemptDirectPathXds;
145148
@Nullable private final Boolean allowNonDefaultServiceAccount;
146149
@VisibleForTesting final ImmutableMap<String, ?> directPathServiceConfig;
147150
@Nullable private final MtlsProvider mtlsProvider;
148151
@Nullable private final SecureSessionAgent s2aConfigProvider;
149-
@Nullable private final List<HardBoundTokenTypes> allowedHardBoundTokenTypes;
152+
private final List<HardBoundTokenTypes> allowedHardBoundTokenTypes;
150153
@VisibleForTesting final Map<String, String> headersWithDuplicatesRemoved = new HashMap<>();
151154

152155
@Nullable
@@ -188,6 +191,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
188191
this.channelPoolSettings = builder.channelPoolSettings;
189192
this.channelConfigurator = builder.channelConfigurator;
190193
this.credentials = builder.credentials;
194+
this.mtlsS2ACallCredentials = builder.mtlsS2ACallCredentials;
191195
this.channelPrimer = builder.channelPrimer;
192196
this.attemptDirectPath = builder.attemptDirectPath;
193197
this.attemptDirectPathXds = builder.attemptDirectPathXds;
@@ -648,6 +652,12 @@ private ManagedChannel createSingleChannel() throws IOException {
648652
}
649653
if (channelCredentials != null) {
650654
// Create the channel using S2A-secured channel credentials.
655+
if (mtlsS2ACallCredentials != null) {
656+
// Set {@code mtlsS2ACallCredentials} to be per-RPC call credentials,
657+
// which will be used to fetch MTLS_S2A hard bound tokens from the metdata server.
658+
channelCredentials =
659+
CompositeChannelCredentials.create(channelCredentials, mtlsS2ACallCredentials);
660+
}
651661
builder = Grpc.newChannelBuilder(endpoint, channelCredentials);
652662
} else {
653663
// Use default if we cannot initialize channel credentials via DCA or S2A.
@@ -812,18 +822,20 @@ public static final class Builder {
812822
@Nullable private Boolean keepAliveWithoutCalls;
813823
@Nullable private ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator;
814824
@Nullable private Credentials credentials;
825+
@Nullable private CallCredentials mtlsS2ACallCredentials;
815826
@Nullable private ChannelPrimer channelPrimer;
816827
private ChannelPoolSettings channelPoolSettings;
817828
@Nullable private Boolean attemptDirectPath;
818829
@Nullable private Boolean attemptDirectPathXds;
819830
@Nullable private Boolean allowNonDefaultServiceAccount;
820831
@Nullable private ImmutableMap<String, ?> directPathServiceConfig;
821-
@Nullable private List<HardBoundTokenTypes> allowedHardBoundTokenTypes;
832+
private List<HardBoundTokenTypes> allowedHardBoundTokenTypes;
822833

823834
private Builder() {
824835
processorCount = Runtime.getRuntime().availableProcessors();
825836
envProvider = System::getenv;
826837
channelPoolSettings = ChannelPoolSettings.staticallySized(1);
838+
allowedHardBoundTokenTypes = new ArrayList<>();
827839
}
828840

829841
private Builder(InstantiatingGrpcChannelProvider provider) {
@@ -841,11 +853,13 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
841853
this.keepAliveWithoutCalls = provider.keepAliveWithoutCalls;
842854
this.channelConfigurator = provider.channelConfigurator;
843855
this.credentials = provider.credentials;
856+
this.mtlsS2ACallCredentials = provider.mtlsS2ACallCredentials;
844857
this.channelPrimer = provider.channelPrimer;
845858
this.channelPoolSettings = provider.channelPoolSettings;
846859
this.attemptDirectPath = provider.attemptDirectPath;
847860
this.attemptDirectPathXds = provider.attemptDirectPathXds;
848861
this.allowNonDefaultServiceAccount = provider.allowNonDefaultServiceAccount;
862+
this.allowedHardBoundTokenTypes = provider.allowedHardBoundTokenTypes;
849863
this.directPathServiceConfig = provider.directPathServiceConfig;
850864
this.mtlsProvider = provider.mtlsProvider;
851865
this.s2aConfigProvider = provider.s2aConfigProvider;
@@ -914,7 +928,10 @@ Builder setUseS2A(boolean useS2A) {
914928
*/
915929
@InternalApi
916930
public Builder setAllowHardBoundTokenTypes(List<HardBoundTokenTypes> allowedValues) {
917-
this.allowedHardBoundTokenTypes = allowedValues;
931+
this.allowedHardBoundTokenTypes =
932+
Preconditions.checkNotNull(
933+
allowedValues, "List of allowed HardBoundTokenTypes cannot be null");
934+
;
918935
return this;
919936
}
920937

@@ -1133,7 +1150,50 @@ public Builder setDirectPathServiceConfig(Map<String, ?> serviceConfig) {
11331150
return this;
11341151
}
11351152

1153+
boolean isMtlsS2AHardBoundTokensEnabled() {
1154+
// If S2A cannot be used, the list of allowed hard bound token types is empty or doesn't
1155+
// contain
1156+
// {@code HardBoundTokenTypes.MTLS_S2A}, the {@code credentials} are null or not of type
1157+
// {@code
1158+
// ComputeEngineCredentials} then {@code HardBoundTokenTypes.MTLS_S2A} hard bound tokens
1159+
// should
1160+
// not
1161+
// be used. {@code HardBoundTokenTypes.MTLS_S2A} hard bound tokens can only be used on MTLS
1162+
// channels established using S2A and when tokens from MDS (i.e {@code
1163+
// ComputeEngineCredentials}
1164+
// are being used.
1165+
if (!this.useS2A
1166+
|| this.allowedHardBoundTokenTypes.isEmpty()
1167+
|| this.credentials == null
1168+
|| !(this.credentials instanceof ComputeEngineCredentials)) {
1169+
return false;
1170+
}
1171+
return allowedHardBoundTokenTypes.stream()
1172+
.anyMatch(val -> val.equals(HardBoundTokenTypes.MTLS_S2A));
1173+
}
1174+
1175+
CallCredentials createHardBoundTokensCallCredentials(
1176+
ComputeEngineCredentials.GoogleAuthTransport googleAuthTransport,
1177+
ComputeEngineCredentials.BindingEnforcement bindingEnforcement) {
1178+
// We only set scopes and HTTP transport factory from the original credentials because
1179+
// only those are used in gRPC CallCredentials to fetch request metadata.
1180+
return MoreCallCredentials.from(
1181+
((ComputeEngineCredentials) this.credentials)
1182+
.toBuilder()
1183+
.setGoogleAuthTransport(googleAuthTransport)
1184+
.setBindingEnforcement(bindingEnforcement)
1185+
.build());
1186+
}
1187+
11361188
public InstantiatingGrpcChannelProvider build() {
1189+
if (isMtlsS2AHardBoundTokensEnabled()) {
1190+
// Set a {@code ComputeEngineCredentials} instance to be per-RPC call credentials,
1191+
// which will be used to fetch MTLS_S2A hard bound tokens from the metdata server.
1192+
this.mtlsS2ACallCredentials =
1193+
createHardBoundTokensCallCredentials(
1194+
ComputeEngineCredentials.GoogleAuthTransport.MTLS,
1195+
ComputeEngineCredentials.BindingEnforcement.ON);
1196+
}
11371197
InstantiatingGrpcChannelProvider instantiatingGrpcChannelProvider =
11381198
new InstantiatingGrpcChannelProvider(this);
11391199
instantiatingGrpcChannelProvider.removeApiKeyCredentialDuplicateHeaders();

gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/InstantiatingGrpcChannelProviderTest.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,79 @@ void createS2ASecuredChannelCredentials_returnsPlaintextToS2AS2AChannelCredentia
11031103
InstantiatingGrpcChannelProvider.LOG.removeHandler(logHandler);
11041104
}
11051105

1106+
@Test
1107+
void isMtlsS2AHardBoundTokensEnabled_useS2AFalse() {
1108+
InstantiatingGrpcChannelProvider.Builder providerBuilder =
1109+
InstantiatingGrpcChannelProvider.newBuilder()
1110+
.setUseS2A(false)
1111+
.setAllowHardBoundTokenTypes(
1112+
Collections.singletonList(
1113+
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
1114+
.setCredentials(computeEngineCredentials);
1115+
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse();
1116+
}
1117+
1118+
@Test
1119+
void isMtlsS2AHardBoundTokensEnabled_hardBoundTokenTypesEmpty() {
1120+
InstantiatingGrpcChannelProvider.Builder providerBuilder =
1121+
InstantiatingGrpcChannelProvider.newBuilder()
1122+
.setUseS2A(true)
1123+
.setAllowHardBoundTokenTypes(new ArrayList<>())
1124+
.setCredentials(computeEngineCredentials);
1125+
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse();
1126+
}
1127+
1128+
@Test
1129+
void isMtlsS2AHardBoundTokensEnabled_nullCreds() {
1130+
InstantiatingGrpcChannelProvider.Builder providerBuilder =
1131+
InstantiatingGrpcChannelProvider.newBuilder()
1132+
.setUseS2A(true)
1133+
.setAllowHardBoundTokenTypes(
1134+
Collections.singletonList(
1135+
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
1136+
.setCredentials(null);
1137+
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse();
1138+
}
1139+
1140+
@Test
1141+
void isMtlsS2AHardBoundTokensEnabled_notComputeEngineCreds() {
1142+
InstantiatingGrpcChannelProvider.Builder providerBuilder =
1143+
InstantiatingGrpcChannelProvider.newBuilder()
1144+
.setUseS2A(true)
1145+
.setAllowHardBoundTokenTypes(
1146+
Collections.singletonList(
1147+
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A))
1148+
.setCredentials(CloudShellCredentials.create(3000));
1149+
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse();
1150+
}
1151+
1152+
@Test
1153+
void isMtlsS2AHardBoundTokensEnabled_mtlsS2ANotInList() {
1154+
InstantiatingGrpcChannelProvider.Builder providerBuilder =
1155+
InstantiatingGrpcChannelProvider.newBuilder()
1156+
.setUseS2A(true)
1157+
.setAllowHardBoundTokenTypes(
1158+
Collections.singletonList(
1159+
InstantiatingGrpcChannelProvider.HardBoundTokenTypes.ALTS))
1160+
.setCredentials(computeEngineCredentials);
1161+
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isFalse();
1162+
}
1163+
1164+
@Test
1165+
void isMtlsS2AHardBoundTokensEnabled_mtlsS2ATokenAllowedInList() {
1166+
List<InstantiatingGrpcChannelProvider.HardBoundTokenTypes> allowHardBoundTokenTypes =
1167+
new ArrayList<>();
1168+
allowHardBoundTokenTypes.add(InstantiatingGrpcChannelProvider.HardBoundTokenTypes.MTLS_S2A);
1169+
allowHardBoundTokenTypes.add(InstantiatingGrpcChannelProvider.HardBoundTokenTypes.ALTS);
1170+
1171+
InstantiatingGrpcChannelProvider.Builder providerBuilder =
1172+
InstantiatingGrpcChannelProvider.newBuilder()
1173+
.setUseS2A(true)
1174+
.setAllowHardBoundTokenTypes(allowHardBoundTokenTypes)
1175+
.setCredentials(computeEngineCredentials);
1176+
Truth.assertThat(providerBuilder.isMtlsS2AHardBoundTokensEnabled()).isTrue();
1177+
}
1178+
11061179
private static class FakeLogHandler extends Handler {
11071180

11081181
List<LogRecord> records = new ArrayList<>();

0 commit comments

Comments
 (0)