package LinkFuture.Core.Auth.LDAPHelper;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.naming.*;
import javax.naming.directory.*;
import javax.naming.ldap.InitialLdapContext;
import javax.naming.ldap.LdapContext;
import java.util.*;
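/*
   Trusting the LDAP server certificate for SSL connections (ldapIsSSL = true):
     1) Import the server certificate into the Java trust store:
        "%JAVA_HOME%/bin/keytool" -import -alias ldapCert -file ldap.crt -keystore "%JAVA_HOME%\jre\lib\security\cacerts"
     2) When prompted for the keystore password, note that the default password of cacerts is "changeit".
     3) If you import into the default Java keystore (%JAVA_HOME%\jre\lib\security\cacerts), you do not need to set
        ldapCertificatePath. Otherwise, set ldapCertificatePath to the path of the trust store you imported into.
*/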

/**
 * Static helpers for locating Windows domain controllers and for looking up and authenticating
 * users against an LDAP directory through JNDI. The searches use Active Directory attribute names
 * ({@code sAMAccountName}, {@code memberof}, {@code msDS-SourceObjectDN}).
 * <p>
 * All configuration is held in static fields. The setters are instance methods (presumably so the
 * class can be configured bean-style — confirm against callers), but they write the shared static
 * state, so every instance and every static caller sees the same configuration.
 * <p>
 * Created by Cyokin
 * on 5/12/2015.
 */
public class LDAPHelper {
    //region Properties
    final static Logger logger = LoggerFactory.getLogger(LDAPHelper.class);

    /** User search filter; JNDI substitutes {0} and escapes it per RFC 2254. */
    private static final String USER_FILTER = "(&(objectClass=user)(sAMAccountName={0}))";

    /** Principal (bind DN) of the service account used for directory searches. */
    static String ldapPrincipal;
    /** Credentials of the service account used for directory searches. */
    static String ldapPassword;
    /** Provider URL of the directory used by {@link #getUserInfo(String)}. */
    static String ldapUrl;
    /** Optional trust store path; see {@link #setLdapCertificatePath(String)}. */
    static String ldapCertificatePath;
    /** When true, connections set {@code Context.SECURITY_PROTOCOL} to "ssl". */
    static boolean ldapIsSSL;
    static String ldapFactory = "com.sun.jndi.ldap.LdapCtxFactory";
    static String ldapAuthenticationMethod = "Simple";
    // NOTE(review): this default looks like a property key rather than a DN — presumably it is
    // always overridden through setLdapSearchBase; confirm.
    static String ldapSearchBase = "ldap.search.base";

    public static String getLdapPrincipal() {
        return ldapPrincipal;
    }

    public void setLdapPrincipal(String ldapPrincipal) {
        LDAPHelper.ldapPrincipal = ldapPrincipal;
    }

    public static String getLdapPassword() {
        return ldapPassword;
    }

    public void setLdapPassword(String ldapPassword) {
        LDAPHelper.ldapPassword = ldapPassword;
    }

    public static String getLdapUrl() {
        return ldapUrl;
    }

    public void setLdapUrl(String ldapUrl) {
        LDAPHelper.ldapUrl = ldapUrl;
    }

    public static String getLdapCertificatePath() {
        return ldapCertificatePath;
    }

    /**
     * Stores the trust store path and, when it is non-empty, also sets the JVM-wide
     * {@code javax.net.ssl.trustStore} system property to it (affecting every SSL connection in
     * the process, not only LDAP ones).
     *
     * @param ldapCertificatePath trust store path; {@code null} or empty leaves the system property untouched
     */
    public void setLdapCertificatePath(String ldapCertificatePath) {
        LDAPHelper.ldapCertificatePath = ldapCertificatePath;
        if (ldapCertificatePath != null && !ldapCertificatePath.isEmpty()) {
            System.setProperty("javax.net.ssl.trustStore", ldapCertificatePath);
        }
    }

    public static boolean getLdapIsSSL() {
        return ldapIsSSL;
    }

    public void setLdapIsSSL(boolean ldapIsSSL) {
        LDAPHelper.ldapIsSSL = ldapIsSSL;
    }

    public static String getLdapFactory() {
        return ldapFactory;
    }

    public void setLdapFactory(String ldapFactory) {
        LDAPHelper.ldapFactory = ldapFactory;
    }

    public static String getLdapAuthenticationMethod() {
        return ldapAuthenticationMethod;
    }

    public void setLdapAuthenticationMethod(String ldapAuthenticationMethod) {
        LDAPHelper.ldapAuthenticationMethod = ldapAuthenticationMethod;
    }

    public static String getLdapSearchBase() {
        return ldapSearchBase;
    }

    public void setLdapSearchBase(String ldapSearchBase) {
        LDAPHelper.ldapSearchBase = ldapSearchBase;
    }

    //endregion

    /**
     * Looks up the LDAP servers of a Windows domain via the DNS SRV records under
     * {@code _ldap._tcp.<ldapDomain>}.
     *
     * @param ldapDomain DNS domain name, e.g. {@code company.com}
     * @return host names in the order returned by DNS (trailing root dot removed); empty when the
     *         name has no SRV attribute
     * @throws NamingException if the DNS lookup fails (including when the name does not exist)
     */
    public static List<String> findWindowsDomainController(String ldapDomain) throws NamingException {
        List<String> servers = new ArrayList<>();
        Hashtable<String, String> env = new Hashtable<>();
        env.put(Context.INITIAL_CONTEXT_FACTORY, "com.sun.jndi.dns.DnsContextFactory");
        env.put(Context.PROVIDER_URL, "dns:");
        DirContext ctx = new InitialDirContext(env);
        try {
            // that's how Windows domain controllers are registered in DNS
            Attributes attributes = ctx.getAttributes("_ldap._tcp." + ldapDomain, new String[]{"SRV"});
            Attribute srv = attributes.get("SRV");
            if (srv == null) {
                return servers;
            }
            for (int i = 0; i < srv.size(); i++) {
                String srvRecord = String.valueOf(srv.get(i)).trim();
                // each SRV record is "priority weight port target", e.g. "0 100 389 dc1.company.com."
                String[] fields = srvRecord.split("\\s+");
                if (fields.length < 4) {
                    logger.warn("Ignoring malformed SRV record '{}' for {}", srvRecord, ldapDomain);
                    continue;
                }
                String target = fields[3];
                if (target.endsWith(".")) {
                    target = target.substring(0, target.length() - 1);
                }
                servers.add(target);
            }
        } finally {
            ctx.close();
        }
        return servers;
    }

    /**
     * Returns an {@code ldap://} URL for the first domain controller found for the domain. No port
     * is included, so the JNDI provider's default port applies.
     *
     * @throws NamingException if the lookup fails or no controller is registered
     */
    private static String findLdapAuthServer(String ldapDomain) throws NamingException {
        List<String> ldapAuthServers = findWindowsDomainController(ldapDomain);
        if (ldapAuthServers.isEmpty()) {
            throw new NamingException("No domain controllers found for " + ldapDomain);
        }
        return "ldap://" + ldapAuthServers.get(0);
    }

    /**
     * Converts the {@code DC=} components of a DN into a dotted DNS domain, e.g.
     * {@code "DC=corp,DC=company,DC=com"} becomes {@code "corp.company.com"}. The attribute name
     * is matched case-insensitively and surrounding whitespace is ignored; other RDNs are skipped.
     *
     * @param dcName DN-style domain string
     * @return the dotted domain; the trimmed input when it contains no {@code DC=} component
     *         (e.g. it is already a dotted name); empty string for {@code null}
     */
    private static String domainParse(String dcName) {
        if (dcName == null) {
            return "";
        }
        StringJoiner domain = new StringJoiner(".");
        for (String component : dcName.split(",")) {
            String rdn = component.trim();
            if (rdn.length() > 3 && rdn.regionMatches(true, 0, "DC=", 0, 3)) {
                domain.add(rdn.substring(3).trim());
            }
        }
        return domain.length() > 0 ? domain.toString() : dcName.trim();
    }

    /**
     * Authenticates a user by looking the account up with the service account and then binding
     * as the user's DN on a domain controller discovered through DNS.
     *
     * @param name     account name, optionally as {@code DOMAIN\alias}
     * @param password the user's password; {@code null} or empty is always rejected
     * @return the user on a successful bind, otherwise {@code null} (failures are logged)
     * @throws NamingException if closing the bound context fails
     */
    public static LDAPUser auth(String name, String password) throws NamingException {
        if (name == null || name.isEmpty()) {
            logger.warn("Rejected authentication attempt without a user name");
            return null;
        }
        // A simple bind with a DN and an empty password is an "unauthenticated bind" (RFC 4513
        // section 5.1.2) that many servers accept, so it must never count as a successful login.
        if (password == null || password.isEmpty()) {
            logger.warn("Rejected authentication attempt with an empty password for {}", name);
            return null;
        }
        DirContext context = null;
        try {
            LDAPUser user = getUserInfo(name);
            if (user == null) {
                logger.info("User {} not found", name);
                return null;
            }
            if (user.getUserDN() == null) {
                // Properties rejects null values, so binding without a DN would throw NullPointerException.
                logger.error("No bind DN (msDS-SourceObjectDN) found for {}", name);
                return null;
            }
            String ldapDomain = domainParse(user.getDCName());
            if (ldapDomain.isEmpty()) {
                logger.error("Could not determine the domain of {}", name);
                return null;
            }
            Properties props = new Properties();
            props.put(Context.INITIAL_CONTEXT_FACTORY, ldapFactory);
            props.put(Context.PROVIDER_URL, findLdapAuthServer(ldapDomain));
            props.put(Context.SECURITY_PRINCIPAL, user.getUserDN());
            props.put(Context.SECURITY_CREDENTIALS, password);
            if (ldapIsSSL) {
                props.put(Context.SECURITY_PROTOCOL, "ssl");
            }
            // Creating the context performs the bind; reaching the next line means the credentials were accepted.
            context = new InitialDirContext(props);
            return user;
        } catch (AuthenticationException e) {
            logger.error("Wrong password", e);
        } catch (CommunicationException e) {
            logger.error("can't reach server", e);
        } catch (NamingException e) {
            logger.error("auth failed", e);
        } finally {
            if (context != null) {
                context.close();
            }
        }
        return null;
    }

    /**
     * Searches the directory, bound as the configured service account, for a user by
     * {@code sAMAccountName}.
     *
     * @param alias account name, optionally prefixed as {@code DOMAIN\alias}; the domain part is
     *              stored on the result but not used in the search
     * @return the first matching user, or {@code null} when none is found
     * @throws NamingException if connecting or searching fails
     */
    public static LDAPUser getUserInfo(String alias) throws NamingException {
        Objects.requireNonNull(alias, "alias");
        String domain = null;
        String[] userAccount = alias.split("\\\\");
        if (userAccount.length == 2) {
            domain = userAccount[0];
            alias = userAccount[1];
        }
        Properties props = new Properties();
        props.put(Context.INITIAL_CONTEXT_FACTORY, ldapFactory);
        props.put(Context.PROVIDER_URL, ldapUrl);
        props.put(Context.SECURITY_AUTHENTICATION, ldapAuthenticationMethod);
        props.put(Context.SECURITY_PRINCIPAL, ldapPrincipal);
        props.put(Context.SECURITY_CREDENTIALS, ldapPassword);
        props.put("com.sun.jndi.ldap.connect.timeout", "2000");
        if (ldapIsSSL) {
            props.put(Context.SECURITY_PROTOCOL, "ssl");
            // NOTE(review): socket factory class from another codebase; it must be on the classpath.
            props.put("java.naming.ldap.factory.socket", "com.hibu.smml.security.factory.CustomSocketFactory");
        }
        LdapContext context = new InitialLdapContext(props, null);
        NamingEnumeration<SearchResult> answers = null;
        try {
            SearchControls ctrls = new SearchControls();
            ctrls.setReturningAttributes(new String[]{"distinguishedName", "displayName", "mail", "memberof", "msDS-SourceObjectDN"});
            ctrls.setSearchScope(SearchControls.SUBTREE_SCOPE);
            ctrls.setTimeLimit(3000);

            // The alias is passed as a filter argument so JNDI escapes '*', '(', ')' and '\';
            // concatenating it would let a caller inject arbitrary filter terms.
            answers = context.search(ldapSearchBase, USER_FILTER, new Object[]{alias}, ctrls);
            if (answers == null || !answers.hasMoreElements()) {
                return null;
            }
            SearchResult result = answers.nextElement();
            Attributes attrs = result.getAttributes();
            LDAPUser user = new LDAPUser();
            user.setAlias(alias);
            user.setEmail(readLDAPAttr(attrs, "mail"));
            user.setNameInNameSpace(readLDAPAttr(attrs, "distinguishedName"));
            user.setUserDN(readLDAPAttr(attrs, "msDS-SourceObjectDN"));
            user.setDisplayName(readLDAPAttr(attrs, "displayName"));
            user.setDomain(domain);
            user.setGroups(readLDAPAttrList(attrs, "memberof"));
            return user;
        } finally {
            // Nested so the context is closed even if closing the enumeration throws.
            try {
                if (answers != null) {
                    answers.close();
                }
            } finally {
                context.close();
            }
        }
    }

    //region Utility

    /** Returns the first value of the attribute as a string, or {@code null} if absent or empty. */
    private static String readLDAPAttr(Attributes attrs, String attrName) throws NamingException {
        Attribute attribute = attrs.get(attrName);
        if (attribute == null || attribute.size() == 0) {
            return null;
        }
        Object value = attribute.get();
        return value == null ? null : value.toString();
    }

    /**
     * Returns the leading CN value of every value of a DN-valued attribute (e.g. group names from
     * {@code memberof}), or {@code null} if the attribute is absent.
     */
    private static List<String> readLDAPAttrList(Attributes attrs, String attrName) throws NamingException {
        Attribute attribute = attrs.get(attrName);
        if (attribute == null) {
            return null;
        }
        List<String> output = new ArrayList<>();
        for (NamingEnumeration<?> e = attribute.getAll(); e.hasMore(); ) {
            output.add(getCN(String.valueOf(e.next())));
        }
        return output;
    }

    /**
     * Extracts the value of the first RDN of a DN, dropping a leading {@code CN=} (case-insensitive),
     * e.g. {@code "CN=Admins,OU=Groups,DC=corp"} becomes {@code "Admins"}. Backslash-escaped commas
     * are not treated as separators; escape sequences are kept verbatim.
     *
     * @return the value, or {@code null} for {@code null} input
     */
    private static String getCN(String cnName) {
        if (cnName == null) {
            return null;
        }
        String value = cnName.regionMatches(true, 0, "CN=", 0, 3) ? cnName.substring(3) : cnName;
        for (int i = 0; i < value.length(); i++) {
            char c = value.charAt(i);
            if (c == '\\') {
                i++; // skip the escaped character
            } else if (c == ',') {
                return value.substring(0, i);
            }
        }
        return value;
    }
    //endregion
}
