/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.quantization.quantizer;

import java.io.IOException;
import java.util.Arrays;
import lombok.Generated;
import lombok.NonNull;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import org.opensearch.knn.quantization.quantizer.RandomGaussianRotation;
import oshi.util.tuples.Pair;

/**
 * Static helpers used to train scalar quantizers.
 *
 * <p>Per-dimension mean and (population) standard deviation are computed over a sample of training
 * vectors, optionally after applying a random rotation, and quantization thresholds are derived from
 * those statistics. For one-bit quantization the per-dimension means of the values falling below and
 * above the threshold are computed as well.
 *
 * <p>Not instantiable; all members are static.
 */
final class QuantizerHelper {
    /** Number of bits per coordinate used by one-bit scalar quantization. */
    private static final int ONE_BIT_NUMBER_OF_BITS_PER_COORDINATE = 1;

    /**
     * Trains a one-bit scalar quantization state from the sampled training vectors.
     *
     * @param trainingRequest source of training vectors
     * @param sampledIndices positions of the vectors to sample; must be non-null and non-empty
     * @param quantizationParams parameters stored verbatim in the resulting state
     * @return a state holding the single threshold row (the per-dimension mean), the rotation matrix
     *     (null when random rotation is disabled) and the below/above threshold means
     * @throws IOException if reading a vector fails
     * @throws IllegalArgumentException if {@code sampledIndices} is null/empty, a sampled vector is
     *     null, or the sampled vectors do not all share the same dimension
     */
    static OneBitScalarQuantizationState calculateQuantizationState(
        TrainingRequest<float[]> trainingRequest,
        int[] sampledIndices,
        ScalarQuantizationParams quantizationParams
    ) throws IOException {
        QuantizerHelperResult result = calculateQuantizationStateHelper(
            trainingRequest,
            sampledIndices,
            ONE_BIT_NUMBER_OF_BITS_PER_COORDINATE
        );
        return OneBitScalarQuantizationState.builder()
            .quantizationParams(quantizationParams)
            .meanThresholds(result.getThresholds()[0])
            .rotationMatrix(result.getRotationMatrix())
            .belowThresholdMeans(result.getBelow())
            .aboveThresholdMeans(result.getAbove())
            .build();
    }

    /**
     * Trains a multi-bit scalar quantization state from the sampled training vectors.
     *
     * @param trainingRequest source of training vectors
     * @param sampledIndices positions of the vectors to sample; must be non-null and non-empty
     * @param quantizationParams parameters stored verbatim in the resulting state
     * @param bitsPerCoordinate number of threshold rows to generate; must be at least 1
     * @return a state holding {@code bitsPerCoordinate} threshold rows and the rotation matrix
     *     (null when random rotation is disabled)
     * @throws IOException if reading a vector fails
     * @throws IllegalArgumentException on invalid sample indices, null or dimension-inconsistent
     *     vectors, or {@code bitsPerCoordinate < 1}
     */
    static MultiBitScalarQuantizationState calculateQuantizationState(
        TrainingRequest<float[]> trainingRequest,
        int[] sampledIndices,
        ScalarQuantizationParams quantizationParams,
        int bitsPerCoordinate
    ) throws IOException {
        QuantizerHelperResult result = calculateQuantizationStateHelper(trainingRequest, sampledIndices, bitsPerCoordinate);
        return MultiBitScalarQuantizationState.builder()
            .quantizationParams(quantizationParams)
            .thresholds(result.getThresholds())
            .rotationMatrix(result.getRotationMatrix())
            .build();
    }

    /**
     * Ensures there is at least one sampled index to train on.
     *
     * @throws IllegalArgumentException if {@code sampledIndices} is null or empty
     */
    private static void validateSampledIndices(int[] sampledIndices) {
        if (sampledIndices == null || sampledIndices.length == 0) {
            throw new IllegalArgumentException("Sampled indices cannot be null or empty.");
        }
    }

