/*
 * Copyright 2017 Global Crop Diversity Trust
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.genesys.blocks.oauth.service;

import java.io.UnsupportedEncodingException;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import com.querydsl.core.types.Predicate;

import org.apache.commons.lang3.RandomStringUtils;
import org.apache.commons.lang3.StringUtils;
import org.genesys.blocks.oauth.model.AccessToken;
import org.genesys.blocks.oauth.model.OAuthClient;
import org.genesys.blocks.oauth.model.OAuthRole;
import org.genesys.blocks.oauth.model.QOAuthClient;
import org.genesys.blocks.oauth.model.RefreshToken;
import org.genesys.blocks.oauth.persistence.AccessTokenRepository;
import org.genesys.blocks.oauth.persistence.OAuthClientRepository;
import org.genesys.blocks.oauth.persistence.RefreshTokenRepository;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.cache.annotation.CacheEvict;
import org.springframework.cache.annotation.Cacheable;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Sort;
import org.springframework.security.oauth2.common.DefaultExpiringOAuth2RefreshToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2RefreshToken;
import org.springframework.security.oauth2.common.util.SerializationUtils;
import org.springframework.security.oauth2.provider.ClientDetails;
import org.springframework.security.oauth2.provider.ClientRegistrationException;
import org.springframework.security.oauth2.provider.NoSuchClientException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.AuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.DefaultAuthenticationKeyGenerator;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

/**
 * OAuth2 client-details service and database-backed token store.
 * <p>
 * Access and refresh tokens are persisted through {@link AccessTokenRepository} and
 * {@link RefreshTokenRepository}, keyed by the MD5 hex digest of the token value (see
 * {@link #extractTokenKey(String)}). The token objects and their {@link OAuth2Authentication}s
 * are stored as blobs produced by Spring Security OAuth2's {@link SerializationUtils}. Client
 * registrations are managed through {@link OAuthClientRepository}.
 * <p>
 * Transactions are read-only by default; every method that writes overrides this with a plain
 * {@code @Transactional}.
 */
@Service
@Transactional(readOnly = true)
public class OAuthServiceImpl implements OAuthClientDetailsService, OAuthTokenStoreService {

	/** Logger. */
	private static final Logger LOG = LoggerFactory.getLogger(OAuthServiceImpl.class);

	/**
	 * Cryptographically strong RNG for generated client ids and secrets.
	 * {@link RandomStringUtils#randomAlphanumeric(int)} draws from {@code java.util.Random},
	 * whose output is predictable and therefore unsuitable for credentials.
	 */
	private static final SecureRandom SECURE_RANDOM = new SecureRandom();

	/** Host name appended to generated client ids (property {@code host.name}). */
	@Value("${host.name}")
	private String hostname;

	/** The oauth client repository. */
	@Autowired
	private OAuthClientRepository oauthClientRepository;

	/** The refresh token repository. */
	@Autowired
	private RefreshTokenRepository refreshTokenRepository;

	/** The access token repository. */
	@Autowired
	private AccessTokenRepository accessTokenRepository;

	/** Derives the authentication id stored alongside each access token. */
	private AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();

	/**
	 * Sets the authentication key generator.
	 *
	 * @param authenticationKeyGenerator the new authentication key generator
	 */
	public void setAuthenticationKeyGenerator(final AuthenticationKeyGenerator authenticationKeyGenerator) {
		this.authenticationKeyGenerator = authenticationKeyGenerator;
	}

	/**
	 * Loads a client registration by client id, with its roles initialized.
	 *
	 * @param clientId the client id
	 * @return the client, never {@code null}
	 * @throws NoSuchClientException if no client with that id exists
	 */
	@Override
	@Cacheable(cacheNames = { "oauthclient" }, key = "#clientId", unless = "#result == null")
	public ClientDetails loadClientByClientId(final String clientId) throws ClientRegistrationException {
		final OAuthClient client = oauthClientRepository.findByClientId(clientId);
		if (client == null) {
			throw new NoSuchClientException(clientId);
		}
		return lazyLoad(client);
	}

	/**
	 * Touches {@code client.getRoles()} so the collection is initialized while the transaction is
	 * still open (presumably a lazily-loaded JPA association).
	 *
	 * @param client the client, may be {@code null}
	 * @return the same client instance (or {@code null})
	 */
	private OAuthClient lazyLoad(final OAuthClient client) {
		if (client != null) {
			client.getRoles().size();
		}
		return client;
	}

