package org.neo4j.gds;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.assertj.core.api.Condition;
import org.assertj.core.api.HamcrestCondition;
import org.assertj.core.api.ObjectAssert;
import org.assertj.core.api.SoftAssertions;
import org.hamcrest.Matcher;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.params.provider.Arguments;
import org.neo4j.gds.api.CSRGraph;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.NodeMapping;
import org.neo4j.gds.canonization.CanonicalAdjacencyMatrix;
import org.neo4j.gds.compat.GraphDatabaseApiProxy;
import org.neo4j.gds.core.Aggregation;
import org.neo4j.gds.core.GraphDimensions;
import org.neo4j.gds.core.loading.construction.GraphFactory;
import org.neo4j.gds.core.loading.construction.NodesBuilder;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.extension.GdlSupportExtension;
import org.neo4j.gds.extension.IdFunction;
import org.neo4j.gds.extension.TestGraph;
import org.neo4j.gds.gdl.GdlFactory;
import org.neo4j.gds.gdl.ImmutableGraphCreateFromGdlConfig;
import org.neo4j.gds.transaction.TransactionContext;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.TransactionTerminatedException;
import org.neo4j.internal.kernel.api.security.SecurityContext;
import org.neo4j.kernel.api.exceptions.Status;
import org.neo4j.kernel.internal.GraphDatabaseAPI;

/**
 * Static helpers shared by GDS tests: builders for parameterized-test arguments, GDL-based graph
 * construction, assertions on node values, graphs, memory estimations and Cypher results, and small
 * node-mapping fixtures.
 */
public final class TestSupport {

    /** Shared empty label array passed when nodes are added without labels. */
    private static final NodeLabel[] NO_LABELS = new NodeLabel[0];

    private TestSupport() {
    }

    /**
     * Returns the two directed orientations, {@link Orientation#NATURAL} and {@link Orientation#REVERSE}.
     */
    public static Stream<Orientation> allDirectedProjections() {
        return Stream.of(Orientation.NATURAL, Orientation.REVERSE);
    }

    /**
     * Wraps each element produced by {@code supplier} into a single-value {@link Arguments}.
     * The returned supplier calls {@code supplier} anew on every invocation.
     */
    public static <T> Supplier<Stream<Arguments>> toArguments(Supplier<Stream<T>> supplier) {
        return () -> supplier.get().map(element -> Arguments.of(element));
    }

    /**
     * Turns each list produced by {@code supplier} into an {@link Arguments} whose values are the list's
     * elements, in list order.
     */
    public static <T> Supplier<Stream<Arguments>> toArgumentsFlat(Supplier<Stream<List<T>>> supplier) {
        return () -> supplier.get().map(values -> Arguments.of(values.toArray()));
    }

    /**
     * Builds the cartesian product of all given argument suppliers, left to right.
     * Each resulting {@link Arguments} holds the values of one argument from every supplier, concatenated.
     * With no additional suppliers this returns {@code firstSupplier.get()} unchanged.
     */
    @SafeVarargs
    public static Stream<Arguments> crossArguments(
        Supplier<Stream<Arguments>> firstSupplier,
        Supplier<Stream<Arguments>>... otherSuppliers
    ) {
        Supplier<Stream<Arguments>> combined = firstSupplier;
        for (Supplier<Stream<Arguments>> next : otherSuppliers) {
            // capture the current accumulator; each fold step lazily crosses it with the next supplier
            Supplier<Stream<Arguments>> left = combined;
            combined = () -> crossArguments(left, next);
        }
        return combined.get();
    }

    /**
     * Builds the cartesian product of two argument suppliers. For every left argument, the right
     * supplier is re-invoked, and the values of both arguments are concatenated (left values first).
     */
    public static Stream<Arguments> crossArguments(
        Supplier<Stream<Arguments>> leftSupplier,
        Supplier<Stream<Arguments>> rightSupplier
    ) {
        return leftSupplier.get().flatMap(left -> rightSupplier.get().map(right -> {
            // copy into a List<Object> so the combined array is an Object[] regardless of the
            // runtime component types of the two source arrays
            List<Object> combinedValues = new ArrayList<>(Arrays.asList(left.get()));
            combinedValues.addAll(Arrays.asList(right.get()));
            return Arguments.of(combinedValues.toArray());
        }));
    }

