/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.shaded.io.airlift.compress.zstd;

import org.apache.flink.shaded.io.airlift.compress.zstd.BitOutputStream;
import org.apache.flink.shaded.io.airlift.compress.zstd.Util;

/**
 * Finite State Entropy (FSE) encoding table.
 *
 * <p>The table is built from a normalized symbol histogram by {@link #initialize} (or set up for a
 * single repeated symbol by {@link #initializeRleTable}). It is then driven by {@link #begin},
 * {@link #encode} and {@link #finish}.
 *
 * <p>Symbols are stored in a {@code byte[]} while being spread over the table. Symbol values are
 * therefore limited to {@code 0..255}; they are always read back as unsigned values.
 */
class FseCompressionTable {
    /** Largest symbol value that the byte-based spread table can represent. */
    private static final int MAX_SYMBOL_VALUE = 255;

    // Next encoder state, indexed by per-symbol cumulative slot; entries hold tableSize + position.
    private final short[] nextState;
    // Per symbol: (bitCount << 16) - threshold, used to derive how many state bits to emit.
    private final int[] deltaNumberOfBits;
    // Per symbol: offset added to (state >>> outputBits) to index nextState.
    private final int[] deltaFindState;
    // log2 of the current table size; also the bit count written for the final state.
    private int log2Size;

    /**
     * Allocates buffers for tables of up to {@code 1 << maxTableLog} states over the symbols
     * {@code 0..maxSymbol}. The table must be initialized before use.
     *
     * @param maxTableLog largest table log this instance will be initialized with
     * @param maxSymbol largest symbol value this instance will be initialized with
     */
    public FseCompressionTable(int maxTableLog, int maxSymbol) {
        this.nextState = new short[1 << maxTableLog];
        this.deltaNumberOfBits = new int[maxSymbol + 1];
        this.deltaFindState = new int[maxSymbol + 1];
    }

    /**
     * Creates a table sized exactly for {@code tableLog}/{@code maxSymbol} and initializes it.
     *
     * @param normalizedCounts normalized count per symbol ({@code -1} marks a low-probability symbol)
     * @param maxSymbol largest symbol value present in {@code normalizedCounts}
     * @param tableLog log2 of the table size
     * @return the initialized table
     */
    public static FseCompressionTable newInstance(short[] normalizedCounts, int maxSymbol, int tableLog) {
        FseCompressionTable result = new FseCompressionTable(tableLog, maxSymbol);
        result.initialize(normalizedCounts, maxSymbol, tableLog);
        return result;
    }

    /**
     * Configures the table for a single repeated symbol.
     *
     * <p>Sets {@code log2Size} to 0, clears the first two {@code nextState} entries and zeroes the
     * transform of {@code symbol}. Writing {@code nextState[1]} requires an instance constructed
     * with {@code maxTableLog >= 1}.
     *
     * @param symbol the only symbol that will be encoded
     */
    public void initializeRleTable(int symbol) {
        log2Size = 0;
        nextState[0] = 0;
        nextState[1] = 0;
        deltaFindState[symbol] = 0;
        deltaNumberOfBits[symbol] = 0;
    }

