/*
 *
 *  * Copyright 2018-2024 huiche.org.
 *  *
 *  * Licensed under the Apache License, Version 2.0 (the "License");
 *  * you may not use this file except in compliance with the License.
 *  * You may obtain a copy of the License at
 *  *
 *  *      https://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS,
 *  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  * See the License for the specific language governing permissions and
 *  * limitations under the License.
 *
 */

package org.huiche.util;

import org.huiche.exception.HcException;
import org.huiche.support.LambdaMethod;
import org.huiche.support.SerializableFunction;
import org.huiche.support.SyntheticParameterizedType;

import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.*;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.*;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.logging.Logger;

/**
 * Static reflection helpers: class-path scanning (directories and jars), generic type lookup,
 * constructor / field discovery for entity classes and records, and lambda introspection.
 *
 * @author Maning
 */
public interface ReflectUtil {
    Logger LOGGER = Logger.getLogger(ReflectUtil.class.getName());

    /**
     * Scans the jar files visible to {@code classLoader} for classes whose fully-qualified name
     * starts with {@code scanPackage} and that are accepted by {@code filter}.
     * <p>
     * Jars are discovered through their {@code META-INF/MANIFEST.MF} entry, so a jar without a
     * manifest is not seen. Only jars that resolve to a plain file on disk are scanned.
     *
     * @param classLoader loader used both to locate the jars and to load the classes
     * @param scanPackage name prefix to match; {@code null} or empty matches every class
     * @param filter      predicate deciding whether a loaded class is kept
     * @return the matching classes (possibly empty, never {@code null})
     */
    static Set<Class<?>> scanClassInJar(ClassLoader classLoader, String scanPackage, Predicate<Class<?>> filter) {
        final String manifestSuffix = "!/META-INF/MANIFEST.MF";
        Set<Class<?>> classes = new HashSet<>();
        try {
            Enumeration<URL> resources = classLoader.getResources("META-INF/MANIFEST.MF");
            while (resources.hasMoreElements()) {
                URL resource = resources.nextElement();
                // For jar:file:/x/y.jar!/META-INF/MANIFEST.MF the protocol is "jar" and
                // getPath() is "file:/x/y.jar!/META-INF/MANIFEST.MF" (the protocol is not part of the path).
                String path = resource.getPath();
                if (!"jar".equals(resource.getProtocol()) || !path.startsWith("file:") || !path.endsWith(manifestSuffix)) {
                    continue;
                }
                File jar = toFile(path.substring(0, path.length() - manifestSuffix.length()));
                try {
                    classes.addAll(scanClassInJar(jar, classLoader, scanPackage, filter));
                } catch (IOException e) {
                    // One unreadable jar must not abort the scan of the remaining jars.
                    logException(e);
                }
            }
        } catch (IOException e) {
            logException(e);
        }
        return classes;
    }

    /**
     * Loads the matching classes of a single jar file. Entries under {@code META-INF} are skipped.
     * A path that is not an existing regular {@code .jar} file yields an empty set.
     */
    private static Set<Class<?>> scanClassInJar(File file, ClassLoader classLoader, String scanPackage, Predicate<Class<?>> filter) throws IOException {
        Set<Class<?>> classes = new HashSet<>();
        if (!file.isFile() || !file.getName().endsWith(".jar")) {
            return classes;
        }
        try (JarFile jarFile = new JarFile(file)) {
            Enumeration<JarEntry> entries = jarFile.entries();
            while (entries.hasMoreElements()) {
                JarEntry entry = entries.nextElement();
                String name = entry.getName();
                if (entry.isDirectory() || name.startsWith("META-INF") || !name.endsWith(".class")) {
                    continue;
                }
                String className = name.substring(0, name.length() - ".class".length()).replace('/', '.');
                if (isPackageMatch(className, scanPackage)) {
                    classes.addAll(loadClass(className, classLoader, filter));
                }
            }
        }
        return classes;
    }

    /**
     * Scans directories and jars using the current thread's context class loader.
     *
     * @see #scanClass(ClassLoader, String, Predicate)
     */
    static Set<Class<?>> scanClass(String scanPackage, Predicate<Class<?>> filter) {
        return scanClass(Thread.currentThread().getContextClassLoader(), scanPackage, filter);
    }