    /** Returns two single-value arguments: {@code true} and {@code false}. */
    public static Stream<Arguments> trueFalseArguments() {
        return Stream.of(true, false).map(flag -> Arguments.of(flag));
    }

    /** Builds a {@link TestGraph} named {@code "graph"} from GDL using the natural orientation. */
    public static TestGraph fromGdl(String gdl) {
        return fromGdl(gdl, Orientation.NATURAL, "graph");
    }

    /** Builds a {@link TestGraph} with the given name from GDL using the natural orientation. */
    public static TestGraph fromGdl(String gdl, String name) {
        return fromGdl(gdl, Orientation.NATURAL, name);
    }

    /** Builds a {@link TestGraph} named {@code "graph"} from GDL using the given orientation. */
    public static TestGraph fromGdl(String gdl, Orientation orientation) {
        return fromGdl(gdl, orientation, "graph");
    }

    /**
     * Builds a {@link TestGraph} from a GDL string.
     *
     * @param gdl         the GDL pattern; must not be null
     * @param orientation orientation applied when loading the graph
     * @param name        name given to the returned {@link TestGraph}
     * @return the union graph of the created graph store, with GDL variables resolvable via the factory
     * @throws NullPointerException if {@code gdl} is null
     */
    public static TestGraph fromGdl(String gdl, Orientation orientation, String name) {
        Objects.requireNonNull(gdl);

        GdlFactory gdlFactory = GdlFactory.builder()
            .createConfig(ImmutableGraphCreateFromGdlConfig.builder()
                .gdlGraph(gdl)
                .graphName("graph")
                .orientation(orientation)
                .build())
            .namedDatabaseId(GdlSupportExtension.DATABASE_ID)
            .build();

        CSRGraph union = gdlFactory.build().graphStore().getUnion();
        return new TestGraph(union, gdlFactory::nodeId, name);
    }

    /**
     * Builds a {@link GraphStore} from a GDL string.
     *
     * @throws NullPointerException if {@code gdl} is null
     */
    public static GraphStore graphStoreFromGDL(String gdl) {
        Objects.requireNonNull(gdl);
        return GdlFactory.of(gdl).build().graphStore();
    }

    /** Resolves each row of GDL variable names to node ids using {@code idFunction}. */
    public static long[][] ids(IdFunction idFunction, String[][] variables) {
        return Arrays.stream(variables)
            .map(row -> ids(idFunction, row))
            .toArray(long[][]::new);
    }

    /** Resolves the given GDL variable names to node ids using {@code idFunction}, preserving order. */
    public static long[] ids(IdFunction idFunction, String... variables) {
        return Arrays.stream(variables).mapToLong(idFunction::of).toArray();
    }

    /**
     * Asserts that, for every GDL variable in {@code expectedValues}, {@code actualValues} applied to the
     * variable's mapped node id equals the expected value.
     */
    public static void assertLongValues(
        TestGraph graph,
        Function<Long, Long> actualValues,
        Map<String, Long> expectedValues
    ) {
        expectedValues.forEach((variable, expectedValue) -> {
            Long actualValue = actualValues.apply(graph.toMappedNodeId(variable));
            // String.valueOf keeps a null value from turning the failure message into an NPE
            Assertions.assertEquals(expectedValue, actualValue, StringFormatting.formatWithLocale(
                "Values do not match for variable %s. Expected %s, got %s.",
                variable,
                String.valueOf(expectedValue),
                String.valueOf(actualValue)
            ));
        });
    }

    /**
     * Asserts that, for every GDL variable in {@code expectedValues}, {@code actualValues} applied to the
     * variable's mapped node id is non-null and equals the expected value within {@code delta}.
     */
    public static void assertDoubleValues(
        TestGraph graph,
        Function<Long, Double> actualValues,
        Map<String, Double> expectedValues,
        double delta
    ) {
        expectedValues.forEach((variable, expectedValue) -> {
            Double actualValue = actualValues.apply(graph.toMappedNodeId(variable));
            String message = StringFormatting.formatWithLocale(
                "Values do not match for variable %s. Expected %s, got %s.",
                variable,
                String.valueOf(expectedValue),
                String.valueOf(actualValue)
            );
            // report a missing value as an assertion failure instead of an unboxing NPE
            Assertions.assertNotNull(actualValue, message);
            Assertions.assertEquals(expectedValue.doubleValue(), actualValue.doubleValue(), delta, message);
        });
    }