    /**
     * Builds the encoding table from a normalized histogram.
     *
     * @param normalizedCounts normalized count per symbol; {@code -1} marks a low-probability symbol
     *     that occupies a single slot at the top of the table
     * @param maxSymbol largest symbol value to include; must be at most 255
     * @param tableLog log2 of the table size; {@code 1 << tableLog} must fit this instance's buffers
     * @throws IllegalArgumentException if {@code maxSymbol} exceeds 255
     * @throws AssertionError if the counts do not spread evenly over the table
     */
    public void initialize(short[] normalizedCounts, int maxSymbol, int tableLog) {
        if (maxSymbol > MAX_SYMBOL_VALUE) {
            // Symbols are stored as bytes in the spread table; larger values would be truncated.
            throw new IllegalArgumentException("maxSymbol must be <= " + MAX_SYMBOL_VALUE + ": " + maxSymbol);
        }

        int tableSize = 1 << tableLog;
        byte[] table = new byte[tableSize];
        int highThreshold = tableSize - 1;
        log2Size = tableLog;

        // cumulative[s] is the first nextState slot owned by symbol s.
        // Low-probability symbols (count -1) take one slot and are placed at the top of the table.
        int[] cumulative = new int[maxSymbol + 2];
        for (int symbol = 0; symbol <= maxSymbol; symbol++) {
            short count = normalizedCounts[symbol];
            if (count == -1) {
                cumulative[symbol + 1] = cumulative[symbol] + 1;
                table[highThreshold--] = (byte) symbol;
            } else {
                cumulative[symbol + 1] = cumulative[symbol] + count;
            }
        }
        cumulative[maxSymbol + 1] = tableSize + 1;

        int position = spreadSymbols(normalizedCounts, maxSymbol, tableSize, highThreshold, table);
        if (position != 0) {
            throw new AssertionError("Spread symbols failed");
        }

        // Assign each table position to the next free slot of its symbol.
        // Bytes are read back unsigned so that symbols 128..255 index correctly.
        for (int i = 0; i < tableSize; i++) {
            int symbol = table[i] & 0xFF;
            nextState[cumulative[symbol]++] = (short) (tableSize + i);
        }

        // Per-symbol transforms used by begin()/encode().
        int total = 0;
        for (int symbol = 0; symbol <= maxSymbol; symbol++) {
            int count = normalizedCounts[symbol];
            if (count == 0) {
                // No slots for this symbol: only the bit delta is set; deltaFindState is left as is.
                deltaNumberOfBits[symbol] = ((tableLog + 1) << 16) - tableSize;
            } else if (count == -1 || count == 1) {
                deltaNumberOfBits[symbol] = (tableLog << 16) - tableSize;
                deltaFindState[symbol] = total - 1;
                total++;
            } else {
                int maxBitsOut = tableLog - Util.highestBit(count - 1);
                int minStatePlus = count << maxBitsOut;
                deltaNumberOfBits[symbol] = (maxBitsOut << 16) - minStatePlus;
                deltaFindState[symbol] = total - count;
                total += count;
            }
        }
    }

    /**
     * Computes the initial encoder state for the first symbol to be encoded.
     *
     * @param symbol the symbol, interpreted as an unsigned byte
     * @return the initial state
     */
    public int begin(byte symbol) {
        int index = symbol & 0xFF;
        int outputBits = (deltaNumberOfBits[index] + (1 << 15)) >>> 16;
        int base = ((outputBits << 16) - deltaNumberOfBits[index]) >>> outputBits;
        return nextState[base + deltaFindState[index]];
    }

    /**
     * Encodes {@code symbol} from {@code state}.
     *
     * <p>Computes the bit count from the symbol's transform and passes {@code state} with that count
     * to {@code stream.addBits}.
     *
     * @param stream destination bit stream
     * @param state current encoder state
     * @param symbol symbol to encode
     * @return the next encoder state
     */
    public int encode(BitOutputStream stream, int state, int symbol) {
        int outputBits = (state + deltaNumberOfBits[symbol]) >>> 16;
        stream.addBits(state, outputBits);
        return nextState[(state >>> outputBits) + deltaFindState[symbol]];
    }

    /**
     * Writes the final state using {@code log2Size} bits and flushes the stream.
     *
     * @param stream destination bit stream
     * @param state final encoder state
     */
    public void finish(BitOutputStream stream, int state) {
        stream.addBits(state, log2Size);
        stream.flush();
    }

    /** Stride used to scatter symbols over the table: size/2 + size/8 + 3. */
    private static int calculateStep(int tableSize) {
        return (tableSize >>> 1) + (tableSize >>> 3) + 3;
    }

    /**
     * Scatters each symbol {@code normalizedCounters[symbol]} times over {@code symbols[0..highThreshold]}.
     *
     * <p>Each symbol is stored as {@code (byte) symbol}, so values above 255 would be truncated.
     *
     * @param normalizedCounters normalized count per symbol (non-positive counts place nothing)
     * @param maxSymbolValue largest symbol to spread
     * @param tableSize table size, a power of two
     * @param highThreshold highest slot available; slots above it are reserved
     * @param symbols destination table of length {@code tableSize}
     * @return the final position, which is 0 when the counts exactly fill the available slots
     */
    public static int spreadSymbols(short[] normalizedCounters, int maxSymbolValue, int tableSize, int highThreshold, byte[] symbols) {
        int mask = tableSize - 1;
        int step = calculateStep(tableSize);
        int position = 0;
        for (int symbol = 0; symbol <= maxSymbolValue; symbol++) {
            for (int i = 0; i < normalizedCounters[symbol]; i++) {
                symbols[position] = (byte) symbol;
                // Advance, skipping slots above highThreshold (reserved for low-probability symbols).
                do {
                    position = (position + step) & mask;
                } while (position > highThreshold);
            }
        }
        return position;
    }
}