    /**
     * Scans both class directories and jars visible to {@code classLoader}.
     *
     * @param classLoader loader used to locate and load classes
     * @param scanPackage name prefix to match; {@code null} or empty matches every class
     * @param filter      predicate deciding whether a loaded class is kept
     * @return the union of {@link #scanClassInSrc} and {@link #scanClassInJar(ClassLoader, String, Predicate)}
     */
    static Set<Class<?>> scanClass(ClassLoader classLoader, String scanPackage, Predicate<Class<?>> filter) {
        Set<Class<?>> classes = new HashSet<>();
        classes.addAll(scanClassInSrc(classLoader, scanPackage, filter));
        classes.addAll(scanClassInJar(classLoader, scanPackage, filter));
        return classes;
    }

    /**
     * Scans the class-path roots returned by {@code classLoader.getResources(".")}
     * (e.g. {@code xxx/target/classes/}) for matching classes.
     *
     * @param classLoader loader used to locate the roots and to load the classes
     * @param scanPackage name prefix to match; {@code null} or empty matches every class
     * @param filter      predicate deciding whether a loaded class is kept
     * @return the matching classes (possibly empty, never {@code null})
     */
    static Set<Class<?>> scanClassInSrc(ClassLoader classLoader, String scanPackage, Predicate<Class<?>> filter) {
        Set<Class<?>> classes = new HashSet<>();
        try {
            Enumeration<URL> resources = classLoader.getResources(".");
            while (resources.hasMoreElements()) {
                URL resource = resources.nextElement();
                if (!"file".equals(resource.getProtocol())) {
                    // Only plain file-system locations can be walked as directories.
                    continue;
                }
                File root = toFile(resource.toExternalForm());
                if (root.isDirectory()) {
                    // classes root, eg: xxx/target/classes/ ; the root itself is the default package
                    classes.addAll(scanClassInDir(root, "", classLoader, scanPackage, filter));
                } else {
                    classes.addAll(loadClass(root, "", classLoader, scanPackage, filter));
                }
            }
        } catch (IOException e) {
            logException(e);
        }
        return classes;
    }

    /**
     * Converts a {@code file:} URL string into a {@link File}, decoding URL escapes such as
     * {@code %20}. When the string is not a well-formed file URI the raw path after
     * {@code file:} is used instead.
     *
     * @param fileUrl a string starting with {@code file:}
     */
    private static File toFile(String fileUrl) {
        try {
            return new File(new URI(fileUrl));
        } catch (URISyntaxException | IllegalArgumentException e) {
            return new File(fileUrl.substring("file:".length()));
        }
    }

    /**
     * Tells whether a class is selected by the scan: its fully-qualified name must start with
     * {@code scanPackage}. A {@code null} or empty {@code scanPackage} selects everything.
     */
    private static boolean isPackageMatch(String className, String scanPackage) {
        return scanPackage == null || scanPackage.isEmpty() || className.startsWith(scanPackage);
    }

    /**
     * Tells whether the directory of {@code packageName} can contain selected classes, i.e. whether
     * it lies inside the scanned package or is one of its ancestors. Used only to decide whether to
     * descend; it never selects classes by itself.
     */
    private static boolean mayContainMatch(String packageName, String scanPackage) {
        if (scanPackage == null || scanPackage.isEmpty()) {
            return true;
        }
        // Every class below this directory has a name starting with "<packageName>."
        String classPrefix = packageName + '.';
        return classPrefix.startsWith(scanPackage) || scanPackage.startsWith(classPrefix);
    }

    /**
     * Loads the class stored in {@code file} when it is a {@code .class} file whose class name is
     * selected by {@code scanPackage}. An empty {@code packageName} denotes the default package.
     */
    private static Set<Class<?>> loadClass(File file, String packageName, ClassLoader classLoader, String scanPackage, Predicate<Class<?>> filter) {
        String fileName = file.getName();
        if (!fileName.endsWith(".class")) {
            return Collections.emptySet();
        }
        String simpleName = fileName.substring(0, fileName.length() - ".class".length());
        // No leading dot for the default package, otherwise the name could never be resolved.
        String className = packageName.isEmpty() ? simpleName : packageName + '.' + simpleName;
        if (!isPackageMatch(className, scanPackage)) {
            return Collections.emptySet();
        }
        return loadClass(className, classLoader, filter);
    }

