/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.evaluator;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.training.evaluator.AbstractAccuracy;
import ai.djl.util.Pair;

/**
 * Accuracy evaluator that compares the arg-max of the predictions along {@code axis} against the
 * labels and reports how many of them match.
 */
public class Accuracy extends AbstractAccuracy {

    /** Creates an evaluator named {@code "Accuracy"} that reduces predictions along axis 1. */
    public Accuracy() {
        this("Accuracy", 1);
    }

    /**
     * Creates an evaluator with the given name that reduces predictions along axis 1.
     *
     * @param name the evaluator name, forwarded to {@link AbstractAccuracy}
     */
    public Accuracy(String name) {
        this(name, 1);
    }

    /**
     * Creates an evaluator with the given name and reduction axis.
     *
     * @param name the evaluator name, forwarded to {@link AbstractAccuracy}
     * @param axis the axis along which predictions (and, when shapes match, labels) are arg-maxed
     */
    public Accuracy(String name, int axis) {
        super(name, axis);
    }

    /**
     * Counts how many arg-max predictions agree with the labels.
     *
     * <p>Only the first array of each list is used. If the label and prediction shapes differ, the
     * reduced predictions are reshaped to the label shape, so the labels are presumably class
     * indices. If the shapes are equal, the labels are arg-maxed along the same axis, presumably
     * because they are one-hot or score-style labels.
     *
     * @param labels the ground-truth arrays; only {@code labels.head()} is read
     * @param predictions the prediction arrays; only {@code predictions.head()} is read
     * @return a pair of (number of elements in the possibly reduced labels, array holding the
     *     count of matching positions)
     */
    @Override
    protected Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions) {
        NDArray truth = labels.head();
        NDArray scores = predictions.head();
        this.checkLabelShapes(truth, scores);

        boolean shapesMatch = truth.getShape().equals(scores.getShape());
        // Both branches start from the index of the highest score along the configured axis.
        NDArray predictedClasses = scores.argMax(this.axis);
        if (shapesMatch) {
            truth = truth.argMax(this.axis);
        } else {
            predictedClasses = predictedClasses.reshape(truth.getShape());
        }

        long sampleCount = truth.size();
        // The INT64 label copy is a temporary that is released when the try block exits.
        try (NDArray truthAsLong = truth.toType(DataType.INT64, true)) {
            NDArray matches = predictedClasses.toType(DataType.INT64, false).eq(truthAsLong);
            return new Pair<>(sampleCount, matches.countNonzero());
        }
    }
}

