Skip to content

Commit 889ca01

Browse files
committed
JDBC implementation of RegisteredClientRepository
1 parent 8e9563a commit 889ca01

File tree

3 files changed

+640
-0
lines changed

3 files changed

+640
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
/*
2+
* Copyright 2020-2021 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.oauth2.server.authorization.client;
17+
18+
import com.fasterxml.jackson.core.JsonProcessingException;
19+
import com.fasterxml.jackson.databind.ObjectMapper;
20+
import org.springframework.jdbc.core.*;
21+
import org.springframework.jdbc.support.lob.DefaultLobHandler;
22+
import org.springframework.jdbc.support.lob.LobCreator;
23+
import org.springframework.jdbc.support.lob.LobHandler;
24+
import org.springframework.security.oauth2.core.AuthorizationGrantType;
25+
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
26+
import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
27+
import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
28+
import org.springframework.util.Assert;
29+
30+
import java.nio.charset.StandardCharsets;
31+
import java.sql.*;
32+
import java.time.Duration;
33+
import java.time.Instant;
34+
import java.util.*;
35+
import java.util.function.Function;
36+
import java.util.stream.Collectors;
37+
38+
/**
39+
* JDBC-backed registered client repository
40+
*
41+
* @author Rafal Lewczuk
42+
* @since 0.1.2
43+
*/
44+
public class JdbcRegisteredClientRepository implements RegisteredClientRepository {
45+
46+
private static final Map<String, AuthorizationGrantType> AUTHORIZATION_GRANT_TYPE_MAP;
47+
private static final Map<String, ClientAuthenticationMethod> CLIENT_AUTHENTICATION_METHOD_MAP;
48+
49+
private static final String COLUMN_NAMES = "id, "
50+
+ "client_id, "
51+
+ "client_id_issued_at, "
52+
+ "client_secret, "
53+
+ "client_secret_expires_at, "
54+
+ "client_name, "
55+
+ "client_authentication_methods, "
56+
+ "authorization_grant_types, "
57+
+ "redirect_uris, "
58+
+ "scopes, "
59+
+ "client_settings,"
60+
+ "token_settings";
61+
62+
private static final String TABLE_NAME = "oauth2_registered_client";
63+
64+
private static final String LOAD_REGISTERED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + " FROM " + TABLE_NAME + " WHERE ";
65+
66+
private static final String INSERT_REGISTERED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME
67+
+ "(" + COLUMN_NAMES + ") values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
68+
69+
private RowMapper<RegisteredClient> registeredClientRowMapper;
70+
71+
private Function<RegisteredClient, List<SqlParameterValue>> registeredClientParametersMapper;
72+
73+
private final JdbcOperations jdbcOperations;
74+
75+
private final LobHandler lobHandler = new DefaultLobHandler();
76+
77+
private final ObjectMapper objectMapper;
78+
79+
public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations, ObjectMapper objectMapper) {
80+
Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
81+
Assert.notNull(objectMapper, "objectMapper cannot be null");
82+
this.jdbcOperations = jdbcOperations;
83+
this.objectMapper = objectMapper;
84+
this.registeredClientRowMapper = new DefaultRegisteredClientRowMapper();
85+
this.registeredClientParametersMapper = new DefaultRegisteredClientParametersMapper();
86+
}
87+
88+
/**
89+
* Allows changing of {@link RegisteredClient} row mapper implementation
90+
*
91+
* @param registeredClientRowMapper mapper implementation
92+
*/
93+
public void setRegisteredClientRowMapper(RowMapper<RegisteredClient> registeredClientRowMapper) {
94+
Assert.notNull(registeredClientRowMapper, "registeredClientRowMapper cannot be null");
95+
this.registeredClientRowMapper = registeredClientRowMapper;
96+
}
97+
98+
/**
99+
* Allows changing of SQL parameter mapper for {@link RegisteredClient}
100+
*
101+
* @param registeredClientParametersMapper mapper implementation
102+
*/
103+
public void setRegisteredClientParametersMapper(Function<RegisteredClient, List<SqlParameterValue>> registeredClientParametersMapper) {
104+
Assert.notNull(registeredClientParametersMapper, "registeredClientParameterMapper cannot be null");
105+
this.registeredClientParametersMapper = registeredClientParametersMapper;
106+
}
107+
108+
@Override
109+
public void save(RegisteredClient registeredClient) {
110+
Assert.notNull(registeredClient, "registeredClient cannot be null");
111+
RegisteredClient foundClient = this.findBy("id = ? OR client_id = ? OR client_secret = ?",
112+
registeredClient.getId(), registeredClient.getClientId(),
113+
registeredClient.getClientSecret().getBytes(StandardCharsets.UTF_8));
114+
115+
if (null != foundClient) {
116+
Assert.isTrue(!foundClient.getId().equals(registeredClient.getId()),
117+
"Registered client must be unique. Found duplicate identifier: " + registeredClient.getId());
118+
Assert.isTrue(!foundClient.getClientId().equals(registeredClient.getClientId()),
119+
"Registered client must be unique. Found duplicate client identifier: " + registeredClient.getClientId());
120+
Assert.isTrue(!foundClient.getClientSecret().equals(registeredClient.getClientSecret()),
121+
"Registered client must be unique. Found duplicate client secret for identifier: " + registeredClient.getId());
122+
}
123+
124+
List<SqlParameterValue> parameters = this.registeredClientParametersMapper.apply(registeredClient);
125+
126+
try (LobCreator lobCreator = this.lobHandler.getLobCreator()) {
127+
PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator, parameters.toArray());
128+
jdbcOperations.update(INSERT_REGISTERED_CLIENT_SQL, pss);
129+
}
130+
}
131+
132+
@Override
133+
public RegisteredClient findById(String id) {
134+
Assert.hasText(id, "id cannot be empty");
135+
return findBy("id = ?", id);
136+
}
137+
138+
@Override
139+
public RegisteredClient findByClientId(String clientId) {
140+
Assert.hasText(clientId, "clientId cannot be empty");
141+
return findBy("client_id = ?", clientId);
142+
}
143+
144+
private RegisteredClient findBy(String condStr, Object...args) {
145+
List<RegisteredClient> lst = jdbcOperations.query(
146+
LOAD_REGISTERED_CLIENT_SQL + condStr,
147+
registeredClientRowMapper, args);
148+
return !lst.isEmpty() ? lst.get(0) : null;
149+
}
150+
151+
private class DefaultRegisteredClientRowMapper implements RowMapper<RegisteredClient> {
152+
153+
private final LobHandler lobHandler = new DefaultLobHandler();
154+
155+
private Collection<String> parseList(String s) {
156+
return s != null ? Arrays.asList(s.split("\\|")) : Collections.emptyList();
157+
}
158+
159+
@Override
160+
@SuppressWarnings("unchecked")
161+
public RegisteredClient mapRow(ResultSet rs, int rowNum) throws SQLException {
162+
Collection<String> scopes = parseList(rs.getString("scopes"));
163+
List<AuthorizationGrantType> authGrantTypes = parseList(rs.getString("authorization_grant_types"))
164+
.stream().map(AUTHORIZATION_GRANT_TYPE_MAP::get).collect(Collectors.toList());
165+
List<ClientAuthenticationMethod> clientAuthMethods = parseList(rs.getString("client_authentication_methods"))
166+
.stream().map(CLIENT_AUTHENTICATION_METHOD_MAP::get).collect(Collectors.toList());
167+
Collection<String> redirectUris = parseList(rs.getString("redirect_uris"));
168+
Timestamp clientIssuedAt = rs.getTimestamp("client_id_issued_at");
169+
Timestamp clientSecretExpiresAt = rs.getTimestamp("client_secret_expires_at");
170+
byte[] clientSecretBytes = this.lobHandler.getBlobAsBytes(rs, "client_secret");
171+
String clientSecret = clientSecretBytes != null ? new String(clientSecretBytes, StandardCharsets.UTF_8) : null;
172+
RegisteredClient.Builder builder = RegisteredClient
173+
.withId(rs.getString("id"))
174+
.clientId(rs.getString("client_id"))
175+
.clientIdIssuedAt(clientIssuedAt != null ? clientIssuedAt.toInstant() : null)
176+
.clientSecret(clientSecret)
177+
.clientSecretExpiresAt(clientSecretExpiresAt != null ? clientSecretExpiresAt.toInstant() : null)
178+
.clientName(rs.getString("client_name"))
179+
.clientAuthenticationMethods(coll -> coll.addAll(clientAuthMethods))
180+
.authorizationGrantTypes(coll -> coll.addAll(authGrantTypes))
181+
.redirectUris(coll -> coll.addAll(redirectUris))
182+
.scopes(coll -> coll.addAll(scopes));
183+
184+
RegisteredClient rc = builder.build();
185+
186+
TokenSettings ts = rc.getTokenSettings();
187+
ClientSettings cs = rc.getClientSettings();
188+
189+
try {
190+
String tokenSettingsJson = rs.getString("token_settings");
191+
if (tokenSettingsJson != null) {
192+
193+
Map<String, Object> m = JdbcRegisteredClientRepository.this.objectMapper.readValue(tokenSettingsJson, Map.class);
194+
195+
Number accessTokenTTL = (Number) m.get("access_token_ttl");
196+
if (accessTokenTTL != null) {
197+
ts.accessTokenTimeToLive(Duration.ofMillis(accessTokenTTL.longValue()));
198+
}
199+
200+
Number refreshTokenTTL = (Number) m.get("refresh_token_ttl");
201+
if (refreshTokenTTL != null) {
202+
ts.refreshTokenTimeToLive(Duration.ofMillis(refreshTokenTTL.longValue()));
203+
}
204+
205+
Boolean reuseRefreshTokens = (Boolean) m.get("reuse_refresh_tokens");
206+
if (reuseRefreshTokens != null) {
207+
ts.reuseRefreshTokens(reuseRefreshTokens);
208+
}
209+
}
210+
211+
String clientSettingsJson = rs.getString("client_settings");
212+
if (clientSettingsJson != null) {
213+
214+
Map<String, Object> m = JdbcRegisteredClientRepository.this.objectMapper.readValue(clientSettingsJson, Map.class);
215+
216+
Boolean requireProofKey = (Boolean) m.get("require_proof_key");
217+
if (requireProofKey != null) {
218+
cs.requireProofKey(requireProofKey);
219+
}
220+
221+
Boolean requireUserConsent = (Boolean) m.get("require_user_consent");
222+
if (requireUserConsent != null) {
223+
cs.requireUserConsent(requireUserConsent);
224+
}
225+
}
226+
227+
228+
} catch (JsonProcessingException e) {
229+
throw new IllegalArgumentException(e.getMessage(), e);
230+
}
231+
232+
return rc;
233+
}
234+
}
235+
236+
private class DefaultRegisteredClientParametersMapper implements Function<RegisteredClient, List<SqlParameterValue>> {
237+
@Override
238+
public List<SqlParameterValue> apply(RegisteredClient registeredClient) {
239+
try {
240+
List<String> clientAuthenticationMethodNames = new ArrayList<>(registeredClient.getClientAuthenticationMethods().size());
241+
for (ClientAuthenticationMethod clientAuthenticationMethod : registeredClient.getClientAuthenticationMethods()) {
242+
clientAuthenticationMethodNames.add(clientAuthenticationMethod.getValue());
243+
}
244+
245+
List<String> authorizationGrantTypeNames = new ArrayList<>(registeredClient.getAuthorizationGrantTypes().size());
246+
for (AuthorizationGrantType authorizationGrantType : registeredClient.getAuthorizationGrantTypes()) {
247+
authorizationGrantTypeNames.add(authorizationGrantType.getValue());
248+
}
249+
250+
Instant issuedAt = registeredClient.getClientIdIssuedAt() != null ?
251+
registeredClient.getClientIdIssuedAt() : Instant.now();
252+
253+
Timestamp clientSecretExpiresAt = registeredClient.getClientSecretExpiresAt() != null ?
254+
Timestamp.from(registeredClient.getClientSecretExpiresAt()) : null;
255+
256+
Map<String, Object> clientSettings = new HashMap<>();
257+
clientSettings.put("require_proof_key", registeredClient.getClientSettings().requireProofKey());
258+
clientSettings.put("require_user_consent", registeredClient.getClientSettings().requireUserConsent());
259+
String clientSettingsJson = JdbcRegisteredClientRepository.this.objectMapper.writeValueAsString(clientSettings);
260+
261+
Map<String, Object> tokenSettings = new HashMap<>();
262+
tokenSettings.put("access_token_ttl", registeredClient.getTokenSettings().accessTokenTimeToLive().toMillis());
263+
tokenSettings.put("reuse_refresh_tokens", registeredClient.getTokenSettings().reuseRefreshTokens());
264+
tokenSettings.put("refresh_token_ttl", registeredClient.getTokenSettings().refreshTokenTimeToLive().toMillis());
265+
String tokenSettingsJson = JdbcRegisteredClientRepository.this.objectMapper.writeValueAsString(tokenSettings);
266+
267+
return Arrays.asList(
268+
new SqlParameterValue(Types.VARCHAR, registeredClient.getId()),
269+
new SqlParameterValue(Types.VARCHAR, registeredClient.getClientId()),
270+
new SqlParameterValue(Types.TIMESTAMP, Timestamp.from(issuedAt)),
271+
new SqlParameterValue(Types.BLOB, registeredClient.getClientSecret().getBytes(StandardCharsets.UTF_8)),
272+
new SqlParameterValue(Types.TIMESTAMP, clientSecretExpiresAt),
273+
new SqlParameterValue(Types.VARCHAR, registeredClient.getClientName()),
274+
new SqlParameterValue(Types.VARCHAR, String.join("|", clientAuthenticationMethodNames)),
275+
new SqlParameterValue(Types.VARCHAR, String.join("|", authorizationGrantTypeNames)),
276+
new SqlParameterValue(Types.VARCHAR, String.join("|", registeredClient.getRedirectUris())),
277+
new SqlParameterValue(Types.VARCHAR, String.join("|", registeredClient.getScopes())),
278+
new SqlParameterValue(Types.VARCHAR, clientSettingsJson),
279+
new SqlParameterValue(Types.VARCHAR, tokenSettingsJson));
280+
} catch (JsonProcessingException e) {
281+
throw new IllegalArgumentException(e.getMessage(), e);
282+
}
283+
}
284+
}
285+
286+
private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {
287+
288+
protected final LobCreator lobCreator;
289+
290+
private LobCreatorArgumentPreparedStatementSetter(LobCreator lobCreator, Object[] args) {
291+
super(args);
292+
this.lobCreator = lobCreator;
293+
}
294+
295+
@Override
296+
protected void doSetValue(PreparedStatement ps, int parameterPosition, Object argValue) throws SQLException {
297+
if (argValue instanceof SqlParameterValue) {
298+
SqlParameterValue paramValue = (SqlParameterValue) argValue;
299+
if (paramValue.getSqlType() == Types.BLOB) {
300+
if (paramValue.getValue() != null) {
301+
Assert.isInstanceOf(byte[].class, paramValue.getValue(),
302+
"Value of blob parameter must be byte[]");
303+
}
304+
byte[] valueBytes = (byte[]) paramValue.getValue();
305+
this.lobCreator.setBlobAsBytes(ps, parameterPosition, valueBytes);
306+
return;
307+
}
308+
}
309+
super.doSetValue(ps, parameterPosition, argValue);
310+
}
311+
312+
}
313+
314+
static {
315+
Map<String, AuthorizationGrantType> am = new HashMap<>();
316+
for (AuthorizationGrantType a : Arrays.asList(
317+
AuthorizationGrantType.AUTHORIZATION_CODE,
318+
AuthorizationGrantType.REFRESH_TOKEN,
319+
AuthorizationGrantType.CLIENT_CREDENTIALS,
320+
AuthorizationGrantType.PASSWORD,
321+
AuthorizationGrantType.IMPLICIT)) {
322+
am.put(a.getValue(), a);
323+
}
324+
AUTHORIZATION_GRANT_TYPE_MAP = Collections.unmodifiableMap(am);
325+
326+
Map<String, ClientAuthenticationMethod> cm = new HashMap<>();
327+
for (ClientAuthenticationMethod c : Arrays.asList(
328+
ClientAuthenticationMethod.NONE,
329+
ClientAuthenticationMethod.BASIC,
330+
ClientAuthenticationMethod.POST)) {
331+
cm.put(c.getValue(), c);
332+
}
333+
CLIENT_AUTHENTICATION_METHOD_MAP = Collections.unmodifiableMap(cm);
334+
}
335+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
CREATE TABLE oauth2_registered_client (
2+
id varchar(100) NOT NULL,
3+
client_id varchar(100) NOT NULL,
4+
client_id_issued_at timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
5+
client_secret blob NOT NULL,
6+
client_secret_expires_at timestamp DEFAULT NULL,
7+
client_name varchar(200),
8+
client_authentication_methods varchar(1000) NOT NULL,
9+
authorization_grant_types varchar(1000) NOT NULL,
10+
redirect_uris varchar(1000) NOT NULL,
11+
scopes varchar(1000) NOT NULL,
12+
client_settings varchar(1000) DEFAULT NULL,
13+
token_settings varchar(1000) DEFAULT NULL,
14+
PRIMARY KEY (id));

0 commit comments

Comments
 (0)