/*
 * Decompiled with CFR 0.152.
 */
package chat.octet.model;

import chat.octet.model.LlamaService;
import chat.octet.model.beans.Token;
import com.google.common.base.Preconditions;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang3.StringUtils;

/**
 * Static helpers for turning token ids back into text, inspecting UTF-8 lead bytes, and slicing
 * token lists between marker words.
 */
public final class TokenDecoder {

    /** Size of the per-token scratch buffer handed to {@link LlamaService#tokenToPiece}. */
    private static final int PIECE_BUFFER_SIZE = 64;

    private TokenDecoder() {
    }

    /**
     * Decodes a sequence of token ids into a string.
     *
     * <p>Each token is rendered to its raw byte piece via {@link LlamaService#tokenToPiece}. The
     * pieces are concatenated and decoded as UTF-8 in a single pass, so a multi-byte character
     * whose bytes are spread over several tokens is reassembled. Malformed or incomplete
     * sequences are replaced with the charset's replacement character by the {@link String}
     * constructor.
     *
     * @param special forwarded unchanged to {@code tokenToPiece}; presumably selects whether
     *                special/control tokens are rendered (NOTE(review): confirm in LlamaService)
     * @param tokens  token ids to decode
     * @return the decoded text; empty when {@code tokens} is empty
     * @throws IllegalStateException if {@code tokenToPiece} still reports a negative size after
     *                               retrying with the size it asked for
     */
    public static String decodeToken(boolean special, int... tokens) {
        byte[] buffer = new byte[tokens.length * PIECE_BUFFER_SIZE];
        byte[] piece = new byte[PIECE_BUFFER_SIZE];
        int length = 0;
        for (int token : tokens) {
            byte[] target = piece;
            int size = LlamaService.tokenToPiece(token, target, target.length, special);
            if (size < 0) {
                // NOTE(review): llama.cpp's llama_token_to_piece reports a too-small buffer as the
                // negated required size; retry once with that size. Confirm LlamaService keeps
                // this convention. A negative size must never reach System.arraycopy.
                target = new byte[-size];
                size = LlamaService.tokenToPiece(token, target, target.length, special);
                if (size < 0) {
                    throw new IllegalStateException(
                            "Failed to decode token " + token + ", tokenToPiece returned " + size);
                }
            }
            if (length + size > buffer.length) {
                // Only reachable when a piece was larger than PIECE_BUFFER_SIZE.
                buffer = Arrays.copyOf(buffer, Math.max(buffer.length * 2, length + size));
            }
            System.arraycopy(target, 0, buffer, length, size);
            length += size;
        }
        return new String(buffer, 0, length, StandardCharsets.UTF_8);
    }

    /**
     * Decodes a sequence of token ids into a string with {@code special} set to {@code false}.
     *
     * @param tokens token ids to decode
     * @return the decoded text
     * @see #decodeToken(boolean, int...)
     */
    public static String decodeToken(int... tokens) {
        return decodeToken(false, tokens);
    }

    /**
     * Sums the UTF-8 sequence lengths announced by the first {@code length} bytes of
     * {@code buffer}.
     *
     * <p>An ASCII byte counts 1, a 2/3/4-byte lead byte counts 2/3/4, and continuation bytes and
     * bytes that can never start a sequence (0x80-0xC1, 0xF8-0xFF) count 0. The result is
     * therefore the byte length the buffer would have if every started character were complete.
     * NOTE(review): presumably callers compare it with {@code length} to detect a trailing
     * incomplete character — confirm.
     *
     * @param buffer bytes to inspect
     * @param length number of leading bytes of {@code buffer} to inspect; must not exceed
     *               {@code buffer.length}
     * @return the summed announced length
     */
    public static int getByteLength(byte[] buffer, int length) {
        int len = 0;
        for (int i = 0; i < length; ++i) {
            len += announcedLength(buffer[i]);
        }
        return len;
    }

