package org.flyte.jflyte;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.flyte.api.v1.Binding;
import org.flyte.api.v1.BindingData;
import org.flyte.api.v1.ContainerError;
import org.flyte.api.v1.ContainerTask;
import org.flyte.api.v1.ContainerTaskRegistrar;
import org.flyte.api.v1.DynamicJobSpec;
import org.flyte.api.v1.DynamicWorkflowTask;
import org.flyte.api.v1.DynamicWorkflowTaskRegistrar;
import org.flyte.api.v1.Literal;
import org.flyte.api.v1.NamedEntityIdentifier;
import org.flyte.api.v1.Node;
import org.flyte.api.v1.PluginTask;
import org.flyte.api.v1.PluginTaskRegistrar;
import org.flyte.api.v1.RunnableTask;
import org.flyte.api.v1.RunnableTaskRegistrar;
import org.flyte.api.v1.Struct;
import org.flyte.api.v1.TaskIdentifier;
import org.flyte.api.v1.TaskTemplate;
import org.flyte.api.v1.WorkflowIdentifier;
import org.flyte.api.v1.WorkflowTemplate;
import org.flyte.api.v1.WorkflowTemplateRegistrar;
import org.flyte.jflyte.api.FileSystem;
import org.flyte.jflyte.api.TokenSource;
import org.flyte.jflyte.utils.ClassLoaders;
import org.flyte.jflyte.utils.Config;
import org.flyte.jflyte.utils.ExecutionConfig;
import org.flyte.jflyte.utils.FileSystemLoader;
import org.flyte.jflyte.utils.FlyteAdminClient;
import org.flyte.jflyte.utils.IdentifierRewrite;
import org.flyte.jflyte.utils.JFlyteCustom;
import org.flyte.jflyte.utils.MoreCollectors;
import org.flyte.jflyte.utils.PackageLoader;
import org.flyte.jflyte.utils.ProjectClosure;
import org.flyte.jflyte.utils.ProtoReader;
import org.flyte.jflyte.utils.ProtoUtil;
import org.flyte.jflyte.utils.ProtoWriter;
import org.flyte.jflyte.utils.Registrars;
import org.flyte.jflyte.utils.WorkflowNodeVisitor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import picocli.CommandLine;

@CommandLine.Command(name = "execute-dynamic-workflow")
/* loaded from: input_file:org/flyte/jflyte/ExecuteDynamicWorkflow.class */
public class ExecuteDynamicWorkflow implements Callable<Integer> {
    private static final Logger LOG = LoggerFactory.getLogger(ExecuteDynamicWorkflow.class);
    private static final int LOAD_PARALLELISM = 32;

    @CommandLine.Option(names = {"--task"}, required = true)
    private String task;

    @CommandLine.Option(names = {"--inputs"}, required = true)
    private String inputs;

    @CommandLine.Option(names = {"--outputPrefix"}, required = true)
    private String outputPrefix;

    @CommandLine.Option(names = {"--taskTemplatePath"}, required = true)
    private String taskTemplatePath;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.flyte.jflyte.ExecuteDynamicWorkflow$1, reason: invalid class name */
    /* loaded from: input_file:org/flyte/jflyte/ExecuteDynamicWorkflow$1.class */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$flyte$api$v1$BindingData$Kind = new int[BindingData.Kind.values().length];