    /**
     * Derives {@code bitsPerCoordinate} rows of per-dimension thresholds from mean and standard deviation.
     *
     * <p>Row {@code b} is {@code mean + c_b * stdDev} with {@code c_b = -1 + 2(b + 1) / (bitsPerCoordinate + 1)},
     * i.e. offsets evenly spaced strictly inside {@code (-1, 1)} standard deviations. With one bit the
     * single offset is 0, so the threshold equals the mean.
     *
     * @param mean per-dimension mean
     * @param stdDev per-dimension standard deviation; must have the same length as {@code mean}
     * @param bitsPerCoordinate number of threshold rows; must be at least 1
     * @return a {@code [bitsPerCoordinate][mean.length]} array of thresholds
     * @throws IllegalArgumentException if {@code bitsPerCoordinate < 1} or the array lengths differ
     */
    protected static float[][] calculateThresholds(float[] mean, float[] stdDev, int bitsPerCoordinate) {
        if (bitsPerCoordinate < 1) {
            throw new IllegalArgumentException("bitsPerCoordinate must be at least 1, but was " + bitsPerCoordinate + ".");
        }
        if (mean.length != stdDev.length) {
            throw new IllegalArgumentException(
                "Mean and standard deviation must have the same length, but were " + mean.length + " and " + stdDev.length + "."
            );
        }
        int dim = mean.length;
        float[][] thresholds = new float[bitsPerCoordinate][dim];
        float coef = bitsPerCoordinate + 1;
        for (int b = 0; b < bitsPerCoordinate; b++) {
            float iCoef = -1.0f + (float) (2 * (b + 1)) / coef;
            for (int d = 0; d < dim; d++) {
                thresholds[b][d] = mean[d] + iCoef * stdDev[d];
            }
        }
        return thresholds;
    }

    /**
     * Shared training pipeline: optional rotation, statistics, thresholds and (one-bit only) the
     * below/above threshold means.
     *
     * @param trainingRequest source of training vectors
     * @param sampledIndices positions of the vectors to sample; must be non-null and non-empty
     * @param bitsPerCoordinate number of threshold rows to generate
     * @return the intermediate training result; {@code below}/{@code above} are null unless
     *     {@code bitsPerCoordinate} is one
     * @throws IOException if reading a vector fails
     * @throws IllegalArgumentException on invalid sample indices or a null first vector
     */
    private static QuantizerHelperResult calculateQuantizationStateHelper(
        TrainingRequest<float[]> trainingRequest,
        int[] sampledIndices,
        int bitsPerCoordinate
    ) throws IOException {
        validateSampledIndices(sampledIndices);
        // The first sampled vector only determines the dimension of the rotation matrix.
        float[] firstVector = trainingRequest.getVectorAtThePosition(sampledIndices[0]);
        if (firstVector == null) {
            throw new IllegalArgumentException("Vector at sampled index " + sampledIndices[0] + " is null.");
        }
        float[][] rotationMatrix = trainingRequest.isEnableRandomRotation()
            ? RandomGaussianRotation.generateRotationMatrix(firstVector.length)
            : null;

        Pair<float[], float[]> meanStd = calculateMeanAndStdDev(trainingRequest, sampledIndices, rotationMatrix);
        float[][] thresholds = calculateThresholds(meanStd.getA(), meanStd.getB(), bitsPerCoordinate);

        QuantizerHelperResult.QuantizerHelperResultBuilder builder = QuantizerHelperResult.builder()
            .thresholds(thresholds)
            .rotationMatrix(rotationMatrix);
        if (bitsPerCoordinate == ONE_BIT_NUMBER_OF_BITS_PER_COORDINATE) {
            assert thresholds.length == 1;
            Pair<float[], float[]> belowAbove = calculateBelowAboveThresholdMeans(
                trainingRequest,
                thresholds[0],
                sampledIndices,
                rotationMatrix
            );
            builder.below(belowAbove.getA()).above(belowAbove.getB());
        }
        return builder.build();
    }

    /**
     * Computes per-dimension mean and population standard deviation of the sampled vectors, without rotation.
     *
     * @see #calculateMeanAndStdDev(TrainingRequest, int[], float[][])
     */
    public static Pair<float[], float[]> calculateMeanAndStdDev(TrainingRequest<float[]> request, int[] sampledIndices) throws IOException {
        return calculateMeanAndStdDev(request, sampledIndices, null);
    }

