Commit ecc793c8 authored by Matija Obreza's avatar Matija Obreza
Browse files

Merge branch '49-upgrade-spring-oauth' into release-4.1

* 49-upgrade-spring-oauth:
  Spring OAuth: Upgrade to Spring OAuth Authorization Server
parents 25c9277e 1f5391f4
......@@ -111,9 +111,21 @@
</dependency>
<dependency>
<groupId>org.springframework.security.oauth</groupId>
<artifactId>spring-security-oauth2</artifactId>
<version>${spring.security.oauth2.version}</version>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-oauth2-resource-server</artifactId>
<version>${spring.security.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-oauth2-jose</artifactId>
<version>${spring.security.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-oauth2-client</artifactId>
<version>${spring.security.version}</version>
<scope>provided</scope>
</dependency>
......@@ -138,6 +150,13 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-oauth2-authorization-server</artifactId>
<version>0.3.1</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>javax.servlet</groupId>
<artifactId>javax.servlet-api</artifactId>
......
......@@ -48,7 +48,6 @@ import org.genesys.blocks.model.JsonViews;
import org.genesys.blocks.security.model.AclSid;
import org.hibernate.annotations.Type;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.oauth2.provider.ClientDetails;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
......@@ -66,7 +65,7 @@ import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
@Entity
@Table(name = "oauthclient")
@DiscriminatorValue(value = "2")
public class OAuthClient extends AclSid implements ClientDetails, Copyable<OAuthClient> {
public class OAuthClient extends AclSid implements Copyable<OAuthClient> {
/** The Constant serialVersionUID. */
private static final long serialVersionUID = -4204753722663196007L;
......@@ -239,11 +238,6 @@ public class OAuthClient extends AclSid implements ClientDetails, Copyable<OAuth
}
}
/*
* (non-Javadoc)
* @see org.springframework.security.oauth2.provider.ClientDetails#getClientId()
*/
@Override
public String getClientId() {
return clientId;
}
......@@ -257,12 +251,6 @@ public class OAuthClient extends AclSid implements ClientDetails, Copyable<OAuth
this.clientId = clientId;
}
/*
* (non-Javadoc)
* @see
* org.springframework.security.oauth2.provider.ClientDetails#getClientSecret()
*/
@Override
public String getClientSecret() {
return clientSecret;
}
......@@ -412,12 +400,6 @@ public class OAuthClient extends AclSid implements ClientDetails, Copyable<OAuth
this.autoApproveScopes = autoApproveScopes;
}
/*
* (non-Javadoc)
* @see
* org.springframework.security.oauth2.provider.ClientDetails#getResourceIds()
*/
@Override
public Set<String> getResourceIds() {
return resourceIds;
}
......@@ -436,7 +418,6 @@ public class OAuthClient extends AclSid implements ClientDetails, Copyable<OAuth
*
* @return true, if is secret required
*/
@Override
public boolean isSecretRequired() {
return clientSecret != null;
}
......@@ -445,7 +426,6 @@ public class OAuthClient extends AclSid implements ClientDetails, Copyable<OAuth
* (non-Javadoc)
* @see org.springframework.security.oauth2.provider.ClientDetails#isScoped()
*/
@Override
public boolean isScoped() {
return !scopes.isEmpty();
}
......@@ -455,7 +435,6 @@ public class OAuthClient extends AclSid implements ClientDetails, Copyable<OAuth
* @see org.springframework.security.oauth2.provider.ClientDetails#getScope()
*/
@JsonProperty("clientScopes")
@Override
public Set<String> getScope() {
return scopes;
}
......@@ -475,7 +454,6 @@ public class OAuthClient extends AclSid implements ClientDetails, Copyable<OAuth
* @see org.springframework.security.oauth2.provider.ClientDetails#
* getAuthorizedGrantTypes()
*/
@Override
@JsonView(JsonViews.Protected.class)
public Set<String> getAuthorizedGrantTypes() {
return grantTypes;
......@@ -490,12 +468,6 @@ public class OAuthClient extends AclSid implements ClientDetails, Copyable<OAuth
this.grantTypes = grantTypes;
}
/*
* (non-Javadoc)
* @see org.springframework.security.oauth2.provider.ClientDetails#
* getRegisteredRedirectUri()
*/
@Override
@JsonView(JsonViews.Protected.class)
public Set<String> getRegisteredRedirectUri() {
return redirectUris;
......@@ -511,12 +483,6 @@ public class OAuthClient extends AclSid implements ClientDetails, Copyable<OAuth
}
/*
* (non-Javadoc)
* @see
* org.springframework.security.oauth2.provider.ClientDetails#getAuthorities()
*/
@Override
@JsonView(JsonViews.Protected.class)
@JsonDeserialize(contentUsing = GrantedAuthorityDeserializer.class)
public Collection<GrantedAuthority> getAuthorities() {
......@@ -535,12 +501,6 @@ public class OAuthClient extends AclSid implements ClientDetails, Copyable<OAuth
return authorities;
}
/*
* (non-Javadoc)
* @see org.springframework.security.oauth2.provider.ClientDetails#
* getAccessTokenValiditySeconds()
*/
@Override
public Integer getAccessTokenValiditySeconds() {
return accessTokenValidity;
}
......@@ -563,12 +523,6 @@ public class OAuthClient extends AclSid implements ClientDetails, Copyable<OAuth
this.accessTokenValidity = accessTokenValidity;
}
/*
* (non-Javadoc)
* @see org.springframework.security.oauth2.provider.ClientDetails#
* getRefreshTokenValiditySeconds()
*/
@Override
public Integer getRefreshTokenValiditySeconds() {
return refreshTokenValidity;
}
......@@ -591,23 +545,10 @@ public class OAuthClient extends AclSid implements ClientDetails, Copyable<OAuth
this.refreshTokenValidity = refreshTokenValidity;
}
/*
* (non-Javadoc)
* @see
* org.springframework.security.oauth2.provider.ClientDetails#isAutoApprove(java
* .lang.String)
*/
@Override
public boolean isAutoApprove(final String scope) {
return autoApprove || autoApproveScopes.contains(scope);
}
/*
* (non-Javadoc)
* @see org.springframework.security.oauth2.provider.ClientDetails#
* getAdditionalInformation()
*/
@Override
public Map<String, Object> getAdditionalInformation() {
return additionalInformation;
}
......
......@@ -20,12 +20,12 @@ import java.util.List;
import org.genesys.blocks.oauth.model.OAuthClient;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.security.oauth2.provider.ClientDetailsService;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
/**
* The Interface OAuthClientDetailsService.
*/
public interface OAuthClientDetailsService extends ClientDetailsService {
public interface OAuthClientService extends RegisteredClientRepository {
/**
* List client details.
......@@ -128,4 +128,6 @@ public interface OAuthClientDetailsService extends ClientDetailsService {
*/
boolean isOriginRegistered(String origin);
OAuthClient loadClientByClientId(String clientId);
}
......@@ -16,12 +16,18 @@
package org.genesys.blocks.oauth.service;
import java.net.URL;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.persistence.EntityNotFoundException;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.lang3.StringUtils;
import org.genesys.blocks.oauth.model.OAuthClient;
......@@ -44,9 +50,14 @@ import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.provider.ClientDetails;
import org.springframework.security.oauth2.provider.ClientRegistrationException;
import org.springframework.security.oauth2.provider.NoSuchClientException;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2TokenFormat;
import org.springframework.security.oauth2.core.oidc.OidcScopes;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient.Builder;
import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
......@@ -57,7 +68,7 @@ import com.querydsl.core.types.Predicate;
*/
@Service
@Transactional(readOnly = true)
public class OAuthServiceImpl implements OAuthClientDetailsService, InitializingBean {
public class OAuthServiceImpl implements OAuthClientService, InitializingBean {
/** The Constant LOG. */
private static final Logger LOG = LoggerFactory.getLogger(OAuthServiceImpl.class);
......@@ -94,17 +105,11 @@ public class OAuthServiceImpl implements OAuthClientDetailsService, Initializing
}
}
/*
* (non-Javadoc)
* @see org.springframework.security.oauth2.provider.ClientDetailsService#
* loadClientByClientId(java.lang.String)
*/
@Override
@Cacheable(cacheNames = { "oauthclient" }, key = "#clientId", unless = "#result == null")
public ClientDetails loadClientByClientId(final String clientId) throws ClientRegistrationException {
public OAuthClient loadClientByClientId(final String clientId) {
final OAuthClient client = getClient(clientId);
if (client == null) {
throw new NoSuchClientException(clientId);
return client;
}
client.getRoles().remove(OAuthRole.EVERYONE);
client.setRuntimeAuthorities(OAuthRole.EVERYONE);
......@@ -156,7 +161,7 @@ public class OAuthServiceImpl implements OAuthClientDetailsService, Initializing
*/
@Override
@Transactional
@CacheEvict(cacheNames = { "oauthclient" }, key = "#client.clientId", condition = "#client != null")
@CacheEvict(cacheNames = { "oauthclient", "oauthclient.registered" }, key = "#client.clientId", condition = "#client != null")
public OAuthClient removeClient(final OAuthClient client) {
oauthClientRepository.delete(client);
return client;
......@@ -190,7 +195,7 @@ public class OAuthServiceImpl implements OAuthClientDetailsService, Initializing
*/
@Override
@Transactional
@CacheEvict(cacheNames = { "oauthclient" }, key = "#updates.clientId", condition = "#updates != null")
@CacheEvict(cacheNames = { "oauthclient", "oauthclient.registered" }, key = "#updates.clientId", condition = "#updates != null")
public OAuthClient updateClient(final long id, final int version, final OAuthClient updates) {
OAuthClient client = oauthClientRepository.findByIdAndVersion(id, version);
client.apply(updates);
......@@ -200,7 +205,7 @@ public class OAuthServiceImpl implements OAuthClientDetailsService, Initializing
@Override
@Transactional
@CacheEvict(cacheNames = { "oauthclient" }, key = "#sourceId", condition = "#sourceId != null && #targetId != null")
@CacheEvict(cacheNames = { "oauthclient", "oauthclient.registered" }, key = "#sourceId", condition = "#sourceId != null && #targetId != null")
public OAuthClient updateClientId(String sourceId, String targetId) {
OAuthClient client = getClient(sourceId);
client.setClientId(targetId);
......@@ -291,4 +296,108 @@ public class OAuthServiceImpl implements OAuthClientDetailsService, Initializing
return found.get();
}
@Override
public void save(RegisteredClient registeredClient) {
LOG.warn("Saving client: {}", registeredClient);
throw new RuntimeException("Not implemented");
}
@Override
public RegisteredClient findById(String registrationId) {
LOG.warn("Loading OAuth registered client by registrationId {}", registrationId);
var client = oauthClientRepository.findById(Long.parseLong(registrationId)).orElseThrow(() -> new EntityNotFoundException("No such client."));
if (client == null) {
return null;
}
return convertToRegisteredClient(client);
}
@Override
@Cacheable(cacheNames = "oauthclient.registered", key = "#clientId", unless = "#result == null")
public RegisteredClient findByClientId(String clientId) {
LOG.warn("Loading OAuth registered client by clientId {}", clientId);
var client = loadClientByClientId(clientId);
if (client == null) {
return null;
}
return convertToRegisteredClient(client);
}
private RegisteredClient convertToRegisteredClient(OAuthClient client) {
Builder registeredClient = RegisteredClient.withId(UUID.randomUUID().toString());
registeredClient
.clientId(client.getClientId())
.clientSecret(client.getClientSecret())
.clientIdIssuedAt(client.getCreatedDate())
.clientName(client.getTitle())
;
if (StringUtils.isBlank(client.getClientSecret())) {
registeredClient.clientAuthenticationMethod(ClientAuthenticationMethod.NONE);
} else {
registeredClient.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST);
registeredClient.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC);
registeredClient.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT);
}
// Grant types
registeredClient
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
;
if (CollectionUtils.isNotEmpty(client.getAuthorizedGrantTypes())) {
client.getAuthorizedGrantTypes().stream().map(AuthorizationGrantType::new).forEach(registeredClient::authorizationGrantType);
}
// Redirect URIs
registeredClient
.redirectUri("http://web.local:8080/login/oauth2/code/local")
;
if (CollectionUtils.isNotEmpty(client.getRegisteredRedirectUri())) {
client.getRegisteredRedirectUri().forEach(registeredClient::redirectUri);
}
// Scopes
registeredClient
.scope(OidcScopes.OPENID)
.scope("profile")
.scope("email")
;
// Apply scopes
if (CollectionUtils.isNotEmpty(client.getScope())) {
client.getScope().forEach(registeredClient::scope);
}
var token = TokenSettings.builder();
token
.accessTokenFormat(OAuth2TokenFormat.SELF_CONTAINED)
// .accessTokenFormat(OAuth2TokenFormat.REFERENCE) // Spring only supports Opaque or JWTs
.accessTokenTimeToLive(Duration.of(
// 3 days
Optional.ofNullable(client.getAccessTokenValidity()).orElse(60 * 60 * 24 * 3).longValue(), ChronoUnit.SECONDS))
.refreshTokenTimeToLive(
Duration.of(
// 30 days
Optional.ofNullable(client.getRefreshTokenValidity()).orElse(60 * 60 * 24 * 30).longValue(), ChronoUnit.SECONDS))
.reuseRefreshTokens(true);
registeredClient.tokenSettings(token.build());
// Settings
var settings = ClientSettings.builder();
settings.requireAuthorizationConsent(false);
settings.setting("randomSetting", "Is here"); // Random
// Copy additional settings
if (MapUtils.isNotEmpty(client.getAdditionalInformation())) {
client.getAdditionalInformation().entrySet().forEach(setting -> settings.setting(setting.getKey(), setting.getValue()));
}
registeredClient
.clientSettings(settings.build())
.build();
return registeredClient.build();
}
}
......@@ -31,7 +31,8 @@ import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken;
import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken;
/**
......@@ -262,14 +263,15 @@ public class SecurityContextUtil {
*/
public static String getOAuthClientId() {
final Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication instanceof OAuth2Authentication) {
OAuth2Authentication oauthAuth = (OAuth2Authentication) authentication;
if (authentication instanceof AbstractOAuth2TokenAuthenticationToken<?>) {
var oauthAuth = (AbstractOAuth2TokenAuthenticationToken<?>) authentication;
LOG.debug("OAuth authentication: {}", oauthAuth);
String clientId = oauthAuth.getOAuth2Request().getClientId();
LOG.debug("OAuth clientId: {}", clientId);
return clientId;
var token = (Jwt) oauthAuth.getToken();
var aud = token.getClaimAsString("aud");
return aud;
} else {
// No OAuth authentication
LOG.warn("TODO {} {}", authentication.getClass(), authentication);
return null;
}
}
......
......@@ -27,12 +27,12 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.genesys.blocks.oauth.model.OAuthClient;
import org.genesys.blocks.oauth.service.OAuthClientService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.provider.ClientDetailsService;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken;
import org.springframework.web.filter.OncePerRequestFilter;
import com.google.common.cache.CacheBuilder;
......@@ -46,16 +46,18 @@ import com.google.common.cache.LoadingCache;
public class OAuthClientOriginCheckFilter extends OncePerRequestFilter {
@Autowired
@Qualifier("oauthService")
private ClientDetailsService clientDetailsService;
private OAuthClientService clientDetailsService;
private LoadingCache<String, Set<String>> clientOriginsCache = CacheBuilder.newBuilder().maximumSize(100).expireAfterWrite(10, TimeUnit.MINUTES).build(
new CacheLoader<String, Set<String>>() {
public Set<String> load(String clientId) {
public Set<String> load(String clientId) throws Exception {
if (logger.isInfoEnabled()) {
logger.info("Loading allowed origins for client: " + clientId);
}
OAuthClient clientDetails = (OAuthClient) clientDetailsService.loadClientByClientId(clientId);
if (clientDetails == null || clientDetails.getAllowedOrigins() == null) {
throw new Exception("No such client");
}
return clientDetails.getAllowedOrigins();
}
});
......@@ -63,8 +65,9 @@ public class OAuthClientOriginCheckFilter extends OncePerRequestFilter {
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
if (authentication != null && authentication instanceof OAuth2Authentication) {
if (!checkValidOrigin(request, (OAuth2Authentication) authentication)) {
if (authentication != null && authentication instanceof AbstractOAuth2TokenAuthenticationToken<?>) {
var oauthAuth = (AbstractOAuth2TokenAuthenticationToken<?>) authentication;
if (!checkValidOrigin(request, oauthAuth)) {
response.sendError(403, "Request origin not valid");
return;
}
......@@ -76,7 +79,9 @@ public class OAuthClientOriginCheckFilter extends OncePerRequestFilter {
filterChain.doFilter(request, response);
}
private boolean checkValidOrigin(HttpServletRequest request, OAuth2Authentication authAuth) {
private boolean checkValidOrigin(HttpServletRequest request, AbstractOAuth2TokenAuthenticationToken<?> authAuth) {
var token = (Jwt) authAuth.getToken();
if (logger.isTraceEnabled()) {
logger.trace(request.getRequestURI());
for (String headerName : Collections.list(request.getHeaderNames())) {
......@@ -86,13 +91,13 @@ public class OAuthClientOriginCheckFilter extends OncePerRequestFilter {
String reqOrigin = request.getHeader("Origin");
String reqReferrer = request.getHeader("Referer"); // GET requests don't carry Origin?
if (authAuth.getOAuth2Request() != null) {
if (token != null) {
boolean isGet = request.getMethod().equalsIgnoreCase("get");
String clientId = authAuth.getOAuth2Request().getClientId();
var clientId = token.getClaimAsString("aud");
try {
Set<String> allowedOrigins = clientOriginsCache.get(clientId);
if (!allowedOrigins.isEmpty()) {
if (reqOrigin == null && reqReferrer == null) {
if (logger.isInfoEnabled()) {
......
/*
* Copyright 2018 Global Crop Diversity Trust
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.genesys.blocks.oauth;
import static org.hamcrest.CoreMatchers.*;
import static org.hamcrest.MatcherAssert.*;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*;
import org.genesys.blocks.oauth.model.OAuthClient;
import org.genesys.blocks.oauth.model.OAuthRole;
import org.genesys.blocks.security.rest.AbstractRestTest;
import org.junit.Before;
import org.junit.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import org.springframework.util.Base64Utils;
import org.springframework.web.context.WebApplicationContext;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
/**
* The Class OAuth2GrantTypeTest.
*
* @author Maxym Borodenko
*/
public class OAuth2GrantTypeTest extends AbstractRestTest {
@Autowired
private WebApplicationContext context;
private MockMvc mockMvc;
private static final String DEFAULT_CLIENT_ID = "my-trusted-client";
private static final String DEFAULT_CLIENT_SECRET = "my-secret";
private static final ObjectMapper objectMapper;
/** The password encoder. */
@Autowired
public PasswordEncoder passwordEncoder;
static {
objectMapper = new ObjectMapper();
objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);