56
56
import com .google .common .io .Files ;
57
57
import io .grpc .CallCredentials ;
58
58
import io .grpc .ChannelCredentials ;
59
+ import io .grpc .CompositeChannelCredentials ;
59
60
import io .grpc .Grpc ;
60
61
import io .grpc .InsecureChannelCredentials ;
61
62
import io .grpc .ManagedChannel ;
69
70
import java .nio .charset .StandardCharsets ;
70
71
import java .security .GeneralSecurityException ;
71
72
import java .security .KeyStore ;
73
+ import java .util .ArrayList ;
72
74
import java .util .HashMap ;
73
75
import java .util .List ;
74
76
import java .util .Map ;
@@ -139,14 +141,15 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
139
141
@ Nullable private final Boolean keepAliveWithoutCalls ;
140
142
private final ChannelPoolSettings channelPoolSettings ;
141
143
@ Nullable private final Credentials credentials ;
144
+ @ Nullable private final CallCredentials mtlsS2ACallCredentials ;
142
145
@ Nullable private final ChannelPrimer channelPrimer ;
143
146
@ Nullable private final Boolean attemptDirectPath ;
144
147
@ Nullable private final Boolean attemptDirectPathXds ;
145
148
@ Nullable private final Boolean allowNonDefaultServiceAccount ;
146
149
@ VisibleForTesting final ImmutableMap <String , ?> directPathServiceConfig ;
147
150
@ Nullable private final MtlsProvider mtlsProvider ;
148
151
@ Nullable private final SecureSessionAgent s2aConfigProvider ;
149
- @ Nullable private final List <HardBoundTokenTypes > allowedHardBoundTokenTypes ;
152
+ private final List <HardBoundTokenTypes > allowedHardBoundTokenTypes ;
150
153
@ VisibleForTesting final Map <String , String > headersWithDuplicatesRemoved = new HashMap <>();
151
154
152
155
@ Nullable
@@ -188,6 +191,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
188
191
this .channelPoolSettings = builder .channelPoolSettings ;
189
192
this .channelConfigurator = builder .channelConfigurator ;
190
193
this .credentials = builder .credentials ;
194
+ this .mtlsS2ACallCredentials = builder .mtlsS2ACallCredentials ;
191
195
this .channelPrimer = builder .channelPrimer ;
192
196
this .attemptDirectPath = builder .attemptDirectPath ;
193
197
this .attemptDirectPathXds = builder .attemptDirectPathXds ;
@@ -648,6 +652,12 @@ private ManagedChannel createSingleChannel() throws IOException {
648
652
}
649
653
if (channelCredentials != null ) {
650
654
// Create the channel using S2A-secured channel credentials.
655
+ if (mtlsS2ACallCredentials != null ) {
656
+ // Set {@code mtlsS2ACallCredentials} to be per-RPC call credentials,
657
+ // which will be used to fetch MTLS_S2A hard bound tokens from the metdata server.
658
+ channelCredentials =
659
+ CompositeChannelCredentials .create (channelCredentials , mtlsS2ACallCredentials );
660
+ }
651
661
builder = Grpc .newChannelBuilder (endpoint , channelCredentials );
652
662
} else {
653
663
// Use default if we cannot initialize channel credentials via DCA or S2A.
@@ -812,18 +822,20 @@ public static final class Builder {
812
822
@ Nullable private Boolean keepAliveWithoutCalls ;
813
823
@ Nullable private ApiFunction <ManagedChannelBuilder , ManagedChannelBuilder > channelConfigurator ;
814
824
@ Nullable private Credentials credentials ;
825
+ @ Nullable private CallCredentials mtlsS2ACallCredentials ;
815
826
@ Nullable private ChannelPrimer channelPrimer ;
816
827
private ChannelPoolSettings channelPoolSettings ;
817
828
@ Nullable private Boolean attemptDirectPath ;
818
829
@ Nullable private Boolean attemptDirectPathXds ;
819
830
@ Nullable private Boolean allowNonDefaultServiceAccount ;
820
831
@ Nullable private ImmutableMap <String , ?> directPathServiceConfig ;
821
- @ Nullable private List <HardBoundTokenTypes > allowedHardBoundTokenTypes ;
832
+ private List <HardBoundTokenTypes > allowedHardBoundTokenTypes ;
822
833
823
834
private Builder () {
824
835
processorCount = Runtime .getRuntime ().availableProcessors ();
825
836
envProvider = System ::getenv ;
826
837
channelPoolSettings = ChannelPoolSettings .staticallySized (1 );
838
+ allowedHardBoundTokenTypes = new ArrayList <>();
827
839
}
828
840
829
841
private Builder (InstantiatingGrpcChannelProvider provider ) {
@@ -841,11 +853,13 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
841
853
this .keepAliveWithoutCalls = provider .keepAliveWithoutCalls ;
842
854
this .channelConfigurator = provider .channelConfigurator ;
843
855
this .credentials = provider .credentials ;
856
+ this .mtlsS2ACallCredentials = provider .mtlsS2ACallCredentials ;
844
857
this .channelPrimer = provider .channelPrimer ;
845
858
this .channelPoolSettings = provider .channelPoolSettings ;
846
859
this .attemptDirectPath = provider .attemptDirectPath ;
847
860
this .attemptDirectPathXds = provider .attemptDirectPathXds ;
848
861
this .allowNonDefaultServiceAccount = provider .allowNonDefaultServiceAccount ;
862
+ this .allowedHardBoundTokenTypes = provider .allowedHardBoundTokenTypes ;
849
863
this .directPathServiceConfig = provider .directPathServiceConfig ;
850
864
this .mtlsProvider = provider .mtlsProvider ;
851
865
this .s2aConfigProvider = provider .s2aConfigProvider ;
@@ -914,7 +928,10 @@ Builder setUseS2A(boolean useS2A) {
914
928
*/
915
929
@ InternalApi
916
930
public Builder setAllowHardBoundTokenTypes (List <HardBoundTokenTypes > allowedValues ) {
917
- this .allowedHardBoundTokenTypes = allowedValues ;
931
+ this .allowedHardBoundTokenTypes =
932
+ Preconditions .checkNotNull (
933
+ allowedValues , "List of allowed HardBoundTokenTypes cannot be null" );
934
+ ;
918
935
return this ;
919
936
}
920
937
@@ -1133,7 +1150,50 @@ public Builder setDirectPathServiceConfig(Map<String, ?> serviceConfig) {
1133
1150
return this ;
1134
1151
}
1135
1152
1153
+ boolean isMtlsS2AHardBoundTokensEnabled () {
1154
+ // If S2A cannot be used, the list of allowed hard bound token types is empty or doesn't
1155
+ // contain
1156
+ // {@code HardBoundTokenTypes.MTLS_S2A}, the {@code credentials} are null or not of type
1157
+ // {@code
1158
+ // ComputeEngineCredentials} then {@code HardBoundTokenTypes.MTLS_S2A} hard bound tokens
1159
+ // should
1160
+ // not
1161
+ // be used. {@code HardBoundTokenTypes.MTLS_S2A} hard bound tokens can only be used on MTLS
1162
+ // channels established using S2A and when tokens from MDS (i.e {@code
1163
+ // ComputeEngineCredentials}
1164
+ // are being used.
1165
+ if (!this .useS2A
1166
+ || this .allowedHardBoundTokenTypes .isEmpty ()
1167
+ || this .credentials == null
1168
+ || !(this .credentials instanceof ComputeEngineCredentials )) {
1169
+ return false ;
1170
+ }
1171
+ return allowedHardBoundTokenTypes .stream ()
1172
+ .anyMatch (val -> val .equals (HardBoundTokenTypes .MTLS_S2A ));
1173
+ }
1174
+
1175
+ CallCredentials createHardBoundTokensCallCredentials (
1176
+ ComputeEngineCredentials .GoogleAuthTransport googleAuthTransport ,
1177
+ ComputeEngineCredentials .BindingEnforcement bindingEnforcement ) {
1178
+ // We only set scopes and HTTP transport factory from the original credentials because
1179
+ // only those are used in gRPC CallCredentials to fetch request metadata.
1180
+ return MoreCallCredentials .from (
1181
+ ((ComputeEngineCredentials ) this .credentials )
1182
+ .toBuilder ()
1183
+ .setGoogleAuthTransport (googleAuthTransport )
1184
+ .setBindingEnforcement (bindingEnforcement )
1185
+ .build ());
1186
+ }
1187
+
1136
1188
public InstantiatingGrpcChannelProvider build () {
1189
+ if (isMtlsS2AHardBoundTokensEnabled ()) {
1190
+ // Set a {@code ComputeEngineCredentials} instance to be per-RPC call credentials,
1191
+ // which will be used to fetch MTLS_S2A hard bound tokens from the metdata server.
1192
+ this .mtlsS2ACallCredentials =
1193
+ createHardBoundTokensCallCredentials (
1194
+ ComputeEngineCredentials .GoogleAuthTransport .MTLS ,
1195
+ ComputeEngineCredentials .BindingEnforcement .ON );
1196
+ }
1137
1197
InstantiatingGrpcChannelProvider instantiatingGrpcChannelProvider =
1138
1198
new InstantiatingGrpcChannelProvider (this );
1139
1199
instantiatingGrpcChannelProvider .removeApiKeyCredentialDuplicateHeaders ();
0 commit comments