    /** Asserts equal node counts and equal canonical adjacency representations. */
    public static void assertGraphEquals(Graph expected, Graph actual) {
        Assertions.assertEquals(expected.nodeCount(), actual.nodeCount(), "Node counts do not match.");
        Assertions.assertEquals(
            CanonicalAdjacencyMatrix.canonicalize(expected),
            CanonicalAdjacencyMatrix.canonicalize(actual)
        );
    }

    /**
     * Asserts that the canonical form of {@code actual} equals that of at least one of
     * {@code expectedGraphs}. An empty collection always fails.
     */
    public static void assertGraphEquals(Collection<Graph> expectedGraphs, Graph actual) {
        List<String> expectedCanonicalized = expectedGraphs.stream()
            .map(CanonicalAdjacencyMatrix::canonicalize)
            .collect(Collectors.toList());
        String actualCanonicalized = CanonicalAdjacencyMatrix.canonicalize(actual);

        // "any matches" semantics; an XOR reduction would wrongly fail when two expected graphs match
        boolean anyMatches = expectedCanonicalized.contains(actualCanonicalized);

        Assertions.assertTrue(anyMatches, StringFormatting.formatWithLocale(
            "None of the given graphs matches the actual one.%nActual:%n%s%nExpected:%n%s",
            actualCanonicalized,
            String.join("\n\n", expectedCanonicalized)
        ));
    }

    /** Asserts the estimated memory range for a graph with {@code nodeCount} nodes and no relationships. */
    public static void assertMemoryEstimation(
        Supplier<MemoryEstimation> actualEstimation,
        long nodeCount,
        int concurrency,
        long expectedMin,
        long expectedMax
    ) {
        assertMemoryEstimation(actualEstimation, nodeCount, 0L, concurrency, expectedMin, expectedMax);
    }

    /** Asserts the estimated memory range for a graph with the given node and relationship counts. */
    public static void assertMemoryEstimation(
        Supplier<MemoryEstimation> actualEstimation,
        long nodeCount,
        long relationshipCount,
        int concurrency,
        long expectedMin,
        long expectedMax
    ) {
        assertMemoryEstimation(
            actualEstimation,
            GraphDimensions.of(nodeCount, relationshipCount),
            concurrency,
            expectedMin,
            expectedMax
        );
    }

    /** Asserts the min and max of the memory range estimated for {@code dimensions} and {@code concurrency}. */
    public static void assertMemoryEstimation(
        Supplier<MemoryEstimation> actualEstimation,
        GraphDimensions dimensions,
        int concurrency,
        long expectedMin,
        long expectedMax
    ) {
        MemoryRange actual = actualEstimation.get().estimate(dimensions, concurrency).memoryUsage();
        Assertions.assertEquals(expectedMin, actual.min);
        Assertions.assertEquals(expectedMax, actual.max);
    }

    /**
     * Asserts that {@code executable} throws a {@link TransactionTerminatedException} with status
     * {@link Status.Transaction#Terminated}.
     */
    public static void assertTransactionTermination(Executable executable) {
        TransactionTerminatedException exception =
            Assertions.assertThrows(TransactionTerminatedException.class, executable);
        Assertions.assertEquals(Status.Transaction.Terminated, exception.status());
    }

    /** Runs {@code query} without parameters and asserts its rows; see the four-argument overload. */
    public static void assertCypherResult(
        GraphDatabaseService db,
        @Language("Cypher") String query,
        List<Map<String, Object>> expected
    ) {
        assertCypherResult(db, query, Collections.emptyMap(), expected);
    }