    /**
     * Returns the length of the UTF-8 sequence that {@code bytes} starts.
     *
     * @param bytes a single byte expected to be ASCII or a UTF-8 lead byte
     * @return 1, 2, 3 or 4
     * @throws IllegalArgumentException if the byte is a continuation byte or can never start a
     *                                  UTF-8 sequence
     */
    public static int getUtf8ByteLength(byte bytes) {
        int length = announcedLength(bytes);
        if (length == 0) {
            throw new IllegalArgumentException("Illegal byte, byte code is " + bytes);
        }
        return length;
    }

    /** Classifies one byte: 1 for ASCII, 2/3/4 for lead bytes, 0 for anything else. */
    private static int announcedLength(byte b) {
        int unsigned = b & 0xFF;
        if (unsigned <= 0x7F) {
            return 1;
        }
        if (unsigned >= 0xC2 && unsigned <= 0xDF) {
            return 2;
        }
        if (unsigned >= 0xE0 && unsigned <= 0xEF) {
            return 3;
        }
        if (unsigned >= 0xF0 && unsigned <= 0xF7) {
            return 4;
        }
        return 0;
    }

    /**
     * Finds the first occurrence of {@code target} as a contiguous run of token ids in
     * {@code sources}, with the match starting no earlier than {@code fromIndex}.
     *
     * @param sources   tokens to search
     * @param target    token ids to look for
     * @param fromIndex earliest position at which a match may start
     * @param toEnd     if {@code true} return the index just past the match, otherwise the
     *                  index of its first token
     * @return the requested index, or -1 when {@code target} is empty or not found
     */
    private static int findTokenIndex(List<Token> sources, int[] target, int fromIndex, boolean toEnd) {
        if (target.length == 0) {
            return -1;
        }
        // Last start position at which target still fits entirely inside sources.
        int lastStart = sources.size() - target.length;
        for (int i = Math.max(fromIndex, 0); i <= lastStart; ++i) {
            if (matchesAt(sources, target, i)) {
                return toEnd ? i + target.length : i;
            }
        }
        return -1;
    }

    /** Returns whether the ids of {@code sources} starting at {@code offset} equal {@code target}. */
    private static boolean matchesAt(List<Token> sources, int[] target, int offset) {
        for (int j = 0; j < target.length; ++j) {
            if (sources.get(offset + j).getId() != target[j]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the tokens after the first occurrence of {@code startWord}.
     *
     * @param tokens    tokens to slice
     * @param startWord marker word; when blank or not found, slicing starts at index 0
     * @return a {@link List#subList} view backed by {@code tokens}
     * @throws NullPointerException if {@code tokens} is null
     */
    public static List<Token> subTokensBetween(List<Token> tokens, String startWord) {
        return subTokensBetween(tokens, startWord, null);
    }

    /**
     * Returns the tokens strictly between the first occurrence of {@code startWord} and the
     * first occurrence of {@code endWord} that follows it.
     *
     * <p>Both words are tokenized with {@link LlamaService#tokenize} and matched by token id.
     * The end word is searched only after the start match, so an earlier occurrence of the end
     * word cannot produce an inverted range.
     *
     * @param tokens    tokens to slice
     * @param startWord start marker; when blank or not found, slicing starts at index 0
     * @param endWord   end marker; when blank or not found after the start, slicing runs to the end
     * @return a {@link List#subList} view backed by {@code tokens}
     * @throws NullPointerException if {@code tokens} is null
     */
    public static List<Token> subTokensBetween(List<Token> tokens, String startWord, String endWord) {
        Preconditions.checkNotNull(tokens, "Tokens cannot be null");
        int startIndex = 0;
        if (StringUtils.isNotBlank(startWord)) {
            int[] startIds = LlamaService.tokenize(startWord, false, true);
            int index = findTokenIndex(tokens, startIds, 0, true);
            startIndex = index == -1 ? 0 : index;
        }
        int endIndex = tokens.size();
        if (StringUtils.isNotBlank(endWord)) {
            int[] endIds = LlamaService.tokenize(endWord, false, true);
            int index = findTokenIndex(tokens, endIds, startIndex, false);
            endIndex = index == -1 ? tokens.size() : index;
        }
        return tokens.subList(startIndex, endIndex);
    }
}

