package com.yahoo.vespa.indexinglanguage.expressions;

import com.yahoo.document.ArrayDataType;
import com.yahoo.document.DataType;
import com.yahoo.document.DocumentType;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.TensorType;

/**
 * Indexing-language expression that converts the current string value into a tensor using an {@link Embedder}.
 * <p>
 * The tensor type to embed into is derived from the statement's output field. That happens either when the
 * statement output is set ({@link #setStatementOutput}) or during verification ({@link #doVerify}). If the output
 * field is an array type, the element type is used.
 */
public class EmbedExpression extends Expression {

    private final Embedder embedder;

    /** "documentTypeName.fieldName" of the output field, handed to the embedder as context; null until set. */
    private String destination;

    /** Tensor type the embedder is asked to produce; null until the output field has been resolved. */
    private TensorType targetType;

    /**
     * Creates an embed expression backed by the given embedder.
     *
     * @param embedder the embedder used to turn text into tensors
     */
    public EmbedExpression(Embedder embedder) {
        // NOTE(review): DataType.STRING is presumably this expression's required input type — confirm in Expression.
        super(DataType.STRING);
        this.embedder = embedder;
    }

    /**
     * Records the output field of the enclosing statement: resolves the target tensor type from the field's data
     * type and remembers "documentTypeName.fieldName" as the embedding destination.
     *
     * @throws IllegalArgumentException if the field is not a tensor field (or an array of tensors)
     */
    @Override
    public void setStatementOutput(DocumentType documentType, Field field) {
        // Resolve the type first so an invalid field leaves destination untouched, as before.
        targetType = toTargetTensor(field.getDataType());
        destination = documentType.getName() + "." + field.getName();
    }

    /**
     * Embeds the current value (read as a string) into {@link #targetType}, using the context's language, and
     * replaces the current value with the resulting tensor.
     */
    @Override
    protected void doExecute(ExecutionContext context) {
        String text = context.getValue().getString();
        Embedder.Context embedderContext = new Embedder.Context(destination).setLanguage(context.getLanguage());
        context.setValue(new TensorFieldValue(embedder.embed(text, embedderContext, targetType)));
    }

    /**
     * Resolves the target tensor type from the statement's output field and publishes the created tensor type as
     * the value type.
     *
     * @throws VerificationException if the statement has no output field
     * @throws IllegalArgumentException if the output field is not a tensor field (or an array of tensors)
     */
    @Override
    protected void doVerify(VerificationContext context) {
        String outputField = context.getOutputField();
        if (outputField == null) {
            throw new VerificationException(this, "No output field in this statement: Don't know what tensor type to embed into.");
        }
        targetType = toTargetTensor(context.getInputType(this, outputField));
        context.setValueType(createdOutputType());
    }

    /**
     * Returns the tensor data type this expression produces.
     * <p>
     * NOTE(review): for an array-of-tensor output field this still returns a single (non-array) tensor type, since
     * {@link #toTargetTensor} unwraps arrays — confirm this is intended.
     */
    @Override
    public DataType createdOutputType() {
        return new TensorDataType(targetType);
    }

    /**
     * Returns the tensor type of {@code dataType}, unwrapping (possibly nested) array types to their element type.
     *
     * @throws IllegalArgumentException if no tensor type is found
     */
    private static TensorType toTargetTensor(DataType dataType) {
        if (dataType instanceof ArrayDataType) {
            return toTargetTensor(((ArrayDataType) dataType).getNestedType());
        }
        if (dataType instanceof TensorDataType) {
            return ((TensorDataType) dataType).getTensorType();
        }
        throw new IllegalArgumentException("Expected a tensor data type but got " + dataType);
    }

    @Override
    public String toString() {
        return "embed";
    }

    /** Constant hash, consistent with {@link #equals}, under which all embed expressions are equal. */
    @Override
    public int hashCode() {
        return 1;
    }

    /**
     * Returns true for any {@code EmbedExpression}.
     * <p>
     * NOTE(review): equality ignores the embedder and the resolved target type — confirm this is intended.
     */
    @Override
    public boolean equals(Object obj) {
        return obj instanceof EmbedExpression;
    }
}