    /**
     * Runs {@code query} with {@code queryParameters} inside a transaction and soft-asserts its rows
     * against {@code expected}, in order. Per row, the column sets must match; an expected cell value may
     * be a Hamcrest {@link Matcher}, an AssertJ {@link Condition}, or a plain value compared via equality.
     * All mismatches are collected and reported together.
     */
    public static void assertCypherResult(
        GraphDatabaseService db,
        @Language("Cypher") String query,
        Map<String, Object> queryParameters,
        List<Map<String, Object>> expected
    ) {
        GraphDatabaseApiProxy.runInTransaction(db, tx -> {
            SoftAssertions softAssertions = new SoftAssertions();
            List<Map<String, Object>> actual = new ArrayList<>();
            QueryRunner.runQueryWithResultConsumer(db, query, queryParameters, result -> {
                result.accept(row -> {
                    Map<String, Object> actualRow = new HashMap<>();
                    for (String column : result.columns()) {
                        actualRow.put(column, row.get(column));
                    }
                    actual.add(actualRow);
                    return true;
                });
            });

            softAssertions.assertThat(actual)
                .withFailMessage(
                    "Different amount of rows returned for actual result (%d) than expected (%d)",
                    actual.size(),
                    expected.size()
                )
                .hasSize(expected.size());

            // only compare rows present on both sides; a count mismatch is already reported above
            // instead of aborting with an IndexOutOfBoundsException before assertAll()
            int comparableRows = Math.min(expected.size(), actual.size());
            for (int rowIndex = 0; rowIndex < comparableRows; rowIndex++) {
                Map<String, Object> expectedRow = expected.get(rowIndex);
                Map<String, Object> actualRow = actual.get(rowIndex);
                softAssertions.assertThat(actualRow.keySet()).containsExactlyInAnyOrderElementsOf(expectedRow.keySet());

                int rowNumber = rowIndex;
                expectedRow.forEach((column, expectedValue) -> {
                    Object actualValue = actualRow.get(column);
                    ObjectAssert<Object> assertion = softAssertions.assertThat(actualValue)
                        .withFailMessage(
                            "Different value for column '%s' of row %d (expected %s, but got %s)",
                            column,
                            rowNumber,
                            expectedValue,
                            actualValue
                        );
                    assertCellMatches(assertion, expectedValue);
                });
            }
            softAssertions.assertAll();
        });
    }

    /** Applies {@code expectedValue} to {@code assertion} as a Hamcrest matcher, AssertJ condition, or equality. */
    @SuppressWarnings("unchecked") // expected matchers/conditions are supplied untyped inside the expected rows
    private static void assertCellMatches(ObjectAssert<Object> assertion, Object expectedValue) {
        if (expectedValue instanceof Matcher) {
            assertion.is(new HamcrestCondition<>((Matcher<Object>) expectedValue));
        } else if (expectedValue instanceof Condition) {
            assertion.is((Condition<Object>) expectedValue);
        } else {
            assertion.isEqualTo(expectedValue);
        }
    }

    /**
     * Wraps {@code property} in the Cypher aggregation corresponding to the parsed {@link Aggregation};
     * other aggregations return {@code property} unchanged.
     */
    public static String getCypherAggregation(String aggregation, String property) {
        String template;
        switch (Aggregation.parse(aggregation)) {
            case SINGLE:
                template = "head(collect(%s))";
                break;
            case SUM:
                template = "sum(%s)";
                break;
            case MIN:
                template = "min(%s)";
                break;
            case MAX:
                template = "max(%s)";
                break;
            case COUNT:
                template = "count(%s)";
                break;
            default:
                template = "%s";
                break;
        }
        return StringFormatting.formatWithLocale(template, property);
    }

    /** Creates a {@link TransactionContext} for {@code db} with {@link SecurityContext#AUTH_DISABLED}. */
    public static TransactionContext fullAccessTransaction(GraphDatabaseAPI db) {
        return TransactionContext.of(db, SecurityContext.AUTH_DISABLED);
    }

    /** Builds a node mapping containing the unlabeled original ids {@code 0 .. nodeCount - 1}. */
    public static NodeMapping nodeMapping(long nodeCount) {
        NodesBuilder nodesBuilder = GraphFactory.initNodesBuilder()
            .nodeCount(nodeCount)
            .maxOriginalId(nodeCount - 1)
            .allocationTracker(AllocationTracker.empty())
            .build();
        LongStream.range(0, nodeCount).forEach(originalId -> nodesBuilder.addNode(originalId, NO_LABELS));
        return nodesBuilder.build().nodeMapping();
    }

    /**
     * Builds a node mapping containing the given unlabeled original ids, added in array order.
     * The maximum original id is taken from the array (0 when it is empty).
     */
    public static NodeMapping nodeMapping(long[] originalIds) {
        NodesBuilder nodesBuilder = GraphFactory.initNodesBuilder()
            .nodeCount(originalIds.length)
            .maxOriginalId(Arrays.stream(originalIds).max().orElse(0L))
            .allocationTracker(AllocationTracker.empty())
            .build();
        for (long originalId : originalIds) {
            nodesBuilder.addNode(originalId, NO_LABELS);
        }
        return nodesBuilder.build().nodeMapping();
    }
}