    /**
     * Computes per-dimension mean and population standard deviation of the sampled vectors in a single
     * pass using Welford's online algorithm.
     *
     * <p>Calls {@code request.resetVectorValues()} before reading; presumably this rewinds the
     * request's vector source — TODO confirm against {@code TrainingRequest}.
     *
     * @param request source of training vectors
     * @param sampledIndices positions of the vectors to include
     * @param rotationMatrix if non-null, applied to every vector before accumulating
     * @return a pair of (mean, standard deviation), each one value per dimension; the standard
     *     deviation divides by the sample count (population, not sample, variance)
     * @throws IOException if reading a vector fails
     * @throws IllegalArgumentException if a sampled vector is null or its dimension differs from the first one
     * @throws IllegalStateException if {@code sampledIndices} is empty
     */
    public static Pair<float[], float[]> calculateMeanAndStdDev(
        TrainingRequest<float[]> request,
        int[] sampledIndices,
        float[][] rotationMatrix
    ) throws IOException {
        float[] mean = null;
        float[] m2 = null; // running sum of squared deviations from the mean, per dimension
        int count = 0;
        request.resetVectorValues();
        for (int docId : sampledIndices) {
            float[] vector = request.getVectorAtThePosition(docId);
            if (vector == null) {
                throw new IllegalArgumentException("Vector at sampled index " + docId + " is null.");
            }
            if (rotationMatrix != null) {
                vector = RandomGaussianRotation.applyRotation(vector, rotationMatrix);
            }
            if (mean == null) {
                mean = new float[vector.length];
                m2 = new float[vector.length];
            } else if (vector.length != mean.length) {
                // Mixing dimensions would silently skew the statistics of the trailing dimensions.
                throw new IllegalArgumentException(
                    "Vector at sampled index " + docId + " has dimension " + vector.length + ", expected " + mean.length + "."
                );
            }
            count++;
            for (int i = 0; i < vector.length; i++) {
                float delta = vector[i] - mean[i];
                mean[i] += delta / count;
                float delta2 = vector[i] - mean[i];
                m2[i] += delta * delta2;
            }
        }
        if (mean == null) {
            throw new IllegalStateException("Mean array should not be null after processing vectors.");
        }
        float[] stdDev = new float[mean.length];
        for (int i = 0; i < stdDev.length; i++) {
            stdDev[i] = (float) Math.sqrt(m2[i] / count);
        }
        return new Pair<>(mean, stdDev);
    }

    /**
     * Computes, per dimension, the mean of sampled values at or below the threshold and the mean of
     * values strictly above it.
     *
     * <p>A dimension with no samples on one side keeps a mean of {@code 0} for that side. Calls
     * {@code request.resetVectorValues()} before reading.
     *
     * @param request source of training vectors
     * @param thresholds one threshold per dimension
     * @param sampledIndices positions of the vectors to include
     * @param rotationMatrix if non-null, applied to every vector before comparison
     * @return a pair of (below-threshold means, above-threshold means)
     * @throws IOException if reading a vector fails
     * @throws IllegalArgumentException if a sampled vector is null or its dimension differs from {@code thresholds.length}
     */
    protected static Pair<float[], float[]> calculateBelowAboveThresholdMeans(
        TrainingRequest<float[]> request,
        float[] thresholds,
        int[] sampledIndices,
        float[][] rotationMatrix
    ) throws IOException {
        int dim = thresholds.length;
        float[] below = new float[dim];
        float[] above = new float[dim];
        int[] belowCount = new int[dim];
        int[] aboveCount = new int[dim];
        request.resetVectorValues();
        for (int docId : sampledIndices) {
            float[] vector = request.getVectorAtThePosition(docId);
            if (vector == null) {
                throw new IllegalArgumentException("Vector at sampled index " + docId + " is null.");
            }
            if (rotationMatrix != null) {
                vector = RandomGaussianRotation.applyRotation(vector, rotationMatrix);
            }
            if (vector.length != dim) {
                throw new IllegalArgumentException(
                    "Vector at sampled index " + docId + " has dimension " + vector.length + ", expected " + dim + "."
                );
            }
            for (int d = 0; d < dim; d++) {
                if (vector[d] <= thresholds[d]) {
                    below[d] += vector[d];
                    belowCount[d]++;
                } else {
                    above[d] += vector[d];
                    aboveCount[d]++;
                }
            }
        }
        for (int d = 0; d < dim; d++) {
            if (belowCount[d] > 0) {
                below[d] /= belowCount[d];
            }
            if (aboveCount[d] > 0) {
                above[d] /= aboveCount[d];
            }
        }
        return new Pair<>(below, above);
    }

    @Generated
    private QuantizerHelper() {
        throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
    }

    /**
     * Intermediate training output.
     *
     * <p>It holds one threshold row per bit, the rotation matrix (null when rotation is disabled) and,
     * for one-bit training only, the per-dimension below/above threshold means (null otherwise).
     * Arrays are stored and returned by reference; no defensive copies are made.
     */
    public static final class QuantizerHelperResult {
        @NonNull
        private final float[][] thresholds;
        private final float[][] rotationMatrix;
        private final float[] below;
        private final float[] above;