	/**
	 * Persists an access token.
	 * <p>
	 * Replaces any existing row for the same token value, and any token previously stored for the
	 * same authentication.
	 *
	 * @param token the access token
	 * @param authentication the authentication the token was issued for
	 */
	@Override
	@Transactional
	@CacheEvict(cacheNames = { "oauthaccesstoken", "oauthaccesstokenauth" }, allEntries = true)
	public void storeAccessToken(final OAuth2AccessToken token, final OAuth2Authentication authentication) {
		final String tokenKey = extractTokenKey(token.getValue());
		final String authenticationId = authenticationKeyGenerator.extractKey(authentication);
		final OAuth2RefreshToken refreshToken = token.getRefreshToken();

		// Check the row directly: going through readAccessToken() would deserialize the blob and
		// log a "not found" message for every brand-new token.
		if (accessTokenRepository.findOne(tokenKey) != null) {
			removeAccessToken(token.getValue());
		}

		// Drop any other token stored for the same authentication
		accessTokenRepository.deleteByAuthenticationId(authenticationId);

		final AccessToken storedToken = new AccessToken();
		storedToken.setTokenId(tokenKey);
		storedToken.setToken(serializeAccessToken(token));
		storedToken.setAuthenticationId(authenticationId);
		storedToken.setUsername(authentication.isClientOnly() ? null : authentication.getName());
		storedToken.setClientId(authentication.getOAuth2Request().getClientId());
		storedToken.setAuthentication(serializeAuthentication(authentication));
		storedToken.setRefreshToken(refreshToken == null ? null : extractTokenKey(refreshToken.getValue()));
		storedToken.setExpiration(token.getExpiration());

		accessTokenRepository.save(storedToken);
	}

	/**
	 * Reads an access token by its value.
	 *
	 * @param tokenValue the raw token value
	 * @return the token, or {@code null} if not stored or not deserializable
	 */
	@Override
	@Cacheable(cacheNames = { "oauthaccesstoken" }, key = "#tokenValue", unless = "#result == null")
	public OAuth2AccessToken readAccessToken(final String tokenValue) {
		final String tokenKey = extractTokenKey(tokenValue);
		LOG.trace("Reading access token key={}", tokenKey);

		final AccessToken storedToken = accessTokenRepository.findOne(tokenKey);
		if (storedToken == null) {
			// Log only the digest: the raw value is a live bearer credential
			LOG.info("Failed to find access token for key={}", tokenKey);
			return null;
		}
		return deserializeAccessToken(storedToken.getToken());
	}

	/**
	 * Removes an access token.
	 *
	 * @param token the token to remove
	 */
	@Override
	@Transactional
	@CacheEvict(cacheNames = { "oauthaccesstoken", "oauthaccesstokenauth" }, allEntries = true)
	public void removeAccessToken(final OAuth2AccessToken token) {
		removeAccessToken(token.getValue());
	}

	/**
	 * Removes the access token with the given value.
	 *
	 * @param tokenValue the raw token value
	 */
	@Override
	@Transactional
	@CacheEvict(cacheNames = { "oauthaccesstoken", "oauthaccesstokenauth" }, allEntries = true)
	public void removeAccessToken(final String tokenValue) {
		accessTokenRepository.delete(extractTokenKey(tokenValue));
	}

	/**
	 * Reads the authentication stored with an access token.
	 *
	 * @param token the access token
	 * @return the authentication, or {@code null} if not found
	 */
	@Override
	@Cacheable(cacheNames = { "oauthaccesstokenauth" }, key = "#token.value", unless = "#result == null")
	public OAuth2Authentication readAuthentication(final OAuth2AccessToken token) {
		return readAuthentication(token.getValue());
	}

	/**
	 * Reads the authentication stored with the access token of the given value.
	 *
	 * @param tokenValue the raw access token value
	 * @return the authentication, or {@code null} if not stored or not deserializable
	 */
	@Override
	@Cacheable(cacheNames = { "oauthaccesstokenauth" }, key = "#tokenValue", unless = "#result == null")
	public OAuth2Authentication readAuthentication(final String tokenValue) {
		final String tokenKey = extractTokenKey(tokenValue);
		LOG.trace("Reading authentication for access token key={}", tokenKey);

		final AccessToken storedToken = accessTokenRepository.findOne(tokenKey);
		if (storedToken == null) {
			LOG.info("Failed to find access token for key={}", tokenKey);
			return null;
		}
		return deserializeAuthentication(storedToken.getAuthentication());
	}