    /**
     * Loads {@code className} through {@code classLoader} and returns it when {@code filter}
     * accepts it. Failures never propagate: an {@link Error} is silently ignored (scanning is
     * best-effort), any other exception is logged.
     */
    private static Set<Class<?>> loadClass(String className, ClassLoader classLoader, Predicate<Class<?>> filter) {
        try {
            Class<?> clazz = Class.forName(className, true, classLoader);
            if (filter.test(clazz)) {
                return Collections.singleton(clazz);
            }
        } catch (Error ignored) {
            // Deliberately ignored: a class that cannot be linked or initialized is simply not a scan result.
        } catch (Exception e) {
            logException(e);
        }
        return Collections.emptySet();
    }

    /**
     * Logs the full stack trace of {@code e} at WARNING level. The trace is only rendered when
     * the logger actually accepts the record.
     */
    private static void logException(Throwable e) {
        LOGGER.warning(() -> {
            StringWriter sw = new StringWriter();
            e.printStackTrace(new PrintWriter(sw));
            return sw.toString();
        });
    }

    /**
     * Recursively collects the selected classes below {@code dir}.
     *
     * @param dir         directory to walk
     * @param packageName package represented by {@code dir}; empty for a class-path root
     */
    private static Set<Class<?>> scanClassInDir(File dir, String packageName, ClassLoader classLoader, String scanPackage, Predicate<Class<?>> filter) {
        Set<Class<?>> classes = new HashSet<>();
        File[] files = dir.listFiles();
        if (files == null) {
            return classes;
        }
        for (File file : files) {
            if (file.isDirectory()) {
                String subPackage = packageName.isEmpty() ? file.getName() : packageName + '.' + file.getName();
                if (mayContainMatch(subPackage, scanPackage)) {
                    classes.addAll(scanClassInDir(file, subPackage, classLoader, scanPackage, filter));
                }
            } else {
                classes.addAll(loadClass(file, packageName, classLoader, scanPackage, filter));
            }
        }
        return classes;
    }

    /**
     * Finds the first concrete class used as a type argument of a generic interface implemented
     * by {@code clazz}; when none is found the search continues with the superclass.
     * Only generic <em>interfaces</em> are inspected, not the generic superclass declaration.
     *
     * @param clazz class to inspect
     * @return the first type argument that is a {@link Class}, or {@code null} when there is none
     */
    static Class<?> getGenericsClass(Class<?> clazz) {
        for (Type t : clazz.getGenericInterfaces()) {
            if (t instanceof ParameterizedType pt) {
                for (Type at : pt.getActualTypeArguments()) {
                    if (at instanceof Class<?> ct) {
                        return ct;
                    }
                }
            }
        }
        Class<?> superClazz = clazz.getSuperclass();
        if (superClazz != null && superClazz != Object.class) {
            return getGenericsClass(superClazz);
        }
        return null;
    }

    /**
     * Creates a {@link ParameterizedType} describing {@code raw<args...>}.
     *
     * @param raw  the raw type
     * @param args the actual type arguments
     */
    static ParameterizedType parameterizedType(Class<?> raw, Type... args) {
        return new SyntheticParameterizedType(raw, args);
    }

    /**
     * Returns the constructor that takes one argument per field, made accessible.
     * <p>
     * For a real record this is the canonical constructor, looked up from the record components
     * (so static fields and extra constructors do not interfere). For any other class the
     * constructors whose parameter count equals the number of declared fields are considered;
     * if several exist, the one whose parameters all match a declared field by name and type wins.
     *
     * @param clazz class to inspect
     * @return the constructor, or {@code null} when none qualifies
     */
    @SuppressWarnings("unchecked")
    static <T> Constructor<T> getRecordConstructor(Class<T> clazz) {
        if (clazz.isRecord()) {
            RecordComponent[] components = clazz.getRecordComponents();
            Class<?>[] componentTypes = new Class<?>[components.length];
            for (int i = 0; i < components.length; i++) {
                componentTypes[i] = components[i].getType();
            }
            try {
                Constructor<T> canonical = clazz.getDeclaredConstructor(componentTypes);
                canonical.setAccessible(true);
                return canonical;
            } catch (NoSuchMethodException e) {
                return null;
            }
        }
        Field[] fields = clazz.getDeclaredFields();
        int length = fields.length;
        Constructor<?>[] constructors = clazz.getDeclaredConstructors();
        if (constructors.length == 0) {
            return null;
        }
        List<Constructor<?>> list = Arrays.stream(constructors).filter(i -> i.getParameterCount() == length).toList();
        if (list.isEmpty()) {
            return null;
        }
        if (list.size() == 1) {
            Constructor<?> constructor = list.get(0);
            constructor.setAccessible(true);
            return (Constructor<T>) constructor;
        }
        for (Constructor<?> c : list) {
            boolean ok = true;
            for (Parameter p : c.getParameters()) {
                if (Arrays.stream(fields).noneMatch(f -> f.getName().equals(p.getName()) && f.getType().equals(p.getType()))) {
                    ok = false;
                    break;
                }
            }
            if (ok) {
                c.setAccessible(true);
                return (Constructor<T>) c;
            }
        }
        return null;
    }