        static {
            try {
                $SwitchMap$org$flyte$api$v1$BindingData$Kind[BindingData.Kind.SCALAR.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$flyte$api$v1$BindingData$Kind[BindingData.Kind.COLLECTION.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$flyte$api$v1$BindingData$Kind[BindingData.Kind.PROMISE.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$flyte$api$v1$BindingData$Kind[BindingData.Kind.MAP.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
        }
    }

    /* JADX WARN: Can't rename method to resolve collision */
    @Override // java.util.concurrent.Callable
    public Integer call() {
        execute();
        return 0;
    }

    private void execute() {
        Config load = Config.load();
        ExecutionConfig load2 = ExecutionConfig.load();
        Map loadFileSystems = FileSystemLoader.loadFileSystems(ClassLoaders.forModuleDir(load.moduleDir()).values());
        ProtoWriter protoWriter = new ProtoWriter(this.outputPrefix, FileSystemLoader.getFileSystem(loadFileSystems, this.outputPrefix));
        try {
            ProtoReader protoReader = new ProtoReader(FileSystemLoader.getFileSystem(loadFileSystems, this.inputs));
            TaskTemplate taskTemplate = protoReader.getTaskTemplate(this.taskTemplatePath);
            ForkJoinPool forkJoinPool = new ForkJoinPool(LOAD_PARALLELISM);
            try {
                ClassLoader load3 = PackageLoader.load(loadFileSystems, taskTemplate, forkJoinPool);
                forkJoinPool.shutdownNow();
                Map<String, String> env = getEnv();
                Map map = (Map) ClassLoaders.withClassLoader(load3, () -> {
                    return Registrars.loadAll(WorkflowTemplateRegistrar.class, env);
                });
                Map map2 = (Map) ClassLoaders.withClassLoader(load3, () -> {
                    return Registrars.loadAll(RunnableTaskRegistrar.class, env);
                });
                Map map3 = (Map) ClassLoaders.withClassLoader(load3, () -> {
                    return Registrars.loadAll(DynamicWorkflowTaskRegistrar.class, env);
                });
                Map map4 = (Map) ClassLoaders.withClassLoader(load3, () -> {
                    return Registrars.loadAll(ContainerTaskRegistrar.class, env);
                });
                Map map5 = (Map) ClassLoaders.withClassLoader(load3, () -> {
                    return Registrars.loadAll(PluginTaskRegistrar.class, env);
                });
                Struct serializeToStruct = JFlyteCustom.deserializeFromStruct(taskTemplate.custom()).serializeToStruct();
                DynamicJobSpec rewrite = rewrite(load, load2, (DynamicJobSpec) ClassLoaders.withClassLoader(load3, () -> {
                    return getDynamicWorkflowTask(this.task).run(protoReader.getInput(this.inputs));
                }), MoreCollectors.mapValues(ProjectClosure.createTaskTemplates(load2, map2, map3, map4, map5), taskTemplate2 -> {
                    return taskTemplate2.toBuilder().custom(ProjectClosure.merge(taskTemplate2.custom(), serializeToStruct)).build();
                }), map);
                if (rewrite.nodes().isEmpty()) {
                    protoWriter.writeOutputs(getLiteralMap(rewrite.outputs()));
                } else {
                    protoWriter.writeFutures(rewrite);
                }
            } catch (Throwable th) {
                forkJoinPool.shutdownNow();
                throw th;
            }
        } catch (ContainerError e) {
            LOG.error("failed to run dynamic workflow", e);
            protoWriter.writeError(ProtoUtil.serializeContainerError(e));
        } catch (Throwable th2) {
            LOG.error("failed to run dynamic workflow", th2);
            protoWriter.writeError(ProtoUtil.serializeThrowable(th2));
        }
    }

    private static DynamicJobSpec rewrite(Config config, ExecutionConfig executionConfig, DynamicJobSpec dynamicJobSpec, Map<TaskIdentifier, TaskTemplate> map, Map<WorkflowIdentifier, WorkflowTemplate> map2) {
        FlyteAdminClient create = FlyteAdminClient.create(config.platformUrl(), config.platformInsecure(), (TokenSource) null);
        try {
            IdentifierRewrite.Visitor visitor = IdentifierRewrite.builder().domain(executionConfig.domain()).project(executionConfig.project()).version(executionConfig.version()).adminClient(create).build().visitor();
            Function function = list -> {
                Stream stream = list.stream();
                Objects.requireNonNull(visitor);
                return (List) stream.map(visitor::visitNode).collect(MoreCollectors.toUnmodifiableList());
            };
            Map<WorkflowIdentifier, WorkflowTemplate> collectAllUsedSubWorkflows = collectAllUsedSubWorkflows(dynamicJobSpec.nodes(), map2, visitor, function);
            HashMap hashMap = new HashMap();
            DynamicJobSpec build = dynamicJobSpec.toBuilder().nodes(collectAllUsedTaskTemplates(dynamicJobSpec, map, function, hashMap, create, collectAllUsedSubWorkflows)).subWorkflows(ImmutableMap.builder().putAll(dynamicJobSpec.subWorkflows()).putAll(collectAllUsedSubWorkflows).build()).tasks(ImmutableMap.builder().putAll(dynamicJobSpec.tasks()).putAll(hashMap).build()).build();
            if (create != null) {
                create.close();
            }
            return build;
        } catch (Throwable th) {
            if (create != null) {
                try {
                    create.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    private static List<Node> collectAllUsedTaskTemplates(DynamicJobSpec dynamicJobSpec, Map<TaskIdentifier, TaskTemplate> map, Function<List<Node>, List<Node>> function, Map<TaskIdentifier, TaskTemplate> map2, FlyteAdminClient flyteAdminClient, Map<WorkflowIdentifier, WorkflowTemplate> map3) {
        HashMap hashMap = new HashMap();
        List<Node> collectTaskTemplates = collectTaskTemplates(dynamicJobSpec.nodes(), function, map2, map, flyteAdminClient, hashMap);
        map3.values().forEach(workflowTemplate -> {
            collectTaskTemplates(workflowTemplate.nodes(), function, map2, map, flyteAdminClient, hashMap);
        });
        return collectTaskTemplates;
    }

    private static Map<WorkflowIdentifier, WorkflowTemplate> collectAllUsedSubWorkflows(List<Node> list, Map<WorkflowIdentifier, WorkflowTemplate> map, WorkflowNodeVisitor workflowNodeVisitor, Function<List<Node>, List<Node>> function) {
        Map collectSubWorkflows = ProjectClosure.collectSubWorkflows(list, map, function);
        Objects.requireNonNull(workflowNodeVisitor);
        return MoreCollectors.mapValues(collectSubWorkflows, workflowNodeVisitor::visitWorkflowTemplate);
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static List<Node> collectTaskTemplates(List<Node> list, Function<List<Node>, List<Node>> function, Map<TaskIdentifier, TaskTemplate> map, Map<TaskIdentifier, TaskTemplate> map2, FlyteAdminClient flyteAdminClient, Map<TaskIdentifier, TaskTemplate> map3) {
        List<Node> apply = function.apply(list);
        map.putAll(ProjectClosure.collectDynamicWorkflowTasks(apply, map2, taskIdentifier -> {
            return fetchTaskTemplate(flyteAdminClient, taskIdentifier, map3);
        }));
        return apply;
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static TaskTemplate fetchTaskTemplate(FlyteAdminClient flyteAdminClient, TaskIdentifier taskIdentifier, Map<TaskIdentifier, TaskTemplate> map) {
        return map.computeIfAbsent(taskIdentifier, taskIdentifier2 -> {
            LOG.info("fetching task template remotely for {}", taskIdentifier);
            return flyteAdminClient.fetchLatestTaskTemplate(NamedEntityIdentifier.builder().domain(taskIdentifier.domain()).project(taskIdentifier.project()).name(taskIdentifier.name()).build());
        });
    }

    private static DynamicWorkflowTask getDynamicWorkflowTask(String str) {
        for (Map.Entry entry : Registrars.loadAll(DynamicWorkflowTaskRegistrar.class, getEnv()).entrySet()) {
            if (((TaskIdentifier) entry.getKey()).name().equals(str)) {
                return (DynamicWorkflowTask) entry.getValue();
            }
        }
        throw new IllegalArgumentException("Dynamic workflow task not found: " + str);
    }

    private static Map<String, String> getEnv() {
        return (Map) System.getenv().entrySet().stream().filter(entry -> {
            return ((String) entry.getKey()).startsWith("JFLYTE_") || ((String) entry.getKey()).startsWith("FLYTE_");
        }).collect(Collectors.toMap((v0) -> {
            return v0.getKey();
        }, (v0) -> {
            return v0.getValue();
        }));
    }

    static Map<String, Literal> getLiteralMap(List<Binding> list) {
        return (Map) list.stream().collect(Collectors.toMap((v0) -> {
            return v0.var_();
        }, binding -> {
            return getLiteral(binding.binding());
        }));
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static Literal getLiteral(BindingData bindingData) {
        switch (AnonymousClass1.$SwitchMap$org$flyte$api$v1$BindingData$Kind[bindingData.kind().ordinal()]) {
            case 1:
                return Literal.ofScalar(bindingData.scalar());
            case 2:
                return Literal.ofCollection((List) bindingData.collection().stream().map(ExecuteDynamicWorkflow::getLiteral).collect(Collectors.toList()));
            case 3:
                throw new IllegalArgumentException("invariant failed, workflows without nodes can't have promises");
            case 4:
                return Literal.ofMap((Map) bindingData.map().entrySet().stream().map(entry -> {
                    return Maps.immutableEntry((String) entry.getKey(), getLiteral((BindingData) entry.getValue()));
                }).collect(Collectors.toMap((v0) -> {
                    return v0.getKey();
                }, (v0) -> {
                    return v0.getValue();
                })));
            default:
                throw new AssertionError("Unexpected BindingData.Kind: " + bindingData.kind());
        }
    }
}