	/**
	 * Persists a refresh token together with its authentication.
	 *
	 * @param refreshToken the refresh token
	 * @param authentication the authentication the token was issued for
	 */
	@Override
	@Transactional
	@CacheEvict(cacheNames = { "oauthaccesstoken", "oauthaccesstokenauth" }, allEntries = true)
	public void storeRefreshToken(final OAuth2RefreshToken refreshToken, final OAuth2Authentication authentication) {
		final RefreshToken storedToken = new RefreshToken();
		storedToken.setClientId(authentication.getOAuth2Request().getClientId());
		storedToken.setTokenId(extractTokenKey(refreshToken.getValue()));
		storedToken.setToken(serializeRefreshToken(refreshToken));
		storedToken.setAuthentication(serializeAuthentication(authentication));
		storedToken.setUsername(authentication.isClientOnly() ? null : authentication.getUserAuthentication().getName());
		if (refreshToken instanceof DefaultExpiringOAuth2RefreshToken) {
			// Only expiring refresh tokens carry an expiration date
			storedToken.setExpiration(((DefaultExpiringOAuth2RefreshToken) refreshToken).getExpiration());
		}

		refreshTokenRepository.save(storedToken);
	}

	/**
	 * Reads a refresh token by its value.
	 *
	 * @param token the raw refresh token value
	 * @return the refresh token, or {@code null} if not stored or not deserializable
	 */
	@Override
	public OAuth2RefreshToken readRefreshToken(final String token) {
		final String tokenKey = extractTokenKey(token);
		final RefreshToken storedToken = refreshTokenRepository.findOne(tokenKey);
		if (storedToken == null) {
			// Log only the digest: the raw value is a long-lived credential
			LOG.info("Failed to find refresh token for key={}", tokenKey);
			return null;
		}
		return deserializeRefreshToken(storedToken.getToken());
	}

	/**
	 * Removes a refresh token.
	 *
	 * @param token the refresh token
	 */
	@Override
	@Transactional
	@CacheEvict(cacheNames = { "oauthaccesstoken", "oauthaccesstokenauth" }, allEntries = true)
	public void removeRefreshToken(final OAuth2RefreshToken token) {
		removeRefreshToken(token.getValue());
	}

	/**
	 * Removes the refresh token with the given value.
	 *
	 * @param token the raw refresh token value
	 */
	@Override
	@Transactional
	@CacheEvict(cacheNames = { "oauthaccesstoken", "oauthaccesstokenauth" }, allEntries = true)
	public void removeRefreshToken(final String token) {
		refreshTokenRepository.delete(extractTokenKey(token));
	}

	/**
	 * Reads the authentication stored with a refresh token.
	 *
	 * @param token the refresh token
	 * @return the authentication, or {@code null} if not found
	 */
	@Override
	public OAuth2Authentication readAuthenticationForRefreshToken(final OAuth2RefreshToken token) {
		return readAuthenticationForRefreshToken(token.getValue());
	}

	/**
	 * Read authentication for refresh token.
	 *
	 * @param token the raw refresh token value
	 * @return the authentication, or {@code null} if not stored or not deserializable
	 */
	public OAuth2Authentication readAuthenticationForRefreshToken(final String token) {
		final String tokenKey = extractTokenKey(token);
		final RefreshToken storedToken = refreshTokenRepository.findOne(tokenKey);
		if (storedToken == null) {
			LOG.info("Failed to find refresh token for key={}", tokenKey);
			return null;
		}
		return deserializeAuthentication(storedToken.getAuthentication());
	}

	/**
	 * Removes the access token(s) issued with the given refresh token.
	 *
	 * @param refreshToken the refresh token
	 */
	@Override
	@Transactional
	@CacheEvict(cacheNames = { "oauthaccesstoken", "oauthaccesstokenauth" }, allEntries = true)
	public void removeAccessTokenUsingRefreshToken(final OAuth2RefreshToken refreshToken) {
		removeAccessTokenUsingRefreshToken(refreshToken.getValue());
	}

	/**
	 * Removes the access token(s) issued with the given refresh token.
	 *
	 * @param refreshToken the raw refresh token value
	 */
	@Transactional
	@CacheEvict(cacheNames = { "oauthaccesstoken", "oauthaccesstokenauth" }, allEntries = true)
	public void removeAccessTokenUsingRefreshToken(final String refreshToken) {
		accessTokenRepository.deleteByRefreshToken(extractTokenKey(refreshToken));
	}