    /**
     * Returns the declared constructor with the fewest parameters (the first such constructor in
     * declaration-array order on ties), made accessible.
     *
     * @param clazz class to inspect
     * @return the constructor, or {@code null} when the class declares none
     */
    @SuppressWarnings("unchecked")
    static <T> Constructor<T> getDefaultConstructor(Class<T> clazz) {
        Constructor<?>[] constructors = clazz.getDeclaredConstructors();
        if (constructors.length == 0) {
            return null;
        }
        Constructor<?> smallest = constructors[0];
        for (Constructor<?> candidate : constructors) {
            if (candidate.getParameterCount() < smallest.getParameterCount()) {
                smallest = candidate;
            }
        }
        smallest.setAccessible(true);
        return (Constructor<T>) smallest;
    }

    /**
     * Resolves the constructor used to instantiate an entity.
     *
     * @param clazz    entity class
     * @param isRecord {@code true} to use {@link #getRecordConstructor}, otherwise {@link #getDefaultConstructor}
     * @return the constructor, never {@code null}
     * @throws HcException when no suitable constructor exists
     */
    static <T> Constructor<T> getEntityConstructor(Class<T> clazz, boolean isRecord) {
        Constructor<T> constructor = isRecord ? getRecordConstructor(clazz) : getDefaultConstructor(clazz);
        if (constructor == null) {
            throw new HcException("Class: " + clazz.getCanonicalName() + " can not found public constructor");
        }
        return constructor;
    }

    /**
     * Returns the instance fields that make up an entity, each made accessible.
     * <p>
     * With {@code isRecord} the non-static fields declared by {@code clazz} itself are returned;
     * otherwise all non-static, non-transient fields of the class and its superclasses.
     *
     * @param clazz    entity class
     * @param isRecord whether to treat {@code clazz} as a record
     */
    static List<Field> getEntityFields(Class<?> clazz, boolean isRecord) {
        List<Field> fields;
        if (isRecord) {
            fields = new ArrayList<>();
            for (Field field : clazz.getDeclaredFields()) {
                // Static fields are not part of the entity state, same as in the non-record branch.
                if (!Modifier.isStatic(field.getModifiers())) {
                    fields.add(field);
                }
            }
        } else {
            fields = ReflectUtil.getAllFieldsNotStaticOrTransient(clazz);
        }
        fields.forEach(field -> field.setAccessible(true));
        return fields;
    }

    /**
     * Builds a factory that creates an instance from a name-to-value map.
     * <p>
     * Constructor arguments are looked up in the map by {@link Parameter#getName()}. Unless
     * {@code isRecord} is set, every field with a non-null map value (looked up by field name)
     * is then assigned reflectively.
     *
     * @param constructor constructor to invoke
     * @param parameters  the constructor's parameters, in invocation order
     * @param fields      fields to populate after construction (ignored for records)
     * @param isRecord    whether the target is a record
     * @return the factory; it throws {@link HcException} when instantiation or assignment fails
     */
    static Function<Map<String, Object>, Object> getInstanceFunction(Constructor<?> constructor, List<Parameter> parameters, List<Field> fields, boolean isRecord) {
        return map -> {
            int parameterCount = parameters.size();
            Object[] values = new Object[parameterCount];
            for (int i = 0; i < parameterCount; i++) {
                values[i] = map.get(parameters.get(i).getName());
            }
            try {
                Object t = constructor.newInstance(values);
                if (!isRecord) {
                    for (Field field : fields) {
                        Object v = map.get(field.getName());
                        if (v != null) {
                            field.set(t, v);
                        }
                    }
                }
                return t;
            } catch (Exception e) {
                // An exception thrown by the constructor arrives wrapped in an InvocationTargetException
                // whose own message is normally null; report the real failure instead.
                Throwable failure = e;
                if (e instanceof InvocationTargetException ite && ite.getCause() != null) {
                    failure = ite.getCause();
                }
                // NOTE(review): the cause is not chained because HcException's constructors are not
                // visible from this file - pass `failure` along if a (String, Throwable) constructor exists.
                throw new HcException("failed to create an instance, errorMessage:" + failure.getLocalizedMessage());
            }
        };
    }