        @Generated
        QuantizerHelperResult(@NonNull float[][] thresholds, float[][] rotationMatrix, float[] below, float[] above) {
            if (thresholds == null) {
                throw new NullPointerException("thresholds is marked non-null but is null");
            }
            this.thresholds = thresholds;
            this.rotationMatrix = rotationMatrix;
            this.below = below;
            this.above = above;
        }

        /** Returns a new, empty builder. */
        @Generated
        public static QuantizerHelperResultBuilder builder() {
            return new QuantizerHelperResultBuilder();
        }

        /** Compares all four arrays by content (deeply for the 2-D ones). */
        @Override
        @Generated
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof QuantizerHelperResult)) {
                return false;
            }
            QuantizerHelperResult other = (QuantizerHelperResult) o;
            return Arrays.deepEquals(this.getThresholds(), other.getThresholds())
                && Arrays.deepEquals(this.getRotationMatrix(), other.getRotationMatrix())
                && Arrays.equals(this.getBelow(), other.getBelow())
                && Arrays.equals(this.getAbove(), other.getAbove());
        }

        /** Content-based hash consistent with {@link #equals(Object)}. */
        @Override
        @Generated
        public int hashCode() {
            final int PRIME = 59;
            int result = 1;
            result = result * PRIME + Arrays.deepHashCode(this.getThresholds());
            result = result * PRIME + Arrays.deepHashCode(this.getRotationMatrix());
            result = result * PRIME + Arrays.hashCode(this.getBelow());
            result = result * PRIME + Arrays.hashCode(this.getAbove());
            return result;
        }

        @Override
        @Generated
        public String toString() {
            return "QuantizerHelper.QuantizerHelperResult(thresholds="
                + Arrays.deepToString(this.getThresholds())
                + ", rotationMatrix="
                + Arrays.deepToString(this.getRotationMatrix())
                + ", below="
                + Arrays.toString(this.getBelow())
                + ", above="
                + Arrays.toString(this.getAbove())
                + ")";
        }

        /** Threshold rows, one per bit; never null. */
        @NonNull
        @Generated
        public float[][] getThresholds() {
            return this.thresholds;
        }

        /** Rotation matrix applied before thresholding, or null when rotation is disabled. */
        @Generated
        public float[][] getRotationMatrix() {
            return this.rotationMatrix;
        }

        /** Per-dimension mean of values at or below the threshold (one-bit only), else null. */
        @Generated
        public float[] getBelow() {
            return this.below;
        }

        /** Per-dimension mean of values above the threshold (one-bit only), else null. */
        @Generated
        public float[] getAbove() {
            return this.above;
        }

        /** Builder for {@link QuantizerHelperResult}; {@code thresholds} must be set before {@link #build()}. */
        @Generated
        public static class QuantizerHelperResultBuilder {
            @Generated
            private float[][] thresholds;
            @Generated
            private float[][] rotationMatrix;
            @Generated
            private float[] below;
            @Generated
            private float[] above;

            @Generated
            QuantizerHelperResultBuilder() {
            }

            @Generated
            public QuantizerHelperResultBuilder thresholds(@NonNull float[][] thresholds) {
                if (thresholds == null) {
                    throw new NullPointerException("thresholds is marked non-null but is null");
                }
                this.thresholds = thresholds;
                return this;
            }

            @Generated
            public QuantizerHelperResultBuilder rotationMatrix(float[][] rotationMatrix) {
                this.rotationMatrix = rotationMatrix;
                return this;
            }

            @Generated
            public QuantizerHelperResultBuilder below(float[] below) {
                this.below = below;
                return this;
            }

            @Generated
            public QuantizerHelperResultBuilder above(float[] above) {
                this.above = above;
                return this;
            }

            /**
             * Builds the result.
             *
             * @throws NullPointerException if {@code thresholds} was never set
             */
            @Generated
            public QuantizerHelperResult build() {
                return new QuantizerHelperResult(this.thresholds, this.rotationMatrix, this.below, this.above);
            }

            @Override
            @Generated
            public String toString() {
                return "QuantizerHelper.QuantizerHelperResult.QuantizerHelperResultBuilder(thresholds="
                    + Arrays.deepToString(this.thresholds)
                    + ", rotationMatrix="
                    + Arrays.deepToString(this.rotationMatrix)
                    + ", below="
                    + Arrays.toString(this.below)
                    + ", above="
                    + Arrays.toString(this.above)
                    + ")";
            }
        }
    }
}