	/**
	 * Returns the access token stored for an authentication, if any.
	 * <p>
	 * Not read-only: if the stored authentication no longer maps to the same key, the token is
	 * re-stored under the current authentication. Those nested calls are self-invocations, so
	 * their own {@code @Transactional} is bypassed and this method's transaction must allow writes.
	 * NOTE(review): for the same reason, their {@code @CacheEvict} does not fire on that path.
	 *
	 * @param authentication the authentication
	 * @return the access token, or {@code null} if none is stored or it cannot be deserialized
	 */
	@Override
	@Transactional
	public OAuth2AccessToken getAccessToken(final OAuth2Authentication authentication) {
		final String key = authenticationKeyGenerator.extractKey(authentication);
		LOG.trace("auth={} key={}", authentication, key);

		final AccessToken storedToken = accessTokenRepository.findByAuthenticationId(key);
		if (storedToken == null) {
			LOG.debug("Failed to find access token for authentication={}", authentication);
			return null;
		}

		final OAuth2AccessToken accessToken = deserializeAccessToken(storedToken.getToken());
		if (accessToken == null) {
			// Unreadable blob (already logged by deserializeAccessToken); treat as absent
			return null;
		}

		final OAuth2Authentication storedAuthentication = readAuthentication(accessToken.getValue());
		if (storedAuthentication != null && !key.equals(authenticationKeyGenerator.extractKey(storedAuthentication))) {
			// Keep the store consistent (maybe the same user is represented by this
			// authentication but the details have changed)
			removeAccessToken(accessToken.getValue());
			storeAccessToken(accessToken, authentication);
		}
		return accessToken;
	}

	/**
	 * Finds access tokens issued to a client for a user.
	 *
	 * @param clientId the client id
	 * @param username the user name
	 * @return the deserializable tokens; never {@code null}, never containing {@code null}
	 */
	@Override
	public Collection<OAuth2AccessToken> findTokensByClientIdAndUserName(final String clientId, final String username) {
		return deserializeAccessTokens(accessTokenRepository.findByClientIdAndUsername(clientId, username));
	}

	/**
	 * Finds access tokens issued to a client.
	 *
	 * @param clientId the client id
	 * @return the deserializable tokens; never {@code null}, never containing {@code null}
	 */
	@Override
	public Collection<OAuth2AccessToken> findTokensByClientId(final String clientId) {
		return deserializeAccessTokens(accessTokenRepository.findByClientId(clientId));
	}

	/**
	 * Deserializes stored access tokens, skipping null rows and blobs that fail to deserialize.
	 *
	 * @param storedTokens the stored rows
	 * @return the deserialized tokens
	 */
	private List<OAuth2AccessToken> deserializeAccessTokens(final Collection<? extends AccessToken> storedTokens) {
		return storedTokens.stream()
			.filter(Objects::nonNull)
			.map(storedToken -> deserializeAccessToken(storedToken.getToken()))
			.filter(Objects::nonNull)
			.collect(Collectors.toList());
	}

	/**
	 * Computes the database key for a token value: the lowercase, zero-padded 32-character hex
	 * MD5 digest of its UTF-8 bytes.
	 *
	 * @param value the raw token value, may be {@code null}
	 * @return the key, or {@code null} if {@code value} is {@code null}
	 */
	protected String extractTokenKey(final String value) {
		if (value == null) {
			return null;
		}
		final MessageDigest digest;
		try {
			digest = MessageDigest.getInstance("MD5");
		} catch (final NoSuchAlgorithmException e) {
			throw new IllegalStateException("MD5 algorithm not available.  Fatal (should be in the JDK).", e);
		}
		final byte[] bytes = digest.digest(value.getBytes(StandardCharsets.UTF_8));
		return String.format("%032x", new BigInteger(1, bytes));
	}

	/**
	 * Serialize access token.
	 *
	 * @param token the token
	 * @return the serialized bytes, or {@code null} if serialization failed
	 */
	protected byte[] serializeAccessToken(final OAuth2AccessToken token) {
		return serializeQuietly(token, "access token");
	}

	/**
	 * Serialize refresh token.
	 *
	 * @param token the token
	 * @return the serialized bytes, or {@code null} if serialization failed
	 */
	protected byte[] serializeRefreshToken(final OAuth2RefreshToken token) {
		return serializeQuietly(token, "refresh token");
	}