    /**
     * Collects the declared fields of {@code clazz} and of all its superclasses except
     * {@link Object}, superclass fields first.
     *
     * @param clazz     class to inspect
     * @param predicate keeps only the fields it accepts; {@code null} keeps every field
     * @return a new mutable list
     */
    static List<Field> getAllFields(Class<?> clazz, Predicate<Field> predicate) {
        List<Field> fields = new ArrayList<>();
        Class<?> superClass = clazz.getSuperclass();
        if (superClass != null && superClass != Object.class) {
            fields.addAll(getAllFields(superClass, predicate));
        }
        for (Field field : clazz.getDeclaredFields()) {
            if (predicate == null || predicate.test(field)) {
                fields.add(field);
            }
        }
        return fields;
    }

    /**
     * Like {@link #getAllFields(Class, Predicate)} but restricted to fields that are neither
     * {@code static} nor {@code transient}.
     */
    static List<Field> getAllFieldsNotStaticOrTransient(Class<?> clazz) {
        return getAllFields(clazz, field -> {
            int modifiers = field.getModifiers();
            return !Modifier.isStatic(modifiers) && !Modifier.isTransient(modifiers);
        });
    }

    /**
     * Extracts the implementation class and method name behind a serializable lambda or method
     * reference, by invoking its {@code writeReplace} method to obtain the {@link SerializedLambda}.
     *
     * @param lambda serializable lambda / method reference
     * @return the implementation class and method name, or {@code null} when they cannot be determined
     */
    static <T> LambdaMethod getLambdaMethod(SerializableFunction<T, ?> lambda) {
        try {
            Method method = lambda.getClass().getDeclaredMethod("writeReplace");
            method.setAccessible(true);
            SerializedLambda serializedLambda = (SerializedLambda) method.invoke(lambda);
            String className = serializedLambda.getImplClass().replace('/', '.');
            // Resolve through the lambda's own loader, which can always see the implementation class.
            Class<?> implClass = Class.forName(className, true, lambda.getClass().getClassLoader());
            return new LambdaMethod(implClass, serializedLambda.getImplMethodName());
        } catch (Exception ignored) {
            // Best-effort lookup: callers receive null when the lambda cannot be introspected.
        }
        return null;
    }

    /**
     * Finds the field an accessor method refers to.
     * <p>
     * For a record the field with exactly the method's name is returned. Otherwise a
     * non-static, non-transient field (including inherited ones) is matched against the method
     * name itself or against the name with a leading {@code get} / {@code is} removed and the
     * following character lower-cased.
     *
     * @param clazz      class owning the accessor
     * @param methodName accessor method name, e.g. {@code getName}, {@code isActive} or {@code name}
     * @return the matching field, or {@code null} when there is none
     */
    static Field getFieldByMethodName(Class<?> clazz, String methodName) {
        if (clazz.isRecord()) {
            try {
                return clazz.getDeclaredField(methodName);
            } catch (NoSuchFieldException ignored) {
                return null;
            }
        }
        String propertyName = null;
        if (methodName.length() > 3 && methodName.startsWith("get")) {
            // Locale.ROOT keeps the conversion independent of the default locale (e.g. Turkish dotless i).
            propertyName = methodName.substring(3, 4).toLowerCase(Locale.ROOT) + methodName.substring(4);
        } else if (methodName.length() > 2 && methodName.startsWith("is")) {
            propertyName = methodName.substring(2, 3).toLowerCase(Locale.ROOT) + methodName.substring(3);
        }
        for (Field field : getAllFieldsNotStaticOrTransient(clazz)) {
            String fieldName = field.getName();
            if (fieldName.equals(propertyName) || fieldName.equals(methodName)) {
                return field;
            }
        }
        return null;
    }
}