	/**
	 * Serialize authentication.
	 *
	 * @param authentication the authentication
	 * @return the serialized bytes, or {@code null} if serialization failed
	 */
	protected byte[] serializeAuthentication(final OAuth2Authentication authentication) {
		return serializeQuietly(authentication, "authentication");
	}

	/**
	 * Deserialize access token.
	 *
	 * @param token the serialized bytes
	 * @return the access token, or {@code null} if absent or not deserializable
	 */
	protected OAuth2AccessToken deserializeAccessToken(final byte[] token) {
		return deserializeQuietly(token, OAuth2AccessToken.class);
	}

	/**
	 * Deserialize refresh token.
	 *
	 * @param token the serialized bytes
	 * @return the refresh token, or {@code null} if absent or not deserializable
	 */
	protected OAuth2RefreshToken deserializeRefreshToken(final byte[] token) {
		return deserializeQuietly(token, OAuth2RefreshToken.class);
	}

	/**
	 * Deserialize authentication.
	 *
	 * @param authentication the serialized bytes
	 * @return the authentication, or {@code null} if absent or not deserializable
	 */
	protected OAuth2Authentication deserializeAuthentication(final byte[] authentication) {
		return deserializeQuietly(authentication, OAuth2Authentication.class);
	}

	/**
	 * Serializes a value, logging (with cause) and returning {@code null} on failure.
	 *
	 * @param value the value to serialize
	 * @param description what is being serialized, for the log message
	 * @return the bytes, or {@code null} on failure
	 */
	private static byte[] serializeQuietly(final Object value, final String description) {
		try {
			return SerializationUtils.serialize(value);
		} catch (final RuntimeException e) {
			LOG.warn("Failed to serialize {}. Returning null.", description, e);
			return null;
		}
	}

	/**
	 * Deserializes bytes into the expected type, logging (with cause) and returning {@code null}
	 * on failure. The type check happens inside the try block, so a blob of the wrong type is
	 * also reported as a failure rather than escaping as a {@link ClassCastException}.
	 * <p>
	 * NOTE(review): this is native deserialization of database content; it is only safe as long
	 * as the token tables are written exclusively by this service.
	 *
	 * @param data the serialized bytes, may be {@code null}
	 * @param type the expected type
	 * @param <T> the expected type
	 * @return the value, or {@code null} if {@code data} is {@code null} or deserialization fails
	 */
	private static <T> T deserializeQuietly(final byte[] data, final Class<T> type) {
		if (data == null) {
			return null;
		}
		try {
			final Object value = SerializationUtils.deserialize(data);
			return type.cast(value);
		} catch (final RuntimeException e) {
			LOG.warn("Failed to deserialize {}. Returning null.", type.getSimpleName(), e);
			return null;
		}
	}

	/**
	 * Lists all registered clients ordered by client id.
	 *
	 * @return the clients
	 */
	@Override
	public List<OAuthClient> listClientDetails() {
		return oauthClientRepository.findAll(new Sort("clientId"));
	}

	/**
	 * Lists the stored access-token rows of a client.
	 *
	 * @param clientId the client id
	 * @return the stored access tokens
	 */
	@Override
	public List<AccessToken> findAccessTokensByClientId(final String clientId) {
		return accessTokenRepository.findByClientId(clientId);
	}

	/**
	 * Lists the stored refresh-token rows of a client.
	 *
	 * @param clientId the client id
	 * @return the stored refresh tokens
	 */
	@Override
	public List<RefreshToken> findRefreshTokensByClientId(final String clientId) {
		return refreshTokenRepository.findByClientId(clientId);
	}

	/**
	 * Lists the stored access-token rows whose username equals the given UUID.
	 *
	 * @param uuid the user UUID (stored as the token's username)
	 * @return the stored access tokens
	 */
	@Override
	public List<AccessToken> findTokensByUserUuid(final String uuid) {
		return accessTokenRepository.findByUsername(uuid);
	}

	/**
	 * Returns a client with its roles initialized.
	 *
	 * @param clientId the client id
	 * @return the client, or {@code null} if not found
	 */
	@Override
	public OAuthClient getClient(final String clientId) {
		return lazyLoad(oauthClientRepository.findByClientId(clientId));
	}

	/**
	 * Deletes a client registration.
	 *
	 * @param client the client to delete
	 * @return the same client instance
	 */
	@Override
	@Transactional
	public OAuthClient removeClient(final OAuthClient client) {
		oauthClientRepository.delete(client);
		return client;
	}

	/**
	 * Registers a new client with generated credentials, {@code read}/{@code write} scopes, the
	 * {@code authorization_code} and {@code refresh_token} grant types and the
	 * {@link OAuthRole#CLIENT} role.
	 *
	 * @param title the title
	 * @param description the description
	 * @param redirectUris the redirect URIs; blank is stored as {@code null}
	 * @param accessTokenValidity the access token validity
	 * @param refreshTokenValidity the refresh token validity
	 * @return the saved client
	 */
	@Override
	@Transactional
	public OAuthClient addClient(final String title, final String description, final String redirectUris, final Integer accessTokenValidity, final Integer refreshTokenValidity) {
		final OAuthClient client = new OAuthClient();
		client.setTitle(title);
		client.setDescription(description);
		client.setRedirect(StringUtils.defaultIfBlank(redirectUris, null));
		client.setAccessTokenValidity(accessTokenValidity);
		client.setRefreshTokenValidity(refreshTokenValidity);
		client.setClientId(generateClientId());
		client.setClientSecret(generateClientSecret());
		client.getScope().add("read");
		client.getScope().add("write");
		client.getAuthorizedGrantTypes().add("authorization_code");
		client.getAuthorizedGrantTypes().add("refresh_token");
		client.getRoles().add(OAuthRole.CLIENT);

		return oauthClientRepository.save(client);
	}

	/**
	 * Registers a new client copied from a template, with freshly generated credentials.
	 *
	 * @param client the template whose properties are applied to the new client
	 * @return the saved client with roles initialized
	 */
	@Override
	@Transactional
	public OAuthClient addClient(final OAuthClient client) {
		final OAuthClient newClient = new OAuthClient();
		newClient.apply(client);
		// Generated credentials always override whatever the template carried
		newClient.setClientId(generateClientId());
		newClient.setClientSecret(generateClientSecret());

		return lazyLoad(oauthClientRepository.save(newClient));
	}

	/**
	 * Applies updates to an existing client.
	 *
	 * @param id the client entity id
	 * @param version the expected entity version
	 * @param updates the properties to apply
	 * @return the saved client with roles initialized
	 * @throws NoSuchClientException if no client matches both {@code id} and {@code version}
	 *         (missing client or stale version)
	 */
	@Override
	@Transactional
	public OAuthClient updateClient(final long id, final int version, final OAuthClient updates) {
		final OAuthClient client = oauthClientRepository.findByIdAndVersion(id, version);
		if (client == null) {
			throw new NoSuchClientException("No OAuth client with id=" + id + " and version=" + version);
		}
		client.apply(updates);
		return lazyLoad(oauthClientRepository.save(client));
	}

	/**
	 * Returns up to 10 clients whose title starts with the given text (case-insensitive),
	 * ordered by title.
	 *
	 * @param title the title prefix; blank or shorter than 4 characters yields no results
	 * @return the matching clients
	 */
	@Override
	public List<OAuthClient> autocompleteClients(final String title) {
		if (StringUtils.isBlank(title) || title.length() < 4) {
			return Collections.emptyList();
		}

		LOG.debug("Autocomplete for={}", title);
		final Predicate predicate = QOAuthClient.oAuthClient.title.startsWithIgnoreCase(title);
		return oauthClientRepository.findAll(predicate, new PageRequest(0, 10, new Sort("title"))).getContent();
	}

	/**
	 * Generates a client id of the form {@code XXXXX.YYYYYYYYYYYYYYYYYYYY@hostname}.
	 *
	 * @return a new random client id
	 */
	private String generateClientId() {
		return secureRandomAlphanumeric(5) + "." + secureRandomAlphanumeric(20) + "@" + hostname;
	}

	/**
	 * Generates a 32-character alphanumeric client secret.
	 *
	 * @return a new random client secret
	 */
	private static String generateClientSecret() {
		return secureRandomAlphanumeric(32);
	}

	/**
	 * Same alphabet as {@link RandomStringUtils#randomAlphanumeric(int)}, but drawn from
	 * {@link #SECURE_RANDOM}.
	 *
	 * @param count the number of characters
	 * @return a random alphanumeric string
	 */
	private static String secureRandomAlphanumeric(final int count) {
		return RandomStringUtils.random(count, 0, 0, true, true, null, SECURE_RANDOM);
	}
}
